# 01 - Byte Pair Encoding from Scratch

## Why Tokenization Is the Hidden Foundation of LLMs

Every large language model -- GPT, LLaMA, Claude -- operates on **tokens**, not
characters or words. Before a single neuron fires, a tokenizer converts raw text
into a sequence of integers drawn from a fixed vocabulary. The model never sees
the original string; it sees only these integer IDs.

This means the tokenizer quietly shapes everything downstream:

- **Context window** -- a 4,096-token limit is not 4,096 words. Depending on
  the tokenizer, it might be 3,000 words of English prose or only 2,500 words
  of dense legal text.
- **Model comprehension** -- if a legal citation like `42 U.S.C. § 1983` is
  split into 8 tokens, the model must learn to reassemble those fragments into
  a meaningful concept. A single-token representation would be far easier to
  learn from.
- **API costs** -- commercial APIs charge per token. Legal text that needs 20%
  more tokens to express the same content costs 20% more to process.

In this notebook, we build the most widely used tokenization algorithm --
**Byte Pair Encoding (BPE)** -- from scratch in a few dozen lines of Python.

## Theory: Character vs Word vs Subword Tokenization

There are three broad families of tokenization. Each makes a different tradeoff
between vocabulary size and sequence length.

### Character-Level

Split text into individual characters. The vocabulary is tiny (about a hundred
entries cover English letters, digits, and punctuation), but sequences become
very long. The word "jurisprudence" becomes 13 tokens, and a 10-page court
opinion runs to tens of thousands of tokens, more than many models' context
windows can hold.

| Pros | Cons |
|------|------|
| Tiny vocabulary | Very long sequences |
| Handles any text | Model must learn spelling from scratch |
| No out-of-vocabulary tokens | Poor compression ratio |
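A minimal sketch of character-level tokenization. It assumes, purely for
illustration, that the vocabulary is the 100 characters in Python's
`string.printable`; `CHAR_VOCAB` and `char_tokenize` are hypothetical names.
Each character maps to one ID, so the sequence length equals the character count.

```python
import string

# Illustrative inventory (an assumption, not a standard): the 100 characters in
# string.printable, i.e. digits, ASCII letters, punctuation, and whitespace.
CHAR_VOCAB = {ch: idx for idx, ch in enumerate(string.printable)}


def char_tokenize(text: str) -> list[int]:
    """Map each character to its integer ID (KeyError if it is not in CHAR_VOCAB)."""
    return [CHAR_VOCAB[ch] for ch in text]


ids = char_tokenize("jurisprudence")
print(len(CHAR_VOCAB))  # 100
print(len(ids))         # 13: one token per character
```

With this particular inventory, a character such as `§` raises `KeyError`; a
character-level tokenizer avoids out-of-vocabulary input only if its inventory
covers every character it will see (all of Unicode, or all 256 byte values).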

### Word-Level

Split on whitespace and punctuation. Common words become single tokens, but the
vocabulary must be enormous to cover every observed word, and any word not seen
during training ("certiorari", "PFAS", "QD-5000") collapses to a special `[UNK]`
token that discards its identity.

| Pros | Cons |
|------|------|
| Intuitive token boundaries | Huge vocabulary (100K+) |
| Short sequences | Cannot handle unseen words |
| Good compression | Morphology is ignored |
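A sketch of word-level tokenization with an `[UNK]` fallback, assuming plain
whitespace splitting and reserving ID 0 for `[UNK]` (both illustrative choices;
`build_word_vocab` and `word_tokenize` are hypothetical helpers). Words absent
from the training text become indistinguishable from one another.

```python
from collections import Counter

UNK = "[UNK]"


def build_word_vocab(corpus: str) -> dict[str, int]:
    """Assign an ID to every whitespace-separated word; ID 0 is reserved for [UNK]."""
    vocab = {UNK: 0}
    for word, _count in Counter(corpus.split()).most_common():
        vocab[word] = len(vocab)
    return vocab


def word_tokenize(text: str, vocab: dict[str, int]) -> list[int]:
    """Map each word to its ID; words missing from the vocabulary collapse to [UNK]."""
    return [vocab.get(word, vocab[UNK]) for word in text.split()]


word_vocab = build_word_vocab("the court held that the court found")
word_ids = word_tokenize("the court granted certiorari", word_vocab)
print(word_ids)  # [1, 2, 0, 0]: "granted" and "certiorari" are both just [UNK]
```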

### Subword (BPE)

The sweet spot. Start with characters, then iteratively merge the most frequent
adjacent pairs. Common words like "the" and "court" stay whole. Rare words
like "certiorari" are split into learned subword pieces (e.g., `cert` + `ior` +
`ari`). The vocabulary size is a tunable hyperparameter, typically 32K-100K.

| Pros | Cons |
|------|------|
| Balanced vocabulary size | Tokenization is not linguistically motivated |
| Handles unseen words gracefully | Domain-specific text may tokenize poorly |
| Good compression ratio | Merges are corpus-dependent |
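To see why a rare word never becomes `[UNK]` under BPE, the sketch below applies
a hand-picked, hypothetical merge list (not learned from any corpus) chosen so
that "certiorari" splits into `cert` + `ior` + `ari`. It uses the `</w>`
end-of-word marker that this notebook's implementation appends, so the last
piece is `ari</w>`. A word the merges do not cover falls back to single
characters instead of `[UNK]`.

```python
# Hypothetical merge rules, hand-picked for illustration; a real tokenizer
# learns its merges from a corpus.
MERGES: list[tuple[str, str]] = [
    ("c", "e"), ("ce", "r"), ("cer", "t"),
    ("i", "o"), ("io", "r"),
    ("a", "r"), ("ar", "i"), ("ari", "</w>"),
]


def apply_merges(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Split a word into characters plus '</w>', then apply each merge in order."""
    symbols = list(word) + ["</w>"]
    for left, right in merges:
        merged: list[str] = []
        i = 0
        while i < len(symbols):
            # Merge non-overlapping occurrences of (left, right), scanning left to right
            if i + 1 < len(symbols) and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


print(apply_merges("certiorari", MERGES))  # ['cert', 'ior', 'ari</w>']
print(apply_merges("court", MERGES))       # ['c', 'o', 'u', 'r', 't', '</w>']
```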

**BPE, in byte-level or SentencePiece variants, is the algorithm behind GPT-2,
GPT-3, GPT-4, LLaMA, and most modern LLMs.** Let's build it.

## BPE Algorithm Explained

The BPE training procedure is surprisingly simple:

1. **Initialize** -- Start with a vocabulary of individual characters (or bytes).
   Represent each word in the corpus as a sequence of single-character symbols
   followed by a special end-of-word marker `</w>`.

2. **Count pairs** -- For every adjacent pair of symbols inside each word,
   count how many times it appears across the entire corpus (each occurrence
   of a pair inside a word that appears n times adds n to that pair's count).

3. **Merge the top pair** -- Find the most frequent pair and merge it into a
   single new symbol. Add this new symbol to the vocabulary.

4. **Repeat** -- Go back to step 2. Each iteration adds one merge rule and
   one new vocabulary entry.
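Step 2 weights each pair by the frequency of the word it occurs in. A small
sketch of that count on the toy corpus `the court held that the court found`
(`count_pairs` is a hypothetical helper, similar in spirit to this notebook's
`get_stats`) shows that several pairs can share the top count, so step 3 needs a
tie-breaking rule.

```python
from collections import Counter

# The toy corpus "the court held that the court found" as a word-frequency table:
# each word is split into characters plus the end-of-word marker "</w>".
word_freqs: dict[tuple[str, ...], int] = {
    ("t", "h", "e", "</w>"): 2,
    ("c", "o", "u", "r", "t", "</w>"): 2,
    ("h", "e", "l", "d", "</w>"): 1,
    ("t", "h", "a", "t", "</w>"): 1,
    ("f", "o", "u", "n", "d", "</w>"): 1,
}


def count_pairs(vocab: dict[tuple[str, ...], int]) -> Counter[tuple[str, str]]:
    """Step 2: count adjacent symbol pairs, weighting each word by its frequency."""
    pairs: Counter[tuple[str, str]] = Counter()
    for symbols, freq in vocab.items():
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    return pairs


pair_counts = count_pairs(word_freqs)
best_count = max(pair_counts.values())
tied = [pair for pair, count in pair_counts.items() if count == best_count]
print(best_count)  # 3
print(tied)        # [('t', 'h'), ('h', 'e'), ('o', 'u'), ('t', '</w>')]
```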

### Visual Example

Corpus: `the court held that the court found`

```
Initial vocabulary:
  t h e </w> c o u r d l a f n

Initial word representations:
  "the"   -> t h e </w>        (freq: 2)
  "court" -> c o u r t </w>    (freq: 2)
  "held"  -> h e l d </w>      (freq: 1)
  "that"  -> t h a t </w>      (freq: 1)
  "found" -> f o u n d </w>    (freq: 1)

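Step 1: (t, h), (h, e), (o, u), and (t, </w>) are tied at count 3
  max() keeps the first tied pair it encounters: (t, h)
  Merge: t h -> th
  "the"   -> th e </w>
  "that"  -> th a t </w>
  ("court", "held", and "found" contain no adjacent t h: unchanged)

Step 2: Most frequent pair is (o, u) with count 3 (court x2 + found x1)
  Merge: o u -> ou
  "court" -> c ou r t </w>
  "found" -> f ou n d </w>

Step 3: Most frequent pair is (t, </w>) with count 3 (court x2 + that x1)
  Merge: t </w> -> t</w>
  "court" -> c ou r t</w>
  "that"  -> th a t</w>

Step 4: Six pairs are tied at count 2; the first encountered is (th, e)
  Merge: th e -> the
  "the"   -> the </w>

Step 5: (the, </w>) comes first among the five pairs tied at count 2
  Merge: the </w> -> the</w>
  "the"   -> the</w>           (single token!)
```

After 5 merges, "the" is a single token. Ties are common on small corpora: with
`max(pairs, key=pairs.get)`, as in the implementation below, the pair counted
first wins, and implementations that break ties differently can learn merges in
a different order. Even so, the algorithm naturally discovers that common words
should be atomic.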
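A compact, self-contained trainer is handy for checking a hand trace like this
one. The sketch below (`train_toy_bpe` is a hypothetical name) follows the same
logic as this notebook's `train_bpe`, including its tie-breaking: `max` returns
the first maximal pair, which is the pair counted first. It prints the merges
learned on the toy corpus.

```python
from collections import Counter


def train_toy_bpe(text: str, num_merges: int) -> list[tuple[str, str]]:
    """Compact BPE trainer with the same logic and tie-breaking as train_bpe."""
    # Word-frequency table: each word as characters plus "</w>".
    vocab: dict[tuple[str, ...], int] = {}
    for word in text.split():
        key = tuple(word) + ("</w>",)
        vocab[key] = vocab.get(key, 0) + 1

    merges: list[tuple[str, str]] = []
    for _ in range(num_merges):
        # Count adjacent pairs, weighted by word frequency.
        pairs: Counter[tuple[str, str]] = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        # max() returns the first maximal pair, i.e. the one counted first.
        best = max(pairs, key=pairs.__getitem__)
        merges.append(best)

        # Replace non-overlapping occurrences of `best`, scanning left to right.
        new_vocab: dict[tuple[str, ...], int] = {}
        for symbols, freq in vocab.items():
            merged: list[str] = []
            i = 0
            while i < len(symbols):
                if symbols[i : i + 2] == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = freq
        vocab = new_vocab
    return merges


toy_merges = train_toy_bpe("the court held that the court found", 5)
for step, (left, right) in enumerate(toy_merges, start=1):
    print(f"Step {step}: {left} + {right} -> {left}{right}")
# Step 1: t + h -> th
# Step 2: o + u -> ou
# Step 3: t + </w> -> t</w>
# Step 4: th + e -> the
# Step 5: the + </w> -> the</w>
```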

## Setup

In [None]:
import json
import re
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt

## Load Legal Text Corpus

We use court opinions from our sample dataset as the training corpus for BPE.

In [None]:
DATA_PATH = Path("../../datasets/sample/court_opinions.jsonl")

corpus_texts = []
with open(DATA_PATH, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines, which json.loads would reject
        record = json.loads(line)
        corpus_texts.append(record["text"])

# Combine all opinions into one training corpus
training_corpus = " ".join(corpus_texts)
print(f"Loaded {len(corpus_texts)} court opinions")
print(f"Total corpus size: {len(training_corpus):,} characters")
print(f"\nFirst 300 characters:")
print(training_corpus[:300])

## BPE from Scratch

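Here is the complete implementation in six functions: `build_vocab`,
`get_stats`, `merge_vocab`, and `train_bpe` for training, plus `encode` and
`decode` for applying the learned merges.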

In [None]:
def build_vocab(text: str) -> dict[tuple[str, ...], int]:
    """Convert raw text into a word-frequency vocabulary.

    Each word is represented as a tuple of characters with a special
    end-of-word marker '</w>'. The dictionary maps each word-tuple to
    its frequency in the corpus.
    """
    words = text.split()
    vocab: dict[tuple[str, ...], int] = Counter()
    for word in words:
        # Represent each word as character tuple + end-of-word marker
        symbols = tuple(word) + ("</w>",)
        vocab[symbols] += 1
    return dict(vocab)

In [None]:
def get_stats(vocab: dict[tuple[str, ...], int]) -> dict[tuple[str, str], int]:
    """Count the frequency of every adjacent symbol pair in the vocabulary."""
    pairs: dict[tuple[str, str], int] = defaultdict(int)
    for symbols, freq in vocab.items():
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return dict(pairs)

In [None]:
def merge_vocab(
    pair: tuple[str, str], vocab: dict[tuple[str, ...], int]
) -> dict[tuple[str, ...], int]:
    """Merge all occurrences of `pair` in every word of the vocabulary."""
    new_vocab = {}
    bigram = pair  # e.g., ('t', 'h')
    replacement = "".join(pair)  # e.g., 'th'
    for symbols, freq in vocab.items():
        new_symbols: list[str] = []
        i = 0
        while i < len(symbols):
            # Look for the bigram at position i
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == bigram:
                new_symbols.append(replacement)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_vocab[tuple(new_symbols)] = freq
    return new_vocab

In [None]:
def train_bpe(
    text: str, num_merges: int
) -> tuple[list[tuple[str, str]], dict[tuple[str, ...], int]]:
    """Train BPE on the given text for a specified number of merges.

    Returns:
        merges: Ordered list of merge operations (pair that was merged).
        vocab: Final vocabulary after all merges.
    """
    vocab = build_vocab(text)
    merges: list[tuple[str, str]] = []

    for i in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            print(f"No more pairs to merge at step {i}. Stopping early.")
            break
        # Find the most frequent pair (on a tie, max() keeps the pair counted first)
        best_pair = max(pairs, key=pairs.get)
        vocab = merge_vocab(best_pair, vocab)
        merges.append(best_pair)

    return merges, vocab

In [None]:
def encode(text: str, merges: list[tuple[str, str]]) -> list[str]:
    """Encode text into BPE tokens using a learned merge list."""
    tokens: list[str] = []
    for word in text.split():
        # Start with character-level representation
        symbols = list(word) + ["</w>"]
        # Apply each merge rule in order
        for pair in merges:
            merged = "".join(pair)
            i = 0
            new_symbols: list[str] = []
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and (symbols[i], symbols[i + 1]) == pair
                ):
                    new_symbols.append(merged)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        tokens.extend(symbols)
    return tokens


def decode(tokens: list[str]) -> str:
    """Decode a list of BPE tokens back into text."""
    text = "".join(tokens)
    # Replace end-of-word markers with spaces
    text = text.replace("</w>", " ")
    return text.strip()
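The introduction describes tokens as integer IDs, while `encode` returns token
strings. Below is a minimal sketch of the missing lookup step under stated
assumptions: ID 0 is reserved for an `[UNK]` token (an illustrative choice;
byte-level BPE as in GPT-2 avoids it by starting from all 256 byte values), base
symbols are numbered in sorted order, and each merge then receives the next ID.
`build_id_table`, `tokens_to_ids`, and `ids_to_tokens` are hypothetical helpers,
and the merge list is the first five merges this notebook's training loop learns
on the toy corpus `the court held that the court found`.

```python
def build_id_table(base_symbols: set[str], merges: list[tuple[str, str]]) -> dict[str, int]:
    """Hypothetical helper: ID 0 for [UNK], base symbols in sorted order,
    then one new ID per merged symbol in merge order."""
    table = {"[UNK]": 0}
    for symbol in sorted(base_symbols):
        table.setdefault(symbol, len(table))
    for left, right in merges:
        table.setdefault(left + right, len(table))
    return table


def tokens_to_ids(tokens: list[str], table: dict[str, int]) -> list[int]:
    """Look up each token string; tokens missing from the table map to [UNK]."""
    unk = table["[UNK]"]
    return [table.get(token, unk) for token in tokens]


def ids_to_tokens(ids: list[int], table: dict[str, int]) -> list[str]:
    """Invert the table to turn IDs back into token strings."""
    inverse = {idx: token for token, idx in table.items()}
    return [inverse[idx] for idx in ids]


toy_merges = [("t", "h"), ("o", "u"), ("t", "</w>"), ("th", "e"), ("the", "</w>")]
toy_base = set("thecourldafn") | {"</w>"}
id_table = build_id_table(toy_base, toy_merges)

# Token strings for "the court \u00a7" under toy_merges; "\u00a7" never occurred in training.
toy_tokens = ["the</w>", "c", "ou", "r", "t</w>", "\u00a7", "</w>"]
toy_ids = tokens_to_ids(toy_tokens, id_table)
print(len(id_table))  # 19: [UNK] + 13 base symbols + 5 merges
print(toy_ids)        # [18, 3, 15, 11, 16, 0, 1]
print(ids_to_tokens(toy_ids, id_table))  # the unseen character comes back as '[UNK]'
```

With the notebook's own objects, the base symbols would be
`{sym for word in build_vocab(training_corpus) for sym in word}` and the merges
the learned `merges` list.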

## Training on the Legal Corpus

Let's train BPE with 500 merges and inspect the merge sequence.

In [None]:
NUM_MERGES = 500
merges, final_vocab = train_bpe(training_corpus, NUM_MERGES)

print(f"Completed {len(merges)} merges.")
print(f"Final vocabulary size: {len(set(sym for word in final_vocab for sym in word))}")
print(f"\nFirst 30 merges (in order):")
for i, (a, b) in enumerate(merges[:30]):
    print(f"  {i + 1:3d}. '{a}' + '{b}' -> '{a}{b}'")

### Examining the Vocabulary at Different Merge Counts

Let's see how the vocabulary evolves as we increase the number of merges.
Tokens that span a whole word (those ending in `</w>`) are words the tokenizer
has learned to treat as atomic units.

In [None]:
def get_token_vocab(vocab: dict[tuple[str, ...], int]) -> set[str]:
    """Extract the set of unique symbols from a BPE vocabulary."""
    return set(sym for word in vocab for sym in word)


# Train at multiple merge counts to compare
for n_merges in [50, 100, 200, 500]:
    m, v = train_bpe(training_corpus, n_merges)
    tokens = get_token_vocab(v)
    # Show the longest tokens (these are the most "complete" words)
    longest = sorted(tokens, key=len, reverse=True)[:10]
    print(f"\n--- {n_merges} merges (vocab size: {len(tokens)}) ---")
    print(f"Longest tokens: {longest}")

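Common legal terms should gradually become single tokens as the number of
merges increases. High-frequency words like "the", "court", "that", and "of" are
typically among the first to be merged into whole tokens because they appear so
often in judicial opinions. Because `build_vocab` splits only on whitespace,
capitalization and attached punctuation matter: "The" and "court," are counted
as different words from "the" and "court", which splits their frequency across
variants.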

## Visualization

### Vocabulary Size vs Number of Merges

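In the formal definition of BPE, each merge adds exactly one new symbol to the
vocabulary. The curve below instead counts the distinct symbols still in use in
the segmented corpus (what `get_token_vocab` returns). That count changes by at
most one per merge and can stay flat or even shrink, because a symbol drops out
once every occurrence of it has been absorbed into a longer one. We also track
when specific legal terms become single tokens.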

In [None]:
# Track vocabulary growth during training
vocab_for_plot = build_vocab(training_corpus)
initial_vocab_size = len(get_token_vocab(vocab_for_plot))

merge_counts = [0]
vocab_sizes = [initial_vocab_size]

# Track when specific terms become single tokens
target_words = ["the</w>", "court</w>", "that</w>", "plaintiff</w>", "defendant</w>"]
word_merge_points: dict[str, int | None] = {w: None for w in target_words}

vocab_tracking = build_vocab(training_corpus)
for i in range(NUM_MERGES):
    pairs = get_stats(vocab_tracking)
    if not pairs:
        break
    best_pair = max(pairs, key=pairs.get)
    vocab_tracking = merge_vocab(best_pair, vocab_tracking)

    current_tokens = get_token_vocab(vocab_tracking)
    merge_counts.append(i + 1)
    vocab_sizes.append(len(current_tokens))

    # Check if any target words just became single tokens
    for target in target_words:
        if word_merge_points[target] is None and target in current_tokens:
            word_merge_points[target] = i + 1

print("Merge at which each word becomes a single token:")
for word, step in sorted(word_merge_points.items(), key=lambda x: x[1] or 9999):
    label = word.replace("</w>", "")
    if step is not None:
        print(f"  '{label}' -> merge #{step}")
    else:
        print(f"  '{label}' -> not reached in {NUM_MERGES} merges")

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(merge_counts, vocab_sizes, linewidth=2, color="#2563eb")

# Annotate when target words become single tokens
for word, step in word_merge_points.items():
    if step is not None:
        label = word.replace("</w>", "")
        size_at_step = vocab_sizes[step]
        ax.annotate(
            f'"{label}"',
            xy=(step, size_at_step),
            xytext=(step + 20, size_at_step + 5),
            arrowprops=dict(arrowstyle="->", color="#dc2626"),
            fontsize=9,
            color="#dc2626",
        )
        ax.plot(step, size_at_step, "o", color="#dc2626", markersize=6)

ax.set_xlabel("Number of Merges")
ax.set_ylabel("Vocabulary Size")
ax.set_title("BPE Vocabulary Growth on Legal Corpus")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
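A small sketch of the difference between the formal vocabulary and the
symbols-in-use count that `get_token_vocab` returns, using the toy corpus
`the court held that the court found` after the first five merges this
notebook's training loop learns on it. The segmentation is written out by hand,
and the variable names are illustrative.

```python
# Toy corpus after the merges t+h, o+u, t+</w>, th+e, the+</w>:
# segmented word -> frequency.
segmented = {
    ("the</w>",): 2,
    ("c", "ou", "r", "t</w>"): 2,
    ("h", "e", "l", "d", "</w>"): 1,
    ("th", "a", "t</w>"): 1,
    ("f", "ou", "n", "d", "</w>"): 1,
}
merges_so_far = [("t", "h"), ("o", "u"), ("t", "</w>"), ("th", "e"), ("the", "</w>")]
base_symbols = set("thecourldafn") | {"</w>"}

# Formal BPE vocabulary: every base symbol plus one new symbol per merge.
formal_vocab = base_symbols | {left + right for left, right in merges_so_far}
# What get_token_vocab measures: symbols that still occur in the segmented corpus.
in_use = {symbol for word in segmented for symbol in word}

print(len(formal_vocab))              # 18 (13 base symbols + 5 merges)
print(len(in_use))                    # 14
print(sorted(formal_vocab - in_use))  # ['o', 't', 'the', 'u']
```

Here `t`, `o`, `u`, and the intermediate `the` have been absorbed into longer
symbols, which is why the in-use count trails the number of merges.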

### How Legal Terms Emerge

Let's look at how specific legal terms are tokenized at different merge counts.

In [None]:
legal_terms = [
    "court",
    "plaintiff",
    "defendant",
    "summary",
    "judgment",
    "injunction",
    "discrimination",
    "infringement",
]

print(f"{'Term':<20} {'50 merges':<25} {'200 merges':<25} {'500 merges'}")
print("-" * 95)

# Training is deterministic, so the first n merges of the 500-merge run above
# are exactly the merges an n-merge run would learn; slicing avoids retraining.
merges_by_count = {n: merges[:n] for n in (50, 200, 500)}

for term in legal_terms:
    t50, t200, t500 = (
        " | ".join(encode(term, merges_by_count[n])) for n in (50, 200, 500)
    )
    print(f"{term:<20} {t50:<25} {t200:<25} {t500}")

## Round-Trip Test

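A correct tokenizer must preserve text through an encode-decode round trip.
Because `encode` splits on whitespace and `decode` rejoins words with single
spaces, this implementation round-trips exactly only when words are separated
by single spaces, as in the passages below. Let's verify this with a legal
passage.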

In [None]:
test_passage = (
    "The district court granted summary judgment in favor of the defendant "
    "Meridian Health Systems concluding that the plaintiff failed to establish "
    "a prima facie case of discrimination under the Americans with Disabilities Act"
)

# Encode with the 500-merge model
encoded = encode(test_passage, merges)
decoded = decode(encoded)

print(f"Original:  {test_passage}")
print(f"\nEncoded:   {encoded[:20]}... ({len(encoded)} tokens total)")
print(f"\nDecoded:   {decoded}")
print(f"\nRound-trip match: {test_passage == decoded}")
assert test_passage == decoded, "Round-trip failed!"
print("\nRound-trip test PASSED.")

In [None]:
# Test with a longer passage of standard appellate-review language
complex_passage = (
    "We review the district court grant of summary judgment de novo "
    "construing all facts and drawing all reasonable inferences in "
    "favor of the nonmoving party"
)

encoded_complex = encode(complex_passage, merges)
decoded_complex = decode(encoded_complex)

print(f"Original:  {complex_passage}")
print(f"Decoded:   {decoded_complex}")
print(f"Match:     {complex_passage == decoded_complex}")
assert complex_passage == decoded_complex, "Round-trip failed!"
print(f"Tokens:    {len(encoded_complex)}")
print(f"\nToken breakdown:")
for token in encoded_complex:
    display = token.replace('</w>', '_')
    print(f"  [{display}]", end="")
print()
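The round trip depends on whitespace: `encode` splits on any run of whitespace
and `decode` rejoins words with single spaces, so line breaks and repeated
spaces do not survive. A sketch with an empty merge list (`encode_chars` is a
hypothetical stand-in for `encode(text, [])`) makes the loss visible.

```python
def encode_chars(text: str) -> list[str]:
    """Stand-in for the notebook's encode with no merges:
    each whitespace-separated word becomes its characters plus '</w>'."""
    tokens: list[str] = []
    for word in text.split():
        tokens.extend(list(word) + ["</w>"])
    return tokens


original = "The court held:\n\n  Reversed."
# Same steps as the notebook's decode: join, turn '</w>' into spaces, strip.
restored = "".join(encode_chars(original)).replace("</w>", " ").strip()

print(repr(restored))        # 'The court held: Reversed.'
print(restored == original)  # False: the blank line and indentation are gone
```

GPT-2-style byte-level tokenizers avoid this loss by keeping whitespace inside
the tokens themselves instead of discarding it during pre-tokenization.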

## Exercises

### Exercise (a): Unicode-Aware BPE

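The implementation above splits on whitespace and treats each character as a
symbol. Python iterates over strings by Unicode code point, so characters common
in legal text do not crash the code, but they are handled poorly: whitespace
splitting glues them to neighboring words (`"plaintiff\u2014"` is one "word"),
and any character absent from the training corpus has no entry in the base
vocabulary once tokens are mapped to integer IDs. Characters to watch for
include:

- Section symbol: `\u00a7` ("§")
- Em dash: `\u2014` ("—")
- Paragraph symbol: `\u00b6` ("¶")
- Various quotation marks: `\u201c` `\u201d` `\u2018` `\u2019` (“ ” ‘ ’)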

Modify the `build_vocab` function to handle Unicode properly. Two approaches:

1. **Byte-level BPE**: Convert text to UTF-8 bytes and run BPE on byte values
   (this is what GPT-2 does).
2. **Unicode-aware character BPE**: Use Python's built-in Unicode support to
   correctly iterate over characters (Python strings are already Unicode, so
   the main issue is ensuring your pre-tokenization regex handles punctuation
   categories correctly).
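For approach 1, a quick look at the raw material may help: in UTF-8 every
character becomes one to four bytes, so the base alphabet is fixed at 256 values
and nothing is out of vocabulary, at the cost of longer initial sequences. This
is only a starting point, not a solution to the exercise; `citation` and
`byte_ids` are illustrative names.

```python
# Approach 1 starts from UTF-8 bytes: a fixed base alphabet of 256 values.
citation = "Pursuant to 42 U.S.C. \u00a7 1983, the plaintiff\u2014"

for char in ["\u00a7", "\u2014", "\u201c"]:
    print(f"U+{ord(char):04X} -> {list(char.encode('utf-8'))}")
# U+00A7 -> [194, 167]
# U+2014 -> [226, 128, 148]
# U+201C -> [226, 128, 156]

byte_ids = list(citation.encode("utf-8"))
# 44 characters become 47 bytes: one extra for the section sign, two for the dash.
print(len(citation), len(byte_ids))  # 44 47
```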

Test your modified version on: `"Pursuant to 42 U.S.C. \u00a7 1983, the plaintiff\u2014"`

### Exercise (b): Compression Ratio Comparison

Compare how well BPE compresses legal text versus general English:

1. Take a sample of general English text (e.g., a paragraph from Wikipedia).
2. Take a passage of comparable length from the court opinions dataset.
3. Train BPE on each corpus separately (with the same number of merges).
4. Compute the **compression ratio**: `original_characters / num_tokens`.

Which domain achieves better compression? Why? Think about vocabulary
diversity, word frequency distributions, and the presence of specialized
notation in legal text.
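A minimal helper for step 4, assuming "characters" means `len(text)` including
spaces (`compression_ratio` is a hypothetical name). A character-level encoding
gives a baseline ratio of 1.0; BPE should typically land above it once enough
merges have been learned.

```python
def compression_ratio(text: str, tokens: list[str]) -> float:
    """Characters per token: higher means each token covers more text."""
    if not tokens:
        raise ValueError("cannot compute a ratio for an empty token list")
    return len(text) / len(tokens)


# Character-level baseline: one token per character, so the ratio is exactly 1.0.
sample = "the court held that the court found"
print(compression_ratio(sample, list(sample)))  # 1.0
```

For the BPE runs, pass `encode(text, merges)` as the token list. Note that
`encode` emits a bare `</w>` token whenever a word's final merge has not been
learned, and those tokens count toward the total.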