## Setup

In [None]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [None]:
# Lightweight Python counterpart of the `dim3` triple used in the CUDA code below.
# `defaults=(1,1)` applies to the rightmost fields, so `y` and `z` default to 1
# while `x` must always be given.
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))

In [None]:
d = dim3(2,3)
d

dim3(x=2, y=3, z=1)

In [None]:
d.x,d.y

(2, 3)

In [None]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [None]:
sys.path.insert(0, '..')

In [None]:
from utils import show_img,load_cuda,cuda_begin,cdiv

In [None]:
%load_ext wurlitzer

In [None]:
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
torch.manual_seed(42);

In [None]:
# Random operands shared by the checks below: m1 is 5120x256 and m2 is 256x5120.
# m1s (first 4 rows of m1) and m2s (first 4 columns of m2) are the small slices
# fed to the pure-Python kernel simulations.
m1 = torch.rand(5120, 256)
m1s = m1[:4]
m2 = torch.rand(256,5120)
m2s = m2[:,:4]

## Reminder

### 2d Python kernel

In [None]:
def blk_kernel2d(f, blocks, threads, *args):
    """Run kernel `f` sequentially for every (block, thread) pair of a 2D launch.

    `blocks` and `threads` are `dim3`-like objects giving the grid size and the
    block size. `f` is called as `f(blockIdx, threadIdx, blockDim, *args)`, with
    `y` iterated in the outer loops and `x` in the inner ones.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_idx = dim3(bx, by)
            for ty in range(threads.y):
                for tx in range(threads.x):
                    thread_idx = dim3(tx, ty)
                    f(block_idx, thread_idx, threads, *args)

In [None]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    """Per-thread body of the naive matmul: compute one element of `out`.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flat, row-major sequences.
    The thread's global row/column come from the block and thread indices;
    threads that land outside the h x w output do nothing.
    """
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x

    # Only threads that map onto a real output element do any work.
    if row < h and col < w:
        acc = 0.
        row_start = row*k
        for i in range(k): acc += m[row_start+i] * n[i*w+col]
        out[row*w+col] = acc

In [None]:
def matmul_2d(m, n):
    """Multiply 2D tensors `m` (h x k) and `n` (k x w) with the simulated naive kernel.

    Launches `matmul_bk` through `blk_kernel2d` using 16x16 thread blocks and
    enough blocks to cover the h x w result. Raises AssertionError when the
    inner dimensions differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(16,16)
    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # The kernel writes through the flattened handle; this relies on flatten()
    # of the freshly allocated contiguous tensor sharing storage with `output`.
    flat_args = (m.flatten(), n.flatten(), output.flatten())
    blk_kernel2d(matmul_bk, grid, tpb, *flat_args, h, w, k)
    return output

In [None]:
torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()

tensor(True)

### CUDA

In [None]:
# Naive CUDA matmul source, appended to the `cuda_begin` preamble imported from utils.
#  - `matmul_k`: one thread per output element. Thread (r, c) returns early if it is
#    outside the h x w output, otherwise it accumulates the length-k dot product
#    from `m` and `n` and stores it in `out[r*w+c]`.
#  - `matmul`: host wrapper. Checks both inputs and that the inner sizes agree,
#    allocates a zeroed h x w output with `m`'s options, and launches 16x16-thread
#    blocks, with the grid sized by `cdiv` to cover the output.
# NOTE(review): `CHECK_INPUT` and `cdiv` are not defined in this string; presumably
# they come from `cuda_begin` -- confirm in utils.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul'

In [None]:
def get_sig(fname, src):
    """Extract the C++ declaration of function `fname` from source text `src`.

    Looks for the first line of the form `<return type> fname(<args>)`, optionally
    followed by an opening brace, and returns it with a trailing `;` so it can be
    used as a forward declaration. Returns None when no such line exists.
    """
    pattern = rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$'
    found = re.search(pattern, src, re.MULTILINE)
    if found is None: return None
    return found.group(1) + ';'

