Skip to content

Commit

Permalink
Reductions preserve the AxisArray wrapper (fixes #55)
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Feb 21, 2017
1 parent 4d2f8cc commit 6a6c4e4
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
end
end

# These methods allow us to preserve the AxisArray under reductions
# Note that we only extend the following two methods, and then have it
# dispatch to package-local `reduced_indices` and `reduced_indices0`
# methods. This avoids a whole slew of ambiguities.
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices(axs, (region,))
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
reduced_indices0(axs, (region,))

reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
ntuple(d->dregion ? reduced_axis(axs[d]) : axs[d], Val{N})
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
ntuple(d->dregion ? reduced_axis0(axs[d]) : axs[d], Val{N})

@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis, (), region, axs...)
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
_reduced_indices(reduced_axis0, (), region, axs...)
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
_reduced_indices(reduced_axis0, (), region, axs...)

reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices(reduced_indices(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))

@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) =
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) =
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) =
_reduced_indices(f, (out..., ax), chosen, axs...)
_reduced_indices(f, out, chosen) = out

reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))


function Base.permutedims(A::AxisArray, perm)
p = permutation(perm, axisnames(A))
AxisArray(permutedims(A.data, p), axes(A)[[p...]])
Expand Down
44 changes: 44 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,47 @@ A[0] = 12
A = AxisArray(OffsetArrays.OffsetArray(rand(4,5), -1:2, 5:9), :x, :y)
@test indices(A) == (-1:2, 5:9)
@test linearindices(A) == 1:20

# Reductions (issue #55)
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
for C in (A, B)
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
axv = axisvalues(C)
# C1 = @inferred(sum(C, 1))
C1 = op(C, 1)
@test typeof(C1) == typeof(C)
@test axisnames(C1) == (:y,:x)
@test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2])
C2 = op(C, 2)
@test typeof(C2) == typeof(C)
@test axisnames(C2) == (:y,:x)
@test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1)))
# C12 = @inferred(sum(C, (1,2)))
C12 = op(C, (1,2))
@test typeof(C12) == typeof(C)
@test axisnames(C12) == (:y,:x)
@test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1)))
if op == sum
@test C1 == [6 15 24 33 42]
@test C2 == reshape([35,40,45], 3, 1)
@test C12 == reshape([120], 1, 1)
else
@test C1 == [1 4 7 10 13]
@test C2 == reshape([1,2,3], 3, 1)
@test C12 == reshape([1], 1, 1)
end
C1t = @inferred(op(C, Axis{:y}))
@test C1t == C1
C2t = @inferred(op(C, Axis{:x}))
@test C2t == C2
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
@test C12t == C12
C1t = @inferred(op(C, Axis{:y}()))
@test C1t == C1
C2t = @inferred(op(C, Axis{:x}()))
@test C2t == C2
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
@test C12t == C12
end
end

0 comments on commit 6a6c4e4

Please sign in to comment.