Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce compile time for generic matmatmul #52038

Merged
merged 16 commits into from
Nov 14, 2023
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ adjoint(A::Adjoint) = A.parent
# Unwrapping rule: transposing an already-transposed array returns the stored parent
# instead of nesting another wrapper.
transpose(A::Transpose) = A.parent
# For `Real` element types the opposite wrapper is peeled off in the same way.
adjoint(A::Transpose{<:Real}) = A.parent
transpose(A::Adjoint{<:Real}) = A.parent
# Doubly wrapped arrays: skip both stored wrappers (`A.parent.parent`) and apply a single
# wrapper to the innermost parent.
adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent)
transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent)
# disambiguation: these signatures match both a `<:Real` method and a nested-wrapper
# method above, so each needs its own explicit method.
adjoint(A::Transpose{<:Real,<:Adjoint}) = transpose(A.parent.parent)
transpose(A::Adjoint{<:Real,<:Transpose}) = A.parent

# printing
function Base.showarg(io::IO, v::Adjoint, toplevel)
Expand Down Expand Up @@ -395,11 +400,16 @@ map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...
# `map` over transposed vectors: maps over the parents of `tvs`. Inside the closure each
# parent element is transposed before being passed to `f`, and the value returned by `f`
# is transposed again; the mapped result is finally wrapped with `transpose`.
map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...))
# Argument unwrapping for the broadcast definitions below. Arrays are replaced by their
# parent; numbers get a dedicated method because they are passed alongside the arrays.

# Used with transposed vectors: a number is passed through unchanged.
quasiparentt(x) = parent(x)
quasiparentt(x::Number) = x

# Used with adjoint vectors: a number is conjugated.
quasiparenta(x) = parent(x)
quasiparenta(x::Number) = conj(x)

# Used with doubly wrapped vectors: both layers are stripped; a number is conjugated.
quasiparentc(x) = parent(parent(x))
quasiparentc(x::Number) = conj(x)
# `broadcast` over any mix of numbers and adjoint vectors: the arguments are unwrapped
# with `quasiparenta`, the closure applies `adjoint` to each element before calling `f`
# and to the value `f` returns, and the broadcast result is wrapped with `adjoint`.
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
# Same scheme for numbers and transposed vectors, using `quasiparentt` and `transpose`.
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion
# The first two methods have the same right-hand sides as the `broadcast` methods for
# adjoint and transposed vectors.
Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# Doubly wrapped vectors (transpose of an adjoint vector): `quasiparentc` strips both
# wrappers, elements are passed through `conj` before `f`, and both wrappers are applied
# to each value returned by `f` and again to the overall result.
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Transpose{<:Any,<:AdjointAbsVec}}...) =
transpose(adjoint(broadcast((xs...) -> adjoint(transpose(f(conj.(xs)...))), quasiparentc.(tvs)...)))
# Mirror image of the previous method for an adjoint of a transposed vector.
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Adjoint{<:Any,<:TransposeAbsVec}}...) =
adjoint(transpose(broadcast((xs...) -> transpose(adjoint(f(conj.(xs)...))), quasiparentc.(tvs)...)))
# TODO unify and allow mixed combinations with a broadcast style


Expand Down
256 changes: 62 additions & 194 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ julia> lmul!(F.Q, B)
lmul!(A, B)

# THE one big BLAS dispatch
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if all(in(('N', 'T', 'C')), (tA, tB))
if tA == 'T' && tB == 'N' && A === B
Expand All @@ -364,16 +364,16 @@ lmul!(A, B)
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
end
end
return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
if all(in(('N', 'T', 'C')), (tA, tB))
gemm_wrapper!(C, tA, tB, A, B, _add)
else
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
end

Expand Down Expand Up @@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
if all(in(('N', 'T', 'C')), (tA, tB))
gemm_wrapper!(C, tA, tB, A, B)
else
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
end

function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
_generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
_generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down Expand Up @@ -764,197 +764,65 @@ end

# Size budget from which the tiled kernel in `_generic_matmatmul!` derives its tile edge:
# `floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))`.
# NOTE(review): that computation appears among the lines this change removes — confirm
# the constant still has a reader before relying on this description.
const tilebufsize = 10800 # Approximately 32k/3

