# ForestGaps - Benchmark Modèles avec Validation Externe

**Workflow**: Benchmark multiple models → Compare → Validate on external data

Ce notebook permet de:
- Comparer plusieurs modèles (UNet, FiLM-UNet, DeepLabV3+)
- Évaluer sur test set
- Valider sur données externes `/data/data_external_test`
- Visualiser comparaisons avec graphiques
- TensorBoard pour chaque modèle
- Choisir entre config test (rapide) ou production (complète)

---

## 1️⃣ Configuration

**Choisissez:**
- Config: `quick` (5 epochs) ou `production` (50 epochs)
- Modèles à comparer

In [None]:
# ========================================
# CONFIGURATION - edit the values below
# ========================================

# Training preset; accepted values are "quick" and "production".
CONFIG_TYPE = "quick"

# Architectures to benchmark; options: "unet", "film_unet", "deeplabv3_plus".
MODELS_TO_TEST = [
    "unet",
    "film_unet",
]

# Folder with the external validation data, under the mounted Drive path.
EXTERNAL_DATA_DIR = "/content/drive/MyDrive/forestgaps/data/data_external_test"

# Echo the selection so the choice is visible in the cell output.
print("‚úì Configuration: " + CONFIG_TYPE.upper())
print("‚úì Mod√®les √† tester: " + ', '.join(MODELS_TO_TEST))
print("‚úì Donn√©es externes: " + EXTERNAL_DATA_DIR)

## 2️⃣ Setup Colab

In [None]:
# Show the GPU (if any) attached to this runtime.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; EXTERNAL_DATA_DIR points below this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Installation
# System GDAL packages (CLI tools, headers, Python bindings), installed quietly.
!apt-get update -qq
!apt-get install -y -qq gdal-bin libgdal-dev python3-gdal
# The forestgaps package, installed directly from its GitHub repository.
!pip install -q git+https://github.com/arthur048/forestgaps.git
# Plotting and tabular libraries used by the comparison cells.
!pip install -q matplotlib seaborn pandas

print("‚úì Installation termin√©e!")

In [None]:
%load_ext tensorboard

import os
# Create the working folders for logs, checkpoints and result figures.
for _workdir in ("/content/logs", "/content/checkpoints", "/content/results"):
    os.makedirs(_workdir, exist_ok=True)

print("‚úì R√©pertoires cr√©√©s!")

## 3️⃣ Benchmark des Modèles

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import defaultdict

from forestgaps.config import (
    load_training_config,
    load_data_config,
    load_model_config,
)
from forestgaps.models import create_model
from forestgaps.training.losses import ComboLoss
from forestgaps.training.optimization import create_scheduler, TrainingOptimizer

# Apply seaborn's "whitegrid" style to the figures drawn in later cells.
sns.set_style("whitegrid")
print("‚úì Imports OK")

In [None]:
# Load the training and data configs for the selected preset.
# An explicit table makes a typo in CONFIG_TYPE fail loudly instead of silently
# falling through to the production preset.
# NOTE(review): these paths are relative to the current working directory —
# confirm the YAML files are present there when running on Colab.
_CONFIG_PATHS = {
    "quick": ("configs/test/quick.yaml", "configs/test/data_quick.yaml"),
    "production": (
        "configs/production/default.yaml",
        "configs/production/data_default.yaml",
    ),
}
if CONFIG_TYPE not in _CONFIG_PATHS:
    raise ValueError(
        f"Unknown CONFIG_TYPE {CONFIG_TYPE!r}; expected one of {sorted(_CONFIG_PATHS)}"
    )

_training_path, _data_path = _CONFIG_PATHS[CONFIG_TYPE]
training_config = load_training_config(_training_path)
data_config = load_data_config(_data_path)

print(f"‚úì Config {CONFIG_TYPE} charg√©e")
print(f"  - {training_config.epochs} epochs")
print(f"  - Loss: {training_config.loss.type}")

In [None]:
# Create the datasets (random placeholder tiles)
def create_dummy_data(num_samples, tile_size=256):
    """Build a synthetic dataset of random tiles paired with random binary masks.

    Args:
        num_samples: Number of (tile, mask) pairs to generate.
        tile_size: Height and width of each square tile.

    Returns:
        A ``TensorDataset`` of standard-normal tiles and 0/1 float masks, both
        shaped ``(num_samples, 1, tile_size, tile_size)``.
    """
    shape = (num_samples, 1, tile_size, tile_size)
    # Tiles are drawn before masks, so a seeded run stays reproducible.
    tiles = torch.randn(*shape)
    masks = torch.randint(low=0, high=2, size=shape).to(torch.float32)
    return TensorDataset(tiles, masks)

