Skip to content

Commit

Permalink
Merge pull request #19 from JuliaOpt/bl/vectorized
Browse files Browse the repository at this point in the history
Use broadcast! from @rewrite
  • Loading branch information
blegat committed Nov 30, 2019
2 parents 8a1c8f1 + 4314b92 commit d5a456b
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 92 deletions.
11 changes: 9 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,21 @@ function mutable_broadcast!(op::Function, A::Array, args::Vararg{Any, N}) where
return copyto!(A, mutable_broadcasted(instantiated))
end


"""
broadcast!(op::Function, args...)
Returns the value of `broadcast(op, args...)`, possibly modifying `args[1]`.
"""
function broadcast!(op::Function, args::Vararg{Any, N}) where N
    # Dispatch on the arguments:
    # * if any argument is a `LinearAlgebra.UniformScaling`, forward to
    #   `broadcast_with_uniform_scaling!`;
    # * otherwise forward to `broadcast_fallback!`, selecting the method with
    #   the result of `broadcast_mutability(args[1], op, args...)`.
    # There must be no unconditional `return` before this check, otherwise the
    # `UniformScaling` branch can never be reached.
    # TODO use traits instead
    if any(x -> x isa LinearAlgebra.UniformScaling, args)
        return broadcast_with_uniform_scaling!(op, args...)
    else
        return broadcast_fallback!(broadcast_mutability(args[1], op, args...), op, args...)
    end
end
"""
    broadcast_with_uniform_scaling!(op::Function, args...)

Return `op(args...)`, i.e., apply `op` directly to `args` without broadcasting.
"""
broadcast_with_uniform_scaling!(op::Function, args::Vararg{Any, N}) where {N} = op(args...)

function broadcast_fallback!(::NotMutable, op::Function, args::Vararg{Any, N}) where N
Expand Down
221 changes: 131 additions & 90 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct Zero end
# `Zero` on the left of `+`: the result is `copy_if_mutable(x)`.
function Base.:(+)(zero::Zero, x)
    return copy_if_mutable(x)
end

# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`.
function Base.:(+)(x, zero::Zero)
    return copy_if_mutable(x)
end

# Broadcasted `+` or `add_mul` of `Zero` with a single term `x`: the result is
# `copy_if_mutable(x)`.
function broadcast!(::Union{typeof(add_mul), typeof(+)}, ::Zero, x)
    return copy_if_mutable(x)
end

# Broadcasted `add_mul` of `Zero` with two factors: the result is `x * y`.
function broadcast!(::typeof(add_mul), ::Zero, x, y)
    return x * y
end

using Base.Meta

Expand Down Expand Up @@ -100,48 +102,56 @@ end
# See `JuMP._is_sum`
# Return `true` if `s` is one of the symbols accepted as a summation header:
# `sum`, `∑` (n-ary summation sign) or `Σ` (Greek capital sigma).
_is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ)

function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var=gensym())
@assert isexpr(x,:call)
@assert length(x.args) > 1
@assert isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten)
header = x.args[1]
function _parse_generator(vectorized::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var=gensym())
@assert isexpr(inner_factor, :call)
@assert length(inner_factor.args) > 1
@assert isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten)
header = inner_factor.args[1]
if _is_sum(header)
_parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, new_var)
_parse_generator_sum(vectorized, inner_factor.args[2], current_sum, left_factors, right_factors, new_var)
else
error("Expected sum outside generator expression; got $header")
error("Expected `sum` outside generator expression; got `$header`.")
end
end

function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var)
function _parse_generator_sum(vectorized::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var)
# We used to preallocate the expression at the lowest level of the loop.
# When rewriting this some benchmarks revealed that it actually doesn't
# seem to help anymore, so might as well keep the code simple.
code = rewrite_generator(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2])
return :($code; $new_var=$aff)
return _start_summing(current_sum, current_sum -> begin
code = rewrite_generator(inner_factor, t -> _rewrite(vectorized, t, current_sum, left_factors, right_factors, current_sum)[2])
return Expr(:block, code, :($new_var = $current_sum))
end)
end

