From 299787dfab651504c2939f445da5b681c1eb9c7d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 20 Apr 2026 08:38:02 -0700 Subject: [PATCH 1/2] feat(rmsnorm): add multi_wave_cached kernel and mode-based kernel selection Add rms_norm_kernel_multi_wave_cached, a single-tile RMSNorm kernel that caches inputs in registers to avoid reloading from memory. Replace the boolean static_persistent parameter with a mode parameter for explicit kernel selection: - None: heuristic selection based on tensor shape (default) - "static_persistent": rms_norm_kernel_static_persistent - "multi_wave_reload": rms_norm_kernel_multi_wave_reload - "multi_wave_cached": rms_norm_kernel_multi_wave_cached Rename kernels for consistency: - rms_norm_kernel_gather -> rms_norm_kernel_multi_wave_reload - rms_norm_kernel_gather_regs_cached -> rms_norm_kernel_multi_wave_cached Update benchmark to compare all kernel modes side-by-side per dtype. --- src/tilegym/ops/cutile/rms_norm.py | 92 +++++++++++++++---- src/tilegym/ops/ops.py | 4 +- tests/benchmark/bench_mix_triton_cutile.py | 2 +- tests/benchmark/bench_rmsnorm.py | 62 +++++++------ .../ops/experimental/test_rmsnorm_backward.py | 6 +- tests/ops/test_rms_norm.py | 9 +- 6 files changed, 120 insertions(+), 55 deletions(-) diff --git a/src/tilegym/ops/cutile/rms_norm.py b/src/tilegym/ops/cutile/rms_norm.py index 8916f65b..11d206d2 100644 --- a/src/tilegym/ops/cutile/rms_norm.py +++ b/src/tilegym/ops/cutile/rms_norm.py @@ -12,6 +12,46 @@ from .utils import next_power_of_2 +@ct.kernel +def rms_norm_kernel_multi_wave_cached( + x, + w, + out, + Rstd, + N: ct.Constant[int], + eps: ct.Constant[float], + offset: ct.Constant[float], + TILE_SIZE: ct.Constant[int], +): + """ + Multi-wave RMSNorm kernel that caches inputs in registers (single tile). + + Formula: y = norm(x) * (offset + w) + For Llama: offset=0.0, For Gemma3: offset=1.0 + """ + row = ct.bid(0) + _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32) + offsets = ct.arange(TILE_SIZE, dtype=ct.int32) + check_bound = TILE_SIZE != N + + # cache inputs in registers + xj = ct.gather(x, (row, offsets), check_bounds=check_bound, latency=1) + xj = ct.astype(xj, ct.float32) + _rms += xj * xj + + # Calculate RMS Norm + rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps) + ct.scatter(Rstd, row, rms) + + wj = ct.gather(w, offsets, check_bounds=check_bound, latency=1) + wj = ct.astype(wj, ct.float32) + + # Apply offset: y = x_normalized * (offset + w) + yj = xj * rms * (offset + wj) + yj = ct.astype(yj, x.dtype) + ct.scatter(out, (row, offsets), yj, latency=1) + + @ct.kernel def rms_norm_kernel_gather( x, @@ -256,7 +296,7 @@ def forward( weight, eps, bias=None, - static_persistent=None, + mode=None, offset=0.0, ): """ @@ -268,7 +308,7 @@ def forward( weight: Weight tensor of shape [N] eps: Epsilon value for numerical stability bias: Bias tensor of shape [N], default is None - static_persistent: Whether to use static persistent kernel, default is False + mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached") offset: Offset to add to weight (default 0.0 for Llama, 1.0 for Gemma3) Returns: @@ -294,17 +334,17 @@ def forward( NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - if static_persistent is None: + if mode is None: if M > NUM_SMS * 2: # Heuristic for static persistent mode: if we need run over 2 waves, use static persistent mode - static_persistent = True + mode = "static_persistent" else: - static_persistent = False + mode = "multi_wave_reload" # Allocate rstd for backward (both paths now store it) rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if static_persistent: + if mode == "static_persistent": # Static persistent mode if bias is not None: raise NotImplementedError("Bias is not supported in static persistent CuTile RMSNorm") @@ -349,8 +389,21 @@ def ceil_div(a, b): rms_norm_kernel_static_persistent, (x_arg, y, weight, rstd, TILE_SIZE_M, TILE_SIZE_N, eps, offset), ) - else: - # Standard mode + elif mode == "multi_wave_cached": + # Multi-wave cached mode (single tile, inputs cached in registers) + if bias is not None: + raise NotImplementedError("Bias is not supported in multi_wave_cached CuTile RMSNorm") + + TILE_SIZE = next_power_of_2(N) + grid = (M,) + ct.launch( + torch.cuda.current_stream(), + grid, + rms_norm_kernel_multi_wave_cached, + (x_arg, weight, y, rstd, N, eps, offset, TILE_SIZE), + ) + elif mode == "multi_wave_reload": + # Standard multi-wave reload mode if bias is not None: raise NotImplementedError("Bias is not supported in standard CuTile RMSNorm") @@ -372,6 +425,11 @@ def ceil_div(a, b): TILE_SIZE, ), ) + else: + raise ValueError( + f"Unknown mode '{mode}'. Supported modes: None, 'static_persistent', " + f"'multi_wave_reload', 'multi_wave_cached'" + ) # Always save for backward (both paths now produce rstd) ctx.save_for_backward(x, weight, rstd) @@ -396,12 +454,12 @@ def backward(ctx, dy): x, weight, rstd = ctx.saved_tensors dx, dw = rms_norm_backward(x, dy, weight, rstd) - # Gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset) + # Gradients: (x, normalized_shape, weight, eps, bias, mode, offset) return dx, None, dw, None, None, None, None @register_impl("rms_norm", backend="cutile") -def rms_norm(input, normalized_shape, weight, eps, bias=None, static_persistent=None, offset=0.0, **kwargs): +def rms_norm(input, normalized_shape, weight, eps, bias=None, mode=None, offset=0.0, **kwargs): """ Root mean square normalization implemented using CUDA Tile @@ -411,14 +469,14 @@ def rms_norm(input, normalized_shape, weight, eps, bias=None, static_persistent= weight: Tensor of shape (N,) eps: Small constant added to variance calculation bias: Bias tensor of shape (N,), default is None (not supported in cutile) - static_persistent: Whether to use static persistent kernel, default is False + mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached") offset: Offset to add to weight (default 0.0 for Llama, 1.0 for Gemma3) **kwargs: Additional arguments for backend-specific configurations Returns: Normalized tensor with same shape as input """ - return RMSNorm.apply(input, normalized_shape, weight, eps, bias, static_persistent, offset) + return RMSNorm.apply(input, normalized_shape, weight, eps, bias, mode, offset) class TileRMSNorm(nn.Module): @@ -437,21 +495,21 @@ def __init__(self, hidden_size, eps=1e-6, offset=0.0): self.hidden_size = hidden_size self.offset = offset - def forward(self, hidden_states, static_persistent=None): + def forward(self, hidden_states, mode=None): """ - Forward pass with optional static_persistent override + Forward pass with optional mode override Args: hidden_states: Input tensor - static_persistent: Default is None, which means use heuristic to - decide whether to use static persistent mode for better performance + mode: Default is None, which means use heuristic to + decide which kernel mode to use for better performance """ return rms_norm( hidden_states, None, self.weight, self.variance_epsilon, - static_persistent=static_persistent, + mode=mode, offset=self.offset, ) diff --git a/src/tilegym/ops/ops.py b/src/tilegym/ops/ops.py index 8411e444..7539296a 100644 --- a/src/tilegym/ops/ops.py +++ b/src/tilegym/ops/ops.py @@ -138,7 +138,7 @@ def rms_norm( weight: torch.Tensor, eps: float, bias: Optional[torch.Tensor] = None, - static_persistent: bool = False, + mode: Optional[str] = None, **kwargs: Any, ): """ @@ -150,7 +150,7 @@ def rms_norm( weight: Tensor of shape (N,) eps: small scaler to be added to variance calculation prior to division. bias: Bias tensor of shape (N,), default is None - static_persistent: Whether to use static persistent kernel, default is False + mode: Kernel selection mode (None, "static_persistent", "multi_wave_reload", "multi_wave_cached") **kwargs: Additional arguments for backend-specific configurations """ raise NotImplementedError(f"rms_norm is not implemented for {get_current_backend()}") diff --git a/tests/benchmark/bench_mix_triton_cutile.py b/tests/benchmark/bench_mix_triton_cutile.py index 9fa3e1ed..56d7825d 100644 --- a/tests/benchmark/bench_mix_triton_cutile.py +++ b/tests/benchmark/bench_mix_triton_cutile.py @@ -131,7 +131,7 @@ def mixed_fn(): # Step 1: Triton vector add added = triton_vector_add(x, y) # Step 2: CuTile rmsnorm - return tilegym.ops.cutile.rms_norm(added, w_shape, weight, eps, static_persistent=True) + return tilegym.ops.cutile.rms_norm(added, w_shape, weight, eps, mode="static_persistent") fn = mixed_fn diff --git a/tests/benchmark/bench_rmsnorm.py b/tests/benchmark/bench_rmsnorm.py index 53431475..8b6bb118 100644 --- a/tests/benchmark/bench_rmsnorm.py +++ b/tests/benchmark/bench_rmsnorm.py @@ -19,7 +19,7 @@ def reference_rms_norm( weight: torch.Tensor, eps: float, bias: torch.Tensor = None, # Unused - kept for interface compatibility - static_persistent: bool = False, # Unused - kept for interface compatibility + **kwargs, # Unused - kept for interface compatibility ): """Fused PyTorch RMSNorm baseline using F.rms_norm. @@ -36,52 +36,55 @@ def reference_rms_norm( register_impl("rms_norm", "torch")(reference_rms_norm) -# Available backends with their display names and plot styles -ALL_BACKENDS = [ - ("cutile", "CuTile", ("blue", "-")) if is_backend_available("cutile") else None, - ("torch", "PyTorch", ("green", "-")), +# Available configs with their display names and plot styles +# (backend, mode, display label, plot style) +ALL_CONFIGS = [ + ("cutile", "static_persistent", "CuTile static persistent", ("purple", "-")), + ("cutile", "multi_wave_reload", "CuTile multi wave reload", ("blue", "-")), + ("cutile", "multi_wave_cached", "CuTile multi wave cached", ("red", "--")), + ("torch", None, "PyTorch", ("green", "-")), ] -def get_supported_backends(): - """Filter backends based on availability""" - return [p for p in ALL_BACKENDS if p is not None] +def get_supported_configs(): + cutile_available = is_backend_available("cutile") + if cutile_available: + return ALL_CONFIGS + return [c for c in ALL_CONFIGS if c[0] == "torch"] -def create_benchmark_config(dtype, static_persistent=True): +M_DEFAULT = 4096 + + +def create_benchmark_config(dtype): """Create a benchmark configuration for given parameters""" - available_backends = get_supported_backends() - if not available_backends: + supported = get_supported_configs() + if not supported: return None - backends, names, styles = zip(*available_backends) + # Use index as line_val key to pass both backend and mode + labels = [c[2] for c in supported] + styles = [c[3] for c in supported] dtype_name = str(dtype).split(".")[-1] # e.g., 'float16' from 'torch.float16' return triton.testing.Benchmark( x_names=["N"], x_vals=[2**i for i in range(10, 15)], # Hidden size from 1024 to 16384 - line_arg="backend", - line_vals=list(backends), - line_names=list(names), - styles=list(styles), + line_arg="config_idx", + line_vals=list(range(len(supported))), + line_names=labels, + styles=styles, ylabel="GB/s", - plot_name=f"rmsnorm-performance-{dtype_name}-persistent-{static_persistent}-GBps", + plot_name=f"rmsnorm-{dtype_name}-M{M_DEFAULT}", args={ "dtype": dtype, - "static_persistent": static_persistent, - "M": 4096, + "M": M_DEFAULT, }, # Fixed batch*seq_len ) -@triton.testing.perf_report( - [ - create_benchmark_config(dtype, static_persistent) - for dtype in [torch.float16, torch.bfloat16] - for static_persistent in [True, False] - ] -) -def bench_rmsnorm(N, backend, dtype, static_persistent, M, device=DEVICE): +@triton.testing.perf_report([create_benchmark_config(dtype) for dtype in [torch.float16, torch.bfloat16]]) +def bench_rmsnorm(N, config_idx, dtype, M, device=DEVICE): eps = 1e-5 # Create input tensors @@ -91,7 +94,10 @@ def bench_rmsnorm(N, backend, dtype, static_persistent, M, device=DEVICE): x = torch.rand(x_shape, dtype=dtype, device=device, requires_grad=False).mul_(0.5).add_(-2.3) weight = torch.randn(w_shape, dtype=dtype, device=device, requires_grad=False) - fn = lambda: tilegym.ops.rms_norm(x, w_shape, weight, eps, static_persistent=static_persistent, backend=backend) + supported = get_supported_configs() + backend, mode, _, _ = supported[config_idx] + + fn = lambda: tilegym.ops.rms_norm(x, w_shape, weight, eps, mode=mode, backend=backend) ref = lambda: reference_rms_norm(x, w_shape, weight, eps) torch.testing.assert_close(fn(), ref(), atol=5e-2, rtol=0.0) diff --git a/tests/ops/experimental/test_rmsnorm_backward.py b/tests/ops/experimental/test_rmsnorm_backward.py index eeaf508b..f727756d 100644 --- a/tests/ops/experimental/test_rmsnorm_backward.py +++ b/tests/ops/experimental/test_rmsnorm_backward.py @@ -91,9 +91,9 @@ def reference(input, weight, eps): (256, 256, torch.float32), ], ) - @pytest.mark.parametrize("static_persistent", [True, False]) + @pytest.mark.parametrize("mode", [None, "static_persistent", "multi_wave_reload", "multi_wave_cached"]) @pytest.mark.parametrize("backend", _backends) - def test_op(self, m, n, dtype, static_persistent, backend, arch): + def test_op(self, m, n, dtype, mode, backend, arch): if tilegym.is_backend_available(backend): tilegym.set_backend(backend) else: @@ -114,7 +114,7 @@ def test_op(self, m, n, dtype, static_persistent, backend, arch): (n,), w_cutile, eps, - static_persistent=static_persistent, + mode=mode, ) y_cutile.backward(dy) diff --git a/tests/ops/test_rms_norm.py b/tests/ops/test_rms_norm.py index 0c9458b1..41088f64 100644 --- a/tests/ops/test_rms_norm.py +++ b/tests/ops/test_rms_norm.py @@ -42,13 +42,13 @@ def reference(input, normalized_shape, weight, eps): (256, 256, torch.float32), ], ) - @pytest.mark.parametrize("static_persistent", [True, False]) + @pytest.mark.parametrize("mode", [None, "static_persistent", "multi_wave_reload", "multi_wave_cached"]) @pytest.mark.parametrize("backend", _backends) @markif( lambda arch, m, n: arch in ["sm120", "sm121"] and m == 31072 and n == 4096, mark=pytest.mark.slow, ) - def test_op(self, m, n, dtype, static_persistent, backend, arch): + def test_op(self, m, n, dtype, mode, backend, arch): if tilegym.is_backend_available(backend): tilegym.set_backend(backend) else: @@ -56,7 +56,8 @@ def test_op(self, m, n, dtype, static_persistent, backend, arch): # skip static_persistent tests when n > 16384 to avoid excessive memory usage # Avoid tileiras hangs on RTX PRO 6000 which has 100 KB shared memory per SM - if static_persistent and n > 16384: + # mode=None can also select static_persistent via heuristic when M > NUM_SMS * 2 + if mode in ("static_persistent", None) and n > 16384: pytest.skip("Skipping static_persistent test for large n to avoid excessive memory usage") self.setUp() @@ -81,7 +82,7 @@ def test_op(self, m, n, dtype, static_persistent, backend, arch): "eps": eps, }, extra_test_kwargs={ - "static_persistent": static_persistent, + "mode": mode, }, rtol=0.0, atol=5e-2, From 757a5914c6565aaf0914eaa42f9c0fe41554a5f7 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 23 Apr 2026 06:01:38 -0700 Subject: [PATCH 2/2] ensure same signature --- tests/benchmark/bench_rmsnorm.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/benchmark/bench_rmsnorm.py b/tests/benchmark/bench_rmsnorm.py index 8b6bb118..30b35e91 100644 --- a/tests/benchmark/bench_rmsnorm.py +++ b/tests/benchmark/bench_rmsnorm.py @@ -2,6 +2,9 @@ # # SPDX-License-Identifier: MIT +from typing import Any +from typing import Optional + import torch import torch.nn.functional as F import triton @@ -15,11 +18,12 @@ def reference_rms_norm( input: torch.Tensor, - normalized_shape: tuple, + normalized_shape: Any, weight: torch.Tensor, eps: float, - bias: torch.Tensor = None, # Unused - kept for interface compatibility - **kwargs, # Unused - kept for interface compatibility + bias: Optional[torch.Tensor] = None, + mode: Optional[str] = None, + **kwargs: Any, ): """Fused PyTorch RMSNorm baseline using F.rms_norm.