# Model Architecture Comparison

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Osman-Geomatics93/pansharpening-toolkit-/blob/main/notebooks/02_model_comparison.ipynb)

This notebook provides a detailed comparison of all available pansharpening models.

## Models Covered
1. **PNN** - Basic 3-layer CNN
2. **PanNet** - ResNet-style with high-pass filtering
3. **DRPNN** - Deep Residual Pan-sharpening Neural Network
4. **PanNetCBAM** - PanNet with CBAM attention
5. **MultiScalePanNet** - Feature pyramid architecture
6. **PanFormer** - Transformer-based
7. **PanFormerLite** - Lightweight transformer

In [None]:
# Bootstrap the toolkit: pip-install it from GitHub on Colab, otherwise
# fall back to importing from the directory above the working directory.
import subprocess
import sys

# Colab is detected by whether its runtime package can be imported
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if not IN_COLAB:
    # Local run: put the parent of the working directory first on the import path
    sys.path.insert(0, '..')
else:
    print("Running in Google Colab - Installing pansharpening toolkit...")
    # Use the running interpreter's own pip so the package lands in this kernel
    pip_install = [sys.executable, "-m", "pip", "install", "-q",
                   "git+https://github.com/Osman-Geomatics93/pansharpening-toolkit-.git"]
    subprocess.check_call(pip_install)

import torch
import numpy as np
import matplotlib.pyplot as plt
from time import time

from models import (
    PNN, PanNet, DRPNN, PanNetCBAM, 
    MultiScalePanNet, PanFormer, PanFormerLite,
    create_model, AVAILABLE_MODELS
)

