# LayerNorm with Zero Centered Gamma: Inference

This notebook shows how to compute the forward pass of a zero-centered-gamma layernorm in inference mode using cuDNN.

$$\text{LayerNorm\_Zero\_Centered\_Gamma}(x) = \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}}\cdot(1+\gamma)+\beta$$

Here $\mu = E[x]$ and $\sigma^2 = Var[x]$ are computed separately for each sample (each token) over the normalized embedding dimension, not across the batch.
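
To make the formula concrete, here is a minimal pure-Python sketch that applies it to one small feature vector. It assumes the statistics are taken over the features of a single row and that the variance is the biased one (divide by the number of features). The function name `layernorm_zero_centered_gamma` is illustrative only and is not part of cuDNN.

```python
import math


def layernorm_zero_centered_gamma(x, gamma, beta, eps=1e-3):
    """Apply the formula above to a single feature vector (plain Python lists)."""
    n = len(x)
    mean = sum(x) / n
    # Biased variance (divide by n), an assumption stated in the text above.
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [
        (v - mean) * inv_std * (1.0 + g) + b
        for v, g, b in zip(x, gamma, beta)
    ]


x = [1.0, 2.0, 3.0, 4.0]
zeros = [0.0] * len(x)

# gamma = 0 gives a scale of exactly 1, i.e. plain normalization with no shift.
y = layernorm_zero_centered_gamma(x, zeros, zeros)
print(y)
```

With a zero-centered gamma, a freshly zero-initialized `gamma` leaves the normalized values unscaled, which is the reason for storing `gamma` as an offset from 1.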

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/24_layernorm_zero_centered_gamma_inference.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
#get_ipython().system('nvidia-smi')

If you are running on Colab, you will need to install the cuDNN Python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

We will apply layer norm to a tensor with the following dimensions:

- Batch Size: 4
- Sequence Size: 1024
- Embedding Dimension: 768

Let's define these dimensions as constants:

In [None]:
import cudnn
import torch

torch.manual_seed(1)
print("Running with cudnn backend version:", cudnn.backend_version())

handle = cudnn.create_handle()
assert torch.cuda.is_available()
assert (
    cudnn.backend_version() >= 91000
), "LayerNorm Zero Centered Gamma operation is only supported in cuDNN version 9.10.0 or above"

batch, seq_size, embedding_dim = 4, 1024, 768
# Epsilon is a small number to prevent division by 0.
epsilon_value = 1e-3
dtype = torch.float16
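
The version check above compares against the integer `91000`, which the assertion message pairs with version 9.10.0. As a rough sketch, and assuming the integer is encoded as `major * 10000 + minor * 100 + patch` (an assumption that is consistent with that pairing), a hypothetical helper to decode it could look like this:

```python
def decode_cudnn_version(version):
    """Split an integer version into (major, minor, patch).

    Assumes the encoding major * 10000 + minor * 100 + patch,
    which matches 91000 <-> 9.10.0 as used in the assertion above.
    """
    major, rest = divmod(version, 10000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


print(decode_cudnn_version(91000))  # (9, 10, 0)
```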

## Using Wrapper

In [None]:
# input tensors
x_gpu = torch.randn(
    batch * seq_size,
    embedding_dim,
    1,
    1,
    device="cuda",
    dtype=dtype,
    requires_grad=True,
).to(memory_format=torch.channels_last)
gamma_gpu = torch.randn(
    1, embedding_dim, 1, 1, device="cuda", dtype=dtype, requires_grad=True
).to(memory_format=torch.channels_last)
bias_gpu = torch.randn(
    1, embedding_dim, 1, 1, device="cuda", dtype=dtype, requires_grad=True
).to(memory_format=torch.channels_last)
one_cpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cpu")
eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")

# forward pass of layernorm using cuDNN graph
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=[
        "gamma_plus_one::a",
        "gamma_plus_one::b",
        "ln_fwd::input",
        "ln_fwd::bias",
        "ln_fwd::epsilon",
    ],
    outputs=["ln_fwd::Y"],
) as fwd_graph:
    # Add a pointwise add operation for zero centered gamma + 1
    scale = fwd_graph.add(
        name="gamma_plus_one",
        a=gamma_gpu,
        b=one_cpu,
    )
    # Add a layernorm operation
    out, mean, inv_var = fwd_graph.layernorm(
        name="ln_fwd",
        norm_forward_phase=cudnn.norm_forward_phase.INFERENCE,
        input=x_gpu,
        scale=scale,
        bias=bias_gpu,
        epsilon=eps_cpu,
    )
    assert mean is None, "mean should be None in inference mode"
    assert inv_var is None, "inv_var should be None in inference mode"
    # Mark the result as a graph output; outputs are disabled by default
    out.set_name("output").set_output(True).set_data_type(dtype)

out_gpu = fwd_graph(gamma_gpu, one_cpu, x_gpu, bias_gpu, eps_cpu, handle=handle)

In the above, the layer norm node is created with `norm_forward_phase` set to `INFERENCE`. This means the node is used explicitly for inference, so the `mean` and `inv_var` outputs are not computed and are returned as `None`.

You can verify the output with PyTorch:

In [None]:
# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(
    x_gpu,
    [embedding_dim, 1, 1],
    weight=(1 + gamma_gpu).squeeze(0),
    bias=bias_gpu.squeeze(0),
    eps=epsilon_value,
)

torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)
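
The comparison above passes when each element satisfies `|actual - expected| <= atol + rtol * |expected|`. The following scalar sketch of that check uses the same tolerances; `is_close` is an illustrative helper, not a PyTorch function.

```python
def is_close(actual, expected, atol=5e-3, rtol=3e-3):
    """Scalar version of the check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Allowed error near 1.0 is about 5e-3 + 3e-3 * 1.0 = 8e-3.
print(is_close(1.000, 1.007))  # difference 7e-3, within tolerance
print(is_close(1.000, 1.010))  # difference 1e-2, outside tolerance
```

The relatively loose tolerances reflect that inputs and outputs are stored in half precision, even though intermediate computation is done in float32.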

## Using Python Binding APIs

Create input tensor GPU buffers. We use PyTorch to allocate GPU tensors so we can reuse them easily when we calculate reference outputs.

In [None]:
# input tensors
x_gpu = torch.randn(
    batch * seq_size, embedding_dim, 1, 1, device="cuda", dtype=dtype
).to(memory_format=torch.channels_last)
gamma_gpu = torch.randn(1, embedding_dim, 1, 1, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
bias_gpu = torch.randn(1, embedding_dim, 1, 1, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
one_cpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cpu")
eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")

Next, we define UIDs for the tensors and then create the graph for the forward pass:

In [None]:
from enum import Enum


class UID(Enum):
    SCALE0 = 1
    X = 2
    BIAS = 3
    OUT = 5
    ONE = 8
    EPSILON = 9

In [None]:
# Create the cuDNN graph.
graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensor handles with the graph API, assign UIDs.
x = graph.tensor_like(x_gpu.detach()).set_name("X").set_uid(UID.X.value)
gamma = (
    graph.tensor_like(gamma_gpu.detach()).set_name("scale0").set_uid(UID.SCALE0.value)
)
one = graph.tensor_like(one_cpu).set_name("one").set_uid(UID.ONE.value)
bias = graph.tensor_like(bias_gpu.detach()).set_name("bias").set_uid(UID.BIAS.value)
epsilon = graph.tensor_like(eps_cpu).set_name("epsilon").set_uid(UID.EPSILON.value)

# A node for pointwise add operation: zero centered gamma + 1
scale = graph.add(name="gamma_plus_one", a=gamma, b=one)

# A node for layernorm operation
out, mean, inv_var = graph.layernorm(
    name="layernorm",
    norm_forward_phase=cudnn.norm_forward_phase.INFERENCE,
    input=x,
    scale=scale,
    bias=bias,
    epsilon=epsilon,
)

# Mark output tensors
out.set_name("output").set_output(True).set_data_type(dtype).set_uid(UID.OUT.value)

# Build the graph
graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Here we assign UIDs to tensors. A UID is a unique identifier that maps a tensor created through cuDNN graph API calls, such as `graph.tensor_like()`, to the underlying memory that stores it. Virtual tensors don't require explicitly allocated memory, but non-virtual tensors such as inputs and outputs need UIDs assigned to them.

Alternatively, one can use tensor handles directly in the mapping; however, UIDs can be more convenient when caching cuDNN graphs.

For each of our inputs {X, Scale0, One, Bias, Epsilon} and our output {Out}, we allocate a UID. The `mean` and `inv_var` outputs are not produced in inference mode, so they need no UID.
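
Since every non-virtual tensor needs its own UID and its own buffer, two cheap sanity checks can catch mistakes early. The sketch below repeats the `UID` enum from above with the standard library's `@unique` decorator, and adds a hypothetical helper `missing_uids` (not part of cuDNN) that reports UIDs without a buffer. Placeholder strings stand in for the real tensors.

```python
from enum import Enum, unique


@unique  # raises ValueError at class creation if two members share a value
class UID(Enum):
    SCALE0 = 1
    X = 2
    BIAS = 3
    OUT = 5
    ONE = 8
    EPSILON = 9


def missing_uids(variant_pack):
    """Return the names of UIDs that have no buffer in the mapping (hypothetical helper)."""
    return sorted(u.name for u in UID if u.value not in variant_pack)


# Placeholder strings stand in for the real tensors.
partial_pack = {UID.X.value: "x_gpu", UID.OUT.value: "out_gpu"}
print(missing_uids(partial_pack))  # ['BIAS', 'EPSILON', 'ONE', 'SCALE0']
```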

After validating and building a cuDNN graph, we can execute it. To do this, we have to provide input and output buffers. We use the previously allocated UIDs to associate the tensor handles generated by the graph API with their underlying memory.

The desired input values need to be stored in these buffers before the `graph.execute` call. Because the buffers were allocated with PyTorch, we can pass them directly; the result is written into `out_gpu`, which was allocated in the wrapper section above.

Note that the EPSILON and ONE UIDs expect CPU buffers.

In [None]:
# Mapping of (UIDs -> memory)
variant_pack = {
    UID.X.value: x_gpu,
    UID.SCALE0.value: gamma_gpu,
    UID.BIAS.value: bias_gpu,
    UID.EPSILON.value: eps_cpu,
    UID.OUT.value: out_gpu,
    UID.ONE.value: one_cpu,
}

workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

Compare cuDNN's output against PyTorch's to check correctness:

In [None]:
# PyTorch reference output
out_ref = torch.nn.functional.layer_norm(
    x_gpu,
    [embedding_dim, 1, 1],
    weight=(1 + gamma_gpu).squeeze(0),
    bias=bias_gpu.squeeze(0),
    eps=epsilon_value,
)

# compare to reference output
torch.testing.assert_close(out_gpu, out_ref, atol=5e-3, rtol=3e-3)

Perform Cleanup

In [None]:
cudnn.destroy_handle(handle)