In [1]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple
import torch

In [2]:
# Python stand-in for CUDA's dim3: `x` is required, `y` and `z` default to 1.
# https://docs.python.org/3/library/collections.html#collections.namedtuple
dim3 = namedtuple('dim3', 'x y z', defaults=(1, 1))

In [3]:
dim3(x=3, y=2)

dim3(x=3, y=2, z=1)

In [4]:
from torch.utils.cpp_extension import load_inline

# C++/CUDA prelude prepended to the kernel sources in this notebook: headers, input-check
# macros, a CUDA error-check helper, and a host/device ceiling-division helper.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Multi-statement macros are wrapped in do { ... } while (0) so each expands to a single
// statement (safe under an unbraced if/else) and still takes a trailing semicolon.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
#define CUDA_ERR(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)
// Print a failed CUDA call with its source location; exits the process when `abort` is set.
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None, **kwargs):
    """Simple wrapper for torch.utils.cpp_extension.load_inline.

    Compiles `cuda_src` (with `cpp_src` holding the C++ declarations of `funcs`) and returns
    the loaded extension module exposing `funcs`.
    `opt` selects -O3 (True) or -O0 (False) for nvcc, ptxas and the host compiler.
    `name` defaults to the first entry of `funcs`.
    Any extra keyword arguments are forwarded to `load_inline` unchanged.
    """
    if name is None: name = funcs[0]
    flags = "-O3 -Xptxas -O3 -Xcompiler -O3" if opt else "-O0 -Xptxas -O0 -Xcompiler -O0"
    # `extra_cuda_cflags` takes a list of flags: pass each flag as its own element
    # instead of a single space-joined string.
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=flags.split(), verbose=verbose, name=name, **kwargs)

def cdiv(a,b):
    "Ceiling division: the smallest integer >= `a / b` (for integer `a` and positive integer `b`)."
    padded = a + b - 1  # bump the numerator so floor division rounds up instead of down
    return padded // b

# Seed PyTorch's RNG so the `torch.rand` test matrices created in the next cell are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7fdc58298a70>

In [5]:
# Full-size inputs for the CUDA kernels: m1 is (5120, 256), m2 is (256, 5120), uniform in [0, 1).
m1 = torch.rand(5120, 256)
m2 = torch.rand(256,5120)

# Small slices for the pure-Python kernel simulations: m1s is (4, 256), m2s is (256, 4).
# m2s is a column slice of m2, so it is a non-contiguous view.
m1s = m1[:4]
m2s = m2[:,:4]

In [30]:
m1s.flatten()[:20]

tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936, 0.9408,
        0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739,
        0.2666, 0.6274])

In [6]:
def blk_kernel2d(f, blocks, threads, *args):
    """Sequentially emulate a 2-D CUDA launch of kernel `f`.

    Visits every (block, thread) pair in row-major order (block y, block x, thread y, thread x)
    and calls `f(blockIdx, threadIdx, blockDim, *args)` with `dim3` indices and `threads`
    as the block dimension.
    """
    grid = ((bx, by) for by in range(blocks.y) for bx in range(blocks.x))
    for bx, by in grid:
        for ty in range(threads.y):
            for tx in range(threads.x):
                f(dim3(bx, by), dim3(tx, ty), threads, *args)

In [7]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    """Simulated CUDA thread: compute a single element of `out = m @ n`.

    `m`, `n` and `out` are flat row-major buffers of (h, k), (k, w) and (h, w) matrices.
    The thread's global (row, col) comes from its block and thread indices; threads that
    land outside the h x w output do nothing.
    """
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x
    if row < h and col < w:
        acc = 0.
        for i in range(k):
            acc += m[row*k + i] * n[i*w + col]
        out[row*w + col] = acc

