In [5]:
import torch

In [6]:
from collections import namedtuple

# Three-component index/extent tuple used for the emulated block and thread
# coordinates: `x` is required, `y` and `z` fall back to 1 when omitted.
dim3 = namedtuple("dim3", "x y z", defaults=(1, 1))

In [2]:
d = dim3(2, 3)
d

dim3(x=2, y=3, z=1)

In [5]:
d.x, d.y

(2, 3)

In [7]:
# Fixed seed so the random operands are the same on every run.
torch.manual_seed(42)
m1 = torch.randn(5120, 256)
m2 = torch.randn(256, 5120)

# Small operands for the slow pure-Python kernels:
# the first 4 rows of m1 and the first 4 columns of m2.
m1s, m2s = m1[:4], m2[:, :4]

m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [11]:
m1c = m1.contiguous().cuda()
m2c = m2.contiguous().cuda()

In [8]:
m1sc = m1s.contiguous().cuda()
m2sc = m2s.contiguous().cuda()

In [10]:
m1sc @ m2sc

tensor([[ 1.4065e+01,  6.8529e+00, -2.1297e+01,  9.0340e+00],
        [ 1.3069e+00, -1.9484e+01, -6.0172e+00, -5.2774e+00],
        [-5.4064e-03, -9.1612e+00,  9.6259e+00,  3.4818e+01],
        [ 2.0240e+01, -8.6772e+00,  3.6141e+01, -1.3200e+01]], device='cuda:0')

In [14]:
%timeit m1 @ m2

21.5 ms ± 554 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%timeit (m1c @ m2c).cpu()

75.9 ms ± 559 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:
# NOTE(review): nothing here waits for the GPU (no .cpu() / torch.cuda.synchronize()).
# CUDA work is queued asynchronously, so this figure may under-report the real
# matmul time -- confirm with an explicit synchronize before comparing.
%timeit m1c @ m2c

1.02 ms ± 13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


: 

# Python Kernel

In [10]:
import math
import typing as tp


def matmul_2d_loop(
    func: tp.Callable,
    num_blocks: dim3,
    threads_per_block: dim3,
    *args,
):
    """Sequentially emulate a 2-D kernel launch.

    Calls ``func(block_idx, thread_idx, threads_per_block, *args)`` once for
    every (block, thread) pair, iterating block x, block y, thread x and
    thread y from the outermost to the innermost loop. Only the ``x`` and
    ``y`` extents of ``num_blocks`` and ``threads_per_block`` are used.
    """
    for bx in range(num_blocks.x):
        for by in range(num_blocks.y):
            # The block coordinate is the same for every thread of the block.
            block_idx = dim3(bx, by)
            for tx in range(threads_per_block.x):
                for ty in range(threads_per_block.y):
                    func(block_idx, dim3(tx, ty), threads_per_block, *args)


def matmul_2d_kernel(
    block_idx: dim3,
    thread_idx: dim3,
    block_dim: dim3,
    A: torch.Tensor,
    B: torch.Tensor,
    out: torch.Tensor,
    height: int,
    width: int,
    k: int,
) -> None:
    """Compute a single output element of ``A @ B`` for one emulated thread.

    ``A``, ``B`` and ``out`` are indexed as flat row-major buffers of logical
    shape (height, k), (k, width) and (height, width). The thread's global
    (row, col) position is derived from the block and thread indices; threads
    that land outside the output matrix return without writing anything.
    """
    col = thread_idx.x + block_dim.x * block_idx.x
    row = thread_idx.y + block_dim.y * block_idx.y

    # Threads past the matrix edge have no output element to produce.
    if not (row < height and col < width):
        return

    a_row_start = row * k
    acc = 0.0
    for step in range(k):
        acc += A[a_row_start + step] * B[step * width + col]
    out[row * width + col] = acc


