# Setup

In [1]:
import os
import math
import sys
import torch
import re
import numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [2]:
dim3 = namedtuple("dim3", ["x", "y", "z"], defaults=(1, 1))

In [3]:
d = dim3(2, 3)
d

dim3(x=2, y=3, z=1)

In [4]:
d.x, d.y

(2, 3)

In [5]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [6]:
sys.path.insert(0, "..")
from utils import show_img, load_cuda, cuda_begin, cdiv

In [7]:
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.manual_seed(42);

In [8]:
m1 = torch.rand(5120, 256)
m1s = m1[:4]
m2 = torch.rand(256, 5120)
m2s = m2[:, :4]

# Reminder

### 2d Python kernel

In [9]:
def blk_kernel2d(f, blocks, threads, *args):
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x):
                    f(dim3(i1, i0), dim3(j1, j0), threads, *args)

In [10]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    r = blockIdx.y * blockDim.y + threadIdx.y
    c = blockIdx.x * blockDim.x + threadIdx.x

    if r >= h or c >= w:
        return
    o = 0.0
    for i in range(k):
        o += m[r * k + i] * n[i * w + c]
    out[r * w + c] = o

In [11]:
def matmul_2d(m, n):
    h, k = m.shape
    k2, w = n.shape
    assert k == k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(16, 16)
    blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    blk_kernel2d(
        matmul_bk, blocks, tpb, m.flatten(), n.flatten(), output.flatten(), h, w, k
    )
    return output

In [12]:
torch.isclose(matmul_2d(m1s, m2s), m1s @ m2s).all()

tensor(True)

### CUDA

In [13]:
cuda_src = (
    cuda_begin
    + r"""
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;

    if (r >= h || c >= w) return;
    float o = 0;
    for (int i = 0; i<k; i++) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    
    dim3 tpb(16, 16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
       m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

"""
)

In [14]:
fname = "matmul"

In [15]:
def get_sig(fname, src):
    res = re.findall(rf"^(.+\s+{fname}\(.*?\))\s*{{?\s*$", src, re.MULTILINE)
    return res[0] + ";" if res else None

In [16]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [17]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [18]:
m1c, m2c = m1.contiguous().cuda(), m2.contiguous().cuda()

In [19]:
module.matmul(m1c, m2c).shape

torch.Size([5120, 5120])

In [20]:
torch.isclose(module.matmul(m1c, m2c), m1c @ m2c).all()

tensor(False, device='cuda:0')

In [21]:
%%timeit -n 10
module.matmul(m1c, m2c)
torch.cuda.synchronize()

6.25 ms ± 61.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


When I removed the call to the kernel itself, it took around 50μs (0.05 ms) to run, so that's the overhead of the call on my machine.