In [1]:
import math
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCache
import math
from tqdm.auto import tqdm
import random
import pandas as pd

!pip install "optimum[quanto]"

MODEL_NAME = "gpt2"
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
DEVICE = "cuda"
SEQ_LEN = 1024
BATCH_SIZE = 16



In [2]:
def load_wikitext(tokenizer, ratio=0.5):
    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG)
    text_texts = dataset['test']['text']
    full_text = "\n\n".join(text_texts[:int(len(text_texts)*ratio)])
    
    enc = tokenizer(full_text, return_tensors='pt')
    return enc["input_ids"][0]

In [3]:
def make_windows(ids, stride):
    trunc_size = ids.size(0) - (ids.size(0) % SEQ_LEN)
    ids = ids[:trunc_size]
    windows = []
    for i in range(0, ids.size(0) - SEQ_LEN + 1, stride):
        windows.append(ids[i:i+SEQ_LEN])
    return torch.stack(windows)
    

In [4]:
def compute_ppl_cache(model, windows, prefill_len, scored_len, forward_kwargs, window_batch_size=16):
    B, L = windows.shape
    latencies = []
    total_nll = 0.0
    total_tokens = 0
    num_batches = (B + window_batch_size - 1) // window_batch_size
    
    batch = 0
    for start in range(0, B, window_batch_size):
        print(f"Processing batch {batch+1}/{num_batches}")
        batch += 1
        
        end = min(start + window_batch_size, B)
        win_batch = windows[start:end].to(DEVICE)
        b = win_batch.size(0)
        
        prefill = win_batch[:, :prefill_len]
        
        out = model(
            input_ids=prefill,
            use_cache=True,
            **forward_kwargs
        )
        pkvs = out.past_key_values
    
        for offset in range(scored_len):
            tok = prefill_len + offset
            
            input_token = win_batch[:, tok-1:tok]
            target = win_batch[:, tok]
            t0 = time.perf_counter()
            out = model(
                input_ids=input_token,
                past_key_values=pkvs,
                use_cache=True,
                **forward_kwargs
            )
            t1 = time.perf_counter()
            pkvs = out.past_key_values
            
            logits = out.logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)
            nll_step = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
            
            total_nll += nll_step.sum().item()
            total_tokens += b
            latencies.append(t1 - t0)
        
    return total_nll, total_tokens, latencies
        

In [5]:
def compute_ppl_no_cache(model, windows, prefill_len, scored_len, forward_kwargs=None, window_batch_size=16):
    B, L = windows.shape
    latencies = []
    total_nll = 0.0
    total_tokens = 0
    num_batches = (B + window_batch_size - 1) // window_batch_size
    
    batch = 0
    for start in range(0, B, window_batch_size):
        print(f"Processing batch {batch+1}/{num_batches}")
        batch += 1
        
        end = min(start + window_batch_size, B)
        win_batch = windows[start:end].to(DEVICE)
        b = win_batch.size(0)
        context = win_batch[:, :prefill_len]
        
        for offset in range(scored_len):
            tok = prefill_len + offset
            input_context = win_batch[:, :tok]
            target = win_batch[:, tok]
            
            t0 = time.perf_counter()
            out = model(
                input_ids=input_context,
                use_cache=False,
                **forward_kwargs
            )
            t1 = time.perf_counter()
            
            logits = out.logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)
            nll_step = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
            
            total_nll += nll_step.sum().item()
            total_tokens += b
            latencies.append(t1 - t0)
            
    return total_nll, total_tokens, latencies
        
        
        

In [6]:
def compute_ppl_quant_cache(model, windows, prefill_len, scored_len, forward_kwargs, window_batch_size=16, nbits=2):
    B, L = windows.shape
    latencies = []
    total_nll = 0.0
    total_tokens = 0
    num_batches = (B + window_batch_size - 1) // window_batch_size
    
    batch = 0
    for start in range(0, B, window_batch_size):
        print(f"Processing batch {batch+1}/{num_batches}")
        batch += 1
        
        end = min(start + window_batch_size, B)
        win_batch = windows[start:end].to(DEVICE)
        b = win_batch.size(0)
        
        prefill = win_batch[:, :prefill_len]
        
        cache_config = {
            "backend": "quanto",
            "nbits": nbits
        }
        
        pkvs = QuantizedCache(config=model.config, **cache_config)
        
        out = model(
            input_ids=prefill,
            past_key_values=pkvs,
            use_cache=True,
            **forward_kwargs
        )
        pkvs = out.past_key_values
        # print(f"Cache Type: {type(pkvs)}")
    
        for offset in range(scored_len):
            tok = prefill_len + offset
            
            input_token = win_batch[:, tok-1:tok]
            target = win_batch[:, tok]
            t0 = time.perf_counter()
            out = model(
                input_ids=input_token,
                past_key_values=pkvs,
                use_cache=True,
                **forward_kwargs
            )
            t1 = time.perf_counter()
            pkvs = out.past_key_values
            
            logits = out.logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)
            nll_step = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
            
            total_nll += nll_step.sum().item()
            total_tokens += b
            latencies.append(t1 - t0)
        
    return total_nll, total_tokens, latencies
        

