##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3

In [1]:
import os

# Backend selection ("jax" here; "torch" or "tensorflow" are the alternatives).
# The memory-fraction setting is meant to avoid memory fragmentation on the
# JAX backend.
os.environ.update(
    {
        "KERAS_BACKEND": "jax",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0",
    }
)
# Alternative left disabled:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

### Import packages

Import Keras and KerasNLP.

In [4]:
import keras
import keras_nlp

## Load Dataset

Preprocess the data. Each JSON Lines record is formatted into a single instruction/response string. To execute the notebook faster, a random subset of 15,000 examples is sampled before training. Consider using more training data for higher quality fine-tuning.

In [5]:
import json

# Prompt layout used for every training example. It is defined before the loop
# so that `template` exists even when the file yields no records; the inference
# cells further down reuse it.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open('/kaggle/input/nutritionx-data/first_data_revised.jsonl', encoding='utf-8') as file:
    for line in file:
        # Blank lines (e.g. a trailing newline at end of file) are not valid JSON.
        if not line.strip():
            continue
        features = json.loads(line)
        # Format the entire example as a single string.
        data.append(template.format(**features))

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [6]:
# Instantiate the causal language model from the "gemma_2b_en" preset.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# Print the model summary.
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.


### Balanced Diet Prompt

Ask what a balanced diet is.

In [None]:
# Fill only the instruction slot; the response slot stays empty.
prompt = template.format(instruction="What is a balanced diet?", response="")
print(gemma_lm.generate(prompt, max_length=256))

The model's answers are too verbose.

### Breakfast Prompt

Prompt the model to suggest breakfast ideas.


In [None]:
# Fill only the instruction slot; the response slot stays empty.
breakfast_question = "I am a 70 year old moderately active man, what is a good breakfast option?"
prompt = template.format(instruction=breakfast_question, response="")
print(gemma_lm.generate(prompt, max_length=256))

The model's answer is not easy to understand.

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the nutrition dataset (`nutritionx-data`) loaded above.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [7]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Print the summary again so the parameter counts can be compared with the
# summary shown right after loading the model.
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [8]:
import random

# Fixed seed so the same subset is drawn on every run.
random.seed(42)
# Draw up to 15,000 examples. The size is capped at len(data) because
# random.sample raises ValueError when asked for more items than are available.
data_sample = random.sample(data, min(15000, len(data)))

In [9]:
# Cap the input sequence length at 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512

# AdamW is a common optimizer for transformer models.
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
# Layernorm and bias terms are left out of weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# One pass over the sampled examples, one example per batch.
gemma_lm.fit(data_sample, epochs=1, batch_size=1)

