# Multi-Word Tokenizer Lab

This notebook implements a tokenizer that treats specific multi-word phrases as single tokens and compares its performance against the standard GPT-2 BPE tokenizer using the OpenWebText dataset.

In [68]:
# Install necessary libraries if not present
import sys
import subprocess

packages = ["datasets", "tiktoken", "regex", "requests", "torch", "tqdm"]
subprocess.check_call([sys.executable, "-m", "pip", "install"] + packages)



0

In [69]:
import re
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from datasets import load_dataset
from collections import Counter, defaultdict
import tqdm

## 1. Load Dataset

We use `Skylion007/openwebtext` (WebText replication) via streaming.

In [70]:
print("Loading dataset (streaming)...")
dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True)

Loading dataset (streaming)...


## 3. Multi-Word Tokenizer Implementation

This tokenizer will:
1.  Identify specific multi-word phrases in the text.
2.  Tokenize them as single units.
3.  Delegate the rest to the base tokenizer.

In [71]:
class MultiWordTokenizer:
    def __init__(self, phrases):
        """
        phrases: list of strings (e.g., ["New York", "machine learning"])
        """
        self.base_tokenizer = tiktoken.get_encoding("gpt2")
        self.base_vocab_size = self.base_tokenizer.n_vocab

        # Mappings for custom phrase tokens
        self.phrase_to_id = {}
        self.id_to_phrase = {}

        next_id = self.base_vocab_size

        # Flattened list for regex
        all_phrases = []

        for phrase in phrases:
            # Register original phrase
            self.phrase_to_id[phrase] = next_id
            self.id_to_phrase[next_id] = phrase
            all_phrases.append(phrase)
            next_id += 1

            # Register phrase with leading space (CRITICAL for BPE efficiency)
            # This prevents breaking " word" tokens into " " + "word"
            space_phrase = " " + phrase
            self.phrase_to_id[space_phrase] = next_id
            self.id_to_phrase[next_id] = space_phrase
            all_phrases.append(space_phrase)
            next_id += 1

        self.vocab_size = next_id

        # Create a regex pattern to match phrases first
        # Sort by length (descending) to match longest phrases first (e.g. " New York" before "New York")
        sorted_phrases = sorted(all_phrases, key=len, reverse=True)
        self.phrase_pattern = re.compile("|".join(map(re.escape, sorted_phrases)))

    def encode(self, text, allowed_special=None):
        # Split text by phrases
        tokens = []
        last_end = 0

        # tiktoken encode needs allowed_special={'<|endoftext|>'} if present
        special = {'<|endoftext|>'} if allowed_special is None else allowed_special

        for match in self.phrase_pattern.finditer(text):
            start, end = match.span()

            # Process text before the match with base tokenizer
            if start > last_end:
                pre_text = text[last_end:start]
                # If pre_text ends with a space but we matched a non-space phrase, 
                # we might still be splitting "word " + "phrase". 
                # But since we prioritized " phrase", this is less likely to happen for the phrase itself.
                tokens.extend(self.base_tokenizer.encode(pre_text, allowed_special=special))

            # Add the phrase token
            phrase = match.group()
            tokens.append(self.phrase_to_id[phrase])

            last_end = end

        # Process remaining text
        if last_end < len(text):
            tokens.extend(self.base_tokenizer.encode(text[last_end:], allowed_special=special))

        return tokens

    def decode(self, ids):
        # Split ids into chunks of (base_tokens) and (phrase_tokens)
        decoded_parts = []
        current_chunk = []

        for id in ids:
            if id < self.base_vocab_size:
                current_chunk.append(id)
            else:
                # Decode accumulated base tokens
                if current_chunk:
                    decoded_parts.append(self.base_tokenizer.decode(current_chunk))
                    current_chunk = []
                # Decode phrase token
                decoded_parts.append(self.id_to_phrase.get(id, ""))

        # Decode final chunk
        if current_chunk:
             decoded_parts.append(self.base_tokenizer.decode(current_chunk))

        return "".join(decoded_parts)

    @property
    def vocab(self):
        # Fake vocab dict for compatibility if needed
        return {**{str(i):i for i in range(self.base_vocab_size)}, **self.phrase_to_id}

## 4. Streaming Vocab Building

We use `tqdm` for progress tracking and implement periodic pruning to prevent Out-Of-Memory errors when processing a massive dataset.

