# Lab: Train an SLM with Your BPE Tokenizer
## Purpose:
- Learn merges for a BPE subword tokenizer
- Use t-SNE dimensionality reduction to visualize embeddings

### Topics:
- BPE
- Unicode characters

### Steps
* Load and inspect the Africa Galore dataset.
* Train a BPE encoder on the Africa Galore dataset.
* Encode and decode example words and sentences (including made-up words) to see how this tokenizer handles out-of-vocabulary (OOV) cases.
* Convert the tokenized corpus into padded numerical index sequences required for model training.
* Train the transformer model from the previous course on the dataset and observe how the loss decreases and the generated text improves during training.
* Visualize the learned embeddings of some of the tokens using the t-SNE algorithm.

Date: 2026-02-21

Source: https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_2/gdm_lab_2_6_train_an_slm_with_your_bpe_tokenizer.ipynb

References: https://github.com/google-deepmind/ai-foundations
- GDM GH repo used in AI training courses at the university & college level.

In [None]:
%%capture
# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

from collections import Counter # For counting tokens in the BPE tokenizer.
import os # Used for setting Keras configuration variables.
import string # For accessing string constants.

# The following line provides configuration for Keras.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np # For working with vectors and matrices.
import pandas as pd # For loading the Africa Galore dataset.
import tqdm # For displaying progress bars.

from ai_foundations import training # For defining and training the SLM.
from ai_foundations import embeddings as emb # For visualizing embeddings.

## The full BPE Tokenizer class, all in one place

