# Scaled Dot Product Attention (SDPA) in cuDNN Frontend

This notebook demonstrates the scaled dot product attention (SDPA) operator in the cuDNN frontend. The operation computes

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$

using the FlashAttention-2 algorithm described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It applies to both training and inference, and can optionally produce a stats tensor that the backward pass uses during training.
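
To make the formula concrete, here is a minimal pure-Python sketch for a single attention head. It is a naive reference that builds every row of scores explicitly; it is not the tiled FlashAttention-2 algorithm and not what cuDNN runs internally. The helper names `softmax` and `naive_attention` are illustrative only.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def naive_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d)) V for one head.

    Q, K, V are lists of rows (one row per sequence position).
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        # One row of Q K^T, scaled by 1 / sqrt(d)
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        weights = softmax(scores)
        # Weighted sum of the rows of V
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = naive_attention(Q, K, V)
```

Each output row is a convex combination of the rows of `V`, because the softmax weights are non-negative and sum to 1.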

The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)

The Python test code for the full set of features can be found in: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)

#### Prerequisites and Setup
This notebook requires an NVIDIA A100 GPU or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [1]:
# get_ipython().system('nvidia-smi')

In [2]:
# get_ipython().system('export CUDA_VERSION="12.3"')
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

In [3]:
import cudnn
import torch
import math

torch.manual_seed(42)
handle = cudnn.create_handle()

assert torch.cuda.is_available()
assert torch.cuda.get_device_capability()[0] >= 8, "SDPA operation is only supported on SM80 architecture (Ampere) or above"
assert cudnn.backend_version() >= 8903, "SDPA operation is only supported on cuDNN version 8.9.3 or above"

#### Problem sizes

For this example, we will use the problem size from the original GPT-2 paper where:
 - maximum sequence length = 1024
 - hidden dim = number of heads $\times$ embedding dimension per head = 12 $\times$ 64 = 768

In [4]:
b = 4    # batch size
h = 12   # query number of heads
s = 1024 # maximum sequence length
d = 64   # embedding dimension per head

attn_scale = 1.0 / math.sqrt(d)
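
As a rough back-of-the-envelope check on these sizes, the sketch below computes the hidden dimension and the memory footprint of one input tensor in half precision. It also estimates how much memory the full `s × s` score matrix would take if it were materialized in half precision, which is the cost that FlashAttention-style tiling is designed to avoid. The variable names `bytes_per_half`, `qkv_bytes`, and `scores_bytes` are illustrative and not part of the cuDNN API.

```python
b, h, s, d = 4, 12, 1024, 64
bytes_per_half = 2  # float16 uses 2 bytes per element

hidden_dim = h * d                             # 768
qkv_bytes = b * s * h * d * bytes_per_half     # size of one of Q, K, V, or O
scores_bytes = b * h * s * s * bytes_per_half  # full score matrix, if it were stored

qkv_mib = qkv_bytes / 2**20        # 6.0 MiB
scores_mib = scores_bytes / 2**20  # 96.0 MiB
```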

Create the query, key, value, and output GPU tensors using PyTorch. Any DLPack-compatible tensor can be used instead.

In [5]:
# The tensors will have non-interleaved
# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout
# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
dims = (b, h, s, d)
strides = (s * h * d, d, h * d, 1)

q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)
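
The stride tuple above can be derived rather than memorized. The following pure-Python sketch computes the contiguous strides of the physical BSHD layout and then reorders them into the logical BHSD order; the helpers `contiguous_strides` and `offset` are hypothetical names used only for illustration.

```python
b, h, s, d = 4, 12, 1024, 64

def contiguous_strides(shape):
    """Row-major strides (in elements) for a densely packed tensor."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))

# Physical memory order is (batch, sequence, head, dim)
stride_b, stride_s, stride_h, stride_d = contiguous_strides((b, s, h, d))

# Logical order is (batch, head, sequence, dim), so swap the middle two strides
strides = (stride_b, stride_h, stride_s, stride_d)
assert strides == (s * h * d, d, h * d, 1)

def offset(index, strides):
    """Linear element offset of a logical (b, h, s, d) index."""
    return sum(i * st for i, st in zip(index, strides))

next_head = offset((0, 1, 0, 0), strides)      # 64: heads are adjacent in memory
next_position = offset((0, 0, 1, 0), strides)  # 768: one full token (h * d) apart
```

Stepping to the next head moves `d` elements, while stepping to the next sequence position moves `h * d` elements, which is what makes the layout physically BSHD.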

Create the graph

In [6]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q = graph.tensor_like(q_gpu)
k = graph.tensor_like(k_gpu)
v = graph.tensor_like(v_gpu)

# sdpa returns (output, stats). The stats tensor is only needed for training,
# so it is ignored here because is_inference=True.
# The causal mask is enabled.
o, _ = graph.sdpa(
    name="sdpa",
    q=q, k=k, v=v,
    is_inference=True,
    attn_scale=attn_scale,
    use_causal_mask=True,
)

o.set_output(True).set_dim(dims).set_stride(strides)
pass
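
Conceptually, a causal mask means that query position `i` may only attend to key positions `j <= i`. The pure-Python sketch below illustrates that idea by setting the masked scores to negative infinity before the softmax. It describes the mathematical effect only, not how cuDNN implements the mask, and `causal_softmax` is a hypothetical helper name.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where query i only sees keys j <= i."""
    out = []
    for i, row in enumerate(scores):
        # Future positions get -inf, which becomes weight 0 after exp()
        masked = [x if j <= i else -math.inf for j, x in enumerate(row)]
        m = max(masked)
        exps = [math.exp(x - m) for x in masked]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# With all-equal scores, each row spreads its weight evenly over the visible keys
scores = [[0.0] * 4 for _ in range(4)]
weights = causal_softmax(scores)
```

The resulting weight matrix is lower triangular: the first query attends only to the first key, and the last query attends to all four.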

Build the graph

In [7]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

Execute the graph

In [8]:
variant_pack = {
    q: q_gpu,
    k: k_gpu,
    v: v_gpu,
    o: o_gpu,
}

workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

Compare cuDNN's output against PyTorch's reference implementation to check correctness

In [9]:
q_ref = q_gpu.detach().float().requires_grad_()
k_ref = k_gpu.detach().float().requires_grad_()
v_ref = v_gpu.detach().float().requires_grad_()

o_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale)
torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)
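
For context on the tolerances, `torch.testing.assert_close` accepts a pair of elements when `|actual - expected| <= atol + rtol * |expected|`. The scalar sketch below reproduces that check with the same `atol` and `rtol` values; the function name `is_close` is illustrative. The tolerances are fairly loose because the cuDNN output is stored in half precision, whose machine epsilon is `2**-10` (about `9.8e-4`).

```python
def is_close(actual, expected, atol=5e-3, rtol=3e-3):
    """Scalar version of the mixed absolute/relative tolerance check."""
    return abs(actual - expected) <= atol + rtol * abs(expected)

# Around 1.0 the allowed error is 5e-3 + 3e-3 * 1.0 = 8e-3
within = is_close(1.007, 1.0)   # True
outside = is_close(1.009, 1.0)  # False

# Around 0.0 only the absolute tolerance applies
near_zero = is_close(0.004, 0.0)  # True
```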