diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 8a5dd1f0..97527124 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -147,27 +147,69 @@ end VA.t,VA.u end -# Broadcast - -#add_idxs(x,expr) = expr -#add_idxs{T<:AbstractVectorOfArray}(::Type{T},expr) = :($(expr)[i]) -#add_idxs{T<:AbstractArray}(::Type{Vector{T}},expr) = :($(expr)[i]) -#= -@generated function Base.broadcast!(f,A::AbstractVectorOfArray,B...) - exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...) - :(for i in eachindex(A) - broadcast!(f,A[i],$(exs...)) - end) -end - -@generated function Base.broadcast(f,B::Union{Number,AbstractVectorOfArray}...) - arr_idx = 0 - for (i,b) in enumerate(B) - if b <: ArrayPartition - arr_idx = i - break +## broadcasting + +struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end +VectorOfArrayStyle(::Any) = VectorOfArrayStyle() +VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle() + +# promotion rules +#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle} +# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle())) +#end +Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle() +Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}() + +function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S} + VectorOfArrayStyle() +end + +@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle}) + N = narrays(bc) + x = unpack_voa(bc, 1) + VectorOfArray(map(1:N) do i + copy(unpack_voa(bc, i)) + end) +end + +@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle}) + N = narrays(bc) + @inbounds for i in 1:N + copyto!(dest[i], unpack_voa(bc, i)) end - end - :(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A) + dest end -=# + +## broadcasting utils + +""" + narrays(A...) + +Retrieve number of arrays in the AbstractVectorOfArrays of a broadcast +""" +narrays(A) = 0 +narrays(A::AbstractVectorOfArray) = length(A) +narrays(bc::Broadcast.Broadcasted) = _narrays(bc.args) +narrays(A, Bs...) = common_length(narrays(A), _narrays(Bs)) + +common_length(a, b) = + a == 0 ? b : + (b == 0 ? a : + (a == b ? a : + throw(DimensionMismatch("number of arrays must be equal")))) + +_narrays(args::AbstractVectorOfArray) = length(args) +@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args))) +_narrays(args::Tuple{Any}) = _narrays(args[1]) +_narrays(args::Tuple{}) = 0 + +# drop axes because it is easier to recompute +@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args)) +@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args)) +unpack_voa(x,::Any) = x +unpack_voa(x::AbstractVectorOfArray, i) = x.u[i] +unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i] + +@inline unpack_args_voa(i, args::Tuple) = (unpack_voa(args[1], i), unpack_args_voa(i, Base.tail(args))...) +unpack_args_voa(i, args::Tuple{Any}) = (unpack_voa(args[1], i),) +unpack_args_voa(::Any, args::Tuple{}) = () diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index c7a06824..198b47fb 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools +using RecursiveArrayTools, Test # Example Problem recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] @@ -86,3 +86,14 @@ v = VectorOfArray([zeros(20), zeros(10,10), zeros(3,3,3)]) v[CartesianIndex((2, 3, 2, 3))] = 1 @test v[CartesianIndex((2, 3, 2, 3))] == 1 @test v.u[3][2, 3, 2] == 1 + +v = VectorOfArray([rand(20), rand(10,10), rand(3,3,3)]) +w = v .* v +@test w isa VectorOfArray +@test w[1] isa Vector +@test w[1] == v[1] .* v[1] +@test w[2] == v[2] .* v[2] +@test w[3] == v[3] .* v[3] +x = copy(v) +x .= v .* v +@test x.u == w.u