# Lab: Prepare The Dataset For Training an SLM
## Purpose:
- Learn the preprocessing steps needed to train a transformer-based small language model (SLM)
- Tokenize text
- Generate primitive embeddings
### Topics:
- Tokenization
- Vocabulary creation
- Embeddings
### Steps
1. Load and tokenize the dataset.
2. List the tokens in the dataset.
3. List the unique tokens in the dataset.
4. Map tokens to token IDs and vice versa.
5. Translate between tokens and their corresponding IDs with functions.
6. Wrap it all in a class that encapsulates all methods necessary for preparing the data for a transformer model.

Date: 2026-02-19

Source: https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/gdm_lab_1_4_prepare_the_dataset_for_training_a_slm.ipynb

References: https://github.com/google-deepmind/ai-foundations
- GDM GH repo used in AI training courses at the university & college level.

In [None]:
%%capture
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import re # Used for splitting strings on spaces.

# Packages used.
import pandas as pd # For reading the dataset.
import textwrap # For adding linebreaks to paragraphs.

# For providing feedback.
from ai_foundations.feedback.course_1 import slm

### Load & Tokenize
Load the dataset with pandas.

In [None]:
# Read the Africa Galore JSON file from its hosted URL into a DataFrame and
# keep only the "description" column as the dataset of paragraphs.
africa_galore_url = (
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
africa_galore = pd.read_json(africa_galore_url)
dataset = africa_galore["description"]

# Report how many paragraphs were loaded, then show the first one re-wrapped
# by textwrap.fill() so it is readable as a block of text.
print(f"Loaded Africa Galore dataset with {len(dataset)} paragraphs.")
print(f"\nFirst paragraph:")
print(textwrap.fill(dataset[0]))

In [None]:
# Tokenize
def tokenizer(text: str) -> list[str]:
    """Creates a list of tokens by splitting a string on spaces.

    A run of consecutive spaces counts as a single separator, and leading or
    trailing spaces do not produce empty tokens.

    Args:
        text: The input text.
    Returns:
        A list of tokens. Returns empty list if text is empty or all spaces.
    """
    # Use `re` package so that splitting on multiple spaces also works.
    # `re.split` yields "" for an empty input and next to any leading or
    # trailing run of spaces; drop those so the result matches the documented
    # contract (no empty tokens, empty list for blank input).
    return [token for token in re.split(r" +", text) if token]

# Tokenize an example text with tokenizer(). Only spaces act as separators,
# so punctuation stays attached to the neighboring word (e.g. "Kanga,").
tokenizer("Kanga, a colorful printed cloth is more than just a fabric.")

### List all tokens

In [None]:
# Create the list of tokens.
tokens = []

# Loop through each paragraph in the dataset.
for paragraph in dataset:
    # tokenizer() returns a list; extend() adds each of its items to `tokens`
    # so the result stays one flat list rather than a list of lists.
    tokens.extend(tokenizer(paragraph))

print(f"Total number of tokens in the Africa Galore dataset: {len(tokens):,}")

# Print a sample of the first 30 tokens. A bare `tokens[:30]` expression is
# only displayed when it is the last statement of a notebook cell, which is
# not the case here, so print it explicitly.
print(tokens[:30])

# Check the token list with the course's feedback helper.
slm.test_build_tokens_list(tokens, tokenizer, dataset)

### List unique tokens
(5260)

In [None]:
def build_vocabulary(tokens: list[str]) -> list[str]:
    """Builds a vocabulary list from the tokens in the dataset.

    Args:
        tokens: The list of tokens in the dataset (duplicates allowed).
    Returns:
        The unique tokens, sorted in ascending order.
    """
    # A set removes duplicates, but the iteration order of a set of strings
    # can differ between interpreter runs (string hashing is randomized by
    # default). Sorting makes the vocabulary order, and therefore any token IDs
    # derived from it, reproducible.
    return sorted(set(tokens))

# Check build_vocabulary() with the course's feedback helper.
slm.test_build_vocabulary(build_vocabulary)

In [None]:
# Reduce the dataset's tokens to the unique ones and record how many there are.
vocabulary = build_vocabulary(tokens)
vocabulary_size = len(vocabulary)

# Report the vocabulary size, formatted with a thousands separator.
print(f"Total number of unique tokens in the Africa Galore dataset: {vocabulary_size:,}")

### Create an index for each token
- IDs should always be consecutive
- Use a dictionary for each mapping

In [None]:
# Build the `token_to_index` dictionary.
# enumerate() pairs each token with its position in `vocabulary`, which gives
# every token a consecutive integer ID starting at 0.
token_to_index = {token: index for index, token in enumerate(vocabulary)}

In [None]:
# Build the `index_to_token` dictionary, which maps each ID back to its token.
# enumerate() yields (index, token) pairs, which dict() turns directly into
# the index -> token mapping.
index_to_token = dict(enumerate(vocabulary))

### Encoding and decoding
- encode() takes a string of text and returns the corresponding indices.
- decode() takes a list of indices and returns the corresponding text.

In [None]:
def encode(text: str) -> list[int]:
    """Turns a text sequence into the vocabulary indices of its tokens.

    Args:
        text: The input text to be encoded.
    Returns:
        One entry per token produced by tokenizer(text), in order. A token
        that is missing from `token_to_index` is encoded as None, because the
        lookup uses dict.get() without a default.
    """
    # Look up every token's ID in a single pass.
    return [token_to_index.get(word) for word in tokenizer(text)]


def decode(indices: int | list[int]) -> str:
    """Decodes a list (or single index) of integers back into text.

    Args:
        indices: A single index or a list of indices to be decoded into tokens.
    Returns:
        A string of the decoded tokens, joined with single spaces.
    Raises:
        TypeError: If an index is missing from `index_to_token`; the lookup
            then yields None, which cannot be joined into a string.
    """
    # If a single integer is passed, convert it into a list.
    if isinstance(indices, int):
        indices = [indices]

    # Map indices to tokens.
    decoded_tokens = [index_to_token.get(index) for index in indices]

    # Join the decoded tokens with spaces.
    return " ".join(decoded_tokens)

In [None]:
# Show the first paragraph of the dataset as stored (no line wrapping applied).
text = dataset[0]
print(text)

### Wrap it all into a class

In [None]:
class SimpleWordTokenizer:
    """A simple word tokenizer that can be initialized with a corpus of texts
       or using a provided vocabulary list.

    The tokenizer splits the text on spaces,
    encode() converts the text to indices.
    decode() converts indices to text.

    Typical usage example:
        corpus = "Hello there!"
        tokenizer = SimpleWordTokenizer(corpus)
        print(tokenizer.encode('Hello'))
    """

    def __init__(self, corpus: list[str], vocabulary: list[str] | None = None):
        """Initialize tokenizer with texts in corpus or with a vocabulary.
        Args:
            corpus: Input text dataset.
            vocabulary: A pre-defined vocabulary. If None,
                the vocabulary is automatically inferred from the texts.
        """

        if vocabulary is None:
            # Build the vocabulary from scratch.
            if isinstance(corpus, str):
                corpus = [corpus]

            # Convert text sequence to tokens.
            tokens = []
            for text in corpus:
                for token in self.tokenizer(text):
                    tokens.append(token)

            # Create a vocabulary of unique tokens.
            self.vocabulary = self.build_vocabulary(tokens)

        else:
            self.vocabulary = vocabulary

        # Size of vocabulary.
        self.vocabulary_size = len(self.vocabulary)

        # Create token-to-index and index-to-token mappings.
        self.token_to_index = {}
        self.index_to_token = {}
        # Loop through all tokens in the vocabulary. enumerate automatically
        # assigns a unique index to each token.
        for index, token in enumerate(self.vocabulary):
            self.token_to_index[token] = index
            self.index_to_token[index] = token

    def tokenizer(self, text: str) -> list[str]:
        """Splits text on space into tokens.
        Args:
            text: Text to split on space.
        Returns:
            List of tokens after splitting `text`.
        """

        # Use re.split such that multiple spaces are treated as a single
        # separator.
        return re.split(" +", text)

    def join_text(self, text_list: list[str]) -> str:
        """Combines a list of tokens into a single string, with tokens separated
           by spaces.
        Args:
            text_list: List of tokens to be joined.
        Returns:
            String with all tokens joined with a space.
        """
        return " ".join(text_list)

    def build_vocabulary(self, tokens: list[str]) -> list[str]:
        """Create a vocabulary list from the list of tokens.
        Args:
            tokens: The list of tokens in the dataset.
        Returns:
            List of unique tokens (vocabulary) in the dataset.
        """
        return sorted(list(set(tokens)))

    def encode(self, text: str) -> list[int]:
        """Encodes a text sequence into a list of indices.
        Args:
            text: The input text to be encoded.
        Returns:
            A list of indices corresponding to the tokens in the input text.
        """

        # Convert tokens into indices.
        indices = []
        for token in self.tokenizer(text):
            token_index = self.token_to_index.get(token)
            indices.append(token_index)

        return indices

    def decode(self, indices: int | list[int]) -> str:
        """Decodes a list (or single index) of integers back into tokens.
        Args:
            indices: A single index or a list of indices to be decoded into
                tokens.
        Returns:
            str: A string of decoded tokens corresponding to the input indices.
        """

        # If a single integer is passed, convert it into a list.
        if isinstance(indices, int):
            indices = [indices]

        # Map indices to tokens.
        tokens = []
        for index in indices:
            token = self.index_to_token.get(index)
            tokens.append(token)

        # Join the decoded tokens into a single string.
        return self.join_text(tokens)