In [1]:
import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.2-3B-Instruct")
# BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.1-8B")
TOP_K = int(os.environ.get("HF_TOP_K", "20"))
ATTN_IMPL_ENV = os.environ.get("HF_ATTN_IMPL", "").strip()  # e.g., "flash_attention_2" if available

In [3]:
DEVICE = torch.device('mps')
DTYPE = torch.float16

In [4]:
hf_token = os.environ.get("HF_TOKEN", None)
from_kwargs = {
    "dtype": DTYPE,
    "device_map": None,           # keep single process; move to one device below
    "low_cpu_mem_usage": True,
    "token": hf_token,
    "local_files_only": True,
}
if ATTN_IMPL_ENV:
    from_kwargs["attn_implementation"] = ATTN_IMPL_ENV

import os
os.environ["HF_HUB_OFFLINE"] = "1"

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    **from_kwargs,
).to(DEVICE) # type: ignore

tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL, 
    use_fast=True, 
    token=hf_token, 
    local_files_only=True,
)

PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (
    tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.05it/s]


## Prefill Stage

In [5]:
from typing import Any


prompts_str = [
    "Explanation of speculative decoding in simple terms",
    "This is a terse haiku about Apple MLX",
    "def bubble_sort(x: list[int])",
    "Why is the sky blue",
]

tokens: list[list[int]] = [tokenizer.encode(prompt) for prompt in prompts_str]

In [6]:
# def mask_and_pos_ids(L: list[int]): # TODO: Should this be a list or a tensor?
#     max_len = max(L)

#     attention_mask = torch.zeros((len(L), max_len), dtype=torch.long, device=DEVICE)
#     for i, l in enumerate(L):
#         attention_mask[i, max_len - l:] = 1

#     # position_ids / cache_position: 0..L_i-1 for non-pad tokens, 0 for pads
#     # This works for both absolute and RoPE-style position handling.
#     position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp_min(0)
#     position_ids = position_ids.masked_fill(attention_mask == 0, 0)

#     return attention_mask, position_ids 

def mask_and_pos_ids(L: list[int]): # TODO: Should this be a list or a tensor?
    max_len = max(L)

    attention_mask = torch.zeros((len(L), max_len), dtype=torch.long, device=DEVICE)
    for i, l in enumerate(L):
        attention_mask[i, max_len - l:] = 1

    # position_ids / cache_position: 0..L_i-1 for non-pad tokens, 0 for pads
    # This works for both absolute and RoPE-style position handling.
    position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp_min(0)
    position_ids = position_ids.masked_fill(attention_mask == 0, 0)

    return attention_mask, position_ids 

@torch.inference_mode()
def prefill(model: nn.Module, tokens: list[list[int]]):
    max_len = max(len(x) for x in tokens)
    padded = [[PAD_ID] * (max_len - len(prompt)) + prompt for prompt in tokens]  # is pad id 0 correct?
    x = torch.tensor(padded, dtype=torch.long, device=DEVICE)

    cache = DynamicCache(
        config=model.config, 
    )

    attention_mask, position_ids = mask_and_pos_ids([len(x) for x in tokens])

    outputs = model(
        input_ids=x, 
        # attention_mask=attention_mask,
        past_key_values=cache, 
        use_cache=True,
        # position_ids=position_ids,
    )

    return outputs.past_key_values

cache = prefill(model, tokens)

print(tokens)

[[128000, 70869, 315, 66836, 48216, 304, 4382, 3878], [128000, 2028, 374, 264, 51637, 6520, 39342, 922, 8325, 20187, 55], [128000, 755, 24529, 18942, 2120, 25, 1160, 19155, 2526], [128000, 10445, 374, 279, 13180, 6437]]


