
##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large (on the order of billions of parameters). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.
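To make the idea concrete, here is a minimal NumPy sketch of a LoRA-style layer. It is an illustration under assumed toy dimensions (`d_in`, `d_out`, and `rank` are chosen for the example), not the KerasNLP implementation: the pre-trained weight `W` stays frozen, and only the two small matrices `A` and `B` would be trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions, chosen only for illustration.
d_in, d_out, rank = 512, 512, 4

# Frozen pre-trained weight: never updated during LoRA fine-tuning.
W = rng.normal(size=(d_in, d_out))

# Trainable low-rank factors. B starts at zero, so the adapted layer
# initially computes exactly the same output as the frozen layer.
A = rng.normal(scale=0.01, size=(d_in, rank))
B = np.zeros((rank, d_out))


def lora_forward(x):
    """Apply the frozen weight plus the low-rank update A @ B."""
    return x @ W + (x @ A) @ B


x = rng.normal(size=(2, d_in))
y = lora_forward(x)

full_params = W.size            # parameters updated by full fine-tuning
lora_params = A.size + B.size   # parameters updated by LoRA
print(f"full: {full_params:,}  lora: {lora_params:,}")
```

In this toy setting the adapter trains 4,096 values instead of 262,144, which is where the memory and speed savings come from.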

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the `Pistachio-LLM/Medical-llama3-finetune-train` dataset from the Hugging Face Hub. This dataset contains 37,179 records, each with `instruction`, `input`, and `output` fields.

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on [kaggle.com](https://kaggle.com).
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### Select the runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select &#9662; (**Additional connection options**).
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

### Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.

In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.
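Outside Colab, one way to do this is to read the downloaded `kaggle.json` file directly. The sketch below assumes the file holds `username` and `key` fields; the helper name `load_kaggle_credentials` is hypothetical, and the demonstration uses a temporary file with dummy values rather than real credentials.

```python
import json
import os
import tempfile
from pathlib import Path


def load_kaggle_credentials(path):
    """Copy the credentials from a kaggle.json file into environment variables."""
    creds = json.loads(Path(path).read_text())
    missing = [name for name in ("username", "key") if not creds.get(name)]
    if missing:
        raise ValueError(f"kaggle.json is missing fields: {missing}")
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]


# Demonstration with a temporary file and dummy values (this overwrites any
# credentials already set); in practice, pass the path to your own kaggle.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "kaggle.json"
    demo_path.write_text(json.dumps({"username": "demo-user", "key": "demo-key"}))
    load_kaggle_credentials(demo_path)

print(os.environ["KAGGLE_USERNAME"])
```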

In [2]:
!pip install kagglehub
import kagglehub



### Install dependencies

Install Keras, KerasNLP, and other dependencies.

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.
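Keras reads the `KERAS_BACKEND` environment variable once, when it is first imported, so the variable has to be set before `import keras` runs. A minimal sketch of a guard for this follows; the helper name `select_backend` is hypothetical.

```python
import os
import sys

VALID_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND; return False if Keras was already imported."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    # If Keras is already loaded, the new value is ignored until the
    # runtime is restarted.
    return "keras" not in sys.modules


if not select_backend("jax"):
    print("Keras is already imported; restart the runtime for the change to apply.")
```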

### Import packages

Import Keras and KerasNLP.

In [5]:
!python --version

Python 3.10.15


In [6]:
!pip show keras
!pip show tensorflow
!pip show keras-nlp

Name: keras
Version: 3.6.0
Summary: Multi-backend Keras.
Home-page: https://github.com/keras-team/keras
Author: Keras team
Author-email: keras-users@googlegroups.com
License: Apache License 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, h5py, ml-dtypes, namex, numpy, optree, packaging, rich
Required-by: tensorflow
Name: tensorflow
Version: 2.18.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, keras, libclang, ml-dtypes, numpy, opt-einsum, packaging, protobuf, requests, setuptools, six, tensorboard, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: tensorflow-text
Name: keras-nlp
Version: 0.17.0
Summary: Industry-streng

In [7]:
import tensorflow as tf
import keras
import keras_nlp
import numpy as np
import pandas as pd

In [8]:
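import os
from dotenv import load_dotenv

# Load KAGGLE_USERNAME and KAGGLE_KEY from a local .env file into os.environ.
load_dotenv()

# Fail early with a clear message if either credential is missing.
for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.getenv(name):
        raise RuntimeError(f"{name} is not set; add it to your .env file.")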

In [9]:
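# Note: Keras reads KERAS_BACKEND only when it is first imported, so run this
# cell before `import keras` (restart the kernel if Keras is already loaded).
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"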

In [11]:
!kaggle datasets list

ref                                                              title                                           size  lastUpdated          downloadCount  voteCount  usabilityRating  
---------------------------------------------------------------  ---------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
valakhorasani/mobile-device-usage-and-user-behavior-dataset      Mobile Device Usage and User Behavior Dataset   11KB  2024-09-28 20:21:12          20073        417  1.0              
valakhorasani/gym-members-exercise-dataset                       Gym Members Exercise Dataset                    22KB  2024-10-06 11:27:38          11330        174  1.0              
muhammadroshaanriaz/students-performance-dataset-cleaned         Students Performance | Clean Dataset            10KB  2024-10-29 19:32:26           1456         33  1.0              
mahdiehhajian/life-expectancy-around-the-world                   Life expectancy

## Load Dataset

In [1]:
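from datasets import Dataset
# Note: without parentheses this only displays the method and clears nothing.
# To clear the cache, call it on a loaded dataset, e.g. ds.cleanup_cache_files().
Dataset.cleanup_cache_files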



<function datasets.arrow_dataset.Dataset.cleanup_cache_files(self) -> int>

In [12]:
from datasets import load_dataset

ds = load_dataset("Pistachio-LLM/Medical-llama3-finetune-train")

In [13]:
ds

DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction'],
        num_rows: 37179
    })
})

In [14]:
train_ds = ds['train']

# Function to extract inputs and outputs from the dataset
def extract_features(example):
    return {
        'input': example['input'],
        'instruction': example['instruction'],
        'output': example['output']
    }

# Map the dataset to extract features
train_ds = train_ds.map(extract_features)
train_ds = pd.DataFrame(train_ds)


In [9]:
# type(train_ds)
train_ds[:1]

| | output | input | instruction |
|---|---|---|---|
| 0 | The Mesenchyme gives rise to most connective t... | | What does the Mesenchyme give rise to? |


In [15]:
train_ds_list = [
    f"Instruction:\n{row['instruction']}\n\nResponse:\n{row['output']}"
    for index, row in train_ds.iterrows()  # Use iterrows to iterate over DataFrame rows
]

In [16]:
train_ds_list[:1]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.']

Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.
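As a variation on this preprocessing step, the sketch below formats records with the same instruction / response template and takes a fixed-size subset. Unlike the cells in this notebook, which keep every row and ignore the `input` field, it skips rows that carry extra `input` context. The helper name `build_examples` and the placeholder records are invented for illustration.

```python
from itertools import islice

TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{output}"


def build_examples(records, limit=None):
    """Format records as prompt strings, skipping rows that carry extra input."""
    formatted = (
        TEMPLATE.format(instruction=r["instruction"], output=r["output"])
        for r in records
        if not r.get("input")  # Skip rows with additional context.
    )
    # islice stops after `limit` examples; limit=None keeps everything.
    return list(islice(formatted, limit))


# Placeholder records with the same three fields as the dataset.
records = [
    {"instruction": "Q1?", "input": "", "output": "A1."},
    {"instruction": "Q2?", "input": "extra context", "output": "A2."},
    {"instruction": "Q3?", "input": "", "output": "A3."},
]
examples = build_examples(records, limit=2)
print(examples[0])
```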

In [17]:
# import json
# data = []
# with open("databricks-dolly-15k.jsonl") as file:
#     for line in file:
#         features = json.loads(line)
#         # Filter out examples with context, to keep it simple.
#         if features["context"]:
#             continue
#         # Format the entire example as a single string.
#         template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
#         data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = train_ds_list[:1000]

In [18]:
data[:5]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.',
 'Instruction:\nWhich class of antimicrobials is known to displace unconjugated bilirubin from serum albumin in the blood?\n\nResponse:\nSulfonamides are known to displace unconjugated bilirubin from serum albumin in the blood.',
 'Instruction:\nIn a female athlete who has amenorrhea and laboratory exam shows decreased FSH, LH, and estrogen levels, what is the likely diagnosis?\n\nResponse:\nThe likely diagnosis is hypogonadotropic hypogonadism, also known as hypothalamic amenorrhea. This is a condition where the hypothalamus in the brain does not release enough gonadotropin-releasing hormone (GnRH) to stimulate the pituitary gland to produce follicle-stimulating hormone (FSH) and luteinizing hormone (LH), which are necessary for ovulation and menstruation. As a result, estrogen levels are low, leading to amenorrhea. Female athletes are at increased risk of develo

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
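As a toy illustration of next-token prediction (not how Gemma works internally), the sketch below generates text greedily from a hand-written table of next-token scores. The vocabulary, the scores, and the function name are invented for the example, and the toy conditions only on the last token, whereas a real causal language model conditions on the whole prefix.

```python
# Hypothetical next-token scores: for each previous token, a score per candidate.
NEXT_TOKEN_SCORES = {
    "<start>": {"the": 0.9, "cat": 0.1},
    "the": {"cat": 0.7, "mat": 0.3},
    "cat": {"sat": 0.8, "<end>": 0.2},
    "mat": {"<end>": 1.0},
    "sat": {"<end>": 1.0},
}


def generate_greedy(max_length=10):
    """Repeatedly append the highest-scoring next token until <end>."""
    tokens = ["<start>"]
    while len(tokens) < max_length:
        candidates = NEXT_TOKEN_SCORES[tokens[-1]]
        next_token = max(candidates, key=candidates.get)
        if next_token == "<end>":
            break
        tokens.append(next_token)
    # Drop the <start> marker before returning the generated tokens.
    return tokens[1:]


print(generate_greedy())
```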

Create the model using the `from_preset` method:

In [19]:
from keras_nlp.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/assets/tokenizer/vocabulary.spm...


100%|██████████████████████████████████████| 4.04M/4.04M [00:00<00:00, 5.05MB/s]
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture: a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.

### Europe Trip Prompt

Query the model for suggestions on what to do on a trip to Europe.

In [29]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

In [30]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

I0000 00:00:1730990605.399889 1909735 service.cc:148] XLA service 0x600005619a00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1730990605.401037 1909735 service.cc:156]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1730990605.542299 1909735 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Instruction:
What should I do on a trip to Europe?

Response:
If you are traveling with a group, you might want to consider booking a tour. A tour is a great way to see a lot of the country in a short amount of time. You can usually choose from a variety of itineraries, so you can find one that fits your schedule. Tours are also a great way to meet new people and make friends. Many tours offer activities, such as sightseeing or hiking, so you can get a real taste of the country.


The model responds with generic tips on how to plan a trip.

### ELI5 Photosynthesis Prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand.

In [31]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process in which plants, algae, and some bacteria capture light energy and transform it into chemical energy in the form of sugars, which are later used to build biomass. This process occurs in a series of chemical reactions that take place in the chloroplasts, which are found in the cells of green plants. The chloroplasts capture light energy from the sun, and use it to split water molecules into oxygen and hydrogen ions. The hydrogen ions then combine with carbon dioxide molecules to form glucose and oxygen, which are both used in the process of respiration. Photosynthesis is important for sustaining life because it provides the energy that all living organisms need to grow, reproduce and maintain their functions. Without photosynthesis, plants, algae, and certain bacteria would not be able to survive, which would have dire consequences for the planet's ecosyste

The model response contains words that a child might not easily understand, such as "chloroplasts".

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the training examples prepared in the Load Dataset section.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, or 16), which is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Then gradually increase the rank in subsequent trials and check whether that further boosts performance.
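The cost side of this trade-off is linear: doubling the rank doubles the number of added parameters. The sketch below counts the adapter parameters for one kernel, assuming factor A has shape `kernel_shape[:-1] + (rank,)` and factor B has shape `(rank, kernel_shape[-1])`. The function name and the example kernel shape are hypothetical.

```python
import math


def lora_param_count(kernel_shape, rank):
    """Trainable parameters added by a rank-`rank` adapter on one kernel.

    Assumes factor A has shape kernel_shape[:-1] + (rank,) and factor B has
    shape (rank, kernel_shape[-1]).
    """
    a_size = math.prod(kernel_shape[:-1]) * rank
    b_size = rank * kernel_shape[-1]
    return a_size + b_size


# Hypothetical 2-D projection kernel, chosen only for illustration.
kernel_shape = (2048, 2048)
for rank in (4, 8, 16):
    print(rank, lora_param_count(kernel_shape, rank))
```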

In [22]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

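Note that enabling LoRA reduces the number of trainable parameters significantly. For the `gemma_2b_en` preset loaded above, the count drops from about 2.5 billion to about 1.4 million; check the `summary()` output for the exact figures.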

In [25]:
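# Optional: stream the samples through a tf.data pipeline.
# Note: the fit() call below passes the `data` list directly, so this dataset
# is only used if you pass `train_ds` to fit() instead.
def data_generator():
    for sample in data:
        yield sample

train_ds = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.string),
).batch(1)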

In [28]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 14872s 15s/step - loss: 0.8906 - sparse_categorical_accuracy: 0.6080


<keras.src.callbacks.history.History at 0x31bb635e0>

In [65]:
len(train_ds_list)

37179

### Note on mixed precision fine-tuning on NVIDIA GPUs

Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality. Mixed precision fine-tuning consumes more memory, so it is useful only on larger GPUs.


For inference, half-precision (`keras.config.set_floatx("bfloat16")`) will work and save memory while mixed precision is not applicable.

In [None]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

## Inference after fine-tuning

After fine-tuning, responses follow the instruction provided in the prompt.

### Europe Trip Prompt

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.


The model now recommends places to visit in Europe.

### ELI5 Photosynthesis Prompt

In [32]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process that plants use to convert carbon dioxide and water into glucose and oxygen using sunlight. This process is essential for life on Earth because it provides plants with the energy they need to grow and survive.


The model now explains photosynthesis in simpler terms.

Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.
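One way to organize such experiments is to enumerate the combinations up front and train one model per configuration. The sketch below only builds the list of configurations; the search space values are illustrative rather than recommendations, and the names `search_space` and `iter_configs` are hypothetical.

```python
from itertools import product

# Hypothetical search space; the values are illustrative, not recommendations.
search_space = {
    "num_examples": [1000, 5000],
    "epochs": [1, 2],
    "lora_rank": [4, 8],
    "learning_rate": [5e-5, 1e-4],
}


def iter_configs(space):
    """Yield one dict per combination of hyperparameter values."""
    names = list(space)
    for values in product(*(space[name] for name in names)):
        yield dict(zip(names, values))


configs = list(iter_configs(search_space))
print(len(configs), "configurations; first:", configs[0])
```

Because each run in this notebook takes hours on CPU, a full grid is expensive; changing one setting at a time is a cheaper starting point.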

## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).