# SDPA operation using cudnn FE and serialization
This notebook shows how to run a scaled dot-product attention (SDPA) operation with the cuDNN frontend, and how to serialize and deserialize the built graph.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/02_sdpa_graph_serialization.ipynb)

If running on Colab, you will need to install the cudnn python interface.

In [1]:
# get_ipython().system('export CUDA_VERSION="12.3"')
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

#### General Setup
In this example we pass PyTorch tensors to cuDNN; in general, any DLPack-compatible tensor should work.
The cuDNN handle is a per-device handle used to initialize the cuDNN context.


In [2]:
import cudnn
import torch
from enum import Enum

# Create the cuDNN handle once; it is passed to the graphs built later in this notebook.
handle = cudnn.create_handle()

#### Problem definition

In [3]:
b = 2  # batch size

s_q = 1024   # query sequence length
s_kv = 1024  # key/value sequence length

h = 6  # number of attention heads (same for Q, K and V here)

d = 64  # per-head embedding dimension of Q, K and V

# Logical shapes are (batch, heads, seq, dim).
shape_q = (b, h, s_q, d)
shape_k = (b, h, s_kv, d)
shape_v = (b, h, s_kv, d)
shape_o = (b, h, s_q, d)

# The strides place elements in (batch, seq, heads, dim) memory order:
# heads step by d, sequence positions step by h * d.
stride_q = (s_q * h * d, d, h * d, 1)
stride_k = (s_kv * h * d, d, h * d, 1)
stride_v = (s_kv * h * d, d, h * d, 1)
stride_o = (s_q * h * d, d, h * d, 1)

# Softmax scaling 1/sqrt(d), derived from d so it stays correct if d is changed
# (evaluates to exactly 0.125 for d = 64).
attn_scale = d ** -0.5

# Fixed seed so the random inputs are identical on every Restart & Run All.
torch.manual_seed(0)

q_gpu     = torch.randn(b * h * s_q * d, dtype=torch.bfloat16, device="cuda").as_strided(shape_q, stride_q)
k_gpu     = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device="cuda").as_strided(shape_k, stride_k)
v_gpu     = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device="cuda").as_strided(shape_v, stride_v)
o_gpu     = torch.empty(b * h * s_q * d, dtype=torch.bfloat16, device="cuda").as_strided(shape_o, stride_o)
# One float32 value per (batch, head, query position).
stats_gpu = torch.empty(b, h, s_q, 1, dtype=torch.float32, device="cuda")

class UIDs(Enum):
    """Integer IDs that tie graph tensors to the device buffers bound at execution.

    Each value is assigned to a graph tensor with ``set_uid`` and reused as the
    key of the execution variant pack.
    """

    Q_UID = 0      # query input
    K_UID = 1      # key input
    V_UID = 2      # value input
    O_UID = 3      # attention output
    STATS_UID = 4  # softmax statistics output

#### Graph build helper
This helper is called by the `check_support` and `serialize` functions to build the SDPA graph.

In [4]:
def build_and_validate_graph_helper():
    """Build and validate the causal SDPA graph used throughout this notebook.

    The module-level ``q_gpu``/``k_gpu``/``v_gpu`` tensors serve only as templates
    (via ``tensor_like``) for the graph's input tensors. Every graph tensor is
    tagged with a ``UIDs`` value so the graph, including a serialized and
    deserialized copy, can be executed with a variant pack keyed by those UIDs.

    Returns:
        A validated ``cudnn.pygraph``. The operation graph and execution plans
        are not built here; callers do that themselves.
    """
    graph = cudnn.pygraph(
        # Must match the dtype of q_gpu/k_gpu/v_gpu/o_gpu (torch.bfloat16).
        # `o` below gets no explicit data type, so it takes this I/O type; HALF
        # here would disagree with the bf16 buffer it is written into.
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
        handle=handle,
    )

    q = graph.tensor_like(q_gpu)
    k = graph.tensor_like(k_gpu)
    v = graph.tensor_like(v_gpu)

    # is_inference=False makes sdpa also return the softmax stats tensor
    # (typically consumed by a backward pass).
    o, stats = graph.sdpa(
        name="sdpa",
        q=q, k=k, v=v,
        is_inference=False,
        attn_scale=attn_scale,
        use_causal_mask=True,
    )

    o.set_output(True).set_dim(shape_o).set_stride(stride_o)
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    # UIDs let execution bind buffers without holding these tensor objects.
    q.set_uid(UIDs.Q_UID.value)
    k.set_uid(UIDs.K_UID.value)
    v.set_uid(UIDs.V_UID.value)
    o.set_uid(UIDs.O_UID.value)
    stats.set_uid(UIDs.STATS_UID.value)

    graph.validate()

    return graph

#### Check support 

In [5]:
def check_support():
    """Check whether cuDNN can run the SDPA graph on this device.

    Builds a fresh graph, lowers it to an operation graph, and requests
    execution plans from the A and FALLBACK heuristics. It then calls
    ``check_support`` on the graph. Nothing is returned, so an unsupported
    configuration is reported only through whatever the cudnn calls raise.
    """
    heuristic_modes = [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]

    sdpa_graph = build_and_validate_graph_helper()
    sdpa_graph.build_operation_graph()
    sdpa_graph.create_execution_plans(heuristic_modes)
    sdpa_graph.check_support()

#### Serialization function

In [6]:
def serialize():
    """Build the SDPA graph end to end and return its serialized form.

    Runs the same build/plan/support steps as ``check_support``, then builds
    the execution plans. It returns whatever ``graph.serialize()`` produces;
    that payload is meant to be passed to ``deserialize``.
    """
    sdpa_graph = build_and_validate_graph_helper()
    sdpa_graph.build_operation_graph()
    sdpa_graph.create_execution_plans(
        [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]
    )
    sdpa_graph.check_support()
    # Serialize only after the plans are built, so the payload reflects them.
    sdpa_graph.build_plans()
    payload = sdpa_graph.serialize()
    return payload

#### De-serialization function

In [7]:
def deserialize(payload):
    """Recreate an executable SDPA graph from the output of ``serialize()``.

    Args:
        payload: the value returned by ``serialize()``.

    Returns:
        A ``cudnn.pygraph`` restored from ``payload``. Its tensors are addressed
        by the ``UIDs`` values assigned when the graph was first built.
    """
    # Use the notebook's cuDNN handle, as build_and_validate_graph_helper does,
    # rather than constructing the graph without one.
    graph = cudnn.pygraph(handle=handle)

    graph.deserialize(payload)

    return graph

#### Running the execution plan

In [8]:
check_support()

# Round-trip the graph through its serialized form; only the deserialized
# copy is executed below.
data = serialize()
deserialized_graph = deserialize(data)

# Scratch buffer sized by what the deserialized graph reports it needs.
workspace = torch.empty(
    deserialized_graph.get_workspace_size(),
    dtype=torch.uint8,
    device="cuda",
)

# Bind each device buffer to the graph tensor carrying the matching UID.
variant_pack = {
    uid.value: tensor
    for uid, tensor in (
        (UIDs.Q_UID, q_gpu),
        (UIDs.K_UID, k_gpu),
        (UIDs.V_UID, v_gpu),
        (UIDs.O_UID, o_gpu),
        (UIDs.STATS_UID, stats_gpu),
    )
}

deserialized_graph.execute(variant_pack, workspace)

# Wait for the GPU so o_gpu / stats_gpu hold the results before they are read.
torch.cuda.synchronize()