In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load GPT-2 and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [5]:
# Sample input
prompt = "The future of your life is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Run baseline inference
with torch.no_grad():
    output = model.generate(input_ids, max_length=30)
    print(tokenizer.decode(output[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


The future of your life is in your hands.

I'm not saying that you should be ashamed of your past, but I'm saying that


In [6]:
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)
    past_key_values = outputs.past_key_values

# Inspect shape of past key/value for layer 0
print(f"# of layers: {len(past_key_values)}")
print(f"Key shape (layer 0): {past_key_values[0][0].shape}")  # [B, n_heads, seq_len, head_dim]
print(f"Value shape (layer 0): {past_key_values[0][1].shape}")

# of layers: 12
Key shape (layer 0): torch.Size([1, 12, 6, 64])
Value shape (layer 0): torch.Size([1, 12, 6, 64])


In [7]:
import torch
config = model.config
print(f"Layers: {config.n_layer}, Heads: {config.n_head}, Hidden: {config.n_embd}, Head dim: {config.n_embd // config.n_head}")

Layers: 12, Heads: 12, Hidden: 768, Head dim: 64


In [2]:
class LogicalBlock:
    def __init__(self, physical_block_id=None):
        self.physical_block_id = physical_block_id  # None if on CPU
        self.status = "gpu"  # or "cpu"
        self.token_count = 0

In [3]:
class EvictionPolicy:
    def register(self, block): pass
    def unregister(self, block): pass
    def notify_use(self, block): pass
    def evict(self): raise NotImplementedError

In [4]:
from collections import deque

class LRUEvictionPolicy(EvictionPolicy):
    def __init__(self):
        self.lru_queue = deque()

    def register(self, block):
        self.lru_queue.appendleft(block)

    def unregister(self, block):
        try:
            self.lru_queue.remove(block)
        except ValueError:
            pass

    def notify_use(self, block):
        try:
            self.lru_queue.remove(block)
            self.lru_queue.appendleft(block)
        except ValueError:
            pass  # block not tracked (e.g., on CPU)

    def evict(self):
        while self.lru_queue:
            victim = self.lru_queue.pop()
            if victim.status == "gpu":
                return victim
        raise RuntimeError("No GPU-resident blocks to evict.")

In [5]:
class PageTable:
    def __init__(self, num_blocks, block_size, num_heads, head_dim, device="cuda", eviction_policy=None):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device

        # Flat GPU memory pool
        self.gpu_k = torch.zeros((num_blocks, block_size, num_heads, head_dim), device=device)
        self.gpu_v = torch.zeros_like(self.gpu_k)

        # CPU-stored evicted blocks: keyed by LogicalBlock object
        self.cpu_k = {} # logical_block -> CPU tensor
        self.cpu_v = {}

        self.free_list = list(range(num_blocks))
        self.eviction_policy = eviction_policy or LRUEvictionPolicy()

    def allocate_block(self, logical_block):
        """Assign a physical block to a logical block."""
        if not self.free_list:
            self._evict_one_block()

        pid = self.free_list.pop()
        logical_block.physical_block_id = pid
        logical_block.status = "gpu"
        self.eviction_policy.register(logical_block)
        return pid

    def _evict_one_block(self):
        """Evict a block found on the GPU based on eviction policy."""
        victim = self.eviction_policy.evict()
        self.swap_to_cpu(victim)      

    def swap_to_cpu(self, logical_block):
        """Move a block from GPU to CPU and free its physical block ID."""
        pid = logical_block.physical_block_id
        if pid is None or logical_block.status != "gpu":
            raise RuntimeError("Block is not on GPU or already swapped.")

        self.cpu_k[logical_block] = self.gpu_k[pid].clone().cpu()
        # print(self.cpu_k[logical_block])
        self.cpu_v[logical_block] = self.gpu_v[pid].clone().cpu()

        logical_block.physical_block_id = None
        logical_block.status = "cpu"
        self.eviction_policy.unregister(logical_block)
        self.free_list.append(pid)

    def swap_to_gpu(self, logical_block):
        """Move a block from CPU to GPU, assigning a new physical block ID."""
        if logical_block not in self.cpu_k:
            raise RuntimeError("Block not found in CPU cache.")

        if not self.free_list:
            self._evict_one_block()

        pid = self.free_list.pop()
        self.gpu_k[pid] = self.cpu_k.pop(logical_block).to(self.device)
        self.gpu_v[pid] = self.cpu_v.pop(logical_block).to(self.device)

        logical_block.physical_block_id = pid
        logical_block.status = "gpu"
        self.eviction_policy.register(logical_block)

    def resolve_block(self, logical_block):
        """Return the (K, V) tensors, swapping in if needed."""
        if logical_block.status == "cpu":
            self.swap_to_gpu(logical_block)

        self.eviction_policy.notify_use(logical_block)
        pid = logical_block.physical_block_id
        return self.gpu_k[pid], self.gpu_v[pid]

    def free_block(self, logical_block):
        """Free both GPU and CPU copies of the block."""
        if logical_block.status == "gpu":
            pid = logical_block.physical_block_id
            self.gpu_k[pid].zero_()
            self.gpu_v[pid].zero_()
            self.free_list.append(pid)
            self.eviction_policy.unregister(logical_block)
        elif logical_block.status == "cpu":
            del self.cpu_k[logical_block]
            del self.cpu_v[logical_block]

        logical_block.status = "freed"
        logical_block.physical_block_id = None

In [None]:
# Create a logical block and allocate it
block = LogicalBlock()
page_table = PageTable(num_blocks=4, block_size=2, num_heads=2, head_dim=4, device="cuda")
pid = page_table.allocate_block(block)

# Set some known values
page_table.gpu_k[pid][0][0] = torch.tensor([1., 1., 1., 1.])
page_table.gpu_v[pid][1][1] = torch.tensor([2., 2., 2., 2.])

# Swap to CPU and back
page_table.swap_to_cpu(block)
page_table.swap_to_gpu(block)

# Resolve and print
restored_k, restored_v = page_table.resolve_block(block)
print("K_block[0,0]:", restored_k[0, 0])  # should be [1, 1, 1, 1]
print("V_block[1,1]:", restored_v[1, 1])  # should be [2, 2, 2, 2]

In [None]:
num_blocks = 2
page_table = PageTable(num_blocks=num_blocks, block_size=1, num_heads=1, head_dim=2, device="cpu")

# Allocate 3 blocks to force eviction
blocks = [LogicalBlock() for _ in range(3)]

for i, blk in enumerate(blocks):
    page_table.allocate_block(blk)
    # print(blk)
    k, v = page_table.resolve_block(blk)
    k[0, 0] = torch.tensor([i + 1.0, i + 1.0])    # Set K
    v[0, 0] = torch.tensor([(i + 1) * 10.0] * 2)  # Set V

# for i, blk in enumerate(blocks):
#     print(blk.status)

# print(page_table.cpu_k)

# Read them back to verify swap-in works
for i, blk in enumerate(blocks):
    k, v = page_table.resolve_block(blk)
    print(f"Block {i} K:", k[0, 0].tolist(), " V:", v[0, 0].tolist())

In [6]:
from collections import defaultdict

class KVCacheManager:
    def __init__(self, page_table):
        self.page_table = page_table
        self.block_size = page_table.block_size
        self.sequence_table = defaultdict(lambda: defaultdict(list))  # seq_id → layer_id → [LogicalBlock]

    def _get_active_block(self, seq_id, layer_id):
        blocks = self.sequence_table[seq_id][layer_id]
        if blocks and blocks[-1].token_count < self.block_size:
            return blocks[-1]

        # Allocate a new logical block
        new_block = LogicalBlock()
        self.page_table.allocate_block(new_block)
        self.sequence_table[seq_id][layer_id].append(new_block)
        return new_block

    def write_token(self, seq_id, layer_id, key_vec, value_vec):
        block = self._get_active_block(seq_id, layer_id)
        if block.token_count >= self.block_size:
            raise RuntimeError("Attempted to write to full block.")

        # Ensure it's on GPU before writing
        k_buf, v_buf = self.page_table.resolve_block(block) # will swap in if needed

        idx = block.token_count
        k_buf[idx] = key_vec
        v_buf[idx] = value_vec
        block.token_count += 1

    def prefill(self, seq_id, layer_id, k_list, v_list):
        for k, v in zip(k_list, v_list):
            self.write_token(seq_id, layer_id, k, v)

    def yield_k_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            k_buf, _ = self.page_table.resolve_block(block)
            yield k_buf[:block.token_count]

    def yield_v_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            _, v_buf = self.page_table.resolve_block(block)
            yield v_buf[:block.token_count]

    def yield_kv_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            k_buf, v_buf = self.page_table.resolve_block(block)
            yield k_buf[:block.token_count], v_buf[:block.token_count]

    def free(self, seq_id):
        for layer_blocks in self.sequence_table[seq_id].values():
            for block in layer_blocks:
                self.page_table.free_block(block)
        del self.sequence_table[seq_id]

In [None]:
# Setup
num_blocks = 2
block_size = 2
num_heads = 2
head_dim = 4
device = "cuda"

page_table = PageTable(num_blocks, block_size, num_heads, head_dim, device)
manager = KVCacheManager(page_table)

# Simulate a prompt for 1 sequence, 1 layer
seq_id = 0
layer_id = 0

# Write 4 tokens (fills 2 blocks)
for i in range(4):
    k = torch.full((num_heads, head_dim), i + 1.0, device=device)
    v = torch.full((num_heads, head_dim), (i + 1) * 10.0, device=device)
    manager.write_token(seq_id, layer_id, k, v)

# Print all resolved KV blocks
print("\n--- Resolved Blocks ---")
for i, (k, v) in enumerate(manager.yield_kv_blocks(seq_id, layer_id)):
    print(f"Block {i}:")
    print("  K:\n", k)
    print("  V:\n", v)


In [7]:
from typing import Optional, Union

import torch
from torch.nn.attention.flex_attention import (
    _identity,
    _mask_mod_signature,
    _score_mod_signature,
    BlockMask,
    noop_mask,
)


def _cdiv(x: Union[int, float, torch.Tensor], multiple: Union[int, float, torch.Tensor]):
    return (x + multiple - 1) // multiple


class PagedAttention:
    """
    PagedAttention supports flex attention inference with a large batch size.
    With PagedAttention, a batch of key/value tensors with varying kv length
    is splitted into tensor blocks of fixed length and cached in a compact way.
    Thus we can avoid redundant memory consumption due to varying kv length and
    support a larger batch size.
    """

    def __init__(
        self,
        n_pages: int,
        page_size: int,
        max_batch_size: int,
        device: str = "cuda",
    ):
        # number of pages
        self.n_pages = n_pages

        # number of tokens per page
        self.page_size = page_size

        # page table: [batch, logical_block_idx] -> physical_page_idx
        self.page_table = -torch.ones(
            (max_batch_size, self.n_pages), dtype=torch.int64, device=device
        )

        # capacity: batch_idx -> allocated sequence length
        self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)

        # index of empty pages that is available for allocation
        self.empty_pages = list(range(n_pages - 1, -1, -1))

        # mapping from physical page index to logical page index
        self.physical_to_logical = -torch.ones(
            (max_batch_size, n_pages), dtype=torch.int64, device=device
        )

    def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None:
        """
        Requests the capacity of a given batch to be at least enough to
        hold `seq_len` elements.

        Args:
            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
            seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
        """

        if seq_len <= self.capacity[batch_idx]:
            return

        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx], self.page_size)

        assert len(self.empty_pages) >= num_pages_to_allocate, (
            f"requested {num_pages_to_allocate.item()} pages "
            f"but there are only {len(self.empty_pages)} empty pages"
        )

        start_page_idx = self.capacity[batch_idx] // self.page_size
        end_page_idx = start_page_idx + num_pages_to_allocate

        # find empty physical pages
        allocated_pages = torch.tensor(
            self.empty_pages[-num_pages_to_allocate:],
            device=num_pages_to_allocate.device,
        )
        self.empty_pages = self.empty_pages[:-num_pages_to_allocate]

        # update page table
        self.page_table[
            batch_idx,
            start_page_idx:end_page_idx,
        ] = allocated_pages

        # update metadata
        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
            start_page_idx.item(),
            end_page_idx.item(),
            device=num_pages_to_allocate.device,
        )
        self.capacity[batch_idx] += num_pages_to_allocate * self.page_size

    def erase(self, batch_idx: torch.Tensor) -> None:
        """
        Removes a single batch from paged attention.

        Args:
            batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
        """

        # find allocated pages
        allocated_page_idx = self.page_table[batch_idx] != -1
        allocated_pages = self.page_table[batch_idx][allocated_page_idx]

        # clean metadata
        self.capacity[batch_idx] = 0
        self.empty_pages += allocated_pages.tolist()
        self.physical_to_logical[batch_idx][:, allocated_pages] = -1
        self.page_table[batch_idx] = -1

    def assign(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> None:
        """
        Assigns new contents `val` to the storage `cache` at the location
        `batch_idx` and `input_pos`.

        Args:
            batch_idx (Tensor): batch index; shape :math:`(B)`.
            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
            val (Tensor): value to be assigned; shape :math:`(B, H, S, D)`
            cache (Tensor): the cache to store the values; shape:`(1, H, MAX_S, D)`
        """
        if k_val.requires_grad:
            raise RuntimeError("val must not require gradient")

        B, H, S, K_D = k_val.shape
        V_D = v_val.shape[3]
        if B != batch_idx.shape[0]:
            raise RuntimeError(
                f"Expect val and batch_idx have the same batch size "
                f"but got B={B} and B={batch_idx.shape[0]}."
            )
        if H != k_cache.shape[1]:
            raise RuntimeError(
                f"Expect val and cache has the same number of heads "
                f"but got H={H} and H={k_cache.shape[1]}."
            )
        if S != input_pos.shape[1]:
            raise RuntimeError(
                f"Expect val and input_pos has the same length "
                f"but got S={S} and S={input_pos.shape[0]}."
            )
        if K_D != k_cache.shape[3]:
            raise RuntimeError(
                f"Expect k_val and k_cache has the same hidden dim "
                f"but got D={K_D} and D={k_cache.shape[3]}."
            )
        if V_D != v_cache.shape[3]:
            raise RuntimeError(
                f"Expect v_val and v_cache has the same hidden dim "
                f"but got D={V_D} and D={v_cache.shape[3]}."
            )

        # find address
        logical_block_idx = input_pos // self.page_size  # [B, S]
        logical_block_offset = input_pos % self.page_size  # [B, S]
        physical_block_idx = torch.gather(
            self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
        ).to(torch.int32)  # [B, S]

        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]

        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)

        k_cache[:, :, addr, :] = k_val
        v_cache[:, :, addr, :] = v_val

    def convert_logical_block_mask(
        self,
        block_mask: BlockMask,
        batch_idx: Optional[torch.Tensor] = None,
    ) -> BlockMask:
        """
        Converts a logical block mask by mapping its logical kv indices to the corresponding
        physical kv indices.

        Args:
            block_mask (BlockMask): logical block mask;
                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
            batch_idx (Tensor): batch index corresponding to the block_mask
                batch dimension. This provides flexibility to convert a
                block mask with smaller batch size than the page table;
                shape :math:`(B)`.
        """
        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape

        if block_mask.BLOCK_SIZE[1] != self.page_size:
            raise RuntimeError(
                f"Expect block_mask has the same column block size as page_size"
                f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
            )

        # Increase the num columns of converted block mask from logical block mask's
        # num columns to n_pages, since a) the converted block mask
        # may have larger indices values; and b) `_ordered_to_dense` realizes
        # a dense tensor with these converted indices. There would be an IndexError
        # if using the logical block mask's num columns.

        device = block_mask.kv_num_blocks.device

        if batch_idx is None:
            batch_idx = torch.arange(B, device=device)
        page_table = self.page_table[batch_idx]

        new_kv_num_blocks = block_mask.kv_num_blocks.clone()

        new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)
        new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
            torch.gather(page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64))
            .view(block_mask.kv_indices.shape)
            .to(torch.int32)
        )

        new_full_kv_indices, new_full_kv_num_blocks = None, None
        if block_mask.full_kv_num_blocks is not None:
            assert block_mask.full_kv_indices is not None
            new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone()
            new_full_kv_indices = torch.zeros(
                (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
            )
            new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
                torch.gather(
                    page_table,
                    1,
                    block_mask.full_kv_indices.view(B, -1).to(torch.int64),
                )
                .view(block_mask.full_kv_indices.shape)
                .to(torch.int32)
            )

        new_mask_mod = self.get_mask_mod(block_mask.mask_mod)

        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            block_mask.BLOCK_SIZE,
            new_mask_mod,
            seq_lengths=seq_lengths,
        )

    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
        """
        Converts a mask_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
        """
        if mask_mod is None:
            mask_mod = noop_mask

        def new_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
            )

        return new_mask_mod

    def get_score_mod(self, score_mod: Optional[_score_mod_signature]) -> _score_mod_signature:
        """
        Converts a score_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
        """
        if score_mod is None:
            score_mod = _identity

        def new_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0,
                score_mod(score, b, h, q_idx, logical_kv_idx),
                float("-inf"),
            )

        return new_score_mod