# Generic (non-BLAS) entry point for a matrix product written into `C`.
#
# `tA` / `tB` are character flags describing how `A` / `B` enter the product and `_add`
# carries the `alpha` / `beta` coefficients. Small square problems are dispatched to the
# fixed-size kernels; everything else is forwarded to `_generic_matmatmul!`.
function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    rowsA, colsA = lapack_size(tA, A)
    rowsB, colsB = lapack_size(tB, B)
    rowsC, colsC = size(C)

    # A zero `alpha` short-circuits: `C` and `beta` are handed to `_rmul_or_fill!`.
    iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)

    # All six extents equal to 2 (or 3) selects the dedicated small-matrix kernel.
    extents = (rowsA, colsA, rowsB, colsB, rowsC, colsC)
    all(==(2), extents) && return matmul2x2!(C, tA, tB, A, B, _add)
    all(==(3), extents) && return matmul3x3!(C, tA, tB, A, B, _add)

    # The flags 'H', 'h', 'S', 's' are folded into the operand via `wrap`, after which
    # the operand is treated as 'N' by the kernel.
    if tA in ('H', 'h', 'S', 's')
        A, tA = wrap(A, tA), 'N'
    end
    if tB in ('H', 'h', 'S', 's')
        B, tB = wrap(B, tB), 'N'
    end
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
# Generic fallback: fold the flags `tA` / `tB` into the operands with `wrap` and forward
# to the flag-free kernel `_generic_matmatmul!`.
Base.@constprop :aggressive function generic_matmatmul!(C::AbstractVecOrMat, tA, tB,
        A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul) where {T,S,R}
@assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
require_one_based_indexing(C, A, B)

mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
end
if size(C,1) != mA || size(C,2) != nB
throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
end

if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
end

tile_size = 0
if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
end
@inbounds begin
if tile_size > 0
sz = (tile_size, tile_size)
Atile = Array{T}(undef, sz)
Btile = Array{S}(undef, sz)

z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
z = convert(promote_type(typeof(z1), R), z1)

if mA < tile_size && nA < tile_size && nB < tile_size
copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
for j = 1:nB
boff = (j-1)*tile_size
for i = 1:mA
aoff = (i-1)*tile_size
s = z
for k = 1:nA
s += Atile[aoff+k] * Btile[boff+k]
end
_modify!(_add, s, C, (i,j))
end
end
else
Ctile = Array{R}(undef, sz)
for jb = 1:tile_size:nB
jlim = min(jb+tile_size-1,nB)
jlen = jlim-jb+1
for ib = 1:tile_size:mA
ilim = min(ib+tile_size-1,mA)
ilen = ilim-ib+1
fill!(Ctile, z)
for kb = 1:tile_size:nA
klim = min(kb+tile_size-1,mB)
klen = klim-kb+1
copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
for j=1:jlen
bcoff = (j-1)*tile_size
for i = 1:ilen
aoff = (i-1)*tile_size
s = z
for k = 1:klen
s += Atile[aoff+k] * Btile[bcoff+k]
end
Ctile[bcoff+i] += s
end
end
end
if isone(_add.alpha) && iszero(_add.beta)
copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
else
C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
end
end
# Axes of the operands, named after the product dimensions they must share:
# A supplies (M, K), B supplies (K, N) and C supplies (M, N).
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
BxN = axes(B, 2)
CxM = axes(C, 1)
CxN = axes(C, 2)
# Each check compares one pair of axes that must coincide and throws a
# `DimensionMismatch` naming the axes of the two operands involved.
if AxM != CxM
    throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
end
if AxK != BxK
    # Report B's own second axis: `BxN` has not been compared with `CxN` yet, so
    # printing `CxN` here could describe B with an axis it does not have.
    throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$BxN)"))
end
if BxN != CxN
    throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
Balpha = B[k,n]*_add.alpha
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
end
elseif isbitstype(R) && sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
t = wrapperop(A)
pB = parent(B)
pA = parent(A)
tmp = similar(C, CxN)
ci = first(CxM)
ta = t(_add.alpha)
for i in AxM
mul!(tmp, pB, view(pA, :, i))
C[ci,:] .+= t.(ta .* tmp)
ci += 1
end
else
# Multiplication for non-plain-data uses the naive algorithm
if tA == 'N'
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k]*B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k] * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k]*B[j, k]'
end
_modify!(_add, Ctmp, C, (i,j))
end
end
elseif tA == 'T'
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * adjoint(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
end
else
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[k, i]'B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += adjoint(A[k, i]) * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[k, i]'B[j, k]'
end
_modify!(_add, Ctmp, C, (i,j))
end
if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
end
a1 = first(AxK)
b1 = first(BxK)
@inbounds for i in AxM, j in BxN
z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
@simd for k in AxK
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
end
_modify!(_add, Ctmp, C, (i,j))
end
end
end # @inbounds
C
return C
end


Expand All @@ -963,7 +831,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (2,2))
Expand Down Expand Up @@ -1030,7 +898,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (3,3))
Expand Down
Loading