Skip to content

[Example] Add example of blackwell matmul v7#91

Merged
yaoyaoding merged 13 commits into main from
blackwell-gemm
Mar 11, 2026
Merged

[Example] Add example of blackwell matmul v7#91
yaoyaoding merged 13 commits into main from
blackwell-gemm

Conversation

@yaoyaoding
Copy link
Member

This matmul example uses the cluster launch control (CLC) feature of Blackwell GPUs to implement the persistent-block optimization.

       m      n      k   name  latency (ms)       tflops
0   4096   4096   4096  torch      0.144560   950.739843
1   4096   4096   4096  tilus      0.164880   833.569587
2   4096   4096  14336  torch      0.445376  1080.067951
3   4096   4096  14336  tilus      0.482304   997.371638
4   8192   8192   8192  torch      0.891872  1232.813275
5   8192   8192   8192  tilus      0.981024  1120.779544
6  10240  10240  10240  torch      1.697824  1264.844676
7  10240  10240  10240  tilus      1.891312  1135.446528
import os

import pandas
import tilus
import torch
from tilus import RegisterTensor, SharedTensor, float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv

class Pipeline(tilus.Class):
    """Multi-stage producer/consumer pipeline built on two sets of mbarriers.

    Each stage owns one "empty" barrier (the producer waits on it, the
    consumer signals it) and one "full" barrier (the consumer waits on it,
    the producer signals it). The producer and the consumer each track their
    own current stage index and wait phase; the stage index wraps around
    after ``num_stages`` steps and the phase bit is flipped on every wrap.

    The class only performs the *wait* and the bookkeeping. Signalling a
    barrier is left to the caller, who obtains the barrier to signal from
    ``producer_barrier()`` / ``consumer_barrier()``.
    """

    def __init__(
        self,
        num_stages: int,
        producer_arrive_count: int = 1,
        consumer_arrive_count: int = 1,
    ):
        # num_stages:            number of stages (one empty + one full barrier each).
        # producer_arrive_count: value passed to mbarrier.alloc for every full barrier
        #                        (presumably the expected number of arrivals - confirm
        #                        against the tilus mbarrier API).
        # consumer_arrive_count: same, for every empty barrier.
        self.num_stages: int = num_stages
        self.empty_barriers = self.mbarrier.alloc(
            [consumer_arrive_count for _ in range(num_stages)]
        )
        self.full_barriers = self.mbarrier.alloc(
            [producer_arrive_count for _ in range(num_stages)]
        )
        # Both sides start at stage 0, each with its own initial phase as
        # provided by the mbarrier instruction group.
        self.producer_stage: int32 = 0
        self.consumer_stage: int32 = 0
        self.producer_phase: uint32 = self.mbarrier.producer_initial_phase
        self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase

    def producer_acquire(self, scope: str = "cta"):
        # Wait on the empty barrier of the producer's current stage, using the
        # producer's current phase. `scope` is forwarded to mbarrier.wait.
        self.mbarrier.wait(
            barrier=self.empty_barriers[self.producer_stage],
            phase=self.producer_phase,
            scope=scope,
        )

    def producer_barrier(self) -> RegisterTensor:
        # The full barrier of the producer's current stage, i.e. the barrier
        # the producer side is expected to signal once the stage is filled.
        return self.full_barriers[self.producer_stage]

    def producer_advance(self):
        # Move to the next stage (wrapping around) and flip the phase bit
        # whenever the stage index wraps back to 0.
        self.producer_stage = (self.producer_stage + 1) % self.num_stages
        self.producer_phase = self.producer_phase ^ (self.producer_stage == 0)

    def consumer_acquire(self, scope: str = "cta"):
        # Wait on the full barrier of the consumer's current stage, using the
        # consumer's current phase. `scope` is forwarded to mbarrier.wait.
        self.mbarrier.wait(
            barrier=self.full_barriers[self.consumer_stage],
            phase=self.consumer_phase,
            scope=scope,
        )

    def consumer_barrier(self) -> RegisterTensor:
        # The empty barrier of the consumer's current stage, i.e. the barrier
        # the consumer side is expected to signal once the stage is drained.
        return self.empty_barriers[self.consumer_stage]

    def consumer_advance(self):
        # Mirror of producer_advance for the consumer-side stage and phase.
        self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0)


