## 1. Environment Setup

In [None]:
# Check GPU availability
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
import psutil

# Total memory reported by psutil, converted from bytes to decimal gigabytes.
ram_gb = psutil.virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

# 20 GB is the cut-off this notebook uses to call a runtime "high-RAM".
is_high_ram = ram_gb >= 20
print('You are using a high-RAM runtime!' if is_high_ram else 'Not using a high-RAM runtime')

In [None]:
# Clone repository (or upload files)
!git clone https://github.com/YourRepo/Few-Shot-Domain-Adaptation-for-Medical-Image-Classification.git
%cd Few-Shot-Domain-Adaptation-for-Medical-Image-Classification

In [None]:
# Install dependencies
!pip install -q torch torchvision timm transformers scikit-learn pandas Pillow matplotlib pytorch-lightning wandb opencv-python scipy

In [None]:
# Import libraries
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Subset

# Add project to path (if needed)
# sys.path.append('/content/Few-Shot-Domain-Adaptation-for-Medical-Image-Classification')

from configs.config import Config
from data.datasets import SimpleMedicalDataset, get_transforms, sample_few_shot_indices
from models.vit_backbone import ViTWrapper
from models.cnn_backbones import build_cnn
from lora.lora import apply_lora_to_model
from adapters.adapter import attach_adapter_to_vit
from prompts.prompt_tuning import attach_visual_prompt_to_vit
from train.trainer import LitModel
from eval.evaluator import compute_metrics, bootstrap_confidence_interval
from utils.utils import set_seed, count_parameters

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Report the PyTorch build and the CUDA devices it can see.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
print(f"GPU count: {torch.cuda.device_count()}")
if cuda_ok:
    # Only query the device name when at least one CUDA device is usable.
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

## 2. Dataset Preparation

**Note:** You need to prepare your CheXpert and NIH ChestX-ray14 datasets.

Expected structure:
```
data/
├── chexpert_train.csv
├── chexpert_val.csv
├── nih_train.csv
├── nih_val.csv
├── images/
│   ├── patient001/
│   │   ├── study1/
│   │   │   └── view1.jpg
```

CSV format:
- Column `Path`: relative image path
- Columns for each of 14 pathologies (0/1 labels, -1 for uncertainty)

In [None]:
!pip install kaggle

In [None]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the `ashery/chexpert` dataset with the Kaggle CLI
# (requires the API token installed by the previous cell).
!kaggle datasets download ashery/chexpert

In [None]:
!unzip air-heart-disease.zip -d ./data

In [None]:
# Configure paths (modify as needed)
DATA_ROOT = './data'

# Every split CSV lives directly under DATA_ROOT and is named `<stem>.csv`.
CHEXPERT_TRAIN, CHEXPERT_VAL, NIH_TRAIN, NIH_VAL, NIH_TEST = (
    f'{DATA_ROOT}/{stem}.csv'
    for stem in ('chexpert_train', 'chexpert_val', 'nih_train', 'nih_val', 'nih_test')
)

# Create checkpoint and log directories.
# Pure-Python equivalent of `mkdir -p`: portable across platforms and a no-op
# when the directories already exist.
for output_dir in ('checkpoints', 'logs'):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Verify datasets
print("Checking dataset files...")
expected_csvs = (CHEXPERT_TRAIN, CHEXPERT_VAL, NIH_TRAIN, NIH_VAL)
for csv_path in expected_csvs:
    if not os.path.exists(csv_path):
        print(f"âœ— {csv_path}: NOT FOUND")
        continue
    # The CSV is loaded only to report how many rows it has.
    df = pd.read_csv(csv_path)
    print(f"âœ“ {csv_path}: {len(df)} samples")

## 3. Configuration and Utilities