In [7]:
def compute_ppl(model, windows, device, prefill_len, use_cache="normal", forward_kwargs=None, quant_bits=2):
    if forward_kwargs is None:
        forward_kwargs = {}
    model.eval()
    windows = windows.to(device)
    B, L = windows.shape
    scored_len = L - prefill_len

    with torch.inference_mode():
        if use_cache == "normal":
            total_nll, total_tokens, latencies = compute_ppl_cache(
                model, windows, prefill_len, scored_len, forward_kwargs
            )
        elif use_cache == "quantized":
            total_nll, total_tokens, latencies = compute_ppl_quant_cache(
                model, windows, prefill_len, scored_len, forward_kwargs, nbits=quant_bits
            )
        else:
            total_nll, total_tokens, latencies = compute_ppl_no_cache(
                model, windows, prefill_len, scored_len, forward_kwargs
            )
    avg_nll = total_nll / total_tokens
    ppl = math.exp(avg_nll)
    avg_latency_ms = 1000.0 * sum(latencies) / len(latencies)
    return ppl, avg_latency_ms

In [8]:
def run_ppl(model, windows, device, prefill_len, use_cache="normal", forward_kwargs=None, quant_bits=2):
    return compute_ppl(model, windows, device, prefill_len, use_cache, forward_kwargs, quant_bits=quant_bits)

In [9]:
def track_memory(fn, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_alloc = torch.cuda.memory_allocated()
    t0 = time.time()
    
    result = fn(*args, **kwargs)
    
    torch.cuda.synchronize()
    t1 = time.time()
    peak_alloc = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    
    stats = {
        "start_alloc_bytes": start_alloc,
        "peak_alloc_bytes": peak_alloc,
        "peak_reserved_bytes": peak_reserved,
        "delta_alloc_bytes": peak_alloc - start_alloc,
        "elapsed_sec": t1 - t0,
    }
    
    return result, stats

In [10]:
def warmup_gpu(model, windows, device, prefill_len, forward_kwargs=None):
    if forward_kwargs is None:
        forward_kwargs = {}

    model.eval()
    # use a tiny subset â€“ just enough to trigger kernels + quant machinery
    warmup_windows = windows[:2].to(device)

    # modes you actually benchmark; include quant settings you care about
    modes = [
        {"use_cache": "normal"},
        {"use_cache": "quantized", "quant_bits": 4},
        {"use_cache": "quantized", "quant_bits": 2},
        {"use_cache": "no_cache"},
    ]

    with torch.inference_mode():
        for cfg in modes:
            print(f"Warming up mode: {cfg}")
            _ = compute_ppl(
                model,
                warmup_windows,
                device,
                prefill_len=prefill_len,
                use_cache=cfg["use_cache"],
                forward_kwargs=forward_kwargs,
                quant_bits=cfg.get("quant_bits", 2),
            )


In [11]:
print(f"Loading wikitext")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.float16).to(DEVICE)

print("Tokenizing dataset")
ids = load_wikitext(tokenizer, ratio=0.5)
print(f"Total tokens: {ids.size(0)}")