In [None]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [None]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [None]:
module.matmul(m1c,m2c).shape

In [None]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [None]:
%%timeit -n 10
module.matmul(m1c,m2c)
torch.cuda.synchronize()

6.03 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


When I removed the call to the kernel itself, it took around 50 µs (0.05 ms) to run, so that's the overhead of the call on my machine.

## Shared mem

### Python

In [None]:
a = torch.zeros(5)
b,c = a[:3],a[3:]

In [None]:
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

In [None]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    """Run block-level kernel `f` once per block, giving each block its own scratch buffer.

    For every block of the 2D grid `blocks`, a fresh zeroed tensor of `sh_sz`
    elements is allocated to play the role of shared memory, and `f` is called
    as `f(blockIdx, blockDim, shared, *args, **kwargs)`. Looping over the
    threads of the block is left to `f`.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_shared = torch.zeros(sh_sz)
            block_idx = dim3(bx, by)
            f(block_idx, threads, block_shared, *args, **kwargs)

In [None]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one thread block of a tiled matmul by looping over its threads.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flat, row-major tensors and
    `shared` is a scratch buffer of at least 2*tw*tw elements: the first tw*tw
    hold the current tile of `m`, the rest the current tile of `n`.

    The k dimension is processed in phases of width `tw`. Each phase is two
    complete passes over the block's threads: first every thread loads its tile
    elements (0 where the tile hangs over the matrix edge), then every thread
    adds its partial dot product into `out`. `out` is accumulated into, so it
    must start zeroed.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    for ph in range(cdiv(k,tw)):
        idx = ph*tw
        # fill shared
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                ms[tr*tw+tc] = m[ tc+idx + r*k] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(tr+idx)*w +c] if c<w and idx+tr<k else 0.

        # do dotprods from shared
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # Check row and column separately: a flat `r*w+c < len(out)` test
                # would let out-of-range columns alias onto elements of the next row.
                if r>=h or c>=w: continue
                for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [None]:
def matmul_2d(m, n, tw=16):
    """Multiply `m` (h x k) by `n` (k x w) with the simulated tiled kernel.

    Uses tw x tw thread blocks and a shared buffer of 2*tw*tw elements per
    block (one tile of each operand). Raises AssertionError when the inner
    dimensions differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # Flattened operands; the kernel accumulates into `output` through its flat handle.
    kernel_args = (m.flatten(), n.flatten(), output.flatten(), h, w, k)
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, tw*tw*2, *kernel_args, tw=tw)
    return output

In [None]:
m1s.shape, m2.shape

(torch.Size([4, 256]), torch.Size([256, 5120]))

In [None]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python run_threads

In [None]:
def run_threads(f, blockDim, *args, **kwargs):
    """Call `f(ty, tx, *args, **kwargs)` for every thread position in a block.

    Positions are visited row by row: `ty` runs over `blockDim.y` in the outer
    loop and `tx` over `blockDim.x` in the inner one.
    """
    positions = [(ty, tx) for ty in range(blockDim.y) for tx in range(blockDim.x)]
    for ty, tx in positions: f(ty, tx, *args, **kwargs)

In [None]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one thread block of a tiled matmul using `run_threads`.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flat, row-major tensors and
    `shared` is a scratch buffer of at least 2*tw*tw elements, split into the
    `m` tile (`ms`) followed by the `n` tile (`ns`).

    For each phase along k, `run_threads` first runs the load step for every
    thread of the block and then the dot-product step for every thread, so all
    loads of a phase finish before any thread reads the tiles. `out` is
    accumulated into, so it must start zeroed.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    def get_rc(tr, tc): return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    def fill_shared_tk(tr, tc, ph):
        # Load this thread's element of each tile, padding with 0 past the matrix edges.
        r,c = get_rc(tr, tc)
        ms[tr*tw+tc] = m[ tc + ph*tw + r*k] if r<h and (ph*tw+tc)<k else 0.
        ns[tr*tw+tc] = n[(tr + ph*tw)*w +c] if c<w and (ph*tw+tr)<k else 0.

    def dotprod_tk(tr, tc):
        r,c = get_rc(tr, tc)
        # Check row and column separately: a flat `r*w+c < len(out)` test would
        # let out-of-range columns alias onto elements of the next row.
        if r>=h or c>=w: return
        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

    for ph in range(int(math.ceil(k/tw))):
        run_threads(fill_shared_tk, blockDim, ph)
        run_threads(dotprod_tk, blockDim)