In [None]:
# Base configuration
config = Config(
    # data
    data_root=DATA_ROOT,
    img_size=224,
    num_classes=14,
    batch_size=32,
    num_workers=4,
    # model
    backbone='vit_base_patch16_224',
    pretrained=True,
    # optimisation
    epochs=30,
    lr=1e-4,
    weight_decay=1e-4,
    optimizer='adamw',
    mixed_precision=True,
    gradient_checkpointing=True,
    # few-shot / bookkeeping
    few_shot_k=50,
    checkpoint_dir='./checkpoints',
    use_wandb=False,
    seed=42
)

# Seed before any model or sampler is created.
set_seed(config.seed)

# Echo the settings a reader is most likely to tune.
print("Configuration:")
for label, value in (
    ("Backbone", config.backbone),
    ("Batch size", config.batch_size),
    ("Epochs", config.epochs),
    ("Learning rate", config.lr),
    ("Mixed precision", config.mixed_precision),
    ("Gradient checkpointing", config.gradient_checkpointing),
):
    print(f"  {label}: {value}")

## 4. Experiment 1: Baseline ViT on Source Domain (CheXpert)

Train a standard Vision Transformer on the source domain without any adaptation.

In [None]:
# Load CheXpert datasets
# The second argument of get_transforms is True for the training split and
# False for validation (presumably toggling augmentation; confirm in data.datasets).
train_ds = SimpleMedicalDataset(CHEXPERT_TRAIN, DATA_ROOT, transform=get_transforms(config.img_size, True))
val_ds = SimpleMedicalDataset(CHEXPERT_VAL, DATA_ROOT, transform=get_transforms(config.img_size, False))

# Settings shared by both loaders; only `shuffle` differs between them.
source_loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, **source_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, **source_loader_kwargs)

for split_name, loader in (("Train", train_loader), ("Val", val_loader)):
    print(f"{split_name} batches: {len(loader)}")

In [None]:
# Build baseline ViT model
baseline_model = ViTWrapper(model_name=config.backbone, num_classes=config.num_classes, pretrained=config.pretrained)

# Enable gradient checkpointing only when the config requests it and the
# backbone exposes the hook; previously the config flag was ignored.
if config.gradient_checkpointing and hasattr(baseline_model.backbone, 'set_grad_checkpointing'):
    baseline_model.backbone.set_grad_checkpointing(True)

# `params` is reused by the results-comparison cell further down.
params = count_parameters(baseline_model)
print(f"Total parameters: {params['total']:,}")
print(f"Trainable parameters: {params['trainable']:,}")

In [None]:
# Setup Lightning trainer
lit_baseline = LitModel(baseline_model, config)

# The monitored metric is called 'val/auc'. With the default
# auto_insert_metric_name=True Lightning splices the metric name, including its
# '/', into the file name, which creates an unintended sub-directory per
# checkpoint. Disable the auto-insertion and write the labels out by hand.
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints/baseline_vit',
    filename='best-epoch={epoch:02d}-val_auc={val/auc:.4f}',
    auto_insert_metric_name=False,
    monitor='val/auc',
    mode='max',
    save_top_k=3
)

# Stop once validation AUC has not improved for 10 consecutive checks.
early_stop = EarlyStopping(monitor='val/auc', patience=10, mode='max', verbose=True)

trainer_baseline = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,  # Use all available GPUs
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback, early_stop],
    gradient_clip_val=1.0,
    log_every_n_steps=10
)

