In [1]:
import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from const import DEVICE, BASE_MODEL

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
%run utils_hf

In [3]:
model, tokenizer = load_model(BASE_MODEL)

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.03s/it]
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


## Prefill Stage

In [53]:
from typing import Any


prompts_str = [
    "Explanation of speculative decoding in simple terms",
    "This is a terse haiku about Apple MLX",
    "def bubble_sort(x: list[int])",
    "Why is the sky blue? Well, it's a complicated question",
]

tokens: list[list[int]] = [tokenizer.encode(prompt) for prompt in prompts_str]

cache = prefill(model, tokens)

In [54]:
cache = zero_cache(cache, [len(x) for x in tokens])

In [55]:
print_cache(cache, 2, 10)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0061,  0.7448, -0.1790,
        -0.6929, -1.2870,  0.5196,  0.4144,  0.1880, -0.0737], device='cuda:0')


## Experimental Verify Step

In [56]:
@torch.inference_mode()
def verify(model: nn.Module, cache: DynamicCache, tokens: list[list[int]], draft_logits: np.array):
    assert all([len(x) == len(tokens[0]) for x in tokens])
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE)

    outputs = model(
        x, 
        use_cache=True, 
        past_key_values=cache
    )

    accept_values = []
    for _ in range(len(tokens)):
        accept_values.append(random.randint(0, len(tokens[0])))

    return outputs.past_key_values, accept_values

suffix_text = ': a short story'
suffix_tokens = [tokenizer.encode(suffix_text)[1:] for _ in range(4)]

tokens = [x + y for x, y in zip(tokens, suffix_tokens)]

cache, accept_values = verify(model, cache, suffix_tokens, None)

print(accept_values)

[2, 2, 1, 1]


In [57]:
rollback_values = [4 - y for y in accept_values]

print([tokenizer.decode(x) for x in tokens])
print_cache(cache, 2)

cache, tokens = rollback_dynamic_per_row_simple(cache, tokens, rollback_values)

print('\n'.join([tokenizer.decode(x) for x in tokens]))
print_cache(cache, 2)

['<|begin_of_text|>Explanation of speculative decoding in simple terms: a short story', '<|begin_of_text|>This is a terse haiku about Apple MLX: a short story', '<|begin_of_text|>def bubble_sort(x: list[int]): a short story', "<|begin_of_text|>Why is the sky blue? Well, it's a complicated question: a short story"]
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.6243,  0.9151, -5.3545,
        -7.8965, -2.5987,  2.0659,  6.6940,  2.9187, -1.7030, -0.2650, -5.2807,
        -5.4822,  0.1768], device='cuda:0')
<|begin_of_text|>Explanation of speculative decoding in simple terms: a
<|begin_of_text|>This is a terse haiku about Apple MLX: a
<|begin_of_text|>def bubble_sort(x: list[int]):
<|begin_of_text|>Why is the sky blue? Well, it's a complicated question:
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.6243,  0.9151, -5.3545, -7.8965, -2.5987,  2.0659,  6.6940,  2.9187,
        -1.7030, -0.2650], device='cuda:0')


In [58]:
suffix_text = [
    ' method',
    ' long',
    ' list',
    ' firstly',
]
generated = [tokenizer.encode(x)[1:] for x in suffix_text]
print(generated)

lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
print(lengths)

for _ in tqdm(range(20)):
    tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
    # print(tokens)

    generated, lengths = generate_step(model, cache, generated, lengths)
    # print([tokenizer.decode(x) for x in generated])

for i in range(4):
    print(tokenizer.decode(tokens[i]))

[[1749], [1317], [1160], [95052]]
tensor([10, 13, 10, 15], device='cuda:0')


100%|██████████| 20/20 [00:03<00:00,  6.46it/s]

<|begin_of_text|>Explanation of speculative decoding in simple terms: a method
I am trying to understand the speculative decoding method in simple terms. I have read the paper
<|begin_of_text|>This is a terse haiku about Apple MLX: a long ago.
This is a terse haiku about Apple MLX long ago.
This is a terse
<|begin_of_text|>def bubble_sort(x: list[int]): list[int]:
    """Sorts a list of integers using the bubble sort algorithm."""
    for i
<|begin_of_text|>Why is the sky blue? Well, it's a complicated question: firstly, the sky is not blue, it's a mixture of blue, green and yellow. Secondly



