diff --git a/src/matmul.jl b/src/matmul.jl index 234a3a8..62eb47b 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -146,8 +146,8 @@ end return C end @inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt) - _matmul_serial!(C, A, B, α, β, MKN) - return C + _matmul_serial!(C, A, B, α, β, MKN) + return C end """ @@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul Otherwise, based on the array's size, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`. """ @inline function _matmul_serial!( - C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN + C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN ) where {T} - M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN - if M * N == 0 - return - elseif K == 0 - matmul_only_β!(C, β) - return - end - pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); - Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); - Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T)); - GC.@preserve Cb Ab Bb begin - if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable - inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N) - return - elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T) - loopmul!(pC, pA, pB, α, β, M, K, N) - return - else - matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N) - return - end + ((β ≢ Zero()) && iszero(β)) && return _matmul_serial!(C, A, B, α, Zero(), MKN) + (β isa Bool) && return _matmul_serial!(C, A, B, α, One(), MKN) + M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + matmul_only_β!(C, β) + return + end + pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); + Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); + Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T)); + GC.@preserve Cb Ab Bb begin + if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable + inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N) + return + elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T) + loopmul!(pC, pA, pB, α, β, M, K, N) + return + else + matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N) + return end + end end # function function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T @@ -266,35 +268,37 @@ end # passing MKN directly would let osmeone skip the size check. @inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T} - M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN - if M * N == 0 - return - elseif K == 0 - matmul_only_β!(C, β) - return - end - W = pick_vector_width(T) - pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); - Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); - mᵣ, nᵣ = matmul_params(Val(T)) - GC.@preserve Cb Ab Bb begin - if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable - inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N) - return - else - (nᵣ ≥ N) && @goto LOOPMUL - if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686) - (M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL - else - (M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL - end - __matmul!(pC, pA, pB, α, β, M, K, N, nthread) - return - @label LOOPMUL - loopmul!(pC, pA, pB, α, β, M, K, N) - return - end + ((β ≢ Zero()) && iszero(β)) && return _matmul!(C, A, B, α, Zero(), nthread, MKN) + (β isa Bool) && return _matmul!(C, A, B, α, One(), nthread, MKN) + M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN + if M * N == 0 + return + elseif K == 0 + matmul_only_β!(C, β) + return + end + W = pick_vector_width(T) + pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C); + Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B); + mᵣ, nᵣ = matmul_params(Val(T)) + GC.@preserve Cb Ab Bb begin + if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable + inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N) + return + else + (nᵣ ≥ N) && @goto LOOPMUL + if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686) + (M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL + else + (M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL + end + __matmul!(pC, pA, pB, α, β, M, K, N, nthread) + return + @label LOOPMUL + loopmul!(pC, pA, pB, α, β, M, K, N) + return end + end end # This funciton is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N. diff --git a/test/matmul_main.jl b/test/matmul_main.jl index 6190aa2..f5152b3 100644 --- a/test/matmul_main.jl +++ b/test/matmul_main.jl @@ -9,3 +9,7 @@ for T ∈ (Float64,Float32,Int64,Int32) @time test_real(T, m_values, k_values, n_values, testset_name_suffix) end +A = rand(2,2); B = rand(2,2); AB = A*B; C = fill(NaN, 2, 2); +@test Octavian.matmul!(C, A, B, true, false) ≈ AB +@test Octavian.matmul!(C, A, B, true, true) ≈ 2AB +