Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 274 additions & 13 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,35 +388,88 @@ end
@inline Base.IteratorSize(::Type{<:AbstractVectorOfArray}) = Base.HasLength()
@inline Base.first(VA::AbstractVectorOfArray) = first(VA.u)
@inline Base.last(VA::AbstractVectorOfArray) = last(VA.u)
function Base.firstindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
function Base.firstindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A}
N > 1 && Base.depwarn(
"Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ",
:firstindex)
return firstindex(VA.u)
end

function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
N > 1 && Base.depwarn(
function Base.lastindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A}
N > 1 && Base.depwarn(
"Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ",
:lastindex)
return lastindex(VA.u)
end

# Always return RaggedEnd for type stability. Use dim=0 to indicate a plain index stored in offset.
# _resolve_ragged_index and _column_indices handle the dim=0 case to extract the actual index value.
@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
if d == ndims(VA)
return RaggedEnd(0, Int(lastindex(VA.u)))
elseif d < ndims(VA)
isempty(VA.u) && return RaggedEnd(0, 0)
return RaggedEnd(Int(d), 0)
else
return RaggedEnd(0, 1)
end
end

Base.getindex(A::AbstractVectorOfArray, I::Int) = A.u[I]
Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) = A.u[I]
Base.getindex(A::AbstractDiffEqArray, I::Int) = A.u[I]
Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I]

@deprecate Base.getindex(VA::AbstractVectorOfArray{T,N,A}, I::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false
@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A},
I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractVectorOfArray{T,N,A}, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false
@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A},
I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractDiffEqArray{T,N,A}, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false
@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A},
I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractDiffEqArray{T,N,A}, i::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[i] false
@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A},
i::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[i] false

__parameterless_type(T) = Base.typename(T).wrapper

# `end` support for ragged inner arrays
# Use runtime fields instead of type parameters for type stability
struct RaggedEnd
dim::Int
offset::Int
end
RaggedEnd(dim::Int) = RaggedEnd(dim, 0)

Base.:+(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset + Int(n))
Base.:-(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset - Int(n))
Base.:+(n::Integer, re::RaggedEnd) = re + n

struct RaggedRange
dim::Int
start::Int
step::Int
offset::Int
end

Base.:(:)(stop::RaggedEnd) = RaggedRange(stop.dim, 1, 1, stop.offset)
function Base.:(:)(start::Integer, stop::RaggedEnd)
RaggedRange(stop.dim, Int(start), 1, stop.offset)
end
function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
RaggedRange(stop.dim, Int(start), Int(step), stop.offset)
end

@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
length(VA.u) <= 1 && return false
first_size = size(VA.u[1], d)
@inbounds for idx in 2:length(VA.u)
size(VA.u[idx], d) == first_size || return true
end
return false
end

Base.@propagate_inbounds function _getindex(
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
A.u[I]
Expand Down Expand Up @@ -487,11 +540,206 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
return getindex(A, all_variable_symbols(A), args...)
end

@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ?
eachindex(VA.u) : idx
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
findall(idx)
end
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedEnd)
# RaggedEnd with dim=0 means it's just a plain index stored in offset
idx.dim == 0 ? idx.offset : idx
end

@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
if idx.dim == 0
# Special case: dim=0 means the offset contains the actual index value
return idx.offset
else
return lastindex(VA.u[col], idx.dim) + idx.offset
end
end
@inline function _resolve_ragged_index(idx::RaggedRange, VA::AbstractVectorOfArray, col)
stop_val = if idx.dim == 0
# dim == 0 is the sentinel for an already-resolved plain index stored in offset
idx.offset
else
lastindex(VA.u[col], idx.dim) + idx.offset
end
return Base.range(idx.start; step = idx.step, stop = stop_val)
end
@inline function _resolve_ragged_index(
idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return Base.range(_resolve_ragged_index(first(idx), VA, col); step = step(idx),
stop = _resolve_ragged_index(last(idx), VA, col))
end
@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col)
return Base.Slice(_resolve_ragged_index(idx.indices, VA, col))
end
@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col)
return CartesianIndex(_resolve_ragged_indices(Tuple(idx), VA, col)...)
end
@inline function _resolve_ragged_index(
idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(
idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col)
return _has_ragged_end(idx) ? map(i -> _resolve_ragged_index(i, VA, col), idx) : idx
end

@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col)
map(i -> _resolve_ragged_index(i, VA, col), idxs)
end

@inline function _has_ragged_end(x)
x isa RaggedEnd && return true
x isa RaggedRange && return true
x isa Base.Slice && return _has_ragged_end(x.indices)
x isa CartesianIndex && return _has_ragged_end(Tuple(x))
x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange}
if x isa AbstractArray
el = eltype(x)
return el <: Union{RaggedEnd, RaggedRange} ||
(el === Any && any(_has_ragged_end, x))
end
x isa Tuple && return any(_has_ragged_end, x)
return false
end
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)

