diff --git a/src/matmul.jl b/src/matmul.jl index 58bad14..64187c7 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -53,15 +53,19 @@ for T in (Complex{Float32}, Complex{Float64}, Float32, Float64), Base.:(*)(A::$AT, B::StridedMatrix{$T}) = mul!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B) end + # define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra, + # redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1) + # here A is LowerTri/UpperTri etc of a SparseMatrixCSC or adjoint/transpose of it (Symmetric not supported) if w != :Symmetric @eval begin - LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedVector{$T}, C::StridedVector{$T}) = + LinearAlgebra.ldiv!(C::StridedVector{$T}, A::$AT, B::StridedVector{$T}, α::Number) = cscsv!($tchar, $T(α), describe_and_unwrap(A)..., B, C) - LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedMatrix{$T}, C::StridedMatrix{$T}) = + LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$AT, B::StridedMatrix{$T}, α::Number) = cscsm!($tchar, $T(α), describe_and_unwrap(A)..., B, C) - LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(one($T), A, B, C) + LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(C, A, B, one($T)) + # base A\B converts A to dense, so we redefine it to use MKLSparse-enabled ldiv! Base.:(\)(A::$AT, B::StridedVector{$T}) = ldiv!(Vector{$T}(undef, size(A, 1)), A, B) Base.:(\)(A::$AT, B::StridedMatrix{$T}) = ldiv!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B) end diff --git a/test/test_BLAS.jl b/test/test_BLAS.jl index 55066ce..c36daae 100644 --- a/test/test_BLAS.jl +++ b/test/test_BLAS.jl @@ -144,7 +144,7 @@ end spAclass = Aclass(spA) α = rand() - @test @blas(ldiv!(α, Aclass(spA), B, similar(B))) ≈ α * (Aclass(A) \ B) + @test @blas(ldiv!(similar(B), Aclass(spA), B, α)) ≈ α * (Aclass(A) \ B) @test @blas(ldiv!(similar(B), Aclass(spA), B)) ≈ Aclass(A) \ B @test @blas(Aclass(spA) \ B) ≈ Aclass(A) \ B end