# Fine-Tuning a Language Model for Healthcare Question Answering

**Module 02 | Notebook 4 of 4**

In this notebook, we will fine-tune a small language model (`TinyLlama`) to answer medical questions using the `MedQuAD` dataset.

## Learning Objectives
1.  Understand the difference between **Context (RAG)** and **Fine-Tuning**.
2.  Learn about **PEFT (Parameter-Efficient Fine-Tuning)** and **LoRA**.
3.  Fine-tune a model on a custom dataset.
4.  Export the model for local use.

---

## 1. When to use What? (RAG vs. Fine-Tuning)

Before we start, it's important to know *when* to fine-tune.

| Feature | **RAG (Retrieval-Augmented Generation)** | **Fine-Tuning** |
| :--- | :--- | :--- |
| **Analogy** | Giving the model an open textbook during the exam. | Sending the model to medical school for 4 years. |
| **Goal** | Add new *knowledge* (facts, data). | Change *behavior*, style, or learn specialized jargon. |
| **Pros** | Cheaper, easier to update facts. | Better performance on specific tasks, faster inference (no retrieval). |
| **Cons** | Limited by the context window; answers are only as good as what is retrieved. | Expensive to train, hard to update facts (requires re-training). |

**In this notebook**, we are doing **Fine-Tuning** to teach the model how to *act* like a medical assistant and understand medical terminology, not necessarily to memorize every drug interaction (RAG would be better for that).
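
To make the contrast concrete, here is a minimal sketch of how the prompt differs between the two approaches. The helper names (`build_rag_prompt`, `build_finetuned_prompt`) and the sample passage are hypothetical and only illustrate the idea: with RAG, retrieved text is pasted into the prompt at inference time, while a fine-tuned model receives only the question.

```python
def build_rag_prompt(question, retrieved_passages):
    """RAG: knowledge arrives through the prompt at inference time."""
    context = "\n".join(f"- {p}" for p in retrieved_passages)
    return f"Context:\n{context}\n\n### Question:\n{question}\n\n### Answer:\n"


def build_finetuned_prompt(question):
    """Fine-tuning: behavior lives in the weights, so the prompt stays short."""
    return f"### Question:\n{question}\n\n### Answer:\n"


# Hypothetical passage standing in for the output of a retriever.
passages = ["Example passage returned by a document search."]
question = "What is hypertension?"

rag_prompt = build_rag_prompt(question, passages)
ft_prompt = build_finetuned_prompt(question)

print(rag_prompt)
print(ft_prompt)
print(f"RAG prompt is {len(rag_prompt) - len(ft_prompt)} characters longer")
```

Every retrieved passage consumes part of the context window, which is the "Cons" entry for RAG in the table above.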

In [None]:
%%capture
!pip install transformers datasets accelerate peft trl bitsandbytes

In [None]:
import torch
import os
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.backends.mps.is_available():
    device = "mps"  # For Mac users

print(f"Using device: {device}")

---

## 2. Load the Dataset

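We will use `MedQuAD`, a medical question-answering dataset compiled from U.S. National Institutes of Health (NIH) websites. The Hugging Face copy loaded below contains pairs of `Question` and `Answer`.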

In [None]:
# Load dataset
dataset_name = "keivalya/MedQuad-MedicalQnADataset"
dataset = load_dataset(dataset_name, split="train")

# Use a small subset for demonstration (first 500 examples)
dataset = dataset.select(range(500))

print(f"Training on {len(dataset)} examples")
print("Sample:", dataset[0])

### Formatting
To train a chat model, we wrap every example in a fixed template so the model learns where the question ends and the answer begins.

```
### Question:
{User's Question}

### Answer:
{Model's Answer}
```

In [None]:
def formatting_func(example):
    text = f"""### Question:
{example['Question']}

### Answer:
{example['Answer']}"""
    return text

print(formatting_func(dataset[0]))

---

## 3. Model Setup (with Conditional Quantization)

We use **TinyLlama-1.1B**. It's small enough to run on most laptops.

### Hardware Note
*   **NVIDIA GPU**: We can use **4-bit quantization** to cut the memory needed for the weights to roughly a quarter of what `float16` requires.
*   **Mac (M1/M2/M3)**: 4-bit quantization with `bitsandbytes` is not natively supported. We will load the model in `float16` instead. TinyLlama is small (about 2 GB of weights in `float16`), so this works fine!
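
The memory figures above follow from simple arithmetic: parameters times bits per parameter. The sketch below is a rough estimate for the weights only. It assumes the nominal 1.1 billion parameters from the model name and ignores activations, gradients, optimizer state, and quantization overhead, so real usage will be higher.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Nominal size taken from the model name; the exact parameter count differs slightly.
NUM_PARAMS = 1.1e9

for label, bits in [("float32", 32), ("float16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: ~{weight_memory_gb(NUM_PARAMS, bits):.2f} GB")
```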

In [None]:
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Determine optimized settings based on hardware
if device == "cuda":
    # Quantization Config (NVIDIA only)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model_kwargs = {"quantization_config": bnb_config}
else:
    # Mac/CPU: Load in half-precision (float16) to save RAM
    bnb_config = None
    model_kwargs = {"torch_dtype": torch.float16}

print(f"Loading model with config: {model_kwargs}")

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto" if device == "cuda" else None, # MPS/CPU mapping handled manually or by defaults
    trust_remote_code=True,
    **model_kwargs
)

# device_map is None on this path, so on Apple Silicon we move the model to MPS explicitly
if device == "mps":
    model.to("mps")

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

---

## 4. Setting up LoRA (Low-Rank Adaptation)

### An Analogy
Imagine you want to customize your car (Pre-trained Model).
*   **Full Fine-Tuning**: Rebuilding the entire engine. Powerful, but expensive and slow.
*   **LoRA**: Adding a "Turbocharger" plugin. You don't touch the engine; you just add a small, focused part that modifies the performance.

### Parameters
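*   **`r` (Rank)**: The size of the "plugin". A higher rank adds more trainable parameters, which gives the adapter more capacity but costs more memory and training time. (Common: 8, 16)
*   **`target_modules`**: Where to attach the plugin. In a Transformer, we usually attach to the Attention projections. Many examples use only `q_proj` and `v_proj`; the cell below also includes `k_proj` and `o_proj` to cover all four.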
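
LoRA leaves each targeted weight matrix frozen and learns a low-rank update next to it. For a layer that maps `d_in` inputs to `d_out` outputs, the adapter adds two small matrices with `r * (d_in + d_out)` parameters in total. The sketch below uses that formula to estimate adapter size for different settings. The layer shapes are assumptions about the TinyLlama-1.1B architecture (hidden size 2048, 22 layers, key/value projections of width 256), so check them against `model.config` and the output of `model.print_trainable_parameters()` rather than treating them as authoritative.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters in one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


# Assumed shapes (d_in, d_out) for the attention projections of TinyLlama-1.1B.
HIDDEN = 2048
KV_DIM = 256     # assumed: grouped-query attention with narrower k/v projections
NUM_LAYERS = 22  # assumed

shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}


def total_lora_params(target_modules, r):
    """Sum adapter parameters over the targeted modules in every layer."""
    per_layer = sum(lora_param_count(*shapes[name], r) for name in target_modules)
    return per_layer * NUM_LAYERS


for r in (8, 16):
    for targets in (["q_proj", "v_proj"], ["q_proj", "k_proj", "v_proj", "o_proj"]):
        print(f"r={r:<2} targets={targets}: {total_lora_params(targets, r):,} trainable parameters")
```

Under these assumptions, doubling `r` doubles the adapter size, and even the largest setting shown stays well under 1% of the base model's parameters.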

In [None]:
if bnb_config: # If using quantization
    model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16, 
    lora_alpha=32, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

---

## 5. Training

In [None]:
training_args = SFTConfig(
    output_dir="./tinyllama-medical",
    num_train_epochs=1,
    per_device_train_batch_size=2, # Keep low for standard laptops
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=10,
    fp16=(device == "cuda"),  # fp16 mixed precision is only enabled on CUDA; support on MPS/CPU varies by library version
    dataset_text_field="text",
    max_length=512,
    packing=False,
)

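# The model was already wrapped by get_peft_model above, so peft_config is not
# passed again here; doing both is redundant and some TRL versions reject it.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=formatting_func,
    args=training_args,
    processing_class=tokenizer,
)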

print("Starting Training...")
trainer.train()
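
To get a feel for how long this run is, the settings above can be turned into an approximate number of optimizer steps. This is a back-of-the-envelope sketch that assumes a single device; the trainer's own count may differ by one depending on how it handles the last partial accumulation window.

```python
import math

num_examples = 500
per_device_batch_size = 2
gradient_accumulation_steps = 4
num_epochs = 1
num_devices = 1  # assumption: a single GPU, MPS device, or CPU

# Gradients from several small batches are accumulated before each weight update.
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs

print(f"Effective batch size: {effective_batch_size}")
print(f"Approximate optimizer steps: {total_steps}")
```

With `logging_steps=10`, a run of this length should print only a handful of loss values.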

In [None]:
# Save the adapter (the plugin)
trainer.model.save_pretrained("./tinyllama-medical-adapter")
tokenizer.save_pretrained("./tinyllama-medical-adapter")
print("Adapter saved!")
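
The saved folder holds only the LoRA adapter, not a standalone model. One way to export a single model for local use is to reload the base model, attach the adapter, and merge the two. The function below is a sketch under assumptions: the function name and the output directory in the usage comment are hypothetical, and the base model is reloaded in `float16` rather than 4-bit because merging into quantized weights can introduce rounding error.

```python
def load_and_merge_adapter(base_model_id, adapter_dir, output_dir):
    """Reload a saved LoRA adapter, merge it into the base model, and save the result."""
    # Imported here so that defining the function does not load the libraries.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the unquantized base model, then attach the trained adapter to it.
    base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base, adapter_dir)

    # Fold the LoRA update into the base weights and drop the adapter wrappers.
    merged = model.merge_and_unload()

    # Save the merged weights together with the tokenizer stored next to the adapter.
    merged.save_pretrained(output_dir)
    AutoTokenizer.from_pretrained(adapter_dir).save_pretrained(output_dir)
    return merged


# Example call (downloads the base model, so it is left commented out):
# load_and_merge_adapter(
#     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#     "./tinyllama-medical-adapter",
#     "./tinyllama-medical-merged",  # hypothetical output directory
# )
```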

---

## 6. Testing

In [None]:
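def ask(question):
    # Use the same template as training, leaving the answer empty for the model to fill in
    prompt = f"### Question:\n{question}\n\n### Answer:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    model.eval()  # switch off dropout (including LoRA dropout) for inference
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
    )

    # The decoded text contains the prompt followed by the generated answer
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))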

ask("What are the symptoms of a cold?")
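
The decoded output repeats the prompt, and a small model may keep generating and invent a follow-up question. A small post-processing helper can isolate the first answer. The sketch below is a hypothetical addition (`extract_answer` is not part of any library), and the sample string is made up for illustration rather than real model output.

```python
ANSWER_MARKER = "### Answer:"
QUESTION_MARKER = "### Question:"


def extract_answer(decoded_text):
    """Return the text after the first answer marker, cut at any follow-up question."""
    _, found, answer = decoded_text.partition(ANSWER_MARKER)
    if not found:
        # Marker missing: fall back to the whole text.
        return decoded_text.strip()
    # Keep only the first answer if the model started a new question.
    return answer.split(QUESTION_MARKER)[0].strip()


# Made-up text in the same template, standing in for decoded model output.
sample = (
    "### Question:\nWhat are the symptoms of a cold?\n\n"
    "### Answer:\nA runny nose and sore throat.\n\n"
    "### Question:\nWhat is flu?"
)
print(extract_answer(sample))
```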

---

## 🎯 Student Challenge

### Challenge: Create an Empathetic Mental Health Bot

Medical data is factual. But what if we want a bot that is comforting?

1.  **Modify the formatting function**: Add a system instruction like "You are a caring friend."
2.  **Dataset**: I've provided a tiny list of mental health Q&A below.
3.  **Train**: Retrain the model on this new small dataset.

In [None]:
# Mental Health Mini-Dataset
mental_health_data = [
    {"q": "I feel sad.", "a": "I'm sorry to hear that. It's okay to feel down sometimes. Do you want to talk about it?"},
    {"q": "I am anxious.", "a": "Take a deep breath. Anxiety is tough, but you are not alone. Let's focus on the present moment."},
    {"q": "Nobody likes me.", "a": "That must be a painful thought. I care about you, and I'm sure others do too, even if it's hard to see right now."}
]

# TODO: Create a new formatting function for this data
# def empathetic_format(example):
#     ...

# TODO: Fine-tune the model on this new list
# Hint: transform the list into a Hugging Face dataset first:
# from datasets import Dataset
# mh_dataset = Dataset.from_list(mental_health_data)

---

## Key Takeaways
1.  **LoRA** makes fine-tuning much cheaper by freezing the base model and training only small adapter matrices.
2.  **Quantization** with `bitsandbytes` is great for NVIDIA GPUs, but a small model (about 1B parameters) can be loaded in `float16` on Mac/CPU instead.
3.  **Data Formatting** is critical in teaching the model *how* to speak (e.g., Q&A format).