diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index a38c9795..2778bfcc 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -388,35 +388,88 @@ end @inline Base.IteratorSize(::Type{<:AbstractVectorOfArray}) = Base.HasLength() @inline Base.first(VA::AbstractVectorOfArray) = first(VA.u) @inline Base.last(VA::AbstractVectorOfArray) = last(VA.u) -function Base.firstindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A} +function Base.firstindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A} N > 1 && Base.depwarn( "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", :firstindex) return firstindex(VA.u) end -function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A} - N > 1 && Base.depwarn( +function Base.lastindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A} + N > 1 && Base.depwarn( "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ", :lastindex) return lastindex(VA.u) end +# Always return RaggedEnd for type stability. Use dim=0 to indicate a plain index stored in offset. +# _resolve_ragged_index and _column_indices handle the dim=0 case to extract the actual index value. +@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer) + if d == ndims(VA) + return RaggedEnd(0, Int(lastindex(VA.u))) + elseif d < ndims(VA) + isempty(VA.u) && return RaggedEnd(0, 0) + return RaggedEnd(Int(d), 0) + else + return RaggedEnd(0, 1) + end +end + Base.getindex(A::AbstractVectorOfArray, I::Int) = A.u[I] Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) = A.u[I] Base.getindex(A::AbstractDiffEqArray, I::Int) = A.u[I] Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I] -@deprecate Base.getindex(VA::AbstractVectorOfArray{T,N,A}, I::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false +@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A}, + I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false -@deprecate Base.getindex(VA::AbstractVectorOfArray{T,N,A}, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false +@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A}, + I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false -@deprecate Base.getindex(VA::AbstractDiffEqArray{T,N,A}, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false +@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A}, + I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false -@deprecate Base.getindex(VA::AbstractDiffEqArray{T,N,A}, i::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} VA.u[i] false +@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A}, + i::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[i] false __parameterless_type(T) = Base.typename(T).wrapper +# `end` support for ragged inner arrays +# Use runtime fields instead of type parameters for type stability +struct RaggedEnd + dim::Int + offset::Int +end +RaggedEnd(dim::Int) = RaggedEnd(dim, 0) + +Base.:+(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset + Int(n)) +Base.:-(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset - Int(n)) +Base.:+(n::Integer, re::RaggedEnd) = re + n + +struct RaggedRange + dim::Int + start::Int + step::Int + offset::Int +end + +Base.:(:)(stop::RaggedEnd) = RaggedRange(stop.dim, 1, 1, stop.offset) +function Base.:(:)(start::Integer, stop::RaggedEnd) + RaggedRange(stop.dim, Int(start), 1, stop.offset) +end +function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd) + RaggedRange(stop.dim, Int(start), Int(step), stop.offset) +end + +@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer) + length(VA.u) <= 1 && return false + first_size = size(VA.u[1], d) + @inbounds for idx in 2:length(VA.u) + size(VA.u[idx], d) == first_size || return true + end + return false +end + Base.@propagate_inbounds function _getindex( A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int) A.u[I] @@ -487,11 +540,206 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb return getindex(A, all_variable_symbols(A), args...) end +@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ? + eachindex(VA.u) : idx +@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool}) + findall(idx) +end +@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedEnd) + # RaggedEnd with dim=0 means it's just a plain index stored in offset + idx.dim == 0 ? idx.offset : idx +end + +@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx +@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col) + if idx.dim == 0 + # Special case: dim=0 means the offset contains the actual index value + return idx.offset + else + return lastindex(VA.u[col], idx.dim) + idx.offset + end +end +@inline function _resolve_ragged_index(idx::RaggedRange, VA::AbstractVectorOfArray, col) + stop_val = if idx.dim == 0 + # dim == 0 is the sentinel for an already-resolved plain index stored in offset + idx.offset + else + lastindex(VA.u[col], idx.dim) + idx.offset + end + return Base.range(idx.start; step = idx.step, stop = stop_val) +end +@inline function _resolve_ragged_index( + idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col) + return Base.range(_resolve_ragged_index(first(idx), VA, col); step = step(idx), + stop = _resolve_ragged_index(last(idx), VA, col)) +end +@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col) + return Base.Slice(_resolve_ragged_index(idx.indices, VA, col)) +end +@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col) + return CartesianIndex(_resolve_ragged_indices(Tuple(idx), VA, col)...) +end +@inline function _resolve_ragged_index( + idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col) + return map(i -> _resolve_ragged_index(i, VA, col), idx) +end +@inline function _resolve_ragged_index( + idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col) + return map(i -> _resolve_ragged_index(i, VA, col), idx) +end +@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col) + return _has_ragged_end(idx) ? map(i -> _resolve_ragged_index(i, VA, col), idx) : idx +end + +@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col) + map(i -> _resolve_ragged_index(i, VA, col), idxs) +end + +@inline function _has_ragged_end(x) + x isa RaggedEnd && return true + x isa RaggedRange && return true + x isa Base.Slice && return _has_ragged_end(x.indices) + x isa CartesianIndex && return _has_ragged_end(Tuple(x)) + x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange} + if x isa AbstractArray + el = eltype(x) + return el <: Union{RaggedEnd, RaggedRange} || + (el === Any && any(_has_ragged_end, x)) + end + x isa Tuple && return any(_has_ragged_end, x) + return false +end +@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs) + +@inline function _ragged_getindex(A::AbstractVectorOfArray, I...) + n = ndims(A) + # Special-case when user provided one fewer index than ndims(A): last index is column selector. + if length(I) == n - 1 + raw_cols = last(I) + # If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector. + cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0 + lastindex(A.u) + raw_cols.offset + elseif raw_cols isa RaggedRange && raw_cols.dim != 0 + stop_val = lastindex(A.u) + raw_cols.offset + Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val) + else + _column_indices(A, raw_cols) + end + prefix = Base.front(I) + if cols isa Int + resolved_prefix = _resolve_ragged_indices(prefix, A, cols) + inner_nd = ndims(A.u[cols]) + n_missing = inner_nd - length(resolved_prefix) + padded = if n_missing > 0 + if all(idx -> idx === Colon(), resolved_prefix) + (resolved_prefix..., ntuple(_ -> Colon(), n_missing)...) + else + (resolved_prefix..., + (lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)...) + end + else + resolved_prefix + end + return A.u[cols][padded...] + else + return VectorOfArray([begin + resolved_prefix = _resolve_ragged_indices(prefix, A, col) + inner_nd = ndims(A.u[col]) + n_missing = inner_nd - length(resolved_prefix) + padded = if n_missing > 0 + if all(idx -> idx === Colon(), resolved_prefix) + (resolved_prefix..., + ntuple(_ -> Colon(), n_missing)...) + else + (resolved_prefix..., + (lastindex(A.u[col], + length(resolved_prefix) + i) for i in 1:n_missing)...) + end + else + resolved_prefix + end + A.u[col][padded...] + end + for col in cols]) + end + end + + # Otherwise, use the full-length interpretation (last index is column selector; missing columns default to Colon()). + if length(I) == n + cols = last(I) + prefix = Base.front(I) + else + cols = Colon() + prefix = I + end + if cols isa Int + if all(idx -> idx === Colon(), prefix) + return A.u[cols] + end + resolved = _resolve_ragged_indices(prefix, A, cols) + inner_nd = ndims(A.u[cols]) + padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...) + return A.u[cols][padded...] + else + col_idxs = _column_indices(A, cols) + # Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection + if col_idxs isa RaggedEnd + col_idxs = _resolve_ragged_index(col_idxs, A, 1) + elseif col_idxs isa RaggedRange + col_idxs = _resolve_ragged_index(col_idxs, A, 1) + end + # If we're selecting whole inner arrays (all leading indices are Colons), + # keep the result as a VectorOfArray to match non-ragged behavior. + if all(idx -> idx === Colon(), prefix) + if col_idxs isa Int + return A.u[col_idxs] + else + return VectorOfArray(A.u[col_idxs]) + end + end + # If col_idxs resolved to a single Int, handle it directly + if col_idxs isa Int + resolved = _resolve_ragged_indices(prefix, A, col_idxs) + inner_nd = ndims(A.u[col_idxs]) + padded = ( + resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...) + return A.u[col_idxs][padded...] + end + vals = map(col_idxs) do col + resolved = _resolve_ragged_indices(prefix, A, col) + inner_nd = ndims(A.u[col]) + padded = ( + resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...) + A.u[col][padded...] + end + return stack(vals) + end +end + +@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...) + cols = _column_indices(VA, last(idxs)) + prefix = Base.front(idxs) + if cols isa Int + resolved = _resolve_ragged_indices(prefix, VA, cols) + return checkbounds(Bool, VA.u, cols) && checkbounds(Bool, VA.u[cols], resolved...) + else + for col in cols + resolved = _resolve_ragged_indices(prefix, VA, col) + checkbounds(Bool, VA.u, col) || return false + checkbounds(Bool, VA.u[col], resolved...) || return false + end + return true + end +end + Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...) symtype = symbolic_type(_arg) elsymtype = symbolic_type(eltype(_arg)) if symtype == NotSymbolic() && elsymtype == NotSymbolic() + if _has_ragged_end(_arg, args...) + return _ragged_getindex(A, _arg, args...) + end if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg) _getindex(A, symtype, elsymtype, _arg, args...) @@ -523,16 +771,21 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N} VA.u[I] = v end -Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Int) = Base.setindex!(VA.u, v, I) -@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::Int) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(VA.u, v, I) false +Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Int) = Base.setindex!( + VA.u, v, I) +@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v, + I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!( + VA.u, v, I) false Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, ::Colon, I::Colon) where {T, N} VA.u[I] = v end -Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) = Base.setindex!(VA.u, v, I) -@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::Colon) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!( +Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon) = Base.setindex!( + VA.u, v, I) +@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v, + I::Colon) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!( VA.u, v, I) false Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v, @@ -540,8 +793,10 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N} VA.u[I] = v end -Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) = Base.setindex!(VA.u, v, I) -@deprecate Base.setindex!(VA::AbstractVectorOfArray{T,N,A}, v, I::AbstractArray{Int}) where {T,N,A<:Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!( +Base.@propagate_inbounds Base.setindex!(VA::AbstractVectorOfArray, v, I::AbstractArray{Int}) = Base.setindex!( + VA.u, v, I) +@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v, + I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!( VA, v, :, I) false Base.@propagate_inbounds function Base.setindex!( @@ -710,12 +965,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N function Base.checkbounds( ::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, idxs...) where {T, N} + if _has_ragged_end(idxs...) + return _checkbounds_ragged(Bool, VA, idxs...) + end if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1) return checkbounds(Bool, VA.u, idxs[2]) end return checkbounds(Bool, VA.u, idxs...) end function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...) + if _has_ragged_end(idx...) + return _checkbounds_ragged(Bool, VA, idx...) + end checkbounds(Bool, VA.u, last(idx)) || return false if last(idx) isa Int return checkbounds(Bool, VA.u[last(idx)], Base.front(idx)...) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 273c2513..85feee65 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -162,6 +162,40 @@ f2 = VectorOfArray([[1.0, 2.0], [3.0]]) @test collect(view(f2, :, 1)) == f2[:, 1] @test collect(view(f2, :, 2)) == f2[:, 2] +# Test `end` with ragged arrays +ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]]) +@test ragged[end, 1] == 2.0 +@test ragged[end, 2] == 5.0 +@test ragged[end, 3] == 9.0 +@test ragged[end - 1, 1] == 1.0 +@test ragged[end - 1, 2] == 4.0 +@test ragged[end - 1, 3] == 8.0 +@test ragged[1:end, 1] == [1.0, 2.0] +@test ragged[1:end, 2] == [3.0, 4.0, 5.0] +@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0] +@test ragged[:, end] == [6.0, 7.0, 8.0, 9.0] +@test ragged[:, 2:end] == VectorOfArray(ragged.u[2:end]) + +ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]]) +@test ragged2[end, 1] == 4.0 +@test ragged2[end, 2] == 6.0 +@test ragged2[end, 3] == 9.0 +@test ragged2[end - 1, 1] == 3.0 +@test ragged2[end - 1, 2] == 5.0 +@test ragged2[end - 1, 3] == 8.0 +@test ragged2[end - 2, 1] == 2.0 +@test ragged2[1:end, 1] == [1.0, 2.0, 3.0, 4.0] +@test ragged2[1:end, 2] == [5.0, 6.0] +@test ragged2[1:end, 3] == [7.0, 8.0, 9.0] +@test ragged2[2:end, 1] == [2.0, 3.0, 4.0] +@test ragged2[2:end, 2] == [6.0] +@test ragged2[2:end, 3] == [8.0, 9.0] +@test ragged2[:, end] == [7.0, 8.0, 9.0] +@test ragged2[:, 2:end] == VectorOfArray(ragged2.u[2:end]) +@test ragged2[1:(end - 1), 1] == [1.0, 2.0, 3.0] +@test ragged2[1:(end - 1), 2] == [5.0] +@test ragged2[1:(end - 1), 3] == [7.0, 8.0] + # Broadcasting of heterogeneous arrays (issue #454) u = VectorOfArray([[1.0], [2.0, 3.0]]) @test length(view(u, :, 1)) == 1 @@ -179,6 +213,10 @@ u[1, :, 2] .= [1.0, 2.0, 3.0] # partial column selection by indices u[1, [1, 3], 2] .= [7.0, 9.0] @test u.u[2] == [7.0 2.0 9.0] +# test scalar indexing with end +@test u[1, 1, end] == u.u[end][1, 1] +@test u[1, end, end] == u.u[end][1, end] +@test u[1, 2:end, end] == vec(u.u[end][1, 2:end]) # 3D inner arrays (tensors) with ragged third dimension u = VectorOfArray([zeros(2, 1, n) for n in (2, 3)]) @@ -193,6 +231,10 @@ u[1:2, 1, [1, 3], 2] .= [1.0 3.0; 2.0 4.0] @test u.u[2][2, 1, 1] == 2.0 @test u.u[2][1, 1, 3] == 3.0 @test u.u[2][2, 1, 3] == 4.0 +@test u[:, :, end] == u.u[end] +@test u[:, :, 2:end] == VectorOfArray(u.u[2:end]) +@test u[1, 1, end] == u.u[end][1, 1, end] +@test u[end, 1, end] == u.u[end][end, 1, end] # Test that views can be modified f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])