In [8]:
class LRUEvictionPolicy:
    def __init__(self):
        self.lru_queue = deque()
    def register(self, block): self.lru_queue.appendleft(block)
    def evict(self): return self.lru_queue.pop()
    def unregister(self, block): self.lru_queue.remove(block)

    def unregister(self, block):
        if block in self.lru_queue:
            self.lru_queue.remove(block)


In [None]:
import torch
from collections import defaultdict, deque

# ----------------------- Setup -----------------------

B = 1
H = 2
D = 4
BLOCK_SIZE = 2
NUM_BLOCKS = 4
TOTAL_TOKENS = 6
device = "cuda" if torch.cuda.is_available() else "cpu"

# Simulated token-level key and value tensors
k_data = torch.randn(TOTAL_TOKENS, H, D, device=device)
v_data = torch.randn_like(k_data)

# Flat KV cache buffers for FlexAttention
k_cache = torch.zeros(1, H, NUM_BLOCKS * BLOCK_SIZE, D, device=device)
v_cache = torch.zeros_like(k_cache)

# Page table tensor as expected by FlexAttention
page_table_tensor = -torch.ones((B, 8), dtype=torch.long, device=device)

# ----------------------- Logical Components -----------------------

class LogicalBlock:
    def __init__(self, physical_block_id=None):
        self.physical_block_id = physical_block_id
        self.status = "gpu"
        self.token_count = 0


class PageTable:
    def __init__(self):
        self.free_list = list(range(NUM_BLOCKS))
        self.eviction_policy = LRUEvictionPolicy()
    def allocate_block(self, logical_block):
        if not self.free_list:
            victim = self.eviction_policy.evict()
            self.free_list.append(victim.physical_block_id)
        pid = self.free_list.pop()
        logical_block.physical_block_id = pid
        logical_block.status = "gpu"
        return pid
    def swap_to_gpu(self, block): pass  # not needed for this test

class KVCacheManager:
    def __init__(self, page_table):
        self.page_table = page_table
        self.sequence_table = defaultdict(lambda: defaultdict(list))

# ----------------------- Test Class -----------------------

class SwappablePagedAttention:
    def __init__(self, kv_cache_manager, page_size, page_table_tensor, layer_id):
        self.kv_cache_manager = kv_cache_manager
        self.page_size = page_size
        self.page_table = page_table_tensor
        self.layer_id = layer_id

    def assign(self, batch_idx, input_pos, k_val, v_val, k_cache, v_cache):
        B, H, S, D = k_val.shape
        for b in range(B):
            seq_id = int(batch_idx[b])
            for s in range(S):
                token_idx = int(input_pos[b, s])
                logical_block_id = token_idx // self.page_size
                offset = token_idx % self.page_size

                block_list = self.kv_cache_manager.sequence_table[seq_id][self.layer_id]
                if len(block_list) <= logical_block_id:
                    for _ in range(logical_block_id - len(block_list) + 1):
                        new_block = LogicalBlock()
                        pid = self.kv_cache_manager.page_table.allocate_block(new_block)
                        self.kv_cache_manager.page_table.eviction_policy.register(new_block)
                        block_list.append(new_block)

                block = block_list[logical_block_id]
                pid = block.physical_block_id

                flat_idx = pid * self.page_size + offset
                k_cache[0, :, flat_idx, :] = k_val[b, :, s, :]
                v_cache[0, :, flat_idx, :] = v_val[b, :, s, :]

                block.token_count += 1
                self.page_table[seq_id, logical_block_id] = pid

# ----------------------- Run the Test -----------------------

pt = PageTable()
kvm = KVCacheManager(pt)
spa = SwappablePagedAttention(kvm, BLOCK_SIZE, page_table_tensor, layer_id=0)

for t in range(TOTAL_TOKENS):
    k = k_data[t].unsqueeze(0).unsqueeze(2)  # (1, H, 1, D)
    v = v_data[t].unsqueeze(0).unsqueeze(2)
    input_pos = torch.tensor([[t]], device=device)
    spa.assign(torch.tensor([0], device=device), input_pos, k, v, k_cache, v_cache)

# ----------------------- Check Result -----------------------

diff = (k_cache[0, :, :TOTAL_TOKENS] - k_data.transpose(0, 1)).abs().max()
print("Max difference:", diff.item())  # Expect: 0.0


In [None]:
# Build token_index_to_flat_idx mapping
token_to_flat_idx = []
for t in range(TOTAL_TOKENS):
    token_idx = t
    logical_block_id = token_idx // BLOCK_SIZE
    offset = token_idx % BLOCK_SIZE
    pid = int(page_table_tensor[0, logical_block_id].item())
    flat_idx = pid * BLOCK_SIZE + offset
    token_to_flat_idx.append(flat_idx)

# Gather actual written vectors from k_cache
actual_k = torch.stack([k_cache[0, :, i, :] for i in token_to_flat_idx], dim=0)  # (T, H, D)
expected_k = k_data  # (T, H, D)

diff = (actual_k - expected_k).abs().max()
print("Corrected Max difference:", diff.item())  # Should be ~0.0


