From 3a28e4cf357d7db4d5b6befb496721b39e9fe6d9 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Jun 2020 23:25:00 -0400 Subject: [PATCH 1/5] More robust broadcast --- src/vector_of_array.jl | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 97527124..1a28250b 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -149,30 +149,24 @@ end ## broadcasting -struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end -VectorOfArrayStyle(::Any) = VectorOfArrayStyle() -VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle() - -# promotion rules -#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle} -# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle())) -#end -Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle() -Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}() - -function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S} - VectorOfArrayStyle() -end +struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays +VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{N}() + +# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle. +Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle{M}) where {M,N} = Base.Broadcast.DefaultArrayStyle(Val(max(M, N))) +Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.AbstractArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M, N))) +Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where {M,N} = VectorOfArrayStyle(Val(max(M, N))) +Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}() -@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle}) +@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) N = narrays(bc) x = unpack_voa(bc, 1) - VectorOfArray(map(1:N) do i + return VectorOfArray(map(1:N) do i copy(unpack_voa(bc, i)) end) end -@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle}) +@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) N = narrays(bc) @inbounds for i in 1:N copyto!(dest[i], unpack_voa(bc, i)) @@ -205,7 +199,7 @@ _narrays(args::Tuple{}) = 0 # drop axes because it is easier to recompute @inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args)) -@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args)) +@inline unpack_voa(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args)) unpack_voa(x,::Any) = x unpack_voa(x::AbstractVectorOfArray, i) = x.u[i] unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i] From 8d6155f101bd3f173858cec99dcf55c84fd7a4fe Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Jun 2020 23:34:13 -0400 Subject: [PATCH 2/5] Handle 0 dim broadcast --- src/vector_of_array.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 1a28250b..e7e69128 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -153,6 +153,7 @@ struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{N}() # The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle. +Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle{M}) where {M,N} = Base.Broadcast.DefaultArrayStyle(Val(max(M, N))) Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.AbstractArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M, N))) Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where {M,N} = VectorOfArrayStyle(Val(max(M, N))) @@ -161,7 +162,7 @@ Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = Vec @inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}) N = narrays(bc) x = unpack_voa(bc, 1) - return VectorOfArray(map(1:N) do i + VectorOfArray(map(1:N) do i copy(unpack_voa(bc, i)) end) end @@ -195,7 +196,7 @@ common_length(a, b) = _narrays(args::AbstractVectorOfArray) = length(args) @inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args))) _narrays(args::Tuple{Any}) = _narrays(args[1]) -_narrays(args::Tuple{}) = 0 +_narrays(::Any) = 0 # drop axes because it is easier to recompute @inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args)) From 549b7544d37841eaa165c1973cc97684162414d9 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Jun 2020 23:35:39 -0400 Subject: [PATCH 3/5] Add tests --- test/basic_indexing.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 198b47fb..e527130d 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -4,6 +4,16 @@ using RecursiveArrayTools, Test recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] testa = cat(recs..., dims=2) testva = VectorOfArray(recs) + +# broadcast with array +X = rand(3, 3) +mulX = testva .* X +ref = mapreduce((x,y)->x.*y, hcat, testva, eachcol(X)) +@test mulX == ref +fill!(mulX, 0) +mulX .= testva .* X +@test mulX == ref + t = [1,2,3] diffeq = DiffEqArray(recs,t) @test Array(testva) == [1 4 7 @@ -73,11 +83,6 @@ testva = VectorOfArray(recs) #TODO: clearly this printed form is nonsense @test testva[:, 1] == recs[1] testva[1:2, 1:2] -# Test broadcast -a = testva .+ rand(3,3) -a.= testva -@test all(a .== testva) - recs = [rand(2,2) for i in 1:5] testva = VectorOfArray(recs) @test Array(testva) isa Array{Float64,3} @@ -97,3 +102,8 @@ w = v .* v x = copy(v) x .= v .* v @test x.u == w.u + +# broadcast with number +w = v .+ 1 +@test w isa VectorOfArray +@test w.u == map(x -> x .+ 1, v.u) From f008461c643f4251d84d4f4f56506ca0b87e3dff Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Jun 2020 23:35:50 -0400 Subject: [PATCH 4/5] Patch release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e248236..3ca5b40d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "2.4.0" +version = "2.4.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From c89d00cf52c80268479b6881daee7dfd6f2d95ad Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Jun 2020 23:39:53 -0400 Subject: [PATCH 5/5] Fix AbstractVectorOfArray printing --- src/vector_of_array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index e7e69128..a96fa907 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -129,8 +129,8 @@ Base.vec(VA::AbstractVectorOfArray) = vec(convert(Array,VA)) # Allocates @inline Statistics.cor(VA::AbstractVectorOfArray;kwargs...) = cor(Array(VA);kwargs...) # make it show just like its data -Base.show(io::IO, x::AbstractVectorOfArray) = show(io, x.u) -Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = show(io, m, x.u) +Base.show(io::IO, x::AbstractVectorOfArray) = Base.print_array(io, x.u) +Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = (println(io, summary(x), ':'); show(io, m, x.u)) Base.summary(A::AbstractVectorOfArray) = string("VectorOfArray{",eltype(A),",",ndims(A),"}") Base.show(io::IO, x::AbstractDiffEqArray) = (print(io,"t: ");show(io, x.t);println(io);print(io,"u: ");show(io, x.u))