In this notebook, I experiment with my numba-cuda version of matmul_2. It runs in ~1.2 s, but should run in ~200 ms.

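**Hypothesis:** I made simple code changes (compared to the CUDA C version) which should not change the runtime, but do.<br/>
**Result:** At least for the 4 code changes identified, that's not the case; i.e. the runtime is still slow. The one change that does help is keeping the accumulator `tmp` in fp32 (variant7, ~0.57 s), but that is still above the ~200 ms target.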

In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']='1' # disabling this doesn't change cuda runtime

import numpy as np
from numba import cuda
from util import cdiv, measure_runtime

In [2]:
@cuda.jit()
def original(a,b,c,m,n,k,bs):
    """Baseline kernel: each in-range thread writes one element of c = a @ b.

    a, b and c are indexed as 2D arrays (a[row, p], b[p, col], c[row, col]).
    m and n bound the output row/column indices; k is the reduction length.
    The 1D thread index is unpacked into a (bs, bs) tile, so the launch is
    expected to use bs*bs threads per block.
    """
    # Within a block, col changes with every thread and row every bs threads.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    # Starts from the int literal 0; later variants in this notebook change this.
    acc = 0
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [3]:
# Time `original` with 32*32 threads per block; the grid is derived from the
# output shape in (32, 32) tiles. kernel_args=[32] is presumably forwarded as
# the kernel's trailing `bs` argument (measure_runtime lives in util).
measure_runtime(
    original,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel original for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:43:59 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:03 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:03 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.256s


In [4]:
# fixed bs
@cuda.jit()
def variant2(a,b,c,m,n,k):
    """Same computation as `original`, but the tile size is a literal 32.

    Each in-range thread writes c[row, col] as the sum over p < k of
    a[row, p] * b[p, col]. The block size is no longer a kernel argument,
    so the launch must use 32*32 threads per block.
    """
    bs = 32
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc = 0
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [5]:
# Time `variant2` with 32*32 threads per block; no kernel_args are passed
# because the kernel hard-codes its tile size.
measure_runtime(
    variant2,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
);

Measuring runtime of kernel variant2 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:06 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:10 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:10 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.261s


In [6]:
# no premature return
@cuda.jit()
def variant3(a,b,c,m,n,k,bs):
    """Same computation as `original`, without the early `return`.

    The bounds check wraps the work instead of exiting the thread: only
    threads with row < m and col < n accumulate and write c[row, col].
    The launch is expected to use bs*bs threads per block.
    """
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Deliberately a positive guard around the body rather than a guard clause.
    if row < m and col < n:
        acc = 0
        for p in range(k):
            acc += a[row, p] * b[p, col]
        c[row, col] = acc

In [7]:
# Time `variant3` with 32*32 threads per block; kernel_args=[32] is presumably
# forwarded as the kernel's trailing `bs` argument.
measure_runtime(
    variant3,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel variant3 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:14 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:17 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:17 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.263s


In [8]:
# init tmp as float
@cuda.jit()
def variant4(a,b,c,m,n,k,bs):
    """Same computation as `original`, with the accumulator started at 0.0.

    Each in-range thread writes c[row, col] as the sum over p < k of
    a[row, p] * b[p, col]. The only difference from the baseline is the
    Python float literal used to initialise the accumulator.
    """
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc = 0.0  # float literal instead of the baseline's int literal
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [9]:
# Time `variant4` with 32*32 threads per block; kernel_args=[32] is presumably
# forwarded as the kernel's trailing `bs` argument.
measure_runtime(
    variant4,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel variant4 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:21 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:24 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:24 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.265s


In [10]:
# with signature explicitly provided
from numba import float32, int32
# Explicit kernel signature, in parameter order:
#   a, b, c  -> three 2D float32 arrays
#   m, n, k  -> int32 sizes
#   bs       -> int32 tile size
sig = (float32[:, :],) * 3 + (int32,) * 4

@cuda.jit(sig)
def variant5(a,b,c,m,n,k,bs):
    """Same computation as `original`, compiled with the explicit signature `sig`.

    Each in-range thread writes c[row, col] as the sum over p < k of
    a[row, p] * b[p, col]. The body matches the baseline, including the
    int-literal accumulator; only the decorator differs.
    """
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc = 0
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [11]:
# Time `variant5` with 32*32 threads per block; kernel_args=[32] is presumably
# forwarded as the kernel's trailing `bs` argument.
measure_runtime(
    variant5,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel variant5 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:28 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:31 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:31 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.239s


In [12]:
# with types
# as seen in https://twitter.com/haseox94/status/1752130508182708417
@cuda.jit()
def variant6(a,b,c,m:int,n:int,k:int,bs:int):
    """Same computation as `original`, with Python type annotations added.

    The scalar parameters are annotated as `int` and the accumulator as
    `float` (initialised to 0.0). Each in-range thread writes c[row, col]
    as the sum over p < k of a[row, p] * b[p, col].
    """
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc: float = 0.0
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [13]:
# Time `variant6` with 32*32 threads per block; kernel_args=[32] is presumably
# forwarded as the kernel's trailing `bs` argument.
measure_runtime(
    variant6,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel variant6 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:35 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:39 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:39 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


1.268s


✅ From `check_types.ipynb` I know that `tmp` is cast to fp64, and explicitly defining it as fp32 prevents this.<br/>
Let's measure the kernel without fp64-casting.

In [14]:
@cuda.jit()
def variant7(a,b,c,m,n,k,bs):
    """Same computation as `original`, with an explicit fp32 accumulator.

    Each in-range thread writes c[row, col] as the sum over p < k of
    a[row, p] * b[p, col]. The accumulator is created with np.float32(0)
    so that it is not promoted to fp64 (see the note preceding this cell).
    """
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc = np.float32(0)  # explicit fp32 start value
    for p in range(k):
        acc += a[row, p] * b[p, col]
    c[row, col] = acc

In [15]:
# Time `variant7` with 32*32 threads per block; kernel_args=[32] is presumably
# forwarded as the kernel's trailing `bs` argument.
measure_runtime(
    variant7,
    nthreads=32 * 32,
    nblocks_fn=lambda outp_shape, nthreads: cdiv(outp_shape, (32, 32)),
    kernel_args=[32],
);

Measuring runtime of kernel variant7 for m,n,k = 4092,4092,4092, averaging over 3 runs.


STAGE:2024-05-05 14:44:41 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:44:42 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:44:42 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


0.566s


Avoiding the fp64 cast roughly halves the runtime (~1.26 s → ~0.57 s), but it should still be ~200 ms.

In [31]:
# without 2d indexing
@cuda.jit()
def variant8(a,b,c,m,n,k):
    """fp32-accumulator kernel operating on flat (1D) buffers.

    a, b and c are indexed with a single offset each: a[row*k + p],
    b[p*n + col] and c[row*n + col]. Each in-range thread writes one
    output element. The tile size is a literal 32, so the launch must use
    32*32 threads per block.
    """
    bs = 32
    # Unpack the 1D thread index into a position inside the bs x bs tile.
    tile_row = cuda.threadIdx.x // bs
    tile_col = cuda.threadIdx.x % bs
    row = cuda.blockIdx.x * bs + tile_row
    col = cuda.blockIdx.y * bs + tile_col
    # Threads that land outside the (m, n) output have nothing to compute.
    if row >= m or col >= n:
        return
    acc = np.float32(0)  # explicit fp32 start value
    for p in range(k):
        # Offsets are computed by hand instead of using 2D indexing.
        acc += a[row*k + p] * b[p*n + col]
    c[row*n + col] = acc

In [32]:
from torch.profiler import profile, record_function, ProfilerActivity
from torch.profiler import schedule as profiler_schedule

from util import to_d, to_h
dtype = 'float32'

m,n,k=4092,4092,4092
# variant8 hard-codes a 32x32 tile, so launch 32*32 threads per block.
nthreads=(32*32,)
# Derive the grid from m,n rather than repeating the literal sizes, so that
# changing the problem size above cannot leave a stale launch configuration.
nblocks = cdiv((m,n), (32,32))

# Profiler schedule: steps to skip, steps to warm up, steps to record.
wait,warmup,runs=1,1,3

# variant8 indexes flat buffers, so the inputs are allocated as 1D arrays of
# m*k and k*n ones.
a = to_d(np.ones((m*k,), dtype=dtype))
b = to_d(np.ones((k*n,), dtype=dtype))

with profile(activities=[ProfilerActivity.CUDA], schedule=profiler_schedule(wait=wait, warmup=warmup, active=runs)) as p:
    # One profiler step per launch; the loop length must match the schedule.
    for _ in range(wait+warmup+runs):
        # Fresh (uninitialised) output buffer for every launch.
        c = to_d(np.empty((m*n,), dtype=dtype))
        variant8[nblocks, nthreads](a,b,c,m,n,k)
        p.step()

STAGE:2024-05-05 14:51:02 130288:130288 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-05 14:51:04 130288:130288 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-05 14:51:04 130288:130288 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [33]:
from torch import allclose, tensor

# Sanity check: reshape the flat buffers back to 2D and compare the kernel's
# output `c` against torch's matmul of the same inputs.
allclose(
    tensor(a).reshape(m, k) @ tensor(b).reshape(k, n),
    tensor(c).reshape(m, n),
)

True

In [37]:
from util import cuda_mean_runtime
# Look up the kernel in the profiler results `p` by its function name and
# print its runtime (presumably the mean over the recorded steps; the helper
# lives in util). The trailing `;` keeps the notebook from echoing the return value.
cuda_mean_runtime(p, variant8.__name__, do_print=True);

0.587s
