# MedGemma Fine-Tuning for CXR Triage

This notebook demonstrates how to fine-tune MedGemma 4B for chest X-ray urgency classification.

**Requirements:**
- GPU with 24GB+ VRAM (for full fine-tuning) or 16GB+ (with LoRA)
- MIMIC-CXR dataset access (see `data_scripts/download_datasets.sh`)
- `transformers>=4.50.0`, `peft>=0.7.0`

**Time to complete:** ~2-4 hours (depending on dataset size)

## 1. Setup & Configuration

In [None]:
# Install dependencies
# !pip install -U transformers accelerate peft torch bitsandbytes

import os
import json
import torch
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Dict, List, Any
import warnings
warnings.filterwarnings("ignore")

# Report the framework version and, when CUDA is present, the first GPU's name and memory
print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    device_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {device_props.total_memory / 1e9:.1f} GB")

In [None]:
@dataclass
class TrainingConfig:
    """Training configuration for MedGemma fine-tuning.

    Every field has a default, so ``TrainingConfig()`` is usable as-is;
    override individual fields with keyword arguments.
    """
    # Model
    model_id: str = "google/medgemma-4b-it"

    # Data
    data_dir: str = "../data/processed"
    max_samples: Optional[int] = None  # None = use all

    # LoRA Configuration
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # None -> replaced in __post_init__ by the attention + MLP projection names
    lora_target_modules: Optional[List[str]] = None

    # Training Hyperparameters
    learning_rate: float = 2e-5
    num_epochs: int = 3
    batch_size: int = 2  # per device; effective batch = batch_size * gradient_accumulation_steps
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01

    # Generation
    max_new_tokens: int = 256

    # Output
    output_dir: str = "../models/checkpoints/urgency_classifier"
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 50

    def __post_init__(self):
        if self.lora_target_modules is None:
            # Target attention and MLP layers for LoRA
            self.lora_target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ]
        else:
            # Keep a private list copy: accepts any iterable of module names
            # (e.g. a tuple), and later mutation of the caller's list cannot
            # silently change this config.
            self.lora_target_modules = list(self.lora_target_modules)


config = TrainingConfig()
print("Training Configuration:")
for k, v in config.__dict__.items():
    print(f"  {k}: {v}")

## 2. Load Model with LoRA

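To reduce memory requirements, we load the base model with 4-bit NF4 quantization (bitsandbytes) and train LoRA adapters via PEFT (Parameter-Efficient Fine-Tuning); this combination is known as QLoRA. With `use_lora=False`, the model is instead loaded unquantized in bfloat16 for full fine-tuning.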

In [None]:
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization with double quantization reduces the base model's
# memory footprint; matmul compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading model: {config.model_id}")

# Quantize only for LoRA training; full fine-tuning loads unquantized bf16 weights
quantization = bnb_config if config.use_lora else None
model = AutoModelForImageTextToText.from_pretrained(
    config.model_id,
    quantization_config=quantization,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The processor handles both text tokenization and image preprocessing
processor = AutoProcessor.from_pretrained(config.model_id)

print(f"Model loaded. Parameters: {model.num_parameters():,}")

In [None]:
if config.use_lora:
    print("Applying LoRA configuration...")

    # PEFT helper that readies a quantized (k-bit) model for training
    model = prepare_model_for_kbit_training(model)

    # Adapter hyperparameters come straight from TrainingConfig
    lora_kwargs = dict(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    lora_config = LoraConfig(**lora_kwargs)

    # Wrap the base model so only the LoRA adapter weights are trainable
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

## 3. Prepare Dataset

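Load and preprocess the MIMIC-CXR subset for training. By default, the cells below pass `use_sample_data=True`, which uses a small synthetic sample set. Set it to `False` to read the processed `train.jsonl` / `val.jsonl` files from `config.data_dir`.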

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random

class CXRTriageDataset(Dataset):
    """Chat-formatted chest X-ray triage dataset.

    Each item is a three-turn conversation (system prompt, user turn carrying
    the image, assistant turn with the target triage report), plus the raw
    image and the sample's original ``urgency`` label.

    Samples are read from a JSONL file (one JSON object per line). When
    ``use_sample_data`` is True, or the file does not exist, a small synthetic
    set is used instead. Each sample needs ``urgency`` and ``primary_finding``,
    may have ``explanation``, and must have an image source: ``image_url`` or
    ``image_path``.
    """
    
    SYSTEM_PROMPT = """You are an expert radiologist assistant. Analyze chest X-rays and provide:
1. Urgency classification (Urgent or Non-Urgent)
2. Brief explanation
3. Key findings

This is for clinical decision support only."""

    # Seconds to wait for an image download before giving up.
    REQUEST_TIMEOUT = 30
    
    def __init__(
        self,
        data_file: str,
        processor,
        max_samples: Optional[int] = None,
        use_sample_data: bool = False,
    ):
        """
        Args:
            data_file: Path to a JSONL file of samples.
            processor: Stored as ``self.processor``; not used by this class itself.
            max_samples: Keep only the first ``max_samples`` samples. ``None``
                (or 0, since the check is truthiness) keeps all samples.
            use_sample_data: Ignore ``data_file`` and use synthetic samples.
        """
        self.processor = processor
        # Downloaded image bytes keyed by URL, so a URL that appears many times
        # (as in the synthetic data) is fetched once, not on every access/epoch.
        self._url_cache: Dict[str, bytes] = {}
        
        if use_sample_data:
            # Create synthetic samples for testing
            self.samples = self._create_sample_data()
        else:
            data_path = Path(data_file)
            if data_path.exists():
                self.samples = self._read_jsonl(data_path)
            else:
                print(f"Warning: Data file not found: {data_file}")
                print("Using sample data for demonstration.")
                self.samples = self._create_sample_data()
        
        if max_samples:
            self.samples = self.samples[:max_samples]
        
        print(f"Dataset loaded: {len(self.samples)} samples")

    @staticmethod
    def _read_jsonl(data_path: Path) -> List[Dict]:
        """Parse a JSONL file, skipping blank lines such as a trailing newline."""
        with open(data_path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]
    
    def _create_sample_data(self) -> List[Dict]:
        """Create sample data for testing the training pipeline."""
        template = {
            "id": "sample_1",
            "image_url": "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
            "urgency": "non-urgent",
            "primary_finding": "No Finding",
            "explanation": "Normal chest X-ray with clear lung fields.",
        }
        # Repeat for minimal training. Use independent copies: ``[d] * 10`` would
        # repeat one dict object, so mutating one sample would change them all.
        return [dict(template) for _ in range(10)]
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def _load_image(self, sample: Dict) -> "Image.Image":
        """Load ``sample``'s image as RGB from ``image_url`` or ``image_path``.

        Raises:
            requests.HTTPError: If the download returns an error status.
            ValueError: If the sample has neither ``image_url`` nor ``image_path``.
        """
        from io import BytesIO
        
        if "image_url" in sample:
            url = sample["image_url"]
            content = self._url_cache.get(url)
            if content is None:
                import requests

                response = requests.get(
                    url,
                    headers={"User-Agent": "MedGemma"},
                    timeout=self.REQUEST_TIMEOUT,
                )
                # Fail with the HTTP status rather than passing an error page to PIL.
                response.raise_for_status()
                content = response.content
                self._url_cache[url] = content
            return Image.open(BytesIO(content)).convert("RGB")
        elif "image_path" in sample:
            # The context manager closes the file once the RGB copy is made.
            with Image.open(sample["image_path"]) as img:
                return img.convert("RGB")
        else:
            raise ValueError("No image source found in sample")

    @staticmethod
    def _build_target_response(sample: Dict) -> str:
        """Format the assistant's target triage report for ``sample``."""
        # Normalize e.g. "non-urgent" -> "Non-Urgent", "urgent" -> "Urgent"
        urgency = sample["urgency"].capitalize()
        if urgency == "Non-urgent":
            urgency = "Non-Urgent"
        
        explanation = sample.get("explanation", f"{sample['primary_finding']} detected.")
        
        return f"""1. URGENCY: [{urgency}]
2. EXPLANATION: [{explanation}]
3. KEY FINDINGS: [{sample['primary_finding']}]"""
    
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return ``messages`` (full chat incl. target), ``image`` and ``label``."""
        sample = self.samples[idx]
        
        # Load image
        image = self._load_image(sample)
        
        # Create target response
        target_response = self._build_target_response(sample)
        
        # Create messages
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.SYSTEM_PROMPT}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Analyze this chest X-ray for triage."},
                    {"type": "image", "image": image}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": target_response}]
            }
        ]
        
        return {
            "messages": messages,
            "image": image,
            "label": sample["urgency"],
        }

In [None]:
# Build the train/validation splits. When max_samples is set, validation is
# capped at a tenth of it; otherwise both splits keep every sample.
data_root = config.data_dir
val_cap = config.max_samples // 10 if config.max_samples else None

train_dataset = CXRTriageDataset(
    data_file=f"{data_root}/train.jsonl",
    processor=processor,
    max_samples=config.max_samples,
    use_sample_data=True,  # Set to False when using real data
)
val_dataset = CXRTriageDataset(
    data_file=f"{data_root}/val.jsonl",
    processor=processor,
    max_samples=val_cap,
    use_sample_data=True,
)

for split_name, split in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_name} samples: {len(split)}")

## 4. Training Setup

In [None]:
from transformers import TrainingArguments, Trainer
from tqdm import tqdm

# Create output directory
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Training arguments
training_args = TrainingArguments(
    output_dir=str(output_dir),
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    warmup_ratio=config.warmup_ratio,
    weight_decay=config.weight_decay,
    logging_steps=config.logging_steps,
    save_steps=config.save_steps,
    eval_steps=config.eval_steps,
    # `evaluation_strategy` was renamed to `eval_strategy`; the old keyword is
    # rejected by the transformers releases this notebook requires (>=4.50).
    eval_strategy="steps",
    save_total_limit=2,
    # With step-based saving/eval, TrainingArguments requires save_steps to be
    # a multiple of eval_steps for load_best_model_at_end.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    bf16=True,
    dataloader_pin_memory=False,
    # The dataset yields "messages"/"image" keys that the model's forward()
    # does not accept. Keep Trainer from dropping them before the custom
    # collate function builds model inputs from them.
    remove_unused_columns=False,
    report_to="none",  # Disable wandb/tensorboard for simplicity
)

print("Training arguments configured.")

In [None]:
def collate_fn(batch):
    """Collate CXRTriageDataset items into a padded multimodal training batch.

    Each item must provide ``"messages"`` (the full chat, including the
    assistant target) and ``"image"`` (the image referenced in the user turn).

    Returns:
        dict holding every tensor the processor produces (``input_ids``,
        ``attention_mask``, ``pixel_values`` and any model-specific extras) plus
        ``labels`` for the causal-LM loss.
    """
    # Render each conversation to text. The chat template inserts the image
    # placeholder where the {"type": "image"} entry appears.
    texts = [
        processor.apply_chat_template(
            sample["messages"],
            add_generation_prompt=False,
            tokenize=False,
        ).strip()
        for sample in batch
    ]
    # One image per conversation, nested so the processor pairs images with texts.
    images = [[sample["image"]] for sample in batch]

    # Tokenize text and preprocess images together, padding to the longest
    # sequence. Pixel values are kept so the model actually sees the image.
    # Concatenating unpadded per-sample tensors would fail as soon as
    # sequence lengths differ.
    inputs = processor(text=texts, images=images, padding=True, return_tensors="pt")

    # Labels: predict every real token, but ignore (-100) padding and image
    # placeholder tokens, which are not text the model should generate.
    labels = inputs["input_ids"].clone()
    labels[inputs["attention_mask"] == 0] = -100
    image_token_id = getattr(processor, "image_token_id", None)
    if image_token_id is not None:
        labels[labels == image_token_id] = -100
    boi_token = getattr(processor, "boi_token", None)
    if boi_token is not None:
        boi_token_id = processor.tokenizer.convert_tokens_to_ids(boi_token)
        labels[labels == boi_token_id] = -100
    inputs["labels"] = labels

    return dict(inputs)

## 5. Run Training

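⚠️ **Note**: This section does not run a training loop. The cells below save the configuration and run a single forward pass to verify the setup. To actually fine-tune, pass `training_args`, `train_dataset`, `val_dataset`, and `collate_fn` (as `data_collator`) to `transformers.Trainer` and call `trainer.train()`.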

In [None]:
# Demonstration banner: this cell only prints guidance and records the config.
banner = "=" * 60
print(banner)
print("Starting Fine-Tuning (Demonstration Mode)")
print(banner)
print()
checklist = [
    "Note: This notebook demonstrates the training setup.",
    "For full training, ensure you have:",
    "  1. MIMIC-CXR dataset downloaded and processed",
    "  2. Sufficient GPU memory (24GB+ recommended)",
    "  3. Several hours of training time",
]
for entry in checklist:
    print(entry)
print()

# Persist the configuration next to the checkpoints; non-JSON values become strings
config_file = output_dir / "training_config.json"
with open(config_file, "w") as fh:
    json.dump(config.__dict__, fh, indent=2, default=str)
print(f"Configuration saved to: {config_file}")

In [None]:
# Optional sanity check: push one training example through the model (no gradients)
print("\nVerifying training setup with single batch...")

try:
    first_example = train_dataset[0]

    # Tokenize the full conversation (prompt + target) together with its image
    batch_inputs = processor.apply_chat_template(
        first_example["messages"],
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    target_device = model.device
    inputs = {name: tensor.to(target_device) for name, tensor in batch_inputs.items()}
    # Labels mirror input_ids so the model returns a causal-LM loss
    inputs["labels"] = inputs["input_ids"].clone()

    with torch.no_grad():
        loss = model(**inputs).loss

    print(f"✓ Forward pass successful! Loss: {loss.item():.4f}")
    print(f"✓ Input shape: {inputs['input_ids'].shape}")
    print("\nSetup verified. Ready for full training.")

except Exception as err:
    # Report, then re-raise so the failure is visible in the notebook
    print(f"✗ Error during verification: {err}")
    raise

## 6. Save LoRA Adapter

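Save the LoRA adapter (or the full model when `use_lora=False`) together with the processor for inference. If no training has been run, this saves the untrained adapter.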

In [None]:
# Save the current model state (LoRA adapter or full model) and the processor.
# NOTE(review): no training call appears above this cell; unless training was
# run separately, this saves untrained weights.
if config.use_lora:
    print("Saving LoRA adapter...")
else:
    print("Full model fine-tuning - saving complete model...")

model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

if config.use_lora:
    print(f"Adapter saved to: {output_dir}")

# Record a small JSON summary of the run next to the weights
summary = dict(
    model_id=config.model_id,
    training_samples=len(train_dataset),
    val_samples=len(val_dataset),
    epochs=config.num_epochs,
    learning_rate=config.learning_rate,
    use_lora=config.use_lora,
    lora_r=config.lora_r if config.use_lora else None,
)

summary_path = output_dir / "training_summary.json"
with summary_path.open("w") as fh:
    json.dump(summary, fh, indent=2)

print("\nTraining Summary:")
for key, value in summary.items():
    print(f"  {key}: {value}")

## 7. Next Steps

After training, proceed to:
1. **Evaluation**: Run `03_evaluation_and_metrics.ipynb` to compute AUC, sensitivity, PPV
2. **Demo**: Export to ONNX and deploy via `demo_app/`

### Hyperparameter Tuning Tips

| Parameter | Recommendation |
|-----------|---------------|
| `lora_r` | Start with 16, increase to 32 if underfitting |
| `learning_rate` | 1e-5 to 5e-5 works well for LoRA |
| `epochs` | 3-5 epochs usually sufficient |
| `batch_size` | As large as GPU allows (effective batch = batch × grad_accum) |