# Paged Attention in cuDNN Frontend

This notebook illustrates how cuDNN's frontend scaled dot product attention operator can be used with paged K/V caches, specifically for decode. For a simpler introduction to the scaled dot product attention operator, please refer to [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).

The full documentation of cuDNN's scaled dot product attention operator can be found in [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention). The Python test code for the full set of features can be found in [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py).

More details on paged attention can be found in the [PagedAttention paper](https://arxiv.org/abs/2309.06180).


This notebook specifically illustrates the following:
- SDPA Decode (s_q=1)
- Paged Attention
- Variable sequence lengths for KV
- Packed Block Tables

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention.ipynb)

#### Prerequisites and Setup
This notebook requires an NVIDIA A100 GPU or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [1]:
# get_ipython().system('nvidia-smi')

In [2]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')

In [3]:
import cudnn
import torch
import math

torch.manual_seed(42)
handle = cudnn.create_handle()

assert torch.cuda.is_available()
assert (
    torch.cuda.get_device_capability()[0] >= 8
), "SDPA operation is only supported on SM80 architecture (Ampere) or above"

assert (
    cudnn.backend_version() >= 90500
), "SDPA operation is only supported on cuDNN version 9.5.0 or above"

#### Problem sizes and Q/K/V setup

Create the query, key, value, and output GPU tensors using PyTorch. Any DLPack-compatible tensor can be used instead.

In [5]:
b = 2  # batch size
h = 12  # query number of heads
s_q = 1  # For decode, we only have one query token
s_kv = 1024  # maximum sequence length
d = 64  # embedding dimension per head

block_size_k = block_size_v = 64  # block size used by the non-contiguous K/V containers

attn_scale = 1.0 / math.sqrt(d)

# Logical dims are (batch, num_heads, seq_len, head_dim); the strides below give a BHSD physical layout
# (for s_q = 1, the Q/O layout is equivalent to BSHD)
dims_qo = (b, h, s_q, d)
strides_qo = (s_q * h * d, s_q * d, d, 1)

dims_kv = (b, h, s_kv, d)
strides_kv = (s_kv * h * d, s_kv * d, d, 1)

q_gpu = torch.randn(b * s_q * h * d).half().cuda().as_strided(dims_qo, strides_qo)
k_gpu = torch.randn(b * s_kv * h * d).half().cuda().as_strided(dims_kv, strides_kv)
v_gpu = torch.randn(b * s_kv * h * d).half().cuda().as_strided(dims_kv, strides_kv)
o_gpu = torch.empty(b * s_q * h * d).half().cuda().as_strided(dims_qo, strides_qo)
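The tensors above use logical dims `(b, h, s, d)`, and only the strides determine how they are laid out in memory. The sketch below is a quick layout check in plain PyTorch on the CPU, with sizes chosen to match this notebook and no cuDNN involved. It compares the strides of a contiguous BHSD tensor with those of a BSHD tensor viewed through the same logical dims.

```python
import torch

b, h, s, d = 2, 12, 1024, 64

# Logical dims are (b, h, s, d) in both cases; only the strides differ.
bhsd_strides = (s * h * d, s * d, d, 1)  # heads outer, sequence inner
bshd_strides = (s * h * d, d, h * d, 1)  # sequence outer, heads inner

# A contiguous [b, h, s, d] tensor has BHSD strides.
bhsd = torch.empty(b, h, s, d)
assert bhsd.stride() == bhsd_strides

# A contiguous [b, s, h, d] tensor viewed as (b, h, s, d) has BSHD strides.
bshd = torch.empty(b, s, h, d).transpose(1, 2)
assert bshd.shape == (b, h, s, d)
assert bshd.stride() == bshd_strides
```

`strides_kv = (s_kv * h * d, s_kv * d, d, 1)` matches the BHSD pattern, so the K/V caches here are BHSD in memory. When the sequence dimension has size 1, as it does for Q and O in decode, the two layouts address the same memory, because the stride of a size-1 dimension is never used.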

#### Generate containers and block tables for K and V

In a real-world scenario, the container and block table tensors are produced by other parts of the model. For illustration, this example provides a helper function that builds a trivial container from contiguous K and V caches.
The helper takes a cache (e.g., the K cache) and splits its sequence (`S`) dimension into blocks of length `block_size`. The resulting block table maps each batch entry and logical block index (i.e., each `block_size`-long slice of the sequence) to the index of the physical block that holds it in the container.
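In `create_container_and_block_table` below, the padded tensor is split into `blocks_per_batch` chunks along `S`, and those chunks are concatenated along the block dimension. Chunk `j` of batch entry `i` therefore lands at container index `j * B + i`. The following minimal pure-Python sketch shows how a block table built for that layout turns a logical token position into a physical location. `interleaved_block_table` and `locate` are hypothetical helpers for illustration only.

```python
import math


def interleaved_block_table(B, S, block_size):
    """Block table matching the helper's container: chunk j of batch i sits at index j * B + i."""
    blocks_per_batch = math.ceil(S / block_size)
    return [[j * B + i for j in range(blocks_per_batch)] for i in range(B)]


def locate(block_table, batch, pos, block_size):
    """Map a logical (batch, token position) to (physical block index, slot within that block)."""
    return block_table[batch][pos // block_size], pos % block_size


table = interleaved_block_table(B=2, S=1024, block_size=64)
print(len(table[0]))  # 16 blocks per batch entry
print(table[0][:4], table[1][:4])  # [0, 2, 4, 6] [1, 3, 5, 7]
print(locate(table, batch=1, pos=130, block_size=64))  # (5, 2)
```

Position 130 falls in logical block 2 (`130 // 64`) at slot 2 (`130 % 64`). For batch entry 1, logical block 2 is stored at container index `2 * 2 + 1 = 5`.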

In [None]:
# @brief Helper function to create a non-contiguous container in blocks of block_size from a contiguous tensor
def create_container_and_block_table(tensor, block_size):
    B, H, S, D = tensor.shape
    # The container will hold num_blocks = B * blocks_per_batch blocks
    blocks_per_batch = math.ceil(S / block_size)

    # Only needed if S is not a multiple of block_size
    padding_seq = (blocks_per_batch * block_size) - S
    if padding_seq > 0:
        zeros = torch.zeros(B, H, padding_seq, D, device="cuda", dtype=tensor.dtype)
        cat_tensor = torch.cat((tensor, zeros), axis=2)
    else:
        cat_tensor = tensor

    # Create a container by splitting on the S dimension and concatenating at the block dimension
    # Its dimensions are [num_blocks, H, block_size, D] with num_blocks = B * blocks_per_batch
    container = torch.cat((cat_tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0)

    # Create the block table: chunk j of batch i was placed at container index j * B + i
    table_size = math.ceil(S / block_size)
    block_table_temp = torch.linspace(
        0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32
    ).reshape(table_size, 1, B, 1)
    block_table_temp = torch.transpose(block_table_temp, 0, 2)

    # Make batch size outer dimension (cuDNN backend preference)
    block_table = (
        torch.zeros(blocks_per_batch * B)
        .int()
        .cuda()
        .as_strided(
            (B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1)
        )
    )
    block_table.copy_(block_table_temp)

    return (container, block_table)


# Create non-contiguous containers with block tables for K and V from the contiguous k_gpu and v_gpu
container_k_gpu, block_table_k = create_container_and_block_table(k_gpu, block_size_k)
container_v_gpu, block_table_v = create_container_and_block_table(v_gpu, block_size_v)

#### Variable KV sequence lengths and packed block tables
Note that the block tables created above hold a block offset for every block up to the maximum sequence length `s_kv`. With variable sequence lengths, however, blocks beyond each batch entry's actual sequence length are never read. The block tables can therefore be "packed" so that they store only the block offsets that are needed, similar to how ragged tensors work. Because block tables are small, packing barely reduces memory traffic, and performance is expected to degrade slightly, since packed block tables rule out vectorized loads. However, this feature can still be useful for compatibility with many frameworks.

Let's start with creating the actual sequence length tensors.

In [None]:
# In decode, s_q is set to 1 for all batches
seq_len_q_gpu = torch.ones((b, 1, 1, 1), device="cuda", dtype=torch.int32)

# Random actual KV sequence lengths in [1, s_kv - 1], one per batch entry
seq_len_kv_gpu = torch.randint(1, s_kv, (b, 1, 1, 1), device="cuda", dtype=torch.int32)

Now let's pack the previously created block tables. We use the following helper function to do so:

In [None]:
# @brief Helper function to convert a padded block table into a packed block table
# @return packed_block_table: packed block table
# @return ragged_offset: offset into the packed block table
def convert_uniform_to_ragged_block_tables(uniform_tensor, seq_len, block_size):
    [B, H, S, D] = uniform_tensor.size()
    ragged_offset = torch.zeros(
        B + 1, 1, 1, 1, dtype=torch.int32, device=uniform_tensor.device
    )  # Initialize with first offset as 0
    for i in range(1, B + 1):
        prev_seq_len = seq_len[i - 1]
        num_pages_prev_batch = (prev_seq_len + block_size - 1) // block_size
        next_batch_offset = ragged_offset[i - 1] + num_pages_prev_batch
        ragged_offset[i, 0, 0, 0] = next_batch_offset

    packed_block_table = torch.zeros(B * S, H, D).to(
        dtype=uniform_tensor.dtype, device=uniform_tensor.device
    )

    uniform_tensor_thd = torch.einsum("bhsd->bshd", uniform_tensor).reshape(B * S, H, D)

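    # Copy the first (t_1 - t_0) entries of each batch's padded table to position t_0 of the packed table
    t_0 = 0
    for batch_idx, t_1 in enumerate(ragged_offset.flatten()[1:]):
        packed_block_table[t_0:t_1, :, :] = uniform_tensor_thd[
            batch_idx * S : batch_idx * S + (t_1 - t_0), :, :
        ]
        t_0 = t_1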

    packed_block_table = packed_block_table.reshape(B, S, H, D)
    packed_block_table = torch.einsum("bshd->bhsd", packed_block_table)

    return packed_block_table, ragged_offset


block_table_k_packed_gpu, block_table_k_ragged_offset_gpu = (
    convert_uniform_to_ragged_block_tables(block_table_k, seq_len_kv_gpu, block_size_k)
)
block_table_v_packed_gpu, block_table_v_ragged_offset_gpu = (
    convert_uniform_to_ragged_block_tables(block_table_v, seq_len_kv_gpu, block_size_v)
)
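`convert_uniform_to_ragged_block_tables` builds the offsets and copies the entries with Python loops over the batch. A vectorized sketch of the same packing is shown below. It assumes the `[B, 1, blocks_per_batch, 1]` block table shape used in this notebook, computes the offsets with `torch.cumsum`, and selects the needed entries with a boolean mask. `pack_block_table` is a hypothetical name, and the demo uses small CPU tensors that reproduce the worked example further down.

```python
import torch


def pack_block_table(block_table, seq_len, block_size):
    """Pack a padded [B, 1, T, 1] block table; returns (packed [B, 1, T, 1], ragged_offset [B + 1, 1, 1, 1])."""
    B, _, T, _ = block_table.shape
    table = block_table.reshape(B, T)
    lens = seq_len.reshape(B).to(torch.int64)

    # Blocks actually needed per batch entry: ceil(seq_len / block_size)
    num_blocks = (lens + block_size - 1) // block_size

    # Exclusive prefix sum: batch i's entries start at ragged_offset[i]
    ragged_offset = torch.zeros(B + 1, dtype=torch.int32, device=table.device)
    ragged_offset[1:] = torch.cumsum(num_blocks, dim=0).to(torch.int32)

    # keep[i, j] is True for the first num_blocks[i] entries of row i
    keep = torch.arange(T, device=table.device).unsqueeze(0) < num_blocks.unsqueeze(1)

    # Boolean indexing returns kept entries in row-major order, i.e. batch by batch
    packed = torch.zeros(B * T, dtype=table.dtype, device=table.device)
    packed[: int(ragged_offset[-1])] = table[keep]
    return packed.reshape(B, 1, T, 1), ragged_offset.reshape(B + 1, 1, 1, 1)


# Two batch entries, 16 blocks each, stored contiguously per batch entry
B, T, block_size = 2, 16, 64
padded = torch.arange(B * T, dtype=torch.int32).reshape(B, 1, T, 1)
seq_len_kv = torch.tensor([250, 300], dtype=torch.int32).reshape(B, 1, 1, 1)

packed, offsets = pack_block_table(padded, seq_len_kv, block_size)
print(packed.flatten()[:9].tolist())  # [0, 1, 2, 3, 16, 17, 18, 19, 20]
print(offsets.flatten().tolist())  # [0, 4, 9]
```

As in the helper above, the entries after the last ragged offset are left as zeros. Applied to `block_table_k` and `seq_len_kv_gpu`, the sketch is intended to produce the same values as `convert_uniform_to_ragged_block_tables`, but it has not been validated against cuDNN.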

`block_table_{k,v}_packed_gpu` are now packed block tables that contain only the block offsets needed for the actual sequence lengths.
`block_table_{k,v}_ragged_offset_gpu` are the ragged offsets into the packed block tables. Entry `i` is the position in the packed table where batch entry `i`'s block offsets start, and the last entry is the total number of packed entries.

To illustrate this, consider a scenario where `seq_len_kv_gpu = {250,300}`, and assume further that the container has blocks contiguously allocated per batch (so the block table offsets are just linear increments).

A padded block table would have shape `{B, 1, max_s_kv/block_size, 1} = {2, 1, 16, 1}`:
```
b = 0 : 
    block_table_k[0,0] = 0
    block_table_k[0,1] = 1
    block_table_k[0,2] = 2
    block_table_k[0,3] = 3
    block_table_k[0,4] = x
    block_table_k[0,5] = x
    ...
    block_table_k[0,15] = x

b = 1 : 
    block_table_k[1,0] = 16
    block_table_k[1,1] = 17
    block_table_k[1,2] = 18
    block_table_k[1,3] = 19
    block_table_k[1,4] = 20
    block_table_k[1,5] = x
    block_table_k[1,6] = x
    ...
    block_table_k[1,15] = x

```
Since only 4 and 5 entries contain meaningful block offsets for batches 0 and 1 respectively (`ceil(seq_len_kv_gpu / block_size) = [4, 5]`), a packed block table would be:
```
block_table_k_packed_gpu = [0,1,2,3,16,17,18,19,20]
```
with ragged offsets:
```
block_table_k_ragged_offset_gpu = [0, 4, 9]
```
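With a packed table, the block that holds token position `pos` of batch entry `i` is found in two steps. The ragged offset first gives the start of that batch entry's segment, and the logical block number `pos // block_size` then indexes into that segment. The following minimal pure-Python sketch performs this lookup on the example above. `lookup_block` is a hypothetical helper for illustration, not a cuDNN API.

```python
def lookup_block(packed, ragged_offset, batch, pos, block_size):
    """Physical block index holding token `pos` of batch entry `batch` in a packed block table."""
    start, end = ragged_offset[batch], ragged_offset[batch + 1]
    idx = start + pos // block_size
    if idx >= end:
        raise IndexError("position lies beyond the blocks stored for this batch entry")
    return packed[idx]


packed = [0, 1, 2, 3, 16, 17, 18, 19, 20]
ragged_offset = [0, 4, 9]

print(lookup_block(packed, ragged_offset, batch=0, pos=249, block_size=64))  # 3
print(lookup_block(packed, ragged_offset, batch=1, pos=0, block_size=64))  # 16
print(lookup_block(packed, ragged_offset, batch=1, pos=299, block_size=64))  # 20
```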

#### Graph creation and execution

Create the graph

In [6]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q = graph.tensor_like(q_gpu)

container_k = graph.tensor_like(container_k_gpu)
container_v = graph.tensor_like(container_v_gpu)
block_table_k_packed = graph.tensor_like(block_table_k_packed_gpu)
block_table_v_packed = graph.tensor_like(block_table_v_packed_gpu)

seq_len_q = graph.tensor_like(seq_len_q_gpu)
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

# Add ragged offset tensors to the block tables
block_table_k_ragged_offset = graph.tensor_like(block_table_k_ragged_offset_gpu)
block_table_k_packed.set_ragged_offset(block_table_k_ragged_offset)
block_table_v_ragged_offset = graph.tensor_like(block_table_v_ragged_offset_gpu)
block_table_v_packed.set_ragged_offset(block_table_v_ragged_offset)

o, _ = graph.sdpa(
    name="sdpa",
    q=q,
    k=container_k,  # Container K: non-contiguous container with K blocks
    v=container_v,  # Container V: non-contiguous container with V blocks
    is_inference=True,
    attn_scale=attn_scale,
    use_causal_mask=False,
    use_padding_mask=True,
    seq_len_q=seq_len_q,
    seq_len_kv=seq_len_kv,
    paged_attention_k_table=block_table_k_packed,  # Block Table K: Tensor containing offsets to the container with K blocks
    paged_attention_v_table=block_table_v_packed,  # Block Table V: Tensor containing offsets to the container with V blocks
    paged_attention_max_seq_len_kv=s_kv,  # The maximum sequence length for K caches (this is optional, but recommended)
)

o.set_output(True).set_dim(dims_qo).set_stride(strides_qo)

Build the graph

In [7]:
graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

Execute the graph

In [8]:
variant_pack = {
    q: q_gpu,
    container_k: container_k_gpu,
    container_v: container_v_gpu,
    block_table_k_packed: block_table_k_packed_gpu,
    block_table_v_packed: block_table_v_packed_gpu,
    block_table_k_ragged_offset: block_table_k_ragged_offset_gpu,  # Ragged offset for K's block table
    block_table_v_ragged_offset: block_table_v_ragged_offset_gpu,  # Ragged offset for V's block table
    seq_len_q: seq_len_q_gpu,
    seq_len_kv: seq_len_kv_gpu,
    o: o_gpu,
}

workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

cudnn.destroy_handle(handle)

#### Run the PyTorch reference and compare against cuDNN's output

In [9]:
q_ref = q_gpu.detach().float().requires_grad_()
k_ref = k_gpu.detach().float().requires_grad_()
v_ref = v_gpu.detach().float().requires_grad_()

# Create attention mask for variable lengths in KV
mask = torch.ones(b, s_kv, dtype=torch.bool, device="cuda")

for i in range(b):
    seqlen = seq_len_kv_gpu[i, 0, 0, 0].item()
    mask[i, seqlen:] = False

# Expand mask (B,s_kv) -> (B,1,1,s_kv) to match attention shape
mask = mask.unsqueeze(1)
mask = mask.unsqueeze(1)

o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_ref, k_ref, v_ref, is_causal=False, scale=attn_scale, attn_mask=mask
)

torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)