# Lab: Compare N-Gram Models and Transformer Language Models
## Purpose:
- Introduction to transformer models
- Compare n-gram models with transformer models

### Topics:
- Transformer model (Gemma-1B)
- Probability distribution
- Token prediction

Date: 2026-02-18

Source: https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/gdm_lab_1_3_compare_n_gram_models_and_transformer_language_models.ipynb

References: https://github.com/google-deepmind/ai-foundations
- Google DeepMind GitHub repository used in university- and college-level AI training courses.

Lab evaluation criteria:

- **Fluency**: Does it read naturally? Consider grammar, punctuation, and sentence length.
- **Coherence**: Does it make logical sense and stay on topic, or does it ramble? Could it have been produced by a human? Because language models predict one token at a time, the end of a generation may be about a different topic than its beginning.
- **Relevance**: Does it fit the context or prompt?
- **Bias**: Does the output promote inequalities? Language models are trained on human-written data that likely includes biases and stereotypes, so a model's generations may contain strongly stereotypical content.

```python
%%capture
!pip install orbax-checkpoint==0.11.21 jax[cuda12]==0.7.2
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

# Packages used.
import os  # For setting a variable needed to load the model onto the GPU.

# Let JAX use all of the GPU memory. Set this before JAX initializes the GPU.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import pandas as pd  # For loading the Africa Galore dataset.

# Functions for clearing outputs and formatting.
from IPython.display import clear_output, display, HTML

# Functions for generating texts with a language model, visualizing probability
# distributions, and loading an n-gram model.
from ai_foundations import generation
from ai_foundations import visualizations
from ai_foundations.ngram import model as ngram_model
```

### Load the models
Reuse the Africa Galore dataset to build a trigram model, then load Gemma-1B for comparison.

```python
# Load the Africa Galore dataset.
africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
dataset = africa_galore["description"]
print(f"Loaded Africa Galore dataset with {len(dataset)} paragraphs.\n")

# Build a trigram model (n = 3) from the dataset.
trigram_model = ngram_model.NGramModel(dataset, 3)
print("Loaded trigram model.\n")

print("Loading Gemma-1B model...")
gemma_model = generation.load_gemma()
print("Loaded Gemma-1B model.")
```
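The trigram model predicts each word from only the two words before it. The lab does not show how `ngram_model.NGramModel` is built, so here is a minimal, hypothetical sketch using maximum-likelihood counts, P(w3 | w1, w2) = count(w1, w2, w3) / count(w1, w2). The sketch makes two assumptions:

- Tokens are split on single spaces with `split(" ")`, as the lab's own cells do.
- Probabilities are stored in a dictionary keyed by two-word tuples, mirroring how the lab indexes `trigram_model.probabilities[context_ngram]`.

`build_trigram_probabilities` and the toy corpus are illustrative only.

```python
from collections import Counter, defaultdict


def build_trigram_probabilities(paragraphs):
    """Map each two-word context to a distribution over the next word.

    Hypothetical sketch (not the ai_foundations implementation): splits on
    single spaces and uses maximum-likelihood estimates.
    """
    counts = defaultdict(Counter)
    for paragraph in paragraphs:
        words = paragraph.split(" ")
        # Slide a three-word window over the paragraph.
        for w1, w2, w3 in zip(words, words[1:], words[2:]):
            counts[(w1, w2)][w3] += 1

    probabilities = {}
    for context, next_counts in counts.items():
        total = sum(next_counts.values())
        # Normalize the counts for this context into probabilities.
        probabilities[context] = {
            word: count / total for word, count in next_counts.items()
        }
    return probabilities


toy_corpus = [
    "she went looking for a snack",
    "he went looking for a book",
    "they went looking for water",
]
probs = build_trigram_probabilities(toy_corpus)
print(probs[("looking", "for")])  # {'a': 0.6666666666666666, 'water': 0.3333333333333333}
print(probs.get(("looking", "at")))  # None: this context never occurred
```

A context that never appeared in the training text has no entry at all. This is why the lab's visualization cell falls back to a message when the bigram is missing.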

```python
# @title Compute the next token for a prompt
# Test the models.
prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

# Test Gemma-1B.
output_text_transformer, _, _ = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

clear_output()
print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

# Test the trigram model.
output_text_ngram = trigram_model.generate(1, prompt)
print(f"Generation by trigram model:\n{output_text_ngram}")
```

Expected output:
```
Generation by Gemma-1B:
Jide was hungry so she went looking for a

Generation by trigram model:
Jide was hungry so she went looking for a
```

### Visualize the probability distribution
The `ai_foundations` package has a handy `visualizations` module built on `matplotlib.pyplot`, `numpy`, `jax`, and `pandas`.

**Observations**
- Compared with the trigram model, Gemma-1B assigns probabilities to
    - a larger variety of words,
    - across a wider range of probability values,
    - with more fine-grained differences between them.
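Gemma-1B returns `next_token_logits`, which contains one raw score for every token in its vocabulary. Applying a softmax turns these scores into probabilities that are all greater than zero. A trigram model, by contrast, only has entries for words it actually saw after the context.

The NumPy sketch below uses a hypothetical five-token vocabulary and made-up logits; the real vocabulary is far larger. It assumes, but does not confirm, that `visualizations.plot_next_token` performs a similar softmax internally.

```python
import numpy as np


def softmax(logits):
    """Convert raw scores (logits) into a probability distribution."""
    # Subtracting the max avoids overflow and does not change the result.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


# Hypothetical logits over a tiny, made-up vocabulary.
vocab = ["a", "food", "something", "water", "the"]
logits = np.array([3.0, 2.5, 2.0, -1.0, 0.5])
transformer_probs = softmax(logits)

# A count-based trigram model only stores continuations it observed.
trigram_probs = {"a": 2 / 3, "water": 1 / 3}

for token, p in zip(vocab, transformer_probs):
    print(f"{token:>10}  transformer={p:.4f}  trigram={trigram_probs.get(token, 0.0):.4f}")
```

Every token gets some transformer probability, even if tiny. The trigram model gives exactly zero to anything it never counted.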

```python
# @title Visualize the probability distributions

prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

display(HTML("<h3>Gemma-1B</h3>"))

# Visualize the Gemma-1B probabilities.
visualizations.plot_next_token(
    next_token_logits,
    prompt=prompt,
    tokenizer=tokenizer
)

display(HTML("<h3>Trigram model</h3>"))

# Visualize the trigram probabilities.
context_ngram = tuple(prompt.split(" ")[-2:])
if context_ngram in trigram_model.probabilities:
    visualizations.plot_next_token(
        trigram_model.probabilities[context_ngram], prompt=prompt
    )
else:
    print(
        "The trigram model does not make any predictions for the prompt"
        f" \"{prompt}\" since the bigram \"{' '.join(context_ngram)}\""
        f" is not part of the dataset."
    )
```

### Context sensitivity
What happens when the context changes? Replace "hungry" with "thirsty" and compare the predictions again.

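**Observations**
- Gemma-1B assigns probabilities to
    - an even larger variety of words,
    - across a wider range of probability values,
    - but most of these probabilities are vanishingly small.

- The trigram probability distribution did not change. The trigram model conditions only on the last two words, "looking for", which are identical in both prompts.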
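The trigram model's unchanged output follows from the lab's own context extraction, `tuple(prompt.split(" ")[-2:])`: both prompts end in the same two words, so they look up the same entry. The helper name `trigram_context` is introduced here for illustration only.

```python
def trigram_context(prompt):
    """Return the last two space-separated words, as the lab's cells do."""
    return tuple(prompt.split(" ")[-2:])


hungry = "Jide was hungry so she went looking for"
thirsty = "Jide was thirsty so she went looking for"

print(trigram_context(hungry))   # ('looking', 'for')
print(trigram_context(thirsty))  # ('looking', 'for')
print(trigram_context(hungry) == trigram_context(thirsty))  # True
```

The word that differs, "hungry" versus "thirsty", lies outside the two-word window, so the trigram model never sees it.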

```python
prompt = "Jide was thirsty so she went looking for"  # @param {type: "string"}

output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

output_text_ngram = trigram_model.generate(1, prompt)

clear_output()

print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")
print(f"Generation by trigram model:\n{output_text_ngram}")

display(HTML("<h3>Gemma-1B</h3>"))

# Visualize the Gemma-1B probabilities.
visualizations.plot_next_token(next_token_logits, prompt=prompt, tokenizer=tokenizer)

display(HTML("<h3>Trigram model</h3>"))

# Visualize the trigram probabilities.
context_ngram = tuple(prompt.split(" ")[-2:])
if context_ngram in trigram_model.probabilities:
    visualizations.plot_next_token(
        trigram_model.probabilities[context_ngram], prompt=prompt
    )
else:
    print(
        "The trigram model does not make any predictions for the prompt"
        f" \"{prompt}\" since the bigram \"{' '.join(context_ngram)}\""
        f" is not part of the dataset."
    )
```

### Generating sequences

**Observations**
- Gemma-1B
    - is fluent,
    - is coherent,
    - stays relevant to the prompt,
    - and any bias becomes more noticeable over long responses, because the text stays coherent long enough to develop it.

- The trigram model
    - loses fluency and coherence quickly,
    - because its context window is only two words.
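The trigram model's short memory shows up in the generation loop itself. Every step looks at only the last two words, so the earlier topic (here, hunger) has no influence on what comes next.

This is a hypothetical sketch, not the implementation of `trigram_model.generate`. `generate_trigram` and the hand-written `toy_probabilities` table are illustrative, and the sketch simply stops when it reaches an unseen context.

```python
import random


def generate_trigram(probabilities, prompt, num_words, seed=0):
    """Extend `prompt` by sampling one word at a time from P(w | last two words).

    Hypothetical sketch: `probabilities` maps (w1, w2) tuples to
    {next_word: probability} dictionaries.
    """
    rng = random.Random(seed)
    words = prompt.split(" ")
    for _ in range(num_words):
        context = tuple(words[-2:])  # Everything before these two words is ignored.
        next_dist = probabilities.get(context)
        if next_dist is None:  # Unseen context: this sketch cannot continue.
            break
        candidates = list(next_dist)
        weights = [next_dist[w] for w in candidates]
        words.append(rng.choices(candidates, weights=weights)[0])
    return " ".join(words)


# Hypothetical table: after "a book", the text drifts away from food.
toy_probabilities = {
    ("looking", "for"): {"a": 1.0},
    ("for", "a"): {"snack": 0.5, "book": 0.5},
    ("a", "book"): {"about": 1.0},
    ("book", "about"): {"water": 1.0},
}
example_prompt = "Jide was hungry so she went looking for"
print(generate_trigram(toy_probabilities, example_prompt, 5))
```

Depending on the sampled word, the output ends either at "a snack" or at "a book about water". The second ending loses the hunger topic entirely.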

Sample output:
```
Generation by Gemma-1B:
Jide was hungry so she went looking for something to eat, when she saw a bag she grabbed it and started eating it while eating the bag she was thirsty and started drinking the water in the bag, as she was drinking she noticed a big rat had entered the bag and started eating the water

Generation by trigram model:
Jide was hungry so she went looking for a host of uniquely adapted endemic species, such as meat, poultry, fish, beans, and nuts, supply essential amino acids for building and repairing tissues. Healthy fats, found in avocados, nuts, and olive oil, provide essential vitamins, minerals, and fiber. Grains, particularly whole grains, provide carbohydrates and fiber. Protein foods, such
```

```python
# @title Generate sequences
# Hint: greedy mode always yields the same result.
prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

num_tokens_to_generate = 50  # @param {type: "number"}

output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=num_tokens_to_generate, loaded_model=gemma_model
    )
)

clear_output()

print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

output_text_ngram = trigram_model.generate(num_tokens_to_generate, prompt)
print(f"Generation by trigram model:\n{output_text_ngram}")
```
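The hint in the cell above says greedy mode always yields the same result. The lab does not show how `generation.prompt_transformer_model` chooses tokens, so the following is only a sketch of the general idea, using a hypothetical `pick_next_token` helper and made-up logits:

- Greedy decoding always takes the highest-scoring token (the argmax), so the same input gives the same output.
- Sampling draws from the softmax distribution, so results can vary from run to run.

```python
import numpy as np


def pick_next_token(logits, greedy, temperature=1.0, rng=None):
    """Choose a token index from logits, greedily or by sampling (hypothetical helper)."""
    if greedy:
        # Greedy decoding: always the single highest-scoring token.
        return int(np.argmax(logits))
    if rng is None:
        rng = np.random.default_rng()
    # Sampling: softmax over temperature-scaled logits, then draw one index.
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))


logits = np.array([2.0, 1.5, 0.2, -1.0])  # Made-up scores for four tokens.
greedy_picks = {pick_next_token(logits, greedy=True) for _ in range(10)}
sampled_picks = {
    pick_next_token(logits, greedy=False, rng=np.random.default_rng(seed))
    for seed in range(50)
}
print("greedy:", greedy_picks)  # {0}
print("sampled:", sampled_picks)
```

Greedy decoding repeats itself exactly, which makes comparisons reproducible. Sampling trades that reproducibility for more varied text.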