<a href="https://colab.research.google.com/github/Matan-Vinkler/phi-3.5-finetuned-medqa/blob/main/phi3_5instruct_medqa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning `Phi-3.5-mini-instruct` on Medical QA Dataset

This notebook demonstrates how to fine-tune the **Phi-3.5-Mini-Instruct** model using **LoRA (Low-Rank Adaptation)** for a domain-specific task: creating a lightweight **medical assistant**.  
We use the **Medical Meadow** MedQA dataset (`medalpaca/medical_meadow_medqa`, USMLE-style medical exam questions with their answers) and Hugging Face's **TRL `SFTTrainer`** to align the model with concise, factual medical Q&A style outputs.  

The pipeline includes:
- Loading and formatting the dataset (`instruction → response` format).
- Efficient fine-tuning with **LoRA** and **4-bit quantization** (Colab-friendly).
- Saving and pushing the fine-tuned model to the Hugging Face Hub.
- Running **inference examples** to validate the assistant's responses.

⚠️ **Disclaimer**: This model is for **educational purposes only** and is **not a substitute for professional medical advice**.


### Table of Contents

>[Setup](#scrollTo=2bwXpsrNDvbE)

>[Data Loading and Preprocessing](#scrollTo=4kpYwX6wDXb0)

>[Model Training](#scrollTo=sAWrisz4DofO)

>[Model Inference](#scrollTo=IqQdsLAPE8Y8)

### Setup

We install the dependencies that are not preinstalled on Colab.

In [None]:
!pip install "datasets==3.6.0"
!pip install -U bitsandbytes
!pip install trl

Next, we import all the libraries used in this project and check that the installed `datasets` version is `3.6.0`.

In [None]:
import datasets
from datasets import load_dataset, get_dataset_config_names
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig

from huggingface_hub import notebook_login

import random

datasets.__version__

'3.6.0'
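Displaying `datasets.__version__` only shows the version; it does not stop the notebook if a different one is installed. A small guard like the sketch below would fail fast instead. It uses only the standard-library `importlib.metadata`; the helper names `installed_version` and `check_exact_version` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_exact_version(package: str, expected: str) -> None:
    """Raise RuntimeError unless exactly `expected` is installed (hypothetical helper)."""
    found = installed_version(package)
    if found != expected:
        raise RuntimeError(f"{package}=={expected} is required, found {found!r}")


# In this notebook: check_exact_version("datasets", "3.6.0")
```

Unlike `datasets.__version__`, `importlib.metadata.version` reads the installed package metadata, so right after a `pip install` in a running session the two can disagree until the runtime is restarted.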

We also log in to the Hugging Face Hub so we can upload the model.

In [None]:
notebook_login()


### Data Loading and Preprocessing

First, we load the `medalpaca/medical_meadow_medqa` dataset from the Hugging Face Hub.

In [None]:
dataset = load_dataset("medalpaca/medical_meadow_medqa")
dataset



DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 10178
    })
})

A quick look at a random training example:

In [None]:
# randint includes both endpoints, so the upper bound is len - 1
IDX = random.randint(0, len(dataset["train"]) - 1)
dataset["train"][IDX]["input"]

"Q:A 77-year-old man is brought to the emergency department by his wife because of headache, nausea, and vomiting for 24 hours. His wife says that over the past 2 weeks, he has been more irritable and has had trouble remembering to do routine errands. Two weeks ago, he fell during a skiing accident but did not lose consciousness. He has coronary artery disease and hypertension. He has smoked one pack of cigarettes daily for 50 years. He has had 2 glasses of wine daily since his retirement 10 years ago. Current medications include atenolol, enalapril, furosemide, atorvastatin, and aspirin. He appears acutely ill. He is oriented to person but not to place or time. His temperature is 37°C (98.6°F), pulse is 99/min, respirations are 16/min, and blood pressure is 160/90 mm Hg. During the examination, he is uncooperative and unable to answer questions. Deep tendon reflexes are 4+ on the left and 2+ on the right. Babinski's sign is present on the left. There is mild weakness of the left iliop

In [None]:
sample = dataset["train"].select(range(10))
sample

Dataset({
    features: ['input', 'instruction', 'output'],
    num_rows: 10
})

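Next, we define the preprocessing function. It removes the leading `Q:` and the answer options from `input`, removes the option letter (e.g. `E: `) from `output`, and joins the two into a single `text` field in the `### Instruction: ... ### Response: ...` layout: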

In [None]:
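def preprocess_data(sample):
    # Keep the question stem: drop the leading "Q:" and the answer options after "? \n".
    instruction = sample["input"].split("? \n")[0][2:]
    # Drop the option-letter prefix (e.g. "E: ") from the answer.
    output = sample["output"][3:]

    # Combine both parts into a single training text.
    sample["text"] = f"### Instruction: {instruction}\n### Response: {output}"

    return sample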

sample_processed = sample.map(preprocess_data, remove_columns=["input", "instruction", "output"])
sample_processed

Dataset({
    features: ['text'],
    num_rows: 10
})

In [None]:
sample_processed[0]

{'text': '### Instruction: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?\n### Response: Nitrofurantoin'}
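The slicing in `preprocess_data` relies on fixed offsets: `[2:]` assumes every `input` starts with `Q:`, and `[3:]` assumes every `output` starts with a letter prefix such as `E: `. A slightly more defensive sketch of the same transformation strips each prefix only when it is actually present. The function name `format_example` and the toy record are illustrative, not taken from the dataset.

```python
import re

# Matches an option-letter prefix such as "E: " at the start of the answer.
OPTION_PREFIX = re.compile(r"^[A-Z]: ")


def format_example(record: dict) -> dict:
    """Build the same Instruction/Response text as preprocess_data,
    but strip the "Q:" and option-letter prefixes only if present."""
    question = record["input"]
    if question.startswith("Q:"):
        question = question[len("Q:"):]
    # Keep the question stem; everything after the first "? \n" (the options) is dropped.
    question = question.split("? \n", 1)[0]
    answer = OPTION_PREFIX.sub("", record["output"], count=1)
    return {"text": f"### Instruction: {question}\n### Response: {answer}"}


toy = {
    "input": "Q:Which vitamin deficiency causes scurvy?? \n{'A': 'Vitamin A', 'B': 'Vitamin C'},",
    "instruction": "",
    "output": "B: Vitamin C",
}
print(format_example(toy)["text"])
# ### Instruction: Which vitamin deficiency causes scurvy?
# ### Response: Vitamin C
```

Like `preprocess_data`, it could be passed to `dataset["train"].map(...)` together with `remove_columns`.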

Then we apply the preprocessing function to the whole training split:

In [None]:
processed_dataset = dataset["train"].map(preprocess_data, remove_columns=["input", "instruction", "output"])
processed_dataset

Dataset({
    features: ['text'],
    num_rows: 10178
})

### Model Training

We load the tokenizer and the 4-bit quantized model from the `microsoft/Phi-3.5-mini-instruct` checkpoint:

In [None]:
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="auto")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


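The deprecation warning printed by this cell asks for a `BitsAndBytesConfig` passed through `quantization_config` instead of `load_in_4bit=True`. A sketch of the non-deprecated call is below. Only `load_in_4bit=True` mirrors the original cell; the NF4 quantization type and the float16 compute dtype are our assumptions (float16 because the T4 GPU has no native bfloat16 support, in line with `fp16=True` in the training config). Drop those two arguments to keep the library defaults used above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "microsoft/Phi-3.5-mini-instruct"

# 4-bit quantization settings. load_in_4bit=True matches the original cell;
# NF4 and a float16 compute dtype are assumptions, not the author's settings.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
```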

Next, we define the training configuration (`SFTConfig`) and the LoRA configuration (`LoraConfig`):

In [None]:
sft_config = SFTConfig(
    output_dir="./medical-assistant",
    per_device_train_batch_size=1,
    num_train_epochs=2,
    learning_rate=2e-4,
    push_to_hub=True,
    hub_model_id="Matanvinkler18/phi-3.5-finetuned-medqa",
    bf16=False,
    fp16=True,
    dataset_text_field="text"
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
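In LoRA, each targeted linear layer keeps its frozen weight $W$ and learns a low-rank update $BA$, where $A$ has shape $r \times d_{in}$ and $B$ has shape $d_{out} \times r$; the update is scaled by `lora_alpha / r`, which is $32 / 16 = 2$ with the settings above. The pure-Python sketch below shows that forward pass on tiny matrices. It omits dropout and quantization and illustrates the technique only, not PEFT's implementation. Because $B$ is initialized to zero, the adapted layer initially reproduces the base layer exactly.

```python
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, lora_alpha, r):
    """y = W x + (lora_alpha / r) * B (A x); W is frozen, only A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


d_in, d_out, r, lora_alpha = 4, 3, 2, 4
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen weight
A = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]      # r x d_in, random init
B = [[0.0] * r for _ in range(d_out)]                                  # d_out x r, zero init
x = [1.0, 2.0, 3.0, 4.0]

# With B = 0 the adapted layer matches the frozen base layer.
print(lora_forward(W, A, B, x, lora_alpha, r) == matvec(W, x))  # True
```

Only $r(d_{in} + d_{out})$ values per layer are trained instead of $d_{in} d_{out}$, which is what makes fine-tuning a 4-bit model feasible on a single T4.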

Finally, we create an `SFTTrainer` and start training (roughly 3 to 4 hours on a Colab T4 GPU):

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=sft_config
)

In [None]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mmatanvinkler[0m ([33mmatanvinkler-my-company[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


| Step | Training Loss |
|-----:|--------------:|
| 10 | 1.6413 |
| 20 | 1.5175 |
| 30 | 1.2194 |
| 40 | 1.1324 |
| 50 | 1.19 |
| 60 | 1.2179 |
| 70 | 1.0794 |
| 80 | 1.1654 |
| 90 | 1.2521 |
| 100 | 1.1903 |


TrainOutput(global_step=20356, training_loss=1.048548677014939, metrics={'train_runtime': 12112.335, 'train_samples_per_second': 1.681, 'train_steps_per_second': 1.681, 'total_flos': 9.987171594603725e+16, 'train_loss': 1.048548677014939})
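The numbers in `TrainOutput` follow from the configuration: with `per_device_train_batch_size=1`, a single GPU and no gradient accumulation (none is set in `SFTConfig`), every training example is one optimizer step, repeated for 2 epochs. A quick arithmetic check:

```python
num_examples = 10_178   # rows in processed_dataset
epochs = 2              # num_train_epochs
batch_size = 1          # per_device_train_batch_size (single GPU, no gradient accumulation)

steps = num_examples * epochs // batch_size
runtime_s = 12112.335   # train_runtime reported by trainer.train()

print(steps)                        # 20356, matching global_step
print(round(steps / runtime_s, 3))  # 1.681, matching train_samples_per_second
print(round(runtime_s / 3600, 2))   # 3.36 hours of training
```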

### Model Inference

We load the trained LoRA adapter from `./medical-assistant` on top of the base model with `PeftModel`:

In [None]:
lora_model = PeftModel.from_pretrained(model, "./medical-assistant")
lora_model.eval()



PeftModel(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (qkv_proj)
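The printout lists `o_proj` as a `lora.Linear4bit` next to a fused `qkv_proj` module: Phi-3.5 does not have separate `q_proj`, `k_proj` and `v_proj` layers. For a list of names, PEFT matches `target_modules` against the end of each module's dotted name, so of the four configured names only `o_proj` matches. The sketch below approximates that rule on module names taken from the printout and counts the resulting LoRA parameters, $r(d_{in} + d_{out})$ per wrapped layer. The `qkv_proj` output width of $3 \times 3072$ is an assumption (query, key and value each 3072 wide). Adding `"qkv_proj"` to `target_modules` would be one way to adapt those projections too, at the cost of retraining.

```python
def matches(module_name: str, targets: list) -> bool:
    """Approximate PEFT's rule for a list of target names: exact or dotted-suffix match."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Parameters in one LoRA pair: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# Attention projections of one decoder layer, named as in the printout.
# The qkv_proj output width (3 * 3072) is an assumption.
layer_modules = {
    "model.layers.0.self_attn.qkv_proj": (3072, 3 * 3072),
    "model.layers.0.self_attn.o_proj": (3072, 3072),
}
num_layers, r = 32, 16

for targets in (["q_proj", "k_proj", "v_proj", "o_proj"], ["qkv_proj", "o_proj"]):
    hit = [name for name in layer_modules if matches(name, targets)]
    total = num_layers * sum(lora_param_count(*layer_modules[n], r) for n in hit)
    print(targets, "->", [n.rsplit(".", 1)[-1] for n in hit], total)
# ['q_proj', 'k_proj', 'v_proj', 'o_proj'] -> ['o_proj'] 3145728
# ['qkv_proj', 'o_proj'] -> ['qkv_proj', 'o_proj'] 9437184
```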

Then we build a text-generation pipeline from this model and the tokenizer:

In [None]:
generator = pipeline(
    "text-generation",
    model=lora_model,
    tokenizer=tokenizer,
    device_map="auto"
)

Device set to use cuda:0
The model 'PeftModel' is not supported for text-generation. Supported models are ['PeftModelForCausalLM', 'ArceeForCausalLM', 'AriaTextForCausalLM', ...]

Finally, we run a few example medical prompts to check the fine-tuned model's responses:

In [None]:
prompts = [
    """### Instruction:
Explain the difference between Type 1 and Type 2 Diabetes.

### Response:""",
    """### Instruction:
List the common symptoms of iron deficiency anemia.

### Response:""",
    """### Instruction:
Explain the difference between bacterial and viral infections.

### Response:""",
    """### Instruction:
A patient complains of chest pain when climbing stairs. Suggest possible causes.

### Response:""",
]

for prompt in prompts:
    output = generator(
        prompt,
        max_new_tokens=256,
    )

    print(output[0]["generated_text"])
    print("-" * 50)

### Instruction:
Explain the difference between Type 1 and Type 2 Diabetes.

### Response: Type 1 Diabetes is caused by the immune system destroying pancreatic beta cells, while Type 2 Diabetes is caused by the body not responding to insulin.
--------------------------------------------------
### Instruction:
List the common symptoms of iron deficiency anemia.

### Response: Fatigue, pale skin, shortness of breath, weakness, dizziness, chest pain, headache, irregular or rapid heartbeat, short attention span, cold hands and feet, brittle nails, and headaches.
--------------------------------------------------
### Instruction:
Explain the difference between bacterial and viral infections.

### Response: Bacteria are single-celled microorganisms that are much larger than viruses. They are capable of independent existence outside a host cell, and they can reproduce on non-living surfaces.

Viruses are much smaller than bacteria and cannot reproduce or survive outside of a host cell. They m
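The training texts were built as `### Instruction: {instruction}\n### Response: {output}` (one line each, no blank line), while the prompts above put the instruction on its own line and add a blank line before `### Response:`. The model still answers, but keeping inference prompts in the exact training layout is generally the safer choice. A hypothetical pair of helpers, one that formats a prompt the way `preprocess_data` does and one that keeps only the answer from the pipeline's `generated_text`, might look like this:

```python
RESPONSE_MARKER = "### Response:"


def build_prompt(instruction: str) -> str:
    """Format an instruction exactly like the training texts, leaving the response empty."""
    return f"### Instruction: {instruction}\n{RESPONSE_MARKER}"


def extract_response(generated_text: str) -> str:
    """Return the text after the first response marker, or "" if there is none."""
    _, found, response = generated_text.partition(RESPONSE_MARKER)
    if not found:
        return ""
    # Stop at the next instruction block, in case the model kept generating.
    return response.split("### Instruction:", 1)[0].strip()


prompt = build_prompt("List the common symptoms of iron deficiency anemia.")
print(repr(prompt))
# '### Instruction: List the common symptoms of iron deficiency anemia.\n### Response:'

# Hypothetical generated_text, shaped like the pipeline output above.
generated = prompt + " Fatigue, pale skin, shortness of breath."
print(extract_response(generated))  # Fatigue, pale skin, shortness of breath.
```

Alternatively, passing `return_full_text=False` to the `generator(...)` call makes the pipeline return only the generated continuation instead of prompt plus answer.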