Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "2.4.0"
version = "2.4.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
33 changes: 14 additions & 19 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ Base.vec(VA::AbstractVectorOfArray) = vec(convert(Array,VA)) # Allocates
@inline Statistics.cor(VA::AbstractVectorOfArray;kwargs...) = cor(Array(VA);kwargs...)

# make it show just like its data
Base.show(io::IO, x::AbstractVectorOfArray) = show(io, x.u)
Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = show(io, m, x.u)
Base.show(io::IO, x::AbstractVectorOfArray) = Base.print_array(io, x.u)
Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = (println(io, summary(x), ':'); show(io, m, x.u))
Base.summary(A::AbstractVectorOfArray) = string("VectorOfArray{",eltype(A),",",ndims(A),"}")

Base.show(io::IO, x::AbstractDiffEqArray) = (print(io,"t: ");show(io, x.t);println(io);print(io,"u: ");show(io, x.u))
Expand All @@ -149,30 +149,25 @@ end

## broadcasting

struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
VectorOfArrayStyle(::Any) = VectorOfArrayStyle()
VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle()
struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{N}()

# promotion rules
#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
#end
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle()
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle{M}) where {M,N} = Base.Broadcast.DefaultArrayStyle(Val(max(M, N)))
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.AbstractArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M, N)))
Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where {M,N} = VectorOfArrayStyle(Val(max(M, N)))
Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}()

function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
VectorOfArrayStyle()
end

@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle})
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
N = narrays(bc)
x = unpack_voa(bc, 1)
VectorOfArray(map(1:N) do i
copy(unpack_voa(bc, i))
end)
end

@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle})
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
N = narrays(bc)
@inbounds for i in 1:N
copyto!(dest[i], unpack_voa(bc, i))
Expand Down Expand Up @@ -201,11 +196,11 @@ common_length(a, b) =
_narrays(args::AbstractVectorOfArray) = length(args)
@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args)))
_narrays(args::Tuple{Any}) = _narrays(args[1])
_narrays(args::Tuple{}) = 0
_narrays(::Any) = 0

# drop axes because it is easier to recompute
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
@inline unpack_voa(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
unpack_voa(x,::Any) = x
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]
Expand Down
20 changes: 15 additions & 5 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ using RecursiveArrayTools, Test
recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
testa = cat(recs..., dims=2)
testva = VectorOfArray(recs)

# broadcast with array
X = rand(3, 3)
mulX = testva .* X
ref = mapreduce((x,y)->x.*y, hcat, testva, eachcol(X))
@test mulX == ref
fill!(mulX, 0)
mulX .= testva .* X
@test mulX == ref

t = [1,2,3]
diffeq = DiffEqArray(recs,t)
@test Array(testva) == [1 4 7
Expand Down Expand Up @@ -73,11 +83,6 @@ testva = VectorOfArray(recs) #TODO: clearly this printed form is nonsense
@test testva[:, 1] == recs[1]
testva[1:2, 1:2]

# Test broadcast
a = testva .+ rand(3,3)
a.= testva
@test all(a .== testva)

recs = [rand(2,2) for i in 1:5]
testva = VectorOfArray(recs)
@test Array(testva) isa Array{Float64,3}
Expand All @@ -97,3 +102,8 @@ w = v .* v
x = copy(v)
x .= v .* v
@test x.u == w.u

# broadcast with number
w = v .+ 1
@test w isa VectorOfArray
@test w.u == map(x -> x .+ 1, v.u)