In [None]:
def matmul_2d(m, n, tw=16):
    """Tiled matmul of `m` (h x k) and `n` (k x w) via the simulated block kernel.

    Launches `matmul_tiled_bk` over a grid of tw x tw blocks covering the
    h x w result, with 2*tw*tw elements of per-block shared scratch space.
    Raises AssertionError on mismatched inner dimensions.
    """
    n_rows, inner = m.shape
    inner2, n_cols = n.shape
    assert inner==inner2, "Size mismatch!"
    result = torch.zeros(n_rows, n_cols, dtype=m.dtype)
    block_dim = dim3(tw,tw)
    grid_dim = dim3(cdiv(n_cols,block_dim.x), cdiv(n_rows,block_dim.y))
    shared_elems = tw*tw*2
    blk_kernel2d_shar(matmul_tiled_bk, grid_dim, block_dim, shared_elems,
                      m.flatten(), n.flatten(), result.flatten(),
                      n_rows, n_cols, inner, tw=tw)
    return result

In [None]:
m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [None]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python threads

In [None]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [None]:
def g(x, sb):
    """Barrier demo: print `x`, then `-x`, then `x*10`, waiting on `sb` between prints.

    `sb` only needs a `wait()` method; it is called once after each of the
    first two prints.
    """
    for stage_value in (x, -x):
        print(stage_value)
        sb.wait()
    print(x*10)

In [None]:
num = 3
sb = Barrier(num)
with ThreadPoolExecutor(num) as ex: list(ex.map(lambda i: g(i,sb), range(1,num+1)))

1
2
3
-3
-2
-1
30
10
20


In [None]:
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    """Run kernel `f` block by block, with one Python thread per simulated CUDA thread.

    Blocks are processed one after another. For each block a zeroed tensor of
    `sh_sz` elements (shared memory) and a `Barrier` sized to the number of
    threads in the block are created, then `f` is started in tpb.y*tpb.x
    threads as `f(blockIdx, threadIdx, blockDim, shared, barrier, *args, **kwargs)`.
    All threads of a block are joined before the next block starts.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            shar = torch.zeros(sh_sz)
            syncb = Barrier(tpb.y*tpb.x)
            block_idx = dim3(bx,by)
            workers = []
            for ty in range(tpb.y):
                for tx in range(tpb.x):
                    worker_args = (block_idx, dim3(tx,ty), tpb, shar, syncb, *args)
                    workers.append(Thread(target=f, args=worker_args, kwargs=kwargs))
            # Start everything before joining so all threads can reach the barrier.
            for worker in workers: worker.start()
            for worker in workers: worker.join()

In [None]:
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    """Per-thread body of the tiled matmul, run concurrently by every thread of a block.

    `m` (h x k), `n` (k x w) and `out` (h x w) are flat, row-major tensors.
    `shared` holds the `m` tile in its first tw*tw elements and the `n` tile
    after that. `syncb` is a barrier shared by all threads of the block.

    In each phase along k the thread loads one element of each tile (0 past the
    matrix edges), waits for the whole block, accumulates tw products, and waits
    again so nobody overwrites a tile that is still being read. The result is
    written once at the end, only by threads inside the h x w output.
    """
    tx, ty = threadIdx.x, threadIdx.y
    row = blockIdx.y*blockDim.y + ty
    col = blockIdx.x*blockDim.x + tx

    tile = tw*tw
    ms, ns = shared[:tile], shared[tile:]
    slot = ty*tw + tx   # this thread's position inside both tiles
    row_ok, col_ok = row < h, col < w

    acc = 0.
    for ph in range(cdiv(k,tw)):
        base = ph*tw
        if row_ok and (base+tx) < k: ms[slot] = m[tx + base + row*k]
        else: ms[slot] = 0.
        if col_ok and (base+ty) < k: ns[slot] = n[(ty + base)*w + col]
        else: ns[slot] = 0.
        syncb.wait()
        for i in range(tw): acc += ms[ty*tw+i] * ns[tw*i+tx]
        syncb.wait()

    if row_ok and col_ok: out[row*w + col] = acc

In [None]:
def matmul_2d(m, n, tw=16):
    """Compute `m @ n` for 2D tensors with the thread-based tiled kernel simulation.

    The launch uses tw x tw threads per block, a grid large enough to cover the
    h x w output, and 2*tw*tw elements of shared scratch space per block.
    Raises AssertionError if the inner dimensions of `m` and `n` differ.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw,tw)
    blocks = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    launch_kwargs = dict(tw=tw)
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, 2*tw*tw,
                      m.flatten(), n.flatten(), output.flatten(), h, w, k,
                      **launch_kwargs)
    return output

