In [1]:
# 1968

import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.2-3B-Instruct")
# BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.1-8B")
TOP_K = int(os.environ.get("HF_TOP_K", "20"))
ATTN_IMPL_ENV = os.environ.get("HF_ATTN_IMPL", "").strip()  # e.g., "flash_attention_2" if available

In [3]:
DEVICE = torch.device('mps')
DTYPE = torch.float16

In [4]:
hf_token = os.environ.get("HF_TOKEN", None)
from_kwargs = {
    "dtype": DTYPE,
    "device_map": None,           # keep single process; move to one device below
    "low_cpu_mem_usage": True,
    "token": hf_token,
}
if ATTN_IMPL_ENV:
    from_kwargs["attn_implementation"] = ATTN_IMPL_ENV

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    **from_kwargs,
).to(DEVICE) # type: ignore

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, token=hf_token)

PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (
    tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
)

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: c274c583-c99a-4e1a-b8d8-dcafa04377c1)')' thrown while requesting HEAD https://huggingface.co/meta-llama/Llama-3.1-8B/resolve/main/config.json
Retrying in 1s [Retry 1/5].
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.94it/s]


## Prefill Stage

In [5]:
from typing import Any


prompts_str = [
    "Explanation of speculative decoding in simple terms",
    "This is a terse haiku about Apple MLX",
    "Why is the sky blue",
]

prompts: list[list[int]] = [tokenizer.encode(prompt) for prompt in prompts_str]

In [14]:
# def mask_and_pos_ids(L: list[int]): # TODO: Should this be a list or a tensor?
#     max_len = max(L)

#     attention_mask = torch.zeros((len(L), max_len), dtype=torch.long, device=DEVICE)
#     for i, l in enumerate(L):
#         attention_mask[i, max_len - l:] = 1

#     # position_ids / cache_position: 0..L_i-1 for non-pad tokens, 0 for pads
#     # This works for both absolute and RoPE-style position handling.
#     position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp_min(0)
#     position_ids = position_ids.masked_fill(attention_mask == 0, 0)

#     return attention_mask, position_ids 

def mask_and_pos_ids(L: list[int]): # TODO: Should this be a list or a tensor?
    max_len = max(L)

    attention_mask = torch.zeros((len(L), max_len), dtype=torch.long, device=DEVICE)
    for i, l in enumerate(L):
        attention_mask[i, max_len - l:] = 1

    # position_ids / cache_position: 0..L_i-1 for non-pad tokens, 0 for pads
    # This works for both absolute and RoPE-style position handling.
    position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp_min(0)
    position_ids = position_ids.masked_fill(attention_mask == 0, 0)

    return attention_mask, position_ids 

@torch.inference_mode()
def prefill(model: nn.Module, prompts: list[list[int]]):
    max_len = max(len(x) for x in prompts)
    padded = [[PAD_ID] * (max_len - len(prompt)) + prompt for prompt in prompts]  # is pad id 0 correct?
    x = torch.tensor(padded, dtype=torch.long, device=DEVICE)

    cache = DynamicCache(
        config=model.config, 
    )

    attention_mask, position_ids = mask_and_pos_ids([len(x) for x in prompts])

    outputs = model(
        input_ids=x, 
        # attention_mask=attention_mask,
        past_key_values=cache, 
        use_cache=True,
        # position_ids=position_ids,
    )

    return outputs.past_key_values

cache = prefill(model, prompts)

In [None]:
# def zero_cache(cache: DynamicCache, lengths: list[int]):
#     assert cache.layers[0].keys is not None and cache.layers[0].values is not None
#     B = cache.layers[0].keys.shape[0]

#     # Prepare destination cache
#     dst = DynamicCache()

#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].values
#         assert K is not None and V is not None

#         _, H, S_old, D = K.shape

#         K_new = K.new_zeros((B, H, S_old, D))
#         V_new = V.new_zeros((B, H, S_old, D))

#         # Copy per row
#         for i in range(B):
#             keep = int(S_keep[i].item())
#             if keep == 0:
#                 continue
#             # surviving tokens are the first 'keep' positions (earliest..latest-rollback)
#             K_src = K[i, :, :keep, :]
#             V_src = V[i, :, :keep, :]

#             # right-aligned → write to the right, pad on the left implicitly
#             start = S_new - keep
#             K_new[i, :, start:, :] = K_src
#             V_new[i, :, start:, :] = V_src
#             # print(K_new[i, 0, :, 0])

#         # print(dst.layers[layer].keys[i, 0, :, 0])
#         dst.update(K_new, V_new, layer)
#         print(dst.layers[layer].keys[0, 0, :, 0])

#     return dst

# rollback_values = torch.tensor([3, 0, 1], dtype=torch.long, device=model.device)
# # print(cache.layers[0].keys[:, 0, :, 0])
# new_cache = rollback_dynamic_per_row_simple(cache, rollback_values)
# print(new_cache.layers[0].keys[:, 0, :, 0])

