<link rel="stylesheet" href="/site-assets/css/gemma.css">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200" />

# Fine-tune Gemma models in Keras using LoRA

## Overview

In this project, I demonstrate how to fine-tune the Gemma 2B model for conversational AI in the medical domain. 

### About Gemma
Gemma is a family of large language models designed for robust and scalable applications. With pretrained architectures optimized for versatility, Gemma models are particularly suitable for tasks involving natural language understanding and generation.

### About Low Rank Adaptation (LoRA)
[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique that enables efficient fine-tuning of large language models by introducing trainable low-rank matrices. This approach reduces computational requirements while maintaining model performance, making it ideal for fine-tuning Gemma.

### Dataset
- **Source**: [Hugging Face Medical-Llama3 Fine-tune Dataset](https://huggingface.co/datasets/Pistachio-LLM/Medical-llama3-finetune-train)
- **Description**: A curated collection of over 37,000 medical conversational entries, optimized for healthcare-specific language tasks.
- **Justification**: Its diversity and focus on the medical domain make it a valuable dataset for enhancing conversational adaptability and accuracy.


## Setup

### Access to Gemma
I followed the setup instructions for [Gemma](https://ai.google.dev/gemma/docs/setup) to ensure smooth integration, including:
- Accessing Gemma on [Kaggle](https://kaggle.com).
- Configuring the runtime for the Gemma 2B model.
- Generating a Kaggle API key.


### Configure Environment
To integrate Gemma, I configured the required environment variables:
- Set `KAGGLE_USERNAME` and `KAGGLE_KEY` using the downloaded Kaggle API credentials.
- Ensured the environment is ready by validating access to `kagglehub`.

In [1]:
# !pip install kagglehub
import kagglehub

In [2]:
from dotenv import load_dotenv
import os

# Path and load the .env file 
dotenv_path = "../.env" 
load_dotenv(dotenv_path)

# Access the environment variables
kaggle_username = os.getenv("KAGGLE_USERNAME")
kaggle_key = os.getenv("KAGGLE_KEY")

In [3]:
!kaggle datasets list

ref                                                           title                                              size  lastUpdated          downloadCount  voteCount  usabilityRating  
------------------------------------------------------------  ------------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
muhammadroshaanriaz/students-performance-dataset-cleaned      Students Performance | Clean Dataset               10KB  2024-10-29 19:32:26           7803        148  1.0              
whisperingkahuna/footballers-with-50-international-goals-men  Footballers with 50+ International Goals [men]      3KB  2024-11-17 12:51:23           1121         25  1.0              
daniellopez01/credit-risk                                     credit_risk                                        13KB  2024-11-17 22:13:54            899         29  1.0              
steve1215rogg/student-lifestyle-dataset                       student lifestyle 

### Install Dependencies
I installed the required packages, including:
- **Keras**: For model training and customization.
- **KerasNLP**: For natural language processing utilities.
- **TensorFlow/JAX**: For backend support.
- Additional utilities like `pandas` and `numpy`.

In [4]:
!python --version

Python 3.10.15


In [5]:
!pip show keras
!pip show tensorflow
!pip show keras-nlp

Name: keras
Version: 3.6.0
Summary: Multi-backend Keras.
Home-page: https://github.com/keras-team/keras
Author: Keras team
Author-email: keras-users@googlegroups.com
License: Apache License 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, h5py, ml-dtypes, namex, numpy, optree, packaging, rich
Required-by: tensorflow
Name: tensorflow
Version: 2.18.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /Users/babak/anaconda3/envs/llm/lib/python3.10/site-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, keras, libclang, ml-dtypes, numpy, opt-einsum, packaging, protobuf, requests, setuptools, six, tensorboard, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: tensorflow-text
Name: keras-nlp
Version: 0.17.0
Summary: Industry-streng

### Select a Backend
For this project, I utilized the JAX backend for efficiency:

In [6]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Tensorflow, Keras and KerasNLP.
Also pandas and numpy

In [7]:
import tensorflow as tf
import keras
import keras_nlp
import numpy as np
import pandas as pd

## Dataset Preparation
I loaded the Medical-Llama3 Fine-tune Dataset and preprocessed it to extract relevant features:

In [8]:
from datasets import Dataset
Dataset.cleanup_cache_files

<function datasets.arrow_dataset.Dataset.cleanup_cache_files(self) -> int>

In [9]:
from datasets import load_dataset

ds = load_dataset("Pistachio-LLM/Medical-llama3-finetune-train")

In [10]:
ds

DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction'],
        num_rows: 37179
    })
})

In [11]:
train_ds = ds['train']

# Function to extract inputs and outputs from the dataset
def extract_features(example):
    return {
        'input': example['input'],
        'instruction': example['instruction'],
        'output': example['output']
    }

# Map the dataset to extract features
train_ds = train_ds.map(extract_features)
train_ds = pd.DataFrame(train_ds)