[1m15000/15000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10926s[0m 727ms/step - loss: 0.1006 - sparse_categorical_accuracy: 0.7534


<keras.src.callbacks.history.History at 0x7a86d012c580>

### Save Model - Weights

In [11]:
# Write the fine-tuned model to a local preset directory.
preset_dir = "./gemma-nutritionx-2b"
gemma_lm.save_to_preset(preset_dir)

# Upload the preset as a new model variant on Kaggle
kaggle_url = "kaggle://favouryahdii/gemma-nutritionx/keras/gemma-nutritionx-2b"
keras_nlp.upload_preset(kaggle_url, preset_dir)

Uploading Model https://www.kaggle.com/models/favouryahdii/gemma-nutritionx/keras/gemma-nutritionx-2b ...
Model 'gemma-nutritionx' does not exist or access is forbidden for user 'favouryahdii'. Creating or handling Model...
Model 'gemma-nutritionx' Created.
Starting upload for file ./gemma-nutritionx-2b/tokenizer.json


Uploading: 100%|██████████| 367/367 [00:00<00:00, 582B/s]

Upload successful: ./gemma-nutritionx-2b/tokenizer.json (367B)
Starting upload for file ./gemma-nutritionx-2b/preprocessor.json



Uploading: 100%|██████████| 899/899 [00:00<00:00, 1.55kB/s]

Upload successful: ./gemma-nutritionx-2b/preprocessor.json (899B)
Starting upload for file ./gemma-nutritionx-2b/config.json



Uploading: 100%|██████████| 785/785 [00:00<00:00, 1.30kB/s]

Upload successful: ./gemma-nutritionx-2b/config.json (785B)
Starting upload for file ./gemma-nutritionx-2b/task.json



Uploading: 100%|██████████| 2.34k/2.34k [00:00<00:00, 3.64kB/s]

Upload successful: ./gemma-nutritionx-2b/task.json (2KB)
Starting upload for file ./gemma-nutritionx-2b/model.weights.h5



Uploading: 100%|██████████| 10.0G/10.0G [03:00<00:00, 55.5MB/s]

Upload successful: ./gemma-nutritionx-2b/model.weights.h5 (9GB)
Starting upload for file ./gemma-nutritionx-2b/metadata.json



Uploading: 100%|██████████| 143/143 [00:00<00:00, 242B/s]

Upload successful: ./gemma-nutritionx-2b/metadata.json (143B)
Starting upload for file ./gemma-nutritionx-2b/assets/tokenizer/vocabulary.spm



Uploading: 100%|██████████| 4.24M/4.24M [00:00<00:00, 5.89MB/s]

Upload successful: ./gemma-nutritionx-2b/assets/tokenizer/vocabulary.spm (4MB)





Your model instance has been created.
Files are being processed...
See at: https://www.kaggle.com/models/favouryahdii/gemma-nutritionx/keras/gemma-nutritionx-2b


In [None]:
# save model weights
# gemma_lm.save_weights('nutrition_x.weights.h5')

## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

In [None]:
# Same question as before fine-tuning, with the response slot left empty.
prompt = template.format(instruction="What is a balanced diet?", response="")
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
# Breakfast question with the response slot left empty.
breakfast_question = "I am a 25 year old active man, what is a good breakfast option?"
prompt = template.format(instruction=breakfast_question, response="")
print(gemma_lm.generate(prompt, max_length=256))

### Adding Function Calling


In [2]:
import requests
import spacy
import keras
import keras_nlp
from kaggle_secrets import UserSecretsClient
# Read the secret stored under the label "google 2" from Kaggle secrets; the
# functions below send it as the `key` of their Google Maps API requests.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("google 2")

# Load spaCy's pre-trained English model
nlp = spacy.load('en_core_web_sm')

# Define template for instruction and response
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Load the model from the Kaggle preset URL that the fine-tuned model was
# uploaded to earlier in this notebook.
base = keras_nlp.models.CausalLM.from_preset("kaggle://favouryahdii/gemma-nutritionx/keras/gemma-nutritionx-2b")

# Step 2: Function to process user input and make a prediction
def process_user_input(user_input, api_data="", max_length=256):
    """Generate a model completion for a user query.

    Args:
        user_input: Text placed in the instruction slot of the module-level
            ``template``.
        api_data: Text placed in the response slot of the template; empty by
            default.
        max_length: Forwarded to ``base.generate`` as its second positional
            argument.

    Returns:
        Whatever ``base.generate`` returns for the formatted prompt.
    """
    prompt_text = template.format(instruction=user_input, response=api_data)
    # `base` is the module-level model loaded before this function is called.
    return base.generate(prompt_text, max_length)

# Function to extract the place using spaCy's NER
def extract_place_spacy(query):
    """Return the text of the first place-like named entity in ``query``.

    The query is run through the module-level spaCy pipeline ``nlp``. Entities
    labelled ORG, GPE (geopolitical entity, like cities and countries) or
    PERSON are accepted, and the first one in document order wins.

    Returns:
        The entity text, or None when no entity carries an accepted label.
    """
    accepted_labels = ('ORG', 'GPE', 'PERSON')
    matches = (ent.text for ent in nlp(query).ents if ent.label_ in accepted_labels)
    return next(matches, None)

# Step 3: Function to convert a place to latitude and longitude using Google Geocoding API
def get_location(place):
    """Resolve the place mentioned in a query to coordinates via the Geocoding API.

    Args:
        place: Free-text query. The place name is pulled out of it with
            ``extract_place_spacy`` before geocoding.

    Returns:
        The ``geometry.location`` object of the first geocoding result (the
        caller reads its ``lat`` and ``lng`` keys), or a dict of the form
        ``{"error": <message>}`` when no place is found in the query, the
        request fails, or the service returns no results.
    """
    # Extract the place name using the NER model
    place_name = extract_place_spacy(place)
    if not place_name:
        return {"error": "No place found in the query"}

    try:
        # Pass the query as `params` so requests URL-encodes it; concatenating
        # the name into the URL breaks on characters such as "&" or "#".
        # The timeout keeps an unresponsive service from hanging the notebook.
        response = requests.get(
            "https://maps.googleapis.com/maps/api/geocode/json",
            params={"address": place_name, "key": secret_value_0},
            timeout=10,
        )
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network failure, timeout, or a body that is not valid JSON.
        return {"error": "Could not find the location"}

    # Error payloads may omit "results" entirely, so do not index it directly.
    results = data.get('results') if isinstance(data, dict) else None
    if results:
        # This is the latitude/longitude of the best match.
        return results[0]['geometry']['location']
    return {"error": "Could not find the location"}

# Step 4: Function to call an external API (Google Places API example)
def get_nearby_places(location, radius=1000, place_type='convenience_store'):
    """Look up places around a coordinate with the Places Nearby Search API.

    Args:
        location: Either a dict with ``lat`` and ``lng`` keys (as returned by
            ``get_location``) or an error dict containing an ``"error"`` key.
        radius: Value sent as the ``radius`` request parameter.
        place_type: Value sent as the ``type`` request parameter.

    Returns:
        A comma-separated string of place names on success. Otherwise a dict
        ``{"error": <message>}``: the incoming error dict unchanged, or a new
        one when the request fails or no named places come back.
    """
    # Propagate an upstream geocoding failure before doing any other work.
    if "error" in location:
        return location

    base_url = "https://maps.googleapis.com/maps/api/place/nearbysearch/json"
    params = {
        'location': f"{location['lat']},{location['lng']}",
        'radius': radius,
        'type': place_type,
        'key': secret_value_0
    }
    failure = {"error": "Failed to retrieve places information"}

    try:
        # The timeout keeps an unresponsive service from hanging the notebook.
        response = requests.get(base_url, params=params, timeout=10)
    except requests.RequestException:
        return failure

    if response.status_code != 200:
        return failure

    try:
        places = response.json().get('results', [])
    except ValueError:
        # Body was not valid JSON.
        return failure

    place_names = [place['name'] for place in places if place.get('name')]
    if not place_names:
        # Report "nothing found" explicitly rather than returning an empty
        # string that the caller would present as a successful result.
        return {"error": "No nearby places found"}
    return ", ".join(place_names)

# Function to extract the place type from the user query
def extract_place_type(query):
    """Pick the place type mentioned in a user query.

    Known types are checked in list order and the first one found in the query
    wins. Matching is case-insensitive and treats underscores, hyphens and
    spaces as equivalent, so "department store", "department-store" and
    "department_store" all map to ``"department_store"``.

    Args:
        query: Free-text user query.

    Returns:
        The matching place type identifier (underscore form), or
        ``"convenience_store"`` when none of the known types is mentioned.
    """
    # Define common place types (you can extend this list as needed)
    place_types = ['convenience_store', 'restaurant', 'gym', 'drugstore', 'hospital', 'department_store', 'pharmacy', 'physiotherapist']

    # Users write "department store", not "department_store", so compare on a
    # normalized form where the separators are all plain spaces.
    normalized_query = query.lower().replace('_', ' ').replace('-', ' ')

    for place_type in place_types:
        if place_type.replace('_', ' ') in normalized_query:
            return place_type

    return "convenience_store"  # Default place type if none is found


# Step 5: Handle user input, model response, and function calling
def handle_function_call(user_input):
    """Answer a user query, adding nearby-place results for location queries.

    The model response is always generated first. When the query contains the
    word "nearby" or "location" (case-insensitive), the place and place type
    are extracted from it and the nearby-places lookup is appended to the
    model response.

    Args:
        user_input: Free-text user query.

    Returns:
        The model response, optionally followed by either the list of nearby
        places or a sentence explaining why the lookup failed.
    """
    # Always generate a model response first
    model_response = process_user_input(user_input)

    query_lower = user_input.lower()
    # Return the regular model response if not location-related
    if 'nearby' not in query_lower and 'location' not in query_lower:
        return model_response

    # Extract the place and get its coordinates (latitude, longitude)
    location = get_location(user_input)

    # Extract the place type from the user query dynamically
    place_type = extract_place_type(user_input)

    # Set default values for radius
    radius = 1000

    # Call the API to get nearby places using the latitude and longitude from get_location
    api_response = get_nearby_places(location, radius, place_type)

    # Failures come back as a dict with an "error" key; successes are a plain
    # string of names. Distinguish them by type: a substring test would treat
    # a place name that merely contains "error" as a failure and then crash
    # when indexing the string with ['error'].
    if isinstance(api_response, dict) and "error" in api_response:
        return model_response + "\n\nHowever, I couldn't find any nearby places due to: " + api_response['error']

    return model_response + "\n\nHere are some nearby places I found: " + api_response


# Step 6: Example user query
# The query contains "nearby", so handle_function_call takes its
# location-lookup branch in addition to generating a model response.
user_query = "Can you find nearby grocery stores to University of Leeds?"
response = handle_function_call(user_query)
print(response)


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
I0000 00:00:1729449374.454995      30 service.cc:145] XLA service 0x56e0330341b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1729449374.455054      30 service.cc:153]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1729449381.978361      30 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Instruction:
Can you find nearby grocery stores to University of Leeds?

Response:
There are 12 grocery stores in walking distance to University of Leeds, and the nearest one is 0.10 miles away.

Here are some nearby places I found: One Stop, Best-one, Yonisa Store, Tesco Express, Bestnes, Nisa Local, Mace, Tesco Express, Tesco Express, Around The Clock, Co-op Food @ LUU, Yonisa Local 7 days, Blenheim convenience store,, Everyday Essentials Store, Morrisons Daily, Tesco Express, Go Local - Hyde Park Stores, Londis Woodhouse Street Convenience Store / Post Office, Woodhouse Mini Market, Co-op Food - Leeds - Burley Street
