##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large in size (parameters in the order of millions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are relatively much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the [Databricks Dolly 15k dataset](https://www.kaggle.com/datasets/databricks/databricks-dolly-15k){:.external}. This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs.

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Keras and KerasNLP.

In [4]:
import keras
import keras_nlp

## Load Dataset

Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [5]:
import pandas as pd

quizard_df = pd.read_excel("/kaggle/input/quizard-excel/Quizard_dataset.xlsx")

In [6]:
quizard_df.head()

Unnamed: 0,Context,Question,Answer
0,The process of photosynthesis occurs in the ch...,What happens during photosynthesis?,"During photosynthesis, light energy is convert..."
1,"In World War II, the Axis Powers consisted of ...",Who were the Axis Powers in World War II?,"The Axis Powers in World War II were Germany, ..."
2,The human circulatory system includes the hear...,What does the circulatory system do?,The circulatory system pumps blood through the...
3,The Earth orbits the Sun once every year. The ...,How long does it take for the Earth to orbit t...,The Earth takes one year to orbit the Sun.
4,The process of photosynthesis is crucial for p...,What is photosynthesis?,Photosynthesis is the process through which pl...


In [7]:
quizard_df.isna().sum()

Context     19
Question    19
Answer      19
dtype: int64

In [8]:
quizard_df=quizard_df.dropna()

In [9]:
quizard_df.isna().sum()

Context     0
Question    0
Answer      0
dtype: int64

In [10]:
quizard_df.describe()

Unnamed: 0,Context,Question,Answer
count,2994,2994,2994
unique,2932,2677,2895
top,A black hole is a region of space where the gr...,"What is a black hole, and how is it formed?",Opportunity cost refers to the value of the ne...
freq,3,11,3


In [11]:
import json

data = []

template = """Instruction: Generate an answer to the question using the provided context. If the context does not contain sufficient information to answer the question, respond with "The provided context does not contain sufficient information to answer this question."

Context:
{Context}

Question:
{Question}

Response:
{Answer}
"""

for index, row in quizard_df.iterrows():
    data.append(json.dumps(template.format(**row.to_dict())))

In [12]:
data[0]

'"Instruction: Generate an answer to the question using the provided context. If the context does not contain sufficient information to answer the question, respond with \\"The provided context does not contain sufficient information to answer this question.\\"\\n\\nContext:\\nThe process of photosynthesis occurs in the chloroplasts of plant cells. During photosynthesis, light energy is converted into chemical energy, which is stored in the form of glucose. This process uses carbon dioxide from the air and water from the soil, releasing oxygen as a byproduct.\\n\\nQuestion:\\nWhat happens during photosynthesis?\\n\\nResponse:\\nDuring photosynthesis, light energy is converted into chemical energy, stored as glucose, and oxygen is released as a byproduct.\\n"'

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [13]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture — a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7
billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.


### Europe Trip Prompt

Query the model for suggestions on what to do on a trip to Europe.

In [14]:
prompt = template.format(
    Context="The Great Barrier Reef is the world's largest coral reef system, located in the Coral Sea, off the coast of Queensland, Australia. It is composed of over 2,900 individual reefs and 900 islands stretching over 2,300 kilometers. The reef is known for its biodiversity, hosting countless marine species, and is a popular destination for snorkeling and diving enthusiasts. However, it faces threats from climate change, overfishing, and pollution.",
    Question="What are some of the major threats faced by the Great Barrier Reef?",
    Answer=""
)

sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction: Generate an answer to the question using the provided context. If the context does not contain sufficient information to answer the question, respond with "The provided context does not contain sufficient information to answer this question."

Context:
The Great Barrier Reef is the world's largest coral reef system, located in the Coral Sea, off the coast of Queensland, Australia. It is composed of over 2,900 individual reefs and 900 islands stretching over 2,300 kilometers. The reef is known for its biodiversity, hosting countless marine species, and is a popular destination for snorkeling and diving enthusiasts. However, it faces threats from climate change, overfishing, and pollution.

Question:
What are some of the major threats faced by the Great Barrier Reef?

Response:

The provided context states that the Great Barrier Reef faces threats from: 
* **Climate change**
* **Overfishing**
* **Pollution**. 
<end_of_turn>


