# Mixed-precision matrix multiplication using the cuDNN frontend
This notebook shows how to run a mixed-precision matmul, where the two inputs have different data types, using the cuDNN frontend (FE) Python API.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/03_mixed_precision_matmul.ipynb)

## Prerequisites for running on Colab
This notebook requires an NVIDIA H100 GPU or newer. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected.

In [None]:
# get_ipython().system('nvidia-smi')

If you are running on Colab, you will need to install the cuDNN Python interface.

In [None]:
# get_ipython().system('export CUDA_VERSION="12.3"')
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

#### General Setup
This example calls cuDNN through PyTorch, but in general any tensor that supports DLPack should work.
The cuDNN handle is a per-device handle used to initialize the cuDNN context.


In [None]:
import cudnn
import torch
import sys

handle = cudnn.create_handle()

#### Create input tensors and calculate reference

In [None]:
batch, m, n, k = 16, 128, 128, 512

# input data types can be different
input_type_a = torch.int8
input_type_b = torch.bfloat16
output_type  = torch.bfloat16

# data type the matmul operation consumes directly (inputs are cast to it first)
mma_data_type = torch.bfloat16

# input tensors
if input_type_a != torch.int8:
    a = 2 * torch.randn(batch, m, k, dtype=input_type_a, device='cuda') - 0.5
else:
    a = torch.randint(4, (batch, m, k), dtype=input_type_a, device='cuda') - 1

if input_type_b != torch.int8:
    b_row_major = 3 * torch.randn(batch, k, n, dtype=input_type_b, device='cuda') - 1.25
else:
    b_row_major = torch.randint(3, (batch, k, n), dtype=input_type_b, device='cuda').contiguous() - 2
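# reinterpret the buffer so each (k, n) matrix is column-major:
# element [i, j] sits at offset i + j * k within its batch
b = torch.as_strided(b_row_major, (batch, k, n), (n * k, 1, k))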

# reference output
c_ref = torch.matmul(a.to(mma_data_type), b.to(mma_data_type)).to(output_type)

# placeholder for the cudnn output (overwritten when the graph executes)
c = torch.randn_like(c_ref, device='cuda')
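
The `as_strided` call above only changes how indices map to offsets in the existing buffer; no data is moved. As a rough illustration in plain Python (no GPU needed; the helper names `offsets` and `is_aliased` are made up for this sketch), the following checks whether a shape/stride pair maps two different indices to the same element:

```python
from itertools import product


def offsets(shape, strides):
    """Map every index of a strided view to its element offset in the flat buffer."""
    return {
        idx: sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(d) for d in shape))
    }


def is_aliased(shape, strides):
    """Return True if two distinct indices land on the same offset."""
    offs = offsets(shape, strides)
    return len(set(offs.values())) < len(offs)


k, n = 4, 2  # small stand-ins for the k and n used in the notebook
row_major_aliased = is_aliased((k, n), (n, 1))   # row-major strides
col_major_aliased = is_aliased((k, n), (1, k))   # column-major strides
short_stride_aliased = is_aliased((k, n), (1, n))  # second stride smaller than k

print(row_major_aliased, col_major_aliased, short_stride_aliased)
```

For a `(k, n)` matrix, strides `(n, 1)` describe a row-major layout and `(1, k)` a column-major one; a second stride smaller than `k` makes neighbouring columns overlap in memory.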

#### Create cudnn graph and tensors

In [None]:
graph = cudnn.pygraph()

a_cudnn_tensor = graph.tensor_like(a)
b_cudnn_tensor = graph.tensor_like(b)

# cudnn will do the following conversion path: input_data_type -> compute_data_type -> output_data_type
# compute_data_type can be int32 as well
a_cudnn_tensor_casted = graph.identity(input = a_cudnn_tensor, compute_data_type=cudnn.data_type.FLOAT)
a_cudnn_tensor_casted.set_data_type(mma_data_type)

# Tensor b needs no cast here: its input type (bf16) already matches mma_data_type.
# If b had a different input type, it could be cast the same way as tensor a.

# compute_data_type should be set to int32 if the mma_data_type is int8
c_cudnn_tensor = graph.matmul(name = "matmul", A = a_cudnn_tensor_casted, B = b_cudnn_tensor, compute_data_type = cudnn.data_type.FLOAT)
c_cudnn_tensor.set_name("c").set_output(True).set_data_type(output_type)
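
To see why an int8 matmul is paired with an int32 compute type, a back-of-the-envelope check in plain Python can help. This sketch assumes the worst case in which every product is `(-128) * (-128)`; `worst_case_int8_dot` and `fits_signed` are illustrative helpers, not cudnn functions.

```python
def worst_case_int8_dot(k: int) -> int:
    """Largest magnitude a length-k dot product of int8 values can reach."""
    return k * (-128) * (-128)


def fits_signed(value: int, bits: int) -> bool:
    """Check whether value is representable as a signed integer of the given width."""
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return lo <= value <= hi


k = 512  # reduction dimension used in this notebook
worst = worst_case_int8_dot(k)
print(worst, fits_signed(worst, 8), fits_signed(worst, 16), fits_signed(worst, 32))
# -> 8388608 False False True
```

With `k = 512` the accumulated sum already exceeds the int16 range, so a 32-bit accumulator is the narrowest of these widths that cannot overflow.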

#### Build the graph

In [None]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

#### Execute the graph

In [None]:
variant_pack = {
    a_cudnn_tensor: a,
    b_cudnn_tensor: b,
    c_cudnn_tensor: c,
}

workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

In [None]:
torch.testing.assert_close(c, c_ref, rtol = 5e-3, atol = 5e-3)
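
The tolerances are fairly loose because bfloat16 keeps only 8 bits of significand precision. As a rough illustration in plain Python, the sketch below emulates rounding a value to bfloat16 and measures the relative error. `round_to_bfloat16` is a hypothetical helper written for this sketch: it handles finite values only and uses round-to-nearest-even, which is assumed rather than confirmed to be what the GPU path does.

```python
import struct


def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); finite inputs only."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits; add a bias so truncation rounds to nearest even.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFFFFFF
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


samples = [0.1 * i + 0.013 for i in range(1, 200)]
max_rel_err = max(abs(round_to_bfloat16(v) - v) / v for v in samples)

print(f"max relative error over samples: {max_rel_err:.5f}")
print(f"bound for one bfloat16 rounding:  {2 ** -8:.5f}")
```

A single rounding to bfloat16 can change a value by a relative amount of up to 2^-8 (about 3.9e-3), which is the same order as the `rtol = 5e-3` used in the comparison above.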