In [1]:
import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from const import DEVICE, BASE_MODEL

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Execute utils_hf inside this kernel so that the names it defines become
# available to the cells below. The helpers used later (load_model, prefill,
# print_cache, generate_step, rollback_dynamic_per_row_simple) are not imported
# anywhere in this notebook, so they are presumably defined there — TODO confirm.
%run utils_hf

In [None]:
# Load the base model together with its tokenizer.
# load_model is not imported in this notebook; presumably it comes from `%run utils_hf`.
model, tokenizer = load_model(BASE_MODEL)

# Choose the padding id: the tokenizer's own pad token when it has one,
# otherwise its EOS token, and 0 as the last resort.
if tokenizer.pad_token_id is not None:
    PAD_ID = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    PAD_ID = tokenizer.eos_token_id
else:
    PAD_ID = 0

Loading checkpoint shards:  25%|██▌       | 1/4 [00:00<00:02,  1.02it/s]

## Prefill Stage

In [None]:
from typing import Any


# Example prompts; together they form the batch used by the cells below.
prompts_str = [
    "Explanation of speculative decoding in simple terms",
    "This is a terse haiku about Apple MLX",
    "def bubble_sort(x: list[int])",
    "Why is the sky blue",
]

# One list of token ids per prompt, in the same order as `prompts_str`.
tokens: list[list[int]] = list(map(tokenizer.encode, prompts_str))

In [None]:
# Run the prefill helper on the whole batch of tokenized prompts and store
# its result as `cache` for the rollback / decode cells further down.
cache = prefill(model, tokens)

# Inspect the cache right after prefill. The meaning of the two integer
# arguments is defined by print_cache, which is not visible in this notebook.
# NOTE(review): the output saved with this cell is a tensor of all NaNs —
# confirm whether that is expected or a problem in prefill.
print_cache(cache, 2, 10)

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='mps:0',
       dtype=torch.bfloat16)


In [None]:
# def zero_cache(cache: DynamicCache, lengths: list[int]):
#     assert cache.layers[0].keys is not None and cache.layers[0].values is not None
#     B = cache.layers[0].keys.shape[0]

#     # Prepare destination cache
#     dst = DynamicCache()

#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].values
#         assert K is not None and V is not None

#         _, H, S, D = K.shape

#         K_new = K.new_zeros((B, H, S, D))
#         V_new = V.new_zeros((B, H, S, D))

#         # Copy per row
#         for i in range(B):
#             length = lengths[i]
#             if length == 0:
#                 continue
#             # valid tokens are the last `length` positions (the cache is right-aligned)
#             K_src = K[i, :, S-length:, :]
#             V_src = V[i, :, S-length:, :]

#             # right-aligned → write to the right, pad on the left implicitly
#             K_new[i, :, S-length:, :] = K_src
#             V_new[i, :, S-length:, :] = V_src

#         # print(dst.layers[layer].keys[i, 0, :, 0])
#         dst.update(K_new, V_new, layer)

#     return dst

# cache = zero_cache(cache, [len(x) for x in tokens])

In [None]:
# Print the cache again, this time without the third argument used in the
# prefill cell (argument semantics are defined by print_cache, not visible here).
print_cache(cache, 2)

tensor([ 3.1094,  3.1094,  0.4473,  0.9531, -5.1562, -8.1250, -2.8125,  1.5000,
         6.7812,  3.0156, -1.3984], device='mps:0', dtype=torch.bfloat16)


## Pure Decode (Just A Sanity Check)

In [None]:
# suffix_text = ':'
# generated = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 4

# lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
# print(lengths)

# for _ in tqdm(range(20)):
#     tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
#     # print(tokens)

#     generated, lengths = generate_step(model, cache, generated, lengths)
#     # print(generated)

# for i in range(4):
#     print(tokenizer.decode(tokens[i]))


## Experimental Verify Step

In [None]:
# @torch.inference_mode()
# def verify(model: nn.Module, cache: DynamicCache, tokens: list[list[int]], draft_logits: np.array):
#     assert all([len(x) == len(tokens[0]) for x in tokens])
#     x = torch.tensor(tokens, dtype=torch.long, device=DEVICE)

#     print(cache.layers[0].keys.shape)

#     outputs = model(
#         x, 
#         use_cache=True, 
#         past_key_values=cache
#     )

#     print(cache.layers[0].keys.shape)


