# ForestGaps - Training Complet avec Validation Externe

**Workflow complet**: Train → Eval → Inference sur données indépendantes

Ce notebook permet de:
- Entraîner un modèle de détection de trouées forestières
- Valider sur données de test (train/val/test split)
- Valider sur données externes `/data/data_external_test`
- Visualiser avec TensorBoard
- Choisir entre config test (rapide) ou production (complète)

---

## 1️⃣ Sélection de Configuration

**Choisissez votre configuration:**
- `quick`: Test rapide (5 epochs, 50 tiles) - 2-5 minutes
- `production`: Training complet (50 epochs, toutes donn√©es) - plusieurs heures

In [None]:
# ========================================
# CONFIGURATION - edit the values below
# ========================================

# Which configuration set the notebook loads: "quick" or "production".
CONFIG_TYPE = "quick"

# External validation data (always read from this folder on the mounted Drive).
EXTERNAL_DATA_DIR = "/content/drive/MyDrive/forestgaps/data/data_external_test"

# Echo both settings so the selection is visible in the cell output.
print(
    f"‚úì Configuration s√©lectionn√©e: {CONFIG_TYPE.upper()}",
    f"‚úì Donn√©es externes: {EXTERNAL_DATA_DIR}",
    sep="\n",
)

## 2️⃣ Setup Colab

In [None]:
# Check the GPU: nvidia-smi prints the NVIDIA devices visible to this runtime
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the system dependencies (GDAL tools, headers and Python bindings)
!apt-get update -qq
!apt-get install -y -qq gdal-bin libgdal-dev python3-gdal

# Install the ForestGaps package directly from its GitHub repository
!pip install -q git+https://github.com/arthur048/forestgaps.git

print("‚úì Installation termin√©e!")

In [None]:
# Charger TensorBoard
%load_ext tensorboard

import os
os.makedirs("/content/logs", exist_ok=True)
os.makedirs("/content/checkpoints", exist_ok=True)

print("‚úì TensorBoard pr√™t!")

## 3️⃣ Training Complet

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

from forestgaps.config import (
    load_training_config,
    load_data_config,
    load_model_config,
)
from forestgaps.models import create_model
from forestgaps.training.losses import ComboLoss
from forestgaps.training.optimization import create_scheduler, TrainingOptimizer

# Plot styling: seaborn "whitegrid" theme and a 12x6 default figure size
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

print("‚úì Imports OK")

In [None]:
# Load the training / data / model configurations for the selected CONFIG_TYPE.
# Each entry maps a configuration type to its (training, data, model) YAML paths.
_CONFIG_FILES = {
    "quick": (
        "configs/test/quick.yaml",
        "configs/test/data_quick.yaml",
        "configs/test/model_quick.yaml",
    ),
    "production": (
        "configs/production/default.yaml",
        "configs/production/data_default.yaml",
        "configs/defaults/model.yaml",
    ),
}

# Fail fast on an unknown value: silently falling back to the production files
# on a typo (e.g. "Quick") would start the long training run by accident.
if CONFIG_TYPE not in _CONFIG_FILES:
    raise ValueError(
        f"Unknown CONFIG_TYPE {CONFIG_TYPE!r}; expected one of {sorted(_CONFIG_FILES)}"
    )

_training_yaml, _data_yaml, _model_yaml = _CONFIG_FILES[CONFIG_TYPE]
training_config = load_training_config(_training_yaml)
data_config = load_data_config(_data_yaml)
model_config = load_model_config(_model_yaml)

# Summarize the settings that drive the rest of the notebook.
print(f"\n‚úì Configurations charg√©es: {CONFIG_TYPE}")
print(f"  - Training: {training_config.epochs} epochs")
print(f"  - Model: {model_config.model_type}")
print(f"  - Loss: {training_config.loss.type}")
print(f"  - Scheduler: {training_config.scheduler.type}")

In [None]:
# Build dummy data for the example
# (replace with your real data when available)
def create_dummy_data(num_samples, tile_size=256):
    """Return a TensorDataset of random single-channel tiles and binary masks.

    Args:
        num_samples: Number of (tile, mask) pairs to generate.
        tile_size: Height and width of each square tile.

    Returns:
        TensorDataset whose items are ``(dsm_tile, gap_mask)`` pairs, each of
        shape ``(1, tile_size, tile_size)``; mask values are 0.0 or 1.0.
    """
    shape = (num_samples, 1, tile_size, tile_size)
    # Tiles are drawn before masks so the random generator is consumed in a fixed order.
    tiles = torch.randn(shape)
    masks = torch.randint(low=0, high=2, size=shape).to(torch.float32)
    return TensorDataset(tiles, masks)

