# Disclaimer: This notebook was run on a T4 GPU in Google Colab, and you need a Hugging Face access token

# Fine-tune Llama 3.2 on a Medical Dataset with Hugging Face and PEFT

In this notebook, we fine-tune Llama 3.2 1B Instruct on a medical Q&A dataset using LoRA for parameter-efficient fine-tuning. We cover loading the model, configuring LoRA, preparing the data, training, and evaluation.

## Part 1: Setup and Model Loading

In [6]:
!pip install torch transformers datasets peft huggingface_hub ipywidgets -q

In [7]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
import json
from huggingface_hub import login
login()

In [8]:
!hf auth whoami

user: Othocs
orgs: cvmistralparis


We use CUDA because the notebook runs on a T4 GPU in Colab.

In [9]:
device = torch.device("cuda")

In [10]:
model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16,
    device_map={"": device},
)

print(f"Model loaded: {model_name}")

Model loaded: meta-llama/Llama-3.2-1B-Instruct


## Part 2: LoRA Configuration

LoRA (Low-Rank Adaptation) allows efficient fine-tuning by training only small adapter matrices instead of the full model weights.

Parameters:
- r=16: rank of the LoRA matrices
- lora_alpha=32: scaling factor
- target_modules: attention and MLP projection layers
- lora_dropout=0.05: dropout for regularization

In [11]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 11,272,192 || all params: 1,247,086,592 || trainable%: 0.9039
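
The trainable-parameter count printed above can be reproduced by hand. Each adapted linear layer with `d_in` inputs and `d_out` outputs gains two matrices, A (r x d_in) and B (d_out x r), so it adds r * (d_in + d_out) parameters. The sketch below assumes the published Llama 3.2 1B dimensions (hidden size 2048, MLP size 8192, 16 layers, 8 key/value heads of dimension 64); these values are not shown in this notebook and should be checked against `model.config`.

```python
# Assumed Llama 3.2 1B dimensions (verify against model.config).
HIDDEN = 2048        # hidden_size
INTERMEDIATE = 8192  # intermediate_size of the MLP
KV_DIM = 8 * 64      # num_key_value_heads * head_dim
NUM_LAYERS = 16      # num_hidden_layers
R = 16               # LoRA rank from lora_config

# (input features, output features) of each targeted projection
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in, d_out, r=R):
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

per_layer = sum(lora_params(d_in, d_out) for d_in, d_out in shapes.values())
total = per_layer * NUM_LAYERS
print(f"{total:,}")  # 11,272,192
```

Under these assumptions the result matches the `trainable params` figure reported by `print_trainable_parameters()`.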


## Part 3: Dataset Preparation

Load the medical Q&A dataset and format it for training.

In [12]:
def format_prompt(example):
    question = example["Open-ended Verifiable Question"]
    answer = example["Ground-True Answer"]
    text = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{answer}<|eot_id|>"
    return {"text": text}

In [13]:
dataset = load_dataset("FreedomIntelligence/medical-o1-verifiable-problem")
print(f"Dataset loaded: {len(dataset['train'])} examples")

Dataset loaded: 40644 examples


In [14]:
print(dataset["train"].column_names)
print(dataset["train"][0])

['Open-ended Verifiable Question', 'Ground-True Answer']
{'Open-ended Verifiable Question': 'An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?', 'Ground-True Answer': 'Gastric ulcer'}


In [15]:
train_dataset = dataset["train"].select(range(500))
train_dataset = train_dataset.map(format_prompt)
print(f"Training on {len(train_dataset)} examples")

Training on 500 examples


In [16]:
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512, padding="max_length")

train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
train_dataset.set_format("torch")
print(f"Dataset tokenized: {len(train_dataset)} examples")

Dataset tokenized: 500 examples


## Part 4: Training Setup

Configure the data collator, training arguments, and trainer.

In [17]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
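
One interaction worth checking: `tokenizer.pad_token` was set to the EOS token earlier, and with `mlm=False` the collator builds labels by copying `input_ids` and replacing every position equal to `pad_token_id` with -100. Assuming the EOS token is also the `<|eot_id|>` that closes each answer, the real end-of-turn token would be excluded from the loss along with the padding. The sketch below imitates that masking rule in plain Python with made-up token ids; it is not the collator's actual code.

```python
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def mask_pad_labels(input_ids, pad_token_id):
    """Copy input_ids to labels, ignoring every position that equals the pad id."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

EOT = 9  # hypothetical id shared by the end-of-turn token and the pad token
# Three content tokens, one real end-of-turn token, then two padding tokens.
sequence = [1, 42, 43, EOT, EOT, EOT]
labels = mask_pad_labels(sequence, pad_token_id=EOT)
print(labels)  # [1, 42, 43, -100, -100, -100]
```

The first `EOT` is the genuine end of the answer, yet it is masked like the padding. If this applies to your run, one option is to use a dedicated padding token, or to build the labels from the attention mask instead.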

In [18]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_steps=10,
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    fp16=False,
    logging_dir="./logs",
    report_to="none",
)
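
With these settings, the effective batch size and the number of optimizer steps follow from simple arithmetic. A minimal sketch, assuming a single GPU and the 500-example training subset prepared above:

```python
import math

num_examples = 500     # size of the training subset
per_device_batch = 1   # per_device_train_batch_size
grad_accum = 4         # gradient_accumulation_steps
epochs = 3             # num_train_epochs
num_devices = 1        # assumption: a single T4

# Examples consumed per optimizer update.
effective_batch = per_device_batch * grad_accum * num_devices
# Optimizer updates per pass over the data.
steps_per_epoch = math.ceil(num_examples / effective_batch)
total_steps = steps_per_epoch * epochs

print(effective_batch, steps_per_epoch, total_steps)  # 4 125 375
```

Under these assumptions, `warmup_steps=10` covers a little under 3% of the run.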

In [19]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

## Part 5: Training and Saving

In [20]:
print("Starting training...")
trainer.train()
print("Training complete!")

Starting training...


| Step | Training Loss |
|------|---------------|
| 10   | 2.9893 |
| 20   | 2.2771 |
| 30   | 2.306  |
| 40   | 2.0234 |
| 50   | 1.9142 |
| 60   | 1.9331 |
| 70   | 1.8456 |
| 80   | 1.9208 |
| 90   | 1.8087 |
| 100  | 1.8646 |


Training complete!


In [21]:
model.save_pretrained("./llama3_medical_lora")
tokenizer.save_pretrained("./llama3_medical_lora")

('./llama3_medical_lora/tokenizer_config.json',
 './llama3_medical_lora/special_tokens_map.json',
 './llama3_medical_lora/chat_template.jinja',
 './llama3_medical_lora/tokenizer.json')

## Part 6: Evaluation

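Test the fine-tuned model on examples that were not used for training. They are drawn from the train split at indices 1000 and above; training used only the first 500 examples.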

In [22]:
import random
import time

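random.seed(42)
# Training used only indices 0-499. Starting the held-out pool at index 1000
# leaves indices 500-999 unused, so evaluation never overlaps with training data.
test_set = dataset["train"].select(range(1000, len(dataset["train"])))
# These indices are positions within test_set, not within the full dataset.
selected_indices = random.sample(range(len(test_set)), 20)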

print(f"Total dataset size: {len(dataset['train'])}")
print(f"Training set: 0 to 1000")
print(f"Test set: 1000 to {len(dataset['train'])}")
print(f"Selected {len(selected_indices)} test examples")
print(f"Indices: {selected_indices[:5]}... (showing first 5)")

Total dataset size: 40644
Training set: 0 to 1000
Test set: 1000 to 40644
Selected 20 test examples
Indices: [7296, 1639, 18024, 16049, 14628]... (showing first 5)


In [23]:
def get_prediction(question, max_tokens=50):
    prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

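    # Decode only the newly generated tokens; splitting the full decoded text on
    # "assistant" would break if the answer itself contained that word.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return answer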

In [24]:
def check_accuracy(prediction, ground_truth):
    pred_lower = prediction.lower()
    truth_lower = ground_truth.lower()

    if truth_lower in pred_lower:
        return True, "exact"

    stop_words = {"the", "a", "an", "is", "are", "was", "were", "of", "in", "to", "for", "with", "on", "at"}
    truth_words = [w for w in truth_lower.split() if w not in stop_words]

    if not truth_words:
        return False, "no_match"

    matches = sum(1 for w in truth_words if w in pred_lower)
    match_ratio = matches / len(truth_words)

    if match_ratio >= 0.7:
        return True, "partial"

    return False, "no_match"
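
Note that `w in pred_lower` is a substring test on the whole prediction string, so a truth word can match inside a longer, unrelated word, and truth words that carry punctuation (for example `vein,`) never match. Below is a hedged variant that compares whole word tokens instead; `tokenize` and `partial_match` are illustrative names, not part of the notebook.

```python
import re

STOP_WORDS = {"the", "a", "an", "is", "are", "was", "were",
              "of", "in", "to", "for", "with", "on", "at"}

def tokenize(text):
    """Lowercase and split into alphanumeric word tokens, dropping punctuation."""
    return re.findall(r"[a-z0-9]+", text.lower())

def partial_match(prediction, ground_truth, threshold=0.7):
    """True if enough non-stop-word truth tokens appear as whole words in the prediction."""
    pred_tokens = set(tokenize(prediction))
    truth_tokens = [w for w in tokenize(ground_truth) if w not in STOP_WORDS]
    if not truth_tokens:
        return False
    hits = sum(1 for w in truth_tokens if w in pred_tokens)
    return hits / len(truth_tokens) >= threshold

# "cold" is a substring of "scolded" but not a word in the prediction.
print(partial_match("The patient was scolded", "Cold"))  # False
# Punctuation after "vein" no longer blocks the match (5 of 6 truth words found).
print(partial_match(
    "Thrombosis of subclavian vein with pulmonary embolism",
    "Thrombosis of the subclavian vein, causing a pulmonary embolism",
))  # True
```

The 0.7 threshold is carried over from `check_accuracy`; whether it is the right cut-off is a separate question.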

In [25]:
results = []
correct_exact = 0
correct_partial = 0
start_time = time.time()

for i, idx in enumerate(selected_indices, 1):
    example = test_set[idx]
    question = example["Open-ended Verifiable Question"]
    ground_truth = example["Ground-True Answer"]

    print(f"\nTEST {i}/20")
    print(f"Question: {question[:100]}...")
    print(f"Ground Truth: {ground_truth}")

    prediction = get_prediction(question)
    print(f"Prediction: {prediction[:100]}...")

    correct, match_type = check_accuracy(prediction, ground_truth)

    if correct:
        if match_type == "exact":
            correct_exact += 1
        else:
            correct_partial += 1
        print("CORRECT")
    else:
        print("INCORRECT")

    results.append({
        "question": question,
        "ground_truth": ground_truth,
        "prediction": prediction,
        "correct": correct,
        "match_type": match_type
    })

    accuracy = (correct_exact + correct_partial) / i * 100
    print(f"Running accuracy: {accuracy:.1f}% ({correct_exact + correct_partial}/{i})")

total_time = time.time() - start_time


TEST 1/20
Question: After a 60-year-old man underwent a successful orthotopic liver transplantation, the transplanted li...
Ground Truth: Reactive oxygen species
Prediction: Free oxygen radicals produced by the ischemic hepatocytes....
INCORRECT
Running accuracy: 0.0% (0/1)

TEST 2/20
Question: In a 37-year-old female patient with a fractured clavicle where the junction of the inner and middle...
Ground Truth: Thrombosis of the subclavian vein, causing a pulmonary embolism
Prediction: Neurological injury of the brachial plexi due to fracture of the clavicle....
INCORRECT
Running accuracy: 0.0% (0/2)

TEST 3/20
Question: In which condition does the antagonism of histamine by H1 antihistaminics not afford any benefit?...
Ground Truth: Common cold
Prediction: Asthma in children....
INCORRECT
Running accuracy: 0.0% (0/3)

TEST 4/20
Question: A 74-year-old man has a 1.5-centimeter, faintly erythematous, raised lesion with irregular borders o...
Ground Truth: Irreversible nuclear changes in

In [26]:
accuracy = (correct_exact + correct_partial) / 20 * 100
print(f"Total: {correct_exact + correct_partial}/20 correct ({accuracy:.1f}%)")
print(f"Exact matches: {correct_exact}")
print(f"Partial matches: {correct_partial}")
print(f"Incorrect: {20 - correct_exact - correct_partial}")
print(f"Total time: {total_time:.1f}s ({total_time/20:.1f}s per example)")

Total: 2/20 correct (10.0%)
Exact matches: 1
Partial matches: 1
Incorrect: 18
Total time: 14.2s (0.7s per example)


In [27]:
results_summary = {
    "accuracy": accuracy,
    "exact_matches": correct_exact,
    "partial_matches": correct_partial,
    "incorrect": 20 - correct_exact - correct_partial,
    "total_time": total_time,
    "selected_indices": selected_indices,
    "results": results
}

with open("evaluation_results.json", "w") as f:
    json.dump(results_summary, f, indent=2)

print("Results saved to: evaluation_results.json")

Results saved to: evaluation_results.json


## Part A: Model Improvement Strategies

### Question 1: Improving Model Performance

Based on your evaluation results, propose at least 2 or 3 specific strategies to improve your model's accuracy. For each strategy, explain what you would change, why it helps, and potential trade-offs.

To improve model performance, we could train on more data, increasing from 500 to 2000+ examples so the model sees more patterns, at the cost of longer training time. We could also train for more epochs (5-10 instead of 3), although with only 500 examples this increases the risk of overfitting. A third option is to add an instruction to the prompt, such as "Answer concisely in 1-3 words", so that responses match the short format of the expected answers; the trade-off is that some ground-truth answers are longer than three words and would be cut short.

### Question 2: Analyzing Failure Patterns

Review your incorrect predictions and identify patterns in failures. What can you tell about the model errors?

Looking at the failures, the model often answers with a longer descriptive phrase instead of the concise expected term. For example, it outputs "Islet cell tumor of the pancreas" when "Insulinoma" is expected. It also gives closely related wording that the string match rejects, such as "Free oxygen radicals" instead of "Reactive oxygen species". In other cases the answer is simply wrong, such as "Asthma in children" when "Common cold" is expected.

### Question 3: Data Quality vs. Quantity

What do you think is better between training on 2000 examples (same quality) or 500 curated high-quality examples?

For medical Q&A, 500 curated high-quality examples are likely better, because the medical domain requires precision over breadth and noisy data could teach wrong associations. However, 2000 diverse examples might generalize better if quality is maintained throughout.

## Part B: Resource-Constrained Inference

### Question 4: Optimizing for Limited Resources

How can you design a strategy to reduce inference time/memory for deployment in constrained environments?

To reduce inference time and memory in constrained environments, we can quantize the model to 8-bit or 4-bit to shrink its memory footprint. We can also reduce max_new_tokens from 50 to 20, since the expected answers are short. A smaller LoRA rank (r=8 instead of 16) yields smaller adapter weights, though it requires retraining. Merging the LoRA weights into the base model removes the adapter overhead at inference time. Additionally, torch.compile() can speed up inference.
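
Merging works because the adapter is a purely additive low-rank update: the adapted layer computes W x + (alpha / r) * B A x, which equals (W + (alpha / r) * B A) x. Below is a small pure-Python sketch with toy matrices; the values and helper names are made up for illustration. In PEFT, the corresponding step is `merge_and_unload()` on the adapted model.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b, scale=1.0):
    """Return a + scale * b, elementwise."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

r, alpha = 1, 2                  # toy rank and scaling factor
W = [[1.0, 0.0], [0.0, 1.0]]     # frozen base weight (2 x 2)
B = [[1.0], [2.0]]               # LoRA matrix B (2 x r)
A = [[0.5, -1.0]]                # LoRA matrix A (r x 2)
x = [[3.0], [4.0]]               # input as a column vector

scale = alpha / r
# Adapter path: base output plus the scaled low-rank correction (two extra matmuls).
y_adapter = add(matmul(W, x), matmul(B, matmul(A, x)), scale)
# Merged path: fold the correction into the weight once, then a single matmul.
W_merged = add(W, matmul(B, A), scale)
y_merged = matmul(W_merged, x)

print(y_adapter)  # [[-2.0], [-6.0]]
print(y_merged)   # [[-2.0], [-6.0]]
```

Both paths give the same output, but the merged weight needs no extra matrix products per forward pass.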

### Question 5: Speed vs. Accuracy Trade-offs

Analyze how changing generation parameters affects speed, quality, and consistency.

Lower temperature (0.1) gives more deterministic, consistent outputs but less diversity. Higher temperature (0.7) produces more varied responses but increases the risk of hallucination. Temperature and top_p do not meaningfully change the cost per token; inference time is driven mainly by the number of generated tokens, so reducing max_new_tokens speeds up inference but may truncate longer answers. Using do_sample=False (greedy decoding) is fully deterministic and skips the sampling step, but it may get stuck in repetitive patterns. Our current settings (temperature=0.3, top_p=0.9) are a reasonable compromise between consistency and diversity.
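
To see why temperature changes determinism but not the per-token cost, it helps to look at what it does to the next-token distribution: the logits are divided by the temperature before the softmax. A minimal sketch with made-up logits:

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature, with max-subtraction for numerical stability."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical scores for three candidate tokens
for t in (0.1, 0.3, 0.7, 1.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

As the temperature drops, probability mass concentrates on the highest-scoring token, so sampling approaches greedy decoding; the amount of computation per token stays the same.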

## Part C: Evaluation Methodology

### Question 6: Improving Evaluation Metrics

Analyze limitations of current exact/partial match evaluation and propose improvements. Do you think you have false negatives or false positives? What can we do about it?

The current evaluation has several limitations. It produces false negatives: "Islet cell tumor of the pancreas" is rejected because it does not contain the exact term "Insulinoma", and "Free oxygen radicals" is rejected against "Reactive oxygen species" even though the two are closely related. It can also produce false positives: the 70% partial-match threshold could accept wrong answers that merely share common medical terms with the ground truth, and the "exact" check is a substring test, so it accepts any prediction that contains the ground truth anywhere in its text.

To improve this, we could use semantic similarity with embeddings instead of string matching. We could also use an LLM as a judge to decide whether the prediction is medically equivalent to the ground truth. Another approach is synonym matching using medical ontologies such as UMLS or SNOMED.
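
As a rough illustration of the synonym-matching idea, the sketch below maps surface forms to a canonical concept before comparing. The alias table is a tiny hand-written stand-in; a real system would draw these mappings from an ontology such as UMLS or SNOMED, and the entries and function names here are assumptions for demonstration only.

```python
import re

# Hypothetical alias table: surface form -> canonical concept.
ALIASES = {
    "free oxygen radicals": "reactive oxygen species",
    "oxygen free radicals": "reactive oxygen species",
    "heart attack": "myocardial infarction",
}

def normalize(text):
    """Lowercase, drop punctuation, and collapse whitespace."""
    return " ".join(re.findall(r"[a-z0-9]+", text.lower()))

def contains_phrase(haystack, phrase):
    """Whole-word phrase test on already-normalized strings."""
    return f" {phrase} " in f" {haystack} "

def canonical_concepts(text):
    """Return the set of canonical concepts mentioned in the text."""
    norm = normalize(text)
    found = {c for alias, c in ALIASES.items() if contains_phrase(norm, alias)}
    found |= {c for c in set(ALIASES.values()) if contains_phrase(norm, c)}
    return found

def synonym_match(prediction, ground_truth):
    """True if both texts mention a shared concept or the truth appears verbatim."""
    shared = canonical_concepts(prediction) & canonical_concepts(ground_truth)
    verbatim = contains_phrase(normalize(prediction), normalize(ground_truth))
    return bool(shared) or verbatim

print(synonym_match("Free oxygen radicals produced by the ischemic hepatocytes.",
                    "Reactive oxygen species"))  # True
print(synonym_match("Asthma in children.", "Common cold"))  # False
```

A lookup like this only helps for pairs that appear in the table, so it complements rather than replaces embedding-based or judge-based scoring.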

### Question 7: Test Set Size and Confidence

Test other test set sizes and observe the result. What can you say about the results? How can you improve it?

With 20 samples the estimate has high variance: a single answer flipping changes accuracy by 5 percentage points. Our 10% accuracy (2/20) has a wide 95% confidence interval, roughly 3% to 30% using a Wilson score interval. Using 50 samples would narrow the interval, and 100 or more would narrow it further at the cost of longer evaluation time. For a more reliable assessment, we recommend at least 50-100 samples.
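
The width of the interval can be estimated with a Wilson score interval for a binomial proportion. The sketch below applies it to the observed 2/20 and, purely as hypothetical comparisons, to the same 10% rate at larger sample sizes; those larger runs were not performed in this notebook.

```python
import math

def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 for about 95%)."""
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return center - half, center + half

# 2/20 is the observed result; 5/50 and 10/100 are hypothetical runs at the same rate.
for correct, n in [(2, 20), (5, 50), (10, 100)]:
    low, high = wilson_interval(correct, n)
    print(f"{correct}/{n}: {low:.1%} to {high:.1%}")
```

For 2 correct out of 20 this gives roughly 3% to 30%, and the interval tightens as the sample grows.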

## Part D: Real-World Deployment Scenario

### Question 8: Production Considerations

What can you do to address safety, reliability, updates, and edge cases for deploying in a medical assistance application?

For safety, we should add confidence scores and refuse to answer when uncertain. We should include disclaimers stating "This is not medical advice, consult a healthcare professional". High-risk queries involving drug interactions or emergencies should be flagged for human review.

For reliability, we need input validation to reject malformed or adversarial inputs. The model should fall back to "I don't know" rather than hallucinating. We should also monitor for distribution shift in incoming queries.

For updates, we need version control for model weights and training data. A/B testing should be performed before deploying new models, and we should maintain rollback capability to previous versions.

For edge cases, we should handle multilingual queries and medical abbreviations. The model should be tested on rare diseases and unusual presentations. We need to clearly define the scope of what the model should and should not answer.