# suffix_text = ': a short story'
# suffix_tokens = [tokenizer.encode(suffix_text)[1:] for _ in range(4)]

# tokens = [x + y for x, y in zip(tokens, suffix_tokens)]

# verify(model, cache, suffix_tokens, None)

In [None]:
# Per-row rollback amounts handed to rollback_dynamic_per_row_simple, one entry
# per prompt in the batch. (The original line read `rollback_values = data=[...]`,
# a chained assignment that also created a stray global named `data`.)
rollback_values = [3, 0, 3, 8]

# Before: decoded prompts and the cache as produced by the earlier cells.
print([tokenizer.decode(x) for x in tokens])
print_cache(cache, 2)

# `tokens` is rebound to the helper's second return value; the helper's first
# return value is kept under a separate name, so `cache` itself is not rebound.
new_cache, tokens = rollback_dynamic_per_row_simple(cache, tokens, rollback_values)

# After: decoded prompts and the cache returned by the rollback helper.
# Printing `cache` here (as the original did) only repeats the "before" output.
print([tokenizer.decode(x) for x in tokens])
print_cache(new_cache, 2)

# NOTE(review): the decode cells below still pass `cache`, not `new_cache`, to
# generate_step while computing `lengths` from the rolled-back `tokens` —
# confirm which cache is intended there.

['<|begin_of_text|>Explanation of speculative decoding in simple terms', '<|begin_of_text|>This is a terse haiku about Apple MLX', '<|begin_of_text|>def bubble_sort(x: list[int])', '<|begin_of_text|>Why is the sky blue']
tensor([ 3.1094,  3.1094,  0.4473,  0.9531, -5.1562, -8.1250, -2.8125,  1.5000,
         6.7812,  3.0156, -1.3984], device='mps:0', dtype=torch.bfloat16)
['<|begin_of_text|>Explanation of speculative decoding', '<|begin_of_text|>This is a terse haiku about Apple MLX', '<|begin_of_text|>def bubble_sort(x:', '<|begin_of_text|>Why is the']
tensor([ 3.1094,  3.1094,  0.4473,  0.9531, -5.1562, -8.1250, -2.8125,  1.5000,
         6.7812,  3.0156, -1.3984], device='mps:0', dtype=torch.bfloat16)


In [None]:
# def get_layer(cache: DynamicCache, b: int):
#     """
#     Extract batch row `b` from a DynamicCache into a new single-row DynamicCache.
#     The first 2 sequence positions of every layer are dropped, so the output length is S - 2.
#     """
#     assert cache.layers[0].keys is not None and cache.layers[0].values is not None
#     B = cache.layers[0].keys.shape[0]

#     # Prepare destination cache
#     dst = DynamicCache()

#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].values
#         assert K is not None and V is not None

#         _, H, S, D = K.shape

#         K_new = K.new_zeros((1, H, S-2, D))
#         V_new = V.new_zeros((1, H, S-2, D))

#         K_new[0, :, :, :] = K[b, :, 2: :]
#         V_new[0, :, :, :] = V[b, :, 2:, :]

#         dst.update(K_new, V_new, layer)

#     return dst

# i = 2
# cache = get_layer(cache, i)
# tokens = [tokens[i]]

# print_cache(cache, 0)

In [None]:
# One continuation string per batch row; its token ids are the first decode
# input for that row.
suffix_text = [
    ' speculative',
    ' Apple',
    ':',
    ' is',
]
# [1:] drops the first id returned by encode — presumably the begin-of-text
# token, which the decoded prompts in the saved outputs start with.
generated = [tokenizer.encode(x)[1:] for x in suffix_text]
print(generated)

# Current number of tokens in each row, placed on the model's device.
lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
print(lengths)

for _ in tqdm(range(20)):
    # Append the ids that are about to be fed to the model. Because this happens
    # before generate_step, the ids returned by the final iteration are never
    # appended to `tokens`.
    tokens = [x + y for x, y in zip(tokens, generated)]

    generated, lengths = generate_step(model, cache, generated, lengths)
    print([tokenizer.decode(x) for x in generated])

# Decode every row of the batch (previously hard-coded to exactly 4 rows).
for row in tokens:
    print(tokenizer.decode(row))

[[66836], [8325], [25], [374]]
tensor([ 5, 11,  6,  4], device='mps:0')


 10%|█         | 2/20 [00:00<00:03,  4.57it/s]