In [None]:
def build_vocab_stream(dataset, limit=None):
    """
    Builds phrase counts by streaming through the dataset.
    (Base vocab is handled by tiktoken)
    """
    phrase_counts_2 = Counter()
    phrase_counts_3 = Counter()

    MAX_PHRASE_ENTRIES = 500000

    print("Starting streaming phrase extraction...")
    if limit:
        pbar = tqdm.tqdm(total=limit)
    else:
        pbar = tqdm.tqdm(desc="Processing docs")

    i = 0
    try:
        for index, doc in enumerate(dataset):
            if limit is not None and i >= limit:
                break

            text = doc['text']

            # Phrases (n-grams)
            # Use regex to find alphanumeric tokens, avoiding punctuation issues
            words = re.findall(r'\b[a-zA-Z]+\b', text)
            words = [w for w in words if w.strip()]

            if len(words) >= 2:
                ngrams_2 = zip(*[words[i:] for i in range(2)])
                phrase_counts_2.update(" ".join(ngram) for ngram in ngrams_2)

            if len(words) >= 3:
                ngrams_3 = zip(*[words[i:] for i in range(3)])
                phrase_counts_3.update(" ".join(ngram) for ngram in ngrams_3)

            if i % 10000 == 0 and i > 0:
                 if len(phrase_counts_2) > MAX_PHRASE_ENTRIES:
                     phrase_counts_2 = Counter({k: c for k, c in phrase_counts_2.items() if c > 1})
                 if len(phrase_counts_3) > MAX_PHRASE_ENTRIES:
                     phrase_counts_3 = Counter({k: c for k, c in phrase_counts_3.items() if c > 1})

            i += 1
            pbar.update(1)

    except KeyboardInterrupt:
        print("\nInterrupted by user.")

    pbar.close()
    print(f"\nFinished processing. Documents processed: {i}")
    return phrase_counts_2, phrase_counts_3


# Set limit for testing.
processed_limit = 100000

phrases_2, phrases_3 = build_vocab_stream(dataset, limit=processed_limit)

# --- FILTERING LOGIC ---
# Expand stop words to catch more common glue phrases
stop_words = {
}

def is_interesting(phrase):
    parts = phrase.split()
    # Check if all parts are stopwords (e.g., "of the", "in a")
    if all(p.lower() in stop_words for p in parts):
        return False
    return True

print("Filtering phrases...")
filtered_2 = Counter({p: c for p, c in phrases_2.items() if is_interesting(p)})
filtered_3 = Counter({p: c for p, c in phrases_3.items() if is_interesting(p)})

# Select top phrases from the FILTERED lists
common_phrases = [p for p, c in filtered_2.most_common(5000)] + [p for p, c in filtered_3.most_common(2000)]

print(f"Top 10 INTERESTING phrases (Actual Vocab): {common_phrases[:10]}")


Starting streaming phrase extraction...


100%|██████████| 100000/100000 [01:39<00:00, 1000.73it/s]



Finished processing. Documents processed: 100000
Filtering phrases...
Top 10 INTERESTING phrases (Actual Vocab): ['of the', 'in the', 'to the', 'on the', 'for the', 'to be', 'and the', 'at the', 'with the', 'in a']


In [73]:
# Initialize Tokenizers
gpt2_tokenizer = tiktoken.get_encoding("gpt2")
mw_tokenizer = MultiWordTokenizer(common_phrases)

# Tiktoken base vocab is ~50257. Our tokens start after that.
print(f"Multi-Word Tokenizer Total Vocab Size: {mw_tokenizer.vocab_size}")


Multi-Word Tokenizer Total Vocab Size: 52257


## 5. Comparision Results

In [74]:
def evaluate_tokenizer(tokenizer, text, name):
    encoded = tokenizer.encode(text)
    token_count = len(encoded)
    chars = len(text)
    compression_ratio = chars / token_count if token_count > 0 else 0
    print(f"--- {name} ---")
    print(f"Token Count: {token_count}")
    print(f"Compression Ratio: {compression_ratio:.2f} chars/token")
    return token_count


