# Lab: Implement a BPE Tokenizer
## Purpose:
- Understand the BPE algorithm
- Train a subword tokenizer
- Implement and run the BPE algorithm over multiple merge steps

### Topics:
- Byte-pair encoding

### Steps
* Write code to convert each word into a list of characters with `</w>` appended.
* Complete the function `get_pair_frequencies` to count adjacent symbol pairs.
* Fill in the logic in `merge_pair_in_word` to replace the most frequent pair.
* Apply merges iteratively to update the corpus and vocabulary.
* Implement the loop in `learn_bpe` to perform a fixed number of merge steps and inspect the results.

* Date: 2026-02-20

Source: https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_2/gdm_lab_2_4_implement_a_bpe_tokenizer.ipynb

References: https://github.com/google-deepmind/ai-foundations
- Google DeepMind (GDM) GitHub repository used in AI training courses at the university and college level.

In [None]:
from collections import Counter, defaultdict # For counting token pairs.
import pandas as pd # For loading the dataset.

In [None]:
africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
dataset = africa_galore["description"].values
print(f"Dataset contains {dataset.shape[0]:,} paragraphs.")

## The BPE Algorithm

1. **Initialization**: Split the dataset into individual characters, marking the end of each space-separated word with the symbol `</w>`. Initialize the vocabulary with the set of unique characters plus `</w>`.
2. **Counting**: Count how often each pair of adjacent tokens appears in the corpus.
3. **Merging**: Choose the most frequently appearing adjacent pair of tokens `(p, q)`. Add a new merged token `pq` to the vocabulary.
4. **Replacing**: Replace all adjacent pairs of tokens `(p, q)` with the new token `pq` in the corpus.
5. **Repeating**: Repeat steps 2-4 until you reach the target vocabulary size.
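The steps above can be sketched as one complete iteration on a tiny corpus. This is a minimal illustration, not the lab solution: the toy text `"hug pug hugs"` and the helper name `merge` are assumptions made for the example.

```python
from collections import Counter

EOW = "</w>"  # End-of-word marker, as used in this lab.
text = "hug pug hugs"  # Hypothetical toy corpus.

# Step 1: split each word into characters and append the end-of-word marker.
corpus = [list(word) + [EOW] for word in text.split(" ")]
vocabulary = {token for word in corpus for token in word}

# Step 2: count adjacent token pairs across all words.
pair_counts = Counter(pair for word in corpus for pair in zip(word, word[1:]))

# Step 3: pick the most frequent pair and add the merged token.
(left, right), count = pair_counts.most_common(1)[0]
merged = left + right
vocabulary.add(merged)


# Step 4: replace every adjacent (left, right) with the merged token.
def merge(word: list[str]) -> list[str]:
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == (left, right):
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return out


corpus = [merge(word) for word in corpus]
print((left, right), count)  # ('u', 'g') 3
print(corpus)
# [['h', 'ug', '</w>'], ['p', 'ug', '</w>'], ['h', 'ug', 's', '</w>']]
```

Step 5 repeats steps 2 to 4 on the updated corpus. Each merge adds at most one new token to the vocabulary.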

### 1. Initialization

- Convert every space-separated word into a list of characters.
- Append "</w>" at the end of each word's list of characters.
- Add the set of unique characters to the vocabulary.

Expected output
```
First 10 words:
  ['T', 'h', 'e', '</w>']
  ['L', 'a', 'g', 'o', 's', '</w>']
  ['a', 'i', 'r', '</w>']
  ['w', 'a', 's', '</w>']
  ['t', 'h', 'i', 'c', 'k', '</w>']
  ['w', 'i', 't', 'h', '</w>']
  ['h', 'u', 'm', 'i', 'd', 'i', 't', 'y', ',', '</w>']
  ['b', 'u', 't', '</w>']
  ['t', 'h', 'e', '</w>']
  ['e', 'n', 'e', 'r', 'g', 'y', '</w>']
```

In [None]:
EOW_SYMBOL = "</w>"

def bpe_initialize(dataset: list[str]) -> tuple[list[list[str]], set[str]]:
    """The initialization step.
    Args:
      dataset: The corpus consisting of a list of raw paragraphs.
    Returns:
      corpus: A list containing every word in the corpus represented as a list
        of characters and the special end-of-word-symbol "</w>".
      vocabulary: The set of unique tokens in the dataset that serves as the
        vocabulary before the first merge.
    """

    corpus = []
    vocabulary = {EOW_SYMBOL}
    for paragraph in dataset:
        for word in paragraph.split(" "):
            # Convert word into chars and add end-of-word symbol ('</w>').
            chars = list(word)
            chars.append(EOW_SYMBOL)
            corpus.append(chars)

            # Append characters to the vocabulary. Since vocabulary is a set,
            # it will ignore characters that have already been added.
            vocabulary.update(chars)

    return corpus, vocabulary

corpus, vocabulary = bpe_initialize(dataset)

# Print the first 10 "words" in the corpus.
print("First 10 words: ")
for word in corpus[:10]:
    print(f"  {word}")

print("\n\nVocabulary:")
print(f"  {vocabulary}")

### 2. Counting adjacent token pairs

For every word in the corpus, extract every pair of tokens that appear next to each other and increase their count.

Ex. "rice" becomes `(r,i)`, `(i,c)`, `(c,e)`, `(e, </w>)`.
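As a quick check of the "rice" example, the pairs can be produced by zipping a word with itself shifted by one position. This is a small sketch. The variable names are illustrative, and `zip` is only one of several ways to enumerate the pairs.

```python
from collections import Counter

# Represent "rice" as a list of characters plus the end-of-word symbol.
word = list("rice") + ["</w>"]

# Pair each token with its right-hand neighbour.
pairs = list(zip(word, word[1:]))
print(pairs)
# [('r', 'i'), ('i', 'c'), ('c', 'e'), ('e', '</w>')]

# Tuples are hashable, so they can be used directly as Counter keys.
counts = Counter(pairs)
print(counts[("e", "</w>")])  # 1
```

Keeping each pair as a tuple of two tokens, rather than a joined string, preserves the boundary between the tokens once they grow longer than one character.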

In [None]:
def get_pair_frequencies(
    corpus: list[list[str]]
) -> Counter[tuple[str, str]]:
    """
    Calculates the frequency of adjacent token pairs in a corpus.
    Args:
      corpus: A list of tokenized words, where each word is represented as a
        list of subword tokens. Before the first merge these are all individual
        characters.
    Returns:
      A Counter object mapping each adjacent token pair (a tuple) to its
        frequency in the corpus.
    """
    pairs = Counter()
    for word in corpus:
        # Loop through the position of every token but the last one.
        for i in range(len(word) - 1):
            # Create a tuple representing the adjacent pair consisting of the
            # i-th and (i+1)-th token in `word`.
            pair = (word[i], word[i + 1])
            pairs[pair] += 1
    return pairs


pair_freqs = get_pair_frequencies(corpus)
print("The 10 most common pair frequencies are:")
for freq in pair_freqs.most_common(10):
    print(f" {freq}")

### 3. Merge the most frequent pair
Identify the most frequent pair of adjacent tokens and merge them into a new subword token.

In [None]:
most_freq_pair, freq = pair_freqs.most_common(1)[0]
print(
    f"The most frequent pair is {most_freq_pair} with a frequency of {freq:,}."
)

# Concatenate the two tokens of the pair to form the new subword token.
new_token = most_freq_pair[0] + most_freq_pair[1]
vocabulary.add(new_token)

In [None]:
def merge_pair_in_word(
    word: list[str], pair_to_merge: tuple[str, str]
) -> list[str]:
    """Merges adjacent occurrences of a specified pair of tokens in a word.

    Given a word represented as a list of subword tokens and a pair of tokens to
    merge (represented as a tuple), return a new list where every
    instance of the specified pair appearing adjacently in the original word is
    replaced by a single string representing the merged pair.

    Args:
      word: A list of subword tokens representing one space-separated word.
      pair_to_merge: A pair of two subword tokens that should be merged into one
        subword token.

    Returns:
      New list of subword tokens representing the word after applying the merge.
    """
    merged_symbol = pair_to_merge[0] + pair_to_merge[1]
    i = 0
    new_word = []
    while i < len(word):
        # If this position and the next match the pair, merge them.
        if i < len(word) - 1 and (word[i], word[i + 1]) == pair_to_merge:
            new_word.append(merged_symbol)
            i += 2
        else:
            new_word.append(word[i])
            i += 1
    return new_word

print(f"Before merging: {' '.join(corpus[0])}")
new_word = merge_pair_in_word(corpus[0], most_freq_pair)
print(f"After merging:  {' '.join(new_word)}")
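One detail of the left-to-right scan is worth checking by hand: matches are consumed without overlap, so three identical tokens in a row yield one merged token followed by one leftover token. The sketch below uses a compact stand-in named `merge_pair` (a hypothetical helper written for this example, following the same scan as `merge_pair_in_word`) so that it runs on its own.

```python
def merge_pair(word: list[str], pair: tuple[str, str]) -> list[str]:
    """Stand-in for the lab's merge function, using the same left-to-right scan."""
    merged = pair[0] + pair[1]
    out, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
            out.append(merged)
            i += 2  # Skip both tokens of the matched pair.
        else:
            out.append(word[i])
            i += 1
    return out


# Overlapping candidates: only the first ('a', 'a') is merged.
print(merge_pair(["a", "a", "a", "</w>"], ("a", "a")))
# ['aa', 'a', '</w>']

# A pair that does not occur leaves the word unchanged.
print(merge_pair(["r", "i", "c", "e", "</w>"], ("a", "a")))
# ['r', 'i', 'c', 'e', '</w>']
```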

### 4. Replace tokens in corpus

Use `merge_pair_in_word` to replace every occurrence of the most frequent pair in the corpus with the new merged subword token.

In [None]:
new_corpus = []
for word in corpus:
    new_word = merge_pair_in_word(word, most_freq_pair)
    new_corpus.append(new_word)

# Each iteration of BPE results in a new corpus with merged pairs.
print(
    "The first 10 \"words\" in the corpus, with \"e</w>\" being a single token"
    " now:"
)
for word in new_corpus[:10]:
    print(f"  {word}")

print("\n\nVocabulary:")
print(f"  {vocabulary}")

### 5. Repeat

Wrap steps 2-4 in a loop that performs a fixed number of merge operations.

In [None]:
def learn_bpe(
    dataset: list[str], num_merges: int
) -> tuple[list[tuple[str, str]], set[str], list[list[str]]]:
    """
    Implements the basic BPE algorithm by iteratively merging the
    most frequent adjacent character pairs in a corpus to create subword units.
    It continues merging until the specified number of merges is reached or all
    words are represented by a single token.

    Args:
      dataset: A list of paragraphs, where each paragraph is a string.
      num_merges: The desired number of BPE merge operations to perform.

    Returns:
      merges: A list of tuples, where each tuple represents a merged pair of
        tokens. (e.g., [('a', 'b'), ('ab', 'c')]).  The order of tuples reflects
        the order in which merges were performed.
      vocabulary: The final vocabulary of all subword tokens.
      corpus: The final corpus after BPE is applied. The corpus is a flat list
        of words, and each word is represented as a list of subword tokens.
    """
    # Step 1: Initialize corpus as list of lists of characters + </w>.
    corpus, vocabulary = bpe_initialize(dataset)

    merges = []
    for _ in range(num_merges):
        # Step 2: Count all adjacent-pair frequencies.
        pair_freqs = get_pair_frequencies(corpus)
        if not pair_freqs:
            break

        # Step 3: Pick the most frequent pair.
        most_freq_pair, freq = pair_freqs.most_common(1)[0]
        if freq < 1:
            break  # No more pairs to merge.
        merges.append(most_freq_pair)
        new_token = most_freq_pair[0] + most_freq_pair[1]
        vocabulary.add(new_token)

        # Step 4: Merge that pair in every word of the corpus.
        new_corpus = []
        for word in corpus:
            new_word = merge_pair_in_word(word, most_freq_pair)
            new_corpus.append(new_word)
        corpus = new_corpus

    # Return the list of merges, the vocabulary, and the final tokenized corpus.
    return merges, vocabulary, corpus

In [None]:
sample_text = [
        "desert",
        "deserted",
        "deserts",
        "desert",
        "tested",
        "test",
        "deserted",
        "desert",
        "desertion",
        "desertion",
        "function",
]
# Learn BPE tokenizer with 10 merge operations.
merges, vocabulary, tokenized_corpus = learn_bpe(sample_text, num_merges=10)

print("Learned merges (in order):")
for i, pair in enumerate(merges, start=1):
    print(f" {i}. Merge {pair[0]!r} + {pair[1]!r}  →  {pair[0]+pair[1]!r}")

print("\nFinal tokenized corpus:")
for orig, tokenized in zip(sample_text, tokenized_corpus):
    print(f"  {orig:8s} → {' '.join(tokenized)}")
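Once a merge list has been learned, it can be replayed in order to tokenize a word that was not in the training data. The sketch below is one possible way to do that and is not part of the original lab: the function name `apply_merges` and the hard-coded `example_merges` list are hypothetical, chosen only so the block runs on its own.

```python
def apply_merges(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Tokenizes one word by replaying merges in the order they were learned."""
    tokens = list(word) + ["</w>"]
    for left, right in merges:
        merged, out, i = left + right, [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == (left, right):
                out.append(merged)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens


# Hypothetical merge list; in practice, pass the `merges` returned by learn_bpe.
example_merges = [("e", "s"), ("es", "t"), ("e", "d"), ("ed", "</w>")]
print(apply_merges("tested", example_merges))
# ['t', 'est', 'ed</w>']
```

The order matters: a later merge such as `('es', 't')` can only fire after the earlier merge `('e', 's')` has produced the token `es`.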