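This notebook shows how to compute a batch normalization (batchnorm) training forward pass with the cuDNN Python frontend. The batchnorm is fused with a ReLU activation and a boolean ReLU mask.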

$$\text{BatchNorm}(x) = \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}}\cdot\gamma+\beta$$

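Here, $\mu = \mathrm{E}[x]$ and $\sigma^2 = \mathrm{Var}[x]$ are computed per channel over the batch and spatial dimensions ($N$, $H$, $W$), with $\sigma^2$ being the biased (population) variance. $\gamma$ (scale) and $\beta$ (bias) are per-channel parameters.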

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. On Colab, if `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected.

In [1]:
# Uncomment to run `nvidia-smi` and confirm that an NVIDIA GPU is attached to this runtime.
# get_ipython().system('nvidia-smi')

If running on Colab, uncomment the cell below to install the cuDNN Python frontend and a CUDA-enabled PyTorch build.

In [2]:
# Uncomment on Colab to install the cuDNN wheels for CUDA 12, the cuDNN Python
# frontend, and a CUDA 12.1 PyTorch nightly build.
# NOTE(review): versions are unpinned; consider pinning them for reproducibility.
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

#### General Setup
Create a cuDNN handle. It is a per-device handle that holds the cuDNN context and is passed to graph construction and execution.

In [3]:
import cudnn
import torch
import sys

# Fail fast with a clear message before any GPU work: the cuDNN handle created
# below is a per-device handle, so a CUDA device must be available first.
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable NVIDIA GPU."

# Seed PyTorch's RNG so the random tensors created later are reproducible.
torch.manual_seed(1)
# cuDNN handle, passed to the graph constructor and to graph.execute().
handle = cudnn.create_handle()

print("Running with cudnn backend version:", cudnn.backend_version())

Running with cudnn backend version: 90400


### Batchnorm Training Forward

In [4]:
# Problem size, in NCHW order.
n = 4   # batch size
c = 16  # channels
h = 56  # height
w = 56  # width
# Element type of the input x and of the normalized output y.
input_type = torch.float16

# Small constant added to the variance so the normalization never divides by 0.
epsilon_value = 0.001
# Weight of the current batch statistics in the running-statistics update:
#   running_mean_next = (1 - momentum) * running_mean + momentum * local_mean
momentum_value = 0.1

Create the input and output tensor buffers in PyTorch. The input `x` is stored in the `channels_last` (NHWC) memory layout.

In [5]:
# input tensors
# Activations in the channels_last (NHWC) memory layout.
x_gpu = torch.randn(n, c, h, w, dtype=input_type, device="cuda")
x_gpu = x_gpu.to(memory_format=torch.channels_last)
# Per-channel parameters and running statistics, shape (1, C, 1, 1), default dtype.
scale_gpu = torch.randn(1, c, 1, 1, device="cuda")
bias_gpu = torch.randn_like(scale_gpu)
running_mean_gpu = torch.randn_like(scale_gpu)
# A variance must be non-negative. Take abs() of the same normal draw so the
# RNG stream (and therefore every other random tensor) is unchanged.
running_var_gpu = torch.randn_like(scale_gpu).abs()

# Threshold for the mask comparison: all zeros, so the mask is `y > 0`.
comparison_gpu = torch.zeros_like(x_gpu, dtype=input_type, device="cuda")

# Scalar hyperparameters as 1x1x1x1 CPU tensors, bound to the graph's
# epsilon/momentum tensors.
epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value)
momentum_cpu = torch.full((1, 1, 1, 1), momentum_value)

# output tensors
saved_mean_gpu = torch.empty_like(running_mean_gpu, device="cuda")
saved_inv_var_gpu = torch.empty_like(running_var_gpu, device="cuda")
y_gpu = torch.empty_like(x_gpu, dtype=input_type, device="cuda")
mask_gpu = torch.empty_like(x_gpu, dtype=torch.bool, device="cuda")

Create the cuDNN graph: batchnorm followed by ReLU, plus a `y > 0` comparison that produces a boolean mask.

In [6]:
graph = cudnn.pygraph(
    handle=handle,
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Declare one graph tensor per PyTorch buffer, using that buffer as the template.
x = graph.tensor_like(x_gpu)
scale = graph.tensor_like(scale_gpu)
bias = graph.tensor_like(bias_gpu)

in_running_mean = graph.tensor_like(running_mean_gpu)
in_running_var = graph.tensor_like(running_var_gpu)
epsilon = graph.tensor_like(epsilon_cpu)
momentum = graph.tensor_like(momentum_cpu)
# Threshold for the mask below; templated on x_gpu, so it has x's shape.
comparison = graph.tensor_like(x_gpu)

# The batchnorm node yields five tensors: normalized output, saved mean,
# saved inverse variance, updated running mean, updated running variance.
bn_outputs = graph.batchnorm(
    name="BN",
    input=x,
    scale=scale,
    bias=bias,
    in_running_mean=in_running_mean,
    in_running_var=in_running_var,
    epsilon=epsilon,
    momentum=momentum,
)
y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = bn_outputs

# ReLU on the normalized output, then an element-wise `y > comparison` mask.
y = graph.relu(name="relu", input=y_before_relu)
mask = graph.cmp_gt(name="cmp", input=y, comparison=comparison)

# y is left at the graph default (presumably the HALF io type); the statistics
# are emitted as float32 and the mask as booleans.
y.set_output(True)
for fp32_stat in (saved_mean, saved_inv_var, out_running_mean, out_running_var):
    fp32_stat.set_output(True).set_data_type(cudnn.data_type.FLOAT)
mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)
pass  # keeps the notebook cell from echoing the last expression's value

Validate and build the graph, and create execution plans using heuristic modes `A` and `FALLBACK`.

In [7]:
# Validate the graph and build its operation graph.
graph.validate()
graph.build_operation_graph()

# Create execution plans from heuristic mode A and the FALLBACK mode.
heuristic_modes = [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]
graph.create_execution_plans(heuristic_modes)

# Check that the plans are supported, then build them.
graph.check_support()
graph.build_plans()

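Execute the graph. The running mean and variance buffers are bound as both inputs and outputs, so they are updated in place.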

In [8]:
# Bind every graph tensor to the PyTorch buffer that backs it.
graph_inputs = {
    x: x_gpu,
    scale: scale_gpu,
    bias: bias_gpu,
    in_running_mean: running_mean_gpu,
    in_running_var: running_var_gpu,
    epsilon: epsilon_cpu,
    momentum: momentum_cpu,
    comparison: comparison_gpu,
}
# The updated running statistics are bound to the same buffers as the input
# running statistics, so running_mean_gpu / running_var_gpu are updated in place.
graph_outputs = {
    out_running_mean: running_mean_gpu,
    out_running_var: running_var_gpu,
    saved_mean: saved_mean_gpu,
    saved_inv_var: saved_inv_var_gpu,
    y: y_gpu,
    mask: mask_gpu,
}
variant_pack = {**graph_inputs, **graph_outputs}

workspace_size = graph.get_workspace_size()
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace, handle=handle)
# Wait for the GPU to finish before the results are read on the host.
torch.cuda.synchronize()

Check cuDNN's outputs for correctness against a PyTorch reference implementation.

In [9]:
# Float32 reference computation on the same input.
x_ref = x_gpu.clone().float()
# NOTE: graph.execute() bound running_mean_gpu / running_var_gpu as both input
# and output, so they already hold the *updated* running statistics here.
# These clones only supply the running-stat arguments F.batch_norm requires;
# the running-statistics update itself is not verified in this cell.
running_mean_ref = running_mean_gpu.clone().float()
running_var_ref = running_var_gpu.clone().float()

y_before_relu_ref = torch.nn.functional.batch_norm(
    x_ref,
    running_mean_ref,  # updated in place by F.batch_norm when training=True
    running_var_ref,  # updated in place by F.batch_norm when training=True
    weight=scale_gpu,
    bias=bias_gpu,
    training=True,
    momentum=momentum_cpu.item(),
    eps=epsilon_cpu.item(),
)

# Per-channel batch statistics over the N, H, W dimensions. Batchnorm
# normalizes with the biased (population) variance, so pass correction=0;
# torch.var's default (correction=1) is the unbiased estimator.
mean_ref = torch.mean(x_ref, dim=(0, 2, 3), keepdim=True)
var_ref = torch.var(x_ref, dim=(0, 2, 3), keepdim=True, correction=0)
inv_var_ref = torch.rsqrt(var_ref + epsilon_value)
y_ref = torch.relu(y_before_relu_ref)
mask_ref = y_ref > 0

torch.testing.assert_close(y_ref, y_gpu.float(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(mean_ref, saved_mean_gpu.float(), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(inv_var_ref, saved_inv_var_gpu.float(), atol=1e-3, rtol=1e-3)

# Mask check, left disabled as in the original. The previous form compared a
# bool tensor with `mask_gpu.float()`, which always fails assert_close's dtype
# check. The corrected form compares the bool tensors directly and skips
# elements whose pre-ReLU value is within rounding error of 0, because those
# may legitimately fall on either side of the threshold.
# NOTE(review): enable once cuDNN's BOOLEAN output is confirmed to be stored
# one byte per element, matching the torch.bool buffer.
# decisive = y_before_relu_ref.abs() > 1e-3
# torch.testing.assert_close(mask_gpu[decisive], mask_ref[decisive])