# Lab 6 - Fine-tuning Llama 3.2 on a Medical Dataset

This notebook fine-tunes Llama 3.2 1B Instruct on 1,000 examples from a medical Q&A dataset using LoRA (Low-Rank Adaptation), then evaluates the result on 20 held-out questions.

## 1. Setup and Imports


In [None]:
%pip install -r requirements.txt -q

In [None]:
import warnings
import os
warnings.filterwarnings("ignore")
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"


## Hugging Face Authentication

Run the `login()` cell below once to log in with your Hugging Face token (needed for gated models such as Llama 3.2).


In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [None]:
from huggingface_hub import login

# Run once to store token locally
# If you already logged in on this machine, skip rerunning this cell.
login()

## 2. Load Model and Tokenizer


In [None]:
import os

model_name = os.environ.get("LLAMA_MODEL", "meta-llama/Llama-3.2-1B-Instruct")
hf_token = os.environ.get("HF_TOKEN")  # optional; if unset (None), the token cached by login() is used

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
print(f"Model loaded: {model_name}")


Model loaded: meta-llama/Llama-3.2-1B-Instruct


## 3. Configure LoRA

LoRA freezes the base model weights and trains a pair of small low-rank matrices for each targeted projection, so only a fraction of a percent of the parameters is updated:


In [None]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 3,407,872 || all params: 1,239,222,272 || trainable%: 0.2750
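
The trainable-parameter count printed above can be reproduced by hand. For a linear layer with `in_features` inputs and `out_features` outputs, LoRA of rank `r` adds two matrices with `r * in_features` and `out_features * r` entries. The sketch below assumes the Llama 3.2 1B dimensions (hidden size 2048, 16 layers, 32 attention heads, 8 key/value heads); these values come from the model's published config rather than from this notebook, so treat them as assumptions and check `model.config` if you use a different checkpoint.

```python
# Assumed Llama 3.2 1B dimensions (taken from the model config, not from this notebook)
HIDDEN_SIZE = 2048
NUM_LAYERS = 16
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS   # 64
KV_DIM = NUM_KV_HEADS * HEAD_DIM      # 512 (grouped-query attention)

RANK = 16  # matches r=16 in the LoraConfig above

# (in_features, out_features) for each targeted projection
PROJECTION_SHAPES = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "k_proj": (HIDDEN_SIZE, KV_DIM),
    "v_proj": (HIDDEN_SIZE, KV_DIM),
    "o_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by one LoRA adapter: A is rank x in, B is out x rank."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in PROJECTION_SHAPES.values())
total = per_layer * NUM_LAYERS

print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")  # 3,407,872
```

Under these assumptions the total matches the `trainable params` figure reported by `print_trainable_parameters()`. Doubling `r` would double the adapter size, since the count is linear in the rank.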


## 4. Load and Prepare Dataset


In [None]:
dataset = load_dataset("FreedomIntelligence/medical-o1-verifiable-problem")
train_dataset = dataset["train"].select(range(1000))
print(f"Training on {len(train_dataset)} examples")
print(f"Sample: {train_dataset[0]}")


Training on 1000 examples
Sample: {'Open-ended Verifiable Question': 'An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?', 'Ground-True Answer': 'Gastric ulcer'}


In [None]:
def format_prompt(example):
    question = example.get("question", example.get("Open-ended Verifiable Question", ""))
    answer = example.get("answer", example.get("Ground-True Answer", ""))
    prompt = f"<|user|>\n{question}\n<|assistant|>\nThe answer is: {answer}"
    return {"text": prompt}

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length"
    )

train_dataset = train_dataset.map(format_prompt)
tokenized_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)


## 5. Training


In [None]:
training_args = TrainingArguments(
    output_dir="./llama3_medical_lora",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_steps=100,
    warmup_steps=50,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

print("Starting training...")
trainer.train()
print("Training complete!")


The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting training...


| Step | Training Loss |
|------|---------------|
| 10   | 2.7181        |
| 20   | 2.2867        |
| 30   | 2.0166        |
| 40   | 1.8655        |
| 50   | 1.8469        |
| 60   | 1.7877        |
| 70   | 1.7538        |
| 80   | 1.6989        |
| 90   | 1.7213        |
| 100  | 1.6753        |


Training complete!
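
As a rough sanity check on the training configuration, the effective batch size and approximate number of optimizer steps can be worked out from the `TrainingArguments` values. This is only a sketch: it assumes a single data-parallel device (`num_devices` is a hypothetical variable, not something set in this notebook), and the exact step count that `Trainer` reports can differ slightly depending on how the library version rounds partial accumulation windows.

```python
import math

# Values taken from the cells above
num_examples = 1000
per_device_batch = 4
grad_accum = 4
epochs = 3
warmup_steps = 50

# Assumption: one data-parallel device
num_devices = 1

# Examples consumed per optimizer update
effective_batch = per_device_batch * grad_accum * num_devices

# Approximate optimizer updates per epoch and overall
steps_per_epoch = math.ceil(num_examples / effective_batch)
total_steps = steps_per_epoch * epochs

# Share of training spent in learning-rate warmup
warmup_fraction = warmup_steps / total_steps

print(f"Effective batch size: {effective_batch}")
print(f"Approx. steps per epoch: {steps_per_epoch}")
print(f"Approx. total steps: {total_steps}")
print(f"Warmup fraction: {warmup_fraction:.0%}")
```

With these settings, warmup covers roughly a quarter of training, which is worth keeping in mind when reading the early part of the loss curve.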


In [None]:
model.save_pretrained("./llama3_medical_lora")
tokenizer.save_pretrained("./llama3_medical_lora")
print("Model saved to: ./llama3_medical_lora")


Model saved to: ./llama3_medical_lora


## 6. Evaluation on Test Set


In [None]:
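# Rows 1000-1019 were not seen during training (training used rows 0-999)
test_dataset = dataset["train"].select(range(1000, 1020))

# Switch to inference mode so LoRA dropout is disabled during generation
model.eval()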

def get_prediction(question):
    prompt = f"<|user|>\n{question}\n<|assistant|>\nThe answer is:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        # temperature and top_p only take effect when sampling is enabled
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.3, top_p=0.9)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("The answer is:")[-1].strip()

correct = 0
for i, example in enumerate(test_dataset):
    question = example.get("Open-ended Verifiable Question", example.get("question", ""))
    ground_truth = example.get("Ground-True Answer", example.get("answer", ""))
    prediction = get_prediction(question)
    is_correct = ground_truth.lower() in prediction.lower()
    if is_correct:
        correct += 1
    print(f"{i+1}. {'✓' if is_correct else '✗'} Pred: {prediction[:50]}... | Truth: {ground_truth[:30]}")

print(f"\nAccuracy: {correct}/{len(test_dataset)} ({100*correct/len(test_dataset):.1f}%)")


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.

1. ✗ Pred: Acute rheumatic fever.... | Truth: Endocarditis.
2. ✗ Pred: Squamous cells with keratinization.... | Truth: Hyperplastic cells
3. ✗ Pred: Placenta previa with uterine rupture.... | Truth: Accidental hemorrhage
4. ✓ Pred: Chronic lymphocytic leukemia (CLL) with splenomega... | Truth: Chronic lymphocytic leukemia
5. ✗ Pred: Gallstones, Gallbladder inflammation, Gallbladder ... | Truth: Cholesterosis, Adenomyomatosis
6. ✓ Pred: Trastuzumab (Herceptin) <|... | Truth: Trastuzumab
7. ✗ Pred: CD3 and CD4 markers are positive, but CD4 is negat... | Truth: BCL-6
8. ✗ Pred: Inspiratory component of the respiratory control p... | Truth: Pre-BOTC
9. ✗ Pred: Esophageal stricture, Pharyngoesophageal fistula, ... | Truth: Pharyngeal diverticulum and di
10. ✗ Pred: Genetic predisposition to PBC.... | Truth: Improved quality of care for P
11. ✗ Pred: 2 years old.... | Truth: 15 months
12. ✗ Pred: Hypothyroidism, Hypogonadism, Hypogonadism, Hypoth... | Truth: Cushing's syndrome, insulinoma
13. ✓ Pred: Rheumatoid factor (RF) and anti-citrullinated prot... | Truth: Anti-CCP
14. ✗ Pred: 50-100 mN/m2.... | Truth: Capillary blood pressure
15. ✗ Pred: Alcohol, cocaine, and nicotine.... | Truth: Heroin, Ketamine, LSD
16. ✗ Pred: Shigella dysenteriae 1 and 2, and Salmonella Typhi... | Truth: Clostridium perfringens
17. ✗ Pred: Chronic alcohol use leading to iron deficiency ane... | Truth: toxic marrow suppression
18. ✗ Pred: Chronic interstitial cystitis.... | Truth: Pelvic lipomatosis
19. ✗ Pred: Bed rest and monitoring of fetal well-being.... | Truth: Resuscitation and observation
20. ✗ Pred: Deposits of IgA in the glomeruli and mesangium.... | Truth: Mesangial deposits of IgA

Accuracy: 3/20 (15.0%)
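
The check used above counts a prediction as correct only when the ground-truth string appears verbatim inside it, so trailing punctuation or a different word order is enough to cause a miss. In item 20, for instance, the prediction and the reference share most of their words yet are scored as incorrect. One possible complement, sketched below, is a token-level recall score: the fraction of ground-truth words that also occur in the prediction. The helper names `normalize` and `token_recall` are hypothetical and not part of the notebook, and this metric is a swapped-in alternative rather than the evaluation the notebook actually ran. It measures word overlap only and says nothing about medical correctness.

```python
import string


def normalize(text: str) -> list[str]:
    """Lowercase, drop ASCII punctuation, and split on whitespace."""
    lowered = text.lower()
    stripped = lowered.translate(str.maketrans("", "", string.punctuation))
    return stripped.split()


def token_recall(prediction: str, ground_truth: str) -> float:
    """Fraction of ground-truth tokens that also appear in the prediction."""
    truth_tokens = normalize(ground_truth)
    if not truth_tokens:
        return 0.0
    pred_tokens = set(normalize(prediction))
    hits = sum(1 for token in truth_tokens if token in pred_tokens)
    return hits / len(truth_tokens)


# Pairs taken from the evaluation output above
examples = [
    ("Acute rheumatic fever.", "Endocarditis."),
    ("Trastuzumab (Herceptin)", "Trastuzumab"),
    ("Deposits of IgA in the glomeruli and mesangium.", "Mesangial deposits of IgA"),
]

for prediction, truth in examples:
    strict = truth.lower() in prediction.lower()
    print(f"strict={strict!s:<5} recall={token_recall(prediction, truth):.2f} | {truth}")
```

A threshold on the recall score (for example, counting anything at or above some cutoff as a match) would be a design choice to validate against manually graded answers before reporting it as accuracy.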
