From 3295cfa6502d81e8fd20fc3c3350189b129c1e0f Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sat, 31 Aug 2019 11:24:32 -0700 Subject: [PATCH] Check eltype in possible(::typeof(mul!), C, A, B) etc. --- src/linearalgebra.jl | 15 ++++++++++++--- test/test_lmul.jl | 5 +++++ test/test_mul.jl | 3 +++ test/test_rmul.jl | 7 ++++++- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 80fae59f..fa24f374 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -5,7 +5,14 @@ mul!!(C, A, B) = may(mul!, C, A, B) pure(::typeof(mul!)) = NoBang.mul _asbb(::typeof(mul!)) = mul!! -possible(::typeof(mul!), C, ::Any, ::Any) = ismutable(C) +possible(::typeof(mul!), C, A, B) = + ismutable(C) && _matmuleltype(A, B) <: eltype(C) + +# Estimate `eltype` of `C`. This is how it's done in LinearAlgebra.jl +# but maybe it's better to use the approach of +# https://github.com/tpapp/AlgebraResultTypes.jl ? +_matprod(x, y) = x * y + x * y +_matmuleltype(A, B) = Base.promote_op(_matprod, eltype(A), eltype(B)) """ lmul!!(A, B) -> B′ @@ -14,7 +21,8 @@ lmul!!(A, B) = may(lmul!, A, B) pure(::typeof(lmul!)) = * _asbb(::typeof(lmul!)) = lmul!! -possible(::typeof(lmul!), ::Any, B) = ismutable(B) +possible(::typeof(lmul!), A, B) = + ismutable(B) && _matmuleltype(A, B) <: eltype(B) """ rmul!!(A, B) -> A′ @@ -23,4 +31,5 @@ rmul!!(A, B) = may(rmul!, A, B) pure(::typeof(rmul!)) = * _asbb(::typeof(rmul!)) = rmul!! -possible(::typeof(rmul!), A, ::Any) = ismutable(A) +possible(::typeof(rmul!), A, B) = + ismutable(A) && _matmuleltype(A, B) <: eltype(A) diff --git a/test/test_lmul.jl b/test/test_lmul.jl index b43ae220..a1349e7e 100644 --- a/test/test_lmul.jl +++ b/test/test_lmul.jl @@ -11,6 +11,11 @@ using LinearAlgebra AB = A * B @test lmul!!(A, B) === B @test B == AB + + A = LowerTriangular(collect(Float64, reshape(1:4, 2, 2))) + B = ones(Int, 2, 2) + AB = A * B + @test lmul!!(A, B) :: Matrix{Float64} == AB end end # module diff --git a/test/test_mul.jl b/test/test_mul.jl index cb90ad1b..f4a0fa75 100644 --- a/test/test_mul.jl +++ b/test/test_mul.jl @@ -12,6 +12,9 @@ include("preamble.jl") B = ones(2, 2) @test mul!!(C, A, B) === C @test C == A * B + + C = zeros(Int, size(A * B)) + @test mul!!(C, A, B) :: Matrix{Float64} == A * B end end # module diff --git a/test/test_rmul.jl b/test/test_rmul.jl index abcebf5c..474392da 100644 --- a/test/test_rmul.jl +++ b/test/test_rmul.jl @@ -6,11 +6,16 @@ using LinearAlgebra @testset begin @test rmul!!(1, 2) === 2 - A = copy(reshape(1:4, 2, 2)) + A = collect(Float64, reshape(1:4, 2, 2)) B = UpperTriangular(ones(2, 2)) AB = A * B @test rmul!!(A, B) === A @test A == AB + + A = collect(Int, reshape(1:4, 2, 2)) + B = UpperTriangular(ones(2, 2)) + AB = A * B + @test rmul!!(A, B) :: Matrix{Float64} == AB end end # module