In [1]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple
import os
import torch
import numpy as np

In [2]:
DEBUG_MODE = True
if DEBUG_MODE: os.environ['NUMBA_ENABLE_CUDASIM'] = '1'

In [3]:
import numba
from numba import cuda
if not DEBUG_MODE:
    from numba.cuda import as_cuda_array as ca

In [4]:
import sys
sys.path.insert(0, '..')
from utils import show_img,load_cuda,cuda_begin,cdiv

In [5]:
import re
def get_sig(fname, src):
    res = re.findall(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
    return res[0]+';' if res else None

In [6]:
cuda_src = cuda_begin + r'''
#include <cuda_runtime.h>
#include <sstream> // Include the stringstream header

std::string get_device_prop() {
    cudaDeviceProp devProp;
    cudaError_t cudaStatus = cudaGetDeviceProperties(&devProp, 0);
    if (cudaStatus != cudaSuccess) {
        return "Failed to get device properties"; // Handle error appropriately
    }

    std::ostringstream stream; // Use ostringstream for formatting the string

    // Extract device properties
    int maxThreads = devProp.maxThreadsPerBlock;
    size_t totalGlobalMem = devProp.totalGlobalMem;
    size_t sharedMemPerBlock = devProp.sharedMemPerBlock;
    int regsPerBlock = devProp.regsPerBlock;
    int warpSize = devProp.warpSize;
    int maxThreadsPerMultiProcessor = devProp.maxThreadsPerMultiProcessor;
    size_t sharedMemPerMultiprocessor = devProp.sharedMemPerMultiprocessor;
    int regsPerMultiprocessor = devProp.regsPerMultiprocessor;

    // Format the string with device properties
    stream << maxThreads << ", " << totalGlobalMem << ", " << sharedMemPerBlock
           << ", " << regsPerBlock << ", " << warpSize << ", " 
           << maxThreadsPerMultiProcessor << ", " << sharedMemPerMultiprocessor 
           << ", " << regsPerMultiprocessor;

    return stream.str(); // Return the formatted string
}
'''

In [7]:
fname = 'get_device_prop'
# cpp_src = get_sig(fname, cuda_src)
cpp_src = "std::string get_device_prop();"
print_module = load_cuda(cuda_src, cpp_src, [fname])
print_module.get_device_prop()

'1024, 25438126080, 49152, 65536, 32, 1536, 102400, 65536'

In [8]:
shared_mem_per_sm = 102400
shared_mem_per_block = 49152

In [9]:
shared_mem_per_sm/shared_mem_per_block

2.0833333333333335

In [10]:
max_threads_per_sm = 1536
approx_block_size = int(max_threads_per_sm / np.floor(shared_mem_per_sm/shared_mem_per_block)); approx_block_size

768

In [11]:
assert approx_block_size * 2 * np.dtype("float32").itemsize <= shared_mem_per_sm, "Max possible shared mem per block exceeded!"

In [12]:
max_threads_per_block = 1024
assert approx_block_size <= max_threads_per_block, f"Have more threads per block {approx_block_size} > {max_threads_per_block}"

In [13]:
warp_size = 32
assert approx_block_size % warp_size == 0, "Block size is not divisible by warp which will cause underutilization!"

In [14]:
# for registers and occupancy: https://docs.nvidia.com/cuda/archive/10.2/cuda-occupancy-calculator/index.html

### Rectangle Tile Matmul

**Note:** Using rectangle tiles either wide or tall is causing issues with synchronization and race conditions if we want to implement something efficient. Otherwise solutions would require a more than needed shared memory usage and/or idle threads.

So for matmul with tall and wide matrices, such as the MNIST example, where A=50,000 x 768 and B= 768 x 10 we will use a block with dim3(x,1,1) and process a single row or single col per block.

- Case 1: `A=50,000 x 768 and B= 768 x 10`, each thread produce an output row.
- Case 2: `A=10 x 768 and B= 768 x 50,000`, each thread produce an output col.

In [48]:
@cuda.jit
def matmul_flat_tile_numba(m, n, out, tw):

    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tx,ty = tid.x,tid.y

    h,k  = m.shape
    k2,w = n.shape
    
    ms = cuda.shared.array(0, dtype=np.dtype("float32"))
    
    r = cbi.x

    for ph in range(math.ceil(k/tw)):
                
        # fill shared mem
        idx = ph*tw
        ms[tx] = m[r, idx+tx] if r<h and idx+tx<k else 0.
        cuda.syncthreads()

        for c in range(w):        
            # dot-product and accumulate
            p = 0
            for i in range(tw):
                 p += ms[i] * n[idx+tx,c]
            out[r,c] = p
        cuda.syncthreads()

In [49]:
def matmul_2d_numba(m, n, tw):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype).cuda()
    
    dtype_size = np.dtype("float32").itemsize
    dyn_shared_mem_size = tw * dtype_size
    
    blocks = (h,)
    tpb = (tw,)
    
    if DEBUG_MODE:
        matmul_flat_tile_numba[blocks, tpb, 0, dyn_shared_mem_size](m, n, out, tw) 
    else:
        matmul_flat_tile_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw) 
    return out

In [50]:
m1s = torch.arange(12).view(3,4).float().contiguous().cuda()
m2 = torch.arange(24).view(4,6).float().contiguous().cuda()
m1s,m2

(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]], device='cuda:0'),
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.],
         [18., 19., 20., 21., 22., 23.]], device='cuda:0'))

In [51]:
matmul_2d_numba(m1s,m2,tw=2)

tensor([[ 60.,  65.,  70.,  75.,  80.,  85.],
        [234., 247., 260., 273., 286., 299.],
        [252., 399., 420., 315., 336., 357.]], device='cuda:0')

In [52]:
m1s@m2

tensor([[ 84.,  90.,  96., 102., 108., 114.],
        [228., 250., 272., 294., 316., 338.],
        [372., 410., 448., 486., 524., 562.]], device='cuda:0')

In [44]:
torch.allclose(matmul_2d_numba(m1s,m2,tw=4), m1s@m2)

True

### CUDA

In [57]:
A = torch.randn(50_000, 768).contiguous().cuda()
B = torch.randn(768, 10).contiguous().cuda()

In [7]:
cuda_src = cuda_begin + r'''
template<int tw>
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;
    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [8]:
cuda_src += r'''
torch::Tensor matmul_sq_tile(torch::Tensor m, torch::Tensor n, int TW) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    
    //int TW = 8; // TODO: Calculate this dynamically
    
    size_t size = TW*TW*2 * sizeof(float) + 1;
    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));

    auto f = [&](auto kf) { kf<<<blocks, tpb, size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    };
    switch(TW) {
        case 8: f(matmul_k<8>); break;
        case 16: f(matmul_k<16>); break;
        case 32: f(matmul_k<32>); break;
        default: break;
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [None]:
fname = 'matmul_sq_tile'
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname])

In [15]:
%%timeit -n 10
module.matmul_sq_tile(A,B,8)
torch.cuda.synchronize()

699 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%%timeit -n 10
module.matmul_sq_tile(A,B,16)
torch.cuda.synchronize()

489 µs ± 8.62 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
%%timeit -n 10
module.matmul_sq_tile(A,B,32)
torch.cuda.synchronize()

1.03 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [23]:
cuda_src = cuda_begin + r'''
template<int tw>
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;
    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

**Note:** This is much slower !!! Because shared memory is repeatedly filled for ms, instead of using it multiple times. fix `matmul_flat_tile_numba`

In [None]:
@cuda.jit
def matmul_flat_tile_numba(m, n, out, tw):

    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tx,ty = tid.x,tid.y

    h,k  = m.shape
    k2,w = n.shape
    
    shar = cuda.shared.array(0, dtype=np.dtype("float32"))
    ms,ns = shar[:tw],shar[tw:]
    
    r = cbi.x

    for c in range(w):
        p = 0
        for ph in range(math.ceil(k/tw)):
                    
            # fill shared mem
            idx = ph*tw
            
            # transposed and aligned for dot-product
            ms[tx] = m[r, idx+tx] if r<h and idx+tx<k else 0.
            ns[tx] = n[idx+tx, c] if c<w and idx+tx<k else 0.
            cuda.syncthreads()

            # dot-product and accumulate
            for i in range(tw):
                p += ms[i] * ns[i]
            cuda.syncthreads()
            
        if r < h and c < w: out[r,c] = p

In [None]:
@cuda.jit
def matmul_flat_tile_numba(m, n, out, tw):

    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tx,ty = tid.x,tid.y

    h,k  = m.shape
    k2,w = n.shape
    
    ms = cuda.shared.array(0, dtype=np.dtype("float32"))
    
    r = cbi.x

    for ph in range(math.ceil(k/tw)):
                
        # fill shared mem
        idx = ph*tw
        ms[tx] = m[r, idx+tx] if r<h and idx+tx<k else 0.
        cuda.syncthreads()

        for c in range(w):        
            # dot-product and accumulate
            p = 0
            for i in range(tw):
                 p += ms[i] * n[idx+tx,c]
            out[r,c] = p
        cuda.syncthreads()

In [53]:
cuda_src = cuda_begin + r'''
template<int tw>
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
    
    int tx=threadIdx.x, ty=threadIdx.y, r=blockIdx.x;
    
    extern __shared__ float ms[];

    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tx] = r<h && idx+tx<k ? m[r*k + idx + tx] : 0.0f;
        __syncthreads();
        for (int c=0; c<w; ++c){
            for (int i=0; i<tw; ++i) out[r*w + c] += ms[i] * n[(idx+tx)*w + c];
        }
        __syncthreads();
    }
    
}
'''

In [54]:
cuda_src += r'''
torch::Tensor matmul_flat_tile(torch::Tensor m, torch::Tensor n, int TW) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    size_t size = TW * sizeof(float) + 1;
    dim3 tpb(TW);
    dim3 blocks(h);

    auto f = [&](auto kf) { kf<<<blocks, tpb, size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    };
    switch(TW) {
        case 256: f(matmul_k<256>); break;
        case 512: f(matmul_k<512>); break;
        case 768: f(matmul_k<768>); break;
        default: break;
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [55]:
fname = 'matmul_flat_tile'
cpp_src = get_sig(fname, cuda_src)
flat_module = load_cuda(cuda_src, cpp_src, [fname])

In [None]:
assert torch.allclose(flat_module.matmul_flat_tile(A,B,768), A@B, atol=1e-4)

In [60]:
flat_module.matmul_flat_tile(A,B,768)

tensor([[  -7.4253,   18.9958,  -31.8680,  ...,   -9.9729,   60.7632,
           43.9456],
        [  12.7517,   -4.6914,   -7.9195,  ...,   23.2486, -181.1063,
            6.4379],
        [  -4.8458,   -2.0506,   17.1899,  ...,  -37.1760,   44.4268,
           19.3395],
        ...,
        [  -7.7356,   15.3391,   25.3848,  ...,    7.1080,    7.9670,
           -9.5791],
        [   8.8504,  -18.9655,    9.4784,  ...,  123.9788,  -38.5527,
           11.5962],
        [   2.3148,    4.7446,    7.9802,  ...,   -6.2798,    2.5232,
          -11.4697]], device='cuda:0')

In [None]:
%%timeit -n 10
flat_module.matmul_flat_tile(A,B,768)
torch.cuda.synchronize()