print("ðŸš€ Training baseline ViT on CheXpert...")
trainer_baseline.fit(lit_baseline, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Save baseline checkpoint
# NOTE(review): this stores the weights currently held by `baseline_model`
# (the state after the last training step), while the score printed below is
# the best one seen by the checkpoint callback; confirm that is intended.
baseline_ckpt_path = './checkpoints/baseline_vit_source.pth'
baseline_ckpt = {
    'model_state_dict': baseline_model.state_dict(),
    'config': config,
}
torch.save(baseline_ckpt, baseline_ckpt_path)

print(f"âœ“ Baseline model saved")
best_baseline_auc = checkpoint_callback.best_model_score
print(f"Best val AUC: {best_baseline_auc:.4f}")

## 5. Experiment 2: Few-Shot LoRA Adaptation (CheXpert → NIH)

Apply LoRA to adapt the baseline model to the NIH dataset with a limited number of samples.

In [None]:
# Load NIH target dataset
nih_train_full = SimpleMedicalDataset(NIH_TRAIN, DATA_ROOT, transform=get_transforms(config.img_size, True))
nih_val = SimpleMedicalDataset(NIH_VAL, DATA_ROOT, transform=get_transforms(config.img_size, False))

# Sample few-shot subset (seeded, so the same indices are drawn on every run
# provided sample_few_shot_indices honours the seed).
few_shot_indices = sample_few_shot_indices(nih_train_full, k_per_class=config.few_shot_k, seed=config.seed)
n_few_shot = len(few_shot_indices)
print(f"Few-shot sampling: {n_few_shot} samples (k={config.few_shot_k} per class)")

# Only the sampled indices are exposed to the training loader.
nih_train_fewshot = Subset(nih_train_full, few_shot_indices)

# Settings shared by both loaders; only `shuffle` differs between them.
nih_loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
)
train_loader_nih = torch.utils.data.DataLoader(nih_train_fewshot, shuffle=True, **nih_loader_kwargs)
val_loader_nih = torch.utils.data.DataLoader(nih_val, shuffle=False, **nih_loader_kwargs)

In [None]:
# Load baseline checkpoint and apply LoRA
lora_model = ViTWrapper(model_name=config.backbone, num_classes=config.num_classes, pretrained=False)
# The checkpoint written earlier in this notebook stores the Config object next
# to the weights, so it cannot be read with the weights-only unpickler that
# recent PyTorch releases use by default. The file is produced by this notebook
# (trusted), so full unpickling is acceptable here; never do this for files
# from an untrusted source. map_location='cpu' makes loading independent of the
# device the weights were saved from.
checkpoint = torch.load('./checkpoints/baseline_vit_source.pth', map_location='cpu', weights_only=False)
lora_model.load_state_dict(checkpoint['model_state_dict'])
print("âœ“ Loaded baseline checkpoint")

# Apply LoRA
apply_lora_to_model(lora_model, r=8, alpha=32.0)

# Freeze original parameters; keep the LoRA weights and the classifier head
# trainable. The adapter and prompt-tuning experiments also leave the
# classifier trainable, so all three methods are compared on the same footing
# and the head can be re-fit to the target labels.
for name, param in lora_model.named_parameters():
    if 'lora_' not in name and 'classifier' not in name:
        param.requires_grad = False

lora_params = count_parameters(lora_model)
print(f"\nLoRA model parameters:")
print(f"  Total: {lora_params['total']:,}")
print(f"  Trainable: {lora_params['trainable']:,}")
print(f"  Efficiency: {100.0 * lora_params['trainable'] / lora_params['total']:.2f}% trainable")

In [None]:
# Train with LoRA
lit_lora = LitModel(lora_model, config)

# 'val/auc' contains a '/', so let the template name the fields explicitly
# instead of having Lightning auto-insert the metric name (which would create
# a sub-directory per checkpoint).
checkpoint_callback_lora = ModelCheckpoint(
    dirpath='./checkpoints/lora_adaptation',
    filename='best-epoch={epoch:02d}-val_auc={val/auc:.4f}',
    auto_insert_metric_name=False,
    monitor='val/auc',
    mode='max',
    save_top_k=3
)

trainer_lora = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback_lora, EarlyStopping(monitor='val/auc', patience=10, mode='max')],
    gradient_clip_val=1.0
)

print("ðŸš€ Training with LoRA adaptation...")
trainer_lora.fit(lit_lora, train_dataloaders=train_loader_nih, val_dataloaders=val_loader_nih)

