# Matrix multiplication with fused bias using the cuDNN frontend (FE)
This notebook shows how to run a matmul with a fused bias add as a single cuDNN graph.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/01_matmul_bias.ipynb)

## Prerequisites for running on Colab
This notebook requires an NVIDIA H100 GPU or newer. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected.

In [10]:
# get_ipython().system('nvidia-smi')

If you are running on Colab, you will need to install the cuDNN Python interface.

In [11]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

#### General Setup
This example passes PyTorch tensors to cuDNN. In general, any tensor that supports DLPack should work.
The cuDNN handle is a per-device handle used to initialize the cuDNN context.


In [12]:
import cudnn
import torch
import sys

handle = cudnn.create_handle()

#### Create input tensors and calculate reference

In [13]:
batch, m, n, k = 16, 128, 128, 512

input_type = torch.float16

# input tensors
a = torch.randn(batch, m, k, dtype=input_type, device="cuda")
b = torch.randn(batch, k, n, dtype=input_type, device="cuda")
# bias with batch dimension 1, broadcast across the batch
B = torch.randn(1, m, n, dtype=input_type, device="cuda")

# reference output
c_ref = torch.matmul(a, b) + B

# placeholder for the cuDNN output (overwritten when the graph executes)
c = torch.randn_like(c_ref, device="cuda")
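
The reference above relies on broadcasting: the bias has a batch dimension of 1 and is added to every matrix in the batch. A minimal pure-Python sketch of that arithmetic follows. The helper name `matmul_bias` and the tiny nested-list inputs are illustrative assumptions, and this is not how PyTorch or cuDNN compute the result.

```python
def matmul_bias(a, b, bias):
    """Batched matmul plus a bias whose batch dimension is 1.

    a: [batch][m][k], b: [batch][k][n], bias: [1][m][n] (nested lists).
    """
    batch, m, k, n = len(a), len(a[0]), len(b[0]), len(b[0][0])
    out = []
    for i in range(batch):
        rows = []
        for r in range(m):
            row = []
            for col in range(n):
                acc = sum(a[i][r][p] * b[i][p][col] for p in range(k))
                # bias[0] is reused for every batch index i (broadcast).
                row.append(acc + bias[0][r][col])
            rows.append(row)
        out.append(rows)
    return out


a_small = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]   # batch=2, m=2, k=2
b_small = [[[1, 0], [0, 1]], [[5, 6], [7, 8]]]   # batch=2, k=2, n=2
bias_small = [[[10, 20], [30, 40]]]              # 1 x m x n
result = matmul_bias(a_small, b_small, bias_small)
```

Both batch entries receive the same `bias_small[0]` matrix, which mirrors how `B` with shape `(1, m, n)` is added to the `(batch, m, n)` product.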

#### Define the cache key of the graph in terms of its tensors

In [14]:
def matmul_cache_key(handle, a, b, bias):
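    """Custom key function for matmul + bias.

    Two calls share a cached graph only if every field below matches.
    The bias layout is included so a differently shaped bias is not
    served a graph built for another shape.
    """
    return (
        tuple(a.shape),
        tuple(b.shape),
        tuple(bias.shape),
        tuple(a.stride()),
        tuple(b.stride()),
        tuple(bias.stride()),
        a.dtype,
        b.dtype,
        bias.dtype,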
    )
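
To illustrate the idea of keying a cache on tensor metadata, here is a minimal sketch of a key-function cache in plain Python. The names `cache_by_key`, `FakeTensor`, `shape_key`, and `build_graph` are hypothetical, and this is only an assumption about the general pattern, not a description of how `cudnn.graph_cache` is implemented.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor; only the metadata the key needs.
FakeTensor = namedtuple("FakeTensor", ["shape", "strides", "dtype"])


def cache_by_key(key_fn):
    """Hypothetical decorator: memoize a builder on key_fn(*args)."""
    def decorator(build):
        cache = {}

        def wrapper(*args):
            key = key_fn(*args)
            if key not in cache:
                cache[key] = build(*args)  # build only on a cache miss
            return cache[key]

        return wrapper
    return decorator


def shape_key(a, b):
    return (a.shape, b.shape, a.strides, b.strides, a.dtype, b.dtype)


builds = []  # records every time the expensive build actually runs


@cache_by_key(shape_key)
def build_graph(a, b):
    builds.append((a.shape, b.shape))
    return f"graph#{len(builds)}"


x = FakeTensor((16, 128, 512), (65536, 512, 1), "float16")
y = FakeTensor((16, 512, 128), (65536, 128, 1), "float16")
z = FakeTensor((8, 512, 128), (65536, 128, 1), "float16")

first = build_graph(x, y)
second = build_graph(x, y)  # same key: reuses the first result
third = build_graph(x, z)   # different shape: builds again
```

The key contains only metadata (shapes, strides, dtypes), never tensor values, so new data with the same layout can reuse an already built graph.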

#### Create cudnn matmul + bias fused graph.

In [15]:
@cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.B])
@cudnn.graph_cache(key_fn=matmul_cache_key)
def create_matmul_bias_graph(handle, a, b, bias):
    with cudnn.graph(handle) as (g, _):
        a_cudnn = g.tensor_like(a)
        b_cudnn = g.tensor_like(b)
        bias_cudnn = g.tensor_like(bias)
        c_cudnn = g.matmul(name="matmul", A=a_cudnn, B=b_cudnn)
        out = g.bias(name="bias", input=c_cudnn, bias=bias_cudnn)
        out.set_output(True).set_data_type(cudnn.data_type.HALF)

    return g, [a_cudnn, b_cudnn, bias_cudnn, out]  # Return raw graph and tensors

#### Build the graph

In [16]:
g, uids = create_matmul_bias_graph(handle, a, b, B)

a_uid, b_uid, bias_uid, out_uid = uids

#### Execute the graph

In [18]:
variant_pack = {
    a_uid: a,
    b_uid: b,
    bias_uid: B,
    out_uid: c,
}

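# Allocate the scratch workspace the graph asks for, then run it.
workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8)
g.execute(variant_pack, workspace)
# Wait for the GPU to finish before reading c.
torch.cuda.synchronize()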

In [19]:
# Compare the cuDNN output against the PyTorch reference.
torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)
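
For intuition about the tolerances, `torch.testing.assert_close` treats two values as close when `|actual - expected| <= atol + rtol * |expected|`. The scalar sketch below applies that criterion in plain Python. The helper name `is_close` and the sample values are illustrative assumptions, not part of the notebook.

```python
def is_close(actual, expected, rtol, atol):
    """Closeness test: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


rtol = atol = 5e-3

# Large values are governed mostly by the relative term (0.5 here).
large_ok = is_close(100.4, 100.0, rtol, atol)
large_bad = is_close(100.6, 100.0, rtol, atol)

# Values near zero are governed by the absolute term (0.005 here).
small_ok = is_close(0.004, 0.0, rtol, atol)
small_bad = is_close(0.02, 0.0, rtol, atol)
```

Using both tolerances lets half-precision results with large magnitudes pass on relative error while still bounding the error of outputs near zero.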