In [None]:
torch.isclose(matmul_2d(m1s, m2s, tw=8), m1s@m2s).all()

tensor(True)

### CUDA dynamic shared

Code auto-generated by ChatGPT 4, using the following prompt:

> Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove `blockIdx, threadIdx, blockDim, shared` from the argument list, since they're already provided by CUDA. Change `syncb.wait()` to `__syncthreads`. Use `extern __shared__ float shared[]` to create the `shared` array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.

The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming `shared` to `ms`).

In [None]:
# Tiled matmul kernel using dynamically sized shared memory (`extern __shared__`).
# A single shared buffer holds both tiles: `ms` is its first tw*tw floats and `ns`
# starts right after, so the launch has to request 2*tw*tw floats of shared memory.
# Per phase `ph`, each thread loads one element of each tile (0.0f when out of
# range), syncs, accumulates tw products into `p`, then syncs again before the
# tiles are overwritten. `p` is written to `out` once, only for in-range (r, c).
# NOTE(review): tiles are indexed with `tw` while r/c use blockDim, so this looks
# like it expects blockDim == (tw, tw) -- the launcher has to guarantee that.
# NOTE(review): `cdiv` is not defined in this string; presumably it comes from
# `cuda_begin` -- confirm in utils.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [None]:
# Host wrapper `matmul_dyn` for the dynamic-shared-memory kernel. It fixes TW=16,
# passes TW*TW*2 floats' worth of bytes as the shared-memory size (third launch
# parameter), and launches `matmul_k` over cdiv(w,TW) x cdiv(h,TW) blocks.
# The C block comment inside the string sketches deriving TW and the size from
# device properties; it is commented out and therefore not compiled.
# This cell appends with `+=`, so it needs the kernel cell to have run first, and
# running it twice in a row appends a second copy of the wrapper to `cuda_src`.
cuda_src += r'''
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    /*
    cudaDeviceProp devProp;
    CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));
    int maxThreads = devProp.maxThreadsPerBlock;
    size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);
    size_t size = min(devProp.sharedMemPerBlock, requiredSize);
    int TW = std::sqrt(maxThreads);
    */
    int TW = 16;
    size_t size = TW*TW * 2 * sizeof(float);

    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul_dyn'

In [None]:
cpp_src = get_sig(fname, cuda_src)

In [None]:
module = load_cuda(cuda_src, cpp_src, [fname], opt=True)

In [None]:
torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [None]:
%%timeit -n 10
module.matmul_dyn(m1c,m2c)
torch.cuda.synchronize()

6.56 ms ± 372 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### CUDA static shared

In [None]:
# Tiled matmul with statically sized shared memory. The tile width is the
# compile-time constant `tw` (16), so `ms`/`ns` are fixed-size 2D `__shared__`
# arrays and the launch passes no shared-memory size.
#  - `matmul_ks`: same load / sync / accumulate / sync phases as the dynamic
#    version, but indexing the tiles as ms[tr][i] and ns[i][tc].
#  - `matmul_static`: host wrapper launching tw x tw thread blocks over a grid
#    sized with `cdiv` to cover the h x w output.
# NOTE(review): `CHECK_INPUT` and `cdiv` are not defined in this string; presumably
# they come from `cuda_begin` -- confirm in utils.
cuda_src = cuda_begin + r'''
constexpr int tw = 16;