print(f"âœ“ LoRA training complete")
print(f"Best val AUC: {checkpoint_callback_lora.best_model_score:.4f}")

## 6. Experiment 3: Few-Shot Adapter Adaptation

In [None]:
# Load baseline and apply adapters
adapter_model = ViTWrapper(model_name=config.backbone, num_classes=config.num_classes, pretrained=False)
# The baseline checkpoint also holds the pickled Config object, which the
# weights-only unpickler (the default in recent PyTorch releases) rejects.
# The file is written by this notebook, so full unpickling is acceptable;
# map_location='cpu' makes loading independent of the saving device.
checkpoint = torch.load('./checkpoints/baseline_vit_source.pth', map_location='cpu', weights_only=False)
adapter_model.load_state_dict(checkpoint['model_state_dict'])

# Attach adapters
attach_adapter_to_vit(adapter_model.backbone, adapter_dim=64)

# Freeze backbone: only parameters whose name mentions 'adapter' or
# 'classifier' stay trainable.
for name, param in adapter_model.named_parameters():
    if 'adapter' not in name and 'classifier' not in name:
        param.requires_grad = False

adapter_params = count_parameters(adapter_model)
print(f"Adapter model - Trainable: {adapter_params['trainable']:,} ({100.0 * adapter_params['trainable'] / adapter_params['total']:.2f}%)")

In [None]:
# Train with adapters
lit_adapter = LitModel(adapter_model, config)

# 'val/auc' contains a '/', so name the fields explicitly in the template and
# turn off Lightning's automatic metric-name insertion (which would otherwise
# create a sub-directory per checkpoint).
checkpoint_callback_adapter = ModelCheckpoint(
    dirpath='./checkpoints/adapter_adaptation',
    filename='best-epoch={epoch:02d}-val_auc={val/auc:.4f}',
    auto_insert_metric_name=False,
    monitor='val/auc',
    mode='max'
)

# Use the same precision rule and gradient clipping as the baseline and LoRA
# trainers so the methods are compared under identical optimisation settings.
trainer_adapter = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback_adapter, EarlyStopping(monitor='val/auc', patience=10, mode='max')],
    gradient_clip_val=1.0
)

print("ðŸš€ Training with Adapter adaptation...")
trainer_adapter.fit(lit_adapter, train_dataloaders=train_loader_nih, val_dataloaders=val_loader_nih)
print(f"Best val AUC: {checkpoint_callback_adapter.best_model_score:.4f}")

## 7. Experiment 4: Few-Shot Prompt Tuning

In [None]:
# Load baseline and apply prompt tuning
prompt_model = ViTWrapper(model_name=config.backbone, num_classes=config.num_classes, pretrained=False)
# The baseline checkpoint also holds the pickled Config object, which the
# weights-only unpickler (the default in recent PyTorch releases) rejects.
# The file is written by this notebook, so full unpickling is acceptable;
# map_location='cpu' makes loading independent of the saving device.
checkpoint = torch.load('./checkpoints/baseline_vit_source.pth', map_location='cpu', weights_only=False)
prompt_model.load_state_dict(checkpoint['model_state_dict'])

# Attach visual prompts
attach_visual_prompt_to_vit(prompt_model.backbone, prompt_tokens=10)

# Freeze everything except prompts and classifier (matched by parameter name).
for name, param in prompt_model.named_parameters():
    if 'visual_prompt' not in name and 'classifier' not in name:
        param.requires_grad = False

prompt_params = count_parameters(prompt_model)
print(f"Prompt model - Trainable: {prompt_params['trainable']:,} ({100.0 * prompt_params['trainable'] / prompt_params['total']:.2f}%)")

In [None]:
# Train with prompt tuning
lit_prompt = LitModel(prompt_model, config)

