Skip to content

Commit

Permalink
Merge 7ee6d28 into c88b584
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 18, 2019
2 parents c88b584 + 7ee6d28 commit c2a7777
Show file tree
Hide file tree
Showing 8 changed files with 759 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
julia = "1.1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "OffsetArrays", "SparseArrays"]
6 changes: 6 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@ include("Test/Test.jl")
include("bigint.jl")
include("linear_algebra.jl")

isequal_canonical(a, b) = a == b

include("rewrite.jl")

include("dispatch.jl")

end # module
9 changes: 9 additions & 0 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
abstract type AbstractMutable end

# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)}))
_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A))
function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
@assert _one_indexed(x) # Base.diagm doesn't work for non-one-indexed arrays in general.
ZeroType = promote_operation(zero, eltype(x))
return diagm(0 => copyto!(similar(x, ZeroType), x))
end
7 changes: 7 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ function promote_operation end
function promote_operation(op::Function, args::Vararg{Type, N}) where N
return typeof(op(zero.(args)...))
end
#promote_operation(::typeof(*), ::Type{T}) where {T} = T
#function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N}
# return Array{promote_operation(op, T, S), N}
#end
#function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N}
# return Array{promote_operation(op, S, T), N}
#end

# Define Traits
abstract type MutableTrait end
Expand Down
230 changes: 230 additions & 0 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Heavily inspired from `JuMP/src/parse_expr.jl` code.

export @rewrite
macro rewrite(expr)
return rewrite(expr)
end

struct Zero end
# We need to copy `x` as it will be used as might be given by the user and be
# given as first argument of `operate!`.
Base.:(+)(zero::Zero, x) = copy(x)
# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`.
Base.:(+)(x, zero::Zero) = copy(x)

using Base.Meta

function _parse_idx_set(arg::Expr)
parse_done, idxvar, idxset = Containers._try_parse_idx_set(arg)
if parse_done
return idxvar, idxset
end
error("Invalid syntax: $arg")
end

# takes a generator statement and returns a properly nested for loop
# with nested filters as specified
function _parse_gen(ex, atleaf)
if isexpr(ex, :flatten)
return _parse_gen(ex.args[1], atleaf)
end
if !isexpr(ex, :generator)
return atleaf(ex)
end
function itrsets(sets)
if isa(sets, Expr)
return sets
elseif length(sets) == 1
return sets[1]
else
return Expr(:block, sets...)
end
end

idxvars = []
if isexpr(ex.args[2], :filter) # if condition
loop = Expr(:for, esc(itrsets(ex.args[2].args[2:end])),
Expr(:if, esc(ex.args[2].args[1]),
_parse_gen(ex.args[1], atleaf)))
for idxset in ex.args[2].args[2:end]
idxvar, s = _parse_idx_set(idxset)
push!(idxvars, idxvar)
end
else
loop = Expr(:for, esc(itrsets(ex.args[2:end])),
_parse_gen(ex.args[1], atleaf))
for idxset in ex.args[2:end]
idxvar, s = _parse_idx_set(idxset)
push!(idxvars, idxvar)
end
end
return loop
end

function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym())
@assert isexpr(x,:call)
@assert length(x.args) > 1
@assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten)
header = x.args[1]
if _is_sum(header)
_parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, newaff)
else
error("Expected sum outside generator expression; got $header")
end
end

function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff)
# We used to preallocate the expression at the lowest level of the loop.
# When rewriting this some benchmarks revealed that it actually doesn't
# seem to help anymore, so might as well keep the code simple.
code = _parse_gen(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2])
return :($code; $newaff=$aff)
end

_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)

function rewrite(x)
variable = gensym()
new_variable, code = _rewrite_toplevel(x, variable)
return quote
$variable = Zero()
$code
$new_variable
end
end

_rewrite_toplevel(x, variable::Symbol) = _rewrite(x, variable, [], [])

function _is_comparison(ex::Expr)
if isexpr(ex, :comparison)
return true
elseif isexpr(ex, :call)
if ex.args[1] in (:<=, :, :>=, :, :(==))
return true
else
return false
end
else
return false
end
end