In [8]:
def matmul_2d(m, n):
    """Multiply `m` (h, k) by `n` (k, w) by simulating a 2-D CUDA launch of `matmul_bk`.

    Uses 16x16 blocks and enough of them to cover the (h, w) output. The kernel writes into a
    flattened view of the result, which is returned as a new (h, w) tensor with `m`'s dtype.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(16, 16)
    blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    flat_args = (m.flatten(), n.flatten(), output.flatten())  # output.flatten() is a view of output
    blk_kernel2d(matmul_bk, blocks, tpb, *flat_args, h, w, k)
    return output

In [9]:
torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()

tensor(True)

# CUDA

In [10]:
# Simple CUDA matmul: one thread per output element, launched on a 2-D grid of 16x16 blocks.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i< k; ++i)
        o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m);
    CHECK_INPUT(n);

    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);

    TORCH_CHECK(k==n.size(0), "Size mismatch");

    auto output = torch::zeros({h, w}, m.options());
    // An empty output would produce a zero-sized grid, which CUDA rejects as an invalid
    // launch configuration. There is nothing to compute, so return the empty tensor.
    if (h == 0 || w == 0) return output;

    dim3 tpb(16, 16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));

    matmul_k<<<blocks, tpb >>>( m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k );

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

fname = 'matmul'

def get_sig(fname, src):
    """Return the C++ declaration (signature followed by ';') of function `fname` in `src`.

    Looks for a line of the form `<return type> fname(<params>)`, optionally ending in `{`,
    and turns the first match into a prototype usable as `cpp_src` for `load_cuda`.
    Returns None when no such line exists.
    """
    # Escape `fname` so it is matched literally; unescaped regex metacharacters in the
    # name (e.g. '.') would otherwise match unrelated functions.
    res = re.findall(rf'^(.+\s+{re.escape(fname)}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
    return res[0]+';' if res else None

In [11]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [12]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [13]:
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [14]:
module.matmul(m1c,m2c).shape

torch.Size([5120, 5120])

In [15]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [19]:
%%timeit -n 10
module.matmul(m1c,m2c)
torch.cuda.synchronize()

5.97 ms ± 54.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Shared Memory

In [26]:
m1s

tensor([[0.8823, 0.9150, 0.3829,  ..., 0.2576, 0.3470, 0.0240],
        [0.7797, 0.1519, 0.7513,  ..., 0.4078, 0.5411, 0.0410],
        [0.6556, 0.1186, 0.1836,  ..., 0.7819, 0.6328, 0.0317],
        [0.1782, 0.9942, 0.6911,  ..., 0.3092, 0.0702, 0.1836]])

In [20]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    """Sequentially emulate a 2-D CUDA launch where `f` handles a whole block at a time.

    Blocks are visited in row-major order. Each block gets its own freshly zeroed `shared`
    tensor of `sh_sz` elements, and `f` is called as
    `f(blockIdx, blockDim, shared, *args, **kwargs)` with `threads` as the block dimension.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            f(dim3(bx, by), threads, torch.zeros(sh_sz), *args, **kwargs)

In [21]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Sequential simulation of one block of a tiled matmul kernel.

    `shared` is split into two tw*tw tiles: `ms` (a tile of `m`) and `ns` (a tile of `n`).
    `m`, `n` and `out` are flat row-major buffers of (h, k), (k, w) and (h, w) matrices, and
    `out` is accumulated into, so it should start zeroed.
    For each phase, the first loop loads both tiles for every thread of the block, with
    out-of-range elements zero-padded. The second loop then adds each in-range thread's partial
    dot product into `out`. Finishing the load loop before the compute loop plays the role of
    `__syncthreads()`.
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]
    for ph in range( cdiv(k, tw) ):
        idx = ph * tw
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r, c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                ms[tr*tw+tc] = m[ tc+idx + r*k] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(tr+idx)*w +c] if c<w and idx+tr<k else 0.

        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # Check the 2-D coordinate, not the linear index. With `r*w+c < len(out)`,
                # threads with c >= w wrap into the next row and add `ms*0` there. That is
                # NaN whenever the `m` tile holds inf/NaN.
                if r >= h or c >= w: continue
                for i in range(tw):
                    out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [22]:
def matmul_2d(m, n, tw=16):
    """Tiled matmul of `m` (h, k) by `n` (k, w) via the simulated kernel `matmul_tiled_bk`.

    Uses `tw` x `tw` blocks covering the (h, w) output. Each block gets 2*tw*tw elements of
    simulated shared memory, one tile per operand. Returns a new (h, w) tensor with `m`'s dtype.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(tw, tw)
    grid = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    shared_elems = 2 * tw * tw
    flat = [t.flatten() for t in (m, n, output)]  # output.flatten() is a view of output
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, shared_elems, *flat, h, w, k, tw=tw)
    return output

In [23]:
m1s.shape, m2.shape

(torch.Size([4, 256]), torch.Size([256, 5120]))

In [24]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

# Threading

In [32]:
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [33]:
def g(x, sb):
    """Print `x`, then `-x`, then `x*10`, calling `sb.wait()` after each of the first two.

    When `num` threads share a Barrier of `num` parties, no thread starts printing a stage
    until every thread has finished printing the previous one.
    """
    for stage in (x, -x):
        print(stage)
        sb.wait()
    print(x*10)

num = 3
sb = Barrier(num)
with ThreadPoolExecutor(num) as ex:
    list(ex.map(g, range(1, num+1), [sb]*num))

1
2
3
-3
-2
-1
20
10
30


In [35]:
def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    """Simulate a 2-D CUDA grid launch with shared memory and `__syncthreads()`.

    Blocks run one after another. For each block, a fresh zeroed `shared` tensor of `sh_sz`
    elements and a `Barrier` sized to the block are created. `f` then runs in one Python
    thread per CUDA thread as `f(blockIdx, threadIdx, blockDim, shared, syncb, *args, **kwargs)`.

    If any thread raises, the barrier is aborted so sibling threads blocked in `syncb.wait()`
    are released, instead of `join()` hanging forever. The first exception that is not the
    resulting `BrokenBarrierError` is then re-raised in the caller.
    """
    def _run(blockIdx, threadIdx, shar, syncb, errors):
        # Run one simulated thread, recording its failure and unblocking its siblings.
        try:
            f(blockIdx, threadIdx, tpb, shar, syncb, *args, **kwargs)
        except Exception as e:
            errors.append(e)
            syncb.abort()

    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            shar = torch.zeros(sh_sz)
            syncb = Barrier(tpb.y*tpb.x)
            errors = []
            threads = [Thread(target=_run, args=(dim3(i1,i0), dim3(p,o), shar, syncb, errors))
                       for o in range(tpb.y) for p in range(tpb.x)]
            for tr in threads: tr.start()
            for tr in threads: tr.join()
            if errors:
                root = next((e for e in errors if not isinstance(e, threading.BrokenBarrierError)), errors[0])
                raise root

def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    """Body of one simulated CUDA thread of the tiled matmul; all threads of a block run it concurrently.

    `shared` holds two tw*tw tiles back to back (the current tile of `m`, then of `n`) and is
    shared by the block's threads. `syncb` is a Barrier that stands in for `__syncthreads()`.
    `m`, `n` and `out` are flat row-major buffers of (h, k), (k, w) and (h, w) matrices.
    In each phase, every thread loads one element of each tile (zero-padding out-of-range
    elements) and waits for the tiles to be complete. It then accumulates its partial dot
    product and waits again, so no thread overwrites a tile another thread is still reading.
    """
    tx, ty = threadIdx.x, threadIdx.y
    row = blockIdx.y*blockDim.y + ty
    col = blockIdx.x*blockDim.x + tx

    tile_m, tile_n = shared[:tw*tw], shared[tw*tw:]
    slot = ty*tw + tx  # this thread's element within each tile

    acc = 0.
    for ph in range(cdiv(k, tw)):
        base = ph*tw  # first k-index covered by this phase
        tile_m[slot] = m[row*k + base + tx] if (row < h and base + tx < k) else 0.
        tile_n[slot] = n[(base + ty)*w + col] if (col < w and base + ty < k) else 0.
        syncb.wait()  # both tiles fully loaded
        for i in range(tw):
            acc += tile_m[ty*tw + i] * tile_n[i*tw + tx]
        syncb.wait()  # everyone done reading before the next phase overwrites the tiles

    if row < h and col < w:
        out[row*w + col] = acc

def matmul_2d(m, n, tw=16):
    """Tiled matmul `m @ n` launched through `blk_kernel2d_shar` with kernel `matmul_tiled_bk`.

    `m` is (h, k) and `n` is (k, w). The grid of `tw` x `tw` blocks covers the (h, w) output,
    and each block gets 2*tw*tw elements of simulated shared memory (one tile of each operand).
    Returns a new (h, w) tensor with `m`'s dtype.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(tw, tw)
    n_blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    blk_kernel2d_shar(matmul_tiled_bk, n_blocks, tpb, 2*tw*tw,
                      m.flatten(), n.flatten(), output.flatten(),  # output.flatten() views output
                      h, w, k, tw=tw)
    return output

