# Fusion Operation with Instance Norm

This notebook shows how to compute an instance norm (+ add + ReLU) as a single fused graph using the cuDNN Python frontend.

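$$y = \text{ReLU}\big(\text{InstanceNorm}(x) + A\big) = \max\Big(0,\ \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta + A\Big)$$

where $\mathbb{E}[x]$ and $\mathrm{Var}[x]$ are computed separately for each sample and each channel over the spatial dimensions $H \times W$, and $\gamma$, $\beta$ are the per-channel scale and bias.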

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/31_instancenorm_fusion.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If running on Colab, you will need to install the cuDNN Python interface.

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

The following cell sets up the problem: the batch size, number of channels, spatial dimensions, data type, and the epsilon used by the instance norm. It also creates the cuDNN handle used by both sections below:

In [None]:
import cudnn
import torch

# Fix the RNG seed so the random inputs below are reproducible across runs.
torch.manual_seed(1)
# Fail fast with an actionable message instead of a bare AssertionError.
assert torch.cuda.is_available(), (
    "This notebook requires a CUDA-capable GPU (see 'Prerequisites and Setup')."
)

# One cuDNN handle, reused by every graph executed in this notebook.
handle = cudnn.create_handle()

# Problem size: batch, channels, height, width.
N, C, H, W = 16, 32, 64, 64
# Data type of the activations (input x, residual A, and output y).
dtype = torch.float16
# Epsilon is a small number added to the variance to prevent division by 0.
epsilon_value = 1e-5

## Using Wrapper

This section shows how to use the high-level `cudnn.Graph` wrapper to compute instance norm, add, and ReLU in one fused graph:

In [None]:
def _to_channels_last(t):
    """Return `t` in the channels_last memory format."""
    return t.to(memory_format=torch.channels_last)


# Inputs: activations and residual in `dtype`, per-channel scale/bias in the
# default float dtype, and epsilon as a host (CPU) tensor.
x_gpu = _to_channels_last(torch.randn(N, C, H, W, device="cuda", dtype=dtype))
scale_gpu = _to_channels_last(torch.randn(1, C, 1, 1, device="cuda"))
bias_gpu = _to_channels_last(torch.randn(1, C, 1, 1, device="cuda"))
eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")
a_gpu = _to_channels_last(torch.randn(N, C, H, W, device="cuda", dtype=dtype))

# Graph inputs, listed in the same order the tensors are passed to
# `fwd_graph(...)` at the end of this cell.
graph_inputs = [
    "in_fwd::input",
    "in_fwd::scale",
    "in_fwd::bias",
    "in_fwd::epsilon",
    "add::b",
]

# Forward pass of instance norm + add + relu as one cuDNN graph.
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=graph_inputs,
    outputs=["relu"],
) as fwd_graph:
    # Inference phase: only the normalized tensor is produced.
    norm_out, mean, inv_var = fwd_graph.instancenorm(
        name="in_fwd",
        norm_forward_phase=cudnn.norm_forward_phase.INFERENCE,
        input=x_gpu.detach(),
        scale=scale_gpu.detach(),
        bias=bias_gpu.detach(),
        epsilon=eps_cpu,
    )
    assert mean is None, "Instance norm in inference mode should return mean=None"
    assert inv_var is None, "Instance norm in inference mode should return inv_var=None"

    # Residual add followed by ReLU; the ReLU result is the graph output.
    residual = fwd_graph.add(a=norm_out, b=a_gpu, name="add")
    relu_out = fwd_graph.relu(residual, name="relu")
    relu_out.set_output(True).set_data_type(dtype).set_name("relu")

y_gpu = fwd_graph(x_gpu, scale_gpu, bias_gpu, eps_cpu, a_gpu, handle=handle)

Then, we can verify the correctness with PyTorch:

In [None]:
# PyTorch reference: instance norm using the same epsilon as the cuDNN graph
# (rather than relying on PyTorch's default), then the residual add and ReLU.
out_ref = torch.nn.functional.instance_norm(
    x_gpu,
    weight=scale_gpu.view(C),
    bias=bias_gpu.view(C),
    eps=epsilon_value,
)
y_ref = torch.relu(out_ref + a_gpu)

# Tolerances are loose because the graph's inputs and output are fp16.
torch.testing.assert_close(y_gpu, y_ref, atol=5e-3, rtol=3e-3)

## Using Python Binding APIs

Next, we build the same fused graph with the lower-level `pygraph` binding API. First, we create GPU buffers as input. We use PyTorch tensors here so we can reuse them easily when we calculate reference outputs.

In [None]:
# Input tensor memory, initialized to random numbers in channels_last layout.
# NOTE: requires_grad is set, but this notebook runs no backward pass.
x_gpu = torch.randn(N, C, H, W, device="cuda", dtype=dtype, requires_grad=True).to(
    memory_format=torch.channels_last
)
scale_gpu = torch.randn(1, C, 1, 1, device="cuda", requires_grad=True).to(
    memory_format=torch.channels_last
)
bias_gpu = torch.randn(1, C, 1, 1, device="cuda", requires_grad=True).to(
    memory_format=torch.channels_last
)

# Set epsilon to epsilon_value, allocated on the CPU.
eps_cpu = torch.full((1, 1, 1, 1), epsilon_value, dtype=torch.float32, device="cpu")

# Residual tensor A, added to the normalized output before the ReLU.
# The add result itself is never bound in the variant pack, so no buffer is
# allocated for it here.
a_gpu = torch.randn(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)

Create the graph

In [None]:
# Create the graph: intermediates and math in fp32; the output type is set
# explicitly below.
graph = cudnn.pygraph(
    handle=handle,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Create tensor handles with the graph API, each from the buffer that is
# later bound to it in the variant pack.
x = graph.tensor_like(x_gpu.detach()).set_name("X")
scale = graph.tensor_like(scale_gpu.detach()).set_name("scale")
bias = graph.tensor_like(bias_gpu.detach()).set_name("bias")
epsilon = graph.tensor_like(eps_cpu).set_name("epsilon")
# Describe A from its own buffer rather than copying X's descriptor, so the
# graph stays correct even if a_gpu's layout differs from x_gpu's.
a = graph.tensor_like(a_gpu).set_name("A")

# instance norm (inference) + add + relu
# In inference mode only `y` is used; mean/inv_var are not graph outputs.
y, mean, inv_var = graph.instancenorm(
    name="in_fwd",
    input=x,
    norm_forward_phase=cudnn.norm_forward_phase.INFERENCE,
    scale=scale,
    bias=bias,
    epsilon=epsilon,
)
sum_out = graph.add(y, a, name="add")
sum_out.set_name("sum")
relu = graph.relu(sum_out, name="relu")
# Give the output a distinct name so it does not collide with the "sum"
# intermediate above.
relu.set_name("relu").set_output(True).set_data_type(cudnn.data_type.HALF)

# Build the graph using heuristic modes A and FALLBACK.
graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Execute the graph

In [None]:
# Prepare the output buffer: same shape, dtype, and layout as the input.
relu_gpu = torch.empty(N, C, H, W, device="cuda", dtype=dtype).to(
    memory_format=torch.channels_last
)

# Mapping of (graph tensor handles -> memory). The add intermediate is not
# listed: it stays inside the fused graph and needs no buffer.
variant_pack = {
    # input tensors
    x: x_gpu,
    scale: scale_gpu,
    bias: bias_gpu,
    epsilon: eps_cpu,
    a: a_gpu,
    # output tensor
    relu: relu_gpu,
}
# Scratch memory sized by the graph for its execution plan.
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
# Wait for the GPU so relu_gpu is populated before the reference check reads it.
torch.cuda.synchronize()

Compute reference outputs with PyTorch and verify the results

In [None]:
# PyTorch reference: instance norm using the same epsilon as the cuDNN graph
# (rather than relying on PyTorch's default), then add A and apply ReLU.
y_ref = torch.nn.functional.instance_norm(
    x_gpu,
    weight=scale_gpu.view(C),
    bias=bias_gpu.view(C),
    eps=epsilon_value,
)
relu_ref = torch.nn.functional.relu(y_ref + a_gpu)

# Tolerances are loose because the graph's inputs and output are fp16.
torch.testing.assert_close(relu_gpu, relu_ref, atol=5e-3, rtol=3e-3)

Cleanup

In [None]:
# Release the cuDNN handle created in the setup cell; no graph runs after this.
cudnn.destroy_handle(handle)