In [9]:
class PageTableWithSwap(PageTable):
    def __init__(self, num_blocks, block_size, num_heads, head_dim, device="cuda", k_cache=None, v_cache=None):
        super().__init__(num_blocks, block_size, num_heads, head_dim, device=device)
        self.k_cache = k_cache  # flat buffer: (1, H, max_tokens, D)
        self.v_cache = v_cache
        self.cpu_k = {}  # LogicalBlock → CPU tensor
        self.cpu_v = {}

    def swap_to_cpu(self, block):
        pid = block.physical_block_id
        if pid is None or block.status != "gpu":
            raise RuntimeError("Block is not on GPU")

        bs = self.block_size
        self.cpu_k[block] = self.k_cache[0, :, pid * bs : (pid + 1) * bs, :].clone().cpu()
        self.cpu_v[block] = self.v_cache[0, :, pid * bs : (pid + 1) * bs, :].clone().cpu()

        block.physical_block_id = None
        block.status = "cpu"
        self.eviction_policy.unregister(block)
        self.free_list.append(pid)
        print("sent")

    def swap_to_gpu(self, block):
        if block.status != "cpu":
            return

        if not self.free_list:
            victim = self.eviction_policy.evict()
            self.swap_to_cpu(victim)

        pid = self.free_list.pop()
        bs = self.block_size

        self.k_cache[0, :, pid * bs : (pid + 1) * bs, :] = self.cpu_k.pop(block).to(self.k_cache.device)
        self.v_cache[0, :, pid * bs : (pid + 1) * bs, :] = self.cpu_v.pop(block).to(self.v_cache.device)

        block.physical_block_id = pid
        block.status = "gpu"
        self.eviction_policy.register(block)
        print(f"retrieved to pid={pid}")


In [None]:
# ✅ Set up environment
pt = PageTableWithSwap(
    num_blocks=NUM_BLOCKS,
    block_size=BLOCK_SIZE,
    num_heads=H,
    head_dim=D,
    device=device,
    k_cache=k_cache,
    v_cache=v_cache
)
kvm = KVCacheManager(pt)
spa = SwappablePagedAttention(kvm, BLOCK_SIZE, page_table_tensor, layer_id=0)
k_cache.zero_()
v_cache.zero_()

# ✅ Assign 6 tokens (should fill 3 logical blocks)
for t in range(TOTAL_TOKENS):
    k = k_data[t].unsqueeze(0).unsqueeze(2)
    v = v_data[t].unsqueeze(0).unsqueeze(2)
    input_pos = torch.tensor([[t]], device=device)
    spa.assign(torch.tensor([0], device=device), input_pos, k, v, k_cache, v_cache)

# ✅ Swap out logical block 1 (tokens 2 & 3)
evicted_block = kvm.sequence_table[0][0][1]
pt.swap_to_cpu(evicted_block, k_cache, v_cache)

# ✅ Zero out its region in the flat cache
pid = pt.free_list[-1]  # last freed
k_cache[0, :, pid * BLOCK_SIZE : (pid + 1) * BLOCK_SIZE, :].zero_()

# ✅ Swap it back in
pt.swap_to_gpu(evicted_block, k_cache, v_cache)

# ✅ Validate restored values match original token 2 and 3
flat_idx_2 = evicted_block.physical_block_id * BLOCK_SIZE + 0
flat_idx_3 = evicted_block.physical_block_id * BLOCK_SIZE + 1

actual_k = torch.stack([
    k_cache[0, :, flat_idx_2, :],
    k_cache[0, :, flat_idx_3, :]
])
expected_k = k_data[2:4]

diff = (actual_k - expected_k).abs().max()
print("Restoration diff:", diff.item())


In [None]:
# Verify flex_attention and BlockMask are available from torch
try:
    from torch.nn.attention.flex_attention import flex_attention, BlockMask
    flex_ready = True
except ImportError:
    flex_ready = False

flex_ready


In [None]:
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# --- Setup ---
B = 1
H = 2
D = 4
SEQ_LEN = 1       # query length
BLOCK_SIZE = 2
NUM_BLOCKS = 4    # total blocks allocated in k_cache
KV_LEN = BLOCK_SIZE * NUM_BLOCKS

# Dummy query vector
q = torch.randn(B, H, SEQ_LEN, D, device=device)

# --- Construct BlockMask to cover all 8 kv slots ---
# We must provide all NUM_BLOCKS to match k_cache shape
kv_indices = -torch.ones((B, H, 1, NUM_BLOCKS), dtype=torch.int32, device=device)
kv_num_blocks = torch.full((B, H, 1), NUM_BLOCKS, dtype=torch.int32, device=device)

# Map physical page ids (0 to NUM_BLOCKS-1) as-is
for h in range(H):
    for i in range(NUM_BLOCKS):
        kv_indices[0, h, 0, i] = i

# FlexAttention requires BLOCK_SIZE and seq_lengths
block_mask = BlockMask.from_kv_blocks(
    kv_num_blocks=kv_num_blocks,
    kv_indices=kv_indices,
    full_kv_num_blocks=None,
    full_kv_indices=None,
    BLOCK_SIZE=(SEQ_LEN, BLOCK_SIZE),  # (Q, KV) block sizes
    mask_mod=None,
    seq_lengths=(SEQ_LEN, KV_LEN)
)

# --- Run FlexAttention ---
out = flex_attention(
    q,
    k_cache,
    v_cache,
    block_mask=block_mask,
    score_mod=None
)

print("Output shape:", out.shape)       # Expect (1, H, 1, D)
print("Output vector:\n", out[0, :, 0])  # One vector per head


In [10]:
from torch.nn.attention.flex_attention import BlockMask
class SwappablePagedAttention:
    def __init__(self, kv_cache_manager, k_cache, v_cache, page_table_tensor, page_size, layer_id):
        self.kv_cache_manager = kv_cache_manager
        self.k_cache = k_cache
        self.v_cache = v_cache
        self.page_table = page_table_tensor
        self.page_size = page_size
        self.layer_id = layer_id
        self.max_flat_idx_written = 0

    # def assign_token(self, seq_id, token_idx, key_vec, value_vec):
    #     logical_block_id = token_idx // self.page_size
    #     offset = token_idx % self.page_size

    #     block_list = self.kv_cache_manager.sequence_table[seq_id][self.layer_id]
    #     if len(block_list) <= logical_block_id:
    #         for _ in range(logical_block_id - len(block_list) + 1):
    #             new_block = LogicalBlock()
    #             pid = self.kv_cache_manager.page_table.allocate_block(new_block)
    #             self.kv_cache_manager.page_table.eviction_policy.register(new_block)
    #             block_list.append(new_block)

    #     block = block_list[logical_block_id]
    #     pid = block.physical_block_id

    #     # flat_idx = pid * self.page_size + offset
    #     # self.k_cache[0, :, flat_idx, :] = key_vec
    #     # self.v_cache[0, :, flat_idx, :] = value_vec

    #     flat_idx = pid * self.page_size + offset
    #     self.max_flat_idx_written = max(self.max_flat_idx_written, flat_idx + 1)
    #     self.k_cache[0, :, flat_idx, :] = key_vec.squeeze(1)  # (H, D)
    #     self.v_cache[0, :, flat_idx, :] = value_vec.squeeze(1)


    #     block.token_count += 1
    #     self.page_table[seq_id, logical_block_id] = pid

    def assign_token(self, seq_id, token_idx, key_vec, value_vec):
        logical_block_id = token_idx // self.page_size
        offset = token_idx % self.page_size

        # # ⬇️ Expand page_table if seq_id exceeds current rows
        # if seq_id >= self.page_table.shape[0]:
        #     pad = seq_id - self.page_table.shape[0] + 1
        #     new_table = torch.full(
        #         (self.page_table.shape[0] + pad, self.page_table.shape[1]),
        #         fill_value=-1,
        #         dtype=self.page_table.dtype,
        #         device=self.page_table.device
        #     )
        #     new_table[:self.page_table.shape[0], :] = self.page_table
        #     self.page_table = new_table


        # ⬇️ Expand page_table if the logical block id exceeds current columns
        if logical_block_id >= self.page_table.shape[1]:
            pad = logical_block_id - self.page_table.shape[1] + 1
            new_table = torch.full(
                (self.page_table.shape[0], self.page_table.shape[1] + pad),
                fill_value=-1,
                dtype=self.page_table.dtype,
                device=self.page_table.device
            )
            new_table[:, :self.page_table.shape[1]] = self.page_table
            self.page_table = new_table

        # Get or grow the block list for this sequence/layer
        block_list = self.kv_cache_manager.sequence_table[seq_id][self.layer_id]
        if len(block_list) <= logical_block_id:
            for _ in range(logical_block_id - len(block_list) + 1):
                new_block = LogicalBlock()
                pid = self.kv_cache_manager.page_table.allocate_block(new_block)
                self.kv_cache_manager.page_table.eviction_policy.register(new_block)
                block_list.append(new_block)

        block = block_list[logical_block_id]

        # If it's not on GPU, swap it in
        if block.status == "cpu":
            self.kv_cache_manager.page_table.swap_to_gpu(block)

        pid = block.physical_block_id
        # Sanity check:
        if block.physical_block_id is None:
            raise RuntimeError(f"Token {token_idx}: block still has no physical_block_id after swap.")

        flat_idx = pid * self.page_size + offset

        self.k_cache[0, :, flat_idx, :] = key_vec.squeeze(1)
        self.v_cache[0, :, flat_idx, :] = value_vec.squeeze(1)

        block.token_count += 1
        if seq_id >= self.page_table.shape[0] or logical_block_id >= self.page_table.shape[1]:
            print(f"🚨 Resize needed: seq_id={seq_id}, logical_block_id={logical_block_id}, current shape={self.page_table.shape}")

        self.page_table[seq_id, logical_block_id] = pid

        self.max_flat_idx_written = max(self.max_flat_idx_written, flat_idx + 1)



    # def build_blockmask(self, num_query_tokens, total_blocks):
    #     from torch.nn.attention.flex_attention import BlockMask
    #     B, H = 1, self.k_cache.shape[1]

    #     kv_indices = -torch.ones((B, H, 1, total_blocks), dtype=torch.int32, device=self.k_cache.device)
    #     kv_num_blocks = torch.full((B, H, 1), total_blocks, dtype=torch.int32, device=self.k_cache.device)

    #     for h in range(H):
    #         for i in range(total_blocks):
    #             kv_indices[0, h, 0, i] = i

    #     block_mask = BlockMask.from_kv_blocks(
    #         kv_num_blocks=kv_num_blocks,
    #         kv_indices=kv_indices,
    #         full_kv_num_blocks=None,
    #         full_kv_indices=None,
    #         BLOCK_SIZE=(num_query_tokens, self.page_size),
    #         mask_mod=None,
    #         seq_lengths=(num_query_tokens, total_blocks * self.page_size)
    #     )
    #     return block_mask
    
    # def build_blockmask(self, num_query_tokens, total_tokens_written):
    #     from torch.nn.attention.flex_attention import BlockMask
        
    #     # assert total_tokens_written <= self.k_cache.shape[2], "KV length exceeds cache capacity"
    #     B, H = 1, self.k_cache.shape[1]
    #     device = self.k_cache.device

    #     logical_to_physical = self.page_table[0]
    #     valid = logical_to_physical != -1
    #     pids = logical_to_physical[valid]
    #     total_blocks = pids.numel()

    #     kv_indices = -torch.ones((B, H, 1, total_blocks), dtype=torch.int32, device=device)
    #     kv_num_blocks = torch.full((B, H, 1), total_blocks, dtype=torch.int32, device=device)

    #     for h in range(H):
    #         for i in range(total_blocks):
    #             kv_indices[0, h, 0, i] = pids[i].item()


    #     # print("max pid:", max(pids))
    #     # print("kv_num_blocks:", kv_num_blocks)
    #     # print("kv_indices:", kv_indices)
    #     # assert (kv_indices >= 0).all(), "Found -1 in kv_indices"
    #     # assert (kv_indices < kv_num_blocks.max()).all(), "kv_indices out of bounds"


    #     block_mask = BlockMask.from_kv_blocks(
    #         kv_num_blocks=kv_num_blocks,
    #         kv_indices=kv_indices,
    #         full_kv_num_blocks=None,
    #         full_kv_indices=None,
    #         BLOCK_SIZE=(num_query_tokens, self.page_size),
    #         mask_mod=None,
    #         seq_lengths=(num_query_tokens, self.max_flat_idx_written) # seq_lengths=(num_query_tokens, total_tokens_written)
    #     )

    #     # ✅ Crop to match actual (query, kv) shapes
    #     return block_mask._adjust(num_query_tokens, total_tokens_written)
    
    def build_blockmask(self, num_query_tokens, total_tokens_written):
        from torch.nn.attention.flex_attention import BlockMask

        assert total_tokens_written <= self.k_cache.shape[2], "KV length exceeds cache capacity"
        B, H = 1, self.k_cache.shape[1]
        device = self.k_cache.device

        logical_to_physical = self.page_table[0]  # single seq_id only
        valid = logical_to_physical != -1
        pids = logical_to_physical[valid]
        total_blocks = pids.numel()

        kv_indices = torch.full((B, H, 1, total_blocks), -1, dtype=torch.int32, device=device)
        kv_num_blocks = torch.full((B, H, 1), total_blocks, dtype=torch.int32, device=device)

        for h in range(H):
            for i in range(total_blocks):
                kv_indices[0, h, 0, i] = pids[i].item()
        print(">> max_flat_idx_written:", self.max_flat_idx_written)
        print(">> kv_num_blocks.shape:", kv_num_blocks.shape)
        print(">> kv_num_blocks:", kv_num_blocks)
        print(">> kv_indices.shape:", kv_indices.shape)
        print(">> kv_indices:", kv_indices)

        block_mask = BlockMask.from_kv_blocks(
            kv_num_blocks=kv_num_blocks,
            kv_indices=kv_indices,
            full_kv_num_blocks=None,
            full_kv_indices=None,
            BLOCK_SIZE=(num_query_tokens, self.page_size),
            mask_mod=None,
            seq_lengths=(num_query_tokens, total_tokens_written)  # <- exact KV length
        )

        # Crop block mask to match token counts
        return block_mask._adjust(num_query_tokens, total_tokens_written)




    # def query(self, q, block_mask):
    #     from torch.nn.attention.flex_attention import flex_attention
    #     return flex_attention(
    #         q, self.k_cache, self.v_cache,
    #         block_mask=block_mask,
    #         score_mod=None
    #     )
    def query(self, q, block_mask):
        from torch.nn.attention.flex_attention import flex_attention

        # 🔁 Ensure required blocks are on GPU
        for block_list in self.kv_cache_manager.sequence_table.values():
            for block in block_list[self.layer_id]:
                if block.status == "cpu":
                    self.kv_cache_manager.page_table.swap_to_gpu(block)
                    print(f"⏪ Swapping in block with pid={block.physical_block_id}")

        print("Calling flex_attention with:")
        print("  q.shape:", q.shape)
        print("  k.shape:", self.k_cache.shape)
        print("  v.shape:", self.v_cache.shape)


        return flex_attention(
            q, self.k_cache, self.v_cache,
            block_mask=block_mask,
            score_mod=None
        )



