# SDPA Prefill with Paged Attention

This notebook illustrates how the cuDNN frontend's scaled dot product attention (SDPA) operator can be used with paged K/V caches, specifically for prefill. For a simpler introduction to the scaled dot product attention operator, please refer to [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).

The full documentation of cuDNN's scaled dot product attention operator can be found in [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention). The Python test code for the full set of features can be found in [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py).

More details on paged attention can be found in the [PagedAttention paper](https://arxiv.org/abs/2309.06180).

This notebook specifically illustrates the following:
- SDPA Prefill
- Variable sequence lengths
- Q-tensor in a packed (ragged) format, on cuDNN 9.10.0 or later
- Paged Attention
- Running the same graph with variable sequence lengths

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention.ipynb)

#### Prerequisites and Setup
This notebook requires an NVIDIA GPU A100 or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [1]:
# get_ipython().system('nvidia-smi')

In [2]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

In [3]:
import cudnn
import torch
import math

torch.manual_seed(42)
handle = cudnn.create_handle()

assert torch.cuda.is_available()
assert (
    torch.cuda.get_device_capability()[0] >= 8
), "SDPA operation is only supported on SM80 architecture (Ampere) or above"

assert (
    cudnn.backend_version() >= 90500
), "SDPA operation is only supported on cuDNN version 9.5.0 or above"

# An issue is preventing packed Q-tensors in versions prior to 9.10.0
packed_Q = cudnn.backend_version() >= 91000

#### Problem sizes and Q/K/V setup

Create the query, key, value, and output GPU tensors using PyTorch. Any DLPack-compatible tensor can be used instead.

In [5]:
b = 4  # batch size
h_q = 12  # query number of heads
h_kv = 12  # key and value number of heads
s_q = 128  # maximum sequence length for Q
s_kv = 128  # maximum sequence length for K/V
d = 64  # embedding dimension per head

block_size_k = block_size_v = (
    8  # block size to be used by the non contiguous K/V containers
)
attn_scale = 1.0 / math.sqrt(d)

# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
dims_q = (b, h_q, s_q, d)
dims_kv = (b, h_kv, s_kv, d)
# BSHD physical tensor layout (this is required for packed Q-tensors)
strides_q = (s_q * h_q * d, d, h_q * d, 1)
strides_kv = (s_kv * h_kv * d, d, h_kv * d, 1)

# Randomly initialize the query, key, and value tensors.
q_gpu = torch.randn(b * s_q * h_q * d).half().cuda().as_strided(dims_q, strides_q)
k_gpu = torch.randn(b * s_kv * h_kv * d).half().cuda().as_strided(dims_kv, strides_kv)
v_gpu = torch.randn(b * s_kv * h_kv * d).half().cuda().as_strided(dims_kv, strides_kv)
o_gpu = torch.empty(b * s_q * h_q * d).half().cuda().as_strided(dims_q, strides_q)
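
The stride tuples above describe a tensor that is indexed as BHSD but stored as BSHD. As a minimal sketch of what that means, the pure-Python snippet below computes the flat element offset for a logical index; `element_offset` is a hypothetical helper introduced only for this illustration and is not part of cuDNN or PyTorch.

```python
def element_offset(index, strides):
    """Flat element offset of a logical index, given per-dimension strides."""
    return sum(i * s for i, s in zip(index, strides))


# Same sizes as the notebook: batch, heads, sequence length, head dimension
b, h, s, d = 4, 12, 128, 64

# Logical order is (b, h, s, d); the strides encode a BSHD memory layout
strides_bshd = (s * h * d, d, h * d, 1)

# Stepping to the next head moves d elements, because heads of one token are adjacent
print(element_offset((0, 1, 0, 0), strides_bshd))

# Stepping to the next token moves h * d elements, because a token holds all heads
print(element_offset((0, 0, 1, 0), strides_bshd))
```

Because all heads of a token sit next to each other in memory, sequences of different lengths can be packed back to back, which is why this physical layout is used for packed Q-tensors.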

#### Setup actual sequence lengths and ragged offsets
While we defined `s_q` as the maximum sequence length for Q and `s_kv` as the maximum sequence length for K/V, not all sequences have the same length. Therefore we specify actual sequence lengths in this section.

While optional for Q, actual sequence lengths are required when using paged K/V caches. To keep things simple, we initialize the actual sequence lengths for K/V to the maximum sequence length `s_kv` in this example, but we specify random sequence lengths for Q.

Lastly, when Q is in a packed format, we also need to create a ragged offset tensor. This tensor has `b + 1` entries and holds the element offset at which each sequence starts in the packed buffer.
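
As a rough illustration of the ragged offset arithmetic, the sketch below computes the offsets in plain Python for a set of made-up sequence lengths. The function name `ragged_offsets` and the example lengths are assumptions for this illustration only; the notebook itself builds the same values on the GPU with `torch.cumsum`.

```python
from itertools import accumulate


def ragged_offsets(seq_lens, num_heads, head_dim):
    """Element offsets marking where each sequence starts in a packed buffer."""
    # Cumulative token counts, with a leading 0: b + 1 entries in total
    token_starts = [0] + list(accumulate(seq_lens))
    # Each token occupies num_heads * head_dim elements
    return [t * num_heads * head_dim for t in token_starts]


# Hypothetical actual sequence lengths for a batch of 4
offsets = ragged_offsets([3, 5, 2, 7], num_heads=12, head_dim=64)
print(offsets)  # [0, 2304, 6144, 7680, 13056]
```

The last entry is the total number of valid elements, so entry `i + 1` minus entry `i` gives the size of sequence `i` in elements.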

In [None]:
# @brief Helper function to return variable sequence lengths, along with a ragged offset tensor indicating the start of each sequence
def create_variable_seq_lens(b, s):
    seq_len_q_gpu = torch.randint(1, s, (b, 1, 1, 1), device="cuda")

    # Create a [b+1, 1, 1, 1] ragged offset tensor
    q_ragged_offset_gpu = (
        torch.cat(
            (
                torch.zeros(1, 1, 1, 1, dtype=torch.int32, device="cuda"),
                torch.cumsum(seq_len_q_gpu, dim=0),
            )
        )
        * h_q
        * d
    )
    return seq_len_q_gpu, q_ragged_offset_gpu


# For Q, randomly generate sequence lengths between [1,s)
seq_len_q_gpu, q_ragged_offset_gpu = create_variable_seq_lens(b, s_q)

# For KV, set to s_kv for all batches, just to keep this notebook sample simple
seq_len_kv_gpu = torch.full((b, 1, 1, 1), s_kv, device="cuda")

print(seq_len_q_gpu)

#### Generate containers and block tables for K and V

In a real-world scenario, container and block table tensors are generated by other parts of the model. For illustration purposes, we provide a helper function that generates a trivial container from contiguous K and V caches.
The helper function takes a contiguous cache (e.g., the K-cache) and splits its sequence (`S`) dimension into blocks of length `block_size`. The resulting block table records which container block holds each logical block of each sequence.
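
To make the role of the block table concrete, here is a minimal pure-Python sketch of a lookup. It assumes the same trivial layout the helper in the next cell produces (logical block `j` of sequence `i` is stored at container index `j * batch + i`); `build_block_table` and `locate` are hypothetical names used only here, and a real serving system would typically assign blocks in arbitrary order.

```python
import math


def build_block_table(batch, seq_len, block_size):
    """Block table for the trivial layout assumed in this sketch."""
    blocks_per_batch = math.ceil(seq_len / block_size)
    # Row i lists the container block ids used by sequence i, in logical order
    return [[j * batch + i for j in range(blocks_per_batch)] for i in range(batch)]


def locate(block_table, batch_idx, token_idx, block_size):
    """Map a logical token position to (container block id, offset inside block)."""
    block_id = block_table[batch_idx][token_idx // block_size]
    return block_id, token_idx % block_size


table = build_block_table(batch=4, seq_len=128, block_size=8)
print(locate(table, batch_idx=1, token_idx=19, block_size=8))  # (9, 3)
```

Token 19 of sequence 1 falls in logical block 2 at offset 3, and the table redirects logical block 2 to container block 9. The attention kernel only ever sees logical positions; the table supplies the indirection.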

In [None]:
# @brief Helper function to create a non contiguous container in blocks of block_size from a contiguous tensor
def create_container_and_block_table(tensor, block_size):
    B, H, S, D = tensor.shape
    # num_blocks = math.ceil(S/block_size) * B
    blocks_per_batch = math.ceil(S / block_size)

    # Only needed if S is not a multiple of block_size
    padding_seq = (blocks_per_batch * block_size) - S
    if padding_seq > 0:
        zeros = torch.zeros(B, H, padding_seq, D, device="cuda", dtype=tensor.dtype)
        cat_tensor = torch.cat((tensor, zeros), axis=2)
    else:
        cat_tensor = tensor

    # Create a container by splitting on the S dimension and concatenating at the block dimension
    # Its dimensions are [num_blocks, H, block_size, D] with num_blocks = B * blocks_per_batch
    container = torch.cat((cat_tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0)

    # Create the block table
    table_size = math.ceil(S / block_size)
    block_table_temp = torch.linspace(
        0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32
    ).reshape(table_size, 1, B, 1)
    block_table_temp = torch.transpose(block_table_temp, 0, 2)

    # Make batch size outer dimension (cuDNN backend preference)
    block_table = (
        torch.zeros(blocks_per_batch * B)
        .int()
        .cuda()
        .as_strided(
            (B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1)
        )
    )
    block_table.copy_(block_table_temp)

    return (container, block_table)


# Create non contiguous containers with block tables for K and V from the contiguous k_gpu and v_gpu
container_k_gpu, block_table_k_gpu = create_container_and_block_table(
    k_gpu, block_size_k
)
container_v_gpu, block_table_v_gpu = create_container_and_block_table(
    v_gpu, block_size_v
)

#### Graph creation

Create the graph

In [6]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q = graph.tensor_like(q_gpu)

if packed_Q:
    q_ragged_offset = graph.tensor_like(q_ragged_offset_gpu)
    q.set_ragged_offset(
        q_ragged_offset
    )  # With Q in a packed layout, we need to indicate the ragged offset

container_k = graph.tensor_like(container_k_gpu)
container_v = graph.tensor_like(container_v_gpu)
block_table_k = graph.tensor_like(block_table_k_gpu)
block_table_v = graph.tensor_like(block_table_v_gpu)

seq_len_q = graph.tensor_like(seq_len_q_gpu)
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

o, _ = graph.sdpa(
    name="sdpa",
    q=q,
    k=container_k,  # Container K: non contiguous container with K blocks
    v=container_v,  # Container V: non contiguous container with V blocks
    is_inference=True,  # Prefill is inference only; no softmax stats output is needed
    attn_scale=attn_scale,
    use_causal_mask=True,
    use_padding_mask=True,
    seq_len_q=seq_len_q,
    seq_len_kv=seq_len_kv,
    paged_attention_k_table=block_table_k,  # Block Table K: Tensor containing offsets to the container with K blocks
    paged_attention_v_table=block_table_v,  # Block Table V: Tensor containing offsets to the container with V blocks
    paged_attention_max_seq_len_kv=s_kv,  # The maximum sequence length for K caches (this is optional, but recommended)
)

o.set_output(True).set_dim(dims_q).set_stride(strides_q)

Build the graph

In [7]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

#### Graph Execution
Execute the graph

In [8]:
variant_pack = {
    q: q_gpu,
    container_k: container_k_gpu,
    container_v: container_v_gpu,
    block_table_k: block_table_k_gpu,
    block_table_v: block_table_v_gpu,
    seq_len_q: seq_len_q_gpu,
    seq_len_kv: seq_len_kv_gpu,
    o: o_gpu,
}
if packed_Q:
    variant_pack[q_ragged_offset] = q_ragged_offset_gpu

print("First execution")
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

#### Run the PyTorch reference and compare against cuDNN's output

In [9]:
def compare_against_torch_ref(q_gpu, k_gpu, v_gpu, o_gpu, seq_len_q_gpu):
    q_gpu_packed = q_gpu.detach().float().requires_grad_()
    k_ref = k_gpu.detach().float().requires_grad_()
    v_ref = v_gpu.detach().float().requires_grad_()

    (b, h, s_q, d) = q_gpu_packed.shape
    s_kv = k_gpu.shape[2]

    # Create attention mask for variable lengths in Q
    mask = torch.ones(b, s_q, s_kv, dtype=torch.bool, device="cuda")
    for i, length in enumerate(seq_len_q_gpu):
        mask[i, length:, :] = False

    # Create attention mask for variable lengths in KV
    # (seq_len_kv_gpu is read from the notebook's global scope)
    for i, length in enumerate(seq_len_kv_gpu):
        mask[i, :, length:] = False
    # Causal masking
    for i in range(s_q):
        mask[:, i, i + 1 :] = False

    # Expand mask to match attention shape
    mask = mask.unsqueeze(1)

    o_ref = None
    if packed_Q:
        # Create unpacked tensor with proper shape
        # Convert bhsd to bshd logical layout and flatten
        uniform_tensor = torch.zeros(b, s_q, h, d).to(
            dtype=q_gpu_packed.dtype, device=q_gpu_packed.device
        )
        q_gpu_packed_thd = torch.einsum("bhsd->bshd", q_gpu_packed).reshape(
            b * s_q, h, d
        )

        # Copy the data from the packed tensor to the unpacked tensor
        start_idx = 0
        for i in range(b):
            s = seq_len_q_gpu[i]
            uniform_tensor[i, 0:s, :, :] = q_gpu_packed_thd[
                start_idx : start_idx + s, :, :
            ]
            start_idx += s

        # Convert back to bhsd logical layout
        q_unpacked_ref = torch.einsum("bshd->bhsd", uniform_tensor)

        o_ref = torch.nn.functional.scaled_dot_product_attention(
            q_unpacked_ref,
            k_ref,
            v_ref,
            is_causal=False,
            scale=attn_scale,
            attn_mask=mask,
        )
    else:
        o_ref = torch.nn.functional.scaled_dot_product_attention(
            q_gpu_packed,
            k_ref,
            v_ref,
            is_causal=False,
            scale=attn_scale,
            attn_mask=mask,
        )

    torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)


compare_against_torch_ref(q_gpu, k_gpu, v_gpu, o_gpu, seq_len_q_gpu)
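
The reference mask built in `compare_against_torch_ref` combines three conditions: the query row is within the actual Q length, the key column is within the actual K/V length, and the key position does not lie after the query position. The sketch below reproduces that logic in plain Python for one sequence with small, made-up sizes; `build_mask` is a hypothetical helper for illustration only.

```python
def build_mask(s_q, s_kv, len_q, len_kv):
    """Boolean mask for one sequence: True means the score is kept."""
    return [
        [(i < len_q) and (j < len_kv) and (j <= i) for j in range(s_kv)]
        for i in range(s_q)
    ]


# Hypothetical sizes: maximum length 4, actual Q length 3, actual K/V length 4
mask = build_mask(s_q=4, s_kv=4, len_q=3, len_kv=4)
for row in mask:
    print("".join("x" if keep else "." for keep in row))
```

In the printed pattern, the first three rows form the usual causal lower triangle, while the last row is entirely masked because it lies beyond the actual Q length.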

#### Executing the same graph with different sequence lengths

Note that the graph construction we went through earlier is a one-time cost. We can reuse the same graph for different actual sequence lengths. We illustrate this below by creating new variable sequence lengths for Q.

In [None]:
seq_len_q_gpu, q_ragged_offset_gpu = create_variable_seq_lens(b, s_q)
print(seq_len_q_gpu)

In [None]:
variant_pack = {
    q: q_gpu,
    container_k: container_k_gpu,
    container_v: container_v_gpu,
    block_table_k: block_table_k_gpu,
    block_table_v: block_table_v_gpu,
    seq_len_q: seq_len_q_gpu,
    seq_len_kv: seq_len_kv_gpu,
    o: o_gpu,
}

if packed_Q:
    variant_pack[q_ragged_offset] = q_ragged_offset_gpu

print("Second execution")
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()
compare_against_torch_ref(q_gpu, k_gpu, v_gpu, o_gpu, seq_len_q_gpu)

In [None]:
cudnn.destroy_handle(handle)