In [15]:
print(cache.layers[0].keys[:, 0, :, 0])

tensor([[ 5.5508,  0.7910, -4.6953, -0.3940, -2.1562,  3.0645,  7.5547,  3.3496,
         -2.5117, -6.6289, -4.2461],
        [ 0.4468,  1.0215, -3.5273, -5.5859, -2.6172,  3.6797,  9.8750,  3.3047,
         -4.4180, -8.2578, -4.9453],
        [ 5.5508,  0.7910, -4.6953, -5.8672, -1.6426, -0.2014,  7.8828,  2.9688,
         -1.5879, -8.2578, -5.5938]], device='mps:0', dtype=torch.float16)


In [16]:
for i, layer in enumerate(cache.layers):
    print(f"Layer {i}:")
    print(f"  Keys shape: {layer.keys.shape}")
    print(f"  Values shape: {layer.values.shape}")

Layer 0:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 1:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 2:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 3:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 4:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 5:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 6:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 7:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 8:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 9:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 10:
  Keys shape: torch.Size([3, 8, 11, 128]

## Pure Decode (Just A Sanity Check)

In [17]:
# @torch.inference_mode()
# def generate_step(model: nn.Module, cache: StaticCache, tokens: list[int], lengths: torch.LongTensor):
#     x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).view(-1, 1)

#     B = lengths.size(0)
#     S_prev = cache.layers[0].keys.shape[2]   # current cache time dim (same for all rows)
#     S_tot  = S_prev + 1                      # we add one token

#     # Mask: keep only (real prompt + generated so far) + the new token, right-aligned
#     attn_mask = torch.zeros((B, S_tot), dtype=torch.long, device=DEVICE)
#     for i, l in enumerate(lengths.tolist()):
#         attn_mask[i, S_prev - l - 1:] = 1   # existing real tokens
#         attn_mask[i, -1] = 1            # the new token itself

#     # print(attn_mask)

#     # Rope positions (per row) for the NEW token are exactly L
#     pos_ids = lengths.view(B, 1)

#     # Where to write in the cache (index along the time dimension) for this step
#     # cache_pos = torch.full((B, 1), S_prev, dtype=torch.long, device=DEVICE)
    
#     outputs = model(
#         x, 
#         use_cache=True, 
#         past_key_values=cache,
#         attention_mask=attn_mask,
#         position_ids=pos_ids,      # keep rope continuity
#         # cache_position=cache_pos,  # write K/V at the next slot
#     )

#     tokens = outputs.logits.argmax(axis=-1).tolist()
#     lengths = lengths + 1
#     return tokens, lengths


@torch.inference_mode()
def generate_step(
    model: nn.Module,
    cache: DynamicCache,
    tokens: list[list[int]],
    lengths: torch.LongTensor,
):
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).view(-1, 1)

    B = lengths.size(0)
    S_prev = cache.layers[0].keys.shape[2]

    # Use int64 (not bool) on MPS; build past + current token mask
    attn_mask = torch.zeros((B, S_prev + 1), dtype=torch.long, device=DEVICE)

    starts = (S_prev - lengths).clamp_min(0)                # (B,)
    idx = torch.arange(S_prev, device=DEVICE)[None, :]      # (1, S_prev)
    attn_mask[:, :-1] = (idx >= starts[:, None]).to(torch.long)
    attn_mask[:, -1] = 1

    print(attn_mask.shape)
    print(attn_mask)

    pos_ids = lengths.view(B, 1)  # new token's RoPE position (0..L_i)

    out = model(
        input_ids=x,
        past_key_values=cache,
        use_cache=True,
        attention_mask=attn_mask,  # int mask
        position_ids=pos_ids,
    )

    next_tok = out.logits[:, -1].argmax(dim=-1).tolist()
    lengths = lengths + 1
    return [[t] for t in next_tok], lengths




suffix_text = ':'
tokens = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 3
full_tokens = prompts

lengths = torch.tensor([len(x) for x in prompts], dtype=torch.long, device=model.device)

for _ in tqdm(range(20)):
    full_tokens: list[list[int]] = [x + y for x, y in zip(full_tokens, tokens)]

    tokens, lengths = generate_step(model, cache, tokens, lengths)

for i in range(3):
    print(tokenizer.decode(full_tokens[i]))

  0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([3, 12])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')


 10%|█         | 2/20 [00:00<00:04,  4.22it/s]

torch.Size([3, 13])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')
torch.Size([3, 14])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')


 20%|██        | 4/20 [00:00<00:02,  5.81it/s]

torch.Size([3, 15])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')
torch.Size([3, 16])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')


 30%|███       | 6/20 [00:01<00:02,  6.56it/s]

