<img src="https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C6-white-bg.png">

# Lab: Apply Gradient Accumulation

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_7/gdm_lab_7_6_apply_gradient_accumulation.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

Learn how to simulate a larger batch size on a memory-constrained GPU using gradient accumulation.

30 minutes

## Overview

Previously, you explored how using a larger batch size can lead to more stable training and make more efficient use of the GPU. However, you also experienced that GPU memory is a critical bottleneck. So what happens when you want the benefits of a larger batch size, but your GPU doesn't have enough memory to handle it?

In this lab, you will encounter this exact problem. You will first attempt to train the Gemma-4B model with a larger batch size. Then, you will learn how to use **gradient accumulation** to get the benefits of that larger batch size on your resource-constrained GPU.

### What you will learn:

By the end of this lab, you will be able to:

* Explain why a larger batch size can cause out-of-memory errors.
* Describe the concept of gradient accumulation.
* Implement gradient accumulation in Keras to train with a larger effective batch size.

### Tasks

**In this lab, you will**:
* Attempt to fine-tune Gemma-4B with a batch size of 4 and observe the memory error.
* Reconfigure the optimizer to use gradient accumulation.
* Fine-tune the model successfully by simulating a larger batch size of 8.

 **This lab needs to be run on a GPU. Choose a T4 GPU.** See the section "How to use Google Colaboratory (Colab)" below for instructions on how to do this.

 **You also need a Kaggle account** to download the weights of the Gemma 3 model. See the section "Set up a Kaggle account" for instructions on how to do this.


## How to use Google Colaboratory (Colab)

Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in **cells** that are executed on a remote server.

To run a cell, hover over a cell and click on the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can also click on a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime
print(f"Today is {datetime.today():%A}.")

Note that the **order in which you run the cells matters**. When you are working through a lab, make sure to always run all cells in order, otherwise the code might not work. If you take a break while working on a lab, Colab may disconnect you and in that case, you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose _Runtime → Run before_ from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

### Using Colab with a GPU

Follow these steps to run the activities in this lab on a GPU:

1.  In the top menu bar, click on **Runtime**.
2.  Select **Change runtime type** from the dropdown menu.
3.  In the pop-up window under **Hardware accelerator**, select **T4 GPU**.
4.  Click **Save**.

Your Colab session will now restart with GPU access.

Note that access to GPUs is limited and at times, you may not be able to run this lab on a GPU. All activities will still work but they will run slower and you will have to wait longer for some of the cells to finish running.

## Set up a Kaggle account



To run this notebook, you will have to sign up for [Kaggle](https://www.kaggle.com), a platform that hosts datasets and models for machine learning. You will also need to sign the agreement for using the Gemma 3 model. This is required so that you can download the weights of the Gemma model for fine-tuning.

### Step 1: Create your Kaggle account

* Go to the Kaggle website: https://www.kaggle.com

* Click the "Register" button in the top-right corner.

* You can sign up using your Google account (recommended for easy Colab integration) or by entering an email and password.

* Follow the on-screen prompts to complete your registration and verify your email.

### Step 2: Sign the Gemma 3 model agreement

* Make sure you are logged into your new Kaggle account.

* Go directly to the Gemma 3 model card page: https://www.kaggle.com/models/keras/gemma3/keras/

* You should see a "Request Access" button.

* Click the button, read through the license agreement, and if you are happy to, click "Accept" to gain access to the model. You must do this before the API will let you download the model.

### Step 3: Generate your Kaggle API key

* From any Kaggle page, click on your profile picture or icon in the top-right corner.

* Select "Account" from the drop-down menu.

* Scroll down to the "API" section.

* Click the "Create New API Token" button.

* This will immediately download a file named `kaggle.json` to your computer. This file contains your username and your secret API key. Keep it safe.

### Step 4: Set your API Key in Colab

* Click the "key" icon 🔑 in the left-hand sidebar.

* You will see the "Secrets" panel.

* Now, open the kaggle.json file you downloaded on your computer. It's a simple text file and will look like this:

   ```json
   {"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_API_KEY"}
   ```
* In the Colab Secrets panel, create two new secrets:

   1. Name: `KAGGLE_USERNAME`

      Value: Copy and paste `YOUR_KAGGLE_USERNAME` from your `kaggle.json` file.

   2. Name: `KAGGLE_KEY`

      Value: Copy and paste `YOUR_KAGGLE_API_KEY` from your `kaggle.json` file.

* For both secrets, make sure the "Notebook access" toggle is switched on.


## Imports

In this lab, you will use the Keras package to load and fine-tune a Gemma model, and the Pandas package to load the dataset.

Run the following cell to import the required packages.


In [None]:
%%capture
# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import os # For setting system variables.

from google.colab import userdata # For using Colab secrets.

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

os.environ["KERAS_BACKEND"] = "jax"  # Set the Keras backend to JAX.
# Disable the command buffer pre-allocation to free up memory.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="
# Let JAX preallocate up to 90% of the GPU's memory (the default is 75%).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import keras # For defining and training models.
import keras_nlp # For loading the Keras implementation of Gemma.
import pandas as pd # For loading the dataset.
from textwrap import fill # For formatting long paragraphs.
from ai_foundations import formatting # For formatting the training data.

keras.utils.set_random_seed(812)  # Seed Python, NumPy, and the backend for reproducibility.

## Load and process data

As in previous labs, the following cell defines a function `format_question` that formats a prompt. It also loads the Africa Galore QA dataset and formats each question-answer pair so that the data can be used to fine-tune a model to generate answers for flashcards.

In [None]:
def format_question(
    question: str,
    sot = "<start_of_turn>",
    eot = "<end_of_turn>"
) -> str:
    """
    Formats a question for prompting the model and adds special delimiters at
    the start and end of the question.

    Args:
      question: The question to be formatted.
      sot: The token to mark the start of a turn.
      eot: The token to mark the end of a turn.

    Returns:
      Formatted string of the question.
    """

    formatted_q = f"{sot}user\n{question}{eot}\n"

    return formatted_q

# Load the question-answer dataset.
africa_galore_qa = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore_qa_v2.json"
)

questions = []  # List of formatted questions.
answers = []  # List of formatted answers.

for idx, row in africa_galore_qa.iterrows():
    # Run the format_qa function from the previous lab to format the question
    # and the answer.
    question, answer = formatting.format_qa(row)
    questions.append(question)
    answers.append(answer)

# Show the first set of input and output.
print(questions[0])
print(fill(answers[0], replace_whitespace=False))

# Prepare the data dictionary for fine-tuning Gemma.
data = {
    "prompts": questions,
    "responses": answers
}

## Coding Activity 1: Fine-tune with a larger batch size

As you have previously encountered, by loading the data in batches you can:

* Improve training stability. This is because processing only a single datapoint at a time can introduce more noise when estimating gradients.

* Improve training efficiency by leveraging the parallel processing ability of GPUs to increase the amount of data being processed at once.
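To make the first point concrete, the following sketch (not part of the lab) simulates per-example gradients as hypothetical scalars: each one is a true gradient of 1.0 plus random noise with a standard deviation of 2.0. Averaging over a batch shrinks the spread of the gradient estimate roughly in proportion to 1/sqrt(batch size).

```python
import random
import statistics


def batch_gradient_std(batch_size: int, num_trials: int = 5000, seed: int = 0) -> float:
    """Standard deviation of a mini-batch gradient estimate.

    Each per-example "gradient" is a hypothetical scalar drawn from a normal
    distribution with mean 1.0 (the true gradient) and standard deviation 2.0
    (per-example noise). The mini-batch estimate is the mean over the batch.
    """
    rng = random.Random(seed)
    estimates = []
    for _ in range(num_trials):
        batch = [rng.gauss(1.0, 2.0) for _ in range(batch_size)]
        estimates.append(statistics.fmean(batch))
    return statistics.stdev(estimates)


for batch_size in (1, 4, 8):
    spread = batch_gradient_std(batch_size)
    print(f"batch_size={batch_size}: std of gradient estimate ~ {spread:.3f}")
```

Under these assumptions, the printed spreads should land close to 2.0, 1.0, and 0.71 for batch sizes 1, 4, and 8.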

What happens when you try to increase the batch size for training Gemma-4B?

------
> 💻 **Your task**:
>
> Run the following four cells to load the model, activate LoRA, set the hyperparameters, and define the callback function.
>
> Then, in the cell that calls the `model.fit()` method, set the `batch_size` value to 4.
>
> Observe what happens when you attempt to run training.

Note: It may take a few minutes before you see any output.
------

In [None]:
# Load the Gemma3-4B Keras model with bfloat16 precision.
model = keras_nlp.models.Gemma3CausalLM.from_preset(
    preset="gemma3_4b_text", dtype="bfloat16"
)
model.summary()

In [None]:
model.backbone.enable_lora(rank=4)
model.summary()

In [None]:
# Set hyperparameters.
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-6,
    clipnorm=0.5,
)

# Determine the number of epochs.
num_epochs = 4

# Set the maximum length.
model.preprocessor.sequence_length = 400

# Compile the model with the optimizer.
model.compile(
    optimizer=optimizer,
)

In [None]:
# Define a list of prompts you want to check after each epoch.
test_prompts = [
    "What is Kente cloth?",
    "What is Kilimanjaro?",
    "What is Tokyo?"
]

class GenerationMonitor(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"\n--- Generations after epoch {epoch + 1} ---")
        for prompt in test_prompts:
            # Format the prompt correctly for the model.
            formatted_prompt = format_question(prompt)

            # Generate and print the output.
            output = self.model.generate(formatted_prompt, max_length=150)
            print(output)
            print("-" * 20)

In [None]:
# Create an instance of the monitoring callback.
generation_callback = GenerationMonitor()

# Train the model.
history = model.fit(
    data,
    epochs=num_epochs,
    batch_size=..., # Add your code here.
    callbacks=[generation_callback],
)

### What did you observe?

You should observe a `RESOURCE_EXHAUSTED` error in the output above. This is very similar to the error you observed in the lab where you attempted to load Gemma-4B with 32-bit floating point numbers.

This is again an out-of-memory error. It means that the GPU ran out of its dedicated memory while trying to execute the training step.

The key line is at the end:

`Out of memory while trying to allocate 52428800 bytes.`

This tells you that the GPU needed to allocate another 52,428,800 bytes (about 52 MB) during this step, but it did not have enough free memory left.

In this lab, the error occurred because increasing the batch size from 1 (the value used in previous labs) to 4 roughly quadrupled the activations that have to be kept in memory during the forward pass so that they can be reused in the backward pass. The model weights take up the same amount of memory regardless of the batch size, so this increase in activation memory was enough to exceed the GPU's memory limit.
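A rough back-of-the-envelope sketch of why activation memory grows with the batch size: one activation tensor of shape (batch size, sequence length, hidden width) stored in bfloat16 takes 2 bytes per value, so its size grows linearly with the batch size. `HIDDEN_DIM` is a hypothetical value chosen for illustration, not Gemma's actual configuration, and a real training step stores many such tensors per layer on top of the weights and optimizer state. The last lines convert the failed allocation from the error message into megabytes.

```python
def activation_bytes(batch_size: int, seq_len: int, hidden_dim: int,
                     bytes_per_value: int = 2) -> int:
    """Size in bytes of one hypothetical activation tensor of shape
    (batch_size, seq_len, hidden_dim), stored in bfloat16 (2 bytes per value)."""
    return batch_size * seq_len * hidden_dim * bytes_per_value


SEQ_LEN = 400      # Matches model.preprocessor.sequence_length in this lab.
HIDDEN_DIM = 2048  # Hypothetical hidden width, for illustration only.

for batch_size in (1, 4):
    size = activation_bytes(batch_size, SEQ_LEN, HIDDEN_DIM)
    print(f"batch_size={batch_size}: {size / 1e6:.1f} MB per activation tensor")

# The failed allocation from the error message, in decimal MB and binary MiB.
failed_allocation = 52428800
print(f"{failed_allocation / 1e6:.1f} MB = {failed_allocation / 2**20:.1f} MiB")
# prints "52.4 MB = 50.0 MiB"
```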

### Restart your Colab session

Sometimes, especially when working with large models, it is best to clear all variables and free up the GPU's memory for the next task. The most reliable way to do this is to restart your session.

Before continuing with Coding Activity 2, follow these steps to restart your Colab session:

1. In the top menu bar, click **Runtime**.

2. Select **Restart session** from the dropdown menu.

3. In the pop-up window, click **Yes** to confirm.

Your Colab session will now restart with a fresh environment. Note that this clears all of your variables and imported libraries, so you will need to re-run the imports and any set-up cells before proceeding. These cells are copied below for convenience, so you can run all of the remaining cells in order, starting with the **Imports** section that follows.

## Imports

In [None]:
%%capture
# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import os # For setting system variables.

from google.colab import userdata # For using Colab secrets.

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

os.environ["KERAS_BACKEND"] = "jax"  # Set the Keras backend to JAX.
# Disable the command buffer pre-allocation to free up memory.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="
# Let JAX preallocate up to 90% of the GPU's memory (the default is 75%).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import keras # For defining and training models.
import keras_nlp # For loading the Keras implementation of Gemma.
import pandas as pd # For loading the dataset.
from textwrap import fill # For formatting long paragraphs.
from ai_foundations import formatting # For formatting the training data.

keras.utils.set_random_seed(812)  # Seed Python, NumPy, and the backend for reproducibility.

## Data preprocessing

In [None]:
def format_question(
    question: str,
    sot = "<start_of_turn>",
    eot = "<end_of_turn>"
) -> str:
    """
    Formats a question for prompting the model and adds special delimiters at
    the start and end of the question.

    Args:
      question: The question to be formatted.
      sot: The token to mark the start of a turn.
      eot: The token to mark the end of a turn.

    Returns:
      Formatted string of the question.
    """

    formatted_q = f"{sot}user\n{question}{eot}\n"

    return formatted_q

# Load the question-answer dataset.
africa_galore_qa = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore_qa_v2.json"
)

questions = []  # List of formatted questions.
answers = []  # List of formatted answers.

for idx, row in africa_galore_qa.iterrows():
    # Run the format_qa function from the previous lab to format the question
    # and the answer.
    question, answer = formatting.format_qa(row)
    questions.append(question)
    answers.append(answer)

# Show the first set of input and output.
print(questions[0])
print(fill(answers[0], replace_whitespace=False))

# Prepare the data dictionary for fine-tuning Gemma.
data = {
    "prompts": questions,
    "responses": answers
}

## Coding Activity 2: Fine-tune with gradient accumulation

You have just confirmed that your GPU cannot handle a physical batch size of 4. One remedy is to simulate a larger batch size using **gradient accumulation**.

This technique works by running the forward and backward passes on several smaller batches (often called micro-batches) one after another and accumulating (summing) their gradients. Only after processing a specified number of micro-batches does the optimizer update the model's weights, using the average of the accumulated gradients. Because only one micro-batch's activations need to be held in memory at a time, this gives approximately the same learning signal as a single large batch without the high memory cost. The trade-off is time: each weight update still requires processing all of the micro-batches sequentially.
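The following NumPy sketch illustrates the mechanism on a toy linear-regression problem rather than on Gemma. `mse_gradient` and `GradientAccumulator` are hypothetical helpers written for this illustration, not Keras APIs. The accumulator sums the gradients of eight micro-batches of size 1 and then averages them, which reproduces the gradient of the mean-reduced loss on one batch of eight.

```python
import numpy as np


def mse_gradient(w: np.ndarray, X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of the mean squared error of a linear model X @ w w.r.t. w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)


class GradientAccumulator:
    """Hypothetical helper that mimics the idea behind gradient accumulation.

    Gradients from `steps` micro-batches are summed; on the last micro-batch,
    their average is returned so that a single weight update can be applied.
    """

    def __init__(self, steps: int):
        self.steps = steps
        self.count = 0
        self.total = None

    def add(self, grad: np.ndarray):
        self.total = grad.copy() if self.total is None else self.total + grad
        self.count += 1
        if self.count < self.steps:
            return None  # Keep accumulating; no weight update yet.
        averaged = self.total / self.steps
        self.count, self.total = 0, None  # Reset for the next update.
        return averaged


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))  # 8 examples with 3 features each.
y = rng.normal(size=8)
w = np.zeros(3)

# Gradient computed on one physical batch of 8 examples.
full_batch_grad = mse_gradient(w, X, y)

# The same 8 examples, processed as 8 micro-batches of size 1.
accumulator = GradientAccumulator(steps=8)
for i in range(8):
    update = accumulator.add(mse_gradient(w, X[i:i + 1], y[i:i + 1]))

print(np.allclose(update, full_batch_grad))  # True
```

Only the gradient returned on the eighth call would be passed to the optimizer; the first seven calls return `None`, which corresponds to steps where the weights are not updated.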

Run the following cells to load the model again and enable LoRA for fine-tuning.

In [None]:
# Load the Gemma3-4B Keras model with bfloat16 precision.
model = keras_nlp.models.Gemma3CausalLM.from_preset(
    preset="gemma3_4b_text", dtype="bfloat16"
)
model.summary()

In [None]:
model.backbone.enable_lora(rank=4)
model.summary()

------
> 💻 **Your task**:
>
> In the following cell, configure the `AdamW` optimizer to use gradient accumulation.
>
> To do this, you need to add the `gradient_accumulation_steps` argument directly to the optimizer. Configure it to simulate a batch size of **8** (i.e., 8 times as large as the physical batch size).
>
> ```python
> optimizer = keras.optimizers.AdamW(
>     learning_rate=8e-4,
>     ...,
>     gradient_accumulation_steps=<NUM_STEPS>
> )
> ```
>
------
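The arithmetic behind `gradient_accumulation_steps`, and behind the learning-rate comment in the next cell, can be sketched in plain Python. The helper names are hypothetical, and scaling the learning rate linearly with the batch size is a common heuristic rather than a guarantee, especially for adaptive optimizers such as AdamW. The example uses a physical batch size of 2 so that the activity's own value is left for you to work out.

```python
def accumulation_steps(effective_batch_size: int, physical_batch_size: int) -> int:
    """Number of micro-batches to accumulate before each weight update."""
    if effective_batch_size % physical_batch_size != 0:
        raise ValueError(
            "The effective batch size must be a multiple of the physical batch size."
        )
    return effective_batch_size // physical_batch_size


def linearly_scaled_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling heuristic: scale the learning rate with the batch size."""
    return base_lr * new_batch_size / base_batch_size


# A physical batch size of 2 accumulated up to an effective batch size of 8.
print(accumulation_steps(8, 2))  # 4

# Scaling the original learning rate (1e-4 at batch size 1) to batch size 8.
print(linearly_scaled_lr(1e-4, 1, 8))
```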

In [None]:
# Set hyperparameters.
optimizer = keras.optimizers.AdamW(
    # Multiply the learning rate to match the new effective batch size.
    learning_rate=8e-4,
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-6,
    clipnorm=0.5,
    # Add your code here.
)

# Determine the number of epochs.
num_epochs = 4

# Set the maximum length.
model.preprocessor.sequence_length = 400

# Compile the model with the optimizer.
model.compile(
    optimizer=optimizer,
)

In [None]:
# Define a list of prompts you want to check after each epoch.
test_prompts = [
    "What is Kente cloth?",
    "What is Kilimanjaro?",
    "What is Tokyo?"
]

class GenerationMonitor(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"\n--- Generations after epoch {epoch + 1} ---")
        for prompt in test_prompts:
            # Format the prompt correctly for the model.
            formatted_prompt = format_question(prompt)

            # Generate and print the output.
            output = self.model.generate(formatted_prompt, max_length=150)
            print(output)
            print("-" * 20)

In [None]:
# Create an instance of the monitoring callback.
generation_callback = GenerationMonitor()

# Train the model.
history = model.fit(
    data,
    epochs=num_epochs,
    batch_size=1,
    callbacks=[generation_callback],
)

### What did you observe?

Across epochs, you likely observed that the generations become steadily more coherent and on-topic. After one epoch, there is still some incoherence. However, notice that the model does not generate any single-token loops like the "KMnO4" output from the previous training run in the lab "Fine-Tune a Model with bfloat16." By the end of the second epoch, the outputs are fluent and factual.

The prior lab's training run used a physical batch size of 1. In comparison, this run with gradient accumulation produced coherent results more quickly and finished with a lower loss and higher accuracy. This suggests that the larger effective batch size made possible by gradient accumulation leads to more stable training and smoother training dynamics.

## Summary

This lab introduced you to another practical solution for training with limited GPU resources. You explored how you can use a larger effective batch size than your GPU's memory can physically handle. This can be done through **gradient accumulation**. By accumulating gradients over several smaller batches before updating the weights, you can successfully simulate a larger batch size, gaining the benefit of more stable gradients without the high memory cost.

## Solutions

The following cells provide reference solutions to the coding activities in this notebook. If you really get stuck after trying to solve the activities yourself, you may want to consult these solutions.

It is recommended that you *only* look at the solutions after you have tried to solve the activities *multiple times*. The best way to learn challenging concepts in computer science and artificial intelligence is to debug your code piece-by-piece until it works, rather than copying existing solutions.

If you feel stuck, you may want to first try to debug your code, for example, by adding additional print statements to see what your code is doing at every step. This will provide you with a much deeper understanding of the code and the materials. It will also provide you with practice on how to solve challenging coding problems beyond this course.

To view the solutions for an activity, click on the arrow to the left of the activity name. If you consult the solutions, do not copy and paste them into the cells above. Instead, look at them, and type them manually into the cell. This will help you understand where you went wrong.

### Coding Activity 1

In [None]:
# Train the model.
history = model.fit(
    data,
    epochs=num_epochs,
    batch_size=4,
    callbacks=[generation_callback],
)

### Coding Activity 2

In [None]:
# Set hyperparameters.
optimizer = keras.optimizers.AdamW(
    learning_rate=8e-4, # Multiply the learning rate to match the new effective batch size.
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-6,
    clipnorm=0.5,
    gradient_accumulation_steps=8,
)