## MedMCQA — LoRA fine-tune of **Qwen2.5‑7B‑Instruct** (Unsloth + QLoRA)

**Goal.** Train a tutor to pick the correct A/B/C/D option on MedMCQA and (optionally) give a short explanation.

**Dataset.** `openlifescienceai/medmcqa` (train/validation splits from the HF dataset).

**Base model.** `Qwen/Qwen2.5-7B-Instruct`  
*(A separate notebook covers Meta-Llama-3-8B-Instruct.)*

**Method.** Unsloth + QLoRA (bnb NF4) + LoRA adapters

**LoRA hyperparams actually used here.** `r=32`, `alpha=64`, `dropout=0.0`  
**Max sequence length.** `768` (reduced from 1024 to fit T4 16 GB during fine-tuning)  
**Hardware.** Kaggle T4 (16 GB)

> ⚠️ Educational use only. Not medical advice.

In [2]:
%%capture

!pip install unsloth # install unsloth
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git # Also get the latest version Unsloth!

## Import all relevant packages throughout this walkthrough

In [3]:
import os, gc
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

In [4]:
# Modules for fine-tuning
from unsloth import FastLanguageModel
import torch # Import PyTorch
from trl import SFTTrainer # Trainer for supervised fine-tuning (SFT)
from unsloth import is_bfloat16_supported # Checks if the hardware supports bfloat16 precision
# Hugging Face modules
from huggingface_hub import login # Lets you login to API
from transformers import TrainingArguments, EarlyStoppingCallback # Defines training hyperparameters
from datasets import load_dataset # Lets you load fine-tuning datasets
# Import weights and biases
import wandb
# Import kaggle secrets
from kaggle_secrets import UserSecretsClient
from functools import partial

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-08-21 03:14:18.687717: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755746059.016882      63 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755746059.115539      63 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🦥 Unsloth Zoo will now patch everything to make training faster!


## Create API keys and login to Hugging Face and Weights and Biases

In [5]:
import random

# reproducibility
torch.manual_seed(3407)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(3407)
random.seed(3407)

In [6]:
# Initialize Hugging Face & WnB tokens
user_secrets = UserSecretsClient() # from kaggle_secrets import UserSecretsClient
hugging_face_token = user_secrets.get_secret("Hugging_Face_Token")
wnb_token = user_secrets.get_secret("wnb")

# Login to Hugging Face
login(hugging_face_token) # from huggingface_hub import login