__global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {
    __shared__ float ms[tw][tw], ns[tw][tw];
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    float p=0.0f;
    for (int ph=0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr][tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr][tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr][i] * ns[i][tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}

torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    dim3 tpb(tw,tw);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_ks<<<blocks,tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul_static'
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname])
torch.isclose(module.matmul_static(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [None]:
%%timeit -n 10
module.matmul_static(m1c,m2c)
torch.cuda.synchronize()

4.52 ms ± 253 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Numba

In [None]:
# Switch numba.cuda to its pure-Python CUDA simulator instead of compiling for a
# real GPU. This has to be set before numba is imported for the first time.
os.environ['NUMBA_ENABLE_CUDASIM'] = '1'

In [None]:
from numba import cuda
# from numba.cuda import as_cuda_array as ca

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    """Tiled matmul kernel: each thread produces one element of `out` = `m` @ `n`.

    `m`, `n` and `out` are 2D arrays. The kernel uses the dynamically sized
    shared buffer (shape 0), viewed as two tw*tw float32 tiles, so the launch
    must provide at least 2*tw*tw float32 values of dynamic shared memory.
    Tiles are padded with 0 where they extend past the matrix edges.
    """
    tx, ty = cuda.threadIdx.x, cuda.threadIdx.y
    row = cuda.blockIdx.y * cuda.blockDim.y + ty
    col = cuda.blockIdx.x * cuda.blockDim.x + tx
    h,k  = m.shape
    k2,w = n.shape

    tile = tw*tw
    shar = cuda.shared.array(0, dtype=np.float32)
    ms,ns = shar[:tile],shar[tile:2*tile]
    slot = ty*tw + tx   # this thread's position inside both tiles

    acc = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        base = ph*tw
        if row<h and base+tx<k: ms[slot] = m[row, tx+base]
        else: ms[slot] = 0.
        if col<w and base+ty<k: ns[slot] = n[ty+base, col]
        else: ns[slot] = 0.
        cuda.syncthreads()
        for i in range(tw): acc += ms[ty*tw+i] * ns[i*tw+tx]
        # Second barrier: nobody may refill the tiles while others still read them.
        cuda.syncthreads()
    if row < h and col < w: out[row, col] = acc

In [None]:
def matmul_2d_numba(m, n, tw=16):
    """Multiply 2D tensors `m` (h x k) and `n` (k x w) with the numba tiled kernel.

    Allocates the h x w result on `m`'s device, launches `matmul_k_numba` with
    tw x tw threads per block and enough dynamic shared memory for two tw x tw
    float32 tiles, and returns the result tensor. Raises AssertionError when
    the inner dimensions differ.
    """
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    dyn_shared_mem_size = 2 * tw * tw * 4   # two tiles, 4 bytes per float32
    tpb = tw,tw
    blocks = cdiv(w,tpb[0]), cdiv(h,tpb[1])
    # No name `ca` is imported in this notebook, so look the wrapper up on the
    # `cuda` module instead. When numba.cuda does not expose `as_cuda_array`
    # (this appears to be the case under NUMBA_ENABLE_CUDASIM -- TODO confirm),
    # hand the tensors to the kernel unchanged.
    wrap = getattr(cuda, 'as_cuda_array', None)
    if wrap is None: wrap = lambda t: t
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](wrap(m), wrap(n), wrap(out), tw)
    return out

In [None]:
m1c.shape, m2c.shape

In [None]:
torch.isclose(matmul_2d_numba(m1c,m2c), m1c@m2c).all()

In [None]:
%%timeit -n 10
matmul_2d_numba(m1c,m2c)
torch.cuda.synchronize()

#### Rectangle Tiling

This version works under the CUDA simulator (`NUMBA_ENABLE_CUDASIM=1`), but not when run on a real GPU in CUDA mode.

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw_w, tw_h):
    """Tiled matmul kernel with rectangular tiles, accumulating directly into `out`.

    Launched with blocks of (tw_w, tw_h) threads. Each block covers a tw_h x tw_h
    patch of `out`. Per phase it stages tw_h rows by tw_w k-entries of `m` in
    `ms`, and the matching entries of `n` in `ns`, stored per output column.
    Thread (tx, ty) then adds the single product for k-offset `tx` to every
    output column of row `ty`.

    `out` must start zeroed, and the launch must provide 2*tw_w*tw_h float32
    values of dynamic shared memory.
    """
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tx,ty = tid.x,tid.y

    h,k  = m.shape
    k2,w = n.shape

    shar = cuda.shared.array(0, dtype=np.dtype("float32"))
    ms,ns = shar[:tw_w*tw_h],shar[tw_w*tw_h:]

    # Row / column handled by this thread do not change between phases.
    # The column uses cbd.y and ty on purpose: output tiles are tw_h wide.
    r,c = cbi.y * cbd.y + ty, cbi.x * cbd.y + ty
    c_base = cbi.x * cbd.y

    for ph in range(math.ceil(k/tw_w)):
        idx = ph*tw_w

        # fill shared mem
        ms[ty*tw_w+tx] = m[r, idx+tx] if r<h and idx+tx<k else 0.
        ns[ty*tw_w+tx] = n[idx+tx, c] if c<w and idx+tx<k else 0.
        cuda.syncthreads()

        for col_i in range(tw_h):
            c_out = c_base + col_i
            msIndex = ty*tw_w    + tx
            nsIndex = col_i*tw_w + tx
            if r < h and c_out < w:
                # All threads with the same ty (any tx) target the same output
                # element, so the accumulation has to be atomic to avoid a
                # read-modify-write race between them.
                cuda.atomic.add(out, (r, c_out), ms[msIndex] * ns[nsIndex])
        cuda.syncthreads()