# 'val/auc' contains a '/', so name the fields explicitly in the template and
# turn off Lightning's automatic metric-name insertion (which would otherwise
# create a sub-directory per checkpoint).
checkpoint_callback_prompt = ModelCheckpoint(
    dirpath='./checkpoints/prompt_adaptation',
    filename='best-epoch={epoch:02d}-val_auc={val/auc:.4f}',
    auto_insert_metric_name=False,
    monitor='val/auc',
    mode='max'
)

# Use the same precision rule and gradient clipping as the baseline and LoRA
# trainers so the methods are compared under identical optimisation settings.
trainer_prompt = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback_prompt, EarlyStopping(monitor='val/auc', patience=10, mode='max')],
    gradient_clip_val=1.0
)

print("ðŸš€ Training with Prompt Tuning...")
trainer_prompt.fit(lit_prompt, train_dataloaders=train_loader_nih, val_dataloaders=val_loader_nih)
print(f"Best val AUC: {checkpoint_callback_prompt.best_model_score:.4f}")

## 8. Experiment 5: CNN Baselines

In [None]:
# ResNet-50 baseline (trained on the same few-shot NIH loaders as the ViT adaptation runs)
resnet_model = build_cnn('resnet50', num_classes=config.num_classes, pretrained=True)
lit_resnet = LitModel(resnet_model, config)

# Keep a handle on the callback so the best score can be reported afterwards,
# as the other experiments do.
checkpoint_callback_resnet = ModelCheckpoint(dirpath='./checkpoints/resnet50', monitor='val/auc', mode='max')

trainer_resnet = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,
    # Follow the global mixed-precision switch instead of hard-coding 16-bit.
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback_resnet]
)

print("ðŸš€ Training ResNet-50 baseline...")
trainer_resnet.fit(lit_resnet, train_dataloaders=train_loader_nih, val_dataloaders=val_loader_nih)
print(f"Best val AUC: {checkpoint_callback_resnet.best_model_score:.4f}")

In [None]:
# DenseNet-121 baseline (trained on the same few-shot NIH loaders as the ViT adaptation runs)
densenet_model = build_cnn('densenet121', num_classes=config.num_classes, pretrained=True)
lit_densenet = LitModel(densenet_model, config)

# Keep a handle on the callback so the best score can be reported afterwards,
# as the other experiments do.
checkpoint_callback_densenet = ModelCheckpoint(dirpath='./checkpoints/densenet121', monitor='val/auc', mode='max')

trainer_densenet = pl.Trainer(
    max_epochs=config.epochs,
    accelerator='gpu',
    devices=-1,
    # Follow the global mixed-precision switch instead of hard-coding 16-bit.
    precision=16 if config.mixed_precision else 32,
    callbacks=[checkpoint_callback_densenet]
)

print("ðŸš€ Training DenseNet-121 baseline...")
trainer_densenet.fit(lit_densenet, train_dataloaders=train_loader_nih, val_dataloaders=val_loader_nih)
print(f"Best val AUC: {checkpoint_callback_densenet.best_model_score:.4f}")

## 9. Results Comparison and Visualization

In [None]:
# Collect results
# NOTE(review): the baseline score was measured on the CheXpert validation
# loader, whereas the other three were measured on the NIH validation loader,
# so the baseline row is not directly comparable with the rest.
results = {
    method: callback.best_model_score.item()
    for method, callback in (
        ('Baseline ViT', checkpoint_callback),
        ('LoRA', checkpoint_callback_lora),
        ('Adapter', checkpoint_callback_adapter),
        ('Prompt Tuning', checkpoint_callback_prompt),
    )
}

# Parameter efficiency
param_counts = {
    method: counts['trainable']
    for method, counts in (
        ('Baseline ViT', params),
        ('LoRA', lora_params),
        ('Adapter', adapter_params),
        ('Prompt Tuning', prompt_params),
    )
}

