# LayerNorm

In [2]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl


@triton.jit
def layer_norm_fwd_kernel(
    X, Y, W, B, Mean, Rstd, stride, N, eps, BLOCK_SIZE: tl.constexpr
):
    """Layer-normalize one row of X into Y; each program instance owns one row.

    X / Y:       base pointers of the input / output; row ``r`` is addressed
                 at ``base + r * stride`` with consecutive elements.
    W / B:       pointers to the per-column scale and shift, indexed 0..N-1.
    Mean / Rstd: pointers that receive one value per row: the row mean and
                 ``1 / sqrt(var + eps)``.
    stride:      element distance between consecutive rows (shared by X and Y).
    N:           number of elements in a row.
    eps:         value added to the variance before the square root.
    BLOCK_SIZE:  compile-time width of the chunks used to walk a row.
    """
    row = tl.program_id(axis=0)

    # Advance both base pointers to the row handled by this program instance.
    X += row * stride
    Y += row * stride

    # Lane indices inside one BLOCK_SIZE-wide chunk, shared by all three passes.
    lanes = tl.arange(0, BLOCK_SIZE)

    # Pass 1: row mean. Each lane accumulates its own partial sum in float32
    # across chunks; the lanes are reduced once at the end.
    sum_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        cols = start + lanes
        vals = tl.load(X + cols, mask=cols < N, other=0).to(tl.float32)
        sum_acc += vals
    mean = tl.sum(sum_acc) / N
    tl.store(Mean + row, mean)

    # Pass 2: row variance (divides by N) around the mean computed above.
    sq_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        cols = start + lanes
        in_row = cols < N
        vals = tl.load(X + cols, mask=in_row, other=0).to(tl.float32)
        # Padding lanes loaded 0; force their deviation to 0 as well so they
        # do not contribute (0 - mean)^2 to the sum.
        centered = tl.where(in_row, vals - mean, 0.0)
        sq_acc += centered * centered
    var = tl.sum(sq_acc) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)

    # Pass 3: normalize the row, then apply the scale and shift.
    for start in range(0, N, BLOCK_SIZE):
        cols = start + lanes
        in_row = cols < N
        vals = tl.load(X + cols, mask=in_row).to(tl.float32)
        gamma = tl.load(W + cols, mask=in_row)
        beta = tl.load(B + cols, mask=in_row)
        normed = (vals - mean) * rstd
        tl.store(Y + cols, normed * gamma + beta, mask=in_row)


def layer_norm_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps=1e-5):
    """Apply layer normalization over the last dimension of ``x`` on the GPU.

    Args:
        x: CUDA tensor of any shape; the last dimension is normalized.
        weight: CUDA tensor with ``x.size(-1)`` elements (per-feature scale).
        bias: CUDA tensor with ``x.size(-1)`` elements (per-feature shift).
        eps: value added to the variance before taking the square root.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.

    Raises:
        AssertionError: if any argument is not a CUDA tensor, or if
            ``weight`` / ``bias`` do not have ``x.size(-1)`` elements.
    """
    assert x.is_cuda and weight.is_cuda and bias.is_cuda

    feature_dim = x.size(-1)
    # The kernel reads W and B at offsets 0..feature_dim-1 without any bounds
    # information of its own, so a shorter parameter would be read past its end.
    assert weight.numel() == feature_dim and bias.numel() == feature_dim, (
        "weight and bias must each have x.size(-1) elements"
    )

    # The kernel addresses input and output with one shared row stride and
    # unit column stride. Normalizing the layout first guarantees that the
    # freshly allocated output really has the same strides as the input;
    # a strided input paired with a dense output would be written out of bounds.
    x = x.contiguous()
    weight = weight.contiguous()
    bias = bias.contiguous()

    # Every element is written by the kernel, so no zero-fill is needed.
    # Allocating from ``x`` keeps the result on x's device rather than on
    # whichever GPU happens to be the current default.
    output = torch.empty_like(x)
    if x.numel() == 0:
        # Nothing to normalize; also avoids a zero-sized grid launch and an
        # ambiguous ``view(-1, 0)``.
        return output

    rows = x.view(-1, feature_dim)
    n = rows.size(0)

    # Per-row statistics are produced in float32 by the kernel; keep that
    # precision instead of truncating to a half-precision input dtype.
    # They are scratch outputs here and are not returned.
    mean = torch.empty(n, dtype=torch.float32, device=x.device)
    rstd = torch.empty(n, dtype=torch.float32, device=x.device)

    grid = (n,)  # one program instance per row
    # Launch on the device that owns the data, not the current default device.
    with torch.cuda.device(x.device):
        layer_norm_fwd_kernel[grid](
            rows,
            output,
            weight,
            bias,
            mean,
            rstd,
            stride=rows.stride(0),
            N=feature_dim,
            eps=eps,
            BLOCK_SIZE=64,
        )
    return output


def main_test():
    """Print whether the Triton LayerNorm forward matches ``F.layer_norm``.

    Builds a small random CUDA input with a random scale and a zero shift,
    runs both implementations, and prints the ``torch.allclose`` result
    (absolute tolerance 1e-6).
    """
    shape = (2, 3, 5)  # (batch_size, seq_len, input_size)
    hidden = shape[-1]

    x = torch.randn(*shape, device="cuda")
    weight = torch.randn((hidden,), device="cuda")
    bias = torch.zeros((hidden,), device="cuda")

    expected = F.layer_norm(x, (hidden,), weight=weight, bias=bias)
    actual = layer_norm_fwd(x, weight, bias)

    matches = torch.allclose(expected, actual, atol=1e-6)
    print(matches)



# Run the Triton-vs-PyTorch comparison; it prints True when the outputs match.
main_test()


True