In [None]:
# 1. Setup the full KVCacheManager + PageTableWithSwap
pt = PageTableWithSwap(
    num_blocks=NUM_BLOCKS,
    block_size=BLOCK_SIZE,
    num_heads=H,
    head_dim=D,
    device=device,
    k_cache=k_cache,
    v_cache=v_cache
)
kvm = KVCacheManager(pt)

# 2. Flat cache buffers
k_cache = torch.zeros(1, H, NUM_BLOCKS * BLOCK_SIZE, D, device=device)
v_cache = torch.zeros_like(k_cache)
page_table_tensor = -torch.ones((1, NUM_BLOCKS), dtype=torch.long, device=device)

# 3. Swappable attention wrapper
spa = SwappablePagedAttention(kvm, k_cache, v_cache, page_table_tensor, page_size=BLOCK_SIZE, layer_id=0)

# 4. Assign 6 tokens (token 0–5) using full KVCacheManager
for t in range(6):
    spa.assign_token(seq_id=0, token_idx=t, key_vec=k_data[t], value_vec=v_data[t])

# 5. Evict logical block 1 (tokens 2–3)
evicted_block = kvm.sequence_table[0][0][1]
pt.swap_to_cpu(evicted_block)

# Zero out that region in the flat cache
evicted_pid = pt.free_list[-1]
k_cache[0, :, evicted_pid * BLOCK_SIZE : (evicted_pid + 1) * BLOCK_SIZE, :] = 0

# 6. Restore the block
pt.swap_to_gpu(evicted_block)

# 7. Build a query vector
q = torch.randn(1, H, 1, D, device=device)

# 8. Build a block mask for all 4 physical blocks (0–3)
block_mask = spa.build_blockmask(num_query_tokens=1, total_blocks=NUM_BLOCKS)

# 9. Run flex attention
out = spa.query(q, block_mask)

# 10. Print
print("FlexAttention output after swap/restore:", out[0, :, 0])


In [11]:
import torch.nn as nn


class SwappablePagedAttentionLayer(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa: SwappablePagedAttention):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)

        self.block_size = spa.page_size
        self.freqs_cis = precompute_freqs_cis(
            seq_len=512,  # max positions
            n_elem=self.head_dim,
            dtype=dtype
        )


    def forward(self, x, seq_id, token_idx):
        B, S, _ = x.shape
        kv_size = self.n_heads * self.head_dim

        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
        q = q.view(B, S, self.n_heads, self.head_dim)
        k = k.view(B, S, self.n_heads, self.head_dim)
        v = v.view(B, S, self.n_heads, self.head_dim)

        freqs = self.freqs_cis[token_idx].unsqueeze(0)  # (1, 1, D//2, 2)
        q = apply_rotary_emb(q, freqs)
        k = apply_rotary_emb(k, freqs)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # (B, H, S, D)

        # assign KV
        self.spa.assign_token(seq_id, token_idx, k[0], v[0])

        # build blockmask for this query
        block_list = self.spa.kv_cache_manager.sequence_table[seq_id][self.spa.layer_id]
        num_blocks = len(block_list)
        # block_mask = self.spa.build_blockmask(num_query_tokens=1, total_blocks=num_blocks)
        block_mask = self.spa.build_blockmask(num_query_tokens=1, total_tokens_written=self.spa.max_flat_idx_written) # token_idx + 1)

        # run attention
        out = self.spa.query(q, block_mask)  # (B, H, 1, D)
        out = out.transpose(1, 2).contiguous().view(B, S, -1)

        return self.wo(out)
        
def precompute_freqs_cis(seq_len, n_elem, base=10000, dtype=torch.float16):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype, device="cuda")

def apply_rotary_emb(x, freqs_cis):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    return x_out2.flatten(3).type_as(x)

In [79]:
k_cache = torch.zeros(
    1, H, NUM_BLOCKS * BLOCK_SIZE, D,
    device=device, dtype=torch.float16  # ← match the model/layer dtype
)
v_cache = torch.zeros_like(k_cache)
pt = PageTableWithSwap(
    num_blocks=NUM_BLOCKS,
    block_size=BLOCK_SIZE,
    num_heads=H,
    head_dim=D,
    device=device,
    k_cache=k_cache,
    v_cache=v_cache
)
kvm = KVCacheManager(pt)
spa = SwappablePagedAttention(kvm, k_cache, v_cache, page_table_tensor, BLOCK_SIZE, layer_id=0)
layer = SwappablePagedAttentionLayer(hidden_dim=D*H, n_heads=H, head_dim=D, dtype=torch.float16, spa=spa)


In [82]:
x = torch.randn(1, 1, D * H, device=device, dtype=torch.float16)  # (B, 1, D_model)
out = layer(x, seq_id=0, token_idx=6)

print("Layer output:", out.shape)  # expect (1, 1, D_model)


sent
retrieved to pid=1
>> max_flat_idx_written: 8
>> kv_num_blocks.shape: torch.Size([1, 2, 1])
>> kv_num_blocks: tensor([[[10],
         [10]]], device='cuda:0', dtype=torch.int32)
>> kv_indices.shape: torch.Size([1, 2, 1, 10])
>> kv_indices: tensor([[[[1, 1, 0, 0, 0, 0, 0, 0, 0, 0]],

         [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]]], device='cuda:0',
       dtype=torch.int32)
