Skip to content

Commit

Permalink
fix 4-arg ldiv!() declaration
Browse files Browse the repository at this point in the history
to be compatible with 3-arg ldiv!()
in base Julia;
add tests for 3- and 4-arg ldiv!()
  • Loading branch information
alyst committed May 6, 2023
1 parent 096d1da commit b32fed7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,19 @@ for T in (Complex{Float32}, Complex{Float64}, Float32, Float64),
Base.:(*)(A::$AT, B::StridedMatrix{$T}) = mul!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B)
end

# define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra,
# redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1)
# here A is LowerTri/UpperTri etc of a SparseMatrixCSC or adjoint/transpose of it (Symmetric not supported)
if w != :Symmetric
@eval begin
LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedVector{$T}, C::StridedVector{$T}) =
LinearAlgebra.ldiv!(C::StridedVector{$T}, A::$AT, B::StridedVector{$T}, α::Number) =
cscsv!($tchar, $T(α), describe_and_unwrap(A)..., B, C)
LinearAlgebra.ldiv!(α::Number, A::$AT, B::StridedMatrix{$T}, C::StridedMatrix{$T}) =
LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$AT, B::StridedMatrix{$T}, α::Number) =
cscsm!($tchar, $T(α), describe_and_unwrap(A)..., B, C)

LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(one($T), A, B, C)
LinearAlgebra.ldiv!(C::BT, A::$AT, B::BT) where BT <: $BT = ldiv!(C, A, B, one($T))

# base A\B converts A to dense, so we redefine it to use MKLSparse-enabled ldiv!
Base.:(\)(A::$AT, B::StridedVector{$T}) = ldiv!(Vector{$T}(undef, size(A, 1)), A, B)
Base.:(\)(A::$AT, B::StridedMatrix{$T}) = ldiv!(Matrix{$T}(undef, size(A, 1), size(B, 2)), A, B)
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_BLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ end
spAclass = Aclass(spA)
α = rand()

@test @blas(ldiv!(α, Aclass(spA), B, similar(B))) α * (Aclass(A) \ B)
@test @blas(ldiv!(similar(B), Aclass(spA), B, α)) α * (Aclass(A) \ B)
@test @blas(ldiv!(similar(B), Aclass(spA), B)) Aclass(A) \ B
@test @blas(Aclass(spA) \ B) Aclass(A) \ B
end
Expand Down

0 comments on commit b32fed7

Please sign in to comment.