<img src="https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C6-white-bg.png">

# Lab: Fine-Tune a Model with bfloat16

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_7/gdm_lab_7_5_fine_tune_a_model_with_bfloat16.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

Learn how to successfully train a large model on a single GPU using the bfloat16 data type.

25 minutes

## Overview

In the previous fine-tuning lab "Hitting a Wall," you encountered an out-of-memory error when attempting to load Gemma-4B on a standard 16GB GPU.

In this lab, you will learn how to load the model parameters using the bfloat16 data type. Combining this technique with LoRA, you can then fine-tune a 4B model on the GPU that is available to you.

### What you will learn:

By the end of this lab, you will be able to:

* Load model parameters in bfloat16 using Keras.
* Explain how bfloat16 significantly reduces the memory footprint of a model.
* Fine-tune a large model on a resource-constrained GPU.

### Tasks

**In this lab, you will**:

* Load the Gemma-4B model using bfloat16 precision.
* Fine-tune the model to generate flashcards, as in course 05 "Fine-tune Your Model."
* Reflect on the performance of the fine-tuned model.

**This lab needs to be run on a GPU. Choose a T4 GPU.** See the section "How to use Google Colaboratory (Colab)" below for instructions on how to do this.

**You also need a Kaggle account** to download the weights of the Gemma 3 model. See the section "Set up a Kaggle account" for instructions on how to do this.


## How to use Google Colaboratory (Colab)

Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in **cells** that are executed on a remote server.

To run a cell, hover over a cell and click on the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can also click on a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime
print(f"Today is {datetime.today():%A}.")

Note that the **order in which you run the cells matters**. When you are working through a lab, make sure to always run all cells in order, otherwise the code might not work. If you take a break while working on a lab, Colab may disconnect you, and in that case, you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose _Runtime → Run before_ from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

### Using Colab with a GPU

Follow these steps to run the activities in this lab on a GPU:

1.  In the top menu bar, click on **Runtime**.
2.  Select **Change runtime type** from the dropdown menu.
3.  In the pop-up window under **Hardware accelerator**, select **T4 GPU**.
4.  Click **Save**.

Your Colab session will now restart with GPU access.

Note that access to GPUs is limited, and at times you may not be able to run this lab on a GPU. All activities will still work, but they will run more slowly and you will have to wait longer for some of the cells to finish running.

## Set up a Kaggle account



To run this notebook, you will have to sign up for [Kaggle](https://www.kaggle.com), a platform that hosts datasets and models for machine learning. You will also need to sign the agreement for using the Gemma 3 model. This is required so that you can download the weights of the Gemma model for fine-tuning.

### Step 1: Create your Kaggle account

* Go to the Kaggle website: https://www.kaggle.com

* Click the "Register" button in the top-right corner.

* You can sign up using your Google account (recommended for easy Colab integration) or by entering an email and password.

* Follow the on-screen prompts to complete your registration and verify your email.

### Step 2: Sign the Gemma 3 model agreement

* Make sure you are logged into your new Kaggle account.

* Go directly to the Gemma 3 model card page: https://www.kaggle.com/models/keras/gemma3/keras/

* You should see a "Request Access" button.

* Click the button and read through the license agreement. If you agree to its terms, click "Accept" to gain access to the model. You must do this before the API will let you download the model.

### Step 3: Generate your Kaggle API key

* From any Kaggle page, click on your profile picture or icon in the top-right corner.

* Select "Account" from the drop-down menu.

* Scroll down to the "API" section.

* Click the "Create New API Token" button.

* This will immediately download a file named `kaggle.json` to your computer. This file contains your username and your secret API key. Keep it safe.

### Step 4: Set your API key in Colab

* Click the "key" icon 🔑 in the left-hand sidebar.

* You will see the "Secrets" panel.

* Now, open the `kaggle.json` file you downloaded on your computer. It's a simple text file and will look like this:

   ```json
   {"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_API_KEY"}
   ```
* In the Colab Secrets panel, create two new secrets:

   1. Name: `KAGGLE_USERNAME`

      Value: Copy and paste `YOUR_KAGGLE_USERNAME` from your `kaggle.json` file.

   2. Name: `KAGGLE_KEY`

      Value: Copy and paste `YOUR_KAGGLE_API_KEY` from your `kaggle.json` file.

* For both secrets, make sure the "Notebook access" toggle is switched on.


## Imports

In this lab, you will use the Keras package to load and fine-tune a Gemma model, and the Pandas package to load the dataset.

Run the following cell to import the required packages.


In [None]:
%%capture
# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import os # For setting system variables.

from google.colab import userdata # For using Colab secrets.

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

os.environ["KERAS_BACKEND"] = "jax"  # Set the Keras backend to JAX.

# Disable XLA's GPU command buffers to free up memory.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="
# Let JAX preallocate up to 90% of the GPU memory (the default is 75%).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import keras # For defining and training models.
import keras_nlp # For loading the Keras implementation of Gemma.
import pandas as pd # For loading the dataset.
from textwrap import fill # For formatting long paragraphs.
from ai_foundations import formatting # For formatting the training data.

keras.utils.set_random_seed(812)  # Seed Python, NumPy, and backend RNGs for reproducibility.

### Load and process data

As in previous labs, the following cell defines a function `format_question` that formats a prompt. It also loads the Africa Galore QA dataset and formats each question-answer pair so that the data can be used to fine-tune a model to generate flashcard answers.

In [None]:
def format_question(
    question: str,
    sot: str = "<start_of_turn>",
    eot: str = "<end_of_turn>"
) -> str:
    """
    Formats a question for prompting the model and adds special delimiters at
    the start and end of the question.

    Args:
      question: The question to be formatted.
      sot: The token to mark the start of a turn.
      eot: The token to mark the end of a turn.

    Returns:
      Formatted string of the question.
    """

    formatted_q = f"{sot}user\n{question}{eot}\n"

    return formatted_q

# Load the question-answer dataset.
africa_galore_qa = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore_qa_v2.json"
)

questions = []  # List of formatted questions.
answers = []  # List of formatted answers.

for _, row in africa_galore_qa.iterrows():
    # Use the format_qa function from the previous lab to format the question
    # and the answer.
    question, answer = formatting.format_qa(row)
    questions.append(question)
    answers.append(answer)

# Show the first set of input and output.
print(questions[0])
print(fill(answers[0], replace_whitespace=False))

# Prepare the data dictionary for fine-tuning Gemma.
data = {
    "prompts": questions,
    "responses": answers
}

## Coding Activity 1: Load Gemma-4B parameters

In the previous fine-tuning lab "Hitting a Wall," the attempt to load Gemma-4B failed because the model's parameters (stored in the default 32-bit format) were too large to fit in the GPU's memory.

The solution is to load the model using a more memory-efficient number format. You will use bfloat16, which uses 16 bits instead of 32, effectively halving the memory required for the model's weights.
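A quick back-of-the-envelope calculation makes the halving concrete. The sketch below assumes a nominal 4 billion parameters (the exact count of the `gemma3_4b_text` preset differs somewhat) and counts only the memory for the weights themselves, not activations or other runtime overhead. The helper `weight_memory_gb` is hypothetical and only for illustration.

```python
# Approximate memory needed just to hold the model weights.
# Assumption: a nominal 4 billion parameters; the real preset differs somewhat.
NUM_PARAMS = 4_000_000_000

BYTES_PER_VALUE = {
    "float32": 4,   # 32 bits = 4 bytes per parameter.
    "bfloat16": 2,  # 16 bits = 2 bytes per parameter.
}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Returns the memory in GB (1 GB = 10**9 bytes) needed to store the weights."""
    return num_params * BYTES_PER_VALUE[dtype] / 1e9


for dtype in BYTES_PER_VALUE:
    print(f"{dtype}: {weight_memory_gb(NUM_PARAMS, dtype):.0f} GB")
```

Under this assumption, the float32 weights alone need about 16 GB, which already matches the total memory of a 16GB T4 GPU before anything else is counted, while the bfloat16 weights need about 8 GB.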

<br>

------
> 💻 **Your task**:
>
> In the following cell, load the Gemma-4B model parameters in the `bfloat16` format.
>
> In Keras, you can load the parameters as `bfloat16` as follows:
>
> ```python
> model = keras_nlp.models.Gemma3CausalLM.from_preset(preset=<MODEL_NAME>, dtype="bfloat16")
> ```
>
> For the Gemma-4B model, the `<MODEL_NAME>` value is `"gemma3_4b_text"`.
------

<br>

It should take 2-3 minutes to load the model.

In [None]:
# Load the Gemma3-4B Keras model with bfloat16 precision.
model = ... # Add your code here.
model.summary()

Now that the model has been loaded with bfloat16, run the following cell to inspect one of its weight matrices and verify that the parameters are stored as bfloat16.

In [None]:
# Access the first transformer block.
first_transformer_block = model.backbone.get_layer("decoder_block_0")

# Access the attention layer.
attention_layer = first_transformer_block.attention

# Get the weight matrix for the query projections, and check its dtype.
query_weights = attention_layer.query_dense.kernel

print(f"The model's weight precision is {query_weights.dtype}.")

### Activate LoRA

Even with its parameters stored as bfloat16, the model is too large for full-parameter fine-tuning on the available GPU, so you will again train it using LoRA. Recall from course 05 "Fine-tune Your Model" that LoRA freezes the original weights and only updates the parameters of small low-rank matrices added to the model. These low-rank matrices make up a small fraction of the total parameters, so the GPU only needs to store gradients and optimizer states for this small set of parameters. This drastically reduces the overall memory requirement and is why you can fine-tune a 4B model on a T4 GPU.
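The savings follow from simple arithmetic. LoRA represents the update to a weight matrix of shape d_in × d_out as the product of two low-rank matrices A (d_in × r) and B (r × d_out), so it trains r·(d_in + d_out) parameters instead of d_in·d_out. The sketch below assumes a hypothetical 2560 × 2560 projection matrix together with the rank of 4 used in this lab; real Gemma layers have a variety of shapes, and `lora_trainable_params` is an illustrative helper, not part of Keras.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Counts the parameters of the low-rank matrices A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


# Hypothetical shape of a single square projection matrix, for illustration only.
d_in, d_out, rank = 2560, 2560, 4

full_params = d_in * d_out  # Parameters updated by full fine-tuning of this matrix.
lora_params = lora_trainable_params(d_in, d_out, rank)  # Parameters updated by LoRA.

print(f"Full fine-tuning: {full_params:,} trainable parameters")
print(f"LoRA, rank {rank}: {lora_params:,} trainable parameters")
print(f"LoRA trains {lora_params / full_params:.2%} of this matrix's parameters")
```

Gradients, and AdamW's two moment estimates per trainable parameter, are only needed for the LoRA parameters, which is where most of the training-memory savings come from.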

Run the following code to enable LoRA for Gemma-4B with a rank of 4.

In [None]:
model.backbone.enable_lora(rank=4)
model.summary()

## Fine-tune Gemma-4B with bfloat16

Now that LoRA is enabled, you need to configure the training process. This involves setting the hyperparameters that will guide the learning.

Training with a low-precision format like bfloat16 often requires more careful hyperparameter selection than standard 32-bit training. As the numbers are less precise, the learning process can be more sensitive, so it is important to use settings that result in a stable training process.
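To see what "less precise" means, it helps to know that bfloat16 keeps the sign bit and all 8 exponent bits of a float32 but only 7 of its 23 mantissa bits. It therefore covers roughly the same range of magnitudes as float32, but with far fewer significant digits. The sketch below emulates this in pure Python by rounding away the lower 16 bits of a float32 bit pattern. `to_bfloat16` is a hypothetical helper for illustration, not how JAX or Keras convert values, and it ignores NaN and overflow handling.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Rounds x to the nearest bfloat16 value (round to nearest, ties to even).

    Emulation only: x is first rounded to float32, then the lower 16 bits of
    its bit pattern are rounded away. NaN and overflow handling are omitted.
    """
    # Reinterpret the float32 value as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, based on the 16 bits that will be dropped.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Zero the lower 16 bits, keeping only the bfloat16 part.
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bfloat16(3.14159265))   # 3.140625
print(to_bfloat16(1.0 + 0.001))  # 1.0
print(to_bfloat16(1.0 + 2**-7))  # 1.0078125
```

The second call shows that adding 0.001 to 1.0 has no effect in bfloat16, because the gap between 1.0 and the next representable value is 2^-7 = 0.0078125.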

The following cell sets up the AdamW optimizer with a set of values that have been found to work well for stable bfloat16 fine-tuning. Run it to configure your model for training.

In [None]:
# Set hyperparameters.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-6,
    clipnorm=0.5,
)

# Set the number of training epochs.
num_epochs = 4

# Set the maximum sequence length (in tokens) of the training inputs.
model.preprocessor.sequence_length = 400

# Compile the model with the optimizer.
model.compile(
    optimizer=optimizer,
)
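The `clipnorm=0.5` argument enables gradient clipping: according to the Keras documentation, the gradient of each weight is individually rescaled so that its L2 norm is at most 0.5, which helps keep occasional very large gradients from destabilizing training. The pure-Python sketch below shows this rule for a single gradient vector; `clip_by_norm` is a hypothetical helper for illustration, not the Keras implementation.

```python
import math


def clip_by_norm(grad: list[float], max_norm: float) -> list[float]:
    """Rescales grad so that its L2 norm is at most max_norm, keeping its direction."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)  # Already small enough: leave unchanged.
    scale = max_norm / norm
    return [g * scale for g in grad]


# A hypothetical gradient with L2 norm 5.0 is scaled down to norm 0.5.
print(clip_by_norm([3.0, 4.0], max_norm=0.5))
# A small gradient passes through unchanged.
print(clip_by_norm([0.1, 0.2], max_norm=0.5))
```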

### Monitor training progress

As in previous fine-tuning runs, you need a way to evaluate training progress beyond the loss. As before, you will do this by sampling generations from the model after each epoch.

The following cell creates a Keras callback: an object whose methods Keras calls automatically at specific points during training. Here, the callback runs at the end of each epoch and generates and prints outputs for three specific test prompts. This allows you to monitor the model's progress in real time.

The three test prompts are designed to check different aspects of learning:

* "What is Kente cloth?": This tests if the model is learning the specific knowledge from the fine-tuning dataset.
* "What is Kilimanjaro?": This tests generalization, as "Mount Kilimanjaro" features in the fine-tuning dataset but the word "Mount" is omitted here to make it more difficult.
* "What is Tokyo?": Since there is nothing about Tokyo or Japan in the fine-tuning dataset, this checks if the model retains its pre-trained knowledge while learning the new task.

Run the following cell to define this callback.

In [None]:
# Define a list of prompts you want to check after each epoch.
test_prompts = [
    "What is Kente cloth?",
    "What is Kilimanjaro?",
    "What is Tokyo?"
]

class GenerationMonitor(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"\n--- Generations after epoch {epoch + 1} ---")
        for prompt in test_prompts:
            # Format the prompt correctly for the model.
            formatted_prompt = format_question(prompt)

            # Generate and print the output.
            output = self.model.generate(formatted_prompt, max_length=150)
            print(output)
            print("-" * 20)
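You never call `on_epoch_end` yourself; `model.fit` does. The toy loop below is a hypothetical, heavily simplified stand-in for `model.fit` (it trains nothing and uses placeholder loss values) that shows the control flow: after each epoch, every callback's `on_epoch_end` is called with a 0-based epoch index, which is why `GenerationMonitor` prints `epoch + 1`. The classes `Callback` and `EpochLogger` and the function `toy_fit` are illustrative only.

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback (illustration only)."""

    def on_epoch_end(self, epoch, logs=None):
        pass


class EpochLogger(Callback):
    """Records each finished epoch, similar in spirit to GenerationMonitor."""

    def __init__(self):
        self.seen_epochs = []

    def on_epoch_end(self, epoch, logs=None):
        self.seen_epochs.append(epoch)
        print(f"Epoch {epoch + 1} finished, loss={logs['loss']:.3f}")


def toy_fit(num_epochs, callbacks):
    """Hypothetical training loop showing when on_epoch_end is called."""
    for epoch in range(num_epochs):
        loss = 1.0 / (epoch + 1)  # Placeholder value; no real training happens.
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs={"loss": loss})


logger = EpochLogger()
toy_fit(num_epochs=4, callbacks=[logger])
```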

Now run the following cell to fine-tune the model using LoRA and the bfloat16 number format. Notice that this does not lead to an out-of-memory error.

Compare the outputs after each epoch. How do they look?

The training will take about five minutes per epoch on a T4 GPU. In total, it should take around 20 minutes for four epochs.

In [None]:
# Create an instance of the monitoring callback.
generation_callback = GenerationMonitor()

# Train the model.
history = model.fit(
    data,
    epochs=num_epochs,
    batch_size=1,
    callbacks=[generation_callback],
)

### What did you observe?

In the previous fine-tuning lab "Hitting a Wall," even loading the model with full 32-bit precision resulted in an immediate out-of-memory error. Now, with a single change (loading the model with bfloat16 instead of 32-bit floating-point numbers), you can not only load the model but also fine-tune it with LoRA.

The training output also shows how the model learns to respond to the questions in the desired flashcard format after a few epochs. Notice the initial instability in Epoch 1: for the "Kente" prompt, the model generates a repeated, random token ("KMnO4"), and for the "Tokyo" prompt, it keeps generating "Category:" lists. After this, however, the generations improve dramatically, and the model produces increasingly coherent text in the flashcard format. By Epoch 4, the generations are factually accurate and stylistically correct.

You have now successfully trained a 4-billion-parameter model on a single GPU, a task that was impossible on this GPU with 32-bit precision. The quality of the final generations indicates that, for this task, bfloat16 did not significantly harm the model's performance.

## Summary

Loading the model parameters with a lower-precision number format like bfloat16 allows you to use and fine-tune bigger models with limited GPU resources. As you observed in this lab, by using bfloat16 representations, you were able to load all parameters of Gemma-4B into memory and even fine-tune the model on a T4 GPU available through Colab. The result is a model that generates high-quality responses, thanks to the powerful 4-billion-parameter foundation model that served as the basis for fine-tuning.

## Solutions

The following cells provide reference solutions to the coding activities in this notebook. If you really get stuck after trying to solve the activities yourself, you may want to consult these solutions.

It is recommended that you *only* look at the solutions after you have tried to solve the activities *multiple times*. The best way to learn challenging concepts in computer science and artificial intelligence is to debug your code piece-by-piece until it works, rather than copying existing solutions.

If you feel stuck, first try to debug your code, for example by adding print statements to see what your code is doing at every step. This will give you a much deeper understanding of the code and the materials. It will also give you practice in solving challenging coding problems beyond this course.

To view the solutions for an activity, click on the arrow to the left of the activity name. If you consult the solutions, do not copy and paste them into the cells above. Instead, look at them, and type them manually into the cell. This will help you understand where you went wrong.


### Coding Activity 1

In [None]:
model = keras_nlp.models.Gemma3CausalLM.from_preset("gemma3_4b_text", dtype="bfloat16")