# Lab: Train Your Own Small Language Model (SLM)
## Purpose:
- Apply the pre-processing steps from 1_4
- Prepare data
- Train SLM
- Observe results
### Topics:
- Tokenization
- Vocabulary creation
- Embeddings
- Keras
### Steps
- Load & tokenize the dataset.
- Map tokens to token IDs and vice-versa.
- Create same-length sequences by padding (and, if needed, truncating) them.
- Shuffle examples and group them into batches.
- Transform the data into model inputs and model targets.
- Train the transformer model.

Date: 2026-02-19

Source: https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/gdm_lab_1_5_train_your_own_small_language_model.ipynb

References: https://github.com/google-deepmind/ai-foundations
- Google DeepMind (GDM) GitHub repository used in AI training courses at the university and college level.

In [None]:
%%capture
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

# Packages used.
import os # Used for setting Keras configuration variables.
os.environ["KERAS_BACKEND"] = "jax" # Set a parameter for Keras.
import re # Used for splitting text on whitespace.

import keras # Used for defining and training the model.
import pandas as pd # Used for loading the dataset.
import tensorflow as tf # Used for shuffling the dataset.

# Used for displaying nicer error messages.
from IPython.display import display, HTML
from ai_foundations import training # For training your model.
from ai_foundations import generation # For prompting your model.
from ai_foundations import visualizations # For visualizing probabilities.
from ai_foundations.feedback.course_1 import slm # For providing feedback.

# Fix the random seed so that results are reproducible across runs.
keras.utils.set_random_seed(812)

### Load dataset

In [None]:
africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
dataset = africa_galore["description"].values
print("Loaded dataset with", dataset.shape[0], "paragraphs.")

### Tokenize dataset

In [None]:
class SimpleWordTokenizer:
    """A simple word tokenizer.

    Splits a text sequence on spaces into tokens. Use encode() to convert
    text to token indices and decode() to convert indices back to text.

    Can be initialized with a corpus or with a vocabulary list.

    Typical usage example:
        corpus = "Hello there!"
        tokenizer = SimpleWordTokenizer(corpus)
        print(tokenizer.encode("Hello"))

    """

    # Define constants.
    UNKNOWN_TOKEN = "<UNK>"
    PAD_TOKEN = "<PAD>"

    def __init__(self, corpus: list[str], vocabulary: list[str] | None = None):
        """Initializes the tokenizer with texts in corpus or with a vocabulary.

        Args:
          corpus: Input text dataset.
          vocabulary: A pre-defined vocabulary. If None,
              the vocabulary is automatically inferred from the texts.
        """

        if vocabulary is None:
            # Build the vocabulary from scratch.
            if isinstance(corpus, str):
                corpus = [corpus]

            # Convert text sequence to tokens.
            tokens = []
            for text in corpus:
                for token in self.space_tokenize(text):
                    tokens.append(token)

            # Create a vocabulary consisting of the unique tokens.
            vocabulary = self.build_vocabulary(tokens)

            # Add special unknown and pad tokens to the vocabulary list.
            self.vocabulary = (
                [self.PAD_TOKEN] + vocabulary + [self.UNKNOWN_TOKEN]
            )

        else:
            self.vocabulary = vocabulary

        # Size of vocabulary.
        self.vocabulary_size = len(self.vocabulary)

        # Create token-to-index and index-to-token mappings.
        self.token_to_index = {}
        self.index_to_token = {}
        # Loop through all tokens in the vocabulary. enumerate automatically
        # assigns a unique index to each token.
        for index, token in enumerate(self.vocabulary):
            self.token_to_index[token] = index
            self.index_to_token[index] = token

        # Map the special tokens to their IDs.
        self.pad_token_id = self.token_to_index[self.PAD_TOKEN]
        self.unknown_token_id = self.token_to_index[self.UNKNOWN_TOKEN]

    def space_tokenize(self, text: str) -> list[str]:
        """Splits a given text on whitespace into tokens.
        Args:
            text: Text to split on whitespace.
        Returns:
            List of tokens after splitting `text`.
        """

        # Use re.split such that multiple spaces are treated as a single
        # separator.
        return re.split(" +", text)

    def join_text(self, text_list: list[str]) -> str:
        """Combines a list of tokens into a single string.
        The combined tokens, as a single string, are separated by spaces in the
        string.
        Args:
            text_list: List of tokens to be joined.
        Returns:
            String with all tokens joined with a whitespace.
        """
        return " ".join(text_list)

    def build_vocabulary(self, tokens: list[str]) -> list[str]:
        """Create a vocabulary list from the list of tokens.
        Args:
            tokens: The list of tokens in the dataset.
        Returns:
            List of unique tokens (vocabulary) in the dataset.
        """
        return sorted(list(set(tokens)))

    def encode(self, text: str) -> list[int]:
        """Encodes a text sequence into a list of indices.
        Args:
            text: The input text to be encoded.
        Returns:
            A list of indices corresponding to the tokens in the input text.
        """

        # Convert tokens into indices.
        indices = []
        unk_index = self.token_to_index[self.UNKNOWN_TOKEN]
        for token in self.space_tokenize(text):
            token_index = self.token_to_index.get(token, unk_index)
            indices.append(token_index)

        return indices

    def decode(self, indices: int | list[int]) -> str:
        """Decodes a list (or single index) of integers back into tokens.
        Args:
            indices: A single index or a list of indices to be
                decoded into tokens.
        Returns:
            A string of decoded tokens corresponding to the input indices.
        """

        # If a single integer is passed, convert it into a list.
        if isinstance(indices, int):
            indices = [indices]

        # Map indices to tokens.
        tokens = []
        for index in indices:
            # Fall back to the <UNK> token for indices not in the vocabulary.
            token = self.index_to_token.get(index, self.UNKNOWN_TOKEN)
            tokens.append(token)

        # Join the decoded tokens into a single string.
        return self.join_text(tokens)


# Initialize the tokenizer. This will build the tokenizer's vocabulary with
# all the tokens that appear in the dataset.
tokenizer = SimpleWordTokenizer(dataset)

# Translate all tokens to their corresponding IDs.
encoded_tokens = []
for text in dataset:
    # Split text into tokens and translate the tokens to token IDs.
    token_ids = tokenizer.encode(text)
    encoded_tokens.append(token_ids)
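
To make the mapping concrete, here is a condensed, self-contained sketch of the same vocabulary logic applied to a hypothetical one-sentence corpus. It re-implements the tokenizer's steps with plain dictionaries rather than calling `SimpleWordTokenizer`: sorted unique tokens, with `<PAD>` prepended (ID 0) and `<UNK>` appended.

```python
import re

# Hypothetical toy corpus (not the Africa Galore dataset).
corpus = ["the cat sat on the mat"]

# Split on runs of spaces, as SimpleWordTokenizer.space_tokenize does.
tokens = [tok for text in corpus for tok in re.split(" +", text)]

# Sorted unique tokens, with <PAD> first (ID 0) and <UNK> last.
vocabulary = ["<PAD>"] + sorted(set(tokens)) + ["<UNK>"]
token_to_index = {tok: i for i, tok in enumerate(vocabulary)}
index_to_token = {i: tok for tok, i in token_to_index.items()}


def encode(text):
    # Tokens missing from the vocabulary fall back to the <UNK> ID.
    unk_id = token_to_index["<UNK>"]
    return [token_to_index.get(tok, unk_id) for tok in re.split(" +", text)]


def decode(ids):
    # IDs missing from the vocabulary fall back to the <UNK> string.
    return " ".join(index_to_token.get(i, "<UNK>") for i in ids)


print(vocabulary)   # ['<PAD>', 'cat', 'mat', 'on', 'sat', 'the', '<UNK>']
ids = encode("the dog sat")
print(ids)          # [5, 6, 4]
print(decode(ids))  # the <UNK> sat
```

Because text is only split on spaces, punctuation stays attached to words: `Hello` and `Hello,` would become two different vocabulary entries.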

### Pad the dataset
Create a matrix containing the index of each token in the dataset.
- Each paragraph in the dataset becomes one row of the matrix.
- All rows must have the same length, so the `<PAD>` token is appended to shorter paragraphs until they match the length of the longest paragraph.
    - Alternative methods include:
        - truncating all paragraphs to the length of the shortest paragraph.
        - selecting an arbitrary paragraph length, truncating longer paragraphs and padding shorter ones (the most common method).
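
The three strategies can be compared on toy data with a short pure-Python sketch. It assumes a pad ID of 0 (true for `SimpleWordTokenizer`, which puts `<PAD>` first in the vocabulary). `pad_or_truncate` is a hypothetical helper that mimics Keras `pad_sequences` with `padding="post"` and `truncating="post"`.

```python
PAD_ID = 0  # Assumed pad ID: <PAD> is the first vocabulary entry.


def pad_or_truncate(seq, length, pad_id=PAD_ID):
    """Hypothetical helper: cut long sequences at the end, pad short ones at the end."""
    seq = seq[:length]
    return seq + [pad_id] * (length - len(seq))


# Hypothetical encoded paragraphs of different lengths.
paragraphs = [[7, 3, 9], [4, 8], [5, 1, 2, 6, 2]]
lengths = [len(p) for p in paragraphs]

# 1. Pad everything to the longest paragraph (no tokens are lost).
to_longest = [pad_or_truncate(p, max(lengths)) for p in paragraphs]
# 2. Truncate everything to the shortest paragraph (no padding needed).
to_shortest = [pad_or_truncate(p, min(lengths)) for p in paragraphs]
# 3. Pick a fixed length: truncate longer paragraphs, pad shorter ones.
fixed = [pad_or_truncate(p, 4) for p in paragraphs]

print(to_longest)   # [[7, 3, 9, 0, 0], [4, 8, 0, 0, 0], [5, 1, 2, 6, 2]]
print(to_shortest)  # [[7, 3], [4, 8], [5, 1]]
print(fixed)        # [[7, 3, 9, 0], [4, 8, 0, 0], [5, 1, 2, 6]]
```

The fixed-length option is a compromise: it bounds memory use like truncation while losing tokens only from unusually long paragraphs.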

### Compute length statistics
- Length of the shortest paragraph: 26 tokens
- Length of the longest paragraph: 318 tokens

In [None]:
# Compute the length (number of tokens) of every encoded paragraph.
paragraph_lengths = [len(token_ids) for token_ids in encoded_tokens]

# Length of the shortest paragraph.
shortest_paragraph_length = min(paragraph_lengths)

# Length of the longest paragraph.
longest_paragraph_length = max(paragraph_lengths)

print("Length of the shortest paragraph is:", shortest_paragraph_length)
print("Length of the longest paragraph is:", longest_paragraph_length)

### Set max_length for padding & truncating

In [None]:
max_length = 300  # @param {type: "number"}

if max_length <= 0:
    display(
        HTML(
            f"<h3>Error:</h3><p>Max length must be greater than 0. Please"
            f" increase <code>max_length</code>.</p><p></p>"
        )
    )

elif max_length > longest_paragraph_length:
    display(
        HTML(
            f"<h3>Error:</h3><p>The padding token <code>"
            f" {tokenizer.pad_token_id}</code> will be added to all"
            f" sequences - you probably don't want that. Please reduce"
            f" <code>max_length</code>.</p><p></p>"
        )
    )

else:
    if max_length < longest_paragraph_length:
        display(
            HTML(
                f"<p><strong>Note:</strong> The longest paragraph has"
                f" {longest_paragraph_length} tokens,"
                f" but <code>max_length</code> is set to {max_length}."
                f" Paragraphs longer than <code>max_length</code> will be"
                " truncated.</p><p></p>"
            )
        )

    # Pad short sequences and truncate long ones, both at the end ("post"),
    # so that every sequence has exactly max_length token IDs.
    padded_sequences = keras.preprocessing.sequence.pad_sequences(
        encoded_tokens,
        maxlen=max_length,
        padding="post",
        truncating="post",
        value=tokenizer.pad_token_id,
    )

    print("New length of first paragraph:", len(padded_sequences[0]), "\n")

    print(
        "Padding makes the length of all sequences the same as the specified"
        " `max_length`."
    )

    print(
        f"Notice the padding token ID {tokenizer.pad_token_id} appearing at the"
        " end of the sequence.\n"
    )
    print("Padded tokens of first paragraph:\n", padded_sequences[0])

### Preparing the input and target datasets
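- Input: the sequence without its last token (the context the model sees).
- Target: the sequence without its first token, i.e., the input shifted by one position, so each input position is paired with the token that follows it.

For example, for the sentence "Table Mountain is beautiful":
- Input: ["Table", "Mountain", "is"] (last token removed).
- Target: ["Mountain", "is", "beautiful"] (shifted by one token).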
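
The shift can be checked on the example sentence with plain Python slicing; the notebook applies the same `[:-1]` and `[1:]` slices to every row of `padded_sequences`. The loop assumes a causal (left-to-right) model, where position `i` may only look at inputs up to and including `i`.

```python
tokens = ["Table", "Mountain", "is", "beautiful"]

inputs = tokens[:-1]  # Drop the last token.
targets = tokens[1:]  # Drop the first token (shift left by one).

# At each position, the model sees inputs[:i + 1] and should predict targets[i].
for i, target in enumerate(targets):
    print(inputs[: i + 1], "->", target)
# ['Table'] -> Mountain
# ['Table', 'Mountain'] -> is
# ['Table', 'Mountain', 'is'] -> beautiful
```

Both slices are one token shorter than the original sequence, which is why `max_length` is reduced by one afterwards.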

In [None]:
# Prepare input and target for the transformer model.
# For each example, extract all tokens except the last one.
input_sequences = padded_sequences[:, :-1]
# For each example, extract all tokens except the first one.
target_sequences = padded_sequences[:, 1:]

In [None]:
print("First 10 token IDs in first input sequence:", input_sequences[0, :10])
print(
    "First 10 tokens in first input sequence:",
    tokenizer.decode(input_sequences[0, :10]),
)

print("\n")

print("First 10 token IDs in first target sequence:", target_sequences[0, :10])
print(
    "First 10 tokens in first target sequence:",
    tokenizer.decode(target_sequences[0, :10]),
)

In [None]:
# The inputs drop the last token and the targets drop the first token, so both
# are one token shorter than the padded sequences. Update max_length to match.
max_length = input_sequences.shape[1]

### Shuffle the dataset and specify the batch size
A **batch** is a chunk of the data; here, each batch holds a few paragraphs (input/target pairs).

Shuffling randomizes which paragraphs land in which batch, so each batch contains a diverse mix of topics. We don't want one batch to be all about coffee and another batch to be all about gorillas.
- Shuffling reorders whole paragraphs; the order of tokens within each paragraph must stay unchanged.

Processing order:
1. Create the dataset (tokenize & encode)
2. Pad
3. Split into inputs and targets
4. Shuffle
5. Batch (select batch size)
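
A minimal pure-Python sketch of the shuffle and batch steps, using hypothetical (input, target) pairs. It shuffles whole pairs, so the token order inside each paragraph is untouched. `tf.data` does the same job in the next cell, but it streams examples through a shuffle buffer instead of shuffling a list in place.

```python
import random

random.seed(0)  # Fixed seed so the run is repeatable.

# Hypothetical (input, target) pairs; each paragraph is a list of token IDs.
examples = [
    ([1, 2, 3], [2, 3, 4]),
    ([5, 6, 7], [6, 7, 8]),
    ([9, 10, 11], [10, 11, 12]),
    ([13, 14, 15], [14, 15, 16]),
    ([17, 18, 19], [18, 19, 20]),
]

shuffled = examples[:]    # Copy so the original order is kept for comparison.
random.shuffle(shuffled)  # Reorders whole examples, never tokens inside one.

batch_size = 2
batches = [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]

for batch in batches:
    print(batch)
```

With 5 examples and a batch size of 2, the last batch holds only one example.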

In [None]:
# TensorFlow has libraries to shuffle & batch.
# Create TensorFlow dataset to prepare sequences.
tf_dataset = tf.data.Dataset.from_tensor_slices((input_sequences, target_sequences))

# Randomly shuffle the dataset.
# shuffle() fills a buffer with buffer_size examples and samples randomly
# from it; a buffer at least as large as the dataset gives a full shuffle.
# If you are working with a very large dataset, reduce the buffer_size
# to save memory (at the cost of a less thorough shuffle).
tf_dataset = tf_dataset.shuffle(buffer_size=len(input_sequences))

# Specify batch size.
batch_size = 32  # @param {type: "number"}

# Create batches.
batches = tf_dataset.batch(batch_size)

for batch in batches.take(1):
    print(batch)

Expected output
>```
>(<tf.Tensor: shape=(32, 299), dtype=int32, numpy=
>array([[ 719, 5092, 4815, ...,    0,    0,    0],
>       [ 797,  597,  912, ...,    0,    0,    0],
>       [ 470, 4084, 2932, ...,    0,    0,    0],
>       ...,
>       [ 814, 4079, 1171, ...,    0,    0,    0],
>       [ 814, 3085, 2932, ...,    0,    0,    0],
>       [ 358, 1605, 2935, ...,    0,    0,    0]], dtype=int32)>, <tf.Tensor: shape=(32, 299), dtype=int32, numpy=
>array([[5092, 4815, 4403, ...,    0,    0,    0],
>       [ 597,  912, 2364, ...,    0,    0,    0],
>       [4084, 2932,  912, ...,    0,    0,    0],
>       ...,
>       [4079, 1171, 3522, ...,    0,    0,    0],
>       [3085, 2932, 4792, ...,    0,    0,    0],
>       [1605, 2935, 2968, ...,    0,    0,    0]], dtype=int32)>)

In [None]:
# Count the number of batches by iterating over the dataset once.
total_batches = 0
for batch in batches:
    total_batches += 1
print("Total number of batches is:", total_batches)
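
Counting by iteration works, but the result can also be checked by hand. `Dataset.batch` keeps the final, smaller batch by default (`drop_remainder=False`), so the count is the number of examples divided by the batch size, rounded up; `batches.cardinality()` should also report it without a loop. A small sketch with hypothetical example counts:

```python
import math


def num_batches(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced when splitting num_examples into batches."""
    if drop_remainder:
        return num_examples // batch_size  # Incomplete final batch is discarded.
    return math.ceil(num_examples / batch_size)  # Final batch may be smaller.


print(num_batches(100, 32))                       # 4  (32 + 32 + 32 + 4)
print(num_batches(100, 32, drop_remainder=True))  # 3
print(num_batches(64, 32))                        # 2
```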

## Train the SLM
Our SLM will have around 3.5 million parameters.
------
> **ℹ️ Info: Parameters of a transformer model**
>
> **Parameters** are a set of numbers that guide the model to perform whatever task it was trained to do. In transformer models, individual parameters are hard to interpret: they form a very large collection of numbers that together determine the model's behavior.
>
> At the start of training, the parameters are initialized with random numbers. They are then updated after the model processes each batch of paragraphs.
> Models are usually trained by processing the data multiple times. Going through the entire dataset once is known as an **epoch**. During each epoch, the parameters are updated so that they lead to better predictions of the next token.
------
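
To get a feel for how quickly parameters add up, here is the arithmetic for two layer types found in many language models: an embedding table (one vector per vocabulary entry) and a dense output layer (a weight matrix plus one bias per vocabulary entry). The vocabulary size and embedding dimension are hypothetical round numbers, not the settings used by `training.create_model`.

```python
vocabulary_size = 5_000  # Hypothetical vocabulary size.
embedding_dim = 256      # Hypothetical embedding dimension.

# Embedding layer: a table with one embedding_dim-sized vector per token.
embedding_params = vocabulary_size * embedding_dim

# Dense output layer: a weight matrix plus one bias per output token.
output_params = embedding_dim * vocabulary_size + vocabulary_size

print(f"{embedding_params:,}")                  # 1,280,000
print(f"{output_params:,}")                     # 1,285,000
print(f"{embedding_params + output_params:,}")  # 2,565,000
```

Each of these numbers starts out random and is adjusted during training. For a built Keras model, `model.summary()` prints the actual per-layer parameter counts.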

### Initialize the model
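Use `training.create_model()` to build a transformer model.
- max_length: Length of each input sequence (the updated `max_length`, one token shorter than the padded paragraphs).
- vocabulary_size: Number of entries in the tokenizer's vocabulary (the unique tokens plus `<PAD>` and `<UNK>`).
- learning_rate: How large each parameter update is.
    - High values make bigger updates and can speed up training, but may overshoot and make training unstable.
    - Low values make smaller, more stable updates, but training is slower.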
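
The trade-off can be seen on a toy problem: plain gradient descent minimizing the loss `(w - 3)^2`, starting from `w = 0`. This is only an illustration; the optimizer inside `training.create_model` is not shown in this notebook and may behave differently.

```python
def gradient_descent(learning_rate, steps=20, w=0.0):
    """Minimize the toy loss (w - 3)^2 and return the final w."""
    for _ in range(steps):
        gradient = 2 * (w - 3)         # Derivative of (w - 3)^2.
        w -= learning_rate * gradient  # Parameter update.
    return w


for lr in (0.01, 0.4, 1.1):
    print(f"learning_rate={lr}: w after 20 steps = {gradient_descent(lr):.4f}")
```

With 0.01, `w` is still only about 1 after 20 steps (too slow). With 0.4, it reaches the minimum at 3. With 1.1, every update overshoots the minimum by more than the previous error, so `w` moves further away and training diverges.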

In [None]:
model = training.create_model(
    max_length=max_length,
    vocabulary_size=tokenizer.vocabulary_size,
    learning_rate=1e-4
)

### Create a callback function
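A callback is code that runs at set points during training. Here, `training.TextGenerator` periodically (controlled by `print_every`) uses the model to generate a short continuation of the prompt, limited by `max_tokens`, so you can watch the generated text improve as training progresses and the loss falls.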
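
Keras calls hooks such as `on_epoch_end(epoch, logs)` on every object in the `callbacks` list passed to `fit`. The pure-Python sketch below imitates that mechanism with a hypothetical `toy_fit` loop and a hypothetical `PrintEveryCallback` that reports every `print_every` epochs. It does not use Keras, and `training.TextGenerator` may be implemented differently.

```python
class PrintEveryCallback:
    """Hypothetical callback: reports the loss every `print_every` epochs."""

    def __init__(self, print_every):
        self.print_every = print_every
        self.messages = []

    def on_epoch_end(self, epoch, logs):
        # Epochs are counted from 0, so epoch + 1 is the human-readable number.
        if (epoch + 1) % self.print_every == 0:
            message = f"epoch {epoch + 1}: loss={logs['loss']:.3f}"
            self.messages.append(message)
            print(message)


def toy_fit(num_epochs, callbacks):
    """Hypothetical training loop that calls each callback after every epoch."""
    loss = 5.0
    for epoch in range(num_epochs):
        loss *= 0.9  # Stand-in for one epoch of real training.
        for callback in callbacks:
            callback.on_epoch_end(epoch, {"loss": loss})


callback = PrintEveryCallback(print_every=10)
toy_fit(num_epochs=30, callbacks=[callback])
```

The training loop stays generic; anything to run periodically, such as sampling text from the model, lives in the callback.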

In [None]:
prompt = "Abeni,"
prompt_ids = tokenizer.encode(prompt)
text_gen_callback = training.TextGenerator(
    max_tokens=10, start_tokens=prompt_ids, tokenizer=tokenizer, print_every=10
)

## Train the Model!
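- **Step**: The training process updates the model parameters after processing each batch. Processing one batch and updating the parameters is a single step.
- **Epoch**: One pass through all batches in the dataset. Running multiple epochs usually improves the model's predictions, up to a point.
- **Overfitting**: When the model trains for too many epochs and starts memorizing noise and quirks of the training data instead of general patterns, so it performs worse on new text.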
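
A common way to catch overfitting is to hold out validation text and watch its loss. Training loss keeps falling, but validation loss bottoms out and then rises once the model starts memorizing noise. The sketch below picks the stopping epoch from a hypothetical list of per-epoch validation losses using a "patience" rule, the same idea behind Keras's `keras.callbacks.EarlyStopping(monitor="val_loss", patience=...)`. The loss values are made up for illustration.

```python
def best_epoch(val_losses, patience=2):
    """Return the 1-based epoch with the lowest validation loss, stopping once
    the loss has failed to improve for `patience` consecutive epochs."""
    best_loss = float("inf")
    best = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # Validation loss stopped improving: likely overfitting.
    return best


# Hypothetical validation losses: improving at first, then rising again.
val_losses = [3.2, 2.6, 2.3, 2.2, 2.25, 2.4, 2.6]
print(best_epoch(val_losses))  # 4
```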