Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 56 additions & 52 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ end
return C
end
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt)
_matmul_serial!(C, A, B, α, β, MKN)
return C
_matmul_serial!(C, A, B, α, β, MKN)
return C
end

"""
Expand All @@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul
Otherwise, based on the array's size, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`.
"""
@inline function _matmul_serial!(
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
loopmul!(pC, pA, pB, α, β, M, K, N)
return
else
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
return
end
((β ≢ Zero()) && iszero(β)) && return _matmul_serial!(C, A, B, α, Zero(), MKN)
(β isa Bool) && return _matmul_serial!(C, A, B, α, One(), MKN)
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
loopmul!(pC, pA, pB, α, β, M, K, N)
return
else
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
return
end
end
end # function

function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T
Expand Down Expand Up @@ -266,35 +268,37 @@ end

# passing MKN directly would let osmeone skip the size check.
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
W = pick_vector_width(T)
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
mᵣ, nᵣ = matmul_params(Val(T))
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
else
(nᵣ ≥ N) && @goto LOOPMUL
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
else
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
end
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
return
@label LOOPMUL
loopmul!(pC, pA, pB, α, β, M, K, N)
return
end
((β ≢ Zero()) && iszero(β)) && return _matmul!(C, A, B, α, Zero(), nthread, MKN)
(β isa Bool) && return _matmul!(C, A, B, α, One(), nthread, MKN)
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
elseif K == 0
matmul_only_β!(C, β)
return
end
W = pick_vector_width(T)
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
mᵣ, nᵣ = matmul_params(Val(T))
GC.@preserve Cb Ab Bb begin
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
return
else
(nᵣ ≥ N) && @goto LOOPMUL
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
else
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
end
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
return
@label LOOPMUL
loopmul!(pC, pA, pB, α, β, M, K, N)
return
end
end
end

# This funciton is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N.
Expand Down
4 changes: 4 additions & 0 deletions test/matmul_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ for T ∈ (Float64,Float32,Int64,Int32)
@time test_real(T, m_values, k_values, n_values, testset_name_suffix)
end

A = rand(2,2); B = rand(2,2); AB = A*B; C = fill(NaN, 2, 2);
@test Octavian.matmul!(C, A, B, true, false) ≈ AB
@test Octavian.matmul!(C, A, B, true, true) ≈ 2AB