In [7]:
def zero_cache(cache: DynamicCache, lengths: list[int]):
    assert cache.layers[0].keys is not None and cache.layers[0].values is not None
    B = cache.layers[0].keys.shape[0]

    # Prepare destination cache
    dst = DynamicCache()

    for layer in range(len(cache)):
        K = cache.layers[layer].keys
        V = cache.layers[layer].values
        assert K is not None and V is not None

        _, H, S, D = K.shape

        K_new = K.new_zeros((B, H, S, D))
        V_new = V.new_zeros((B, H, S, D))

        # Copy per row
        for i in range(B):
            length = lengths[i]
            if length == 0:
                continue
            # surviving tokens are the first 'keep' positions (earliest..latest-rollback)
            K_src = K[i, :, S-length:, :]
            V_src = V[i, :, S-length:, :]

            # right-aligned → write to the right, pad on the left implicitly
            K_new[i, :, S-length:, :] = K_src
            V_new[i, :, S-length:, :] = V_src

        # print(dst.layers[layer].keys[i, 0, :, 0])
        dst.update(K_new, V_new, layer)

    return dst

cache = zero_cache(cache, [len(x) for x in tokens])

In [8]:
print(cache.layers[0].keys[:, 0, 2, 0])

tensor([ 0.0000, -3.5273,  0.1252,  0.0000], device='mps:0',
       dtype=torch.float16)


## Pure Decode (Just A Sanity Check)

In [9]:
@torch.inference_mode()
def generate_step(
    model: nn.Module,
    cache: DynamicCache,
    tokens: list[list[int]],
    lengths: torch.LongTensor,
):
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).view(-1, 1)

    B = lengths.size(0)
    S_prev = cache.layers[0].keys.shape[2]

    # Use int64 (not bool) on MPS; build past + current token mask
    attn_mask = torch.zeros((B, S_prev + 1), dtype=torch.long, device=DEVICE)

    starts = (S_prev - lengths).clamp_min(0)                # (B,)
    idx = torch.arange(S_prev, device=DEVICE)[None, :]      # (1, S_prev)
    attn_mask[:, :-1] = (idx >= starts[:, None]).to(torch.long)
    attn_mask[:, -1] = 1

    pos_ids = lengths.view(B, 1)  # new token's RoPE position (0..L_i)

    out = model(
        input_ids=x,
        past_key_values=cache,
        use_cache=True,
        attention_mask=attn_mask,  # int mask
        position_ids=pos_ids,
    )

    next_tok = out.logits[:, -1].argmax(dim=-1).tolist()
    lengths = lengths + 1
    return [[t] for t in next_tok], lengths



# suffix_text = ':'
# generated = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 4

# lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
# print(lengths)

# for _ in tqdm(range(20)):
#     tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
#     # print(tokens)

#     generated, lengths = generate_step(model, cache, generated, lengths)
#     # print(generated)

# for i in range(4):
#     print(tokenizer.decode(tokens[i]))


## Experimental Verify Step

In [10]:
@torch.inference_mode()
def verify(model: nn.Module, cache: DynamicCache, tokens: list[list[int]], draft_logits: np.array):
    assert all([len(x) == len(tokens[0]) for x in tokens])
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE)

    print(cache.layers[0].keys.shape)

    outputs = model(
        x, 
        use_cache=True, 
        past_key_values=cache
    )

    print(cache.layers[0].keys.shape)


suffix_text = ': a short story'
suffix_tokens = [tokenizer.encode(suffix_text)[1:] for _ in range(4)]

tokens = [x + y for x, y in zip(tokens, suffix_tokens)]

verify(model, cache, suffix_tokens, None)

torch.Size([4, 8, 11, 128])
torch.Size([4, 8, 15, 128])


