Skip to content

Commit

Permalink
fix 4-arg ldiv!() declaration
Browse files Browse the repository at this point in the history
to be compatible with 3-arg ldiv!()
in base Julia;
add tests for 3- and 4-arg ldiv!()
  • Loading branch information
alyst committed May 6, 2023
1 parent c972d04 commit d8bce50
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,19 @@ for T in (Complex{Float32}, Complex{Float64}, Float32, Float64),
Base.:(*)(A::$AT, B::StridedMatrix{$T}) = mul!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B)
end

# define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra,
# redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1)
# here A is LowerTri/UpperTri etc of a SparseMatrixCSC or adjoint/transpose of it (Symmetric not supported)
if w != :Symmetric
@eval begin
LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedVector{$T}, C::StridedVector{$T}) =
LinearAlgebra.ldiv!(C::StridedVector{$T}, A::$AT, B::StridedVector{$T}, α::Number) =
cscsv!($tchar, $T(α), describe_and_unwrap(A)..., B, C)
LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedMatrix{$T}, C::StridedMatrix{$T}) =
LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$AT, B::StridedMatrix{$T}, α::Number) =
cscsm!($tchar, $T(α), describe_and_unwrap(A)..., B, C)

LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(one($T), A, B, C)
LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(C, A, B, one($T))

# base A\B converts A to dense, so we redefine it to use MKLSparse-enabled ldiv!
Base.:(\)(A::$AT, B::StridedVector{$T}) = ldiv!(Vector{$T}(undef, size(A, 1)), A, B)
Base.:(\)(A::$AT, B::StridedMatrix{$T}) = ldiv!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B)
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_BLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ end
B = Bdim == 2 ? rand(T, n, n) : rand(T, n)
spAclass = Aclass(convert_to_class(spA))

@test @blas(ldiv!(0,5, spAclass, B, similar(B))) 0.5 * (Array(spAclass) \ B)
@test @blas(ldiv!(similar(B), spAclass, B, 0.5)) 0.5 * (Array(spAclass) \ B)
@test @blas(ldiv!(similar(B), spAclass, B)) Array(spAclass) \ B
@test @blas(spAclass \ B) Array(spAclass) \ B
end
Expand Down

0 comments on commit d8bce50

Please sign in to comment.