<img src="https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C6-white-bg.png">

# Lab: Compare Models of Different Sizes

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_7/gdm_lab_7_1_compare_models_of_different_sizes.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

Explore generations for models of different sizes to understand trade-offs between model performance and efficiency.

15 minutes

## Overview
This activity explores how generations vary between models of different sizes, both in terms of performance and efficiency. It should help you to gain an intuition for the trade-offs you will need to consider when deciding which model to use.

### What you will learn:

By the end of this lab, you will be able to:

* Describe how generations differ between models of different sizes.
* Recall the factors to consider when choosing a model.

### Tasks

**In this lab, you will**:
* Write code to inspect generations for the Gemma-1B and Gemma-4B transformer models.
* Compare the generation times and outputs for each model to gain a better understanding of the trade-offs between smaller and larger models.

**This lab needs to be run on a GPU. Choose a T4 GPU.** See the section "How to use Google Colaboratory (Colab)" below for instructions on how to do this.

## How to use Google Colaboratory (Colab)

Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in **cells** that are executed on a remote server.

To run a cell, hover over a cell and click on the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can click on a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime
print(f"Today is {datetime.today():%A}.")

Note that the **order in which you run the cells matters**. When you are working through a lab, always run all cells in order; otherwise, the code might not work. If you take a break while working on a lab, Colab may disconnect you. In that case, you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose _Runtime → Run before_ from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

### Using Colab with a GPU

Follow these steps to run the activities in this lab on a GPU:

1.  In the top menu bar, click on **Runtime**.
2.  Select **Change runtime type** from the dropdown menu.
3.  In the pop-up window under **Hardware accelerator**, select **T4 GPU**.
4.  Click **Save**.

Your Colab session will now restart with GPU access.

Note that access to GPUs is limited, and at times you may not be able to run this lab on a GPU. All activities will still work, but they will run more slowly, and you will have to wait longer for some of the cells to finish running.

## Imports

In this lab, you will primarily interact with the `ai_foundations` package, which has been specifically developed for this course. In the background, this package uses the [`gemma`](https://github.com/google-deepmind/gemma) package to load and prompt the Gemma-1B and Gemma-4B models.


In [None]:
%%capture
# Install the custom package for this course.
!pip install orbax-checkpoint==0.11.21
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import os  # For setting system variables.
import time  # For timing how long generations take.

# Let JAX use all of the GPU memory. Set this before importing any module that
# may initialize JAX, because the value is read when the GPU backend starts.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

# For formatting the model generations.
from IPython.display import display, HTML
from ai_foundations import generation  # For generating texts with Gemma.

## Generate outputs with Gemma-1B

As preparation for comparing the generations between the Gemma-1B and Gemma-4B models, run the following cell. This will load the Gemma-1B model and generate an output for the prompt "Jide was hungry so she went looking for", which you encountered in previous courses in the curriculum. Note that this is the original pre-trained variant of Gemma-1B, not a version of Gemma that has been instruction-tuned and optimized for dialog.

In [None]:
# Load the Gemma-1B model.
print("Loading Gemma-1B model...")
gemma_1b_model = generation.load_gemma(model_name="Gemma-1B")
print("Loaded Gemma-1B model.\n\n")

# Generate an output.
prompt = "Jide was hungry so she went looking for"
output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=10, loaded_model=gemma_1b_model
    )
)

print(f"Generation by Gemma-1B:\n{output_text_transformer}")
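The coding activities below ask you to set `sampling_mode="greedy"`. Greedy decoding picks the single highest-scoring token at every step, so the same model and prompt always produce the same output. This makes the comparison between Gemma-1B and Gemma-4B reflect the models themselves rather than random chance. The following is a minimal NumPy sketch of the idea on a toy example; the vocabulary, scores, and function names are hypothetical and are not the `ai_foundations` implementation.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert raw next-token scores into probabilities that sum to 1."""
    shifted = logits - logits.max()  # Subtract the max for numerical stability.
    exp = np.exp(shifted)
    return exp / exp.sum()


def greedy_next_token(logits: np.ndarray) -> int:
    """Greedy decoding: always pick the index of the highest-scoring token."""
    return int(np.argmax(logits))


def sample_next_token(logits: np.ndarray, rng: np.random.Generator) -> int:
    """Random sampling: draw a token index in proportion to its probability."""
    probs = softmax(logits)
    return int(rng.choice(len(probs), p=probs))


# Hypothetical toy vocabulary and next-token scores, for illustration only.
vocab = ["food", "a", "her", "something", "snacks"]
logits = np.array([2.5, 1.0, 0.3, 1.8, 0.9])

# Greedy decoding returns the same token on every call.
print([vocab[greedy_next_token(logits)] for _ in range(3)])

# Sampling can return a different token on each call.
rng = np.random.default_rng(seed=0)
print([vocab[sample_next_token(logits, rng)] for _ in range(3)])
```

Running the greedy line repeatedly always yields `food`, while the sampled tokens vary with the random seed.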

### Coding Activity 1: Prompt Gemma-1B

<br>

------
> 💻 **Your task**:
>
> Complete the code in the following cell to:
> 1. Loop through the list of prompts.
> 2. For each prompt, generate an output with Gemma-1B.
> 3. Add the output to the list `generations_gemma_1b`.
>
> When generating outputs:
> - Set the `max_new_tokens` argument to `50`.
> - Set the `sampling_mode` argument to `"greedy"`.
>
> Once you have implemented the code, run the following cell to generate outputs. It will also print the time taken to generate outputs for all ten prompts.
------

In [None]:
list_of_prompts = [
    "The tallest mountain in Africa is",
    "The key ingredients and preparation method for making traditional Nigerian Jollof rice are",
    "Wild coffee plants originated in",
    "The national flower of South Africa is",
    "The significance of the city of Marrakesh is",
    "The most populous African nation is",
    "The world's deepest river is",
    "The Maasai ethnic group originate from ",
    "The African proverb 'A roaring lion kills no game' means",
    "Braai is",
]

generations_gemma_1b = []

before_loop = time.time_ns()

# Add your code here.

after_loop = time.time_ns()

# Calculate duration in nanoseconds.
duration_ns = after_loop - before_loop

# Convert to total seconds.
total_seconds = duration_ns / 1_000_000_000

# Calculate minutes and remaining seconds.
minutes = int(total_seconds // 60)
seconds = total_seconds % 60

if minutes > 0:
    minutes_str = "minute" if minutes == 1 else "minutes"
    print(f"Duration: {minutes} {minutes_str} {seconds:.2f} seconds")
else:
    print(f"Duration: {seconds:.2f} seconds")

Now inspect the generations by running the following cell.

In [None]:
for i in range(len(list_of_prompts)):
    display(HTML(f"<h3>Prompt:</h3><p>{list_of_prompts[i]}</p>"))
    display(
        HTML(
            f"<blockquote><b>Gemma-1B generation:</b><br>{generations_gemma_1b[i]}</blockquote>"
        )
    )
    display(HTML("<hr>"))
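The timing code in the Coding Activity 1 cell measures elapsed time with `time.time_ns()` and formats the result inline, and the Coding Activity 2 cell repeats the same logic. As a sketch, that logic could be factored into a reusable helper. `format_duration` and `time_block` are hypothetical names introduced here. The sketch also swaps in `time.perf_counter()`, a monotonic clock intended for measuring elapsed time, whereas `time.time_ns()` reads the system clock, which can be adjusted while code runs.

```python
import time
from contextlib import contextmanager


def format_duration(total_seconds: float) -> str:
    """Format seconds as 'M minute(s) S.SS seconds', or 'S.SS seconds' under a minute."""
    minutes = int(total_seconds // 60)
    seconds = total_seconds % 60
    if minutes == 0:
        return f"{seconds:.2f} seconds"
    minutes_str = "minute" if minutes == 1 else "minutes"
    return f"{minutes} {minutes_str} {seconds:.2f} seconds"


@contextmanager
def time_block(label: str):
    """Print how long the code inside the `with` block took to run."""
    start = time.perf_counter()  # Monotonic, high-resolution clock.
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"{label}: {format_duration(elapsed)}")


# Example usage with a stand-in workload instead of model generation.
with time_block("Duration"):
    time.sleep(0.1)
```

In an activity cell, you could wrap the generation loop in `with time_block("Duration"):` instead of recording `before_loop` and `after_loop` by hand.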

## Generate outputs with Gemma-4B

Run the following cell to load the Gemma-4B model and generate an output for the prompt "Jide was hungry so she went looking for".

In [None]:
# Delete the Gemma-1B model to free up memory.
del gemma_1b_model

# Load the Gemma-4B model.
print("Loading Gemma-4B model...")
gemma_4b_model = generation.load_gemma(model_name="Gemma-4B")
print("Loaded Gemma-4B model.\n\n")

# Generate an output.
prompt = "Jide was hungry so she went looking for"
output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=10, loaded_model=gemma_4b_model
    )
)

print(f"Generation by Gemma-4B:\n{output_text_transformer}")
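The cell above deletes Gemma-1B before loading Gemma-4B. A back-of-the-envelope estimate shows why memory, and also speed, scale with model size. This sketch assumes the nominal parameter counts in the model names and 2 bytes per parameter (16-bit weights such as bfloat16); the lab does not state which precision `ai_foundations` uses. It ignores activations, caches, and framework overhead. It also uses a common rule of thumb that a transformer forward pass costs roughly 2 FLOPs per parameter per generated token. The helper names are hypothetical.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Rough memory (in GB) needed just to store a model's weights."""
    return num_params * bytes_per_param / 1e9


def forward_flops_per_token(num_params: float) -> float:
    """Rule-of-thumb compute per generated token: about 2 FLOPs per parameter."""
    return 2 * num_params


T4_MEMORY_GB = 16  # Total memory of an NVIDIA T4 GPU.

# Nominal parameter counts taken from the model names (an assumption).
for name, num_params in [("Gemma-1B", 1e9), ("Gemma-4B", 4e9)]:
    mem_gb = weight_memory_gb(num_params)
    print(
        f"{name}: ~{mem_gb:.0f} GB of weights "
        f"({mem_gb / T4_MEMORY_GB:.0%} of a T4), "
        f"~{forward_flops_per_token(num_params):.0e} FLOPs per token"
    )
```

Under these assumptions, Gemma-4B needs about four times the weight memory and about four times the compute per generated token of Gemma-1B. Real-world speed also depends on factors such as memory bandwidth, so treat this only as a rough guide.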

### Coding Activity 2: Prompt Gemma-4B

------
> 💻 **Your task**:
>
> Complete the code in the following cell to generate an output with the Gemma-4B model for each prompt, this time adding the outputs to the list `generations_gemma_4b`.
>
> When generating outputs:
> - Set the `max_new_tokens` argument to `50`.
> - Set the `sampling_mode` argument to `"greedy"`.
>
> Once you have implemented the code, run the following cell to generate outputs.
------

In [None]:
list_of_prompts = [
    "The tallest mountain in Africa is",
    "The key ingredients and preparation method for making traditional Nigerian Jollof rice are",
    "Wild coffee plants originated in",
    "The national flower of South Africa is",
    "The significance of the city of Marrakesh is",
    "The most populous African nation is",
    "The world's deepest river is",
    "The Maasai ethnic group originate from ",
    "The African proverb 'A roaring lion kills no game' means",
    "Braai is",
]

generations_gemma_4b = []

before_loop = time.time_ns()

# Add your code here.

after_loop = time.time_ns()

# Calculate duration in nanoseconds.
duration_ns = after_loop - before_loop

# Convert to total seconds.
total_seconds = duration_ns / 1_000_000_000

# Calculate minutes and remaining seconds.
minutes = int(total_seconds // 60)
seconds = total_seconds % 60

if minutes > 0:
    minutes_str = "minute" if minutes == 1 else "minutes"
    print(f"Duration: {minutes} {minutes_str} {seconds:.2f} seconds")
else:
    print(f"Duration: {seconds:.2f} seconds")

Inspect the generations by running the following cell.

In [None]:
for i in range(len(list_of_prompts)):
    display(HTML(f"<h3>Prompt:</h3><p>{list_of_prompts[i]}</p>"))
    display(
        HTML(
            f"<blockquote><b>Gemma-4B generation:</b><br>{generations_gemma_4b[i]}</blockquote>"
        )
    )
    display(HTML("<hr>"))

## Compare the generations

Now that you have the outputs from both models, compare and evaluate the generations.

**Speed**: First, compare how long each model took to generate the outputs. Which model was faster?

**Quality**: Then, review the outputs for each prompt. Which model generally provides more accurate, detailed, and coherent answers?

Run the next cell to display the generations side by side.

In [None]:
# This cell will display the outputs from both models for each prompt.
for i in range(len(list_of_prompts)):
    display(HTML(f"<h3>Prompt:</h3><p>{list_of_prompts[i]}</p>"))
    display(
        HTML(
            f"<blockquote><b>Gemma-1B generation:</b>"
            f"<br>{generations_gemma_1b[i]}</blockquote>"
        )
    )
    display(
        HTML(
            f"<blockquote><b>Gemma-4B generation:</b>"
            f"<br>{generations_gemma_4b[i]}</blockquote>"
        )
    )
    display(HTML("<hr>"))
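To make the speed comparison concrete, you can turn the two durations printed by the timing cells into a slowdown factor and an approximate throughput. The helper `summarize_speed` below is a hypothetical sketch. It assumes each of the 10 prompts produced the full 50 new tokens. If generation stops early for some prompts, fewer tokens were actually produced, and the tokens-per-second figures will overestimate the true throughput.

```python
def summarize_speed(
    duration_small_s: float,
    duration_large_s: float,
    num_prompts: int = 10,
    max_new_tokens: int = 50,
) -> dict:
    """Compare two timing runs that generated outputs for the same prompts."""
    # Upper bound on generated tokens: every prompt used all of max_new_tokens.
    max_tokens = num_prompts * max_new_tokens
    return {
        # How many times longer the larger model took.
        "slowdown": duration_large_s / duration_small_s,
        # Approximate (upper-bound) tokens generated per second by each model.
        "small_tokens_per_s": max_tokens / duration_small_s,
        "large_tokens_per_s": max_tokens / duration_large_s,
    }
```

Pass in the Gemma-1B and Gemma-4B durations in seconds, as printed by the two timing cells.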

### What did you observe?
This exercise highlights the central trade-off in choosing a model. Take a moment to reflect on what you have observed.

While considerably slower, the Gemma-4B model tends to be better at factual recall and generating more nuanced text for some of the prompts. For example, you likely observed that the Gemma-4B model provided a better response for the prompt about the meaning of the proverb, and more factually accurate statements about the origin of coffee and the world's deepest river. Such a model, therefore, is generally well-suited for tasks where quality, accuracy, and detail are paramount.

Conversely, the smaller Gemma-1B model may be a better choice for applications where speed is critical and "good-enough" quality may be acceptable, such as in a real-time chatbot.
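One way to make this trade-off explicit is to choose the highest-quality model that still fits a latency budget. The sketch below is a hypothetical helper: `Candidate` and `choose_model` are names introduced here, not part of `ai_foundations`. The latency would come from the timing cells, and the quality score is assumed to come from your own judgment of the generations.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Candidate:
    """A model option described by measurements you collected yourself."""

    name: str
    seconds_per_prompt: float  # Average generation time per prompt.
    quality_score: float  # E.g. fraction of answers you judged accurate.


def choose_model(
    candidates: Sequence[Candidate], latency_budget_s: float
) -> Optional[Candidate]:
    """Return the highest-quality candidate within the latency budget, or None."""
    fast_enough = [
        c for c in candidates if c.seconds_per_prompt <= latency_budget_s
    ]
    if not fast_enough:
        return None  # No candidate meets the latency budget.
    return max(fast_enough, key=lambda c: c.quality_score)
```

With a tight budget, such as a real-time chatbot, only the smaller model may qualify. With a generous budget, the larger, more accurate model wins.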

### Summary
This lab explored the practical trade-offs between language models of different sizes. The following highlights the key takeaways:

1. **Performance versus efficiency**: Larger models like Gemma-4B generally produce higher-quality, more factual, and more coherent text. However, this comes at the cost of being significantly slower.

2. **Task-dependent model choice**: Smaller models like Gemma-1B are much faster, making them suitable for applications where speed and efficiency are critical. The "best" model is always a balance between the performance you need and the resources you have.

3. **Intuition for scaling**: You have now experienced in practice how scaling up a model's parameter count affects both its capabilities and its efficiency.

## Solutions

The following cells provide reference solutions to the coding activities in this notebook. If you really get stuck after trying to solve the activities yourself, you may want to consult these solutions.


It is recommended that you *only* look at the solutions after you have tried to solve the activities *multiple times*. The best way to learn challenging concepts in computer science and artificial intelligence is to debug your code piece-by-piece until it works, rather than copying existing solutions.


If you feel stuck, first try to debug your code, for example by adding print statements to see what your code is doing at every step. This will give you a much deeper understanding of the code and the materials. It will also give you practice in solving challenging coding problems beyond this course.


To view the solutions for an activity, click on the arrow to the left of the activity name. If you consult the solutions, do not copy and paste them into the cells above. Instead, look at them, and type them manually into the cell. This will help you understand where you went wrong.


### Coding Activity 1

In [None]:
# Add this code to the cell above.
for prompt in list_of_prompts:
    output_text_transformer, next_token_logits, tokenizer = (
        generation.prompt_transformer_model(
            prompt,
            max_new_tokens=50,
            loaded_model=gemma_1b_model,
            sampling_mode="greedy",
        )
    )
    generations_gemma_1b.append(output_text_transformer)

### Coding Activity 2

In [None]:
# Add this code to the cell above.
for prompt in list_of_prompts:
    output_text_transformer, next_token_logits, tokenizer = (
        generation.prompt_transformer_model(
            prompt,
            max_new_tokens=50,
            loaded_model=gemma_4b_model,
            sampling_mode="greedy",
        )
    )
    generations_gemma_4b.append(output_text_transformer)