a decode loop

In [5]:
import torch

In [20]:
class LogicalBlock:
    def __init__(self, physical_block_id=None):
        self.physical_block_id = physical_block_id  # None if on CPU
        self.status = "gpu"  # or "cpu"
        self.token_count = 0

In [13]:
class EvictionPolicy:
    def register(self, block): pass
    def unregister(self, block): pass
    def notify_use(self, block): pass
    def evict(self): raise NotImplementedError

In [14]:
from collections import deque

class LRUEvictionPolicy(EvictionPolicy):
    def __init__(self):
        self.lru_queue = deque()

    def register(self, block):
        self.lru_queue.appendleft(block)

    def unregister(self, block):
        try:
            self.lru_queue.remove(block)
        except ValueError:
            pass

    def notify_use(self, block):
        try:
            self.lru_queue.remove(block)
            self.lru_queue.appendleft(block)
        except ValueError:
            pass  # block not tracked (e.g., on CPU)

    def evict(self):
        while self.lru_queue:
            victim = self.lru_queue.pop()
            if victim.status == "gpu":
                return victim
        raise RuntimeError("No GPU-resident blocks to evict.")

In [10]:
class PageTable:
    def __init__(self, num_blocks, block_size, num_heads, head_dim, device="cuda", eviction_policy=None):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device

        # Flat GPU memory pool
        self.gpu_k = torch.zeros((num_blocks, block_size, num_heads, head_dim), device=device)
        self.gpu_v = torch.zeros_like(self.gpu_k)

        # CPU-stored evicted blocks: keyed by LogicalBlock object
        self.cpu_k = {} # logical_block -> CPU tensor
        self.cpu_v = {}

        self.free_list = list(range(num_blocks))
        self.eviction_policy = eviction_policy or LRUEvictionPolicy()

    def allocate_block(self, logical_block):
        """Assign a physical block to a logical block."""
        if not self.free_list:
            self._evict_one_block()

        pid = self.free_list.pop()
        logical_block.physical_block_id = pid
        logical_block.status = "gpu"
        self.eviction_policy.register(logical_block)
        return pid

    def _evict_one_block(self):
        """Evict a block found on the GPU based on eviction policy."""
        victim = self.eviction_policy.evict()
        self.swap_to_cpu(victim)      

    def swap_to_cpu(self, logical_block):
        """Move a block from GPU to CPU and free its physical block ID."""
        pid = logical_block.physical_block_id
        if pid is None or logical_block.status != "gpu":
            raise RuntimeError("Block is not on GPU or already swapped.")

        self.cpu_k[logical_block] = self.gpu_k[pid].clone().cpu()
        # print(self.cpu_k[logical_block])
        self.cpu_v[logical_block] = self.gpu_v[pid].clone().cpu()

        logical_block.physical_block_id = None
        logical_block.status = "cpu"
        self.eviction_policy.unregister(logical_block)
        self.free_list.append(pid)

    def swap_to_gpu(self, logical_block):
        """Move a block from CPU to GPU, assigning a new physical block ID."""
        if logical_block not in self.cpu_k:
            raise RuntimeError("Block not found in CPU cache.")

        if not self.free_list:
            self._evict_one_block()

        pid = self.free_list.pop()
        self.gpu_k[pid] = self.cpu_k.pop(logical_block).to(self.device)
        self.gpu_v[pid] = self.cpu_v.pop(logical_block).to(self.device)

        logical_block.physical_block_id = pid
        logical_block.status = "gpu"
        self.eviction_policy.register(logical_block)

    def resolve_block(self, logical_block):
        """Return the (K, V) tensors, swapping in if needed."""
        if logical_block.status == "cpu":
            self.swap_to_gpu(logical_block)

        self.eviction_policy.notify_use(logical_block)
        pid = logical_block.physical_block_id
        return self.gpu_k[pid], self.gpu_v[pid]

    def free_block(self, logical_block):
        """Free both GPU and CPU copies of the block."""
        if logical_block.status == "gpu":
            pid = logical_block.physical_block_id
            self.gpu_k[pid].zero_()
            self.gpu_v[pid].zero_()
            self.free_list.append(pid)
            self.eviction_policy.unregister(logical_block)
        elif logical_block.status == "cpu":
            del self.cpu_k[logical_block]
            del self.cpu_v[logical_block]

        logical_block.status = "freed"
        logical_block.physical_block_id = None

