> <p><small><small>This Notebook is made available subject to the licence and terms set out in the <a href = "http://www.github.com/google-deepmind/ai-foundations">AI Research Foundations Github README file</a>.

<img src="https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C5-white-bg.png">

# Lab: Fine-Tune Gemma with LoRA

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_5/gdm_lab_5_6_fine_tune_gemma_with_lora.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

30 minutes understanding + 20 minutes computation time

Explore how you can fine-tune larger pre-trained models by applying LoRA.


## Overview

Recall that in the lab "Full-Parameter Fine-Tuning of Gemma," your training process failed due to an out-of-memory error. In this lab, you will explore how this can be avoided by applying LoRA when fine-tuning a Gemma model. This will allow you to fine-tune Gemma for your flashcard generation task. You will also explore the advantages of fine-tuning a large pre-trained model, such as how it can answer questions on a range of topics due to the diverse data that it has been pre-trained on.




### What you will learn:

By the end of this lab, you will be able to:

* Apply LoRA to Gemma-1B.
* Fine-tune a model for your own task.
* Evaluate what the fine-tuned model learned from pre-training.
* Evaluate what the fine-tuned model learned from fine-tuning.
* Describe how combining knowledge from both training processes allows it to perform in ways that could not be achieved with either type of training alone.


### Tasks

**In this lab, you will**:
* Repeat the steps from lab "Large pre-trained models" to load and prepare Gemma-1B.
* Prepare Gemma-1B for fine-tuning with LoRA.
* Choose the hyperparameters for fine-tuning.
* Monitor the fine-tuning process.
* Evaluate the performance of the pre-trained and fine-tuned model.

**This lab needs to be run on a GPU. Choose a T4 GPU.** See the section "How to use Google Colaboratory (Colab)" below for instructions on how to do this.

**You also need a Kaggle account** to download the weights of the Gemma 3 model. See the section "Set up a Kaggle account" for instructions on how to do this.


## How to use Google Colaboratory (Colab)

Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in cells that are executed on a remote server.

To run a cell, hover over a cell, and click the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can also click a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime
print(f"Today is {datetime.today():%A}.")

Note that the order in which you run the cells matters. When you are working through a lab, make sure to always run all cells in order. Otherwise, the code might not work. If you take a break while working on a lab, Colab may disconnect you; in that case, you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose __Runtime → Run before__ from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

### Using Colab with a GPU

Follow these steps to run the activities in this lab on a GPU:

1.  In the top menu bar, click **Runtime**.
2.  Select **Change runtime type** from the dropdown menu.
3.  In the pop-up window under **Hardware Accelerator**, select **GPU** (usually listed as `T4 GPU`).
4.  Click **Save**.

Your Colab session will now restart with GPU access.

Note that access to GPUs is limited and at times, you may not be able to run this lab on a GPU. All activities will still work but they will run slower and you will have to wait longer for some of the cells to finish running.


## Set up a Kaggle account



To run this notebook, you will have to sign up for [Kaggle](https://www.kaggle.com), a platform that hosts datasets and models for machine learning, and sign the agreement for using the Gemma 3 model. This is required so that you can download the weights of the Gemma model for fine-tuning.

### Step 1: Create your Kaggle account

* Go to the Kaggle website: https://www.kaggle.com

* Click the "Register" button in the top-right corner.

* You can sign up using your Google account (recommended for easy Colab integration) or by entering an email and password.

* Follow the on-screen prompts to complete your registration and verify your email.

### Step 2: Sign the Gemma 3 model agreement

* Make sure you are logged into your new Kaggle account.

* Go directly to the Gemma 3 model card page: https://www.kaggle.com/models/keras/gemma3/keras/

* You should see a "Request Access" button.

* Click the button, read through the license agreement, and click "Accept" to gain access to the model. You must do this before the API will let you download the model.

### Step 3: Generate your Kaggle API key

* From any Kaggle page, click on your profile picture or icon in the top-right corner.

* Select "Account" from the drop-down menu.

* Scroll down to the "API" section.

* Click the "Create New API Token" button.

* This will immediately download a file named `kaggle.json` to your computer. This file contains your username and your secret API key. Keep it safe.

### Step 4: Set your API key in Colab

* Click the "key" icon 🔑 in the left-hand sidebar.

* You will see the "Secrets" panel.

* Now, open the `kaggle.json` file you downloaded on your computer. It's a simple text file and will look like this:

   ```json
   {"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_API_KEY"}
   ```
* In the Colab Secrets panel, create two new secrets:

   1. Name: `KAGGLE_USERNAME`

      Value: Copy and paste `YOUR_KAGGLE_USERNAME` from your `kaggle.json` file.

   2. Name: `KAGGLE_KEY`

      Value: Copy and paste `YOUR_KAGGLE_API_KEY` from your `kaggle.json` file.

* For both secrets, make sure the "Notebook access" toggle is switched on.
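If you run this notebook outside Colab, the Secrets panel is not available. The following is a minimal sketch, assuming you have the contents of `kaggle.json` available as a string, that sets the same two environment variables. The helper name `load_kaggle_credentials` is hypothetical and is not part of the course package.

```python
import json
import os


def load_kaggle_credentials(json_text: str) -> dict:
    """Parses kaggle.json content and exports it as environment variables.

    Hypothetical helper for illustration; not part of the course package.
    """
    credentials = json.loads(json_text)
    # Fail early with a readable message if a field is missing.
    missing = {"username", "key"} - credentials.keys()
    if missing:
        raise ValueError(f"kaggle.json is missing fields: {sorted(missing)}")
    os.environ["KAGGLE_USERNAME"] = credentials["username"]
    os.environ["KAGGLE_KEY"] = credentials["key"]
    return credentials


# Placeholder values, as in the kaggle.json template above.
example = '{"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_API_KEY"}'
load_kaggle_credentials(example)
print(os.environ["KAGGLE_USERNAME"])
```

Never commit a real `kaggle.json` or paste its contents into a shared notebook; read it from a local file instead.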


## Imports

In this lab, you will use the Keras package for loading a Gemma model and the Pandas package to load the dataset.

Run the following cell to import the required packages.

In [None]:
%%capture
# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import os

from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

os.environ["KERAS_BACKEND"] = "jax"  # Set the Keras backend to JAX.

import json # For loading training logs.
from urllib import request # For downloading training logs.
import keras # For training the model.
import keras_nlp # For loading Gemma-1B.
import pandas as pd # For loading the dataset.
import jax.numpy as jnp # For working with matrices and vectors.
from textwrap import fill # For formatting long paragraphs.
# For loading the formatting function that you implemented in previous labs.
from ai_foundations import formatting

# Let JAX preallocate 95% of the GPU memory, which helps to avoid memory
# fragmentation on the JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
keras.utils.set_random_seed(812)  # For making the training reproducible.

## Fine-tune Gemma 3 with LoRA

### Prepare the data and load the model

The following cells implement the data preparation steps and load the Gemma-1B model. These are the same steps as you performed when you attempted to perform full-parameter fine-tuning with Gemma.

#### Load and format the fine-tuning data

For adding the special delimiters, you will again use the `format_qa` function that you have implemented earlier. Additionally, the following cell also defines a function `format_question` for formatting only a question. This function can be used for formatting prompts when you evaluate the model.
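To make the prompt format concrete, the standalone sketch below repeats the `format_question` logic in condensed form and prints the raw string, including the newline characters, for one of the evaluation questions used later in this lab.

```python
def format_question(
    question: str,
    sot: str = "<start_of_turn>",
    eot: str = "<end_of_turn>",
) -> str:
    """Wraps a question in the delimiters that mark a user turn."""
    return f"{sot}user\n{question}{eot}\n"


prompt = format_question("What is Kente cloth?")
# repr() makes the newline characters visible.
print(repr(prompt))
# → '<start_of_turn>user\nWhat is Kente cloth?<end_of_turn>\n'
```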

The following cell defines `format_question`, loads the Africa Galore QA dataset, and creates the data dictionary that can be used to fine-tune the Keras implementation of Gemma.

In [None]:
def format_question(
    question: str,
    sot: str = "<start_of_turn>",
    eot: str = "<end_of_turn>"
) -> str:
    """
    Formats a question for prompting the model and adds special delimiters at
    the start and end of the question.

    Args:
      question: The question to be formatted.
      sot: The token to mark the start of a turn.
      eot: The token to mark the end of a turn.

    Returns:
      Formatted string of the question.
    """

    formatted_q = f"{sot}user\n{question}{eot}\n"

    return formatted_q

# Load the question-answer dataset.
africa_galore_qa = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore_qa_v2.json"
)

questions = []  # List of formatted questions.
answers = []  # List of formatted answers.

for idx, row in africa_galore_qa.iterrows():
    # Run the format_qa function from the previous lab to format the question
    # and the answer.
    question, answer = formatting.format_qa(row)
    questions.append(question)
    answers.append(answer)

# Show the first set of input and output.
print(questions[0])
print(fill(answers[0], replace_whitespace=False))

# Prepare the data dictionary for fine-tuning Gemma.
data = {
    "prompts": questions,
    "responses": answers
}

#### Load the model

Run the following cell to load the Gemma model and print a summary of its number of parameters.


In [None]:
# Load the Gemma-1B Keras model.
model = keras_nlp.models.Gemma3CausalLM.from_preset("gemma3_1b")
model.summary()

### Prompt the pre-trained Gemma model

Before fine-tuning the model, observe how the pre-trained model responds to prompts in the format that you have been using for the flashcard generator.

Run the following cell to generate responses to three questions that you will use to evaluate your model throughout this lab:

1. "What is Kente cloth?": A question that is present in the Africa Galore QA dataset.
2. "What is the tallest mountain in Africa?": A question that does not appear in the Africa Galore QA dataset but that is on a topic that is present in the fine-tuning dataset.
3. "What is Mount Aconcagua?": A question that does not appear in the fine-tuning dataset and that is on a topic that is not covered in the fine-tuning dataset.

Inspect the answers to each of these questions.

In [None]:
evaluation_prompts = [
    "What is Kente cloth?",
    "What is the tallest mountain in Africa?",
    "What is Mount Aconcagua?"
]

# Predict answers for three formatted questions through the model.
# Generate answers with a length of (up to) 200 tokens.
for prompt in evaluation_prompts:
    formatted_prompt = format_question(prompt)
    model_response = model.generate(formatted_prompt, max_length=200)
    print(fill(model_response, replace_whitespace=False))
    print('\n------\n')

#### What did you observe?

You likely observed that a pre-trained model that has only been trained to predict the next token from a large corpus of text can produce some answers to questions. However, in many cases it produces quite repetitive answers, and sometimes it generates similar questions rather than an appropriate answer.

Furthermore, you likely noticed that the model did not generate the category or the turn-taking tokens for any of the questions. As such, it did not adhere to the format that you want for your flashcard generator.

### Activate LoRA

LoRA is already implemented in the Keras implementation of Gemma and can be added with a call to the `enable_lora` method. For example, to enable LoRA with a rank of 4, you can call:

```python
model.backbone.enable_lora(rank=4)
```
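To build intuition for why the rank matters, recall that LoRA keeps a weight matrix of shape `d_in × d_out` frozen and trains two small matrices of shapes `d_in × rank` and `rank × d_out` instead. The sketch below counts both for a single matrix. The layer size is a hypothetical round number chosen for illustration, not one of Gemma's actual shapes, and the helper `lora_parameter_counts` is not part of Keras.

```python
def lora_parameter_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Returns (frozen weights, trainable LoRA weights) for one matrix."""
    frozen = d_in * d_out
    # Two low-rank factors: (d_in x rank) and (rank x d_out).
    trainable = rank * (d_in + d_out)
    return frozen, trainable


# Hypothetical layer size for illustration; not Gemma's actual shape.
frozen, trainable = lora_parameter_counts(d_in=1024, d_out=1024, rank=4)
print(frozen, trainable, f"{trainable / frozen:.2%}")
# → 1048576 8192 0.78%
```

Doubling the rank doubles the number of trainable LoRA weights, while the frozen count stays the same.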

Run the following cell to enable LoRA. Then consider the output of the summary method. Observe how the parameter count has slightly increased because of the additional LoRA parameters. At the same time, observe how the number of trainable parameters is much smaller than the number of total parameters.

In [None]:
model.backbone.enable_lora(rank=4)
model.summary()

## Activity 1: Setting hyperparameters

The last step before you can start training is to set the hyperparameters of the training process. Recall that the two parameters that are particularly important for fine-tuning are the **learning rate** and the **number of epochs**, that is, the number of times you iterate through the fine-tuning data during fine-tuning.

For both of these parameters, you generally want to perform hyperparameter tuning and choose a set of parameters that leads to useful results on your evaluation examples. Note that the model you are using here is different from your SLM, and you are training it using LoRA instead of performing full-parameter fine-tuning. Therefore, the optimal hyperparameters for fine-tuning the Gemma model may differ from the ones you used to fine-tune your SLM.

For pre-trained models like Gemma, a reasonable range for the number of epochs is usually 1 to 20, depending both on the size of your fine-tuning set and the learning rate. If you use a larger fine-tuning set and/or a higher learning rate, you generally need to fine-tune for fewer epochs. Conversely, if you have very few fine-tuning examples or use a very low learning rate, you will need to fine-tune for more epochs.

For the learning rate, values between 0.0001 (1e-4) and 0.000001 (1e-6) tend to work well in practice for fine-tuning models like Gemma.
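One way to see why a larger fine-tuning set calls for fewer epochs is to count weight updates. The sketch below assumes one optimizer step per batch (no gradient accumulation) and uses hypothetical dataset sizes; the helper `total_update_steps` is illustrative only.

```python
import math


def total_update_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Counts optimizer steps, assuming one step per batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


# Hypothetical dataset sizes for illustration.
for num_examples in (100, 1000):
    steps = total_update_steps(num_examples, batch_size=1, num_epochs=10)
    print(num_examples, steps)
# → 100 1000
# → 1000 10000
```

With ten times as many examples, a single epoch already performs as many updates as ten epochs on the smaller set.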

<br>

------
> **💻 Your task:**
>
> Find the right setting for the learning rate.
>
> In practice, you would fine-tune a model multiple times with different learning rates and evaluate the model performance during fine-tuning. You would then choose a learning rate that leads to the best evaluation results. However, since fine-tuning consumes a lot of GPU resources and your Colab GPU resources may be limited, the following cell provides you with the output of training runs with different learning rates.
>
> Select a learning rate and execute the cell to print the log of the training run with that specific learning rate. Look at the model generations after each epoch and choose a learning rate that leads to "good" results for all prompts.
>
------
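The cell that follows looks up each log with `training_logs[str(learning_rate)]`. If you ever inspect the log dictionary directly, it helps to know how Python turns these floats into strings: values of 1e-4 and above are written in decimal notation, while smaller values use scientific notation. A short sketch:

```python
learning_rates = [1e-3, 5e-4, 2e-4, 1e-4, 5e-5, 2e-5, 1e-5, 5e-6, 2e-6, 1e-6]

# str() is the same conversion that the log lookup applies.
keys = [str(lr) for lr in learning_rates]
print(keys)
# → ['0.001', '0.0005', '0.0002', '0.0001', '5e-05', '2e-05', '1e-05', '5e-06', '2e-06', '1e-06']
```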

In [None]:
# @title Training logs for different learning rates

if "training_logs" not in globals():
    TRAINING_LOG_URL = "https://storage.googleapis.com/dm-educational/assets/ai_foundations/finetune-gemma-training-logs.json"
    with request.urlopen(TRAINING_LOG_URL) as json_file:
        training_logs = json.loads(json_file.read().decode())

learning_rate = 1e-3 # @param ["1e-3","5e-4","2e-4","1e-4","5e-5","2e-5","1e-5","5e-6","2e-6","1e-6"] {"type":"raw"}

print(training_logs[str(learning_rate)])

Once you have determined the optimal learning rate for this fine-tuning problem, add the learning rate to the following cell.

In [None]:
# Set the learning rate
model.optimizer.learning_rate = # Add your code here.

# Set the number of epochs.
num_epochs = 10

# Set the maximum length.
model.preprocessor.sequence_length = 400

### Perform fine-tuning

Now that you have prepared the data and set up the model and LoRA, fine-tuning can be done just like any other training in Keras, that is, with the `model.fit()` method.

As previously, it is important to monitor the training progress using a callback function. Run the following cell to define a callback function that prints generations for the three evaluation prompts that you defined above.



In [None]:
class EvaluationCallback(keras.callbacks.Callback):
    """
    A Keras callback function to print generations for evaluation prompts.
    """
    def on_epoch_end(self, epoch: int, logs=None):
        """Prints generations for the three evaluation prompts.

        Args:
          epoch: The current epoch.
          logs: The logs dictionary.
        """

        evaluation_prompts = [
            "What is Kente cloth?",
            "What is the tallest mountain in Africa?",
            "What is Mount Aconcagua?"
        ]

        # Run three formatted questions through the model.
        # Generate answers with a length of (up to) 200 tokens.
        for prompt in evaluation_prompts:
            formatted_prompt = format_question(prompt)
            # Use the model attached to the callback rather than a global.
            model_response = self.model.generate(
                formatted_prompt, max_length=200
            )
            print(fill(model_response, replace_whitespace=False))
            print('\n------\n')

evaluation_callback = EvaluationCallback()

------
> **💻 Your task:**
>
>Run the following cell to fine-tune your model.
>
>While the model is fine-tuning, monitor the fine-tuning progress using the three evaluation prompts.
>
>Keep note of:
>
>* When does the model begin to output "\<start_of_turn>model"?
>* When does the model begin to produce outputs that start with "Category:" and the correct category?
>* At which epochs does the model produce very strange outputs?
>* From which epoch on does the model consistently produce a single paragraph and stop repeating texts?
>* At which epoch does the model start to produce "\<end_of_turn>" and stop producing any other output thereafter?
>
>After the first epoch, the training will take about one to two minutes per epoch on a T4 GPU. Expect fine-tuning the model for 10 epochs to take 15 to 20 minutes in total.
>
------

In [None]:
training_history = model.fit(
    data,
    epochs=num_epochs,
    batch_size=1,
    verbose=1,
    callbacks=[evaluation_callback]
)


At the end of the fine-tuning process, did the model outputs satisfy all of the criteria mentioned above? If not, you may want to continue fine-tuning for a few more epochs.
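If you want to check some of these criteria programmatically rather than by eye, a minimal sketch could look like the following. It assumes that the generated string contains the prompt followed by the model turn and that the special tokens appear as plain text in the output. The helper `check_flashcard_format` is hypothetical, and the sample response is hand-written for illustration, not real model output.

```python
def check_flashcard_format(
    response: str,
    sot: str = "<start_of_turn>",
    eot: str = "<end_of_turn>",
) -> dict[str, bool]:
    """Checks a generated string against three of the format criteria."""
    model_marker = f"{sot}model"
    has_model_turn = model_marker in response
    # Keep only the text that follows the start of the model turn.
    answer = response.split(model_marker, 1)[1] if has_model_turn else ""
    return {
        "has_model_turn": has_model_turn,
        "starts_with_category": answer.lstrip().startswith("Category:"),
        # True only if nothing but whitespace follows the end-of-turn token.
        "ends_with_eot": answer.rstrip().endswith(eot),
    }


# Hand-written sample for illustration; not real model output.
sample = (
    "<start_of_turn>user\nWhat is Kente cloth?<end_of_turn>\n"
    "<start_of_turn>model\nCategory: Culture\n"
    "Kente is a woven cloth.<end_of_turn>"
)
result = check_flashcard_format(sample)
print(result)
```

Checks like these only cover the format. Whether the category and the answer are factually correct still needs to be judged by reading the output.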

Overall, you likely noticed "good" performance, even for questions that are not about African culture and geography, such as the question about the South American mountain, "What is Mount Aconcagua?". You likely observed that the fine-tuned model was able to provide an answer about this mountain and did so in the desired format for the flashcard generator. In doing so, it combined the factual information from its pre-training, which was performed on a dataset of 2 trillion tokens, with the specifics of the format that it learned from the fine-tuning examples.

## Additional evaluations

As always, it is important to test your model on a wide range of questions. Define some more questions below and check the model outputs.

In [None]:
additional_prompts = [
    "What is Mount Aconcagua?",
    "What is American Football?",
    "What is Python?"
]

for prompt in additional_prompts:
    formatted_prompt = format_question(prompt)
    model_response = model.generate(formatted_prompt, max_length=300)
    print(fill(model_response, replace_whitespace=False))
    print('\n------\n')

#### What did you observe?

You likely observed that the model was able to infer appropriate categories for many questions and managed to output many reasonable responses.

If the responses to all your questions were useful, try to find very niche topics and ask the model about them. Are you able to find the limits of this model?


## Summary

In this lab, you fine-tuned Gemma-1B with LoRA. You observed that the model could produce answers in the desired format not only for the topics that are present in the fine-tuning dataset (Africa Galore QA) but also for a variety of other topics.

This process demonstrated the power of combining "good" pre-trained models with fine-tuning. It allows you to build highly capable models for many tasks with a small fine-tuning dataset.

## Solutions

The following cells provide reference solutions to the coding activities in this notebook. If you really get stuck after trying to solve the activities yourself, you may want to consult these solutions.

It is recommended that you *only* look at the solutions after you have tried to solve the activities *multiple times*. The best way to learn challenging concepts in computer science and artificial intelligence is to debug your code piece-by-piece until it works, rather than copying existing solutions.

If you feel stuck, you may want to first try to debug your code, for example by adding print statements to see what your code is doing at every step. This will provide you with a much deeper understanding of the code and the materials. It will also give you practice in solving challenging coding problems beyond this course.

To view the solutions for an activity, click the arrow to the left of the activity name. If you consult the solutions, do not copy and paste them into the cells above. Instead, look at them, and type them manually into the cell. This will help you understand where you went wrong.


### Activity 1

A good choice for the learning rate is 2e-5. With this setting, the answers stabilize after about 7 to 8 epochs and are factually correct.

A lower learning rate is also possible. However, the lower the learning rate you choose, the more epochs you need for fine-tuning.

Higher learning rates look good at first glance but the answers are of lower quality.

The answers produced by a learning rate of 5e-5 do not mention that Tokyo is the capital city of Japan and give a more generic description that applies to very large cities in general. That being said, they may still be acceptable.

A learning rate of 1e-4 leads to answers that contain factual errors. For example, Kilimanjaro is not the second highest peak in the world after Mount Everest. Kilimanjaro is notable for many reasons, but it did not play a particular role in a 1952 ascent of Mount Everest. The first successful climb of Mount Everest happened one year later, in 1953. Even higher learning rates have similar problems.

As you will observe, it can be quite difficult to judge the correctness of answers. The "destruction" of pre-trained knowledge with higher learning rates can be difficult to spot. Catastrophic forgetting does not necessarily affect all aspects of the model's knowledge. It can be more subtle, as in the Kilimanjaro answers for a learning rate of 1e-4. Therefore, it is important to choose validation prompts that you yourself can answer at expert level.

## References

[1] Gemma Team, Google DeepMind. 2025. Gemma 3 technical report. https://arxiv.org/pdf/2503.19786

[2] Edward Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. 2022. LoRA: Low-rank adaptation of large language models. In *International Conference on Learning Representations (ICLR 2022)*. https://openreview.net/pdf?id=nZeVKeeFYf9

[3] Dan Biderman, Jacob Portes, Jose Javier Gonzalez Ortiz, Mansheej Paul, Philip Greengard, Connor Jennings, Daniel King, Sam Havens, Vitaliy Chiley, Jonathan Frankle, Cody Blakeney, and John P. Cunningham. 2024. LoRA learns less and forgets less. *Transactions on Machine Learning Research*. https://openreview.net/forum?id=aloEru2qCG