In [12]:
# type(train_ds)
train_ds[:1]

Unnamed: 0,output,input,instruction
0,The Mesenchyme gives rise to most connective t...,,What does the Mesenchyme give rise to?


In [13]:
train_ds_list = [
    f"Instruction:\n{row['instruction']}\n\nResponse:\n{row['output']}"
    for index, row in train_ds.iterrows()  # Use iterrows to iterate over DataFrame rows
]

In [14]:
train_ds_list[:1]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.']

This subset ensured faster execution during early experimentation.

In [15]:
# Only use 1000 training examples, to keep it fast.
data = train_ds_list[:1000]

In [16]:
#checking the datast
data[:5]

['Instruction:\nWhat does the Mesenchyme give rise to?\n\nResponse:\nThe Mesenchyme gives rise to most connective tissue.',
 'Instruction:\nWhich class of antimicrobials is known to displace unconjugated bilirubin from serum albumin in the blood?\n\nResponse:\nSulfonamides are known to displace unconjugated bilirubin from serum albumin in the blood.',
 'Instruction:\nIn a female athlete who has amenorrhea and laboratory exam shows decreased FSH, LH, and estrogen levels, what is the likely diagnosis?\n\nResponse:\nThe likely diagnosis is hypogonadotropic hypogonadism, also known as hypothalamic amenorrhea. This is a condition where the hypothalamus in the brain does not release enough gonadotropin-releasing hormone (GnRH) to stimulate the pituitary gland to produce follicle-stimulating hormone (FSH) and luteinizing hormone (LH), which are necessary for ovulation and menstruation. As a result, estrogen levels are low, leading to amenorrhea. Female athletes are at increased risk of develo

## Model Loading and Inference

### Load the Gemma Model


KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this project, I create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [17]:
from keras_nlp.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The model architecture includes 2 billion parameters, optimized for causal language modeling.

## Inference before fine tuning

I tested the model's responses to initial prompts, such as diagnosing medical conditions or explaining complex terms in simple language:

### Coughing and Wheezing Prompt

Query the model for suggestions on the most probable diagnosis.

In [18]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

In [19]:
prompt = template.format(
    instruction="A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?

Response:
Exercise-induced asthma, also known as exercise-induced bronchoconstriction (EIB), is a type of asthma that occurs in individuals who are sensitive to the cold, and is caused by exercise-induced bronchoconstriction. It is characterized by the narrowing of the airways, making it difficult to breathe.

Symptoms may include wheezing, coughing, shortness of breath, and chest pain. It may be aggravated by physical activity, cold temperatures, or exertion and may improve with rest. The condition is often exacerbated by exposure to cold air or cold weather.

Treatment typically involves avoiding exposure to cold air or cold weather and taking medication such as inhaled steroids or bronchodilators. It can be managed through regular exercise and physical co

### ELI5 chemotherapy Prompt

Prompt the model to explain chemotherapy in terms simple enough for a 5 year old child to understand.

In [20]:
prompt = template.format(
    instruction="Explain the process of chemotherapy in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of chemotherapy in a way that a child could understand.

Response:
Chemotherapy is a treatment that uses drugs to kill or slow the growth of cancer cells. Chemotherapy works best when it is used early in cancer treatment. It helps kill cancer cells that are already in your body and prevent new cancer from growing. Chemotherapy can also shrink tumors or make them easier to remove.


These initial results provided a baseline for comparison after fine-tuning.

## LoRA Fine-tuning

- LoRA Rank: Set to 8, balancing computational efficiency and expressive power.
- Optimizer: AdamW, configured for transformer models.

In [21]:
# Enable LoRA for the model and set the LoRA rank to 8.
gemma_lm.backbone.enable_lora(rank=8)
gemma_lm.summary()

In [22]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

The fine-tuning process trained the model on a subset of data for one epoch:

In [None]:
gemma_lm.fit(data, epochs=1, batch_size=1)

[1m   1/1000[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m10:01:51[0m 36s/step - loss: 0.3563 - sparse_categorical_accuracy: 0.4516

## Post-Tuning Evaluation

### Improved Inference
After fine-tuning, I observed improved contextual accuracy in responses:

### Coughing and Wheezing Prompt

In [None]:
prompt = template.format(
    instruction="A young teenage boy experiences wheezing, coughing, and shortness of breath triggered by exposure to cold air, often worsening during or after physical activity outdoors in chilly weather.What is the likely diagnosis?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=512))

The model responds with most probable diagnosis.

### ELI5 Photosynthesis Prompt

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

The model now explains photosynthesis in simpler terms.

## Summary
This project demonstrates how LoRA fine-tuning can enhance a Gemma 2B model's performance for medical conversational tasks. With focused datasets and efficient techniques, the chatbot is now better equipped to handle diverse prompts, providing accurate and user-friendly responses.