# Scaled Dot Product Attention (SDPA) in cuDNN Frontend

This notebook demonstrates the scaled dot product attention (SDPA) operator in the cuDNN frontend. The operator computes

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

using the FlashAttention-2 algorithm described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It applies to both training and inference, and can optionally generate a stats tensor that the backward pass uses during training.
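
For reference, the formula above can be written out directly. The following is a minimal pure-Python sketch for a single head with no masking. It is not how cuDNN or FlashAttention-2 computes the result, and the names `softmax` and `naive_attention` are illustrative, not part of any library.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(q, k, v):
    """Compute softmax(Q K^T / sqrt(d)) V for one head.

    q is s_q x d, k is s_kv x d, v is s_kv x d_v (lists of lists).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q_row in q:
        # scaled dot product of this query with every key
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        probs = softmax(scores)
        # probability-weighted sum of the value rows
        out.append(
            [sum(p * v_row[j] for p, v_row in zip(probs, v)) for j in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = naive_attention(q, k, v)
print(out)
```

Each output row is a convex combination of the rows of `v`, weighted by how strongly the query matches each key.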

The full documentation is in [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention).

The Python test code covering the full set of features is in [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/50_sdpa_forward.ipynb)

## Prerequisites and Setup

This notebook requires an NVIDIA GPU with compute capability 8.0 (Ampere, for example A100) or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [None]:
# get_ipython().system('nvidia-smi')

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

For this example, we will use the problem size from the original GPT-2 paper where:

 - maximum sequence length = 1024
 - hidden dim = number of heads $\times$ embedding dimension per head = 12 $\times$ 64 = 768

In [None]:
import math
import cudnn
import torch

torch.manual_seed(42)

assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU"
assert (
    torch.cuda.get_device_capability()[0] >= 8
), "SDPA operation is only supported on SM80 architecture (Ampere) or above"
assert (
    cudnn.backend_version() >= 8903
), "SDPA operation is only supported on cuDNN version 8.9.3 or above"

# create the cuDNN handle only after the environment checks pass
handle = cudnn.create_handle()

In [None]:
B = 4  # batch size
S = 1024  # maximum sequence length
H = 12  # number of query heads
D = 64  # embedding dimension per head
dtype = torch.half

attn_scale = 1.0 / math.sqrt(D)
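
As a quick sanity check on these sizes, the arithmetic can be worked out by hand. The sketch below uses the values from the cell above; the FP32 score-matrix figure assumes a naive implementation that materializes the full (B, H, S, S) scores, which is an illustrative assumption rather than something cuDNN does.

```python
import math

B, S, H, D = 4, 1024, 12, 64  # sizes used in this notebook
BYTES_FP16 = 2  # torch.half, the I/O data type
BYTES_FP32 = 4  # assumed storage type for naive intermediate scores

hidden_dim = H * D  # number of heads x embedding dimension per head
attn_scale = 1.0 / math.sqrt(D)

# Each of Q, K, V and O holds B * S * H * D FP16 elements.
tensor_bytes = B * S * H * D * BYTES_FP16

# A naive implementation would materialize a (B, H, S, S) score matrix.
score_bytes_fp32 = B * H * S * S * BYTES_FP32

print(hidden_dim, attn_scale)  # 768 0.125
print(tensor_bytes / 2**20, "MiB per Q/K/V/O tensor")  # 6.0
print(score_bytes_fp32 / 2**20, "MiB for a materialized FP32 score matrix")  # 192.0
```

The materialized score matrix would be 32 times larger than any single input tensor at this problem size, which is the cost a tiled algorithm such as FlashAttention-2 is designed to avoid.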

## Using Wrapper

In [None]:
# allocate random input tensors in BSHD physical layout;
# transpose(1, 2) exposes them with BHSD logical dims without copying
q_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)
k_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)
v_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)

# create a graph
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    inputs=["SDPA::q", "SDPA::k", "SDPA::v"],
    outputs=["attn_output"],
) as graph:
    o, _ = graph.sdpa(
        name="SDPA",
        q=q_gpu,
        k=k_gpu,
        v=v_gpu,
        attn_scale=attn_scale,
        is_inference=True,
        use_causal_mask=True,
    )
    o.set_output(True).set_name("attn_output").set_dim(q_gpu.shape).set_stride(
        q_gpu.stride()
    )

# execute the graph
o_gpu = graph(q_gpu, k_gpu, v_gpu, handle=handle)

# verify the result with PyTorch operations
o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_gpu, k_gpu, v_gpu, is_causal=True, scale=attn_scale
)
torch.testing.assert_close(o_ref, o_gpu, atol=5e-3, rtol=3e-3)
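
The tolerances passed to `torch.testing.assert_close` combine as `|actual - expected| <= atol + rtol * |expected|` for each element. A scalar sketch of that rule follows; the helper name `is_close` is hypothetical and only mirrors the documented formula, not the full behavior of the PyTorch function.

```python
def is_close(actual, expected, atol=5e-3, rtol=3e-3):
    """Scalar version of the closeness rule, using this notebook's tolerances."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Around 1.0 the allowed error is 5e-3 + 3e-3 * 1.0 = 8e-3.
print(is_close(1.007, 1.0))  # True
print(is_close(1.009, 1.0))  # False

# Around 0.0 only the absolute tolerance of 5e-3 applies.
print(is_close(0.004, 0.0))  # True
print(is_close(0.006, 0.0))  # False
```

The absolute term keeps the check meaningful for outputs near zero, where a purely relative tolerance would be too strict for FP16 results.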

## Using Python Binding APIs

Create the query, key, and value GPU tensors using PyTorch; the output tensor is allocated later, just before execution. Any DLPack-compatible tensor can be used in place of a PyTorch tensor.

In [None]:
# The tensors will have non-interleaved
# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
dims = (B, H, S, D)
# BSHD (batch, sequence_length, num_head, dims_per_head) physical layout
strides = (S * H * D, D, H * D, 1)
# For BHSD (batch, num_head, sequence_length, dims_per_head) physical tensor layout, uncomment the following:
# strides = (S*H*D, S*D, D, 1)

q_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)
k_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)
v_gpu = torch.randn(B, S, H, D, device="cuda", dtype=dtype).transpose(1, 2)
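
To see why these strides describe BSHD memory under BHSD logical dims, the offset arithmetic can be checked on a small shape. This is a minimal sketch in plain Python; the helper names `bshd_strides` and `offset` and the small sizes are illustrative assumptions, not part of cuDNN or PyTorch.

```python
def bshd_strides(nb, nh, ns, nd):
    """Element strides for logical (B, H, S, D) dims over BSHD memory."""
    return (ns * nh * nd, nd, nh * nd, 1)


def offset(index, strides):
    """Linear memory offset of a logical index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


nb, nh, ns, nd = 2, 3, 4, 5  # small illustrative sizes
small_strides = bshd_strides(nb, nh, ns, nd)

# Walk memory in BSHD order and check that the logical (b, h, s, d)
# index maps to consecutive offsets 0, 1, 2, ...
expected = 0
layout_ok = True
for b in range(nb):
    for s in range(ns):
        for h in range(nh):
            for d in range(nd):
                layout_ok = layout_ok and offset((b, h, s, d), small_strides) == expected
                expected += 1

print(small_strides, layout_ok)
```

The head dimension gets stride `D` and the sequence dimension gets stride `H * D`, which is the same stride pattern that `transpose(1, 2)` produces on a contiguous `(B, S, H, D)` tensor.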

Create the graph

In [None]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q = graph.tensor_like(q_gpu)
k = graph.tensor_like(k_gpu)
v = graph.tensor_like(v_gpu)

# The second return value is the stats tensor, which is only needed for training;
# generate_stats=False skips it here. The causal mask is enabled.
o, _ = graph.sdpa(
    name="sdpa",
    q=q,
    k=k,
    v=v,
    generate_stats=False,
    attn_scale=attn_scale,
    use_causal_mask=True,
)

o.set_output(True).set_dim(dims).set_stride(strides)
pass
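
With `use_causal_mask=True`, query position `i` only attends to key positions `j <= i`. The effect on the softmax can be sketched in plain Python for the square case used in this notebook (equal query and key sequence lengths); the function name `causal_softmax_rows` is illustrative and this is not cuDNN's implementation.

```python
import math


def causal_softmax_rows(scores):
    """Mask out future positions (j > i), then softmax each row."""
    out = []
    for i, row in enumerate(scores):
        # masked positions get -inf so they contribute exp(-inf) = 0
        masked = [x if j <= i else -math.inf for j, x in enumerate(row)]
        m = max(masked)
        exps = [math.exp(x - m) for x in masked]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# With all-zero scores, each row spreads its weight evenly over the
# positions it is allowed to see.
probs = causal_softmax_rows([[0.0] * 3 for _ in range(3)])
for row in probs:
    print(row)
```

The first row places all of its weight on position 0, and each later row distributes weight over one more position than the row before it.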

Build the graph

In [None]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

Execute the graph

In [None]:
# allocate output tensor
o_gpu = torch.empty(B, H, S, D, device="cuda", dtype=dtype).as_strided(dims, strides)

# execute the graph
variant_pack = {
    q: q_gpu,
    k: k_gpu,
    v: v_gpu,
    o: o_gpu,
}
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace, handle=handle)
torch.cuda.synchronize()

Check cuDNN's output against PyTorch's reference implementation

In [None]:
# upcast to FP32 so the reference is computed in higher precision;
# no gradients are needed for this forward-only check
q_ref = q_gpu.detach().float()
k_ref = k_gpu.detach().float()
v_ref = v_gpu.detach().float()

o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale
)
torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)


# release the cuDNN handle now that all graphs have executed
cudnn.destroy_handle(handle)