print("Creating windows")
windows = make_windows(ids, stride=SEQ_LEN//2)
print("Success!")

Loading wikitext


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Tokenizing dataset


Token indices sequence length is longer than the specified maximum sequence length for this model (142657 > 1024). Running this sequence through the model will result in indexing errors


Total tokens: 142657
Creating windows
Success!


In [12]:
warmup_gpu(model, windows, DEVICE, prefill_len=SEQ_LEN // 2)

(result_noCache, mem_noCache) = track_memory(
    run_ppl,
    model,
    windows,
    DEVICE,
    SEQ_LEN // 2,
    False,
)

ppl_noCache, avg_latency_noCache = result_noCache

(result_normal, mem_normal) = track_memory(
    run_ppl,
    model,
    windows,
    DEVICE,
    SEQ_LEN // 2,
    "normal",
)

ppl_normal, avg_latency_normal = result_normal


(result_quant2, mem_quant2) = track_memory(
    run_ppl,
    model,
    windows,
    DEVICE,
    SEQ_LEN // 2,
    "quantized",
    quant_bits=2,
)

ppl_quant2, avg_latency_quant2 = result_quant2

(result_quant4, mem_quant4) = track_memory(
    run_ppl,
    model,
    windows,
    DEVICE,
    SEQ_LEN // 2,
    "quantized",
    quant_bits=4,
)
ppl_quant4, avg_latency_quant4 = result_quant4


Warming up mode: {'use_cache': 'normal'}
Processing batch 1/1
Warming up mode: {'use_cache': 'quantized', 'quant_bits': 4}
Processing batch 1/1


Multiple distributions found for package optimum. Picked distribution: optimum-quanto


Warming up mode: {'use_cache': 'quantized', 'quant_bits': 2}
Processing batch 1/1
Warming up mode: {'use_cache': 'no_cache'}
Processing batch 1/1
Processing batch 1/18
Processing batch 2/18
Processing batch 3/18
Processing batch 4/18
Processing batch 5/18
Processing batch 6/18
Processing batch 7/18
Processing batch 8/18
Processing batch 9/18
Processing batch 10/18
Processing batch 11/18
Processing batch 12/18
Processing batch 13/18
Processing batch 14/18
Processing batch 15/18
Processing batch 16/18
Processing batch 17/18
Processing batch 18/18
Processing batch 1/18
Processing batch 2/18
Processing batch 3/18
Processing batch 4/18
Processing batch 5/18
Processing batch 6/18
Processing batch 7/18
Processing batch 8/18
Processing batch 9/18
Processing batch 10/18
Processing batch 11/18
Processing batch 12/18
Processing batch 13/18
Processing batch 14/18
Processing batch 15/18
Processing batch 16/18
Processing batch 17/18
Processing batch 18/18
Processing batch 1/18
Processing batch 2/18


In [13]:
def model_size_bytes(model, include_buffers=True):
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = 0
    if include_buffers:
        buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return param_bytes + buffer_bytes

In [14]:
perplexities = {
    "No Cache": round(ppl_noCache, 4),
    "Normal Cache": round(ppl_normal, 4),
    "Quantized Cache (2-bit)": round(ppl_quant2, 4),
    "Quantized Cache (4-bit)": round(ppl_quant4, 4)
}

latencies = {
    "No Cache": round(avg_latency_noCache, 4),
    "Normal Cache": round(avg_latency_normal, 4),
    "Quantized Cache (2-bit)": round(avg_latency_quant2, 4),
    "Quantized Cache (4-bit)": round(avg_latency_quant4, 4)
}

peak_alloc = {
    "No Cache": round(mem_noCache["peak_alloc_bytes"] / (1024**2), 4),
    "Normal Cache": round(mem_normal["peak_alloc_bytes"] / (1024**2), 4),
    "Quantized Cache (2-bit)": round(mem_quant2["peak_alloc_bytes"] / (1024**2), 4),
    "Quantized Cache (4-bit)": round(mem_quant4["peak_alloc_bytes"] / (1024**2), 4)
}

total_elapsed = {
    "No Cache": round(mem_noCache["elapsed_sec"], 4),
    "Normal Cache": round(mem_normal["elapsed_sec"], 4),
    "Quantized Cache (2-bit)": round(mem_quant2["elapsed_sec"], 4),
    "Quantized Cache (4-bit)": round(mem_quant4["elapsed_sec"], 4)
}

results_df = pd.DataFrame({
    "Perplexity": perplexities,
    "Avg Latency (ms)": latencies,
    "Peak Alloc (MB)": peak_alloc,
    "Total Elapsed (s)": total_elapsed
})

model_size = round(model_size_bytes(model) / (1024**2), 4)
results_df["Model Size (MB)"] = model_size
results_df["Cache Size (MB)"] = (results_df["Peak Alloc (MB)"] - model_size).round(4)

results_df = results_df.round(4)

results_df

Unnamed: 0,Perplexity,Avg Latency (ms),Peak Alloc (MB),Total Elapsed (s),Model Size (MB),Cache Size (MB)
No Cache,24.5021,16.002,3430.2461,479.9001,249.3501,3180.896
Normal Cache,24.5416,7.6514,1932.2983,72.1479,249.3501,1682.9482
Quantized Cache (2-bit),34.0366,12.0098,1207.8755,112.4732,249.3501,958.5254
Quantized Cache (4-bit),24.6361,11.9507,1316.019,111.8905,249.3501,1066.6689


In [15]:
print("\a")