sent
retrieved to pid=0
⏪ Swapping in block with pid=0
sent
retrieved to pid=1
⏪ Swapping in block with pid=1
sent
retrieved to pid=0
⏪ Swapping in block with pid=0
sent
retrieved to pid=1
⏪ Swapping in block with pid=1
sent
retrieved to pid=0
⏪ Swapping in block with pid=0
sent
retrieved to pid=1
⏪ Swapping in block with pid=1
sent
retrieved to pid=0
⏪ Swapping in block with pid=0
sent
retrieved to pid=1
⏪ Swapping in block with pid=1
sent
retrieved to pid=0
⏪ Swapping in block with pid=0
Calling flex_attention with:
  q.shape: torch.Size([1, 2, 1, 4])
  k.shape: torch.Size([1, 2, 8, 4])
  v.shape: torch.Size([1, 2

In [None]:
x = torch.randn(1, 1, D * H, device=device, dtype=torch.float16)
out = layer(x, seq_id=0, token_idx=6)
print(out.shape)

In [80]:
BLOCK_SIZE = 4
NUM_BLOCKS = 2

In [81]:
# --- Config ---
seq_id = 0
T = 40  # number of tokens to decode
hidden_dim = D * H

# --- Token input embeddings (e.g., dummy input tokens) ---
token_inputs = torch.randn(T, 1, hidden_dim, device=device, dtype=torch.float16)  # (T, B=1, D_model)

# --- Create a fresh attention layer per run ---
layer = SwappablePagedAttentionLayer(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=torch.float16,
    spa=spa
)

# --- Run the decoding loop ---
outputs = []

for t in range(T):
    x_t = token_inputs[t].unsqueeze(0)  # (1, 1, D_model)
    out_t = layer(x_t, seq_id=seq_id, token_idx=t)  # (1, 1, D_model)
    outputs.append(out_t)

# --- Stack outputs ---
outputs = torch.cat(outputs, dim=1)  # (1, T, D_model)

print("Final output shape:", outputs.shape)  # Expect: (1, T, D_model)


>> max_flat_idx_written: 5
>> kv_num_blocks.shape: torch.Size([1, 2, 1])
>> kv_num_blocks: tensor([[[3],
         [3]]], device='cuda:0', dtype=torch.int32)
>> kv_indices.shape: torch.Size([1, 2, 1, 3])
>> kv_indices: tensor([[[[1, 0, 0]],

         [[1, 0, 0]]]], device='cuda:0', dtype=torch.int32)
Calling flex_attention with:
  q.shape: torch.Size([1, 2, 1, 4])
  k.shape: torch.Size([1, 2, 8, 4])
  v.shape: torch.Size([1, 2, 8, 4])
>> max_flat_idx_written: 6
>> kv_num_blocks.shape: torch.Size([1, 2, 1])
>> kv_num_blocks: tensor([[[3],
         [3]]], device='cuda:0', dtype=torch.int32)
>> kv_indices.shape: torch.Size([1, 2, 1, 3])
>> kv_indices: tensor([[[[1, 0, 0]],

         [[1, 0, 0]]]], device='cuda:0', dtype=torch.int32)
Calling flex_attention with:
  q.shape: torch.Size([1, 2, 1, 4])
  k.shape: torch.Size([1, 2, 8, 4])
  v.shape: torch.Size([1, 2, 8, 4])
>> max_flat_idx_written: 7
>> kv_num_blocks.shape: torch.Size([1, 2, 1])
>> kv_num_blocks: tensor([[[3],
         [3]]], dev

In [1]:
import torch
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# --- Dummy config ---
B, H, S, D = 2, 1, 1, 4             # batch=2, heads=1, q_len=1, head_dim=4
num_kv_tokens = 8
block_size = 1                     # keep it simple: 1 token per block
num_kv_blocks = num_kv_tokens // block_size

# --- Dummy Q, K, V ---
q = torch.randn(B, H, S, D, device='cuda')                  # (2, 1, 1, 4)
k = torch.randn(B, H, num_kv_tokens, D, device='cuda')      # (2, 1, 8, 4)
v = torch.randn(B, H, num_kv_tokens, D, device='cuda')      # (2, 1, 8, 4)

# --- Simulate page tables ---
# For each sequence ID in batch, assign logical blocks to physical blocks
# e.g., both sequence 0 and 1 map to [0, 1, 2, 3, 4, 5, 6, 7] (identity)
logical_to_physical = torch.arange(num_kv_blocks, device='cuda')
dummy_page_table = [logical_to_physical.clone() for _ in range(B)]  # One per batch

# --- Build BlockMask ---
kv_num_blocks = torch.full((B, H, 1), num_kv_blocks, dtype=torch.int32, device='cuda')  # (2, 1, 1)
kv_indices = torch.full((B, H, 1, num_kv_blocks), -1, dtype=torch.int32, device='cuda') # (2, 1, 1, 8)

for b in range(B):
    for i in range(num_kv_blocks):
        kv_indices[b, 0, 0, i] = dummy_page_table[b][i].item()

block_mask = BlockMask.from_kv_blocks(
    kv_num_blocks=kv_num_blocks,
    kv_indices=kv_indices,
    full_kv_num_blocks=None,
    full_kv_indices=None,
    BLOCK_SIZE=(S, block_size),
    mask_mod=None,
    seq_lengths=(S, num_kv_tokens),
)._adjust(S, num_kv_tokens)

# --- Run FlexAttention ---
out = flex_attention(q, k, v, block_mask=block_mask)

# --- Verify ---
print("q.shape:", q.shape)
print("k.shape:", k.shape)
print("v.shape:", v.shape)
print("block_mask.shape:", block_mask.shape)
print("output.shape:", out.shape)  # should be (2, 1, 1, 4)


q.shape: torch.Size([2, 1, 1, 4])
k.shape: torch.Size([2, 1, 8, 4])
v.shape: torch.Size([2, 1, 8, 4])
block_mask.shape: (2, 1, 1, 8)
output.shape: torch.Size([2, 1, 1, 4])


In [2]:
import torch

# Dummy BlockMask that mimics FlexAttention behavior
class DummyBlockMask:
    def __init__(self, kv_num_blocks, kv_indices, shape):
        self.kv_num_blocks = kv_num_blocks
        self.kv_indices = kv_indices
        self._shape = shape

    def _adjust(self, q_len, kv_len):
        self._shape = (self.kv_num_blocks.shape[0], self.kv_num_blocks.shape[1], q_len, kv_len)
        return self

    def shape(self):
        return self._shape


def build_blockmask_batched(page_tables, seq_ids, num_query_tokens, total_tokens_written, page_size):
    B = len(seq_ids)
    H = 2  # number of heads
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    kv_num_blocks = []
    kv_indices = []

    for seq_id in seq_ids:
        page_table = page_tables[seq_id]  # 1D tensor: logical → physical block id
        valid = page_table != -1
        pids = page_table[valid]
        total_blocks = pids.numel()

        kv_num_blocks.append(torch.full((H, 1), total_blocks, dtype=torch.int32, device=device))
        pid_tensor = torch.full((H, 1, total_blocks), -1, dtype=torch.int32, device=device)
        for h in range(H):
            pid_tensor[h, 0, :total_blocks] = pids
        kv_indices.append(pid_tensor)

    kv_num_blocks = torch.stack(kv_num_blocks, dim=0)  # (B, H, 1)
    kv_indices = torch.stack(kv_indices, dim=0)        # (B, H, 1, T)

    block_mask = DummyBlockMask(kv_num_blocks, kv_indices, kv_indices.shape)
    return block_mask._adjust(num_query_tokens, total_tokens_written), kv_num_blocks, kv_indices


# --- Test the blockmask ---
B = 2
T = 8
page_size = 1
seq_ids = list(range(B))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy page tables
page_tables = {
    0: torch.tensor([1, 0, 2, -1, -1, -1, -1, -1], dtype=torch.int32, device=device),
    1: torch.tensor([3, 4, 5, -1, -1, -1, -1, -1], dtype=torch.int32, device=device),
}

# Build and test blockmask
block_mask, kv_num_blocks, kv_indices = build_blockmask_batched(
    page_tables=page_tables,
    seq_ids=seq_ids,
    num_query_tokens=1,
    total_tokens_written=T,
    page_size=page_size,
)

# Print results
print("BlockMask shape:", block_mask._shape)
print("KV Num Blocks:\n", kv_num_blocks)
print("KV Indices (seq 0, head 0):", kv_indices[0, 0, 0])
print("KV Indices (seq 1, head 0):", kv_indices[1, 0, 0])


BlockMask shape: (2, 2, 1, 8)
KV Num Blocks:
 tensor([[[3],
         [3]],

        [[3],
         [3]]], device='cuda:0', dtype=torch.int32)
KV Indices (seq 0, head 0): tensor([1, 0, 2], device='cuda:0', dtype=torch.int32)
KV Indices (seq 1, head 0): tensor([3, 4, 5], device='cuda:0', dtype=torch.int32)


In [None]:
import torch
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# --- Config ---
B, H, S, D = 2, 2, 1, 4  # 2 sequences, 2 heads, 1 query token, head_dim=4
page_size = 1
KV = 16  # total kv tokens
device = torch.device('cuda')

# --- Dummy QKV ---
q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, KV, D, device=device)
v = torch.randn(B, H, KV, D, device=device)

# --- Dummy page tables (NO -1s here!) ---
page_tables = {
    0: torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32, device=device),
    1: torch.tensor([8, 9, 10, 11, 12, 13, 14, 15], dtype=torch.int32, device=device),
}

seq_ids = list(page_tables.keys())

# --- Blockmask builder ---
def build_blockmask_batched(page_tables, seq_ids, num_query_tokens, total_tokens_written, page_size):
    B = len(seq_ids)
    H = 2  # number of heads
    device = torch.device("cuda")

    # max_blocks = max((pt != -1).sum().item() for pt in page_tables.values())
    max_blocks = total_tokens_written // page_size  # = 16 // 1 = 16


    kv_num_blocks = torch.full((B, H, 1), 0, dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, seq_id in enumerate(seq_ids):
        l2p = page_tables[seq_id]
        valid = l2p != -1
        pids = l2p[valid]
        T = pids.numel()

        kv_num_blocks[b] = torch.full((H, 1), T, dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :T] = pids 

    # Safety checks
    assert (kv_indices >= -1).all(), "kv_indices contain values < -1!"
    # assert (kv_indices < total_tokens_written).all(), "kv_indices out-of-bounds!"
    assert (kv_indices < total_tokens_written).all() | (kv_indices == -1).all()
    assert kv_indices.max().item() < total_tokens_written, "kv_indices contains out-of-bound index!"


    print("K.shape[-2]:", k.shape[-2])
    print("total_tokens_written:", KV)

    print("kv_indices:", kv_indices)
    print("kv_indices.max():", kv_indices.max().item())
    print("total_tokens_written:", total_tokens_written)

    block_mask = BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num_blocks,
        kv_indices=kv_indices,
        full_kv_num_blocks=None,
        full_kv_indices=None,
        BLOCK_SIZE=(num_query_tokens, page_size),
        mask_mod=None,
        seq_lengths=(num_query_tokens, total_tokens_written),
    )

    return block_mask
    # return block_mask._adjust(num_query_tokens, total_tokens_written)


# --- Build mask ---
block_mask = build_blockmask_batched(
    page_tables=page_tables,
    seq_ids=seq_ids,
    num_query_tokens=q.shape[2],
    total_tokens_written=k.shape[2],
    page_size=page_size,
)

# --- Call flex_attention ---
out = flex_attention(q, k, v, block_mask=block_mask)

# --- Done ---
print("q.shape:", q.shape)
print("k.shape:", k.shape)
print("block_mask.shape:", block_mask.shape)
print("output.shape:", out.shape)


K.shape[-2]: 16
total_tokens_written: 16
kv_indices: tensor([[[[ 0,  1,  2,  3,  4,  5,  6,  7, -1, -1, -1, -1, -1, -1, -1, -1]],

         [[ 0,  1,  2,  3,  4,  5,  6,  7, -1, -1, -1, -1, -1, -1, -1, -1]]],


        [[[ 8,  9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1]],

         [[ 8,  9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1]]]],
       device='cuda:0', dtype=torch.int32)
kv_indices.max(): 15
total_tokens_written: 16
q.shape: torch.Size([2, 2, 1, 4])
k.shape: torch.Size([2, 2, 16, 4])
block_mask.shape: (2, 2, 1, 16)
output.shape: torch.Size([2, 2, 1, 4])


In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # Helps surface actual error sources


In [3]:
import torch
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# --- Config ---
B, H, S, D = 2, 2, 1, 4
page_size = 1
max_decode_len = 8  # decode 8 tokens per sequence
device = torch.device('cuda')

# --- Allocate flat K/V cache ---
KV = B * max_decode_len
k_cache = torch.zeros((1, H, KV, D), device=device)
v_cache = torch.zeros((1, H, KV, D), device=device)

# --- Initialize empty page_table and sequence_table ---
page_table = torch.full((B, max_decode_len), -1, dtype=torch.int32, device=device)
sequence_table = {i: [] for i in range(B)}

# --- LogicalBlock class for test ---
class LogicalBlock:
    def __init__(self, pid):
        self.physical_block_id = pid
        self.token_count = 0
        self.status = "gpu"

# --- Assign token ---
def assign_token(seq_id, token_idx, k_vec, v_vec):
    logical_block_id = token_idx // page_size
    offset = token_idx % page_size
    pid = seq_id * max_decode_len + token_idx  # flat block assignment (mock logic)

    # Record block if first time
    if len(sequence_table[seq_id]) <= logical_block_id:
        block = LogicalBlock(pid)
        sequence_table[seq_id].append(block)

    flat_idx = pid * page_size + offset
    k_cache[0, :, flat_idx, :] = k_vec.squeeze(1)
    v_cache[0, :, flat_idx, :] = v_vec.squeeze(1)

    page_table[seq_id, logical_block_id] = pid
    sequence_table[seq_id][logical_block_id].token_count += 1


# --- Build blockmask ---
def build_blockmask_batched(page_tables, seq_ids, num_query_tokens, total_tokens_written, page_size):
    B = len(seq_ids)
    H = k_cache.shape[1]
    max_blocks = total_tokens_written // page_size
    device = k_cache.device

    kv_num_blocks = torch.full((B, H, 1), 0, dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, seq_id in enumerate(seq_ids):
        l2p = page_tables[seq_id]
        valid = l2p != -1
        pids = l2p[valid]
        T = pids.numel()

        kv_num_blocks[b] = torch.full((H, 1), T, dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :T] = pids

    block_mask = BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num_blocks,
        kv_indices=kv_indices,
        full_kv_num_blocks=None,
        full_kv_indices=None,
        BLOCK_SIZE=(num_query_tokens, page_size),
        mask_mod=None,
        seq_lengths=(num_query_tokens, total_tokens_written),
    )

    return block_mask


# --- Decode loop ---
for t in range(max_decode_len):
    x = torch.randn(B, S, 3*H * D, device=device)
    q, k, v = x.split([H*D, H*D, H*D], dim=-1)
    q = q.view(B, S, H, D).transpose(1, 2)
    k = k.view(B, S, H, D).transpose(1, 2)
    v = v.view(B, S, H, D).transpose(1, 2)

    for b in range(B):
        assign_token(b, t, k[b], v[b])

    block_mask = build_blockmask_batched(
        page_tables={b: page_table[b] for b in range(B)},
        seq_ids=list(range(B)),
        num_query_tokens=S,
        total_tokens_written=KV,
        page_size=page_size,
    )

    out = flex_attention(q, k_cache, v_cache, block_mask=block_mask)
    print(f"Step {t} | out.shape: {out.shape}")


Step 0 | out.shape: torch.Size([2, 2, 1, 4])
Step 1 | out.shape: torch.Size([2, 2, 1, 4])
Step 2 | out.shape: torch.Size([2, 2, 1, 4])
Step 3 | out.shape: torch.Size([2, 2, 1, 4])
Step 4 | out.shape: torch.Size([2, 2, 1, 4])
Step 5 | out.shape: torch.Size([2, 2, 1, 4])
Step 6 | out.shape: torch.Size([2, 2, 1, 4])
Step 7 | out.shape: torch.Size([2, 2, 1, 4])


In [5]:
import torch
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# Config
B, H, S, D = 2, 2, 1, 4
page_size = 1
max_decode_len = 8
KV_CAPACITY = 8  # triggers eviction after 8 blocks
device = torch.device('cuda')

KV = KV_CAPACITY * page_size
k_cache = torch.zeros((1, H, KV, D), device=device)
v_cache = torch.zeros((1, H, KV, D), device=device)
page_table = torch.full((B, max_decode_len), -1, dtype=torch.int32, device=device)
sequence_table = {i: [] for i in range(B)}

class LogicalBlock:
    def __init__(self, pid):
        self.physical_block_id = pid
        self.token_count = 0
        self.status = "gpu"

class EvictionPolicy:
    def __init__(self, capacity):
        self.capacity = capacity
        self.active = []

    def touch(self, block):
        if block in self.active:
            self.active.remove(block)
        self.active.append(block)

    def evict_if_needed(self):
        if len(self.active) > self.capacity:
            evicted = self.active.pop(0)
            evicted.status = "cpu"
            print(f"Evicted pid={evicted.physical_block_id}")
            return evicted
        return None

eviction_policy = EvictionPolicy(KV_CAPACITY)


next_pid = [0]
def assign_token(seq_id, token_idx, k_vec, v_vec):
    lid = token_idx // page_size
    offset = token_idx % page_size
    pid = seq_id * max_decode_len + token_idx

    if len(sequence_table[seq_id]) <= lid:
        block = LogicalBlock(pid)
        eviction_policy.evict_if_needed()
        pid = next_pid[0] % KV_CAPACITY
        next_pid[0] += 1
        eviction_policy.touch(block)
        sequence_table[seq_id].append(block)

    block = sequence_table[seq_id][lid]
    flat_idx = pid * page_size + offset
    k_cache[0, :, flat_idx, :] = k_vec.squeeze(1)
    v_cache[0, :, flat_idx, :] = v_vec.squeeze(1)
    page_table[seq_id, lid] = pid
    block.token_count += 1
    block.status = "gpu"
    eviction_policy.touch(block)

def build_blockmask_batched(page_tables, seq_ids, num_q, total_kv, page_size):
    B, H = len(seq_ids), k_cache.shape[1]
    max_blocks = total_kv // page_size
    kv_num_blocks = torch.zeros((B, H, 1), dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, seq_id in enumerate(seq_ids):
        pids = page_tables[seq_id][page_tables[seq_id] != -1]
        kv_num_blocks[b] = torch.full((H, 1), len(pids), dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :len(pids)] = pids

    return BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, None, None, (num_q, page_size), None, (num_q, total_kv))

# Decode loop
for t in range(max_decode_len):
    x = torch.randn(B, S, 3 * H * D, device=device)
    q, k, v = x.split([H * D] * 3, dim=-1)
    q, k, v = [z.view(B, S, H, D).transpose(1, 2) for z in (q, k, v)]

    for b in range(B):
        assign_token(b, t, k[b], v[b])

    block_mask = build_blockmask_batched({b: page_table[b] for b in range(B)}, list(range(B)), S, KV, page_size)
    out = flex_attention(q, k_cache, v_cache, block_mask=block_mask)
    print(f"Step {t} | out.shape: {out.shape}")


Step 0 | out.shape: torch.Size([2, 2, 1, 4])
Step 1 | out.shape: torch.Size([2, 2, 1, 4])
Step 2 | out.shape: torch.Size([2, 2, 1, 4])
Step 3 | out.shape: torch.Size([2, 2, 1, 4])
Evicted pid=0
Step 4 | out.shape: torch.Size([2, 2, 1, 4])
Evicted pid=8
Evicted pid=1
Step 5 | out.shape: torch.Size([2, 2, 1, 4])
Evicted pid=9
Evicted pid=2
Step 6 | out.shape: torch.Size([2, 2, 1, 4])
Evicted pid=10
Evicted pid=3
Step 7 | out.shape: torch.Size([2, 2, 1, 4])


In [None]:
# Re-import after code execution environment reset
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# --- Helper functions ---
def precompute_freqs_cis(seq_len, n_elem, base=10000, dtype=torch.float16):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype, device="cuda")

