## Model used

We are using **Gemma** model with 2B parameters, Keras, English version, v2.

## Dataset

We will fine-tune **Gemma** using a [Medical Q & A](https://www.kaggle.com/datasets/gpreda/medquad/) dataset. This is a subset of the full public dataset [Healthcare NLP: LLMs, Transformers, Datasets](https://www.kaggle.com/datasets/jpmiller/layoutlm).



# Prepare packages


We will install updated version of Keras, KerasNLP, which we need for fine-tuning, and other dependencies.

In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U tf-keras
!pip install -q -U keras-nlp==0.10.0
!pip install -q -U keras>=3

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow-decision-forests 1.8.1 requires tensorflow~=2.15.0, but you have tensorflow 2.18.0 which is incompatible.
tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != "arm64" or platform_system != "Darwin", but you have tensorflow 2.18.0 which is incompatible.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency 

In [2]:
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""

In [3]:
import keras_nlp
import keras
import csv

print("KerasNLP version: ", keras_nlp.__version__)
print("Keras version: ", keras.__version__)

2024-11-26 16:45:11.269102: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-26 16:45:11.269156: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-26 16:45:11.270961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


KerasNLP version:  0.10.0
Keras version:  3.6.0


# Load the model

In [4]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' f

In [5]:
gemma_lm.summary()

# Prepare the training data

We prepare the **Medical Q & A** data for training. We will load the data using the template where, for each data that will be included in the training set, we provide pairs of questions and answers.

In [7]:
from datasets import load_dataset

# Load the ChatDoctor-HealthCareMagic-100k dataset
ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")

# Prepare the data
data = []

# Iterate through the dataset and format the questions and answers
for example in ds['train']:  # Adjust to use 'validation' or 'test' splits if needed
    instruction = example.get('instruction', '')  # Instruction field
    input_text = example.get('input', '')         # User input or question
    output_text = example.get('output', '')       # Model's response or answer

    # Template for formatted data
    template = f"Instruction:\n{instruction}\n\nQuestion:\n{input_text}\n\nAnswer:\n{output_text}"
    data.append(template)

# Print some samples to verify
for sample in data[:5]:
    print(sample)


Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.

Question:
I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relieved myself, the spinning lessen so i am not sure whether its connected or coincidences.. Thank you doc!

Answer:
Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. In this condition, the most common symptom is dizziness or giddiness, which is made worse with movements. Accompanying nausea and vomi

In [8]:
len(data)

112165

In [9]:
data = data[:300]

# Check model inference before fine tuning

We wil first check the model before proceeding to fine-tuning. We will test it with some questions about medical matters.  

First, we will define an utility function to display the query and answer from LLM.

In [10]:
from IPython.display import display, Markdown
def colorize_text(text):
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

Let's check how we can display the content of one data input using the `colorize_text` function.

In [11]:
print(data[3])

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.

Question:
lump under left nipple and stomach pain (male) Hi,I have recently noticed a few weeks ago a lump under my nipple, it hurts to touch and is about the size of a quarter. Also I have bern experiencing stomach pains that prevent me from eating. I immediatly feel full and have extreme pain. Please help

Answer:
HI. You have two different problems. The lump under the nipple should be removed, biopsied. This will help you to get rid of the disease, and you get a diagnosis. Second problem looks a bit serious one


In [12]:
display(Markdown(colorize_text(data[3])))

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.



**<font color='red'>Question:</font>**
lump under left nipple and stomach pain (male) Hi,I have recently noticed a few weeks ago a lump under my nipple, it hurts to touch and is about the size of a quarter. Also I have bern experiencing stomach pains that prevent me from eating. I immediatly feel full and have extreme pain. Please help



**<font color='green'>Answer:</font>**
HI. You have two different problems. The lump under the nipple should be removed, biopsied. This will help you to get rid of the disease, and you get a diagnosis. Second problem looks a bit serious one

Now we will ask the model to answer to a question for which we know the expected answer.

In [13]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.



**<font color='red'>Question:</font>**
Within the past few hours, my husband has developed a rash that covers his back, neck, and upper arms. He seems to have no other symptoms. He has been taking antibiotics (Levaquin, sorry about the spelling) and just finished the last one today. Was in the hospital last week for 2 days with bronchitis / pneumonia. Has also been taking Mucinex. He has gone to bed in the last few minutes and is sleeping peacefully. I was very nervous about the rash and wondered

In [14]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.



**<font color='red'>Question:</font>**
Within the past few hours, my husband has developed a rash that covers his back, neck, and upper arms. He seems to have no other symptoms. He has been taking antibiotics (Levaquin, sorry about the spelling) and just finished the last one today. Was in the hospital last week for 2 days with bronchitis / pneumonia. Has also been taking Mucinex. He has gone to bed in the last few minutes and is sleeping peacefully. I was very nervous about the rash and wondered

# Fine-tunning with LoRA   


We are using now **LoRA** for fine-tunning. **LoRA** stands for **Low Rank Adaptation** and is a method for modifying a pretrained model (for example, an LLM or vision transformer) to better suit a specific, often smaller, dataset by **adjusting only a small, low-rank subset of the model's parameters**.


The rank used here for LoRA controls the number of parameters that will be recalculated during fine-tuning.

In [15]:
# Enable LoRA for the model and set the LoRA rank to 3.
gemma_lm.backbone.enable_lora(rank=3)
gemma_lm.summary()

In [16]:
# Fine-tune on the Medical QA dataset.

# Limit the input sequence length to 128 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=10, batch_size=1)

Epoch 1/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m164s[0m 492ms/step - loss: 2.9420 - sparse_categorical_accuracy: 0.4163
Epoch 2/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m159s[0m 504ms/step - loss: 2.4272 - sparse_categorical_accuracy: 0.5061
Epoch 3/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 504ms/step - loss: 2.3851 - sparse_categorical_accuracy: 0.5124
Epoch 4/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 504ms/step - loss: 2.3590 - sparse_categorical_accuracy: 0.5158
Epoch 5/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 502ms/step - loss: 2.3390 - sparse_categorical_accuracy: 0.5177
Epoch 6/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 505ms/step - loss: 2.3194 - sparse_categorical_accuracy: 0.5198
Epoch 7/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 505ms/step - loss: 2.2978 - sparse_categorical_accuracy: 0.5223

<keras.src.callbacks.history.History at 0x7e9c8e3c1180>

# Inference after fine tuning
We will run now the queries through the fine-tuned model.

In [17]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.



**<font color='red'>Question:</font>**
Within the past few hours, my husband has developed a rash that covers his back, neck, and upper arms. He seems to have no other symptoms. He has been taking antibiotics (Levaquin, sorry about the spelling) and just finished the last one today. Was in the hospital last week for 2 days with bronchitis / pneumonia. Has also been taking Mucinex. He has gone to bed in the last few minutes and is sleeping peacefully. I was very nervous about the rash and wondered

In [18]:
print(response)


Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.

Question:
Within the past few hours, my husband has developed a rash that covers his back, neck, and upper arms. He seems to have no other symptoms. He has been taking antibiotics (Levaquin, sorry about the spelling) and just finished the last one today. Was in the hospital last week for 2 days with bronchitis / pneumonia. Has also been taking Mucinex. He has gone to bed in the last few minutes and is sleeping peacefully. I was very nervous about the rash and wondered


In [19]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))

Instruction:
If you are a doctor, please answer the medical questions based on the patient's description.



**<font color='red'>Question:</font>**
Within the past few hours, my husband has developed a rash that covers his back, neck, and upper arms. He seems to have no other symptoms. He has been taking antibiotics (Levaquin, sorry about the spelling) and just finished the last one today. Was in the hospital last week for 2 days with bronchitis / pneumonia. Has also been taking Mucinex. He has gone to bed in the last few minutes and is sleeping peacefully. I was very nervous about the rash and wondered

# Save the model

In [None]:
preset = "./medical_gemma_Healthcare"
# Save the model to the preset directory.
gemma_lm.save_to_preset(preset)