Support and use broadcast with mapreduce.
maleadt committed May 7, 2020
commit 6e7560a1daea34da1e4359cb2a404a648087b26d
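In user-facing terms, this change lets reductions consume a lazy broadcast directly and routes multi-array mapreduce through the same machinery. A minimal sketch of the intended usage, assuming the reference JLArray backend from src/reference.jl is reachable as GPUArrays.JLArray:

using GPUArrays

x = GPUArrays.JLArray(rand(Float32, 4))
y = GPUArrays.JLArray(rand(Float32, 4))

# reduce a lazy broadcast expression without materializing `x .* y` first
bc = Broadcast.instantiate(Broadcast.broadcasted(*, x, y))
reduce(+, bc)

# plain mapreduce over a GPU array keeps working, now lowered through the
# same broadcast-based path
mapreduce(abs2, +, x)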
44 changes: 32 additions & 12 deletions src/host/mapreduce.jl
@@ -1,9 +1,14 @@
# map-reduce

const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}

# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
# argument `init` value to avoid eager initialization of `R` (if set to something).
mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
init=nothing) = error("Not implemented") # COV_EXCL_LINE
# resolve ambiguities
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)

neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
@@ -18,11 +23,30 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)

function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
# resolve ambiguities
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)

function _mapreduce(f, op, As...; dims, init)
# mapreduce should apply `f` like `map` does, consuming elements like iterators.
bc = if allequal(size.(As)...)
Broadcast.instantiate(Broadcast.broadcasted(f, As...))
else
# TODO: can we avoid the reshape + view?
indices = LinearIndices.(As)
common_length = minimum(length.(indices))
Bs = map(As) do A
view(reshape(A, length(A)), 1:common_length)
end
Broadcast.instantiate(Broadcast.broadcasted(f, Bs...))
end

# figure out the destination container type by looking at the initializer element,
# or by relying on inference to reason through the map and reduce functions.
if init === nothing
ET = Base.promote_op(f, map(eltype, As)...)
ET = Broadcast.combine_eltypes(bc.f, bc.args)
ET = Base.promote_op(op, ET, ET)
(ET === Union{} || ET === Any) &&
error("mapreduce cannot figure the output element type, please pass an explicit init value")
@@ -32,14 +56,10 @@ function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
ET = typeof(init)
end

# TODO: Broadcast-semantics after JuliaLang-julia#31020
A = first(As)
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))

sz = size(A)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
R = similar(A, ET, red)
mapreducedim!(f, op, R, As...; init=init)
sz = size(bc)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
R = similar(bc, ET, red)
mapreducedim!(identity, op, R, bc; init=init)

if dims==Colon()
@allowscalar R[]
@@ -57,7 +77,7 @@ Base.count(pred::Function, A::AbstractGPUArray) = mapreduce(pred, +, A; init = 0

Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B))

# avoid calling into `initarray!``
# avoid calling into `initarray!`
Base.sum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.add_sum, R, A)
Base.prod!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.mul_prod, R, A)
Base.maximum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(max, R, A)
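The size-mismatch branch in `_mapreduce` above linearizes every argument and truncates to the shortest length, matching the iterator-style semantics noted in the comment. A CPU-only sketch of what that fallback computes, with made-up values:

a = [1, 2]            # length 2
b = [10 20; 30 40]    # length 4; linearizes column-major to 10, 30, 20, 40

common_length = min(length(a), length(b))                     # 2
Bs = map(A -> view(reshape(A, length(A)), 1:common_length), (a, b))
bc = Broadcast.instantiate(Broadcast.broadcasted(+, Bs...))   # lazy (1+10, 2+30)
reduce(+, bc)                                                 # 43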
5 changes: 3 additions & 2 deletions src/reference.jl
@@ -298,11 +298,12 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
reshape(reinterpret(T, A.data), size)

function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=nothing)
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
init=nothing)
if init !== nothing
fill!(R, init)
end
@allowscalar Base.reducedim!(op, R.data, map(f, As...))
@allowscalar Base.reducedim!(op, R.data, map(f, A))
end

end
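The reference method above can be exercised directly. A hedged sketch, assuming JLArray's plain-Array storage sits in the `.data` field (as the method itself uses) and that `fill!` is implemented for GPU array types:

using GPUArrays

A = GPUArrays.JLArray(Float32[1 2; 3 4])
R = similar(A, Float32, (1, 2))
# reduce abs2 along the first dimension; `init` seeds R via fill!
GPUArrays.mapreducedim!(abs2, +, R, A; init=0f0)
R.data   # 1×2 Matrix{Float32}: 10.0  20.0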
9 changes: 9 additions & 0 deletions test/testsuite/mapreduce.jl
@@ -158,6 +158,15 @@ function test_mapreduce(AT)
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
end
end
@testset "broadcasting behavior" begin
@test compare((x,y)->mapreduce(+, +, x, y), AT,
rand(range, 1), rand(range, 2, 2))
@test compare(AT, rand(range, 1), rand(range, 2, 2)) do x, y
bc = Broadcast.instantiate(Broadcast.broadcasted(*, x, y))
reduce(+, bc)
end

end
end
end
@testset "any all ==" begin