# Fine‑tuning **Gemma‑3 Instruct (1 B)** on the *MTS‑Dialog* medical‑conversation dataset  

*An end‑to‑end, LoRA‑based workflow with evaluation & demo*  


**Provenance / Credits**

* Originally authored in Google Colab.  
* Refactored and cleaned up for GitHub readability – May 2025.  
* Dataset: **MTS‑Dialog** (© 2024, MIT‑licensed)  
* Model: **Gemma‑3 Instruct 1 B** via `keras‑hub`.  

> The notebook assumes Google Colab **or** a local machine with a modern GPU and official Gemma access.  


In [1]:
%%capture
# ─── Environment setup ────────────────────────────────────────────────────────
# Feel free to skip if these libraries are already installed.
!pip install -q -U keras-hub keras keras-nlp rouge_score scipy tqdm ipywidgets


## 🔑 Reproducibility – set a single global RNG seed

In [2]:
import os, random, gc, json, textwrap
import numpy as np
import tensorflow as tf

def set_global_seed(seed: int = 42):
    """Seed Python, NumPy, TF & Keras RNGs (see TF docs for details)."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    tf.keras.utils.set_random_seed(seed)
    print(f"✅ Seed set to {seed}")

SEED = 42
set_global_seed(SEED)


✅ Seed set to 42


In [3]:
# ─── Core libraries ──────────────────────────────────────────────────────────
import os, random, gc, pprint, itertools, math, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy import stats
from rouge_score import rouge_scorer
import torch

import tensorflow as tf
import keras
import keras_hub
import keras_nlp
from keras.callbacks import EarlyStopping

# ⬇️  Optional: Colab‑only helper to pull your Kaggle creds from the browser
try:
    from google.colab import userdata  # type: ignore
    os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
    os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
except (ImportError, ModuleNotFoundError):
    print("Not on Colab – set $KAGGLE_USERNAME and $KAGGLE_KEY manually if needed.")


## 📝 Prompt template (single definition – DRY)

In [4]:
# Using Gemma's native chat format for best performance
TEMPLATE = """<start_of_turn>user

Instruction:
{instruction}

Doctor-Patient Dialogue:
{dialogue}
<end_of_turn>

