In [48]:
import torch
from torch import tensor
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io
from torch.utils.cpp_extension import load_inline

T = torch.Tensor

In [49]:
# Reproducible random operands for the SGEMM experiments

def create_data(*, seed: int = 1234, n_dim: int = 1024):
    """Build seeded random inputs for ``alpha * A @ B + beta * C``.

    Seeds PyTorch's global RNG with ``seed`` and then draws three square
    ``(n_dim, n_dim)`` matrices with ``torch.randn``.

    Args:
        seed: Value handed to ``torch.manual_seed`` before sampling.
        n_dim: Number of rows and columns of every matrix.

    Returns:
        Tuple ``(A, B, C, alpha, beta)`` where ``alpha`` is ``3`` and
        ``beta`` is ``1.5``.
    """
    torch.manual_seed(seed)

    shape = (n_dim, n_dim)
    # The generator is consumed left to right, so the draw order stays A, B, C.
    A, B, C = (torch.randn(*shape) for _ in range(3))

    alpha, beta = 3, 1.5
    return A, B, C, alpha, beta

In [50]:
import os
# Ask the CUDA runtime for synchronous kernel launches (CUDA_LAUNCH_BLOCKING=1), so a
# failing kernel is reported at the call that launched it.
# NOTE(review): this is set after `import torch`; presumably it still takes effect
# because no CUDA work has run yet at this point — confirm.
os.environ['CUDA_LAUNCH_BLOCKING']='1'


# wurlitzer
# Capture C-level stdout/stderr pipes in Python
# More here: https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/

# ninja for build (the extension build log below reports emitting a ninja build file)
%pip install -q wurlitzer ninja

# Enable the wurlitzer IPython extension for this kernel session.
%load_ext wurlitzer

The wurlitzer extension is already loaded. To reload it, use:
  %reload_ext wurlitzer


In [51]:
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False, name="inline_ext"):
    """Compile inline CUDA/C++ sources with ``load_inline`` and return its result.

    Args:
        cuda_src: CUDA source as a single string or a list of strings.
        cpp_src: C++ declarations as a single string or a list of strings.
        funcs: Function names forwarded as ``functions`` to ``load_inline``.
        opt: When true, ``"-O2"`` is passed through ``extra_cuda_cflags``.
        verbose: Forwarded to ``load_inline``.
        name: Extension name forwarded to ``load_inline``. The default keeps the
            previous hard-coded ``"inline_ext"``; pass a distinct name to keep
            unrelated source sets from sharing one extension name (the build log
            in this notebook shows the shared name being version-bumped and
            rebuilt when its inputs change).

    Returns:
        Whatever ``load_inline`` returns for the compiled extension.
    """
    # load_inline is given lists of sources; accept a bare string for convenience.
    if isinstance(cuda_src, str):
        cuda_src = [cuda_src]
    if isinstance(cpp_src, str):
        cpp_src = [cpp_src]

    return load_inline(
        cuda_sources=cuda_src,
        cpp_sources=cpp_src,
        functions=funcs,
        extra_cuda_cflags=["-O2"] if opt else [],
        verbose=verbose,
        name=name,
    )

In [52]:
# Utility functions defined by Jeremy
# Shared prelude that is prepended to the kernel sources before compilation.

cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

// Input validation helpers: each raises through TORCH_CHECK when the condition fails.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor");
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous");
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);

// Element (i, j) of a row-major matrix A that has n columns.
#define get_item(A, n, i, j) ((A)[ (i) * (n) + (j) ])

// cdiv is kept as a macro because the kernels below also use it inside device code.
// The previous inline function of the same name was shadowed by this macro (dead code)
// and has been removed; both arguments are now parenthesised so expression arguments
// such as `n << 1` expand correctly. Note that `b` is evaluated twice.
#define cdiv(a, b) (((a) + (b) - 1) / (b))  // Implementing ceiling division
'''

In [53]:
cuda_src_naive = r'''