@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
n = ndims(A)
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
if length(I) == n - 1
raw_cols = last(I)
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
lastindex(A.u) + raw_cols.offset
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
stop_val = lastindex(A.u) + raw_cols.offset
Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
else
_column_indices(A, raw_cols)
end
prefix = Base.front(I)
if cols isa Int
resolved_prefix = _resolve_ragged_indices(prefix, A, cols)
inner_nd = ndims(A.u[cols])
n_missing = inner_nd - length(resolved_prefix)
padded = if n_missing > 0
if all(idx -> idx === Colon(), resolved_prefix)
(resolved_prefix..., ntuple(_ -> Colon(), n_missing)...)
else
(resolved_prefix...,
(lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)...)
end
else
resolved_prefix
end
return A.u[cols][padded...]
else
return VectorOfArray([begin
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
inner_nd = ndims(A.u[col])
n_missing = inner_nd - length(resolved_prefix)
padded = if n_missing > 0
if all(idx -> idx === Colon(), resolved_prefix)
(resolved_prefix...,
ntuple(_ -> Colon(), n_missing)...)
else
(resolved_prefix...,
(lastindex(A.u[col],
length(resolved_prefix) + i) for i in 1:n_missing)...)
end
else
resolved_prefix
end
A.u[col][padded...]
end
for col in cols])
end
end

# Otherwise, use the full-length interpretation (last index is column selector; missing columns default to Colon()).
if length(I) == n
cols = last(I)
prefix = Base.front(I)
else
cols = Colon()
prefix = I
end
if cols isa Int
if all(idx -> idx === Colon(), prefix)
return A.u[cols]
end
resolved = _resolve_ragged_indices(prefix, A, cols)
inner_nd = ndims(A.u[cols])
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
return A.u[cols][padded...]
else
col_idxs = _column_indices(A, cols)
# Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection
if col_idxs isa RaggedEnd
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
elseif col_idxs isa RaggedRange
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
end
# If we're selecting whole inner arrays (all leading indices are Colons),
# keep the result as a VectorOfArray to match non-ragged behavior.
if all(idx -> idx === Colon(), prefix)
if col_idxs isa Int
return A.u[col_idxs]
else
return VectorOfArray(A.u[col_idxs])
end
end
# If col_idxs resolved to a single Int, handle it directly
if col_idxs isa Int
resolved = _resolve_ragged_indices(prefix, A, col_idxs)
inner_nd = ndims(A.u[col_idxs])
padded = (
resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
return A.u[col_idxs][padded...]
end
vals = map(col_idxs) do col
resolved = _resolve_ragged_indices(prefix, A, col)
inner_nd = ndims(A.u[col])
padded = (
resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
A.u[col][padded...]
end
return stack(vals)
end
end

@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...)
cols = _column_indices(VA, last(idxs))
prefix = Base.front(idxs)
if cols isa Int
resolved = _resolve_ragged_indices(prefix, VA, cols)
return checkbounds(Bool, VA.u, cols) && checkbounds(Bool, VA.u[cols], resolved...)
else
for col in cols
resolved = _resolve_ragged_indices(prefix, VA, col)
checkbounds(Bool, VA.u, col) || return false
checkbounds(Bool, VA.u[col], resolved...) || return false
end
return true
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))

if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _has_ragged_end(_arg, args...)
return _ragged_getindex(A, _arg, args...)
end
if _arg isa Union{Tuple, AbstractArray} &&
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
Expand Down Expand Up @@ -523,25 +771,32 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}
VA.u[I] = v
end

Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Int) = Base.setindex!(VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(VA.u, v, I) false
Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Int) = Base.setindex!(
VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
VA.u, v, I) false

Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v,
::Colon, I::Colon) where {T, N}
VA.u[I] = v
end

Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) = Base.setindex!(VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::Colon) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) = Base.setindex!(
VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
I::Colon) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
VA.u, v, I) false

Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v,
::Colon, I::AbstractArray{Int}) where {T, N}
VA.u[I] = v
end

Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) = Base.setindex!(VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) = Base.setindex!(
VA.u, v, I)
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
VA, v, :, I) false

Base.@propagate_inbounds function Base.setindex!(
Expand Down Expand Up @@ -710,12 +965,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
function Base.checkbounds(
::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
idxs...) where {T, N}
if _has_ragged_end(idxs...)
return _checkbounds_ragged(Bool, VA, idxs...)
end
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
return checkbounds(Bool, VA.u, idxs[2])
end
return checkbounds(Bool, VA.u, idxs...)
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if _has_ragged_end(idx...)
return _checkbounds_ragged(Bool, VA, idx...)
end
checkbounds(Bool, VA.u, last(idx)) || return false
if last(idx) isa Int
return checkbounds(Bool, VA.u[last(idx)], Base.front(idx)...)
Expand Down
Loading
Loading