From 152200735dd92d51167b93bef5b7d5f443c1da76 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 24 Nov 2025 14:06:59 +0100 Subject: [PATCH 1/6] Make matmul work with zero-less eltypes --- src/matmul.jl | 69 ++++++++++++++++++++++++++++++++++++++------------ test/matmul.jl | 25 ++++++++++++++++++ 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index d618bcfe..33372a67 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -599,11 +599,21 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end - _rmul_or_fill!(C, β) + if (!iszero(β) || isempty(A)) # return C*beta + _rmul_or_fill!(C, β) + else # iszero(β) && A and B are non-empty + a1 = firstindex(A, 1) + a2 = firstindex(A, 2) + for j in axes(C, 2), i in axes(C, 1) + z1 = zero(A[i, a2]*A[a1, j] + A[i, a2]*A[a1, j]) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end + end + iszero(α) && return C @inbounds if !conjugate if aat for k ∈ 1:n, j ∈ 1:m - αA_jk = A[j, k] * α + αA_jk = @stable_muladdmul MulAddMul(α, false)(A[j, k]) for i ∈ 1:j C[i, j] += A[i, k] * αA_jk end @@ -614,17 +624,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo for k ∈ 2:m temp += A[k, i] * A[k, j] end - C[i, j] += temp * α + C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp) end end else if aat for k ∈ 1:n, j ∈ 1:m - αA_jk_bar = conj(A[j, k]) * α + αA_jk_bar = @stable_muladdmul MulAddMul(α, false)(conj(A[j, k])) for i ∈ 1:j-1 C[i, j] += A[i, k] * αA_jk_bar end - C[j, j] += abs2(A[j, k]) * α + C[j, j] += @stable_muladdmul MulAddMul(α, false)(abs2(A[j, k])) end else for j ∈ 1:n @@ -633,13 +643,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo for k ∈ 2:m temp += conj(A[k, i]) * A[k, j] end - C[i, j] += temp * α + C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp) end temp = abs2(A[1, j]) for k ∈ 2:m temp += abs2(A[k, j]) end - C[j, j] += temp * α + C[j, j] += @stable_muladdmul MulAddMul(α, false)(temp) end end end @@ -1132,8 +1142,18 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non __generic_matmatmul!(C, A, B, alpha, beta, ::Val{false}) = _generic_matmatmul_generic!(C, A, B, alpha, beta) function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) - _rmul_or_fill!(C, beta) - (iszero(alpha) || isempty(A) || isempty(B)) && return C + # _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes + if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta + _rmul_or_fill!(C, beta) + else # iszero(beta) && A and B are non-empty + a1 = firstindex(A, 2) + b1 = firstindex(B, 1) + for j in axes(C, 2), i in axes(C, 1) + z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end + end + iszero(alpha) && return C @inbounds for n in axes(B, 2), k in axes(B, 1) # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) @@ -1145,20 +1165,37 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) C end function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) - _rmul_or_fill!(C, beta) - (iszero(alpha) || isempty(A) || isempty(B)) && return C + if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta + _rmul_or_fill!(C, beta) + else # iszero(beta) && A and B are non-empty + a1 = firstindex(A, 2) + b1 = firstindex(B, 1) + for j in axes(C, 2), i in axes(C, 1) + z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end + end + iszero(alpha) && return C t = _wrapperop(A) pB = parent(B) pA = parent(A) tmp = similar(C, axes(C, 2)) ci = firstindex(C, 1) ta = t(alpha) - for i in axes(A, 1) - mul!(tmp, pB, view(pA, :, i)) - @views C[ci,:] .+= t.(ta .* tmp) - ci += 1 + if isone(ta) + for i in axes(A, 1) + mul!(tmp, pB, view(pA, :, i)) + @views C[ci,:] .+= t.(tmp) + ci += 1 + end + else + for i in axes(A, 1) + mul!(tmp, pB, view(pA, :, i)) + @views C[ci,:] .+= t.(ta .* tmp) + ci += 1 + end end - C + return C end function _generic_matmatmul_generic!(C, A, B, alpha, beta) if iszero(alpha) || isempty(A) || isempty(B) diff --git a/test/matmul.jl b/test/matmul.jl index 1fcf2009..067f49cb 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -1241,4 +1241,29 @@ end @test C1 ≈ C2 end +@testset "matmul with zero-less types" begin + struct Mod <: Real + val::Int + modulo::Int + Mod(x::Int, y::Int) = new(x % y, y) + end + + Base.:+(x::Mod, y::Mod) = Mod(x.val + y.val, x.modulo) + Base.:*(x::Mod, y::Mod) = Mod(x.val * y.val, x.modulo) + Base.zero(x::Mod) = Mod(0, x.modulo) + + m = Mod.(rand(0:19, 5, 0), 20) + @test_throws MethodError m * copy(m') + for n in (2, 3, 5) + A = rand(0:19, n, n) + M = Mod.(A, 20) + @test M * M == Mod.(A * A, 20) + @test M' * M == Mod.(A' * A, 20) + @test M * M' == Mod.(A * A', 20) + @test M' * M' == Mod.(A' * A', 20) + @test M * M[:, 1] == Mod.(A * A[:, 1], 20) + @test M' * M[:, 1] == Mod.(A' * A[:, 1], 20) + end +end + end # module TestMatmul From c032e87424ef3cd037618ecf2f711445d79a64ab Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 24 Nov 2025 15:09:45 +0100 Subject: [PATCH 2/6] manually hoist out operations --- src/matmul.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 33372a67..b312d19f 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -604,9 +604,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo else # iszero(β) && A and B are non-empty a1 = firstindex(A, 1) a2 = firstindex(A, 2) - for j in axes(C, 2), i in axes(C, 1) - z1 = zero(A[i, a2]*A[a1, j] + A[i, a2]*A[a1, j]) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + for j in axes(C, 2) + A_1j = A[a1, j] + for i in axes(C, 1) + A_ij = A[i, a2]*A_1j + z1 = zero(A_ij + A_ij) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end end end iszero(α) && return C From 62036e8f41408eab6820614354921699bd9b9283 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 24 Nov 2025 16:33:18 +0100 Subject: [PATCH 3/6] hoist out more, initialize more carefully in syrk --- src/matmul.jl | 57 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index b312d19f..6a6cc850 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -601,15 +601,24 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo if (!iszero(β) || isempty(A)) # return C*beta _rmul_or_fill!(C, β) - else # iszero(β) && A and B are non-empty - a1 = firstindex(A, 1) - a2 = firstindex(A, 2) - for j in axes(C, 2) - A_1j = A[a1, j] - for i in axes(C, 1) - A_ij = A[i, a2]*A_1j - z1 = zero(A_ij + A_ij) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + else # iszero(β) && A is non-empty + if aat + for j ∈ 1:m + A_1j = A[j,1]' + for i ∈ 1:j + A_ij = A[i,1]*A_1j + z1 = zero(A_ij + A_ij) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end + end + else # !aat + for j ∈ 1:n + A_1j = A[1,j] + for i ∈ 1:j + A_ij = A[1,i]'A_1j + z1 = zero(A_ij + A_ij) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end end end end @@ -1152,9 +1161,13 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) else # iszero(beta) && A and B are non-empty a1 = firstindex(A, 2) b1 = firstindex(B, 1) - for j in axes(C, 2), i in axes(C, 1) - z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + for j in axes(C, 2) + B_1j = B[b1, j] + for i in axes(C, 1) + C_ij = A[i, a1] * B_1j + z1 = zero(C_ij + C_ij) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end end end iszero(alpha) && return C @@ -1169,20 +1182,24 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) C end function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) + t = _wrapperop(A) + pB = parent(B) + pA = parent(A) if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta _rmul_or_fill!(C, beta) else # iszero(beta) && A and B are non-empty - a1 = firstindex(A, 2) - b1 = firstindex(B, 1) - for j in axes(C, 2), i in axes(C, 1) - z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + a1 = firstindex(pA, 1) + b1 = firstindex(pB, 2) + for j in axes(C, 2) + tB_1j = t(pB[j, b1]) + for i in axes(C, 1) + C_ij = t(pA[a1, i]) * tB_1j + z1 = zero(C_ij + C_ij) + C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + end end end iszero(alpha) && return C - t = _wrapperop(A) - pB = parent(B) - pA = parent(A) tmp = similar(C, axes(C, 2)) ci = firstindex(C, 1) ta = t(alpha) From 497c112e01c22833aae185229fe68f22b58ce4fe Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 24 Nov 2025 22:43:00 +0100 Subject: [PATCH 4/6] simplify initialization in generic syrk --- src/matmul.jl | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 6a6cc850..36351bf4 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -602,24 +602,10 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo if (!iszero(β) || isempty(A)) # return C*beta _rmul_or_fill!(C, β) else # iszero(β) && A is non-empty - if aat - for j ∈ 1:m - A_1j = A[j,1]' - for i ∈ 1:j - A_ij = A[i,1]*A_1j - z1 = zero(A_ij + A_ij) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) - end - end - else # !aat - for j ∈ 1:n - A_1j = A[1,j] - for i ∈ 1:j - A_ij = A[1,i]'A_1j - z1 = zero(A_ij + A_ij) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) - end - end + aA_11 = abs2(A[1,1]) + C_ij = zero(aA_11 + aA_11) + for j ∈ 1:m, i ∈ 1:j + C[i,j] = C_ij end end iszero(α) && return C From 9c29751b2a3498afed2d2eb71309aec6de95883e Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 25 Nov 2025 09:52:35 +0100 Subject: [PATCH 5/6] simplify further, use fill! --- src/matmul.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 36351bf4..49dcb8f8 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -603,10 +603,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo _rmul_or_fill!(C, β) else # iszero(β) && A is non-empty aA_11 = abs2(A[1,1]) - C_ij = zero(aA_11 + aA_11) - for j ∈ 1:m, i ∈ 1:j - C[i,j] = C_ij - end + fill!(UpperTriangular(C), zero(aA_11 + aA_11)) end iszero(α) && return C @inbounds if !conjugate From 384b35d025e939808c050be77ab9fb58a80d5f82 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 25 Nov 2025 16:32:21 +0100 Subject: [PATCH 6/6] rely on implicit type promotion in setindex! --- src/matmul.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 49dcb8f8..311ddfcf 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -1148,8 +1148,7 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) B_1j = B[b1, j] for i in axes(C, 1) C_ij = A[i, a1] * B_1j - z1 = zero(C_ij + C_ij) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + C[i,j] = zero(C_ij + C_ij) end end end @@ -1177,8 +1176,7 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) tB_1j = t(pB[j, b1]) for i in axes(C, 1) C_ij = t(pA[a1, i]) * tB_1j - z1 = zero(C_ij + C_ij) - C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1) + C[i,j] = zero(C_ij + C_ij) end end end