Add fastpath for generic mul! #51812

Closed
wants to merge 3 commits
41 changes: 39 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
@@ -234,7 +234,11 @@ For custom matrix and vector types, it is recommended to implement
if possible.
"""
@inline function mul!(C, A, B)
    if eltype(C) <: BlasFloat && eltype(C) === eltype(A) === eltype(B)
        return mul!(C, A, B, true, false)
    else
        return _generic_matmatmul!(C, A, B)
    end
end

"""
@@ -260,7 +264,13 @@ julia> C
730.0 740.0
```
"""
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number)
    if (!(eltype(C) <: BlasFloat && eltype(C) === eltype(A) === eltype(B))) ||
            (!(C isa StridedVecOrMat && A isa StridedVecOrMat && B isa StridedVecOrMat)) &&
            α == true && (β == true || β == false)

Member:
IIUC, then this condition applies, for instance, to CUDA, CUSparse and sparse matrices (and/or their transposes and adjoints), right? That would massively slow down SparseArrays.jl and even break a few GPUArrays-related packages, which specifically overload LinearAlgebra.generic_matmatmul! (with the character arguments).


Contributor Author:
I can make it Array-only.


Member:
Wait, I might have misread the logic. Suppose I want to multiply non-transposed dense GPUArrays with BlasFloat elements, then this returns ... true? And directs away from generic_matmatmul!? Too many conditions for my little brain.


Contributor Author (@chriselrod, Nov 5, 2023):
The check is supposed to be for BLAS compatibility of types.
BLAS requires both of:
1. The eltypes match and they're BlasFloat.
2. They're all strided.

The idea is, if either of these fails to hold, call the new overload.
Otherwise, call the BLAS dispatcher.

But because packages like CUDA are relying on extending the non-Julian and non-exported generic_matmatmul!, it's suddenly a breaking change.
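
In code, the intended BLAS-compatibility predicate is roughly the following (a sketch; the helper name `blas_compatible` is hypothetical and not part of the PR):

```julia
using LinearAlgebra: BlasFloat

# BLAS can only be used when both requirements hold; the condition in the
# diff above is (essentially) the negation of this predicate.
blas_compatible(C, A, B) =
    (eltype(C) <: BlasFloat && eltype(C) === eltype(A) === eltype(B)) &&
    (C isa StridedVecOrMat && A isa StridedVecOrMat && B isa StridedVecOrMat)
```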



Member:
Yes, they do because that significantly reduced load times of those packages, both SparseArrays and all of the GPUArrays-related packages. Otherwise people had to overload tons of mul! methods, for all combinations of plain/transpose/adjoint factors, possibly dispatching on type parameters. JuliaGPU/CUDA.jl#1904
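
The overload pattern being described looks roughly like this (a sketch only: `MyGPUMatrix` and `my_gpu_gemm!` are hypothetical stand-ins, and the exact internal signature of `generic_matmatmul!` has varied across Julia versions):

```julia
using LinearAlgebra

struct MyGPUMatrix{T} <: AbstractMatrix{T}  # hypothetical package array type
    data::Matrix{T}
end

# One method covers every plain/transpose/adjoint combination via the
# character arguments ('N', 'T', 'C'), instead of many mul! overloads.
function LinearAlgebra.generic_matmatmul!(C::MyGPUMatrix, tA::AbstractChar, tB::AbstractChar,
                                          A::MyGPUMatrix, B::MyGPUMatrix,
                                          _add::LinearAlgebra.MulAddMul)
    # my_gpu_gemm! stands in for the package's own fused kernel.
    my_gpu_gemm!(C, tA, tB, A, B, _add.alpha, _add.beta)
    return C
end
```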


Contributor Author:
Okay, ideally there would have been another way to avoid `mul!`.

I went with requiring Matrix for the first two arguments and Array for the third.


Member:
> Okay, ideally there would have been another way to avoid `mul!`.

Which one? Since this is all internal, it can be changed to whatever yields the correct dispatch to the underlying *BLAS methods and doesn't increase load times. Method insertion at package loading was the big big issue.

        β == false && fill!(C, zero(eltype(C)))
        return _generic_matmatmuladd!(C, A, B)
    end
    generic_matmatmul!(
        C,
        wrapper_char(A),
@@ -269,6 +279,33 @@ julia> C
        _unwrap(B),
        MulAddMul(α, β)
    )
end

function _generic_matmatmuladd!(C, A, B)
    AxM = axes(A, 1)
    AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
    BxK = axes(B, 1)
    BxN = axes(B, 2)
    CxM = axes(C, 1)
    CxN = axes(C, 2)
    if AxM != CxM
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
    end
    if AxK != BxK
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$BxN)"))
    end
    if BxN != CxN
        throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
    end
    for n = BxN, k = BxK, m = AxM
        C[m,n] = muladd(A[m,k], B[k,n], C[m,n])
    end
    return C
end

function _generic_matmatmul!(C, A, B)
    _generic_matmatmuladd!(fill!(C, zero(eltype(C))), A, B)
end


"""
rmul!(A, B)
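
A quick usage sketch of the new internal helpers (assuming `_generic_matmatmul!` and `_generic_matmatmuladd!` from the diff above are in scope); Int is not a BlasFloat, so this exercises the path the PR adds for the three-argument `mul!`:

```julia
A = rand(1:9, 3, 3)
B = rand(1:9, 3, 3)
C = zeros(Int, 3, 3)

_generic_matmatmul!(C, A, B)      # C  = A*B (zeroes C, then accumulates)
_generic_matmatmuladd!(C, A, B)   # C += A*B
@assert C == 2 .* (A * B)
```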