# Scaled Dot Product Attention (SDPA) Backward in cuDNN Frontend

This operation computes gradient tensors for scaled dot product attention using the FlashAttention-2 algorithm as described in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The user is required to pass the stats tensor from the forward operation to the backward operation as input.

The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention-backward](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-backward)

The python test code for the full set of features can be found in: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/51_sdpa_backward.ipynb)

## Prerequisites and Setup
This notebook requires an NVIDIA A100 GPU or newer (SM80 / Ampere architecture or above). If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [None]:
# get_ipython().system('nvidia-smi')

In [None]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

## Overview

For this example, we will use the problem size from the original GPT-2 paper where:
 - maximum sequence length = 1024
 - hidden dim = number of heads $\times$ embedding dimension per head = 12 $\times$ 64 = 768

In [None]:
import cudnn
import torch
import math

# Seed PyTorch's RNG so the random tensors created below are reproducible.
torch.manual_seed(42)

# Validate the environment before touching cuDNN, so that an unsupported
# machine is reported through these descriptive assertions instead of through
# whatever error handle creation would raise.
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU"
assert (
    torch.cuda.get_device_capability()[0] >= 8
), "SDPA operation is only supported on SM80 architecture (Ampere) or above"
assert (
    cudnn.backend_version() >= 8903
), "SDPA operation is only supported cuDNN version 8.9.3 or above"

# cuDNN handle shared by every graph execution in this notebook.
handle = cudnn.create_handle()

In [None]:
# Problem dimensions used by every tensor in this notebook.
B, S, H, D = (
    4,  # batch size
    1024,  # maximum sequence length
    12,  # query number of heads
    64,  # embedding dimension per head
)
dtype = torch.float16  # torch.half is an alias of this dtype

# Attention scale handed to the SDPA graphs: 1 / sqrt(per-head embedding dim).
attn_scale = 1.0 / math.sqrt(D)

## Using Wrapper

#### Forward pass graph

In [None]:
# Allocate random input tensors. Each one is created with shape (B, S, H, D)
# and then transposed, giving a BSHD physical layout and a BHSD logical layout.
q_gpu = torch.randn(
    B, S, H, D, device="cuda", dtype=dtype, requires_grad=True
).transpose(1, 2)
k_gpu = torch.randn(
    B, S, H, D, device="cuda", dtype=dtype, requires_grad=True
).transpose(1, 2)
v_gpu = torch.randn(
    B, S, H, D, device="cuda", dtype=dtype, requires_grad=True
).transpose(1, 2)

# Forward graph.
# - `inputs` lists the graph inputs in the order they are passed when the
#   graph is called; the names look like "<node name>::<argument name>".
# - `outputs` lists the names assigned to the output tensors via set_name().
# - workspace_alloc=False: the workspace is passed in explicitly at execution
#   time so one buffer can be shared with the backward graph.
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    workspace_alloc=False,
    inputs=["SDPA::q", "SDPA::k", "SDPA::v"],
    outputs=["output", "stats"],
) as fwd_graph:
    o, stats = fwd_graph.sdpa(
        name="SDPA",
        q=q_gpu,
        k=k_gpu,
        v=v_gpu,
        attn_scale=attn_scale,
        # Training mode: request the stats tensor that the backward pass needs.
        # Uses the same `generate_stats` flag as the Python-binding section of
        # this notebook (NOTE(review): replaces `is_inference=False`, which is
        # understood to be the deprecated spelling -- confirm against the
        # installed cudnn-frontend version).
        generate_stats=True,
        use_causal_mask=True,
    )
    # The output shares the dims/strides of q.
    o.set_output(True).set_dim(q_gpu.shape).set_stride(q_gpu.stride()).set_name(
        "output"
    )
    # Stats are produced in float32 regardless of the half-precision I/O type.
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_name("stats")

#### Backward pass graph

In [None]:
# Upstream gradient: a random tensor stands in for dO.
dO_gpu = torch.randn_like(q_gpu)
# Placeholders for the forward outputs. They exist only so the backward graph
# can be described; the real forward results are bound when the graphs run.
o_gpu = torch.randn_like(q_gpu)
stats_gpu = torch.randn(B, H, S, 1, device="cuda", dtype=torch.float32)