In [16]:
from collections import defaultdict

class KVCacheManager:
    def __init__(self, page_table):
        self.page_table = page_table
        self.block_size = page_table.block_size
        self.sequence_table = defaultdict(lambda: defaultdict(list))  # seq_id → layer_id → [LogicalBlock]

    def _get_active_block(self, seq_id, layer_id):
        blocks = self.sequence_table[seq_id][layer_id]
        if blocks and blocks[-1].token_count < self.block_size:
            return blocks[-1]

        # Allocate a new logical block
        new_block = LogicalBlock()
        self.page_table.allocate_block(new_block)
        self.sequence_table[seq_id][layer_id].append(new_block)
        return new_block

    def write_token(self, seq_id, layer_id, key_vec, value_vec):
        block = self._get_active_block(seq_id, layer_id)
        if block.token_count >= self.block_size:
            raise RuntimeError("Attempted to write to full block.")

        # Ensure it's on GPU before writing
        k_buf, v_buf = self.page_table.resolve_block(block) # will swap in if needed

        idx = block.token_count
        k_buf[idx] = key_vec
        v_buf[idx] = value_vec
        block.token_count += 1

    def prefill(self, seq_id, layer_id, k_list, v_list):
        for k, v in zip(k_list, v_list):
            self.write_token(seq_id, layer_id, k, v)

    def yield_k_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            k_buf, _ = self.page_table.resolve_block(block)
            yield k_buf[:block.token_count]

    def yield_v_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            _, v_buf = self.page_table.resolve_block(block)
            yield v_buf[:block.token_count]

    def yield_kv_blocks(self, seq_id, layer_id):
        for block in self.sequence_table[seq_id][layer_id]:
            k_buf, v_buf = self.page_table.resolve_block(block)
            yield k_buf[:block.token_count], v_buf[:block.token_count]

    def free(self, seq_id):
        for layer_blocks in self.sequence_table[seq_id].values():
            for block in layer_blocks:
                self.page_table.free_block(block)
        del self.sequence_table[seq_id]

In [11]:
class PageTableWithSwap(PageTable):
    def __init__(self, num_blocks, block_size, num_heads, head_dim, device="cuda", k_cache=None, v_cache=None):
        super().__init__(num_blocks, block_size, num_heads, head_dim, device=device)
        self.k_cache = k_cache  # flat buffer: (1, H, max_tokens, D)
        self.v_cache = v_cache
        self.cpu_k = {}  # LogicalBlock → CPU tensor
        self.cpu_v = {}

    def swap_to_cpu(self, block):
        pid = block.physical_block_id
        if pid is None or block.status != "gpu":
            raise RuntimeError("Block is not on GPU")

        bs = self.block_size
        self.cpu_k[block] = self.k_cache[0, :, pid * bs : (pid + 1) * bs, :].clone().cpu()
        self.cpu_v[block] = self.v_cache[0, :, pid * bs : (pid + 1) * bs, :].clone().cpu()

        block.physical_block_id = None
        block.status = "cpu"
        self.eviction_policy.unregister(block)
        self.free_list.append(pid)
        print("sent")

    def swap_to_gpu(self, block):
        if block.status != "cpu":
            return

        if not self.free_list:
            victim = self.eviction_policy.evict()
            self.swap_to_cpu(victim)

        pid = self.free_list.pop()
        bs = self.block_size

        self.k_cache[0, :, pid * bs : (pid + 1) * bs, :] = self.cpu_k.pop(block).to(self.k_cache.device)
        self.v_cache[0, :, pid * bs : (pid + 1) * bs, :] = self.cpu_v.pop(block).to(self.v_cache.device)

        block.physical_block_id = pid
        block.status = "gpu"
        self.eviction_policy.register(block)
        print(f"retrieved to pid={pid}")


