In [1]:
import torch
from functools import partial
from typing import List

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass.cute import KeepPTX, KeepCUBIN

In [2]:
@cute.kernel
def naive_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    """Compute ``gC = gA + gB`` with one thread per element.

    Each thread flattens its (block, thread) position into a linear index and
    splits that index into a (row, column) coordinate using the column count
    of ``gA``. There is no bounds check, so the launch must not supply more
    threads than ``gA`` has elements.

    Args:
        gA: First input operand; its shape is unpacked as (rows, columns).
        gB: Second input operand, read at the same coordinates as ``gA``.
        gC: Output tensor, written at the same coordinates.
    """
    thread_in_block, _, _ = cute.arch.thread_idx()
    block_in_grid, _, _ = cute.arch.block_idx()
    threads_per_block, _, _ = cute.arch.block_dim()

    # Only the x component of the launch geometry is used.
    linear_idx = block_in_grid * threads_per_block + thread_in_block

    # Split the linear index into a 2-D coordinate: row = idx // cols, col = idx % cols.
    _, num_cols = gA.shape
    row = linear_idx // num_cols
    col = linear_idx % num_cols

    lhs = gA[row, col]
    rhs = gB[row, col]
    gC[row, col] = lhs + rhs
  

In [3]:
@cute.jit
def naive_elementwise_add_launcher(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    """Launch ``naive_elementwise_add_kernel`` with one thread per element of ``mA``.

    The grid is a 1-D row of 256-thread blocks. The block count uses floor
    division, so the element count of ``mA`` is expected to be a multiple of
    256; any remainder elements are not covered by a thread.

    Args:
        mA: First input operand; its shape is unpacked as (rows, columns).
        mB: Second input operand, forwarded to the kernel unchanged.
        mC: Output tensor, forwarded to the kernel unchanged.
    """
    threads_per_block = 256
    rows, cols = mA.shape
    num_blocks = (rows * cols) // threads_per_block

    naive_elementwise_add_kernel(mA, mB, mC).launch(
        grid=(num_blocks, 1, 1),
        block=(threads_per_block, 1, 1),
    )

In [4]:

M, N = 16384, 8192

# Two random fp16 operands and a zero-filled output of the same shape/device/dtype.
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros_like(a)

# Element count summed over all three tensors (two reads + one write); the
# benchmark multiplies this by the element size to get total bytes moved.
num_elements = a.numel() + b.numel() + c.numel()

# Wrap the PyTorch tensors as CuTe tensors via DLPack.
# assumed_align=16 is a promise made to CuTe about the pointer alignment, as
# the name says; the underlying buffers really must be 16-byte aligned.
a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))

In [5]:
# Compile the launcher once for these tensor arguments. KeepPTX is passed as a
# compile option (presumably it retains the generated PTX for inspection —
# confirm against the CuTe DSL documentation).
naive_elementwise_add_compiled = cute.compile[KeepPTX](naive_elementwise_add_launcher, a_,b_,c_)
# Run the compiled kernel; the result is written into `c` through the `c_` view.
naive_elementwise_add_compiled(a_, b_, c_)
# Validate the GPU result against PyTorch's own elementwise sum.
torch.testing.assert_close(c, a+b)

In [13]:
def benchmark(callable, a_, b_, c_, num_elements):
    """Time a compiled kernel and print its effective memory throughput.

    Args:
        callable: Compiled CuTe callable handed to ``cute.testing.benchmark``
            together with ``JitArguments(a_, b_, c_)``. The name shadows the
            builtin but is kept so existing keyword callers keep working.
        a_: First kernel argument; ``a_.element_type.width`` (bits per
            element) is used to size the memory traffic.
        b_: Second kernel argument.
        c_: Third kernel argument.
        num_elements: Total number of elements moved across all tensors
            (reads plus writes), not the size of a single tensor.

    Returns:
        None. The timing and throughput are printed.
    """
    avg_time_us = cute.testing.benchmark(
        callable,
        kernel_arguments=cute.testing.JitArguments(a_, b_, c_),
        warmup_iterations=5,
        iterations=100,
    )

    element_bits = a_.element_type.width

    # Multiply before dividing by 8: flooring bits-per-element to whole bytes
    # first would turn sub-byte element types (width < 8) into 0 bytes and
    # report zero throughput.
    total_bytes = num_elements * element_bits / 8

    # bytes / (us * 1000) == bytes / ns == GB/s (decimal gigabytes).
    achieved_bandwidth = total_bytes / (avg_time_us * 1000)

    print("Performance Metrics:")
    print("-------------------")
    print(f"Kernel execution time: {avg_time_us:.4f} us")
    print(f"Memory throughput: {achieved_bandwidth:.2f} GB/s")

In [14]:
# Time the naive kernel; num_elements counts the elements of A, B and C together.
benchmark(naive_elementwise_add_compiled, a_, b_,c_, num_elements)

Performance Metrics:
-------------------
Kernel execution time: 507.6538 us
Memory throughput: 1586.33 GB/s


