### Fine Tuning Llama 3.2 on Medical Dataset

In [24]:
import warnings
warnings.filterwarnings("ignore")

#### Setup & Loading Model

In [25]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)

from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType

from pathlib import Path
import random
import string
import json
import time
import re


In [None]:
# Check if a CUDA-capable GPU is available
if torch.cuda.is_available():
    # Use GPU for computations if available
    device = torch.device("cuda")
else:
    # Fall back to CPU if no GPU is available
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


Load the `meta-llama/Llama-3.2-1B-Instruct` checkpoint from the Hugging Face Hub and initialize both `AutoTokenizer` and `AutoModelForCausalLM` with this model ID.

In [27]:
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

In [None]:
# Load the tokenizer associated with the pretrained model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set the padding token to be the same as the end-of-sequence token
# (common practice for causal language models)
tokenizer.pad_token = tokenizer.eos_token

# Specify that padding should be added on the right side of sequences
tokenizer.padding_side = "right"

# Load the pretrained causal language model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16,        # Use half precision for reduced memory usage and faster computation
    device_map={"": device},    # Map the entire model to the selected device (CPU or GPU)
)

print(f"Model loaded: {model_name}")

Model loaded: HuggingFaceTB/SmolLM2-1.7B-Instruct


#### LoRA Configuration

In [None]:
print("\nConfiguring LoRA...")

# LoRA hyperparameters (easy to tweak or pass as script arguments)
LORA_RANK = 16
LORA_ALPHA = 32        # typically = 2 * r or = r
LORA_DROPOUT = 0.05    # set to 0.0 for small datasets to avoid underfitting

# Target modules for a LLaMA/SmolLM-like architecture
# Covers both attention layers and MLP layers
TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
    "gate_proj", "up_proj", "down_proj",      # MLP projections
]

# Define the LoRA configuration
lora_config = LoraConfig(
    r=LORA_RANK,                    # Rank of the low-rank adapters
    lora_alpha=LORA_ALPHA,          # Scaling factor for LoRA updates
    lora_dropout=LORA_DROPOUT,      # Dropout applied to LoRA layers
    bias="none",                    # Do not train bias parameters
    task_type=TaskType.CAUSAL_LM,   # Task type: causal language modeling
    target_modules=TARGET_MODULES,  # Modules where LoRA adapters are injected
)

# Wrap the base model with LoRA adapters
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()


Configuring LoRA...
trainable params: 18,087,936 || all params: 1,729,464,320 || trainable%: 1.0459


#### Loading Dataset

Load the dataset, apply the formatting function to it, and restrict the resulting data to 500 formatted examples for training.

In [None]:
def format_prompt_medmcqa(example):
    """
    Format a MedMCQA example into a chat-style prompt
    compatible with SmolLM2-360M-Instruct.
    """

    # Extract MedMCQA fields safely (fallback to empty strings)
    question = example.get("question", "")
    opa = example.get("opa", "")
    opb = example.get("opb", "")
    opc = example.get("opc", "")
    opd = example.get("opd", "")
    cop = example.get("cop", "")  # Correct option: 'a','b','c','d' or 0–3 (ClassLabel)

    # Map answer label/index to the actual answer text
    option_map = {
        "a": opa,
        "b": opb,
        "c": opc,
        "d": opd,
        0: opa,
        1: opb,
        2: opc,
        3: opd,
    }
    answer = option_map.get(cop, "")

    # Basic filtering to remove invalid or low-quality examples
    if not question or len(question) < 10:
        return None
    if not answer or len(answer) < 2:
        return None

    # Format the multiple-choice question with all options
    mcq_text = (
        f"{question}\n"
        f"A. {opa}\n"
        f"B. {opb}\n"
        f"C. {opc}\n"
        f"D. {opd}"
    )

    # SmolLM2 chat-style prompt template
    text = (
        "<|im_start|>system\n"
        "You are a helpful AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{mcq_text}<|im_end|>\n"
        "<|im_start|>assistant\n"
        f"The answer is: {answer}<|im_end|>"
    )

    # Return formatted text in a dict (compatible with HF datasets mapping)
    return {"text": text}

In [None]:
# 1. Load the MedMCQA dataset
# The dataset contains train / validation / test splits
dataset = load_dataset("openlifescienceai/medmcqa")
raw_train = dataset["train"]

# 2. Apply prompt formatting to each training example
# This converts raw MedMCQA samples into chat-style prompts
formatted = raw_train.map(
    format_prompt_medmcqa,
    remove_columns=raw_train.column_names,  # Keep only the formatted text
)

# 3. Filter out invalid examples (where formatting returned None)
formatted = formatted.filter(lambda x: x["text"] is not None)

# 4. Keep only a subset of examples (e.g., for quick experiments or debugging)
train_dataset = formatted.select(range(min(500, len(formatted))))

print(train_dataset)

Dataset({
    features: ['text'],
    num_rows: 500
})


#### Training & Fine Tuning

In [None]:
def tokenize_function(examples):
    """
    Tokenize formatted text prompts for causal language model training.
    The labels are set to be the same as the input IDs (standard LM objective).
    """

    # Tokenize the input text
    tokenized = tokenizer(
        examples["text"],          # List of formatted prompts
        padding="max_length",      # Pad all sequences to the same length
        truncation=True,           # Truncate sequences longer than max_length
        max_length=512,            # Maximum sequence length
        return_tensors="pt"        # Return PyTorch tensors
    )

    # For causal language modeling, labels are identical to input_ids
    tokenized["labels"] = tokenized["input_ids"].clone()

    return tokenized

In [None]:
# Apply tokenization to the training dataset
# batched=True means tokenize_function receives a batch (list) of texts
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],  # Remove raw text, keep only model inputs
)

# Ensure outputs are returned as PyTorch tensors when accessed
tokenized_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

# Print dataset structure and metadata
print(tokenized_train)

# Print the shape of each tensor for the first example (sanity check)
print({k: v.shape for k, v in tokenized_train[0].items()})

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 500
})
{'input_ids': torch.Size([512]), 'attention_mask': torch.Size([512]), 'labels': torch.Size([512])}


In [None]:
training_args = TrainingArguments(
    output_dir="./results",              # Directory to save model checkpoints and outputs
    num_train_epochs=3,                  # Number of full training epochs
    per_device_train_batch_size=1,       # Batch size per GPU/CPU
    gradient_accumulation_steps=4,       # Accumulate gradients to simulate a larger batch size
    learning_rate=2e-4,                  # Initial learning rate
    warmup_steps=10,                     # Number of warmup steps for the learning rate scheduler

    logging_steps=10,                    # Log training metrics every N steps
    save_steps=100,                      # Save a checkpoint every N steps
    save_total_limit=2,                  # Keep only the last 2 checkpoints to save disk space

    fp16=False,                          # Disable mixed precision (enable if training on GPU with FP16 support)
    logging_dir="./logs",                # Directory for TensorBoard logs
    report_to="none"                     # Disable external experiment trackers (e.g., WandB)
)

In [None]:
# Data collator for causal language modeling
# Handles dynamic padding and batch preparation during training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False               # Disable masked language modeling (use causal LM objective)
)

# Initialize the Hugging Face Trainer
trainer = Trainer(
    model=model,            # LoRA-adapted language model
    args=training_args,     # Training configuration
    train_dataset=tokenized_train,  # Tokenized training dataset
    data_collator=data_collator,     # Data collator for batching
    tokenizer=tokenizer,    # Tokenizer (used for padding and decoding)
)

# Start the training process
trainer.train()

Step,Training Loss
10,2.2218
20,1.5711
30,1.4275
40,1.3838
50,1.3508
60,1.4126
70,1.3963
80,1.316
90,1.2965
100,1.226


TrainOutput(global_step=375, training_loss=1.1862285041809082, metrics={'train_runtime': 584.0308, 'train_samples_per_second': 2.568, 'train_steps_per_second': 0.642, 'total_flos': 7505515118592000.0, 'train_loss': 1.1862285041809082, 'epoch': 3.0})

In [36]:
model.save_pretrained("./llama3_medical_lora")
tokenizer.save_pretrained("./llama3_medical_lora")

('./llama3_medical_lora/tokenizer_config.json',
 './llama3_medical_lora/special_tokens_map.json',
 './llama3_medical_lora/chat_template.jinja',
 './llama3_medical_lora/vocab.json',
 './llama3_medical_lora/merges.txt',
 './llama3_medical_lora/added_tokens.json',
 './llama3_medical_lora/tokenizer.json')

#### Evaluation & Performance

In [None]:
train = dataset["train"]

# Parameters
train_cutoff = 1000
n_eval_examples = 20
seed = 42

# Explicitly define dataset splits
# Training set: indices [0, train_cutoff - 1]
# Test set: indices [train_cutoff, end]
train_set = train.select(range(0, min(train_cutoff, len(train))))
test_set = train.select(range(train_cutoff, len(train)))

# Reproducibly sample examples from the test set
rng = random.Random(seed)
selected_indices = rng.sample(range(len(test_set)), n_eval_examples)

# Create a small evaluation subset
eval_subset = test_set.select(selected_indices)

# Print dataset statistics and sanity checks
print(f"Total dataset size: {len(train)}")
print(f"Training set indices: 0 to {len(train_set) - 1}")
print(f"Test set indices: {train_cutoff} to {len(train) - 1}")
print(f"Selected {len(selected_indices)} test examples")
print(f"First 5 sampled indices in test split: {selected_indices[:5]}")

Total dataset size: 182822
Training set indices: 0 to 999
Test set indices: 1000 to 182821
Selected 20 test examples
First 5 sampled indices in test split: [167621, 29184, 6556, 72097, 64196]


In [None]:
def build_medmcqa_prompt(example):
    """
    Construct a simple multiple-choice question (MCQ) text
    from a MedMCQA dataset example.
    """

    # Extract question and answer options from the example
    question = example["question"]
    opa = example["opa"]
    opb = example["opb"]
    opc = example["opc"]
    opd = example["opd"]

    # Format the MCQ as a string with options labeled A-D
    mcq_text = (
        f"{question}\n"
        f"A. {opa}\n"
        f"B. {opb}\n"
        f"C. {opc}\n"
        f"D. {opd}"
    )

    return mcq_text

In [None]:
def get_prediction_from_example(example, max_tokens=50):
    """
    Generate a model prediction for a single MedMCQA example.

    Args:
        example (dict): A dataset example containing 'question' and options.
        max_tokens (int): Maximum number of tokens to generate.

    Returns:
        str: The model's generated answer (cleaned of special tokens).
    """

    # 1. Build the multiple-choice question (MCQ) text for the "user"
    user_content = build_medmcqa_prompt(example)

    # 2. Wrap the MCQ in the same chat-style template used for fine-tuning
    prompt = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI medical assistant.\n"
        "<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_content}"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Tokenize the prompt and move tensors to the correct device (CPU/GPU)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate predictions without computing gradients (inference mode)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,  # Limit output length
            temperature=0.3,            # Low temperature for more deterministic output
            top_p=0.9,                  # Nucleus sampling for diversity
            do_sample=True,             # Enable sampling
            pad_token_id=tokenizer.eos_token_id,  # Ensure proper padding
        )

    # Decode the generated token IDs back to text (include special tokens)
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # 3. Extract only the newly generated part (remove the prompt)
    generated = full_text[len(prompt):].strip()

    # Optional: clean up any leftover special tokens
    for tok in ["<|eot_id|>", "<|end_of_text|>"]:
        generated = generated.split(tok)[0].strip()

    return generated

In [None]:
# Set of common stop words to ignore when computing partial similarity
STOP_WORDS = {
    "the", "a", "an", "is", "are", "was", "were", "of", "in", "to", "for",
    "with", "on", "at", "and", "or", "by", "from"
}

def _normalize(text: str) -> str:
    """
    Normalize text for comparison:
    - Convert to lowercase
    - Remove punctuation
    - Collapse multiple spaces and strip leading/trailing spaces
    """
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Normalize spaces
    text = re.sub(r"\s+", " ", text).strip()
    return text

def check_accuracy(prediction: str, ground_truth: str, partial_threshold: float = 0.7):
    """
    Compare a model prediction against the ground truth.

    Returns:
        (bool, str): Tuple indicating if it's a match and the match type:
                     - "exact" for exact match
                     - "partial" for token-level similarity above threshold
                     - "no_match" if no significant overlap
    """

    # Normalize both prediction and ground truth
    pred_norm = _normalize(prediction)
    truth_norm = _normalize(ground_truth)

    # 1. Exact match (after normalization)
    if truth_norm and truth_norm in pred_norm:
        return True, "exact"

    # 2. Partial similarity on "informative" tokens (exclude stop words)
    truth_tokens = [w for w in truth_norm.split() if w not in STOP_WORDS]
    pred_tokens = [w for w in pred_norm.split() if w not in STOP_WORDS]

    # If either has no informative tokens, consider it a non-match
    if not truth_tokens or not pred_tokens:
        return False, "no_match"

    # Convert tokens to sets for Jaccard similarity
    truth_set = set(truth_tokens)
    pred_set = set(pred_tokens)

    # Compute Jaccard similarity: intersection / union
    intersection = len(truth_set & pred_set)
    union = len(truth_set | pred_set)

    if union == 0:
        return False, "no_match"

    jaccard = intersection / union

    # Consider it a partial match if similarity exceeds threshold
    if jaccard >= partial_threshold:
        return True, "partial"

    # Otherwise, no match
    return False, "no_match"

In [None]:
results = []
correct_exact = 0
correct_partial = 0
start_time = time.time()

n_total = len(selected_indices)  # Total number of evaluation examples

# Iterate over the sampled test examples
for i, idx in enumerate(selected_indices, 1):
    example = test_set[idx]

    # Extract MedMCQA fields
    question = example["question"]
    opa = example["opa"]
    opb = example["opb"]
    opc = example["opc"]
    opd = example["opd"]
    cop = example["cop"]  # Correct option: 'a'/'b'/'c'/'d'

    # Map the correct option to its text
    option_map = {"a": opa, "b": opb, "c": opc, "d": opd, 0: opa, 1: opb, 2: opc, 3: opd}
    ground_truth = option_map.get(cop, "")

    # Construct the MCQ text seen by the user (matching fine-tuning format)
    user_question = (
        f"{question}\n"
        f"A. {opa}\n"
        f"B. {opb}\n"
        f"C. {opc}\n"
        f"D. {opd}"
    )

    # Print example info for debugging / tracking
    print(f"\nTEST {i}/{n_total}")
    print(f"Question: {question[:100]}...")
    print(f"Options: A) {opa[:30]} | B) {opb[:30]} | C) {opc[:30]} | D) {opd[:30]}")
    print(f"Ground Truth ({cop}): {ground_truth}")

    # Call the model to get a prediction (chat-style prompt)
    prediction = get_prediction_from_example(example)
    print(f"Prediction: {prediction[:100]}...")

    # Evaluate the prediction against the ground truth
    correct, match_type = check_accuracy(prediction, ground_truth)

    # Update counters for exact or partial matches
    if correct:
        if match_type == "exact":
            correct_exact += 1
        else:
            correct_partial += 1
        print(f"CORRECT ({match_type})")
    else:
        print("INCORRECT")

    # Store the results for later analysis
    results.append({
        "id": example.get("id"),
        "question": question,
        "opa": opa,
        "opb": opb,
        "opc": opc,
        "opd": opd,
        "cop": cop,
        "ground_truth": ground_truth,
        "prediction": prediction,
        "correct": correct,
        "match_type": match_type,
    })

    # Print running accuracy after each example
    accuracy = (correct_exact + correct_partial) / i * 100
    print(f"Running accuracy: {accuracy:.1f}% "
          f"({correct_exact + correct_partial}/{i} ; exact={correct_exact}, partial={correct_partial})")

# Total evaluation time
total_time = time.time() - start_time
print(f"\nFinished {n_total} examples in {total_time:.1f} seconds.")
print(f"Final accuracy: {((correct_exact + correct_partial) / n_total) * 100:.1f}% "
      f"(exact={correct_exact}, partial={correct_partial})")


TEST 1/20
Question: Which of the following is found in the respiratory zone of the lung?...
Options: A) Goblet cells | B) Main bronchi | C) Mucous cells | D) Type I epithelial cells
Ground Truth (3): Type I epithelial cells
Prediction: 1. The answer is: Mucous cells...
INCORRECT
Running accuracy: 0.0% (0/1 ; exact=0, partial=0)

TEST 2/20
Question: Which of the following does not occur in starvation?...
Options: A) Hypoglycemia | B) Hypercholesterolemia | C) Lipolyiss | D) Ketoacidosis
Ground Truth (1): Hypercholesterolemia
Prediction: 1. The answer is: Ketoacidosis...
INCORRECT
Running accuracy: 0.0% (0/2 ; exact=0, partial=0)

TEST 3/20
Question: A 20 month old female child is brought for routine check-up. Complete blood count (CBC) shows modera...
Options: A) Corticosteroid administration | B) Multivitamin administration | C) Watch and wait strategy | D) Antibiotics to prevent infecti
Ground Truth (2): Watch and wait strategy
Prediction: 1. The answer is: Watch and wait strategy...

In [None]:
n_total = len(selected_indices)
n_correct = correct_exact + correct_partial
n_incorrect = n_total - n_correct

# Compute overall accuracy as a percentage
accuracy = (n_correct / n_total) * 100 if n_total > 0 else 0.0

# Print evaluation summary
print("\n=== Evaluation summary ===")
print(f"Total: {n_correct}/{n_total} correct ({accuracy:.1f}%)")
print(f"  Exact matches   : {correct_exact}")
print(f"  Partial matches : {correct_partial}")
print(f"  Incorrect       : {n_incorrect}")

# Print timing statistics if available
if total_time > 0:
    print(f"\nTotal time   : {total_time:.1f}s")
    print(f"Per example  : {total_time / n_total:.2f}s/example")
else:
    print("\nTotal time   : 0.0s")


=== Evaluation summary ===
Total: 6/20 correct (30.0%)
  Exact matches   : 6
  Partial matches : 0
  Incorrect       : 14

Total time   : 66.8s
Per example  : 3.34s/example


In [None]:
n_total = len(selected_indices)
n_correct = correct_exact + correct_partial
n_incorrect = n_total - n_correct

# Prepare a dictionary summarizing the evaluation results
results_summary = {
    "n_total": n_total,                                                 # Total examples evaluated
    "accuracy": accuracy,                                               # Overall accuracy (%)
    "exact_matches": correct_exact,                                     # Number of exact matches
    "partial_matches": correct_partial,                                 # Number of partial matches
    "incorrect": n_incorrect,                                           # Number of incorrect predictions
    "total_time": total_time,                                           # Total evaluation time in seconds
    "time_per_example": total_time / n_total if n_total > 0 else None,  # Average time per example
    "selected_indices": list(map(int, selected_indices)),     # List of evaluated indices
    "results": results                                                  # List of individual example results
}

# Define the path for saving the evaluation results
output_path = Path("evaluation_results.json")

# Save the results summary as a JSON file
with output_path.open("w", encoding="utf-8") as f:
    json.dump(results_summary, f, indent=2, ensure_ascii=False)

# Confirm that results were saved
print(f"Results saved to: {output_path.resolve()}")

Results saved to: /content/evaluation_results.json


#### Model Improvement Strategies

1. Based on your evaluation results, propose at least 2 or 3 specific strategies to improve your model's accuracy. For each strategy, explain what you would change, why it helps, and potential trade-offs.

Accuracy is low (30% on 20 questions), so the main levers are better supervision, better decoding, and better use of the base model’s capacity. Below are three concrete, defensible strategies.

**Improve prompt formatting & decoding**

- The evaluation shows many predictions that include extra text, question restatements, or numbered lists instead of a clean option or answer string, which then fail the matcher despite being semantically right in some cases.

- Change:
    - Constrain generation with a system/user instruction like “Answer with exactly one option (A, B, C, or D) and nothing else”.

    - At evaluation time, post-process the model output to extract only the option letter or the answer span before comparing to the ground truth.

- Why it helps: Reduces “formatting” errors such as answering “1. The answer is: Varicella” when the expected output is just “Varicella” or the option index.

- Trade-offs:
    - Slightly more engineering in the evaluation script and prompting.
    - If the post‑processing is brittle, it may mis‑parse some valid free‑form answers.

**Increase & balance the fine‑tuning data**

- The model was fine‑tuned only on 500 training examples, which is tiny compared with the difficulty and variety of MedMCQA questions seen in evaluation.

- Change:
    - Use many more MedMCQA training samples (e.g., the full train split or at least several thousand examples).

    - Optionally upsample under‑represented specialties or question types where errors are concentrated (e.g., oncology, ophthalmology, ENT), which appear repeatedly among wrong predictions.

- Why it helps: More coverage of patterns and medical subdomains reduces random guessing and improves generalization across rare entities and treatments.

- Trade‑offs:
    - Higher compute and training time, possibly requiring more aggressive LoRA or lower max_length to fit in memory.
    
    - Risk of overfitting if training too long without validation monitoring.

**Strengthen the supervision signal**

- Currently, the model is trained with a standard causal LM objective over the whole prompt plus answer, which means it also learns to copy the question and template tokens rather than focusing purely on predicting the correct option.

- Change:
    - Mask the loss so that it is applied only on the answer segment (e.g., the “The answer is: …” part).

    - Alternatively, reformat data so the target is just the option letter or short answer, and treat this as a short‑sequence generation problem.

- Why it helps: Directly optimizes the part of the output that matters for accuracy, instead of wasting capacity on reconstructing the prompt.

- Trade‑offs:
    - Requires modifying the data collator or label construction logic.

    - The model becomes specialized for this answer style and may be less fluent for longer generative explanations without additional fine‑tuning.


2. Review your incorrect predictions and identify patterns in failures. What can you tell about the model errors ?

The errors are mostly systematic rather than random guessing, with several recurring patterns in how the model fails.

**Misclassification across plausible options**

- In many cases the model chooses a medically plausible but wrong option (e.g., picking “Varicella” instead of “HIV” for minimal teratogenic risk, or “Endometrial carcinoma” instead of “Renal cell carcinoma” for polycythemia).

- This suggests limited factual grounding or confusion between related risk factors and tumor paraneoplastic associations rather than purely random outputs.

**Confusion in “except / not” style questions**

- For questions framed as “all are true except” or “which is not used,” the model often selects a true statement instead of the requested exception (e.g., xenon anesthesia question, agents not used in diabetic macular edema).

- This indicates the model struggles with negation and exception reasoning, a common weakness in LMs, especially under few‑shot or weak supervision.

**Domain gaps & rare facts**

- Several errors are on niche topics (specific lymphoma epidemiology, named surgical procedures, unusual drug uses) where the correct answer is a relatively rare fact, and the model picks a more common entity instead.

- This pattern suggests insufficient exposure during fine‑tuning (only 500 examples) and reliance on prior general‑domain knowledge that is not specialized enough for fine‑grained medical trivia.

3. What do you think it's better between training on 2000 examples (same quality) or 500 curated high-quality examples ?

With the same label quality and format, 2000 reasonably good examples are typically better for this task than 500 curated ones, given your current low data regime and broad domain coverage needs.

**Why 2000 > 500 here**

- The model currently sees only 500 MedMCQA questions while being evaluated on diverse, fine‑grained medical facts (tumors, infections, anesthesia, etc.), and many errors look like domain‑coverage gaps rather than noise from bad labels.

- Adding more examples (up to 2000) improves coverage of specialties and question templates, which is crucial for a general medical MCQ model, and LoRA can handle that scale easily on our setup.

**When 500 curated could win**

- If “curated” means carefully chosen hard/representative items that target known weaknesses (negation, “except” questions, rare entities) and the alternative 2000 set contains substantial label noise or badly formatted prompts, the smaller curated set can yield better generalization per example.​

- However, our notebook already filters out obviously low‑quality items (short questions, missing answers), and there is no strong evidence of heavy label noise, so in this specific setup the larger 2000‑example training set is the safer choice.

#### Resource-Constrained Inference

4. How can you design a strategy to reduce inference time/memory for deployment in constrained environments ?

A good strategy is to combine parameter‑efficient fine‑tuning (LoRA), quantization, and hardware‑aware deployment so that only a small, compressed adapter‑augmented model is served.

**Use LoRA & freeze the base**

- Keep the current setup where only ~1% of parameters are trainable adapters and the base SmolLM2 model is frozen.

- At deployment, we load the pretrained base weights once and merge or attach only the small LoRA layers, which reduces the size of the fine‑tuned delta we need to ship and store.

**Apply post‑training quantization**

- Quantize the model to 8‑bit or 4‑bit weights (e.g., with bitsandbytes or similar libraries) so that parameters and activations use fewer bytes, lowering both memory footprint and bandwidth.

- This typically yields large memory and latency savings with modest accuracy loss, especially on relatively small models like 1.7B parameters.

**Optimize decoding & batching**

- Constrain generation for MCQs (e.g., short max_length, greedy or low‑beam search) since you only need one option, which cuts per‑request compute.

- If the use case allows, batch multiple questions per forward pass and reuse the same loaded model instance, improving throughput without increasing model size.

5. Analyze how changing generation parameters affects speed, quality, and consistency.

Generation parameters trade off speed, output quality, and consistency in predictable ways for your MCQ setting.

**Max length & early stopping**

- Lower `max_new_tokens` (or `max_length`) speeds up inference almost linearly, because fewer tokens are generated, but can truncate explanations or long answers if set too low.

- In your MCQ task, a small limit (e.g., 10–20 tokens) is usually enough and improves latency without hurting accuracy, since only a short option/phrase is needed.

**Decoding strategy (greedy vs sampling vs beam)**

- Greedy decoding (no sampling, `top_p`/`temperature` off) is fastest and most consistent because each step picks the argmax token, but may be slightly less robust if the model’s next‑token distribution is poorly calibrated.

- Sampling with higher `temperature` or `top_p` increases diversity and can improve answer quality in open‑ended tasks, but here it mainly adds randomness and reduces answer consistency across runs.

- Beam search can improve exact‑match quality by exploring multiple candidate sequences but increases latency roughly proportional to beam size and is often unnecessary when the desired output is a single short label.

**Temperature, top‑k & top‑p**

- Lower `temperature` (e.g., 0.1–0.3) and small `top_p`/`top_k` make outputs more deterministic and regular, improving consistency across repeated calls but possibly locking in systematic biases or mistakes.​

- Higher `temperature` or larger `top_p`/`top_k` allow exploration, which can occasionally find the correct option when the greedy path is wrong, but at the cost of unstable answers and slower average decoding if more tokens are sampled before reaching an end token.



#### Evaluation Methodology

6. Analyze limitations of current exact/partial match evaluation and propose improvements. Do you think you have false negatives or false positives ? What can we do about it ?

The current evaluation is prone mainly to false negatives (correct or acceptable answers marked wrong) because it relies on brittle string matching of fairly free‑form generations.

**Limitations & likely false negatives**

- Many outputs contain the right answer embedded in extra text or formatting, such as “1. The answer is: Varicella” when the ground truth is “Varicella”; these are counted as incorrect because the evaluation expects an exact string or tightly constrained form.

- Some predictions restate or slightly rephrase the option text (e.g., “lymphocyte‑predominant HD” vs “lymphocyte‑predominant Hodgkin disease (HD)”), which are semantically equivalent but scored as no‑match due to surface differences.

**Possible false positives**

- There are fewer indications of false positives, but they can occur if the evaluation only checks for the presence of the correct string anywhere in the output; in that case, a long explanation that mentions multiple options could be mis‑scored as correct even if the final choice is wrong.

- For example, a model might list all options or discuss them and incidentally include the correct phrase without clearly selecting it, which a naive substring check might accept.

**Proposed improvements**

- Constrain generation & labels: Force the model to answer with a single letter (A/B/C/D) or a fixed pattern like “Answer: A”, and store ground truth as the option index; evaluation then compares normalized letters instead of full strings.

- Robust post‑processing: Strip prefixes like “1.”, “The answer is:”, and surrounding punctuation; extract the final option letter or answer span using regex/heuristics before matching.

- Semantic or rule‑based matching: Normalize case, remove parentheses/abbreviations, and allow small edit distances so that “lymphocyte‑predominant HD” matches the longer canonical form.

- Human spot‑check: For a small sample, manually inspect “incorrect” and “partial” cases to quantify how often the model is actually right but mis‑scored, then refine the matching rules accordingly.


7. Test other test size and observe the result. What can you say about the results ? How can you improve it ?

With the current setup (20‑question sample, 30% accuracy), the test size is very small and the metric is noisy; increasing or varying test size will mainly stabilize the estimate and reveal variance across question subsets.

**What changing test size shows**

- On only 20 items, each individual question shifts accuracy by 5 percentage points, so small changes in sampled indices can give quite different reported performance, especially with many borderline cases and formatting‑sensitive scoring.

- Using larger test sizes (e.g., 50, 100, 200 questions) will smooth out this sampling noise and give a more reliable picture of true accuracy, and may show that performance varies by topic (e.g., pediatrics vs oncology vs ENT).

**How to improve the evaluation ?**

- Increase `n_total` in our evaluation script to test on more randomly sampled questions (or the full validation/test split) so that metrics are statistically more stable.

- Stratify the test set by specialty or question type so you can see systematic weaknesses instead of a single aggregate accuracy, then focus fine‑tuning and prompt design on those weak categories.

#### Real-World Deployment Scenario

8. What can you do to address safety, reliability, updates, and edge cases for deploying in a medical assistance application ?

For deployment as a medical assistance tool, you would need a multi‑layer approach: strict scope limits, guardrails around the model, continuous validation, and clear user‑facing safety mechanisms.

**Safety & scope**

- Explicitly restrict the model’s scope to educational/decision‑support use, with prompts and system instructions that forbid giving definitive diagnoses, prescriptions, or emergency instructions; always include a disclaimer and encourage consulting a clinician.

- Add rule‑based and learned filters that block or heavily qualify outputs in sensitive areas (e.g., dosing, oncologic treatment plans, pregnancy‑related advice, pediatrics) and instead respond with safe guidance like “consult a specialist.”

**Reliability & evaluation**

- Before deployment, run large‑scale evaluation on curated medical benchmarks and internal test sets, including manual review by domain experts, especially for high‑risk topics where our 30% accuracy shows current unreliability.​

- In production, implement monitoring: log queries and sampled outputs (with proper privacy protections), track error reports and adverse feedback, and regularly re‑evaluate performance to detect regressions over time.

**Updates & versioning**

- Maintain a clear model versioning scheme (base model version, fine‑tune date, dataset snapshot) so that any change to guidelines or medical knowledge results in a new, traceable release.

- Periodically update training data with new guidelines, drug warnings, and protocols, and run regression tests to ensure that improvements in one area do not degrade performance elsewhere.

**Handling edge cases**

- Detect out‑of‑scope or ambiguous queries (e.g., very rare diseases, incomplete clinical context, medico‑legal questions) via confidence estimates or auxiliary classifiers, and respond by asking for clarification or deferring to human clinicians rather than improvising.

- For critical domains (e.g., emergency medicine, intensive care), design workflows that require human approval: the model can propose differential diagnoses or questions to ask, but a clinician must review and confirm any actionable recommendation.