# MedMCQA Robustness Study - Google Colab Demo

This notebook demonstrates running the Medical MCQ Robustness experiments on Google Colab.

**Requirements:**
- Google Colab with a GPU runtime (T4 for the 4B model, A100 80GB for the 27B model)
- Hugging Face account with access to MedGemma models

**Note:** MedGemma requires accepting the license at https://huggingface.co/google/medgemma-4b-it

## 1. Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Install dependencies
!pip install -q torch transformers datasets bitsandbytes accelerate huggingface_hub

In [None]:
# Login to Hugging Face (required for MedGemma access)
from huggingface_hub import login
login()  # This will prompt for your HF token

## 2. Load MedGemma Model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pick the checkpoint that matches the GPU memory of this runtime:
# - T4 (16GB): 4B model
# - A100 (40GB): 27B fits only with 4-bit quantization, which is unreliable (see NOTE below)
# - A100 (80GB): 27B in bfloat16 without quantization

MODEL_ID = "google/medgemma-4b-it"  # Use 4B for T4 GPU
# MODEL_ID = "google/medgemma-27b-text-it"  # Use 27B for A100 80GB

print(f"Loading {MODEL_ID}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE: For the 27B model, do NOT use 4-bit quantization - it produces NaN logits!
# The un-quantized bfloat16 load below is the one to use for both checkpoints;
# for 27B it needs an A100 80GB.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()  # .eval() returns the module itself, switched to evaluation mode

print(f"Model loaded on {model.device}")

## 3. Test Model with a Simple Query

In [None]:
def generate_response(prompt, max_new_tokens=50):
    """Generate a greedy-decoded reply to a single user prompt.

    The prompt is wrapped as one user turn with the tokenizer's chat template,
    fed to the notebook-level ``model``, and only the newly generated tokens
    (everything after the prompt) are decoded.

    Args:
        prompt: Text of the user message.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation with special tokens stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template has already rendered the special tokens it needs;
    # letting the tokenizer add its own again would duplicate them
    # (e.g. a second BOS) and can degrade generation quality.
    inputs = tokenizer(
        chat_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding keeps the evaluation deterministic
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs['input_ids'].shape[1]
    new_tokens = outputs[0, prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Test query
response = generate_response("What is the mechanism of action of aspirin?")
print(f"Response: {response}")

## 4. Load MedMCQA Dataset

In [None]:
from datasets import load_dataset

# Fetch the validation split of MedMCQA
dataset = load_dataset("openlifescienceai/medmcqa", split="validation")
print(f"Loaded {len(dataset)} questions")

# Preview the first record; `cop` is used as an index into the letters A-D
sample = dataset[0]
print("\nSample question:")
print(f"Question: {sample['question']}")
print("\n".join(
    f"{letter}: {sample[field]}"
    for letter, field in zip("ABCD", ("opa", "opb", "opc", "opd"))
))
print(f"Correct: {'ABCD'[sample['cop']]}")

## 5. Run MCQ Evaluation

In [None]:
import re
from tqdm import tqdm

# Per-style (instruction line, answer cue) pairs used by format_mcq_prompt.
_MCQ_PROMPT_PARTS = {
    "zero_shot": (
        "Answer the following medical question by selecting the correct option.",
        "Answer with just the letter (A, B, C, or D):",
    ),
    "zero_shot_cot": (
        "Answer the following medical question. Think step by step.",
        "Let's think step by step, then provide the answer as a single letter (A, B, C, or D):",
    ),
}


def format_mcq_prompt(item, style="zero_shot"):
    """Format a MedMCQA item into a prompt.

    Args:
        item: Mapping with the keys ``question``, ``opa``, ``opb``, ``opc``
            and ``opd``.
        style: Prompt style, either ``"zero_shot"`` (answer with the letter
            only) or ``"zero_shot_cot"`` (reason step by step first).

    Returns:
        The prompt text: instruction line, question, the four options
        labelled ``A.``-``D.`` on separate lines, and the answer cue.

    Raises:
        ValueError: If ``style`` is not one of the supported styles.
    """
    try:
        instruction, answer_cue = _MCQ_PROMPT_PARTS[style]
    except KeyError:
        # Previously an unknown style fell through both branches and failed
        # with an opaque UnboundLocalError on `prompt`.
        supported = ", ".join(sorted(_MCQ_PROMPT_PARTS))
        raise ValueError(
            f"Unknown prompt style {style!r}; expected one of: {supported}"
        ) from None

    options = "\n".join(
        f"{letter}. {item[key]}"
        for letter, key in zip("ABCD", ("opa", "opb", "opc", "opd"))
    )
    return (
        f"{instruction}\n\n"
        f"Question: {item['question']}\n\n"
        f"{options}\n\n"
        f"{answer_cue}"
    )

# Response that opens with an option letter set off by a delimiter or line end:
# "B", "b.", "(C) text", "**D**", "A: ...". Either case is accepted here because
# the delimiter already rules out ordinary words.
_LEADING_LETTER = re.compile(r'[\s(\[*]*([A-Da-d])(?:\s*[.)\]:*]|[ \t]*(?:\n|$))')
# Explicit statement such as "The answer is C", "Answer: (B)", "answer is **D**".
# Only the word "answer"/"is" is case-insensitive; the letter must be upper case
# so that the article in "the answer is a beta blocker" is not read as option A.
_STATED_ANSWER = re.compile(r'(?i:\banswer\s*(?:is)?)[\s:\-(\[*]*([A-D])\b')
# Weaker fallbacks, tried in order: a standalone capital followed by "." or ")",
# then any standalone capital A-D.
_FALLBACK_PATTERNS = (
    re.compile(r'\b([A-D])\s*[.)]'),
    re.compile(r'\b([A-D])\b'),
)


def parse_answer(response):
    """Extract the answer letter from model response.

    Patterns are tried from most to least specific: a leading option letter,
    an explicit "answer is X" statement, a standalone capital followed by
    ``.``/``)``, and finally any standalone capital A-D. Letters must stand
    alone, so words that merely start with or contain A-D ("Aspirin",
    "Based on", "cold.") are no longer mistaken for an answer.

    Args:
        response: Raw text generated by the model.

    Returns:
        ``"A"``, ``"B"``, ``"C"`` or ``"D"``, or ``"unknown"`` when no option
        letter can be identified.
    """
    text = response.strip()

    leading = _LEADING_LETTER.match(text)
    if leading:
        return leading.group(1).upper()

    stated = _STATED_ANSWER.search(text)
    if stated:
        return stated.group(1)

    for pattern in _FALLBACK_PATTERNS:
        found = pattern.search(text)
        if found:
            return found.group(1)

    return "unknown"

def evaluate_mcq(dataset, num_samples=20, style="zero_shot"):
    """Evaluate the model on the first ``num_samples`` items of an MCQ dataset.

    Each item is formatted with ``format_mcq_prompt``, answered with
    ``generate_response`` and scored by comparing ``parse_answer``'s letter
    with the letter derived from the item's ``cop`` index.

    Args:
        dataset: Dataset supporting ``len()`` and ``select()``, whose items
            carry ``question``, ``opa``-``opd`` and ``cop``.
        num_samples: Number of leading items to evaluate. Values larger than
            the dataset are clamped to its length.
        style: Prompt style forwarded to ``format_mcq_prompt``.

    Returns:
        Tuple ``(accuracy, results)``: the fraction of correct answers (0 when
        nothing was evaluated) and a list of per-question dicts with the keys
        ``question`` (first 50 characters plus ``...``), ``predicted``,
        ``expected`` and ``correct``.
    """
    # Clamp so that asking for more rows than exist does not make select() fail.
    sample_count = max(0, min(num_samples, len(dataset)))
    # Step-by-step answers need more room than a bare letter.
    token_budget = 100 if style == "zero_shot_cot" else 20

    results = []
    for item in tqdm(dataset.select(range(sample_count))):
        prompt = format_mcq_prompt(item, style)
        response = generate_response(prompt, max_new_tokens=token_budget)
        predicted = parse_answer(response)
        expected = ['A', 'B', 'C', 'D'][item['cop']]

        results.append({
            'question': item['question'][:50] + '...',
            'predicted': predicted,
            'expected': expected,
            'correct': predicted == expected
        })

    correct = sum(1 for row in results if row['correct'])
    accuracy = correct / len(results) if results else 0
    return accuracy, results

# Run evaluation on a small sample
print("Evaluating on 20 samples...")
accuracy, results = evaluate_mcq(dataset, num_samples=20, style="zero_shot")
print(f"\nAccuracy: {accuracy:.1%}")

In [None]:
# Show detailed results
import pandas as pd

# One row per evaluated question, built from the `results` list returned by evaluate_mcq
df = pd.DataFrame(results)
# `display` is not imported in this cell; it is expected to come from the notebook (IPython) runtime
display(df)

## 6. Experiment: Option Order Sensitivity

In [None]:
import random

def shuffle_options(item, seed=42):
    """Return a copy of ``item`` with its four options in a seeded random order.

    The permutation is drawn from a private ``random.Random(seed)``, so the
    result is reproducible for a given seed and the process-wide random state
    is left untouched. ``cop`` in the returned item is updated to point at the
    new position of the originally correct option.

    Args:
        item: Mapping with the keys ``opa``, ``opb``, ``opc``, ``opd`` and
            ``cop`` (index of the correct option). It is not modified.
        seed: Seed for the permutation.

    Returns:
        A shallow copy of ``item`` with permuted option texts and a matching
        ``cop``.

    Raises:
        IndexError: If ``cop`` is not a valid index into the four options.
    """
    option_keys = ('opa', 'opb', 'opc', 'opd')

    # range(4)[...] validates the index and normalises a negative one,
    # mirroring plain list indexing.
    correct_index = range(len(option_keys))[item['cop']]

    # order[new_position] == original_position
    order = list(range(len(option_keys)))
    random.Random(seed).shuffle(order)

    new_item = item.copy()
    for key, source_index in zip(option_keys, order):
        new_item[key] = item[option_keys[source_index]]
    new_item['cop'] = order.index(correct_index)

    return new_item

def _predicted_option_text(item, predicted_letter):
    """Return the option text behind ``predicted_letter``, or None if it is not A-D."""
    if predicted_letter not in ('A', 'B', 'C', 'D'):
        return None
    return item['op' + predicted_letter.lower()]

# Test option order sensitivity
print("Testing option order sensitivity on 10 samples...")

consistency_results = []
for i, item in enumerate(tqdm(dataset.select(range(10)))):
    # Original order
    prompt_orig = format_mcq_prompt(item)
    response_orig = generate_response(prompt_orig, max_new_tokens=20)
    pred_orig = parse_answer(response_orig)

    # Shuffled order (a different, reproducible permutation per question)
    shuffled_item = shuffle_options(item, seed=i)
    prompt_shuffled = format_mcq_prompt(shuffled_item)
    response_shuffled = generate_response(prompt_shuffled, max_new_tokens=20)
    pred_shuffled = parse_answer(response_shuffled)

    expected_orig = ['A', 'B', 'C', 'D'][item['cop']]
    expected_shuffled = ['A', 'B', 'C', 'D'][shuffled_item['cop']]

    orig_correct = pred_orig == expected_orig
    shuffled_correct = pred_shuffled == expected_shuffled

    # Check if predictions are consistent (same content, different letter is OK).
    # Compare the option *text* each prediction points to: comparing the two
    # correctness flags would also count two different wrong answers as
    # consistent. An unparseable prediction is never consistent.
    orig_choice = _predicted_option_text(item, pred_orig)
    shuffled_choice = _predicted_option_text(shuffled_item, pred_shuffled)

    consistency_results.append({
        'original_pred': pred_orig,
        'shuffled_pred': pred_shuffled,
        'original_correct': orig_correct,
        'shuffled_correct': shuffled_correct,
        'consistent': orig_choice is not None and orig_choice == shuffled_choice
    })

df_consistency = pd.DataFrame(consistency_results)
print(f"\nOriginal accuracy: {df_consistency['original_correct'].mean():.1%}")
print(f"Shuffled accuracy: {df_consistency['shuffled_correct'].mean():.1%}")
print(f"Consistency rate: {df_consistency['consistent'].mean():.1%}")
display(df_consistency)

## 7. Load PubMedQA Dataset

In [None]:
# Fetch the "pqa_labeled" configuration of PubMedQA (requested as its "train" split)
pubmedqa = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
print(f"Loaded {len(pubmedqa)} PubMedQA questions")

# Preview the first record: question, gold decision, and the first 200
# characters of its first context passage
sample = pubmedqa[0]
print(
    f"\nQuestion: {sample['question']}",
    f"Answer: {sample['final_decision']}",
    f"Context: {sample['context']['contexts'][0][:200]}...",
    sep="\n",
)

In [None]:
def _parse_pubmedqa_decision(response):
    """Return the first whole-word yes/no/maybe in ``response``, else 'unknown'."""
    import re  # local import keeps this cell independent of earlier import cells

    # Whole-word matching: plain substring tests read "no" out of "unknown" or
    # "not" and "yes" out of "eyes", and always preferred "no" over "maybe".
    match = re.search(r'\b(yes|no|maybe)\b', response.strip().lower())
    return match.group(1) if match else 'unknown'


def evaluate_pubmedqa(dataset, num_samples=10, use_context=True):
    """Evaluate the model on the first ``num_samples`` PubMedQA items.

    Args:
        dataset: Dataset supporting ``len()`` and ``select()``, whose items
            carry ``question``, ``final_decision`` and
            ``context['contexts']`` (a list of passages).
        num_samples: Number of leading items to evaluate. Values larger than
            the dataset are clamped to its length.
        use_context: When true, the joined context passages (truncated to
            1500 characters) are included in the prompt; otherwise the model
            is asked to answer from its own knowledge.

    Returns:
        Tuple ``(accuracy, results)``: the fraction of correct answers (0 when
        nothing was evaluated) and a list of per-question dicts with the keys
        ``question`` (first 50 characters plus ``...``), ``predicted``,
        ``expected`` and ``correct``.
    """
    # Clamp so that asking for more rows than exist does not make select() fail.
    sample_count = max(0, min(num_samples, len(dataset)))
    results = []

    for item in tqdm(dataset.select(range(sample_count))):
        if use_context:
            context = " ".join(item['context']['contexts'])
            preamble = (
                "Based on the following context, answer the research question.\n\n"
                f"Context: {context[:1500]}\n\n"
            )
        else:
            preamble = "Based on your medical knowledge, answer the following research question.\n\n"
        prompt = (
            f"{preamble}"
            f"Question: {item['question']}\n\n"
            "Answer with one word: yes, no, or maybe."
        )

        response = generate_response(prompt, max_new_tokens=20)
        predicted = _parse_pubmedqa_decision(response)
        expected = item['final_decision']

        results.append({
            'question': item['question'][:50] + '...',
            'predicted': predicted,
            'expected': expected,
            'correct': predicted == expected
        })

    correct = sum(1 for row in results if row['correct'])
    # Divide by what was actually evaluated; guards num_samples == 0 as evaluate_mcq does.
    accuracy = correct / len(results) if results else 0
    return accuracy, results

# Test with and without context
print("Evaluating PubMedQA WITH context...")
acc_with_ctx, res_with = evaluate_pubmedqa(pubmedqa, num_samples=10, use_context=True)
print(f"Accuracy with context: {acc_with_ctx:.1%}")

print("\nEvaluating PubMedQA WITHOUT context...")
acc_without_ctx, res_without = evaluate_pubmedqa(pubmedqa, num_samples=10, use_context=False)
print(f"Accuracy without context: {acc_without_ctx:.1%}")

## 8. Summary

This notebook demonstrated:
1. Loading MedGemma models on Colab
2. Evaluating on MedMCQA multiple choice questions
3. Testing option order sensitivity
4. Evaluating on PubMedQA with/without context

**Key Findings:**
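- MedGemma-4B loads and runs on a T4 GPU in bfloat16
- MedGemma-27B requires an A100 80GB without quantization (4-bit quantization produces NaN logits)
- Option shuffling can change the model's predictions; the 10-question check above is illustrative, not a full measurement
- Providing context generally improves PubMedQA performance; the 10-question comparison above is only a quick sanity check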