torch.Size([3, 17])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')
torch.Size([3, 18])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')


 40%|████      | 8/20 [00:01<00:01,  6.85it/s]

torch.Size([3, 19])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')
torch.Size([3, 20])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')


 50%|█████     | 10/20 [00:01<00:01,  6.95it/s]

torch.Size([3, 21])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')
torch.Size([3, 22])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')


 60%|██████    | 12/20 [00:01<00:01,  6.98it/s]

torch.Size([3, 23])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')
torch.Size([3, 24])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')


 70%|███████   | 14/20 [00:02<00:00,  6.95it/s]

torch.Size([3, 25])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]], device='mps:0')
torch.Size([3, 26])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]], device='mps:0')


 80%|████████  | 16/20 [00:02<00:00,  6.87it/s]

torch.Size([3, 27])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]], device='mps:0')
torch.Size([3, 28])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='mps:0')


 90%|█████████ | 18/20 [00:02<00:00,  6.84it/s]

torch.Size([3, 29])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1]], device='mps:0')
torch.Size([3, 30])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]], device='mps:0')


100%|██████████| 20/20 [00:03<00:00,  6.46it/s]

torch.Size([3, 31])
tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1]], device='mps:0')
<|begin_of_text|>Explanation of speculative decoding in simple terms: Speculative decoding is a technique used in computer architecture to optimize the execution of instructions. It involves
<|begin_of_text|>This is a terse haiku about Apple MLX: a machine learning framework for iOS and macOS. It’s a powerful tool for developers to create machine
<|begin_of_text|>Why is the sky blue: The sky blue?
The sky appears blue due to a phenomenon called Rayleigh scattering. This occurs





In [10]:
print(cache.layers[0].keys[:, 0, :, 0])
cache.crop(-2)
print(cache.layers[0].keys[:, 0, :, 0])

tensor([[ 5.5508,  5.5508,  5.5508,  0.4468,  1.1914, -3.6875, -7.4219, -2.4609,
          3.2773,  6.6562,  3.4648, -0.9922, -3.0820, -2.2852,  0.6128,  2.9492,
          2.5723, -0.1682, -2.7539, -2.8086, -0.2803,  2.5039,  2.9883,  0.7236,
         -2.2070, -3.1074, -1.1514,  1.8633,  3.1660,  1.5566, -1.4824],
        [ 0.4468,  1.0215, -3.5273, -5.5859, -2.6172,  3.6797,  9.8750,  3.3047,
         -4.4180, -8.2578, -4.9453,  0.4590,  5.4766,  5.6875, -1.4619, -6.6797,
         -4.9961,  0.1621,  4.3750,  7.0586,  0.8062, -5.2578, -5.4805, -1.4014,
          4.8320,  7.1250,  2.4062, -4.3164, -5.2734, -3.2305,  3.5684],
        [ 5.5508,  5.5508,  5.5508,  5.5508,  5.5508,  0.4468,  1.3496, -3.5273,
         -4.6562, -3.6094,  4.7266,  3.7832,  1.9541, -1.0459, -3.0820, -2.2852,
          0.6128,  2.9492,  2.5723, -0.1682, -2.7539, -2.8086, -0.2803,  2.5039,
          2.9883,  0.7236, -2.2070, -3.1074, -1.1514,  1.8633,  3.1660]],
       device='mps:0', dtype=torch.float16)
tensor(

## Experimental Verify Step

In [11]:
@torch.inference_mode()
def verify(model: nn.Module, cache: DynamicCache, tokens: list[list[int]], draft_logits: np.array):
    assert all([len(x) == len(tokens[0]) for x in tokens])
    print(tokens)
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE)

    outputs = model(
        x, 
        use_cache=True, 
        past_key_values=cache
    )

    print(cache.layers[0].keys.shape)


suffix_text = ' but this is a really useful test.'
suffix_tokens = tokenizer.encode(suffix_text)[1:]

verify(model, cache, [suffix_tokens for _ in range(3)], None)

[[719, 420, 374, 264, 2216, 5505, 1296, 13], [719, 420, 374, 264, 2216, 5505, 1296, 13], [719, 420, 374, 264, 2216, 5505, 1296, 13]]
torch.Size([3, 8, 37, 128])


