Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ end
# Entry from NamedTuple, Dict, or kwargs
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))
ComponentArray(nt::NamedTuple) = ComponentArray(make_carray_args(nt)...)
ComponentArray(nt::Union{NamedTuple, AbstractDict}) = ComponentArray(make_carray_args(nt)...)
ComponentArray(::NamedTuple{(), Tuple{}}) = ComponentArray(Any[], (FlatAxis(),))
ComponentArray(d::AbstractDict) = ComponentArray(NamedTuple{Tuple(keys(d))}(values(d)))
ComponentArray{T}(;kwargs...) where T = ComponentArray{T}((;kwargs...))
ComponentArray(;kwargs...) = ComponentArray((;kwargs...))

Expand Down Expand Up @@ -138,17 +137,22 @@ make_carray_args(::NamedTuple{(), Tuple{}}) = (Any[], FlatAxis())
make_carray_args(::Type{T}, ::NamedTuple{(), Tuple{}}) where {T} = (T[], FlatAxis())
function make_carray_args(nt)
data, ax = make_carray_args(Vector, nt)
data = length(data)==1 ? [data[1]] : reduce(vcat, data)
data = length(data)==1 ? [data[1]] : map(identity, data)
return (data, ax)
end
make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt)
function make_carray_args(A::Type{<:AbstractArray}, nt)
data, idx = make_idx([], nt, 0)
T = recursive_eltype(nt)
init = _isbitstype(T) ? T[] : []
data, idx = make_idx(init, nt, 0)
return (A(data), Axis(idx))
end

_isbitstype(::Type{<:Union{T, Nothing, Missing}}) where {T} = isbitstype(T)
_isbitstype(T) = isbitstype(T)

# Builds up data vector and returns appropriate AbstractAxis type for each input type
function make_idx(data, nt::NamedTuple, last_val)
function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val)
len = recursive_length(nt)
kvs = []
lv = 0
Expand Down Expand Up @@ -325,4 +329,4 @@ julia> sum(prod(ca[k]) for k in valkeys(ca))
k = Val.(keys(idxmap))
return :($k)
end
valkeys(ca::ComponentVector) = valkeys(getaxes(ca)[1])
valkeys(ca::ComponentVector) = valkeys(getaxes(ca)[1])
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ recursive_length(a::AbstractArray{T,N}) where {T<:Number,N} = length(a)
recursive_length(a::AbstractArray) = recursive_length.(a) |> sum
recursive_length(nt::NamedTuple) = values(nt) .|> recursive_length |> sum
recursive_length(::Union{Nothing, Missing}) = 1

# Find the highest element type
recursive_eltype(nt::NamedTuple) = mapreduce(recursive_eltype, promote_type, nt)
recursive_eltype(x::Vector) = mapreduce(recursive_eltype, promote_type, x)
recursive_eltype(x::Dict) = mapreduce(recursive_eltype, promote_type, values(x))
recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N}= T
recursive_eltype(x) = typeof(x)