diff --git a/src/tilegym/ops/cutile/rms_norm.py b/src/tilegym/ops/cutile/rms_norm.py
index e983d74..3ac0145 100644
--- a/src/tilegym/ops/cutile/rms_norm.py
+++ b/src/tilegym/ops/cutile/rms_norm.py
@@ -2,8 +2,6 @@
 #
 # SPDX-License-Identifier: MIT
 
-import math
-
 import cuda.tile as ct
 import torch
 import torch.nn as nn
@@ -14,127 +12,6 @@
 from .utils import next_power_of_2
 
 
-@experimental_kernel
-@ct.kernel(occupancy=2)
-def rms_norm_backward_kernel(
-    dx,
-    dy,
-    x,
-    weight,
-    Rstd,
-    temp_buffer,
-    TILE_SIZE: ct.Constant[int],
-):
-    """
-    Compute input gradients for RMSNorm backward pass.
-
-    Formula: dx_{m,i} = dy_{m,i} w_i / r_m - x_{m,i} / (N r_m^3) * sum_j dy_{m,j} w_j x_{m,j}
-    where:
-    - dy_{m,i} = dy[m,i] (upstream gradient)
-    - w_i = weight[i] (scale parameter)
-    - r_m = 1 / rstd[m] (RMS for row m)
-    - N = number of columns
-
-    See rms_norm_backward_annotated() for detailed derivation.
-
-    Each block handles exactly one row and processes all columns at once.
-    TILE_SIZE should be >= N (number of columns).
-    """
-    row_idx = ct.bid(0)
-    M, N = x.shape
-
-    # Load entire row from input and gradient
-    input_row = ct.load(x, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO)
-    gradient_row = ct.load(dy, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO)
-
-    # Load reciprocal std (1D tensor [M]) and reshape for broadcasting
-    inv_std_row = ct.load(Rstd, index=(row_idx,), shape=(1,), padding_mode=ct.PaddingMode.ZERO)
-    inv_std_row = ct.reshape(inv_std_row, (1, 1))  # Reshape to [1, 1] for broadcasting
-
-    # Load weight vector and reshape for broadcasting
-    weight_vector = ct.load(weight, index=(0,), shape=(TILE_SIZE,), padding_mode=ct.PaddingMode.ZERO)
-    weight_vector = ct.reshape(weight_vector, (1, TILE_SIZE))  # Reshape to [1, TILE_SIZE] for broadcasting
-
-    # Compute sum_j dy_{m,j} w_j x_{m,j} for the correction term
-
-    c1 = input_row * gradient_row
-    c2 = c1 * inv_std_row
-
-    ct.store(temp_buffer, index=(row_idx, 0), tile=ct.astype(c2, temp_buffer.dtype))
-
-    weighted_gradient_product = c1 * weight_vector
-    weighted_gradient_sum = ct.sum(weighted_gradient_product, axis=1, keepdims=True)  # [1, 1]
-
-    # Compute normalization correction: x_{m,i} / (N r_m^3) * sum_j dy_{m,j} w_j x_{m,j}
-    # Since inv_std_row = 1/r_m, we have r_m^3 = 1/(inv_std_row^3)
-    inv_std_cubed = inv_std_row * inv_std_row * inv_std_row  # [1, 1]
-    norm_factor = ct.full((1, 1), N * 1.0, dtype=ct.float32)  # [1, 1]
-    normalization_correction_coeff = input_row * inv_std_cubed / norm_factor  # [1, TILE_SIZE]
-    normalization_correction = normalization_correction_coeff * weighted_gradient_sum  # [1, TILE_SIZE]
-
-    # Compute direct term: dy_{m,i} w_i / r_m = gradient_row * weight_vector * inv_std_row
-    scaled_gradient = gradient_row * weight_vector * inv_std_row  # [1, TILE_SIZE]
-
-    # Final dx: direct term minus normalization correction
-    input_gradient_row = scaled_gradient - normalization_correction  # [1, TILE_SIZE]
-
-    # Convert back to the original dtype of dx
-    input_gradient_row = ct.astype(input_gradient_row, dx.dtype)
-
-    # Store the result back to dx
-    ct.store(dx, index=(row_idx, 0), tile=input_gradient_row)
-
-
-def rms_norm_backward(
-    x: torch.Tensor,
-    dy: torch.Tensor,
-    weight: torch.Tensor,
-    rstd: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    x = x.contiguous()
-    dy = dy.contiguous()
-    weight = weight.contiguous()
-    rstd = rstd.contiguous()
-
-    x_shape = x.shape
-
-    # Flatten to [M, N]
-    x = x.reshape(-1, x.shape[-1])
-    dy = dy.reshape(-1, dy.shape[-1])
-
-    M, N = x.shape
-
-    # Allocate outputs
-    dx = torch.empty_like(x)
-    dw = torch.empty_like(weight)  # shape (N,)
-    temp_buffer = torch.empty(x.shape, device=x.device, dtype=torch.float32)
-
-    dx = dx.detach()
-    dw = dw.detach()
-
-    TILE_SIZE_N = next_power_of_2(N)
-
-    # dx (row-parallel) algorithim
-    # Also stores dy * x / rms into temp_buffer for each row
-    grid_dx = (M,)
-    ct.launch(
-        torch.cuda.current_stream(),
-        grid_dx,
-        rms_norm_backward_kernel,
-        (dx, dy, x, weight, rstd, temp_buffer, TILE_SIZE_N),
-    )
-
-    # Compute dw by summing temp_buffer over the batch dimension
-    # temp_buffer contains: dy_{b,j} * x_{b,j} / rms_b (shape [M, N])
-    # dw_j = sum_b(dy_{b,j} * x_{b,j} / rms_b) * weight_j
-    # temp_buffer already has dy * x * rstd, so we just sum over row dim (torch performance would be the same as cuTILE)
-    # Ensure accumulates are done in float32 to avoid precision issues
-    dw = temp_buffer[:, :N].to(torch.float32).sum(dim=0).to(weight.dtype)
-
-    # Reshape dx back, dw already correct
-    return dx.view(*x_shape), dw
-
-
 @ct.kernel
 def rms_norm_kernel_gather(
     x,
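
Note for review: the removed docstring points at rms_norm_backward_annotated() for the full derivation; for convenience, here is a sketch of where the formula comes from, using the same notation (r_m is the row RMS, and xhat = x * rstd is the normalized input the new kernel works with):

```latex
y_{m,i} = \frac{w_i\,x_{m,i}}{r_m}, \qquad
r_m = \sqrt{\tfrac{1}{N}\sum_j x_{m,j}^2 + \epsilon}
\quad\Longrightarrow\quad
\frac{\partial\,r_m^{-1}}{\partial x_{m,i}} = -\frac{x_{m,i}}{N\,r_m^{3}}

dx_{m,i} = \sum_j dy_{m,j}\,w_j\!\left(\frac{\delta_{ij}}{r_m} - \frac{x_{m,i}\,x_{m,j}}{N\,r_m^{3}}\right)
         = \frac{dy_{m,i}\,w_i}{r_m} - \frac{x_{m,i}}{N\,r_m^{3}}\sum_j dy_{m,j}\,w_j\,x_{m,j}

\text{with } \hat{x}_{m,i} = x_{m,i}/r_m:\qquad
dx_m = \frac{1}{r_m}\Bigl(w \odot dy_m - \hat{x}_m\,\tfrac{1}{N}\bigl\langle \hat{x}_m,\; w \odot dy_m \bigr\rangle\Bigr),
\qquad
dw_j = \sum_m dy_{m,j}\,\hat{x}_{m,j}
```

The last line is exactly what the replacement kernel computes: `c = sum(xhat * wdy) / N`, `dx = (wdy - xhat * c) * r`, and a per-block partial sum of `dyt * xhat` for dw.
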
@@ -184,6 +61,7 @@ def rms_norm_kernel_static_persistent(
     X,  # Input tensor
     Y,  # Output tensor
     W,  # Weight tensor
+    Rstd,  # rstd output (for backward)
     TILE_SIZE_M: ct.Constant[int],  # rows per tile
     TILE_SIZE_N: ct.Constant[int],  # columns per tile
     eps: ct.Constant[float],  # Epsilon value
@@ -238,6 +116,9 @@ def rms_norm_kernel_static_persistent(
         variance_eps = ct.add(variance, eps_tensor)
         rsqrt_var = ct.rsqrt(variance_eps)
 
+        # Store rstd for backward pass
+        ct.store(Rstd, index=(current_bid,), tile=ct.reshape(rsqrt_var, (TILE_SIZE_M,)), allow_tma=False)
+
         # Step 5: Apply normalization
         x_normalized = ct.mul(x, rsqrt_var)
 
@@ -265,6 +146,105 @@ def rms_norm_kernel_static_persistent(
     )
 
 
+@experimental_kernel
+@ct.kernel(occupancy=1)
+def _rms_bwd(dx, dy, x, weight, Rstd, dw_partial, TILE_M: ct.Constant[int], TILE_N: ct.Constant[int]):
+    """
+    Persistent RMSNorm backward: grid-stride loop with fused dw accumulation.
+
+    Each block accumulates its dw contribution into a (grid, TILE_N) partial
+    sum buffer, avoiding the old M×N temp_buffer allocation.
+
+    Only supports offset=0 (Gemma3 backward is not supported).
+    """
+    bid = ct.bid(0)
+    M, N = x.shape[0], x.shape[1]
+    blocks = ct.num_blocks(0)
+    upper = (M + TILE_M - 1) // TILE_M
+
+    w = ct.astype(ct.load(weight, index=(0,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO), ct.float32)
+    w = ct.reshape(w, (1, TILE_N))
+    rcp = ct.full((TILE_M, 1), 1.0 / N, dtype=ct.float32)
+    dw_acc = ct.full((1, TILE_N), 0.0, dtype=ct.float32)
+
+    for i in range(bid, upper, blocks):
+        xt = ct.astype(
+            ct.load(x, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10),
+            ct.float32,
+        )
+        dyt = ct.astype(
+            ct.load(dy, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10),
+            ct.float32,
+        )
+        r = ct.reshape(
+            ct.load(Rstd, index=(i,), shape=(TILE_M,), padding_mode=ct.PaddingMode.ZERO),
+            (TILE_M, 1),
+        )
+        xhat = xt * r
+        wdy = dyt * w
+        c = ct.sum(xhat * wdy, axis=1, keepdims=True) * rcp
+        ct.store(dx, index=(i, 0), tile=ct.astype((wdy - xhat * c) * r, dx.dtype), allow_tma=False, latency=3)
+        dw_acc = dw_acc + ct.sum(dyt * xhat, axis=0, keepdims=True)
+
+    ct.store(dw_partial, index=(bid, 0), tile=dw_acc, allow_tma=False)
+
+
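For readers unfamiliar with the grid-stride pattern, here is a host-side PyTorch sketch of what each block of the kernel above computes. It is illustrative only: the `ct.*` tiling, power-of-two padding of `TILE_N`, and the `latency` hints are omitted, and all names are hypothetical.

```python
import torch

def rms_bwd_emulated(dx, dy, x, weight, rstd, dw_partial, tile_m, grid):
    """Emulate _rms_bwd on the host: block `bid` walks row-tiles
    bid, bid + grid, bid + 2*grid, ... and keeps a private dw partial sum."""
    M, N = x.shape
    upper = (M + tile_m - 1) // tile_m           # number of row-tiles
    w = weight.float().view(1, N)
    for bid in range(grid):                      # one iteration per block
        dw_acc = torch.zeros(1, N, dtype=torch.float32, device=x.device)
        for i in range(bid, upper, grid):        # grid-stride over row-tiles
            rows = slice(i * tile_m, min((i + 1) * tile_m, M))
            xt, dyt = x[rows].float(), dy[rows].float()
            r = rstd[rows].view(-1, 1)
            xhat = xt * r                        # normalized input
            wdy = dyt * w                        # weight-scaled upstream grad
            c = (xhat * wdy).sum(1, keepdim=True) / N
            dx[rows] = ((wdy - xhat * c) * r).to(dx.dtype)
            dw_acc += (dyt * xhat).sum(0, keepdim=True)  # fused dw partial
        dw_partial[bid] = dw_acc                 # one partial row per block
    return dw_partial.sum(0)                     # final dw reduction
```

The (grid, TILE_N) partial buffer is what replaces the old (M, N) `temp_buffer`; the remaining `sum(0)` over at most a few hundred rows is done by `rms_norm_backward` below.
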
+ """ + bid = ct.bid(0) + M, N = x.shape[0], x.shape[1] + blocks = ct.num_blocks(0) + upper = (M + TILE_M - 1) // TILE_M + + w = ct.astype(ct.load(weight, index=(0,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO), ct.float32) + w = ct.reshape(w, (1, TILE_N)) + rcp = ct.full((TILE_M, 1), 1.0 / N, dtype=ct.float32) + dw_acc = ct.full((1, TILE_N), 0.0, dtype=ct.float32) + + for i in range(bid, upper, blocks): + xt = ct.astype( + ct.load(x, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10), + ct.float32, + ) + dyt = ct.astype( + ct.load(dy, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10), + ct.float32, + ) + r = ct.reshape( + ct.load(Rstd, index=(i,), shape=(TILE_M,), padding_mode=ct.PaddingMode.ZERO), + (TILE_M, 1), + ) + xhat = xt * r + wdy = dyt * w + c = ct.sum(xhat * wdy, axis=1, keepdims=True) * rcp + ct.store(dx, index=(i, 0), tile=ct.astype((wdy - xhat * c) * r, dx.dtype), allow_tma=False, latency=3) + dw_acc = dw_acc + ct.sum(dyt * xhat, axis=0, keepdims=True) + + ct.store(dw_partial, index=(bid, 0), tile=dw_acc, allow_tma=False) + + +_bwd_cfg: dict = {} # (M, N) → (tile_m, tile_n, grid, N) + + +def _bwd_tiles(M, N): + """Heuristic tile configuration for backward kernel.""" + T = next_power_of_2(N) + if T > 4096: + tm = 1 + elif T <= 2048 or (M >= 8192 and T <= 4096): + tm = 4 + else: + tm = 1 + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + tiles = (M + tm - 1) // tm + g = min(NUM_SMS, tiles) + if tiles <= 64: + g = min(g, 32) + return (tm, T, g, N) + + +def rms_norm_backward( + x: torch.Tensor, + dy: torch.Tensor, + weight: torch.Tensor, + rstd: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Standalone backward pass using persistent CuTile kernel.""" + x = x.contiguous() + dy = dy.contiguous() + weight = weight.contiguous() + rstd = rstd.contiguous() + + x_shape = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + M, N = x.shape + + cfg = _bwd_cfg.get((M, N)) + if cfg is None: + cfg = _bwd_tiles(M, N) + _bwd_cfg[(M, N)] = cfg + tm, T, g, No = cfg + + stream = torch.cuda.current_stream() + + dx = torch.empty_like(x) + dwp = torch.empty((g, T), device=x.device, dtype=torch.float32) + ct.launch(stream, (g,), _rms_bwd, (dx, dy, x, weight, rstd, dwp, tm, T)) + + dw = dwp.sum(0) + if T != No: + dw = dw[:No] + + return dx.view(*x_shape), dw.to(weight.dtype) + + class RMSNorm(torch.autograd.Function): @staticmethod def forward( @@ -319,6 +299,9 @@ def forward( else: static_persistent = False + # Allocate rstd for backward (both paths now store it) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if static_persistent: # Static persistent mode if bias is not None: @@ -341,27 +324,24 @@ def ceil_div(a, b): ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N), ) grid = (grid_size,) - kernel_sp = rms_norm_kernel_static_persistent ct.launch( torch.cuda.current_stream(), grid, - kernel_sp, - (x_arg, y, weight, TILE_SIZE_M, TILE_SIZE_N, eps, offset), + rms_norm_kernel_static_persistent, + (x_arg, y, weight, rstd, TILE_SIZE_M, TILE_SIZE_N, eps, offset), ) else: # Standard mode if bias is not None: raise NotImplementedError("Bias is not supported in standard CuTile RMSNorm") - rstd = torch.empty((M,), dtype=torch.float32, device="cuda") MAX_FUSED_SIZE = 4096 // x.element_size() TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N)) grid = (M,) - kernel = rms_norm_kernel_gather ct.launch( torch.cuda.current_stream(), grid, - kernel, + 
 class RMSNorm(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -319,6 +299,9 @@ def forward(
         else:
             static_persistent = False
 
+        # Allocate rstd for backward (both paths now store it)
+        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+
         if static_persistent:
             # Static persistent mode
             if bias is not None:
@@ -341,27 +324,24 @@ def ceil_div(a, b):
                 ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
             )
             grid = (grid_size,)
-            kernel_sp = rms_norm_kernel_static_persistent
             ct.launch(
                 torch.cuda.current_stream(),
                 grid,
-                kernel_sp,
-                (x_arg, y, weight, TILE_SIZE_M, TILE_SIZE_N, eps, offset),
+                rms_norm_kernel_static_persistent,
+                (x_arg, y, weight, rstd, TILE_SIZE_M, TILE_SIZE_N, eps, offset),
             )
         else:
             # Standard mode
             if bias is not None:
                 raise NotImplementedError("Bias is not supported in standard CuTile RMSNorm")
 
-            rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
             MAX_FUSED_SIZE = 4096 // x.element_size()
             TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
             grid = (M,)
-            kernel = rms_norm_kernel_gather
             ct.launch(
                 torch.cuda.current_stream(),
                 grid,
-                kernel,
+                rms_norm_kernel_gather,
                 (
                     x_arg,
                     weight,
@@ -374,30 +354,30 @@ def ceil_div(a, b):
                 ),
             )
 
-        # Save variables needed for backward pass
-        ctx.save_for_backward(x, weight, rstd)
-        ctx.TILE_SIZE = TILE_SIZE
-        ctx.eps = eps
-        ctx.offset = offset
+        # Always save for backward (both paths now produce rstd)
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.TILE_SIZE = next_power_of_2(N)
+        ctx.eps = eps
+        ctx.offset = offset
 
         return y.view(*x.shape)
 
     @staticmethod
     def backward(ctx, dy):
         """
-        Backward pass for RMSNorm.
-        Retrieves saved tensors and delegates to rms_norm_backward().
+        Persistent backward pass using grid-stride kernel.
+        Supports backward from both gather and static persistent forward modes.
         """
         # Check if offset was used (backward not supported with non-zero offset)
         if ctx.offset != 0.0:
-            raise NotImplementedError("Backward pass not implemented for CuTile RMSNorm with non-zero offset")
+            raise NotImplementedError(
+                f"Backward pass not implemented for CuTile RMSNorm with non-zero offset ({ctx.offset})"
+            )
 
         x, weight, rstd = ctx.saved_tensors
-
-        # Call the standalone backward function
         dx, dw = rms_norm_backward(x, dy, weight, rstd)
 
-        # Return gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset)
+        # Gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset)
         return dx, None, dw, None, None, None, None
 
 
@@ -501,18 +481,18 @@ def rms_norm_backward_torch(
     # Reshape rstd for broadcasting: (M,) -> (M, 1)
     rstd = rstd.view(M, 1)
 
-    # Gradient w.r.t. weight: sum over batch dimension (accumulate in float32)
-    # Match kernel order: (x * dy) * rstd to match precision behavior
-    dw = ((x * dy) * rstd).sum(dim=0, dtype=torch.float32)
+    # Cast to fp32 up front so all intermediates are full precision
+    x_f = x.float()
+    dy_f = dy.float()
+    w_f = weight.float()
 
-    # Normalized x (before scaling by weight) - for dx computation
-    x_norm = x * rstd
+    # Gradient w.r.t. weight: dw = sum((x * rstd) * dy, dim=0)
+    x_norm = x_f * rstd
+    dw = (dy_f * x_norm).sum(dim=0)
 
-    # Gradient w.r.t. x (accumulate in float32)
-    dy_weighted = dy * weight
-    c1 = (dy_weighted * x_norm).sum(
-        dim=1, keepdim=True, dtype=torch.float32
-    )  # ensure accumulates are done in float32 to avoid precision issues
+    # Gradient w.r.t. x
+    dy_weighted = dy_f * w_f
+    c1 = (dy_weighted * x_norm).sum(dim=1, keepdim=True)
 
     dx = rstd * (dy_weighted - x_norm * c1 / N)
     dx = dx.view(x_shape).to(x.dtype)
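
To make the memory win concrete, a worked example of the heuristic for one large shape (the SM count is an assumption; substitute your device's):

```python
from tilegym.ops.cutile.rms_norm import _bwd_tiles

# (M, N) = (8192, 4096): T = 4096, and the (M >= 8192 and T <= 4096) branch
# picks tm = 4, so there are 8192 / 4 = 2048 row-tiles.
tm, T, g, No = _bwd_tiles(8192, 4096)
# On a hypothetical 132-SM device: g = min(132, 2048) = 132, and the dw
# partial buffer holds g * T fp32 values:
#   132 * 4096 * 4 B ≈ 2.1 MiB
# versus the removed (M, N) fp32 temp_buffer:
#   8192 * 4096 * 4 B = 128 MiB
```
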
diff --git a/tests/benchmark/experimental/bench_rmsnorm_backward.py b/tests/benchmark/experimental/bench_rmsnorm_backward.py
index 8d8abef..9aa209f 100644
--- a/tests/benchmark/experimental/bench_rmsnorm_backward.py
+++ b/tests/benchmark/experimental/bench_rmsnorm_backward.py
@@ -15,6 +15,7 @@
 from tilegym.backend import is_backend_available
 
 from tilegym.ops.cutile.rms_norm import TileRMSNorm
+from tilegym.ops.cutile.rms_norm import _bwd_tiles
 from tilegym.ops.cutile.rms_norm import rms_norm_backward
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -116,7 +117,11 @@ def run_backward():
 
     dx_bytes = x.numel() * bytes_per_element  # Write dx
     dw_bytes = weight.numel() * bytes_per_element  # Write dw
-    temp_buffer_bytes = x.numel() * 4 * 2  # always write + read float32
+    if backend == "cutile":
+        _, tile_n, grid, _ = _bwd_tiles(M, N)
+        temp_buffer_bytes = grid * tile_n * 4 * 2  # partial-sum buffer: write + read float32
+    else:
+        temp_buffer_bytes = 0
 
     total_bytes = input_x_bytes + dy_bytes + weight_bytes + rstd_bytes + dx_bytes + dw_bytes + temp_buffer_bytes
 
diff --git a/tests/ops/experimental/test_rmsnorm_backward.py b/tests/ops/experimental/test_rmsnorm_backward.py
index 80ea8c3..eeaf508 100644
--- a/tests/ops/experimental/test_rmsnorm_backward.py
+++ b/tests/ops/experimental/test_rmsnorm_backward.py
@@ -70,3 +70,58 @@ def test_op(self, m, n, dtype, backend, arch):
             atol=5e-2,
             multiple_outputs=True,
         )
+
+
+class Test_RMSNormAutogradBackward(common.PyTestCase):
+    @staticmethod
+    def reference(input, weight, eps):
+        x_fp32 = input.to(torch.float32)
+        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+        x_norm = x_fp32 * torch.rsqrt(variance + eps)
+        out = x_norm * weight.to(torch.float32)
+        return out.to(input.dtype)
+
+    _backends = ["cutile"]
+
+    @pytest.mark.parametrize(
+        "m, n, dtype",
+        [
+            (256, 256, torch.float16),
+            (4096, 256, torch.bfloat16),
+            (256, 256, torch.float32),
+        ],
+    )
+    @pytest.mark.parametrize("static_persistent", [True, False])
+    @pytest.mark.parametrize("backend", _backends)
+    def test_op(self, m, n, dtype, static_persistent, backend, arch):
+        if tilegym.is_backend_available(backend):
+            tilegym.set_backend(backend)
+        else:
+            pytest.skip(f"Backend {backend} is not available")
+
+        self.setUp()
+        device = torch.device("cuda")
+        eps = 1e-5
+
+        x = torch.rand((m, n), dtype=dtype, device=device).mul_(0.5).add_(-2.3)
+        w = torch.randn((n,), dtype=dtype, device=device)
+        dy = torch.randn((m, n), dtype=dtype, device=device)
+
+        x_cutile = x.clone().detach().requires_grad_(True)
+        w_cutile = w.clone().detach().requires_grad_(True)
+        y_cutile = tilegym.ops.rms_norm(
+            x_cutile,
+            (n,),
+            w_cutile,
+            eps,
+            static_persistent=static_persistent,
+        )
+        y_cutile.backward(dy)
+
+        x_ref = x.clone().detach().requires_grad_(True)
+        w_ref = w.clone().detach().requires_grad_(True)
+        y_ref = self.reference(x_ref, w_ref, eps)
+        y_ref.backward(dy)
+
+        torch.testing.assert_close(x_cutile.grad, x_ref.grad, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(w_cutile.grad, w_ref.grad, rtol=1e-2, atol=1e-2)
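
End-to-end, the autograd path the new test exercises looks like this (condensed from the test above; assumes the cutile backend is available on a CUDA device):

```python
import torch
import tilegym

tilegym.set_backend("cutile")
x = torch.randn(256, 256, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(256, device="cuda", dtype=torch.float16, requires_grad=True)

y = tilegym.ops.rms_norm(x, (256,), w, 1e-5, static_persistent=True)
y.backward(torch.randn_like(y))  # dx via _rms_bwd, dw via the partial-sum reduction

print(x.grad.shape, w.grad.shape)  # torch.Size([256, 256]) torch.Size([256])
```
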