The model just responds with a recommendation to take a trip to Europe.

The responses contains words that might not be easy to understand for a child such as chlorophyll, glucose, etc.

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [15]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [16]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.001,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [17]:
gemma_lm.fit(data, epochs=8, batch_size=1)

Epoch 1/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2978s[0m 984ms/step - loss: 0.3302 - sparse_categorical_accuracy: 0.8367
Epoch 2/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2947s[0m 979ms/step - loss: 0.2375 - sparse_categorical_accuracy: 0.8701
Epoch 3/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2932s[0m 979ms/step - loss: 0.2219 - sparse_categorical_accuracy: 0.8765
Epoch 4/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2937s[0m 980ms/step - loss: 0.2061 - sparse_categorical_accuracy: 0.8841
Epoch 5/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2936s[0m 980ms/step - loss: 0.1901 - sparse_categorical_accuracy: 0.8925
Epoch 6/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2937s[0m 981ms/step - loss: 0.1749 - sparse_categorical_accuracy: 0.9008
Epoch 7/8
[1m2994/2994[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2936s[0m 980ms/step - loss: 0.1603 - sparse_categorical_ac

<keras.src.callbacks.history.History at 0x7e6ca4326ef0>

In [25]:
!ls -lh

  pid, fd = os.forkpty()


total 20G
-rw-r--r-- 1 root root    0 Nov 24 11:13 '=3'
-rw-r--r-- 1 root root 3.1K Nov 24 17:50  QA-Gemma-Quizard.keras
-rw-r--r-- 1 root root  78M Nov 24 17:53  QA-Gemma-Quizard.weights.h5
-rw-r--r-- 1 root root 9.8G Nov 24 17:50  finetuned_study_assistant.keras
-rw------- 1 root root 9.7G Nov 24 17:50  tmpq55momln


In [26]:
# Delete the large Keras model file if it's not needed
!rm -rf finetuned_study_assistant.keras

In [27]:
!rm -rf QA-Gemma-Quizard.keras

In [28]:
!rm -rf QA-Gemma-Quizard.weights.h5

In [33]:
# Step 1: Install huggingface_hub if not already installed
!pip install huggingface_hub

# Step 2: Import the notebook_login function
from huggingface_hub import notebook_login

# Step 3: Log in to Hugging Face
notebook_login()  # This will prompt you for your token



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [34]:
# Step 1: Install huggingface_hub if not already installed
!pip install huggingface_hub

# Step 2: Import necessary functions
from huggingface_hub import create_repo, HfApi
import os

# Step 3: Load Hugging Face token from Kaggle Secrets
hf_token = os.getenv('HF_TOKEN')  # Retrieve token securely from Kaggle secrets

# Step 4: Load Hugging Face username and repository name from environment variables
hf_username = os.getenv('HF_USERNAME')  # You can set this in Kaggle Secrets
repo_name = os.getenv('HF_REPO_NAME')  # Set this in Kaggle Secrets, e.g., "QA-Gemma-Quizard"

# Ensure the Hugging Face token is available
if not hf_token:
    raise ValueError("Hugging Face token not found! Please add it to Kaggle Secrets.")

# Step 5: Create an API instance with the token
api = HfApi(token=hf_token)

# Step 6: Define the repository ID using the username and repo name
repo_id = f"{hf_username}/{repo_name}"

# Create the repository (you can set private=True if you want it to be private)
create_repo(repo_id, token=hf_token)  # This creates a new repository

# Step 7: Upload the weights file to Hugging Face
api.upload_file(
    path_or_fileobj='QA-Gemma-Quizard.weights.h5',  # Path to your weights file
    path_in_repo='QA-Gemma-Quizard.weights.h5',     # Path in the repo where it will be stored
    repo_id=repo_id,                                # Your repository ID
    repo_type='model'                               # Type of repository (model)
)



QA-Gemma-Quizard.weights.h5:   0%|          | 0.00/10.5G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/aayeshanakarmi/QA-Gemma-Quizard/commit/f8a2aee96790df1cb1a8b43e8a9991b70c81bf71', commit_message='Upload QA-Gemma-Quizard.weights.h5 with huggingface_hub', commit_description='', oid='f8a2aee96790df1cb1a8b43e8a9991b70c81bf71', pr_url=None, repo_url=RepoUrl('https://huggingface.co/aayeshanakarmi/QA-Gemma-Quizard', endpoint='https://huggingface.co', repo_type='model', repo_id='aayeshanakarmi/QA-Gemma-Quizard'), pr_revision=None, pr_num=None)

In [None]:
from huggingface_hub import HfApi

# Define your Hugging Face username and model name
username = "aayeshanakarmi"
model_name = "QA-Gemma-Quizard"

# Create an API instance
api = HfApi()

# Save the weights temporarily
weights_path = 'QA-Gemma-Quizard.weights.h5'
gemma_lm.save_weights(weights_path)

# Upload the weights file to Hugging Face
api.upload_file(
    path_or_fileobj=weights_path,
    path_in_repo=weights_path,
    repo_id=f"{username}/{model_name}",
    repo_type="model"
)

In [None]:
# Step 1: Install huggingface_hub if not already installed
# !pip install huggingface_hub

# Step 2: Log in to Hugging Face
from huggingface_hub import notebook_login

notebook_login()

# Step 3: Upload the Keras model
from huggingface_hub import push_to_hub_keras

# Save your model (already done)
gemma_lm.save('QA-Gemma-Quizard.keras')

# Push the model to Hugging Face Hub
push_to_hub_keras('QA-Gemma-Quizard.keras', "your_hf_username/QA-Gemma-Quizard")

In [23]:
# Save only the weights with the correct filename
gemma_lm.save_weights('QA-Gemma-Quizard.weights.h5')

OSError: [Errno 28] Can't synchronously write data (file write failed: time = Sun Nov 24 17:53:27 2024
, filename = 'QA-Gemma-Quizard.weights.h5', file descriptor = 113, errno = 28, error message = 'No space left on device', buf = 0x7e65c73fd2c8, total write size = 2278235464, bytes this sub-write = 2278235464, bytes actually written = 18446744073709551615, offset = 0)

In [21]:
gemma_lm.save('QA-Gemma-Quizard.keras')

RuntimeError: Can't decrement id ref count (file write failed: time = Sun Nov 24 17:50:46 2024
, filename = '/kaggle/working/tmpq55momln', file descriptor = 116, errno = 28, error message = 'No space left on device', buf = 0x5b7abc486f00, total write size = 2048, bytes this sub-write = 2048, bytes actually written = 18446744073709551615, offset = 0)

In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assuming gemma_lm is your trained model
# gemma_lm.save_pretrained('./path_to_save_model')

# Assuming 'model' is your fine-tuned Gemma model
gemma_lm.save_pretrained('/kaggle/tmp/QA-Gemma-Quizard')

AttributeError: 'GemmaCausalLM' object has no attribute 'save_pretrained'

In [None]:
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained('/kaggle/tmp/QA-Gemma-Quizard')
tokenizer = AutoTokenizer.from_pretrained('/kaggle/tmp/QA-Gemma-Quizard')

In [None]:
print(model)
print(tokenizer)

In [None]:
# Define your model repository name
model_name = "aayeshanakarmi/QA-Gemma-Quizard"

# Push the model and tokenizer to Hugging Face Hub
model.push_to_hub(model_name)
tokenizer.push_to_hub(model_name)

In [None]:
import keras_nlp

# Save as a KerasNLP preset
gemma.save_to_preset('./gemma-pirate-instruct-7b')

# Define your Hugging Face URI
hf_uri = "hf://your_hf_username/gemma-pirate-instruct-7b"

# Upload the preset to Hugging Face Hub
keras_nlp.upload_preset(hf_uri, './gemma-pirate-instruct-7b')

In [None]:
from transformers import AutoModelForCausalLM

# Load your model
model = AutoModelForCausalLM.from_pretrained('./path_to_save_model')

# Define your model repository name (username/model_name)
model_name = "your_hf_username/gemma_finetuned_model"

# Push the model to Hugging Face Hub
model.push_to_hub(model_name)

## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

### Europe Trip Prompt


In [18]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

KeyError: 'Context'

The model now recommends places to visit in Europe.

### ELI5 Photosynthesis Prompt


In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

The model now explains photosynthesis in simpler terms.

Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.


## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).