# Backward graph: consumes q, k, v, o, stats and dO and produces dQ, dK, dV.
with cudnn.Graph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
    workspace_alloc=False,
    inputs=[f"d_sdpa::{arg}" for arg in ("q", "k", "v", "o", "stats", "dO")],
    outputs=["dQ", "dK", "dV"],
) as bwd_graph:
    dQ, dK, dV = bwd_graph.sdpa_backward(
        name="d_sdpa",
        q=q_gpu,
        k=k_gpu,
        v=v_gpu,
        o=o_gpu,
        dO=dO_gpu,
        stats=stats_gpu,
        attn_scale=attn_scale,
        use_causal_mask=True,
    )
    # Each gradient takes the dims/strides of the tensor it differentiates.
    for grad, like, label in (
        (dQ, q_gpu, "dQ"),
        (dK, k_gpu, "dK"),
        (dV, v_gpu, "dV"),
    ):
        grad.set_output(True).set_dim(like.shape).set_stride(like.stride()).set_name(
            label
        )

#### Execute the graphs

Here we reuse the same workspace for both forward and backward graphs to save memory.

In [None]:
# A single scratch buffer, sized for whichever graph needs more, serves both runs.
workspace_size = max(
    graph.get_workspace_size() for graph in (fwd_graph, bwd_graph)
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

# Forward pass through cuDNN; this replaces the placeholder o/stats tensors.
o_gpu, stats_gpu = fwd_graph(q_gpu, k_gpu, v_gpu, workspace=workspace, handle=handle)

# Check the forward output against PyTorch's own SDPA implementation.
o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_gpu, k_gpu, v_gpu, scale=attn_scale, is_causal=True
)
torch.testing.assert_close(o_ref, o_gpu, atol=5e-3, rtol=3e-3)

# Backward pass through cuDNN; inputs follow the order declared on the graph.
dQ_gpu, dK_gpu, dV_gpu = bwd_graph(
    q_gpu, k_gpu, v_gpu, o_gpu, stats_gpu, dO_gpu, workspace=workspace, handle=handle
)

# Check the gradients against PyTorch autograd on the reference output.
dQ_ref, dK_ref, dV_ref = torch.autograd.grad(
    outputs=[o_ref], inputs=[q_gpu, k_gpu, v_gpu], grad_outputs=[dO_gpu]
)
for expected, actual in (
    (dQ_ref, dQ_gpu),
    (dK_ref, dK_gpu),
    (dV_ref, dV_gpu),
):
    torch.testing.assert_close(expected, actual, atol=5e-3, rtol=3e-3)

## Using Python Binding APIs

Create the query, key, value, and output GPU tensors using PyTorch.

**However, for the backward computation, we also need to pass the stats tensor from the forward graph to the backward graph.**

The stats tensor should have dims $(B, H, S, 1)$ and float32 datatype.

In [None]:
# The tensors will have non-interleaved
# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout
# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
# The tensors use a non-interleaved layout:
#   physical: BSHD (batch, sequence_length, num_head, dims_per_head)
#   logical:  BHSD (batch, num_head, sequence_length, dims_per_head)
dims = (B, H, S, D)
strides = (S * H * D, D, H * D, 1)

# Input tensors for the forward pass. They are allocated directly on the GPU in
# half precision (no fp32 CPU staging tensor, conversion, or host-to-device
# copy) and then viewed with the BHSD dims/strides above.
q_gpu = torch.randn(B * S * H * D, device="cuda", dtype=torch.half).as_strided(
    dims, strides
)
k_gpu = torch.randn(B * S * H * D, device="cuda", dtype=torch.half).as_strided(
    dims, strides
)
v_gpu = torch.randn(B * S * H * D, device="cuda", dtype=torch.half).as_strided(
    dims, strides
)
# Preallocated output tensors for the forward pass; stats is float32 with
# dims (B, H, S, 1).
o_gpu = torch.empty(B * S * H * D, device="cuda", dtype=torch.half).as_strided(
    dims, strides
)
stats_gpu = torch.empty(B, H, S, 1, device="cuda", dtype=torch.float32)

Create and build the forward graph.

In [None]:
# Forward graph: half-precision I/O with float intermediates and compute.
graph_forward = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Graph tensors created from the corresponding torch tensors.
q_forward, k_forward, v_forward = (
    graph_forward.tensor_like(t) for t in (q_gpu, k_gpu, v_gpu)
)

# generate_stats=True enables training mode (the stats output is produced);
# the causal mask is enabled as well.
o_forward, stats_forward = graph_forward.sdpa(
    name="sdpa",
    q=q_forward,
    k=k_forward,
    v=v_forward,
    generate_stats=True,
    attn_scale=attn_scale,
    use_causal_mask=True,
)

# Mark both results as graph outputs and describe their layout from the
# preallocated torch buffers.
o_forward.set_output(True)
o_forward.set_dim(o_gpu.size())
o_forward.set_stride(o_gpu.stride())

stats_forward.set_output(True)
stats_forward.set_dim(stats_gpu.size())
stats_forward.set_stride(stats_gpu.stride())
stats_forward.set_data_type(cudnn.data_type.FLOAT)

