From 487bd8372a49bd3dc5f531d6591f3e265376bb1b Mon Sep 17 00:00:00 2001 From: Jonnie Diegelman Date: Sat, 2 Jul 2022 22:54:44 -0400 Subject: [PATCH 1/3] Slightly faster construction --- Project.toml | 2 +- src/componentarray.jl | 38 +++++++++++++++++--------------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 14a7fe21..400c08d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.12.2" +version = "0.12.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/componentarray.jl b/src/componentarray.jl index ea09fa84..b45c4651 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -143,48 +143,46 @@ function make_carray_args(nt) end make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt) function make_carray_args(A::Type{<:AbstractArray}, nt) - data, idx = make_idx([], nt, 0) + data, idx = build_data!([], nt, 0) return (A(data), Axis(idx)) end # Builds up data vector and returns appropriate AbstractAxis type for each input type -function make_idx(data, nt::NamedTuple, last_val) +function build_data!(data, nt::NamedTuple, last_val) len = recursive_length(nt) - kvs = [] lv = 0 - for (k,v) in zip(keys(nt), values(nt)) - (_,val) = make_idx(data, v, lv) - push!(kvs, k => val) - lv = val + kvs = map(nt) do v + (_, lv) = build_data!(data, v, lv) + lv end - return (data, ViewAxis(last_index(last_val) .+ (1:len), (;kvs...))) + return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs)) end -function make_idx(data, pair::Pair, last_val) - data, ax = make_idx(data, pair.second, last_val) +function build_data!(data, pair::Pair, last_val) + data, ax = build_data!(data, pair.second, last_val) return (data, ViewAxis(last_val:(last_val+len-1), Axis(pair.second))) end -make_idx(data, x, last_val) = ( +build_data!(data, x, last_val) = ( push!(data, x), ViewAxis(last_index(last_val) + 1) ) -make_idx(data, x::ComponentVector, last_val) = ( - pushcat!(data, x), +build_data!(data, x::ComponentVector, last_val) = ( + append!(data, x), ViewAxis( last_index(last_val) .+ (1:length(x)), getaxes(x)[1] ) ) -function make_idx(data, x::AbstractArray, last_val) - pushcat!(data, x) +function build_data!(data, x::AbstractArray, last_val) + append!(data, x) out = last_index(last_val) .+ (1:length(x)) return (data, ViewAxis(out, ShapedAxis(size(x)))) end -function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}} +function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}} len = recursive_length(x) if eltype(x) |> isconcretetype out = () for elem in x - (_,out) = make_idx(data, elem, last_val) + (_,out) = build_data!(data, elem, last_val) end return ( data, @@ -200,7 +198,7 @@ function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTup error("Only homogeneous arrays of inner ComponentArrays are allowed.") end end -function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}} +function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}} error("ComponentArrays cannot currently contain arrays of arrays as elements. This one contains: \n $x\n") end @@ -209,7 +207,7 @@ end _maybe_add_field(x, pair) = haskey(x, pair.first) ? _update_field(x, pair) : _add_field(x, pair) function _add_field(x, pair) data = copy(getdata(x)) - new_data, new_ax = make_idx(data, pair.second, length(data)) + new_data, new_ax = build_data!(data, pair.second, length(data)) new_ax = Axis(NamedTuple{tuple(pair.first)}(tuple(new_ax))) new_ax = merge(getaxes(x)[1], new_ax) return ComponentArray(new_data, new_ax) @@ -220,8 +218,6 @@ function _update_field(x, pair) return x_copy end -pushcat!(a, b) = reduce((x1,x2) -> push!(x1,x2), b; init=a) - # Reshape ComponentArrays with ShapedAxis axes maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data function maybe_reshape(data, axs::AbstractAxis...) From a28628610bac361647dd94e8675f7e02f7eb4ab6 Mon Sep 17 00:00:00 2001 From: Jonnie Diegelman Date: Sat, 2 Jul 2022 22:54:44 -0400 Subject: [PATCH 2/3] Slightly faster construction --- Project.toml | 2 +- src/componentarray.jl | 37 ++++++++++++++++--------------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 14a7fe21..400c08d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.12.2" +version = "0.12.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/componentarray.jl b/src/componentarray.jl index ea09fa84..e6ae6346 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -143,48 +143,45 @@ function make_carray_args(nt) end make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt) function make_carray_args(A::Type{<:AbstractArray}, nt) - data, idx = make_idx([], nt, 0) + data, idx = build_data!([], nt, 0) return (A(data), Axis(idx)) end # Builds up data vector and returns appropriate AbstractAxis type for each input type -function make_idx(data, nt::NamedTuple, last_val) +function build_data!(data, nt::NamedTuple, last_val) len = recursive_length(nt) - kvs = [] lv = 0 - for (k,v) in zip(keys(nt), values(nt)) - (_,val) = make_idx(data, v, lv) - push!(kvs, k => val) - lv = val + kvs = map(nt) do v + lv = build_data!(data, v, lv)[2] end - return (data, ViewAxis(last_index(last_val) .+ (1:len), (;kvs...))) + return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs)) end -function make_idx(data, pair::Pair, last_val) - data, ax = make_idx(data, pair.second, last_val) +function build_data!(data, pair::Pair, last_val) + data, ax = build_data!(data, pair.second, last_val) return (data, ViewAxis(last_val:(last_val+len-1), Axis(pair.second))) end -make_idx(data, x, last_val) = ( +build_data!(data, x, last_val) = ( push!(data, x), ViewAxis(last_index(last_val) + 1) ) -make_idx(data, x::ComponentVector, last_val) = ( - pushcat!(data, x), +build_data!(data, x::ComponentVector, last_val) = ( + append!(data, x), ViewAxis( last_index(last_val) .+ (1:length(x)), getaxes(x)[1] ) ) -function make_idx(data, x::AbstractArray, last_val) - pushcat!(data, x) +function build_data!(data, x::AbstractArray, last_val) + append!(data, x) out = last_index(last_val) .+ (1:length(x)) return (data, ViewAxis(out, ShapedAxis(size(x)))) end -function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}} +function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}} len = recursive_length(x) if eltype(x) |> isconcretetype out = () for elem in x - (_,out) = make_idx(data, elem, last_val) + (_,out) = build_data!(data, elem, last_val) end return ( data, @@ -200,7 +197,7 @@ function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTup error("Only homogeneous arrays of inner ComponentArrays are allowed.") end end -function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}} +function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}} error("ComponentArrays cannot currently contain arrays of arrays as elements. This one contains: \n $x\n") end @@ -209,7 +206,7 @@ end _maybe_add_field(x, pair) = haskey(x, pair.first) ? _update_field(x, pair) : _add_field(x, pair) function _add_field(x, pair) data = copy(getdata(x)) - new_data, new_ax = make_idx(data, pair.second, length(data)) + new_data, new_ax = build_data!(data, pair.second, length(data)) new_ax = Axis(NamedTuple{tuple(pair.first)}(tuple(new_ax))) new_ax = merge(getaxes(x)[1], new_ax) return ComponentArray(new_data, new_ax) @@ -220,8 +217,6 @@ function _update_field(x, pair) return x_copy end -pushcat!(a, b) = reduce((x1,x2) -> push!(x1,x2), b; init=a) - # Reshape ComponentArrays with ShapedAxis axes maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data function maybe_reshape(data, axs::AbstractAxis...) From 999c3353f9d90d08607fdfd86cd3428480bfe715 Mon Sep 17 00:00:00 2001 From: Jonnie Diegelman Date: Sun, 17 Jul 2022 11:25:36 -0400 Subject: [PATCH 3/3] Fixed type promotion in getproperty adjoint. Fixes #148 --- Project.toml | 2 +- src/compat/chainrulescore.jl | 2 +- test/autodiff_tests.jl | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 400c08d1..9102fd3b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.12.3" +version = "0.12.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index be876232..7709ef7d 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -1,6 +1,6 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val}) function getproperty_adjoint(Δ) - zero_x = zero(x) + zero_x = ComponentArray(zeros(eltype(Δ), size(x)), getaxes(x)) setproperty!(zero_x, s, Δ) return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent()) end diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index ee7cee27..2f65b070 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -43,6 +43,11 @@ truth = ComponentArray(a = [32, 48], x = 156) (;c...,).x^2 end[1] end + + # Issue #148 + ps = ComponentArray(;bias = rand(4)) + out = Zygote.gradient(x -> sum(x.^3 .+ ps.bias), Zygote.seed(rand(4),Val(12)))[1] + @test out isa Vector{<:ForwardDiff.Dual} end