## Model used

We are using **Gemma** model with 2B parameters, Keras, English version, v2.

## Dataset

We will fine-tune **Gemma** using a [Medical Q & A](https://www.kaggle.com/datasets/gpreda/medquad/) dataset. This is a subset of the full public dataset [Healthcare NLP: LLMs, Transformers, Datasets](https://www.kaggle.com/datasets/jpmiller/layoutlm).



# Prepare packages


We will install updated version of Keras, KerasNLP, which we need for fine-tuning, and other dependencies.

In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U tf-keras
!pip install -q -U keras-nlp==0.10.0
!pip install -q -U keras>=3

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow-decision-forests 1.8.1 requires tensorflow~=2.15.0, but you have tensorflow 2.18.0 which is incompatible.
tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != "arm64" or platform_system != "Darwin", but you have tensorflow 2.18.0 which is incompatible.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency 

In [2]:
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""

In [3]:
import keras_nlp
import keras
import csv

print("KerasNLP version: ", keras_nlp.__version__)
print("Keras version: ", keras.__version__)

2024-11-22 10:58:29.470786: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-22 10:58:29.470836: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-22 10:58:29.472503: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


KerasNLP version:  0.10.0
Keras version:  3.6.0


# Load the model

In [4]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'metadata.json' f

In [5]:
gemma_lm.summary()

# Prepare the training data

We prepare the **Medical Q & A** data for training. We will load the data using the template where, for each data that will be included in the training set, we provide pairs of questions and answers.

In [6]:
from datasets import load_dataset

# Load the PubMedQA dataset
ds = load_dataset("qiaojin/PubMedQA", "pqa_artificial")

# Prepare the data in a similar format
data = []

# Iterate through the dataset and format the questions and answers
for example in ds['train']:  # You can also use 'validation' or 'test' splits
    question = example['question']
    answer = example['final_decision']  # Use 'final_decision' for binary yes/no answers or 'context' for detailed context
    template = f"Question:\n{question}\n\nAnswer:\n{answer}"
    data.append(template)

# Print some samples
for sample in data[:5]:
    print(sample)


Downloading readme:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/233M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/211269 [00:00<?, ? examples/s]

Question:
Are group 2 innate lymphoid cells ( ILC2s ) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?

Answer:
yes
Question:
Does vagus nerve contribute to the development of steatohepatitis and obesity in phosphatidylethanolamine N-methyltransferase deficient mice?

Answer:
yes
Question:
Does psammaplin A induce Sirtuin 1-dependent autophagic cell death in doxorubicin-resistant MCF-7/adr human breast cancer cells and xenografts?

Answer:
yes
Question:
Is methylation of the FGFR2 gene associated with high birth weight centile in humans?

Answer:
yes
Question:
Do tumor-infiltrating immune cell profiles and their change after neoadjuvant chemotherapy predict response and prognosis of breast cancer?

Answer:
yes


In [7]:
len(data)

211269

In [8]:
data = data[:300]

# Check model inference before fine tuning

We wil first check the model before proceeding to fine-tuning. We will test it with some questions about medical matters.  

First, we will define an utility function to display the query and answer from LLM.

In [9]:
from IPython.display import display, Markdown
def colorize_text(text):
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

Let's check how we can display the content of one data input using the `colorize_text` function.

In [10]:
print(data[3])

Question:
Is methylation of the FGFR2 gene associated with high birth weight centile in humans?

Answer:
yes


In [11]:
display(Markdown(colorize_text(data[3])))



**<font color='red'>Question:</font>**
Is methylation of the FGFR2 gene associated with high birth weight centile in humans?



**<font color='green'>Answer:</font>**
yes

Now we will ask the model to answer to a question for which we know the expected answer.

In [12]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
Is low intramucosal pH associated with failure to acidify the gastric lumen in response to pentagastrin?



**<font color='green'>Answer:</font>**
yes



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



**<font color='green'>Answer:</font>**
Pentagastrin increases the pH of the gastric lumen.



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



**<font color='green'>Answer:</font>**
Pentagastrin increases the pH of the gastric lumen.



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



In [13]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
Is low intramucosal pH associated with failure to acidify the gastric lumen in response to pentagastrin?



**<font color='green'>Answer:</font>**
yes



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



**<font color='green'>Answer:</font>**
Pentagastrin increases the pH of the gastric lumen.



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



**<font color='green'>Answer:</font>**
Pentagastrin increases the pH of the gastric lumen.



**<font color='red'>Question:</font>**
What is the effect of pentagastrin on the pH of the gastric lumen?



# Fine-tunning with LoRA   


We are using now **LoRA** for fine-tunning. **LoRA** stands for **Low Rank Adaptation** and is a method for modifying a pretrained model (for example, an LLM or vision transformer) to better suit a specific, often smaller, dataset by **adjusting only a small, low-rank subset of the model's parameters**.


The rank used here for LoRA controls the number of parameters that will be recalculated during fine-tuning.

In [14]:
# Enable LoRA for the model and set the LoRA rank to 3.
gemma_lm.backbone.enable_lora(rank=3)
gemma_lm.summary()

In [15]:
# Fine-tune on the Medical QA dataset.

# Limit the input sequence length to 128 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=10, batch_size=1)

Epoch 1/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m140s[0m 413ms/step - loss: 0.7900 - sparse_categorical_accuracy: 0.4900
Epoch 2/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 415ms/step - loss: 0.5686 - sparse_categorical_accuracy: 0.5743
Epoch 3/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 416ms/step - loss: 0.5548 - sparse_categorical_accuracy: 0.5803
Epoch 4/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 416ms/step - loss: 0.5439 - sparse_categorical_accuracy: 0.5854
Epoch 5/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 415ms/step - loss: 0.5326 - sparse_categorical_accuracy: 0.5902
Epoch 6/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 415ms/step - loss: 0.5205 - sparse_categorical_accuracy: 0.5950
Epoch 7/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 415ms/step - loss: 0.5071 - sparse_categorical_accuracy: 0.6029

<keras.src.callbacks.history.History at 0x781846a11540>

# Inference after fine tuning
We will run now the queries through the fine-tuned model.

In [16]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
Is low intramucosal pH associated with failure to acidify the gastric lumen in response to pentagastrin?



**<font color='green'>Answer:</font>**
yes

In [18]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
Is low intramucosal pH associated with failure to acidify the gastric lumen in response to pentagastrin?



**<font color='green'>Answer:</font>**
yes

# Save the model

In [19]:
preset = "./medical_gemma_pubmed"
# Save the model to the preset directory.
gemma_lm.save_to_preset(preset)