Skip to content

Commit

Permalink
Merge pull request #56 from JuliaArrays/teh/reductions
Browse files Browse the repository at this point in the history
Reductions preserve the AxisArray wrapper (fixes #55)
  • Loading branch information
timholy committed Mar 25, 2017
2 parents c56dc40 + ff95732 commit 8f8272d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
63 changes: 63 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,69 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
end
end

# These methods allow us to preserve the AxisArray under reductions
# Note that we only extend the following two methods, and then have it
# dispatch to package-local `reduced_indices` and `reduced_indices0`
# methods. This avoids a whole slew of ambiguities.
if VERSION == v"0.5.0"
Base.reduced_dims(A::AxisArray, region) = reduced_indices(axes(A), region)
Base.reduced_dims0(A::AxisArray, region) = reduced_indices0(axes(A), region)
else
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)
end

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices(axs, (region,))
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices0(axs, (region,))

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
map((ax,d)->dregion ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N}))
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
map((ax,d)->dregion ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N}))

@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis0, (), region, axs...)
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis0, (), region, axs...)

reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))

@pure samesym{n1,n2}(::Type{Axis{n1}}, ::Type{Axis{n2}}) = Val{n1==n2}()
samesym{n1,n2,T1,T2}(::Type{Axis{n1,T1}}, ::Type{Axis{n2,T2}}) = samesym(Axis{n1},Axis{n2})
samesym{n1,n2}(::Type{Axis{n1}}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2})
samesym{n1,n2}(::Axis{n1}, ::Type{Axis{n2}}) = samesym(Axis{n1}, Axis{n2})
samesym{n1,n2}(::Axis{n1}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2})

@inline _reduced_indices{Ax<:Axis}(f, out, chosen::Type{Ax}, ax::Axis, axs...) =
__reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs)
@inline _reduced_indices(f, out, chosen::Axis, ax::Axis, axs...) =
__reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs)
_reduced_indices(f, out, chosen) = out

@inline __reduced_indices(f, out, ::Val{true}, chosen, ax, axs) =
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
@inline __reduced_indices(f, out, ::Val{false}, chosen, ax, axs) =
_reduced_indices(f, (out..., ax), chosen, axs...)

reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))


function Base.permutedims(A::AxisArray, perm)
p = permutation(perm, axisnames(A))
AxisArray(permutedims(A.data, p), axes(A)[[p...]])
Expand Down
42 changes: 42 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,45 @@ map!(*, A2, A, A)
@test isa(A2, AxisArray)
@test A2.axes == A.axes
@test A2.data == A.data .* A.data

# Reductions (issue #55)
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
for C in (A, B)
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
axv = axisvalues(C)
C1 = @inferred(op(C, 1))
@test typeof(C1) == typeof(C)
@test axisnames(C1) == (:y,:x)
@test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2])
C2 = op(C, 2)
@test typeof(C2) == typeof(C)
@test axisnames(C2) == (:y,:x)
@test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1)))
C12 = @inferred(op(C, (1,2)))
@test typeof(C12) == typeof(C)
@test axisnames(C12) == (:y,:x)
@test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1)))
if op == sum
@test C1 == [6 15 24 33 42]
@test C2 == reshape([35,40,45], 3, 1)
@test C12 == reshape([120], 1, 1)
else
@test C1 == [1 4 7 10 13]
@test C2 == reshape([1,2,3], 3, 1)
@test C12 == reshape([1], 1, 1)
end
C1t = @inferred(op(C, Axis{:y}))
@test C1t == C1
C2t = @inferred(op(C, Axis{:x}))
@test C2t == C2
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
@test C12t == C12
C1t = @inferred(op(C, Axis{:y}()))
@test C1t == C1
C2t = @inferred(op(C, Axis{:x}()))
@test C2t == C2
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
@test C12t == C12
end
end

0 comments on commit 8f8272d

Please sign in to comment.