def apply_rotary_emb(x, freqs_cis):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    return x_out2.flatten(3).type_as(x)

# --- Blockmask builder ---
def build_blockmask_batched(page_tables, seq_ids, num_q, total_kv, page_size, k_cache):
    B, H = len(seq_ids), k_cache.shape[1]
    # max_blocks = total_kv // page_size
    max_blocks = max((page_tables[seq_id] != -1).sum().item() for seq_id in seq_ids)
    max_blocks = min(max_blocks, k_cache.shape[2] // page_size)
    device = k_cache.device

    kv_num_blocks = torch.zeros((B, H, 1), dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)
    max_valid_pid = k_cache.shape[2] - 1  # e.g. 3 for shape (1, 2, 4, 4)
    for b, seq_id in enumerate(seq_ids):
        pids = page_tables[seq_id][page_tables[seq_id] != -1]
        kv_num_blocks[b] = torch.full((H, 1), len(pids), dtype=torch.int32, device=device)
        safe_pids = pids[pids <= max_valid_pid]
        for h in range(H):
            # kv_indices[b, h, 0, :len(pids)] = pids
            kv_indices[b, h, 0, :safe_pids.numel()] = safe_pids


    return BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, None, None, (num_q, page_size), None, (num_q, total_kv))

# --- Batched Attention Layer ---
class SwappablePagedAttentionLayerBatched(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa, max_pos=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa
        self.block_size = spa.page_size

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)

        self.freqs_cis = precompute_freqs_cis(seq_len=max_pos, n_elem=self.head_dim, dtype=dtype)

    def forward(self, x, seq_ids, token_idxs):
        B, S, _ = x.shape
        kv_size = self.n_heads * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size] * 3, dim=-1)

        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        # for b, t in enumerate(token_idxs):
        #     freqs = self.freqs_cis[t].unsqueeze(0)
        #     q[b] = apply_rotary_emb(q[b].unsqueeze(0), freqs)
        #     k[b] = apply_rotary_emb(k[b].unsqueeze(0), freqs)

        q = q.clone()
        k = k.clone()
        for b, t in enumerate(token_idxs):
            freqs = self.freqs_cis[t]  # (H, D//2, 2)
            q_b = apply_rotary_emb(q[b:b+1], freqs)  # returns (1, H, S, D)
            k_b = apply_rotary_emb(k[b:b+1], freqs)
            q[b:b+1] = q_b
            k[b:b+1] = k_b



        for b, (sid, tidx) in enumerate(zip(seq_ids, token_idxs)):
            self.spa.assign_token(sid, tidx, k[b], v[b])

        block_mask = build_blockmask_batched(
            page_tables=self.spa.page_table,
            seq_ids=seq_ids,
            num_q=S,
            total_kv=self.spa.max_flat_idx_written,
            page_size=self.block_size,
            k_cache=self.spa.k_cache
        )
        print("== DEBUG INFO ==")
        print("max_flat_idx_written:", self.spa.max_flat_idx_written)
        print("q.shape:", q.shape)
        print("k_cache.shape:", self.spa.k_cache.shape)
        print("block_mask shape:", block_mask.shape)
        print("block_mask.kv_indices.max():", block_mask.kv_indices.max().item())
        print("k_cache max valid idx:", self.spa.k_cache.shape[2] - 1)

        out = flex_attention(q, self.spa.k_cache, self.spa.v_cache, block_mask=block_mask)
        return self.wo(out.transpose(1, 2).contiguous().view(B, S, -1))


In [14]:
def apply_rotary_emb(x, freqs_cis):
    # x: (B, H, S, D)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (..., D/2, 2)
    freqs_cis = freqs_cis.unsqueeze(1).unsqueeze(2)    # (1, 1, 1, D/2, 2)

    # Broadcastable multiplication
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)

    return x_out2.flatten(3).type_as(x)


In [15]:
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import BlockMask, flex_attention

# --- Config ---
B, H, D = 2, 2, 4                   # batch, heads, head_dim
T = 6                              # tokens per sequence
page_size = 1
num_blocks = 4                     # total physical cache blocks
hidden_dim = H * D
dtype = torch.float16
device = "cuda"