In [11]:
def rollback_dynamic_per_row_simple(cache: DynamicCache, tokens: list[list[int]], r: list[int]):
    """
    Roll back r[i] tokens for each batch row i in a DynamicCache.
    The output cache maintains the same sequence length as the input, padding with zeros where needed.
    """
    assert cache.layers[0].keys is not None and cache.layers[0].values is not None
    assert all([x >= 0 for x in r])
    B = cache.layers[0].keys.shape[0]
    device = cache.layers[0].keys.device
    dtype = cache.layers[0].keys.dtype

    # Prepare destination cache
    dst = DynamicCache()

    for layer in range(len(cache)):
        K = cache.layers[layer].keys
        V = cache.layers[layer].values
        assert K is not None and V is not None

        _, H, S, D = K.shape

        K_new = K.new_zeros((B, H, S, D))
        V_new = V.new_zeros((B, H, S, D))

        # Copy per row
        for i in range(B):
            keep = S - r[i]
            if keep <= 0:
                continue
            # surviving tokens are the first 'keep' positions (earliest..latest-rollback)
            K_src = K[i, :, :keep, :]
            V_src = V[i, :, :keep, :]

            # right-aligned → write to the right, pad on the left implicitly
            start = S - keep
            K_new[i, :, start:, :] = K_src
            V_new[i, :, start:, :] = V_src
            # print(K_new[i, 0, :, 0])

        # print(dst.layers[layer].keys[i, 0, :, 0])
        dst.update(K_new, V_new, layer)

    tokens = [x[:len(x) - trim] for x, trim in zip(tokens, r)]

    return dst, tokens

rollback_values = data=[3, 0, 3, 8]
print([tokenizer.decode(x) for x in tokens])
print(cache.layers[0].keys[2, 0, :, 0])
new_cache, tokens = rollback_dynamic_per_row_simple(cache, tokens, rollback_values)
print([tokenizer.decode(x) for x in tokens])
print(new_cache.layers[0].keys[2, 0, :, 0])

['<|begin_of_text|>Explanation of speculative decoding in simple terms: a short story', '<|begin_of_text|>This is a terse haiku about Apple MLX: a short story', '<|begin_of_text|>def bubble_sort(x: list[int]): a short story', '<|begin_of_text|>Why is the sky blue: a short story']
tensor([ 0.0000,  0.0000,  0.1252, -6.6953, -2.9492,  5.2969,  7.7773,  2.5840,
        -2.9531, -6.8164, -2.0684,  0.4590,  5.4766,  4.9961, -1.1309],
       device='mps:0', dtype=torch.float16)
['<|begin_of_text|>Explanation of speculative decoding in simple terms:', '<|begin_of_text|>This is a terse haiku about Apple MLX: a short story', '<|begin_of_text|>def bubble_sort(x: list[int]):', '<|begin_of_text|>Why']
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1252, -6.6953, -2.9492,
         5.2969,  7.7773,  2.5840, -2.9531, -6.8164, -2.0684,  0.4590],
       device='mps:0', dtype=torch.float16)


In [12]:
for i in range(4):
    print(tokenizer.decode(tokens[i]))

# raise Exception('stop')

<|begin_of_text|>Explanation of speculative decoding in simple terms:
<|begin_of_text|>This is a terse haiku about Apple MLX: a short story
<|begin_of_text|>def bubble_sort(x: list[int]):
<|begin_of_text|>Why


In [13]:
suffix_text = [
    ' speculative',
    ' Apple',
    ' \n',
    ' is',
]
# generated = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 4
generated = [tokenizer.encode(x)[1:] for x in suffix_text]
print(generated)

lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
print(lengths)

for _ in tqdm(range(20)):
    tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
    # print(tokens)

    generated, lengths = generate_step(model, cache, generated, lengths)
    print([tokenizer.decode(x) for x in generated])

for i in range(4):
    print(tokenizer.decode(tokens[i]))

[[66836], [8325], [720], [374]]
tensor([ 9, 15, 10,  2], device='mps:0')


  5%|▌         | 1/20 [00:00<00:04,  3.93it/s]

[' decoding', ' has', ':', '://']


 15%|█▌        | 3/20 [00:00<00:03,  4.69it/s]

[' in', ' been', ' a', ' story']
[' simple', ' telling', ' a', '\n']


 25%|██▌       | 5/20 [00:01<00:03,  4.87it/s]

[' decoding', ' about', ' a', 'The']
[' in', ' its', ' a', ' story']


 35%|███▌      | 7/20 [00:01<00:02,  5.02it/s]

[' simple', ' machine', ' a', ' is']
[' decoding', ' learning', ' a', ' about']


 45%|████▌     | 9/20 [00:01<00:02,  5.11it/s]

