# Fine-tuned Gemma 2B Model on Wastewater and Stormwater Dataset

##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large in size (parameters in the order of billions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.
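As a rough illustration of the idea (not KerasNLP's implementation), the NumPy sketch below keeps a weight matrix `W` frozen and adds a trainable low-rank product `A @ B` on top of it. The layer shape, the rank, and the zero initialization of `B` are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, rank = 64, 64, 4  # hypothetical layer size and LoRA rank

# Frozen pre-trained weight: never updated during fine-tuning.
W = rng.normal(size=(d_in, d_out))

# Trainable low-rank factors. B starts at zero, so the adapted layer
# initially produces exactly the same output as the frozen layer.
A = rng.normal(size=(d_in, rank)) * 0.01
B = np.zeros((rank, d_out))

def lora_forward(x):
    """Apply the frozen weight plus the low-rank update A @ B."""
    return x @ W + (x @ A) @ B

x = rng.normal(size=(2, d_in))
y = lora_forward(x)
print(y.shape, np.allclose(y, x @ W))
```

Only `A` and `B` would receive gradient updates in this setup, which is why the saved fine-tuning weights are small compared with the full model.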

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the [Wastewater and Stormwater Dataset](https://www.kaggle.com/datasets/databricks/databricks-dolly-15k){:.external}. This dataset contains 40,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs.

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

In [3]:
#Import json and pandas, include other libraries you may need.
import json
import pandas as pd

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.
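The credentials cell that follows does not set the backend itself. A minimal sketch of the configuration, assuming the `KERAS_BACKEND` environment variable is set before `keras` is imported for the first time:

```python
import os

# Keras 3 reads this variable when `keras` is first imported,
# so it has to be set before the import cell runs.
os.environ["KERAS_BACKEND"] = "jax"
```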

In [4]:
# Authenticate with Kaggle using your username and API key.
import os
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"  # replace with your Kaggle username
os.environ["KAGGLE_KEY"] = "your_kaggle_key"  # replace with your Kaggle API key

In [5]:
!pip install jax jaxlib

Collecting jax
  Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib
  Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Downloading jax-0.4.35-py3-none-any.whl (2.2 MB)
Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl (87.3 MB)
Installing collected packages: jaxlib, jax
Successfully installed jax-0.4.35 jaxlib-0.4.35


### Import packages

Import Keras and KerasNLP.

In [7]:
# Import keras and keras_nlp
import keras
import keras_nlp

2024-10-15 21:34:21.717358: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-15 21:34:21.997434: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-15 21:34:22.282967: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-15 21:34:22.492180: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-15 21:34:22.552552: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-15 21:34:23.000770: I tensorflow/core/platform/cpu_feature_gu

## Load Dataset and Preprocess (as needed)

Preprocess the data. This fine-tuning uses a subset of 2000 training examples to execute the notebook faster. More or fewer examples may be useful depending on your scenario.

In [8]:
# Examples of the Dataset
data_list = []
file_path = "llm_merged_data.jsonl"

with open(file_path, 'r') as f:
    for i, line in enumerate(f):
        if i < 10:  # Only process the first ten lines
            data = json.loads(line)
            data_list.append(data)
        else:
            break  # Stop reading after ten lines

# Convert to DataFrame and display the rows
df = pd.DataFrame(data_list)
display(df)

| Unnamed: 0 | instruction | context | response | category |
|---|---|---|---|---|
| 0 | When is the next maintenance due for Lift Stat... | On 1784-01-12 06:33:12, technician Michael Geo... | Based on the last service date of 1784-01-12 a... | closed_qa |
| 1 | Based on the lifespan of Sludge Thickener, how... | On 1753-05-24 10:11:01, technician Virginia Sa... | The Sludge Thickener-487731 was installed on 1... | closed_qa |
| 2 | When is the next maintenance due for Reverse O... | On 1863-10-05 06:15:22, technician Henry Woods... | Based on the last service date of 1863-10-05 a... | closed_qa |
| 3 | What kind of maintenance would be needed for R... | On 1892-11-17 01:47:25, technician Miranda Hay... | Given that the Reverse Osmosis System-50577 is... | closed_qa |
| 4 | What maintenance action was performed on the M... | On 2002-06-03 14:53:52, technician Ashley Mcbr... | A emergency repair was performed on Membrane B... | closed_qa |
| 5 | What maintenance action was performed on the A... | On 1994-11-12 12:35:54, technician Amanda Guti... | A emergency repair was performed on Activated ... | closed_qa |
| 6 | What kind of maintenance would be needed for G... | On 1709-09-29 08:32:03, technician Timothy Cra... | Given that the Grit Chamber-581333 is currentl... | closed_qa |
| 7 | What maintenance action was performed on the M... | On 1998-11-22 13:34:09, technician Kathleen Cl... | A emergency repair was performed on Membrane B... | closed_qa |
| 8 | When is the next maintenance due for Primary C... | On 1777-04-04 10:31:34, technician Michael Bri... | Based on the last service date of 1777-04-04 a... | closed_qa |
| 9 | When is the next maintenance due for Sludge Th... | On 1834-12-24 17:39:33, technician Christina R... | Based on the last service date of 1834-12-24 a... | closed_qa |


In [9]:
import json

data = []  # Initialize the empty list to save the data
seen = set()  # Set to track unique observations

with open("llm_merged_data.jsonl") as file:
    for line in file:
        features = json.loads(line)
        unique_key = (features['instruction'], features['response'])
        if unique_key not in seen:
            seen.add(unique_key)
            template = (
                "Instruction: {instruction}\n\n"
                "Response: {response}\n"
            ).format(**features)
            data.append(template)
            
            if len(data) == 2000:
                break

In [10]:
import random
random.shuffle(data)

In [11]:
length = len(data)
print(f"The length of the dataset to be used for training is: {length}")

The length of the dataset to be used for training is: 2000


In [12]:
# Print the first item in the list for inspection
print(data[0] if data else "No unique data found.")

Instruction: Based on the lifespan of Lift Station, how much longer can Lift Station-153125 be expected to operate?

Response: The Lift Station-153125 was installed on 1973-09-25 and has been in operation for approximately -30.7 years. Given the typical lifespan of 25 years for Lift Station, it can be expected to operate for about 55.7 more years, assuming proper maintenance. However, its performance should be closely monitored as it approaches the end of its expected lifespan.



In [13]:
# Check that the list is not empty before accessing elements, then check the data type
if data:
  print(type(data[0]))

  # If the data type is not string, convert it
  if not isinstance(data[0], str):
    # Example: Convert numerical data to string
    data = [str(x) for x in data]
else:
  print("The data list is empty.")

<class 'str'>


In [14]:
# Function to create a prompt dynamically
def create_prompt(instruction):
    # Create a generic prompt structure
    prompt = (
        "Instruction: {instruction}\n"
        "Response:"
    ).format(instruction=instruction)
    return prompt
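Note that the training template in the preprocessing cell separates the instruction and the response with a blank line (`\n\n`), while `create_prompt` uses a single newline. A minimal sketch of one way to keep the two in sync, assuming you want inference prompts to match the training format exactly. The names `PROMPT_PREFIX`, `build_training_example`, and `build_inference_prompt` are hypothetical and not part of the notebook.

```python
# Hypothetical shared template, so training examples and inference
# prompts use the same separator between instruction and response.
PROMPT_PREFIX = "Instruction: {instruction}\n\nResponse:"

def build_training_example(instruction, response):
    """Format one training example from an instruction/response pair."""
    return PROMPT_PREFIX.format(instruction=instruction) + f" {response}\n"

def build_inference_prompt(instruction):
    """Format a prompt that stops where the model should continue."""
    return PROMPT_PREFIX.format(instruction=instruction)

# Placeholder strings, used only to show the resulting format.
example = build_training_example("example question", "example answer")
shared_prompt = build_inference_prompt("example question")
print(repr(example))
print(repr(shared_prompt))
```

With a single template, every training example starts with exactly the text the model will later be prompted with.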

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
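To make "predicts the next token based on previous tokens" concrete, here is a toy sketch of the generation loop. The lookup table is hypothetical and stands in for the model: a real causal language model scores every token in its vocabulary using all previous tokens, not only the last one.

```python
# Toy next-token table (hypothetical): maps the previous token to the
# single most likely next token.
NEXT_TOKEN = {
    "<start>": "the",
    "the": "pump",
    "pump": "needs",
    "needs": "service",
    "service": "<end>",
}

def greedy_generate(start="<start>", max_tokens=10):
    """Append one predicted token at a time until <end> or the limit."""
    tokens = [start]
    while len(tokens) < max_tokens + 1:
        next_token = NEXT_TOKEN.get(tokens[-1], "<end>")
        if next_token == "<end>":
            break
        tokens.append(next_token)
    return tokens[1:]  # drop the start marker

print(greedy_generate())
```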

Create the model using the `from_preset` method:

In [16]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture — a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.


### Reverse Osmosis System Status and maintenance

Query the model for the next maintenance date of a reverse osmosis system.

In [17]:
# Try the pre-trained Gemma model on a question from the dataset.
prompt = create_prompt("When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?")
print(gemma_lm.generate(prompt))

I0000 00:00:1729028159.816639  164891 service.cc:146] XLA service 0x7f07f40375f0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1729028159.816801  164891 service.cc:154]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1729028159.953926  164891 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020.
Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020.
Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020.
Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020.
Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020.
Instruction: When is the next maintenance due for Reverse Osmosis System-683384 based on its last service date?
Response: The next maintenance due date is 12/31/2020

The model answers with a specific due date even though the prompt contains no service date, and it repeats the same instruction and response pair several times.
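The recorded output repeats the same instruction/response pair. A minimal sketch of a post-processing helper that keeps only the first pair, assuming the generated text uses the `Instruction:` marker from `create_prompt`. The helper name `first_response` is hypothetical. KerasNLP's `generate` also accepts a `max_length` argument, which bounds the output length but does not remove repeats on its own.

```python
def first_response(generated_text, marker="Instruction:"):
    """Return the first Instruction/Response pair, dropping any repeats."""
    first = generated_text.find(marker)
    if first == -1:
        # No marker at all: return the text unchanged apart from whitespace.
        return generated_text.strip()
    second = generated_text.find(marker, first + len(marker))
    if second == -1:
        return generated_text[first:].strip()
    return generated_text[first:second].strip()

# Placeholder text in the same shape as the output above.
sample = (
    "Instruction: When is maintenance due?\n"
    "Response: Next month.\n"
    "Instruction: When is maintenance due?\n"
    "Response: Next month.\n"
)
print(first_response(sample))
```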

### Grit Chamber-392614

Prompt the model for the kind of maintenance the grit chamber would need.


In [18]:
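# Same check for the grit chamber. Note: create_prompt already adds the "Instruction:" prefix,
# so the prefix is duplicated in the recorded output.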
prompt = create_prompt("Instruction: What kind of maintenance would be needed for Grit Chamber-3926147")
print(gemma_lm.generate(prompt))

Instruction: Instruction: What kind of maintenance would be needed for Grit Chamber-3926147
Response:
Grit Chamber-3926147
1. The grit chamber is a kind of equipment for removing the sand and gravel in the water. It is mainly used for the treatment of water with high sand content. The sand and gravel in the water are removed by the sand and gravel removal device. The sand and gravel removal device is composed of a sand and gravel removal device and a sand and gravel removal device. The sand and gravel removal device is composed of a sand and gravel removal device and a sand and gravel removal device. The sand and gravel removal device is composed of a sand and gravel removal device and a sand and gravel removal device. The sand and gravel removal device is composed of a sand and gravel removal device and a sand and gravel removal device. The sand and gravel removal device is composed of a sand and gravel removal device and a sand and gravel removal device. The sand and gravel removal d

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using a wastewater and stormwater dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.
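The trade-off between rank and trainable parameters can be sketched with simple arithmetic. The example below assumes a single adapted weight matrix of a hypothetical shape and a LoRA update built from two factors of shape `(d_in, rank)` and `(rank, d_out)`. The exact shapes KerasNLP uses for Gemma's layers may differ.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters for one adapted d_in x d_out weight matrix."""
    # One factor of shape (d_in, rank) plus one of shape (rank, d_out).
    return rank * (d_in + d_out)

d_in, d_out = 2048, 2048  # hypothetical layer shape
full = d_in * d_out       # parameters in the frozen matrix

for rank in (4, 8, 16):
    count = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank:>2}: {count:>6} trainable vs {full} frozen "
          f"({100 * count / full:.2f}%)")
```

The count grows linearly with the rank, so doubling the rank doubles the number of trainable parameters for each adapted matrix.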

In [19]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [1]:
from keras_nlp.samplers import TopKSampler

# Sample from the 50 most likely tokens at temperature 0.7 when generating.
sampler = TopKSampler(temperature=0.7, k=50)

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    sampler=sampler,
)
gemma_lm.fit(data, epochs=1, batch_size=1)

# Saving the model (can be saved in different formats)

In [None]:
# Save the finetuned model as a KerasNLP preset.
gemma_lm.save_to_preset("./gemma-HydroSense-instruct-2b")

In [2]:
# Save the full model in the Keras format.
gemma_lm.save("gemma_model_updated.keras")

## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

In [30]:
# Function to create a prompt based on any question
def create_prompt(instruction):
    # Create a generic prompt structure
    prompt = (
        "Instruction: {instruction}\n"
        "Response:"
    ).format(instruction=instruction)
    return prompt

In [12]:
# Example question
user_question = "What kind of maintenance would be needed for Lift Station-48560 given its current status?"

# Create the prompt
prompt = create_prompt(user_question)

# Generate the response
generated_response = gemma_lm.generate(prompt)  # Adjust based on your model's method

# Print the output
print(f"Model Answer: {generated_response}")

Instruction: What kind of maintenance would be needed for Lift Station-48560 given its current status?

Response: Given that the Lift Station-48560 is currently operational, it would likely need routine maintenance to ensure continued optimal performance.



In [32]:
# Example question
user_question = "what kind of maintenance would be needed for Reverse Osmosis System-50577?"

# Create the prompt
prompt = create_prompt(user_question)

# Generate the response
generated_response = gemma_lm.generate(prompt)  # Adjust based on your model's method

# Print the output
print(f"Model Answer: {generated_response}")

Model Answer: Instruction: what kind of maintenance would be needed for Reverse Osmosis System-50577?
Response: Reverse Osmosis System-50577 needs routine maintenance.

Instruction: how often should Reverse Osmosis System-50577 be cleaned and maintained?
Response: Reverse Osmosis System-50577 should be cleaned and maintained every 6 months.



Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.
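
One way to organize such experiments is to enumerate the combinations up front and run one fine-tuning trial per combination. The sketch below only builds the list of configurations. The search space values are placeholders rather than recommendations, and the names `search_space` and `trial_configs` are hypothetical.

```python
from itertools import product

# Hypothetical search space; the values are placeholders, not recommendations.
search_space = {
    "rank": [4, 8],
    "learning_rate": [5e-5, 1e-4],
    "epochs": [1, 3],
}

def trial_configs(space):
    """Yield one dict per combination of hyperparameter values."""
    names = list(space)
    for values in product(*(space[name] for name in names)):
        yield dict(zip(names, values))

configs = list(trial_configs(search_space))
print(len(configs), "trials; first:", configs[0])
```

Each configuration would then drive one run of the LoRA, compile, and fit cells with a freshly loaded model.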


## Next steps and Additional Resources

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).


# Uploading the model to Kaggle

In [None]:
# Upload the preset as a new model variant on Kaggle
# The URI has the form kaggle://<username>/<model>/<framework>/<variation>; replace the username and model name with your own.
kaggle_uri = "kaggle://username/gemma-pirate/keras/gemma-HydroSense-instruct-2b"
keras_nlp.upload_preset(kaggle_uri, "./gemma-HydroSense-instruct-2b")