# Create comparison DataFrame: one row per method, best validation AUC first.
# The efficiency column is expressed relative to the baseline's total
# parameter count.
comparison_df = (
    pd.DataFrame({
        'Method': list(results.keys()),
        'Val AUC': list(results.values()),
        'Trainable Params': list(param_counts.values()),
    })
    .assign(**{
        'Param Efficiency (%)': lambda frame: 100.0 * frame['Trainable Params'] / params['total'],
    })
    .sort_values('Val AUC', ascending=False)
)

banner = "=" * 70
print("\n" + banner)
print("FINAL RESULTS COMPARISON")
print(banner)
print(comparison_df.to_string(index=False))
print(banner)

In [None]:
# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

bar_colors = ['blue', 'green', 'orange', 'red']
# (DataFrame column, y-axis label, panel title) for each subplot, left to right:
# AUC comparison first, parameter efficiency second.
panel_specs = [
    ('Val AUC', 'Validation AUC', 'Performance Comparison'),
    ('Param Efficiency (%)', 'Trainable Parameters (%)', 'Parameter Efficiency'),
]

for ax, (column, y_label, title) in zip(axes, panel_specs):
    ax.bar(comparison_df['Method'], comparison_df[column], color=bar_colors)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nâœ“ Results saved to results_comparison.png")

## 10. Test Set Evaluation with Confidence Intervals

In [None]:
# Load test dataset (optional: the whole evaluation is skipped when the CSV is absent)
if os.path.exists(NIH_TEST):
    nih_test = SimpleMedicalDataset(NIH_TEST, DATA_ROOT, transform=get_transforms(config.img_size, False))
    test_loader = torch.utils.data.DataLoader(nih_test, batch_size=config.batch_size,
                                              shuffle=False, num_workers=config.num_workers)

    # Evaluate the LoRA model with the weights it currently holds in memory
    # (no checkpoint is reloaded here).
    print("Evaluating LoRA model on test set...")
    test_results = trainer_lora.test(lit_lora, dataloaders=test_loader)

    # Get predictions for bootstrap CI.
    # Place the model explicitly: the device the module is left on after
    # Lightning's fit()/test() is not guaranteed, and inputs and weights must
    # live on the same device. Falls back to CPU when CUDA is unavailable.
    eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lora_model.to(eval_device)
    lora_model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            imgs = batch['image'].to(eval_device)
            labels = batch['labels'].numpy()
            logits = lora_model(imgs)
            # Independent per-label probabilities (sigmoid rather than softmax).
            preds = torch.sigmoid(logits).cpu().numpy()
            all_preds.append(preds)
            all_labels.append(labels)

    # Stack the per-batch arrays row-wise into one array each.
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    # Compute metrics with CI: each bootstrap draw is scored by the 'auc_roc'
    # entry that compute_metrics returns.
    ci_results = bootstrap_confidence_interval(all_preds, all_labels,
                                               lambda p, l: compute_metrics(p, l)['auc_roc'],
                                               n_bootstrap=1000)

    print("\n" + "="*70)
    print("TEST SET RESULTS (LoRA Model)")
    print("="*70)
    print(f"AUC-ROC: {ci_results['mean']:.4f} (95% CI: [{ci_results['ci_lower']:.4f}, {ci_results['ci_upper']:.4f}])")
    print("="*70)
else:
    print("Test set not found, skipping test evaluation")

## 11. Save Final Report

In [None]:
# Save comparison results
comparison_df.to_csv('results_comparison.csv', index=False)

# Closing summary and suggested follow-ups, printed one message per line.
closing_messages = (
    "\nâœ“ All experiments complete!",
    "âœ“ Results saved to results_comparison.csv",
    "âœ“ Checkpoints saved in ./checkpoints/",
    "\nNext steps:",
    "1. Download checkpoints and results for further analysis",
    "2. Experiment with different k values (few_shot_k)",
    "3. Try different LoRA ranks and adapter dimensions",
    "4. Compare cross-domain generalization (CheXpertâ†’NIH vs NIHâ†’CheXpert)",
)
for message in closing_messages:
    print(message)