# LayerNorm with Zero Centered Gamma: Forward and Backward

This notebook shows how to compute a zero centered gamma layernorm forward training and backward operation using cuDNN.

$$\text{LayerNorm\_Zero\_Centered\_Gamma}(x) = \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}}\cdot(1+\gamma)+\beta$$

Where $\mu = E[x]$ and $\sigma^2 = Var[x]$ are computed independently for each sample (each row of the input) over its embedding dimension.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/23_layernorm_zero_centered_gamma_forward_training_and_backward.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
#get_ipython().system('nvidia-smi')

If running on Colab, you will need to install the cudnn python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

In the following, we will apply layer norm to a tensor of the following shape:

- Batch Size: 4
- Sequence Size: 1024
- Embedding Dimension: 768

Let's define these dimensions as constants:

In [None]:
import cudnn
import torch

torch.manual_seed(1)

# Fail fast: verify that a GPU is available and that the cuDNN backend is
# recent enough *before* creating the handle, so that a missing GPU is reported
# by the assertion below instead of by whatever error handle creation raises.
assert (
    torch.cuda.is_available()
), "This notebook requires an NVIDIA GPU with CUDA support"

print("Running with cudnn backend version:", cudnn.backend_version())
assert (
    cudnn.backend_version() >= 91000
), "LayerNorm Zero Centered Gamma operation is only supported cuDNN version 9.10.0 or above"

# A single handle is shared by every graph built in this notebook.
handle = cudnn.create_handle()

# Problem size used throughout the notebook.
batch, seq_size, embedding_dim = 4, 1024, 768
# Epsilon is a small number to prevent division by 0.
epsilon_value = 1e-3
dtype = torch.float16

## Using Wrapper

Forward pass:

In [None]:
# Input tensors for the wrapper-based example.
# Options shared by the three half-precision GPU tensors.
gpu_opts = dict(device="cuda", dtype=dtype, requires_grad=True)

# Activations: one row per token, shaped (batch * seq_size, embedding_dim, 1, 1).
x_gpu = torch.randn(batch * seq_size, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)
# Zero-centered gamma and bias: one value per embedding channel.
gamma_gpu = torch.randn(1, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)
bias_gpu = torch.randn(1, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)

# Scalars kept on the CPU: the constant 1 that is added to gamma, and epsilon.
scalar_shape = (1, 1, 1, 1)
one_cpu = torch.ones(scalar_shape, dtype=torch.float32, device="cpu")
eps_cpu = torch.full(scalar_shape, epsilon_value, dtype=torch.float32, device="cpu")

# forward pass of layernorm using cuDNN graph
# Each entry of `inputs` is written as "<node name>::<argument name>", matching the
# `name=` and keyword arguments of the nodes added inside the `with` block. The
# tensors passed when the graph is called below are listed in this same order.
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=[
        "gamma_plus_one::a",
        "gamma_plus_one::b",
        "ln_fwd::input",
        "ln_fwd::bias",
        "ln_fwd::epsilon",
    ],
    # NOTE(review): these look like the output tensor names of the "ln_fwd" node;
    # confirm the exact spelling against the cuDNN frontend documentation.
    outputs=["ln_fwd::Y", "ln_fwd::MEAN", "ln_fwd::INV_VARIANCE"],
) as fwd_graph:
    # a pointwise add operation for zero centered gamma + 1
    scale = fwd_graph.add(
        name="gamma_plus_one",
        a=gamma_gpu,
        b=one_cpu,
    )
    # Add a layernorm operation
    # The scale fed to layernorm is the result of the add above (gamma + 1),
    # not the zero-centered gamma itself.
    out, mean, inv_var = fwd_graph.layernorm(
        name="ln_fwd",
        norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
        input=x_gpu,
        scale=scale,
        bias=bias_gpu,
        epsilon=eps_cpu,
    )
    # Enable all outputs, by default outputs are disabled
    # mean and inv_var must be float32 tensors
    out.set_name("output").set_output(True).set_data_type(dtype)
    mean.set_name("mean").set_output(True).set_data_type(torch.float32)
    inv_var.set_name("inv_var").set_output(True).set_data_type(torch.float32)

# Run the graph; the three results are unpacked in the order given by `outputs`.
out_gpu, mean_gpu, inv_var_gpu = fwd_graph(
    gamma_gpu, one_cpu, x_gpu, bias_gpu, eps_cpu, handle=handle
)

Layer norm in cuDNN accepts `scale` and `bias` as input arguments. The `scale` is a factor to multiply to the normalized input. If the scale factor is zero centered, you need a conversion before passing it to the layer norm. The node `gamma_plus_one` defined above is such a conversion.