# Login to WnB
wandb.login(key=wnb_token) # import wandb

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mavpk[0m ([33mavpk-university-of-waterloo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Load Llama-3-8B-Instruct and Tokenizer

We load the base model with Unsloth's `FastLanguageModel.from_pretrained()` and enable 4-bit (NF4) quantization to save VRAM.

**Key knobs:**
- `max_seq_length = 768`
- `load_in_4bit = True` (bitsandbytes NF4 via Unsloth)
- `gpu_memory_utilization = 0.45`

**Intuition behind 4-bit quantization**

Imagine compressing a **high-resolution image** to a smaller size—**it takes up less space but still looks good enough**. Similarly, **4-bit quantization reduces the precision of model weights**, making the model **smaller and faster while keeping most of its accuracy**. Instead of storing precise **32-bit or 16-bit numbers**, we compress them into **4-bit values**. This allows **large language models to run efficiently on consumer GPUs** without needing massive amounts of memory. 

In [7]:
# ==== Base model for this notebook (Qwen2.5-7B-Instruct) ====
base_name = "Qwen/Qwen2.5-7B-Instruct"

max_seq_length = 768
dtype = None
load_in_4bit = True  # QLoRA / NF4 via bitsandbytes under Unsloth

def load_base(model_name):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,         
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        token=hugging_face_token,
        device_map={"": torch.cuda.current_device()} if torch.cuda.is_available() else {"": "cpu"},
        gpu_memory_utilization=0.45,             
        low_cpu_mem_usage=True,                 
    )
    return model, tokenizer

## Quick baseline check before fine-tuning

We sanity-check prompt formatting and then compute a small decode-free baseline with eval_mcq_logits. This is just to verify prompt formatting and tokenization before training.


### Running a quick baseline

We compute a small validation baseline with a **decode-free, answer-only scorer** (`eval_mcq_logits`), which scores the first generated token’s log-prob among {A,B,C,D}. This matches our training format and is much faster than long decoding.

Later, we also sample a **small batch of short explanations** (a few dozen tokens) purely for qualitative inspection.


### Step 1 — Load and lightly format MedMCQA

We use `openlifescienceai/medmcqa` directly (HF Datasets).  
We format each example into a short chat conversation and train with **answer-only supervision** (letter A/B/C/D).


In [8]:
# # MedMCQA splits (train/val/test). 

raw = load_dataset("openlifescienceai/medmcqa")
train_ds = raw["train"]
val_ds_orig   = raw["validation"]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/936k [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

In [9]:
def to_text_examples(batch, tokenizer, eos_token=None):
    texts = []
    for q, a, b, c, d, gold in zip(batch["question"], batch["opa"], batch["opb"], batch["opc"], batch["opd"], batch["cop"]):
        gold_letter = gold_to_letter(gold) #or "A"
        user = (
            "You are a medical expert. Answer this MCQ with a single letter.\n\n"
            f"Question:\n{q}\n\nOptions:\nA. {a}\nB. {b}\nC. {c}\nD. {d}\n\n"
            "Respond in the format:\nAnswer: <A/B/C/D>"
        )
        assistant = f"Answer: {gold_letter}"
        messages = [
            {"role": "system", "content": "You are a medical expert."},
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ]
        txt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        texts.append(txt)
    return {"text": texts}

### Step 2 — Setting up the model using LoRA

**An intuitive explanation of LoRA** 

Large language models (LLMs) have **millions or even billions of weights** that determine how they process and generate text. When fine-tuning a model, we usually update all these weights, which **requires massive computational resources and memory**.

LoRA (**Low-Rank Adaptation**) allows to fine-tune efficiently by:

- Instead of modifying all weights, **LoRA adds small, trainable adapters** to specific layers.  
- These adapters **capture task-specific knowledge** while leaving the original model unchanged.  
- This reduces the number of trainable parameters **by more than 90%**, making fine-tuning **faster and more memory-efficient**.  

Think of an LLM as a **complex factory**. Instead of rebuilding the entire factory to produce a new product, LoRA **adds small, specialized tools** to existing machines. This allows the factory to adapt quickly **without disrupting its core structure**.

Below, we will use the `get_peft_model()` function which stands for Parameter-Efficient Fine-Tuning — this function wraps the base model (`model`) with LoRA modifications, ensuring that only specific parameters are trained.

**This notebook uses:** `r=32`, `alpha=64`, `dropout=0.0` and targets attention/MLP projections (`q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj`).


In [10]:
def add_lora(model):
    return FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        lora_alpha=64,                
        lora_dropout=0.0,            
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )


In [11]:
# where make_trainer is defined
def make_trainer(model_lora, tokenizer, dataset, outdir, eval_dataset=None):
    return SFTTrainer(
        model=model_lora,
        tokenizer=tokenizer,
        train_dataset=dataset,
        eval_dataset=eval_dataset,           
        dataset_text_field="text",
        max_seq_length=768,
        dataset_num_proc=2,
        packing=False,
        args=TrainingArguments(
            per_device_train_batch_size=1,
            per_device_eval_batch_size=2,
            gradient_accumulation_steps=8,   
            num_train_epochs=2,               # let early stopping pick best
            learning_rate=1e-4,               # LoRA likes a bit higher LR
            warmup_ratio=0.05,
            dataloader_num_workers=2,
            lr_scheduler_type="cosine",
            weight_decay=0.0,
            gradient_checkpointing=True,      
            max_grad_norm=1.0,                # clip for stability
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=50,
            optim="adamw_8bit",               # good with 4-bit base
            group_by_length=True,
            output_dir=outdir,
            report_to="wandb",
            seed=3407,
            data_seed=3407,
        ),
    )

In [12]:
import re
import numpy as np
from torch.utils.data import Subset

np.random.seed(3407)

# Answer-only prompt (no explanation at eval time)
ANSWER_ONLY_PROMPT = """You are a helpful medical AI.
Question: {question}

Options:
A. {opa}
B. {opb}
C. {opc}
D. {opd}

Respond with exactly one letter (A, B, C, or D).
Answer: """

def chat_wrap(tokenizer, user_text):
    # Use proper assistant turn for chatty instruct models (e.g., Llama 3, Qwen)
    if hasattr(tokenizer, "apply_chat_template"):
        messages = [
            {"role": "system", "content": "You are a medical expert."},
            {"role": "user",   "content": user_text},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return user_text

In [13]:
# ===== Quick explanation spot-check (small sample, batched) =====
EXPLAIN_PROMPT = """You are a medical expert. Answer the MCQ and briefly justify in 3–6 sentences.

Question:
{question}

Options:
A. {opa}
B. {opb}
C. {opc}
D. {opd}

Respond in the format:
Answer: <A/B/C/D>
Explanation: <3–6 sentences>

Answer: """


ANS_RE = re.compile(r"Answer\s*:\s*([ABCD])", re.I)
EXPL_RE = re.compile(r"Explanation\s*:\s*(.*)", re.I | re.S)

def format_explain(row, tok):
    user = EXPLAIN_PROMPT.format(
        question=row["question"],
        opa=row["opa"], opb=row["opb"], opc=row["opc"], opd=row["opd"],
    )
    return chat_wrap(tok, user)

@torch.no_grad()
def sample_explanations(model, tokenizer, dataset, k=8, batch_size=2, max_len=768, new_tokens=160):
    model.eval()
    FastLanguageModel.for_inference(model)

    idxs = list(range(min(k, len(dataset))))
    ds = Subset(dataset, idxs)

    prompts, golds = [], []
    for i in range(len(ds)):# k items
        ex = ds[i]
        prompts.append(format_explain(ex, tokenizer))
        golds.append(gold_to_letter(ex["cop"]))

    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    eos_id = tokenizer.eos_token_id

    rows = []
    for s in range(0, len(prompts), batch_size):
        batch_prompts = prompts[s:s+batch_size]
        batch = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding="longest",          # explicit; respects tokenizer.padding_side="left"
            truncation=True,
            max_length=max_len,
        ).to(model.device)

        out = model.generate(
            **batch,
            max_new_tokens=new_tokens,
            do_sample=False, temperature=0.0, num_beams=1,
            pad_token_id=pad_id, eos_token_id=eos_id, use_cache=True
        )
        
        gen_only = out[:, batch["input_ids"].shape[1]:]
        decoded = tokenizer.batch_decode(gen_only, skip_special_tokens=True)
        for j, txt in enumerate(decoded):
            m_ans = ANS_RE.search(txt); pred = m_ans.group(1).upper() if m_ans else None
            m_exp = EXPL_RE.search(txt); expl = m_exp.group(1).strip() if m_exp else "(no explanation parsed)"
            idx = idxs[s + j]; gold = golds[s + j]
            rows.append((idx, gold, pred, pred == gold, expl, txt))

    # log to W&B if active
    try:
        table = wandb.Table(columns=["idx","gold","pred","correct","explanation","raw"])
        for r in rows: table.add_data(*r)
        wandb.log({"sample_explanations": table})
    except Exception:
        pass

    # Console preview (first 3)
    for idx, gold, pred, correct, expl, _ in rows[:3]:
        print(f"[{idx}] gold={gold} pred={pred} correct={correct}\n  explanation: {expl[:200]}...\n")

In [14]:
from collections import Counter
raw_train_cop = Counter(map(int, raw["train"]["cop"]))
raw_val_cop   = Counter(map(int, raw["validation"]["cop"]))
print("[raw cop] train:", raw_train_cop)
print("[raw cop] val_orig:", raw_val_cop)

[raw cop] train: Counter({0: 53591, 1: 47826, 2: 42442, 3: 38963})
[raw cop] val_orig: Counter({0: 1348, 1: 1085, 2: 925, 3: 825})


In [15]:
def gold_to_letter(cop):
    s = str(cop).strip()
    if s and s[0].upper() in "ABCD":
        return s[0].upper()
    try:
        n = int(s)
    except ValueError:
        return None
    return "ABCD"[n] if 0 <= n <= 3 else None

In [16]:
from transformers.trainer_callback import TrainerCallback
import os

class MCQAccuracyCallback(TrainerCallback):
    def __init__(self, tokenizer, val_dataset, every=100, patience=8,
                 max_items=600, log_key="val_accuracy", save_dir=None):
        self.tokenizer = tokenizer
        self.val_dataset = val_dataset
        self.every = every
        self.patience = patience
        self.max_items = max_items
        self.log_key = log_key
        self.wait = 0
        self.best_acc = -1.0
        self.save_dir = save_dir  

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if state.global_step == 0 or (state.global_step % self.every) != 0:
            return control

        model.eval()
        acc = eval_mcq_logits(model, self.tokenizer, self.val_dataset, max_items=self.max_items, batch_size=2, max_len=768)
        print(f"[step {state.global_step}] accuracy {acc:.4f}")

        # optional W&B log
        try:
            import wandb
            wandb.log({self.log_key: acc}, step=int(state.global_step))
        except Exception:
            pass

        if self.best_acc < 0 or acc > self.best_acc + 1e-6:
            self.best_acc = acc
            self.wait = 0
            base = self.save_dir or args.output_dir
            best_dir = os.path.join(base, "best")
            os.makedirs(best_dir, exist_ok=True)
            print(f"New best acc={acc:.4f}, saving to {best_dir}")
            model.save_pretrained(best_dir)
            self.tokenizer.save_pretrained(best_dir)
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f"Early stopping at step {state.global_step}, best={self.best_acc:.4f}")
                control.should_training_stop = True

        model.train()
        FastLanguageModel.for_training(model)
        
        return control

## Evaluation protocol (fast, decode-free)

During training we run a **decode-free evaluator** that scores the **log-probability of the first generated token** for `A/B/C/D` (with both `A` and `" A"` tokenizations). This closely matches the answer-only training format and is **much faster** than full generation. We also sample a few explanation outputs after training for qualitative checks.


In [17]:
@torch.no_grad()
def eval_mcq_logits(model, tokenizer, ds, max_items=600, batch_size=2, max_len=768):
    """
    Fast, regex-free MCQ evaluator that matches your training chat format.
    - Handles HF Dataset slicing (dict-of-lists) and list-of-rows.
    - Scores next-token log-prob for A/B/C/D (both 'A' and ' A' tokenizations).
    - Expects ANSWER_ONLY_PROMPT to end with 'Answer: ' (note trailing space).
    """
    import torch

    # ---- 0) Limit set size
    N = min(len(ds), max_items)
    if N == 0:
        return 0.0

    # ---- 1) Cache candidate token ids per tokenizer
    cache_key = getattr(tokenizer, "name_or_path", None) or id(tokenizer)
    _cache = getattr(eval_mcq_logits, "_cache", {})
    if cache_key not in _cache:
        letters = ["A", "B", "C", "D"]
        cand_ids_no = []
        cand_ids_sp = []
        for L in letters:
            ids0 = tokenizer(L, add_special_tokens=False).input_ids
            ids1 = tokenizer(" " + L, add_special_tokens=False).input_ids
            cand_ids_no.append(ids0[0] if ids0 else -1)
            cand_ids_sp.append(ids1[0] if ids1 else -1)
        _cache[cache_key] = {"no": cand_ids_no, "sp": cand_ids_sp}
        eval_mcq_logits._cache = _cache

    cand_ids_no = torch.tensor(_cache[cache_key]["no"], device=model.device)
    cand_ids_sp = torch.tensor(_cache[cache_key]["sp"], device=model.device)

    # ---- 2) Helper: build a single chat-formatted prompt
    def build_prompt(q, a, b, c, d):
        user = ANSWER_ONLY_PROMPT.format(question=q, opa=a, opb=b, opc=c, opd=d)
        messages = [
            {"role": "system", "content": "You are a medical expert."},
            {"role": "user",   "content": user},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    correct = 0
    total = 0
    model.eval()

    # ---- 3) Iterate in minibatches
    for i in range(0, N, batch_size):
        j = min(i + batch_size, N)

        # Slice; HF Dataset returns dict-of-lists here
        subset = ds[i:j]

        if isinstance(subset, dict):
            rows = list(zip(
                subset["question"], subset["opa"], subset["opb"],
                subset["opc"], subset["opd"], subset["cop"],
            ))
        else:
            # Fallback: list of row dicts (or another sequence-like)
            rows = [
                (subset[k]["question"], subset[k]["opa"], subset[k]["opb"],
                 subset[k]["opc"], subset[k]["opd"], subset[k]["cop"])
                for k in range(len(subset))
            ]

        prompts = [build_prompt(q, a, b, c, d) for (q, a, b, c, d, _) in rows]

        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
        ).to(model.device)

        out = model(**inputs)
        logits = out.logits  # [B, T, V]

        # Position of first generated token (right after prompt)
        last_idx = (inputs["attention_mask"].sum(dim=1) - 1)  # [B]
        first_logits = logits[torch.arange(logits.size(0), device=logits.device), last_idx]  # [B, V]
        logp = torch.log_softmax(first_logits, dim=-1)  # [B, V]

        # Gather nospace/space variant scores in one go; take max across variants
        minus_inf = torch.finfo(logp.dtype).min
        def gather_or_inf(idx_vec):
            ok = idx_vec >= 0
            out = torch.full((logp.size(0), 4), minus_inf, device=logp.device, dtype=logp.dtype)
            if ok.any():
                out[:, ok] = logp.index_select(1, idx_vec[ok])
            return out

        scores = torch.maximum(gather_or_inf(cand_ids_no), gather_or_inf(cand_ids_sp))  # [B,4]
        preds = scores.argmax(dim=1).tolist()

        golds = ["ABCD".index(gold_to_letter(cop)) for (_, _, _, _, _, cop) in rows]
        correct += sum(int(p == g) for p, g in zip(preds, golds))
        total   += len(rows)

    return correct / max(total, 1)

In [18]:
# ---------- Filter: keep only rows with a valid, parseable gold label ----------
def _keep_valid(batch):
    return [gold_to_letter(c) is not None for c in batch["cop"]]

def filter_if_labeled(ds, name=""):
    if name in ("train", "val"):
        old = len(ds)
        ds = ds.filter(_keep_valid, batched=True, num_proc=2, desc=f"Filter {name}")
        print(f"[label filter] {name}: {old} -> {len(ds)}")
    else:
        # test labels are hidden; keep for inference
        print(f"[label filter] {name}: skipped (kept for predictions; labels hidden)")
    return ds

# (Optional) report sizes before filtering
n_train0, n_val0_orig = len(train_ds), len(val_ds_orig)

train_ds = filter_if_labeled(train_ds, "train")
val_ds_orig   = filter_if_labeled(val_ds_orig,   "val")

# (Optional) report sizes after filtering
n_train1, n_val1_orig = len(train_ds), len(val_ds_orig)
print(f"[label filter] train: {n_train0} -> {n_train1} | val: {n_val0_orig} -> {n_val1_orig}")

Filter train (num_proc=2):   0%|          | 0/182822 [00:00<?, ? examples/s]

[label filter] train: 182822 -> 182822


Filter val (num_proc=2):   0%|          | 0/4183 [00:00<?, ? examples/s]

[label filter] val: 4183 -> 4183
[label filter] train: 182822 -> 182822 | val: 4183 -> 4183


### Step 3 — Shortening the Dataset

#### Train/val subject selection

To keep runs feasible on T4, we **automatically pick the “middle-2” subjects by frequency** in the training split, and then stratify a 70/30 train/val split on `subject_name`. This keeps train/val distributions similar.

> If you want to pin specific subjects (e.g., “Physiology”, “Biochemistry”), replace the middle-2 selection with a fixed list.

**Purpose:** Done to reduce the size of the dataset, so that it can be trained using Kaggle free compute resources

**Note:** The validation dataset is obtained from splitting the train set. Original validation split is treated as test here (original test has no labels).

**Reason**

The original test dataset doesn't contain label to measure the accuracy. Hence, the existing validation dataset is used as an alternative for comparison

In [19]:
from collections import Counter

# 0) sanity: make sure column exists
assert "subject_name" in train_ds.column_names, "subject_name column missing"

# 1) pick middle-2
subj_counts = Counter(train_ds["subject_name"])
ordered = [s for s, _ in sorted(subj_counts.items(), key=lambda kv: kv[1], reverse=True)]
if len(ordered) < 2:
    raise ValueError(f"Need ≥2 subjects, found {len(ordered)}")

mid = len(ordered) // 2
middle2_subjects = ordered[max(0, mid - 1): mid + 1]  # 2 around the median
print("Middle-2 subjects (by train frequency):", middle2_subjects)

# 2) filter splits (IID: keep only these subjects)
keep = lambda ex: ex["subject_name"] in middle2_subjects
train_ds     = train_ds.filter(keep, num_proc=2)
val_ds_orig  = val_ds_orig.filter(keep, num_proc=2)

print(f"[middle2 filter] sizes -> train: {len(train_ds)} | val_orig: {len(val_ds_orig)}")
print("subjects (train):", sorted(set(train_ds["subject_name"])))

Middle-2 subjects (by train frequency): ['Physiology', 'Biochemistry']


Filter (num_proc=2):   0%|          | 0/182822 [00:00<?, ? examples/s]

Filter (num_proc=2):   0%|          | 0/4183 [00:00<?, ? examples/s]

[middle2 filter] sizes -> train: 17112 | val_orig: 342
subjects (train): ['Biochemistry', 'Physiology']


In [20]:
# Encode subject_name to ClassLabel so we can stratify
train_ds = train_ds.class_encode_column("subject_name")

split = train_ds.train_test_split(
    test_size=0.3,                  # 30% of train → val
    seed=3407,
    stratify_by_column="subject_name"
)

train_ds, val_ds = split["train"], split["test"]

print(f"[new split] train: {len(train_ds)} | val: {len(val_ds)}")
print("Subjects (train):", sorted(set(train_ds["subject_name"])))
print("Subjects (val):", sorted(set(val_ds["subject_name"])))

Flattening the indices:   0%|          | 0/17112 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/17112 [00:00<?, ? examples/s]

[new split] train: 11978 | val: 5134
Subjects (train): [0, 1]
Subjects (val): [0, 1]


In [21]:
print("[label sanity] train:", Counter(gold_to_letter(x) for x in train_ds["cop"]))
print("[label sanity] val_orig:", Counter(gold_to_letter(x) for x in val_ds_orig["cop"]))

[label sanity] train: Counter({'A': 3379, 'B': 3193, 'C': 2826, 'D': 2580})
[label sanity] val_orig: Counter({'A': 103, 'C': 86, 'B': 77, 'D': 76})


### Checkpoints & logging

- Checkpoints are saved under `/kaggle/working/outputs/<model-name>/`.  
  If you are not on Kaggle, change `save_dir` to a local folder like `outputs/<model-name>/`.
- Weights & Biases logging is enabled via `WANDB_API_KEY` (secret name: **`wnb`**).  
  If not set, disable/report-to by setting `report_to=None` in `TrainingArguments`.


### Step 4 — Fine-tuning the model

This block fine-tunes the model and picks the best iteration based on validation accuracy.

In [22]:
print(f"\n=== Fine-tuning {base_name} ===")

# free leftovers from any previous attempts / previous base
for var in ("trainer","model_lora","model","tokenizer"):
    if var in globals():
        try:
            del globals()[var]
        except:
            pass
gc.collect()
torch.cuda.empty_cache()

model, tokenizer = load_base(base_name)

# padding config FIRST
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id

print("Baseline (no LoRA) small-val:",
      eval_mcq_logits(model, tokenizer, val_ds_orig, max_items=200, batch_size=2, max_len=768))

run = wandb.init(
    project="medmcqa-finetune-reduced-updated",
    job_type="training",
    name = base_name.split("/")[-1].replace("-", "_"),
    reinit=True
)

eos_token = tokenizer.eos_token or "</s>"
train_finetune = train_ds.map(
    lambda b: to_text_examples(b, tokenizer, eos_token=tokenizer.eos_token),
    batched=True
)

model_lora = add_lora(model)
safe_name = base_name.split("/")[-1].lower().replace(" ", "-")

print("TRAIN EXAMPLE:\n", train_finetune[0]["text"][:400])

# sanity: trainable params > 0
trainable = sum(p.numel() for p in model_lora.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model_lora.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

# pass eval_dataset=None; we validate via callback + eval_mcq
model.config.use_cache = False
save_dir = f"/kaggle/working/outputs/{safe_name}"
trainer = make_trainer(model_lora, tokenizer, train_finetune, outdir=save_dir, eval_dataset=None)
trainer.add_callback(MCQAccuracyCallback(
    tokenizer, val_ds, every=200, patience=8,
    max_items=600, save_dir=save_dir
))
_ = trainer.train()

# === reload best checkpoint if it exists ===
best_dir = os.path.join(save_dir, "best")
if os.path.isdir(best_dir):
    from peft import PeftModel
    model_best = PeftModel.from_pretrained(model, best_dir)
    print(f"Loaded best adapter from {best_dir}")
else:
    print("⚠️ No 'best' dir found, using last-step weights.")
    model_best = model_lora

# Evaluate on validation (answer-only, batched)
val_acc = eval_mcq_logits(model_best, tokenizer, val_ds, max_items=300, batch_size=2, max_len=768)

# ✅ log with the current step so it lands on the same x-axis
wandb.log({"val_accuracy": val_acc}, step=int(trainer.state.global_step))
print(f"Validation accuracy (A/B/C/D) for {base_name}: {val_acc:.3f}")

# small explanation probe
sample_explanations(model_best, tokenizer, val_ds, k=8, batch_size=2, new_tokens=160)


=== Fine-tuning Qwen/Qwen2.5-7B-Instruct ===
==((====))==  Unsloth 2025.8.9: Fast Qwen2 patching. Transformers: 4.55.2.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.16G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Baseline (no LoRA) small-val: 0.485




Map:   0%|          | 0/11978 [00:00<?, ? examples/s]

Unsloth 2025.8.9 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


TRAIN EXAMPLE:
 <|im_start|>system
You are a medical expert.<|im_end|>
<|im_start|>user
You are a medical expert. Answer this MCQ with a single letter.

Question:
Which one of these is absorbed in ileum?

Options:
A. Vitamin D
B. B12
C. Iron
D. Fat

Respond in the format:
Answer: <A/B/C/D><|im_end|>
<|im_start|>assistant
Answer: B<|im_end|>

Trainable params: 80,740,352 / 4,972,287,488


Unsloth: Tokenizing ["text"]:   0%|          | 0/11978 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 11,978 | Num Epochs = 2 | Total steps = 1,498
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 80,740,352 of 7,696,356,864 (1.05% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
50,1.5224
100,0.6548
150,0.6477
200,0.6529
250,0.642
300,0.6263
350,0.6298
400,0.6377
450,0.629
500,0.6203


[step 200] accuracy 0.7033
New best acc=0.7033, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 400] accuracy 0.7300
New best acc=0.7300, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 600] accuracy 0.7417
New best acc=0.7417, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 800] accuracy 0.7433
New best acc=0.7433, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 1000] accuracy 0.7550
New best acc=0.7550, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 1200] accuracy 0.7650
New best acc=0.7650, saving to /kaggle/working/outputs/qwen2.5-7b-instruct/best
[step 1400] accuracy 0.7650