// Naive SGEMM: one thread per output element.
// Computes C[row][col] = alpha * dot(A[row, :], B[:, col]) + beta * C[row][col]
// for row-major dim x dim matrices; matrix_c is updated in place.
// threadIdx.x selects the column, threadIdx.y the row.
__global__ void sgemm_kernel(float* matrix_a, float* matrix_b, float* matrix_c, float alpha, float beta, int dim) {
    int row = blockIdx.y*blockDim.y + threadIdx.y;
    int col = blockIdx.x*blockDim.x + threadIdx.x;

    // The grid is rounded up to whole blocks, so some threads fall outside the matrix.
    if (row >= dim || col >= dim) return;
    float tmp = 0;
    for (int i = 0; i<dim; ++i) {
        tmp += (matrix_a[row*dim + i] * matrix_b[dim*i + col]);
    }
    matrix_c[row*dim + col] = alpha * tmp + beta * matrix_c[row*dim + col];

}


// Host wrapper: validates the inputs, launches sgemm_kernel with 16x16 thread blocks and
// returns matrix_c (the same tensor that was passed in, now holding the result).
torch::Tensor sgemm(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta) {
    CHECK_INPUT(matrix_a); CHECK_INPUT(matrix_b); CHECK_INPUT(matrix_c);
    TORCH_CHECK(matrix_a.dim() == 2 && matrix_b.dim() == 2 && matrix_c.dim() == 2,
                "sgemm expects 2-D matrices");
    TORCH_CHECK(matrix_a.scalar_type() == torch::kFloat32 &&
                matrix_b.scalar_type() == torch::kFloat32 &&
                matrix_c.scalar_type() == torch::kFloat32,
                "sgemm expects float32 matrices");
    int dim = matrix_a.size(0);
    // The kernel indexes all three matrices with the same dim; anything else would
    // read or write out of bounds.
    TORCH_CHECK(matrix_a.size(1) == dim &&
                matrix_b.size(0) == dim && matrix_b.size(1) == dim &&
                matrix_c.size(0) == dim && matrix_c.size(1) == dim,
                "sgemm expects square matrices of identical size");
    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if (dim == 0) return matrix_c;

    dim3 tpb(16,16);
    dim3 blocks(cdiv(dim, tpb.x), cdiv(dim, tpb.y));

    sgemm_kernel<<<blocks, tpb>>>(
        matrix_a.data_ptr<float>(),
        matrix_b.data_ptr<float>(),
        matrix_c.data_ptr<float>(),
        alpha,
        beta,
        dim
        );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return matrix_c;
}
'''

cpp_src_naive = r'''
    torch::Tensor sgemm(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta);
'''


In [54]:
cuda_src_naive_reverse = r'''

// Same computation as the naive kernel, but with the thread-to-element mapping swapped:
// threadIdx.x selects the row and threadIdx.y the column, so threads that are adjacent
// in x touch matrix_a / matrix_c addresses that are `dim` floats apart.
// matrix_c is updated in place.
__global__ void sgemm_kernel_reverse(float* matrix_a, float* matrix_b, float* matrix_c, float alpha, float beta, int dim) {
    int col = blockIdx.y*blockDim.y + threadIdx.y;
    int row = blockIdx.x*blockDim.x + threadIdx.x;

    // The grid is rounded up to whole blocks, so some threads fall outside the matrix.
    if (row >= dim || col >= dim) return;
    float tmp = 0;
    for (int i = 0; i<dim; ++i) {
        tmp += (matrix_a[row*dim + i] * matrix_b[dim*i + col]);
    }
    matrix_c[row*dim + col] = alpha * tmp + beta * matrix_c[row*dim + col];
}


// Host wrapper: validates the inputs, launches sgemm_kernel_reverse with 16x16 thread
// blocks and returns matrix_c (the same tensor that was passed in, now holding the result).
torch::Tensor sgemm_reverse(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta) {
    CHECK_INPUT(matrix_a); CHECK_INPUT(matrix_b); CHECK_INPUT(matrix_c);
    TORCH_CHECK(matrix_a.dim() == 2 && matrix_b.dim() == 2 && matrix_c.dim() == 2,
                "sgemm_reverse expects 2-D matrices");
    TORCH_CHECK(matrix_a.scalar_type() == torch::kFloat32 &&
                matrix_b.scalar_type() == torch::kFloat32 &&
                matrix_c.scalar_type() == torch::kFloat32,
                "sgemm_reverse expects float32 matrices");
    int dim = matrix_a.size(0);
    // The kernel indexes all three matrices with the same dim; anything else would
    // read or write out of bounds.
    TORCH_CHECK(matrix_a.size(1) == dim &&
                matrix_b.size(0) == dim && matrix_b.size(1) == dim &&
                matrix_c.size(0) == dim && matrix_c.size(1) == dim,
                "sgemm_reverse expects square matrices of identical size");
    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if (dim == 0) return matrix_c;

    dim3 tpb(16,16);
    dim3 blocks(cdiv(dim, tpb.x), cdiv(dim, tpb.y));

    sgemm_kernel_reverse<<<blocks, tpb>>>(
        matrix_a.data_ptr<float>(),
        matrix_b.data_ptr<float>(),
        matrix_c.data_ptr<float>(),
        alpha,
        beta,
        dim
        );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return matrix_c;
}
'''

cpp_src_naive_reverse = r'''
    torch::Tensor sgemm_reverse(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta);