In [1]:
import torch.nn as nn
# === ATTENTION LAYER TO DROP-IN EDITED FOR DECODE===
class SwappablePagedAttentionLayerBatched(nn.Module):
    def __init__(self, hidden_dim, n_heads, head_dim, dtype, spa, max_pos=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.spa = spa
        self.block_size = spa.page_size

        self.wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False, device="cuda", dtype=dtype)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False, device="cuda", dtype=dtype)
        # Inside __init__ of SwappablePagedAttentionLayerBatched
        self.freqs_cis = precompute_freqs_cis(max_pos, head_dim, dtype=dtype).to("cuda")


    def forward(self, x, seq_ids, token_idxs):
        B, S, _ = x.shape
        is_decode = (x.shape[1] == 1)  # True if S == 1
        x = x.to(dtype=self.wqkv.weight.dtype)
        qkv = self.wqkv(x)
        # q, k, v = qkv.split(self.hidden_dim, dim=-1)

        # q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        # k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        # v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        print("qkv.shape before reshape:", qkv.shape)  # [B, S, 3 * hidden_dim]
        print("⛏ RAW qkv.shape:", qkv.shape)
        qkv = qkv.view(B, S, 3, self.n_heads, self.head_dim)
        print("🔧 After view →", qkv.shape)
        qkv = qkv.permute(0, 2, 3, 1, 4).contiguous()
        print("🔄 After permute →", qkv.shape)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        print("✅ Final k.shape:", k.shape)

        print("q.shape after reshape:", q.shape)  # [B, H, S, D]
        print("v.shape after reshape:", v.shape)  # [B, H, S, D]
        print("k.shape after reshape:", k.shape)  # [B, H, S, D]

        q, k = q.clone(), k.clone()
        # for b, t in enumerate(token_idxs):
        #     freqs = self.freqs_cis[t]
        #     q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
        #     k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)
        # for b, t in enumerate(token_idxs):
        #     if t >= self.freqs_cis.shape[0]:
        #         raise ValueError(f"Token index {t} exceeds precomputed freqs_cis range")
        #     freqs = self.freqs_cis[t]
        #     q[b:b+1] = apply_rotary_emb(q[b:b+1], freqs)
        #     k[b:b+1] = apply_rotary_emb(k[b:b+1], freqs)
        if S == 1:
            # Decode step: rotate with single frequency (1 per token)
            for b, t in enumerate(token_idxs):
                freqs = self.freqs_cis[t]
                # q[b] = apply_rotary_emb(q[b], freqs)  # q[b]: [H, 1, D]
                # k[b] = apply_rotary_emb(k[b], freqs)  # k[b]: [H, 1, D]
                q[b] = apply_rotary_emb(q[b].unsqueeze(0), freqs)[0]  # ✅ [1, H, S, D] → [H, S, D]
                k[b] = apply_rotary_emb(k[b].unsqueeze(0), freqs)[0]

                print("k[b].shape after applying rotary:", k[b].shape)  # [B, H, S, D]

        else:
            # Prefill step: rotate with freqs for [0:S]
            freqs = self.freqs_cis[:S].to(x.device)
            q = apply_rotary_emb(q, freqs)
            k = apply_rotary_emb(k, freqs)
            print("k shape after else in S block: ", k.shape)
        
        for b in range(B):
            assert k[b].ndim == 3 and k[b].shape[2] == self.head_dim, f"BAD k[b] shape: {k[b].shape}"
            assert q[b].ndim == 3 and q[b].shape[2] == self.head_dim, f"BAD q[b] shape: {q[b].shape}"


        for b, (sid, tidx) in enumerate(zip(seq_ids, token_idxs)):
            print("assign_token → k[b].shape:", k[b].shape)  # Should be [H, S, D]

            self.spa.assign_token(sid, tidx, k[b], v[b])

        block_mask = build_blockmask_batched(
            page_tables=self.spa.page_table,
            seq_ids=seq_ids,
            token_idxs=token_idxs,
            num_q=S,
            total_kv=self.spa.max_flat_idx_written,
            page_size=self.block_size,
            k_cache=self.spa.k_cache
        )

        print("\n🔍 DEBUG BLOCK_MASK BEFORE FLEX_ATTENTION")
        flat_kv_indices = block_mask.kv_indices.view(-1)
        invalid = flat_kv_indices[(flat_kv_indices < 0) | (flat_kv_indices >= self.spa.k_cache.shape[2])]
        print("🟥 Found invalid indices:", invalid)
        print("✅ Unique indices:", torch.unique(flat_kv_indices))
        print("✅ Max valid:", self.spa.k_cache.shape[2] - 1)

        out = self.spa.query(q, block_mask)
        return self.wo(out.transpose(1, 2).contiguous().view(B, S, -1))

