90 changes: 46 additions & 44 deletions src/tilegym/ops/cutile/group_gemm.py
@@ -2,7 +2,10 @@
#
# SPDX-License-Identifier: MIT

from types import SimpleNamespace

import cuda.tile as ct
import cuda.tile_experimental as ct_experimental
import torch

from tilegym.backend import register_impl
@@ -99,6 +102,46 @@ def group_gemm_kernel(
        last_problem_end = last_problem_end + num_tiles


def _group_gemm_autotune_configs():
    """
    Iterator of autotune configurations for the group GEMM kernel.
    """
    gpu_capability = torch.cuda.get_device_capability()
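    # Compute capability 12.x devices get smaller single-CTA tile shapes;
    # other architectures fall back to the larger two-CTA configuration below.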
    if gpu_capability in [(12, 0), (12, 1)]:
        yield SimpleNamespace(TILE_M=64, TILE_N=128, TILE_K=128, num_ctas=1, occupancy=1)
        yield SimpleNamespace(TILE_M=128, TILE_N=128, TILE_K=128, num_ctas=1, occupancy=1)
        yield SimpleNamespace(TILE_M=128, TILE_N=128, TILE_K=64, num_ctas=1, occupancy=1)
    else:
        yield SimpleNamespace(TILE_M=256, TILE_N=256, TILE_K=64, num_ctas=2, occupancy=1)


def cutile_autotune_group_gemm(stream, group_A, group_B, group_C, transpose_b, device):
    """Autotune group GEMM kernel."""
    NUM_SMS = torch.cuda.get_device_properties(device).multi_processor_count
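    # Persistent-style launch: grid_fn sizes the grid as NUM_SMS // num_ctas * occupancy
    # blocks so the kernel fills the available SMs; the same value is passed to the
    # kernel as its scheduling stride in args_fn.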

    ct_experimental.autotune_launch(
        stream,
        grid_fn=lambda cfg: (NUM_SMS // cfg.num_ctas * cfg.occupancy, 1, 1),
        kernel=group_gemm_kernel,
        args_fn=lambda cfg: (
            group_A,
            group_B,
            group_C,
            cfg.TILE_M,
            cfg.TILE_N,
            cfg.TILE_K,
            NUM_SMS // cfg.num_ctas * cfg.occupancy,
            transpose_b,
        ),
        hints_fn=lambda cfg: {
            "num_ctas": cfg.num_ctas,
            "occupancy": cfg.occupancy,
        },
        search_space=_group_gemm_autotune_configs,
    )
    return group_C


def group_gemm(
    group_A,
    group_B,
@@ -115,26 +158,6 @@ def group_gemm(
    device = group_A[0].device
    dtype = group_A[0].dtype

    # Kernel configuration
    default_configs = {
        "TILE_M": 128,
        "TILE_N": 128,
        "TILE_K": 64,
        "num_ctas": None,  # Let compiler auto-pick
    }
    user_cfg = kwargs.get("kernel_configs")
    if user_cfg is None:
        kernel_configs = default_configs
    else:
        kernel_configs = {**default_configs, **user_cfg}
    TILE_M = kernel_configs.get("TILE_M")
    TILE_N = kernel_configs.get("TILE_N")
    TILE_K = kernel_configs.get("TILE_K")
    num_ctas = kernel_configs.get("num_ctas", None)
    occupancy = kernel_configs.get("occupancy", None)

    NUM_SMS = torch.cuda.get_device_properties(device).multi_processor_count

    # Create output tensors
    group_C = []
    for A, B in zip(group_A, group_B):
@@ -143,30 +166,9 @@
        C = torch.empty((M, N), device=device, dtype=dtype)
        group_C.append(C)

    kernel = group_gemm_kernel
    # When num_ctas is specified, adjust grid size to account for multiple CTAs per SM
    num_ctas_for_grid = num_ctas if num_ctas is not None else 1
    grid_size = NUM_SMS // num_ctas_for_grid
    grid = (grid_size,)

    logger.debug(f"[cuTile] group_gemm launching with grid={grid}, num_ctas={num_ctas}, NUM_SMS={NUM_SMS}")

    ct.launch(
        torch.cuda.current_stream(),
        grid,
        kernel,
        (
            group_A,
            group_B,
            group_C,
            TILE_M,
            TILE_N,
            TILE_K,
            grid_size,  # Use adjusted grid size for persistent scheduling stride
            transpose_b,
        ),
    )

    # Autotune mode
    stream = torch.cuda.current_stream()
    cutile_autotune_group_gemm(stream, group_A, group_B, group_C, transpose_b, device)
    return group_C


126 changes: 126 additions & 0 deletions tests/benchmark/bench_group_gemm.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: MIT
import torch
import triton
import triton.testing

import tilegym
from tilegym.backend import is_backend_available
from tilegym.backend import register_impl

# Available backends for benchmarking
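# Each entry is (backend key, display name, (plot color, line style)); backends that
# are unavailable are recorded as None and filtered out in get_supported_backends().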
ALL_BACKENDS = [
    ("cutile", "CuTile", ("orange", "-")) if is_backend_available("cutile") else None,
    ("torch", "PyTorch", ("green", "-")),
]


def get_supported_backends(datatype):
    """Filter backends based on datatype support and availability"""
    if datatype == torch.float8_e5m2:
        return [p for p in ALL_BACKENDS if p is not None and p[0] != "torch"]
    else:
        return [p for p in ALL_BACKENDS if p is not None]


def reference_group_gemm(group_A: list, group_B: list, transpose_b: bool = False):
    """Reference implementation using PyTorch"""
    group_C = []
    for i in range(len(group_A)):
        A = group_A[i]
        B = group_B[i]
        if transpose_b:
            B = B.transpose(-2, -1)
        C = torch.matmul(A, B)
        group_C.append(C)
    return group_C


register_impl("group_gemm", "torch")(reference_group_gemm)
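# reference_group_gemm is registered as the "torch" backend, so the benchmark
# can request it via tilegym.ops.group_gemm(..., backend="torch").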


def create_benchmark_config(datatype, num_groups, transpose_b):
    """Create a benchmark configuration for the given datatype and backends"""
    available_backends = get_supported_backends(datatype)
    if not available_backends:
        return None

    backends, names, styles = zip(*available_backends)
    dtype_name = str(datatype).split(".")[-1]

    return triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[2**i for i in range(10, 14)],
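        # i.e. M = N = K swept over 1024, 2048, 4096, 8192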
        line_arg="backend",
        line_vals=list(backends),
        line_names=list(names),
        styles=list(styles),
        xlabel="M/N/K",
        ylabel="TFLOPS",
        plot_name=f"group-gemm-num_groups{num_groups}-transpose{transpose_b}-{dtype_name}-TFLOPS",
        args={
            "num_groups": num_groups,
            "transpose_b": transpose_b,
            "datatype": datatype,
        },
    )


@triton.testing.perf_report(
    [
        cfg
        for cfg in (
            create_benchmark_config(datatype, num_groups, transpose_b)
            for datatype in [torch.float16, torch.float8_e5m2]
            for num_groups in [4, 16]
            for transpose_b in [False, True]
        )
        # Skip datatypes with no available backend (create_benchmark_config returns None)
        if cfg is not None
    ]
)
def bench_group_gemm(
    M,
    N,
    K,
    num_groups,
    transpose_b,
    backend,
    datatype,
    device="cuda",
):
    # Create input tensors
    group_A = []
    group_B = []

    for i in range(num_groups):
        A = torch.rand((M, K), device=device, dtype=torch.half).normal_(std=0.3).to(datatype)
        B_shape = (N, K) if transpose_b else (K, N)
        B = torch.rand(B_shape, device=device, dtype=torch.half).normal_(std=0.3).to(datatype)

        group_A.append(A)
        group_B.append(B)

    fn = lambda: tilegym.ops.group_gemm(group_A, group_B, transpose_b=transpose_b, backend=backend)

    if datatype != torch.float8_e5m2:
        # Verify correctness for non-FP8 types; the PyTorch reference cannot run FP8 matmul
        ref = lambda: reference_group_gemm(group_A, group_B, transpose_b=transpose_b)
        result = fn()
        ref_result = ref()
        for i in range(len(result)):
            torch.testing.assert_close(result[i], ref_result[i], atol=1e-2, rtol=1e-2)

    # Calculate theoretical TFLOPS
    # GEMM operation: C = A @ B
    # Each GEMM in the group performs 2 * M * N * K FLOPs (multiply-add operations)
    total_flops = num_groups * 2 * M * N * K
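    # e.g. num_groups=4 at M=N=K=4096: 4 * 2 * 4096**3 ≈ 5.5e11 FLOPs per call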

    ms = triton.testing.do_bench(fn)

    # Calculate TFLOPS
    tflops = total_flops / (ms * 1e-3) / 1e12

    return tflops


if __name__ == "__main__":
    bench_group_gemm.run(print_data=True)