# Run comparison on a larger, random sample from the dataset
print("Fetching a random sample from dataset...")
try:
    # We use a new shuffled instance to get a random sample
    # Using a different seed or buffer ensures some randomness
    temp_ds = load_dataset("Skylion007/openwebtext", split="train", streaming=True).shuffle(seed=42, buffer_size=1000)
    sample_doc = next(iter(temp_ds))
    sample_text = sample_doc['text']

    # Ensure it's substantial
    if len(sample_text) < 1000:
        # If too short, grab another
        sample_doc = next(iter(temp_ds))
        sample_text += "\n\n" + sample_doc['text']

    # Truncate if too huge (e.g., limit to 5000 chars for clean output)
    if len(sample_text) > 5000:
        sample_text = sample_text[:5000]

    print(f"Sample Text Length: {len(sample_text)} characters")
    print(f"Sample snippet: {sample_text[:100]}...")
except Exception as e:
    print(f"Error fetching sample: {e}")
    sample_text = "The quick brown fox jumps over the lazy dog. New York is a big city. " * 100

mw_count = evaluate_tokenizer(mw_tokenizer, sample_text, "Multi-Word Tokenizer")
gpt2_count = evaluate_tokenizer(gpt2_tokenizer, sample_text, "GPT-2 BPE (tiktoken)")

diff = gpt2_count - mw_count
pct = (diff / gpt2_count) * 100
print(f"\nDifference: {diff} tokens ({pct:.2f}%)")


Fetching a random sample from dataset...
Sample Text Length: 5000 characters
Sample snippet: If you live in Minneapolis’ Powderhorn neighborhood, keep your eyes peeled on Aug. 14. Starting that...
--- Multi-Word Tokenizer ---
Token Count: 1100
Compression Ratio: 4.55 chars/token
--- GPT-2 BPE (tiktoken) ---
Token Count: 1175
Compression Ratio: 4.26 chars/token

Difference: 75 tokens (6.38%)


In [97]:
# Spot check
#sample_phrase = common_phrases[0] if common_phrases else "test phrase"
sample_phrase = "new york"
print(f"Sample Phrase: '{sample_phrase}'")
print(f"MW Encoded: {mw_tokenizer.encode(sample_phrase)}")
print(f"GPT-2 Encoded: {gpt2_tokenizer.encode(sample_phrase)}")

Sample Phrase: 'new york'
MW Encoded: [3605, 331, 967]
GPT-2 Encoded: [3605, 331, 967]


## 6. Model Implementation (GPT-2)

Copying the GPT-2 model architecture from `llm-from-scratch.ipynb`.

In [76]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x

class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

## 7. Streaming Data Loaders

We implement a `StreamingTextDataset` using `IterableDataset` to handle the streaming data.

In [77]:
class StreamingTextDataset(IterableDataset):
    def __init__(self, hf_dataset, tokenizer, max_length, stride):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride

    def __iter__(self):
        # Iterate over streaming HF dataset
        buffer_tokens = []

        for doc in self.hf_dataset:
            text = doc['text']

            # Tokenize
            if hasattr(self.tokenizer, "encode"):
                 try:
                     # Use allowed_special for both tokenizers to be safe
                     token_ids = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})
                 except TypeError:
                     token_ids = self.tokenizer.encode(text)
            else:
                 continue # skip invalid

            buffer_tokens.extend(token_ids)

            while len(buffer_tokens) >= self.max_length + 1:
                chunk = buffer_tokens[:self.max_length + 1]
                input_chunk = torch.tensor(chunk[:-1])
                target_chunk = torch.tensor(chunk[1:])
                yield input_chunk, target_chunk

                buffer_tokens = buffer_tokens[self.stride:]

def create_streaming_dataloader(hf_dataset, tokenizer, batch_size=4, max_length=256, stride=128):
    dataset = StreamingTextDataset(hf_dataset, tokenizer, max_length, stride)
    return DataLoader(dataset, batch_size=batch_size)


## 8. Training Loop

A training loop adapted for infinite streams (based on steps, not epochs).

In [78]:
def train_model(model, train_loader, optimizer, device, max_steps=None, eval_freq=1000):
    model.to(device)
    model.train()

    step = 0
    loss_val = 0.0

    # Iterate over stream
    for input_batch, target_batch in train_loader:
        if max_steps is not None and step >= max_steps:
            break

        input_batch, target_batch = input_batch.to(device), target_batch.to(device)

        optimizer.zero_grad()
        logits = model(input_batch)
        loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
        loss.backward()
        optimizer.step()

        loss_val = loss.item()

        if step % eval_freq == 0:
            print(f"Step {step}, Loss: {loss_val:.4f}")

        step += 1

    return model

