Skip to content

Commit

Permalink
Use mutate-or-widen implementation of materialize in Base
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Dec 22, 2019
1 parent d1066c2 commit ece75d9
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 24 deletions.
10 changes: 9 additions & 1 deletion src/BangBang.jl
Expand Up @@ -26,7 +26,15 @@ export @!,
singletonof,
splice!!

using Base.Broadcast: materialize!
using Base.Broadcast:
Broadcasted,
broadcasted,
combine_eltypes,
copyto_nonleaf!,
instantiate,
materialize!,
preprocess
using Base: promote_typejoin
using Compat: hasproperty
using InitialValues
using LinearAlgebra
Expand Down
1 change: 0 additions & 1 deletion src/NoBang/linearalgebra.jl
@@ -1,3 +1,2 @@
add(A, B) = A .+ B
mul(::Any, x, y) = x * y
mul(C, A, B, α, β) = A * B * α + C * β
44 changes: 41 additions & 3 deletions src/base.jl
Expand Up @@ -418,8 +418,46 @@ setproperty!!(value, name::Symbol, x) = setproperties!!(value, (; name => x))
"""
materialize!!(dest, x)
"""
@inline materialize!!(dest, x) = may(materialize!, dest, x)
@inline materialize!!(dest, x) = may(_materialize!!, dest, x)
# TODO: maybe instantiate `x` and be aware of `x`'s style

pure(::typeof(materialize!)) = NoBang.materialize
possible(::typeof(materialize!), x, ::Any) = ismutable(x)
@inline _materialize!!(dest, bc::Broadcasted{Style}) where {Style} =
_copyto!!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))

pure(::typeof(_materialize!!)) = NoBang.materialize
possible(::typeof(_materialize!!), ::Any, ::Any) = false
possible(::typeof(_materialize!!), x::AbstractArray, ::Any) = ismutable(x)

@noinline throwdm(axdest, axsrc) =
throw(DimensionMismatch("destination axes $axdest are not compatible with source axes $axsrc"))

# Based on default `copy(bc)` implementation
@inline function _copyto!!(dest, bc::Broadcasted)
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))

ElType = combine_eltypes(bc.f, bc.args)
# Use `copyto!` if we can trust the inference result:
if ElType <: eltype(dest)
return copyto!(dest, bc)
elseif Base.isconcretetype(ElType)
return copyto!(similar(bc, promote_typejoin(eltype(dest), ElType)), bc)
end

bc′ = preprocess(nothing, bc)
iter = eachindex(bc′)
y = iterate(iter)
y === nothing && return dest

# Try to store the first value
I, state = y
@inbounds val = bc′[I]
if typeof(val) <: eltype(dest)
@inbounds dest[I] = val
dest′ = dest
else
dest′ = similar(bc′, typeof(val))
end

# Handle the rest
return copyto_nonleaf!(dest′, bc′, iter, state, 1)
end
17 changes: 1 addition & 16 deletions src/linearalgebra.jl
Expand Up @@ -15,22 +15,7 @@ julia> add!!([1], [2])
3
```
"""
add!!(A, B) = may(add!, A, B)
add!(A, B) = A .+= B

pure(::typeof(add!)) = NoBang.add
_asbb(::typeof(add!)) = add!!
possible(::typeof(add!), A, B) = ismutable(A) && _addeltype(A, B) <: eltype(A)

_addeltype(A, B) = Base.promote_op(+, eltype(A), _eltype(B))
_eltype(x) = eltype(x)
function _eltype(x::Broadcast.Broadcasted)
bc = Broadcast.instantiate(x)
return Base._return_type(getindex, Tuple{typeof(bc), Vararg{Int, ndims(bc)}})
end
# TODO: Implement `materialize!!(dest, bc)` based on `copyto_nonleaf!`
# and use it like this:
# add!!(A, B) = materialize!!(A, instantiate(broadcasted(+, A, B)))
add!!(A, B) = materialize!!(A, instantiate(broadcasted(+, A, B)))

"""
mul!!(C, A, B, [α, β]) -> C′
Expand Down
4 changes: 2 additions & 2 deletions test/test_add.jl
Expand Up @@ -10,7 +10,7 @@ using Base.Broadcast: broadcasted
@test add!!(SVector(1), SVector(2)) === SVector(3)
@test add!!(SVector(1), [2]) == [3] # ok if `SVector(3)`
@test add!!([1], SVector(2)) ==ₜ [3]
@test add!!([1], [0.5]) ==ₜ [1.5]
@test add!!([1], [0.5]) ==Real[1.5]
end

@testset "mutation" begin
Expand All @@ -20,7 +20,7 @@ end

@testset "broadcasted" begin
@test @inferred(add!!(broadcasted(*, [1], [2]), [2])) ==ₜ [4]
@test @inferred(add!!([1], broadcasted(/, [1], [2]))) ==ₜ [1.5]
@test @inferred(add!!([1], broadcasted(/, [1], [2]))) ==Real[1.5]

x = [1]
@test @inferred(add!!(x, broadcasted(*, [1], [2]))) === x ==ₜ [3]
Expand Down
8 changes: 7 additions & 1 deletion test/test_materialize.jl
Expand Up @@ -5,9 +5,15 @@ using BangBang: air

@testset begin
@test materialize!!(nothing, air.([0, 1] .+ 2))::Vector{Int} == [2, 3]
@test materialize!!([NaN, NaN], air.([0, 1] .+ 2))::Vector{Float64} == [2, 3]
@test materialize!!([NaN, NaN], air.([0, 1] .+ 2))::Vector{Real} == [2, 3]
@test materialize!!(nothing, air.((0, 1) .+ 2))::Tuple === (2, 3)
@test materialize!!(nothing, air.(SVector(0, 1) .+ 2))::SVector === SVector(2, 3)
end

@testset "type-unstable materialize!!" begin
unstabletrue = Ref{Any}(true)
unstableone(unused) = unstabletrue[] ? unstabletrue[] : unused
@test materialize!!(nothing, air.([1] .+ unstableone.([0]) ./ [1])) == [2]
end

end # module

0 comments on commit ece75d9

Please sign in to comment.