In [1]:
import torch
from functools import partial

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

In [2]:
@cute.kernel
def naive_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    """Compute ``gC = gA + gB`` with one element handled per thread.

    Each thread flattens its (block, thread) position into a single linear
    index and converts that index into a ``(row, col)`` coordinate using the
    shape of ``gA``.

    Args:
        gA: First input, indexed as a 2D tensor.
        gB: Second input; assumed to have the same shape as ``gA``
            (not checked here).
        gC: Output; assumed to have the same shape as ``gA`` (not checked here).
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdimx, _, _ = cute.arch.block_dim()

    # Linear index of this thread across the whole 1D grid.
    thread_idx = bidx * bdimx + tidx

    m, n = gA.shape
    ni = thread_idx % n
    mi = thread_idx // n

    # A grid sized by rounding up can contain threads past the last element
    # when m * n is not a multiple of the block size; those threads must not
    # touch memory.
    if thread_idx < m * n:
        a_val = gA[mi, ni]
        b_val = gB[mi, ni]

        gC[mi, ni] = a_val + b_val


@cute.jit
def naive_elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    """Launch ``naive_elementwise_add_kernel`` over every element of ``mA``.

    Uses a 1D grid of 256-thread blocks with one thread per element.

    Args:
        mA: First input; its 2D shape determines the launch size.
        mB: Second input, forwarded to the kernel unchanged.
        mC: Output, forwarded to the kernel unchanged.
    """
    num_threads_per_block = 256

    m, n = mA.shape

    # Ceiling division so a partially filled last block is still launched.
    num_blocks = (m * n + num_threads_per_block - 1) // num_threads_per_block

    # Grid and block are both given as explicit 3-tuples, matching the
    # vectorized launch elsewhere in this notebook.
    naive_elementwise_add_kernel(mA, mB, mC).launch(
        grid=(num_blocks, 1, 1),
        block=(num_threads_per_block, 1, 1),
    )

In [3]:
# Problem size (rows, cols). A smaller alternative is 4096 x 4096.
M, N = 40960, 8192

# Two random inputs plus an output buffer. The buffer also starts out random,
# so a passing comparison below means the kernel really wrote the result.
a, b, c = (
    torch.randn(M, N, device="cuda", dtype=torch.float32) for _ in range(3)
)

# Wrap the torch tensors for the CuTe DSL via DLPack.
a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))

# Compile once for these arguments, then run the compiled function.
naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)
naive_elementwise_add_(a_, b_, c_)

# Compare against PyTorch's own addition.
torch.testing.assert_close(c, a + b)

In [4]:
def benchmark(callable, *, num_warmups, num_iters, num_elements=None, bytes_per_element=4):
    """Time ``callable`` on the current CUDA stream and print throughput figures.

    The callable is run ``num_warmups`` times untimed, then ``num_iters`` times
    between two CUDA events. The average time per call is printed together with
    a GFLOPS figure (one operation per element) and an effective bandwidth
    figure (three tensors of ``num_elements`` elements each).

    Args:
        callable: Zero-argument function that enqueues the work to measure.
            (The name shadows the builtin; it is kept for compatibility.)
        num_warmups: Number of untimed calls made before measuring.
        num_iters: Number of timed calls; must be positive.
        num_elements: Element count of one tensor. When omitted, falls back to
            ``a.numel()`` of the module-level tensor ``a`` (the original
            behavior), which is only correct if ``a`` is the tensor being
            benchmarked.
        bytes_per_element: Size of one element in bytes (4 for float32).

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    # Validate before touching CUDA so a bad argument fails fast instead of
    # surfacing as a ZeroDivisionError after the GPU work has run.
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    if num_elements is None:
        # Backward-compatible fallback to notebook-global state.
        num_elements = a.numel()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()

    for _ in range(num_warmups):
        callable()

    start_event.record(stream=torch.cuda.current_stream())
    for _ in range(num_iters):
        callable()
    end_event.record(stream=torch.cuda.current_stream())
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()

    elapsed_time = start_event.elapsed_time(end_event)  # milliseconds
    avg_time = elapsed_time / num_iters
    avg_seconds = avg_time / 1000
    gflops = num_elements / avg_seconds / 1e9
    # Two tensors read plus one written per call.
    bytes_moved = 3 * num_elements * bytes_per_element

    print(f"Average execution time: {avg_time:.4f}ms")
    print(f"Performance (GFLOPS): {gflops:.4f}GFLOPS")
    print(f"Effective Memory Bandwidth: {bytes_moved / avg_seconds / 1e9:.2f} GB/s")

