# Layer Norm with Pointwise Add

This notebook shows how to run a forward pointwise add fused with layer normalization in cuDNN, while also writing the intermediate sum out as a graph output.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/25_layernorm_forward_training_and_backward_with_relu_bitmask.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If you are running on Colab, you will need to install the cuDNN Python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

In the following, we will apply a pointwise add followed by layer norm to tensors with these dimensions:

- Batch Size: 4
- Sequence Size: 1024
- Embedding Dimension: 768

Let's define these dimensions as constants:

In [None]:
import cudnn
import torch

torch.manual_seed(1)
handle = cudnn.create_handle()

print("Running with cudnn backend version:", cudnn.backend_version())

assert torch.cuda.is_available()
assert (
    cudnn.backend_version() >= 91400
), "LayerNorm pointwise fusion with intermediate output is only supported in cuDNN version 9.14.0 or above"

batch, seq_size, embedding_dim = 4, 1024, 768
dtype = torch.float32

# Epsilon is a small number to prevent division by 0.
epsilon_value = 1e-3

## Using Wrapper

#### Add and LayerNorm with Intermediate bfloat16 Output

First, we define the input tensors:

In [None]:
# allocate random input tensors
x_gpu = torch.randn(
    batch * seq_size, embedding_dim, 1, 1, device="cuda", dtype=dtype
).to(memory_format=torch.channels_last)
add_gpu = torch.randn(
    batch * seq_size, embedding_dim, 1, 1, device="cuda", dtype=dtype
).to(memory_format=torch.channels_last)
scale_gpu = torch.randn(1, embedding_dim, 1, 1, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
bias_gpu = torch.randn(1, embedding_dim, 1, 1, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")
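The tensors above fold the batch and sequence dimensions into the leading dimension and place the embedding dimension in the channel position. If your activations start out as a `(batch, seq_size, embedding_dim)` tensor, one way to reach this layout is sketched below. The small sizes, the CPU device, and the variable names `tokens`, `x_4d`, and `restored` are assumptions for illustration only.

```python
import torch

batch, seq_size, embedding_dim = 2, 3, 4  # small illustrative sizes
tokens = torch.randn(batch, seq_size, embedding_dim)

# Fold batch and sequence into N, use the embedding as C, and add H = W = 1.
x_4d = tokens.reshape(batch * seq_size, embedding_dim, 1, 1).to(
    memory_format=torch.channels_last
)

# Undo the reshaping to recover the original view of the data.
restored = x_4d.reshape(batch, seq_size, embedding_dim)
```

With this layout, each token is one row along the first dimension and the statistics are computed over its `embedding_dim` channel values, which is why the mean and inverse-variance outputs later have shape `(batch * seq_size, 1, 1, 1)`.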

Next, create the graph for the forward pass.

In [None]:
with cudnn.Graph(
    io_data_type=cudnn.data_type.FLOAT,
    intermediate_data_type=cudnn.data_type.BFLOAT16,
    compute_data_type=cudnn.data_type.FLOAT,
) as fwd_graph:
    # pointwise add operation: x + b
    added_x = fwd_graph.add(
        name="Pointwise add",
        a=x_gpu,
        b=add_gpu,
    )
    # layernorm forward pass
    out, mean, inv_var = fwd_graph.layernorm(
        name="LN",
        norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
        input=added_x,
        scale=scale_gpu,
        bias=bias_gpu,
        epsilon=epsilon_cpu,
    )
    # mark the output tensors
    added_x.set_name("added_x").set_output(True).set_data_type(cudnn.data_type.BFLOAT16)
    out.set_name("output").set_output(True).set_data_type(cudnn.data_type.FLOAT)
    mean.set_name("mean").set_output(True).set_data_type(cudnn.data_type.FLOAT)
    inv_var.set_name("inv_var").set_output(True).set_data_type(cudnn.data_type.FLOAT)

Then, execute the graph and compare the output to the reference output from PyTorch:

In [None]:
# allocate output tensors
added_x_gpu = torch.empty(
    batch * seq_size, embedding_dim, 1, 1, dtype=torch.bfloat16, device="cuda"
)
out_gpu = torch.empty_like(x_gpu)
mean_gpu = torch.empty(batch * seq_size, 1, 1, 1, dtype=torch.float32, device="cuda")
inv_var_gpu = torch.empty(batch * seq_size, 1, 1, 1, dtype=torch.float32, device="cuda")

# execute the graph
output = fwd_graph(
    {
        # input tensors
        "Pointwise add::a": x_gpu,
        "Pointwise add::b": add_gpu,
        "LN::scale": scale_gpu,
        "LN::bias": bias_gpu,
        "LN::epsilon": epsilon_cpu,
        # output tensors
        "added_x": added_x_gpu,
        "output": out_gpu,
        "mean": mean_gpu,
        "inv_var": inv_var_gpu,
    },
    handle=handle,
)

# PyTorch reference forward operation with intermediate bfloat16 output
added_x_ref = torch.add(x_gpu, add_gpu).to(torch.bfloat16)
out_ref = torch.nn.functional.layer_norm(
    added_x_ref.to(torch.float32),
    [embedding_dim, 1, 1],
    weight=scale_gpu.squeeze(0),
    bias=bias_gpu.squeeze(0),
    eps=epsilon_value,
)
mean_ref = added_x_ref.to(torch.float32).mean(dim=(1, 2, 3), keepdim=True)
inv_var_ref = torch.rsqrt(
    # layer norm uses the biased variance, so disable Bessel's correction
    torch.var(
        added_x_ref.to(torch.float32), dim=(1, 2, 3), keepdim=True, unbiased=False
    )
    + epsilon_value
)

# compare to reference output
torch.testing.assert_close(added_x_gpu, added_x_ref, rtol=5e-3, atol=5e-3)
torch.testing.assert_close(out_gpu, out_ref, rtol=5e-3, atol=5e-3)
torch.testing.assert_close(inv_var_gpu, inv_var_ref, rtol=5e-3, atol=5e-3)
torch.testing.assert_close(mean_gpu, mean_ref, rtol=5e-3, atol=5e-3)
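The reference above casts the sum to `bfloat16` before normalizing, mirroring the `bfloat16` intermediate requested in the graph. The sketch below shows how much precision that cast costs. It assumes round-to-nearest conversion and uses small CPU tensors with illustrative names. Whether this rounding is the reason for the `5e-3` tolerances above is an assumption, not something the notebook states.

```python
import torch

torch.manual_seed(0)
a = torch.randn(1024)
b = torch.randn(1024)

summed = a + b                       # float32 result
rounded = summed.to(torch.bfloat16)  # what a bfloat16 intermediate would hold

abs_error = (rounded.to(torch.float32) - summed).abs()
# Round-to-nearest with an 8-bit significand bounds the relative error by 2**-8.
bound = 2.0 ** -8 * summed.abs()
max_rel_error = (abs_error / summed.abs().clamp_min(1e-30)).max().item()
print(f"max relative error: {max_rel_error:.2e}")
```

Because `2**-8` is roughly `3.9e-3`, a comparison against a pure `float32` reference at much tighter tolerances would likely fail for the `added_x` output, even when the fused kernel is correct.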