
# Fine-tune Gemma models in Keras using LoRA

## Overview

In this project, I demonstrate how to fine-tune the Gemma 2B model for conversational AI in the medical domain. 

### About Gemma
Gemma is a family of lightweight open models from Google, available as pretrained checkpoints in several sizes. The models are suited to natural language understanding and generation tasks; this project uses the 2B English preset.

### About Low Rank Adaptation (LoRA)
[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique that enables efficient fine-tuning of large language models by introducing trainable low-rank matrices. This approach reduces computational requirements while maintaining model performance, making it ideal for fine-tuning Gemma.
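To make the savings concrete, here is a rough sketch of the arithmetic behind LoRA. For a weight matrix of shape `d_out x d_in`, full fine-tuning updates every entry, while LoRA trains two small matrices of rank `r`. The layer size below is an illustrative assumption, not a value read from the Gemma configuration, and the helper names are hypothetical.

```python
def full_finetune_params(d_out: int, d_in: int) -> int:
    """Trainable parameters when the whole d_out x d_in weight matrix is updated."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters for LoRA: B is d_out x rank, A is rank x d_in."""
    return rank * (d_out + d_in)


# Illustrative layer size (an assumption, not read from the Gemma config).
d_out, d_in, rank = 2048, 2048, 8

full = full_finetune_params(d_out, d_in)
lora = lora_params(d_out, d_in, rank)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.2%}")
```

For this example layer, the adapter trains well under one percent of the parameters that full fine-tuning would.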

### Dataset
- **Source**: [Hugging Face Medical-Llama3 Fine-tune Dataset](https://huggingface.co/datasets/Pistachio-LLM/Medical-llama3-finetune-train)
- **Description**: A curated collection of over 37,000 medical conversational entries, optimized for healthcare-specific language tasks.
- **Justification**: Its diversity and focus on the medical domain make it a valuable dataset for enhancing conversational adaptability and accuracy.


## Setup

### Access to Gemma
I followed the setup instructions for [Gemma](https://ai.google.dev/gemma/docs/setup) to ensure smooth integration, including:
- Accessing Gemma on [Kaggle](https://kaggle.com).
- Configuring the runtime for the Gemma 2B model.
- Generating a Kaggle API key.


### Configure Environment
To integrate Gemma, I configured the required environment variables:
- Set `KAGGLE_USERNAME` and `KAGGLE_KEY` using the downloaded Kaggle API credentials.
- Ensured the environment is ready by validating access to `kagglehub`.

In [1]:
# !pip install kagglehub
import kagglehub

In [2]:
from dotenv import load_dotenv
import os

# Path and load the .env file 
dotenv_path = "../.env" 
load_dotenv(dotenv_path)

# Access the environment variables
kaggle_username = os.getenv("KAGGLE_USERNAME")
kaggle_key = os.getenv("KAGGLE_KEY")
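If the `.env` file is missing or misnamed, `os.getenv` silently returns `None`. A small check like the following sketch can catch that early; `missing_credentials` is a hypothetical helper, not part of `kagglehub` or `python-dotenv`.

```python
import os

REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_credentials(env=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_credentials()
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
else:
    print("Kaggle credentials found.")
```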

In [3]:
!kaggle datasets list

ref                                                          title                                          size  lastUpdated          downloadCount  voteCount  usabilityRating  
-----------------------------------------------------------  --------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
bhadramohit/customer-shopping-latest-trends-dataset          Customer Shopping (Latest Trends) Dataset      76KB  2024-11-23 15:26:12           4225         82  1.0              
ikynahidwin/depression-student-dataset                       Depression Student Dataset                      4KB  2024-11-20 06:42:01           3624         71  1.0              
steve1215rogg/student-lifestyle-dataset                      student lifestyle dataset                      22KB  2024-11-11 19:11:28           7006        114  1.0              
steve1215rogg/e-commerce-dataset                             E-Commerce Dataset                          

### Install Dependencies
I installed the required packages, including:
- **Keras**: For model training and customization.
- **KerasNLP**: For natural language processing utilities.
- **TensorFlow/JAX**: For backend support.
- Additional utilities like `pandas` and `numpy`.

In [4]:
!python --version

Python 3.10.15


In [5]:
!pip show keras
!pip show tensorflow
!pip show keras-nlp

Name: keras
Version: 3.6.0
Summary: Multi-backend Keras.
Home-page: https://github.com/keras-team/keras
Author: Keras team
Author-email: keras-users@googlegroups.com
License: Apache License 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, h5py, ml-dtypes, namex, numpy, optree, packaging, rich
Required-by: tensorflow
Name: tensorflow
Version: 2.18.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, keras, libclang, ml-dtypes, numpy, opt-einsum, packaging, protobuf, requests, setuptools, six, tensorboard, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: tensorflow-text
Name: keras-nlp
Version: 0.17.0
Summary: Industry-streng

### Select a Backend
For this project, I utilized the JAX backend for efficiency:

In [6]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
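Keras reads `KERAS_BACKEND` once, when it is first imported, so the variable has to be set before `import keras` runs. The sketch below wraps that rule in a hypothetical helper, `select_backend`, which validates the name and reports whether Keras was already loaded in the current session.

```python
import os
import sys

VALID_BACKENDS = {"jax", "torch", "tensorflow"}


def select_backend(name: str) -> bool:
    """Set KERAS_BACKEND; return False if Keras was already imported.

    Keras reads the variable at import time, so setting it after
    `import keras` has no effect on the running session.
    """
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend: {name!r}")
    already_imported = "keras" in sys.modules
    os.environ["KERAS_BACKEND"] = name
    return not already_imported


if not select_backend("jax"):
    print("Keras is already imported; restart the kernel to switch backends.")
```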

### Import packages

Import TensorFlow, Keras, and KerasNLP, along with pandas and NumPy.

In [7]:
import tensorflow as tf
import keras
import keras_nlp
import numpy as np
import pandas as pd

## Dataset Preparation
I loaded the Medical-Llama3 Fine-tune Dataset and preprocessed it to extract relevant features:

In [8]:
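from datasets import Dataset
# Note: this expression only references the method and does not call it,
# so no cache files are removed.
Dataset.cleanup_cache_files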

<function datasets.arrow_dataset.Dataset.cleanup_cache_files(self) -> int>

In [9]:
from datasets import load_dataset

ds = load_dataset("Pistachio-LLM/Medical-llama3-finetune-train")

In [10]:
ds

DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction'],
        num_rows: 37179
    })
})

In [11]:
train_ds = ds['train']

# Function to extract inputs and outputs from the dataset
def extract_features(example):
    return {
        'input': example['input'],
        'instruction': example['instruction'],
        'output': example['output']
    }

# Map the dataset to extract features
train_ds = train_ds.map(extract_features)
train_ds = pd.DataFrame(train_ds)


In [41]:
# type(train_ds)
train_ds[['instruction', 'output', 'input']].head(4).style.format().set_caption("Medical Dataset")

Unnamed: 0,instruction,output,input
0,What does the Mesenchyme give rise to?,The Mesenchyme gives rise to most connective tissue.,
1,Which class of antimicrobials is known to displace unconjugated bilirubin from serum albumin in the blood?,Sulfonamides are known to displace unconjugated bilirubin from serum albumin in the blood.,
2,"In a female athlete who has amenorrhea and laboratory exam shows decreased FSH, LH, and estrogen levels, what is the likely diagnosis?","The likely diagnosis is hypogonadotropic hypogonadism, also known as hypothalamic amenorrhea. This is a condition where the hypothalamus in the brain does not release enough gonadotropin-releasing hormone (GnRH) to stimulate the pituitary gland to produce follicle-stimulating hormone (FSH) and luteinizing hormone (LH), which are necessary for ovulation and menstruation. As a result, estrogen levels are low, leading to amenorrhea. Female athletes are at increased risk of developing this condition due to the stress of exercise and low body fat. Treatment may involve lifestyle changes, such as reducing exercise and increasing caloric intake, as well as hormone therapy to stimulate ovulation and restore menstrual cycles.",
3,What does a physical examination for aortic dissection entail?,"Tachycardia may be present due to pain, anxiety, aortic rupture with massive bleeding, pericardial tamponade, aortic insufficiency with acute pulmonary edema and hypoxemia. Pulsus paradoxus (a drop of > 10 mmHg in arterial blood pressure on inspiration) may be present of pericardial tamponade develops. Pseudohypotension (falsely low blood pressure measurement) may occur due to involvement of the brachiocephalic artery (supplying the right arm) or the left subclavian artery (supplying the left arm). While many patients with an aortic dissection have a history of hypertension, the blood pressure is quite variable among patients with acute aortic dissection, and tends to be higher in individuals with a distal dissection. In individuals with a proximal aortic dissection, 36% present with hypertension, while 25% present with hypotension. In those that present with distal aortic dissections, 70% present with hypertension while 4% present with hypotension. A wide pulse pressure may be present if acute aortic insufficiency develops. Severe hypotension at presentation is a grave prognostic indicator. It is usually associated with pericardial tamponade, severe aortic insufficiency, or rupture of the aorta. Accurate measurement of the blood pressure is important. Swelling of the neck and face may be present due to compression of the superior vena cava or Superior vena cava syndrome Horner syndrome may be present due to compression of the superior cervical ganglia The patient may be hoarse due to compression of the left recurrent laryngeal nerve. Rales may be present due to cardiogenic pulmonary edema which may result from acute aortic regurgitation. Hemothorax and / or pleural effusion may cause dullness to percussion. 
Stridor and wheezing may be present due to compression of the airway Hemoptysis may be present due to compression of and erosion into the bronchus Aortic insufficiency occurs in 1/2 to 2/3 of ascending aortic dissections, and the murmur of aortic insufficiency is audible in about 32% of proximal dissections. The intensity (loudness) of the murmur is dependent on the blood pressure and may be inaudible in the event of hypotension. Aortic insufficiency is more commonly associated with type I or type II dissection. The murmur of aortic insufficiency (AI) due to aortic dissection is best heard at the right 2nd intercostal space (ICS), as compared with the lower left sternal border for AI due to primary aortic valvular disease. Beck's triad may be present: Hypotension (due to decreased stroke volume) Jugular venous distension (due to impaired venous return to the heart) Muffled heart sounds (due to fluid inside the pericardium) Distension of veins in the forehead and scalp Altered sensorium (decreasing Glasgow coma scale) Peripheral edema In addition to the Beck's triad and pulsus paradoxus the following can be found on cardiovascular examination: Pericardial rub Clicks - As ventricular volume shrinks disproportionately, there may be psuedoprolapse/true prolapse of mitral and/or tricuspid valvular structures that result in clicks. Kussmaul's sign - Decrease in jugular venous pressure with inspiration is uncommon. Diminution or absence of pulses is found in up to 40% of patients, and occurs due to occlusion of a major aortic branch. For this reason it is critical to assess the pulse and blood pressure in both arms. The iliac arteries may be affected as well. Neurologic deficits such as coma, altered mental status, Cerebrovascular accident (CVA) and vagal episodes are seen in up to 20%. There can also be focal neurologic signs due to occlusion of a spinal artery. This condition is known as Anterior spinal artery syndrome or ""Beck's syndrome"". 
Physical Examination Findings Evidence of insufficient blood supply: Absent pulse Systolic blood pressure difference Focal neurological deficit (along with pain) Aortic diastolic murmur (new and with pain) Hypotension or shock",


In [13]:
train_ds_list = [
    f"Instruction:\n{row['instruction']}\n\nResponse:\n{row['output']}"
    for index, row in train_ds.iterrows()  # Use iterrows to iterate over DataFrame rows
]

In [14]:
train_ds_list[:1]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.']
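The same formatting can be expressed as a small reusable function, which also produces inference prompts when the response is left empty. This is a sketch: `format_example` is a hypothetical helper name, and the row below is copied from the first dataset entry shown above.

```python
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def format_example(instruction: str, response: str = "") -> str:
    """Render one training string, or an inference prompt when response is empty."""
    return TEMPLATE.format(instruction=instruction, response=response)


rows = [
    {
        "instruction": "What does the Mesenchyme give rise to?",
        "output": "The Mesenchyme gives rise to most connective tissue.",
    },
]
formatted = [format_example(r["instruction"], r["output"]) for r in rows]
print(formatted[0])
```

Using one template for both training and inference keeps the prompt format the model sees at generation time identical to what it was tuned on.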

To keep early experimentation fast, I used only the first 1,000 formatted examples:

In [15]:
# Only use 1000 training examples, to keep it fast.
data = train_ds_list[:1000]
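Taking the first 1,000 rows is the simplest option, but it inherits whatever ordering the dataset file has. A seeded random sample is one possible alternative. The sketch below uses stand-in data and a hypothetical helper, `sample_subset`; it is not what the notebook ran.

```python
import random


def sample_subset(examples, n, seed=0):
    """Return n examples drawn without replacement, reproducibly for a given seed."""
    if n >= len(examples):
        return list(examples)
    return random.Random(seed).sample(list(examples), n)


# Stand-in data; in the notebook this would be train_ds_list.
examples = [f"example {i}" for i in range(10_000)]
subset = sample_subset(examples, 1000, seed=42)
print(len(subset), subset[:3])
```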

In [16]:
# Check the dataset
data[:5]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.',
 'Instruction:\nWhich class of antimicrobials is known to displace unconjugated bilirubin from serum albumin in the blood?\n\nResponse:\nSulfonamides are known to displace unconjugated bilirubin from serum albumin in the blood.',
 'Instruction:\nIn a female athlete who has amenorrhea and laboratory exam shows decreased FSH, LH, and estrogen levels, what is the likely diagnosis?\n\nResponse:\nThe likely diagnosis is hypogonadotropic hypogonadism, also known as hypothalamic amenorrhea. This is a condition where the hypothalamus in the brain does not release enough gonadotropin-releasing hormone (GnRH) to stimulate the pituitary gland to produce follicle-stimulating hormone (FSH) and luteinizing hormone (LH), which are necessary for ovulation and menstruation. As a result, estrogen levels are low, leading to amenorrhea. Female athletes are at increased risk of develo

## Model Loading and Inference

### Load the Gemma Model


KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this project, I create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [20]:
from keras_nlp.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

The loaded model has roughly 2.5 billion parameters and is set up for causal language modeling.
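To illustrate what "predicting the next token from previous tokens" means in practice, here is a toy generation loop. It is only a sketch: the lookup table and its scores are made up, and it conditions on the last token alone, whereas a real causal language model such as Gemma conditions on the entire preceding sequence.

```python
# Toy "model": for each token, scores for possible next tokens (made-up values).
NEXT_TOKEN_SCORES = {
    "<s>": {"the": 0.9, "a": 0.1},
    "the": {"patient": 0.7, "doctor": 0.3},
    "patient": {"coughs": 0.6, "</s>": 0.4},
    "coughs": {"</s>": 1.0},
}


def generate_greedy(start="<s>", max_length=10):
    """Repeatedly append the highest-scoring next token until </s> or max_length."""
    tokens = [start]
    while len(tokens) < max_length:
        candidates = NEXT_TOKEN_SCORES.get(tokens[-1])
        if not candidates:
            break
        next_token = max(candidates, key=candidates.get)
        tokens.append(next_token)
        if next_token == "</s>":
            break
    return tokens


print(generate_greedy())
```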

## Inference before fine-tuning

I tested the base model on two prompts, one asking for a likely diagnosis and one asking what to do next:

### Coughing and Wheezing Prompt

Query the model for suggestions on the most probable diagnosis.

In [21]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

In [22]:
prompt = template.format(
    instruction="A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?

Response:
Exercise-induced asthma, also known as exercise-induced bronchoconstriction (EIB), is a type of asthma that occurs in individuals who are sensitive to the cold, and is caused by exercise-induced bronchoconstriction. It is characterized by the narrowing of the airways, making it difficult to breathe.

Symptoms may include wheezing, coughing, shortness of breath, and chest pain. It may be aggravated by physical activity, cold temperatures, or exertion and may improve with rest. The condition is often exacerbated by exposure to cold air or cold weather.

Treatment typically involves avoiding exposure to cold air or cold weather and taking medication such as inhaled steroids or bronchodilators. It can be managed through regular exercise and physical co
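The `TopKSampler(k=5, seed=2)` used in the cell above restricts each sampling step to the five most likely tokens. The following is a generic sketch of top-k sampling in plain Python, not the KerasNLP implementation; the logits are made-up values.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample a token index from the k highest logits, weighted by softmax."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    max_logit = logits[top[0]]
    # Subtract the max before exponentiating for numerical stability.
    weights = [math.exp(logits[i] - max_logit) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)
logits = [2.0, 0.5, 1.5, -1.0, 0.0, 3.0]  # made-up scores for six tokens
draws = [top_k_sample(logits, k=3, rng=rng) for _ in range(20)]
print(draws)
```

With `k=3`, only indices 5, 0, and 2 can ever be drawn, which is the point of the technique: low-probability tokens are excluded while some variety is preserved.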

### Next step Prompt

Prompt the model to suggest the next step based on a scenario.

In [23]:
prompt = template.format(
    instruction="Hi, I’ve been feeling short of breath for the past two days, and today I also started having mild chest pain. What should I do?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
Hi, I’ve been feeling short of breath for the past two days, and today I also started having mild chest pain. What should I do?

Response:
Hi, I’m glad you’re feeling a little better, but I’m sorry to see you’re experiencing some symptoms like shortness of breath (which is a common sign of anxiety) and chest pain.

It can be difficult to figure out what’s causing your symptoms, especially because there are so many possible causes. But it’s important to rule out any serious conditions or health issues first.

If you haven’t done so already, it’s a good idea to see a healthcare professional who can conduct a thorough assessment and determine the cause of your symptoms. This will help you get the right treatment and feel better.

I hope you’re feeling a little better and I’m wishing you a good night’s sleep.

Best regards,

Dr. N

Instruction:
I’ve been having chest pain for a few weeks now, and it’s getting worse. What should I do?

Response:
It can be difficult to figure ou

These initial results provided a baseline for comparison after fine-tuning.

## LoRA Fine-tuning

- LoRA Rank: Set to 8, balancing computational efficiency and expressive power.
- Optimizer: AdamW, configured for transformer models.
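As a rough illustration of what the rank controls, the sketch below follows the formulation in the LoRA paper: the frozen weight `W` is left untouched, and a trainable update `B A` of rank `r` is added, scaled by `alpha / r`. It is a plain-Python toy with tiny made-up sizes and is not how KerasNLP implements `enable_lora` internally.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, rank):
    """Compute W x + (alpha / rank) * B (A x) without forming B A explicitly."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


rng = random.Random(0)
d_out, d_in, rank = 4, 6, 2  # tiny illustrative sizes
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen
A = [[rng.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # trainable
B = [[0.0] * rank for _ in range(d_out)]  # trainable, zero-initialized
x = [rng.gauss(0, 1) for _ in range(d_in)]

# With B at zero the adapter contributes nothing, so the output equals W x.
print(lora_forward(W, A, B, x, alpha=rank, rank=rank) == matvec(W, x))
```

Because `B` starts at zero, the adapted model initially behaves exactly like the pretrained one, and training only moves it away from that starting point as `A` and `B` are updated.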

In [24]:
# Enable LoRA for the model and set the LoRA rank to 8.
gemma_lm.backbone.enable_lora(rank=8)
gemma_lm.summary()

In [25]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
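The loss passed to `compile` scores each predicted token by the negative log-probability the model assigns to the correct next token. Here is a minimal per-token sketch in plain Python, using made-up logits over a four-token vocabulary; the value Keras reports during training is an aggregate over many such tokens.

```python
import math


def sparse_categorical_crossentropy(logits, target_index):
    """Return -log softmax(logits)[target_index], computed stably."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target_index]


# Made-up logits over a four-token vocabulary; the correct next token is index 2.
loss = sparse_categorical_crossentropy([1.0, 0.5, 3.0, -1.0], target_index=2)
print(f"per-token loss: {loss:.4f}")
```

Setting `from_logits=True` tells Keras that the model outputs raw scores like these rather than probabilities, so the softmax is applied inside the loss.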

The fine-tuning process trained the model on a subset of data for one epoch:

In [26]:
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 23885s 24s/step - loss: 0.8988 - sparse_categorical_accuracy: 0.6055


<keras.src.callbacks.history.History at 0x487de27a0>
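The progress bar gives enough information for a back-of-the-envelope estimate of training cost. The sketch below converts the logged time to hours and extrapolates linearly to the full dataset. The extrapolation is an assumption: it presumes the same hardware, batch size, and sequence length.

```python
steps_run = 1000          # from the progress bar above
total_seconds = 23885     # from the progress bar above
dataset_rows = 37179      # num_rows reported for the train split

seconds_per_step = total_seconds / steps_run
hours_for_subset = total_seconds / 3600
# Linear extrapolation; assumes the same hardware, batch size, and sequence length.
hours_for_full_epoch = seconds_per_step * dataset_rows / 3600

print(f"{seconds_per_step:.1f} s/step, {hours_for_subset:.1f} h for the subset")
print(f"about {hours_for_full_epoch:.0f} h for one epoch over all rows")
```

On this setup, the 1,000-example run took a little over six and a half hours, which is why the subset was used for experimentation.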

In [27]:
# from keras_nlp.models import GemmaCausalLM

# # Define the preset and weights file path
# preset = "gemma_2b_en"  # Replace with the correct preset for your model
# weights_file = "/Users/babak/Documents/Model/model.weights.h5"

# # Initialize the model from the preset
# gemma_lm = GemmaCausalLM.from_preset(preset)

# # Load the custom weights
# gemma_lm.load_weights(weights_file)

# # Display the model summary
# gemma_lm.summary()

## Post-Tuning Evaluation

### Improved Inference
After fine-tuning, I reran the two baseline prompts. The responses were shorter and stayed closer to the question:

### Coughing and Wheezing Prompt

In [28]:
prompt = template.format(
    instruction="A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?

Response:
This patient has symptoms suggestive of asthma, which is a common cause of chronic airway disease in children.


The model now responds with the most probable diagnosis.

### Next step Prompt

Prompt the model to suggest the next step based on a scenario.

In [42]:
prompt = template.format(
    instruction="Hi, I’ve been feeling short of breath for the past two days, and today I also started having mild chest pain. What should I do?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
Hi, I’ve been feeling short of breath for the past two days, and today I also started having mild chest pain. What should I do?

Response:
Hi, I’m sorry that you’re feeling short of breath and having chest pain. This is a serious medical condition and it’s important that you get it checked out as soon as possible. Please call your healthcare provider or go to the emergency department of your nearest hospital if you are experiencing chest pain or shortness of breath. If you are unable to reach your healthcare provider or if you’re not able to get to a hospital, please call 911 or your local emergency number.


## Summary
This project demonstrates how LoRA fine-tuning can adapt the Gemma 2B model to medical conversational tasks. After one epoch on 1,000 examples, the model gave more concise and more directly relevant answers to the two test prompts than the base model did.

## Save the Model and Publish It as a Kaggle Model

In [43]:
# Save the fine-tuned model as a KerasNLP preset.
# Note: upload_preset uploads every file found under preset_dir, so a
# dedicated, otherwise empty directory avoids publishing unrelated files.
preset_dir = "/Users/babak/Documents"
gemma_lm.save_to_preset(preset_dir)

# Upload the preset as a new model variant on Kaggle.
kaggle_uri = f"kaggle://{kaggle_username}/gemmed/keras/finetuned_gpt2"
keras_nlp.upload_preset(kaggle_uri, preset_dir)

Uploading Model https://www.kaggle.com/models/babakdavani/gemmed/keras/finetuned_gpt2 ...
Starting upload for file /Users/babak/Documents/preprocessor.json
Uploading: 100%|███████████████████████████| 1.42k/1.42k [00:00<00:00, 1.87kB/s]
Upload successful: /Users/babak/Documents/preprocessor.json (1KB)
Starting upload for file /Users/babak/Documents/Untitled.drawio
Uploading: 100%|███████████████████████████████| 85.0/85.0 [00:00<00:00, 104B/s]
Upload successful: /Users/babak/Documents/Untitled.drawio (85B)
Starting upload for file /Users/babak/Documents/.DS_Store
Uploading: 100%|███████████████████████████| 6.15k/6.15k [00:00<00:00, 7.13kB/s]
Upload successful: /Users/babak/Documents/.DS_Store (6KB)
Starting upload for file /Users/babak/Documents/.localized
Uploading: 0.00B [00:00, ?B/s]
Upload successful: /Users/babak/Documents/.localized (0B)
Starting upload for file /Users/babak/Documents/config.json
Uploading: 100%|███████████████████████████████| 785/785 [00:00<00:00, 1.02kB/s]
Upload successful: /Users/babak/Documents/config.json (785B)
Starting upload for file /Users/babak/Documents/task.json
Uploading: 100%|███████████████████████████| 2.98k/2.98k [00:00<00:00, 4.01kB/s]
Upload successful: /Users/babak/Documents/task.json (3KB)
Starting upload for file /Users/babak/Documents/tokenizer.json
Uploading: 100%|█████████████████████████████████| 591/591 [00:00<00:00, 778B/s]
Upload successful: /Users/babak/Documents/tokenizer.json (591B)
Starting upload for file /Users/babak/Documents/metadata.json
Uploading: 100%|█████████████████████████████████| 143/143 [00:00<00:00, 168B/s]
Upload successful: /Users/babak/Documents/metadata.json (143B)
Starting upload for file /Users/babak/Documents/model.weights.h5
Uploading: 100%|██████████████████████████| 10.0G/10.0G [2:54:48<00:00, 956kB/s]
Upload successful: /Users/babak/Documents/model.weights.h5 (9GB)
Starting upload for file /Users/babak/Documents/GitHub/desktop-tutorial/README.md
Uploading: 100%|█████████████████████████████████| 206/206 [00:00<00:00, 256B/s]
Upload successful: /Users/babak/Documents/GitHub/desktop-tutorial/README.md (206B)
Starting upload for file /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Checkpoint2-checkpoint.ipynb
Uploading: 100%|████████████████████████████| 1.31M/1.31M [00:02<00:00, 575kB/s]
Upload successful: /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Checkpoint2-checkpoint.ipynb (1MB)
Starting upload for file /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Checkpoint3-checkpoint.ipynb
Uploading: 100%|████████████████████████████| 1.92M/1.92M [00:02<00:00, 699kB/s]
Upload successful: /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Checkpoint3-checkpoint.ipynb (2MB)
Starting upload for file /Users/babak/Documents/.ipynb_checkpoints/PCA-checkpoint.ipynb
Uploading: 100%|████████████████████████████| 2.03M/2.03M [00:03<00:00, 661kB/s]
Upload successful: /Users/babak/Documents/.ipynb_checkpoints/PCA-checkpoint.ipynb (2MB)
Starting upload for file /Users/babak/Documents/.ipynb_checkpoints/V4_Clustering Codealong-checkpoint.ipynb
Uploading: 100%|████████████████████████████| 2.88M/2.88M [00:04<00:00, 712kB/s]
Upload successful: /Users/babak/Documents/.ipynb_checkpoints/V4_Clustering Codealong-checkpoint.ipynb (3MB)
Starting upload for file /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Template-checkpoint.ipynb
Uploading: 100%|███████████████████████████| 10.6k/10.6k [00:00<00:00, 14.5kB/s]
Upload successful: /Users/babak/Documents/.ipynb_checkpoints/Clustering Codealong_Student_Facing_Template-checkpoint.ipynb (10KB)
Starting upload for file /Users/babak/Documents/assets/tokenizer/vocabulary.spm
Uploading: 100%|████████████████████████████| 4.24M/4.24M [00:05<00:00, 812kB/s]
Upload successful: /Users/babak/Documents/assets/tokenizer/vocabulary.spm (4MB)

Your model instance has been created.
Files are being processed...
See at: https://www.kaggle.com/models/babakdavani/gemmed/keras/finetuned_gpt2
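The upload log above includes files that are not part of the preset, such as `.DS_Store` and notebook checkpoints, because the preset was saved into a general-purpose folder. A pre-upload check along the following lines could flag that. This is only a sketch: `unexpected_files` is a hypothetical helper, and the expected file list is inferred from the log above, so it may differ across KerasNLP versions.

```python
import os
import tempfile

# File names taken from the upload log above; treat the list as an assumption.
EXPECTED_FILES = {
    "config.json",
    "metadata.json",
    "model.weights.h5",
    "preprocessor.json",
    "task.json",
    "tokenizer.json",
    "assets/tokenizer/vocabulary.spm",
}


def unexpected_files(preset_dir):
    """List files under preset_dir (relative paths) that are not preset files."""
    found = set()
    for dirpath, _dirnames, filenames in os.walk(preset_dir):
        for name in filenames:
            rel = os.path.relpath(os.path.join(dirpath, name), preset_dir)
            found.add(rel.replace(os.sep, "/"))
    return sorted(found - EXPECTED_FILES)


# Demonstration on a throwaway directory with one preset file and one stray file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("config.json", ".DS_Store"):
        with open(os.path.join(tmp, name), "w") as handle:
            handle.write("")
    print(unexpected_files(tmp))
```

Running such a check on the preset directory before calling `upload_preset`, and stopping if the list is non-empty, would keep personal files out of the published model.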
