In [1]:
import torch
from functools import partial
from typing import List

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass.cute import KeepPTX, KeepCUBIN
import numpy as np

In [88]:
@cute.kernel
def patterned_access(gA: cute.Tensor, gB: cute.Tensor):
  """Copy one element per thread from ``gA`` to ``gB`` at a scrambled, strided index.

  The thread's lane (``tidx % 32``) is remapped through ``(7 * lane + 5) % 32``
  and then multiplied by 32, while the warp number (``tidx // 32``) supplies the
  low part of the index. Neighbouring lanes of a warp therefore address
  elements that are multiples of 32 elements apart instead of adjacent ones.

  Because 7 and 32 share no common factor the lane remap is a permutation, so
  with a 1024-thread block (32 warps, which is how ``patterned_launcher``
  launches it) every offset inside the block's 1024-element window is hit
  exactly once. No bounds check is performed on the computed index.
  """
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
  bdimx, _, _ = cute.arch.block_dim()

  lane = tidx % 32
  warp = tidx // 32

  # Remap the lane to another slot in 0..31, then stride it by 32 elements.
  scrambled_lane = (lane * 7 + 5) % 32
  # First element owned by this block.
  block_base = bdimx * bidx

  idx = scrambled_lane * 32 + warp + block_base
  gB[idx] = gA[idx]
  
@cute.kernel
def coalesced_access(gA: cute.Tensor, gB: cute.Tensor):
  """Copy one element per thread from ``gA`` to ``gB`` at the thread's linear index.

  The index is assembled from the lane (``tidx % 32``) and warp (``tidx // 32``)
  as ``lane + 32 * warp``, which is just ``tidx`` again, plus the block's base
  offset ``bdimx * bidx``. Consecutive threads therefore touch consecutive
  elements. No bounds check is performed on the computed index.
  """
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
  bdimx, _, _ = cute.arch.block_dim()

  lane = tidx % 32
  warp = tidx // 32

  # Offset of this thread inside its block, rebuilt from (warp, lane).
  offset_in_block = lane + warp * 32
  # First element owned by this block.
  block_base = bdimx * bidx

  idx = offset_in_block + block_base
  gB[idx] = gA[idx]
  

In [89]:
@cute.jit
def coalesced_launcher(mA, mB):
  """Launch ``coalesced_access`` on ``mA``/``mB`` using 1024-thread blocks.

  The grid contains ``mA.shape[0] // 1024`` blocks along x. The division
  rounds down, so when the length is not a multiple of 1024 the trailing
  elements are not assigned to any thread.
  """
  threads_per_block = 1024
  num_blocks = mA.shape[0] // threads_per_block
  coalesced_access(mA, mB).launch(
    grid=(num_blocks, 1, 1),
    block=(threads_per_block, 1, 1),
  )
  
@cute.jit
def patterned_launcher(mA, mB):
  """Launch ``patterned_access`` on ``mA``/``mB`` using 1024-thread blocks.

  The grid contains ``mA.shape[0] // 1024`` blocks along x. The division
  rounds down, so when the length is not a multiple of 1024 the trailing
  elements are not assigned to any thread. The kernel's index arithmetic
  uses a fixed stride of 32, which lines up with the 32 warps of a
  1024-thread block chosen here.
  """
  threads_per_block = 1024
  num_blocks = mA.shape[0] // threads_per_block
  patterned_access(mA, mB).launch(
    grid=(num_blocks, 1, 1),
    block=(threads_per_block, 1, 1),
  )

In [97]:
# Number of float32 elements in every tensor used below.
num_elems = 4096*4096*16

# Sources: standard-normal samples shifted by 0.003 (a is drawn before b).
a = torch.randn(num_elems, device="cuda", dtype=torch.float32) + 0.003
b = torch.randn(num_elems, device="cuda", dtype=torch.float32) + 0.003

# Destinations start zero-filled: c receives the patterned copy, d the coalesced one.
c = torch.zeros(num_elems, device="cuda", dtype=torch.float32)
d = torch.zeros(num_elems, device="cuda", dtype=torch.float32)

# Wrap each torch tensor for CuTe, passing the same assumed alignment of 128.
a_, b_, c_, d_ = (from_dlpack(t, assumed_align=128) for t in (a, b, c, d))

# Compile both launchers ahead of time (KeepCUBIN option), patterned first.
patterned_compiled = cute.compile[KeepCUBIN](patterned_launcher, a_, c_)
coalesced_compiled = cute.compile[KeepCUBIN](coalesced_launcher, b_, d_)


In [95]:
def mem_benchmark(callable, a_, b_, num_elements, accesses_per_element=2,
                  warmup_iterations=5, iterations=200):
    """Benchmark a compiled kernel and print its effective memory throughput.

    Args:
        callable: Compiled function handed to ``cute.testing.benchmark``.
            The name shadows the ``callable`` builtin inside this function;
            it is kept so existing keyword callers keep working.
        a_: First kernel argument. Its ``element_type.width`` (in bits) is
            used as the size of one element.
        b_: Second kernel argument.
        num_elements: Number of elements the kernel processes.
        accesses_per_element: Memory accesses counted for each element. The
            default of 2 models a copy: one read of the source plus one write
            of the destination.
        warmup_iterations: Untimed runs forwarded to the benchmark helper.
        iterations: Timed runs forwarded to the benchmark helper.

    Returns:
        None. The timing and the derived throughput are printed.
    """
    # The value returned by the benchmark helper is treated as an average
    # time in microseconds (see the unit conversion below).
    avg_time_us = cute.testing.benchmark(
        callable,
        kernel_arguments=cute.testing.JitArguments(a_, b_),
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    # Calculate metrics
    # ----------------
    dtype = a_.element_type

    # Bytes covered by one pass over the data. Multiply by the bit width
    # before dividing by 8 so element types narrower than one byte do not
    # truncate to zero bytes per element.
    total_bytes = num_elements * dtype.width // 8

    # Total traffic = accesses_per_element passes over the data.
    # bytes / (us * 1000) == bytes / ns == GB/s.
    achieved_bandwidth = (accesses_per_element * total_bytes) / (avg_time_us * 1000)

    # Print results
    # ------------
    print("Performance Metrics:")
    print("-------------------")
    print(f"Kernel execution time: {avg_time_us:.4f} us")
    print(f"Memory throughput: {achieved_bandwidth:.2f} GB/s")

In [96]:
# Time both compiled launchers on their own source/destination pair,
# patterned first and coalesced second.
for compiled_fn, src, dst in (
    (patterned_compiled, a_, c_),
    (coalesced_compiled, b_, d_),
):
    mem_benchmark(compiled_fn, src, dst, num_elems)

Performance Metrics:
-------------------
Kernel execution time: 3027.2070 us
Memory throughput: 709.39 GB/s
Performance Metrics:
-------------------
Kernel execution time: 1428.8667 us
Memory throughput: 1502.93 GB/s


In [93]:
# Tally the entries of `c` that are still exactly zero. `c` starts zero-filled
# and is the destination of the patterned kernel, so any entry that kernel
# never wrote stays zero (assuming no source value is exactly 0.0).
count = c.eq(0).sum().item()

In [94]:
# Show the tally of zero-valued entries in `c`; 0 means none were left at zero.
count

0