In [2]:
# === PATCHED CAUSAL BLOCKMASK BUILDER ===
def build_blockmask_batched(page_tables, seq_ids, token_idxs, num_q, total_kv, page_size, k_cache):
    B, H = len(seq_ids), k_cache.shape[1]
    device = k_cache.device

    max_blocks = k_cache.shape[2] // page_size
    kv_num_blocks = torch.zeros((B, H, 1), dtype=torch.int32, device=device)
    kv_indices = torch.full((B, H, 1, max_blocks), -1, dtype=torch.int32, device=device)

    for b, sid in enumerate(seq_ids):
        raw_pids = page_tables[sid]
        max_token_idx = token_idxs[b]
        max_block_idx = max_token_idx // page_size

        causal_pids = raw_pids[:max_block_idx + 1]
        valid_pids = causal_pids[(causal_pids >= 0) & (causal_pids < max_blocks)]
        num_valid = valid_pids.numel()

        kv_num_blocks[b] = torch.full((H, 1), num_valid, dtype=torch.int32, device=device)
        for h in range(H):
            kv_indices[b, h, 0, :num_valid] = valid_pids[:max_blocks]
            kv_indices[b, h, 0, num_valid:] = -1

    assert (kv_indices < max_blocks).all() | (kv_indices == -1).all(), "🟥 kv_indices out of range"

    return BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num_blocks,
        kv_indices=kv_indices,
        full_kv_num_blocks=None,
        full_kv_indices=None,
        BLOCK_SIZE=(num_q, page_size),
        mask_mod=None,
        seq_lengths=(num_q, total_kv),
    )


In [3]:
## JUST FOR THE DECODE LOOP. DELETE FOR ABOVE TEST CASE RUN. JUST CHANGED ASSIGN TOKEN