# Report the installed PyTorch version and whether CUDA is available
print(
    f"PyTorch: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## Model Statistics

In [None]:
# Gather size statistics for every architecture listed in AVAILABLE_MODELS
stats = []

for model_name in AVAILABLE_MODELS:
    model = create_model(model_name, ms_bands=4)

    # One pass over the parameters: (element count, requires_grad flag)
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    n_params = sum(count for count, _ in sizes)
    n_trainable = sum(count for count, trainable in sizes if trainable)

    # Rough weight footprint, assuming 4 bytes (float32) per parameter
    memory_mb = n_params * 4 / 2**20

    record = {
        'Model': model_name,
        'Parameters': n_params,
        'Trainable': n_trainable,
        'Memory (MB)': memory_mb,
    }
    stats.append(record)

# Print a fixed-width table: name, parameter count, estimated memory
header = f"{'Model':<18} {'Parameters':>12} {'Memory (MB)':>12}"
print(header)
print("-" * 45)
for row in stats:
    name, count, megabytes = row['Model'], row['Parameters'], row['Memory (MB)']
    print(f"{name:<18} {count:>12,} {megabytes:>12.2f}")

In [None]:
# Visualize parameter counts as a bar chart (one bar per entry in `stats`)
fig, ax = plt.subplots(figsize=(10, 5))

models = [s['Model'] for s in stats]
params = [s['Parameters'] / 1000 for s in stats]  # in thousands

bars = ax.bar(models, params, color='steelblue', edgecolor='navy')
ax.set_ylabel('Parameters (K)')
ax.set_title('Model Parameter Comparison')
# Rotate the tick labels the bar plot already created instead of calling
# set_xticklabels(): that method is documented as discouraged unless the tick
# positions were fixed first, because it replaces the axis formatter.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

# Add value labels just above each bar (offset is in y-axis units, i.e. K params)
for bar, param in zip(bars, params):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 10,
            f'{param:.0f}K', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

## Inference Speed Comparison

In [None]:
# Benchmark inference speed: average forward-pass latency per model, in ms.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations; time() is wall-clock time, which can be adjusted mid-run
# and may tick too coarsely for millisecond-scale measurements.
from time import perf_counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
use_cuda = device.type == 'cuda'

# Create test input directly on the target device (no host-to-device copy)
ms = torch.randn(1, 4, 256, 256, device=device)
pan = torch.randn(1, 1, 256, 256, device=device)

speed_stats = []
n_runs = 10

for model_name in AVAILABLE_MODELS:
    model = create_model(model_name, ms_bands=4).to(device)
    model.eval()

    with torch.no_grad():
        # Warmup: the first calls are not representative, so they are not timed
        for _ in range(3):
            _ = model(ms, pan)

        # CUDA work is queued asynchronously; drain the warmup work before
        # starting the clock so it is not charged to the timed runs.
        if use_cuda:
            torch.cuda.synchronize()

        start = perf_counter()
        for _ in range(n_runs):
            _ = model(ms, pan)
            # Wait for the GPU to finish so each run's full cost is measured
            if use_cuda:
                torch.cuda.synchronize()
        elapsed = (perf_counter() - start) / n_runs * 1000  # ms

    speed_stats.append({'Model': model_name, 'Time (ms)': elapsed})
    print(f"{model_name:<18}: {elapsed:>8.2f} ms")

In [None]:
# Horizontal bar chart of the per-model latencies recorded in `speed_stats`
fig, ax = plt.subplots(figsize=(10, 5))

# Split the benchmark records into parallel label / value lists
models, times = [], []
for entry in speed_stats:
    models.append(entry['Model'])
    times.append(entry['Time (ms)'])

bars = ax.barh(models, times, color='coral', edgecolor='darkred')
ax.set_title(f'Inference Speed Comparison (256x256, {device})')
ax.set_xlabel('Inference Time (ms)')

# Write each latency just right of its bar, vertically centred on the bar
for bar, t in zip(bars, times):
    label_x = bar.get_width() + 0.5
    label_y = bar.get_y() + bar.get_height()/2
    ax.text(label_x, label_y, f'{t:.1f}ms', va='center', fontsize=9)

plt.tight_layout()
plt.show()

## Architecture Details

In [None]:
# Build the PanNet-with-CBAM model and print its textual representation
model_name = 'pannet_cbam'
model = create_model(model_name, ms_bands=4)
banner = f"=== {model_name.upper()} ==="
print(banner, model, sep="\n")

In [None]:
# Build the lightweight transformer variant and print its textual representation
model = create_model('panformer_lite', ms_bands=4)
print("=== PANFORMER_LITE ===", model, sep="\n")

## Summary Table

In [None]:
# Qualitative summary of each architecture, keyed by its registry name.
# Every entry carries the same four fields, which are the table columns below.
summary = {
    'pnn': {'Type': 'CNN', 'Attention': 'No', 'Multi-scale': 'No', 'Best For': 'Baseline'},
    'pannet': {'Type': 'ResNet', 'Attention': 'No', 'Multi-scale': 'No', 'Best For': 'General use'},
    'drpnn': {'Type': 'Deep ResNet', 'Attention': 'No', 'Multi-scale': 'No', 'Best For': 'Complex scenes'},
    'pannet_cbam': {'Type': 'ResNet+Attn', 'Attention': 'CBAM', 'Multi-scale': 'No', 'Best For': 'Balanced'},
    'mspannet': {'Type': 'FPN', 'Attention': 'CBAM', 'Multi-scale': 'Yes', 'Best For': 'Multi-scale features'},
    'panformer': {'Type': 'Transformer', 'Attention': 'Self+Cross', 'Multi-scale': 'No', 'Best For': 'Best quality'},
    'panformer_lite': {'Type': 'Window Trans.', 'Attention': 'Window', 'Multi-scale': 'No', 'Best For': 'Efficient transformer'},
}

print(f"{'Model':<16} {'Type':<14} {'Attention':<12} {'Multi-scale':<12} {'Best For':<20}")
print("-" * 75)
# The loop variable is deliberately not called `model`: notebook cells share one
# namespace, and that name holds the model object built in an earlier cell.
for arch, info in summary.items():
    print(f"{arch:<16} {info['Type']:<14} {info['Attention']:<12} {info['Multi-scale']:<12} {info['Best For']:<20}")

## Recommendations

| Use Case | Recommended Model |
|----------|------------------|
| Quick prototyping | `pnn` |
| General purpose | `pannet` |
| Best quality | `panformer_lite` (with 100+ epochs) |
| Limited compute | `pannet_cbam` |
| Research/SOTA | `panformer` |