In [12]:
def rollback_dynamic_per_row_simple(cache: DynamicCache, r: torch.LongTensor):
    """
    Roll back r[i] tokens for each batch row i in a DynamicCache.
    Produce a new DynamicCache with time dim S_new = max_i (S_old - r[i]).
    """
    assert cache.layers[0].keys is not None and cache.layers[0].values is not None
    assert (r >= 0).all().item()
    B = cache.layers[0].keys.shape[0]
    device = cache.layers[0].keys.device
    dtype = cache.layers[0].keys.dtype

    # Prepare destination cache
    dst = DynamicCache()

    for layer in range(len(cache)):
        K = cache.layers[layer].keys
        V = cache.layers[layer].values
        assert K is not None and V is not None

        _, H, S_old, D = K.shape

        # Per-row keep lengths
        S_keep = S_old - r  # shape (B,)
        S_new = int(S_keep.max().item())

        K_new = K.new_zeros((B, H, S_new, D))
        V_new = V.new_zeros((B, H, S_new, D))

        # Copy per row
        for i in range(B):
            keep = int(S_keep[i].item())
            if keep == 0:
                continue
            # surviving tokens are the first 'keep' positions (earliest..latest-rollback)
            K_src = K[i, :, :keep, :]
            V_src = V[i, :, :keep, :]

            # right-aligned → write to the right, pad on the left implicitly
            start = S_new - keep
            K_new[i, :, start:, :] = K_src
            V_new[i, :, start:, :] = V_src
            # print(K_new[i, 0, :, 0])

        # print(dst.layers[layer].keys[i, 0, :, 0])
        dst.update(K_new, V_new, layer)
        print(dst.layers[layer].keys[0, 0, :, 0])

    return dst

rollback_values = torch.tensor([3, 0, 1], dtype=torch.long, device=model.device)
# print(cache.layers[0].keys[:, 0, :, 0])
new_cache = rollback_dynamic_per_row_simple(cache, rollback_values)
print(new_cache.layers[0].keys[:, 0, :, 0])

tensor([ 0.0000,  0.0000,  0.0000,  5.5508,  5.5508,  5.5508,  0.4468,  1.1914,
        -3.6875, -7.4219, -2.4609,  3.2773,  6.6562,  3.4648, -0.9922, -3.0820,
        -2.2852,  0.6128,  2.9492,  2.5723, -0.1682, -2.7539, -2.8086, -0.2803,
         2.5039,  2.9883,  0.7236, -2.2070, -3.1074, -1.1514,  1.8633,  3.1660,
        -3.1875,  2.7129,  5.3828,  3.5391, -2.0801], device='mps:0',
       dtype=torch.float16)
tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,         nan,         nan,
                nan, -3.2978e-03,  4.0664e+00,  1.1885e+00, -5.6953e+00,
        -4.9805e+00, -1.0137e+00,  6.3086e+00,  4.4688e+00,  2.4062e+00,
        -3.2754e+00, -4.4336e+00, -2.0840e+00,  2.2461e+00,  4.5742e+00,
         2.6562e+00, -1.6953e+00, -4.4375e+00, -3.0879e+00,  1.0508e+00,
         4.1836e+00,  3.4336e+00, -4.9414e-01, -3.9199e+00, -3.6934e+00,
        -8.9844e-02,  3.5410e+00, -5.7617e+00, -2.2109e+00,  3.4258e+00,
         3.7461e+00,  6.2744e-01], device='mps:0', dtype=torch.float16)

In [13]:
# def rollback_dynamic_per_row(cache: DynamicCache, r: torch.LongTensor):
#     """
#     Roll back r[i] tokens for each batch row i in a DynamicCache.
#     Returns a *new* DynamicCache with time dim equal to max(L_i - r_i).
#     """
#     assert cache.layers[0].keys is not None
#     B = cache.layers[0].keys.shape[0]
#     device = cache.layers[0].keys.device
#     uniq = torch.unique(r)

#     # Build empty destination (we'll fill layer by layer)
#     dst = DynamicCache()
#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].keys
#         assert K is not None
#         assert V is not None

#         B, H, S_old, D = K.shape
#         # Compute new per-row lengths after rollback
#         L_after = torch.full((B,), S_old, dtype=torch.long, device=device) - r
#         S_new = int(L_after.max().item())

#         K_new = K.new_zeros(B, H, S_new, D)
#         V_new = V.new_zeros(B, H, S_new, D)

#         # For each rollback bucket, crop and scatter back
#         for rv in uniq.tolist():
#             idx = (r == rv).nonzero(as_tuple=False).squeeze(-1)
#             if idx.numel() == 0: 
#                 continue
#             # Select sub-batch, crop rv tokens from the right
#             K_sub = K.index_select(0, idx)
#             V_sub = V.index_select(0, idx)

#             # physical crop for this bucket
#             S_keep = S_old - rv
#             K_sub = K_sub[..., :S_keep, :]
#             V_sub = V_sub[..., :S_keep, :]

#             # place back into dst; zero-filling beyond S_keep keeps them "rolled back"
#             K_new.index_copy_(0, idx, torch.nn.functional.pad(K_sub, (0,0,0,0,0, S_new - S_keep)))
#             V_new.index_copy_(0, idx, torch.nn.functional.pad(V_sub, (0,0,0,0,0, S_new - S_keep)))

#         dst.update(K_new, V_new, layer)

#     return dst
