From e389956b821c3423e32464748c1224fa1baa1259 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sat, 20 Mar 2021 00:11:09 -0400 Subject: [PATCH 1/2] fix zero shaped matmul --- src/matmul.jl | 15 +++++++++++++++ test/_matmul.jl | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/src/matmul.jl b/src/matmul.jl index 73957c8..3bb69af 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -151,6 +151,21 @@ Otherwise, based on the array's size, whether they are transposed, and whether t C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN ) where {T} M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + if β isa StaticInt{0} # this is actually type stable + @avx for i=1:length(C) + C[i] = zero(T) + end + return + else + @avx for i=1:length(C) + C[i] = β * C[i] + end + return + end + end pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); Mc, Kc, Nc = block_sizes(T); mᵣ, nᵣ = matmul_params(); diff --git a/test/_matmul.jl b/test/_matmul.jl index 6df60d3..fa8087d 100644 --- a/test/_matmul.jl +++ b/test/_matmul.jl @@ -140,3 +140,9 @@ end @test matmul_pack_ab!(similar(AB), A′, B′) ≈ AB end +@time @testset "zero-sized-matrices" begin + @test Octavian.matmul_serial(randn(0,0), randn(0,0)) == zeros(0, 0) + @test Octavian.matmul_serial(randn(2,3), randn(3,0)) == zeros(2, 0) + @test Octavian.matmul_serial(randn(2,0), randn(0,2)) == zeros(2, 2) + @test Octavian.matmul_serial!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2 +end From c5b05f8c98a0832cadbd2d884217802a07021d60 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sat, 20 Mar 2021 00:30:01 -0400 Subject: [PATCH 2/2] fix matmul too --- src/matmul.jl | 31 ++++++++++++++++++++----------- test/_matmul.jl | 4 ++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 3bb69af..beef0ff 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -154,17 +154,8 @@ Otherwise, based on the array's size, whether they are transposed, and whether t if M * N == 0 return elseif K == 0 - if β isa StaticInt{0} # this is actually type stable - @avx for i=1:length(C) - C[i] = zero(T) - end - return - else - @avx for i=1:length(C) - C[i] = β * C[i] - end - return - end + matmul_only_β!(C, β) + return end pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); @@ -183,6 +174,18 @@ Otherwise, based on the array's size, whether they are transposed, and whether t end end # function +function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T + @avx for i=1:length(C) + C[i] = zero(T) + end +end + +function matmul_only_β!(C::AbstractMatrix{T}, β) where T + @avx for i=1:length(C) + C[i] = β * C[i] + end +end + function matmul_st_pack_dispatcher!(pC::AbstractStridedPointer{T}, pA, pB, α, β, M, K, N) where {T} Mc, Kc, Nc = block_sizes(T) if (contiguousstride1(pB) ? (Kc * Nc ≥ K * N) : (firstbytestride(pB) ≤ 1600)) @@ -243,6 +246,12 @@ end # passing MKN directly would let osmeone skip the size check. @inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}#::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T} M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + matmul_only_β!(C, β) + return + end W = pick_vector_width(T) pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); diff --git a/test/_matmul.jl b/test/_matmul.jl index fa8087d..b186610 100644 --- a/test/_matmul.jl +++ b/test/_matmul.jl @@ -145,4 +145,8 @@ end @test Octavian.matmul_serial(randn(2,3), randn(3,0)) == zeros(2, 0) @test Octavian.matmul_serial(randn(2,0), randn(0,2)) == zeros(2, 2) @test Octavian.matmul_serial!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2 + @test Octavian.matmul(randn(0,0), randn(0,0)) == zeros(0, 0) + @test Octavian.matmul(randn(2,3), randn(3,0)) == zeros(2, 0) + @test Octavian.matmul(randn(2,0), randn(0,2)) == zeros(2, 2) + @test Octavian.matmul!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2 end