In [None]:
class BPEWordTokenizer:
    """
    A Byte Pair Encoding (BPE) based subword tokenizer.

    Supports encoding and decoding text to subword tokens using BPE.
    Learns merge rules from a corpus or be initialized with a pre-built vocabulary.

    Attributes:
        vocabulary: List of subword tokens including special tokens.
        vocabulary_size : Total number of tokens in vocabulary.
        token_to_index: Mapping from tokens to indices.
        index_to_token: Mapping from indices to tokens.
        pad_token_id: Index of the padding token.
        unknown_token_id: Index of the unknown token.
        tokenized_corpus: Cached tokenized corpus after BPE training.
    """

    UNKNOWN_TOKEN = "<UNK>"
    PAD_TOKEN = "<PAD>"
    END_WORD = "</w>"

    def __init__(
        self,
        texts: list[str] | str,
        vocabulary: list[str] | None = None,
        num_merges: int = 100,
    ):
        """Initializes the BPEWordTokenizer.

        If no vocabulary is specified, it extracts the unique tokens from the
        text corpus and learns the BPE merges.

        Args:
          texts: A list of strings or a string representing the text corpus.
          vocabulary: Optional list of strings with unique tokens.
          num_merges: Defines how many rounds of merges should be performed
            when learning the BPE merges.
        """

        # Normalize to list of strings.
        if isinstance(texts, str):
            texts = [texts]

        if vocabulary is None:
            # Learn BPE merges and derive vocabulary from tokenized corpus.
            self.merges, tokenized, vocabulary_set = self._learn_bpe(
                texts, num_merges
            )
            self.tokenized_corpus = tokenized

            # Ensure that basic alphanumeric characters are always included in
            # the vocabulary.
            required_chars = set(
                string.ascii_lowercase + string.ascii_uppercase + string.digits
            )

            vocabulary_set.update(required_chars)

            # Add special tokens to the vocabulary.
            self.vocabulary = (
                [self.PAD_TOKEN] + sorted(vocabulary_set) + [self.UNKNOWN_TOKEN]
            )

        else:
            self.vocabulary = vocabulary
            self.merges = []  # Skip merge logic when a vocabulary is provided.

        # Build mappings and set IDs of special tokens.
        self.vocabulary_size = len(self.vocabulary)
        self.token_to_index = {tok: i for i, tok in enumerate(self.vocabulary)}
        self.index_to_token = {i: tok for i, tok in enumerate(self.vocabulary)}
        self.pad_token_id = self.token_to_index[self.PAD_TOKEN]
        self.unknown_token_id = self.token_to_index[self.UNKNOWN_TOKEN]

    def _split_text(self, text: str) -> list[str]:
        """Split a string into subword tokens using learned BPE merges.

        Args:
          text: String to split into subword tokens.

        Returns:
          List of subword tokens that together form the original text.
        """
        tokens = []
        for word in text.strip().split():
            # Split the string into characters and add special END_WORD token.
            chars = list(word) + [self.END_WORD]

            # Merge individual characters according to learned BPE merges.
            for pair in self.merges:
                chars = self._merge_pairs_in_word(chars, pair)
            tokens.extend(chars)
        return tokens

    def join_text(self, tokens: list[str]) -> str:
        """Join subword tokens into full string, preserving word boundaries.

        Args:
          tokens: List of subword tokens to be joined.

        Returns:
          String obtained from joining the subword tokens.
        """
        words = []
        current_word = []
        for token in tokens:
            # Check whether token ends with a word boundary marker.
            if token.endswith(self.END_WORD):
                current_word.append(token.replace(self.END_WORD, ""))
                words.append("".join(current_word))
                current_word = []
            else:
                current_word.append(token)
        if current_word:
            words.append("".join(current_word))
        return " ".join(words).strip()

    def encode(self, text: str) -> list[int]:
        """
        Encode a string into list of token indices.

        Args:
            text: Input text.

        Returns:
            List of integers corresponding to tokens.
        """
        token_ids = []
        for token in self._split_text(text):
            token_id = self.token_to_index.get(token, self.unknown_token_id)
            token_ids.append(token_id)
        return token_ids

    def decode(self, token_ids: int | list[int]) -> str:
        """
        Decode list of token IDs back to original text.

        Args:
          token_ids: Single index or list of token IDs.

        Returns:
          Decoded text string.
        """
        # Covert to list if a single token index is specified.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tokens = []
        for token_id in token_ids:
            tokens.append(
                self.index_to_token.get(
                    token_id,
                    self.UNKNOWN_TOKEN + self.END_WORD,
                )
            )
        return self.join_text(tokens)

    def _get_pair_frequencies(self, corpus: list[list[str]]) -> Counter[str]:
        """Count all adjacent token pairs in corpus.

        Args:
          corpus: A list of lists of strings representing subword tokens.

        Returns:
          Counter mapping adjacent pairs of subword tokens to their frequencies.
        """
        pairs = Counter()
        for word in corpus:
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                # Increase the count by 1.
                pairs[pair] += 1
        return pairs

    def _merge_pairs_in_word(
        self, word: list[str], pair_to_merge: tuple[str, str]
    ) -> list[str]:
        """Merge all occurrences of a token pair inside a word.

        Args:
          tokens: A list of subword tokens representing one space separated
            word.
          pair_to_merge: A pair of two subword tokens that should be merged into
            one subword token.

        Returns:
          New list of subword tokens representing the word after applying the
            merge.
        """

        merged_symbol = pair_to_merge[0] + pair_to_merge[1]
        if pair_to_merge[0] not in word or pair_to_merge[1] not in word:
            return word
        i = 0
        new_word = []
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair_to_merge:
                new_word.append(merged_symbol)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        return new_word

    def _learn_bpe(
        self, corpus: list[str], num_merges: int
    ) -> tuple[list[tuple[str, str]], list[list[list[str]]], set[str]]:
        """
        Learn BPE merges from a corpus of texts.

        Args:
          corpus: List of input texts.
          num_merges: Number of merge operations to perform.

        Returns:
            merges: List of merges in order they are learned to be performed.
            tokenized_corpus: List of list of list of subword tokens where each
              paragraph in the corpus is tokenized as a list of list of
              subword tokens.
            vocabulary_set: Set of subword tokens after performing all merges.
        """
        # List of lists of lists to store tokenized text corpus.
        tokenized_corpus = []
        vocabulary = set([self.END_WORD])
        for paragraph in corpus:
            sentence_raw_tokens = []
            for word in paragraph.strip().split():
                # Split the word into characters and add word boundary marker.
                sentence_raw_tokens.append(list(word) + [self.END_WORD])
                vocabulary.update(list(word))
            tokenized_corpus.append(sentence_raw_tokens)

        merges = []
        for _ in (pbar := tqdm.tqdm(range(num_merges), unit="merges")):
            # Build a one-dimensional list of all tokens in the corpus.
            flat_corpus = []
            for tokenized_paragraph in tokenized_corpus:
                flat_corpus.extend(tokenized_paragraph)

            # Find the most frequent pair of adjacent tokens.
            pair_freqs = self._get_pair_frequencies(flat_corpus)
            if not pair_freqs:
                break
            most_freq_pair, freq = pair_freqs.most_common(1)[0]
            if freq < 1:
                break
            merges.append(most_freq_pair)

            # Apply merge to each token in each paragraph.
            new_tokenized_corpus = []
            for para_tokens in tokenized_corpus:
                new_para_tokens = []
                for word_tokens in para_tokens:
                    new_para_tokens.append(
                        self._merge_pairs_in_word(word_tokens, most_freq_pair)
                    )
                new_tokenized_corpus.append(new_para_tokens)
            tokenized_corpus = new_tokenized_corpus
            vocabulary.add(most_freq_pair[0] + most_freq_pair[1])
            pbar.set_postfix(vocabulary_size=f"{len(vocabulary):,}")

        return merges, tokenized_corpus, vocabulary