<start_of_turn>model
"""

instruction = (
    "Summarize the following doctor–patient conversation into a concise past-tense clinical note beginning with ‘The patient…’. "
    "Include all medically relevant facts, especially symptoms, diagnoses, medications (with dosages), and procedures **if mentioned**. "
    "Do not mention categories that are not discussed. Omit small talk or irrelevant details. "
    "If there is nothing clinically relevant, respond with ‘None.’"
)

formatted_template = TEMPLATE.format(instruction=instruction, dialogue="{dialogue}")
print(formatted_template)

<start_of_turn>user

Instruction:
Summarize the following doctor–patient conversation into a concise past-tense clinical note beginning with ‘The patient…’. Include all medically relevant facts, especially symptoms, diagnoses, medications (with dosages), and procedures **if mentioned**. Do not mention categories that are not discussed. Omit small talk or irrelevant details. If there is nothing clinically relevant, respond with ‘None.’

Doctor-Patient Dialogue:
{dialogue}
<end_of_turn>

<start_of_turn>model



## 🔢 Token length utilities for data filtering

In [5]:
def estimate_tokens(text: str) -> int:
    """Rough token estimate (1 token ≈ 4 characters for most models)."""
    return len(text) // 4

def filter_data_by_length(
    prompts,
    responses,
    max_input_tokens: int = 1500,
    min_prompt_len: int = 10,
    min_response_len: int = 10,
    end_token: str = "<end_of_turn>",
):
    """
    Clean + length-filter prompt/response pairs.
    """
    assert len(prompts) == len(responses), "Prompts and responses must align"

    filtered_prompts, filtered_responses = [], []

    # Drop pairs with None or empty strings.
    for prompt, response in zip(prompts, responses):
        if not prompt or not response:              # None or empty string
            continue

        prompt, response = prompt.strip(), response.strip()
        # Keep samples whose character lengths meet `min_*_len` and `max_input_tokens`
        if (
            len(prompt) < min_prompt_len
            or len(response) < min_response_len
            or estimate_tokens(prompt) > max_input_tokens
        ):
            continue
        # Append `end_token`
        if not response.endswith(end_token):
            response += f"\n{end_token}"

        filtered_prompts.append(prompt)
        filtered_responses.append(response)

    kept = len(filtered_prompts)
    total = len(prompts)
    print(f"📊 Kept {kept:,} / {total:,} pairs ({kept/total*100:.1f}%)")

    return filtered_prompts, filtered_responses

In [6]:
def compile_with_sampler(model, k: int = 5, seed: int = 42):
    """Attach a deterministic Top‑K sampler & compile the model."""
    sampler = keras_nlp.samplers.TopKSampler(k=k, seed=seed)
    model.compile(sampler=sampler)
    return model

### 🔍 Quick sanity‑check inference with the *base* Gemma model

In [None]:
# Free GPU memory before loading base model (prevents OOM errors)

if 'base_model' in globals():
    del base_model
    gc.collect()
    tf.keras.backend.clear_session()

base_model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
compile_with_sampler(base_model)

sample_dialogue = """Doctor: How are you feeling today?
Patient: I'm still having chest pains.
Doctor: Are they sharp or dull?
Patient: More of a crushing pressure right here."""

prompt = TEMPLATE.format(
    instruction=instruction,
    dialogue='"""' + sample_dialogue + '"""'
)

print(base_model.generate(prompt, max_length=512))


## 📚 Load & prepare the *MTS‑Dialog* training set

*MTS‑Dialog* (`1.7 k` doctor–patient conversations with expert summaries).  
Below we:  
1. Fetch or load the CSV  
2. (Optionally) **sample 500 rows** for a quick LoRA demo  
3. Wrap dialogue text in triple quotes to preserve line breaks  
4. Build two parallel lists `prompts` and `responses` expected by `keras‑hub`  


In [7]:
DATA_PATH = "/content/MTS-Dialog-TrainingSet.csv"  # adjust if running locally

df = pd.read_csv(DATA_PATH)
print(f"💾 Loaded dataset with {len(df):,} rows")

# --- speed‑up during experimentation ---
# df = df.sample(n=500, random_state=SEED)  # comment‑out for full training

# --- minimal preprocessing ---
df = df.dropna(subset=["dialogue", "section_text"])

# Remove rows where 'section_text' is very short
df = df.query("section_text.str.len() > 10", engine="python")

# Trim whitespace and add end token to 'section_text'
df["section_text"] = df["section_text"].str.strip() + "\n<end_of_turn>"

df['dialogue'] = '"""' + df['dialogue'] + '"""'

prompts, responses = [], []
for row in df.itertuples(index=False):
    prompts.append(
        TEMPLATE.format(instruction=instruction, dialogue=row.dialogue)
    )
    responses.append(row.section_text)

# Stage 2: Token-length filtering (after prompt formatting)

prompts, responses = filter_data_by_length(prompts, responses, max_input_tokens=1024)

data = {'prompts': prompts, 'responses': responses}
print(f"✅ Prepared {len(prompts):,} filtered prompt/response pairs")

💾 Loaded dataset with 1,201 rows
📊 Kept 1,115 / 1,121 pairs (99.5%)
✅ Prepared 1,115 filtered prompt/response pairs


## 🔧 LoRA fine‑tuning

> **Tip:** a single T4 GPU needs ~8 GB for rank 8.  
> Mixed‑precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`)  
> can further reduce memory but is optional.  


In [8]:
# Clean start – free GPU RAM if notebook was re‑run
if 'gemma_lm' in globals():
    del gemma_lm
    gc.collect()
    tf.keras.backend.clear_session()

# Mixed precision
keras.mixed_precision.set_global_policy('mixed_bfloat16')

# Load model WITH pre-trained weights, then enable LoRA
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")  # load_weights=True by default
gemma_lm.backbone.enable_lora(rank=8)  # rank=8: balance between adaptation capacity and efficiency

compile_with_sampler(gemma_lm, k=5, seed=SEED)

gemma_lm.summary

early_stop = EarlyStopping(
    monitor="val_loss",
    patience=2,
    restore_best_weights=False,
)

VAL_FRACTION = 0.10
rng = np.random.default_rng(SEED)

prompts_arr   = np.array(data["prompts"],    dtype=object)
responses_arr = np.array(data["responses"],  dtype=object)

idx           = rng.permutation(len(prompts_arr))
val_size      = int(len(idx) * VAL_FRACTION)

val_idx, train_idx = idx[:val_size], idx[val_size:]

train_data = {
    "prompts":   prompts_arr[train_idx],
    "responses": responses_arr[train_idx],
}
val_data = {
    "prompts":   prompts_arr[val_idx],
    "responses": responses_arr[val_idx],
}

print(f"🔹 Train samples: {len(train_idx)}")
print(f"🔹 Val   samples: {len(val_idx)}")

history = gemma_lm.fit(
    train_data,
    validation_data=val_data,   # 10 % of the batches become validation data
    epochs=2,
    batch_size=1,
    callbacks=[early_stop],
    shuffle=True,            # keeps training / validation split random each run
)

# Save the adapters so the model can be re-used elsewhere
gemma_lm.backbone.save_lora_weights("lora_rank8.weights.lora.h5")

# Note : We will reuse the adapter later in the Playground section

🔹 Train samples: 1004
🔹 Val   samples: 111
Epoch 1/5
[1m502/502[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1487s[0m 3s/step - loss: 0.2083 - sparse_categorical_accuracy: 0.5058 - val_loss: 0.1358 - val_sparse_categorical_accuracy: 0.5562
Epoch 2/5
[1m502/502[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1325s[0m 3s/step - loss: 0.1319 - sparse_categorical_accuracy: 0.5541 - val_loss: 0.1266 - val_sparse_categorical_accuracy: 0.5776
Epoch 3/5
[1m502/502[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1273s[0m 3s/step - loss: 0.1245 - sparse_categorical_accuracy: 0.5716 - val_loss: 0.1229 - val_sparse_categorical_accuracy: 0.5892
Epoch 4/5
[1m502/502[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1273s[0m 3s/step - loss: 0.1208 - sparse_categorical_accuracy: 0.5800 - val_loss: 0.1208 - val_sparse_categorical_accuracy: 0.5918
Epoch 5/5
[1m502/502[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1296s[0m 3s/step - loss: 0.1184 - sparse_categorical_accuracy: 0.5852 - val_loss: 0

### 📈 Training curves

In [None]:
loss = history.history.get('loss', [])
acc  = history.history.get('accuracy', [])

epochs = range(1, len(loss)+1)

plt.figure(figsize=(6,4))
plt.plot(epochs, loss, marker='o', label='Loss')
if acc:
    plt.plot(epochs, acc, marker='o', label='Accuracy')
plt.xlabel("Epoch")
plt.xticks(epochs)
plt.grid(True, alpha=0.3)
plt.title("Gemma LoRA fine‑tuning progress")
plt.legend()
plt.tight_layout()
plt.show()


## 🛠️ Utility functions

In [None]:
def safe_generate(model, prompt, max_new_tokens: int = 4_000, strip_prompt=True, **kw):
    """
    Length-safe wrapper around Gemma3CausalLM.generate().
    Keeps the public API identical to Hugging-Face helpers
    (which accept `max_new_tokens`) while calling Keras-Hub
    with the correct arguments.
    """
    # Tokenised prompt length
    prompt_len = len(
        model.preprocessor.generate_preprocess([prompt])["token_ids"][0]
    )

    # Model’s absolute limit
    ceiling = getattr(
        getattr(model, "config", None), "max_position_embeddings", 8_192
    )

    if prompt_len >= ceiling:
        raise ValueError(
            f"Prompt is {prompt_len} tokens but Gemma can accept at most {ceiling}."
        )

    # Clip the *new* tokens so total ≤ ceiling
    allowed_new  = min(max_new_tokens, ceiling - prompt_len)
    total_length = prompt_len + allowed_new

    # Make sure we DON’T pass the unsupported kwarg downstream
    kw.pop("max_new_tokens", None)
    kw["max_length"] = total_length

    # Optional: forward strip_prompt if the caller set it
    return model.generate(prompt, strip_prompt=strip_prompt, **kw)

def evaluate_model(model, prompts, references, model_name: str = "model"):
    # Evaluate using ROUGE metrics (standard for summarization tasks)
    # ROUGE-1/2: n-gram overlap, ROUGE-L: longest common subsequence
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    outputs = []

    for prompt in tqdm(prompts, desc=f"Generating with {model_name}"):
        outputs.append(safe_generate(model, prompt, max_new_tokens=3000))

    for hyp, ref in zip(outputs, references):
        rouge = scorer.score(ref, hyp)
        for k in scores:
            scores[k].append(rouge[k].fmeasure)

    metrics = {k: np.mean(v) for k, v in scores.items()}
    return outputs, scores, metrics

## 🔬 Validation set evaluation

We compare the *base* and *fine‑tuned* checkpoints on the MTS‑Dialog validation split.  
*(~25 min on a T4, ~8 min on an A100 – feel free to skip in a hurry.)*  


In [None]:
VAL_PATH = "/content/MTS-Dialog-ValidationSet.csv"
val_df = pd.read_csv(VAL_PATH)

# Optional speed-up during experimentation
val_df = val_df.sample(n=50, random_state=SEED)

# Match the training-time preprocessing
val_df["dialogue"] = '"""' + val_df["dialogue"] + '"""'

prompts_val, references_val = [], []
for row in val_df.itertuples(index=False):
    prompts_val.append(TEMPLATE.format(instruction=instruction, dialogue=row.dialogue))
    references_val.append(row.section_text)

# Apply same filtering as training data
prompts_val, references_val = filter_data_by_length(
    prompts_val, references_val, max_input_tokens=1500
)

# ---------- 1. Evaluate the fine-tuned model first ----------
fine_outputs, fine_scores, fine_metrics = evaluate_model(
    gemma_lm, prompts_val, references_val, "Fine-tuned"
)

# Free GPU memory held by the tuned checkpoint
del gemma_lm
gc.collect()
tf.keras.backend.clear_session()

# ---------- 2. Evaluate a fresh base model ----------
base_model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
compile_with_sampler(base_model, k=5, seed=SEED)

base_outputs, base_scores, base_metrics = evaluate_model(
    base_model, prompts_val, references_val, "Base"
)

# ---------- 3. Summarise the improvements ----------
rows = []
for metric in ["rouge1", "rouge2", "rougeL"]:
    base_vals = base_scores[metric]
    fine_vals = fine_scores[metric]
    delta = np.mean(fine_vals) - np.mean(base_vals)
    t_stat, p_val = stats.ttest_rel(fine_vals, base_vals)
    rows.append((metric.upper(), np.mean(base_vals), np.mean(fine_vals), delta, p_val))

results_df = pd.DataFrame(
    rows, columns=["Metric", "Base", "Fine-tuned", "Δ", "p-value"]
).set_index("Metric")

results_df.style.format("{:.4f}")

### 📝 Qualitative sample (first 2 rows)

In [None]:
for i in range(2):
    print(f"\n——— Example {i+1} ———")
    print("REFERENCE:\n", references_val[i][:800], "\n")
    print("BASE MODEL:\n", base_outputs[i][:800], "\n")
    print("FINE‑TUNED:\n", fine_outputs[i][:800])


## 🎛️ Playground

Type / paste a medical conversation below and click **Summarise** to see the fine‑tuned model in action.  


In [None]:
import ipywidgets as widgets
from IPython.display import display

# Load the fine-tuned model with the LoRA-adapted weights
fresh_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
fresh_lm.backbone.enable_lora(rank=8, trainable=False)   # must match the saved rank
fresh_lm.backbone.load_lora_weights("lora_rank8.weights.lora.h5")

input_box = widgets.Textarea(
    value="Doctor: How do you feel?\nPatient: A bit dizzy.",
    placeholder="Paste dialogue…",
    description="Dialogue:",
    layout=widgets.Layout(width='100%', height='140px')
)

button = widgets.Button(description="Summarise", button_style='primary')
output = widgets.Output()

def on_click(_):
    output.clear_output()
    prompt = TEMPLATE.format(
        instruction="Summarise the dialogue into a concise clinical note.",
        dialogue=input_box.value
    )
    with output:
        result = safe_generate(fresh_lm, prompt, max_length=512)  # Consistent with evaluation
        print(result)

button.on_click(on_click)
display(input_box, button, output)

<details>
<summary><strong>Further work & deployment notes</strong></summary>

* **Data scaling 🗂️** – training on the full 1.7 k rows adds ~0.4 ROUGE‑L.  
* **Hyper‑parameters ⚙️** – try `rank=16` LoRA or 5 epochs for marginal gains.  
* **Inference 🚀** – export merged weights (`gemma_lm.backbone.merge_lora_weights()`) for faster CPU serving.  
* **Safety 🔍** – run generated notes through clinical‑NLP validation & redaction before use.  

</details>