# Return `true` if `ex` is an `Expr` other than an indexing expression (head
# `:ref`); e.g., `x + y` qualifies while `x[1]` and the symbol `x` do not.
_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)

# Return `true` if `ex` is a complex expression (see `_is_complex_expr`) whose
# first argument is neither `.+` nor `.-`.
function _is_decomposable_with_factors(ex)
    # `.+` and `.-` do not support being decomposed if `left_factors` or
    # `right_factors` are not empty. Otherwise, for instance
    # `I * (x .+ 1)` would be rewritten into `(I * x) .+ (I * 1)` which is
    # incorrect.
    _is_complex_expr(ex) || return false
    isempty(ex.args) && return true
    # The operators are spelled `Symbol(".+")`/`Symbol(".-")` rather than
    # `:(.+)`/`:(.-)`: since Julia 1.6 a standalone dotted operator parses as
    # `Expr(:., :+)`, which never equals the `Symbol` stored as the head of the
    # call `x .+ 1`.
    operator = ex.args[1]
    return operator != Symbol(".+") && operator != Symbol(".-")
end

"""
    rewrite(x)

Return `(variable, code)` where `variable` is a fresh `gensym` symbol and `code`
is the assignment of the expression returned by `rewrite_and_return(x)` to
`variable`.
"""
function rewrite(x)
    output = gensym()
    assignment = Expr(:(=), output, rewrite_and_return(x))
    return output, assignment
end
"""
    rewrite_and_return(x)

Return an expression made of the code produced by
`_rewrite(false, x, nothing, [], [])` followed by the output variable of that
call, wrapped in a `let` block, so that evaluating it yields the value of the
output variable.
"""
function rewrite_and_return(x)
    # `current_sum` is `nothing`: no accumulator is created here, so no
    # separate `gensym` or `Zero()` initialisation is needed in this function.
    output_variable, code = _rewrite(false, x, nothing, [], [])
    # We need to use `let` because of expressions such as
    # `rewrite(:(sum(i for i in 1:2)))`.
    # NOTE(review): the original comment stops here without giving the reason;
    # presumably the `let` scope keeps the generated temporaries local — confirm.
    return quote
        let
            $code
            $output_variable
        end
    end
end

_rewrite_to(x, variable::Symbol) = _rewrite(x, variable, [], [])

function _is_comparison(ex::Expr)
if isexpr(ex, :comparison)
# Range comparison `_ <= _ <= _`.
Expand All @@ -168,118 +178,149 @@ function _has_assignment_in_ref(ex::Expr)
end
# Fallback for arguments not handled by the `ex::Expr` method: always `false`.
function _has_assignment_in_ref(other)
    return false
end

function rewrite_sum(terms, current::Symbol, lcoeffs::Vector, rcoeffs::Vector, output::Symbol, block = Expr(:block))
var = current
function rewrite_sum(vectorized::Bool, terms, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, output::Symbol, block = Expr(:block))
var = current_sum
for term in terms[1:(end-1)]
var, code = _rewrite(term, var, lcoeffs, rcoeffs)
var, code = _rewrite(vectorized, term, var, left_factors, right_factors)
push!(block.args, code)
end
new_output, code = _rewrite(terms[end], var, lcoeffs, rcoeffs, output)
new_output, code = _rewrite(vectorized, terms[end], var, left_factors, right_factors, output)
@assert new_output == output
push!(block.args, code)
return output, block
end

# No sum exists yet: create a fresh accumulator symbol, emit its initialisation
# to `MutableArithmetics.Zero()` and follow it with the code that `first_term`
# builds for that symbol.
function _start_summing(current_sum::Nothing, first_term::Function)
    accumulator = gensym()
    initialization = :($accumulator = MutableArithmetics.Zero())
    return Expr(:block, initialization, first_term(accumulator))
end

# A sum already exists: simply build the code of `first_term` for it.
_start_summing(current_sum::Symbol, first_term::Function) = first_term(current_sum)

