# Paged Attention in cuDNN Frontend

This notebook illustrates how the cuDNN frontend's scaled dot product attention (SDPA) operator can be used to support paged attention. For a simpler introduction to the scaled dot product attention operator, please refer to [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).

The full documentation of cuDNN's scaled dot product attention operator can be found in [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention). The Python test code for the full set of features can be found in [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py).

More details on paged attention can be found in the [PagedAttention paper](https://arxiv.org/abs/2309.06180).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention.ipynb)

#### Prerequisites and Setup
This notebook requires an NVIDIA A100 GPU or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU.

In [1]:
# get_ipython().system('nvidia-smi')

In [2]:
# get_ipython().system('pip install nvidia-cudnn-cu12')
# get_ipython().system('pip install nvidia-cudnn-frontend')
# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')

In [3]:
import cudnn
import torch
import math

torch.manual_seed(42)
handle = cudnn.create_handle()

assert torch.cuda.is_available()
assert (
    torch.cuda.get_device_capability()[0] >= 8
), "SDPA operation is only supported on SM80 architecture (Ampere) or above"

assert (
    cudnn.backend_version() >= 90500
), "SDPA operation is only supported on cuDNN version 9.5.0 or above"

#### Problem sizes and tensor setup

For this example, we will use the same problem size as in [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).
In addition, we set the block size for both K and V to 64.

Create the query, key, value, and output GPU tensors using PyTorch. However, the user may use any DLPack compatible tensor instead.

In [5]:
b = 4  # batch size
h = 12  # query number of heads
s = 1024  # maximum sequence length
d = 64  # embedding dimension per head

block_size_k = block_size_v = (
    64  # block size to be used by the non contiguous K/V containers
)

attn_scale = 1.0 / math.sqrt(d)

# The tensors will have non-interleaved
# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout
# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout
dims = (b, h, s, d)
strides = (s * h * d, d, h * d, 1)

q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)
o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)
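
To see why these strides describe a BSHD buffer viewed through BHSD dimensions, the following minimal sketch computes the linear offset of every logical index and compares it with the offset in a contiguous (b, s, h, d) buffer. It is plain Python and runs without a GPU. The helpers `offset` and `bshd_offset` and the tiny sizes are illustrative assumptions, not part of the notebook.

```python
# Illustrative sizes only; the notebook itself uses b=4, h=12, s=1024, d=64.
b, h, s, d = 2, 3, 4, 5

# Strides for the logical (b, h, s, d) view, written the same way as in the notebook.
strides = (s * h * d, d, h * d, 1)


def offset(index, strides):
    """Hypothetical helper: linear offset of a multi-dimensional index."""
    return sum(i * st for i, st in zip(index, strides))


def bshd_offset(bi, si, hi, di):
    """Hypothetical helper: row-major offset in a contiguous (b, s, h, d) buffer."""
    return ((bi * s + si) * h + hi) * d + di


# Every logical BHSD index lands on the same element as its BSHD position.
matches = all(
    offset((bi, hi, si, di), strides) == bshd_offset(bi, si, hi, di)
    for bi in range(b)
    for hi in range(h)
    for si in range(s)
    for di in range(d)
)
print(matches)  # True
```

Stepping along the head dimension moves by `d` elements, while stepping along the sequence dimension moves by `h * d` elements, which is what makes the physical layout BSHD.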

Create variable sequence length tensors. These are required when using paged K/V caches. To keep things simple, we set these to the maximum sequence length `s` in this example.

In [None]:
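# Set to s for all batches, just for the notebook sample.
# The dtype is set to int32 explicitly, matching the page tables created below.
seq_len_q_gpu = torch.full((b, 1, 1, 1), s, device="cuda", dtype=torch.int32)
seq_len_kv_gpu = torch.full((b, 1, 1, 1), s, device="cuda", dtype=torch.int32)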
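
When the sequences in a batch have different lengths, each one occupies only as many blocks as it needs. The short sketch below shows the block count per sequence for a block size of 64. The lengths in `seq_lens_kv` are hypothetical values chosen for illustration, not values from the notebook.

```python
import math

block_size = 64

# Hypothetical actual K/V lengths for a batch of four sequences.
seq_lens_kv = [1024, 700, 64, 1]

# A sequence of length n occupies ceil(n / block_size) blocks;
# the last block may be only partially filled.
blocks_needed = [math.ceil(n / block_size) for n in seq_lens_kv]
print(blocks_needed)  # [16, 11, 1, 1]
```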

#### Generate containers and page tables for K and V

In a real-world scenario, the container and page table tensors are generated by other parts of the model. For illustration purposes, this example provides a helper function that generates a trivial container from contiguous K and V caches.
The helper takes a contiguous cache (for example, the K cache) and splits its sequence (`S`) dimension into blocks of length `block_size`. The resulting page table records, for each sequence in the batch, which blocks of the container hold its data and in what order.

In [None]:
# Helper function to create a non contiguous container in blocks of block_size from a contiguous tensor
def create_container_and_page_table(tensor, block_size):
    B, H, S, D = tensor.shape
    blocks_per_batch = math.ceil(S / block_size)

    # This assertion keeps the helper function of this example simple, but is not a requirement for paged attention.
    assert (blocks_per_batch * block_size) == S

    # Create a container by splitting on the S dimension and concatenating at the block dimension
    # Its dimensions are [num_blocks, H, block_size, D] with num_blocks = B * blocks_per_batch
    container = torch.cat((tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0)

    # Create the page table: entry [b, 0, i, 0] is the container index of the i-th block of batch b
    page_table = torch.arange(
        B * blocks_per_batch,
        device=tensor.device,
        dtype=torch.int32,
    ).reshape(blocks_per_batch, 1, B, 1)
    page_table = torch.transpose(page_table, 0, 2)

    return (container, page_table)


# Create non contiguous containers with page tables for K and V from the contiguous k_gpu and v_gpu
container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, block_size_k)
container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, block_size_v)
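
The addressing scheme behind the helper above can be modeled in plain Python, without a GPU. The sketch below builds a toy container in the same block order (chunk index major, batch index minor) and resolves each token through the page table. The `lookup` helper, the string labels, and the tiny sizes are hypothetical and serve only to illustrate the indexing. This is not how cuDNN implements the lookup internally.

```python
block_size = 2
B, S = 2, 4  # two sequences of four tokens each
blocks_per_batch = S // block_size

# Contiguous toy cache: tokens[b][t] is a label for token t of sequence b.
tokens = [[f"b{b}t{t}" for t in range(S)] for b in range(B)]

# Build the container in the same order as the helper above:
# all batches for chunk 0, then all batches for chunk 1, and so on.
container = [
    tokens[b][c * block_size:(c + 1) * block_size]
    for c in range(blocks_per_batch)
    for b in range(B)
]

# page_table[b][c] is the container index of the c-th block of sequence b.
page_table = [[c * B + b for c in range(blocks_per_batch)] for b in range(B)]


def lookup(b, t):
    """Hypothetical helper: resolve token t of sequence b through the page table."""
    block = page_table[b][t // block_size]
    return container[block][t % block_size]


roundtrip = [[lookup(b, t) for t in range(S)] for b in range(B)]
print(page_table)           # [[0, 2], [1, 3]]
print(roundtrip == tokens)  # True
```

Because every access goes through the page table, the blocks of one sequence do not need to be adjacent in the container.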

#### Graph creation and execution

Create the graph

In [6]:
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q = graph.tensor_like(q_gpu)

container_k = graph.tensor_like(container_k_gpu)
container_v = graph.tensor_like(container_v_gpu)
page_table_k = graph.tensor_like(page_table_k_gpu)
page_table_v = graph.tensor_like(page_table_v_gpu)

seq_len_q = graph.tensor_like(seq_len_q_gpu)
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

o, _ = graph.sdpa(
    name="sdpa",
    q=q,
    k=container_k,  # Container K: non contiguous container with K blocks
    v=container_v,  # Container V: non contiguous container with V blocks
    is_inference=True,
    attn_scale=attn_scale,
    use_causal_mask=True,
    use_padding_mask=True,
    seq_len_q=seq_len_q,
    seq_len_kv=seq_len_kv,
    paged_attention_k_table=page_table_k,  # Page Table K: Tensor containing offsets to the container with K blocks
    paged_attention_v_table=page_table_v,  # Page Table V: Tensor containing offsets to the container with V blocks
    paged_attention_max_seq_len_kv=s,  # The maximum sequence length for K caches (this is optional, but recommended)
)

o.set_output(True).set_dim(dims).set_stride(strides)
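
The combined effect of `use_causal_mask` and `use_padding_mask` can be pictured as a boolean grid over query and key positions. The sketch below is a plain-Python illustration that assumes the common top-left-aligned causal convention (query `i` attends to keys `j <= i`) and treats positions at or beyond the given sequence lengths as padding. The `allowed` helper and the sizes are hypothetical. Refer to the cuDNN attention documentation for the authoritative masking semantics.

```python
def allowed(i, j, seq_len_q, seq_len_kv):
    """Hypothetical helper: may query position i attend to key position j?"""
    in_bounds = i < seq_len_q and j < seq_len_kv  # padding mask (assumed semantics)
    causal = j <= i                               # causal mask (assumed top-left aligned)
    return in_bounds and causal


max_s = 4                   # padded (maximum) sequence length
seq_len_q = seq_len_kv = 3  # actual length of this sequence

mask = [
    [int(allowed(i, j, seq_len_q, seq_len_kv)) for j in range(max_s)]
    for i in range(max_s)
]
for row in mask:
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [0, 0, 0, 0]
```

In this notebook both sequence lengths equal `s`, so the padding mask excludes nothing and only the causal part has an effect.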

Build the graph

In [7]:
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

Execute the graph

In [8]:
variant_pack = {
    q: q_gpu,
    container_k: container_k_gpu,
    container_v: container_v_gpu,
    page_table_k: page_table_k_gpu,
    page_table_v: page_table_v_gpu,
    seq_len_q: seq_len_q_gpu,
    seq_len_kv: seq_len_kv_gpu,
    o: o_gpu,
}

workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute(variant_pack, workspace)
torch.cuda.synchronize()

cudnn.destroy_handle(handle)

#### Run the PyTorch reference and compare against cuDNN's output

In [9]:
q_ref = q_gpu.detach().float().requires_grad_()
k_ref = k_gpu.detach().float().requires_grad_()
v_ref = v_gpu.detach().float().requires_grad_()

o_ref = torch.nn.functional.scaled_dot_product_attention(
    q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale
)

torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)
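
As a reminder of what the tolerances mean, `torch.testing.assert_close` accepts an element when `|actual - expected| <= atol + rtol * |expected|`. The sketch below restates that check in plain Python using the tolerances from the comparison above. The helper `is_close` and the sample values are made up for illustration.

```python
def is_close(actual, expected, atol=5e-3, rtol=3e-3):
    """Hypothetical helper mirroring the element-wise closeness criterion."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


print(is_close(1.006, 1.0))  # True: 0.006 is within 0.005 + 0.003
print(is_close(1.010, 1.0))  # False: 0.010 exceeds 0.008
print(is_close(0.004, 0.0))  # True: only atol applies when expected is 0
```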