<a href="https://colab.research.google.com/github/Matan-Vinkler/phi-3.5-finetuned-medqa/blob/main/phi3_5instruct_medqa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning `Phi-3.5-mini-instruct` on Medical QA Dataset

This notebook demonstrates how to fine-tune the **Phi-3.5-mini-instruct** model using **LoRA (Low-Rank Adaptation)** for a domain-specific task: creating a lightweight **medical assistant**.  
We use the MedQA subset of the **Medical Meadow** dataset (`medalpaca/medical_meadow_medqa`, medical question-answer pairs) and Hugging Face's **TRL `SFTTrainer`** to align the model with concise, factual medical Q&A style outputs.  

The pipeline includes:
- Loading and formatting the dataset (`instruction → response` format).
- Efficient fine-tuning with **LoRA** and **4-bit quantization** (Colab-friendly).
- Saving and pushing the fine-tuned model to the Hugging Face Hub.
- Running **inference examples** to validate the assistant's responses.

⚠️ **Disclaimer**: This model is for **educational purposes only** and is **not a substitute for professional medical advice**.


### Table of Contents

>[Setup](#scrollTo=2bwXpsrNDvbE)

>[Data Loading and Preprocessing](#scrollTo=4kpYwX6wDXb0)

>[Model Training](#scrollTo=sAWrisz4DofO)

>[Model Inference](#scrollTo=IqQdsLAPE8Y8)

### Setup

We'll install the dependencies that aren't preinstalled on Colab.

In [None]:
!pip install "datasets==3.6.0"
!pip install -U bitsandbytes
!pip install trl

Import all libraries used in this project and check that the installed `datasets` version is `3.6.0`:

In [None]:
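import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig

from huggingface_hub import notebook_login

import random

# Fail early if a different `datasets` version is active (e.g. the runtime was not restarted after pip install)
assert datasets.__version__ == "3.6.0", f"expected datasets 3.6.0, got {datasets.__version__}"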

We'll also log in to the Hugging Face Hub so we can upload the model.

In [None]:
notebook_login()

### Data Loading and Preprocessing

First of all, we'll load the `medalpaca/medical_meadow_medqa` dataset from the Hugging Face Hub.

In [None]:
dataset = load_dataset("medalpaca/medical_meadow_medqa")
dataset

A quick look at the data, first a random training example and then the first 10 rows:

In [None]:
# randrange excludes the upper bound, so IDX is always a valid row index
IDX = random.randrange(len(dataset["train"]))
dataset["train"][IDX]["input"]

In [None]:
sample = dataset["train"].select(range(10))
sample

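Now we'll define the preprocessing function. It keeps only the question text from `input` (the part before `? \n`, minus its two-character prefix), strips the three-character answer-letter prefix from `output`, and joins both into a single `text` field using an `### Instruction: ... / ### Response: ...` template: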

In [None]:
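def preprocess_data(sample):
    # Question: keep the text before "? \n" (whatever follows, e.g. the answer options, is discarded),
    # then drop the two-character prefix at the start of `input`
    instruction = sample["input"].split("? \n")[0][2:]
    # Answer: drop the three-character answer-letter prefix at the start of `output`
    output = sample["output"][3:]

    # One training string per example, in an instruction/response template
    sample["text"] = f"### Instruction: {instruction}\n### Response: {output}"

    return sample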

sample_processed = sample.map(preprocess_data, remove_columns=["input", "instruction", "output"])
sample_processed

In [None]:
sample_processed[0]
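To make the string slicing in `preprocess_data` concrete, here is a stdlib-only sketch that runs the same function on a hypothetical record. The record is invented for illustration and only assumes the layout that the slicing implies: a two-character prefix such as `Q:` before the question, extra content after `? \n`, and a three-character prefix such as `A: ` in `output`.

```python
def preprocess_data(sample):
    # Question: text before "? \n", minus the two-character prefix
    instruction = sample["input"].split("? \n")[0][2:]
    # Answer: drop the three-character answer-letter prefix
    output = sample["output"][3:]
    sample["text"] = f"### Instruction: {instruction}\n### Response: {output}"
    return sample

# Hypothetical raw record, invented for illustration (not taken from the dataset)
raw = {
    "input": "Q:A 30-year-old man presents with fatigue. What is the most likely cause? \n{'A': 'Anemia', 'B': 'Hypothyroidism'},",
    "output": "A: Anemia",
}

print(preprocess_data(dict(raw))["text"])
```

Note that `split("? \n")` consumes the separator, so the question mark at the end of the question does not appear in the formatted `text`.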

Then apply the preprocessing function to the whole training split:

In [None]:
processed_dataset = dataset["train"].map(preprocess_data, remove_columns=["input", "instruction", "output"])
processed_dataset

### Model Training

Load the model and tokenizer from the `microsoft/Phi-3.5-mini-instruct` checkpoint, with the model weights quantized to 4 bits:

In [None]:
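from transformers import BitsAndBytesConfig

model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Passing `load_in_4bit=True` directly to from_pretrained is deprecated; use a BitsAndBytesConfig instead
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=BitsAndBytesConfig(load_in_4bit=True), device_map="auto")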
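As a rough, back-of-the-envelope sketch of how much memory 4-bit loading saves, the weight footprint is approximately parameters × bits per parameter / 8 bytes. This assumes about 3.8 billion parameters for Phi-3.5-mini and ignores quantization constants, layers that bitsandbytes leaves unquantized, activations, and training state, so real usage on the 16 GB T4 will be higher.

```python
def weight_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate weight memory in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

N_PARAMS = 3.8e9  # assumption: approximate Phi-3.5-mini parameter count

for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: ~{weight_gb(N_PARAMS, bits):.1f} GB")
```

Under these assumptions the weights shrink from roughly 7.6 GB in fp16 to roughly 1.9 GB in 4-bit, leaving more of the GPU for activations and the LoRA training state.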

Define the training configuration (`SFTConfig`) and the LoRA configuration (`LoraConfig`):

In [None]:
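sft_config = SFTConfig(
    output_dir="./medical-assistant",
    per_device_train_batch_size=1,
    num_train_epochs=2,
    learning_rate=2e-4,
    push_to_hub=True,
    hub_model_id="Matanvinkler18/phi-3.5-finetuned-medqa",
    bf16=False,  # the Colab T4 GPU does not support bf16, so train in fp16
    fp16=True,
    dataset_text_field="text"  # column created by preprocess_data
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    # Phi-3.5 fuses the query/key/value projections into a single `qkv_proj` module,
    # so separate `q_proj`/`k_proj`/`v_proj` names would not match any layer
    target_modules=["qkv_proj", "o_proj"],
    task_type="CAUSAL_LM"
)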
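LoRA freezes each targeted weight matrix W (shape d_out × d_in) and learns a low-rank update B·A, with B of shape d_out × r and A of shape r × d_in, scaled by `lora_alpha / r`. The sketch below counts the trainable parameters LoRA adds to one such matrix; the square 3072 × 3072 shape is an assumption (Phi-3.5-mini's hidden size) used only for illustration.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    # A is r x d_in and B is d_out x r, so the adapter adds r * (d_in + d_out) weights
    return r * (d_in + d_out)

r, lora_alpha = 16, 32   # values from the LoraConfig above
d_in = d_out = 3072      # assumption: a hidden_size x hidden_size projection

full = d_in * d_out
lora = lora_param_count(d_in, d_out, r)
scaling = lora_alpha / r  # multiplier applied to B @ A in the forward pass

print(f"full matrix: {full:,} params, LoRA adapter: {lora:,} params ({lora / full:.2%})")
print(f"scaling factor: {scaling}")
```

For this shape the adapter trains about 1% as many parameters as the frozen matrix, which is what keeps fine-tuning within a single Colab GPU.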

Finally, we'll create an `SFTTrainer` and start training (about 3 to 4 hours on a Colab T4 GPU):

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=sft_config
)

In [None]:
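trainer.train()
# Write the final LoRA adapter to output_dir ("./medical-assistant"), where the
# inference section loads it from; with push_to_hub=True this also uploads it
trainer.save_model()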
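With `per_device_train_batch_size=1` and no gradient accumulation, every optimizer step consumes one example on a single GPU, so the number of steps is roughly the number of training rows times `num_train_epochs`. A minimal sketch of that arithmetic; the function name is hypothetical and the 10,000-row figure is an illustrative round number, so substitute `len(processed_dataset)`.

```python
import math

def total_optimizer_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int, n_gpus: int = 1) -> int:
    # Examples consumed per optimizer step across all devices
    effective_batch = batch_size * grad_accum * n_gpus
    # Approximate: the final partial batch of each epoch still counts as a step
    return math.ceil(n_examples / effective_batch) * epochs

# Hypothetical dataset size; replace with len(processed_dataset)
steps = total_optimizer_steps(n_examples=10_000, batch_size=1, grad_accum=1, epochs=2)
print(steps)
```

Raising `per_device_train_batch_size` or `gradient_accumulation_steps` reduces the step count proportionally, at the cost of more memory per step or more forward passes per update.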

### Model Inference

We'll load the trained LoRA adapter from `./medical-assistant` on top of the base model as a `PeftModel`:

In [None]:
lora_model = PeftModel.from_pretrained(model, "./medical-assistant")
lora_model.eval()

Then build a `text-generation` pipeline from this model and tokenizer:

In [None]:
generator = pipeline(
    "text-generation",
    model=lora_model,
    tokenizer=tokenizer,
    device_map="auto"
)

Finally, run a few example medical prompts to check the fine-tuned model's answers:

In [None]:
questions = [
    "Explain the difference between Type 1 and Type 2 Diabetes.",
    "List the common symptoms of iron deficiency anemia.",
    "Explain the difference between bacterial and viral infections.",
    "A patient complains of chest pain when climbing stairs. Suggest possible causes.",
]

# Same "### Instruction: ...\n### Response:" template that preprocess_data used for training
prompts = [f"### Instruction: {q}\n### Response:" for q in questions]

for prompt in prompts:
    output = generator(
        prompt,
        max_new_tokens=256,
    )

    print(output[0]["generated_text"])
    print("-" * 50)
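The pipeline returns the prompt followed by the completion in `generated_text`. Below is a small stdlib sketch of two hypothetical helpers, `build_prompt` and `extract_response`: the first formats a question with the same `### Instruction: ... / ### Response:` template that `preprocess_data` used for training, and the second keeps only the text after the response marker, cutting off any new `### Instruction` block the model might start on its own.

```python
RESPONSE_MARKER = "### Response:"
INSTRUCTION_MARKER = "### Instruction:"

def build_prompt(question: str) -> str:
    # Same layout as the training text produced by preprocess_data, minus the answer
    return f"{INSTRUCTION_MARKER} {question}\n{RESPONSE_MARKER}"

def extract_response(generated_text: str) -> str:
    # Keep only what follows the last response marker
    answer = generated_text.rsplit(RESPONSE_MARKER, 1)[-1]
    # Drop any follow-up instruction the model generated after its answer
    answer = answer.split(INSTRUCTION_MARKER, 1)[0]
    return answer.strip()

# Example with a made-up completion string (not real model output)
prompt = build_prompt("List the common symptoms of iron deficiency anemia.")
fake_output = prompt + " Fatigue, pallor and shortness of breath.\n### Instruction: Another question"
print(extract_response(fake_output))
```

In the loop above, `extract_response(output[0]["generated_text"])` would print just the model's answer instead of the full prompt plus completion.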