# Time the compiled naive kernel: 5 untimed warm-up calls, then 100 timed calls.
benchmark(
    partial(naive_elementwise_add_, a_, b_, c_),
    num_warmups=5,
    num_iters=100,
)

Average execution time: 1.1835ms
Performance (GFLOPS): 283.5213GFLOPS
Effective Memory Bandwidth: 3402.26 GB/s


### 向量化的LD/ST

#### cute_divide_test

In [5]:
@cute.jit
def devide_layout_test(
    mA: cute.Tensor
):
    """Print the result of CuTe's four divide operations applied to ``mA``.

    ``mA`` is divided with the fixed tiler ``(2, 4)`` using ``logical_divide``,
    ``zipped_divide``, ``flat_divide`` and ``tiled_divide``; each resulting
    tensor is printed so their layouts can be compared side by side.

    Args:
        mA: Tensor to divide.
    """
    tiler = (2, 4)

    # Apply the same tiler with each divide variant.
    logical = cute.logical_divide(mA, tiler=tiler)
    zipped = cute.zipped_divide(mA, tiler=tiler)
    flat = cute.flat_divide(mA, tiler=tiler)
    tiled = cute.tiled_divide(mA, tiler=tiler)

    print("[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO] Tiler: {tiler}")
    print(f"[DSL INFO] logical_divide  lA = {logical}")
    # Separator line between the logical_divide result and the other variants.
    print("=" * 100)
    print(f"[DSL INFO] zipped_divide  zA = {zipped}")
    print(f"[DSL INFO] flatten_divide  fA = {flat}")
    print(f"[DSL INFO] tiled_divide  tA = {tiled}")

# A small 16 x 32 tensor keeps the printed layouts easy to read.
# NOTE: this rebinds the notebook-level names M, N, a and a_.
M, N = 16, 32
a = torch.randn(M, N, device="cuda", dtype=torch.float32)
a_ = from_dlpack(a, assumed_align=16)

print(f"Tensor shape {a.shape} Stride: {a.stride()}")
devide_layout_test(a_)

Tensor shape torch.Size([16, 32]) Stride: (32, 1)
[DSL INFO] Tiled Tensors:
[DSL INFO] Tiler: (2, 4)
[DSL INFO] logical_divide  lA = tensor<ptr<f32, gmem, align<16>> o ((2,8),(4,8)):((32,64),(1,4))>
[DSL INFO] zipped_divide  zA = tensor<ptr<f32, gmem, align<16>> o ((2,4),(8,8)):((32,1),(64,4))>
[DSL INFO] flatten_divide  fA = tensor<ptr<f32, gmem, align<16>> o (2,4,8,8):(32,1,64,4)>
[DSL INFO] tiled_divide  tA = tensor<ptr<f32, gmem, align<16>> o ((2,4),8,8):((32,1),64,4)>