# Build pipeline: validate -> operation graph -> candidate plans -> support
# check -> final plans.
graph_forward.validate()
graph_forward.build_operation_graph()
graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph_forward.check_support()
graph_forward.build_plans()

Also create the query, key, value, and output gradient tensors to be used for the backward computation.

In [None]:
# note: torch 'like' preserves the strided layout
# Random upstream gradient with the same layout as the forward output
# (note: the torch '*_like' factories preserve the strided layout).
dO_gpu = torch.randn_like(o_gpu)
# Uninitialized buffers that will receive dQ, dK and dV.
dQ_gpu, dK_gpu, dV_gpu = (torch.empty_like(t) for t in (q_gpu, k_gpu, v_gpu))

Create and build the backward graph.

In [None]:
# Backward graph: same precision configuration as the forward graph.
graph_backward = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Graph tensors for every input of the backward operation.
q_backward, k_backward, v_backward, o_backward, dO_backward, stats_backward = (
    graph_backward.tensor_like(t)
    for t in (q_gpu, k_gpu, v_gpu, o_gpu, dO_gpu, stats_gpu)
)

dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
    name="sdpa_backward",
    q=q_backward,
    k=k_backward,
    v=v_backward,
    o=o_backward,
    dO=dO_backward,
    stats=stats_backward,
    attn_scale=attn_scale,
    use_causal_mask=True,
)

# Each gradient is a graph output laid out like its preallocated torch buffer.
for grad_out, buf in (
    (dQ_backward, dQ_gpu),
    (dK_backward, dK_gpu),
    (dV_backward, dV_gpu),
):
    grad_out.set_output(True).set_dim(buf.size()).set_stride(buf.stride())

# Build pipeline: validate -> operation graph -> candidate plans -> support
# check -> final plans.
graph_backward.validate()
graph_backward.build_operation_graph()
graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph_backward.check_support()
graph_backward.build_plans()

Allocate the workspace required for execution. Because the forward and backward graphs are executed sequentially, they can share one buffer sized to the larger of the two requirements.

In [None]:
# One byte buffer, sized for the more demanding of the two graphs.
workspace_size = max(
    graph.get_workspace_size() for graph in (graph_forward, graph_backward)
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

Execute forward graph

In [None]:
# Bind every forward graph tensor (inputs and outputs) to its torch buffer.
variant_pack_forward = dict(
    zip(
        (q_forward, k_forward, v_forward, o_forward, stats_forward),
        (q_gpu, k_gpu, v_gpu, o_gpu, stats_gpu),
    )
)

graph_forward.execute(variant_pack_forward, workspace, handle=handle)
# Wait for the GPU work to finish before moving on.
torch.cuda.synchronize()

Execute backward graph

In [None]:
# Bind every backward graph tensor to its torch buffer: the forward inputs,
# the forward results (o, stats), the upstream gradient, and the gradient
# output buffers.
variant_pack_backward = dict(
    zip(
        (
            q_backward,
            k_backward,
            v_backward,
            o_backward,
            dO_backward,
            stats_backward,
            dQ_backward,
            dK_backward,
            dV_backward,
        ),
        (
            q_gpu,
            k_gpu,
            v_gpu,
            o_gpu,
            dO_gpu,
            stats_gpu,
            dQ_gpu,
            dK_gpu,
            dV_gpu,
        ),
    )
)

graph_backward.execute(variant_pack_backward, workspace, handle=handle)
# Wait for the GPU work to finish before reading the gradients.
torch.cuda.synchronize()

Test cuDNN's output against PyTorch's and check correctness

In [None]:
# float32 copies of the inputs, detached from the cuDNN tensors, that track
# gradients for the PyTorch reference computation.
q_ref, k_ref, v_ref = (
    t.detach().float().requires_grad_() for t in (q_gpu, k_gpu, v_gpu)
)
dO_ref = dO_gpu.detach().float()

# Reference forward pass, compared with cuDNN's output upcast to float32.
o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_ref, k_ref, v_ref, scale=attn_scale, is_causal=True
)
torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)

# Reference gradients from autograd, driven by the same upstream gradient.
dQ_ref, dK_ref, dV_ref = torch.autograd.grad(
    outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref]
)
for ref_grad, cudnn_grad in (
    (dQ_ref, dQ_gpu),
    (dK_ref, dK_gpu),
    (dV_ref, dV_gpu),
):
    torch.testing.assert_close(ref_grad, cudnn_grad.float(), atol=5e-3, rtol=3e-3)


# Release the cuDNN handle now that all graph executions are finished.
cudnn.destroy_handle(handle)