# Patch summary (free-form preamble; everything from the `diff --git` header onward is unchanged).
#
# src/matmul.jl
#   * Two functions gain the same guard immediately after `M, K, N` are computed:
#     the one taking `(C, A, B, α, β, MKN)` whose signature starts above the first hunk,
#     and `_matmul!`.
#       - `M * N == 0`: bare `return` before any pointer or buffer setup.
#       - `K == 0`: call `matmul_only_β!(C, β)`, then bare `return`.
#   * Two new methods `matmul_only_β!(C, β)`, each an `@avx` loop over `i = 1:length(C)`:
#       - `β::StaticInt{0}`: stores `zero(T)` into every `C[i]` (C is not read).
#       - any other `β`: stores `β * C[i]` back into `C[i]`.
#
# test/_matmul.jl
#   * New testset "zero-sized-matrices" calling `matmul_serial`, `matmul_serial!`, `matmul`
#     and `matmul!` with 0x0, (2x3)*(3x0) and (2x0)*(0x2) operands; the in-place calls pass
#     `1.0, 2.0` with `ones(2,2)` as destination and expect `ones(2, 2) .* 2`.
#
# NOTE(review): the line breaks of this patch appear to have been collapsed (the whole diff
#   sits on two physical lines); it presumably will not apply in this form - regenerate it
#   from the commit before use.
# NOTE(review): both new loops index `C` linearly via `1:length(C)`; confirm this is valid
#   for every `AbstractMatrix` that can reach these paths (e.g. non-contiguous views or
#   offset axes) - TODO confirm against the `@avx` documentation.
# NOTE(review): the generic method multiplies even when a runtime `β` equals zero, so
#   NaN/Inf already in `C` would presumably survive; verify this is the intended contract.
# NOTE(review): "osmeone" in the third hunk is a typo in a context line; it has to match
#   the target file, so correct it in a separate change rather than in this patch.
diff --git a/src/matmul.jl b/src/matmul.jl index 73957c8..beef0ff 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -151,6 +151,12 @@ Otherwise, based on the array's size, whether they are transposed, and whether t C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN ) where {T} M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + matmul_only_β!(C, β) + return + end pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); Mc, Kc, Nc = block_sizes(T); mᵣ, nᵣ = matmul_params(); @@ -168,6 +174,18 @@ Otherwise, based on the array's size, whether they are transposed, and whether t end end # function +function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T + @avx for i=1:length(C) + C[i] = zero(T) + end +end + +function matmul_only_β!(C::AbstractMatrix{T}, β) where T + @avx for i=1:length(C) + C[i] = β * C[i] + end +end + function matmul_st_pack_dispatcher!(pC::AbstractStridedPointer{T}, pA, pB, α, β, M, K, N) where {T} Mc, Kc, Nc = block_sizes(T) if (contiguousstride1(pB) ? (Kc * Nc ≥ K * N) : (firstbytestride(pB) ≤ 1600)) @@ -228,6 +246,12 @@ end # passing MKN directly would let osmeone skip the size check. @inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}#::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T} M, K, N = MKN === nothing ? 
matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + matmul_only_β!(C, β) + return + end W = pick_vector_width(T) pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); diff --git a/test/_matmul.jl b/test/_matmul.jl index 6df60d3..b186610 100644 --- a/test/_matmul.jl +++ b/test/_matmul.jl @@ -140,3 +140,13 @@ end @test matmul_pack_ab!(similar(AB), A′, B′) ≈ AB end +@time @testset "zero-sized-matrices" begin + @test Octavian.matmul_serial(randn(0,0), randn(0,0)) == zeros(0, 0) + @test Octavian.matmul_serial(randn(2,3), randn(3,0)) == zeros(2, 0) + @test Octavian.matmul_serial(randn(2,0), randn(0,2)) == zeros(2, 2) + @test Octavian.matmul_serial!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2 + @test Octavian.matmul(randn(0,0), randn(0,0)) == zeros(0, 0) + @test Octavian.matmul(randn(2,3), randn(3,0)) == zeros(2, 0) + @test Octavian.matmul(randn(2,0), randn(0,2)) == zeros(2, 2) + @test Octavian.matmul!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2 +end