In [1]:
import torch
from functools import partial
from typing import List

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass.cute import KeepPTX, KeepCUBIN
import numpy as np

In [None]:
# Naive GEMM kernel: one thread per output element, gC[y, x] = sum_k gA[y, k] * gB[k, x].
#
# Indexing: the x thread/block coordinates select the output column and the y
# coordinates select the output row. Neighbouring threads in x therefore read
# neighbouring gB[k, x] and write neighbouring gC[y, x]. Those accesses are
# coalesced when gB/gC have stride 1 along their last mode (row-major), which is
# what the layout printed by the launcher shows for these torch inputs.
# NOTE(review): layout-dependent; confirm for other inputs.
# Threads that share y all read the same gA[y, k].
@cute.kernel
def naive_gemm(gA:cute.Tensor, gB:cute.Tensor, gC:cute.Tensor):
  # Assumes gA is (M, K), gB is (K, N) and gC is (M, N). The inner dimensions
  # are NOT checked: K is simply re-read from gB's shape.
  M,K = gA.shape
  K,N = gB.shape
  tid_x, tid_y, _ = cute.arch.thread_idx()
  bid_x, bid_y, _ = cute.arch.block_idx()
  bdim_x, bdim_y, _ = cute.arch.block_dim()
  # Global coordinates of this thread's output element: x = column, y = row.
  x = tid_x + (bid_x*bdim_x)
  y = tid_y + (bid_y*bdim_y)
  # Per-thread accumulator for the length-K dot product.
  c_val = 0.0
  # Bounds check: threads of a grid that overhangs the M x N output do nothing.
  if (x < N and y < M):
    for k in range(K):
      a_val = gA[y,k]
      b_val = gB[k,x]
      c_val += a_val*b_val
    gC[y,x] = c_val
      
      

In [18]:
# Host-side launcher for naive_gemm: one thread per element of the M x N output.
#
# Blocks are tpb_x x tpb_y = 32 x 32 threads (x -> output columns, y -> output rows).
# The grid is rounded UP so every row and column is covered even when M or N is
# not a multiple of the block size. The kernel's bounds check masks the extra threads.
@cute.jit
def naive_gemm_launcher(mA: cute.Tensor, mB: cute.Tensor, mC:cute.Tensor):
  tpb_x = 32  # threads per block along x (output columns)
  tpb_y = 32  # threads per block along y (output rows)
  # Assumes mA is (M, K), mB is (K, N) and mC is (M, N).
  M,K = mA.shape
  K,N = mB.shape
  # Debugging aid: prints mA's layout as shape:stride. The notebook's row-major
  # check relies on this output.
  # NOTE(review): confirm whether it also prints on every call (e.g. during benchmarking).
  cute.printf(mA.layout)
  kernel = naive_gemm(mA, mB, mC)
  # Ceil-division. Floor division (N//tpb_x) would silently skip the last partial
  # tile of columns/rows whenever N or M is not a multiple of the block size.
  grid_x = (N + tpb_x - 1) // tpb_x
  grid_y = (M + tpb_y - 1) // tpb_y
  kernel.launch(grid = (grid_x, grid_y, 1), block=(tpb_x, tpb_y, 1))

In [19]:
# Problem size: C (M x N) = A (M x K) @ B (K x N), all float32 on the GPU.
M,K,N = 4096, 4096, 4096
# Seed the generator so the random inputs (and the check below) are reproducible
# across Restart & Run All.
torch.manual_seed(0)
a = torch.randn(M,K, device = "cuda", dtype=torch.float32)
b = torch.randn(K,N, device = "cuda", dtype=torch.float32)
c = torch.zeros(M,N, device = "cuda", dtype = torch.float32)

# Wrap the torch tensors as CuTe tensors via DLPack. The views share storage with
# a, b and c, which is why the kernel's writes through c_ are checked on c later.
a_ = from_dlpack(a)
b_ = from_dlpack(b)
c_ = from_dlpack(c)

In [20]:
# JIT-compile the launcher for these argument types (KeepPTX presumably retains the
# generated PTX for inspection), run it once, and check c against torch's matmul.
naive_gemm_compiled = cute.compile[KeepPTX](
    naive_gemm_launcher, a_, b_, c_,
)
naive_gemm_compiled(a_, b_, c_)
torch.testing.assert_close(c, a @ b, rtol=1e-3, atol=1e-3)

(4096,4096):(4096,1)


# This confirms the layout conversion: `from_dlpack` preserves PyTorch's memory layout.
# Contiguous torch tensors are row-major, and the printed CuTe layout (4096,4096):(4096,1)
# is row-major too (stride 4096 between rows, stride 1 between columns).

In [16]:
def benchmark(callable, a_, b_, c_, M,N,K, warmup_iterations=5, iterations=100):
    """Time a compiled GEMM with ``cute.testing.benchmark`` and report its throughput.

    Args:
        callable: compiled CuTe function, invoked with ``(a_, b_, c_)``. The name
            shadows the ``callable`` builtin and is kept for backward compatibility.
        a_, b_, c_: CuTe tensors forwarded as the kernel arguments.
        M, N, K: GEMM problem size. Used only to count FLOPs: 2*M*N*K, i.e. one
            multiply and one add per inner-product term.
        warmup_iterations: warm-up runs, passed to ``cute.testing.benchmark``.
        iterations: measured runs, passed to ``cute.testing.benchmark``.

    Returns:
        Throughput in GFLOP/s. It is also printed.
    """
    # Average time per call; the "_us" suffix means microseconds.
    avg_time_us = cute.testing.benchmark(
        callable,
        kernel_arguments=cute.testing.JitArguments(a_, b_, c_),
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    flop = 2*M*N*K

    # GFLOP/s = flop / (t_us * 1e-6 s) / 1e9 = flop / (t_us * 1e3)
    giga_flop_per_second = flop/(avg_time_us*1000)
    print(f"giga_flops_per_second:{giga_flop_per_second}")
    return giga_flop_per_second
    

In [17]:
benchmark(naive_gemm_compiled, a_, b_, c_, M=M, N=N, K=K);

giga_flops_per_second:7801.126391856618
