# **LM Fine-tuning for Medical MCQ (AfriMed-QA)**

**Models:**
1. TinyLlama | [Link](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
2. SmolLM2-360M-Instruct | 0.4B | [link](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct)

### **Load Data from Drive**

In [None]:
from google.colab import drive
import os
from datasets import load_from_disk # Import load_from_disk

# Mount Google Drive
drive.mount('/content/drive')
DATA_PATH = "/content/drive/MyDrive/NLP/prosit_1/afrimedqa_splits"

train = load_from_disk(f"{DATA_PATH}/train") # Use load_from_disk
val   = load_from_disk(f"{DATA_PATH}/val")     # Use load_from_disk
test  = load_from_disk(f"{DATA_PATH}/test")    # Use load_from_disk

print(f"Dataset will be loaded from: {DATA_PATH}")

In [2]:
len(train), len(val), len(test)

(3351, 189, 184)

# **Finetuning**

In [4]:
!pip -q install -U transformers datasets accelerate peft bitsandbytes trl evaluate

In [None]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

### **Imports and Configs**

In [None]:
"""Section C — Domain-Specific Adaptation (Colab)

This notebook fine-tunes a small size chat LLMs (TinyLlama + SmolLM2) on AfriMed-QA MCQ items using QLoRA + LoRA.
We keep *structured* fields (question/options/correct answer) for clean evaluation, while
training uses a single `text` field for SFT.

Key outputs for the PROSIT Section C write-up:
- Baseline vs fine-tuned MCQ accuracy (primary)
- Validation/test loss + approximate perplexity (secondary)
- Confusion matrix
"""

from __future__ import annotations

import json
import math
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import DatasetDict, load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTConfig, SFTTrainer

# System prompt used for chat-template inference (baseline + post-finetune evaluation).
SYSTEM_PROMPT = (
    "You are a helpful medical assistant. Answer concisely and clearly. "
    "Choose the correct option (A, B, C, D, or E)."
)


In [6]:
@dataclass
class CFG:
    """Experiment configuration.

    Notes:
      - `model_id` controls which base model we fine-tune.
      - Default: TinyLlama (1.1B).
      - For quick experiments, you can swap to SmolLM2-360M-Instruct.
    """

    # Base model to fine-tune
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    # Uncomment for experiment purposes (smaller model, faster iteration):
    # model_id: str = "HuggingFaceTB/SmolLM2-360M-Instruct"

    # Preprocessed DatasetDict saved via `save_to_disk(...)`
    dataset_dir: str = "/content/drive/MyDrive/NLP/prosit_1/afrimedqa_splits"

    # Reproducibility
    seed: int = 42

    # Sequence / training
    max_seq_len: int = 512
    epochs: int = 5
    per_device_batch: int = 2
    grad_accum: int = 8
    lr: float = 2e-4
    warmup_ratio: float = 0.03

    # Output
    output_dir: str = "/content/drive/MyDrive/tinyllama_afrimedqa_mcq_qlora"


cfg = CFG()


### **Dataset Loader**

In [None]:
from google.colab import drive
drive.mount("/content/drive")

In [None]:
"""Load preprocessed AfriMed-QA splits from Google Drive.

Expected schema per split includes (at minimum):
- question_clean
- answer_options (JSON string)
- correct_answer (e.g., 'option3')
- text (supervised training string)
Optionally:
- answer_rationale
- metadata (country, specialty, etc.)
"""

ds = load_from_disk(cfg.dataset_dir)
train_ds = ds["train"]
val_ds = ds["validation"]
test_ds = ds["test"]

print({k: len(ds[k]) for k in ds.keys()})
print("Train columns:", train_ds.column_names)


In [None]:
train_ds[1]

In [None]:
LETTERS = "ABCDE"

def _parse_answer_options(answer_options: Any) -> Dict[str, str]:
    """Parse `answer_options` which is typically a JSON string.

    Returns a dict like: {"option1": "...", ..., "option5": "..."}.
    """
    if answer_options is None:
        return {}
    if isinstance(answer_options, dict):
        return answer_options
    if isinstance(answer_options, str):
        return json.loads(answer_options)
    raise TypeError(f"Unsupported answer_options type: {type(answer_options)}")


def _option_key_to_letter(opt_key: str) -> str:
    """Map option key (e.g., 'option3') -> letter ('C')."""
    idx = int(opt_key.replace("option", ""))
    return LETTERS[idx - 1]


def _build_prompt(question: str, options_dict: Dict[str, str]) -> str:
    """Build an inference prompt (no gold answer appended)."""
    return (
        "### Instruction:\n"
        "Choose the correct option (A, B, C, D, or E).\n\n"
        f"### Question:\n{question}\n\n"
        "### Options:\n"
        f"A) {options_dict.get('option1','')}\n"
        f"B) {options_dict.get('option2','')}\n"
        f"C) {options_dict.get('option3','')}\n"
        f"D) {options_dict.get('option4','')}\n"
        f"E) {options_dict.get('option5','')}\n\n"
        "### Answer:\n"
    )


def enrich_fields(example: Dict[str, Any]) -> Dict[str, Any]:
    """Add helper fields for evaluation/analysis without altering existing columns.

    Adds (if not present):
      - options_dict: parsed dict from answer_options
      - gold_letter: A–E from correct_answer
      - prompt: prompt-only string (no answer)
    """
    if "options_dict" not in example:
        example["options_dict"] = _parse_answer_options(example.get("answer_options"))
    if "gold_letter" not in example:
        example["gold_letter"] = _option_key_to_letter(example["correct_answer"])
    if "prompt" not in example:
        q = example.get("question_clean") or example.get("question") or ""
        example["prompt"] = _build_prompt(q, example["options_dict"])
    return example


# Only enrich if any split is missing expected helper fields
need_cols = {"options_dict", "gold_letter", "prompt"}
if not need_cols.issubset(set(train_ds.column_names)):
    train_ds = train_ds.map(enrich_fields)
if not need_cols.issubset(set(val_ds.column_names)):
    val_ds = val_ds.map(enrich_fields)
if not need_cols.issubset(set(test_ds.column_names)):
    test_ds = test_ds.map(enrich_fields)

print("Enriched columns present:", need_cols.issubset(set(train_ds.column_names)))


In [None]:
"""Sanity check: verify the stored examples are consistent."""

ex0 = train_ds[0]
print("--- train[0] ---")
print("correct_answer:", ex0["correct_answer"], "| gold_letter:", ex0.get("gold_letter"))
print("question_clean (preview):", (ex0.get("question_clean") or "")[:120], "...")
print("text ends with:", ex0["text"].strip().splitlines()[-1])

assert "### Answer:\n" in ex0["text"], "Expected '### Answer:\n' delimiter in `text`."
assert ex0["text"].strip().splitlines()[-1].strip() in LETTERS, "Expected final line to be A–E."
if ex0.get("gold_letter") is not None:
    assert ex0["text"].strip().splitlines()[-1].strip() == ex0["gold_letter"], "Gold letter mismatch."
print("Sanity check passed.")


### **Load Model weights with QLoRA + LoRA**

In [None]:
def build_qlora_lora_model(model_id: str):
    """Load a causal LM in 4-bit (QLoRA) and attach LoRA adapters.

    Args:
      model_id: Hugging Face model id (e.g., TinyLlama or SmolLM2).

    Returns:
      (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tokenizer.pad_token is None:
        # Common practice for causal LMs: reuse EOS as PAD for batching.
        tokenizer.pad_token = tokenizer.eos_token

    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,  # safest default on most Colab GPUs
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb,
        torch_dtype=torch.float16,
    )
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)

    # LoRA targets: keep broad coverage for transformer blocks.
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    return model, tokenizer


# Base model: TinyLlama (default).
# If you want to run a smaller experiment, set cfg.model_id to SmolLM2 above.
model, tokenizer = build_qlora_lora_model(cfg.model_id)


### **Baseline evaluation (before fine-tune) — MCQ accuracy**

In [16]:
len(train_ds), len(val_ds), len(test_ds)

(3351, 189, 184)

In [None]:
def format_chat_example(example: Dict[str, Any], include_rationale: bool = False) -> Dict[str, Any]:
    """Create a chat-formatted training string for SFT (stored in `text`).

    We keep training simple: supervised *answer-only* (A–E).
    Optionally, you can include rationales (if present) for an alternate experiment.

    Expected keys in `example`:
      - question_clean
      - options_dict
      - gold_letter
      - answer_rationale (optional)

    Returns:
      dict with a single key 'text' for SFTTrainer.
    """
    q = example["question_clean"]
    opts = example["options_dict"]
    gold = example["gold_letter"]

    user_content = (
        "### Question:\n"
        f"{q}\n\n"
        "### Options:\n"
        f"A) {opts.get('option1','')}\n"
        f"B) {opts.get('option2','')}\n"
        f"C) {opts.get('option3','')}\n"
        f"D) {opts.get('option4','')}\n"
        f"E) {opts.get('option5','')}\n"
    )

    if include_rationale and example.get("answer_rationale"):
        assistant_content = f"{gold}\n\n### Explanation:\n{example['answer_rationale']}"
    else:
        assistant_content = f"{gold}"

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text}


In [None]:
@torch.no_grad()
def predict_letter(model, tokenizer, example: Dict[str, Any]) -> Optional[str]:
    """Greedy-predict a single MCQ option letter (A–E) for one example.

    Uses the chat template at inference time (system + user message).
    """
    q = example["question_clean"]
    opts = example["options_dict"]

    user_content = (
        "### Instruction:\n"
        "Choose the correct option (A, B, C, D, or E).\n\n"
        f"### Question:\n{q}\n\n"
        "### Options:\n"
        f"A) {opts.get('option1','')}\n"
        f"B) {opts.get('option2','')}\n"
        f"C) {opts.get('option3','')}\n"
        f"D) {opts.get('option4','')}\n"
        f"E) {opts.get('option5','')}\n"
    )

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **inputs,
        max_new_tokens=2,
        do_sample=False,
        temperature=0.0,
    )
    gen = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    for ch in gen:
        if ch in LETTERS:
            return ch
    return None


@torch.no_grad()
def mcq_accuracy(model, tokenizer, dataset, max_items: int = 50) -> Dict[str, Any]:
    """Compute MCQ accuracy (+ confusion counts) over a dataset split."""
    n = min(max_items, len(dataset))
    correct = 0
    invalid = 0
    confusion = {g: {p: 0 for p in LETTERS} for g in LETTERS}

    for ex in dataset.select(range(n)):
        gold = ex["gold_letter"]
        pred = predict_letter(model, tokenizer, ex)
        if pred is None:
            invalid += 1
            continue
        confusion[gold][pred] += 1
        correct += int(pred == gold)

    scored = n - invalid
    acc = correct / scored if scored > 0 else 0.0
    return {
        "accuracy": acc,
        "correct": correct,
        "n_total": n,
        "n_scored": scored,
        "n_invalid": invalid,
        "confusion": confusion,
    }


baseline_test = mcq_accuracy(model, tokenizer, test_ds, max_items=50)
print(
    f"Baseline TEST acc: {baseline_test['accuracy']:.3f} | "
    f"Correct: {baseline_test['correct']} | Total: {baseline_test['n_total']}"
)


### **Train with SFTTrainer**

In [None]:
# To avoid TRL switching into prompt/completion mode, we pass a *text-only* view to the trainer.
# Keep the full datasets (train_ds/val_ds/test_ds) for analysis and MCQ accuracy evaluation.
train_text = train_ds.select_columns(["text"])
val_text = val_ds.select_columns(["text"])

In [None]:
args = SFTConfig(
    output_dir=cfg.output_dir,
    num_train_epochs=cfg.epochs,
    per_device_train_batch_size=cfg.per_device_batch,
    per_device_eval_batch_size=cfg.per_device_batch,
    gradient_accumulation_steps=cfg.grad_accum,
    learning_rate=cfg.lr,
    warmup_ratio=cfg.warmup_ratio,
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Precision: default to fp16 for broad Colab compatibility.
    fp16=True,
    bf16=False,

    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none",
    seed=cfg.seed,

    remove_unused_columns=False,
    dataset_text_field="text",
    max_length=cfg.max_seq_len,
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_text,
    eval_dataset=val_text,
    args=args,
)


In [None]:
trainer.train()

### **Loss/perplexity evaluation**

In [None]:
def tokenize_for_clm(example: Dict[str, Any]) -> Dict[str, Any]:
    """Tokenize a `text` example for loss/perplexity evaluation."""
    tok = tokenizer(
        example["text"],
        truncation=True,
        max_length=cfg.max_seq_len,
        padding=False,
    )
    tok["labels"] = tok["input_ids"].copy()
    return tok


val_tok = val_ds.map(tokenize_for_clm, remove_columns=val_ds.column_names)
test_tok = test_ds.map(tokenize_for_clm, remove_columns=test_ds.column_names)

val_metrics = trainer.evaluate(eval_dataset=val_tok)
test_metrics = trainer.evaluate(eval_dataset=test_tok)

def approx_ppl(metrics: Dict[str, Any]) -> Optional[float]:
    """Compute exp(loss) when loss is in a safe numeric range."""
    loss = metrics.get("eval_loss")
    if loss is None:
        return None
    return math.exp(loss) if loss < 20 else float("inf")

print("VAL  loss:", val_metrics.get("eval_loss"), "| ppl:", approx_ppl(val_metrics))
print("TEST loss:", test_metrics.get("eval_loss"), "| ppl:", approx_ppl(test_metrics))


### **Post-finetune MCQ accuracy**

In [23]:
ft_test = mcq_accuracy(trainer.model, tokenizer, test_ds, max_items=500)
print("Fine-tuned TEST acc:", ft_test["accuracy"])

Fine-tuned TEST acc: 0.22282608695652173


### **Plots**

In [None]:
def plot_loss_curves(log_history):
    steps_train, loss_train = [], []
    steps_eval, loss_eval = [], []

    for row in log_history:
        if "loss" in row and "step" in row:
            steps_train.append(row["step"])
            loss_train.append(row["loss"])
        if "eval_loss" in row and "step" in row:
            steps_eval.append(row["step"])
            loss_eval.append(row["eval_loss"])

    plt.figure()
    if steps_train:
        plt.plot(steps_train, loss_train, label="train loss")
    if steps_eval:
        plt.plot(steps_eval, loss_eval, label="val loss")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.legend()
    plt.title("TinyLlama QLoRA Fine-tuning: Train vs Val Loss")
    plt.show()

def plot_confusion(confusion: Dict[str, Dict[str, int]], title="MCQ Confusion (counts)"):
    letters = LETTERS
    mat = np.array([[confusion[g][p] for p in letters] for g in letters])

    plt.figure()
    plt.imshow(mat)
    plt.xticks(range(5), letters)
    plt.yticks(range(5), letters)
    plt.xlabel("Predicted")
    plt.ylabel("Gold")
    plt.title(title)
    plt.colorbar()
    plt.show()

plot_loss_curves(trainer.state.log_history)
plot_confusion(ft_test["confusion"], title="Fine-tuned TEST Confusion (counts)")

### **Save Artifacts**

In [None]:
"""Save adapter weights + tokenizer and (optionally) persist to Google Drive."""

import shutil

# Save locally (cfg.output_dir can be on Drive already; saving is still fine).
os.makedirs(cfg.output_dir, exist_ok=True)
trainer.model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)

# Save dataset splits alongside the model for reproducibility.
splits = DatasetDict({"train": train_ds, "validation": val_ds, "test": test_ds})
splits.save_to_disk(os.path.join(cfg.output_dir, "dataset_splits"))
print("Saved model + tokenizer + dataset_splits to:", cfg.output_dir)


In [27]:
# Post-training MCQ accuracy (primary Section C metric)
ft_test = mcq_accuracy(trainer.model, tokenizer, test_ds, max_items=50)
print(
    f"Fine-tuned TEST acc: {ft_test['accuracy']:.3f} | "
    f"Correct: {ft_test['correct']} | Total: {ft_test['n_total']}"
)

# Quick confusion-matrix heatmap (counts)
letters = list(LETTERS)
mat = np.array([[ft_test["confusion"][g][p] for p in letters] for g in letters])

plt.figure()
plt.imshow(mat)
plt.xticks(range(5), letters)
plt.yticks(range(5), letters)
plt.xlabel("Predicted")
plt.ylabel("Gold")
plt.title("MCQ Confusion Matrix (counts) — Test")
plt.colorbar()
plt.show()


Correct: 15 | Total: 50
TEST MCQ Accuracy: 0.3


### **Save Adapter Weights**

In [None]:
# (Optional) Save adapters again to a specific directory
# Useful when you want a clean, named export separate from cfg.output_dir.
EXPORT_DIR = cfg.output_dir  # change if you want a different folder

trainer.model.save_pretrained(EXPORT_DIR)
tokenizer.save_pretrained(EXPORT_DIR)
print("Saved adapter + tokenizer to:", EXPORT_DIR)


# **TinyLlama Inference**

In [None]:
"""Inference: load base model + LoRA adapter and answer MCQs.

This section loads:
- BASE model weights (frozen)
- LoRA adapter (fine-tuned)

Then runs generation for a single example.
"""

from peft import PeftModel

# Base model used during fine-tuning
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Uncomment to test the smaller model variant (if you fine-tuned it):
# BASE = "HuggingFaceTB/SmolLM2-360M-Instruct"

# Path to the fine-tuned adapter directory on Drive
ADAPTER_DIR = "/content/drive/MyDrive/NLP/prosit_1/fine_tuned_models/tinyllama_afrimedqa_mcq_qlora"

tokenizer = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE,
    device_map="auto",
    quantization_config=bnb,
    torch_dtype=torch.float16,
)

model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model.eval()

# Demo: run on one test example
ex = test_ds[0]
print("GOLD:", ex["gold_letter"])
pred = predict_letter(model, tokenizer, ex)
print("PRED:", pred)
