From 24dea7c32746f1d27c9e45702f95031759a40dce Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 5 Feb 2017 06:25:48 -0600 Subject: [PATCH] Reductions preserve the AxisArray wrapper (fixes #55) --- src/core.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ test/core.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/src/core.jl b/src/core.jl index ccdc4d8..0f3a9ec 100644 --- a/src/core.jl +++ b/src/core.jl @@ -280,6 +280,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S end end +# These methods allow us to preserve the AxisArray under reductions +# Note that we only extend the following two methods, and then have it +# dispatch to package-local `reduced_indices` and `reduced_indices0` +# methods. This avoids a whole slew of ambiguities. +Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region) +Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices(axs, (region,)) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices0(axs, (region,)) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + map((ax,d)->d∈region ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N})) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + map((ax,d)->d∈region ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N})) + +@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis0, (), region, axs...) +@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis0, (), region, axs...) + +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) + +@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) = + _reduced_indices(f, (out..., ax), chosen, axs...) +_reduced_indices(f, out, chosen) = out + +reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1))) +reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1))) + + function Base.permutedims(A::AxisArray, perm) p = permutation(perm, axisnames(A)) AxisArray(permutedims(A.data, p), axes(A)[[p...]]) diff --git a/test/core.jl b/test/core.jl index 5c15565..b29c28a 100644 --- a/test/core.jl +++ b/test/core.jl @@ -225,3 +225,45 @@ map!(*, A2, A, A) @test isa(A2, AxisArray) @test A2.axes == A.axes @test A2.data == A.data .* A.data + +# Reductions (issue #55) +A = AxisArray(collect(reshape(1:15,3,5)), :y, :x) +B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50)) +for C in (A, B) + for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0 + axv = axisvalues(C) + C1 = @inferred(op(C, 1)) + @test typeof(C1) == typeof(C) + @test axisnames(C1) == (:y,:x) + @test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2]) + C2 = op(C, 2) + @test typeof(C2) == typeof(C) + @test axisnames(C2) == (:y,:x) + @test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1))) + C12 = @inferred(op(C, (1,2))) + @test typeof(C12) == typeof(C) + @test axisnames(C12) == (:y,:x) + @test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1))) + if op == sum + @test C1 == [6 15 24 33 42] + @test C2 == reshape([35,40,45], 3, 1) + @test C12 == reshape([120], 1, 1) + else + @test C1 == [1 4 7 10 13] + @test C2 == reshape([1,2,3], 3, 1) + @test C12 == reshape([1], 1, 1) + end + C1t = @inferred(op(C, Axis{:y})) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x})) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y},Axis{:x}))) + @test C12t == C12 + C1t = @inferred(op(C, Axis{:y}())) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x}())) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}()))) + @test C12t == C12 + end +end