Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions src/tilegym/ops/cutile/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,46 @@
from .utils import next_power_of_2


@ct.kernel
def rms_norm_kernel_multi_wave_cached(
x,
w,
out,
Rstd,
N: ct.Constant[int],
eps: ct.Constant[float],
offset: ct.Constant[float],
TILE_SIZE: ct.Constant[int],
):
"""
Multi-wave RMSNorm kernel that caches inputs in registers (single tile).

Formula: y = norm(x) * (offset + w)
For Llama: offset=0.0, For Gemma3: offset=1.0
"""
row = ct.bid(0)
_rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
check_bound = TILE_SIZE != N

# cache inputs in registers
xj = ct.gather(x, (row, offsets), check_bounds=check_bound, latency=1)
xj = ct.astype(xj, ct.float32)
_rms += xj * xj

# Calculate RMS Norm
rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
ct.scatter(Rstd, row, rms)

wj = ct.gather(w, offsets, check_bounds=check_bound, latency=1)
wj = ct.astype(wj, ct.float32)

# Apply offset: y = x_normalized * (offset + w)
yj = xj * rms * (offset + wj)
yj = ct.astype(yj, x.dtype)
ct.scatter(out, (row, offsets), yj, latency=1)


@ct.kernel
def rms_norm_kernel_gather(
x,
Expand Down Expand Up @@ -265,7 +305,7 @@ def forward(
weight,
eps,
bias=None,
static_persistent=None,
mode=None,
offset=0.0,
):
"""
Expand All @@ -277,7 +317,7 @@ def forward(
weight: Weight tensor of shape [N]
eps: Epsilon value for numerical stability
bias: Bias tensor of shape [N], default is None
static_persistent: Whether to use static persistent kernel, default is False
mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached")
offset: Offset to add to weight (default 0.0 for Llama, 1.0 for Gemma3)

Returns:
Expand All @@ -303,17 +343,17 @@ def forward(

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

if static_persistent is None:
if mode is None:
if M > NUM_SMS * 2:
# Heuristic for static persistent mode: if we need run over 2 waves, use static persistent mode
static_persistent = True
mode = "static_persistent"
else:
static_persistent = False
mode = "multi_wave_reload"

# Allocate rstd for backward (both paths now store it)
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

if static_persistent:
if mode == "static_persistent":
# Static persistent mode
if bias is not None:
raise NotImplementedError("Bias is not supported in static persistent CuTile RMSNorm")
Expand Down Expand Up @@ -358,8 +398,21 @@ def ceil_div(a, b):
rms_norm_kernel_static_persistent,
(x_arg, y, weight, rstd, TILE_SIZE_M, TILE_SIZE_N, eps, offset),
)
else:
# Standard mode
elif mode == "multi_wave_cached":
# Multi-wave cached mode (single tile, inputs cached in registers)
if bias is not None:
raise NotImplementedError("Bias is not supported in multi_wave_cached CuTile RMSNorm")

TILE_SIZE = next_power_of_2(N)
grid = (M,)
ct.launch(
torch.cuda.current_stream(),
grid,
rms_norm_kernel_multi_wave_cached,
(x_arg, weight, y, rstd, N, eps, offset, TILE_SIZE),
)
elif mode == "multi_wave_reload":
# Standard multi-wave reload mode
if bias is not None:
raise NotImplementedError("Bias is not supported in standard CuTile RMSNorm")

Expand All @@ -381,6 +434,11 @@ def ceil_div(a, b):
TILE_SIZE,
),
)
else:
raise ValueError(
f"Unknown mode '{mode}'. Supported modes: None, 'static_persistent', "
f"'multi_wave_reload', 'multi_wave_cached'"
)

# Always save for backward (both paths now produce rstd)
ctx.save_for_backward(x, weight, rstd)
Expand All @@ -405,12 +463,12 @@ def backward(ctx, dy):
x, weight, rstd = ctx.saved_tensors
dx, dw = rms_norm_backward(x, dy, weight, rstd)

# Gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset)
# Gradients: (x, normalized_shape, weight, eps, bias, mode, offset)
return dx, None, dw, None, None, None, None


@register_impl("rms_norm", backend="cutile")
def rms_norm(input, normalized_shape, weight, eps, bias=None, static_persistent=None, offset=0.0, **kwargs):
def rms_norm(input, normalized_shape, weight, eps, bias=None, mode=None, offset=0.0, **kwargs):
"""
Root mean square normalization implemented using CUDA Tile

Expand All @@ -420,14 +478,14 @@ def rms_norm(input, normalized_shape, weight, eps, bias=None, static_persistent=
weight: Tensor of shape (N,)
eps: Small constant added to variance calculation
bias: Bias tensor of shape (N,), default is None (not supported in cutile)
static_persistent: Whether to use static persistent kernel, default is False
mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached")
offset: Offset to add to weight (default 0.0 for Llama, 1.0 for Gemma3)
**kwargs: Additional arguments for backend-specific configurations

Returns:
Normalized tensor with same shape as input
"""
return RMSNorm.apply(input, normalized_shape, weight, eps, bias, static_persistent, offset)
return RMSNorm.apply(input, normalized_shape, weight, eps, bias, mode, offset)


class TileRMSNorm(nn.Module):
Expand All @@ -446,21 +504,21 @@ def __init__(self, hidden_size, eps=1e-6, offset=0.0):
self.hidden_size = hidden_size
self.offset = offset

def forward(self, hidden_states, static_persistent=None):
def forward(self, hidden_states, mode=None):
"""
Forward pass with optional static_persistent override
Forward pass with optional mode override

Args:
hidden_states: Input tensor
static_persistent: Default is None, which means use heuristic to
decide whether to use static persistent mode for better performance
mode: Default is None, which means use heuristic to
decide which kernel mode to use for better performance
"""
return rms_norm(
hidden_states,
None,
self.weight,
self.variance_epsilon,
static_persistent=static_persistent,
mode=mode,
offset=self.offset,
)

Expand Down
4 changes: 2 additions & 2 deletions src/tilegym/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def rms_norm(
weight: torch.Tensor,
eps: float,
bias: Optional[torch.Tensor] = None,
static_persistent: bool = False,
mode: Optional[str] = None,
**kwargs: Any,
):
"""
Expand All @@ -150,7 +150,7 @@ def rms_norm(
weight: Tensor of shape (N,)
eps: small scaler to be added to variance calculation prior to division.
bias: Bias tensor of shape (N,), default is None
static_persistent: Whether to use static persistent kernel, default is False
mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached")
**kwargs: Additional arguments for backend-specific configurations
"""
raise NotImplementedError(f"rms_norm is not implemented for {get_current_backend()}")
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmark/bench_mix_triton_cutile.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def mixed_fn():
# Step 1: Triton vector add
added = triton_vector_add(x, y)
# Step 2: CuTile rmsnorm
return tilegym.ops.cutile.rms_norm(added, w_shape, weight, eps, static_persistent=True)
return tilegym.ops.cutile.rms_norm(added, w_shape, weight, eps, mode="static_persistent")

fn = mixed_fn

Expand Down
70 changes: 40 additions & 30 deletions tests/benchmark/bench_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#
# SPDX-License-Identifier: MIT

from typing import Any
from typing import Optional

import torch
import torch.nn.functional as F
import triton
Expand All @@ -15,11 +18,12 @@

def reference_rms_norm(
input: torch.Tensor,
normalized_shape: tuple,
normalized_shape: Any,
weight: torch.Tensor,
eps: float,
bias: torch.Tensor = None, # Unused - kept for interface compatibility
static_persistent: bool = False, # Unused - kept for interface compatibility
bias: Optional[torch.Tensor] = None,
mode: Optional[str] = None,
**kwargs: Any,
):
"""Fused PyTorch RMSNorm baseline using F.rms_norm.

Expand All @@ -36,52 +40,55 @@ def reference_rms_norm(
register_impl("rms_norm", "torch")(reference_rms_norm)


# Available backends with their display names and plot styles
ALL_BACKENDS = [
("cutile", "CuTile", ("blue", "-")) if is_backend_available("cutile") else None,
("torch", "PyTorch", ("green", "-")),
# Available configs with their display names and plot styles
# (backend, mode, display label, plot style)
ALL_CONFIGS = [
("cutile", "static_persistent", "CuTile static persistent", ("purple", "-")),
("cutile", "multi_wave_reload", "CuTile multi wave reload", ("blue", "-")),
("cutile", "multi_wave_cached", "CuTile multi wave cached", ("red", "--")),
("torch", None, "PyTorch", ("green", "-")),
]


def get_supported_backends():
"""Filter backends based on availability"""
return [p for p in ALL_BACKENDS if p is not None]
def get_supported_configs():
cutile_available = is_backend_available("cutile")
if cutile_available:
return ALL_CONFIGS
return [c for c in ALL_CONFIGS if c[0] == "torch"]


M_DEFAULT = 4096

def create_benchmark_config(dtype, static_persistent=True):

def create_benchmark_config(dtype):
"""Create a benchmark configuration for given parameters"""
available_backends = get_supported_backends()
if not available_backends:
supported = get_supported_configs()
if not supported:
return None

backends, names, styles = zip(*available_backends)
# Use index as line_val key to pass both backend and mode
labels = [c[2] for c in supported]
styles = [c[3] for c in supported]
dtype_name = str(dtype).split(".")[-1] # e.g., 'float16' from 'torch.float16'

return triton.testing.Benchmark(
x_names=["N"],
x_vals=[2**i for i in range(10, 15)], # Hidden size from 1024 to 16384
line_arg="backend",
line_vals=list(backends),
line_names=list(names),
styles=list(styles),
line_arg="config_idx",
line_vals=list(range(len(supported))),
line_names=labels,
styles=styles,
ylabel="GB/s",
plot_name=f"rmsnorm-performance-{dtype_name}-persistent-{static_persistent}-GBps",
plot_name=f"rmsnorm-{dtype_name}-M{M_DEFAULT}",
args={
"dtype": dtype,
"static_persistent": static_persistent,
"M": 4096,
"M": M_DEFAULT,
}, # Fixed batch*seq_len
)


@triton.testing.perf_report(
[
create_benchmark_config(dtype, static_persistent)
for dtype in [torch.float16, torch.bfloat16]
for static_persistent in [True, False]
]
)
def bench_rmsnorm(N, backend, dtype, static_persistent, M, device=DEVICE):
@triton.testing.perf_report([create_benchmark_config(dtype) for dtype in [torch.float16, torch.bfloat16]])
def bench_rmsnorm(N, config_idx, dtype, M, device=DEVICE):
eps = 1e-5

# Create input tensors
Expand All @@ -91,7 +98,10 @@ def bench_rmsnorm(N, backend, dtype, static_persistent, M, device=DEVICE):
x = torch.rand(x_shape, dtype=dtype, device=device, requires_grad=False).mul_(0.5).add_(-2.3)
weight = torch.randn(w_shape, dtype=dtype, device=device, requires_grad=False)

fn = lambda: tilegym.ops.rms_norm(x, w_shape, weight, eps, static_persistent=static_persistent, backend=backend)
supported = get_supported_configs()
backend, mode, _, _ = supported[config_idx]

fn = lambda: tilegym.ops.rms_norm(x, w_shape, weight, eps, mode=mode, backend=backend)
ref = lambda: reference_rms_norm(x, w_shape, weight, eps)
torch.testing.assert_close(fn(), ref(), atol=5e-2, rtol=0.0)

Expand Down
6 changes: 3 additions & 3 deletions tests/ops/experimental/test_rmsnorm_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def reference(input, weight, eps):
(256, 256, torch.float32),
],
)
@pytest.mark.parametrize("static_persistent", [True, False])
@pytest.mark.parametrize("mode", [None, "static_persistent", "multi_wave_reload", "multi_wave_cached"])
@pytest.mark.parametrize("backend", _backends)
def test_op(self, m, n, dtype, static_persistent, backend, arch):
def test_op(self, m, n, dtype, mode, backend, arch):
if tilegym.is_backend_available(backend):
tilegym.set_backend(backend)
else:
Expand All @@ -114,7 +114,7 @@ def test_op(self, m, n, dtype, static_persistent, backend, arch):
(n,),
w_cutile,
eps,
static_persistent=static_persistent,
mode=mode,
)
y_cutile.backward(dy)

Expand Down
9 changes: 5 additions & 4 deletions tests/ops/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,22 @@ def reference(input, normalized_shape, weight, eps):
(256, 256, torch.float32),
],
)
@pytest.mark.parametrize("static_persistent", [True, False])
@pytest.mark.parametrize("mode", [None, "static_persistent", "multi_wave_reload", "multi_wave_cached"])
@pytest.mark.parametrize("backend", _backends)
@markif(
lambda arch, m, n: arch in ["sm120", "sm121"] and m == 31072 and n == 4096,
mark=pytest.mark.slow,
)
def test_op(self, m, n, dtype, static_persistent, backend, arch):
def test_op(self, m, n, dtype, mode, backend, arch):
if tilegym.is_backend_available(backend):
tilegym.set_backend(backend)
else:
pytest.skip(f"Backend {backend} is not available")

# skip static_persistent tests when n > 16384 to avoid excessive memory usage
# Avoid tileiras hangs on RTX PRO 6000 which has 100 KB shared memory per SM
if static_persistent and n > 16384:
# mode=None can also select static_persistent via heuristic when M > NUM_SMS * 2
if mode in ("static_persistent", None) and n > 16384:
pytest.skip("Skipping static_persistent test for large n to avoid excessive memory usage")

self.setUp()
Expand All @@ -81,7 +82,7 @@ def test_op(self, m, n, dtype, static_persistent, backend, arch):
"eps": eps,
},
extra_test_kwargs={
"static_persistent": static_persistent,
"mode": mode,
},
rtol=0.0,
atol=5e-2,
Expand Down
Loading