## Setup

![tiling.png](ax-images/tiling.png)

So far, we have been using global memory: calling the `.cuda()` method on our PyTorch tensors copies their data into the GPU's global memory, and our kernels have read from and written to it directly.

To speed up our kernel, we can use faster memory, such as shared memory. Shared memory is typically around 10× faster than global memory, but it is scoped to a thread block. Only the threads of the block that owns a piece of shared memory can access it, so data can only be shared among threads of the same block.

We can think of shared memory as a cache where we store data, so we don't need to re-read it from global memory. Specifically, it allows data that is reused by threads within a block to be stored in much faster on-chip memory, reducing the need to access slower global memory repeatedly.

In [1]:
import os
os.environ['TORCH_CUDA_ARCH_LIST'] = "9.0"

In [2]:
import os,math,sys,torch,re,numpy as np
from types import SimpleNamespace as ns
from collections import namedtuple

In [3]:
# CUDA-style 3-component index/size: only x is required; y and z default to 1,
# so dim3(2, 3) means x=2, y=3, z=1.
dim3 = namedtuple('dim3', 'x y z', defaults=(1, 1))

In [4]:
d = dim3(2,3)
d

dim3(x=2, y=3, z=1)

In [5]:
d.x,d.y

(2, 3)

In [6]:
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

In [7]:
sys.path.insert(0, '..')

In [8]:
from utils import show_img,load_cuda,cuda_begin,cdiv

In [9]:
%load_ext wurlitzer

In [10]:
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
torch.manual_seed(42);

In [11]:
m1 = torch.rand(5120, 256)
m1s = m1[:4] # A sample of the first matrix, as pure python will be too slow to print the whole thing
m2 = torch.rand(256,5120)
m2s = m2[:,:4]

## Reminder

### 2d Python kernel

In [12]:
def blk_kernel2d(f, blocks, threads, *args):
    """Sequentially emulate a 2D CUDA launch of kernel `f`.

    Calls ``f(blockIdx, threadIdx, blockDim, *args)`` once for every thread of
    every block. Blocks are visited row by row (y outer, x inner) and so are the
    threads inside each block. `threads` is forwarded unchanged as blockDim.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            block_idx = dim3(bx, by)
            for ty in range(threads.y):
                for tx in range(threads.x):
                    f(block_idx, dim3(tx, ty), threads, *args)

In [13]:
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    """Compute one element of ``out = m @ n`` for a single simulated thread.

    m, n and out are flat buffers indexed row-major as (h, k), (k, w) and (h, w).
    The thread's global (row, col) comes from its block and thread indices;
    threads that fall outside the h x w output write nothing.
    """
    row = blockIdx.y*blockDim.y + threadIdx.y
    col = blockIdx.x*blockDim.x + threadIdx.x
    if row < h and col < w:
        # start=0. keeps the same left-to-right float accumulation as `o = 0.; o += ...`
        out[row*w + col] = sum((m[row*k + i] * n[i*w + col] for i in range(k)), 0.)

In [14]:
def matmul_2d(m, n):
    """Multiply 2D tensors m (h, k) and n (k, w) with the sequential 2D kernel simulator.

    Launches 16x16-thread blocks over the output. Returns a new (h, w) tensor,
    created with ``torch.zeros(h, w, dtype=m.dtype)``.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(16, 16)
    blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    output = torch.zeros(h, w, dtype=m.dtype)
    # The kernel works on flat buffers; output.flatten() of the fresh contiguous
    # tensor is what the kernel writes into.
    blk_kernel2d(matmul_bk, blocks, tpb,
                 m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

In [15]:
torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()

tensor(True)

### CUDA

In [16]:
# Naive CUDA matmul: one thread per output element, with every operand read
# straight from global memory (no shared memory yet).
# `cuda_begin` (imported from utils) is prepended to the source. Judging by its
# use below, it provides CHECK_INPUT and cdiv — TODO confirm in utils.
# matmul() launches 16x16-thread blocks covering the (h, w) output.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [17]:
fname = 'matmul'

In [18]:
def get_sig(fname, src):
    """Return the C++ forward declaration of function `fname` found in `src`.

    Searches (line by line) for the first line shaped like
    ``<return type> fname(<params>)``, optionally followed by ``{``.
    Returns that signature with a trailing ``;`` (suitable as a declaration for
    the C++ side of the extension), or None if no such line exists.
    """
    # re.escape: insert the name literally, so characters such as '.' cannot
    # make the pattern match a different function.
    pattern = rf'^(.+\s+{re.escape(fname)}\(.*?\))\s*{{?\s*$'
    res = re.findall(pattern, src, re.MULTILINE)
    return res[0]+';' if res else None

In [19]:
cpp_src = get_sig(fname, cuda_src)
cpp_src

'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'

In [20]:
module = load_cuda(cuda_src, cpp_src, [fname])

In [21]:
m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()

In [22]:
module.matmul(m1c,m2c).shape

torch.Size([5120, 5120])

In [23]:
torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [24]:
%%timeit -n 10
module.matmul(m1c,m2c)
torch.cuda.synchronize()

2.6 ms ± 4.69 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


When I removed the call to the kernel itself, it took around 50 µs (0.05 ms) to run, so that's the overhead of the call on my machine.

## Shared mem

### Python

In [25]:
# Simulating shared memory in python
a = torch.zeros(5)        
b,c = a[:3],a[3:] # Slicing a tensor creates a view of the same memory 
#b[0] = 1         # So changing b changes a

In [26]:
b[1] = 2
c[0] = 6
a

tensor([0., 2., 0., 6., 0.])

For each tile, our shared-memory kernel repeats two steps:

- Step 1: Load data into shared memory.
- Step 2: Compute the dot product.

In [27]:
def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
    """Run kernel `f` once per block of a 2D grid, giving each block its own "shared memory".

    For every block (row by row), a fresh zero-filled tensor of length `sh_sz`
    is created and passed to ``f(blockIdx, blockDim, shared, *args, **kwargs)``.
    `f` is responsible for looping over the block's threads itself.
    """
    for by in range(blocks.y):
        for bx in range(blocks.x):
            # A new buffer per block: it is visible to all threads of this block
            # only, never to other blocks.
            block_shared = torch.zeros(sh_sz)
            f(dim3(bx, by), threads, block_shared, *args, **kwargs)

In [28]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one block of a tiled matmul, running the block's threads sequentially.

    m, n and out are flat row-major buffers of the (h, k), (k, w) and (h, w)
    matrices. out is accumulated into with ``+=``, so it must start zeroed
    (the callers in this notebook pass torch.zeros).

    `shared` holds 2*tw*tw values: the first tw*tw are the tile of m (`ms`),
    the rest the tile of n (`ns`). The tile indexing assumes the block is
    tw x tw threads; the callers here launch dim3(tw, tw).
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:] # `shared` is contiguous, so slicing gives two views into it

    # ph is the index of the tile we're up to along the shared k dimension (one dimension
    # is the same for both matrices in a valid matmul): phase ph covers columns
    # [ph*tw, ph*tw+tw) of m and the matching rows of n.
    for ph in range(cdiv(k,tw)):
        idx = ph*tw
        # Step 1 - fill shared: each "thread" copies one element of each tile,
        # zero-padding positions that fall outside the matrices.
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # The buffers are 1D, so a 2D tile position (tr, tc) is flattened to tr*tw+tc
                ms[tr*tw+tc] = m[ tc+idx + r*k] if r<h and idx+tc<k else 0.
                ns[tr*tw+tc] = n[(tr+idx)*w +c] if c<w and idx+tr<k else 0.

        # Step 2 - dot products from shared: each in-bounds "thread" adds this
        # tile's partial dot product to its output element.
        for tr in range(blockDim.y):
            for tc in range(blockDim.x):
                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
                # Check the 2D bounds: a flat check like r*w+c < len(out) lets a thread
                # with c >= w add into the *next* row of out (0*inf there gives NaN).
                if r>=h or c>=w: continue
                for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

In [29]:
# Same driver as before, but launched through the shared-memory runner
# blk_kernel2d_shar(), which also hands each block its shared buffer.
def matmul_2d(m, n, tw=16):
    """Compute m @ n with the tiled (shared-memory) kernel simulator.

    m is (h, k) and n is (k, w). Blocks are tw x tw threads, and each block
    gets 2*tw*tw values of simulated shared memory: one tw x tw tile of m plus
    one of n. Returns a new (h, w) tensor with m's dtype.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = dim3(tw, tw)
    grid = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))
    n_shared = 2 * tw * tw  # size of shared memory: one tile for m and one for n
    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, n_shared,
                      m.flatten(), n.flatten(), output.flatten(),
                      h, w, k, tw=tw)
    return output

In [30]:
m1s.shape, m2.shape

(torch.Size([4, 256]), torch.Size([256, 5120]))

In [31]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python run_threads

NOTE:

- **Block (CUDA concept) vs. Tile (operational concept):**
    > A block is a CUDA programming concept that groups a set of threads; what we do with those threads is entirely up to us. Since we compute the output for each block as a section of the overall result, we refer to that section as a tile.

In [32]:
# Refactor to look more like CUDA code: visit every thread of a block and call a per-thread function.
def run_threads(f, blockDim, *args, **kwargs):
    """Call ``f(ty, tx, *args, **kwargs)`` for every thread of a blockDim.y x blockDim.x block.

    Threads are visited one at a time in row-major order (ty outer, tx inner),
    standing in for what CUDA would run in parallel.
    """
    coords = ((ty, tx) for ty in range(blockDim.y) for tx in range(blockDim.x))
    for ty, tx in coords:
        f(ty, tx, *args, **kwargs)

In [33]:
def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
    """Simulate one block of a tiled matmul using per-thread helpers driven by run_threads.

    m, n and out are flat row-major buffers of the (h, k), (k, w) and (h, w)
    matrices. out is accumulated into with ``+=`` and must start zeroed.
    `shared` holds 2*tw*tw values: tile of m first, then tile of n.
    Assumes the block is tw x tw threads (callers launch dim3(tw, tw)).
    """
    shar_sz = tw*tw
    ms,ns = shared[:shar_sz],shared[shar_sz:]

    def get_rc(tr, tc): return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc

    def fill_shared_tk(tr, tc, ph):
        # One "thread": copy its element of the m tile and of the n tile, zero-padded.
        r,c = get_rc(tr, tc)
        idx = ph*tw
        ms[tr*tw+tc] = m[ tc + idx + r*k] if r<h and (idx+tc)<k else 0.
        ns[tr*tw+tc] = n[(tr + idx)*w +c] if c<w and (idx+tr)<k else 0.

    def dotprod_tk(tr, tc):
        # One "thread": add this tile's partial dot product to its output element.
        r,c = get_rc(tr, tc)
        # 2D bounds check: a flat check (r*w+c < len(out)) lets a thread with c >= w
        # write into the next row of out.
        if r>=h or c>=w: return
        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]

    # cdiv, like every other kernel in this notebook, instead of int(math.ceil(k/tw))
    for ph in range(cdiv(k,tw)):
        run_threads(fill_shared_tk, blockDim, ph)  # all threads fill shared first...
        run_threads(dotprod_tk, blockDim)          # ...then all threads read it

In [34]:
def matmul_2d(m, n, tw=16):
    """Tiled matmul m @ n via the sequential shared-memory simulator (run_threads version).

    m: (h, k), n: (k, w). Each tw x tw block receives 2*tw*tw values of
    simulated shared memory. Returns a new (h, w) tensor of m's dtype.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    block = dim3(tw, tw)
    grid = dim3(cdiv(w, block.x), cdiv(h, block.y))
    flat_bufs = (m.flatten(), n.flatten(), output.flatten())
    blk_kernel2d_shar(matmul_tiled_bk, grid, block, 2*tw*tw, *flat_bufs, h, w, k, tw=tw)
    return output

In [35]:
m1s.shape, m2s.shape

(torch.Size([4, 256]), torch.Size([256, 4]))

In [36]:
torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()

tensor(True)

### Python threads

In [37]:
# We need a way to tell all our threads to wait until they all reach a certain point before continuing.
import threading
from threading import Barrier, Thread
from concurrent.futures import ThreadPoolExecutor

In [38]:
def gs(x):
    """Print x, -x and x*10, one per line, with nothing synchronising the steps.

    Because there is no barrier between the prints, each caller runs through
    all three of them on its own, without waiting for any other thread.
    """
    for value in (x, -x, x*10):
        print(value)  # a barrier wait (sb.wait()) between prints would keep threads in step

In [39]:
# Without a barrier this does not give the behaviour we need: each thread runs all of
# its steps independently, whereas we want every thread to
# 1- first fill our shared memory with its part of the matrix tile we're going to work with,
# 2- and only then (once all threads have done step 1) do the dot products from shared memory.
num = 3
# list(...) consumes the results so any exception raised in a worker surfaces here
with ThreadPoolExecutor(num) as ex: list(ex.map(gs, range(1,num+1)))

1
-1
10
2
-2
20
3
-3
30


In [40]:
def g(x, sb):
    """Print x, -x and x*10, waiting on barrier `sb` between consecutive prints.

    With ``sb = Barrier(N)`` shared by N threads, no thread prints -x until all N
    have printed x, and none prints x*10 until all have printed -x.
    """
    stages = (x, -x, x*10)
    print(stages[0])
    for value in stages[1:]:
        sb.wait()
        print(value)

In [41]:
num = 3
sb = Barrier(num) # This barrier will wait for num threads to reach it before continuing
                  # The threads run concurrently, so the order of prints *within* a stage varies between runs,
                  # but the barrier guarantees every x is printed before any -x, and every -x before any x*10
with ThreadPoolExecutor(num) as ex: list(ex.map(lambda i: g(i,sb), range(1,num+1)))

1
2
3
-3
-1
-2
10
20
30


In [42]:
def _shar_thread(f, blockIdx, threadIdx, tpb, shar, syncb, errors, args, kwargs):
    """Run one simulated CUDA thread; on failure record the error and break the block's barrier."""
    try:
        f(blockIdx, threadIdx, tpb, shar, syncb, *args, **kwargs)
    except BaseException as e:
        errors.append(e)
        # Without this, the block's other threads would wait at syncb forever
        # and the join() in blk_kernel2d_shar would hang.
        syncb.abort()

def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
    """Run kernel `f` over a 2D grid with one real Python thread per CUDA thread.

    Blocks run one after another. For each block:
    - a fresh zero-filled shared buffer of length `sh_sz` is created;
    - a Barrier sized to the block's tpb.x*tpb.y threads is created;
    - ``f(blockIdx, threadIdx, blockDim, shared, syncb, *args, **kwargs)``
      is started in one thread per (x, y) thread index.

    If any thread raises, the barrier is aborted so the other threads are
    released instead of deadlocking. Once all threads have finished, the
    original exception is re-raised in the caller.
    """
    for i0 in range(blocks.y): # iterate over our 2d grid of blocks
        for i1 in range(blocks.x):
            shar = torch.zeros(sh_sz) # create this block's shared memory
            syncb = Barrier(tpb.y*tpb.x) # create a synchronization barrier for every thread in our block
            errors = []
            # 2 for loops in a list comprehension = cartesian product -> 2d grid of threads;
            # p and o are the x and y coordinates of each thread in the block
            threads = [Thread(target=_shar_thread,
                              args=(f, dim3(i1,i0), dim3(p,o), tpb, shar, syncb, errors, args, kwargs))
                       for o in range(tpb.y) for p in range(tpb.x)]
            for tr in threads: tr.start()
            for tr in threads: tr.join()
            if errors:
                # Report the real failure rather than the BrokenBarrierErrors it caused elsewhere.
                raise next((e for e in errors if not isinstance(e, threading.BrokenBarrierError)), errors[0])

In [None]:
def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
    """Tiled matmul kernel body for ONE simulated thread.

    All threads of a block share `shared` (m tile, then n tile) and the
    barrier `syncb`. Thread (tr, tc) computes out[r*w + c]. In each phase it:
    - loads one element of the m tile and one of the n tile into shared memory;
    - waits until the whole block has done so;
    - adds its tw-term partial dot product;
    - waits again before the tiles are overwritten in the next phase.
    """
    tr, tc = threadIdx.y, threadIdx.x
    r = blockIdx.y*blockDim.y + tr
    c = blockIdx.x*blockDim.x + tc

    ms, ns = shared[:tw*tw], shared[tw*tw:]   # m tile, then n tile, in one contiguous buffer
    row_ok, col_ok = r < h, c < w

    acc = 0.
    for ph in range(cdiv(k, tw)):
        base = ph*tw
        a_col, b_row = base + tc, base + tr
        ms[tr*tw + tc] = m[r*k + a_col] if row_ok and a_col < k else 0.
        ns[tr*tw + tc] = n[b_row*w + c] if col_ok and b_row < k else 0.
        syncb.wait()   # every thread has written its element: both tiles are complete
        for i in range(tw):
            acc += ms[tr*tw + i] * ns[tw*i + tc]
        syncb.wait()   # every thread has finished reading: safe to overwrite the tiles next phase

    if row_ok and col_ok:
        out[r*w + c] = acc

In [44]:
def matmul_2d(m, n, tw=16):
    """Tiled matmul m @ n where every simulated CUDA thread is a real Python thread.

    m: (h, k), n: (k, w). Blocks are tw x tw threads, each with 2*tw*tw
    values of shared memory. Returns a new (h, w) tensor of m's dtype.
    """
    (h, k), (k2, w) = m.shape, n.shape
    assert k==k2, "Size mismatch!"
    tpb = dim3(tw, tw)
    blocks = dim3(cdiv(w, tw), cdiv(h, tw))
    output = torch.zeros(h, w, dtype=m.dtype)
    blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, 2*tw*tw,
                      m.flatten(), n.flatten(), output.flatten(), h, w, k, tw=tw)
    return output

In [45]:
torch.isclose(matmul_2d(m1s, m2s, tw=8), m1s@m2s).all()

tensor(True)

### CUDA dynamic shared

Code auto-generated by ChatGPT 4, using the following prompt:

> Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove `blockIdx, threadIdx, blockDim, shared` from the argument list, since they're already provided by CUDA. Change `syncb.wait()` to `__syncthreads`. Use `extern __shared__ float shared[]` to create the `shared` array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.

The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming `shared` to `ms`).

In [46]:
# Dynamic-shared-memory version of the tiled kernel (the CUDA translation of the
# threaded Python kernel).
# `extern __shared__ float ms[]` has no size in the source: the number of bytes
# comes from the third <<<...>>> launch parameter. The kernel splits that single
# buffer into the m tile (ms) and the n tile (ns = &ms[tw*tw]).
# The tile width `tw` is an ordinary runtime argument here. r and c are computed
# from blockDim, so the launch must use tw x tw thread blocks.
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    extern __shared__ float ms[]; // to use shared memory, we need to declare it as an array and specify its type
                                  // ms will be a pointer to the start of the shared memory, CUDA is going to create 'extern __shared__'
    float *ns = &ms[tw*tw];       // Now we need a pointer to the second half of the shared memory for ns
                                  // So we need to calculate the address of the start of ns taking into account they are contiguous in memory

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads(); // identical to the barrier in python, it will wait until all the threads have reached this point
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

`cudaDeviceProp devProp;`
`CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));`

`cudaGetDeviceProperties` expects a pointer (a memory address) to a `cudaDeviceProp` structure, which it then fills in. So we pass the address of our local `devProp` variable, which lives on the stack, using the `&` (address-of) operator. The `*` operator does the opposite: it dereferences a pointer, i.e. accesses the value stored at a given address. Since the structure has only been declared (not yet filled with data), and we want the function to write into it rather than receive a copy of its contents, we must pass `&devProp`.

In [47]:
# Host wrapper for the dynamic-shared-memory kernel.
# The shared-memory byte count is passed as the third launch parameter
# (matmul_k<<<blocks, tpb, size>>>). The tile width is fixed at TW = 16, i.e.
# 16*16*2 floats = 2048 bytes of shared memory per block.
# The commented-out C++ sketches deriving TW and the size from
# cudaGetDeviceProperties instead.
# NOTE(review): one comment in that sketch says the heap is "compilation"-time
# managed. Heap memory (new/malloc) is actually allocated at runtime; what
# distinguishes it from the stack is that it is released explicitly rather
# than automatically.
cuda_src += r'''
torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    /*
    // Commented out section demonstrating basic idea of dynamic size calculation
    cudaDeviceProp devProp; // creates a variable devProp of type cudaDeviceProp on the stack
                            // The stack is a region of memory automatically managed by our program, where local variables are allocated when a function is called and freed when the function returns.
                            // stack (runtime - automatically management) - local variables (and variables inside function without using dynamic allocation), function parameters and return addresses 
                            //heap (compilation - explicit management) 'new' (C++) 'malloc' (C) 
    CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));
    int maxThreads = devProp.maxThreadsPerBlock;
    size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);
    size_t size = min(devProp.sharedMemPerBlock, requiredSize); // this ensures that the allocation does not exceed what the device supports
    int TW = std::sqrt(maxThreads); // to define a 2D grid of threads (with dimensions TW x TW) 
    */

    // We just set size fixed for now
    int TW = 16;
    size_t size = TW*TW * 2 * sizeof(float);
    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks,tpb,size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [48]:
fname = 'matmul_dyn'

In [49]:
cpp_src = get_sig(fname, cuda_src)

In [50]:
module = load_cuda(cuda_src, cpp_src, [fname], opt=True)

In [51]:
torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [52]:
%%timeit -n 10
module.matmul_dyn(m1c,m2c)
torch.cuda.synchronize()

2.11 ms ± 1.94 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### CUDA static shared

In [53]:
# Static-shared-memory version: the tile width is a compile-time constant
# (constexpr tw = 16). ms/ns are therefore fixed-size 2D __shared__ arrays, and
# the launch passes no dynamic shared-memory size.
# Because tw is known at compile time, the compiler is free to fully unroll the
# inner `for (int i=0; i<tw; ++i)` loop. This is a plausible (unconfirmed)
# reason this version times faster than the runtime-tw dynamic one.
cuda_src = cuda_begin + r'''
constexpr int tw = 16; // Now we have decided at compile time what our tile width is 

__global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {
    __shared__ float ms[tw][tw], ns[tw][tw]; // then this is not dynamic anymore, it's a fixed size
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;

    float p=0.0f;
    for (int ph=0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr][tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr][tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr][i] * ns[i][tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}

torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    dim3 tpb(tw,tw);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_ks<<<blocks,tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [54]:
fname = 'matmul_static'
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname])
torch.isclose(module.matmul_static(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [55]:
%%timeit -n 10
module.matmul_static(m1c,m2c)
torch.cuda.synchronize()

1.67 ms ± 2.57 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


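The presenter did not explain why the dynamic shared-memory version is slower than the static one in this case.

A likely reason is that in the dynamic kernel the tile width `tw` is a runtime argument. The compiler therefore cannot unroll the inner dot-product loop or turn the index arithmetic into constants.

Supporting this, the templated version at the end of this notebook runs as fast as the static one (1.67 ms). That version still uses dynamic shared memory, but it makes `tw` a compile-time constant.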

## Numba

We can achieve the same functionality using a different library called Numba to write CUDA code in Python. Numba offers several handy features:

- It allows us to write CUDA-like code with fast compilation times.
- It does not require flattening tensors.
- [Numba Simulator](https://numba.pydata.org/numba-doc/dev/cuda/simulator.html): Setting `NUMBA_ENABLE_CUDASIM=1` enables debugging with a simulated CPU version.

In [56]:
from numba import cuda
from numba.cuda import as_cuda_array as ca

In [57]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    """Numba port of the dynamic-shared-memory tiled matmul kernel: out = m @ n.

    m, n and out are indexed as 2D arrays (no flattening needed). `tw` is the
    tile width and is assumed to equal the block's x and y dimensions. The
    shared buffer is declared with size 0 (dynamic), so the launch must supply
    at least 2*tw*tw float32 values of dynamic shared memory.
    """
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc  # global output row/column of this thread
    h,k  = m.shape
    k2,w = n.shape  # k2 is unused here; the host launcher in this notebook asserts k == k2

    shar = cuda.shared.array(0, dtype=np.float32)  # size 0 => dynamically sized shared memory
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]       # m tile, then n tile

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        # Load one element of each tile, zero-padding outside the matrices
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        cuda.syncthreads()  # tiles are complete before anyone reads them
        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]
        cuda.syncthreads()  # everyone has finished reading before the next phase overwrites them
    if r < h and c < w: out[r, c] = p

In [58]:
def matmul_2d_numba(m, n, tw=16):
    """Launch matmul_k_numba to compute m @ n on m's device.

    m is (h, k) and n is (k, w). Uses tw x tw thread blocks and requests
    2*tw*tw float32 values (4 bytes each) of dynamic shared memory.
    Returns a new (h, w) tensor with m's dtype, on m's device.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    threads = (tw, tw)
    grid = (cdiv(w, tw), cdiv(h, tw))   # x covers columns, y covers rows
    shared_bytes = tw * tw * 2 * 4      # two tw x tw float32 tiles
    stream = 0
    matmul_k_numba[grid, threads, stream, shared_bytes](ca(m), ca(n), ca(out), tw)
    return out

In [59]:
torch.isclose(matmul_2d_numba(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [60]:
%%timeit -n 10
matmul_2d_numba(m1c,m2c)
torch.cuda.synchronize()

9.52 ms ± 15.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Extra: Optimised Dynamic CUDA with Template

In [61]:
# Template version of the dynamic-shared-memory kernel: the body is unchanged,
# but the tile width is a template parameter. Each instantiation (matmul_k<8>,
# <16>, <32> below) therefore sees `tw` as a compile-time constant while still
# using dynamically sized (extern) shared memory.
# r and c are computed from blockDim, so the kernel must be launched with
# tw x tw thread blocks.
cuda_src = cuda_begin + r'''
template<int tw>
__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
    int tc=threadIdx.x, tr=threadIdx.y;
    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;
    extern __shared__ float ms[];
    float *ns = &ms[tw*tw];

    float p = 0.0f;
    for (int ph = 0; ph < cdiv(k,tw); ++ph) {
        int idx = ph*tw;
        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
        __syncthreads();
        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
        __syncthreads();
    }
    if (r<h && c<w) out[r*w + c] = p;
}
'''

In [62]:
# Host wrapper for the templated kernel: the runtime TW picks the matching
# compile-time instantiation via the switch.
# - The dynamic shared-memory size is exactly the two TW x TW float tiles the
#   kernel uses. The previous stray "+ 1" byte was never used and only rounded
#   up the per-block allocation.
# - An unsupported TW now raises instead of silently skipping the launch, which
#   would have returned an all-zero output.
cuda_src += r'''
torch::Tensor matmul_dyn1(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h=m.size(0), w=n.size(1), k=m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    int TW = 16; // TODO: Calculate this dynamically
    size_t size = TW*TW*2 * sizeof(float);
    dim3 tpb(TW,TW);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));

    // lambda function to call the templated function
    auto f = [&](auto kf) { kf<<<blocks, tpb, size>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    };
    switch(TW) {
        case 8: f(matmul_k<8>); break;
        case 16: f(matmul_k<16>); break;
        case 32: f(matmul_k<32>); break;
        default: TORCH_CHECK(false, "matmul_dyn1: unsupported tile width ", TW, " (expected 8, 16 or 32)");
    }
    
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

In [63]:
%%time
fname = 'matmul_dyn1'
cpp_src = get_sig(fname, cuda_src)
module = load_cuda(cuda_src, cpp_src, [fname], opt=True)
func = getattr(module, fname)

CPU times: user 7.07 ms, sys: 0 ns, total: 7.07 ms
Wall time: 32.2 s


In [64]:
torch.isclose(func(m1c,m2c), m1c@m2c).all()

tensor(True, device='cuda:0')

In [65]:
%%timeit -n 10
func(m1c,m2c)
torch.cuda.synchronize()

1.67 ms ± 2.28 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
