In [1]:
from model.modeling_llama import LlamaForCausalLM as ModifiedLlama
from transformers import AutoTokenizer, AutoModelForCausalLM
from kvcache.iterative import IterativeReduceKVBiasCache as ModifiedCache
from transformers import DynamicCache
from datasets import load_dataset

import torch
import numpy as np
import random

import torch.nn.functional as F
import time
from IPython.display import DisplayHandle

import os
import random
import pandas as pd

# Run configuration. Prefer Apple's MPS backend when present (the original
# setting), otherwise fall back to CUDA and finally CPU so the notebook still
# runs end-to-end on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

DTYPE = torch.float32  # dtype the model weights are cast to after loading
FIRST_N = 1000         # number of documents pulled from the head of the stream
SAMPLE_SIZE = 10       # number of documents randomly kept for evaluation

In [2]:
# Build (or reload) a small evaluation sample of FineWeb-Edu documents.
# The cache file name encodes SAMPLE_SIZE and FIRST_N, so the sampling below
# must use those same constants; otherwise the file name would not describe
# what the file actually contains.
CACHE_FILENAME = f"fineweb_sample{SAMPLE_SIZE}of{FIRST_N}.csv"

# Fixed seed so that rebuilding the cache from the same stream prefix draws
# the same subset instead of a different one on every fresh run.
SAMPLE_SEED = 0

# Reuse the cached sample when it exists to avoid re-streaming the dataset
if os.path.exists(CACHE_FILENAME):
    print(f"Cache file already exists: {CACHE_FILENAME}")
    df = pd.read_csv(CACHE_FILENAME)
else:
    # Load streaming dataset
    dataset = load_dataset("HuggingFaceFW/fineweb-edu", split="train", name="sample-10BT", streaming=True)
    stream = iter(dataset)

    # Take the first FIRST_N streamed samples
    samples = [next(stream) for _ in range(FIRST_N)]

    # Randomly select SAMPLE_SIZE of them with a dedicated, seeded generator
    # (does not disturb the global `random` state)
    selected_samples = random.Random(SAMPLE_SEED).sample(samples, SAMPLE_SIZE)

    # Convert to DataFrame
    df = pd.DataFrame(selected_samples)

    # Save to CSV
    df.to_csv(CACHE_FILENAME, index=False, encoding="utf-8")
    print(f"Saved CSV with {len(df)} samples to: {CACHE_FILENAME}")

texts = df["text"]

Cache file already exists: fineweb_sample10of1000.csv


In [3]:
def stepwise_perplexity(model, tokenizer, texts, cache_impl, update_every=10, max_length=256):
    """Compute perplexity by feeding each text to the model one token at a time.

    Every text is tokenized (truncated to ``max_length`` tokens) and then
    replayed token by token: the model receives a single input token together
    with the cache returned by the previous step, and the logits for that
    position are scored against the next token with cross-entropy. A live
    status line (progress, ETA, running perplexity) is shown through an
    IPython ``DisplayHandle``.

    Args:
        model: Callable accepting ``input_ids``, ``use_cache`` and
            ``past_key_values`` keyword arguments and returning an object with
            ``logits`` and ``past_key_values`` attributes. Must expose
            ``.device``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            when called with ``return_tensors="pt"``.
        texts: Sized iterable of strings to evaluate.
        cache_impl: Zero-argument factory called once per text; its result is
            passed as ``past_key_values`` on the first step of that text.
        update_every: Refresh the status line every this many scored tokens.
        max_length: Truncation length, in tokens, applied to every text.

    Returns:
        ``exp`` of the mean per-token cross-entropy over all scored tokens, or
        ``float("inf")`` when no token could be scored.
    """
    total_loss = 0.0
    total_tokens = 0
    loss_fn = torch.nn.CrossEntropyLoss()

    global_start = time.time()
    total_texts = len(texts)

    display_handle = DisplayHandle()
    display_handle.display("Starting perplexity evaluation...")

    # Tokenize each text exactly once. The total step count needed for the ETA
    # is loop-invariant, so it is derived here instead of re-tokenizing the
    # whole corpus on every progress update.
    encoded_texts = [
        tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)["input_ids"].squeeze(0)
        for text in texts
    ]
    # A text with n tokens yields n - 1 (input, label) pairs; never negative.
    total_steps = sum(max(ids.size(0) - 1, 0) for ids in encoded_texts)

    for text_idx, token_ids in enumerate(encoded_texts):
        input_ids = token_ids.to(model.device)
        cache = cache_impl()  # fresh cache per text so no context leaks across texts
        seq_len = input_ids.size(0)

        for i in range(1, seq_len):
            input_slice = input_ids[i - 1 : i].unsqueeze(0)  # [1, 1]
            label_token = input_ids[i].unsqueeze(0)          # [1]

            with torch.no_grad():
                output = model(
                    input_ids=input_slice,
                    use_cache=True,
                    past_key_values=cache,
                )
                # Carry the cache forward so the next step sees the prefix.
                cache = output.past_key_values
                logits = output.logits[:, -1, :]  # [1, vocab_size]
                loss = loss_fn(logits, label_token)
                total_loss += loss.item()
                total_tokens += 1

            # Only update output every N steps for smooth UX (and always on the very last step)
            if total_tokens % update_every == 0 or (i == seq_len - 1 and text_idx == total_texts - 1):
                elapsed = time.time() - global_start
                avg_step_time = elapsed / max(total_tokens, 1)
                remaining_steps = total_steps - total_tokens
                eta_minutes = (avg_step_time * remaining_steps) / 60
                current_ppl = np.exp(total_loss / total_tokens)

                status = (f"Text {text_idx + 1}/{total_texts} | "
                          f"Token {i + 1}/{seq_len} "
                          f"Global Steps: {total_tokens} | "
                          f"ETA: {eta_minutes:.1f} min "
                          f"Cumulative PPL: {current_ppl:.2f}")
                display_handle.update(status)

    if total_tokens == 0:
        return float("inf")

    return np.exp(total_loss / total_tokens)


In [4]:
# model_hf = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# model_hf.eval().to(DEVICE).to(DTYPE)
# 
# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

In [5]:
# stepwise_perplexity(model_hf, tokenizer, texts, cache_impl=lambda: None)

In [6]:
# Load the project's modified Llama implementation from the SmolLM2-135M
# checkpoint, switch it to evaluation mode, then move it to the configured
# device and cast it to the configured dtype.
model_mod = ModifiedLlama.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model_mod.eval()
model_mod = model_mod.to(DEVICE).to(DTYPE)

# Tokenizer loaded from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

In [7]:
# Evaluate the modified model; the cache class itself serves as the
# zero-argument factory, so every text starts from a newly constructed cache.
# NOTE(review): the recorded output of this cell ends in
# "RuntimeError: Backend doesn't support synchronizing streams." The traceback
# is not preserved, so the failing call is unknown; possibly related to the
# selected device backend. TODO: investigate before trusting results.
stepwise_perplexity(model_mod, tokenizer, texts, cache_impl=ModifiedCache)

'Starting perplexity evaluation...'

RuntimeError: Backend doesn't support synchronizing streams.