# 🧬 Reproducing Med-PaLM M: Towards Generalist Biomedical AI

**Paper:** Tu et al., "Towards Generalist Biomedical AI" (arXiv:2307.14334)

⚠️ **IMPORTANT:** Run cells ONE AT A TIME from top to bottom. Wait for each cell to finish (spinner stops) before running the next one.

⚠️ **FIRST:** Enable a GPU via Runtime → Change runtime type → T4 GPU

---

## Step 0: Verify GPU and Setup Environment

In [None]:
# Verify GPU is available
import torch
import os
import json

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {vram:.1f} GB")
    print("\n✅ GPU is ready!")
else:
    print("\n❌ NO GPU: go to Runtime → Change runtime type → T4 GPU")

In [None]:
# Clone the repository
!rm -rf biomed-multimodal-reproduction
!git clone https://github.com/Mrabbi3/biomed-multimodal-reproduction.git
%cd biomed-multimodal-reproduction
print("\n✅ Repository cloned!")

In [None]:
# Install dependencies (takes ~2 minutes)
# Version specifiers are quoted so the shell does not treat ">" as output redirection
!pip install -q "transformers>=4.36.0" "accelerate>=0.25.0" "peft>=0.7.0" \
    bitsandbytes datasets Pillow tqdm pyyaml nltk rouge-score \
    matplotlib seaborn evaluate
print("\n‚úÖ Dependencies installed!")

---
## Phase 1: Download Data & Sanity Check

Downloads VQA-RAD (~50 MB) and verifies everything loads correctly.

In [None]:
# Download VQA-RAD dataset (~50 MB)
!python data/download.py --dataset vqa_rad

In [None]:
# Run sanity checks
!python experiments/01_data_sanity_check.py

In [None]:
# Visualize the sanity check output
import os
from IPython.display import Image, display

fig_path = "results/figures/vqa_rad_sanity_check.png"
if os.path.exists(fig_path):
    display(Image(filename=fig_path))
    print("✅ Images and questions look correct!")
else:
    print("⚠️ Sanity check image not generated: check the errors above")

---
## Phase 2: Zero-Shot Baseline (No Training)

Load BLIP-2 and evaluate it on VQA-RAD **without any fine-tuning**.
The closest point of comparison in the paper is the PaLM-E 84B baseline without fine-tuning (BLEU-1: 59.19%).

⏱️ **~10 minutes** (downloads a ~7 GB model on the first run)
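
Every experiment cell passes `--quantize`, but the notebook never says what that flag does. Assuming it loads the model in 8-bit or 4-bit through `bitsandbytes` (installed above), the back-of-the-envelope sketch below shows why quantization matters on a T4, which has 16 GB of VRAM. The parameter count is a hypothetical value in line with the summary table's "~3B". The estimate covers weights only. Activations, LoRA optimizer state, and CUDA overhead come on top of it. 4-bit formats also store per-block scaling constants, so real usage runs somewhat higher.

```python
# Approximate storage cost per parameter for common weight formats.
BYTES_PER_PARAM = {"fp32": 4.0, "fp16": 2.0, "int8": 1.0, "4bit": 0.5}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """GB (1e9 bytes) needed to hold the weights alone.

    Activations, optimizer state, and CUDA/runtime overhead are NOT included.
    """
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Hypothetical model size (~3B parameters, matching the summary table's estimate).
NUM_PARAMS = 3.0e9
for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>5}: ~{weight_memory_gb(NUM_PARAMS, dtype):.2f} GB")
```

The gap between the fp32 and 4-bit rows is what leaves headroom on a 16 GB card for activations during training.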

In [None]:
# Run forward pass test ‚Äî establishes zero-shot baseline
!python experiments/02_forward_pass_test.py --model blip2 --max_samples 50 --quantize

In [None]:
# View the baseline results
import os
import json

baseline_path = "results/tables/baseline_metrics.json"
if os.path.exists(baseline_path):
    with open(baseline_path) as f:
        baseline = json.load(f)
    print("ZERO-SHOT BASELINE RESULTS")
    print("=" * 50)
    print(f"  BLEU-1: {baseline['bleu_1']:.2f}%")
    print(f"  F1:     {baseline['f1']:.2f}%")
    print(f"\nPaper comparison:")
    print(f"  PaLM-E 84B (no finetune): BLEU-1=59.19%, F1=38.67%")
else:
    print("⚠️ No baseline results yet: run the cell above first")
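
The notebook reports BLEU-1 and F1 but does not define them. The sketch below uses the usual per-answer definitions:

- **BLEU-1:** clipped unigram precision multiplied by the BLEU brevity penalty.
- **Token-level F1:** the harmonic mean of token precision and recall.

It assumes lowercasing and whitespace tokenization. The repository's own evaluation code may normalize text differently (it could rely on `nltk` or `evaluate`, both installed above), so scores need not match exactly. Dataset-level numbers are typically the mean over all answers, multiplied by 100 to give a percentage.

```python
import math
from collections import Counter


def tokenize(text: str) -> list:
    """Lowercase and split on whitespace (a simplifying assumption)."""
    return text.lower().split()


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred, ref = tokenize(prediction), tokenize(reference)
    overlap = sum((Counter(pred) & Counter(ref)).values())  # multiset intersection
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def bleu1(prediction: str, reference: str) -> float:
    """Clipped unigram precision times the BLEU brevity penalty (single reference)."""
    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred:
        return 0.0
    clipped = sum((Counter(pred) & Counter(ref)).values())  # min counts = clipping
    precision = clipped / len(pred)
    # Penalize predictions that are not longer than the reference (1.0 when equal).
    brevity = 1.0 if len(pred) > len(ref) else math.exp(1 - len(ref) / len(pred))
    return brevity * precision


# Hypothetical prediction/reference pairs.
pairs = [("yes", "yes"), ("the left lung", "left lung"), ("lung", "left lung")]
for pred, ref in pairs:
    print(f"{pred!r:>16} vs {ref!r:<12} BLEU-1={bleu1(pred, ref):.3f}  F1={token_f1(pred, ref):.3f}")
```

The third pair shows that short answers get very different BLEU-1 and F1 values. A one-word prediction against a two-word reference is penalized hard by the brevity term but still earns partial F1 credit.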

---
## Phase 3: Overfit Test (Verify Training Works)

Before real training, we memorize 5 examples to verify the pipeline:
- LoRA adapters are applied correctly
- Gradients flow through the model
- Loss decreases toward zero

If this fails, there is a bug in the training code. If it passes, it is safe to proceed to full training.

⏱️ **~5 minutes**
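
The logic of this check can be sketched without a GPU. The toy model below is a deliberately trivial stand-in: it has one trainable weight per example, so it has enough capacity to memorize the batch, which is what the real check relies on. Run it with the notebook's settings (`num_samples=5`, `epochs=50`) and its loss should collapse toward zero. The `freeze=True` case simulates a hypothetical bug in which trainable parameters never receive updates, for example adapters that were never attached. That bug shows up as a flat loss curve. None of these names come from `03_overfit_single_batch.py`.

```python
import random


def overfit_tiny_batch(num_samples=5, epochs=50, lr=1.0, freeze=False, seed=0):
    """Fit a toy model to a tiny batch and return the per-epoch training loss.

    The model's prediction for example i is weights[i] (one-hot inputs), so it
    can memorize the batch exactly. freeze=True skips every update to mimic a
    pipeline where gradients never reach the trainable parameters.
    """
    rng = random.Random(seed)
    targets = [rng.uniform(-1.0, 1.0) for _ in range(num_samples)]
    weights = [0.0] * num_samples
    losses = []
    for _ in range(epochs):
        residuals = [w - t for w, t in zip(weights, targets)]
        loss = sum(r * r for r in residuals) / num_samples  # mean squared error
        losses.append(loss)
        if not freeze:
            # Gradient of the mean squared error w.r.t. weights[i] is 2 * r_i / n.
            weights = [w - lr * 2 * r / num_samples for w, r in zip(weights, residuals)]
    return losses


healthy = overfit_tiny_batch()
broken = overfit_tiny_batch(freeze=True)
print(f"healthy: {healthy[0]:.4f} -> {healthy[-1]:.2e}")
print(f"frozen:  {broken[0]:.4f} -> {broken[-1]:.4f}")
```

With a real vision-language model the loss will not shrink geometrically like this. The signal to look for is the same, though: a steady drop toward zero rather than a plateau.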

In [None]:
# Overfit 5 examples for 50 epochs
!python experiments/03_overfit_single_batch.py --num_samples 5 --epochs 50 --quantize

---
## Phase 4: Full Training & Evaluation üöÄ

Fine-tune BLIP-2 on VQA-RAD training set using LoRA, then evaluate
on the test set and compare to Med-PaLM M baselines.

‚è±Ô∏è **~20-40 minutes on T4 GPU**
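
LoRA freezes each pretrained weight matrix W and learns a low-rank update B·A in its place. The layer computes W·x + (alpha / r)·B·(A·x), with B initialized to zeros so that training starts from the pretrained model's behavior. The minimal pure-Python sketch below shows why `--lora_rank 16` trains so few parameters and why the initial output is unchanged. Several values are assumptions: the 2560 × 2560 layer size is for illustration only, the script's `lora_alpha` is not shown here (32 is a placeholder), and which modules the script actually adapts is not stated in this notebook.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> tuple:
    """Return (full fine-tuning params, LoRA params) for one d_out x d_in weight."""
    full = d_out * d_in            # every entry of W would be trainable
    lora = rank * (d_out + d_in)   # B is d_out x rank, A is rank x d_in
    return full, lora


def matvec(matrix, vec):
    """Plain-Python matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def lora_forward(W, A, B, x, alpha, rank):
    """y = W x + (alpha / rank) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


# Hypothetical 2560 x 2560 projection at rank 16.
full, lora = lora_param_count(2560, 2560, rank=16)
print(f"full: {full:,}  LoRA r=16: {lora:,}  ({100 * lora / full:.2f}% of full)")

# Tiny rank-1 example: with B initialized to zeros the output equals W x.
W = [[1.0, 2.0], [3.0, 4.0]]
A = [[0.5, -0.5]]          # rank x d_in
B_init = [[0.0], [0.0]]    # d_out x rank, zero-initialized
x = [1.0, 3.0]
print(lora_forward(W, A, B_init, x, alpha=32, rank=1), matvec(W, x))
```

Because only A and B receive gradients, optimizer state is kept for a tiny fraction of the weights. That is a large part of what makes fine-tuning feasible on a single T4.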

In [None]:
# Full training run
!python experiments/04_train_vqa.py \
    --dataset vqa_rad \
    --epochs 10 \
    --batch_size 4 \
    --lr 5e-5 \
    --lora_rank 16 \
    --grad_accum 4 \
    --quantize \
    --use_exemplar
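
With `--batch_size 4 --grad_accum 4`, gradients from 4 micro-batches of 4 examples are accumulated before each optimizer step. This gives an effective batch size of 16 while only 4 examples sit in GPU memory at a time. The sketch below checks, for a hypothetical one-parameter model with mean-squared-error loss, that averaging the micro-batch gradients reproduces the full-batch gradient. Dividing each micro-batch gradient by `grad_accum` mirrors the common `loss / grad_accum` pattern before `backward()`. Whether `04_train_vqa.py` scales the loss this way is an assumption.

```python
def grad_mse(w, batch):
    """Gradient of mean squared error for the 1-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, data, batch_size, grad_accum):
    """Sum of micro-batch gradients, each scaled by 1 / grad_accum."""
    total = 0.0
    for step in range(grad_accum):
        micro = data[step * batch_size:(step + 1) * batch_size]
        total += grad_mse(w, micro) / grad_accum
    return total


data = [(x, 3.0 * x + 1.0) for x in range(16)]  # hypothetical toy data, 16 examples
w = 0.5
full_batch = grad_mse(w, data)
accumulated = accumulated_grad(w, data, batch_size=4, grad_accum=4)
print(f"full-batch gradient:  {full_batch:.6f}")
print(f"accumulated gradient: {accumulated:.6f}")
```

The equivalence holds exactly only when every micro-batch has the same size and nothing in the model depends on batch statistics. Note also that `--lr 5e-5` is applied once per optimizer step, not once per micro-batch.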

In [None]:
# View training curve
import os
import json
import matplotlib.pyplot as plt

log_path = "results/logs/training_log.json"
if os.path.exists(log_path):
    with open(log_path) as f:
        log = json.load(f)

    epochs = [e["epoch"] for e in log]
    train_loss = [e["train_loss"] for e in log]

    plt.figure(figsize=(8, 4))
    plt.plot(epochs, train_loss, 'b-o', label='Train Loss')

    if "val_loss" in log[0]:
        val_loss = [e["val_loss"] for e in log]
        plt.plot(epochs, val_loss, 'r-o', label='Val Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('results/figures/training_curve.png', dpi=150)
    plt.show()
    print(f"Final train loss: {train_loss[-1]:.4f}")
else:
    print("⚠️ No training log yet: run the training cell above first")

In [None]:
# View evaluation results vs paper baselines
import os
import json

metrics_path = "results/tables/vqa_rad_metrics.json"
if os.path.exists(metrics_path):
    with open(metrics_path) as f:
        results = json.load(f)

    print("=" * 60)
    print("FINE-TUNED RESULTS vs PAPER BASELINES (VQA-RAD)")
    print("=" * 60)

    comparison = [
        ("Prior SOTA (specialist)",    71.03, None),
        ("PaLM-E 84B (no finetune)",   59.19, 38.67),
        ("Med-PaLM M 12B",             64.02, 50.66),
        ("Med-PaLM M 84B",             69.38, 59.90),
        ("Med-PaLM M 562B",            71.27, 62.06),
        ("Ours (BLIP-2 + LoRA)",       results['bleu_1'], results['f1']),
    ]

    print(f"{'Model':<30} {'BLEU-1':>10} {'F1':>10}")
    print("-" * 50)
    for name, bleu, f1 in comparison:
        b = f"{bleu:.2f}%" if bleu is not None else "N/A"
        f = f"{f1:.2f}%" if f1 is not None else "N/A"
        marker = " ← US" if "Ours" in name else ""
        print(f"{name:<30} {b:>10} {f:>10}{marker}")
    print("=" * 60)
else:
    print("⚠️ No evaluation results yet: run the training cell above first")

---
## Phase 5: Generalization Experiments

Tests the paper's key claims:
1. **Cross-dataset transfer:** Does fine-tuning on VQA-RAD help on Slake-VQA?
2. **Exemplar ablation:** Does the one-shot prompting trick actually help?

⏱️ **~15 minutes**
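
The exemplar ablation compares prompts with and without a one-shot exemplar. As described in Med-PaLM M, the exemplar is text-only: its image is stood in for by a placeholder rather than passed to the model. The sketch below shows how the two prompt variants differ. The template wording, the `<img>` placeholder, and the example questions are all hypothetical. The actual prompts built by `05_zero_shot_eval.py` may look different.

```python
from typing import Optional, Tuple

IMAGE_PLACEHOLDER = "<img>"  # hypothetical text stand-in where an image would go


def build_vqa_prompt(question: str,
                     exemplar: Optional[Tuple[str, str]] = None,
                     instruction: str = "Answer the question about the medical image.") -> str:
    """Build a VQA prompt, optionally prepending a text-only one-shot exemplar."""
    lines = [f"Instructions: {instruction}"]
    if exemplar is not None:
        ex_question, ex_answer = exemplar
        # One-shot exemplar: only text, no pixels for the exemplar's image.
        lines.append(f"Given {IMAGE_PLACEHOLDER}. Q: {ex_question} A: {ex_answer}")
    # The real image for this question is given to the model separately (e.g. via its processor).
    lines.append(f"Given {IMAGE_PLACEHOLDER}. Q: {question} A:")
    return "\n".join(lines)


zero_shot = build_vqa_prompt("Is there a pleural effusion?")
one_shot = build_vqa_prompt("Is there a pleural effusion?",
                            exemplar=("Is the heart enlarged?", "no"))
print(zero_shot, one_shot, sep="\n\n")
```

The exemplar mainly demonstrates the expected answer format (short, lowercase). That can move BLEU-1 and F1 even when the model's visual understanding is unchanged.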

In [None]:
# Download Slake-VQA for cross-dataset testing
!python data/download.py --dataset slake_vqa

In [None]:
# Run all generalization experiments
!python experiments/05_zero_shot_eval.py --experiment all --max_samples 100 --quantize

In [None]:
# View exemplar ablation results
import os
import json

ablation_path = "results/tables/exemplar_ablation.json"
if os.path.exists(ablation_path):
    with open(ablation_path) as f:
        ablation = json.load(f)

    print("ONE-SHOT EXEMPLAR ABLATION")
    print("=" * 50)
    print(f"{'Mode':<25} {'BLEU-1':>10} {'F1':>10}")
    print("-" * 45)
    for mode, m in ablation.items():
        print(f"{mode:<25} {m['bleu_1']:>9.2f}% {m['f1']:>9.2f}%")

    diff = ablation['with_exemplar']['bleu_1'] - ablation['without_exemplar']['bleu_1']
    print(f"\nExemplar effect: {diff:+.2f} percentage points BLEU-1")
    if diff > 0:
        print("→ Exemplar HELPS (consistent with the paper's approach)")
    else:
        print("→ Exemplar did not help (an interesting finding for our model)")
else:
    print("⚠️ No ablation results yet: run the cell above first")

---
## Final: Generate Complete Comparison Report

In [None]:
# Generate the unified comparison table and bar charts
!python evaluation/compare_to_paper.py

In [None]:
# Display comparison charts
import os
import glob
from IPython.display import Image, display

charts = glob.glob("results/figures/*_comparison.png")
if charts:
    for fig_path in charts:
        print(f"\n{fig_path}:")
        display(Image(filename=fig_path))
else:
    print("⚠️ No comparison charts generated yet")

In [None]:
# Show the final markdown comparison table
import os

table_path = "results/tables/full_comparison.md"
if os.path.exists(table_path):
    with open(table_path) as f:
        print(f.read())
else:
    print("⚠️ No comparison table yet: run the cell above first")

---
## 📥 Download Results

Downloads all results (metrics, charts, predictions) as a gzip-compressed tarball (`.tar.gz`).
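
The `!tar` line in the next cell only works in a shell-enabled notebook. The sketch below is a pure-Python equivalent using the standard-library `tarfile` module, which may be handy when running the scripts outside Colab. The function name and defaults are hypothetical. The Colab `files.download` step is unchanged.

```python
import os
import tarfile


def package_results(src_dir="results", out_path="/content/reproduction_results.tar.gz"):
    """Write src_dir (recursively) into a gzip-compressed tarball and return its path."""
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"{src_dir!r} not found; run the experiments first")
    with tarfile.open(out_path, "w:gz") as tar:
        # arcname keeps member paths relative (e.g. results/tables/...) inside the archive.
        tar.add(src_dir, arcname=os.path.basename(os.path.normpath(src_dir)))
    return out_path
```

Calling `package_results()` with no arguments produces the same archive path the shell command uses.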

In [None]:
# Package results for download
!tar -czf /content/reproduction_results.tar.gz results/

from google.colab import files
files.download('/content/reproduction_results.tar.gz')
print("✅ Download started! Add these results to your GitHub repo.")

---
## Summary

### What We Reproduced
- Med-PaLM M's medical VQA methodology using open-source models
- Instruction task prompting with one-shot exemplars
- Domain-specific fine-tuning and its impact on performance
- Cross-dataset generalization evaluation

### Key Differences from Original Paper
| Aspect | Med-PaLM M | Our Reproduction |
|--------|-----------|------------------|
| Model | PaLM-E (12B / 84B / 562B) | BLIP-2 (~3B) |
| Training | Full fine-tuning on TPU pods | LoRA on a single GPU |
| Data | 1M+ samples across 14 tasks | ~3.5K-14K VQA samples |
| Compute | Weeks on TPU v4 | ~1 hour on a T4 GPU |