[':', ' technology', ' a', ' a']
[' a', '.', ' a', ' young']


 55%|█████▌    | 11/20 [00:02<00:01,  5.03it/s]

[' short', ' The', ' a', ' man']
[' story', ' story', ' a', ' named']


 65%|██████▌   | 13/20 [00:02<00:01,  5.03it/s]

[':', ' is', ' a', ' Jack']
[' a', ' about', ' a', ' who']


 75%|███████▌  | 15/20 [00:03<00:00,  5.04it/s]

[' short', ' how', ' a', ' is']
[' story', ' Apple', ' a', ' a']


 85%|████████▌ | 17/20 [00:03<00:00,  5.00it/s]

[':', '’s', ' a', ' passionate']
[' a', ' machine', ' a', ' writer']


 90%|█████████ | 18/20 [00:03<00:00,  4.96it/s]

[' simple', ' learning', ' a', ' but']


100%|██████████| 20/20 [00:04<00:00,  4.95it/s]

[' decoding', ' technology', ' a', ' struggling']
[':', ' is', ' a', ' to']
<|begin_of_text|>Explanation of speculative decoding in simple terms: speculative decoding in simple decoding in simple decoding: a short story: a short story: a simple decoding
<|begin_of_text|>This is a terse haiku about Apple MLX: a short story Apple has been telling about its machine learning technology. The story is about how Apple’s machine learning technology
<|begin_of_text|>def bubble_sort(x: list[int]): 
: a a a a a a a a a a a a a a a a a a
<|begin_of_text|>Why is:// story
The story is about a young man named Jack who is a passionate writer but struggling





In [14]:
lengths

tensor([29, 35, 30, 22], device='mps:0')

In [15]:
# suffix_text = '.'
# generated = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 3

# lengths = torch.LongTensor([len(x) for x in tokens], device=model.device)

# for _ in tqdm(range(20)):
    

#     generated, lengths = generate_step(model, cache, tokens, lengths)
#     tokens: list[list[int]] = [x + y for x, y in zip(full_tokens, tokens)]

# for i in range(3):
#     print(tokenizer.decode(full_tokens[i]))

# Bin

In [16]:
# def rollback_dynamic_per_row(cache: DynamicCache, r: torch.LongTensor):
#     """
#     Roll back r[i] tokens for each batch row i in a DynamicCache.
#     Returns a *new* DynamicCache with time dim equal to max(L_i - r_i).
#     """
#     assert cache.layers[0].keys is not None
#     B = cache.layers[0].keys.shape[0]
#     device = cache.layers[0].keys.device
#     uniq = torch.unique(r)

#     # Build empty destination (we'll fill layer by layer)
#     dst = DynamicCache()
#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].keys
#         assert K is not None
#         assert V is not None

#         B, H, S_old, D = K.shape
#         # Compute new per-row lengths after rollback
#         L_after = torch.full((B,), S_old, dtype=torch.long, device=device) - r
#         S_new = int(L_after.max().item())

#         K_new = K.new_zeros(B, H, S_new, D)
#         V_new = V.new_zeros(B, H, S_new, D)

#         # For each rollback bucket, crop and scatter back
#         for rv in uniq.tolist():
#             idx = (r == rv).nonzero(as_tuple=False).squeeze(-1)
#             if idx.numel() == 0: 
#                 continue
#             # Select sub-batch, crop rv tokens from the right
#             K_sub = K.index_select(0, idx)
#             V_sub = V.index_select(0, idx)

#             # physical crop for this bucket
#             S_keep = S_old - rv
#             K_sub = K_sub[..., :S_keep, :]
#             V_sub = V_sub[..., :S_keep, :]

#             # place back into dst; zero-filling beyond S_keep keeps them "rolled back"
#             K_new.index_copy_(0, idx, torch.nn.functional.pad(K_sub, (0,0,0,0,0, S_new - S_keep)))
#             V_new.index_copy_(0, idx, torch.nn.functional.pad(V_sub, (0,0,0,0,0, S_new - S_keep)))

#         dst.update(K_new, V_new, layer)

#     return dst