def matmul_2d(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors with the naive emulated kernel.

    Runs ``matmul_2d_kernel`` over a grid of 16x16 blocks that covers the
    (rows of A) x (columns of B) output and returns the result, allocated
    with A's dtype and device. Raises ``AssertionError`` when the inner
    dimensions differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    assert inner == inner_b, "Size must match!"

    result = torch.zeros((rows, cols), dtype=A.dtype, device=A.device)
    block = dim3(16, 16)
    # Round up so that partially filled edge blocks are still launched.
    grid = dim3(math.ceil(cols / block.x), math.ceil(rows / block.y))

    # The kernel addresses flat row-major buffers. `result` was just allocated,
    # so its flattened form aliases it and the kernel's writes land in `result`.
    flat_args = (A.flatten(), B.flatten(), result.flatten())
    matmul_2d_loop(matmul_2d_kernel, grid, block, *flat_args, rows, cols, inner)

    return result

In [14]:
torch.allclose(matmul_2d(m1s, m2s), m1s @ m2s, atol=1e-5)

True

# Python Tiled Kernel with Loops

In [12]:
import math
import typing as tp


def matmul_2d_tiled_loop(
    func: tp.Callable,
    num_blocks: dim3,
    threads_per_block: dim3,
    shared_memory_size: int,
    *args,
    **kwargs,
):
    """Emulate a block-level launch: ``func`` is called once per block.

    Each call receives ``(block_idx, threads_per_block, shared_memory, *args,
    **kwargs)``; iterating over the threads of the block is left to ``func``.
    ``shared_memory`` is a fresh zero-filled 1-D tensor of
    ``shared_memory_size`` elements per block, created with ``torch.zeros``
    defaults (no dtype or device is passed).
    """
    grid_x, grid_y = num_blocks.x, num_blocks.y
    for bx in range(grid_x):
        for by in range(grid_y):
            # A new scratch buffer per block: nothing is shared between blocks.
            scratch = torch.zeros(shared_memory_size)
            func(dim3(bx, by), threads_per_block, scratch, *args, **kwargs)


def matmul_2d_tiled_kernel(
    block_idx: dim3,
    block_dim: dim3,
    shared_memory: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    out: torch.Tensor,
    height: int,
    width: int,
    k: int,
    tile_width: int,
) -> None:
    """Compute one output tile of ``A @ B``, emulating a whole block with loops.

    ``A``, ``B`` and ``out`` are indexed as flat row-major buffers of logical
    shape (height, k), (k, width) and (height, width). ``shared_memory`` must
    hold at least ``2 * tile_width * tile_width`` elements: the first
    ``tile_width ** 2`` stage a tile of A, the remainder a tile of B.

    Partial products are *added* to ``out`` phase by phase, so ``out`` must be
    zero on entry. The staged tiles are indexed with ``tile_width`` while the
    loops run over ``block_dim``; the callers in this notebook pass
    ``block_dim == (tile_width, tile_width)``.
    """
    shared_memory_size = tile_width * tile_width
    # Slices are views: writing to them fills the shared scratch buffer itself.
    A_shared_memory, B_shared_memory = shared_memory[:shared_memory_size], shared_memory[shared_memory_size:]

    for ph in range(math.ceil(k / tile_width)):

        idx = ph * tile_width

        # Stage the current tiles of both matrices in shared memory.
        # Cells that fall outside A or B are zero-padded.
        for tile_row in range(block_dim.y):
            for tile_col in range(block_dim.x):
                row = block_idx.y * block_dim.y + tile_row
                col = block_idx.x * block_dim.x + tile_col

                A_shared_memory[tile_row * tile_width + tile_col] = (
                    A[row * k + idx + tile_col] if row < height and idx + tile_col < k else 0.0
                )
                B_shared_memory[tile_row * tile_width + tile_col] = (
                    B[(idx + tile_row) * width + col] if idx + tile_row < k and col < width else 0.0
                )

        # Accumulate this phase's partial dot products into the output tile.
        for tile_row in range(block_dim.y):
            row = block_idx.y * block_dim.y + tile_row
            if row >= height:
                continue
            for tile_col in range(block_dim.x):
                col = block_idx.x * block_dim.x + tile_col
                # Guard on the 2-D position, not on the flat index: with
                # col >= width the flat index row * width + col can still be
                # in range and would alias an element of the next row. The
                # padded product is only zero for finite inputs (inf * 0 is nan).
                if col >= width:
                    continue

                for i in range(tile_width):
                    out[row * width + col] += (
                        A_shared_memory[tile_row * tile_width + i] * B_shared_memory[i * tile_width + tile_col]
                    )


def matmul_2d_tiled(A: torch.Tensor, B: torch.Tensor, tile_width: int) -> torch.Tensor:
    """Multiply two 2-D tensors with the loop-emulated tiled kernel.

    The output is covered by ``tile_width`` x ``tile_width`` blocks, each
    processed by one call of ``matmul_2d_tiled_kernel``. Returns a new tensor
    with A's dtype and device. Raises ``AssertionError`` when the inner
    dimensions differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    assert inner == inner_b, "Size must match!"

    result = torch.zeros((rows, cols), dtype=A.dtype, device=A.device)
    block = dim3(tile_width, tile_width)
    # Round up so that partially filled edge tiles are still processed.
    grid = dim3(math.ceil(cols / block.x), math.ceil(rows / block.y))

    # One tile_width x tile_width staging tile per input matrix, hence the factor 2.
    scratch_elems = tile_width * tile_width * 2
    kernel_args = (A.flatten(), B.flatten(), result.flatten(), rows, cols, inner, tile_width)
    matmul_2d_tiled_loop(matmul_2d_tiled_kernel, grid, block, scratch_elems, *kernel_args)

    return result

In [17]:
torch.allclose(matmul_2d_tiled(m1s, m2s, tile_width=16), m1s @ m2s, atol=1e-5)

True

# Python Tiled Kernel with Threads

In [6]:
import threading
from threading import Thread, Barrier
from concurrent.futures import ThreadPoolExecutor

In [1]:
def func(x):
    """Print ``x``, then ``-x``, then ``x * 10``, one value per line."""
    for value in (x, -x, x * 10):
        print(value)

In [5]:
num = 3
with ThreadPoolExecutor(max_workers=num) as ex:
    # list() drains the lazy map, so the cell waits for every call to finish.
    list(ex.map(func, range(1, num + 1)))

1
-1
10
2
-2
20
3
-3
30


In [6]:
def func_b(x, b):
    """Print ``x``, ``-x`` and ``x * 10``, calling ``b.wait()`` between prints.

    With a barrier shared by several callers, no caller prints its second
    (or third) value before every caller has printed its first (or second).
    """
    for stage, value in enumerate((x, -x, x * 10)):
        if stage:
            # Wait for the other parties before starting every stage but the first.
            b.wait()
        print(value)

In [8]:
num = 3
# The barrier has exactly `num` parties and the pool has `num` workers: every
# func_b call blocks in b.wait() until all `num` calls have reached the same
# point, so all of them must be running concurrently or the cell never finishes.
b = Barrier(num)
with ThreadPoolExecutor(num) as ex:
    list(ex.map(lambda i: func_b(i, b), range(1, num + 1)))

1
2
3
-3
-1
-2
30
20
10


In [5]:
import math
import typing as tp
from threading import Thread, Barrier


def matmul_2d_tiled_with_threads_loop(
    func: tp.Callable,
    num_blocks: dim3,
    threads_per_block: dim3,
    shared_memory_size: int,
    *args,
    **kwargs,
):
    """Emulate a kernel launch with one real Python thread per emulated thread.

    Blocks are processed one after another. For each block a fresh zero-filled
    scratch tensor of ``shared_memory_size`` elements (``torch.zeros``
    defaults) and a ``Barrier`` sized to the number of threads in the block
    are created, and ``func`` is run in ``threads_per_block.x *
    threads_per_block.y`` threads as ``func(block_idx, thread_idx,
    threads_per_block, shared_memory, syncb, *args, **kwargs)``.

    Raises:
        Exception: the first exception raised by ``func`` in any thread is
            re-raised here once all threads of that block have finished.
    """
    for i0 in range(num_blocks.x):
        for i1 in range(num_blocks.y):
            shared_memory = torch.zeros(shared_memory_size)
            syncb = Barrier(threads_per_block.y * threads_per_block.x)
            errors: list = []

            def run_thread(*call_args, _syncb=syncb, _errors=errors):
                # Without this wrapper a failing thread would leave its siblings
                # blocked in syncb.wait() forever and join() below would hang.
                try:
                    func(*call_args, **kwargs)
                except Exception as exc:
                    # Record first, then abort: the root cause ends up ahead of
                    # the BrokenBarrierError the waiting threads get afterwards.
                    _errors.append(exc)
                    _syncb.abort()

            threads = [
                Thread(
                    target=run_thread,
                    args=(dim3(i0, i1), dim3(p, o), threads_per_block, shared_memory, syncb, *args),
                )
                for o in range(threads_per_block.y)
                for p in range(threads_per_block.x)
            ]
            for thread in threads:
                thread.start()

            for thread in threads:
                thread.join()

            if errors:
                raise errors[0]


def matmul_2d_tiled_with_threads_kernel(
    block_idx: dim3,
    thread_idx: dim3,
    block_dim: dim3,
    shared_memory: torch.Tensor,
    syncb: Barrier,
    A: torch.Tensor,
    B: torch.Tensor,
    out: torch.Tensor,
    height: int,
    width: int,
    k: int,
    tile_width: int,
) -> None:
    """Compute one output element of ``A @ B`` as one thread of a tiled block.

    ``A``, ``B`` and ``out`` are indexed as flat row-major buffers of logical
    shape (height, k), (k, width) and (height, width). In every phase the
    thread stores one element of A and one of B into ``shared_memory`` (first
    ``tile_width ** 2`` elements for A, the rest for B, zero for out-of-range
    positions), waits on ``syncb``, accumulates its partial dot product from
    the staged tiles and waits again before the tiles are overwritten.

    Every thread runs all phases, including threads outside the output matrix,
    so each ``syncb.wait()`` is reached by every thread of the block. Only
    in-range threads write to ``out``.
    """
    tile_elems = tile_width * tile_width
    # Slices are views: writing to them fills the block's shared scratch buffer.
    a_tile = shared_memory[:tile_elems]
    b_tile = shared_memory[tile_elems:]

    ty, tx = thread_idx.y, thread_idx.x
    row = block_idx.y * block_dim.y + ty
    col = block_idx.x * block_dim.x + tx
    slot = ty * tile_width + tx
    row_in_range = row < height
    col_in_range = col < width

    acc = 0.0
    for phase in range(int(math.ceil(k / tile_width))):
        k_start = phase * tile_width

        # Stage this thread's element of each matrix tile (zero-padded at the edges).
        if row_in_range and k_start + tx < k:
            a_tile[slot] = A[row * k + k_start + tx]
        else:
            a_tile[slot] = 0.0
        if k_start + ty < k and col_in_range:
            b_tile[slot] = B[(k_start + ty) * width + col]
        else:
            b_tile[slot] = 0.0
        # Both tiles must be complete before any thread reads them.
        syncb.wait()

        # Partial dot product of this thread's tile row and tile column.
        for step in range(tile_width):
            acc += a_tile[ty * tile_width + step] * b_tile[step * tile_width + tx]
        # All reads must be done before the next phase overwrites the tiles.
        syncb.wait()

    if row_in_range and col_in_range:
        out[row * width + col] = acc


def matmul_2d_tiled_with_threads(A: torch.Tensor, B: torch.Tensor, tile_width: int) -> torch.Tensor:
    """Multiply two 2-D tensors with the thread-emulated tiled kernel.

    The output is covered by ``tile_width`` x ``tile_width`` blocks; each block
    runs ``matmul_2d_tiled_with_threads_kernel`` in ``tile_width ** 2``
    threads. Returns a new tensor with A's dtype and device. Raises
    ``AssertionError`` when the inner dimensions differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    assert inner == inner_b, "Size must match!"

    result = torch.zeros((rows, cols), dtype=A.dtype, device=A.device)
    block = dim3(tile_width, tile_width)
    # Round up so that partially filled edge tiles are still processed.
    grid = dim3(math.ceil(cols / block.x), math.ceil(rows / block.y))

    # One tile_width x tile_width staging tile per input matrix, hence the factor 2.
    scratch_elems = tile_width * tile_width * 2
    kernel_args = (A.flatten(), B.flatten(), result.flatten(), rows, cols, inner, tile_width)
    matmul_2d_tiled_with_threads_loop(
        matmul_2d_tiled_with_threads_kernel, grid, block, scratch_elems, *kernel_args
    )

    return result

In [7]:
torch.allclose(matmul_2d_tiled_with_threads(m1s, m2s, tile_width=16), m1s @ m2s, atol=1e-5)

True

# CUDA Kernel Dynamic

In [15]:
import os
from pathlib import Path

# CUDA source is read from a file next to the notebook; cpp_source is the C++
# declaration of the function that is later exposed to Python.
cuda_source = Path("matmul_tiled_dynamic.cu").read_text()
cpp_source = "torch::Tensor matmul_tiled(torch::Tensor A, torch::Tensor B);"
# Machine-specific fallback: only applied when CUDA_HOME is not already set, so
# an existing toolkit configuration is not overwritten. Adjust the path if needed.
# NOTE(review): presumably this must run before torch.utils.cpp_extension is
# imported for the value to be picked up -- confirm.
os.environ.setdefault("CUDA_HOME", "/public/apps/cuda/12.1")

In [16]:
from torch.utils.cpp_extension import load_inline

# Build the extension from the in-memory C++/CUDA sources and bind the functions
# listed in `functions`; the resulting module is used below as
# matmul_dynamic_module.matmul_tiled.
matmul_dynamic_module = load_inline(
    name="matmul_tiled_dynamic",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["matmul_tiled"],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    # build_directory='./cuda_build',
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [17]:
dir(matmul_dynamic_module)

['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'matmul_tiled']

In [11]:
matmul_dynamic_module.matmul_tiled(m1sc, m2sc)

tensor([[ 1.4065e+01,  6.8529e+00, -2.1297e+01,  9.0340e+00],
        [ 1.3069e+00, -1.9484e+01, -6.0172e+00, -5.2774e+00],
        [-5.4064e-03, -9.1612e+00,  9.6259e+00,  3.4818e+01],
        [ 2.0240e+01, -8.6772e+00,  3.6141e+01, -1.3200e+01]], device='cuda:0')

In [12]:
torch.allclose(matmul_dynamic_module.matmul_tiled(m1sc, m2sc).cpu(), m1s @ m2s, atol=1e-5)

True

In [18]:
# NOTE(review): the timing includes the .cpu() copy of the result to the host,
# so compare it with `(m1c @ m2c).cpu()`, not with the bare `m1c @ m2c`.
%timeit matmul_dynamic_module.matmul_tiled(m1c, m2c).cpu()

80.3 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# CUDA Kernel Static

In [1]:
import os
from pathlib import Path

# CUDA source is read from a file next to the notebook; cpp_source is the C++
# declaration of the function that is later exposed to Python.
cuda_source = Path("matmul_tiled_static.cu").read_text()
cpp_source = "torch::Tensor matmul_tiled(torch::Tensor A, torch::Tensor B);"
# Machine-specific fallback: only applied when CUDA_HOME is not already set, so
# an existing toolkit configuration is not overwritten. Adjust the path if needed.
os.environ.setdefault("CUDA_HOME", "/public/apps/cuda/12.1")

In [3]:
from torch.utils.cpp_extension import load_inline

# Build the extension from the in-memory C++/CUDA sources and bind the functions
# listed in `functions`; the resulting module is used below as
# matmul_static_module.matmul_tiled.
matmul_static_module = load_inline(
    name="matmul_tiled_static",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["matmul_tiled"],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    # build_directory='./cuda_build',
)

In [4]:
dir(matmul_static_module)

['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'matmul_tiled']

In [10]:
torch.allclose(matmul_static_module.matmul_tiled(m1sc, m2sc).cpu(), m1s @ m2s, atol=1e-5)

True

In [12]:
# NOTE(review): the timing includes the .cpu() copy of the result to the host,
# so compare it with `(m1c @ m2c).cpu()`, not with the bare `m1c @ m2c`.
%timeit matmul_static_module.matmul_tiled(m1c, m2c).cpu()

78.1 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
