# Adaptive LayerNorm: Forward and Backward

This notebook shows how to run the forward pass (in training mode) and the backward pass of an adaptive layer norm operation using cuDNN.

$$\text{Adaptive\_LayerNorm}(x) = \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}}\cdot\gamma+\beta$$

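Here $\mu = E[x]$ and $\sigma^2 = \mathrm{Var}[x]$ are computed over the embedding dimension, separately for each token of each input in the batch, and $\epsilon$ is a small constant that prevents division by zero. $\gamma$ and $\beta$ are learnable parameters that vary for each input in the batch. This is in contrast to the standard layer norm, where $\gamma$ and $\beta$ are shared across all inputs in the batch.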
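To make the reduction axes concrete, here is a minimal NumPy sketch of the forward computation. It is an illustrative CPU reference, not cuDNN's implementation, and the function name `adaln_forward_reference` is hypothetical. It assumes the layout used throughout this notebook: `x` of shape `(batch, seq, dim)` and per-sample `scale`/`bias` of shape `(batch, 1, dim)`. It returns `mean` and `inv_var` with a trailing singleton axis, mirroring the statistics the cuDNN graphs below are asked to output.

```python
import numpy as np


def adaln_forward_reference(x, scale, bias, eps=1e-3):
    """Hypothetical NumPy reference for adaptive LayerNorm (training forward).

    x:     (batch, seq, dim)
    scale: (batch, 1, dim)  one gamma vector per input in the batch
    bias:  (batch, 1, dim)  one beta vector per input in the batch
    Returns y (batch, seq, dim), mean and inv_var (batch, seq, 1).
    """
    x = x.astype(np.float32)
    # Statistics are reduced over the embedding axis only: one pair per token.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased (population) variance, ddof=0
    inv_var = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mean) * inv_var
    # scale/bias broadcast over the sequence axis but NOT over the batch axis.
    y = x_hat * scale + bias
    return y, mean, inv_var


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8)).astype(np.float32)
scale = rng.standard_normal((2, 1, 8)).astype(np.float32)
bias = rng.standard_normal((2, 1, 8)).astype(np.float32)
y, mean, inv_var = adaln_forward_reference(x, scale, bias)
print(y.shape, mean.shape, inv_var.shape)  # (2, 3, 8) (2, 3, 1) (2, 3, 1)
```

Because `scale[b]` only ever multiplies rows of `x[b]`, changing the scale of one input leaves the outputs of every other input in the batch unchanged.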

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/26_adaptive_layernorm_forward_training_and_backward.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If running on Colab, you will need to install the cuDNN Python interface (and a CUDA-enabled PyTorch build).

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

In the rest of this notebook, we apply adaptive layer norm to a tensor with the following shape:

- Batch Size: 4
- Sequence Size: 1024
- Embedding Dimension: 768

Let's import the libraries, create a cuDNN handle, and define these dimensions as constants:

In [None]:
import cudnn
import torch

torch.manual_seed(1)
print("Running with cudnn backend version:", cudnn.backend_version())

# cuDNN needs a CUDA device: check for one before creating the handle.
assert torch.cuda.is_available()

handle = cudnn.create_handle()

batch, seq_size, embedding_dim = 4, 1024, 768
# Epsilon is a small number to prevent division by 0.
epsilon_value = 1e-3
dtype = torch.float16
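For these dimensions, it helps to know the shape of every tensor that flows through the forward and backward graphs. The helper below is a hypothetical plain-Python sketch that only spells out the bookkeeping used in this notebook: `scale`/`bias` (and their gradients) keep one row per input in the batch, and the saved statistics hold one value per token.

```python
def adaln_shapes(batch, seq_size, embedding_dim):
    """Hypothetical helper: expected tensor shapes for adaptive LayerNorm."""
    x = (batch, seq_size, embedding_dim)
    per_sample = (batch, 1, embedding_dim)  # scale, bias and their gradients
    per_token = (batch, seq_size, 1)  # mean and inv_var saved by the forward pass
    return {
        "x": x, "y": x, "dx": x,
        "scale": per_sample, "bias": per_sample,
        "dscale": per_sample, "dbias": per_sample,
        "mean": per_token, "inv_var": per_token,
    }


shapes = adaln_shapes(4, 1024, 768)
print(shapes["scale"], shapes["mean"])  # (4, 1, 768) (4, 1024, 1)

# One (mean, inv_var) pair per token: each normalizes a group of 768 values.
n_tokens = shapes["mean"][0] * shapes["mean"][1]
print(n_tokens)  # 4096
```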

## Using the Wrapper

#### Forward Pass

In [None]:
# input tensors
x_gpu = torch.randn(
    batch, seq_size, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
scale_gpu = torch.randn(
    batch, 1, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
bias_gpu = torch.randn(
    batch, 1, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
eps_cpu = torch.full((1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")

# Forward pass of adaptive layernorm using the cuDNN graph wrapper
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=["adaln::input", "adaln::scale", "adaln::bias", "adaln::epsilon"],
    outputs=["adaln::Y", "adaln::MEAN", "adaln::INV_VARIANCE"],
) as fwd_graph:
    out, mean, inv_var = fwd_graph.adalayernorm(
        name="adaln",
        norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
        input=x_gpu,
        scale=scale_gpu,
        bias=bias_gpu,
        epsilon=eps_cpu,
    )
    # Mark the results as graph outputs; their default names match the `outputs` list above.
    out.set_output(True).set_data_type(dtype)
    mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)

out_gpu, mean_gpu, inv_var_gpu = fwd_graph(
    x_gpu.detach(), scale_gpu.detach(), bias_gpu.detach(), eps_cpu, handle=handle
)

# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(x_gpu, (embedding_dim,), eps=epsilon_value)
out_ref = out_ref * scale_gpu + bias_gpu
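mean_ref = x_gpu.float().mean(dim=2, keepdim=True)
# LayerNorm uses the biased (population) variance, so disable Bessel's correction.
inv_var_ref = torch.rsqrt(
    torch.var(x_gpu.float(), dim=2, keepdim=True, unbiased=False) + epsilon_value
)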

torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(mean_gpu, mean_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(inv_var_gpu, inv_var_ref, atol=5e-3, rtol=3e-3)

Compared with [layer norm](20_layernorm_forward.ipynb), the arguments to the operation are the same, except that the `scale` and `bias` tensors have shape `(batch, 1, embedding_dim)`: they keep the batch dimension of the `input` tensor, so each input in the batch gets its own scale and bias.
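Written out with explicit indices ($b$ indexes the input in the batch, $s$ the sequence position, $d$ the embedding channel, and $D$ is the embedding dimension), the statistics are identical for the two operations and only the indexing of $\gamma$ and $\beta$ differs:

$$\mu_{b,s} = \frac{1}{D}\sum_{d=1}^{D} x_{b,s,d}, \qquad \sigma^2_{b,s} = \frac{1}{D}\sum_{d=1}^{D}\left(x_{b,s,d}-\mu_{b,s}\right)^2$$

$$\text{LayerNorm}(x)_{b,s,d} = \frac{x_{b,s,d}-\mu_{b,s}}{\sqrt{\sigma^2_{b,s} + \epsilon}}\cdot\gamma_{d}+\beta_{d}$$

$$\text{Adaptive\_LayerNorm}(x)_{b,s,d} = \frac{x_{b,s,d}-\mu_{b,s}}{\sqrt{\sigma^2_{b,s} + \epsilon}}\cdot\gamma_{b,d}+\beta_{b,d}$$

In tensor terms, $\gamma_{b,d}$ and $\beta_{b,d}$ are stored with shape `(batch, 1, embedding_dim)` and broadcast over the sequence axis.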

#### Backward pass

In [None]:
# Compute reference gradients with PyTorch autograd. out_ref is not a leaf tensor,
# so retain_grad() is needed to keep out_ref.grad, which is fed to the cuDNN graph
# as the incoming gradient. x_gpu.grad, scale_gpu.grad, and bias_gpu.grad are the
# reference values that the cuDNN outputs are compared against.
target = torch.randn_like(out_ref)
criterion = torch.nn.MSELoss()
loss = criterion(out_ref, target)

out_ref.retain_grad()
x_gpu.retain_grad()
scale_gpu.retain_grad()
bias_gpu.retain_grad()

loss.backward()

# Backward pass
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=[
        "Dadaln::grad",
        "Dadaln::input",
        "Dadaln::scale",
        "Dadaln::mean",
        "Dadaln::inv_variance",
    ],
    outputs=["Dadaln::DX", "Dadaln::DSCALE", "Dadaln::DBIAS"],
) as bwd_graph:
    dx, dscale, dbias = bwd_graph.adalayernorm_backward(
        name="Dadaln",
        grad=out_ref.grad,
        input=x_gpu,
        scale=scale_gpu,
        mean=mean_gpu,
        inv_variance=inv_var_gpu,
    )
    dx.set_output(True).set_data_type(dtype)
    dscale.set_output(True).set_data_type(dtype)
    dbias.set_output(True).set_data_type(dtype)

dx_gpu, dscale_gpu, dbias_gpu = bwd_graph(
    out_ref.grad,
    x_gpu.detach(),
    scale_gpu.detach(),
    mean_gpu.detach(),
    inv_var_gpu.detach(),
    handle=handle,
)

torch.testing.assert_close(x_gpu.grad, dx_gpu, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(scale_gpu.grad, dscale_gpu, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(bias_gpu.grad, dbias_gpu, atol=5e-3, rtol=3e-3)
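Note that `dscale` and `dbias` have shape `(batch, 1, embedding_dim)`: because $\gamma$ and $\beta$ are per-sample, their gradients are summed over the sequence axis only, rather than over both batch and sequence as in a regular LayerNorm. The NumPy sketch below is a hypothetical reference written for illustration (not cuDNN's kernel). It applies the standard LayerNorm backward formulas with that per-sample reduction, reuses the saved `mean` and `inv_var`, and checks one entry of `dx` against a central finite difference. The random `dy` simply stands in for the upstream gradient.

```python
import numpy as np


def adaln_forward(x, scale, bias, eps):
    # Forward pass with statistics over the embedding (last) axis.
    mean = x.mean(axis=-1, keepdims=True)
    inv_var = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return (x - mean) * inv_var * scale + bias, mean, inv_var


def adaln_backward_reference(dy, x, scale, mean, inv_var):
    """Hypothetical reference backward for adaptive LayerNorm.

    dy, x: (batch, seq, dim); scale: (batch, 1, dim); mean, inv_var: (batch, seq, 1).
    """
    x_hat = (x - mean) * inv_var
    # Per-sample parameters: reduce over the sequence axis only (axis=1).
    dbias = dy.sum(axis=1, keepdims=True)
    dscale = (dy * x_hat).sum(axis=1, keepdims=True)
    # Standard LayerNorm input gradient, with the per-sample scale folded into g.
    g = dy * scale
    dx = inv_var * (
        g
        - g.mean(axis=-1, keepdims=True)
        - x_hat * (g * x_hat).mean(axis=-1, keepdims=True)
    )
    return dx, dscale, dbias


rng = np.random.default_rng(0)
B, S, D, eps = 2, 3, 5, 1e-3
x = rng.standard_normal((B, S, D))
scale = rng.standard_normal((B, 1, D))
bias = rng.standard_normal((B, 1, D))
dy = rng.standard_normal((B, S, D))  # stands in for the upstream gradient

_, mean, inv_var = adaln_forward(x, scale, bias, eps)
dx, dscale, dbias = adaln_backward_reference(dy, x, scale, mean, inv_var)


def surrogate_loss(x_):
    # Scalar whose gradient with respect to the forward output is exactly dy.
    return np.sum(adaln_forward(x_, scale, bias, eps)[0] * dy)


# Central finite difference for one element of dx.
h = 1e-6
x_plus, x_minus = x.copy(), x.copy()
x_plus[1, 2, 3] += h
x_minus[1, 2, 3] -= h
fd = (surrogate_loss(x_plus) - surrogate_loss(x_minus)) / (2 * h)
print(dscale.shape, abs(fd - dx[1, 2, 3]) < 1e-5)  # (2, 1, 5) True
```

A useful sanity property of this formula is that each row of `dx` sums to zero over the embedding axis, since shifting a whole row of the input by a constant does not change the normalized output.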

## Using Python Binding APIs

#### Adaptive LayerNorm Forward Pass

Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs.

In [None]:
# input tensors
x_gpu = torch.randn(
    batch, seq_size, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
scale_gpu = torch.randn(
    batch, 1, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
bias_gpu = torch.randn(
    batch, 1, embedding_dim, device="cuda", dtype=dtype, requires_grad=True
)
eps_cpu = torch.full((1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")

Then create the graph for the forward pass.

In [None]:
# Create the cuDNN graph
fwd_graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensor handles with the fwd_graph
x = fwd_graph.tensor_like(x_gpu.detach()).set_name("X")
scale = fwd_graph.tensor_like(scale_gpu.detach()).set_name("scale")
bias = fwd_graph.tensor_like(bias_gpu.detach()).set_name("bias")
epsilon = fwd_graph.tensor_like(eps_cpu).set_name("epsilon")

# Add the adaptive layernorm operation
out, mean, inv_var = fwd_graph.adalayernorm(
    name="ADALN",
    norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
    input=x,
    scale=scale,
    bias=bias,
    epsilon=epsilon,
)

# Enable all outputs
out.set_name("output").set_output(True).set_data_type(dtype)
mean.set_name("mean").set_output(True).set_data_type(cudnn.data_type.FLOAT)
inv_var.set_name("inv_var").set_output(True).set_data_type(cudnn.data_type.FLOAT)

# Build the fwd_graph
fwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Execute the forward graph.
Instead of mapping tensor UIDs to memory (as in [20_layernorm.ipynb](20_layernorm.ipynb)), we map the tensor handles directly to memory. This is simpler, but slightly slower to execute.

In [None]:
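# Allocate the output buffers for this graph (mean and inv_var hold one value per token)
out_gpu = torch.empty_like(x_gpu)
mean_gpu = torch.empty(batch, seq_size, 1, device="cuda", dtype=torch.float32)
inv_var_gpu = torch.empty_like(mean_gpu)

# Mapping of (handles -> memory)
variant_pack = {
    x: x_gpu.detach(),
    scale: scale_gpu.detach(),
    bias: bias_gpu.detach(),
    epsilon: eps_cpu,
    out: out_gpu,
    mean: mean_gpu,
    inv_var: inv_var_gpu,
}

workspace = torch.empty(
    fwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8
)
fwd_graph.execute(variant_pack, workspace, handle=handle)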
torch.cuda.synchronize()

Test cuDNN's output against the PyTorch reference to check correctness.

In [None]:
# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(x_gpu, (embedding_dim,), eps=epsilon_value)
out_ref = out_ref * scale_gpu + bias_gpu
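mean_ref = x_gpu.float().mean(dim=2, keepdim=True)
# LayerNorm uses the biased (population) variance, so disable Bessel's correction.
inv_var_ref = torch.rsqrt(
    torch.var(x_gpu.float(), dim=2, keepdim=True, unbiased=False) + epsilon_value
)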

torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(mean_gpu, mean_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(inv_var_gpu, inv_var_ref, atol=5e-3, rtol=3e-3)
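LayerNorm computes $\sigma^2$ by dividing by $D$ (the biased, or population, variance), whereas `torch.var` applies Bessel's correction by default and divides by $D-1$; a LayerNorm reference should therefore pass `unbiased=False`. The short NumPy sketch below (NumPy's `ddof` argument selects the divisor; the random row is purely illustrative) shows that the two estimates differ by exactly the factor $D/(D-1)$, about 0.13% for $D = 768$: small enough to fit inside the tolerances used here, but still a systematic offset in the reference.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 768  # embedding_dim in this notebook
row = rng.standard_normal(dim)  # one token's activations (illustrative)

var_biased = row.var(ddof=0)  # divides by D, as LayerNorm does
var_unbiased = row.var(ddof=1)  # divides by D - 1, like torch.var's default
print(var_unbiased / var_biased)  # ~1.0013
print(dim / (dim - 1))  # ~1.0013, the same ratio
```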

#### Adaptive LayerNorm Backward Pass

First, let's compute the reference values for the backward pass:

In [None]:
# Reference backward operation using PyTorch
target = torch.randn_like(out_ref)  # random as ground truth
criterion = torch.nn.MSELoss()
loss = criterion(out_ref, target)

# keep grads for comparison
out_ref.retain_grad()
x_gpu.retain_grad()
scale_gpu.retain_grad()
bias_gpu.retain_grad()

loss.backward()
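The upstream gradient fed to the cuDNN backward graph as `d_out` is `out_ref.grad`, the derivative of the mean-reduced MSE loss with respect to the output: $\partial L/\partial y = 2(y - t)/N$, where $t$ is the target and $N$ is the total number of elements (here $4 \cdot 1024 \cdot 768$). The NumPy sketch below, with hypothetical helper names `mse_loss` and `mse_grad`, checks that closed form against a central finite difference on a small array. The division by $N$ is also why the gradient values in this section are very small.

```python
import numpy as np


def mse_loss(pred, target):
    # Mean-reduced squared error, the default reduction of torch.nn.MSELoss().
    return np.mean((pred - target) ** 2)


def mse_grad(pred, target):
    # Closed form: d/dpred mean((pred - target)^2) = 2 * (pred - target) / N
    return 2.0 * (pred - target) / pred.size


rng = np.random.default_rng(0)
pred_demo = rng.standard_normal((2, 3, 4))
target_demo = rng.standard_normal((2, 3, 4))
grad_demo = mse_grad(pred_demo, target_demo)

# Central finite difference on one element.
h = 1e-6
plus, minus = pred_demo.copy(), pred_demo.copy()
plus[0, 1, 2] += h
minus[0, 1, 2] -= h
fd = (mse_loss(plus, target_demo) - mse_loss(minus, target_demo)) / (2 * h)
print(abs(fd - grad_demo[0, 1, 2]) < 1e-6)  # True

# For the notebook's output tensor, N = batch * seq_size * embedding_dim.
print(4 * 1024 * 768)  # 3145728
```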

Build the backward graph

In [None]:
bwd_graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensors associated with the backward graph. DO NOT reuse tensor handles from the forward graph.
d_out = bwd_graph.tensor(
    name="d_out", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype
)
x_bwd = bwd_graph.tensor_like(x, name="x")
scale_bwd = bwd_graph.tensor_like(scale, name="scale")
mean_bwd = bwd_graph.tensor_like(mean, name="mean")
inv_var_bwd = bwd_graph.tensor_like(inv_var, name="inv_var")

# Add the adaptive layernorm backward operation
d_x, d_scale, d_bias = bwd_graph.adalayernorm_backward(
    name="DADALN",
    grad=d_out,
    input=x_bwd,
    scale=scale_bwd,
    mean=mean_bwd,
    inv_variance=inv_var_bwd,
)

# Enable outputs.
d_x.set_output(True).set_data_type(x_gpu.dtype)
d_scale.set_output(True).set_data_type(x_gpu.dtype)
d_bias.set_output(True).set_data_type(x_gpu.dtype)

# Build the bwd_graph
bwd_graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Execute the backward graph

In [None]:
# Create output buffers for gradients
d_x_gpu = torch.empty_like(x_gpu)
d_scale_gpu = torch.empty_like(scale_gpu)
d_bias_gpu = torch.empty_like(bias_gpu)

# Inputs of the backward graph: x_bwd and scale_bwd read the same data as the forward
# inputs, mean_bwd and inv_var_bwd read the statistics saved by the forward graph, and
# d_out reads the upstream gradient computed by PyTorch autograd (out_ref.grad).
variant_pack = {
    x_bwd: x_gpu.detach(),
    scale_bwd: scale_gpu.detach(),
    d_out: out_ref.grad,
    mean_bwd: mean_gpu.detach(),
    inv_var_bwd: inv_var_gpu.detach(),
    d_x: d_x_gpu,
    d_scale: d_scale_gpu,
    d_bias: d_bias_gpu,
}
workspace = torch.empty(
    bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8
)
bwd_graph.execute(variant_pack, workspace, handle=handle)
torch.cuda.synchronize()

Compare results and check correctness

In [None]:
# compare to reference output
torch.testing.assert_close(x_gpu.grad, d_x_gpu, atol=2e-4, rtol=2e-4)
torch.testing.assert_close(scale_gpu.grad, d_scale_gpu, atol=2e-4, rtol=2e-4)
torch.testing.assert_close(bias_gpu.grad, d_bias_gpu, atol=2e-4, rtol=2e-4)
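`torch.testing.assert_close` treats two values as close when $|\text{actual} - \text{expected}| \le \text{atol} + \text{rtol}\cdot|\text{expected}|$, elementwise. Because the MSE gradients here are scaled by $1/N$, most entries are far smaller than `atol = 2e-4`, so this comparison behaves essentially as an absolute-error check. A plain-Python sketch of the rule (the helper name `is_close` is hypothetical, and it ignores the NaN, dtype, and shape checks that the real function also performs):

```python
def is_close(actual, expected, atol, rtol):
    """Hypothetical helper: the elementwise rule applied by torch.testing.assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = 1.0e-6  # an illustrative, very small gradient entry
print(is_close(5.0e-5, expected, atol=2e-4, rtol=2e-4))  # True: atol alone covers it
print(is_close(5.0e-4, expected, atol=2e-4, rtol=2e-4))  # False
print(is_close(5.0e-5, expected, atol=0.0, rtol=2e-4))  # False: rtol alone is strict
```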

Perform Cleanup

In [None]:
cudnn.destroy_handle(handle)