Loaded best adapter from /kaggle/working/outputs/qwen2.5-7b-instruct/best
Validation accuracy (A/B/C/D) for Qwen/Qwen2.5-7B-Instruct: 0.803


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


[0] gold=C pred=C correct=True
  explanation: Hyperventilation is caused by increased pH in CSF, decreased plasma HCO3, increased adrenergic levels. CO poisoning causes hypoxia which leads to hypoventilation....

[1] gold=B pred=A correct=False
  explanation: The prime driving force for counter current multiplier system is medullary hyperosmolarity. This is due to the reabsorption of NaCl in the thick ascending loop of Henle, which creates a hyperosmotic g...

[2] gold=D pred=None correct=False
  explanation: <3–6 sentences>...



### Step 5 — Testing the model on the test-dataset

This block involves testing the best model on the test dataset and storing the parameters

**Note:** As mentioned earlier, in this workflow, original validation dataset is considered as the test dataset

In [23]:
val_acc_orig = eval_mcq_logits(model_best, tokenizer, val_ds_orig, max_items=len(val_ds_orig), batch_size=2, max_len=768)
print(f"[FINAL] Original validation accuracy (best): {val_acc_orig:.4f}")
wandb.log({"final_val_orig_accuracy": val_acc_orig}, step=int(trainer.state.global_step))

# Save final best adapter to a clean dir
final_dir = f"{safe_name}-medmcqa-lora-best"
model_best.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"Saved BEST LoRA to: {final_dir}")

[FINAL] Original validation accuracy (best): 0.6784
Saved BEST LoRA to: qwen2.5-7b-instruct-medmcqa-lora-best




In [24]:
run.finish()

0,1
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇████
train/grad_norm,▃▁▂▂▂▂▂▂▂▁▁▂▂▂▁▆▆▅▆▅▆▆▆█▆▆▇▇█
train/learning_rate,▆█████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁
train/loss,█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▅▅▅█

0,1
total_flos,1.0409321493210931e+17
train/epoch,2.0
train/global_step,1498.0
train/grad_norm,0.62493
train/learning_rate,0.0
train/loss,0.4839
train_loss,0.59327
train_runtime,13079.413
train_samples_per_second,1.832
train_steps_per_second,0.115
