# Python CPU

In [1]:
import torch

In [18]:
torch.set_printoptions(precision=2, sci_mode=False)

In [7]:
def matmul_py(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors with an explicit triple loop.

    This is the slow, pure-Python reference implementation used to check and
    benchmark the faster variants.

    Args:
        A: Left operand of shape ``(ar, ac)``.
        B: Right operand of shape ``(br, bc)``; ``br`` must equal ``ac``.

    Returns:
        A new ``(ar, bc)`` tensor allocated on ``A``'s device. Its dtype is the
        promoted dtype of ``A`` and ``B``, so float64 or integer inputs are no
        longer silently squeezed into float32.

    Raises:
        AssertionError: If either input is not 2-D or the inner dimensions differ.
    """
    assert A.dim() == 2 and B.dim() == 2, "matmul_py expects two 2-D tensors"

    ar, ac = A.shape
    br, bc = B.shape
    assert ac == br, f"inner dimensions differ: {ac} != {br}"

    # Allocate the accumulator in the inputs' promoted dtype; a bare
    # torch.zeros(...) would always be the default float dtype.
    C = torch.zeros((ar, bc), dtype=torch.result_type(A, B), device=A.device)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac):
                # Tuple indexing avoids building an intermediate row view per access.
                C[i, j] += A[i, k] * B[k, j]
    return C

In [3]:
# Two independent 2x2 float32 identity matrices: the product must be the identity again.
A = torch.eye(2, dtype=torch.float32)
B = torch.eye(2, dtype=torch.float32)

In [8]:
matmul_py(A, B)

tensor([[1., 0.],
        [0., 1.]])

In [9]:
assert torch.allclose(A @ B, matmul_py(A, B))

In [2]:
torch.manual_seed(42)
A = torch.randn((20, 300))
B = torch.randn((300, 5))

In [14]:
%timeit _ = matmul_py(A, B)

474 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%timeit _ = A @ B

5.1 µs ± 67.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [18]:
A_cuda = A.contiguous().cuda()
B_cuda = B.contiguous().cuda()

In [21]:
%timeit _ = A_cuda @ B_cuda

11.7 µs ± 65.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


# Python Kernel

In [10]:
from types import SimpleNamespace as ns

In [11]:
import math


def matmul_bk_loops(func, num_blocks, threads_per_block, *args):
    """Sequentially emulate a 2-D kernel launch.

    Calls ``func(block_idx, thread_idx, threads_per_block, *args)`` once for
    every (block, thread) pair. Blocks are visited with ``y`` as the outer
    coordinate and ``x`` as the inner one; threads inside a block follow the
    same order.

    Args:
        func: Kernel body to invoke for each emulated thread.
        num_blocks: Object with integer ``x`` and ``y`` attributes (grid size).
        threads_per_block: Object with integer ``x`` and ``y`` attributes
            (block size); also forwarded to ``func`` as the block dimensions.
        *args: Extra positional arguments passed through to ``func`` unchanged.
    """
    block_coords = [
        (bx, by)
        for by in range(num_blocks.y)
        for bx in range(num_blocks.x)
    ]
    thread_coords = [
        (tx, ty)
        for ty in range(threads_per_block.y)
        for tx in range(threads_per_block.x)
    ]
    for bx, by in block_coords:
        for tx, ty in thread_coords:
            func(ns(x=bx, y=by), ns(x=tx, y=ty), threads_per_block, *args)


def matmul_bk(block_idx, thread_idx, block_dim, m, n, output, h, w, k) -> None:
    """Kernel body: compute one element of the ``h x w`` product.

    The (row, column) handled by this call is derived from the block index,
    block dimensions and thread index. Calls whose coordinates fall outside
    the output do nothing.

    Args:
        block_idx: Object with ``x``/``y`` attributes identifying the block.
        thread_idx: Object with ``x``/``y`` attributes identifying the thread
            within its block.
        block_dim: Object with ``x``/``y`` attributes giving the block size.
        m: Flat buffer indexed as ``m[row * k + i]`` (left operand).
        n: Flat buffer indexed as ``n[i * w + col]`` (right operand).
        output: Flat buffer; the result is stored at ``output[row * w + col]``.
        h: Number of output rows.
        w: Number of output columns.
        k: Length of the shared inner dimension.
    """
    row = block_idx.y * block_dim.y + thread_idx.y
    col = block_idx.x * block_dim.x + thread_idx.x

    # Threads of a partially filled edge block have nothing to compute.
    if not (row < h and col < w):
        return

    row_start = row * k
    acc = 0.0
    for step in range(k):
        acc += m[row_start + step] * n[step * w + col]
    output[row * w + col] = acc


def matmul_2d(m, n):
    """Multiply two 2-D tensors by running ``matmul_bk`` over an emulated grid.

    The output is tiled into 16x16 thread blocks; ``matmul_bk_loops`` then
    invokes ``matmul_bk`` once per emulated thread on flattened views of the
    operands and the result.

    Args:
        m: Tensor of shape ``(h, k)``.
        n: Tensor of shape ``(k, w)``.

    Returns:
        Tensor of shape ``(h, w)`` with ``m``'s dtype.

    Raises:
        AssertionError: If the inner dimensions of ``m`` and ``n`` differ.
    """
    rows, inner = m.shape
    inner_n, cols = n.shape
    assert inner == inner_n

    block = ns(x=16, y=16)
    # Round up so partially filled edge tiles are still covered.
    grid = ns(
        x=math.ceil(cols / block.x),
        y=math.ceil(rows / block.y),
    )

    result = torch.zeros((rows, cols), dtype=m.dtype)
    # The kernel writes through a flat 1-D alias of `result`, so the values
    # land in the tensor that is returned.
    matmul_bk_loops(
        matmul_bk,
        grid,
        block,
        m.flatten(),
        n.flatten(),
        result.view(-1),
        rows,
        cols,
        inner,
    )
    return result

In [8]:
torch.allclose(A @ B, matmul_2d(A, B), atol=1e-5)

True

# CUDA Kernel

In [2]:
import os
from pathlib import Path

# Kernel source is read from a file next to the notebook; decode it as UTF-8
# instead of relying on the platform's locale encoding.
cuda_source = Path("matrix_multiplication.cu").read_text(encoding="utf-8")
# Declaration handed to the C++ side of the extension build.
cpp_source = "torch::Tensor matmul(torch::Tensor m, torch::Tensor n);"
# Keep a CUDA_HOME that is already set in the environment; the path below is a
# site-specific fallback that you may need to adjust for your machine.
os.environ.setdefault("CUDA_HOME", "/public/apps/cuda/12.1")

Run `python -m compile_cuda_kernel` to build the CUDA kernel.

In [3]:
from torch.utils.cpp_extension import load_inline

# Build the extension from the C++ declaration and CUDA source defined above;
# the function listed in `functions` is then callable as `module.matmul`.
module = load_inline(
    name="matmul",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["matmul"],
    with_cuda=True,
    # Extra flags forwarded to the CUDA compiler.
    extra_cuda_cflags=["-O2"],
    # build_directory='./cuda_build',
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [4]:
torch.manual_seed(42)
A = torch.randn((20, 300))
B = torch.randn((300, 5))

In [5]:
A.shape

torch.Size([20, 300])

In [6]:
A_cuda = A.contiguous().cuda()
B_cuda = B.contiguous().cuda()

In [7]:
torch.allclose(module.matmul(A_cuda, B_cuda).cpu(), A @ B, atol=1e-5)

True

# Benchmark

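Our CUDA kernel comes close to PyTorch's `A @ B` on large matrices, but is noticeably slower on small ones. Note that the kernel timings below include copying the result back to the CPU (`.cpu()`), while the `A @ B` baseline runs on CPU tensors.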

### Small Matrices

In [27]:
torch.manual_seed(42)
A = torch.randn((20, 300))
B = torch.randn((300, 5))
print(A.shape[0] * A.shape[1] * B.shape[1])

A_cuda = A.contiguous().cuda()
B_cuda = B.contiguous().cuda()

30000


In [28]:
%%timeit
result = matmul_py(A, B)
result.shape

443 ms ± 5.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
%%timeit
result = module.matmul(A_cuda, B_cuda).cpu()
result.shape

39.5 µs ± 374 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [30]:
%%timeit
result = A @ B
result.shape

5.35 µs ± 85.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Large Matrices

In [31]:
torch.manual_seed(42)
A = torch.randn((320, 5000))
B = torch.randn((5000, 640))
print(A.shape[0] * A.shape[1] * B.shape[1])

1024000000


In [32]:
A_cuda = A.contiguous().cuda()
B_cuda = B.contiguous().cuda()

In [33]:
%%timeit
result = module.matmul(A_cuda, B_cuda).cpu()
result.shape


1.44 ms ± 4.36 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [34]:
%%timeit
result = A @ B
result.shape

1.28 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
