Skip to content

Commit

Permalink
Check eltype in possible(::typeof(mul!), C, A, B) etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Aug 31, 2019
1 parent fbe28fb commit 3295cfa
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
15 changes: 12 additions & 3 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@ mul!!(C, A, B) = may(mul!, C, A, B)

pure(::typeof(mul!)) = NoBang.mul
_asbb(::typeof(mul!)) = mul!!
possible(::typeof(mul!), C, ::Any, ::Any) = ismutable(C)
possible(::typeof(mul!), C, A, B) =
ismutable(C) && _matmuleltype(A, B) <: eltype(C)

# Estimate `eltype` of `C`. This is how it's done in LinearAlgebra.jl
# but maybe it's better to use the approach of
# https://github.com/tpapp/AlgebraResultTypes.jl ?
_matprod(x, y) = x * y + x * y
_matmuleltype(A, B) = Base.promote_op(_matprod, eltype(A), eltype(B))

"""
lmul!!(A, B) -> B′
Expand All @@ -14,7 +21,8 @@ lmul!!(A, B) = may(lmul!, A, B)

pure(::typeof(lmul!)) = *
_asbb(::typeof(lmul!)) = lmul!!
possible(::typeof(lmul!), ::Any, B) = ismutable(B)
possible(::typeof(lmul!), A, B) =
ismutable(B) && _matmuleltype(A, B) <: eltype(B)

"""
rmul!!(A, B) -> A′
Expand All @@ -23,4 +31,5 @@ rmul!!(A, B) = may(rmul!, A, B)

pure(::typeof(rmul!)) = *
_asbb(::typeof(rmul!)) = rmul!!
possible(::typeof(rmul!), A, ::Any) = ismutable(A)
possible(::typeof(rmul!), A, B) =
ismutable(A) && _matmuleltype(A, B) <: eltype(A)
5 changes: 5 additions & 0 deletions test/test_lmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ using LinearAlgebra
AB = A * B
@test lmul!!(A, B) === B
@test B == AB

A = LowerTriangular(collect(Float64, reshape(1:4, 2, 2)))
B = ones(Int, 2, 2)
AB = A * B
@test lmul!!(A, B) :: Matrix{Float64} == AB
end

end # module
3 changes: 3 additions & 0 deletions test/test_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ include("preamble.jl")
B = ones(2, 2)
@test mul!!(C, A, B) === C
@test C == A * B

C = zeros(Int, size(A * B))
@test mul!!(C, A, B) :: Matrix{Float64} == A * B
end

end # module
7 changes: 6 additions & 1 deletion test/test_rmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@ using LinearAlgebra
@testset begin
@test rmul!!(1, 2) === 2

A = copy(reshape(1:4, 2, 2))
A = collect(Float64, reshape(1:4, 2, 2))
B = UpperTriangular(ones(2, 2))
AB = A * B
@test rmul!!(A, B) === A
@test A == AB

A = collect(Int, reshape(1:4, 2, 2))
B = UpperTriangular(ones(2, 2))
AB = A * B
@test rmul!!(A, B) :: Matrix{Float64} == AB
end

end # module

0 comments on commit 3295cfa

Please sign in to comment.