<a href="https://colab.research.google.com/github/anshulsawant/llm-systems/blob/main/cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Getting started with CUDA

## Setup

In [143]:
import torch, os, math
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io
import matplotlib.pyplot as plt
from torch.utils.cpp_extension import load_inline

In [144]:
!pip install wurlitzer ninja



### Python Block Kernel

1. **Streaming Multiprocessors (SMs):** In NVIDIA GPUs, SMs are the fundamental units of execution. Each SM can execute multiple threads concurrently.
2. **Thread Blocks:** A thread block is a group of threads that can cooperate among themselves through shared memory and synchronization. All threads in a block are executed on the same SM. This means they can share resources such as shared memory and can synchronize their execution with each other.
3. **Shared Memory:** Shared memory is a small memory space on the GPU that is shared among the threads in a block. It is much faster than global memory (the main GPU memory), but it is also limited in size. Threads in the same block can use shared memory to share data with each other efficiently.

- The RTX 3090, based on the Ampere architecture, has 82 SMs.
- Each SM in GA10x GPUs contain 128 CUDA Cores, four third-generation Tensor Cores, a 256 KB Register File, and 128 KB of L1/Shared Memory
- In CUDA, all threads in a block have the potential to run concurrently. However, the actual concurrency depends on the number of CUDA cores per SM and the resources required by the threads.

### CUDA Setup

In [145]:
## This is slow but good for dev.
os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [146]:
%load_ext wurlitzer

The wurlitzer extension is already loaded. To reload it, use:
  %reload_ext wurlitzer


In [147]:
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=["-O2"] if opt else [], verbose=verbose, name="inline_ext")

In [148]:
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

<img src="attachment:4590626e-3f24-4381-a14b-50162f737579.png" width="500">

## Matmul

In [149]:
from torch import tensor

### Python matmul

In [150]:
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
m1.shape,m2.shape

(torch.Size([5, 784]), torch.Size([784, 10]))

In [151]:
import numpy as np
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

### 2d Python kernel

In [152]:
from types import SimpleNamespace as ns

In [153]:
def blk_kernel2d(f, blocks, threads, *args):
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x): f(ns(x=i1,y=i0), ns(x=j1,y=j0), threads, *args)

In [154]:
def matmul_bk(blockidx, threadidx, blockdim, m, n, out, h, w, k):
    r = blockidx.y*blockdim.y + threadidx.y
    c = blockidx.x*blockdim.x + threadidx.x

    if (r>=h or c>=w): return
    o = 0.
    for i in range(k): o += m[r*k+i] * n[i*w+c]
    out[r*w+c] = o

In [155]:
def matmul_2d(m, n):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = ns(x=16,y=16)
    blocks = ns(x=math.ceil(w/tpb.x), y=math.ceil(h/tpb.y))
    blk_kernel2d(matmul_bk, blocks, tpb,
                 m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [156]:
res = matmul_2d(m1, m2)
torch.isclose(t1, res).all()

tensor(False)

### Broadcasting

In [157]:
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(dim=0)
    return c

In [158]:
torch.isclose(t1,matmul(m1, m2)).all()

tensor(False)

In [159]:
%time _=matmul(m1, m2)

CPU times: user 1.82 ms, sys: 0 ns, total: 1.82 ms
Wall time: 1.57 ms


In [160]:
m1 = x_train
tr = matmul(m1, m2)
tr.shape

torch.Size([50000, 10])

In [161]:
%time _=matmul(m1, m2)

CPU times: user 1.26 s, sys: 5.21 ms, total: 1.27 s
Wall time: 1.27 s


In [162]:
ar,ac = m1.shape
br,bc = m2.shape
ar*bc*ac

392000000

In [176]:
def index_to_position(index, strides, num_dims):
    '''
     Converts a multidimensional tensor index into a single-dimensional position in storage
     based on strides.
     Args:
        index: index tuple of ints
        strides: tensor strides
        num_dims: number of dimensions in the tensor, e.g. shape/strides of [2, 3, 4] has 3 dimensions

     Returns:
        int - position in storage
    '''
    position = 0;
    for i in range(num_dims):
        position += index[i] * strides[i];
    return position;

def to_index(ordinal, shape, out_index, num_dims):
    '''
     Convert an ordinal to an index in the shape. Should ensure that enumerating position 0 ... size of
     a tensor produces every index exactly once. It may not be the inverse of index_to_position.
     Args:
        ordinal: ordinal position to convert
        shape: tensor shape
        out_index: return index corresponding to position
        num_dims: number of dimensions in the tensor

     Returns:
        None (Fills in out_index)
    '''
    cur_ord = ordinal;
    for i in reversed(range(num_dims)):
        sh = shape[i];
        out_index[i] = cur_ord % sh;
        cur_ord /= sh;

def broadcast_index(big_index, big_shape, shape, out_index, num_dims_big, num_dims):
    '''
     Convert a big_index into big_shape to a smaller out_index into shape following broadcasting rules.
     In this case it may be larger or with more dimensions than the shape given.
     Additional dimensions may need to be mapped to 0 or removed.

     Args:
        big_index: multidimensional index of bigger tensor
        big_shape: tensor shape of bigger tensor
        nums_big_dims: number of dimensions in bigger tensor
        out_index: multidimensional index of smaller tensor
        shape: tensor shape of smaller tensor
        num_dims: number of dimensions in smaller tensor

     Returns:
        None (Fills in out_index)
    '''
    for i in range(num_dims):
        if shape[i] > 1:
            out_index[i] = big_index[i + (num_dims_big - num_dims)]
        else:
            out_index[i] = 0


### CUDA matmul

In [163]:
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [164]:
cpp_src = "torch::Tensor matmul(torch::Tensor m, torch::Tensor n);"

In [165]:
module = load_cuda(cuda_src, cpp_src, ['matmul'])

In [166]:
m1c,m2c = m1.contiguous().cuda(), m2.contiguous().cuda()

In [167]:
torch.isclose(tr,module.matmul(m1c, m2c).cpu(), atol=1e-5).all()

tensor(False)

In [168]:
%%time
res=module.matmul(m1c, m2c).cpu()
res.shape

CPU times: user 6.29 ms, sys: 26 µs, total: 6.32 ms
Wall time: 5.74 ms


torch.Size([50000, 10])

### Pytorch

In [169]:
torch.isclose(tr,(m1c@m2c).cpu(), atol=1e-5).all()

tensor(False)

In [170]:
%timeit -n 10 _=(m1c@m2c).cpu()

2.09 ms ± 69.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