# Build the expression
# `new_var = f(MutableArithmetics.add_mul, sum, left_factors..., inner_factors..., right_factors...)`
# where `f` is `MutableArithmetics.broadcast!` if `vectorized` and
# `MutableArithmetics.operate!` otherwise, and `sum` is the symbol supplied by
# `_start_summing` for `current_sum`.
function _write_add_mul(vectorized, current_sum, left_factors, inner_factors, right_factors, new_var::Symbol)
    f = vectorized ? :(MutableArithmetics.broadcast!) : :(MutableArithmetics.operate!)
    return _start_summing(current_sum, function (sum_variable)
        call = Expr(
            :call, f, :(MutableArithmetics.add_mul), sum_variable,
            left_factors..., inner_factors..., right_factors...,
        )
        return Expr(:(=), new_var, call)
    end)
end

"""
_rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym())
_rewrite(vectorized::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
Return `new_var, code` such that `code` is equivalent to
```julia
new_var = aff + prod(lcoefs) * x * prod(rcoeffs)
new_var = prod(left_factors) * inner_factor * prod(right_factors)
```
if `current_sum` is `nothing`,
```julia
new_var = current_sum + prod(left_factors) * inner_factor * prod(right_factors)
```
if `current_sum` is a `Symbol` and `vectorized` is `false` and
```julia
new_var = current_sum .+ prod(left_factors) * inner_factor * prod(right_factors)
```
otherwise.
"""
function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym())
if isexpr(x, :call)
if x.args[1] == :+
return rewrite_sum(x.args[2:end], aff, lcoeffs, rcoeffs, new_var)
elseif x.args[1] == :-
function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
if isexpr(inner_factor, :call)
# We need to verify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`.
# We also need to verify that `current_sum` is `nothing`, otherwise we cannot be sure that the elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the off-diagonal entries of `I + x` are the same as those of `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
if inner_factor.args[1] == :+ || inner_factor.args[1] == :- ||
(current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)))
block = Expr(:block)
if length(x.args) > 2 # not unary subtraction
aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs)
if length(inner_factor.args) > 2 # not unary addition or subtraction
next_sum, code = _rewrite(vectorized, inner_factor.args[2], current_sum, left_factors, right_factors)
push!(block.args, code)
start = 3
else
aff_ = aff
next_sum = current_sum
start = 2
end
return rewrite_sum(x.args[start:end], aff_, vcat(-1, lcoeffs), rcoeffs, new_var, block)
elseif x.args[1] == :*
vectorized = vectorized || inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)
if inner_factor.args[1] == :- || inner_factor.args[1] == :(.-)
left_factors = vcat(-1, left_factors)
end
return rewrite_sum(vectorized, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block)
elseif inner_factor.args[1] == :* # FIXME && !vectorized ?
# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
n_expr = mapreduce(_is_complex_expr, +, x.args)
if n_expr == 1 # special case, only recurse on one argument and don't create temporary objects
which_idx = 0
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
which_idx = i
end
# special case, only recurse on one argument and don't create temporary objects
if isone(mapreduce(_is_complex_expr, +, inner_factor.args)) &&
isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args))
# `findfirst` returns the index in `2:...` so we need to add `1`.
which_idx = 1 + findfirst(2:length(inner_factor.args)) do i
_is_decomposable_with_factors(inner_factor.args[i])
end
return _rewrite(
x.args[which_idx], aff,
vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]),
vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]),
vectorized,
inner_factor.args[which_idx], current_sum,
vcat(left_factors, [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]),
vcat(right_factors, [esc(inner_factor.args[i]) for i in (which_idx + 1):length(inner_factor.args)]),
new_var)
else
blk = Expr(:block)
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
s = gensym()
new_var_, parsed = _rewrite_to(x.args[i], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
x.args[i] = new_var_
for i in 2:length(inner_factor.args)
if _is_complex_expr(inner_factor.args[i])
new_var_, parsed = rewrite(inner_factor.args[i])
push!(blk.args, parsed)
inner_factor.args[i] = new_var_
else
x.args[i] = esc(x.args[i])
inner_factor.args[i] = esc(inner_factor.args[i])
end
end
callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff,
lcoeffs..., x.args[2:end]..., rcoeffs...)
push!(blk.args, :($new_var = $callexpr))
push!(blk.args, _write_add_mul(
vectorized, current_sum, left_factors,
inner_factor.args[2:end], right_factors, new_var
))
return new_var, blk
end
elseif x.args[1] == :^ && _is_complex_expr(x.args[2])
MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2]))))
if x.args[3] == 2
blk = Expr(:block)
s = gensym()
new_var_, parsed = _rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_,
rcoeffs...)))))
return new_var, blk
elseif x.args[3] == 1
return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs, new_var)
elseif x.args[3] == 0
return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs, new_var)
elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) # FIXME && !vectorized ?
MulType = :(MA.promote_operation(*, typeof($(inner_factor.args[2])), typeof($(inner_factor.args[2]))))
if inner_factor.args[3] == 2
new_var_, parsed = rewrite(inner_factor.args[2])
square_expr = _write_add_mul(
vectorized, current_sum, left_factors,
(new_var_, new_var_), right_factors, new_var
)
return new_var, Expr(:block, parsed, square_expr)
elseif inner_factor.args[3] == 1
return _rewrite(vectorized, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var)
elseif inner_factor.args[3] == 0
return _rewrite(vectorized, :(one($MulType)), current_sum, left_factors, right_factors, new_var)
else
blk = Expr(:block)
s = gensym()
new_var_, parsed = _rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs...,
Expr(:call, :^, new_var_, esc(x.args[3])),
rcoeffs...)))))
return new_var, blk
new_var_, parsed = rewrite(inner_factor.args[2])
power_expr = _write_add_mul(
vectorized, current_sum, left_factors,
(Expr(:call, :^, new_var_, esc(inner_factor.args[3])),),
right_factors, new_var
)
return new_var, Expr(:block, parsed, power_expr)
end
elseif x.args[1] == :/
@assert length(x.args) == 3
numerator = x.args[2]
denom = x.args[3]
return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), new_var)
elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten))
return new_var, _parse_generator(x, aff, lcoeffs, rcoeffs, new_var)
elseif inner_factor.args[1] == :/ # FIXME && !vectorized ?
@assert length(inner_factor.args) == 3
numerator = inner_factor.args[2]
denom = inner_factor.args[3]
return _rewrite(vectorized, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var)
elseif length(inner_factor.args) >= 2 && (isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten))
return new_var, _parse_generator(vectorized, inner_factor, current_sum, left_factors, right_factors, new_var)
end
elseif isexpr(x, :curly)
Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$x`.")
elseif isexpr(inner_factor, :curly)
Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$inner_factor`.")
end
if isa(x, Expr) && _is_comparison(x)
error("Unexpected comparison in expression `$x`.")
if isa(inner_factor, Expr) && _is_comparison(inner_factor)
error("Unexpected comparison in expression `$inner_factor`.")
end
if isa(x, Expr) && _has_assignment_in_ref(x)
error("Unexpected assignment in expression `$x`.")
if isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor)
error("Unexpected assignment in expression `$inner_factor`.")
end
# at the lowest level
callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, lcoeffs..., esc(x), rcoeffs...)
return new_var, :($new_var = $callexpr)
return new_var, _write_add_mul(vectorized, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var)
end
4 changes: 4 additions & 0 deletions test/dummy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ struct DummyBigInt <: MA.AbstractMutable
data::BigInt
end
DummyBigInt(J::UniformScaling) = DummyBigInt(J.λ)

# Broadcast
Base.ndims(::Type{DummyBigInt}) = 0
Base.broadcastable(x::DummyBigInt) = Ref(x)

Base.promote_rule(::Type{DummyBigInt}, ::Type{<:Union{Integer, UniformScaling{<:Integer}}}) = DummyBigInt
# `copy` on BigInt returns the same instance anyway
Base.copy(x::DummyBigInt) = x
Expand Down
2 changes: 2 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ end
include("dummy.jl")

function error_test(x, y, z)
err = ErrorException("Expected `sum` outside generator expression; got `prod`.")
@test_macro_throws err MA.@rewrite(prod(i for i in 1:2))
err = ErrorException("Unexpected assignment in expression `y[j=1]`.")
@test_macro_throws err MA.@rewrite y[j = 1]
err = ErrorException("Unexpected assignment in expression `x[i=1]`.")
Expand Down

0 comments on commit d5a456b

Please sign in to comment.