In [1]:
import asyncio
import os
import time
import random
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, StaticCache

from tqdm import tqdm

from shared import (
    MessageChannel,
    PrefillRequest,
    PrefillResponse,
    PrefillBatchRequest,
    PrefillBatchResponse,
    ResetRequest,
    VerifyRequest,
    VerifyResponse,
    VerifyBatchRequest,
    VerifyBatchResponse,
    VerifyResponseItem,
)

import torch.nn as nn
import numpy as np

from const import DEVICE, BASE_MODEL

from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
%run utils_hf

In [3]:
model, tokenizer = load_model(BASE_MODEL)

PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (
    tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


## Prefill Stage

In [None]:
from typing import Any

prompts_str = [
    "def bubble_sort(x: list[int])",
    # "What is the meaning of life? Well, ",
]

tokens: list[list[int]] = [PAD_ID] * 2 + [tokenizer.encode(prompt) for prompt in prompts_str]

In [5]:
cache = prefill(model, tokens)

print_cache(cache, 0, 10)

tensor([ 5.1558e-05,  6.2207e-01, -1.9043e-02, -5.0879e-01, -1.1699e+00,
         6.2695e-01,  6.3184e-01,  3.3008e-01, -3.7744e-01], device='mps:0',
       dtype=torch.float16)


In [None]:
print_cache(cache, 0, 10)
# cache, tokens = rollback_dynamic_per_row_simple(cache, tokens, [4])
# print_cache(cache, 0)

tensor([ 0.4468,  0.9570, -5.1562, -8.1172, -2.8301,  1.5059,  6.7773,  3.0020,
        -1.3984], device='mps:0', dtype=torch.float16)


In [10]:
print(cache.layers[2].keys[0, 0, :, 0])

raise Exception('stop')

tensor([ 2.0943e-03, -3.3496e-01,  4.4385e-01,  2.3926e+00,  2.7383e+00,
         4.6411e-01, -1.5254e+00, -2.8418e+00, -1.1318e+00], device='mps:0',
       dtype=torch.float16)


Exception: stop

In [None]:
suffix_text = [
    ':',
]
generated = [tokenizer.encode(x)[1:] for x in suffix_text]

lengths = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=model.device)

for _ in tqdm(range(20)):
    tokens: list[list[int]] = [x + y for x, y in zip(tokens, generated)]
    # print(tokens)

    generated, lengths = generate_step(model, cache, generated, lengths)
    # print([tokenizer.decode(x) for x in generated])

print(tokenizer.decode(tokens[0]))

100%|██████████| 20/20 [00:01<00:00, 14.67it/s]

<|begin_of_text|>def bubble_sort(x: list) -> list:
    """Sorts a list in ascending order using the bubble sort algorithm