'''


In [55]:
cuda_src_with_shared_memory = r'''

// Tile edge length; the kernel below must be launched with blockdim x blockdim threads.
const int blockdim = 32;

// Tiled SGEMM: each block computes one blockdim x blockdim tile of
// C = alpha * A * B + beta * C, staging tiles of A and B in shared memory.
// matrix_c is updated in place.
__global__ void sgemm_kernel_with_shared_memory(float* matrix_a, float* matrix_b, float* matrix_c, float alpha, float beta, int dim) {

    __shared__ float As[blockdim][blockdim];
    __shared__ float Bs[blockdim][blockdim];

    int tx = threadIdx.x;
    int ty = threadIdx.y;

    int row = blockIdx.y * blockdim + ty;
    int col = blockIdx.x * blockdim + tx;

    float tmp = 0.0f;

    int n_tiles = cdiv(dim, blockdim);

    for (int tile = 0; tile < n_tiles; ++tile) {
        int a_col = tile * blockdim + tx;
        int b_row = tile * blockdim + ty;

        // Loading data on shared memory; out-of-range elements are zero-filled so they
        // add nothing to the dot product. Every thread takes part (no early return)
        // because all threads of the block must reach the barriers below.
        As[ty][tx] = (row < dim && a_col < dim) ? get_item(matrix_a, dim, row, a_col) : 0.0f;
        Bs[ty][tx] = (col < dim && b_row < dim) ? get_item(matrix_b, dim, b_row, col) : 0.0f;
        __syncthreads();

        // Perform computation using shared memory
        for (int i = 0; i < blockdim; ++i) {
            tmp += As[ty][i] * Bs[i][tx];
        }
        // Do not overwrite the tiles while other threads may still be reading them.
        __syncthreads();
    }

    // Store result back in matrix_c
    if (row < dim && col < dim) {
        matrix_c[row * dim + col] = alpha * tmp + beta * matrix_c[row * dim + col];
    }
}