# Number of dummy tiles per split; the literal fallbacks apply when the
# training config does not define the corresponding attribute.
max_train, max_val, max_test = (
    getattr(training_config, _attr, _fallback)
    for _attr, _fallback in (
        ('max_train_tiles', 100),
        ('max_val_tiles', 20),
        ('max_test_tiles', 20),
    )
)

# Datasets are generated in train, val, test order.
train_dataset, val_dataset, test_dataset = (
    create_dummy_data(_count) for _count in (max_train, max_val, max_test)
)

# Only the training loader shuffles; val and test both use the validation batch size.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=training_config.batch_size,
)
val_loader, test_loader = (
    DataLoader(_split, shuffle=False, batch_size=training_config.val_batch_size)
    for _split in (val_dataset, test_dataset)
)

print(
    f"‚úì DataLoaders cr√©√©s:",
    f"  - Train: {len(train_loader)} batches",
    f"  - Val: {len(val_loader)} batches",
    f"  - Test: {len(test_loader)} batches",
    sep="\n",
)

In [None]:
# Build the model
# Config name -> registry name: "unet_film" is requested from the registry as "film_unet".
model_type = (
    "film_unet" if model_config.model_type == "unet_film" else model_config.model_type
)

model_kwargs = dict(
    in_channels=model_config.in_channels,
    out_channels=model_config.out_channels,
)

# Model-specific parameters: the U-Net variants take the width as
# `init_features` (FiLM additionally needs `condition_size`); every other
# model receives it as `base_channels`.
if model_type in ("unet", "film_unet"):
    model_kwargs["init_features"] = model_config.base_channels
    if model_type == "film_unet":
        model_kwargs["condition_size"] = model_config.num_conditions
else:
    model_kwargs["base_channels"] = model_config.base_channels

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(model_type, **model_kwargs).to(device)

