# Pansharpening Toolkit - Quick Start Guide

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Osman-Geomatics93/pansharpening-toolkit-/blob/main/notebooks/01_quick_start.ipynb)

This notebook demonstrates the basic usage of the pansharpening toolkit.

## Contents
1. Setup & Installation (for Colab)
2. Create Sample Data
3. Model Inference
4. Compare All Models
5. Loss Functions

## 1. Setup & Installation

Run this cell first. On Google Colab it installs the toolkit from GitHub; when run locally it instead adds the parent directory to the Python import path.

In [None]:
# Install from GitHub (for Google Colab)
import subprocess
import sys

# Detect Google Colab by probing for its runtime package.
try:
    import google.colab  # noqa: F401  (imported only as a presence check)
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if not IN_COLAB:
    # Local run: put the parent of the working directory first on the
    # import path instead of installing anything.
    sys.path.insert(0, '..')
else:
    print("Running in Google Colab - Installing pansharpening toolkit...")
    # Quiet pip install from GitHub, run with the current interpreter.
    pip_install = [
        sys.executable, "-m", "pip", "install", "-q",
        "git+https://github.com/Osman-Geomatics93/pansharpening-toolkit-.git",
    ]
    subprocess.check_call(pip_install)

# Import libraries
import torch
import numpy as np
import matplotlib.pyplot as plt

from models import create_model, create_loss, AVAILABLE_MODELS
from utils import compute_all_metrics

# Report the PyTorch build, then whether CUDA is usable (queried once and
# reused for the GPU-name branch).
print(f"PyTorch version: {torch.__version__}")
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
# List the model names imported from the toolkit.
print(f"\nAvailable models: {AVAILABLE_MODELS}")

## 2. Create Sample Data

For demonstration, we'll create synthetic satellite imagery.

