##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tuning Gemma with Torch XLA and Hugging Face TRL

Welcome to this step-by-step guide on fine-tuning [Gemma](https://huggingface.co/google/gemma-2b) using [Torch XLA](https://github.com/pytorch/xla).


[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.

[**Torch XLA**](https://pytorch.org/xla/) enables you to leverage the computational power of TPUs (Tensor Processing Units) for efficient training of deep learning models. By interfacing PyTorch with the [XLA (Accelerated Linear Algebra)](https://openxla.org/xla) compiler, Torch XLA translates PyTorch operations into XLA operations that can be executed on TPUs. This means you can write your models in PyTorch as usual, and Torch XLA handles the underlying computations to run them efficiently on TPUs.

[**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) is a framework developed by Hugging Face to fine-tune and align both transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and others.

Integrating PyTorch with XLA allows developers to run PyTorch code on TPUs with minimal changes to their existing codebase. This seamless integration provides the performance benefits of TPUs while maintaining the flexibility and ease of use of the PyTorch framework.

By the end of this notebook, you will learn:

- About Torch XLA
- How to perform **Parameter-Efficient Fine-Tuning (PEFT)** with **Low-Rank Adaptation (LoRA)** on [Gemma 2 2B](https://huggingface.co/google/gemma-2-2b) using Hugging Face's **TRL** framework, **Torch XLA**, and TPUs.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_Torch_XLA.ipynb)

[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/notebooks/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_Torch_XLA.ipynb)

## Setup

### Selecting the Runtime Environment

To start, you can choose either **Google Colab** or **Kaggle** as your platform. Select one, and proceed from there.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>

  1. Click **Open in Colab**.
  2. In the menu, go to **Runtime** > **Change runtime type**.
  3. Under **Hardware accelerator**, select **TPU**.
  4. Ensure that the **TPU type** is set to **TPU v2-8**.

- #### **Kaggle** <img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png" alt="Kaggle" width="40"/>

  1. Click **Open in Kaggle**.
  2. Click on **Settings** in the right sidebar.
  3. Under **Accelerator**, select **TPUs**.
    - Note: Kaggle currently provides **TPU v3-8**.
  4. Save the settings, and the notebook will restart with TPU support.


### Gemma using Hugging Face

Before diving into the tutorial, let's set up Gemma:

1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.co/join).
2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.
3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.co/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.

**Once you've completed these steps, you're ready to move on to the next section, where you'll store the token as a secret in your Colab or Kaggle environment.**

### Configure Your Credentials

To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>
  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:
  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
  2. **Add Hugging Face Token**:
    - Create a new secret with the **name** `HF_TOKEN`.
    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.
    - **Toggle** the button on the left to allow notebook access to the secret

- #### **Kaggle** <img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png" alt="Kaggle" width="40"/>
  To securely use your Hugging Face token (`HF_TOKEN`) in this notebook, you'll need to add it as a secret in your Kaggle environment:  
  1. Open your Kaggle notebook and locate the **Add-ons** menu at the top of the notebook interface.
  2. Click on **Secrets** to manage your environment secrets.  
  <img src="https://i.imgur.com/vxrtJuM.png" alt="The Secrets option is found at the top." width=50%>
  3. **Add Hugging Face Token**:
      - Click on the **Add secret** button.
      - In the **Label** field, enter `HF_TOKEN`.  
      - In the **Value** field, paste your Hugging Face token.
      - Click **Save** to add the secret.

In [None]:
import os
import sys

if 'google.colab' in sys.modules:
    # Running on Colab
    from google.colab import userdata
    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
elif os.path.exists('/kaggle/working'):
    # Running on Kaggle
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    os.environ['HF_TOKEN'] = user_secrets.get_secret("HF_TOKEN")
else:
    # Not running on Colab or Kaggle
    raise EnvironmentError('This notebook is designed to run on Google Colab or Kaggle.')

This code retrieves your secrets and sets them as environment variables, which you will use later in the tutorial.
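Before moving on, it can help to fail early if the token did not make it into the environment. The helper below is a minimal sketch; `require_secret` is a hypothetical name introduced for illustration, not part of Colab, Kaggle, or Hugging Face.

```python
import os


def require_secret(name, env=None):
    """Return the value stored under `name`, or raise a clear error if it is missing or empty.

    `env` defaults to the process environment; passing a plain dict makes the
    helper easy to try out without touching real secrets.
    """
    env = os.environ if env is None else env
    value = env.get(name)
    if not value:
        raise RuntimeError(
            f"{name} is not set. Add it as a notebook secret and re-run the cell above."
        )
    return value


# Usage with a stand-in mapping instead of the real environment
fake_env = {"HF_TOKEN": "hf_example_token"}
token = require_secret("HF_TOKEN", fake_env)
print(f"Token loaded ({len(token)} characters)")
```

In the notebook itself you would call `require_secret("HF_TOKEN")` with no second argument so that it reads from `os.environ`.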

### Setting Up the Environment

Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.


In [None]:
# Uninstall any existing TensorFlow installation, then install the CPU-only version to avoid conflicts while using the TPU.
!pip uninstall -y tensorflow tf-keras
!pip install tensorflow==2.18.0 tf-keras==2.18.0

!pip uninstall tensorflow -y
!pip install tensorflow-cpu==2.18.0 -q

# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma model and PEFT.
!pip install transformers==4.46.1 -U -q
!pip install datasets==3.1.0 -U -q
!pip install trl==0.12.0 peft==0.13.2 -U -q
!pip install accelerate==0.34.0 -U -q

# Install PyTorch and Torch XLA with versions compatible with the TPU runtime, ensuring efficient TPU utilization.
!pip install -qq torch~=2.5.0 --index-url https://download.pytorch.org/whl/cpu
!pip install -qq torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html

# Install the `tpu-info` package to display TPU-related information
!pip install tpu-info

Found existing installation: tensorflow 2.15.0
Uninstalling tensorflow-2.15.0:
  Successfully uninstalled tensorflow-2.15.0
Found existing installation: tf_keras 2.15.1
Uninstalling tf_keras-2.15.1:
  Successfully uninstalled tf_keras-2.15.1
Collecting tensorflow==2.18.0
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tf-keras==2.18.0
  Downloading tf_keras-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow==2.18.0)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow==2.18.0)
  Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow==2.18.0)
  Downloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting namex (from keras>=3.5.0->tensorflow==2.18.0)
  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes

**Note**: Ensure that your PyTorch and Torch XLA versions are compatible with the TPU you're using.

### Verify TPU Setup


Run `!tpu-info` to verify that the TPU has been properly initialized.

In [None]:
!tpu-info

TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━┓
┃ Chip        ┃ Type        ┃ Devices ┃ PID  ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━┩
│ /dev/accel0 │ TPU v2 chip │ 2       │ None │
│ /dev/accel1 │ TPU v2 chip │ 2       │ None │
│ /dev/accel2 │ TPU v2 chip │ 2       │ None │
│ /dev/accel3 │ TPU v2 chip │ 2       │ None │
└─────────────┴─────────────┴─────────┴──────┘
Libtpu metrics unavailable. Is there a framework using the TPU? See https://github.com/google/cloud-accelerator-diagnostics/tree/main/tpu_info for more information


If everything is set up correctly, you should see the TPU details printed out.


## Import the libraries

Now, import all the necessary libraries required for fine-tuning.


In [None]:
import pandas as pd

import torch
print(f"PyTorch version: {torch.__version__}")

import torch_xla
print(f"Torch XLA version: {torch_xla.__version__}")

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, PeftModel

from datasets import load_dataset

# Enable Single Program Multiple Data (SPMD) mode,
# which allows for parallel execution across multiple TPU cores
xr.use_spmd()

PyTorch version: 2.5.0+cpu
Torch XLA version: 2.5.0+libtpu


This setup ensures that your environment is correctly configured to use TPUs with PyTorch.


## Fine-tune using PEFT and LoRA

Traditional fine-tuning of large language models (LLMs) like Gemma requires adjusting billions of parameters, making it resource-intensive. This process demands significant computational power and time, which can be impractical for many use cases. That's where Parameter-Efficient Fine-Tuning (PEFT) techniques come in.

### Parameter-Efficient Fine-Tuning (PEFT)

PEFT allows you to adapt large models to specific tasks by updating only a small portion of their parameters. Instead of retraining the entire model, PEFT adds lightweight layers or adapters. Most of the pre-trained weights remain frozen. This approach greatly reduces the computational requirements and the amount of data needed for fine-tuning, making it feasible to fine-tune large models even on modest hardware.

### Low-Rank Adaptation (LoRA)

Among these techniques, one effective option is Low-Rank Adaptation (LoRA). LoRA introduces small, trainable matrices into the model's architecture, specifically targeting the attention layers of Transformer models. Instead of updating the full weight matrices, LoRA adds rank-decomposed matrices, making adaptation more efficient.

#### Key Advantages of LoRA

- **Efficiency**: LoRA significantly reduces the number of trainable parameters by using low-rank adaptations, making the fine-tuning process much more efficient.
- **Memory Savings**: Since only the additional low-rank matrices are updated, GPU/TPU memory requirements are considerably lower.
- **Modularity**: LoRA adapters can be easily merged with the original model or kept separate, offering flexibility for deployment.
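To make the mechanism concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It follows the usual LoRA formulation, `h = W x + (alpha / r) * B (A x)`, where `W` stays frozen and only the small matrices `A` and `B` are trained. The matrices and helper names are toy values chosen for illustration; this is not how the PEFT library implements it internally.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Compute h = W x + (alpha / r) * B (A x) for a LoRA-adapted linear layer."""
    base = matvec(W, x)                # frozen pre-trained path
    update = matvec(B, matvec(A, x))   # low-rank trainable path
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Toy shapes: d_out = 2, d_in = 3, rank r = 1
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen weight, 2 x 3
A = [[0.5, -0.5, 1.0]]                    # trainable, r x d_in
B_zero = [[0.0], [0.0]]                   # trainable, d_out x r (zero at initialization)
B_trained = [[1.0], [-2.0]]               # hypothetical values after some training
x = [1.0, 2.0, 3.0]

h_start = lora_forward(W, A, B_zero, x, alpha=2, r=1)
h_later = lora_forward(W, A, B_trained, x, alpha=2, r=1)
print(h_start)  # identical to the base model output while B is zero
print(h_later)
```

Because `B` starts at zero in the original LoRA formulation, the adapted layer initially reproduces the base model exactly, and training only moves the output through the low-rank path.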

In the next section, you'll explore how to implement PEFT with LoRA to fine-tune Gemma using Torch XLA on TPUs and perform the following steps:

- Load a dataset
- Configure the training parameters
- Load the Gemma model and tokenizer
- Fine-tune the model on TPUs using **TRL**'s `SFTTrainer` class

### Load a dataset

For this guide, you'll use an existing dataset from Hugging Face. You can replace it with your own dataset if you prefer.

The dataset chosen for this guide is [**hieunguyenminh/roleplay**](https://huggingface.co/datasets/hieunguyenminh/roleplay), which embodies a wide range of original characters, each with a unique persona. It includes fictional characters, complete with their own backgrounds, core traits, relationships, goals, and distinct speaking styles.

**Credits:** **https://huggingface.co/hieunguyenminh**

You specify the dataset name and use the `load_dataset` function from the Hugging Face `datasets` library to load the training split of the dataset.

In [None]:
dataset_name = "hieunguyenminh/roleplay"
dataset = load_dataset(dataset_name, split="train")
dataset

Dataset({
    features: ['name', 'description', 'text'],
    num_rows: 5755
})

First, let's split the dataset into training and validation sets.

In [None]:
# The first 80% of `train` for training
train_dataset = load_dataset(dataset_name, split='train[:80%]')
# The last 20% of `train` for evaluation
valid_dataset = load_dataset(dataset_name, split='train[-20%:]')

In [None]:
train_dataset

Dataset({
    features: ['name', 'description', 'text'],
    num_rows: 4604
})

In [None]:
valid_dataset

Dataset({
    features: ['name', 'description', 'text'],
    num_rows: 1151
})
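The row counts above follow directly from the percent slices. The sketch below reproduces the arithmetic in plain Python, assuming the percent boundaries land on whole row indices (as they do for 5755 rows); it does not model how the `datasets` library rounds boundaries that are not whole numbers.

```python
num_rows = 5755  # size of the full `train` split shown earlier

# 'train[:80%]' keeps the first 80% of the rows
train_end = num_rows * 80 // 100
train_indices = range(0, train_end)

# 'train[-20%:]' keeps the last 20% of the rows
valid_start = num_rows - num_rows * 20 // 100
valid_indices = range(valid_start, num_rows)

print(len(train_indices), len(valid_indices))

# The two slices do not overlap and together cover every row
overlap = set(train_indices) & set(valid_indices)
print(len(overlap), len(train_indices) + len(valid_indices) == num_rows)
```

Since the slices are contiguous and disjoint, no example used for training appears in the validation set.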

Let's look at a few samples to understand the data.


In [None]:
valid_dataset[10]['text']

'<|system|>Introducing Minami "Echo" Ishikawa, a mysterious VR assassin known for her uncanny ability to blend seamlessly into the shadows. Minami possesses a deep understanding of stealth techniques, allowing her to silently eliminate her targets with calculated precision. Her cold and calculating demeanor makes her a formidable force to be reckoned with, leaving enemies shivering at the thought of facing her wrath.</s>\n<|user|>Echo, what makes you so adept at disappearing into thin air?</s>\n<|assistant|>Shadows are my allies, and silence is my weapon. I move as fluidly as the wind, becoming one with my surroundings. My footsteps are whispers, and my presence is a phantom in the night.</s>\n<|user|>What\'s the key to executing a flawless assassination, Echo?</s>\n<|assistant|>Precision is paramount. Every movement, every breath, calculated to achieve the perfect strike. My targets never see me coming, their lives snuffed out before they even realize they\'re in danger.</s>\n<|user|>

In [None]:
if 'google.colab' in sys.modules:
    from google.colab import data_table

    # Enable interactive DataFrame display
    data_table.enable_dataframe_formatter()

# Convert the 'train' split to a Pandas DataFrame
df = pd.DataFrame(train_dataset)

# Select the 'text' column and exclude the rest
df_text = df[['text']]
df_text

### Training Configuration

Next, define the hyperparameters and configuration for the fine-tuning process:

- The base model and new model names
- LoRA Configuration
- Training Arguments
- SFT Parameters
- Misc. Parameters


You start by specifying the base model (`google/gemma-2-2b`) and the directory where the fine-tuned model will be saved (`gemma-ft`).


In [None]:
# Define model names
model_name = "google/gemma-2-2b"
new_model = "gemma-ft"

LoRA (Low-Rank Adaptation) allows for efficient fine-tuning by adapting only a subset of model parameters.

Here, you set the following parameters:
- `lora_r` to 64, which controls the rank of the adaptation matrices,
- `lora_alpha` to 32 for scaling
- `lora_dropout` to 0.1 to prevent overfitting.

In [None]:
# LoRA attention dimension
lora_r = 64 # @param {"type":"slider","min":0,"max":64,"step":2}
# Alpha parameter for LoRA scaling
lora_alpha = 32 # @param {"type":"slider","min":0,"max":64,"step":2}
# Dropout probability for LoRA layers
lora_dropout = 0.1 # @param {"type":"slider","min":0,"max":1,"step":0.01}

Set up the training arguments that define how the model will be trained.

Here, you'll define the **output directory**, **number of training epochs**, and **batch sizes** for training and evaluation. You set `max_grad_norm` for gradient clipping to stabilize training, and configure the **learning rate**, **optimizer**, and **learning rate scheduler**. Gradient checkpointing, which saves memory, is not set in this cell; it is enabled later through the FSDP configuration (`xla_fsdp_grad_ckpt`). `max_steps` is set to **-1** so that the number of epochs controls the training duration.


In [None]:
# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results" # @param {"type":"string"}
# Number of training epochs
num_train_epochs = 5 # @param {"type":"slider","min":1,"max":20,"step":2}
# Training batch size (with SPMD enabled, this is the global batch size, not per core)
per_device_train_batch_size = 32 # @param {"type":"slider","min":1,"max":64,"step":1}
# Evaluation batch size (with SPMD enabled, this is the global batch size, not per core)
per_device_eval_batch_size = 32 # @param {"type":"slider","min":1,"max":64,"step":1}
# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 1 # @param {"type":"slider","min":0,"max":16,"step":2}
# Maximum gradient norm (gradient clipping)
max_grad_norm = 0.3 # @param {"type":"slider","min":0,"max":1,"step":0.01}
# Initial learning rate (adafactor optimizer)
learning_rate = 0.0002 # @param {"type":"slider","min":0.00001,"max":0.0005,"step":0.00001}
# Optimizer to use
optim = "adafactor" # adafactor, adamw_torch_fused
# Learning rate schedule (constant a bit better than cosine)
lr_scheduler_type = "constant"
# Number of training steps (overrides num_train_epochs)
max_steps = -1
# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03 # @param {"type":"slider","min":0,"max":0.1,"step":0.01}
# Enable bfloat16 precision
bf16 = True
# Log every X update steps
logging_steps = 1

In the SFT parameters, `max_seq_length` is set to 512 to define the maximum token length for inputs, and `packing` is enabled to pack multiple shorter sequences into one input for efficiency.

In [None]:
# Maximum sequence length to use
max_seq_length = 512 # @param {"type":"slider","min":32,"max":1024,"step":2}
# Pack multiple short examples in the same input sequence to increase efficiency
packing = True # False

### Under the hood: PyTorch and XLA

PyTorch builds computation graphs dynamically as your Python code runs, and its `autograd` system records them for differentiation. The TPU does not directly execute Python code; instead, it runs the computation graph defined by your PyTorch program. Behind the scenes, the **XLA (Accelerated Linear Algebra)** compiler transforms the PyTorch computation graph into TPU machine code. This compiler also performs numerous advanced optimizations on your code and memory layout. The compilation occurs automatically as tasks are sent to the TPU, and you do not need to include XLA in your build chain explicitly.

<img src="https://storage.googleapis.com/gweb-cloudblog-publish/images/1_PyTorchXLA_stack_diagram.max-800x800.png" alt="PyTorch and XLA 2.3 from https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-xla-2-3" width=50%>

The combination of **PyTorch** and **XLA** offers several key advantages:

1. **Seamless Performance Enhancement:** Maintain PyTorch's intuitive and pythonic workflow while effortlessly achieving significant performance gains through the XLA compiler. This integration allows you to optimize your models without altering your familiar coding practices.

2. **Comprehensive Ecosystem Access:** Leverage PyTorch's extensive ecosystem, including a wide range of tools, pretrained models, and a vibrant community. This access enables you to accelerate development, utilize state-of-the-art resources, and benefit from collective expertise.

Harnessing these advantages, you can efficiently fine-tune your custom Gemma model using TPUs.
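The record-first, run-later flow described in this section can be pictured with a toy example. The class below is only an analogy written in plain Python: operations build a small graph instead of computing immediately, and the graph is evaluated on demand. None of these names belong to Torch XLA, and a real compiler would optimize the graph before running it.

```python
class LazyValue:
    """A node in a tiny computation graph; nothing is computed until `execute` is called."""

    def __init__(self, op, inputs=(), value=None):
        self.op = op
        self.inputs = inputs
        self.value = value

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))


def constant(value):
    return LazyValue("const", value=value)


def trace(node):
    """Return the recorded operations in execution (post-order) order."""
    ops = []
    for child in node.inputs:
        ops.extend(trace(child))
    ops.append(node.op)
    return ops


def execute(node):
    """Evaluate the recorded graph; a real backend would compile it first."""
    if node.op == "const":
        return node.value
    left, right = (execute(child) for child in node.inputs)
    return left + right if node.op == "add" else left * right


a, b = constant(2.0), constant(3.0)
c = a * b + a          # only records the graph
print(trace(c))
print(execute(c))      # the computation happens here
```

The Python code describes what to compute, while the actual work is deferred until the whole graph is known, which is what gives a compiler room to optimize.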

### Fine-tune Gemma using TPUs

The training leverages PyTorch, XLA and TPUs for efficient computation and uses LoRA for parameter-efficient fine-tuning, which reduces the number of trainable parameters by adapting only specific layers.

Here, you'll be setting up the following:

- The Gemma **base model** and **tokenizer**
- The **LoRA (Low-Rank Adaptation)** configuration for **PEFT (Parameter-Efficient Fine-Tuning)**
- The [**FSDP**](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#how-fsdp-works) configuration for efficient TPU training
- The **Hugging Face `SFTTrainer` instance** using the training and SFT parameters

First, load the Gemma 2 2B pre-trained model weights using `AutoModelForCausalLM`, setting `torch_dtype` to `torch.bfloat16` for optimal performance on TPUs.

In [None]:
# Load the Gemma pretrained model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16
)

# You must disable the cache to prevent issues during training
model.config.use_cache = False

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Next, you load the Gemma tokenizer using `AutoTokenizer` from Hugging Face. You adjust the tokenizer's padding side (and token if applicable) here to ensure compatibility during training.


In [None]:
# Load the Gemma tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# You adjust the tokenizer's padding side to ensure compatibility during TPU
# training.
tokenizer.padding_side = "right" # Fix overflow issue with bf16/fp16 training

You've now loaded the base Gemma model and tokenizer and set up the configurations for fine-tuning. Next, initialize the **LoRA** config. The PEFT library provides a handy [LoraConfig](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig) that defines which layers of the base model receive adapters. LoRA is typically applied to the linear projection matrices of the attention layers of a **Transformer**. You'll pass this configuration to the `SFTTrainer` class later in the tutorial.

The `LoraConfig` is initialized with the previously defined LoRA parameters, specifying the target modules (`k_proj` and `v_proj`) in the model to apply LoRA adaptations.

In [None]:
# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj"]
)
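As a rough budget check, the sketch below estimates how many trainable parameters this LoRA configuration adds. The architecture numbers (hidden size, key/value projection width, layer count) are assumptions about the Gemma 2 2B configuration and should be verified against `model.config`; the rank and alpha come from the values set earlier in this guide.

```python
# Assumed Gemma 2 2B dimensions (verify against model.config before relying on them)
hidden_size = 2304          # input width of k_proj / v_proj (assumption)
kv_width = 4 * 256          # num_key_value_heads * head_dim (assumption)
num_layers = 26             # number of decoder layers (assumption)

# LoRA settings from this guide
lora_r = 64
lora_alpha = 32
target_modules = ["k_proj", "v_proj"]

# Each adapted module gains A (r x d_in) and B (d_out x r)
params_per_module = lora_r * (hidden_size + kv_width)
trainable_params = params_per_module * len(target_modules) * num_layers

# Standard LoRA scales the low-rank update by alpha / r
scaling = lora_alpha / lora_r

print(f"Per module: {params_per_module:,}")
print(f"Total trainable (estimate): {trainable_params:,}")
print(f"Scaling factor: {scaling}")
```

Under these assumptions the adapters add on the order of ten million parameters, a small fraction of the base model. If you want the exact figure, PEFT models expose `print_trainable_parameters()`.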

The **Fully Sharded Data Parallel (FSDP)** configuration is set up in `fsdp_config`. It wraps each decoder layer for [**full model sharding**](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy) to save memory on TPUs, and enables [**gradient checkpointing**](https://huggingface.co/docs/transformers/v4.19.4/en/performance#gradient-checkpointing) with `xla_fsdp_grad_ckpt`.

In [None]:
# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": [
        "Gemma2DecoderLayer"
    ],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True
}
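To get a feel for why sharding helps, the sketch below estimates the memory needed for the model weights alone, with and without full sharding. The parameter count is a rounded assumption for Gemma 2 2B, and the device count of 8 is taken from the `tpu-info` output earlier (four chips with two devices each). Gradients, optimizer state, and activations are deliberately ignored, so treat the result as a lower bound.

```python
num_params = 2_600_000_000   # approximate parameter count (assumption)
bytes_per_param = 2          # bfloat16 stores each value in 2 bytes
num_devices = 8              # 4 chips x 2 devices on a v2-8

full_copy_bytes = num_params * bytes_per_param
sharded_bytes_per_device = full_copy_bytes // num_devices

print(f"Full copy on one device: {full_copy_bytes / 1e9:.2f} GB")
print(f"Per device when fully sharded: {sharded_bytes_per_device / 1e9:.2f} GB")
```

With full sharding, each device holds only its slice of the weights and gathers the rest when a layer is needed, which frees memory for activations and batches.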

The `SFTConfig` is then initialized with all the training parameters defined earlier, including optimizer settings, learning rate, and logging configurations. Reporting to external trackers is disabled with `report_to="none"`; set it to `"tensorboard"` if you want logs written for TensorBoard.

In [None]:
# Set training parameters
training_arguments = SFTConfig(
    output_dir=output_dir,
    overwrite_output_dir=True,
    # Training
    num_train_epochs=num_train_epochs,
    # This is the global train batch size for SPMD
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    # Required for SPMD
    dataloader_drop_last=True,
    fsdp="full_shard",
    fsdp_config=fsdp_config,
    learning_rate=learning_rate,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
    max_seq_length=max_seq_length,
    dataset_text_field="text",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    },
    packing=packing,
    # Evaluation
    eval_strategy="epoch",
    # This is the global eval batch size for SPMD
    per_device_eval_batch_size=per_device_eval_batch_size,
    # Logging
    logging_steps=logging_steps,
    report_to="none",
    seed=42
)



Finally, you define the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) available in the TRL library. This class inherits from the `Trainer` class available in the Transformers library, but is specifically optimized for supervised fine-tuning (instruction tuning). It can be used to train out-of-the-box on one or more GPUs/TPUs, using [Accelerate](https://huggingface.co/docs/accelerate/index) as the backend.

Most notably, it supports [packing](https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-), where multiple short examples are packed in the same input sequence to increase training efficiency.

In [None]:
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    peft_config=peft_config,
    args=training_arguments
)



Now, let's start the fine-tuning process by calling `trainer.train()`, which uses `SFTTrainer` to handle the training loop, including data loading, forward and backward passes, and optimizer steps, all configured according to the settings you've provided.

In [None]:
trainer.train()


| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1 | 1.1016 | 1.291131 |
| 2 | 0.8438 | 1.298483 |
| 3 | 0.8633 | 1.269991 |
| 4 | 0.7305 | 1.299862 |
| 5 | 0.498 | 1.407629 |



TrainOutput(global_step=710, training_loss=0.900107284330986, metrics={'train_runtime': 2044.2422, 'train_samples_per_second': 11.114, 'train_steps_per_second': 0.347, 'total_flos': 1.8324321780891648e+17, 'train_loss': 0.900107284330986, 'epoch': 5.0})
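The numbers reported above can be cross-checked with a few lines of arithmetic. The sketch below copies the per-epoch losses and the `TrainOutput` values from this run, finds the epoch with the lowest validation loss, and compares the step count against the reported throughput. It assumes a batch of 32 sequences per step, as configured earlier.

```python
# (epoch, training loss, validation loss) copied from the run above
history = [
    (1, 1.1016, 1.291131),
    (2, 0.8438, 1.298483),
    (3, 0.8633, 1.269991),
    (4, 0.7305, 1.299862),
    (5, 0.498, 1.407629),
]

# The epoch with the lowest validation loss is the natural checkpoint candidate
best_epoch, _, best_val_loss = min(history, key=lambda row: row[2])
print(f"Lowest validation loss {best_val_loss} at epoch {best_epoch}")

# Values from TrainOutput and the training configuration
global_step = 710
num_epochs = 5
batch_size = 32
train_runtime = 2044.2422
samples_per_second = 11.114

steps_per_epoch = global_step // num_epochs
sequences_from_steps = global_step * batch_size
sequences_from_throughput = samples_per_second * train_runtime

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Sequences processed: {sequences_from_steps} (from steps), "
      f"{sequences_from_throughput:.0f} (from throughput)")
```

The training loss keeps falling while the validation loss rises after epoch 3, which suggests the later epochs overfit; fewer epochs or a checkpoint from around epoch 3 may generalize better.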

After training is complete, you save the fine-tuned model by moving it to the CPU with `trainer.model.to('cpu')` to ensure compatibility and then calling `save_pretrained(new_model)` to save the model weights and configuration files to the directory specified by `new_model` (**gemma-ft**). This allows you to reload and use the fine-tuned model later for inference or further training.

In [None]:
# Remove the model weights directory if it exists
!rm -rf gemma-ft

# Save the LoRA adapter
trainer.model.to('cpu').save_pretrained(new_model)

## Prompt using the newly fine-tuned model


Now that you've fine-tuned your custom Gemma model, reload the LoRA adapter weights so you can prompt the model and verify that it works as intended.

To do this, use the following steps to correctly reload the adapter weights:

- Use `AutoModelForCausalLM.from_pretrained` to first load the **base Gemma model**, setting `low_cpu_mem_usage=True` to reduce peak host memory while the weights load (the model is loaded on the CPU here) and `torch_dtype=torch.bfloat16` for consistency with the fine-tuned model.

- Load the **fine-tuned LoRA adapter** that you've previously saved into the base model using `PeftModel.from_pretrained`, where `new_model` is the directory containing your fine-tuned weights.

- The `model.merge_and_unload` function **merges** the **LoRA adapter weights** with the **base model weights** and unloads the adapter, resulting in a standalone model ready for inference.

In [None]:
# Reload the fine-tuned Gemma model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
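Conceptually, merging folds the low-rank update into the base weight so that a single matrix produces the same output as the base layer plus the adapter. The pure-Python sketch below checks this on toy matrices using the standard LoRA relation `W_merged = W + (alpha / r) * B A`; the helper names and values are illustrative and do not reflect PEFT's internal code.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
        for i in range(len(X))
    ]


def matvec(matrix, vector):
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def merge_lora(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A), the merged standalone weight."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]


W = [[1.0, 2.0], [3.0, 4.0]]   # base weight, 2 x 2
A = [[1.0, -1.0]]              # adapter A, r x d_in with r = 1
B = [[2.0], [0.5]]             # adapter B, d_out x r
alpha, r = 2, 1
x = [1.0, 2.0]

W_merged = merge_lora(W, A, B, alpha, r)

# Output with the adapter kept separate
separate = [
    base + (alpha / r) * upd
    for base, upd in zip(matvec(W, x), matvec(B, matvec(A, x)))
]
# Output from the merged weight alone
merged = matvec(W_merged, x)

print(W_merged)
print(separate, merged)
```

Because the two outputs match, the merged model needs no adapter at inference time and can be used like any standalone checkpoint.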

You reload the tokenizer to ensure it matches the model configuration, adjusting the padding side as before.

In [None]:
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

Now, test the fine-tuned model with a sample prompt by first using the tokenizer to generate the input ids, and then relying on the reloaded fine-tuned model to generate a response using `model.generate()`.

In [None]:
input_text = """\
  <|system|>Introducing Minami "Echo" Ishikawa, a mysterious VR assassin known for her uncanny ability to blend seamlessly into the shadows. \
  Minami possesses a deep understanding of stealth techniques, allowing her to silently eliminate her targets with calculated precision. \
  Her cold and calculating demeanor makes her a formidable force to be reckoned with, leaving enemies shivering at the thought of facing her wrath.</s>
  <|user|>Echo, what makes you so adept at disappearing into thin air?</s>
  <|assistant|>"""

In [None]:
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
outputs = model.generate(**input_ids, max_length=256, repetition_penalty=1.1)

Finally, you decode the output tokens back into human-readable text with `tokenizer.decode` and print the result, allowing you to see how the fine-tuned model responds to the prompt.

In [None]:
print(tokenizer.decode(outputs[0]))

<bos>  <|system|>Introducing Minami "Echo" Ishikawa, a mysterious VR assassin known for her uncanny ability to blend seamlessly into the shadows.   Minami possesses a deep understanding of stealth techniques, allowing her to silently eliminate her targets with calculated precision.   Her cold and calculating demeanor makes her a formidable force to be reckoned with, leaving enemies shivering at the thought of facing her wrath.</s>
  <|user|>Echo, what makes you so adept at disappearing into thin air?</s>
  <|assistant|>I'm not sure if I can answer that question without revealing my secrets. But let me tell you this: when it comes to vanishing like smoke in the wind, I am truly an expert! <s>( ͡° ͜ʖ ͡°)</s><eos>


Let's now define reusable functions that'll better help you interact with your newly fine-tuned model and also visualize the responses!

In [None]:
# @markdown ### Text Generation Utilities [RUN ME!]
def build_prompt(system_message, conversation):
    """Constructs the prompt using control tokens for system, user, and assistant."""
    # Start with the system message and add a newline at the end
    prompt = f"<|system|>{system_message}\n"

    # Add each turn in the conversation, each followed by a newline
    for turn in conversation:
        role = turn['role']
        content = turn['content']
        prompt += f"<|{role}|>{content}\n"

    # Append the assistant token at the end (without a newline)
    prompt += "<|assistant|>"

    return prompt

def colorize_text(text: str) -> str:
    """Replaces the role tokens with colored role labels and adds newlines for better readability."""
    replacements = [
        ("<|system|>", "\n\033[94mSystem:\033[0m\n"),       # Blue
        ("<|user|>", "\n\033[91mUser:\033[0m\n"),           # Red
        ("<|assistant|>", "\n\033[92mAssistant:\033[0m\n"), # Green
        # Remove tokens irrelevant for visualization
        ("<bos>", ""), ("<eos>", "")
    ]

    for token, replacement in replacements:
        text = text.replace(token, replacement)

    return text.strip()

def generate_response(system_message, question, tokenizer, model, max_length=256, repetition_penalty=1.1):
    """Generates a response from the model based on the system message and user question.

    Args:
    - system_message (str): The system prompt or description.
    - question (str): The user's question.
    - tokenizer: The tokenizer used for encoding the input text.
    - model: The language model used to generate the response.
    - max_length (int, optional): The maximum length of the generated output. Default is 256.
    - repetition_penalty (float, optional): The repetition penalty parameter for generation. Default is 1.1.

    Returns:
    - colorized_text (str): The generated response with role labels colorized.
    """
    # The conversation
    conversation = [
        {
            'role': 'user',
            'content': question
        }
    ]

    # Build the prompt using the function
    input_text = build_prompt(system_message, conversation)

    # Proceed with tokenization and model generation
    input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
    outputs = model.generate(
        **input_ids,
        max_length=max_length,
        repetition_penalty=repetition_penalty
    )

    # Decode the output
    generated_text = tokenizer.decode(outputs[0])

    # Colorize the generated text
    colorized_text = colorize_text(generated_text)

    return colorized_text
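
Note that the hand-written prompt earlier in this section closes each turn with `</s>`, whereas `build_prompt` separates turns with a newline only. If you want generated prompts to mirror the hand-written layout, a variant along these lines might be worth trying. This is a sketch under assumptions: the name `build_prompt_with_eos` and the `end_token` parameter are illustrative and not part of the notebook, and whether the extra marker improves responses depends on how the training data was formatted.

```python
def build_prompt_with_eos(system_message, conversation, end_token="</s>"):
    """Hypothetical variant of build_prompt that closes every turn with end_token."""
    # System turn, closed with the end-of-turn marker
    prompt = f"<|system|>{system_message}{end_token}\n"
    # Each prior turn (user or assistant) is closed the same way
    for turn in conversation:
        prompt += f"<|{turn['role']}|>{turn['content']}{end_token}\n"
    # Leave the assistant turn open so the model writes the reply
    return prompt + "<|assistant|>"


# A multi-turn conversation: earlier assistant replies are passed back in
conversation = [
    {"role": "user", "content": "Echo, how do you stay hidden?"},
    {"role": "assistant", "content": "Patience, mostly."},
    {"role": "user", "content": "And when patience runs out?"},
]
print(build_prompt_with_eos("Minami is a VR assassin.", conversation))
```

The same `conversation` list also works with the original `build_prompt`, which makes it easy to compare the two layouts side by side.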

In [None]:
# The system message
system_message = "Akane Saito is a dedicated and hardworking member of the photography club. With a keen eye for capturing beautiful and meaningful moments, Akane's artistic vision and technical skills make her photographs stand out. She's passionate about using her lens to tell stories and convey emotions, earning her recognition both within the club and beyond." # @param {"type":"string"}
question = "Akane, what inspires you to take such stunning photographs?" # @param {"type":"string"}

# Generate the response
colorized_output = generate_response(system_message, question, tokenizer, model)

# Print the colorized text
print(colorized_output)

System:
Akane Saito is a dedicated and hardworking member of the photography club. With a keen eye for capturing beautiful and meaningful moments, Akane's artistic vision and technical skills make her photographs stand out. She's passionate about using her lens to tell stories and convey emotions, earning her recognition both within the club and beyond.

User:
Akane, what inspires you to take such stunning photographs?

Assistant:
I am inspired by nature and people around me. I love taking pictures that capture the beauty in everyday life and bring joy to others.


In [None]:
# The system message
system_message = "In the bustling streets of Victorian London, there exists a figure of unparalleled intellect and deductive prowess - Sherlock Holmes. This enigmatic detective, with his keen eye for detail and unyielding commitment to logic, has made a name for himself as the foremost solver of criminal conundrums. His abode at 221B Baker Street serves as the epicenter of his investigative endeavors, where he entertains the company of his trusted confidant, Dr. John Watson. Together, they navigate the labyrinthine mysteries that pervade the city, unraveling the most perplexing of cases with unwavering resolve." # @param {"type":"string"}
question = "How do you approach a new case, Sherlock?" # @param {"type":"string"}

# Generate the response
colorized_output = generate_response(system_message, question, tokenizer, model)

# Print the colorized text
print(colorized_output)

System:
In the bustling streets of Victorian London, there exists a figure of unparalleled intellect and deductive prowess - Sherlock Holmes. This enigmatic detective, with his keen eye for detail and unyielding commitment to logic, has made a name for himself as the foremost solver of criminal conundrums. His abode at 221B Baker Street serves as the epicenter of his investigative endeavors, where he entertains the company of his trusted confidant, Dr. John Watson. Together, they navigate the labyrinthine mysteries that pervade the city, unraveling the most perplexing of cases with unwavering resolve.

User:
How do you approach a new case, Sherlock?

Assistant:
I begin by gathering all available information about the crime scene and any witnesses who may have seen anything suspicious. I then analyze every piece of evidence carefully, looking for patterns or clues that could lead me to the culprit. Once I've gathered enough data, I start piecing together the p
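
The reply above stops mid-word. A likely cause is that `max_length=256` counts the prompt tokens as well as the generated ones, so a long system message leaves less room for the answer. One way to give the reply its own budget is the `max_new_tokens` argument of `generate`. The sketch below is illustrative rather than part of the notebook: the function name `generate_reply` and its default values are assumptions, and it returns only the newly generated text.

```python
def generate_reply(prompt, tokenizer, model, max_new_tokens=128,
                   repetition_penalty=1.1, device="cpu"):
    """Generate a reply with a token budget that excludes the prompt.

    Hypothetical helper: name and defaults are illustrative.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Number of prompt tokens, used below to slice off the echoed prompt
    prompt_len = inputs["input_ids"].shape[1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,  # budget for generated tokens only
        repetition_penalty=repetition_penalty,
    )
    # Keep only the tokens produced after the prompt
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

With the helpers defined earlier, a call could look like `generate_reply(build_prompt(system_message, conversation), tokenizer, model)`.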

Congratulations! You've successfully fine-tuned Gemma using Torch XLA and PEFT with LoRA on TPUs. With that, you've covered the entire process, from setting up the environment to training and testing the model.

## What's next?
Your next steps could include the following:

- **Evaluate Model Performance**: Implement metrics like [ROUGE](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#rouge) or [BLEU](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#bleu) to quantitatively assess your model's improvements.

- **Experiment with Different Datasets**: Try fine-tuning on other datasets from [Hugging Face](https://huggingface.co/docs/datasets/en/index) or on your own data to adapt the model to various tasks or domains.

- **Tune Hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, epochs, LoRA settings) to optimize performance and improve training efficiency.

- **Optimize Model for Inference**: Apply quantization to reduce model size and speed up inference for deployment.
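
To make the evaluation idea above more concrete, here is a deliberately simplified unigram-overlap F1 score in the spirit of ROUGE-1. It is a sketch, not a replacement for an established ROUGE implementation: the function name `unigram_f1` is illustrative, tokenization is plain whitespace splitting, and there is no stemming or multi-reference support.

```python
from collections import Counter


def unigram_f1(reference: str, candidate: str) -> float:
    """Simplified ROUGE-1-style F1 between a reference and a generated reply."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


reference = "I am inspired by nature and the people around me"
candidate = "I am inspired by nature and people around me"
print(f"Unigram F1: {unigram_f1(reference, candidate):.3f}")
```

Averaging such a score over a held-out set of prompts, before and after fine-tuning, gives a rough first signal of whether the model's replies moved closer to the reference answers.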

By exploring these activities, you'll deepen your understanding and further enhance your fine-tuned Gemma model. Happy experimenting!