### Train the BPE tokenizer
Uses 3000 merges to completely tokenize the dataset.

In [None]:
# Download the Africa Galore dataset and keep only its "description" column
# as an array of paragraphs.
africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
description_column = africa_galore["description"]
dataset = description_column.values
num_paragraphs = dataset.shape[0]
print("Loaded dataset with", num_paragraphs, "paragraphs.")

In [None]:
# Number of BPE merge rounds to learn on the dataset.
num_merges = 3000

# Learning the merges happens inside the constructor.
tokenizer = BPEWordTokenizer(texts=dataset, num_merges=num_merges)
vocabulary_size = tokenizer.vocabulary_size
print(f"\n\nFinal tokenizer vocabulary size: {vocabulary_size:,}\n")

### Check the behavior on the first 20 words in the dataset

Expected Output
```
The</w>
Lago s</w>
air</w>
was</w>
thick</w>
with</w>
humid ity,</w>
but</w>
the</w>
energy</w>
in</w>
the</w>
cl ub </w>
was</w>
electr ic.</w>
The</w>
band</w>
la un ched</w>
into</w>
a</w>
```

In [None]:
# The tokenizer caches the tokenized training corpus: one list of words per
# paragraph, and one list of subword tokens per word.
africa_galore_tokenized = tokenizer.tokenized_corpus
first_paragraph = africa_galore_tokenized[0]
# Print the subword tokens of the first 20 words, one word per line.
for tokens in first_paragraph[:20]:
    print(*tokens)

### Tokenize unknown words
Expected output
```
Decoded sentence from tokens: A Zimbabwian dish <UNK>.
Token 63:	A
Token 388:	Zimbab
Token 3003:	wi
Token 500:	an
Token 1026:	dish
Token 3080:	<UNK>
Token 35:	.
```

In [None]:
# Encode a sentence that contains a made-up word ("Zimbabwian") and an emoji.
# The emoji literal was previously mis-encoded (UTF-8 bytes read as cp1252),
# which turned one character into four. It is restored to the single
# character U+1F60B so that it yields the one <UNK> token shown in the
# expected output.
madeup_tokens = tokenizer.encode("A Zimbabwian dish 😋.")

