From 24dea7c32746f1d27c9e45702f95031759a40dce Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 5 Feb 2017 06:25:48 -0600 Subject: [PATCH 1/3] Reductions preserve the AxisArray wrapper (fixes #55) --- src/core.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ test/core.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/src/core.jl b/src/core.jl index ccdc4d8..0f3a9ec 100644 --- a/src/core.jl +++ b/src/core.jl @@ -280,6 +280,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S end end +# These methods allow us to preserve the AxisArray under reductions +# Note that we only extend the following two methods, and then have it +# dispatch to package-local `reduced_indices` and `reduced_indices0` +# methods. This avoids a whole slew of ambiguities. +Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region) +Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices(axs, (region,)) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) = + reduced_indices0(axs, (region,)) + +reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + map((ax,d)->d∈region ? reduced_axis(ax) : ax, axs, ntuple(identity, Val{N})) +reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = + map((ax,d)->d∈region ? reduced_axis0(ax) : ax, axs, ntuple(identity, Val{N})) + +@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis, (), region, axs...) +@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = + _reduced_indices(reduced_axis0, (), region, axs...) +@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis0, (), region, axs...) + +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices(reduced_indices(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) +reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = + reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) + +@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) = + _reduced_indices(f, (out..., f(ax)), chosen, axs...) +@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) = + _reduced_indices(f, (out..., ax), chosen, axs...) +_reduced_indices(f, out, chosen) = out + +reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1))) +reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1))) + + function Base.permutedims(A::AxisArray, perm) p = permutation(perm, axisnames(A)) AxisArray(permutedims(A.data, p), axes(A)[[p...]]) diff --git a/test/core.jl b/test/core.jl index 5c15565..b29c28a 100644 --- a/test/core.jl +++ b/test/core.jl @@ -225,3 +225,45 @@ map!(*, A2, A, A) @test isa(A2, AxisArray) @test A2.axes == A.axes @test A2.data == A.data .* A.data + +# Reductions (issue #55) +A = AxisArray(collect(reshape(1:15,3,5)), :y, :x) +B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50)) +for C in (A, B) + for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0 + axv = axisvalues(C) + C1 = @inferred(op(C, 1)) + @test typeof(C1) == typeof(C) + @test axisnames(C1) == (:y,:x) + @test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2]) + C2 = op(C, 2) + @test typeof(C2) == typeof(C) + @test axisnames(C2) == (:y,:x) + @test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1))) + C12 = @inferred(op(C, (1,2))) + @test typeof(C12) == typeof(C) + @test axisnames(C12) == (:y,:x) + @test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1))) + if op == sum + @test C1 == [6 15 24 33 42] + @test C2 == reshape([35,40,45], 3, 1) + @test C12 == reshape([120], 1, 1) + else + @test C1 == [1 4 7 10 13] + @test C2 == reshape([1,2,3], 3, 1) + @test C12 == reshape([1], 1, 1) + end + C1t = @inferred(op(C, Axis{:y})) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x})) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y},Axis{:x}))) + @test C12t == C12 + C1t = @inferred(op(C, Axis{:y}())) + @test C1t == C1 + C2t = @inferred(op(C, Axis{:x}())) + @test C2t == C2 + C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}()))) + @test C12t == C12 + end +end From d1857f1ef57eb401964294089e69419f06fbf2c4 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 5 Feb 2017 07:07:25 -0600 Subject: [PATCH 2/3] Add version check for reduced_indices --- src/core.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/core.jl b/src/core.jl index 0f3a9ec..41ffb8e 100644 --- a/src/core.jl +++ b/src/core.jl @@ -284,8 +284,13 @@ end # Note that we only extend the following two methods, and then have it # dispatch to package-local `reduced_indices` and `reduced_indices0` # methods. This avoids a whole slew of ambiguities. -Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region) -Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region) +if VERSION == v"0.5.0" + Base.reduced_dims(A::AxisArray, region) = reduced_indices(axes(A), region) + Base.reduced_dims0(A::AxisArray, region) = reduced_indices0(axes(A), region) +else + Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region) + Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region) +end reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs From ff9573225154744dc7a92fca9f1d35858acedf82 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 5 Feb 2017 08:17:14 -0600 Subject: [PATCH 3/3] Work around inference hang when inlining is off Works around https://github.com/JuliaLang/julia/issues/20714 --- src/core.jl | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/core.jl b/src/core.jl index 41ffb8e..62f2167 100644 --- a/src/core.jl +++ b/src/core.jl @@ -306,10 +306,10 @@ reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) = @inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = _reduced_indices(reduced_axis, (), region, axs...) -@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) = - _reduced_indices(reduced_axis, (), region, axs...) @inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) = _reduced_indices(reduced_axis0, (), region, axs...) +@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) = + _reduced_indices(reduced_axis, (), region, axs...) @inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) = _reduced_indices(reduced_axis0, (), region, axs...) @@ -322,13 +322,22 @@ reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple) = reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) = reduced_indices0(reduced_indices0(axs, region[1]), tail(region)) -@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) = - _reduced_indices(f, (out..., f(ax)), chosen, axs...) -@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) = +@pure samesym{n1,n2}(::Type{Axis{n1}}, ::Type{Axis{n2}}) = Val{n1==n2}() +samesym{n1,n2,T1,T2}(::Type{Axis{n1,T1}}, ::Type{Axis{n2,T2}}) = samesym(Axis{n1},Axis{n2}) +samesym{n1,n2}(::Type{Axis{n1}}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2}) +samesym{n1,n2}(::Axis{n1}, ::Type{Axis{n2}}) = samesym(Axis{n1}, Axis{n2}) +samesym{n1,n2}(::Axis{n1}, ::Axis{n2}) = samesym(Axis{n1}, Axis{n2}) + +@inline _reduced_indices{Ax<:Axis}(f, out, chosen::Type{Ax}, ax::Axis, axs...) = + __reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs) +@inline _reduced_indices(f, out, chosen::Axis, ax::Axis, axs...) = + __reduced_indices(f, out, samesym(chosen, ax), chosen, ax, axs) +_reduced_indices(f, out, chosen) = out + +@inline __reduced_indices(f, out, ::Val{true}, chosen, ax, axs) = _reduced_indices(f, (out..., f(ax)), chosen, axs...) -@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) = +@inline __reduced_indices(f, out, ::Val{false}, chosen, ax, axs) = _reduced_indices(f, (out..., ax), chosen, axs...) -_reduced_indices(f, out, chosen) = out reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1))) reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))