## 9. Train Models & Compare

We will train two small GPT models. One with standard BPE, one with Multi-Word Tokenizer.

In [79]:
# Configuration for small model (to run fast in lab)
# Significantly smaller than GPT-2 Small to allow quick training
SMALL_GPT_CONFIG = {
    "vocab_size": 50257,    # Will be updated per tokenizer
    "context_length": 128,  # Context length
    "emb_dim": 256,         # Embedding dimension
    "n_heads": 4,           # Number of attention heads
    "n_layers": 4,          # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # QKV bias
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [81]:
# 1. Baseline Model (BPE)
print("--- Training Baseline BPE Model ---")
bpe_config = SMALL_GPT_CONFIG.copy()
bpe_config["vocab_size"] = 50257 # Standard GPT-2 vocab size

bpe_model = GPTModel(bpe_config)
bpe_optimizer = torch.optim.AdamW(bpe_model.parameters(), lr=5e-4, weight_decay=0.1)

# Create NEW streaming dataset iterator for training
dataset_bpe = load_dataset("Skylion007/openwebtext", split="train", streaming=True)
dataset_bpe = dataset_bpe.shuffle(seed=42, buffer_size=1000)
bpe_loader = create_streaming_dataloader(dataset_bpe, gpt2_tokenizer, batch_size=4, max_length=SMALL_GPT_CONFIG["context_length"])

bpe_model = train_model(bpe_model, bpe_loader, bpe_optimizer, device, max_steps=6000, eval_freq=1000)


--- Training Baseline BPE Model ---
Step 0, Loss: 10.9786
Step 1000, Loss: 6.8831
Step 2000, Loss: 6.7784
Step 3000, Loss: 5.7678
Step 4000, Loss: 6.4348
Step 5000, Loss: 6.5063


In [82]:
# 2. Multi-Word Model
print("\n--- Training Multi-Word Tokenizer Model ---")
mw_config = SMALL_GPT_CONFIG.copy()
mw_config["vocab_size"] = mw_tokenizer.vocab_size # Use explicit vocab_size

mw_model = GPTModel(mw_config)
mw_optimizer = torch.optim.AdamW(mw_model.parameters(), lr=5e-4, weight_decay=0.1)

# Create NEW streaming dataset iterator for training
dataset_mw = load_dataset("Skylion007/openwebtext", split="train", streaming=True)
dataset_mw = dataset_mw.shuffle(seed=42, buffer_size=1000)
mw_loader = create_streaming_dataloader(dataset_mw, mw_tokenizer, batch_size=4, max_length=SMALL_GPT_CONFIG["context_length"])

mw_model = train_model(mw_model, mw_loader, mw_optimizer, device, max_steps=6000, eval_freq=1000)



--- Training Multi-Word Tokenizer Model ---
Step 0, Loss: 10.9875
Step 1000, Loss: 7.6526
Step 2000, Loss: 7.0121
Step 3000, Loss: 6.8444
Step 4000, Loss: 6.5920
Step 5000, Loss: 5.8412


## 10. Qualitative Comparison

Generate text from both models to see if there's a difference in coherence (given limited training).

In [96]:
def generate_text(model, tokenizer, start_context, max_new_tokens, context_length, device, temperature=0.7, top_k=None):
    model.eval()

    # Handle encoding variations
    if hasattr(tokenizer, "encode"):
         try:
             encoded = tokenizer.encode(start_context, allowed_special={"<|endoftext|>"})
         except TypeError:
             encoded = tokenizer.encode(start_context)

    idx = torch.tensor(encoded).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_length:]
        with torch.no_grad():
            logits = model(idx_cond)

        logits = logits[:, -1, :]
        
        # Apply temperature
        if temperature > 0.0:
            logits = logits / temperature
        
        # Apply Top-K filtering
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        if temperature > 0.0:
            probas = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probas, num_samples=1)
        else:
            # Greedy decoding if temperature is 0
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            
        idx = torch.cat((idx, idx_next), dim=1)

    return tokenizer.decode(idx.squeeze(0).tolist())

start_text = "I love"

print(f"Prompt: '{start_text}'")

print("\n--- Baseline BPE Generation (Temp=0.7, TopK=50) ---")
print(generate_text(bpe_model, gpt2_tokenizer, start_text, 200, SMALL_GPT_CONFIG["context_length"], device, temperature=0.7, top_k=50))