['!', '’s', '!', '!']
['!', ' new', '!', '!']


 20%|██        | 4/20 [00:00<00:03,  5.24it/s]

['!', ' machine', '!', '!']
['!', ' learning', '!', '!']


 30%|███       | 6/20 [00:01<00:02,  5.43it/s]

['!', ' framework', '!', '!']
['!', ',', '!', '!']


 40%|████      | 8/20 [00:01<00:02,  5.53it/s]

['!', ' which', '!', '!']
['!', ' is', '!', '!']


 50%|█████     | 10/20 [00:01<00:01,  5.49it/s]

['!', ' designed', '!', '!']
['!', ' to', '!', '!']


 60%|██████    | 12/20 [00:02<00:01,  5.61it/s]

['!', ' make', '!', '!']
['!', ' it', '!', '!']


 70%|███████   | 14/20 [00:02<00:01,  5.56it/s]

['!', ' easier', '!', '!']
['!', ' for', '!', '!']


 80%|████████  | 16/20 [00:02<00:00,  5.57it/s]

['!', ' developers', '!', '!']
['!', ' to', '!', '!']


 90%|█████████ | 18/20 [00:03<00:00,  5.53it/s]

['!', ' build', '!', '!']
['!', ' machine', '!', '!']


100%|██████████| 20/20 [00:03<00:00,  5.42it/s]