In [None]:
@cute.kernel
def vector_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    """Compute ``gC = gA + gB`` with one tile handled per thread.

    The arguments are expected to be the zipped-divided tensors built by
    ``vectorized_elementwise_add``: mode 0 is the per-thread tile and mode 1 is
    the 2D grid of tiles. Each thread turns its linear index into a tile
    coordinate ``(mi, ni)``, loads that whole tile from ``gA`` and ``gB``, and
    stores the sum to the same tile of ``gC``.

    Args:
        gA: First input, zipped-divided into tiles.
        gB: Second input; assumed to have the same layout as ``gA``
            (not checked here).
        gC: Output; assumed to have the same layout as ``gA`` (not checked here).
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdimx, _, _ = cute.arch.block_dim()

    # Linear index of this thread across the whole 1D grid.
    thread_idx = bidx * bdimx + tidx

    # Layout produced by cute.zipped_divide(mA, (1, 4)), shown for a
    # 40960 x 4096 row-major input:
    #
    #     shape :  ((1, 4), (40960, 1024))
    #     stride:  ((0, 1), ( 4096,    4))
    #                 ^          ^
    #           tile mode    tile-grid mode
    #         (sliced with   (indexed with
    #              None)        (mi, ni))
    #
    # Mode 1 is the grid of tiles; each thread owns one entry of it.
    m, n = gA.shape[1]
    print(f"THread domain m={m} , n={n}")
    ni = thread_idx % n
    mi = thread_idx // n

    # Indexing with (None, (mi, ni)) keeps the whole tile mode and selects one
    # tile of the grid, leaving a tensor of shape ((1, 4)) with stride ((0, 1)).
    # Slicing only builds a view; no memory is accessed here.
    print(f"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}")
    print(f"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}")

    # The launcher rounds the grid up, so threads past the last tile must not
    # touch memory.
    if thread_idx < m * n:
        # .load() reads the whole tile at once, giving the compiler the
        # opportunity to use a single wide memory access.
        a_val = gA[(None, (mi, ni))].load()
        b_val = gB[(None, (mi, ni))].load()

        gC[(None, (mi, ni))] = a_val + b_val


@cute.jit
def vectorized_elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor
):
    """Tile the tensors into ``(1, 4)`` pieces and launch the vectorized kernel.

    Each tensor is zipped-divided so that mode 0 is a ``(1, 4)`` tile and
    mode 1 is the grid of tiles. One thread is launched per tile, in
    256-thread blocks.

    Args:
        mA: First input (2D).
        mB: Second input; assumed to have the same shape as ``mA``.
        mC: Output; assumed to have the same shape as ``mA``.

    Note:
        Presumably the last dimension must be divisible by 4 for the ``(1, 4)``
        tiling to cover the tensor exactly -- TODO confirm for other shapes.
    """
    threads_per_block = 256

    # For a 40960 x 4096 row-major input the divided layout is
    #     ((1, 4), (40960, 1024)) : ((0, 1), (4096, 4))
    # i.e. a 4-element contiguous tile per entry of a 40960 x 1024 tile grid.
    gA = cute.zipped_divide(mA, (1, 4))
    gB = cute.zipped_divide(mB, (1, 4))
    gC = cute.zipped_divide(mC, (1, 4))
    print("[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO] gA = {gA}")
    print(f"[DSL INFO] gB = {gB}")
    print(f"[DSL INFO] gC = {gC}")

    # Ceiling division: a floor division here would silently skip the trailing
    # tiles whenever the tile count is not a multiple of threads_per_block.
    # The kernel guards against the surplus threads this may create.
    num_tiles = cute.size(gC, mode=[1])
    num_blocks = (num_tiles + threads_per_block - 1) // threads_per_block

    vector_elementwise_add_kernel(gA, gB, gC).launch(
        grid=(num_blocks, 1, 1),
        block=(threads_per_block, 1, 1)
    )

In [11]:
# Problem size (rows, cols) for the vectorized kernel.
M, N = 40960, 4096

# Random inputs and a zero-initialized output buffer.
a, b = (
    torch.randn(M, N, device="cuda", dtype=torch.float32) for _ in range(2)
)
c = torch.zeros(M, N, device="cuda", dtype=torch.float32)

# Wrap the torch tensors for the CuTe DSL via DLPack.
a_, b_, c_ = (from_dlpack(t, assumed_align=16) for t in (a, b, c))

# Compile once for these arguments, then run the compiled function.
compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)
compiled_func(a_, b_, c_)

# Compare against PyTorch's own addition before measuring anything.
torch.testing.assert_close(c, a + b)

# 5 untimed warm-up calls, then 100 timed calls.
benchmark(
    partial(compiled_func, a_, b_, c_),
    num_warmups=5,
    num_iters=100,
)

[DSL INFO] Tiled Tensors:
[DSL INFO] gA = tensor<ptr<f32, gmem, align<16>> o ((1,4),(40960,1024)):((0,1),(4096,4))>
[DSL INFO] gB = tensor<ptr<f32, gmem, align<16>> o ((1,4),(40960,1024)):((0,1),(4096,4))>
[DSL INFO] gC = tensor<ptr<f32, gmem, align<16>> o ((1,4),(40960,1024)):((0,1),(4096,4))>
THread domain m=40960 , n=1024
[DSL INFO] sliced gA = tensor<ptr<f32, gmem, align<16>> o ((1,4)):((0,1))>
[DSL INFO] sliced gB = tensor<ptr<f32, gmem, align<16>> o ((1,4)):((0,1))>
Average execution time: 0.4593ms
Performance (GFLOPS): 365.2569GFLOPS
Effective Memory Bandwidth: 4383.08 GB/s