In [None]:
def matmul_2d_numba(m, n, tw_w=3, tw_h=2):
    """Multiply `m` (h x k) by `n` (k x w) with the rectangular-tile numba kernel.

    Launches (tw_w, tw_h)-thread blocks on a grid of cdiv(w,tw_h) x cdiv(h,tw_h)
    blocks, with dynamic shared memory for two tw_w*tw_h float32 tiles. The grid
    shape is printed before the launch. The result is allocated zeroed and moved
    to the GPU. Raises AssertionError when the inner dimensions differ.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype).cuda()

    # Two tiles of tw_w*tw_h float32 values each.
    shared_bytes = 2 * tw_w * tw_h * np.dtype("float32").itemsize

    # Both grid dimensions are divided by tw_h: each block produces a tw_h x tw_h output patch.
    grid = cdiv(w,tw_h), cdiv(h,tw_h)
    print(grid)
    launch = matmul_k_numba[grid, (tw_w,tw_h), 0, shared_bytes]
    # Tensors are passed as-is, which is what the CUDA simulator expects. The
    # alternative for real CUDA mode was to wrap each one with `as_cuda_array`.
    launch(m, n, out, tw_w, tw_h)
    return out

In [None]:
m1s = torch.arange(12).view(3,4).float()
m2 = torch.arange(24).view(4,6).float()
m1s,m2

(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.],
         [18., 19., 20., 21., 22., 23.]]))

In [None]:
torch.isclose(matmul_2d_numba(m1s.cuda(),m2.cuda(), tw_w=2, tw_h=3).cpu(), m1s@m2).all()

(2, 1)


tensor(True)

In [None]:
torch.isclose(matmul_2d_numba(m1s.cuda(),m2.cuda(), tw_w=3, tw_h=2).cpu(), m1s@m2).all()

(3, 2)


tensor(True)