# Scaled Dot Product Attention (SDPA) Backward in cuDNN Frontend

This operation computes gradient tensors for scaled dot product attention using the FlashAttention-2 algorithm, as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor produced by the forward operation to the backward operation as an input.

The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention-backward](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-backward)

The python test code for the full set of features can be found in: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/51_scaled_dot_product_attention_backward.ipynb)

#### Prerequisites and Setup
This notebook requires an NVIDIA GPU A100 or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [39]:
# get_ipython().system('nvidia-smi')

In [40]:
# get_ipython().system('export CUDA_VERSION="12.3"')
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

In [41]:
import cudnn
import torch
import math

# Fix the RNG seed so the random tensors created in later cells are reproducible.
torch.manual_seed(42)

# Check the environment first, so an unsupported setup is reported by these
# assertions before any cuDNN handle is created.
assert torch.cuda.is_available(), "A CUDA-capable GPU is required to run this notebook"
assert torch.cuda.get_device_capability()[0] >= 8, "SDPA operation is only supported on SM80 architecture (Ampere) or above"
assert cudnn.backend_version() >= 8903, "SDPA operation is only supported cuDNN version 8.9.3 or above"

handle = cudnn.create_handle()

#### Problem sizes

For this example, we will use the problem size from the original GPT-2 paper where:
 - maximum sequence length = 1024
 - hidden dim = number of heads $\times$ embedding dimension per head = 12 $\times$ 64 = 768

In [42]:
# GPT-2 sized problem: 12 heads x 64 dims per head = 768 hidden dim.
b, h, s, d = (
    4,     # batch size
    12,    # query number of heads
    1024,  # maximum sequence length
    64,    # embedding dimension per head
)

# Attention scale handed to the SDPA graphs: 1 / sqrt(d).
attn_scale = 1.0 / math.sqrt(d)

Create the query, key, value, and output GPU tensors using PyTorch.

**However for backwards computation, we also need to pass the stats tensor between the forward graph and the backward graph.**

The stats tensor should have dims $(B, H, S, 1)$ and float32 datatype.

In [43]:
# The tensors will have non-interleaved
# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout
# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
dims = (b, h, s, d)
strides = (s * h * d, d, h * d, 1)
numel = b * s * h * d  # element count of the flat backing buffer

# Inputs: random values are generated in float32 on the CPU, cast to half,
# moved to the GPU, and finally viewed with the BHSD dims / BSHD strides.
q_gpu = torch.randn(numel).half().cuda().as_strided(dims, strides)
k_gpu = torch.randn(numel).half().cuda().as_strided(dims, strides)
v_gpu = torch.randn(numel).half().cuda().as_strided(dims, strides)

# Outputs: these buffers are uninitialized, so allocate them directly on the
# GPU with their final dtype instead of staging them through a CPU allocation,
# a dtype cast, and a host-to-device copy.
o_gpu = torch.empty(numel, dtype=torch.float16, device="cuda").as_strided(dims, strides)
stats_gpu = torch.empty(b, h, s, 1, dtype=torch.float32, device="cuda")

Also create the query, key, value, and output gradient tensors to be used for backwards computation.

In [44]:
# note: torch's '*_like' factories preserve the strided layout of their source.
# Uninitialized buffers for the gradients of the query, key, and value tensors.
dQ_gpu, dK_gpu, dV_gpu = (torch.empty_like(t) for t in (q_gpu, k_gpu, v_gpu))
# Gradient with respect to the attention output, filled with random values.
dO_gpu = torch.randn_like(o_gpu)

Create and build the forward graph.

In [45]:
# Forward graph: half-precision I/O with float32 intermediates and compute.
graph_forward = cudnn.pygraph(
    compute_data_type=cudnn.data_type.FLOAT,
    intermediate_data_type=cudnn.data_type.FLOAT,
    io_data_type=cudnn.data_type.HALF,
)

# Graph-side tensors created from the query, key, and value GPU tensors.
q_forward, k_forward, v_forward = (
    graph_forward.tensor_like(t) for t in (q_gpu, k_gpu, v_gpu)
)

# training mode is enabled with is_inference=False
# causal mask is enabled with use_causal_mask=True
o_forward, stats_forward = graph_forward.sdpa(
    q=q_forward,
    k=k_forward,
    v=v_forward,
    attn_scale=attn_scale,
    is_inference=False,
    use_causal_mask=True,
    name="sdpa",
)

# Mark the attention output as a graph output shaped like o_gpu.
o_forward.set_output(True)
o_forward.set_dim(o_gpu.size())
o_forward.set_stride(o_gpu.stride())

# Mark the stats tensor as a float32 graph output shaped like stats_gpu.
stats_forward.set_output(True)
stats_forward.set_dim(stats_gpu.size())
stats_forward.set_stride(stats_gpu.stride())
stats_forward.set_data_type(cudnn.data_type.FLOAT)