# --- Flat KV buffers ---
k_cache = torch.zeros((1, H, num_blocks * page_size, D), device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# --- Real PageTable and Manager ---
pt = PageTableWithSwap(num_blocks, page_size, H, D, device=device, k_cache=k_cache, v_cache=v_cache)
manager = KVCacheManager(page_table=pt)

# --- Real SwappablePagedAttention ---
spa = SwappablePagedAttention(
    kv_cache_manager=manager,
    k_cache=k_cache,
    v_cache=v_cache,
    page_table_tensor=torch.full((B, T), -1, dtype=torch.int32, device=device),
    page_size=page_size,
    layer_id=0
)

# --- The Layer ---
layer = SwappablePagedAttentionLayerBatched(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=dtype,
    spa=spa
)

# --- Decode loop ---
outputs = []
for t in range(T):
    x_t = torch.randn(B, 1, hidden_dim, device=device, dtype=dtype)
    out_t = layer(x_t, seq_ids=list(range(B)), token_idxs=[t] * B)  # (B, 1, D_model)
    outputs.append(out_t)

final_output = torch.cat(outputs, dim=1)  # (B, T, D_model)

# --- Done ---
print("✅ Final output shape:", final_output.shape)
print("✅ No NaNs:", not torch.isnan(final_output).any().item())


kv_indices max: 3
kv_indices unique: tensor([2, 3], device='cuda:0', dtype=torch.int32)
== DEBUG INFO ==
max_flat_idx_written: 4
q.shape: torch.Size([2, 2, 1, 4])
k_cache.shape: torch.Size([1, 2, 4, 4])
block_mask shape: (2, 2, 1, 4)


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [12]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [19]:
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# === CONFIG ===
B, H, D = 2, 2, 4                # batch size, num heads, head dim
T = 6                           # tokens per sequence
page_size = 1
num_blocks = 4                  # small capacity to force evictions
hidden_dim = H * D
dtype = torch.float16
device = "cuda"

# === FLAT CACHE ===
k_cache = torch.zeros((1, H, num_blocks * page_size, D), device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# === PAGE TABLE + MANAGER ===
pt = PageTableWithSwap(num_blocks, page_size, H, D, device=device, k_cache=k_cache, v_cache=v_cache)
manager = KVCacheManager(page_table=pt)

# === SWAPPABLE PAGER ===
spa = SwappablePagedAttention(
    kv_cache_manager=manager,
    k_cache=k_cache,
    v_cache=v_cache,
    page_table_tensor=torch.full((B, T), -1, dtype=torch.int32, device=device),
    page_size=page_size,
    layer_id=0
)

# === ROTARY UTILS ===
def precompute_freqs_cis(seq_len, n_elem, base=10000, dtype=torch.float16):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype, device="cuda")

def apply_rotary_emb(x, freqs_cis):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.unsqueeze(1).unsqueeze(2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    return x_out2.flatten(3).type_as(x)

# === PATCHED BLOCKMASK BUILDER ===
def build_blockmask_batched(page_tables, seq_ids, num_q, total_kv, page_size, k_cache):
    B, H = len(seq_ids), k_cache.shape[1]
    device = k_cache.device

    max_blocks = k_cache.shape[2] // page_size  # 🧠 max physical blocks, not tokens
    kv_num_blocks = torch.zeros((B, H, 1), dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, sid in enumerate(seq_ids):
        raw_pids = page_tables[sid]
        valid_pids = raw_pids[(raw_pids >= 0) & (raw_pids < max_blocks)]
        num_valid = valid_pids.numel()

        kv_num_blocks[b] = torch.full((H, 1), num_valid, dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :num_valid] = valid_pids[:max_blocks]  # clamp to capacity
            kv_indices[b, h, 0, num_valid:] = -1  # pad with -1

    assert (kv_indices < max_blocks).all() | (kv_indices == -1).all(), "🟥 kv_indices out of range"

    return BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num_blocks,
        kv_indices=kv_indices,
        full_kv_num_blocks=None,
        full_kv_indices=None,
        BLOCK_SIZE=(num_q, page_size),
        mask_mod=None,
        seq_lengths=(num_q, total_kv),
    )


In [20]:
# === PATCHED ATTENTION LAYER USING spa.query ===
class SwappablePagedAttentionLayerBatched(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa, max_pos=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa
        self.block_size = spa.page_size

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.freqs_cis = precompute_freqs_cis(max_pos, head_dim, dtype=dtype)

    def forward(self, x, seq_ids, token_idxs):
        B, S, _ = x.shape
        qkv = self.wqkv(x)
        q, k, v = qkv.split(self.hidden_dim, dim=-1)
        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S, D)
        k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = q.clone(), k.clone()
        for b, t in enumerate(token_idxs):
            freqs = self.freqs_cis[t]
            q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
            k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)

        for b, (sid, tidx) in enumerate(zip(seq_ids, token_idxs)):
            self.spa.assign_token(sid, tidx, k[b], v[b])

        block_mask = build_blockmask_batched(
            page_tables=self.spa.page_table,
            seq_ids=seq_ids,
            num_q=S,
            total_kv=self.spa.max_flat_idx_written,
            page_size=self.block_size,
            k_cache=self.spa.k_cache
        )

        print("\n🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION")
        flat_kv_indices = block_mask.kv_indices.view(-1)
        invalid = flat_kv_indices[(flat_kv_indices < 0) | (flat_kv_indices >= self.spa.k_cache.shape[2])]
        print("🟥 Found invalid indices:", invalid)
        print("✅ Unique indices:", torch.unique(flat_kv_indices))
        print("✅ Max valid:", self.spa.k_cache.shape[2] - 1)

        out = self.spa.query(q, block_mask)  # ✅ Use SPA query for retrieval
        return self.wo(out.transpose(1, 2).contiguous().view(B, S, -1))


# === RUN TEST ===
layer = SwappablePagedAttentionLayerBatched(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=dtype,
    spa=spa
)

outputs = []
for t in range(T):
    x_t = torch.randn(B, 1, hidden_dim, device=device, dtype=dtype)
    out_t = layer(x_t, seq_ids=list(range(B)), token_idxs=[t] * B)
    outputs.append(out_t)

final_out = torch.cat(outputs, dim=1)
print("✅ Final output shape:", final_out.shape)



🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0',
       dtype=torch.int32)
