In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl

In [6]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
B, D = 8, 16

x = torch.randn((B, D), device=device)
norm = nn.RMSNorm(normalized_shape=D, device=device)
y_torch = norm(x)

mean_square = torch.sum(x * x, axis=1) / x.shape[1]
_rstd = torch.rsqrt(mean_square + 1e-5) 

In [7]:
# Learnable per-feature scale of the reference RMSNorm. It is all ones at
# initialization (see the output below), so comparing y against y_torch does not
# exercise a non-trivial weight.
(W := norm.weight)

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       device='cuda:0', requires_grad=True)

In [None]:
@triton.jit
def rms_norm_fwd_kernel(
    X_ptr,
    X_row_stride,
    Y_ptr,
    Y_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm forward for a single row: y = (x * rstd) * w.

    Here rstd = 1 / sqrt(mean(x**2) + eps), computed in float32.

    Each program instance handles the row given by ``tl.program_id(0)``. The
    whole row is loaded as one BLOCK_SIZE-wide block, so BLOCK_SIZE must be
    >= n_cols. Columns are addressed with unit stride.

    Writes:
      * Y[row, :n_cols]  -- normalized, weighted row in x's original dtype.
      * RSTD[row]        -- the per-row rstd, reused by the backward pass.
    """
    # int64, presumably so row_id * stride cannot overflow int32 on large tensors.
    row_id = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    X_ptr += row_id * X_row_stride
    Y_ptr += row_id * Y_row_stride
    RSTD_ptr += row_id * RSTD_row_stride

    # Out-of-range lanes load 0.0 so they contribute nothing to the sum below.
    x = tl.load(pointer=X_ptr + col_offsets, mask=mask, other=0.0)
    w = tl.load(pointer=W_ptr + col_offsets, mask=mask, other=0.0)

    # Upcast to float32 for the reduction, and cast back to x's dtype at the end.
    # `.to()` returns a new tensor, so its result must be reassigned; otherwise
    # low-precision inputs would be reduced in their own dtype.
    x_dtype = x.dtype
    x = x.to(tl.float32)

    mean_square = tl.sum(x * x, axis=0) / n_cols
    rstd = tl.rsqrt(mean_square + eps)

    # Cache rstd (one scalar per row) so the backward pass can skip recomputing
    # the square / sum / divide / rsqrt.
    tl.store(pointer=RSTD_ptr, value=rstd)

    # Normalize in float32, then cast to the input dtype before applying the
    # weight. The product is cast back to the input dtype as well.
    x_hat = (x * rstd).to(x_dtype)
    y = (x_hat * w).to(x_dtype)

    tl.store(pointer=Y_ptr + col_offsets, value=y, mask=mask)


@triton.jit
def rmsnorm_bwd_kernel(
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    dX_ptr,
    dX_row_stride,
    dY_ptr,
    dY_row_stride,
    W_ptr,
    dW_ptr,
    dW_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm backward over a contiguous chunk of rows.

    Program ``row_block_id`` processes rows
    ``[row_block_id * rows_per_program, min((row_block_id + 1) * rows_per_program, n_rows))``.

    For each row, using the rstd cached by the forward kernel:

        m  = dy * w
        dx = rstd * m - (rstd**3 / n_cols) * sum(m * x) * x

    dx is cast to ``X_dtype`` and written to dX. The program also accumulates
    ``dy * (x * rstd)`` over its rows in float32 and stores that partial weight
    gradient in row ``row_block_id`` of dW. The launcher presumably sums the
    partial dW rows over programs to obtain the full weight gradient; TODO
    confirm once the host-side wrapper exists.

    BLOCK_SIZE must be >= n_cols (one block covers a whole row). Columns are
    addressed with unit stride.
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance to this program's first row. Use row_start (not row_block_id), and
    # give every tensor its own stride.
    X_ptr += row_start * X_row_stride
    dX_ptr += row_start * dX_row_stride
    dY_ptr += row_start * dY_row_stride
    RSTD_ptr += row_start * RSTD_row_stride
    # dW holds one partial-gradient row per program, not per input row.
    dW_ptr += row_block_id * dW_row_stride

    w = tl.load(pointer=W_ptr + col_offsets, mask=mask, other=0.0)
    dw_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    for _ in range(row_start, row_end):
        x = tl.load(pointer=X_ptr + col_offsets, mask=mask, other=0.0)
        dy = tl.load(pointer=dY_ptr + col_offsets, mask=mask, other=0.0)
        rstd = tl.load(pointer=RSTD_ptr)

        x = x.to(tl.float32)
        m = (dy * w).to(tl.float32)

        # d/dx of (x * rstd(x)) * w, with rstd = (mean(x^2) + eps)^(-1/2).
        dx = rstd * m - (rstd * rstd * rstd / n_cols) * tl.sum(m * x, axis=0) * x
        tl.store(pointer=dX_ptr + col_offsets, value=dx.to(X_dtype), mask=mask)

        # Mirror the forward, which casts x * rstd to x's dtype before the weight.
        dw_acc += (dy * (x * rstd).to(X_dtype)).to(tl.float32)

        X_ptr += X_row_stride
        dX_ptr += dX_row_stride
        dY_ptr += dY_row_stride
        RSTD_ptr += RSTD_row_stride

    tl.store(pointer=dW_ptr + col_offsets, value=dw_acc, mask=mask)


def rms_norm_forward(
    x: torch.Tensor, w: torch.Tensor, eps: float = 1e-5
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Apply RMSNorm over the last dimension of ``x`` using the Triton kernel.

    Args:
        x: input of shape (..., n_cols). Leading dimensions are flattened into
            rows; a 2-D input is handled exactly as (n_rows, n_cols).
        w: per-feature weight with n_cols elements.
        eps: value added to the mean square before the reciprocal sqrt.

    Returns:
        (y, rstd), where:
          * y has x's shape and dtype.
          * rstd is float32 with shape x.shape[:-1], holding one value per row.

    Raises:
        ValueError: if ``w`` does not have exactly n_cols elements (the kernel
            would otherwise read out of bounds).
    """
    n_cols = x.shape[-1]
    if w.numel() != n_cols:
        raise ValueError(
            f"weight has {w.numel()} elements, expected {n_cols} (last dim of x)"
        )

    y = torch.empty_like(x, memory_format=torch.contiguous_format)
    rstd = torch.empty(x.shape[:-1], dtype=torch.float32, device=x.device)
    if x.numel() == 0:
        # Nothing to normalize; avoid launching a zero-sized grid.
        return y, rstd

    # The kernel addresses columns with unit stride, so the normalized dim of x
    # and the weight must be contiguous.
    x_2d = x.reshape(-1, n_cols)
    if x_2d.stride(-1) != 1:
        x_2d = x_2d.contiguous()
    w = w.contiguous()
    y_2d = y.view(-1, n_cols)  # y is freshly allocated and contiguous
    rstd_1d = rstd.view(-1)
    n_rows = x_2d.shape[0]

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    rms_norm_fwd_kernel[(n_rows,)](
        X_ptr=x_2d,
        X_row_stride=x_2d.stride(0),
        Y_ptr=y_2d,
        Y_row_stride=y_2d.stride(0),
        W_ptr=w,
        RSTD_ptr=rstd_1d,
        RSTD_row_stride=rstd_1d.stride(0),
        n_cols=n_cols,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return y, rstd


y, rstd = rms_norm_forward(x=x, w=W, eps=1e-5)  # eps spelled out; same as the default

In [11]:
# Triton output vs. nn.RMSNorm, then cached rstd vs. the reference rstd.
for actual, expected in ((y, y_torch), (rstd, _rstd)):
    torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)