# Build pipeline: validate, lower to an operation graph, create candidate
# execution plans from the listed heuristic modes, check support, build plans.
graph_forward.validate()
graph_forward.build_operation_graph()
graph_forward.create_execution_plans(
    [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]
)
graph_forward.check_support()
graph_forward.build_plans()

Create and build the backward graph.

In [46]:
# Backward graph: half-precision I/O with float32 intermediates and compute.
graph_backward = cudnn.pygraph(
    compute_data_type=cudnn.data_type.FLOAT,
    intermediate_data_type=cudnn.data_type.FLOAT,
    io_data_type=cudnn.data_type.HALF,
)

# Graph-side tensors for the forward inputs ...
q_backward, k_backward, v_backward = (
    graph_backward.tensor_like(t) for t in (q_gpu, k_gpu, v_gpu)
)
# ... and for the forward output, its incoming gradient, and the stats tensor.
o_backward, dO_backward, stats_backward = (
    graph_backward.tensor_like(t) for t in (o_gpu, dO_gpu, stats_gpu)
)

# Same attention scale and causal mask settings as the forward graph.
dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
    q=q_backward,
    k=k_backward,
    v=v_backward,
    o=o_backward,
    dO=dO_backward,
    stats=stats_backward,
    attn_scale=attn_scale,
    use_causal_mask=True,
    name="sdpa_backward",
)

# Mark each gradient as a graph output shaped like its GPU buffer.
for grad_out, grad_buf in zip(
    (dQ_backward, dK_backward, dV_backward),
    (dQ_gpu, dK_gpu, dV_gpu),
):
    grad_out.set_output(True).set_dim(grad_buf.size()).set_stride(grad_buf.stride())

# Build pipeline: validate, lower to an operation graph, create candidate
# execution plans from the listed heuristic modes, check support, build plans.
graph_backward.validate()
graph_backward.build_operation_graph()
graph_backward.create_execution_plans(
    [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]
)
graph_backward.check_support()
graph_backward.build_plans()

Allocate workspace required to execute. We take the maximum since forward and backward are executed sequentially.

In [47]:
# A single scratch buffer serves both graphs, so size it for the larger request.
workspace_size = max(
    graph.get_workspace_size() for graph in (graph_forward, graph_backward)
)
# Raw byte buffer on the GPU.
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")

Execute forward graph

In [48]:
# Map every forward graph tensor to the GPU buffer backing it.
forward_inputs = {q_forward: q_gpu, k_forward: k_gpu, v_forward: v_gpu}
forward_outputs = {o_forward: o_gpu, stats_forward: stats_gpu}
variant_pack_forward = {**forward_inputs, **forward_outputs}

graph_forward.execute(variant_pack_forward, workspace)
# Wait for the GPU work to finish before moving on.
torch.cuda.synchronize()

Execute backward graph

In [49]:
# Map every backward graph tensor to the GPU buffer backing it. The output and
# stats buffers written by the forward pass are consumed here as inputs.
backward_inputs = {
    q_backward: q_gpu,
    k_backward: k_gpu,
    v_backward: v_gpu,
    o_backward: o_gpu,
    dO_backward: dO_gpu,
    stats_backward: stats_gpu,
}
backward_outputs = {
    dQ_backward: dQ_gpu,
    dK_backward: dK_gpu,
    dV_backward: dV_gpu,
}
variant_pack_backward = {**backward_inputs, **backward_outputs}

graph_backward.execute(variant_pack_backward, workspace)
# Wait for the GPU work to finish before reading the gradients.
torch.cuda.synchronize()

Test cuDNN's output against PyTorch's and check correctness

In [50]:
# float32 copies of the inputs, marked as requiring grad for the PyTorch reference.
q_ref, k_ref, v_ref = (
    t.detach().float().requires_grad_() for t in (q_gpu, k_gpu, v_gpu)
)
dO_ref = dO_gpu.detach().float()

# Tolerances shared by every comparison below.
tolerances = {"atol": 5e-3, "rtol": 3e-3}

# Forward check: PyTorch SDPA with the same causal mask and scale.
o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_ref, k_ref, v_ref, scale=attn_scale, is_causal=True
)
torch.testing.assert_close(o_ref, o_gpu.float(), **tolerances)

# Backward check: differentiate the reference output using the same dO.
dQ_ref, dK_ref, dV_ref = torch.autograd.grad(
    outputs=[o_ref],
    inputs=[q_ref, k_ref, v_ref],
    grad_outputs=[dO_ref],
)
torch.testing.assert_close(dQ_ref, dQ_gpu.float(), **tolerances)
torch.testing.assert_close(dK_ref, dK_gpu.float(), **tolerances)
torch.testing.assert_close(dV_ref, dV_gpu.float(), **tolerances)