from torch.nn.attention.flex_attention import BlockMask
class SwappablePagedAttention:
    def __init__(self, kv_cache_manager, k_cache, v_cache, page_table_tensor, page_size, layer_id):
        self.kv_cache_manager = kv_cache_manager
        self.k_cache = k_cache
        self.v_cache = v_cache
        self.page_table = page_table_tensor
        self.page_size = page_size
        self.layer_id = layer_id
        self.max_flat_idx_written = 0

    def assign_token(self, seq_id, token_idx, key_vec, value_vec):
        logical_block_id = token_idx // self.page_size
        offset = token_idx % self.page_size

        # ⬇️ Expand page_table if the logical block id exceeds current columns
        if logical_block_id >= self.page_table.shape[1]:
            pad = logical_block_id - self.page_table.shape[1] + 1
            new_table = torch.full(
                (self.page_table.shape[0], self.page_table.shape[1] + pad),
                fill_value=-1,
                dtype=self.page_table.dtype,
                device=self.page_table.device
            )
            new_table[:, :self.page_table.shape[1]] = self.page_table
            self.page_table = new_table

        # Get or grow the block list for this sequence/layer
        block_list = self.kv_cache_manager.sequence_table[seq_id][self.layer_id]
        if len(block_list) <= logical_block_id:
            for _ in range(logical_block_id - len(block_list) + 1):
                new_block = LogicalBlock()
                pid = self.kv_cache_manager.page_table.allocate_block(new_block)
                self.kv_cache_manager.page_table.eviction_policy.register(new_block)
                block_list.append(new_block)

        block = block_list[logical_block_id]

        # If it's not on GPU, swap it in
        if block.status == "cpu":
            self.kv_cache_manager.page_table.swap_to_gpu(block)

        pid = block.physical_block_id
        # Sanity check:
        if block.physical_block_id is None:
            raise RuntimeError(f"Token {token_idx}: block still has no physical_block_id after swap.")

        flat_idx = pid * self.page_size + offset

        # key_vec, value_vec: [H, S, D]
        # token_idx: scalar index into S

        # assert key_vec.ndim == 3, f"Expected key_vec to be [H, S, D], got {key_vec.shape}"
        # assert value_vec.ndim == 3, f"Expected value_vec to be [H, S, D], got {value_vec.shape}"
        # assert 0 <= token_idx < key_vec.shape[1], f"token_idx {token_idx} out of bounds for seq_len={key_vec.shape[1]}"

        # self.k_cache[0, :, flat_idx, :] = key_vec[:, token_idx, :]  # [H, D]
        # self.v_cache[0, :, flat_idx, :] = value_vec[:, token_idx, :]
        print("[DEBUG] key_vec.shape:", key_vec.shape)
        print("[DEBUG] token_idx:", token_idx)
        # print("[DEBUG] key_vec.shape:", key_vec.shape)
        # print("[DEBUG] k_cache slice shape:", self.k_cache[0, :, flat_idx, :].shape)
        # Note: key_vec is [H, S, D], where usually S = 1 (decode) or up to T (prefill)

        if key_vec.shape[1] == 1:
            print("[DEBUG] decode mode → key_vec[:, 0, :].shape:", key_vec[:, 0, :].shape)
            self.k_cache[0, :, flat_idx, :] = key_vec[:, 0, :]  # decode
            self.v_cache[0, :, flat_idx, :] = value_vec[:, 0, :]
        else:
            # token_idx is the global token index, but we only have local slice (0..S-1) — fix this:
            # local_idx = token_idx % key_vec.shape[1]
            local_idx = token_idx % key_vec.shape[1]
            print("[DEBUG] key_vec[:, local_idx, :].shape:", key_vec[:, local_idx, :].shape)
            self.k_cache[0, :, flat_idx, :] = key_vec[:, local_idx, :]
            self.v_cache[0, :, flat_idx, :] = value_vec[:, local_idx, :]
            


        print("hi")




        block.token_count += 1
        if seq_id >= self.page_table.shape[0] or logical_block_id >= self.page_table.shape[1]:
            print(f"🚨 Resize needed: seq_id={seq_id}, logical_block_id={logical_block_id}, current shape={self.page_table.shape}")

        self.page_table[seq_id, logical_block_id] = pid

        self.max_flat_idx_written = max(self.max_flat_idx_written, flat_idx + 1)

        print(f"[assign_token] token_idx={token_idx} | flat_idx={flat_idx}")
        if key_vec.shape[1] == 1:
            print(f"[assign_token] storing slice: {key_vec[:, 0, :].shape}")
        else:
            local_idx = token_idx % key_vec.shape[1]
            print(f"[assign_token] storing slice: {key_vec[:, local_idx, :].shape}")

        # print(f"[assign_token] storing slice: {key_vec[:, token_idx, :].shape}")


    def build_blockmask(self, num_query_tokens, total_tokens_written):
        from torch.nn.attention.flex_attention import BlockMask

        assert total_tokens_written <= self.k_cache.shape[2], "KV length exceeds cache capacity"
        B, H = 1, self.k_cache.shape[1]
        device = self.k_cache.device

        logical_to_physical = self.page_table[0]  # single seq_id only
        valid = logical_to_physical != -1
        pids = logical_to_physical[valid]
        total_blocks = pids.numel()

        kv_indices = torch.full((B, H, 1, total_blocks), -1, dtype=torch.int32, device=device)
        kv_num_blocks = torch.full((B, H, 1), total_blocks, dtype=torch.int32, device=device)

        for h in range(H):
            for i in range(total_blocks):
                kv_indices[0, h, 0, i] = pids[i].item()
        print(">> max_flat_idx_written:", self.max_flat_idx_written)
        print(">> kv_num_blocks.shape:", kv_num_blocks.shape)
        print(">> kv_num_blocks:", kv_num_blocks)
        print(">> kv_indices.shape:", kv_indices.shape)
        print(">> kv_indices:", kv_indices)

        block_mask = BlockMask.from_kv_blocks(
            kv_num_blocks=kv_num_blocks,
            kv_indices=kv_indices,
            full_kv_num_blocks=None,
            full_kv_indices=None,
            BLOCK_SIZE=(num_query_tokens, self.page_size),
            mask_mod=None,
            seq_lengths=(num_query_tokens, total_tokens_written)  # <- exact KV length
        )

        # Crop block mask to match token counts
        return block_mask._adjust(num_query_tokens, total_tokens_written)

    def query(self, q, block_mask):
        from torch.nn.attention.flex_attention import flex_attention

        # 🔁 Ensure required blocks are on GPU
        for block_list in self.kv_cache_manager.sequence_table.values():
            for block in block_list[self.layer_id]:
                if block.status == "cpu":
                    self.kv_cache_manager.page_table.swap_to_gpu(block)
                    print(f"⏪ Swapping in block with pid={block.physical_block_id}")

        print("Calling flex_attention with:")
        print("  q.shape:", q.shape)
        print("  k.shape:", self.k_cache.shape)
        print("  v.shape:", self.v_cache.shape)


        return flex_attention(
            q, self.k_cache, self.v_cache,
            block_mask=block_mask,
            score_mod=None
        )



