In [1]:
import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
BASE_MODEL = os.environ.get("HF_BASE_MODEL", "meta-llama/Llama-3.1-8B")
TOP_K = int(os.environ.get("HF_TOP_K", "20"))
ATTN_IMPL_ENV = os.environ.get("HF_ATTN_IMPL", "").strip()  # e.g., "flash_attention_2" if available

In [3]:
DEVICE = torch.device('mps')
DTYPE = torch.float16

In [4]:
hf_token = os.environ.get("HF_TOKEN", None)
from_kwargs = {
    "dtype": DTYPE,
    "device_map": None,           # keep single process; move to one device below
    "low_cpu_mem_usage": True,
    "token": hf_token,
    "local_files_only": True,
}
if ATTN_IMPL_ENV:
    from_kwargs["attn_implementation"] = ATTN_IMPL_ENV

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    **from_kwargs,
).to(DEVICE) # type: ignore

tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL, 
    use_fast=True, 
    token=hf_token, 
    local_files_only=True,
)

PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (
    tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.89s/it]


## Prefill Stage

In [11]:
from typing import Any

prompts_str = [
    # "def bubble_sort(x: list[int]):",
    "What is the meaning of life",
]

tokens: list[list[int]] = [tokenizer.encode(prompt) for prompt in prompts_str]

In [12]:
@torch.inference_mode()
def prefill(model: nn.Module, tokens: list[list[int]]):
    max_len = max(len(x) for x in tokens)
    padded = [[PAD_ID] * (max_len - len(prompt)) + prompt for prompt in tokens]  # is pad id 0 correct?
    x = torch.tensor(padded, dtype=torch.long, device=DEVICE)

    cache = DynamicCache(
        config=model.config, 
    )

    outputs = model(
        input_ids=x, 
        past_key_values=cache, 
        use_cache=True,
    )

    return outputs.past_key_values

cache = prefill(model, tokens)

print(tokens)

[[128000, 3923, 374, 279, 7438, 315, 2324]]


In [None]:
@torch.inference_mode()
def generate_step(
    model: nn.Module,
    cache: DynamicCache,
    tokens: list[list[int]],
    lengths: torch.LongTensor,
):
    x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).view(-1, 1)

    B = lengths.size(0)
    S_prev = cache.layers[0].keys.shape[2]

    # Use int64 (not bool) on MPS; build past + current token mask
    attn_mask = torch.zeros((B, S_prev + 1), dtype=torch.long, device=DEVICE)

    starts = (S_prev - lengths).clamp_min(0)                # (B,)
    idx = torch.arange(S_prev, device=DEVICE)[None, :]      # (1, S_prev)
    attn_mask[:, :-1] = (idx >= starts[:, None]).to(torch.long)
    attn_mask[:, -1] = 1

    pos_ids = lengths.view(B, 1)  # new token's RoPE position (0..L_i)

    out = model(
        input_ids=x,
        past_key_values=cache,
        use_cache=True,
        attention_mask=attn_mask,  # int mask
        position_ids=pos_ids,
    )

    next_tok = out.logits[:, -1].argmax(dim=-1).tolist()
    lengths = lengths + 1
    return [[t] for t in next_tok], lengths




suffix_text = [
    '?',
]
generated = [tokenizer.encode(x)[1:] for x in suffix_text]
print(generated)

lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)
print(lengths)

for _ in tqdm(range(20)):
    tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
    print(tokens)

    generated, lengths = generate_step(model, cache, generated, lengths)
    print([tokenizer.decode(x) for x in generated])

print(tokenizer.decode(tokens[0]))

[[30]]
tensor([7], device='mps:0')


 10%|█         | 2/20 [00:00<00:01, 17.66it/s]

[[128000, 3923, 374, 279, 7438, 315, 2324, 30]]
[' This']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115]]
[' is']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374]]
[' a']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264]]


 30%|███       | 6/20 [00:00<00:00, 17.77it/s]

[' question']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488]]
[' that']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430]]
[' has']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706]]
[' puzzled']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420]]


 50%|█████     | 10/20 [00:00<00:00, 17.81it/s]

[' philosophers']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787]]
[',']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11]]
[' theolog']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602]]
['ians']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493]]


 70%|███████   | 14/20 [00:00<00:00, 17.91it/s]

[',']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11]]
[' and']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323]]
[' scientists']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248]]
[' for']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248, 369]]


 90%|█████████ | 18/20 [00:01<00:00, 17.84it/s]

[' centuries']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248, 369, 24552]]
['.']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248, 369, 24552, 13]]
[' There']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248, 369, 24552, 13, 2684]]
[' are']
[[128000, 3923, 374, 279, 7438, 315, 2324, 30, 1115, 374, 264, 3488, 430, 706, 87420, 61787, 11, 90602, 5493, 11, 323, 14248, 369, 24552, 13, 2684, 527]]


100%|██████████| 20/20 [00:01<00:00, 17.83it/s]

[' many']
<|begin_of_text|>What is the meaning of life? This is a question that has puzzled philosophers, theologians, and scientists for centuries. There are





IndexError: list index out of range