Skip to content

Commit

Permalink
Merge db2bab9 into 879b457
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 3, 2019
2 parents 879b457 + db2bab9 commit 8aa8663
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 5 deletions.
5 changes: 1 addition & 4 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@ include("linear_algebra.jl")
include("sparse_arrays.jl")

isequal_canonical(a, b) = a == b
function isequal_canonical(a::Array{T, N}, b::Array{T, N}) where {T, N}
function isequal_canonical(a::AT, b::AT) where AT <: Union{Array, LinearAlgebra.Symmetric}
return all(zip(a, b)) do elements
return isequal_canonical(elements...)
end
end
function isequal_canonical(a::LinearAlgebra.Symmetric, b::LinearAlgebra.Symmetric)
return isequal_canonical(parent(a), parent(b))
end

include("rewrite.jl")
include("dispatch.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/Test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ function symmetric_unary_test(x)
end
end

function symmetric_add_test(x)
if x isa AbstractMatrix && size(x, 1) == size(x, 2)
y = LinearAlgebra.Symmetric(x)
add_test(y, y)
add_test(x, y)
end
end

function matrix_uniform_scaling_test(x)
if !(x isa AbstractMatrix && size(x, 1) == size(x, 2))
return
Expand Down Expand Up @@ -347,6 +355,7 @@ const array_tests = Dict(
"broadcast_division" => broadcast_division_test,
"unary" => unary_test,
"symmetric_unary" => symmetric_unary_test,
"symmetric_add" => symmetric_add_test,
"matrix_uniform_scaling" => matrix_uniform_scaling_test,
"symmetric_matrix_uniform_scaling" => symmetric_matrix_uniform_scaling_test
)
Expand Down
5 changes: 5 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Returns the type returned to the call `operate(op, args...)` where the types of
the arguments `args` are `ArgsTypes`.
"""
function promote_operation end
function promote_operation(op::Function, x::Type{<:AbstractArray}, y::Type{<:AbstractArray})
# `zero` is not defined for `AbstractArray` so the fallback would fail with a cryptic MethodError.
# We replace it by a more helpful error here.
error("`promote_operation($op, $x, $y)` not implemented yet, please report this.")
end
# Julia v1.0.x has trouble with inference with the `Vararg` method, see
# https://travis-ci.org/JuliaOpt/JuMP.jl/jobs/617606373
function promote_operation(op::Function, x::Type, y::Type)
Expand Down
18 changes: 17 additions & 1 deletion src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ end
function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Matrix{T}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T}
return Matrix{promote_operation(op, S, T)}
end
function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Matrix{T}}, ::Type{<:LinearAlgebra.Symmetric{S}}) where {S, T}
return Matrix{promote_operation(op, S, T)}
end

# Only `Scaling`
function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Matrix, B::LinearAlgebra.UniformScaling)
Expand All @@ -32,23 +35,35 @@ mul_rhs(::typeof(+)) = add_mul
mul_rhs(::typeof(-)) = sub_mul

# `Scaling` and `Array`
function _mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::Array{T, N}, left_factors::Tuple, right_factors::Tuple) where {S, T, N}
function _mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N},
B::Union{Array{T, N}, LinearAlgebra.Symmetric{T}},
left_factors::Tuple, right_factors::Tuple) where {S, T, N}
for i in eachindex(A)
A[i] = operate!(mul_rhs(op), A[i], left_factors..., B[i], right_factors...)
end
return A
end

function _check_dims(A, B)
if size(A) != size(B)
throw(DimensionMismatch("Cannot sum matrices of size `$(size(A))` and size `$(size(B))`, the size of the two matrices must be equal."))
end
end

function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::AbstractArray{T, N}) where {S, T, N}
_check_dims(A, B)
return _mutable_operate!(op, A, B, tuple(), tuple())
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::AbstractArray{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M}
_check_dims(A, B)
return _mutable_operate!(+, A, B, tuple(), α)
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
_check_dims(A, B)
return _mutable_operate!(+, A, B, (α,), β)
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
_check_dims(A, B)
return _mutable_operate!(+, A, B, (α1, α2), β)
end

Expand All @@ -59,6 +74,7 @@ end

# Product

similar_array_type(::Type{LinearAlgebra.Symmetric{T, MT}}, ::Type{S}) where {S, T, MT} = LinearAlgebra.Symmetric{S, similar_array_type(MT, S)}
similar_array_type(::Type{Array{T, N}}, ::Type{S}) where {S, T, N} = Array{S, N}
function promote_operation(op::typeof(*), A::Type{<:AbstractArray{T}}, ::Type{S}) where {S, T}
return similar_array_type(A, promote_operation(op, T, S))
Expand Down
3 changes: 3 additions & 0 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...)
function promote_operation(::typeof(add_mul), T::Type, x::Type, y::Type)
return promote_operation(+, T, promote_operation(*, x, y))
end
function promote_operation(::typeof(add_mul), x::Type{<:AbstractArray}, y::Type{<:AbstractArray})
return promote_operation(+, x, y)
end
function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N
return promote_operation(+, T, promote_operation(*, args...))
end
Expand Down
17 changes: 17 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@ const MA = MutableArithmetics

include("utilities.jl")

struct CustomArray{T, N} <: AbstractArray{T, N} end

@testset "Errors" begin
@testset "`promote_op` error" begin
AT = CustomArray{Int, 3}
err = ErrorException("`promote_operation(+, CustomArray{Int64,3}, CustomArray{Int64,3})` not implemented yet, please report this.")
@test_throws err MA.promote_operation(+, AT, AT)
end

@testset "Dimension mismatch" begin
A = zeros(1, 1)
B = zeros(2, 2)
err = DimensionMismatch("Cannot sum matrices of size `(1, 1)` and size `(2, 2)`, the size of the two matrices must be equal.")
@test_throws err MA.@rewrite A + B
end
end

@testset "Matrix multiplication" begin
@testset "matrix-vector product" begin
A = [1 1 1; 1 1 1; 1 1 1]
Expand Down

0 comments on commit 8aa8663

Please sign in to comment.