Now, let's compare the output from cuDNN with PyTorch:

In [None]:
# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(
    x_gpu,
    [embedding_dim, 1, 1],
    weight=(1 + gamma_gpu).squeeze(0),
    bias=bias_gpu.squeeze(0),
    eps=epsilon_value,
)
# Reference statistics in float32, one value per row of x_gpu.
mean_ref = x_gpu.float().mean(dim=(1, 2, 3), keepdim=True)
# Layer norm normalizes with the biased (population) variance, so request
# correction=0 explicitly; torch.var defaults to the unbiased estimator.
inv_var_ref = torch.rsqrt(
    torch.var(x_gpu.float(), dim=(1, 2, 3), correction=0, keepdim=True)
    + epsilon_value
)

# Compare the cuDNN results with the PyTorch reference.
torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(mean_gpu, mean_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(inv_var_gpu, inv_var_ref, atol=5e-3, rtol=3e-3)

Based on the above output, let's implement the backward pass:

In [None]:
# Reference gradients from PyTorch autograd, using an MSE loss against a random target.
# retain_grad() asks PyTorch to keep .grad on these tensors after backward():
# out_ref.grad is fed into the cuDNN backward graph, while x_gpu.grad,
# gamma_gpu.grad and bias_gpu.grad are compared with the cuDNN graph outputs.
target = torch.randn_like(out_ref)
criterion = torch.nn.MSELoss()
loss = criterion(out_ref, target)

for t in (out_ref, x_gpu, gamma_gpu, bias_gpu):
    t.retain_grad()

loss.backward()

# Backward pass
# As in the forward graph, each entry of `inputs` is "<node name>::<argument name>"
# and the tensors passed when the graph is called below follow the same order.
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=[
        "gamma_plus_one::a",
        "gamma_plus_one::b",
        "ln_bwd::grad",
        "ln_bwd::input",
        "ln_bwd::mean",
        "ln_bwd::inv_variance",
    ],
    # NOTE(review): these look like the output tensor names of the "ln_bwd" node;
    # confirm the exact spelling against the cuDNN frontend documentation.
    outputs=["ln_bwd::DX", "ln_bwd::DSCALE", "ln_bwd::DBIAS"],
) as bwd_graph:
    # Rebuild scale = gamma + 1 inside this graph; the backward node takes the
    # same (non zero-centered) scale that the forward layernorm used.
    scale_bwd = bwd_graph.add(
        name="gamma_plus_one",
        a=gamma_gpu,
        b=one_cpu,
    )
    # mean and inv_variance are the statistics produced by the forward graph.
    dx, dscale, dbias = bwd_graph.layernorm_backward(
        name="ln_bwd",
        grad=out_ref.grad,
        input=x_gpu,
        scale=scale_bwd,
        mean=mean_gpu,
        inv_variance=inv_var_gpu,
    )
    # Mark the three gradients as graph outputs.
    dx.set_output(True).set_data_type(dtype)
    dscale.set_output(True).set_data_type(dtype)
    dbias.set_output(True).set_data_type(dtype)

dx_gpu, dscale_gpu, dbias_gpu = bwd_graph(
    gamma_gpu, one_cpu, out_ref.grad, x_gpu, mean_gpu, inv_var_gpu, handle=handle
)

