Skip to content

Commit

Permalink
Merge 67d7824 into 3aceeeb
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 9, 2019
2 parents 3aceeeb + 67d7824 commit 23ceceb
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 2 deletions.
20 changes: 20 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ add_mul(a, b) = a + b
add_mul(a, b, c) = muladd(b, c, a)
add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...))

"""
iszero!(x)
Return a `Bool` indicating whether `x` is zero, possibly modifying `x`.
## Examples
In MathOptInterface, a `ScalarAffineFunction` may contain duplicate terms.
In `Base.iszero`, duplicate terms need to be merged but the function is left
with duplicates as it cannot be modified. If `iszero!` is called instead,
the function will be canonicalized in addition for checking whether it is zero.
"""
iszero!(x) = iszero(x)

include("interface.jl")
include("shortcuts.jl")
include("broadcast.jl")
Expand All @@ -30,6 +44,12 @@ include("broadcast.jl")
import LinearAlgebra
const Scaling = Union{Number, LinearAlgebra.UniformScaling}
scaling(x::Scaling) = x
function scaling_convert(::Type{LinearAlgebra.UniformScaling{T}}, x::LinearAlgebra.UniformScaling) where T
# `convert(::Type{<:UniformScaling}, ::UniformScaling)` is not defined in LinearAlgebra.
return LinearAlgebra.UniformScaling(convert(T, x.λ))
end
scaling_convert(T::Type, x::LinearAlgebra.UniformScaling) = convert(T, x.λ)
scaling_convert(T::Type, x) = convert(T, x)
include("bigint.jl")
include("linear_algebra.jl")
include("sparse_arrays.jl")
Expand Down
24 changes: 24 additions & 0 deletions src/Test/quadratic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ function quadratic_isequal_canonical_test(w, x, y, z)
end

function quadratic_add_test(w, x, y, z)
w_copy = MA.copy_if_mutable(w)
x_copy = MA.copy_if_mutable(x)
y_copy = MA.copy_if_mutable(y)
z_copy = MA.copy_if_mutable(z)
a = 7
b = 2
c = 1
Expand All @@ -72,12 +76,22 @@ function quadratic_add_test(w, x, y, z)
@test_rewrite y * z - x
end

@test MA.isequal_canonical(w, w_copy)
@test MA.isequal_canonical(x, x_copy)
@test MA.isequal_canonical(y, y_copy)
@test MA.isequal_canonical(z, z_copy)

aff = @inferred a * x + b
@test_rewrite a * x + b
@test aff == aff
aff2 = @inferred c * y + c
@test_rewrite c * y + c

@test MA.isequal_canonical(w, w_copy)
@test MA.isequal_canonical(x, x_copy)
@test MA.isequal_canonical(y, y_copy)
@test MA.isequal_canonical(z, z_copy)

@testset "Affine" begin
unary_test(aff)
add_test(aff, aff)
Expand All @@ -99,6 +113,11 @@ function quadratic_add_test(w, x, y, z)
@test string((x+x)*(x+3)) == string((x+3)*(x+x)) # Issue #288
end

@test MA.isequal_canonical(w, w_copy)
@test MA.isequal_canonical(x, x_copy)
@test MA.isequal_canonical(y, y_copy)
@test MA.isequal_canonical(z, z_copy)

@testset "Quadratic" begin
@test_rewrite 2 * x * x + 1 * y * y + z + 3

Expand All @@ -115,6 +134,11 @@ function quadratic_add_test(w, x, y, z)
q2 = @inferred 8 * x * z + aff2
add_test(q, q2)
end

@test MA.isequal_canonical(w, w_copy)
@test MA.isequal_canonical(x, x_copy)
@test MA.isequal_canonical(y, y_copy)
@test MA.isequal_canonical(z, z_copy)
end

const quadratic_tests = Dict(
Expand Down
22 changes: 22 additions & 0 deletions src/Test/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
function iszero_test(x)
x_copy = x

@test iszero(x - x)
@test iszero(MA.@rewrite(x - x))
@test MA.iszero!(x - x)
@test MA.iszero!(MA.@rewrite(x - x))

@test iszero(0 * x)
@test iszero(MA.@rewrite(0 * x))
@test MA.iszero!(0 * x)
@test MA.iszero!(MA.@rewrite(0 * x))

@test iszero(x - 2x + x)
@test iszero(MA.@rewrite(x - 2x + x))
@test MA.iszero!(x - 2x + x)
@test MA.iszero!(MA.@rewrite(x - 2x + x))

@test MA.isequal_canonical(x_copy, x)
end

function cube_test(x)
@test_rewrite x^3
@test_rewrite (x + 1)^3
Expand Down Expand Up @@ -37,6 +58,7 @@ end

const scalar_tests = Dict(
"cube" => cube_test,
"iszero" => iszero_test,
"scalar_in_any" => scalar_in_any_test,
"scalar_uniform_scaling" => scalar_uniform_scaling_test
)
Expand Down
5 changes: 3 additions & 2 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ function _dot(x::AbstractArray, y::AbstractArray)
end

# We need a buffer to hold the intermediate multiplication.
mul_buffer = buffer_for(add_mul, eltype(x), eltype(y))

s = zero(promote_operation(add_mul, eltype(x), eltype(x), eltype(y)))
SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y))
mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y))
s = zero(SumType)

for (Ix, Iy) in zip(eachindex(x), eachindex(y))
s = @inbounds buffered_operate!(mul_buffer, add_mul, s, x[Ix], y[Iy])
Expand Down
8 changes: 8 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ include("utilities.jl")

struct CustomArray{T, N} <: AbstractArray{T, N} end

import LinearAlgebra

@testset "Scaling convert" begin
@test MA.scaling_convert(LinearAlgebra.UniformScaling{Int}, LinearAlgebra.I) isa LinearAlgebra.UniformScaling
@test MA.scaling_convert(Int, LinearAlgebra.I) === 1
@test MA.scaling_convert(Int, 1) === 1
end

@testset "Errors" begin
@testset "`promote_op` error" begin
AT = CustomArray{Int, 3}
Expand Down

0 comments on commit 23ceceb

Please sign in to comment.