# BatchNorm Operation

This notebook shows how to compute a batchnorm forward operation using cuDNN.

$$\text{BatchNorm}(x) = \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}}\cdot\gamma+\beta$$

Here $\mu = E[x]$ and $\sigma^2 = Var[x]$ are computed separately for each channel, over the batch and spatial dimensions; $\gamma$ and $\beta$ are the per-channel scale and bias.
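
As a small illustration of the formula, the sketch below normalizes a single channel in plain Python. The helper name `batchnorm_channel` and the sample values are hypothetical, chosen only for this example; the variance is the biased (population) variance, which is the standard choice for batch norm normalization.

```python
import math


def batchnorm_channel(values, gamma, beta, eps):
    """Apply the batch norm formula to one channel's values.

    `values` holds every element of the channel, flattened over the
    batch and spatial dimensions.
    """
    n = len(values)
    mean = sum(values) / n
    # Biased variance: divide by n, not n - 1
    var = sum((v - mean) ** 2 for v in values) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * inv_std * gamma + beta for v in values]
    return y, mean, var


channel = [1.0, 2.0, 3.0, 4.0]
y, mean, var = batchnorm_channel(channel, gamma=2.0, beta=0.5, eps=0.0)
print(mean, var)
print(y)
```

With `eps=0.0`, the normalized output has mean `beta` and variance `gamma ** 2`, which is a quick way to sanity-check an implementation.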

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/29_batchnorm.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If running on Colab, you will need to install the cuDNN Python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

In the following, we perform the batch norm forward pass. First we define the batch size, number of channels, spatial dimensions, and the remaining parameters (epsilon, momentum, and data type):

In [None]:
import cudnn
import torch

torch.manual_seed(1)
print("Running with cudnn backend version:", cudnn.backend_version())

handle = cudnn.create_handle()
assert torch.cuda.is_available()

N, C, H, W = 4, 16, 56, 56
epsilon_value = 1e-3
momentum = 0.1
dtype = torch.float16
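
The `momentum` value controls how quickly the running statistics follow the per-batch statistics. The sketch below assumes the update rule that PyTorch documents for its batch norm layers, `new = (1 - momentum) * old + momentum * batch_stat`; whether cuDNN applies exactly the same rule should be confirmed against its documentation. The helper name `update_running_stat` is hypothetical.

```python
def update_running_stat(running, batch_stat, momentum):
    """Exponential moving average update (PyTorch's documented convention)."""
    return (1.0 - momentum) * running + momentum * batch_stat


running_mean = 0.0
# Feed three batches whose mean is 1.0 and watch the running mean move toward it
for batch_mean in [1.0, 1.0, 1.0]:
    running_mean = update_running_stat(running_mean, batch_mean, momentum=0.1)

print(running_mean)  # equals 1 - 0.9 ** 3
```

A small momentum such as 0.1 means each batch contributes only 10% of the new running value, so the running statistics change slowly.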

## Using Wrapper

The `Graph` wrapper lets you build and run the batch norm graph directly on PyTorch tensors:

In [None]:
x_gpu = torch.randn(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
scale_gpu = torch.randn(1, C, 1, 1, device="cuda")
bias_gpu = torch.randn(1, C, 1, 1, device="cuda")
running_mean_gpu = torch.randn(1, C, 1, 1, device="cuda")
running_var_gpu = torch.randn(1, C, 1, 1, device="cuda")
comparison_gpu = torch.zeros(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)

eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")
momentum_cpu = torch.full((1, 1, 1, 1), momentum, dtype=torch.float32, device="cpu")

# forward pass of batchnorm using cuDNN graph
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=[
        "bn_fwd::input",
        "bn_fwd::scale",
        "bn_fwd::bias",
        "bn_fwd::in_running_mean",
        "bn_fwd::in_running_var",
        "bn_fwd::epsilon",
        "bn_fwd::momentum",
        "cmp_gt::comparison",
    ],
    outputs=["y", "mean", "inv_var", "run_mean", "run_var", "mask"],
) as fwd_graph:
    out, mean, inv_var, run_mean, run_var = fwd_graph.batchnorm(
        name="bn_fwd",
        input=x_gpu,
        scale=scale_gpu,
        bias=bias_gpu,
        in_running_mean=running_mean_gpu,
        in_running_var=running_var_gpu,
        epsilon=eps_cpu,
        momentum=momentum_cpu,
    )
    y = fwd_graph.relu(name="relu", input=out)
    mask = fwd_graph.cmp_gt(
        name="cmp_gt",
        input=y,
        comparison=comparison_gpu,
    )
    y.set_output(True).set_name("y")
    mean.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_name("mean")
    inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_name("inv_var")
    run_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_name("run_mean")
    run_var.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_name("run_var")
    mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN).set_name("mask")

y_gpu, out_mean_gpu, out_inv_var_gpu, running_mean_gpu, running_var_gpu, mask_gpu = (
    fwd_graph(
        x_gpu,
        scale_gpu,
        bias_gpu,
        running_mean_gpu,
        running_var_gpu,
        eps_cpu,
        momentum_cpu,
        comparison_gpu,
        handle=handle,
    )
)
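
The graph above chains three operations: batch norm, ReLU, and a greater-than comparison against a zero tensor. The plain-Python sketch below mirrors the last two steps element by element on made-up values; the helper names simply echo the graph node names and are not cuDNN calls.

```python
def relu(values):
    """Clamp negative values to zero."""
    return [max(v, 0.0) for v in values]


def cmp_gt(values, comparison):
    """Elementwise values > comparison, producing a boolean mask."""
    return [v > c for v, c in zip(values, comparison)]


bn_out = [-1.5, 0.0, 0.25, 2.0]   # stand-in for the batch norm output
y = relu(bn_out)
mask = cmp_gt(y, [0.0] * len(y))  # comparison tensor is all zeros

print(y)
print(mask)
```

The mask marks the positions where the ReLU output is strictly positive, which is the kind of information a ReLU backward pass typically needs.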

This graph has multiple nodes: the batch norm output feeds a ReLU to produce `y`, and `y` is then compared against the zero-filled `comparison` tensor to produce `mask`. Let's compare the output with PyTorch:

In [None]:
# PyTorch equivalent of forward pass, and the stats of this batch
out_ref = torch.nn.functional.batch_norm(
    x_gpu,
    running_mean_gpu,
    running_var_gpu,
    weight=scale_gpu,
    bias=bias_gpu,
    training=True,
    momentum=momentum_cpu.item(),
    eps=eps_cpu.item(),
)
mean_ref = torch.mean(x_gpu.float(), dim=(0, 2, 3), keepdim=True)
# Batch norm normalizes with the biased (population) variance, hence correction=0
inv_var_ref = torch.rsqrt(
    torch.var(x_gpu.float(), dim=(0, 2, 3), keepdim=True, correction=0)
    + epsilon_value
)
y_ref = torch.relu(out_ref)
mask_ref = y_ref > 0

# Compare the output
torch.testing.assert_close(y_gpu, y_ref, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(mean_ref, out_mean_gpu, atol=5e-3, rtol=3e-3)
torch.testing.assert_close(inv_var_ref, out_inv_var_gpu, atol=5e-3, rtol=3e-3)
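
Batch norm normalizes with the biased variance (dividing by the element count `n`), while an unbiased estimator divides by `n - 1`. The sketch below works out how much the two differ for the shape used in this notebook, assuming a unit variance purely for illustration.

```python
import math

N, C, H, W = 4, 16, 56, 56
epsilon_value = 1e-3

n = N * H * W                     # elements per channel
ratio = n / (n - 1)               # unbiased variance / biased variance
rel_diff_var = ratio - 1.0

# Effect on the inverse standard deviation, assuming a biased variance of 1.0
inv_std_biased = 1.0 / math.sqrt(1.0 + epsilon_value)
inv_std_unbiased = 1.0 / math.sqrt(ratio + epsilon_value)
abs_diff_inv_std = abs(inv_std_biased - inv_std_unbiased)

print(n, rel_diff_var, abs_diff_inv_std)
```

With 12544 elements per channel the gap is far below the tolerances used in the comparison, so the choice of estimator only becomes visible for much smaller batches or spatial sizes.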

## Using Python Binding APIs

### Batchnorm Training Forward

Create input and output tensor buffers in PyTorch.

In [None]:
# input tensors
x_gpu = torch.randn(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)
scale_gpu = torch.randn(1, C, 1, 1, device="cuda")
bias_gpu = torch.randn(1, C, 1, 1, device="cuda")
running_mean_gpu = torch.randn(1, C, 1, 1, device="cuda")
running_var_gpu = torch.randn(1, C, 1, 1, device="cuda")
comparison_gpu = torch.zeros(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)

eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")
momentum_cpu = torch.full((1, 1, 1, 1), momentum, dtype=torch.float32, device="cpu")

# output tensors
saved_mean_gpu = torch.empty_like(running_mean_gpu, device="cuda")
saved_inv_var_gpu = torch.empty_like(running_var_gpu, device="cuda")
y_gpu = torch.empty_like(x_gpu, dtype=dtype, device="cuda")
mask_gpu = torch.empty_like(x_gpu, dtype=torch.bool, device="cuda")
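
The input, output, and mask buffers keep the logical shape `(N, C, H, W)` but are stored in `channels_last` (NHWC) order, so the channel index varies fastest in memory. The sketch below computes the element strides for both layouts; the two helper functions are hypothetical and only illustrate the arithmetic.

```python
def contiguous_strides(n, c, h, w):
    """Strides of an (N, C, H, W) tensor stored in NCHW order."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Strides of an (N, C, H, W) tensor stored in NHWC order."""
    return (h * w * c, 1, w * c, c)


shape = (4, 16, 56, 56)
print(contiguous_strides(*shape))
print(channels_last_strides(*shape))
```

Because `graph.tensor_like` derives its tensor description from the PyTorch tensor, the layout chosen here is what the cuDNN graph is built around.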

Create the cuDNN graph and describe its tensors and operations.

In [None]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    handle=handle,
)

x = graph.tensor_like(x_gpu)
scale = graph.tensor_like(scale_gpu)
bias = graph.tensor_like(bias_gpu)

in_running_mean = graph.tensor_like(running_mean_gpu)
in_running_var = graph.tensor_like(running_var_gpu)
epsilon = graph.tensor_like(eps_cpu)
momentum = graph.tensor_like(momentum_cpu)
comparison = graph.tensor_like(comparison_gpu)  # same shape, dtype, and layout as x_gpu

y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = (
    graph.batchnorm(
        name="BN",
        input=x,
        scale=scale,
        bias=bias,
        in_running_mean=in_running_mean,
        in_running_var=in_running_var,
        epsilon=epsilon,
        momentum=momentum,
    )
)
y = graph.relu(name="relu", input=y_before_relu)
mask = graph.cmp_gt(name="cmp", input=y, comparison=comparison)

y.set_output(True)
saved_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)
saved_inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)
out_running_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT)
out_running_var.set_output(True).set_data_type(cudnn.data_type.FLOAT)
mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)
pass

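Build the graph: validate it, build the operation graph, create execution plans from the heuristics, check that they are supported, and build the plans.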

In [None]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

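Execute the graph. The variant pack maps each graph tensor to its PyTorch buffer; the running mean and variance buffers are used as both input and output.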

In [None]:
variant_pack = {
    x: x_gpu,
    scale: scale_gpu,
    bias: bias_gpu,
    in_running_mean: running_mean_gpu,
    in_running_var: running_var_gpu,
    epsilon: eps_cpu,
    momentum: momentum_cpu,
    out_running_mean: running_mean_gpu,
    out_running_var: running_var_gpu,
    saved_mean: saved_mean_gpu,
    saved_inv_var: saved_inv_var_gpu,
    y: y_gpu,
    comparison: comparison_gpu,
    mask: mask_gpu,
}
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(
    variant_pack,
    workspace,
    handle=handle,
)
torch.cuda.synchronize()

Compare cuDNN's output against a PyTorch reference to check correctness.

In [None]:
x_ref = x_gpu.clone().float()
running_mean_ref = running_mean_gpu.clone().float()
running_var_ref = running_var_gpu.clone().float()

y_before_relu_ref = torch.nn.functional.batch_norm(
    x_ref,
    running_mean_ref,  # running_mean is both input and output
    running_var_ref,  # running_var is both input and output
    weight=scale_gpu,
    bias=bias_gpu,
    training=True,
    momentum=momentum_cpu.item(),
    eps=eps_cpu.item(),
)

mean_ref = torch.mean(x_ref, dim=(0, 2, 3), keepdim=True)
# Batch norm normalizes with the biased (population) variance, hence correction=0
inv_var_ref = torch.var(x_ref, dim=(0, 2, 3), keepdim=True, correction=0)
inv_var_ref = torch.rsqrt(inv_var_ref + epsilon_value)
y_ref = torch.relu(y_before_relu_ref)
mask_ref = y_ref > 0

torch.testing.assert_close(y_ref, y_gpu.float(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(mean_ref, saved_mean_gpu.float(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(inv_var_ref, saved_inv_var_gpu.float(), atol=1e-3, rtol=1e-3)
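
The `atol` and `rtol` arguments combine into a single bound: `torch.testing.assert_close` accepts a pair of values when `|actual - expected| <= atol + rtol * |expected|`. The scalar sketch below reproduces that rule; the function name `is_close` is hypothetical.

```python
def is_close(actual, expected, atol, rtol):
    """Scalar version of the closeness rule used by torch.testing.assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Tolerances from the check above: the bound is 1e-3 + 1e-3 * 1.0 = 2e-3
print(is_close(1.0015, 1.0, atol=1e-3, rtol=1e-3))

# Without the relative term the same difference is rejected
print(is_close(1.0015, 1.0, atol=1e-3, rtol=0.0))

# Larger magnitudes get a proportionally larger allowance
print(is_close(10.02, 10.0, atol=5e-3, rtol=3e-3))
```

The relative term is what lets FP16 outputs with larger magnitudes pass even though their absolute rounding error grows with the value.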