In [6]:
def precompute_freqs_cis(seq_len, n_elem, base=10000, dtype=torch.float16):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype, device="cuda")

In [18]:
class GPT2CompatibleAttention(nn.Module):
    def __init__(self, swappable_layer, layer_id):
        super().__init__()
        self.attn_core = swappable_layer
        self.layer_id = layer_id
        self.token_idx = 0  # Tracks where we are in decoding

    def forward(self, hidden_states, layer_past=None, attention_mask=None, **kwargs):
        B, S, D = hidden_states.shape
        seq_ids = list(range(B))
        token_idxs = [self.token_idx] * B
        self.token_idx += S

        return (self.attn_core(hidden_states, seq_ids, token_idxs), None)

In [7]:
def apply_rotary_emb(x, freqs_cis_t):
    B, H, S, D = x.shape
    assert D % 2 == 0, "Head dim must be even"

    orig_dtype = x.dtype  # Save dtype to restore later
    x = x.float().reshape(B, H, S, D // 2, 2)
    freqs_cis_t = freqs_cis_t.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1,1,1,D//2,2)

    re = x[..., 0] * freqs_cis_t[..., 0] - x[..., 1] * freqs_cis_t[..., 1]
    im = x[..., 1] * freqs_cis_t[..., 0] + x[..., 0] * freqs_cis_t[..., 1]

    out = torch.stack((re, im), dim=-1).reshape(B, H, S, D)  # ✅ use reshape not flatten
    return out.to(dtype=orig_dtype)  # restore original dtype safely


In [21]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# --- Setup ---
device = "cuda"
batch_size = 2
max_new_tokens = 10
D = 64
H = 12
hidden_dim = H * D
num_blocks = 32
page_size = 1
dtype = torch.float16

k_cache = torch.zeros((1, H, num_blocks * page_size, D), device=device, dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
print("[CHECK] KV cache head_dim:", k_cache.shape[-1])  # Should be 64

# Load model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # Needed for padding
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(device)
model.config.use_cache = False  # Important for our SPA patch

pt = PageTableWithSwap(num_blocks, page_size, H, D, device=device, k_cache=k_cache, v_cache=v_cache)
manager = KVCacheManager(page_table=pt)

# === SPA INSTANCE ===
spa = SwappablePagedAttention(
    kv_cache_manager=manager,
    k_cache=k_cache,
    v_cache=v_cache,
    page_table_tensor=torch.full((batch_size, 1024), -1, dtype=torch.int32, device=device),  # GPT-2 max length
    page_size=page_size,
    layer_id=0
)

# === PATCHED LAYER ===
patched_layer = SwappablePagedAttentionLayerBatched(
    hidden_dim=hidden_dim,
    n_heads=H,
    head_dim=D,
    dtype=dtype,
    spa=spa
)

# === PATCH FIRST LAYER WITH CUSTOM ATTENTION ===
model.transformer.h[0].attn = GPT2CompatibleAttention(patched_layer, layer_id=0)


# --- Input Prompts ---
prompts = [
    "The sky is",
    "Once upon a time"
]
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(device)
attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

# Track lengths of each prompt (to know where decode starts)
prompt_lengths = attention_mask.sum(dim=1)

# --- Prefill (token_idx = 0 → prompt_lengths-1) ---
with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    generated = [ids.clone() for ids in input_ids]  # List of token sequences

# --- Decode Loop ---
for step in range(max_new_tokens):
    # last_tokens = torch.stack([g[prompt_lengths[i] + step - 1] for i, g in enumerate(generated)])
    # last_tokens = last_tokens.unsqueeze(1).to(device)  # (B, 1)
    # Only pass the last generated token for each sequence
    last_tokens = torch.stack([g[-1] for g in generated])  # (B,)
    last_tokens = last_tokens.unsqueeze(1)                # (B, 1)

    # Now run the model
    with torch.no_grad():
        out = model(input_ids=last_tokens)                # Correct shape: (B, 1)
        with torch.no_grad():
            out = model(input_ids=last_tokens)  # Forward 1 token per sequence

        next_logits = out.logits[:, -1, :]  # (B, vocab_size)
        next_tokens = torch.argmax(next_logits, dim=-1)  # Greedy decode (B,)

        # Append to each sequence
        for i in range(batch_size):
            generated[i] = torch.cat([generated[i], next_tokens[i:i+1]], dim=0)

# --- Decode output ---
decoded = [tokenizer.decode(g, skip_special_tokens=True) for g in generated]
print("\n=== Outputs ===")
for i, text in enumerate(decoded):
    print(f"[{i}] {text}")


[CHECK] KV cache head_dim: 64
qkv.shape before reshape: torch.Size([2, 4, 2304])
⛏ RAW qkv.shape: torch.Size([2, 4, 2304])
🔧 After view → torch.Size([2, 4, 3, 12, 64])
🔄 After permute → torch.Size([2, 3, 12, 4, 64])
✅ Final k.shape: torch.Size([2, 12, 4, 64])
q.shape after reshape: torch.Size([2, 12, 4, 64])
v.shape after reshape: torch.Size([2, 12, 4, 64])
k.shape after reshape: torch.Size([2, 12, 4, 64])
k shape after else in S block:  torch.Size([2, 12, 4, 64])
assign_token → k[b].shape: torch.Size([12, 4, 64])
[DEBUG] key_vec.shape: torch.Size([12, 4, 64])
[DEBUG] token_idx: 0
[DEBUG] key_vec[:, local_idx, :].shape: torch.Size([12, 64])
hi
[assign_token] token_idx=0 | flat_idx=31
[assign_token] storing slice: torch.Size([12, 64])
assign_token → k[b].shape: torch.Size([12, 4, 64])
[DEBUG] key_vec.shape: torch.Size([12, 4, 64])
[DEBUG] token_idx: 0
[DEBUG] key_vec[:, local_idx, :].shape: torch.Size([12, 64])
hi
[assign_token] token_idx=0 | flat_idx=30
[assign_token] storing slice: to