# Fine-tuning TinyLlama on MedMCQA with LoRA
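This notebook fine-tunes `TinyLlama/TinyLlama-1.1B-Chat-v1.0` on a **3,000-question subset** of the MedMCQA training split using **LoRA**, and compares its accuracy with the base model on 1,000 validation questions.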
Link to SFT+LoRA code: https://colab.research.google.com/drive/1UfRcH8FcByb3mAV8sNTiq7r0Fxfa3ur5?usp=sharing

In [None]:
!pip install -q -U transformers datasets accelerate peft bitsandbytes trl sentencepiece huggingface_hub
!pip install -U trl



In [None]:
import torch
import re
import os
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTConfig,SFTTrainer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:


# Choose the compute device: Apple MPS first, then CUDA, otherwise plain CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    _device_banner = ("✅ MPS enabled on Apple Silicon",)
elif torch.cuda.is_available():
    device = torch.device("cuda")
    _device_banner = ("✅ CUDA GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    _device_banner = ("⚠️  Running on CPU",)

# Report the outcome once, after the branch, rather than inside each arm.
print(*_device_banner)
print("Using device:", device)

✅ MPS enabled on Apple Silicon
Using device: mps


In [None]:
# Hugging Face identifiers for the base checkpoint and the dataset.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DATASET_NAME = "openlifescienceai/medmcqa"

# Local output directory. On Colab, a mounted Drive folder such as
# "/content/drive/MyDrive/tinyllama_medmcqa" can be used instead.
OUTPUT_DIR = "tinyllama_medmcqa/output"

# Linear layers that receive LoRA adapters: attention projections, then MLP projections.
_ATTENTION_PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj"]
_MLP_PROJECTIONS = ["gate_proj", "up_proj", "down_proj"]
LORA_TARGET_MODULES = _ATTENTION_PROJECTIONS + _MLP_PROJECTIONS

In [None]:
# Free cached accelerator memory before loading. The call is MPS-specific, so it
# is guarded to keep this cell runnable when the device cell picked CUDA or CPU.
if device.type == "mps":
    torch.mps.empty_cache()

# TinyLlama uses the stock Llama architecture bundled with transformers, so
# `trust_remote_code` is left at its default (off) rather than allowing
# repository-supplied code to execute.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.to(device)  # move the weights to the selected device

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Use EOS as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [None]:

def _build_question_prompt(sample):
    """Render the shared prompt: system line, question, numbered options, instruction.

    The returned text ends with ``"Answer:"`` (no trailing space) so callers can
    either leave it open for generation or append the gold answer.
    """
    options = "\n".join(
        f"{idx}: {sample[key]}"
        for idx, key in enumerate(("opa", "opb", "opc", "opd"))
    )
    return (
        "You are a helpful medical assistant.\n\n"
        f"Question: {sample['question']}\n"
        f"Options:\n{options}\n"
        "Answer with only the number (0, 1, 2, or 3).\n"
        "Answer:"
    )


def format_instruction_eval(sample):
    """Build an evaluation record whose prompt stops at ``Answer:``.

    Args:
        sample: mapping with ``question``, ``opa``..``opd`` and ``cop`` keys.

    Returns:
        ``{"text": prompt, "cop": sample["cop"]}``; ``cop`` is passed through
        unchanged so it can be compared with the predicted option index.
    """
    return {"text": _build_question_prompt(sample), "cop": sample["cop"]}


def format_instruction_train_chat(sample):
    """Build a training record whose prompt ends with the gold answer.

    Args:
        sample: mapping with ``question``, ``opa``..``opd`` and ``cop`` keys.

    Returns:
        ``{"text": prompt + " <cop>", "cop": str(sample["cop"])}``; note that
        ``cop`` is stringified here, unlike in the evaluation formatter.
    """
    return {
        "text": f"{_build_question_prompt(sample)} {sample['cop']}",
        "cop": str(sample["cop"]),
    }



# Sampling configuration. DATASET_NAME and load_dataset come from the config and
# import cells; they are deliberately not redefined here so the two cannot drift.
TRAIN_SAMPLE_SIZE = 3000
EVAL_SAMPLE_SIZE = 1000
SAMPLE_SEED = 42  # one seed for both shuffles keeps the subsets reproducible

dataset = load_dataset(DATASET_NAME)

# Training subset: prompts include the gold answer.
train_dataset = (
    dataset["train"]
    .shuffle(seed=SAMPLE_SEED)
    .select(range(TRAIN_SAMPLE_SIZE))
    .map(format_instruction_train_chat)
)

# Validation subset: prompts stop at "Answer:" (no gold answer in the text).
val_dataset = (
    dataset["validation"]
    .shuffle(seed=SAMPLE_SEED)
    .select(range(EVAL_SAMPLE_SIZE))
    .map(format_instruction_eval)
)

# Alias used by the evaluation cells.
eval_dataset = val_dataset
print("Sample prompt:")
print(eval_dataset[0]["text"])

Sample prompt:
You are a helpful medical assistant.

Question: Amount of heat that is required to change boiling water into vapor is referred to as
Options:
0: Latent Heat of vaporization
1: Latent Heat of sublimation
2: Latent Heat of condensation
3: Latent heat of fusion
Answer with only the number (0, 1, 2, or 3).
Answer:


In [None]:
def evaluate_model(model, tokenizer, dataset, max_length=256):
    """Greedy-decode an answer for every prompt and return the accuracy.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        dataset: sized iterable of samples with a ``"text"`` prompt and a
            ``"cop"`` gold option index.
        max_length: prompt truncation length in tokens. Defaults to 256, the
            value of ``MAX_SEQ_LENGTH`` in the tokenization cell; it is a
            parameter so this function does not depend on a constant that is
            only defined later in the notebook.

    Returns:
        Fraction of samples whose predicted option equals ``sample["cop"]``,
        or ``0.0`` for an empty dataset.
    """
    total = len(dataset)
    if total == 0:
        return 0.0

    show_example = True  # print the first sample only, as a sanity check
    correct = 0

    for sample in tqdm(dataset):
        prompt = sample["text"]
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_length
        ).to(model.device)

        # Greedy decoding; `temperature` is omitted because it is ignored when
        # do_sample=False.
        outputs = model.generate(
            **inputs,
            max_new_tokens=2,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Parse only the newly generated tokens. The prompt itself contains
        # "(0, 1, 2, or 3)" and the option numbers, so matching against the full
        # text would always find a digit even when the model produced none.
        prompt_len = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        match = re.search(r"\b([0-3])\b", completion)
        # -1 marks "no valid answer"; it is reset on every iteration so a
        # prediction can never leak from the previous sample.
        predicted_answer = int(match.group(1)) if match else -1

        if predicted_answer == sample["cop"]:
            correct += 1

        if show_example:
            print("🔎 Prompt:", prompt)
            print("📤 Full Decoded Output:", decoded)
            print("🔢 Predicted:", predicted_answer, "| Actual:", sample["cop"])
            show_example = False

    return correct / total

In [None]:

# Baseline: accuracy of the base model on eval_dataset, stored in base_acc
# for the comparison table at the end of the notebook.
print("Evaluating base TinyLlama...")
base_acc = evaluate_model(model, tokenizer, eval_dataset)
print(f"Base Accuracy: {base_acc:.2%}")

Evaluating base TinyLlama...


  0%|          | 3/1000 [00:00<03:05,  5.39it/s]

🔎 Prompt: You are a helpful medical assistant.

Question: Amount of heat that is required to change boiling water into vapor is referred to as
Options:
0: Latent Heat of vaporization
1: Latent Heat of sublimation
2: Latent Heat of condensation
3: Latent heat of fusion
Answer with only the number (0, 1, 2, or 3).
Answer:
📤 Full Decoded Output: You are a helpful medical assistant.

Question: Amount of heat that is required to change boiling water into vapor is referred to as
Options:
0: Latent Heat of vaporization
1: Latent Heat of sublimation
2: Latent Heat of condensation
3: Latent heat of fusion
Answer with only the number (0, 1, 2, or 3).
Answer: 0
🔢 Predicted: 0 | Actual: 0


100%|██████████| 1000/1000 [01:26<00:00, 11.59it/s]

Base Accuracy: 24.30%





In [None]:
# LoRA and optimisation hyperparameters.
LORA_R = 16
LORA_ALPHA = 32
LEARNING_RATE = 1e-4
EPOCHS = 3
GRAD_ACC = 4
TRAIN_BATCH_SIZE = 4
LORA_DROPOUT = 0.05

# NOTE(review): prepare_model_for_kbit_training is meant for quantized (k-bit)
# models, but the base model above was loaded without a quantization config.
# Confirm this call is still wanted.
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    use_rslora=True,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=GRAD_ACC,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    logging_steps=10,
    # Inert as configured: no evaluation strategy or eval dataset is set, so
    # no periodic evaluation actually runs during training.
    eval_steps=200,
    save_steps=1000,
    lr_scheduler_type='cosine_with_restarts',
    warmup_ratio=0.05,
    report_to="none",
    save_total_limit=2,
    optim="adamw_torch",
    max_grad_norm=0.3,
    remove_unused_columns=False,  # keep the precomputed "labels" column
    bf16=torch.backends.mps.is_available(),
    save_safetensors=True
)

# Use the device chosen in the device cell instead of hard-coding "mps", and only
# touch the MPS cache when that backend is in use.
if device.type == "mps":
    torch.mps.empty_cache()
model.to(device);  # trailing ';' suppresses the very long model repr


  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
trainable params: 12,615,680 || all params: 1,112,664,064 || trainable%: 1.1338


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 2048)
        (layers): ModuleList(
          (0-21): 22 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(

In [None]:
# Inspect one training record: the raw dataset fields plus the generated "text" prompt.
train_dataset[0]

{'id': 'c69ce981-e254-4707-b87a-cb493e8a948e',
 'question': 'OPV can be used if vaccine l monitor is showing?',
 'opa': 'Colour of outer circle is same as inner square',
 'opb': 'Colour of outer circle is darker than inner square',
 'opc': 'Colour of outer circle is lighter than inner square',
 'opd': 'None of the above',
 'cop': 1,
 'choice_type': 'multi',
 'exp': "Ans. is 'b' i.e., Colour of outer circle is darker than inner square",
 'subject_name': 'Social & Preventive Medicine',
 'topic_name': None,
 'text': 'You are a helpful medical assistant.\n\nQuestion: OPV can be used if vaccine l monitor is showing?\nOptions:\n0: Colour of outer circle is same as inner square\n1: Colour of outer circle is darker than inner square\n2: Colour of outer circle is lighter than inner square\n3: None of the above\nAnswer with only the number (0, 1, 2, or 3).\nAnswer: 1'}

In [None]:
MAX_SEQ_LENGTH = 256


def tokenize_function_masked(example):
    """Tokenize a training prompt and supervise only its final answer tokens.

    Args:
        example: mapping with a ``"text"`` prompt that ends in the gold answer
            and a ``"cop"`` value holding that answer.

    Returns:
        The tokenizer output (padded/truncated to ``MAX_SEQ_LENGTH``) with an
        added ``"labels"`` list: ``-100`` everywhere except the positions of
        the answer at the end of the unpadded sequence. If the answer is not
        found there (for example because the prompt was truncated), every
        label stays ``-100``.
    """
    text = example["text"]
    answer = str(example["cop"])  # e.g. "1"

    tokenized = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LENGTH,
    )
    input_ids = tokenized["input_ids"]
    labels = [-100] * len(input_ids)

    # Tokenize the answer on its own, without BOS/EOS.
    answer_token_ids = tokenizer(answer, add_special_tokens=False)["input_ids"]

    # Step back over trailing pad tokens (assumes right padding) to find where
    # the real text ends.
    end = len(input_ids)
    while end > 0 and input_ids[end - 1] == tokenizer.pad_token_id:
        end -= 1

    # Anchor the match to the END of the text. The same digit also appears
    # earlier (option numbers, the "(0, 1, 2, or 3)" hint), so a first-occurrence
    # search can put the label on the wrong position.
    start = end - len(answer_token_ids)
    if answer_token_ids and start >= 0 and input_ids[start:end] == answer_token_ids:
        labels[start:end] = answer_token_ids

    tokenized["labels"] = labels
    return tokenized
# Tokenize and label-mask every training example, one example per call (batched=False).
tokenized_train_dataset = train_dataset.map(tokenize_function_masked, batched=False)

Map: 100%|██████████| 3000/3000 [00:00<00:00, 5750.21 examples/s]


In [None]:
# Sanity-check the masking on one example: '_' stands for a position labeled -100.
d = tokenize_function_masked(train_dataset[0])
print("Decoded:", tokenizer.decode(d["input_ids"]))
print(
    "Labels:",
    ['_' if tok_id == -100 else tokenizer.decode([tok_id]) for tok_id in d["labels"]],
)

Decoded: <s> You are a helpful medical assistant.

Question: OPV can be used if vaccine l monitor is showing?
Options:
0: Colour of outer circle is same as inner square
1: Colour of outer circle is darker than inner square
2: Colour of outer circle is lighter than inner square
3: None of the above
Answer with only the number (0, 1, 2, or 3).
Answer: 1</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>
Labels: ['_', '_', '_', '_

In [None]:

# Build the trainer. `model` is already a PEFT model (wrapped by get_peft_model in
# the LoRA setup cell), so `peft_config` is not passed again here; supplying both
# is redundant and leaves adapter handling to the trainer.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
)

trainer.train()

Truncating train dataset: 100%|██████████| 3000/3000 [00:00<00:00, 87829.63 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
10,3.2607
20,0.6628
30,0.4632
40,0.4485
50,0.4234
60,0.4142
70,0.4262
80,0.4133
90,0.389
100,0.4491


TrainOutput(global_step=561, training_loss=0.4029473512567938, metrics={'train_runtime': 2448.0006, 'train_samples_per_second': 3.676, 'train_steps_per_second': 0.229, 'total_flos': 1.441116280848384e+16, 'train_loss': 0.4029473512567938})

In [None]:
trainer.save_model(f"{OUTPUT_DIR}/LoRA1")
print("Model saved to", f"{OUTPUT_DIR}/LoRA1")
!ls -lh {OUTPUT_DIR}/LoRA1

Model saved to tinyllama_medmcqa/output/LoRA1
total 106752
-rw-r--r--  1 jagadeeshbandlamudi  staff   5.0K Apr 26 19:49 README.md
-rw-r--r--  1 jagadeeshbandlamudi  staff   866B Apr 26 19:49 adapter_config.json
-rw-r--r--  1 jagadeeshbandlamudi  staff    48M Apr 26 19:49 adapter_model.safetensors
-rw-r--r--  1 jagadeeshbandlamudi  staff   551B Apr 26 19:49 special_tokens_map.json
-rw-r--r--  1 jagadeeshbandlamudi  staff   3.5M Apr 26 19:49 tokenizer.json
-rw-r--r--  1 jagadeeshbandlamudi  staff   488K Apr 26 19:49 tokenizer.model
-rw-r--r--  1 jagadeeshbandlamudi  staff   1.4K Apr 26 19:49 tokenizer_config.json
-rw-r--r--  1 jagadeeshbandlamudi  staff   5.6K Apr 26 19:49 training_args.bin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
# Merge the saved LoRA adapter into a freshly loaded base model and write the
# merged checkpoint (plus tokenizer) to disk.
# Note: this cell rebinds the notebook-level names `model` and `tokenizer`.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os

# Setup
# NOTE(review): torch was already imported and used earlier in the notebook, so
# these CUDA variables may have no effect at this point. Confirm they are needed.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Paths: adapter_path is the directory written by trainer.save_model above.
base_model_id = MODEL_NAME
adapter_path = f"{OUTPUT_DIR}/LoRA1"
merged_model_path = f"{OUTPUT_DIR}/Merged"

print(f"Base model ID: {base_model_id}")
print(f"Adapter path: {adapter_path}")
print(f"Merged model path: {merged_model_path}")

# Load base model (bfloat16 when MPS is available, float32 otherwise)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.backends.mps.is_available() else torch.float32,
    trust_remote_code=True
)

# Load tokenizer (fresh instance; pad_token is not set here, unlike the first load)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Load LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, adapter_path)

# Merge LoRA into base model
model = model.merge_and_unload()

# Save merged model and tokenizer side by side
model.save_pretrained(merged_model_path)
tokenizer.save_pretrained(merged_model_path)
print(f"Model merged and saved to {merged_model_path}")


Base model ID: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Adapter path: tinyllama_medmcqa/output/LoRA1
Merged model path: tinyllama_medmcqa/output/Merged
Model merged and saved to tinyllama_medmcqa/output/Merged


In [None]:
# === Reload merged model for evaluation ===
# Load the merged checkpoint and its tokenizer back from merged_model_path;
# `tokenizer` is rebound to the reloaded instance.
print("merged_model_path", merged_model_path)
mergedModel = AutoModelForCausalLM.from_pretrained(merged_model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(merged_model_path)

merged_model_path tinyllama_medmcqa/output/Merged


In [None]:
# Evaluate
# Score the merged model on the same evaluation set and tabulate it against the baseline.
ft_acc = evaluate_model(mergedModel, tokenizer, eval_dataset)
results = pd.DataFrame(
    [
        ("Base", base_acc, "-"),
        ("Fine-tuned", ft_acc, f"{ft_acc - base_acc:.5%}"),
    ],
    columns=["Model", "Accuracy", "Improvement"],
)
print(results)

  0%|          | 1/1000 [00:05<1:32:15,  5.54s/it]

🔎 Prompt: You are a helpful medical assistant.

Question: Amount of heat that is required to change boiling water into vapor is referred to as
Options:
0: Latent Heat of vaporization
1: Latent Heat of sublimation
2: Latent Heat of condensation
3: Latent heat of fusion
Answer with only the number (0, 1, 2, or 3).
Answer:
📤 Full Decoded Output: You are a helpful medical assistant.

Question: Amount of heat that is required to change boiling water into vapor is referred to as
Options:
0: Latent Heat of vaporization
1: Latent Heat of sublimation
2: Latent Heat of condensation
3: Latent heat of fusion
Answer with only the number (0, 1, 2, or 3).
Answer: 1
🔢 Predicted: 1 | Actual: 0


100%|██████████| 1000/1000 [07:07<00:00,  2.34it/s] 

        Model  Accuracy Improvement
0        Base     0.243           -
1  Fine-tuned     0.322    7.90000%