# Tile budget per split; the defaults apply when the training config omits them.
max_train, max_val, max_test = (
    getattr(training_config, attr_name, fallback)
    for attr_name, fallback in (
        ('max_train_tiles', 100),
        ('max_val_tiles', 20),
        ('max_test_tiles', 20),
    )
)

# Datasets are generated in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    create_dummy_data(tile_count) for tile_count in (max_train, max_val, max_test)
)

# Only the training loader shuffles; val and test share the validation batch size.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=training_config.batch_size,
)
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=training_config.val_batch_size,
)
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=training_config.val_batch_size,
)

print(f"‚úì Data: {len(train_loader)} train / {len(val_loader)} val / {len(test_loader)} test batches")

In [None]:
# Fonctions de training et eval
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"‚úì Device: {device}")

def train_model(model, model_name, train_loader, val_loader, epochs):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Loss, optimizer, scheduler and optimisation settings are read from the
    notebook-level ``training_config``; batches are moved to the notebook-level
    ``device``.

    Args:
        model: Network to train. If its class name contains "film" it is called
            as ``model(inputs, threshold)`` with a constant 5.0 threshold per
            sample; otherwise as ``model(inputs)``.
        model_name: Label used in log lines and in the checkpoint file name
            (``/content/checkpoints/{model_name}_best.pt``).
        train_loader: Sized iterable of ``(inputs, targets)`` training batches.
        val_loader: Sized iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(history, best_val_loss)``, where ``history`` maps ``'train_loss'``
        and ``'val_loss'`` to lists of per-epoch mean losses.

    Raises:
        ValueError: If either loader has no batches.
    """
    # Fail before any work is done; the per-epoch means below divide by these.
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        raise ValueError(
            f"Cannot train {model_name}: got {num_train_batches} training and "
            f"{num_val_batches} validation batches; both loaders must be non-empty."
        )

    # Loss: the configured ComboLoss, or BCE-with-logits for any other loss type.
    if training_config.loss.type == "combo":
        criterion = ComboLoss(
            bce_weight=training_config.loss.bce_weight,
            dice_weight=training_config.loss.dice_weight,
            focal_weight=training_config.loss.focal_weight,
        )
    else:
        criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=training_config.optimizer.lr,
        weight_decay=training_config.optimizer.weight_decay,
    )

    # Presumably a pydantic model: v2 offers model_dump() and deprecates dict(),
    # v1 only has dict(). Prefer the non-deprecated accessor when it exists.
    scheduler_cfg = training_config.scheduler
    if hasattr(scheduler_cfg, 'model_dump'):
        scheduler_dict = scheduler_cfg.model_dump()
    else:
        scheduler_dict = scheduler_cfg.dict()
    # NOTE(review): the scheduler is built but never stepped in this function —
    # confirm whether it should be advanced per batch or per epoch.
    scheduler = create_scheduler(optimizer, scheduler_dict, num_train_batches, epochs)

    training_opt = TrainingOptimizer(
        gradient_clip_value=training_config.optimization.gradient_clip_value,
        gradient_clip_norm=training_config.optimization.gradient_clip_norm,
        use_amp=training_config.optimization.use_amp,
        accumulate_grad_batches=training_config.optimization.accumulate_grad_batches,
        device=str(device),
    )

    # Both checks are loop-invariant, so evaluate them once.
    is_film = 'film' in model.__class__.__name__.lower()
    loss_is_tuple = isinstance(criterion, ComboLoss)

    def compute_loss(inputs, targets):
        """Run the forward pass and return the scalar loss tensor."""
        if is_film:
            threshold = torch.full((inputs.shape[0], 1), 5.0, device=device)
            outputs = model(inputs, threshold)
        else:
            outputs = model(inputs)
        if loss_is_tuple:
            # ComboLoss returns a pair; only the first element is the loss.
            loss, _ = criterion(outputs, targets)
            return loss
        return criterion(outputs, targets)

    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")

    # Log roughly five times over the whole run.
    log_every = max(1, epochs // 5)

    for epoch in range(epochs):
        # Training pass
        model.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with training_opt.forward_context():
                loss = compute_loss(inputs, targets)
            training_opt.backward_step(loss, optimizer, model.parameters())
            train_loss += loss.item()

        # Validation pass
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                val_loss += compute_loss(inputs, targets).item()

        train_loss /= num_train_batches
        val_loss /= num_val_batches
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        # Keep only the weights of the best epoch so far on disk.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"/content/checkpoints/{model_name}_best.pt")

        if (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

    print(f"‚úì {model_name} termin√©! Best val: {best_val_loss:.4f}")
    return history, best_val_loss

def evaluate_model(model, test_loader):
    """Compute binary segmentation metrics for ``model`` over ``test_loader``.

    Outputs are treated as logits: a pixel is predicted positive when
    ``sigmoid(output) > 0.5``. Confusion counts are accumulated batch by batch,
    so memory use does not grow with the size of the test set.

    Args:
        model: Network to evaluate. If its class name contains "film" it is
            called as ``model(inputs, threshold)`` with a constant 5.0 threshold
            per sample; otherwise as ``model(inputs)``.
        test_loader: Iterable of ``(inputs, targets)`` batches. Targets equal to
            1 count as positives and targets equal to 0 as negatives.

    Returns:
        Dict with float values for ``'accuracy'``, ``'precision'``,
        ``'recall'``, ``'f1'`` and ``'iou'``. A metric whose denominator is zero
        (including every metric on an empty loader) is reported as 0.0.
    """
    model.eval()

    # Evaluate on the device the model lives on; a model without parameters
    # falls back to the notebook-level `device`.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    is_film = 'film' in model.__class__.__name__.lower()
    tp = fp = fn = tn = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(eval_device), targets.to(eval_device)
            if is_film:
                threshold = torch.full((inputs.shape[0], 1), 5.0, device=eval_device)
                outputs = model(inputs, threshold)
            else:
                outputs = model(inputs)

            preds = torch.sigmoid(outputs) > 0.5
            positives = targets == 1
            negatives = targets == 0

            tp += int((preds & positives).sum())
            fp += int((preds & negatives).sum())
            fn += int((~preds & positives).sum())
            tn += int((~preds & negatives).sum())

    def ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator > 0 else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)

    return {
        'accuracy': ratio(tp + tn, tp + tn + fp + fn),
        'precision': precision,
        'recall': recall,
        'f1': ratio(2 * precision * recall, precision + recall),
        'iou': ratio(tp, tp + fp + fn),
    }

In [None]:
# BENCHMARK ALL MODELS
# For each entry of MODELS_TO_TEST: build the model, train it, evaluate it on
# the test loader, and store history / best validation loss / metrics.
results = {}

for model_type in MODELS_TO_TEST:
    print(f"\n{'='*80}")
    print(f"BENCHMARK: {model_type.upper()}")
    print(f"{'='*80}")

    # Load the model config for the selected preset.
    if CONFIG_TYPE == "quick":
        if model_type == "unet":
            model_config = load_model_config("configs/test/model_minimal.yaml")
        else:
            model_config = load_model_config("configs/test/model_quick.yaml")
            model_config.model_type = model_type
    else:
        model_config = load_model_config("configs/defaults/model.yaml")
        model_config.model_type = model_type

    # Constructor arguments; the name of the width argument depends on the
    # architecture.
    model_kwargs = {
        "in_channels": model_config.in_channels,
        "out_channels": model_config.out_channels,
    }

    # "unet_film" is accepted as an alias of the "film_unet" registry key.
    registry_type = "film_unet" if model_type == "unet_film" else model_type

    if model_type == "unet":
        model_kwargs["init_features"] = model_config.base_channels
    elif model_type in ["film_unet", "unet_film"]:
        model_kwargs["init_features"] = model_config.base_channels
        model_kwargs["condition_size"] = model_config.num_conditions
    else:
        model_kwargs["base_channels"] = model_config.base_channels

    model = create_model(registry_type, **model_kwargs)
    model = model.to(device)

    print(f"‚úì Model cr√©√©: {sum(p.numel() for p in model.parameters()):,} params")

    # Train
    history, best_val = train_model(
        model, model_type, train_loader, val_loader, training_config.epochs
    )

    # train_model checkpoints the lowest-validation-loss weights but leaves the
    # last-epoch weights in `model`. Restore the checkpoint so the test metrics
    # describe the same weights as `best_val`. A non-finite `best_val` means no
    # checkpoint was written during this run, so nothing is loaded.
    best_checkpoint = Path(f"/content/checkpoints/{model_type}_best.pt")
    if np.isfinite(best_val) and best_checkpoint.exists():
        model.load_state_dict(torch.load(best_checkpoint, map_location=device))

    # Evaluate
    metrics = evaluate_model(model, test_loader)

    # Store the results
    results[model_type] = {
        'history': history,
        'best_val_loss': best_val,
        'metrics': metrics
    }

    print(f"\n‚úì {model_type} Results:")
    print(f"  - Best Val Loss: {best_val:.4f}")
    print(f"  - Test Accuracy: {metrics['accuracy']:.4f}")
    print(f"  - Test F1: {metrics['f1']:.4f}")
    print(f"  - Test IoU: {metrics['iou']:.4f}")

print(f"\n{'='*80}")
print("‚úì BENCHMARK TERMIN√â!")
print(f"{'='*80}")

## 4️⃣ Comparaison Visuelle des Modèles

In [None]:
# Loss comparison figure: per-epoch curves (left) and best validation loss (right)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
curve_ax, best_ax = axes[0], axes[1]

# Left panel: one train curve and one validation curve per model
for model_name, data in results.items():
    curves = data['history']
    curve_ax.plot(
        curves['train_loss'],
        label=f"{model_name} (train)",
        marker='o',
        markersize=4,
    )
    curve_ax.plot(
        curves['val_loss'],
        label=f"{model_name} (val)",
        marker='s',
        markersize=4,
        linestyle='--',
    )

curve_ax.set_xlabel("Epoch", fontsize=12)
curve_ax.set_ylabel("Loss", fontsize=12)
curve_ax.set_title("Training & Validation Loss Comparison", fontsize=14, fontweight='bold')
curve_ax.legend(loc='best')
curve_ax.grid(True, alpha=0.3)

# Right panel: one bar per model, coloured along the viridis map
model_names = list(results)
best_vals = [entry['best_val_loss'] for entry in results.values()]
colors = plt.cm.viridis(np.linspace(0, 0.8, len(model_names)))

bars = best_ax.bar(model_names, best_vals, color=colors, alpha=0.8)
best_ax.set_ylabel("Best Val Loss", fontsize=12)
best_ax.set_title("Best Validation Loss per Model", fontsize=14, fontweight='bold')
best_ax.grid(axis='y', alpha=0.3)

# Write each value just above its bar
for bar, val in zip(bars, best_vals):
    x_center = bar.get_x() + bar.get_width()/2.
    best_ax.text(
        x_center,
        bar.get_height(),
        f'{val:.4f}',
        ha='center',
        va='bottom',
        fontweight='bold',
    )

plt.tight_layout()
plt.savefig("/content/results/loss_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Graphique sauvegard√©: /content/results/loss_comparison.png")

In [None]:
# Metrics comparison: one row per model, one column per metric
metrics_by_model = {}
for model, data in results.items():
    metrics_by_model[model] = data['metrics']
metrics_df = pd.DataFrame(metrics_by_model).transpose()

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
heat_ax, bar_ax = axes[0], axes[1]

# Left panel: annotated heatmap on a fixed 0-1 colour scale
sns.heatmap(
    metrics_df,
    annot=True,
    fmt='.3f',
    cmap='RdYlGn',
    vmin=0,
    vmax=1,
    ax=heat_ax,
    cbar_kws={'label': 'Score'},
)
heat_ax.set_title("Metrics Heatmap", fontsize=14, fontweight='bold')
heat_ax.set_xlabel("Metrics")
heat_ax.set_ylabel("Models")

# Right panel: bars grouped by model
metrics_df.plot(kind='bar', ax=bar_ax, width=0.8, alpha=0.8)
bar_ax.set_title("Metrics Comparison", fontsize=14, fontweight='bold')
bar_ax.set_xlabel("Models")
bar_ax.set_ylabel("Score")
bar_ax.set_ylim(0, 1)
bar_ax.legend(title="Metrics", loc='lower right')
bar_ax.grid(axis='y', alpha=0.3)
bar_ax.set_xticklabels(bar_ax.get_xticklabels(), rotation=45, ha='right')

plt.tight_layout()
plt.savefig("/content/results/metrics_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Graphique sauvegard√©: /content/results/metrics_comparison.png")

# Print the table
print("\nüìä Table des r√©sultats:")
print(metrics_df.round(4))

In [None]:
# Radar chart pour comparaison globale
from math import pi

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

# One spoke per metric, evenly spaced around the circle
metrics_names = list(metrics_df.columns)
num_vars = len(metrics_names)
angles = [spoke / float(num_vars) * 2 * pi for spoke in range(num_vars)]
# Repeat the first angle so each polygon closes on itself
angles = angles + angles[:1]

colors_radar = plt.cm.Set2(range(len(results)))

# One closed polygon per model (one row of metrics_df each)
for (model_name, row), model_color in zip(metrics_df.iterrows(), colors_radar):
    polygon = row.values.tolist()
    polygon.extend(polygon[:1])
    ax.plot(angles, polygon, 'o-', linewidth=2, label=model_name, color=model_color)
    ax.fill(angles, polygon, alpha=0.15, color=model_color)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metrics_names, size=11)
ax.set_ylim(0, 1)
ax.set_title("Model Performance Radar Chart", size=16, fontweight='bold', pad=20)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
ax.grid(True)

plt.tight_layout()
plt.savefig("/content/results/radar_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Radar chart sauvegard√©: /content/results/radar_comparison.png")

## 5️⃣ Meilleur Modèle

In [None]:
# Pick the model with the highest test F1 score
best_model_name = max(results, key=lambda name: results[name]['metrics']['f1'])
best_metrics = results[best_model_name]['metrics']

separator = "=" * 60
print(separator)
print(f"üèÜ MEILLEUR MOD√àLE: {best_model_name.upper()}")
print(separator)
# Labels are left-padded to a common width so the values line up in a column.
for metric_label, metric_key in (
    ("Accuracy:", 'accuracy'),
    ("Precision:", 'precision'),
    ("Recall:", 'recall'),
    ("F1-Score:", 'f1'),
    ("IoU:", 'iou'),
):
    print(f"{metric_label:<10} {best_metrics[metric_key]:.4f}")
print(separator)

## 6️⃣ Validation sur Données Externes

Test du meilleur modèle sur `/data/data_external_test`

In [None]:
# External validation with the best model
print("="*80)
print("VALIDATION SUR DONN√âES EXTERNES")
print(f"Mod√®le: {best_model_name}")
print(f"R√©pertoire: {EXTERNAL_DATA_DIR}")
print("="*80)

external_path = Path(EXTERNAL_DATA_DIR)
if not external_path.exists():
    print(f"‚ö† Donn√©es non trouv√©es: {EXTERNAL_DATA_DIR}")
elif model_type != best_model_name:
    # `model` is whatever the benchmark loop built last. If that is not the
    # best architecture, running it here would report results for the wrong
    # model, so skip and say so.
    print(f"‚úì Donn√©es trouv√©es!")
    print(
        f"External inference skipped: the model in memory is '{model_type}', "
        f"not the best model '{best_model_name}'. Re-create '{best_model_name}' "
        f"and load /content/checkpoints/{best_model_name}_best.pt to run it."
    )
else:
    print(f"‚úì Donn√©es trouv√©es!")

    # TODO: load the real external data; random placeholder tiles for now
    external_dataset = create_dummy_data(10)
    external_loader = DataLoader(external_dataset, batch_size=4, shuffle=False)

    # Restore the best-epoch weights saved during training, when available.
    best_checkpoint = Path(f"/content/checkpoints/{best_model_name}_best.pt")
    if best_checkpoint.exists():
        model.load_state_dict(torch.load(best_checkpoint, map_location=device))
    model.eval()

    # Per-batch probabilities, kept on the CPU
    external_preds = []
    with torch.no_grad():
        for inputs, _ in external_loader:
            inputs = inputs.to(device)
            if 'film' in best_model_name:
                threshold = torch.full((inputs.shape[0], 1), 5.0, device=device)
                outputs = model(inputs, threshold)
            else:
                outputs = model(inputs)
            external_preds.append(torch.sigmoid(outputs).cpu())

    print(f"\n‚úì Inference termin√©e sur donn√©es externes!")
    print(f"  (TODO: Impl√©menter chargement vraies donn√©es)")

## 7️⃣ Résumé Final

In [None]:
# Final recap: run settings, winning model, and where the artefacts were saved
summary_rule = "=" * 80
summary_lines = [
    summary_rule,
    "R√âSUM√â BENCHMARK COMPLET",
    summary_rule,
    f"Configuration: {CONFIG_TYPE}",
    f"Mod√®les test√©s: {len(results)}",
    f"Epochs: {training_config.epochs}",
    f"\nüèÜ Meilleur mod√®le: {best_model_name}",
    f"  - F1-Score: {best_metrics['f1']:.4f}",
    f"  - IoU: {best_metrics['iou']:.4f}",
    "\nFichiers sauvegard√©s:",
    "  - Mod√®les: /content/checkpoints/",
    "  - Graphiques: /content/results/",
    summary_rule,
]
print("\n".join(summary_lines))