This notebook shows how to compute an RMS norm using the cuDNN python frontend.

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\mathbb{E}(x^2) + \epsilon}}\cdot\gamma+\beta$$

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If running on Colab, you will need to install the cudnn python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

#### General Setup
The cuDNN handle is a per-device handle used to initialize the cuDNN context.



In [None]:
import sys
from typing import Optional, Tuple, Union

import cudnn
import torch

# Seed PyTorch's RNG so the randomly generated tensors are reproducible across runs.
torch.manual_seed(1)

# Fail fast, with a readable message, before any cuDNN call is made on a machine without a GPU.
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable NVIDIA GPU"

# Create a cuDNN handle (a per-device handle that initializes the cuDNN context).
handle = cudnn.create_handle()

print("Running with cudnn backend version:", cudnn.backend_version())

### RMSNorm Reference Computation

In [None]:
# Reference Model:
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization (PyTorch reference implementation).

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Args:
        dim: Dimension, or tuple of dimensions, over which the mean of ``x * x`` is taken.
        eps: Constant added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: Union[int, Tuple[int, ...]] = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.dim = dim

    def forward(
        self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x`` by its root mean square, then apply ``weight`` and optional ``bias``.

        Args:
            x: Input tensor.
            weight: Scale multiplied with the normalized input.
            bias: Optional shift added after scaling; skipped when ``None``.

        Returns:
            A ``(output, inv_var)`` tuple, where ``inv_var`` is
            ``rsqrt(mean(x * x, dim, keepdim=True) + eps)``.
        """
        # NOTE: the original RMSNorm paper implementation is not equivalent
        mean_square = torch.mean(x * x, dim=self.dim, keepdim=True)
        inv_var = torch.rsqrt(mean_square + self.eps)
        x_normed = x * inv_var
        x_scaled = weight * x_normed
        if bias is not None:
            # Out-of-place add: does not mutate the scaled intermediate.
            x_scaled = x_scaled + bias
        return x_scaled, inv_var

#### Problem Sizes
- Batch Size: 4 
- Sequence Length: 1024
- Hidden Size: 128

In [None]:
# Problem dimensions.
batch = 4
seq_length = 1024
hidden_size = 128

# dtype used when allocating the input tensors.
input_type = torch.float16

# Small constant that guards against division by zero.
epsilon_value = 1e-3

Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs.

In [None]:
# Input activations: standard-normal samples, scaled by 2 and shifted by -1.25.
# Shape is (batch * seq_length, hidden_size, 1, 1), requested in channels-last format.
x_gpu = torch.randn(
    batch * seq_length,
    hidden_size,
    1,
    1,
    dtype=input_type,
    requires_grad=True,
    device="cuda",
).to(memory_format=torch.channels_last)
x_gpu = 2 * x_gpu - 1.25

# Per-channel scale: standard-normal samples, scaled by 3 and shifted by -2.75.
scale_gpu = torch.randn(
    1, hidden_size, 1, 1, dtype=input_type, requires_grad=True, device="cuda"
).to(memory_format=torch.channels_last)
scale_gpu = 3 * scale_gpu - 2.75

# Per-channel bias: plain standard-normal samples.
bias_gpu = torch.randn(
    1,
    hidden_size,
    1,
    1,
    dtype=input_type,
    requires_grad=True,
    device="cuda",
).to(memory_format=torch.channels_last)

# Epsilon lives on the CPU as a single-element float32 tensor holding epsilon_value.
epsilon_cpu = torch.full(
    size=(1, 1, 1, 1),
    fill_value=epsilon_value,
    dtype=torch.float32,
    requires_grad=False,
    device="cpu",
)

Compute reference outputs and allocate output tensor GPU buffers

In [None]:
# Run the PyTorch reference first: its outputs double as templates for the cuDNN output buffers.
model = RMSNorm(dim=(1, 2, 3), eps=epsilon_value).float()
out_expected, inv_var_expected = model(x_gpu, scale_gpu, bias_gpu)

# PyTorch has already worked out the output shapes, so .empty_like() is enough
# to allocate uninitialized buffers for cuDNN to write into.
out_gpu, inv_var_gpu = (
    torch.empty_like(reference) for reference in (out_expected, inv_var_expected)
)

#### Create cuDNN graph and tensors

In [None]:
# Compare the numeric backend version (9.1.0 is encoded as 90100) rather than the version
# string: string comparison is lexicographic, so e.g. "10.0.0" >= "9.1.0" would be False.
if cudnn.backend_version() >= 90100:
    # Reuse the handle created in the setup cell; creating another one here would
    # overwrite that reference and leak the first handle.

    # create cuDNN graph
    graph = cudnn.pygraph(
        handle=handle,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    # create tensor handles with the graph API
    # (.detach() hands cuDNN the tensors without their autograd history)
    x = graph.tensor_like(x_gpu.detach()).set_name("X")
    scale = graph.tensor_like(scale_gpu.detach()).set_name("scale")
    bias = graph.tensor_like(bias_gpu.detach()).set_name("bias")
    epsilon = graph.tensor_like(epsilon_cpu).set_name("epsilon")

    # The rmsnorm node returns the normalized output together with inv_var,
    # which the backward graph consumes later in the notebook.
    (out, inv_var) = graph.rmsnorm(
        name="rmsnorm",
        input=x,
        norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
        scale=scale,
        bias=bias,
        epsilon=epsilon,
    )

    # enable all outputs, with data types matching the PyTorch reference outputs
    out.set_name("output").set_output(True).set_data_type(out_expected.dtype)
    inv_var.set_name("inv_var").set_output(True).set_data_type(inv_var_expected.dtype)

#### Build the graph

In [None]:
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:
    # Build the graph
    graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

# To run this block more than once, we need to re-run the previous block to get a new graph.
# The same instance of a graph should not be built twice.

#### Execute the graph

In [None]:
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:
    # Mapping of (handles -> memory)
    variant_pack = {
        x: x_gpu.detach(),
        scale: scale_gpu.detach(),
        bias: bias_gpu.detach(),
        epsilon: epsilon_cpu,
        out: out_gpu,
        inv_var: inv_var_gpu,
    }

    # Scratch buffer of the size the built graph reports, allocated as raw bytes on the GPU.
    workspace = torch.empty(
        graph.get_workspace_size(), device="cuda", dtype=torch.uint8
    )
    graph.execute(variant_pack, workspace)
    # Wait for the GPU work to finish before the outputs are inspected.
    torch.cuda.synchronize()

Test cuDNN's output against PyTorch's and check correctness

In [None]:
# Compare the cuDNN outputs against the PyTorch reference outputs.
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:
    torch.testing.assert_close(out_gpu, out_expected, rtol=5e-3, atol=5e-3)
    torch.testing.assert_close(inv_var_gpu, inv_var_expected, rtol=5e-3, atol=5e-3)

#### RMSNorm Backwards Pass

In [None]:
# Random target shaped like the reference output.
target = torch.randn_like(out_expected)

# Mean-squared-error loss: reduces the output to a scalar so that backward()
# can produce reference gradients.
criterion = torch.nn.MSELoss()
loss = criterion(out_expected, target)

# Ask autograd to keep .grad on these tensors; the gradients are read after backward().
for grad_holder in (out_expected, x_gpu, scale_gpu, bias_gpu):
    grad_holder.retain_grad()

loss.backward()

In [None]:
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:

    bwd_graph = cudnn.pygraph(
        handle=handle,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    # Incoming gradient: modelled on the gradient PyTorch computed for the reference output.
    d_out = bwd_graph.tensor_like(out_expected.grad)

    # Backward-graph counterparts of the forward graph's tensors.
    x_bwd = bwd_graph.tensor_like(x, name="x")
    scale_bwd = bwd_graph.tensor_like(scale, name="scale")
    inv_var_bwd = bwd_graph.tensor_like(inv_var, name="inv_var")

    # has_dbias=True requests the bias gradient in addition to the input and scale gradients.
    (d_x, d_scale, d_bias) = bwd_graph.rmsnorm_backward(
        name="d_rmsnorm",
        grad=d_out,
        input=x_bwd,
        scale=scale_bwd,
        inv_variance=inv_var_bwd,
        has_dbias=True,
    )

    # Mark all three gradients as graph outputs, using the input tensor's dtype.
    d_x.set_output(True).set_data_type(x_gpu.dtype)
    d_scale.set_output(True).set_data_type(x_gpu.dtype)
    d_bias.set_output(True).set_data_type(x_gpu.dtype)

In [None]:
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:
    # Build the bwd_graph
    bwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

    # Create output buffers for gradients
    d_x_gpu = torch.empty_like(x_gpu)
    d_scale_gpu = torch.empty_like(scale_gpu)
    d_bias_gpu = torch.empty_like(bias_gpu)

    # Scratch buffer of the size the built graph reports, allocated as raw bytes on the GPU.
    workspace = torch.empty(
        bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8
    )

    # Map each backward-graph tensor to its memory: inputs come from the forward pass
    # and the PyTorch reference gradient, outputs go to the buffers allocated above.
    bwd_graph.execute(
        {
            x_bwd: x_gpu.detach(),
            scale_bwd: scale_gpu.detach(),
            d_out: out_expected.grad,
            inv_var_bwd: inv_var_gpu.detach(),
            d_x: d_x_gpu,
            d_scale: d_scale_gpu,
            d_bias: d_bias_gpu,
        },
        workspace,
        handle=handle,
    )

Compare results and check correctness

In [None]:
# Numeric version check (9.1.0 is encoded as 90100); comparing version strings is
# lexicographic and would wrongly reject e.g. "10.0.0".
if cudnn.backend_version() >= 90100:
    # Wait for the backward graph to finish before reading its outputs.
    torch.cuda.synchronize()

    # PyTorch autograd gradients vs. the gradients written by cuDNN.
    torch.testing.assert_close(x_gpu.grad, d_x_gpu, atol=2e-4, rtol=2e-4)
    torch.testing.assert_close(scale_gpu.grad, d_scale_gpu, atol=2e-4, rtol=2e-4)
    torch.testing.assert_close(bias_gpu.grad, d_bias_gpu, atol=2e-4, rtol=2e-4)

Perform Cleanup

In [None]:
# The handle is created unconditionally in the setup cell, so release it unconditionally:
# guarding this on the backend version would leak the handle on older backends.
cudnn.destroy_handle(handle)