In [16]:
@cute.kernel
def vectorized_elementwise_kernel(gA:cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    """Compute ``gC = gA + gB`` with one thread per tile of a pre-tiled tensor.

    The tensors are expected to arrive already tiled by the caller, so that
    the second mode of the shape is the 2-D grid of tiles. The launcher in
    this notebook produces that with ``cute.zipped_divide``. Indexing with
    ``None`` in the first mode and a tile coordinate in the second selects one
    whole tile. There is no bounds check, so the launch must not supply more
    threads than there are tiles.

    Args:
        gA: First tiled input operand; ``gA.shape[1]`` is read as
            (tile rows, tile columns).
        gB: Second tiled input operand, sliced at the same tile coordinate.
        gC: Tiled output tensor, written at the same tile coordinate.
    """
    thread_in_block, _, _ = cute.arch.thread_idx()
    block_in_grid, _, _ = cute.arch.block_idx()
    threads_per_block, _, _ = cute.arch.block_dim()
    linear_idx = block_in_grid * threads_per_block + thread_in_block

    # One thread per tile: split the linear index over the tile grid.
    _, tile_cols = gA.shape[1]
    tile_row = linear_idx // tile_cols
    tile_col = linear_idx % tile_cols

    # Slice out this thread's tile from each operand.
    frag_a = gA[(None, (tile_row, tile_col))]
    frag_b = gB[(None, (tile_row, tile_col))]

    # .load() reads the whole tile into a value, so the add and the store
    # below act on the full tile at once.
    a_val = frag_a.load()
    b_val = frag_b.load()
    gC[(None, (tile_row, tile_col))] = a_val + b_val
  
  

In [29]:
@cute.jit
def vectorized_elementwise_launcher(mA:cute.Tensor, mB:cute.Tensor, mC:cute.Tensor):
    """Tile the operands into per-thread vectors and launch the vectorized kernel.

    Each tensor is divided into ``(1, elems_per_vector)`` tiles, where
    ``elems_per_vector`` is chosen so that one tile spans 128 bits of the
    element type of ``mA`` (8 elements for the 16-bit inputs used in this
    notebook). One thread is launched per tile in a 1-D grid of 256-thread
    blocks.

    The block count uses floor division, so the number of tiles is expected
    to be a multiple of 256; any remainder tiles are not covered by a thread.

    Args:
        mA: First input operand; its element type decides the vector width.
        mB: Second input operand, tiled the same way as ``mA``.
        mC: Output tensor, tiled the same way as ``mA``.
    """
    # 128 bits per thread, matching the 16-byte alignment the tensors are
    # created with (assumed_align=16). Deriving the element count from the
    # dtype keeps the vector 128 bits wide when the element size changes.
    vector_bits = 128
    elems_per_vector = vector_bits // mA.element_type.width
    tiler = (1, elems_per_vector)

    gA = cute.zipped_divide(mA, tiler)
    gB = cute.zipped_divide(mB, tiler)
    gC = cute.zipped_divide(mC, tiler)
    print("[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO]   gA = {gA}")
    print(f"[DSL INFO]   gB = {gB}")
    print(f"[DSL INFO]   gC = {gC}")

    threads_per_block = 256
    # Mode 1 of the tiled tensor is the grid of tiles; one thread handles one tile.
    num_blocks = cute.size(gC, mode=[1]) // threads_per_block
    vectorized_elementwise_kernel(gA, gB, gC).launch(
        grid=(num_blocks, 1, 1),
        block=(threads_per_block, 1, 1),
    )
  
# Fresh operands and a zeroed output buffer for the vectorized run.
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros_like(a)

# Re-wrap the new buffers as CuTe tensors (16-byte alignment promised to CuTe).
a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))

# Compile once for these arguments, then run the compiled kernel.
vectorized_elementwise_compiled = cute.compile[KeepPTX](
    vectorized_elementwise_launcher, a_, b_, c_
)
vectorized_elementwise_compiled(a_, b_, c_)

# Validate the GPU result against PyTorch's own elementwise sum.
torch.testing.assert_close(c, a + b)

[DSL INFO] Tiled Tensors:
[DSL INFO]   gA = tensor<ptr<f16, gmem, align<16>> o ((1,8),(16384,1024)):((0,1),(8192,8))>
[DSL INFO]   gB = tensor<ptr<f16, gmem, align<16>> o ((1,8),(16384,1024)):((0,1),(8192,8))>
[DSL INFO]   gC = tensor<ptr<f16, gmem, align<16>> o ((1,8),(16384,1024)):((0,1),(8192,8))>


In [23]:
# Time the vectorized kernel with the same traffic accounting as the naive run.
# NOTE(review): this cell's execution count (In [23]) is lower than that of the
# cell that compiles vectorized_elementwise_compiled (In [29]), so the recorded
# output may predate the current kernel — restart and run all cells in order
# before comparing these numbers with the naive result.
benchmark(vectorized_elementwise_compiled, a_, b_, c_, num_elements)

Performance Metrics:
-------------------
Kernel execution time: 511.9517 us
Memory throughput: 1573.01 GB/s
