##### Copyright 2024 Google LLC.

In [19]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/yoruba-bank-data/bank-data-yoruba-full.csv
/kaggle/input/databricks-dolly-15k/README.md
/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl
/kaggle/input/gemma/keras/gemma_2b_en/1/config.json
/kaggle/input/gemma/keras/gemma_2b_en/1/tokenizer.json
/kaggle/input/gemma/keras/gemma_2b_en/1/metadata.json
/kaggle/input/gemma/keras/gemma_2b_en/1/model.weights.h5
/kaggle/input/gemma/keras/gemma_2b_en/1/assets/tokenizer/vocabulary.spm
/kaggle/input/gemma/keras/gemma_2b_en/2/config.json
/kaggle/input/gemma/keras/gemma_2b_en/2/tokenizer.json
/kaggle/input/gemma/keras/gemma_2b_en/2/metadata.json
/kaggle/input/gemma/keras/gemma_2b_en/2/model.weights.h5
/kaggle/input/gemma/keras/gemma_2b_en/2/assets/tokenizer/vocabulary.spm


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large (their parameter counts are on the order of billions). Full fine-tuning, which updates all the parameters in the model, is not required for most applications, because typical fine-tuning datasets are much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.
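The core idea can be sketched in a few lines of NumPy. This is a minimal illustration of the LoRA decomposition described in the paper, not KerasNLP's implementation: the frozen weight `W` is never updated, and the fine-tuning update is the product of two small trainable factors `A` and `B` of rank `rank`. The layer sizes, the zero initialization of `B`, and the `alpha / rank` scaling are assumptions chosen to follow the paper's description.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, rank, alpha = 64, 32, 4, 8.0  # illustrative sizes, not Gemma's

# Frozen pre-trained weight: never updated during LoRA fine-tuning.
W = rng.normal(size=(d_in, d_out))

# Trainable low-rank factors. B starts at zero, so the adapted layer
# initially behaves exactly like the original layer.
A = rng.normal(scale=0.01, size=(d_in, rank))
B = np.zeros((rank, d_out))
scale = alpha / rank


def lora_forward(x):
    """Frozen path plus scaled low-rank update: x @ W + (x @ A @ B) * scale."""
    return x @ W + (x @ A @ B) * scale


x = rng.normal(size=(2, d_in))
assert np.allclose(lora_forward(x), x @ W)  # identical before any training

# Pretend training changed B; the update can be merged into a single matrix.
B = rng.normal(size=(rank, d_out))
W_merged = W + (A @ B) * scale
assert np.allclose(lora_forward(x), x @ W_merged)

full_params = W.size           # 64 * 32
lora_params = A.size + B.size  # 64 * 4 + 4 * 32
print(f"full: {full_params}, LoRA: {lora_params}")
```

Because only `A` and `B` are trained, the saved adapter is much smaller than the full weight, and merging `A @ B` back into `W` after training adds no inference cost.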

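This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model. The original version of this tutorial uses the [Databricks Dolly 15k dataset](https://www.kaggle.com/datasets/databricks/databricks-dolly-15k){:.external}, which contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs. This notebook instead fine-tunes on a Yoruba-language banking dataset (`bank-data-yoruba-full.csv`) of customer-support instruction / response pairs.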

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Let JAX preallocate all GPU memory up front to avoid memory fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

### Import packages

Import Keras and KerasNLP.

In [4]:
import keras
import keras_nlp

## Load Dataset

Load the Yoruba banking dataset and format each row as a single instruction / response string. This notebook trains on every row of the CSV. For a faster run you can train on a subset instead, but more training data generally gives higher quality fine-tuning.

In [44]:
import pandas as pd

# Path to the Yoruba banking CSV on Kaggle; replace it with your own file path if needed.
file_path = "/kaggle/input/yoruba-bank-data/bank-data-yoruba-full.csv"
df = pd.read_csv(file_path)

# Optional: train on a subset for a faster run, e.g. df = df.head(1000).

# Convert the relevant columns into a list of dictionaries, one per row.
data_samples = [
    {
        "instruction": row["instruction-Yoruba"],
        "response": row["response-Yoruba"],
    }
    for _, row in df.iterrows()
]

# Format each sample as a single string using the prompt template.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data = [template.format(**sample) for sample in data_samples]

In [36]:
data[0]

'Instruction:\nMo fẹ́ fagilọ̀sí owó náà, ẹ ràn mí lọ́wọ́\n\nResponse:\nÓ dùn mí láti gbọ́ pé o fẹ́ fagilára yáwó rẹ. Mo mọ̀ pé ipò nǹkan lè yí pa dà, torí náà, màá ṣe gbogbo ohun tí mo lè ṣe láti ràn yín lọ́wọ́ láti fòpin sí ìfọ̀rọ̀wérọ̀ náà. Láti lè fi owó ìfowówówópamọ́ rẹ pamọ, jọ̀wọ́ ṣe àwọn ìgbésẹ̀ yìí: 1. Wá sí àwùjọ àwọn oníbàárà wa lórí ìkànnì {{Àtìlẹ́yìn Ọ̀ràn Ìbàárà}} tàbí kí o fi ìsọfúnni ránṣẹ́ sí ìkànnì {{Àtìlẹ́yìn Ọ̀ràn Ìbàárà}} pẹ̀lú ìbẹ̀wò rẹ. Wọ́n á máa darí rẹ lọ́nà tó yẹ kó o gbà ṣe ìfọ̀rọ̀wérọ̀, wọ́n á sì fún ẹ ní àwọn ìwé tó o nílò.  2. Fún àwọn aṣojú fún iṣẹ́ àkànṣe àwọn oníbàárà ní àwọn ìsọfúnni nípa owó ìfẹ̀yìntì rẹ, irú bí iye owó ìfẹ̀yìntì, nọ́ńbà owó ìfẹ̀yìntì àti àwọn ìsọfúnni mìíràn tó bá yẹ. Èyí á mú kí ìgbésẹ̀ ìfọ̀rora-ẹni-nìkan náà yára.  3. Tó o bá ní àwọn owó tó o ń san tàbí owó tó o ń san tó o bá fẹ́ gbà á lábẹ́ owó náà, rí i dájú pé o bá ẹni tó ń bójú tó àwọn oníbàárà sọ̀rọ̀ nípa rẹ̀. Wọ́n á fún ẹ ní ìsọfúnni tó yẹ nípa bí wàá ṣe máa bójú tó àwọn ojú
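To see the difference between what the model trains on and what it is asked to complete at inference time, here is a plain-Python sketch using the same template string. The sample row is a hypothetical English placeholder, not taken from the dataset. A training string contains both the instruction and the response; an inference prompt uses an empty response, so it ends with `Response:\n` and the model's continuation becomes the answer.

```python
# Same prompt template as the notebook.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Hypothetical example row (placeholder text, not from the dataset).
sample = {"instruction": "How do I block my card?", "response": "Call support."}

# Training example: instruction and response in one string.
train_text = template.format(**sample)

# Inference prompt: same template with an empty response, so generation
# starts right after "Response:\n".
prompt = template.format(instruction=sample["instruction"], response="")

print(repr(train_text))
print(repr(prompt))
```

Since every training string begins with exactly the text of its inference prompt, fine-tuning teaches the model what should follow `Response:\n`.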

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
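As a rough illustration of what predicting the next token from previous tokens means during training, the NumPy sketch below builds next-token targets by shifting a sequence of token IDs by one position and builds a causal (lower-triangular) attention mask. The token IDs are made up, and this is not KerasNLP's code; the model's preprocessor and backbone take care of the equivalent steps internally.

```python
import numpy as np

# Hypothetical token IDs for one tokenized training string.
token_ids = np.array([2, 17, 45, 9, 88, 1])

# Inputs are every token except the last; targets are the same sequence
# shifted left by one, so position t is trained to predict token t + 1.
inputs = token_ids[:-1]
targets = token_ids[1:]

# Causal mask: position t may attend to positions 0..t only.
seq_len = len(inputs)
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

for t in range(seq_len):
    visible = inputs[causal_mask[t]]
    print(f"context {visible.tolist()} -> predict {targets[t]}")
```

The first printed line is `context [2] -> predict 17`; each later line sees one more token of context.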

Create the model using the `from_preset` method:

In [6]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture: a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.


### ATM Withdrawal Prompt

Query the model with a Yoruba-language question about an ATM withdrawal.

In [11]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
prompt = template.format(
    instruction="Mo ti wa ni ATM yiyọ ti Mo ko ṣe ohun ti Mo ni lati ṣe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Mo ti wa ni ATM yiyọ ti Mo ko ṣe ohun ti Mo ni lati ṣe?

Response:
Mo ti wa ni ATM yiyọ ti Mo ko ṣe ohun ti Mo ni lati ṣe.

Meaning:
The ATM is not the ATM.

Explanation:
The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The ATM is not the ATM.

The


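Before fine-tuning, the model repeats the instruction, adds an irrelevant English "Meaning" and "Explanation" ("The ATM is not the ATM."), and then repeats that sentence until generation reaches `max_length`.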
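Because `generate` returns the prompt followed by the model's continuation, it can help to isolate just the answer when comparing outputs before and after fine-tuning. A minimal sketch, assuming the same `Instruction:` / `Response:` template; `extract_response` is a hypothetical helper, not part of KerasNLP.

```python
def extract_response(generated: str, marker: str = "Response:\n") -> str:
    """Return the text after the first response marker, stripped of whitespace.

    If the marker is missing, the whole string is returned (stripped).
    """
    _, found, answer = generated.partition(marker)
    return answer.strip() if found else generated.strip()


output = "Instruction:\nHello?\n\nResponse:\nHi there."
print(extract_response(output))  # Hi there.
```

In the notebook this could wrap a call such as `extract_response(gemma_lm.generate(prompt, max_length=256))`.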

### Card Retrieval Prompt

Ask the model, in Yoruba, for help getting a card back.


In [12]:
prompt = template.format(
    instruction="ṣé o lè ràn mí lọ́wọ́ láti gba káàdì àbáyọ kan padà?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
ṣé o lè ràn mí lọ́wọ́ láti gba káàdì àbáyọ kan padà?

Response:
ṣé o lè ràn mí lọ́wọ́ láti gba káàdì àbáyọ kan padà?

Translation:
What is the reason for the delay in getting the money?

Explanation:
The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for the delay in getting the money.

The delay in getting the money is the reason for


The response has the same problems: the model repeats the question, gives an unrelated English "translation" about a delay in getting money, and then loops on a single sentence.

## LoRA Fine-tuning

To get better responses from the model, fine-tune it with Low Rank Adaptation (LoRA) on the Yoruba banking dataset loaded above.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.
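The trade-off between rank and trainable parameters is easy to quantify for a single dense layer: a frozen `d x k` weight gets two trainable factors of shapes `d x r` and `r x k`, which adds `r * (d + k)` parameters. The 2048 x 2048 layer shape below is an illustrative assumption, not a specific Gemma layer.

```python
def lora_param_count(d: int, k: int, rank: int) -> int:
    """Trainable parameters added by LoRA to one d x k weight matrix."""
    return rank * (d + k)


d, k = 2048, 2048  # hypothetical layer shape
full = d * k
for rank in (4, 8, 16):
    added = lora_param_count(d, k, rank)
    print(f"rank {rank:>2}: {added:>7} trainable params ({added / full:.2%} of {full})")
```

This prints 16384 (0.39%), 32768 (0.78%), and 65536 (1.56%) for ranks 4, 8, and 16: the trainable parameter count grows linearly with the rank, which is why starting small and increasing gradually is cheap to explore.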

In [13]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [45]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

3000/3000 ━━━━━━━━━━━━━━━━━━━━ 2171s 720ms/step - loss: 1.3580 - sparse_categorical_accuracy: 0.6722


<keras.src.callbacks.history.History at 0x7d44bab21480>
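To make the optimizer settings above more concrete, here is a simplified single-step sketch of AdamW's decoupled weight decay combined with substring-based exclusion, in the spirit of `exclude_from_weight_decay(var_names=["bias", "scale"])`, which excludes variables whose names contain any of the given strings. It illustrates the update rule (bias-corrected Adam moments, plus a separate `lr * weight_decay * w` shrinkage) and is not Keras's implementation; the variable names and values are hypothetical.

```python
import numpy as np


def uses_weight_decay(name: str, excluded=("bias", "scale")) -> bool:
    """Decay a variable unless any excluded substring appears in its name."""
    return not any(token in name for token in excluded)


def adamw_step(w, grad, m, v, step, name, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7):
    """One AdamW update on array `w`; returns (new_w, new_m, new_v)."""
    # Decoupled weight decay: shrink weights directly, independent of the gradient.
    if uses_weight_decay(name):
        w = w - lr * weight_decay * w
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v


w = np.ones(3)
zeros = np.zeros(3)
# With a zero gradient, only weight decay can change the weights.
kernel, _, _ = adamw_step(w, zeros, zeros, zeros, step=1, name="lora_kernel_a")
bias, _, _ = adamw_step(w, zeros, zeros, zeros, step=1, name="dense/bias")
print(kernel, bias)
```

The kernel shrinks slightly (to `1 - 5e-7`) while the bias is left unchanged, which is the intended effect of excluding layernorm scales and biases from decay.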

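## Inference after fine-tuning

After fine-tuning, the model no longer echoes the prompt. It replies in Yoruba using the structure of the training responses (an opening sentence followed by numbered steps), although the answers are still imperfect and sometimes collapse into repeated tokens.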

### ATM Withdrawal Prompt


In [46]:
prompt = template.format(
    instruction="Mo ti wa ni ATM yiyọ ti Mo ko ṣe ohun ti Mo ni lati ṣe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Mo ti wa ni ATM yiyọ ti Mo ko ṣe ohun ti Mo ni lati ṣe?

Response:
Mo wá láti ràn yín lọ́wọ́ láti ṣe ìléwó yín. Tó o bá fẹ́ ṣe ìwé ìfọ̀rọ̀wérọ̀ yín, o lè ṣe àwọn ìgbésẹ̀ yìí: 1. Ṣii ohun tó o lè ṣe bíi ìwé ìfowówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówówów


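The model now answers in Yoruba in the style of the bank's customer-support responses, but the answer quickly degenerates into a repeated syllable until generation reaches `max_length`.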

### Card Retrieval Prompt


In [47]:
prompt = template.format(
    instruction="ṣé o lè ràn mí lọ́wọ́ láti gba káàdì àbáyọ kan padà?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
ṣé o lè ràn mí lọ́wọ́ láti gba káàdì àbáyọ kan padà?

Response:
Mo wá láti ràn yín lọ́wọ́ láti gba káàdì yín. Tó o bá fẹ́ gba káàdì yín, o ní àwọn ọ̀nà mélòó kan: 1. Ṣàtẹ̀ sí àkọsílẹ̀ káàdì yín: Tó o bá fẹ́ gba káàdì yín, o lè ràn yín lọ́wọ́ láti fún ẹ ní káàdì yín.  2. Ẹgbẹ́ àwọn ìnáwó tí wọ́n bá ń ṣe: Tó o bá ní ìwé ìforúkọsílẹ̀ rẹ, o lè bá àwọn ìnáwó tí wọ́n bá ń ṣe láti fún ẹ ní káàdì yín.  3. Ìṣòro tàbí ìsọfúnni: Tó o bá ní ìṣòro tàbí ìṣòro tà


In [48]:
prompt = template.format(
    instruction="Mo fẹ́ dá sí ọ̀ràn kan lórí owó tí mo bá ṣe sí ATM, báwo ni mo ṣe lè ṣe é?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Mo fẹ́ dá sí ọ̀ràn kan lórí owó tí mo bá ṣe sí ATM, báwo ni mo ṣe lè ṣe é?

Response:
Mo wá láti ràn yín lọ́wọ́ láti dá àwọn ọ̀nà bá ti ṣe ṣe láti ṣe àwọn àkọsílẹ̀ yín. Tó o bá fẹ́ ṣe àwọn ìgbésẹ̀ yìí, o lè ṣe àwọn ìgbésẹ̀ yìí: 1. Ṣii wọle si ojúewé wa ni {{Nọmba Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn Àwọn À


In [50]:
prompt = template.format(
    instruction="Mo ti gbe owo lọ si akọọlẹ ti ko tọ Mo nilo iranlọwọ lati fagile gbigbe kan?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Mo ti gbe owo lọ si akọọlẹ ti ko tọ Mo nilo iranlọwọ lati fagile gbigbe kan?

Response:
Mo wá láti ràn yín lọ́wọ́ láti ṣe ìsọfúnni yín. Tó o bá fẹ́ ṣe ìyípadà, o lè ṣe àwọn ìgbésẹ̀ yìí: 1. Ṣii wọle si àwọn ìgbésẹ̀ yìí: Ṣii ìsọfúnni nípa ìnáwó tí o ní lórí ìkànnì wa tàbí nípa ìnáwó tí o ní lórí ìkànnì wa.  2. Ṣàwárí: Tó o bá fẹ́ ṣàyẹ̀wò ìnáwó tí o ní lórí ìkànnì wa, o lè ṣàyẹ̀wò ìnáwó tí o ní lórí ìkànnì wa tàbí tó o ní lórí ìkànnì wa.  3. Ẹgb


Note that for demonstration purposes, this tutorial fine-tunes the model for just one epoch with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.
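One lightweight way to organize these experiments is to enumerate candidate settings up front and run one fine-tuning job per configuration. The grid values below are illustrative assumptions, not recommendations from this tutorial; each dict holds values you would pass to `enable_lora(rank=...)`, `keras.optimizers.AdamW(learning_rate=..., weight_decay=...)`, and `fit(epochs=...)`.

```python
from itertools import product

# Hypothetical search space; adjust it to your compute budget.
ranks = (4, 8, 16)
learning_rates = (5e-5, 1e-4)
epoch_counts = (1, 2)

configs = [
    {"rank": r, "learning_rate": lr, "weight_decay": 0.01, "epochs": e}
    for r, lr, e in product(ranks, learning_rates, epoch_counts)
]

print(len(configs))  # 12
for cfg in configs[:3]:
    print(cfg)
```

Because `fit` updates the LoRA weights in place, reload the base model with `from_preset` for each configuration so that runs do not start from each other's trained weights.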


## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).


In [49]:
gemma_lm.backbone.save_lora_weights("model.lora.h5")