In [None]:
# Create synthetic data
def create_synthetic_data(height=256, width=256, ms_bands=4, scale=4):
    """Create synthetic PAN and MS images for testing.

    Args:
        height: Number of rows in the PAN image.
        width: Number of columns in the PAN image.
        ms_bands: Number of MS bands to generate.
        scale: Integer resolution ratio between PAN and MS. The MS grid is
            ``height // scale`` by ``width // scale``. Defaults to 4.

    Returns:
        Tuple ``(pan, ms)`` of float arrays clipped to [0, 1]. ``pan`` has
        shape ``(1, height, width)``; ``ms`` has shape
        ``(ms_bands, height // scale, width // scale)``.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale}")

    span = 4 * np.pi  # both axes run over [0, 4*pi]

    # PAN: full-resolution grid; base pattern plus a 3x-frequency detail term.
    X, Y = np.meshgrid(np.linspace(0, span, width),
                       np.linspace(0, span, height))
    pan = 0.5 + 0.3 * np.sin(X) * np.cos(Y) + 0.2 * np.sin(3*X) * np.sin(3*Y)
    pan = np.clip(pan, 0, 1)

    # MS: the coarse grid is identical for every band, so build it once
    # rather than once per band.
    X_lr, Y_lr = np.meshgrid(np.linspace(0, span, width // scale),
                             np.linspace(0, span, height // scale))

    # Band i is the base pattern shifted along X by i * pi / ms_bands;
    # broadcasting the phases over the grid produces all bands at once.
    phases = np.arange(ms_bands) * np.pi / ms_bands
    shifted_x = X_lr[np.newaxis] + phases[:, np.newaxis, np.newaxis]
    ms = 0.5 + 0.3 * np.sin(shifted_x) * np.cos(Y_lr)
    ms = np.clip(ms, 0, 1)

    return pan[np.newaxis], ms

# Build the default synthetic pair and report both array shapes, one per line.
pan, ms = create_synthetic_data()
print(f"PAN shape: {pan.shape}", f"MS shape: {ms.shape}", sep="\n")

In [None]:
# Visualize the data: one row of three panels, described as
# (image, title, extra imshow keyword arguments).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

panels = [
    (pan[0], 'PAN (High Resolution)', {'cmap': 'gray'}),
    # First three MS bands moved to channels-last for an RGB display.
    (ms[:3].transpose(1, 2, 0), 'MS - RGB (Low Resolution)', {}),
    (ms[3], 'MS - NIR Band', {'cmap': 'Reds'}),
]

for panel, (image, title, options) in zip(axes, panels):
    panel.imshow(image, **options)
    panel.set_title(title)
    panel.axis('off')

plt.tight_layout()
plt.show()

## 3. Model Inference

Let's run a single model on our synthetic data; the next section compares several architectures.

In [None]:
from scipy.ndimage import zoom

# Upsample MS to PAN resolution with cubic spline interpolation.
# Zoom factors are computed per axis as floats so the result matches the PAN
# grid even when the height and width ratios differ or are not whole numbers
# (a single floor-divided height ratio would produce a mismatched shape).
scale = pan.shape[1] // ms.shape[1]  # nominal integer ratio, kept for reference
zoom_factors = (pan.shape[1] / ms.shape[1], pan.shape[2] / ms.shape[2])
ms_up = np.array([zoom(band, zoom_factors, order=3) for band in ms])
print(f"Upsampled MS shape: {ms_up.shape}")

In [None]:
# Build one architecture by name and switch it to evaluation mode.
model_name = 'pannet_cbam'
model = create_model(model_name, ms_bands=4)
model.eval()


def _as_batch(array):
    """Wrap a NumPy array as a float tensor with a leading axis of size 1."""
    return torch.from_numpy(array).unsqueeze(0).float()


# Convert both inputs to tensors.
ms_tensor = _as_batch(ms_up)
pan_tensor = _as_batch(pan)

print(f"Model: {model_name}")
param_total = sum(weight.numel() for weight in model.parameters())
print(f"Parameters: {param_total:,}")

# Run inference without gradient tracking; drop the leading axis and
# convert the result back to NumPy.
with torch.no_grad():
    fused = model(ms_tensor, pan_tensor).squeeze(0).numpy()

print(f"Output shape: {fused.shape}")

In [None]:
# Visualize results: both inputs next to the model output, described as
# (image, title, extra imshow keyword arguments).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

panels = [
    (ms_up[:3].transpose(1, 2, 0), 'MS Upsampled (Input)', {}),
    (pan[0], 'PAN (Input)', {'cmap': 'gray'}),
    # Output values are clipped to [0, 1] before display.
    (np.clip(fused[:3].transpose(1, 2, 0), 0, 1),
     f'{model_name.upper()} (Output)', {}),
]

for panel, (image, title, options) in zip(axes, panels):
    panel.imshow(image, **options)
    panel.set_title(title)
    panel.axis('off')

plt.tight_layout()
plt.show()

## 4. Compare All Models

In [None]:
# Compare different models on the same input tensors.
# The loop uses its own names (arch, candidate, output) so it does not
# overwrite the `model_name`, `model` and `fused` globals set by the
# single-model cell; re-running that cell's plot after this one would
# otherwise pair a stale title with the wrong output.
results = {}

for arch in ['pnn', 'pannet', 'pannet_cbam', 'panformer_lite']:
    candidate = create_model(arch, ms_bands=4)
    candidate.eval()

    # Inference only: no gradient tracking; store the output as NumPy
    # with the leading axis removed.
    with torch.no_grad():
        output = candidate(ms_tensor, pan_tensor)
        results[arch] = output.squeeze(0).numpy()

    n_params = sum(p.numel() for p in candidate.parameters())
    print(f"{arch:15s}: {n_params:>10,} parameters")

In [None]:
# Visualize all results on a 2x3 grid: the two inputs take the first two
# slots and each model output takes the next slot in row-major order.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

input_panels = [
    (ms_up[:3].transpose(1, 2, 0), 'MS Upsampled', {}),
    (pan[0], 'PAN', {'cmap': 'gray'}),
]
for slot, (image, title, options) in enumerate(input_panels):
    panel = axes.flat[slot]
    panel.imshow(image, **options)
    panel.set_title(title)
    panel.axis('off')

# Model outputs continue right after the inputs; values are clipped to
# [0, 1] before display.
for slot, (name, result) in enumerate(results.items(), start=len(input_panels)):
    panel = axes.flat[slot]
    panel.imshow(np.clip(result[:3].transpose(1, 2, 0), 0, 1))
    panel.set_title(name.upper())
    panel.axis('off')

plt.tight_layout()
plt.show()

## 5. Loss Functions

In [None]:
# Demonstrate loss functions
from models import CombinedLoss, AdvancedCombinedLoss

# Random tensors of shape (1, 4, 64, 64) standing in for a prediction and
# its target.
sample_shape = (1, 4, 64, 64)
pred = torch.randn(*sample_shape)
target = torch.randn(*sample_shape)

# Basic combined loss: the call returns the total and a dict of named values.
loss_fn = CombinedLoss()
loss, loss_dict = loss_fn(pred, target)
print("CombinedLoss:")
print(f"  Total: {loss.item():.4f}")
for term, value in loss_dict.items():
    print(f"  {term}: {value:.4f}")

In [None]:
# Advanced loss with SSIM and SAM; the weights are collected in one mapping
# and passed by keyword.
loss_weights = {
    'l1_weight': 1.0,
    'mse_weight': 0.5,
    'gradient_weight': 0.1,
    'ssim_weight': 0.2,
    'sam_weight': 0.1,
}
advanced_loss = AdvancedCombinedLoss(**loss_weights)

# Same call convention as the basic loss: total plus a dict of named values.
loss, loss_dict = advanced_loss(pred, target)
print("\nAdvancedCombinedLoss:")
print(f"  Total: {loss.item():.4f}")
for term, value in loss_dict.items():
    print(f"  {term}: {value:.4f}")

## Summary

This notebook demonstrated:
- Creating and using pansharpening models
- Comparing different architectures
- Using various loss functions

For full training, see `run_deep_learning.py` or the training notebook.