In [1]:
import torch
import os
from torch.utils.cpp_extension import load_inline

In [2]:
# Ask the CUDA runtime to run kernel launches synchronously, so a failing
# kernel is reported at the call that launched it (debugging aid).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
!pip install wurlitzer ninja

Collecting wurlitzer
  Downloading wurlitzer-3.1.1-py3-none-any.whl.metadata (2.5 kB)
Downloading wurlitzer-3.1.1-py3-none-any.whl (8.6 kB)
Installing collected packages: wurlitzer
Successfully installed wurlitzer-3.1.1


In [4]:
matmul_cuda_src = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Ceiling division: the smallest q such that q * b >= a (for b > 0).
// Used to work out how many blocks of size b are needed to cover a items.
inline unsigned int cdiv(unsigned int a, unsigned int b) {
    const unsigned int rounded_up = a + (b - 1);
    return rounded_up / b;
}

// Naive matrix multiply: one thread computes one element of the output.
//
//   a        row-major matrix with n columns
//   b        row-major matrix with n rows and p columns
//   n        shared (inner) dimension
//   p        number of columns of b and of out
//   out      row-major output with p columns; out[i] is element (i / p, i % p)
//   elements total number of output elements; threads with i >= elements do nothing
__global__ void matmul_kernel(float* a, float* b, int n, int p, float* out, int elements) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads past the end of the output must neither read nor write:
    // writing out[i] here would be an out-of-bounds store.
    if (i >= elements) {
        return;
    }

    const int row = i / p;
    const int col = i % p;
    // Offsets are computed in size_t so row * n and j * p cannot overflow int
    // for large operands.
    const float* a_row = a + static_cast<size_t>(row) * n;
    const float* b_col = b + col;

    // Accumulate in float: an integer accumulator would truncate every
    // partial sum and produce wrong results for non-integer inputs.
    float acc = 0.0f;
    for (int j = 0; j < n; ++j) {
        acc += a_row[j] * b_col[static_cast<size_t>(j) * p];
    }
    out[i] = acc;
}

// Host entry point exposed to Python: returns a @ b for 2-D float32 CUDA tensors.
//
//   a  contiguous CUDA tensor of shape (m, n)
//   b  contiguous CUDA tensor of shape (n, p)
//
// Returns a new (m, p) tensor with the same options as a.
// Raises (via TORCH_CHECK) if an input is not a contiguous CUDA tensor, is not
// 2-D float32, if the inner dimensions disagree, or if a size does not fit
// the kernel's 32-bit int parameters.
torch::Tensor matmul(torch::Tensor a, torch::Tensor b) {
    CHECK_INPUT(a);
    CHECK_INPUT(b);
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2,
                "matmul expects 2-D tensors, got ", a.dim(), "-D and ", b.dim(), "-D");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 && b.scalar_type() == torch::kFloat32,
                "matmul only supports float32 tensors");
    TORCH_CHECK(a.device() == b.device(), "a and b must be on the same CUDA device");
    // Without this check the kernel would read past the end of b (or ignore part of it).
    TORCH_CHECK(a.size(1) == b.size(0),
                "inner dimensions must match, got ", a.size(1), " and ", b.size(0));

    const int64_t m = a.size(0);
    const int64_t n = a.size(1);
    const int64_t p = b.size(1);

    // The kernel takes n, p and the element count as 32-bit ints; reject sizes
    // that would be silently truncated when narrowed.
    constexpr int64_t kIntLimit = int64_t{1} << 31;
    TORCH_CHECK(m < kIntLimit && n < kIntLimit && p < kIntLimit && m * p < kIntLimit,
                "matmul: tensor sizes exceed the 32-bit range supported by the kernel");
    const int64_t elements = m * p;

    torch::Tensor out = torch::empty({m, p}, a.options());
    // A launch with zero blocks is an invalid configuration; nothing to compute anyway.
    if (elements == 0) {
        return out;
    }

    const int threads = 256;
    const unsigned int blocks = cdiv(static_cast<unsigned int>(elements),
                                     static_cast<unsigned int>(threads));
    matmul_kernel<<<blocks, threads>>>(
        a.data_ptr<float>(), b.data_ptr<float>(),
        static_cast<int>(n), static_cast<int>(p),
        out.data_ptr<float>(), static_cast<int>(elements));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}
'''

In [5]:
# C++ declaration of the extension's entry point (the definition lives in the CUDA source).
matmul_cpp_src = 'torch::Tensor matmul(torch::Tensor a, torch::Tensor b);'

In [6]:
# Build the inline extension from the C++ declaration and the CUDA source,
# exposing `matmul` as a Python-callable function on the returned module.
matmul_module = load_inline(
    name="inline_ext",
    cpp_sources=[matmul_cpp_src],
    cuda_sources=[matmul_cuda_src],
    functions=["matmul"],
)

In [7]:
# Benchmark operands, created on the CPU and then moved to the GPU:
# a is 50000 x 100 (standard normal), b is 100 x 1200 (uniform on [0, 1)).
a = torch.randn(50000, 100).contiguous().cuda()
b = torch.rand(100, 1200).contiguous().cuda()

## My Implementation

In [8]:
%%timeit 
# Times the custom CUDA matmul on a and b.
# NOTE: the timed statement also includes the .cpu() copy of the result back
# to host memory, so the figure is not the kernel time alone.
_ = matmul_module.matmul(a, b).cpu()

231 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## PyTorch's Implementation
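- Utilises a tuned library kernel (cuBLAS for CUDA tensors) that tiles the computation and reuses operands from on-chip memory (shared memory / L1 cache), for better memory efficiency than re-reading global memory for every multiply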

In [9]:
%%timeit
# Times PyTorch's built-in matrix multiply on the same a and b.
# NOTE: as with the custom kernel's timing, the .cpu() copy of the result back
# to host memory is inside the timed statement.
_ = (a @ b).cpu()

194 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