# Since scale = gamma + 1, the gradient with respect to scale equals the gradient
# with respect to gamma, so dscale_gpu is compared against gamma_gpu.grad.
torch.testing.assert_close(x_gpu.grad, dx_gpu, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(gamma_gpu.grad, dscale_gpu, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(bias_gpu.grad, dbias_gpu, atol=5e-3, rtol=3e-3)

## Using Python Binding APIs

#### Forward pass

Next, we need to create GPU buffers as input, the `gamma_gpu` tensor is the zero-centered scale factor:

In [None]:
# Fresh input tensors for the Python-binding example.
# Options shared by the three half-precision GPU tensors.
gpu_opts = dict(device="cuda", dtype=dtype, requires_grad=True)

# Zero-centered gamma and bias: one value per embedding channel.
gamma_gpu = torch.randn(1, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)
bias_gpu = torch.randn(1, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)
# Activations: one row per token, shaped (batch * seq_size, embedding_dim, 1, 1).
x_gpu = torch.randn(batch * seq_size, embedding_dim, 1, 1, **gpu_opts).to(
    memory_format=torch.channels_last
)

# Scalars kept on the CPU: the constant 1 that is added to gamma, and epsilon.
scalar_shape = (1, 1, 1, 1)
one_cpu = torch.ones(scalar_shape, dtype=torch.float32, device="cpu")
eps_cpu = torch.full(scalar_shape, epsilon_value, dtype=torch.float32, device="cpu")

Then we can create the graph for forward pass:

In [None]:
from enum import Enum


class UID(Enum):
    """Unique tensor IDs used to bind graph tensors to their memory buffers."""

    # Tensor inputs
    SCALE0 = 1  # zero-centered gamma
    X = 2
    BIAS = 3

    # Outputs
    OUT = 5
    MEAN = 6
    INV_VAR = 7

    # Scalar inputs
    ONE = 8
    EPSILON = 9

In [None]:
# Create the cuDNN graph.
graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensor handles with the graph API, assign UIDs.
# The UIDs set here are the keys used later to bind each tensor to its buffer.
# NOTE(review): presumably tensor_like takes dims/strides/dtype from the given
# tensor; confirm against the cuDNN frontend documentation.
x = graph.tensor_like(x_gpu.detach()).set_name("X").set_uid(UID.X.value)
gamma = (
    graph.tensor_like(gamma_gpu.detach()).set_name("scale0").set_uid(UID.SCALE0.value)
)
one = graph.tensor_like(one_cpu).set_name("one").set_uid(UID.ONE.value)
bias = graph.tensor_like(bias_gpu.detach()).set_name("bias").set_uid(UID.BIAS.value)
epsilon = graph.tensor_like(eps_cpu).set_name("epsilon").set_uid(UID.EPSILON.value)

# A node for pointwise add operation: zero centered gamma + 1
# `scale` is not marked as an output and gets no UID; it only connects the add
# node to the layernorm node.
scale = graph.add(name="gamma_plus_one", a=gamma, b=one)

# A node for layernorm operation
out, mean, inv_var = graph.layernorm(
    name="layernorm",
    norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
    input=x,
    scale=scale,
    bias=bias,
    epsilon=epsilon,
)

# Mark output tensors
# mean and inv_var are requested in float32, while the output uses `dtype`.
out.set_name("output").set_output(True).set_data_type(dtype).set_uid(UID.OUT.value)
mean.set_name("mean").set_output(True).set_data_type(torch.float32).set_uid(
    UID.MEAN.value
)
inv_var.set_name("inv_var").set_output(True).set_data_type(torch.float32).set_uid(
    UID.INV_VAR.value
)

# Build the graph
graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Here we assign UIDs to tensors. A UID is a unique identifier that lets us map a tensor created through cuDNN graph API calls, such as `graph.tensor_like()`, to the underlying memory buffer that stores it. Virtual tensors don't require explicit memory to be allocated for them, but non-virtual tensors such as inputs and outputs need to have UIDs assigned to them.

Alternatively, one can use the tensor handles directly in the mapping; however, using UIDs can be more convenient for caching of cuDNN graphs.

For each of our inputs {X, Scale0, One, Bias, Epsilon} and our outputs {Out, Mean, Inverse Variance}, we allocate a UID.

After validating and building a cuDNN graph,  we can now execute it. To do this, we have to provide input and output buffers. We do this by using the previously allocated UIDs to associate between tensor handles generated from the graph API, and their underlying memory. 

The desired input values need to be stored in these buffers before the `graph.execute` call. Because we have done a reference computation, we can simply reuse the buffers we have allocated via PyTorch.

Note that the EPSILON and ONE UIDs are bound to CPU buffers (`eps_cpu` and `one_cpu`).

In [None]:
# Output buffers: the normalized output plus one mean / inverse-variance value per row.
num_rows = batch * seq_size
stats_shape = (num_rows, 1, 1, 1)
out_gpu = torch.empty((num_rows, embedding_dim, 1, 1), device="cuda", dtype=dtype)
mean_gpu = torch.empty(stats_shape, device="cuda", dtype=torch.float32)
inv_var_gpu = torch.empty(stats_shape, device="cuda", dtype=torch.float32)

# Mapping of (UIDs -> memory)
variant_pack = {
    # inputs
    UID.X.value: x_gpu,
    UID.SCALE0.value: gamma_gpu,
    UID.ONE.value: one_cpu,
    UID.BIAS.value: bias_gpu,
    UID.EPSILON.value: eps_cpu,
    # outputs
    UID.OUT.value: out_gpu,
    UID.MEAN.value: mean_gpu,
    UID.INV_VAR.value: inv_var_gpu,
}

# Scratch buffer sized as requested by the built graph.
workspace_size = graph.get_workspace_size()
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

Test cuDNN's output against PyTorch's and check correctness

In [None]:
# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(
    x_gpu,
    [embedding_dim, 1, 1],
    weight=(1 + gamma_gpu).squeeze(0),
    bias=bias_gpu.squeeze(0),
    eps=epsilon_value,
)
# Reference statistics in float32, one value per row of x_gpu.
mean_ref = x_gpu.float().mean(dim=(1, 2, 3), keepdim=True)
# Layer norm normalizes with the biased (population) variance, so request
# correction=0 explicitly; torch.var defaults to the unbiased estimator.
inv_var_ref = torch.rsqrt(
    torch.var(x_gpu.float(), dim=(1, 2, 3), correction=0, keepdim=True)
    + epsilon_value
)

# compare to reference output
torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(mean_gpu, mean_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(inv_var_gpu, inv_var_ref, atol=5e-3, rtol=3e-3)

#### Backward Pass

Let's use random values as groundtruth to calculate the loss and run backward pass in PyTorch:

In [None]:
# Reference backward operation using PyTorch: MSE loss against a random target.
target = torch.randn_like(out_ref)
criterion = torch.nn.MSELoss()
loss = criterion(out_ref, target)

# Keep .grad on these tensors after backward(): out_ref.grad feeds the cuDNN
# backward graph and the other gradients serve as the reference results.
for t in (out_ref, x_gpu, gamma_gpu, bias_gpu):
    t.retain_grad()

loss.backward()

Then we can create the backward graph

In [None]:
# Create the cuDNN graph for the backward pass.
bwd_graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensors associated with the backwards graph.
# DO NOT reuse tensor handles from the forward graph because tensors are not shared across graphs.
# d_out (the incoming gradient) has the same dims, strides and dtype as x_gpu.
d_out = bwd_graph.tensor(
    name="d_out", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype
)
x_bwd = bwd_graph.tensor_like(x, name="x")
gamma_bwd = bwd_graph.tensor_like(gamma, name="gamma")
# The constant 1 must be created on bwd_graph as well, not on the forward `graph`.
one_bwd = bwd_graph.tensor_like(one_cpu).set_name("one")
mean_bwd = bwd_graph.tensor_like(mean, name="mean")
inv_var_bwd = bwd_graph.tensor_like(inv_var, name="inv_var")

# A node for pointwise add operation: zero centered gamma + 1
scale_bwd = bwd_graph.add(name="gamma_bwd_plus_one", a=gamma_bwd, b=one_bwd)

# A node for the layernorm backward operation
d_x, d_scale, d_bias = bwd_graph.layernorm_backward(
    name="DLN",
    grad=d_out,
    input=x_bwd,
    scale=scale_bwd,
    mean=mean_bwd,
    inv_variance=inv_var_bwd,
)

# Enable outputs.
d_x.set_output(True).set_data_type(x_gpu.dtype)
d_scale.set_output(True).set_data_type(x_gpu.dtype)
d_bias.set_output(True).set_data_type(x_gpu.dtype)

# Build the bwd_graph
bwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Execute the graph and check correctness against PyTorch

In [None]:
# Buffers that will receive the gradients computed by cuDNN.
d_x_gpu = torch.empty_like(x_gpu)
d_scale_gpu = torch.empty_like(gamma_gpu)
d_bias_gpu = torch.empty_like(bias_gpu)

# Bind the tensor handles of the backward graph to their memory. The statistics
# come from the forward graph's outputs, and d_out is the gradient that PyTorch
# autograd stored in out_ref.grad.
variant_pack = {
    # inputs
    d_out: out_ref.grad,
    x_bwd: x_gpu.detach(),
    gamma_bwd: gamma_gpu.detach(),
    one_bwd: one_cpu,
    mean_bwd: mean_gpu.detach(),
    inv_var_bwd: inv_var_gpu.detach(),
    # outputs
    d_x: d_x_gpu,
    d_scale: d_scale_gpu,
    d_bias: d_bias_gpu,
}

# Scratch buffer sized as requested by the built backward graph.
workspace_size = bwd_graph.get_workspace_size()
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

bwd_graph.execute(variant_pack, workspace, handle=handle)

Compare results and check correctness

In [None]:
# Wait for the backward graph to finish before reading its outputs.
torch.cuda.synchronize()

# Compare each PyTorch autograd gradient with the corresponding cuDNN result.
for grad_ref, grad_cudnn in (
    (x_gpu.grad, d_x_gpu),
    (gamma_gpu.grad, d_scale_gpu),
    (bias_gpu.grad, d_bias_gpu),
):
    torch.testing.assert_close(grad_ref, grad_cudnn, atol=2e-4, rtol=2e-4)

Perform Cleanup

In [None]:
# Release the cuDNN handle that was created at the start of the notebook.
cudnn.destroy_handle(handle)