In [1]:
import math
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import math
from tqdm.auto import tqdm
import random

MODEL_NAME = "gpt2"
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
DEVICE = "cuda"
SEQ_LEN = 1024
BATCH_SIZE = 16

In [2]:
def load_wikitext(tokenizer, ratio=0.5):
    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG)
    text_texts = dataset['test']['text']
    full_text = "\n\n".join(text_texts[:int(len(text_texts)*ratio)])
    
    enc = tokenizer(full_text, return_tensors='pt')
    return enc["input_ids"][0]

In [3]:
def make_windows(ids, stride):
    trunc_size = ids.size(0) - (ids.size(0) % SEQ_LEN)
    ids = ids[:trunc_size]
    windows = []
    for i in range(0, ids.size(0) - SEQ_LEN + 1, stride):
        windows.append(ids[i:i+SEQ_LEN])
    return torch.stack(windows)
    

In [4]:
def compute_ppl_cache(model, windows, prefill_len, scored_len, forward_kwargs, window_batch_size=8):
    B, L = windows.shape
    latencies = []
    total_nll = 0.0
    total_tokens = 0
    num_batches = (B + window_batch_size - 1) // window_batch_size
    batch = 0
    for start in range(0, B, window_batch_size):
        print(f"Processing batch {batch+1}/{num_batches}")
        batch += 1
        end = min(start + window_batch_size, B)
        win_batch = windows[start:end].to(DEVICE)
        b = win_batch.size(0)
        prefill = win_batch[:, :prefill_len]
        out = model(
            input_ids=prefill,
            use_cache=True,
            **forward_kwargs
        )
        pkvs = out.past_key_values
    
        for offset in range(scored_len):
            tok = prefill_len + offset
            
            input_token = win_batch[:, tok-1:tok]
            target = win_batch[:, tok]
            t0 = time.perf_counter()
            out = model(
                input_ids=input_token,
                past_key_values=pkvs,
                use_cache=True,
                **forward_kwargs
            )
            t1 = time.perf_counter()
            pkvs = out.past_key_values
            
            logits = out.logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)
            nll_step = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
            
            total_nll += nll_step.sum().item()
            total_tokens += b
            latencies.append(t1 - t0)
        
    return total_nll, total_tokens, latencies
        

In [5]:
def compute_ppl_no_cache(model, windows, prefill_len, scored_len, forward_kwargs=None):
    B, _ = windows.shape
    latencies = []
    total_nll = 0.0
    total_tokens = 0
    
    for offset in range(scored_len):
        tok = prefill_len + offset
        
        context = windows[:, :tok]
        target = windows[:, tok]
        
        t0 = time.perf_counter()
        out = model(
            input_ids=context,
            use_cache=False,
            **forward_kwargs
        )
        t1 = time.perf_counter()
        logits = out.logits[:, -1, :]
        log_probs = torch.log_softmax(logits, dim=-1)
        nll_step = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        
        total_nll += nll_step.sum().item()
        total_tokens += B
        latencies.append(t1 - t0)

In [6]:
def compute_ppl(model, windows, device, prefill_len, use_cache=True, forward_kwargs=None):
    if forward_kwargs is None:
        forward_kwargs = {}
    model.eval()
    windows = windows.to(device)
    B, L = windows.shape
    scored_len = L - prefill_len

    with torch.inference_mode():
        if use_cache:
            total_nll, total_tokens, latencies = compute_ppl_cache(
                model, windows, prefill_len, scored_len, forward_kwargs
            )
        else:
            total_nll, total_tokens, latencies = compute_ppl_no_cache(
                model, windows, prefill_len, scored_len, forward_kwargs
            )
    avg_nll = total_nll / total_tokens
    ppl = math.exp(avg_nll)
    avg_latency_ms = 1000.0 * sum(latencies) / len(latencies)
    return ppl, avg_latency_ms

In [7]:
print(f"Loading wikitext")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.float16).to(DEVICE)

print("Tokenizing dataset")
ids = load_wikitext(tokenizer, ratio=0.3)
print(f"Total tokens: {ids.size(0)}")

print("Creating windows")
windows = make_windows(ids, stride=SEQ_LEN//2)

Loading wikitext


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Tokenizing dataset


Token indices sequence length is longer than the specified maximum sequence length for this model (90482 > 1024). Running this sequence through the model will result in indexing errors


Total tokens: 90482
Creating windows


In [8]:
ppl, avg_latency_ms = compute_ppl(
    model, windows, DEVICE, prefill_len=SEQ_LEN//2, use_cache=True
)
print(f"PPL with cache: {ppl:.2f}, Avg latency per token: {avg_latency_ms:.2f} ms")

Processing batch 1/22
Processing batch 2/22
Processing batch 3/22
Processing batch 4/22
Processing batch 5/22
Processing batch 6/22
Processing batch 7/22
Processing batch 8/22
Processing batch 9/22
Processing batch 10/22
Processing batch 11/22
Processing batch 12/22
Processing batch 13/22
Processing batch 14/22
Processing batch 15/22
Processing batch 16/22
Processing batch 17/22
Processing batch 18/22
Processing batch 19/22
Processing batch 20/22
Processing batch 21/22
Processing batch 22/22
PPL with cache: 25.75, Avg latency per token: 8.01 ms