print(f"‚úì Model: {model_config.model_type} sur {device}")
print(f"  - Param√®tres: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Loss function: the weighted combo loss when configured, otherwise BCE-with-logits
_loss_cfg = training_config.loss
if _loss_cfg.type != "combo":
    criterion = nn.BCEWithLogitsLoss()
    print("‚úì BCE Loss")
else:
    _combo_weights = {
        "bce_weight": _loss_cfg.bce_weight,
        "dice_weight": _loss_cfg.dice_weight,
        "focal_weight": _loss_cfg.focal_weight,
    }
    criterion = ComboLoss(**_combo_weights)
    print(f"‚úì Combo Loss (BCE={training_config.loss.bce_weight}, Dice={training_config.loss.dice_weight}, Focal={training_config.loss.focal_weight})")

# Optimizer: AdamW (with weight decay) when configured, plain Adam otherwise
_optim_cfg = training_config.optimizer
_optim_kwargs = {"lr": _optim_cfg.lr}
if _optim_cfg.type == "adamw":
    _optim_cls = torch.optim.AdamW
    _optim_kwargs["weight_decay"] = _optim_cfg.weight_decay
else:
    _optim_cls = torch.optim.Adam
optimizer = _optim_cls(model.parameters(), **_optim_kwargs)

print(f"‚úì Optimizer: {training_config.optimizer.type}")

# Scheduler
# Convert the scheduler config to a plain dict. `model_dump()` is tried first:
# Pydantic v2 models still expose a deprecated `dict()`, so checking for
# `dict` first would never reach `model_dump()` and would emit a deprecation
# warning. Objects without `model_dump()` fall back to `dict()`.
# NOTE(review): assumes the scheduler config is a Pydantic model - confirm.
_scheduler_cfg = training_config.scheduler
if hasattr(_scheduler_cfg, 'model_dump'):
    scheduler_dict = _scheduler_cfg.model_dump()
else:
    scheduler_dict = _scheduler_cfg.dict()

# The scheduler factory also receives the run length (batches per epoch and
# number of epochs).
scheduler = create_scheduler(
    optimizer,
    scheduler_dict,
    steps_per_epoch=len(train_loader),
    epochs=training_config.epochs,
)
print(f"‚úì Scheduler: {training_config.scheduler.type}")

# Training optimizer helper (AMP + gradient clipping), configured from the
# `optimization` section of the training config
_optimization_cfg = training_config.optimization
_training_opt_kwargs = {
    "gradient_clip_value": _optimization_cfg.gradient_clip_value,
    "gradient_clip_norm": _optimization_cfg.gradient_clip_norm,
    "use_amp": _optimization_cfg.use_amp,
    "accumulate_grad_batches": _optimization_cfg.accumulate_grad_batches,
    "device": str(device),
}
training_opt = TrainingOptimizer(**_training_opt_kwargs)
print(f"‚úì AMP: {training_config.optimization.use_amp}, Grad clip: {training_config.optimization.gradient_clip_norm}")

In [None]:
# Launch TensorBoard inline, reading event files from /content/logs
%tensorboard --logdir /content/logs

In [None]:
# PHASE 1: TRAINING
# Trains for `training_config.epochs` epochs, validates after each epoch and
# saves the model weights whenever the validation loss improves.
print("="*80)
print("PHASE 1: TRAINING")
print("="*80)


def _forward_pass(net, batch):
    """Return the model output for `batch`, adding the extra FiLM input when needed."""
    if 'film' in net.__class__.__name__.lower():
        # FiLM models take a second, per-sample input; a constant 5.0 is fed here.
        # NOTE(review): presumably a height threshold - confirm against the model.
        threshold = torch.full((batch.shape[0], 1), 5.0, device=device)
        return net(batch, threshold)
    return net(batch)


def _loss_value(predictions, labels):
    """Return the loss tensor; ComboLoss yields a pair whose first item is the loss."""
    if isinstance(criterion, ComboLoss):
        value, _ = criterion(predictions, labels)
        return value
    return criterion(predictions, labels)


history = {'train_loss': [], 'val_loss': [], 'lr': []}
best_val_loss = float('inf')
# Stays None until a checkpoint is written, so "no improvement at all" can be
# reported explicitly instead of failing later with a NameError.
best_epoch = None

# Scheduler traits are fixed for the whole run, so resolve them once.
_scheduler_name = scheduler.__class__.__name__
_scheduler_can_step = hasattr(scheduler, 'step')
_steps_per_batch = 'OneCycleLR' in _scheduler_name
_steps_on_metric = 'ReduceLROnPlateau' in _scheduler_name

for epoch in range(training_config.epochs):
    # Training
    model.train()
    train_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # The forward pass and loss run inside the optimizer helper's context.
        with training_opt.forward_context():
            outputs = _forward_pass(model, inputs)
            loss = _loss_value(outputs, targets)

        training_opt.backward_step(loss, optimizer, model.parameters())
        train_loss += loss.item()

        # OneCycleLR is stepped inside the batch loop, and only when the
        # gradient accumulator reports that a step is due.
        if (
            _scheduler_can_step
            and training_opt.accumulator.should_step()
            and _steps_per_batch
        ):
            scheduler.step()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = _forward_pass(model, inputs)
            loss = _loss_value(outputs, targets)
            val_loss += loss.item()

    # Average the accumulated batch losses over the number of batches.
    train_loss /= len(train_loader)
    val_loss /= len(val_loader)
    # The learning rate is recorded before the per-epoch scheduler step below.
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['lr'].append(current_lr)

    # Per-epoch scheduler step (every scheduler except OneCycleLR);
    # ReduceLROnPlateau is given the validation loss as its metric.
    if _scheduler_can_step and not _steps_per_batch:
        if _steps_on_metric:
            scheduler.step(val_loss)
        else:
            scheduler.step()

    # Checkpoint on strict improvement of the validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), "/content/checkpoints/best_model.pt")

    print(f"Epoch {epoch+1}/{training_config.epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}, LR: {current_lr:.6f}")

# A NaN/inf validation loss (or zero epochs) never satisfies the comparison
# above, which leaves no checkpoint for the later cells to load.
if best_epoch is None:
    raise RuntimeError(
        "Training finished without saving a checkpoint: the validation loss never "
        "improved (no epochs were run, or the loss was NaN/inf)."
    )

print(f"\n‚úì Training termin√©! Meilleur val loss: {best_val_loss:.4f} (epoch {best_epoch+1})")

### 📊 Visualisation des Courbes de Training

In [None]:
# Training plots: loss curves on the left, learning-rate schedule on the right
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
_loss_ax, _lr_ax = axes

# Epochs are shown 1-based on the x axis.
epochs_range = range(1, len(history['train_loss']) + 1)

# Loss curves, with a dashed vertical line at the best epoch
_loss_ax.plot(epochs_range, history['train_loss'], 'b-o', label='Train Loss', linewidth=2, markersize=6)
_loss_ax.plot(epochs_range, history['val_loss'], 'r-s', label='Val Loss', linewidth=2, markersize=6)
_loss_ax.axvline(x=best_epoch+1, color='g', linestyle='--', linewidth=2, label=f'Best Epoch ({best_epoch+1})')
_loss_ax.set_title(f'Training & Validation Loss - {model_config.model_type.upper()}', fontsize=14, fontweight='bold')
_loss_ax.set_xlabel('Epoch', fontsize=12)
_loss_ax.set_ylabel('Loss', fontsize=12)
_loss_ax.legend(fontsize=10)
_loss_ax.grid(True, alpha=0.3)

# Learning-rate schedule, drawn on a logarithmic y axis
_lr_ax.plot(epochs_range, history['lr'], 'g-o', linewidth=2, markersize=6)
_lr_ax.set_title(f'Learning Rate Schedule - {training_config.scheduler.type.upper()}', fontsize=14, fontweight='bold')
_lr_ax.set_xlabel('Epoch', fontsize=12)
_lr_ax.set_ylabel('Learning Rate', fontsize=12)
_lr_ax.grid(True, alpha=0.3)
_lr_ax.set_yscale('log')

# Save the figure to disk before displaying it.
plt.tight_layout()
plt.savefig('/content/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Graphiques sauvegard√©s: /content/training_curves.png")

## 4️⃣ Evaluation sur Test Set

In [None]:
# PHASE 2: EVALUATION on the test data
print("="*80)
print("PHASE 2: EVALUATION SUR TEST SET")
print("="*80)

# Load the best model. `map_location` remaps the saved tensors onto the
# current device, so the checkpoint still loads when it was written from a
# different device (e.g. saved on GPU, reloaded on a CPU runtime).
model.load_state_dict(
    torch.load("/content/checkpoints/best_model.pt", map_location=device)
)
model.eval()

test_loss = 0.0
all_preds = []
all_targets = []

# The model class does not change between batches, so check for FiLM once.
_is_film = 'film' in model.__class__.__name__.lower()

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        if _is_film:
            # FiLM models take a second, per-sample input (constant 5.0 here).
            threshold = torch.full((inputs.shape[0], 1), 5.0, device=device)
            outputs = model(inputs, threshold)
        else:
            outputs = model(inputs)

        # ComboLoss yields a pair; only its first item (the loss) is used.
        if isinstance(criterion, ComboLoss):
            loss, _ = criterion(outputs, targets)
        else:
            loss = criterion(outputs, targets)
        test_loss += loss.item()

        # Binarize the sigmoid probabilities at 0.5.
        preds = torch.sigmoid(outputs) > 0.5
        all_preds.append(preds.cpu())
        all_targets.append(targets.cpu())

# Mean loss over batches, then flatten the per-batch lists into numpy arrays.
test_loss /= len(test_loader)
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

# Metrics: element-wise confusion counts over all test predictions
_pred_pos, _pred_neg = (all_preds == 1), (all_preds == 0)
_true_pos, _true_neg = (all_targets == 1), (all_targets == 0)

tp = np.sum(_pred_pos & _true_pos)
fp = np.sum(_pred_pos & _true_neg)
fn = np.sum(_pred_neg & _true_pos)
tn = np.sum(_pred_neg & _true_neg)

# Each ratio falls back to 0 when its denominator is empty.
if (tp + fp) > 0:
    precision = tp / (tp + fp)
else:
    precision = 0

if (tp + fn) > 0:
    recall = tp / (tp + fn)
else:
    recall = 0

if (precision + recall) > 0:
    f1 = 2 * (precision * recall) / (precision + recall)
else:
    f1 = 0

if (tp + fp + fn) > 0:
    iou = tp / (tp + fp + fn)
else:
    iou = 0

accuracy = (tp + tn) / (tp + tn + fp + fn)

# Keep the metrics together for the plots and the final summary
test_metrics = dict(zip(
    ('Accuracy', 'Precision', 'Recall', 'F1-Score', 'IoU'),
    (accuracy, precision, recall, f1, iou),
))

print(f"\nR√©sultats Test Set:")
print(f"  Test Loss: {test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

### 📊 Visualisation des Métriques et Matrice de Confusion

In [None]:
# Metric plots: horizontal bar chart on the left, confusion matrix on the right
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
_bar_ax, _cm_ax = axes

# Bar chart of the metrics, each bar colored by its value on the RdYlGn colormap
metrics_names = list(test_metrics.keys())
metrics_values = list(test_metrics.values())
colors = plt.cm.RdYlGn(list(metrics_values))

bars = _bar_ax.barh(metrics_names, metrics_values, color=colors, edgecolor='black', linewidth=1.5)
_bar_ax.set_title('Test Set Metrics', fontsize=14, fontweight='bold')
_bar_ax.set_xlabel('Score', fontsize=12)
_bar_ax.set_xlim(0, 1)
_bar_ax.grid(axis='x', alpha=0.3)

# Write each value just to the right of its bar
for i, value in enumerate(metrics_values):
    _bar_ax.text(value + 0.02, i, f'{value:.3f}', va='center', fontsize=10, fontweight='bold')

# Confusion matrix: rows are the actual class, columns the predicted class
confusion_matrix = np.array([[tn, fp], [fn, tp]])
sns.heatmap(
    confusion_matrix,
    ax=_cm_ax,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=['Predicted Negative', 'Predicted Positive'],
    yticklabels=['Actual Negative', 'Actual Positive'],
    cbar_kws={'label': 'Count'},
    linewidths=2,
    linecolor='black',
)
_cm_ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
_cm_ax.set_ylabel('True Label', fontsize=12)
_cm_ax.set_xlabel('Predicted Label', fontsize=12)

# Save the figure to disk before displaying it.
plt.tight_layout()
plt.savefig('/content/test_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Graphiques sauvegard√©s: /content/test_metrics.png")

## 5️⃣ Validation sur Données Externes

Validation sur `/data/data_external_test` (données jamais vues pendant training)

In [None]:
# PHASE 3: VALIDATION on external data
print("="*80)
print("PHASE 3: VALIDATION DONN√âES EXTERNES")
print(f"R√©pertoire: {EXTERNAL_DATA_DIR}")
print("="*80)

# Inference only runs when the external data folder exists
external_path = Path(EXTERNAL_DATA_DIR)
if not external_path.exists():
    print(f"‚ö† Donn√©es externes non trouv√©es: {EXTERNAL_DATA_DIR}")
    print("  Assurez-vous que le dossier existe dans Google Drive")
else:
    print(f"‚úì Donn√©es trouv√©es: {EXTERNAL_DATA_DIR}")

    # TODO: load the real external data
    # For now, dummy data stands in for it
    external_dataset = create_dummy_data(10)
    external_loader = DataLoader(external_dataset, shuffle=False, batch_size=4)

    model.eval()
    _needs_threshold = 'film' in type(model).__name__.lower()
    _probabilities = []

    with torch.no_grad():
        # Labels are ignored: this phase only collects sigmoid probabilities.
        for inputs, _ in external_loader:
            inputs = inputs.to(device)
            if not _needs_threshold:
                outputs = model(inputs)
            else:
                # FiLM models take a second, per-sample input (constant 5.0 here).
                threshold = torch.full((inputs.shape[0], 1), 5.0, device=device)
                outputs = model(inputs, threshold)
            _probabilities.append(torch.sigmoid(outputs).cpu())

    external_preds = torch.cat(_probabilities)

    print(f"\n‚úì Inference termin√©e sur donn√©es externes!")
    print(f"  Nombre de pr√©dictions: {external_preds.shape[0]}")
    print(f"  Shape: {external_preds.shape}")
    print(f"  Pr√©diction moyenne: {external_preds.mean():.4f}")
    print(f"  Pr√©diction min/max: {external_preds.min():.4f} / {external_preds.max():.4f}")

## 6️⃣ Résumé Final

In [None]:
# Final report: configuration, training losses, test metrics and output files.
# The lines are collected first and printed in a single call.
_rule = "="*80
_report = [
    _rule,
    "R√âSUM√â COMPLET",
    _rule,
    f"Configuration: {CONFIG_TYPE}",
    f"Model: {model_config.model_type}",
    f"Epochs: {training_config.epochs}",
    f"\nTraining:",
    f"  - Final train loss: {history['train_loss'][-1]:.4f}",
    f"  - Final val loss: {history['val_loss'][-1]:.4f}",
    f"  - Best val loss: {best_val_loss:.4f} (epoch {best_epoch+1})",
    f"\nTest Set:",
]
# One line per test metric, in the order they were stored.
_report.extend(
    f"  - {metric_name}: {metric_value:.4f}"
    for metric_name, metric_value in test_metrics.items()
)
_report += [
    f"\nFichiers g√©n√©r√©s:",
    f"  - Mod√®le: /content/checkpoints/best_model.pt",
    f"  - Graphiques: /content/training_curves.png",
    f"  - M√©triques: /content/test_metrics.png",
    _rule,
]
print("\n".join(_report))

---

## 📥 Télécharger les Résultats

Pour télécharger les fichiers générés:

In [None]:
# Download the generated files through the Colab file helper
from google.colab import files

print("T√©l√©chargement des fichiers...")
# Best checkpoint first, then the two saved figures.
for _artifact in (
    '/content/checkpoints/best_model.pt',
    '/content/training_curves.png',
    '/content/test_metrics.png',
):
    files.download(_artifact)
print("‚úì T√©l√©chargements termin√©s!")