// Host wrapper: validates the inputs, launches the tiled kernel with blockdim x blockdim
// thread blocks and returns matrix_c (the same tensor that was passed in, now holding
// the result).
torch::Tensor sgemm_with_shared_memory(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta) {
    CHECK_INPUT(matrix_a); CHECK_INPUT(matrix_b); CHECK_INPUT(matrix_c);
    TORCH_CHECK(matrix_a.dim() == 2 && matrix_b.dim() == 2 && matrix_c.dim() == 2,
                "sgemm_with_shared_memory expects 2-D matrices");
    TORCH_CHECK(matrix_a.scalar_type() == torch::kFloat32 &&
                matrix_b.scalar_type() == torch::kFloat32 &&
                matrix_c.scalar_type() == torch::kFloat32,
                "sgemm_with_shared_memory expects float32 matrices");

    int dim = matrix_a.size(0);
    // The kernel indexes all three matrices with the same dim; anything else would
    // read or write out of bounds.
    TORCH_CHECK(matrix_a.size(1) == dim &&
                matrix_b.size(0) == dim && matrix_b.size(1) == dim &&
                matrix_c.size(0) == dim && matrix_c.size(1) == dim,
                "sgemm_with_shared_memory expects square matrices of identical size");
    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if (dim == 0) return matrix_c;

    dim3 tpb(blockdim, blockdim);
    dim3 blocks(cdiv(dim, blockdim), cdiv(dim, blockdim));

    sgemm_kernel_with_shared_memory<<<blocks, tpb>>>(
        matrix_a.data_ptr<float>(),
        matrix_b.data_ptr<float>(),
        matrix_c.data_ptr<float>(),
        alpha,
        beta,
        dim
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return matrix_c;
}
'''

cpp_src_with_shared_memory = r'''
    torch::Tensor sgemm_with_shared_memory(torch::Tensor matrix_a, torch::Tensor matrix_b, torch::Tensor matrix_c, float alpha, float beta);
'''


In [None]:
# Stitch the shared prelude and the three kernel sources into one CUDA source, do the
# same for the C++ declarations, then compile them and expose the three host wrappers.
cuda_src = "".join([
    cuda_begin,
    cuda_src_naive,
    cuda_src_naive_reverse,
    cuda_src_with_shared_memory,
])
cpp_src = "".join([
    cpp_src_naive,
    cpp_src_naive_reverse,
    cpp_src_with_shared_memory,
])
module = load_cuda(
    cuda_src,
    cpp_src,
    funcs=['sgemm', 'sgemm_reverse', 'sgemm_with_shared_memory'],
    verbose=True,
)

Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
The input conditions for extension module inline_ext have changed. Bumping to version 4 and re-building as inline_ext_v4...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/inline_ext/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module inline_ext_v4...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


In [None]:
# Fresh operands for the shared-memory kernel benchmark.
A, B, C, alpha, beta = create_data()
a_gpu = A.contiguous().cuda()
b_gpu = B.contiguous().cuda()
c_gpu = C.contiguous().cuda()
# CPU reference for alpha * A @ B + beta * C.
pt_out = alpha * torch.mm(A, B) + beta * C

# Validate the kernel against the reference before timing it. The wrapper writes its
# result into the tensor passed as matrix_c, so the check runs on a clone and leaves
# c_gpu untouched for the benchmark.
kernel_out = module.sgemm_with_shared_memory(a_gpu, b_gpu, c_gpu.clone(), alpha, beta)
assert torch.allclose(kernel_out.cpu(), pt_out, rtol=1e-3, atol=1e-2), (
    "sgemm_with_shared_memory disagrees with the PyTorch reference"
)

In [None]:
# Time the shared-memory kernel (-n 100). The wrapper returns the matrix_c tensor it was
# given after the kernel has updated it, so every call here overwrites c_gpu in place.
%timeit -n 100 out = module.sgemm_with_shared_memory(a_gpu, b_gpu, c_gpu, alpha, beta)

In [None]:
# Fresh operands for the naive kernel benchmark.
A, B, C, alpha, beta = create_data()
a_gpu = A.contiguous().cuda()
b_gpu = B.contiguous().cuda()
c_gpu = C.contiguous().cuda()
# CPU reference for alpha * A @ B + beta * C.
pt_out = alpha * torch.mm(A, B) + beta * C

# Validate the kernel against the reference before timing it. The wrapper writes its
# result into the tensor passed as matrix_c, so the check runs on a clone and leaves
# c_gpu untouched for the benchmark.
kernel_out = module.sgemm(a_gpu, b_gpu, c_gpu.clone(), alpha, beta)
assert torch.allclose(kernel_out.cpu(), pt_out, rtol=1e-3, atol=1e-2), (
    "sgemm disagrees with the PyTorch reference"
)

In [None]:
# Time the naive kernel (-n 100). The wrapper returns the matrix_c tensor it was given
# after the kernel has updated it, so every call here overwrites c_gpu in place.
%timeit -n 100 out = module.sgemm(a_gpu, b_gpu, c_gpu, alpha, beta)

In [None]:
# Fresh operands for the reversed-mapping naive kernel benchmark.
A, B, C, alpha, beta = create_data()
a_gpu = A.contiguous().cuda()
b_gpu = B.contiguous().cuda()
c_gpu = C.contiguous().cuda()
# CPU reference for alpha * A @ B + beta * C.
pt_out = alpha * torch.mm(A, B) + beta * C

# Validate the kernel against the reference before timing it. The wrapper writes its
# result into the tensor passed as matrix_c, so the check runs on a clone and leaves
# c_gpu untouched for the benchmark.
kernel_out = module.sgemm_reverse(a_gpu, b_gpu, c_gpu.clone(), alpha, beta)
assert torch.allclose(kernel_out.cpu(), pt_out, rtol=1e-3, atol=1e-2), (
    "sgemm_reverse disagrees with the PyTorch reference"
)

In [None]:
# Time the reversed-mapping naive kernel (-n 100). The wrapper returns the matrix_c tensor
# it was given after the kernel has updated it, so every call here overwrites c_gpu in place.
%timeit -n 100 out = module.sgemm_reverse(a_gpu, b_gpu, c_gpu, alpha, beta)