Skip to content

Commit

Permalink
Merge fe19beb into 0a0a7ed
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Feb 28, 2020
2 parents 0a0a7ed + fe19beb commit 7955135
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
11 changes: 11 additions & 0 deletions src/Test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ function matrix_vector_test(x)
A = [2 1 0
1 2 1
0 1 2]

@test_rewrite x .+ A * x
@test_rewrite A * x .+ A * x
@test_rewrite x .- A * x
@test_rewrite A * x .- A * x
@test_rewrite A .+ (A + A)^2
@test_rewrite A .- (A + A)^2
@test_rewrite A * x .+ (A + A)^2 * x
@test_rewrite A * x .- (A + A)^2 * x
@test_rewrite A * x .- (A + A)^2 * x

@test MA.isequal_canonical(-x, [-x[1], -x[2], -x[3]])
xAx = 2x[1]*x[1] + 2x[1]*x[2] + 2x[2]*x[2] + 2x[2]*x[3] + 2x[3]*x[3]
@test MA.isequal_canonical(x' * A * x, xAx)
Expand Down
5 changes: 3 additions & 2 deletions src/Test/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ function empty_sum_test(x)
@test MA.isequal_canonical(MA.@rewrite(x + 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + sum(1 for i in 1:0) * 1^2 * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0) * 1^2), x)
@test MA.isequal_canonical(MA.@rewrite(x .+ sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x .+ 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0) * 1^2), x)
# Fails because the sum is not rewritten because `*` is not rewritten when `vectorized` is `true`.
#@test MA.isequal_canonical(MA.@rewrite(x .+ sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
#@test MA.isequal_canonical(MA.@rewrite(x .+ 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0) * 1^2), x)
end

function cube_test(x)
Expand Down
11 changes: 7 additions & 4 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Unio
if isexpr(inner_factor, :call)
# We need to verfify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`.
# We also need to verify that `current_sum` is `nothing` otherwise we are unsure that the elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` as the same as `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
if inner_factor.args[1] == :+ || inner_factor.args[1] == :- ||
(current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)))
block = Expr(:block)
Expand All @@ -322,7 +322,9 @@ function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Unio
minus = !minus
end
return rewrite_sum(vectorized, minus, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block)
elseif inner_factor.args[1] == :* # FIXME && !vectorized ?
elseif inner_factor.args[1] == :* && !vectorized
# We need `&& !vectorized` otherwise `x .+ A * b` would be rewritten `broadcast!(add_mul, x, A, b)`.

# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
# special case, only recurse on one argument and don't create temporary objects
Expand Down Expand Up @@ -354,7 +356,8 @@ function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Unio
))
return new_var, blk
end
elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) # FIXME && !vectorized ?
elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) && !vectorized
# We need `&& !vectorized` otherwise `A .+ (A + A)^2` would be rewritten `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`.
MulType = :(MA.promote_operation(*, typeof($(inner_factor.args[2])), typeof($(inner_factor.args[2]))))
if inner_factor.args[3] == 2
new_var_, parsed = rewrite(inner_factor.args[2])
Expand All @@ -376,7 +379,7 @@ function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Unio
)
return new_var, Expr(:block, parsed, power_expr)
end
elseif inner_factor.args[1] == :/ # FIXME && !vectorized ?
elseif inner_factor.args[1] == :/ && !vectorized
@assert length(inner_factor.args) == 3
numerator = inner_factor.args[2]
denom = inner_factor.args[3]
Expand Down

0 comments on commit 7955135

Please sign in to comment.