In [1]:
# 1968

import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Runtime configuration; every value can be overridden through an environment variable.
# Alternative HF_BASE_MODEL defaults kept for reference:
#   meta-llama/Llama-3.2-3B-Instruct
#   meta-llama/Llama-3.2-1B-Instruct
BASE_MODEL = os.getenv("HF_BASE_MODEL", "meta-llama/Llama-3.1-8B")
TOP_K = int(os.getenv("HF_TOP_K", "20"))
# Optional attention backend name, e.g. "flash_attention_2" if available; empty means "not set".
ATTN_IMPL_ENV = os.getenv("HF_ATTN_IMPL", "").strip()

In [3]:
def _select_device() -> torch.device:
    """Pick the device to run on: MPS when available, then CUDA, then CPU."""
    # torch.backends.mps is missing on older torch builds, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Single device used by every cell below; the notebook previously assumed MPS unconditionally.
DEVICE = _select_device()
# NOTE(review): float16 is kept for all devices; on a CPU fallback it may be slow — confirm if CPU matters.
DTYPE = torch.float16

In [4]:
# Hugging Face access token; None when HF_TOKEN is not set.
hf_token = os.environ.get("HF_TOKEN")

# Keyword arguments for from_pretrained(). device_map stays None (single process);
# the model is moved to one device with .to(DEVICE) below.
from_kwargs = dict(
    dtype=DTYPE,
    device_map=None,
    low_cpu_mem_usage=True,
    token=hf_token,
)
if ATTN_IMPL_ENV:
    # Only request a specific attention implementation when HF_ATTN_IMPL is non-empty.
    from_kwargs.update(attn_implementation=ATTN_IMPL_ENV)

model = (
    AutoModelForCausalLM
    .from_pretrained(BASE_MODEL, **from_kwargs)
    .to(DEVICE)  # type: ignore
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token, use_fast=True)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.31it/s]


## Prefill Stage

In [5]:
# Raw prompt strings for the batch (one batch row per prompt).
prompts_str = [
    "Why is the sky blue",
    "Explanation of speculative decoding in simple terms",
    "This is a terse haiku about Apple MLX"
]

# Token-id lists, one per prompt, in the same order as prompts_str.
prompts = list(map(tokenizer.encode, prompts_str))

In [6]:
@torch.inference_mode()
def prefill(model: nn.Module, prompts: list[list[int]], pad_id: int = 0):
    """Run one batched forward pass over `prompts` and return the filled KV cache.

    Prompts shorter than the longest one are left-padded with `pad_id` so that
    every row has the same length.

    NOTE(review): no attention mask is passed to the model, so the padding
    positions are fed through it like ordinary tokens. Confirm this is
    acceptable, or thread an attention mask through prefill and the decode
    steps together.

    Args:
        model: Called as `model(input_ids, past_key_values=..., use_cache=True)`;
            its `config` attribute is used to build the cache.
        prompts: One list of token ids per batch row.
        pad_id: Token id used for left padding. Defaults to 0.
            TODO: presumably this should come from the tokenizer; verify.

    Returns:
        The `past_key_values` object of the model output.

    Raises:
        ValueError: If `prompts` is empty.
    """
    if not prompts:
        raise ValueError("prefill() needs at least one prompt")

    max_len = max(len(prompt) for prompt in prompts)
    print(f'prefilling to max len {max_len}')
    padded = [[pad_id] * (max_len - len(prompt)) + prompt for prompt in prompts]
    x = torch.tensor(padded, dtype=torch.long, device=DEVICE)

    # A fresh, empty cache for this batch; the model call below fills it.
    cache = DynamicCache(config=model.config)

    outputs = model(x, past_key_values=cache, use_cache=True)
    return outputs.past_key_values

# Prefill the whole prompt batch once; the resulting cache is reused by the cells below.
cache = prefill(model, prompts)

prefilling to max len 11