✅ Unique indices: tensor([-1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max valid: 3
Calling flex_attention with:
  q.shape: torch.Size([2, 2, 1, 4])
  k.shape: torch.Size([1, 2, 4, 4])
  v.shape: torch.Size([1, 2, 4, 4])

🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int32)
✅ Unique indices: tensor([-1,  0,  1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max valid: 3
Calling flex_attention with:
  q.shape: torch.Size([2, 2, 1, 4])
  k.shape: torch.Size([1, 2, 4, 4])
  v.shape: torch.Size([1, 2, 4, 4])
sent
sent

🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1], device='cuda:0', dtype=torch.int32)
✅ Unique indices: tensor([-1,  0,  1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max 

In [26]:
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, BlockMask

# === CONFIG ===
B, H, D = 2, 2, 4
T = 6
page_size = 1
num_blocks = 4
hidden_dim = H * D
dtype = torch.float16
device = "cuda"

# === FLAT CACHE ===
k_cache = torch.zeros((1, H, num_blocks * page_size, D), device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# === PAGE TABLE + MANAGER ===
pt = PageTableWithSwap(num_blocks, page_size, H, D, device=device, k_cache=k_cache, v_cache=v_cache)
manager = KVCacheManager(page_table=pt)

# === SWAPPABLE PAGER ===
spa = SwappablePagedAttention(
    kv_cache_manager=manager,
    k_cache=k_cache,
    v_cache=v_cache,
    page_table_tensor=torch.full((B, T), -1, dtype=torch.int32, device=device),
    page_size=page_size,
    layer_id=0
)

# === ROTARY UTILS ===
def precompute_freqs_cis(seq_len, n_elem, base=10000, dtype=torch.float16):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype, device="cuda")

def apply_rotary_emb(x, freqs_cis):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.unsqueeze(1).unsqueeze(2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    return x_out2.flatten(3).type_as(x)

# === PATCHED CAUSAL BLOCKMASK BUILDER ===
def build_blockmask_batched(page_tables, seq_ids, token_idxs, num_q, total_kv, page_size, k_cache):
    B, H = len(seq_ids), k_cache.shape[1]
    device = k_cache.device

    max_blocks = k_cache.shape[2] // page_size
    kv_num_blocks = torch.zeros((B, H, 1), dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, sid in enumerate(seq_ids):
        raw_pids = page_tables[sid]
        max_token_idx = token_idxs[b]
        max_block_idx = max_token_idx // page_size

        causal_pids = raw_pids[:max_block_idx + 1]
        valid_pids = causal_pids[(causal_pids >= 0) & (causal_pids < max_blocks)]
        num_valid = valid_pids.numel()

        kv_num_blocks[b] = torch.full((H, 1), num_valid, dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :num_valid] = valid_pids[:max_blocks]
            kv_indices[b, h, 0, num_valid:] = -1

    assert (kv_indices < max_blocks).all() | (kv_indices == -1).all(), "🟥 kv_indices out of range"

    return BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num_blocks,
        kv_indices=kv_indices,
        full_kv_num_blocks=None,
        full_kv_indices=None,
        BLOCK_SIZE=(num_q, page_size),
        mask_mod=None,
        seq_lengths=(num_q, total_kv),
    )

# === ATTENTION LAYER ===
class SwappablePagedAttentionLayerBatched(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa, max_pos=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa
        self.block_size = spa.page_size

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.freqs_cis = precompute_freqs_cis(max_pos, head_dim, dtype=dtype)

    def forward(self, x, seq_ids, token_idxs):
        B, S, _ = x.shape
        qkv = self.wqkv(x)
        q, k, v = qkv.split(self.hidden_dim, dim=-1)
        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = q.clone(), k.clone()
        for b, t in enumerate(token_idxs):
            freqs = self.freqs_cis[t]
            q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
            k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)

        for b, (sid, tidx) in enumerate(zip(seq_ids, token_idxs)):
            self.spa.assign_token(sid, tidx, k[b], v[b])

        block_mask = build_blockmask_batched(
            page_tables=self.spa.page_table,
            seq_ids=seq_ids,
            token_idxs=token_idxs,
            num_q=S,
            total_kv=self.spa.max_flat_idx_written,
            page_size=self.block_size,
            k_cache=self.spa.k_cache
        )

        print("\n🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION")
        flat_kv_indices = block_mask.kv_indices.view(-1)
        invalid = flat_kv_indices[(flat_kv_indices < 0) | (flat_kv_indices >= self.spa.k_cache.shape[2])]
        print("🟥 Found invalid indices:", invalid)
        print("✅ Unique indices:", torch.unique(flat_kv_indices))
        print("✅ Max valid:", self.spa.k_cache.shape[2] - 1)

        out = self.spa.query(q, block_mask)
        return self.wo(out.transpose(1, 2).contiguous().view(B, S, -1))

# === RUN TEST ===
layer = SwappablePagedAttentionLayerBatched(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=dtype,
    spa=spa
)

outputs = []
for t in range(T):
    x_t = torch.randn(B, 1, hidden_dim, device=device, dtype=dtype)
    out_t = layer(x_t, seq_ids=list(range(B)), token_idxs=[t] * B)
    outputs.append(out_t)

final_out = torch.cat(outputs, dim=1)
print("✅ Final output shape:", final_out.shape)



🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0',
       dtype=torch.int32)
✅ Unique indices: tensor([-1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max valid: 3
Calling flex_attention with:
  q.shape: torch.Size([2, 2, 1, 4])
  k.shape: torch.Size([1, 2, 4, 4])
  v.shape: torch.Size([1, 2, 4, 4])

🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int32)
✅ Unique indices: tensor([-1,  0,  1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max valid: 3
Calling flex_attention with:
  q.shape: torch.Size([2, 2, 1, 4])
  k.shape: torch.Size([1, 2, 4, 4])
  v.shape: torch.Size([1, 2, 4, 4])
sent
sent

🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1], device='cuda:0', dtype=torch.int32)
✅ Unique indices: tensor([-1,  0,  1,  2,  3], device='cuda:0', dtype=torch.int32)
✅ Max 

In [49]:
from torch.nn.attention.flex_attention import BlockMask
class SwappablePagedAttention:
    def __init__(self, kv_cache_manager, k_cache, v_cache, page_table_tensor, page_size, layer_id):
        self.kv_cache_manager = kv_cache_manager
        self.k_cache = k_cache
        self.v_cache = v_cache
        self.page_table = page_table_tensor
        self.page_size = page_size
        self.layer_id = layer_id
        self.max_flat_idx_written = 0

    def assign_token(self, seq_id, token_idx, key_vec, value_vec):
        logical_block_id = token_idx // self.page_size
        offset = token_idx % self.page_size

        # ⬇️ Expand page_table if the logical block id exceeds current columns
        if logical_block_id >= self.page_table.shape[1]:
            pad = logical_block_id - self.page_table.shape[1] + 1
            new_table = torch.full(
                (self.page_table.shape[0], self.page_table.shape[1] + pad),
                fill_value=-1,
                dtype=self.page_table.dtype,
                device=self.page_table.device
            )
            new_table[:, :self.page_table.shape[1]] = self.page_table
            self.page_table = new_table

        # Get or grow the block list for this sequence/layer
        block_list = self.kv_cache_manager.sequence_table[seq_id][self.layer_id]
        if len(block_list) <= logical_block_id:
            for _ in range(logical_block_id - len(block_list) + 1):
                new_block = LogicalBlock()
                pid = self.kv_cache_manager.page_table.allocate_block(new_block)
                self.kv_cache_manager.page_table.eviction_policy.register(new_block)
                block_list.append(new_block)

        block = block_list[logical_block_id]

        # If it's not on GPU, swap it in
        if block.status == "cpu":
            self.kv_cache_manager.page_table.swap_to_gpu(block)

        pid = block.physical_block_id
        # Sanity check:
        if block.physical_block_id is None:
            raise RuntimeError(f"Token {token_idx}: block still has no physical_block_id after swap.")

        flat_idx = pid * self.page_size + offset

        # key_vec, value_vec: [H, S, D]
        # token_idx: scalar index into S

        assert key_vec.ndim == 3, f"Expected key_vec to be [H, S, D], got {key_vec.shape}"
        assert value_vec.ndim == 3, f"Expected value_vec to be [H, S, D], got {value_vec.shape}"
        assert 0 <= token_idx < key_vec.shape[1], f"token_idx {token_idx} out of bounds for seq_len={key_vec.shape[1]}"

        self.k_cache[0, :, flat_idx, :] = key_vec[:, token_idx, :]  # [H, D]
        self.v_cache[0, :, flat_idx, :] = value_vec[:, token_idx, :]


        block.token_count += 1
        if seq_id >= self.page_table.shape[0] or logical_block_id >= self.page_table.shape[1]:
            print(f"🚨 Resize needed: seq_id={seq_id}, logical_block_id={logical_block_id}, current shape={self.page_table.shape}")

        self.page_table[seq_id, logical_block_id] = pid

        self.max_flat_idx_written = max(self.max_flat_idx_written, flat_idx + 1)

        print(f"[assign_token] token_idx={token_idx} | flat_idx={flat_idx}")
        print(f"[assign_token] key_vec.shape: {key_vec.shape}")
        print(f"[assign_token] storing slice: {key_vec[:, token_idx, :].shape}")


    def build_blockmask(self, num_query_tokens, total_tokens_written):
        from torch.nn.attention.flex_attention import BlockMask

        assert total_tokens_written <= self.k_cache.shape[2], "KV length exceeds cache capacity"
        B, H = 1, self.k_cache.shape[1]
        device = self.k_cache.device

        logical_to_physical = self.page_table[0]  # single seq_id only
        valid = logical_to_physical != -1
        pids = logical_to_physical[valid]
        total_blocks = pids.numel()

        kv_indices = torch.full((B, H, 1, total_blocks), -1, dtype=torch.int32, device=device)
        kv_num_blocks = torch.full((B, H, 1), total_blocks, dtype=torch.int32, device=device)

        for h in range(H):
            for i in range(total_blocks):
                kv_indices[0, h, 0, i] = pids[i].item()
        print(">> max_flat_idx_written:", self.max_flat_idx_written)
        print(">> kv_num_blocks.shape:", kv_num_blocks.shape)
        print(">> kv_num_blocks:", kv_num_blocks)
        print(">> kv_indices.shape:", kv_indices.shape)
        print(">> kv_indices:", kv_indices)

        block_mask = BlockMask.from_kv_blocks(
            kv_num_blocks=kv_num_blocks,
            kv_indices=kv_indices,
            full_kv_num_blocks=None,
            full_kv_indices=None,
            BLOCK_SIZE=(num_query_tokens, self.page_size),
            mask_mod=None,
            seq_lengths=(num_query_tokens, total_tokens_written)  # <- exact KV length
        )

        # Crop block mask to match token counts
        return block_mask._adjust(num_query_tokens, total_tokens_written)

    def query(self, q, block_mask):
        from torch.nn.attention.flex_attention import flex_attention

        # 🔁 Ensure required blocks are on GPU
        for block_list in self.kv_cache_manager.sequence_table.values():
            for block in block_list[self.layer_id]:
                if block.status == "cpu":
                    self.kv_cache_manager.page_table.swap_to_gpu(block)
                    print(f"⏪ Swapping in block with pid={block.physical_block_id}")

        print("Calling flex_attention with:")
        print("  q.shape:", q.shape)
        print("  k.shape:", self.k_cache.shape)
        print("  v.shape:", self.v_cache.shape)


        return flex_attention(
            q, self.k_cache, self.v_cache,
            block_mask=block_mask,
            score_mod=None
        )



In [50]:
# === ATTENTION LAYER TO DROP-IN ===
class SwappablePagedAttentionLayerBatched(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa, max_pos=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa
        self.block_size = spa.page_size

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)
        # Inside __init__ of SwappablePagedAttentionLayerBatched
        self.freqs_cis = precompute_freqs_cis(max_pos, head_dim, dtype=dtype).to("cuda")


    def forward(self, x, seq_ids, token_idxs):
        B, S, _ = x.shape
        x = x.to(dtype=self.wqkv.weight.dtype)
        qkv = self.wqkv(x)
        q, k, v = qkv.split(self.hidden_dim, dim=-1)
        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = q.clone(), k.clone()
        # for b, t in enumerate(token_idxs):
        #     freqs = self.freqs_cis[t]
        #     q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
        #     k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)
        for b, t in enumerate(token_idxs):
            if t >= self.freqs_cis.shape[0]:
                raise ValueError(f"Token index {t} exceeds precomputed freqs_cis range")
            freqs = self.freqs_cis[t]
            q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
            k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)


        for b, (sid, tidx) in enumerate(zip(seq_ids, token_idxs)):
            self.spa.assign_token(sid, tidx, k[b], v[b])

        block_mask = build_blockmask_batched(
            page_tables=self.spa.page_table,
            seq_ids=seq_ids,
            token_idxs=token_idxs,
            num_q=S,
            total_kv=self.spa.max_flat_idx_written,
            page_size=self.block_size,
            k_cache=self.spa.k_cache
        )

        print("\n🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION")
        flat_kv_indices = block_mask.kv_indices.view(-1)
        invalid = flat_kv_indices[(flat_kv_indices < 0) | (flat_kv_indices >= self.spa.k_cache.shape[2])]
        print("🟥 Found invalid indices:", invalid)
        print("✅ Unique indices:", torch.unique(flat_kv_indices))
        print("✅ Max valid:", self.spa.k_cache.shape[2] - 1)

        out = self.spa.query(q, block_mask)
        return self.wo(out.transpose(1, 2).contiguous().view(B, S, -1))

In [51]:
def apply_rotary_emb(x, freqs_cis_t):
    B, H, S, D = x.shape
    assert D % 2 == 0, "Head dim must be even"

    x = x.float().reshape(B, H, S, D // 2, 2)
    freqs_cis_t = freqs_cis_t.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1,1,1,D//2,2)

    re = x[..., 0] * freqs_cis_t[..., 0] - x[..., 1] * freqs_cis_t[..., 1]
    im = x[..., 1] * freqs_cis_t[..., 0] + x[..., 0] * freqs_cis_t[..., 1]

    return torch.stack((re, im), dim=-1).flatten(3).type_as(x)


In [52]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch import nn
from torch.nn.attention.flex_attention import BlockMask

# === CONFIG ===
B, H, D = 1, 12, 64  # GPT-2 small: 12 heads, 64 dim each
page_size = 1
num_blocks = 32
hidden_dim = H * D
dtype = torch.float16
device = "cuda"

# === FLAT KV CACHE ===
k_cache = torch.zeros((1, H, num_blocks * page_size, D), device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# === PAGE TABLE + MANAGER ===
pt = PageTableWithSwap(num_blocks, page_size, H, D, device=device, k_cache=k_cache, v_cache=v_cache)
manager = KVCacheManager(page_table=pt)

# === SPA INSTANCE ===
spa = SwappablePagedAttention(
    kv_cache_manager=manager,
    k_cache=k_cache,
    v_cache=v_cache,
    page_table_tensor=torch.full((1, 1024), -1, dtype=torch.int32, device=device),  # GPT-2 max length
    page_size=page_size,
    layer_id=0
)

# === PATCHED LAYER ===
patched_layer = SwappablePagedAttentionLayerBatched(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=dtype,
    spa=spa
)

class GPT2CompatibleAttention(nn.Module):
    def __init__(self, swappable_layer, layer_id):
        super().__init__()
        self.attn_core = swappable_layer
        self.layer_id = layer_id
        self.token_idx = 0  # Tracks where we are in decoding

    def forward(self, hidden_states, layer_past=None, attention_mask=None, **kwargs):
        B, S, D = hidden_states.shape
        seq_ids = list(range(B))
        token_idxs = [self.token_idx] * B
        self.token_idx += S

        return (self.attn_core(hidden_states, seq_ids, token_idxs), None)


# === LOAD GPT-2 ===
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().cuda()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.config.use_cache = False  # disable past_key_values

# === PATCH FIRST LAYER ===
model.transformer.h[0].attn = GPT2CompatibleAttention(patched_layer, layer_id=0)

# === TEST INFERENCE ===
prompt = "The sky is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    out = model(input_ids).logits
    print("✅ Output shape:", out.shape)


[assign_token] token_idx=0 | flat_idx=31
[assign_token] key_vec.shape: torch.Size([12, 3, 64])
[assign_token] storing slice: torch.Size([12, 64])

🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION
🟥 Found invalid indices: tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,

In [56]:
batch_size = 2  # or whatever batch size you're using
num_blocks = 128
page_size = 1  # assuming 1 token per block
num_heads = 12
head_dim = 64
device = "cuda"

spa = SwappablePagedAttention(
    num_seqs=batch_size,         # ✅ supports `seq_id` in [0, batch_size - 1]
    num_blocks=num_blocks,
    page_size=page_size,
    num_heads=num_heads,
    head_dim=head_dim,
    device=device,
)

TypeError: SwappablePagedAttention.__init__() got an unexpected keyword argument 'num_seqs'

In [54]:
prompt = "Once upon a time"
max_new_tokens = 10
batch_size = 2

# Set pad_token to eos_token (common workaround for GPT2)
tokenizer.pad_token = tokenizer.eos_token

# Tokenize input prompt (same prompt duplicated for simplicity)
input_ids = tokenizer([prompt] * batch_size, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)

# Create attention mask if needed
attention_mask = (input_ids != tokenizer.pad_token_id).long()

# Decode state
generated = input_ids.clone()


In [55]:
with torch.no_grad():
    model(input_ids=input_ids, attention_mask=attention_mask)


[assign_token] token_idx=3 | flat_idx=28
[assign_token] key_vec.shape: torch.Size([12, 4, 64])
[assign_token] storing slice: torch.Size([12, 64])
🚨 Resize needed: seq_id=1, logical_block_id=3, current shape=torch.Size([1, 1024])


IndexError: index 1 is out of bounds for dimension 0 with size 1