Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComponentArrays"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
version = "0.12.2"
version = "0.12.4"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
2 changes: 1 addition & 1 deletion src/compat/chainrulescore.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val})
function getproperty_adjoint(Δ)
zero_x = zero(x)
zero_x = ComponentArray(zeros(eltype(Δ), size(x)), getaxes(x))
setproperty!(zero_x, s, Δ)
return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent())
end
Expand Down
37 changes: 16 additions & 21 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,48 +143,45 @@ function make_carray_args(nt)
end
make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt)
function make_carray_args(A::Type{<:AbstractArray}, nt)
data, idx = make_idx([], nt, 0)
data, idx = build_data!([], nt, 0)
return (A(data), Axis(idx))
end

# Builds up data vector and returns appropriate AbstractAxis type for each input type
function make_idx(data, nt::NamedTuple, last_val)
function build_data!(data, nt::NamedTuple, last_val)
len = recursive_length(nt)
kvs = []
lv = 0
for (k,v) in zip(keys(nt), values(nt))
(_,val) = make_idx(data, v, lv)
push!(kvs, k => val)
lv = val
kvs = map(nt) do v
lv = build_data!(data, v, lv)[2]
end
return (data, ViewAxis(last_index(last_val) .+ (1:len), (;kvs...)))
return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs))
end
function make_idx(data, pair::Pair, last_val)
data, ax = make_idx(data, pair.second, last_val)
function build_data!(data, pair::Pair, last_val)
data, ax = build_data!(data, pair.second, last_val)
return (data, ViewAxis(last_val:(last_val+len-1), Axis(pair.second)))
end
make_idx(data, x, last_val) = (
build_data!(data, x, last_val) = (
push!(data, x),
ViewAxis(last_index(last_val) + 1)
)
make_idx(data, x::ComponentVector, last_val) = (
pushcat!(data, x),
build_data!(data, x::ComponentVector, last_val) = (
append!(data, x),
ViewAxis(
last_index(last_val) .+ (1:length(x)),
getaxes(x)[1]
)
)
function make_idx(data, x::AbstractArray, last_val)
pushcat!(data, x)
function build_data!(data, x::AbstractArray, last_val)
append!(data, x)
out = last_index(last_val) .+ (1:length(x))
return (data, ViewAxis(out, ShapedAxis(size(x))))
end
function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}}
function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}}
len = recursive_length(x)
if eltype(x) |> isconcretetype
out = ()
for elem in x
(_,out) = make_idx(data, elem, last_val)
(_,out) = build_data!(data, elem, last_val)
end
return (
data,
Expand All @@ -200,7 +197,7 @@ function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:Union{NamedTup
error("Only homogeneous arrays of inner ComponentArrays are allowed.")
end
end
function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}}
function build_data!(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}}
error("ComponentArrays cannot currently contain arrays of arrays as elements. This one contains: \n $x\n")
end

Expand All @@ -209,7 +206,7 @@ end
_maybe_add_field(x, pair) = haskey(x, pair.first) ? _update_field(x, pair) : _add_field(x, pair)
function _add_field(x, pair)
data = copy(getdata(x))
new_data, new_ax = make_idx(data, pair.second, length(data))
new_data, new_ax = build_data!(data, pair.second, length(data))
new_ax = Axis(NamedTuple{tuple(pair.first)}(tuple(new_ax)))
new_ax = merge(getaxes(x)[1], new_ax)
return ComponentArray(new_data, new_ax)
Expand All @@ -220,8 +217,6 @@ function _update_field(x, pair)
return x_copy
end

pushcat!(a, b) = reduce((x1,x2) -> push!(x1,x2), b; init=a)

# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
Expand Down
5 changes: 5 additions & 0 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ truth = ComponentArray(a = [32, 48], x = 156)
(;c...,).x^2
end[1]
end

# Issue #148
ps = ComponentArray(;bias = rand(4))
out = Zygote.gradient(x -> sum(x.^3 .+ ps.bias), Zygote.seed(rand(4),Val(12)))[1]
@test out isa Vector{<:ForwardDiff.Dual}
end


Expand Down