print("\n--- Multi-Word Generation (Temp=0.7, TopK=50) ---")
print(generate_text(mw_model, mw_tokenizer, start_text, 200, SMALL_GPT_CONFIG["context_length"], device, temperature=0.7, top_k=50))

Prompt: 'I love'

--- Baseline BPE Generation (Temp=0.7, TopK=50) ---
I love for me up. A lot of my own experience with the best-like, and I was also in the same and it was the rest of the country, and the right of a long of the most important-game. He's no longer I was able to get his son of the most of his mother in the first and I was that I was a great of the first time.

She's a number of the time of the same week, when I was an opportunity to know, and I was no longer’s son in their way to know for me of the two of the fact that I’s name.


We’s mother in the next year after the “I am you a two-ball and was “This is really in his way to be in the rest of the other rights of a good and I was taken by both sides

“We were so I was a pretty good of a good-to-up

--- Multi-Word Generation (Temp=0.7, TopK=50) ---
I love, the anti-free, even and no time.

The former two games, the former first-right-right, he knew a left-free is where we're able to keep her.

It's really are to identif

## 11. TriviaQA Evaluation

We use the TriviaQA dataset (rc.nocontext) to measure perplexity. Lower perplexity indicates the model is less "surprised" by the text, suggesting better language modeling capability.

In [94]:
def calculate_perplexity(model, tokenizer, text, context_length, device):
    model.eval()

    if hasattr(tokenizer, "encode"):
       encodings = tokenizer.encode(text)

    if len(encodings) < 2:
        return float('nan')

    input_ids = torch.tensor(encodings).to(device)

    max_len = context_length
    stride = max_len

    seq_len = input_ids.size(0)
    prev_end_loc = 0

    nlls = []

    for begin_loc in range(0, seq_len - 1, stride):
        end_loc = min(begin_loc + max_len, seq_len)

        if end_loc - begin_loc < 2:
             break

        input_chunk = input_ids[begin_loc:end_loc]

        inputs = input_chunk[:-1].unsqueeze(0)
        targets = input_chunk[1:].unsqueeze(0)

        with torch.no_grad():
            logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(logits.permute(0, 2, 1), targets, reduction='sum')
            nlls.append(loss.item())

    if not nlls:
        return float('nan')

    total_nll = sum(nlls)
    total_tokens_predicted = sum([min(i + max_len, seq_len) - i - 1 for i in range(0, seq_len - 1, stride)])

    if total_tokens_predicted <= 0: return float('nan')

    return torch.exp(torch.tensor(total_nll / total_tokens_predicted))

print("Loading TriviaQA (rc.nocontext) validation subset...")
trivia_data = load_dataset("trivia_qa", "rc.nocontext", split="validation", streaming=True)

# Evaluate on first 5000 samples
eval_limit = 5000
bpe_ppls = []
mw_ppls = []

print(f"Evaluating on {eval_limit} examples...")

for i, item in enumerate(trivia_data):
    if i >= eval_limit: break

    question = item['question']
    text = f"Question: {question}\nAnswer:"

    ppl_bpe = calculate_perplexity(bpe_model, gpt2_tokenizer, text, SMALL_GPT_CONFIG["context_length"], device)
    ppl_mw = calculate_perplexity(mw_model, mw_tokenizer, text, SMALL_GPT_CONFIG["context_length"], device)

    if not torch.isnan(ppl_bpe): bpe_ppls.append(ppl_bpe)
    if not torch.isnan(ppl_mw): mw_ppls.append(ppl_mw)

avg_bpe_ppl = torch.tensor(bpe_ppls).mean().item()
avg_mw_ppl = torch.tensor(mw_ppls).mean().item()

print(f"\n--- Results ---")
print(f"GPT-2 BPE Average Perplexity: {avg_bpe_ppl:.4f}")
print(f"Multi-Word Average Perplexity: {avg_mw_ppl:.4f}")

if avg_mw_ppl < avg_bpe_ppl:
    print("\nWinner: Multi-Word Tokenizer Model is less surprised by the data.")
else:
    print("\nWinner: Baseline BPE Model is less surprised by the data.")


Loading TriviaQA (rc.nocontext) validation subset...
Evaluating on 5000 examples...

--- Results ---
GPT-2 BPE Average Perplexity: 1872.2816
Multi-Word Average Perplexity: 2875.7637

Winner: Baseline BPE Model is less surprised by the data.