# x[i=1] <= 2 is a somewhat common user error. Catch it here.
function _has_assignment_in_ref(ex::Expr)
if isexpr(ex, :ref)
return any(x -> isexpr(x, :(=)), ex.args)
else
return any(_has_assignment_in_ref, ex.args)
end
end
_has_assignment_in_ref(other) = false

# output is assigned to newaff
function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symbol=gensym())
if isexpr(x, :call)
if x.args[1] == :+
b = Expr(:block)
aff_ = aff
for arg in x.args[2:(end-1)]
aff_, code = _rewrite(arg, aff_, lcoeffs, rcoeffs)
push!(b.args, code)
end
newaff, code = _rewrite(x.args[end], aff_, lcoeffs, rcoeffs, newaff)
push!(b.args, code)
return newaff, b
elseif x.args[1] == :-
if length(x.args) == 2 # unary subtraction
return _rewrite(x.args[2], aff, vcat(-1.0, lcoeffs), rcoeffs, newaff)
else # a - b - c ...
b = Expr(:block)
aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs)
push!(b.args, code)
for arg in x.args[3:(end-1)]
aff_,code = _rewrite(arg, aff_, vcat(-1.0, lcoeffs), rcoeffs)
push!(b.args, code)
end
newaff,code = _rewrite(x.args[end], aff_, vcat(-1.0, lcoeffs), rcoeffs, newaff)
push!(b.args, code)
return newaff, b
end
elseif x.args[1] == :*
# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
n_expr = mapreduce(_is_complex_expr, +, x.args)
if n_expr == 1 # special case, only recurse on one argument and don't create temporary objects
which_idx = 0
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
which_idx = i
end
end
return _rewrite(
x.args[which_idx], aff,
vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]),
vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]),
newaff)
else
blk = Expr(:block)
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[i], s)
push!(blk.args, :($s = 0.0; $parsed))
x.args[i] = newaff_
else
x.args[i] = esc(x.args[i])
end
end
callexpr = Expr(:call, :operate!, add_mul, aff,
lcoeffs..., x.args[2:end]..., rcoeffs...)
push!(blk.args, :($newaff = $callexpr))
return newaff, blk
end
elseif x.args[1] == :^ && _is_complex_expr(x.args[2])
MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2]))))
if x.args[3] == 2
blk = Expr(:block)
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[2], s)
push!(blk.args, :($s = Zero(); $parsed))
push!(blk.args, :($newaff = operate!(add_mul,
$aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_,
rcoeffs...)))))
return newaff, blk
elseif x.args[3] == 1
return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs)
elseif x.args[3] == 0
return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs)
else
blk = Expr(:block)
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[2], s)
push!(blk.args, :($s = Zero(); $parsed))
push!(blk.args, :($newaff = _destructive_add_with_reorder!(
$aff, $(Expr(:call, :*, lcoeffs...,
Expr(:call, :^, newaff_, esc(x.args[3])),
rcoeffs...)))))
return newaff, blk
end
elseif x.args[1] == :/
@assert length(x.args) == 3
numerator = x.args[2]
denom = x.args[3]
return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), newaff)
elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten))
return newaff, _parse_generator(x,aff,lcoeffs,rcoeffs,newaff)
end
elseif isexpr(x, :curly)
_error_curly(x)
end
if isa(x, Expr) && _is_comparison(x)
error("Unexpected comparison in expression $x.")
end
if isa(x, Expr) && _has_assignment_in_ref(x)
@warn "Unexpected assignment in expression $x. This will" *
" become a syntax error in a future release."
end
# at the lowest level
callexpr = Expr(:call, :operate!, add_mul, aff, lcoeffs..., esc(x), rcoeffs...)
return newaff, :($newaff = $callexpr)
end
2 changes: 2 additions & 0 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...)
Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`.
"""
function add_mul end
add_mul(a, b) = a + b
add_mul(a, b, c) = muladd(b, c, a)
add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...))

function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N
return promote_operation(+, T, promote_operation(*, args...))
Expand Down
Loading

0 comments on commit c2a7777

Please sign in to comment.