diff --git a/src/BangBang.jl b/src/BangBang.jl index a714f0ad..0709c547 100644 --- a/src/BangBang.jl +++ b/src/BangBang.jl @@ -26,7 +26,15 @@ export @!, singletonof, splice!! -using Base.Broadcast: materialize! +using Base.Broadcast: + Broadcasted, + broadcasted, + combine_eltypes, + copyto_nonleaf!, + instantiate, + materialize!, + preprocess +using Base: promote_typejoin using Compat: hasproperty using InitialValues using LinearAlgebra diff --git a/src/NoBang/linearalgebra.jl b/src/NoBang/linearalgebra.jl index dbb17465..800adaa5 100644 --- a/src/NoBang/linearalgebra.jl +++ b/src/NoBang/linearalgebra.jl @@ -1,3 +1,2 @@ -add(A, B) = A .+ B mul(::Any, x, y) = x * y mul(C, A, B, α, β) = A * B * α + C * β diff --git a/src/base.jl b/src/base.jl index 535e0721..aeba6225 100644 --- a/src/base.jl +++ b/src/base.jl @@ -418,8 +418,46 @@ setproperty!!(value, name::Symbol, x) = setproperties!!(value, (; name => x)) """ materialize!!(dest, x) """ -@inline materialize!!(dest, x) = may(materialize!, dest, x) +@inline materialize!!(dest, x) = may(_materialize!!, dest, x) # TODO: maybe instantiate `x` and be aware of `x`'s style -pure(::typeof(materialize!)) = NoBang.materialize -possible(::typeof(materialize!), x, ::Any) = ismutable(x) +@inline _materialize!!(dest, bc::Broadcasted{Style}) where {Style} = + _copyto!!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) + +pure(::typeof(_materialize!!)) = NoBang.materialize +possible(::typeof(_materialize!!), ::Any, ::Any) = false +possible(::typeof(_materialize!!), x::AbstractArray, ::Any) = ismutable(x) + +@noinline throwdm(axdest, axsrc) = + throw(DimensionMismatch("destination axes $axdest are not compatible with source axes $axsrc")) + +# Based on default `copy(bc)` implementation +@inline function _copyto!!(dest, bc::Broadcasted) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + + ElType = combine_eltypes(bc.f, bc.args) + # Use `copyto!` if we can trust the inference result: + if ElType <: eltype(dest) + return copyto!(dest, bc) + elseif Base.isconcretetype(ElType) + return copyto!(similar(bc, promote_typejoin(eltype(dest), ElType)), bc) + end + + bc′ = preprocess(nothing, bc) + iter = eachindex(bc′) + y = iterate(iter) + y === nothing && return dest + + # Try to store the first value + I, state = y + @inbounds val = bc′[I] + if typeof(val) <: eltype(dest) + @inbounds dest[I] = val + dest′ = dest + else + dest′ = similar(bc′, typeof(val)) + end + + # Handle the rest + return copyto_nonleaf!(dest′, bc′, iter, state, 1) +end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 38aab89f..9eaacb0b 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -15,22 +15,7 @@ julia> add!!([1], [2]) 3 ``` """ -add!!(A, B) = may(add!, A, B) -add!(A, B) = A .+= B - -pure(::typeof(add!)) = NoBang.add -_asbb(::typeof(add!)) = add!! -possible(::typeof(add!), A, B) = ismutable(A) && _addeltype(A, B) <: eltype(A) - -_addeltype(A, B) = Base.promote_op(+, eltype(A), _eltype(B)) -_eltype(x) = eltype(x) -function _eltype(x::Broadcast.Broadcasted) - bc = Broadcast.instantiate(x) - return Base._return_type(getindex, Tuple{typeof(bc), Vararg{Int, ndims(bc)}}) -end -# TODO: Implement `materialize!!(dest, bc)` based on `copyto_nonleaf!` -# and use it like this: -# add!!(A, B) = materialize!!(A, instantiate(broadcasted(+, A, B))) +add!!(A, B) = materialize!!(A, instantiate(broadcasted(+, A, B))) """ mul!!(C, A, B, [α, β]) -> C′ diff --git a/test/test_add.jl b/test/test_add.jl index fee56e50..1fc3878b 100644 --- a/test/test_add.jl +++ b/test/test_add.jl @@ -10,7 +10,7 @@ using Base.Broadcast: broadcasted @test add!!(SVector(1), SVector(2)) === SVector(3) @test add!!(SVector(1), [2]) == [3] # ok if `SVector(3)` @test add!!([1], SVector(2)) ==ₜ [3] - @test add!!([1], [0.5]) ==ₜ [1.5] + @test add!!([1], [0.5]) ==ₜ Real[1.5] end @testset "mutation" begin @@ -20,7 +20,7 @@ end @testset "broadcasted" begin @test @inferred(add!!(broadcasted(*, [1], [2]), [2])) ==ₜ [4] - @test @inferred(add!!([1], broadcasted(/, [1], [2]))) ==ₜ [1.5] + @test @inferred(add!!([1], broadcasted(/, [1], [2]))) ==ₜ Real[1.5] x = [1] @test @inferred(add!!(x, broadcasted(*, [1], [2]))) === x ==ₜ [3] diff --git a/test/test_materialize.jl b/test/test_materialize.jl index da2a05e3..56ede065 100644 --- a/test/test_materialize.jl +++ b/test/test_materialize.jl @@ -5,9 +5,15 @@ using BangBang: air @testset begin @test materialize!!(nothing, air.([0, 1] .+ 2))::Vector{Int} == [2, 3] - @test materialize!!([NaN, NaN], air.([0, 1] .+ 2))::Vector{Float64} == [2, 3] + @test materialize!!([NaN, NaN], air.([0, 1] .+ 2))::Vector{Real} == [2, 3] @test materialize!!(nothing, air.((0, 1) .+ 2))::Tuple === (2, 3) @test materialize!!(nothing, air.(SVector(0, 1) .+ 2))::SVector === SVector(2, 3) end +@testset "type-unstable materialize!!" begin + unstabletrue = Ref{Any}(true) + unstableone(unused) = unstabletrue[] ? unstabletrue[] : unused + @test materialize!!(nothing, air.([1] .+ unstableone.([0]) ./ [1])) == [2] +end + end # module