@tilus.autotune("block_m", [256])
@tilus.autotune("block_n, e_block_n", [[64, 16], [128, 16], [256, 16], [256, 32]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("tma_stages", [3, 4, 5, 6])
@tilus.autotune("mma_stages", [1, 2])
class BlackwellMatmulV7(tilus.Script):
    """Persistent, warp-specialized fp16 matmul for clusters of two CTAs.

    The kernel computes ``c = a @ b.T`` with ``a`` viewed as ``[m, k]``,
    ``b`` as ``[n, k]`` and ``c`` as ``[m, n]`` (all float16, float32
    accumulation in tensor memory). After finishing a tile, every worker
    queries a cluster-launch-control (CLC) response to obtain the next block
    index to process and exits its loop once the response is no longer valid.

    Warp roles inside each CTA (8 warps are launched):

    * warp 0      - TMA worker, global memory -> shared memory
    * warp 1      - MMA worker, shared memory -> tensor memory
    * warp 2      - scheduler, issues the CLC ``try_cancel`` requests
    * warps 4..7  - epilogue, tensor memory -> global memory

    No role is assigned to warp 3 in this kernel.

    The constructor arguments are supplied by the ``tilus.autotune``
    decorators above.
    """

    def __init__(
        self,
        block_m: int,
        block_n: int,
        block_k: int,
        tma_stages: int,
        mma_stages: int,
        e_block_n: int,
    ):
        """Store the tile configuration.

        Args:
            block_m: Rows of the output tile handled by one 2-CTA cluster
                (each CTA works on ``block_m // 2`` rows).
            block_n: Columns of the output tile handled by one cluster.
            block_k: Extent of one reduction step along k.
            tma_stages: Number of shared-memory stages in the TMA -> MMA pipeline.
            mma_stages: Number of accumulator stages in the MMA -> epilogue pipeline.
            e_block_n: Number of accumulator columns the epilogue stores per step.
        """
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.e_block_n = e_block_n
        self.tma_stages = tma_stages
        self.mma_stages = mma_stages
        # A single CLC response slot is used (not autotuned).
        self.clc_stages = 1

    def query_clc_response(self, s_clc_response: SharedTensor, pipe: Pipeline):
        # Consume one CLC response from `pipe` and return (is_valid, new_blockIdx).
        #
        # Steps: wait (cluster scope) until the current response slot is full,
        # decode it with clc.query_response, then release the slot by arriving
        # on the consumer (empty) barrier with target_rank=0 and finally advance
        # the consumer side of the pipeline.
        pipe.consumer_acquire(scope="cluster")
        response = s_clc_response[pipe.consumer_stage]
        is_valid, new_blockIdx = self.clc.query_response(response)
        # NOTE(review): presumably orders the shared-memory read of the response
        # before the slot is handed back for the next asynchronous write - confirm.
        self.fence.async_view(space="shared")
        # transaction_bytes=0: this arrival only counts as an arrival and adds
        # no expected transaction bytes.
        self.mbarrier.arrive_and_expect_tx_remote(
            pipe.consumer_barrier(), transaction_bytes=0, target_rank=0
        )
        pipe.consumer_advance()
        return is_valid, new_blockIdx

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        """
        Kernel body: compute c = a @ b.T with clusters of two cooperating CTAs.

        Parameters:
            m_size, n_size, k_size: problem sizes.
            a_ptr: float16 data viewed as [m_size, k_size].
            b_ptr: float16 data viewed as [n_size, k_size]; each shared-memory
                tile of b is transposed when it is handed to the MMA.
            c_ptr: float16 output viewed as [m_size, n_size] (the "D" operand
                in the diagram below).

        Each CTA provides its own slice of A, B, and D.
        CTA0 = CTA with last bit of cluster rank = 0
        CTA1 = CTA with last bit of cluster rank = 1

                            Input B (K, N)
                          ┌───────┬───────┐
                          │  b0   │  b1   │
                          │(K,N/2)│(K,N/2)│
                          │[CTA0] │[CTA1] │
                          └───────┴───────┘
        ┌──────────────┐  ┌───────────────┐
        │  a0 (M/2, K) │  │  d0 (M/2, N)  │
        │  [CTA0]      │  │  [CTA0]       │
        ├──────────────┤  ├───────────────┤
        │  a1 (M/2, K) │  │  d1 (M/2, N)  │
        │  [CTA1]      │  │  [CTA1]       │
        └──────────────┘  └───────────────┘
         Input A (M, K)     Output D (M, N)
        """
        # Launch configuration: two CTAs per block_m rows (paired into
        # clusters of shape [2, 1]) and 8 warps per CTA.
        self.attrs.blocks = [cdiv(m_size, self.block_m) * 2, cdiv(n_size, self.block_n)]
        self.attrs.cluster_blocks = [2, 1]
        self.attrs.warps = 8

        # Local aliases for the tile configuration.
        block_m = self.block_m
        block_n = self.block_n
        block_k = self.block_k
        e_block_n = self.e_block_n
        tma_stages = self.tma_stages
        mma_stages = self.mma_stages
        clc_stages = self.clc_stages

        # Global views of the operands; note that b is laid out as [n, k].
        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])

        # Per-CTA staging buffers: each CTA holds half of the cluster's A tile
        # (block_m // 2 rows) and half of its B tile (block_n // 2 rows of b).
        s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m // 2, block_k])
        s_b = self.shared_tensor(dtype=float16, shape=[tma_stages, block_n // 2, block_k])
        # float32 accumulators in tensor memory, one per MMA stage.
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[mma_stages, block_m // 2, block_n], cta_group=2
        )

        # One 4 x int32 (16-byte) slot per CLC stage for the CLC response.
        s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4])

        # tma_pipe: TMA worker -> MMA worker
        # mma_pipe: MMA worker -> epilogue
        # clc_pipe: scheduler  -> every worker that queries the CLC response
        tma_pipe = Pipeline(tma_stages)
        mma_pipe = Pipeline(
            mma_stages, consumer_arrive_count=128
        )  # 4 warps (epilogue warps)
        clc_pipe = Pipeline(
            clc_stages, consumer_arrive_count=224 * 2
        )  # 7 warps * 2 blocks

        cta_rank = self.cluster.blockRank

        # NOTE(review): presumably ensures the barriers and tensor memory set up
        # above are visible to both CTAs of the cluster before use - confirm.
        self.cluster_sync()

        with self.single_warp(0):  # tma worker (gmem -> smem)
            # First tile comes from the launch-time block index; later tiles
            # come from the CLC response (see the end of the loop body).
            offset_m_a = self.blockIdx.x * (block_m // 2)
            offset_n_b = self.blockIdx.y * block_n + cta_rank * (block_n // 2)
            while True:
                for offset_k in self.range(0, k_size, block_k, unroll=tma_stages):
                    tma_pipe.producer_acquire()
                    mbarrier = tma_pipe.producer_barrier()
                    if cta_rank == 0:
                        with self.single_thread():
                            # the mbarrier on CTA0 will track the completion of both CTAs' loading
                            transaction_bytes = (s_a[0].nbytes + s_b[0].nbytes) * 2
                            self.mbarrier.arrive_and_expect_tx(
                                mbarrier, transaction_bytes
                            )
                    else:
                        # get the mbarrier address in the CTA0 to signal
                        mbarrier = self.cluster.map_shared_addr(mbarrier, target_rank=0)
                    with self.single_thread():
                        # Load this CTA's slice of A and of B for the current
                        # k-step into the current pipeline stage.
                        self.tma.global_to_shared(
                            src=g_a,
                            dst=s_a[tma_pipe.producer_stage],
                            offsets=[offset_m_a, offset_k],
                            mbarrier=mbarrier,
                            cta_group=2,
                        )
                        self.tma.global_to_shared(
                            src=g_b,
                            dst=s_b[tma_pipe.producer_stage],
                            offsets=[offset_n_b, offset_k],
                            mbarrier=mbarrier,
                            cta_group=2,
                        )
                    tma_pipe.producer_advance()

                # Fetch the next tile, or stop when the response is not valid.
                is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe)
                if not is_valid:
                    break
                # NOTE(review): new_blockIdx.x appears to be the block index of
                # the cluster's first CTA, hence the "+ cta_rank" - confirm.
                offset_m_a = (new_blockIdx.x + cta_rank) * (block_m // 2)
                offset_n_b = new_blockIdx.y * block_n + cta_rank * (block_n // 2)

        with self.single_warp(1):  # mma worker (smem -> tmem)
            while True:
                with self.single_thread():
                    # Only CTA0 issues the MMA instructions (cta_group=2).
                    if cta_rank == 0:
                        # Wait until the epilogue has released this accumulator stage.
                        mma_pipe.producer_acquire()
                        for offset_k in self.range(0, k_size, block_k, unroll=mma_stages):
                            tma_pipe.consumer_acquire()
                            # enable_input_d is False only for the first k-step, so
                            # presumably the accumulator is overwritten there and
                            # accumulated into afterwards - confirm.
                            self.tcgen05.mma(
                                s_a[tma_pipe.consumer_stage],
                                s_b[tma_pipe.consumer_stage].transpose(),
                                t_acc[mma_pipe.producer_stage],
                                enable_input_d=offset_k != 0,
                                cta_group=2,
                            )
                            # Signal the stage's empty barrier (multicast mask
                            # 0b11) so the TMA workers can refill the stage.
                            self.tcgen05.commit(
                                mbarrier=tma_pipe.consumer_barrier(),
                                cta_group=2,
                                multicast_mask=0b11,
                            )
                            tma_pipe.consumer_advance()
                        # Signal the accumulator's full barrier for the epilogue.
                        self.tcgen05.commit(
                            mbarrier=mma_pipe.producer_barrier(),
                            cta_group=2,
                            multicast_mask=0b11,
                        )
                        mma_pipe.producer_advance()

                # Queried outside single_thread(): the whole warp takes part in
                # releasing the CLC response slot.
                is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe)
                if not is_valid:
                    break

        with self.single_warp(2):  # scheduler
            while True:
                # Only CTA0 requests new work; the 16-byte response is multicast.
                if cta_rank == 0:
                    clc_pipe.producer_acquire(
                        scope="cluster"
                    )  # the peer CTA also arrives on this barrier, so a 'cluster'-scoped acquire is needed
                    self.mbarrier.arrive_and_expect_tx_multicast(
                        clc_pipe.producer_barrier(),
                        transaction_bytes=16,
                        multicast_mask=0b11,
                    )
                    with self.single_thread():
                        self.clc.try_cancel(
                            s_clc_response[clc_pipe.producer_stage],
                            mbarrier=clc_pipe.producer_barrier(),
                            multicast=True,
                        )
                    clc_pipe.producer_advance()

                # The scheduler warp consumes the response as well, so that it
                # stops together with the other workers.
                is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe)
                if not is_valid:
                    break

        with self.warp_group(warp_begin=4, num_warps=4):  # epilogue (tmem -> gmem)
            # Staging buffer for one e_block_n-wide slice of the output tile.
            s_c = self.shared_tensor(dtype=float16, shape=[block_m // 2, self.e_block_n])
            offset_m_c = self.blockIdx.x * (block_m // 2)
            offset_n_c = self.blockIdx.y * block_n
            while True:
                # Wait until the MMA worker has finished the current accumulator.
                mma_pipe.consumer_acquire()

                # Store the accumulator in slices of e_block_n columns:
                # tmem -> registers -> shared (as float16) -> global via TMA.
                for e_offset_n in range(0, block_n, e_block_n):
                    t_acc_slice = self.tcgen05.slice(
                        t_acc[mma_pipe.consumer_stage],
                        offsets=[0, e_offset_n],
                        shape=[block_m // 2, e_block_n],
                        dims=[0, 1],
                    )
                    r_acc = self.tcgen05.load(t_acc_slice)
                    self.tcgen05.wait_load()
                    self.store_shared(s_c, r_acc.to(float16))
                    self.fence.async_view(space="shared")
                    self.sync()
                    with self.single_thread():
                        self.tma.shared_to_global(
                            s_c,
                            g_c,
                            offsets=[offset_m_c, offset_n_c + e_offset_n],
                            dims=[0, 1],
                        )
                        self.tma.commit_group()
                        self.tma.wait_group(n=0)
                    # s_c is reused by the next slice, so wait for the store above.
                    self.sync()

                # Hand the accumulator stage back to the MMA worker.
                self.mbarrier.arrive(mma_pipe.consumer_barrier())
                mma_pipe.consumer_advance()

                is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe)
                if not is_valid:
                    break
                offset_m_c = (new_blockIdx.x + cta_rank) * (block_m // 2)
                offset_n_c = new_blockIdx.y * block_n

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)


def main(bench=True):
    """Check the tilus kernel against ``torch.matmul`` and optionally benchmark both.

    For every problem size, random 0/1 float16 operands are created on the
    CUDA device, the kernel result is compared with ``torch.matmul(a, b.T)``,
    and, when ``bench`` is true, both implementations are timed. The collected
    latencies (ms) and throughputs (TFLOPS) are printed as a pandas table.

    Args:
        bench: Whether to run and report the benchmark after the check.
    """
    kernel = BlackwellMatmulV7()

    columns = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    records: list = []
    problem_sizes = [
        (4096, 4096, 4096),
        (4096, 4096, 14336),
        (8192, 8192, 8192),
        (10240, 10240, 10240),
    ]

    for m_size, n_size, k_size in problem_sizes:
        print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
        a = torch.randint(0, 2, size=(m_size, k_size), dtype=torch.float16, device="cuda")
        b = torch.randint(0, 2, size=(n_size, k_size), dtype=torch.float16, device="cuda")
        c_actual = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")
        c_expected = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

        # Correctness check: kernel output vs. the torch reference.
        kernel(m_size, n_size, k_size, a, b, c_actual)
        torch.cuda.synchronize()

        torch.matmul(a, b.T, out=c_expected)
        torch.cuda.synchronize()

        torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

        if not bench:
            continue

        # The lambdas are invoked within this iteration, so capturing the
        # loop variables by reference is safe here.
        candidates = {
            "torch": lambda: torch.matmul(a, b.T, out=c_expected),
            "tilus": lambda: kernel(m_size, n_size, k_size, a, b, c_actual),
        }
        flops = 2 * m_size * n_size * k_size
        for name, func in candidates.items():
            latency = benchmark_func(func, warmup=5, repeat=20)
            # latency is in milliseconds: FLOP / ms * 1e-9 == TFLOP / s
            tflops = flops / latency * 1e-9
            records.append([m_size, n_size, k_size, name, latency, tflops])

    if bench:
        report = pandas.DataFrame(records, columns=columns)
        print(report)


# Script entry point: run the correctness check and the benchmark.
if __name__ == "__main__":
    main(bench=True)

@yaoyaoding yaoyaoding force-pushed the blackwell-gemm branch 2 times, most recently from 4fe2772 to ddc79f8 Compare March 11, 2026 23:12
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@yaoyaoding yaoyaoding mentioned this pull request Mar 11, 2026
17 tasks
@yaoyaoding yaoyaoding merged commit 24ec3f6 into main Mar 11, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant