# 🧬 Reproducing Med-PaLM M: Towards Generalist Biomedical AI

**Paper:** Tu et al., "Towards Generalist Biomedical AI" (arXiv:2307.14334)

**What this notebook does:**
1. Clones the project repo
2. Downloads the VQA-RAD dataset
3. Runs data sanity checks
4. Establishes zero-shot baseline (BLIP-2 without training)
5. Verifies training pipeline (overfit test)
6. Fine-tunes on VQA-RAD and evaluates against paper baselines
7. Runs generalization experiments
8. Generates final comparison table

**Requirements:** GPU runtime (T4 is sufficient)

---

⚠️ **FIRST: Enable GPU** → Runtime → Change runtime type → T4 GPU

## Step 0: Verify GPU and Setup Environment

In [None]:
# Verify GPU is available
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("❌ NO GPU DETECTED: go to Runtime → Change runtime type → T4 GPU")

In [None]:
# Clone the repository
!git clone https://github.com/Mrabbi3/biomed-multimodal-reproduction.git
%cd biomed-multimodal-reproduction

In [None]:
# Install dependencies (version specifiers are quoted so the shell does not treat ">" as a redirect)
!pip install -q "transformers>=4.36.0" "accelerate>=0.25.0" "peft>=0.7.0" \
    bitsandbytes datasets Pillow tqdm pyyaml nltk rouge-score \
    matplotlib seaborn evaluate

---
## Phase 1: Download Data & Sanity Check

In [None]:
# Download VQA-RAD dataset (~50 MB)
!python data/download.py --dataset vqa_rad

In [None]:
# Run sanity checks: verifies data loads, preprocessing works, metrics are correct
!python experiments/01_data_sanity_check.py

In [None]:
# Visualize the sanity check output
from IPython.display import Image, display
import os
if os.path.exists("results/figures/vqa_rad_sanity_check.png"):
    display(Image(filename="results/figures/vqa_rad_sanity_check.png"))
    print("✓ Images and questions look correct!")
else:
    print("Sanity check image not generated: check errors above")

---
## Phase 2: Zero-Shot Baseline (No Training)

We load BLIP-2 and test it on VQA-RAD **without any fine-tuning**.
This plays the role of the paper's PaLM-E 84B baseline, which was also evaluated without fine-tuning (BLEU-1: 59.19%).
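
For orientation, here is a minimal sketch of how BLEU-1 and token-level F1 can be computed for a single predicted answer. The function names and the lowercase, alphanumeric-only tokenization are assumptions made for illustration; the repository's own metric code (and the paper's exact normalization) may differ.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    """Lowercase and keep alphanumeric runs (an assumed normalization)."""
    return re.findall(r"[a-z0-9]+", text.lower())


def bleu_1(prediction: str, reference: str) -> float:
    """Clipped unigram precision multiplied by the BLEU brevity penalty."""
    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred or not ref:
        return 0.0
    overlap = sum((Counter(pred) & Counter(ref)).values())
    precision = overlap / len(pred)
    # The brevity penalty lowers the score of predictions shorter than the reference.
    bp = 1.0 if len(pred) >= len(ref) else math.exp(1 - len(ref) / len(pred))
    return bp * precision


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred or not ref:
        return 0.0
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


print(bleu_1("left lung", "left lower lung"))
print(token_f1("left lung", "left lower lung"))
```

Dataset-level scores would then be the mean over all test questions, multiplied by 100 to match the percentages quoted from the paper.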

In [None]:
# Run forward pass test: establishes zero-shot baseline
# This downloads BLIP-2 (~7GB) on first run
!python experiments/02_forward_pass_test.py --model blip2 --max_samples 50 --quantize

In [None]:
# View the baseline results
import json
if os.path.exists("results/tables/baseline_metrics.json"):
    with open("results/tables/baseline_metrics.json") as f:
        baseline = json.load(f)
    print(f"Zero-Shot Baseline Results:")
    print(f"  BLEU-1: {baseline['bleu_1']:.2f}%")
    print(f"  F1:     {baseline['f1']:.2f}%")
    print(f"\nPaper comparison:")
    print(f"  PaLM-E 84B (no finetune): BLEU-1=59.19%, F1=38.67%")

---
## Phase 3: Overfit Test (Verify Training Works)

Before real training, we memorize 5 examples to verify:
- LoRA adapters are applied correctly
- Gradients flow through the model
- Loss decreases toward zero

If this fails, there's a bug. If it passes, the pipeline is trustworthy.
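
The logic of an overfit test can be illustrated on a toy problem. The sketch below is not the repository's script and does not touch BLIP-2: it fits a two-parameter linear model to five noiseless examples with plain gradient descent, then applies the same pass criterion (the loss must fall and end close to zero). The function name, learning rate, step count, and threshold are all illustrative assumptions.

```python
def overfit_check(xs, ys, lr=0.5, steps=200):
    """Fit y = w*x + b by gradient descent and return the MSE recorded at each step."""
    w, b = 0.0, 0.0
    n = len(xs)
    losses = []
    for _ in range(steps):
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        losses.append(sum(e * e for e in errors) / n)
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2 * sum(errors) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return losses


# Five examples that the model can represent exactly, so memorization is possible.
xs = [0.0, 0.25, 0.5, 0.75, 1.0]
ys = [2 * x + 1 for x in xs]
losses = overfit_check(xs, ys)

# Pass criterion: the loss decreased and ended near zero (threshold is an assumption).
passed = losses[-1] < losses[0] and losses[-1] < 1e-6
print(f"first loss={losses[0]:.4f}, last loss={losses[-1]:.2e}, passed={passed}")
```

If the loss plateaus well above zero on data this small, the usual suspects are frozen parameters, a detached graph, or a mismatch between labels and inputs.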

In [None]:
# Overfit 5 examples for 50 epochs: the model should memorize them
!python experiments/03_overfit_single_batch.py --num_samples 5 --epochs 50 --quantize

---
## Phase 4: Full Training & Evaluation 🚀

This is the main experiment. We fine-tune BLIP-2 on the VQA-RAD training set
using LoRA, then evaluate on the test set and compare to Med-PaLM M baselines.

**Expected time:** ~20-40 minutes on a T4 GPU for 10 epochs.
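
The training command in this phase combines `--batch_size 4` with `--grad_accum 4` and `--lora_rank 16`. The sketch below works through the arithmetic those flags imply. The helper names, the example count of 1000, and the 1024-wide layer are illustrative assumptions, not values taken from the repository or from BLIP-2; it also assumes that a final partial accumulation window still triggers an optimizer step.

```python
import math


def effective_batch_size(batch_size: int, grad_accum: int) -> int:
    """Number of examples that contribute to each optimizer step."""
    return batch_size * grad_accum


def optimizer_steps(num_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Total optimizer updates, assuming a partial final window still steps."""
    micro_batches = math.ceil(num_examples / batch_size)
    return math.ceil(micro_batches / grad_accum) * epochs


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters LoRA adds to one d_out x d_in weight: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


print(effective_batch_size(4, 4))        # 16
print(optimizer_steps(1000, 4, 4, 10))   # 630, for a hypothetical 1000 training examples
print(lora_params(1024, 1024, 16))       # 32768, for a hypothetical 1024 x 1024 layer
print(1024 * 1024)                       # 1048576 parameters in the frozen weight itself
```

In this illustrative case the adapter adds about 3% of the parameters of the layer it wraps, which is why LoRA fits on a single T4 alongside a quantized base model.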

In [None]:
# Full training run
!python experiments/04_train_vqa.py \
    --dataset vqa_rad \
    --epochs 10 \
    --batch_size 4 \
    --lr 5e-5 \
    --lora_rank 16 \
    --grad_accum 4 \
    --quantize \
    --use_exemplar

In [None]:
# View training log
if os.path.exists("results/logs/training_log.json"):
    with open("results/logs/training_log.json") as f:
        log = json.load(f)
    
    import matplotlib.pyplot as plt
    
    epochs = [e["epoch"] for e in log]
    train_loss = [e["train_loss"] for e in log]
    
    plt.figure(figsize=(8, 4))
    plt.plot(epochs, train_loss, 'b-o', label='Train Loss')
    
    if "val_loss" in log[0]:
        val_loss = [e["val_loss"] for e in log]
        plt.plot(epochs, val_loss, 'r-o', label='Val Loss')
    
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('results/figures/training_curve.png', dpi=150)
    plt.show()
    print(f"Final train loss: {train_loss[-1]:.4f}")

In [None]:
# View evaluation results vs paper baselines
if os.path.exists("results/tables/vqa_rad_metrics.json"):
    with open("results/tables/vqa_rad_metrics.json") as f:
        results = json.load(f)
    
    print("=" * 60)
    print("FINE-TUNED RESULTS vs PAPER BASELINES (VQA-RAD)")
    print("=" * 60)
    
    comparison = [
        ("Prior SOTA (specialist)",    71.03, None),
        ("PaLM-E 84B (no finetune)",   59.19, 38.67),
        ("Med-PaLM M 12B",             64.02, 50.66),
        ("Med-PaLM M 84B",             69.38, 59.90),
        ("Med-PaLM M 562B",            71.27, 62.06),
        ("Ours (BLIP-2 + LoRA)",       results['bleu_1'], results['f1']),
    ]
    
    print(f"{'Model':<30} {'BLEU-1':>10} {'F1':>10}")
    print("-" * 50)
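    for name, bleu, f1 in comparison:
        # Compare against None so a legitimate 0.0 score is not shown as N/A
        bleu_str = f"{bleu:.2f}%" if bleu is not None else "N/A"
        f1_str = f"{f1:.2f}%" if f1 is not None else "N/A"
        marker = " ← US" if "Ours" in name else ""
        print(f"{name:<30} {bleu_str:>10} {f1_str:>10}{marker}")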
    print("=" * 60)

---
## Phase 5: Generalization Experiments

These experiments test the paper's key claims:
1. **Cross-dataset transfer**: Does fine-tuning on VQA-RAD help on Slake-VQA?
2. **Exemplar ablation**: Does the one-shot prompting trick actually help?
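
To make the exemplar ablation concrete, here is a minimal sketch of how a prompt with and without a one-shot exemplar might be assembled. The instruction wording, the `<img>` placeholder convention, the `build_prompt` helper, and the sample questions are all hypothetical; the repository's actual template may differ.

```python
from typing import Optional, Tuple

# Hypothetical instruction text, written for this illustration only.
INSTRUCTION = (
    "Instructions: You are a helpful radiology assistant. "
    "Answer the question about the image."
)


def build_prompt(question: str, exemplar: Optional[Tuple[str, str]] = None) -> str:
    """Assemble a VQA prompt, optionally prefixed with one text-only exemplar."""
    parts = [INSTRUCTION]
    if exemplar is not None:
        ex_question, ex_answer = exemplar
        # The exemplar carries no real image; "<img>" is a text placeholder.
        parts.append(f"Given <img>. Q: {ex_question} A: {ex_answer}")
    # The query ends with an open "A:" for the model to complete.
    parts.append(f"Given <img>. Q: {question} A:")
    return "\n".join(parts)


with_exemplar = build_prompt(
    "Is there a pleural effusion?", exemplar=("Is the heart enlarged?", "no")
)
without_exemplar = build_prompt("Is there a pleural effusion?")
print(with_exemplar)
```

The ablation then compares metrics for the two prompt variants on the same test questions, so any difference can be attributed to the exemplar alone.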

In [None]:
# Download Slake-VQA for cross-dataset testing
!python data/download.py --dataset slake_vqa

In [None]:
# Run all generalization experiments
!python experiments/05_zero_shot_eval.py --experiment all --max_samples 100 --quantize

In [None]:
# View exemplar ablation results
if os.path.exists("results/tables/exemplar_ablation.json"):
    with open("results/tables/exemplar_ablation.json") as f:
        ablation = json.load(f)
    
    print("ONE-SHOT EXEMPLAR ABLATION")
    print("=" * 50)
    print(f"{'Mode':<25} {'BLEU-1':>10} {'F1':>10}")
    print("-" * 45)
    for mode, m in ablation.items():
        print(f"{mode:<25} {m['bleu_1']:>9.2f}% {m['f1']:>9.2f}%")
    
    diff = ablation['with_exemplar']['bleu_1'] - ablation['without_exemplar']['bleu_1']
    print(f"\nExemplar effect: {diff:+.2f} percentage points BLEU-1")
    if diff > 0:
        print("→ Exemplar HELPS (confirms paper's approach)")
    else:
        print("→ Exemplar did not help (interesting finding for our model)")

---
## Final: Generate Complete Comparison Report

In [None]:
# Generate the unified comparison table and bar charts
!python evaluation/compare_to_paper.py

In [None]:
# Display comparison chart
from IPython.display import Image, display
import glob

for fig_path in glob.glob("results/figures/*_comparison.png"):
    print(f"\n{fig_path}:")
    display(Image(filename=fig_path))

In [None]:
# Show the final markdown comparison table
if os.path.exists("results/tables/full_comparison.md"):
    with open("results/tables/full_comparison.md") as f:
        print(f.read())

---
## 📥 Download Results

Run this cell to zip all results for download.

In [None]:
# Package results for download
!tar -czf /content/reproduction_results.tar.gz results/

from google.colab import files
files.download('/content/reproduction_results.tar.gz')
print("✓ Results downloaded! Add these to your GitHub repo.")

---
## Summary

### What We Reproduced
- Med-PaLM M's medical VQA methodology using open-source models
- Instruction task prompting with one-shot exemplars
- Domain-specific fine-tuning and its impact on performance
- Cross-dataset generalization evaluation

### Key Differences from Original Paper
| Aspect | Med-PaLM M | Our Reproduction |
|--------|-----------|------------------|
| Model | PaLM-E (562B) | BLIP-2 (~3B) |
| Training | Full fine-tuning on TPU pods | LoRA on single GPU |
| Data | 1M+ samples across 14 tasks | ~3.5K-14K VQA samples |
| Compute | Weeks on TPU v4 | ~30 min on T4 GPU |