From 6a6c4e451c24740f84d78631045eb273d8228c26 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 5 Feb 2017 06:25:48 -0600 Subject: [PATCH] Reductions preserve the AxisArray wrapper (fixes #55) --- src/core.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ test/core.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/src/core.jl b/src/core.jl index f6e670f..c4cbd84 100644 --- a/src/core.jl +++ b/src/core.jl @@ -280,6 +280,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S end end +# These methods allow us to preserve the AxisArray under reductions +# Note that we only extend the following two methods, and then have it +# dispatch to package-local `reduced_indices` and `reduced_indices0` +# methods. This avoids a whole slew of ambiguities. +Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region) +Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices(axs, (region,)) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices0(axs, (region,)) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + ntuple(d->d∈region ? reduced_axis(axs[d]) : axs[d], Val{N}) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + ntuple(d->d∈region ? reduced_axis0(axs[d]) : axs[d], Val{N}) + +@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis0, (), region, axs...) +@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis0, (), region, axs...) + +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) + +@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) = + _reduced_indices(f, (out..., ax), chosen, axs...) +_reduced_indices(f, out, chosen) = out + +reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1))) +reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1))) + + function Base.permutedims(A::AxisArray, perm) p = permutation(perm, axisnames(A)) AxisArray(permutedims(A.data, p), axes(A)[[p...]]) diff --git a/test/core.jl b/test/core.jl index c608412..94eba03 100644 --- a/test/core.jl +++ b/test/core.jl @@ -192,3 +192,47 @@ A[0] = 12 A = AxisArray(OffsetArrays.OffsetArray(rand(4,5), -1:2, 5:9), :x, :y) @test indices(A) == (-1:2, 5:9) @test linearindices(A) == 1:20 + +# Reductions (issue #55) +A = AxisArray(collect(reshape(1:15,3,5)), :y, :x) +B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50)) +for C in (A, B) + for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0 + axv = axisvalues(C) + # C1 = @inferred(sum(C, 1)) + C1 = op(C, 1) + @test typeof(C1) == typeof(C) + @test axisnames(C1) == (:y,:x) + @test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2]) + C2 = op(C, 2) + @test typeof(C2) == typeof(C) + @test axisnames(C2) == (:y,:x) + @test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1))) + # C12 = @inferred(sum(C, (1,2))) + C12 = op(C, (1,2)) + @test typeof(C12) == typeof(C) + @test axisnames(C12) == (:y,:x) + @test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1))) + if op == sum + @test C1 == [6 15 24 33 42] + @test C2 == reshape([35,40,45], 3, 1) + @test C12 == reshape([120], 1, 1) + else + @test C1 == [1 4 7 10 13] + @test C2 == reshape([1,2,3], 3, 1) + @test C12 == reshape([1], 1, 1) + end + C1t = @inferred(op(C, Axis{:y})) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x})) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y},Axis{:x}))) + @test C12t == C12 + C1t = @inferred(op(C, Axis{:y}())) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x}())) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}()))) + @test C12t == C12 + end +end