In [39]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

In [40]:
m1s@m2s

tensor([[56.9767, 60.0248, 60.5974, 63.2288],
        [58.5341, 63.8303, 59.8225, 64.5254],
        [60.0873, 66.2759, 63.8680, 64.7840],
        [61.4536, 61.2220, 61.5788, 66.0657]])

In [46]:
matmul_2d(m1s, m2s, tw=3)

tensor([[56.9767, 60.0248, 60.5974, 63.2288],
        [58.5341, 63.8303, 59.8225, 64.5254],
        [60.0873, 66.2759, 63.8680, 64.7840],
        [61.4536, 61.2220, 61.5788, 66.0657]])

# CUDA dynamic shared memory

In [80]:
# Tiled matmul kernel using dynamically sized shared memory (`extern __shared__`).
# `ms` holds a tw*tw tile of m and `ns` starts right after it, so the launch must provide
# 2*tw*tw*sizeof(float) bytes of dynamic shared memory.
# Tiles are indexed as tr*tw + tc with tr/tc taken from threadIdx, so the kernel
# presumably must be launched with blockDim == (tw, tw) — the host wrapper in the next cell does so.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [81]:
# Host wrapper for the dynamic-shared-memory kernel: 16x16 blocks, 2*16*16 floats of shared memory.
cuda_src += r'''
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    // An empty output would produce a zero-sized grid, which CUDA rejects as an invalid
    // launch configuration. There is nothing to compute, so return the empty tensor.
    if (h == 0 || w == 0) return output;

    /*
    // Commented out section demonstrating basic idea of dynamic size calculation
    cudaDeviceProp devProp;
    CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));
    int maxThreads = devProp.maxThreadsPerBlock;
    size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);
    size_t size = min(devProp.sharedMemPerBlock, requiredSize);
    int TW = std::sqrt(maxThreads);
    */

    // We just set size fixed for now
    int TW = 16;
    size_t size = TW*TW * 2 * sizeof(float);
    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [82]:
fname = 'matmul_dyn'
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n);'

In [83]:
module = load_cuda(cuda_src, cpp_src, [fname], opt=True)

In [84]:
torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [85]:
%%timeit -n 10
module.matmul_dyn(m1c,m2c)
torch.cuda.synchronize()

6.65 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [76]:
# Tiled matmul with statically sized shared tiles: the tile width is a compile-time constant (16).
cuda_src = cuda_begin + r'''
constexpr int tw = 16;

__global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {
    __shared__ float ms[tw][tw], ns[tw][tw];
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    float p=0.0f;
    for (int ph=0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr][tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr][tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr][i] * ns[i][tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}

torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    // An empty output would produce a zero-sized grid, which CUDA rejects as an invalid
    // launch configuration. There is nothing to compute, so return the empty tensor.
    if (h == 0 || w == 0) return output;
    dim3 tpb(tw,tw);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_ks<<<blocks,tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [77]:
fname = 'matmul_static'
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname])
torch.isclose(module.matmul_static(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [78]:
%%timeit -n 10
module.matmul_static(m1c,m2c)
torch.cuda.synchronize()

4.74 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