In [7]:
# Dump the shape of the cached key and value tensors, layer by layer.
for i in range(len(cache.layers)):
    layer = cache.layers[i]
    print(
        f"Layer {i}:",
        f"  Keys shape: {layer.keys.shape}",
        f"  Values shape: {layer.values.shape}",
        sep="\n",
    )

Layer 0:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 1:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 2:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 3:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 4:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 5:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 6:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 7:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 8:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 9:
  Keys shape: torch.Size([3, 8, 11, 128])
  Values shape: torch.Size([3, 8, 11, 128])
Layer 10:
  Keys shape: torch.Size([3, 8, 11, 128]

## Pure Decode (Just A Sanity Check)

In [8]:
@torch.inference_mode()
def generate_step(model: nn.Module, cache: StaticCache, tokens: list[int]):
    """Feed `tokens` through the model on top of `cache` and return the resulting cache.

    The token tensor's shape is printed before the forward pass.

    NOTE(review): a later cell defines another `generate_step` that replaces
    this one once it runs.

    Args:
        model: Called as `model(input_ids, use_cache=True, past_key_values=cache)`.
        cache: Existing key/value cache handed to the model.
        tokens: Token ids converted to a long tensor on DEVICE without reshaping.

    Returns:
        The `past_key_values` attribute of the model output.
    """
    token_batch = torch.tensor(tokens, dtype=torch.long, device=DEVICE)
    print(token_batch.shape)

    return model(token_batch, use_cache=True, past_key_values=cache).past_key_values

# print(cache.layers[0].keys[0, 0, :32, 0])

# suffix_text = ' but'
# suffix_tokens = tokenizer.encode(suffix_text)[1:]

# cache = generate_step(model, cache, [suffix_tokens for _ in range(3)])

# print(cache.layers[0].keys[0, 0, :32, 0])

In [9]:
# @torch.inference_mode()
# def generate_step(model: nn.Module, cache: StaticCache, tokens: list[int]):
#     x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).reshape(-1, 1)
    
#     outputs = model(x, use_cache=True, past_key_values=cache)

#     tokens = outputs.logits.argmax(axis=2).tolist()
#     return tokens

# suffix_text = ':'
# tokens = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 3
# full_tokens = prompts

# for _ in tqdm(range(20)):
#     full_tokens = [x + y for x, y in zip(full_tokens, tokens)]

#     tokens = generate_step(model, cache, tokens)

# for i in range(3):
#     print(tokenizer.decode(full_tokens[i]))

In [10]:
# Print element [:, 0, :, 0] of the layer-0 keys before and after cropping.
# crop(-2) modifies `cache` itself: per the printed output, each row goes from
# 11 to 9 cached positions. Re-running this cell crops again (not idempotent).
print(cache.layers[0].keys[:, 0, :, 0])
cache.crop(-2)
print(cache.layers[0].keys[:, 0, :, 0])

tensor([[ 3.1133,  1.1787, -1.8408, -3.1660, -1.5820, -0.2014,  7.8828,  2.9688,
         -1.5879, -8.2578, -5.5938],
        [ 3.1133,  1.1787, -1.8408, -0.3940, -2.1562,  3.0645,  7.5547,  3.3496,
         -2.5117, -6.6289, -4.2461],
        [ 0.4468,  1.0215, -3.5273, -5.5859, -2.6172,  3.6797,  9.8750,  3.3047,
         -4.4180, -8.2578, -4.9453]], device='mps:0', dtype=torch.float16)
tensor([[ 3.1133,  1.1787, -1.8408, -3.1660, -1.5820, -0.2014,  7.8828,  2.9688,
         -1.5879],
        [ 3.1133,  1.1787, -1.8408, -0.3940, -2.1562,  3.0645,  7.5547,  3.3496,
         -2.5117],
        [ 0.4468,  1.0215, -3.5273, -5.5859, -2.6172,  3.6797,  9.8750,  3.3047,
         -4.4180]], device='mps:0', dtype=torch.float16)


In [11]:


@torch.inference_mode()
def generate_step(model: nn.Module, cache: DynamicCache, tokens: list[int]):
    """Run one forward pass with one token per row and return the argmax ids.

    Args:
        model: Called as `model(input_ids, use_cache=True, past_key_values=cache)`.
        cache: Key/value cache handed to the model
            (presumably extended by the call; verify against the library).
        tokens: Token ids; they are reshaped to a single column, i.e. one id per row.

    Returns:
        `logits.argmax(axis=2)` converted to nested Python lists.
    """
    column = torch.tensor(tokens, dtype=torch.long, device=DEVICE).reshape(-1, 1)

    logits = model(column, use_cache=True, past_key_values=cache).logits
    return logits.argmax(axis=2).tolist()

# Greedy-decode sanity check: append ':' to every prompt, then run 20 decode steps
# on top of the existing `cache`.
suffix_text = ':'
# [1:] drops the first id returned by encode() (the begin-of-text token visible in the decoded output).
suffix_ids = tokenizer.encode(suffix_text)[1:]
# generate_step feeds exactly one new token per row, so the suffix must be a single token.
assert len(suffix_ids) == 1, f"suffix {suffix_text!r} must encode to exactly one token"

# Batch size follows the prompt list instead of a hard-coded 3.
batch_size = len(prompts)
tokens = [[suffix_ids[0]] for _ in range(batch_size)]
full_tokens = prompts

initial_len = max([len(x) for x in prompts])
L = torch.tensor([initial_len] * batch_size, dtype=torch.long, device=model.device)

for _ in tqdm(range(20)):
    # Record the tokens about to be fed, then replace them with the model's next picks.
    full_tokens = [x + y for x, y in zip(full_tokens, tokens)]

    tokens = generate_step(model, cache, tokens)

for i in range(batch_size):
    print(tokenizer.decode(full_tokens[i]))

100%|██████████| 20/20 [00:02<00:00,  8.25it/s]

<|begin_of_text|>Why is the sky blue: "The Great Gatsby" by F. Scott Fitzgerald so popular?
The Great Gatsby is
<|begin_of_text|>Explanation of speculative decoding in simple terms: "The Secret of the Golden Flower" by Richard Wilhelm
The Secret of the Golden Flower is
<|begin_of_text|>This is a terse haiku about Apple MLX: the company, the fruit, the color, the computer, the music, the lifestyle, the





## Experimental Verify Step

In [None]:
@torch.inference_mode()
def verify(model: nn.Module, cache: DynamicCache, tokens: list[list[int]], draft_logits: np.array):
    """Feed a batch of equal-length token rows through the model on top of `cache`.

    Prints the token rows, runs the forward pass with the cache attached, then
    prints the shape of the layer-0 keys. `draft_logits` is accepted but not
    used by the body yet. Nothing is returned.

    Args:
        model: Called as `model(input_ids, use_cache=True, past_key_values=cache)`.
        cache: Key/value cache handed to the model and inspected afterwards.
        tokens: One list of token ids per batch row; all rows must have the same length.
        draft_logits: Currently unused.

    Raises:
        AssertionError: If the rows in `tokens` differ in length.
    """
    # All rows share one length exactly when the set of row lengths has at most one member.
    assert len({len(row) for row in tokens}) <= 1
    print(tokens)
    token_batch = torch.tensor(tokens, dtype=torch.long, device=DEVICE)

    model(token_batch, use_cache=True, past_key_values=cache)

    print(cache.layers[0].keys.shape)


# Feed the same multi-token suffix to all three rows of the existing cache.
suffix_text = ' but this is a really useful test.'
# Skip the first id returned by encode().
suffix_tokens = tokenizer.encode(suffix_text)[1:]

verify(model, cache, [suffix_tokens] * 3, None)

[[719, 420, 374, 264, 2216, 5505, 1296, 13], [719, 420, 374, 264, 2216, 5505, 1296, 13], [719, 420, 374, 264, 2216, 5505, 1296, 13]]
torch.Size([3, 8, 37, 128])


In [None]:
def rollback_dynamic_per_row_simple(cache: "DynamicCache", r: torch.LongTensor):
    """
    Roll back r[i] tokens for each batch row i in a DynamicCache.

    For every layer, row i keeps its first ``S_old - r[i]`` cached positions.
    The result has time dim ``S_new = max_i (S_old - r[i])``; rows that keep
    fewer positions are right-aligned and zero-filled on the left. The input
    cache is only read; a new DynamicCache is built and returned.

    Args:
        cache: Source cache. Each ``cache.layers[l].keys`` / ``.values`` is
            indexed as (batch, heads, time, head_dim).
        r: 1-D integer tensor with one non-negative rollback count per batch row.

    Returns:
        A new DynamicCache holding the rolled-back keys and values.

    Raises:
        AssertionError: If a layer has no keys/values or any r[i] is negative.
        ValueError: If ``r`` does not have one entry per batch row, or some
            r[i] is larger than the cached length.
    """
    first_keys = cache.layers[0].keys
    assert first_keys is not None and cache.layers[0].values is not None
    assert (r >= 0).all().item()
    B = first_keys.shape[0]
    if r.dim() != 1 or r.shape[0] != B:
        raise ValueError(f"r must have shape ({B},), got {tuple(r.shape)}")

    # One device-to-host transfer up front instead of one .item() per row and layer.
    rollbacks = [int(v) for v in r.tolist()]

    # Prepare destination cache
    dst = DynamicCache()

    for layer in range(len(cache)):
        K = cache.layers[layer].keys
        V = cache.layers[layer].values
        assert K is not None and V is not None

        _, H, S_old, D = K.shape

        # Per-row keep lengths
        keeps = [S_old - rb for rb in rollbacks]
        if keeps and min(keeps) < 0:
            raise ValueError(
                f"cannot roll back {max(rollbacks)} tokens: layer {layer} only caches {S_old}"
            )
        S_new = max(keeps, default=0)

        K_new = K.new_zeros((B, H, S_new, D))
        V_new = V.new_zeros((B, H, S_new, D))

        # Copy per row
        for i, keep in enumerate(keeps):
            if keep == 0:
                continue
            # Surviving tokens are the first `keep` positions; write them at the
            # right end so shorter rows are implicitly zero-padded on the left.
            start = S_new - keep
            K_new[i, :, start:, :] = K[i, :, :keep, :]
            V_new[i, :, start:, :] = V[i, :, :keep, :]

        dst.update(K_new, V_new, layer)

    return dst

# Roll back 3, 0 and 1 tokens for the three batch rows respectively.
rollback_values = torch.tensor([3, 0, 1], dtype=torch.long, device=model.device)
print(cache.layers[0].keys[:, 0, :, 0])
new_cache = rollback_dynamic_per_row_simple(cache, rollback_values)
# NOTE(review): this prints the source `cache` again rather than `new_cache` —
# confirm whether the intent is to show the source is untouched or to inspect the result.
print(cache.layers[0].keys[:, 0, :, 0])

tensor([[ 3.1133,  1.1787, -1.8408, -3.1660, -1.5820, -0.2014,  7.8828,  2.9688,
         -1.5879, -3.6562, -3.2715,  2.2988,  9.8281,  6.6172, -1.8887, -3.9707,
         -4.9141,  0.7852,  1.4707,  9.6172,  0.2461, -5.5391, -6.2148, -2.1406,
          4.5625,  9.9141,  3.2871, -5.2227, -5.4375, -3.1875,  2.7129,  5.3828,
          3.5391, -2.0801, -6.3438, -4.8164, -0.2000],
        [ 3.1133,  1.1787, -1.8408, -0.3940, -2.1562,  3.0645,  7.5547,  3.3496,
         -2.5117, -3.6562, -3.2715,  2.2988,  8.6719,  4.1367, -0.3003, -9.5156,
         -7.4766,  0.6074,  4.7383,  9.7188,  0.5723, -0.8623, -5.8867, -1.9434,
          3.7227,  4.6523,  3.5352, -5.6875, -5.4375, -3.1875,  2.7129,  5.3828,
          3.5391, -2.0801, -6.3438, -4.8164, -0.2000],
        [ 0.4468,  1.0215, -3.5273, -5.5859, -2.6172,  3.6797,  9.8750,  3.3047,
         -4.4180, -3.6562, -3.3223,  2.1973,  2.5449,  3.7520, -1.4756, -2.2715,
         -4.1016,  0.2227,  1.9541,  4.3750,  0.9512, -1.5957, -4.5586, -1.8594,

In [None]:
def rollback_dynamic_per_row(cache: "DynamicCache", r: torch.LongTensor):
    """
    Roll back r[i] tokens for each batch row i in a DynamicCache.
    Returns a *new* DynamicCache with time dim equal to max(L_i - r_i).

    Rows are processed in buckets that share the same rollback count. For each
    row the first ``S_old - r[i]`` positions are kept at the start of the time
    axis and the remainder up to the new length is zero-filled. The input cache
    is only read.

    NOTE(review): ``rollback_dynamic_per_row_simple`` places the kept positions
    at the END of the time axis instead; confirm which layout callers expect.

    Args:
        cache: Source cache. Each ``cache.layers[l].keys`` / ``.values`` is
            indexed as (batch, heads, time, head_dim).
        r: 1-D integer tensor with one non-negative rollback count per batch row.

    Returns:
        A new DynamicCache holding the rolled-back keys and values.

    Raises:
        AssertionError: If a layer has no keys/values or any r[i] is negative.
        ValueError: If ``r`` does not have one entry per batch row, or some
            r[i] is larger than the cached length.
    """
    first_keys = cache.layers[0].keys
    assert first_keys is not None
    B = first_keys.shape[0]
    device = first_keys.device
    if r.dim() != 1 or r.shape[0] != B:
        raise ValueError(f"r must have shape ({B},), got {tuple(r.shape)}")
    assert (r >= 0).all().item()

    # The buckets depend only on r, so build them once instead of once per layer.
    r = r.to(device)
    buckets = [
        (int(rv), (r == rv).nonzero(as_tuple=False).squeeze(-1))
        for rv in torch.unique(r).tolist()
    ]
    min_rollback = min((rv for rv, _ in buckets), default=0)
    max_rollback = max((rv for rv, _ in buckets), default=0)

    # Build empty destination (we'll fill layer by layer)
    dst = DynamicCache()
    for layer in range(len(cache)):
        K = cache.layers[layer].keys
        V = cache.layers[layer].values  # values, not keys
        assert K is not None
        assert V is not None

        _, H, S_old, D = K.shape
        if max_rollback > S_old:
            raise ValueError(
                f"cannot roll back {max_rollback} tokens: layer {layer} only caches {S_old}"
            )
        # New time length is set by the row that rolls back the least.
        S_new = S_old - min_rollback

        K_new = K.new_zeros(B, H, S_new, D)
        V_new = V.new_zeros(B, H, S_new, D)

        # For each rollback bucket, copy the surviving prefix of the time axis.
        # Positions from S_keep up to S_new stay zero, which keeps them "rolled back".
        for rv, idx in buckets:
            S_keep = S_old - rv
            K_new[idx, :, :S_keep] = K[idx, :, :S_keep]
            V_new[idx, :, :S_keep] = V[idx, :, :S_keep]

        dst.update(K_new, V_new, layer)

    return dst