['!', ' learning', '!', '!']
['!', ' models', '!', '!']
<|begin_of_text|>Explanation of speculative decoding speculative!!!!!!!!!!!!!!!!!!!
<|begin_of_text|>This is a terse haiku about Apple MLX Apple’s new machine learning framework, which is designed to make it easier for developers to build machine learning
<|begin_of_text|>def bubble_sort(x::!!!!!!!!!!!!!!!!!!!
<|begin_of_text|>Why is the is!!!!!!!!!!!!!!!!!!!





In [None]:
# Second decode run over the same batch: the same continuation strings are fed
# again on top of whatever `tokens` currently holds.
suffix_text = [
    ' speculative',
    ' Apple',
    ':',
    ' is',
]
# [1:] drops the first id returned by encode — presumably the begin-of-text
# token, which the decoded prompts in the saved outputs start with.
generated = [tokenizer.encode(x)[1:] for x in suffix_text]
print(generated)

# Current number of tokens in each row, placed on the model's device.
lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
print(lengths)

for _ in tqdm(range(20)):
    # Append the ids that are about to be fed to the model. Because this happens
    # before generate_step, the ids returned by the final iteration are never
    # appended to `tokens`.
    tokens = [x + y for x, y in zip(tokens, generated)]

    generated, lengths = generate_step(model, cache, generated, lengths)
    print([tokenizer.decode(x) for x in generated])

# Decode every row of the batch (previously hard-coded to exactly 4 rows).
for row in tokens:
    print(tokenizer.decode(row))

[[66836], [8325], [25], [374]]
tensor([25, 31, 26, 24], device='mps:0')


  5%|▌         | 1/20 [00:00<00:03,  5.31it/s]

['!', ' ML', '!', '!']


 10%|█         | 2/20 [00:00<00:03,  5.54it/s]

['!', 'X', '!', '!']


 15%|█▌        | 3/20 [00:00<00:03,  5.54it/s]

['!', ' Apple', '!', '!']


 20%|██        | 4/20 [00:00<00:02,  5.55it/s]

['!', '’s', '!', '!']


 25%|██▌       | 5/20 [00:00<00:02,  5.51it/s]

['!', ' new', '!', '!']


 30%|███       | 6/20 [00:01<00:02,  5.49it/s]

['!', ' machine', '!', '!']


 35%|███▌      | 7/20 [00:01<00:02,  5.54it/s]

['!', ' learning', '!', '!']


 40%|████      | 8/20 [00:01<00:02,  5.50it/s]

['!', ' framework', '!', '!']


 45%|████▌     | 9/20 [00:01<00:02,  5.48it/s]

['!', ',', '!', '!']


 50%|█████     | 10/20 [00:01<00:01,  5.52it/s]

['!', ' which', '!', '!']


 55%|█████▌    | 11/20 [00:01<00:01,  5.59it/s]

['!', ' is', '!', '!']


 60%|██████    | 12/20 [00:02<00:01,  5.53it/s]

['!', ' designed', '!', '!']


 65%|██████▌   | 13/20 [00:02<00:01,  5.54it/s]

['!', ' to', '!', '!']


 70%|███████   | 14/20 [00:02<00:01,  5.54it/s]

['!', ' make', '!', '!']


 75%|███████▌  | 15/20 [00:02<00:00,  5.50it/s]

['!', ' it', '!', '!']


 80%|████████  | 16/20 [00:02<00:00,  5.52it/s]

['!', ' easier', '!', '!']


 85%|████████▌ | 17/20 [00:03<00:00,  5.49it/s]

['!', ' for', '!', '!']


 90%|█████████ | 18/20 [00:03<00:00,  5.48it/s]

['!', ' developers', '!', '!']


 95%|█████████▌| 19/20 [00:03<00:00,  5.48it/s]

['!', ' to', '!', '!']


100%|██████████| 20/20 [00:03<00:00,  5.50it/s]

['!', ' build', '!', '!']
<|begin_of_text|>Explanation of speculative decoding speculative!!!!!!!!!!!!!!!!!!! speculative!!!!!!!!!!!!!!!!!!!
<|begin_of_text|>This is a terse haiku about Apple MLX Apple’s new machine learning framework, which is designed to make it easier for developers to build machine learning Apple MLX Apple’s new machine learning framework, which is designed to make it easier for developers to
<|begin_of_text|>def bubble_sort(x::!!!!!!!!!!!!!!!!!!!:!!!!!!!!!!!!!!!!!!!
<|begin_of_text|>Why is the is!!!!!!!!!!!!!!!!!!! is!!!!!!!!!!!!!!!!!!!





In [None]:
# Display `lengths` as last returned by generate_step in the decode loop above.
lengths

tensor([45, 51, 46, 44], device='mps:0')

In [None]:
# suffix_text = '.'
# generated = [[x] for x in tokenizer.encode(suffix_text)[1:]] * 3

# lengths = torch.LongTensor([len(x) for x in tokens], device=model.device)

# for _ in tqdm(range(20)):
    

#     generated, lengths = generate_step(model, cache, tokens, lengths)
#     tokens: list[list[int]] = [x + y for x, y in zip(full_tokens, tokens)]

# for i in range(3):
#     print(tokenizer.decode(full_tokens[i]))

# Bin

In [None]:
# def rollback_dynamic_per_row(cache: DynamicCache, r: torch.LongTensor):
#     """
#     Roll back r[i] tokens for each batch row i in a DynamicCache.
#     Returns a *new* DynamicCache with time dim equal to max(L_i - r_i).
#     """
#     assert cache.layers[0].keys is not None
#     B = cache.layers[0].keys.shape[0]
#     device = cache.layers[0].keys.device
#     uniq = torch.unique(r)

#     # Build empty destination (we'll fill layer by layer)
#     dst = DynamicCache()
#     for layer in range(len(cache)):
#         K = cache.layers[layer].keys
#         V = cache.layers[layer].values
#         assert K is not None
#         assert V is not None

#         B, H, S_old, D = K.shape
#         # Compute new per-row lengths after rollback
#         L_after = torch.full((B,), S_old, dtype=torch.long, device=device) - r
#         S_new = int(L_after.max().item())

#         K_new = K.new_zeros(B, H, S_new, D)
#         V_new = V.new_zeros(B, H, S_new, D)

#         # For each rollback bucket, crop and scatter back
#         for rv in uniq.tolist():
#             idx = (r == rv).nonzero(as_tuple=False).squeeze(-1)
#             if idx.numel() == 0: 
#                 continue
#             # Select sub-batch, crop rv tokens from the right
#             K_sub = K.index_select(0, idx)
#             V_sub = V.index_select(0, idx)

#             # physical crop for this bucket
#             S_keep = S_old - rv
#             K_sub = K_sub[..., :S_keep, :]
#             V_sub = V_sub[..., :S_keep, :]

#             # place back into dst; zero-filling beyond S_keep keeps them "rolled back"
#             K_new.index_copy_(0, idx, torch.nn.functional.pad(K_sub, (0,0,0,0,0, S_new - S_keep)))
#             V_new.index_copy_(0, idx, torch.nn.functional.pad(V_sub, (0,0,0,0,0, S_new - S_keep)))

#         dst.update(K_new, V_new, layer)

#     return dst