print(f"Decoded sentence from tokens: {tokenizer.decode(madeup_tokens)}")
# Decode every token ID on its own to show how the sentence was split.
for token in madeup_tokens:
    decoded_token = tokenizer.decode(token)
    print(f"Token {token}:\t{decoded_token}")

# Convert data to token IDs
Prepare dataset to train model.

In [None]:
# Encode every paragraph into a list of token IDs, showing a progress bar.
encoded_tokens = [
    tokenizer.encode(paragraph)
    for paragraph in tqdm.tqdm(dataset, unit="paragraphs")
]

### Truncate & pad the sequences using keras.preprocessing.sequence.pad_sequences
Truncate any paragraphs longer than 300 tokens & pad the rest.

In [None]:
# Every sequence is brought to exactly this many token IDs.
max_length = 300

# "post" cuts overly long sequences at the end and appends the padding token
# ID to the end of short ones.
padded_sequences = keras.preprocessing.sequence.pad_sequences(
    sequences=encoded_tokens,
    maxlen=max_length,
    value=tokenizer.pad_token_id,
    padding="post",
    truncating="post",
)

Split the dataset to create the inputs and targets.

In [None]:
# Build next-token-prediction pairs for the transformer model: the inputs
# drop the last token of every example and the targets drop the first one,
# so target position i holds the token that follows input position i.
input_sequences, target_sequences = (
    padded_sequences[:, :-1],
    padded_sequences[:, 1:],
)

# The model's sequence length is the (now one shorter) input width.
_, max_length = input_sequences.shape

## Train the model

Initialize the model and train it for 100 epochs.
Monitor progress by printing a statement every tenth epoch.

In [None]:
# Seed the random number generators so that runs are reproducible.
keras.utils.set_random_seed(3112)

model = training.create_model(
    vocabulary_size=tokenizer.vocabulary_size,
    max_length=max_length,
    learning_rate=8e-5,
)

# Text generated during training is seeded with this encoded prompt.
prompt = "Jide"
prompt_ids = tokenizer.encode(prompt)
text_gen_callback = training.TextGenerator(
    max_tokens=11,
    start_tokens=prompt_ids,
    tokenizer=tokenizer,
    print_every=10,
)

num_epochs = 100
# verbose=2 makes model.fit print one line per epoch, so the decreasing loss
# and the improving generated texts can be followed.
history = model.fit(
    input_sequences,
    target_sequences,
    epochs=num_epochs,
    batch_size=2,
    verbose=2,
    callbacks=[text_gen_callback],
)

### Visualize embeddings with t-SNE


In [None]:
# The first trainable weight of the model holds the embedding matrix.
embeddings = model.trainable_weights[0].value

# Token categories. Some entries deliberately keep trailing punctuation
# (e.g. "carrots,").
food_and_drink = [
    "water", "coffee", "onions", "peanut", "pepper",
    "pudding", "sauce", "stew", "carrots,",
]
prepositions = ["on", "in", "with", "for", "of"]
adjectives = ["aromatic", "hot"]
countries = ["Egypt", "Congo", "Ghana,", "Tanzania"]

# Flatten the categories into one token list and give every token the index
# of its category, which is used to color the plot.
token_groups = [food_and_drink, prepositions, adjectives, countries]
tokens = [token for group in token_groups for token in group]
categories = [
    category
    for category, group in enumerate(token_groups)
    for _ in group
]

# Convert tokens into token IDs.
token_ids = [
    token_id for token in tokens for token_id in tokenizer.encode(token)
]

# Check that each token is represented as a single token in the tokenizer;
# otherwise the rows selected below would not line up with the labels.
assert len(token_ids) == len(tokens)

# Select the embedding rows for the tokens of interest.
embeddings_subset = embeddings[token_ids, :]

# Generate t-SNE plot with embeddings from your model.
emb.plot_embeddings_tsne(embeddings_subset, tokens, categories)