# 02 - CycleGAN Training

This notebook trains a CycleGAN to learn the bidirectional mapping
between normal agricultural images and images under drought conditions.

**Pipeline:**
1. Prepare the data (domain A = normal, domain B = simulated drought)
2. Initialize the CycleGAN with checkpoints saved to Drive
3. Train with auto-resume (about 3 min 10 s per epoch with AMP on a Colab T4 in the run below, i.e. roughly 5.3 h for 100 epochs)
4. Visualize the results
5. Save the final model to Drive

In [1]:
"""
Google Drive setup for saving outputs.
The notebooks live on GitHub; only checkpoints and results are stored in Drive.
"""

from google.colab import drive
import os
import sys

# Mount Drive
drive.mount('/content/drive', force_remount=True)

# Root path for outputs only
DRIVE_OUTPUTS = "/content/drive/MyDrive/SatelliteGAN-Outputs"

# Create the output directory structure
for subdir in [
    'data/eurosat', 'data/processed_drought',
    'cyclegan/checkpoints', 'cyclegan/generated_images', 'cyclegan/losses',
    'diffusion/checkpoints', 'diffusion/samples', 'diffusion/losses',
    'evaluation/metrics', 'evaluation/comparisons', 'evaluation/figures',
]:
    os.makedirs(f"{DRIVE_OUTPUTS}/{subdir}", exist_ok=True)

print(f"Drive mounted: {DRIVE_OUTPUTS}")
print("Output structure created")

# Clone the GitHub repo (source code) and install its dependencies
if not os.path.exists('/content/SatelliteGAN-Climate-Agriculture'):
    !git clone https://github.com/aymenssf/SatelliteGAN-Climate-Agriculture.git /content/SatelliteGAN-Climate-Agriculture
    !pip install -q -r /content/SatelliteGAN-Climate-Agriculture/requirements.txt

%cd /content/SatelliteGAN-Climate-Agriculture
sys.path.insert(0, '/content/SatelliteGAN-Climate-Agriculture')

print("Source code loaded from GitHub")
print("Outputs will be saved to Drive")

Mounted at /content/drive
Drive mounted: /content/drive/MyDrive/SatelliteGAN-Outputs
Output structure created
Cloning into '/content/SatelliteGAN-Climate-Agriculture'...
remote: Enumerating objects: 58, done.
remote: Counting objects: 100% (58/58), done.
remote: Compressing objects: 100% (45/45), done.
Receiving objects: 100% (58/58), 1.19 MiB | 4.11 MiB/s, done.
remote: Total 58 (delta 19), reused 42 (delta 12), pack-reused 0 (from 0)
Resolving deltas: 100% (19/19), done.
/content/SatelliteGAN-Climate-Agriculture
Source code loaded from GitHub
Outputs will be saved to Drive


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from src.config import DEVICE, CYCLEGAN, IMAGE_SIZE
from src.dataset import get_agricultural_dataset, split_dataset
from src.preprocessing import (
    get_cyclegan_transform, get_eval_transform,
    simulate_drought, denormalize, tensor_to_numpy
)
from src.cyclegan.train import CycleGANTrainer
from src.evaluation.visualization import (
    show_cyclegan_results, show_comparison, plot_training_losses
)

print(f"Device: {DEVICE}")
print(f"CycleGAN configuration: {CYCLEGAN}")

Device: cuda
CycleGAN configuration: {'n_residual_blocks': 6, 'batch_size': 4, 'lr': 0.0002, 'betas': (0.5, 0.999), 'lambda_cycle': 10.0, 'lambda_identity': 5.0, 'n_epochs': 100, 'decay_epoch': 50, 'save_every': 10, 'replay_buffer_size': 50}
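The configuration sets `lr=0.0002`, `n_epochs=100` and `decay_epoch=50`. In the original CycleGAN recipe, values like these mean the learning rate stays constant for the first 50 epochs and then decays linearly to zero by epoch 100. The sketch below assumes `CycleGANTrainer` follows that recipe, which has not been confirmed from `src/cyclegan/train.py`. `linear_decay_lr` is a hypothetical helper written only for illustration.

```python
def linear_decay_lr(epoch: int, base_lr: float = 2e-4,
                    n_epochs: int = 100, decay_epoch: int = 50) -> float:
    """Learning rate for a 0-indexed epoch under an assumed
    'constant, then linear decay to zero' schedule (hypothetical helper)."""
    if epoch < decay_epoch:
        return base_lr
    # Fraction of the decay phase already completed, clamped to [0, 1]
    progress = min(1.0, (epoch - decay_epoch) / (n_epochs - decay_epoch))
    return base_lr * (1.0 - progress)


if __name__ == "__main__":
    for e in (0, 49, 50, 75, 99):
        print(e, linear_decay_lr(e))
```

In PyTorch, the same shape can be expressed as a multiplicative factor with `torch.optim.lr_scheduler.LambdaLR`, for example `lr_lambda=lambda e: linear_decay_lr(e) / 2e-4`.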


## 1. Data preparation

We create two domains:
- **Domain A**: normal agricultural images (EuroSAT)
- **Domain B**: agricultural images with simulated drought

CycleGAN does not require paired images. Here, however, both domains are built
from the same dataset so that the results can be compared.
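One detail is worth keeping in mind. `TensorDataset(train_a, train_b)` (used below) keeps row `i` of domain A together with row `i` of domain B, and `shuffle=True` permutes those rows jointly. Each batch therefore still contains aligned, effectively paired images. For genuinely unpaired training, one option is to draw an independent permutation of domain B at every epoch. The helper below is a hypothetical plain-Python sketch of that index pairing. In the notebook, it could back a small `torch.utils.data.Dataset` whose `__getitem__` returns `(train_a[i], train_b[j])`.

```python
import random
from typing import List, Optional, Tuple


def unpaired_index_pairs(n_a: int, n_b: int,
                         seed: Optional[int] = None) -> List[Tuple[int, int]]:
    """Pair each domain-A index with a domain-B index taken from an
    independent random permutation (hypothetical helper)."""
    rng = random.Random(seed)
    order_a = list(range(n_a))
    rng.shuffle(order_a)
    order_b = list(range(n_b))
    rng.shuffle(order_b)
    # Wrap around the B permutation when domain B is smaller than domain A
    return [(i, order_b[k % n_b]) for k, i in enumerate(order_a)]
```

Calling it with a new seed at each epoch yields a fresh A/B pairing, so the generators never see the same aligned pair twice in a row.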

In [3]:
from torchvision import transforms
from PIL import Image

# Load the raw dataset (no normalization) so the drought simulation can be applied to it
raw_transform = transforms.Resize(IMAGE_SIZE)
raw_dataset = get_agricultural_dataset(transform=raw_transform)

# Train / validation / test split
train_set, val_set, _ = split_dataset(raw_dataset)

print(f"Training images: {len(train_set)}")

100%|██████████| 94.3M/94.3M [00:00<00:00, 95.4MB/s]


Agricultural EuroSAT: 10500 images, 4 classes ['AnnualCrop', 'PermanentCrop', 'Pasture', 'HerbaceousVegetation']
Split: train=8400, val=1050, test=1050
Training images: 8400


In [4]:
# Build the tensors for both domains
cyclegan_transform = get_cyclegan_transform()
eval_transform = get_eval_transform()

def prepare_paired_batch(dataset, transform, n_max=None):
    """
    Build the two CycleGAN domains from a dataset (first n_max images).
    Domain A = normalized normal images
    Domain B = normalized images with simulated drought
    Returns two stacked tensors of shape (N, C, H, W).
    """
    images_a = []
    images_b = []
    n = min(len(dataset), n_max) if n_max else len(dataset)

    for i in range(n):
        img, _ = dataset[i]

        # Convert to PIL if the image is a tensor
        if isinstance(img, torch.Tensor):
            img_pil = transforms.ToPILImage()(img)
        else:
            img_pil = img

        # Domain A: normal
        img_a = transform(img_pil)

        # Domain B: simulated drought
        drought_pil = simulate_drought(img_pil, severity=0.6)
        img_b = transform(drought_pil)

        images_a.append(img_a)
        images_b.append(img_b)

    return torch.stack(images_a), torch.stack(images_b)

print("Preparing training data...")
train_a, train_b = prepare_paired_batch(train_set, cyclegan_transform)
print(f"Domain A: {train_a.shape}")
print(f"Domain B: {train_b.shape}")

Preparing training data...
Domain A: torch.Size([8400, 3, 64, 64])
Domain B: torch.Size([8400, 3, 64, 64])


In [8]:
# Create the DataLoader (tuned for GPU training)
train_dataset = TensorDataset(train_a, train_b)
train_loader = DataLoader(
    train_dataset,
    batch_size=CYCLEGAN['batch_size'],
    shuffle=True,
    drop_last=True,
    num_workers=2,            # Load batches in parallel worker processes
    pin_memory=True,          # Page-locked memory for faster host-to-GPU copies
    prefetch_factor=2,        # Each worker preloads 2 batches
    persistent_workers=True   # Keep workers alive between epochs
)

print(f"Batches per epoch: {len(train_loader)}")
print("DataLoader configured (num_workers=2, pin_memory=True)")

Batches per epoch: 2100
DataLoader configured (num_workers=2, pin_memory=True)


## 2. CycleGAN initialization

`save_dir` points to Google Drive so that checkpoints,
generated images and the loss history survive Colab disconnections.

If a checkpoint already exists in Drive, training automatically resumes
from the latest checkpoint (auto-resume).
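The auto-resume logic lives inside `CycleGANTrainer`, and its implementation is not shown in this notebook. Below is a minimal sketch of the idea. It assumes checkpoints are named `epoch_<N>.pth`, as in the training log below, and selects the file with the highest epoch number. `find_latest_checkpoint` is a hypothetical helper.

```python
import os
import re
from typing import Optional, Tuple

# Matches file names such as "epoch_40.pth" (naming assumed from the training log)
_CKPT_PATTERN = re.compile(r"^epoch_(\d+)\.pth$")


def find_latest_checkpoint(ckpt_dir: str) -> Optional[Tuple[int, str]]:
    """Return (epoch, path) of the highest-numbered epoch_<N>.pth file,
    or None if the directory is missing or contains no matching file."""
    if not os.path.isdir(ckpt_dir):
        return None
    best = None
    for name in os.listdir(ckpt_dir):
        match = _CKPT_PATTERN.match(name)
        if match:
            epoch = int(match.group(1))
            # Compare numerically so that epoch_100 beats epoch_90
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(ckpt_dir, name))
    return best
```

For example, `find_latest_checkpoint(f"{DRIVE_OUTPUTS}/cyclegan/checkpoints")` would return the resume point. The numeric comparison matters because a plain string sort ranks `epoch_90.pth` after `epoch_100.pth`.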

In [6]:
# Create the trainer, saving to Drive
trainer = CycleGANTrainer(
    save_dir=f"{DRIVE_OUTPUTS}/cyclegan"  # checkpoints, images, losses -> Drive
)

# Count parameters (the total assumes G_B2A and D_B have the same size as G_A2B and D_A)
n_params_G = sum(p.numel() for p in trainer.G_A2B.parameters())
n_params_D = sum(p.numel() for p in trainer.D_A.parameters())
print(f"Parameters per generator: {n_params_G:,}")
print(f"Parameters per discriminator: {n_params_D:,}")
print(f"Total (2 generators + 2 discriminators): {2*n_params_G + 2*n_params_D:,}")

Parameters per generator: 7,833,987
Parameters per discriminator: 2,763,841
Total (2 generators + 2 discriminators): 21,195,656
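The printed counts can be reproduced by hand. The sketch below assumes the classic CycleGAN architectures:

- A ResNet generator with 64 base filters and 6 residual blocks (`n_residual_blocks`).
- A 70x70 PatchGAN discriminator with 64 base filters and 4x4 kernels.
- Biases only on the last generator convolution and on the first and last discriminator convolutions.
- Non-affine InstanceNorm.

These layer choices are an inference that happens to match the printed numbers exactly. They have not been confirmed from `src/cyclegan`.

```python
def conv_params(c_in: int, c_out: int, k: int, bias: bool) -> int:
    """Parameters of a 2D (transposed) convolution: weights plus optional bias."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def resnet_generator_params(n_blocks: int = 6, ngf: int = 64) -> int:
    # Assumed layout: 7x7 conv, two stride-2 downsamplings, n residual blocks
    # (two 3x3 convs each), two transposed-conv upsamplings, 7x7 output conv.
    # Only the output conv has a bias (the others feed a non-affine InstanceNorm).
    total = conv_params(3, ngf, 7, bias=False)
    total += conv_params(ngf, 2 * ngf, 3, bias=False)
    total += conv_params(2 * ngf, 4 * ngf, 3, bias=False)
    total += n_blocks * 2 * conv_params(4 * ngf, 4 * ngf, 3, bias=False)
    total += conv_params(4 * ngf, 2 * ngf, 3, bias=False)
    total += conv_params(2 * ngf, ngf, 3, bias=False)
    total += conv_params(ngf, 3, 7, bias=True)
    return total


def patchgan_params(ndf: int = 64) -> int:
    # Assumed layout: C64 (no norm, with bias), C128/C256/C512 (InstanceNorm,
    # no bias), then a 1-channel output conv with bias; all kernels are 4x4.
    total = conv_params(3, ndf, 4, bias=True)
    total += conv_params(ndf, 2 * ndf, 4, bias=False)
    total += conv_params(2 * ndf, 4 * ndf, 4, bias=False)
    total += conv_params(4 * ndf, 8 * ndf, 4, bias=False)
    total += conv_params(8 * ndf, 1, 4, bias=True)
    return total


if __name__ == "__main__":
    g, d = resnet_generator_params(), patchgan_params()
    print(f"G: {g:,}  D: {d:,}  total: {2 * g + 2 * d:,}")
```

The residual blocks dominate: under these assumptions they hold about 7.1M of the 7.8M generator parameters.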


In [9]:
# ============================================================
# Mixed precision (AMP): expected ~30-40% speedup (estimate,
# not benchmarked in this notebook).
# autocast runs eligible ops (convolutions, matmuls) in FP16;
# weights and optimizer updates stay in FP32. GradScaler scales
# the losses so that small FP16 gradients do not underflow.
# Quality is usually unaffected for GANs at this resolution,
# but this was not measured here.
# ============================================================
# torch.amp replaces the deprecated torch.cuda.amp entry points
from torch.amp import autocast, GradScaler
import types

# One scaler per optimizer: generators, D_A and D_B
scaler_G = GradScaler("cuda")
scaler_D_A = GradScaler("cuda")
scaler_D_B = GradScaler("cuda")

# Keep the original (FP32) step so it can be restored if needed
_original_train_step = trainer.train_step

def _amp_train_step(self, real_a, real_b):
    # non_blocking pairs with pin_memory=True in the DataLoader
    real_a = real_a.to(DEVICE, non_blocking=True)
    real_b = real_b.to(DEVICE, non_blocking=True)

    # ---- Generators with AMP ----
    self.opt_G.zero_grad()
    with autocast("cuda"):
        fake_b = self.G_A2B(real_a)      # A -> B
        fake_a = self.G_B2A(real_b)      # B -> A
        cycle_a = self.G_B2A(fake_b)     # A -> B -> A
        cycle_b = self.G_A2B(fake_a)     # B -> A -> B
        idt_a = self.G_B2A(real_a)       # identity: G_B2A should leave A images unchanged
        idt_b = self.G_A2B(real_b)       # identity: G_A2B should leave B images unchanged
        pred_fake_a = self.D_A(fake_a)
        pred_fake_b = self.D_B(fake_b)
        loss_G, components = self.loss_fn.generator_loss(
            pred_fake_a, pred_fake_b,
            cycle_a, real_a, cycle_b, real_b,
            idt_a, idt_b,
        )

    scaler_G.scale(loss_G).backward()
    scaler_G.step(self.opt_G)
    scaler_G.update()

    # ---- Discriminator A with AMP ----
    # zero_grad also clears the gradients that loss_G left on D_A
    self.opt_D_A.zero_grad()
    fake_a_buffer = self.buffer_A.push_and_pop(fake_a.detach())
    with autocast("cuda"):
        pred_real_a = self.D_A(real_a)
        pred_fake_a = self.D_A(fake_a_buffer)
        loss_D_A = self.loss_fn.discriminator_loss(pred_real_a, pred_fake_a)

    scaler_D_A.scale(loss_D_A).backward()
    scaler_D_A.step(self.opt_D_A)
    scaler_D_A.update()

    # ---- Discriminator B with AMP ----
    self.opt_D_B.zero_grad()
    fake_b_buffer = self.buffer_B.push_and_pop(fake_b.detach())
    with autocast("cuda"):
        pred_real_b = self.D_B(real_b)
        pred_fake_b = self.D_B(fake_b_buffer)
        loss_D_B = self.loss_fn.discriminator_loss(pred_real_b, pred_fake_b)

    scaler_D_B.scale(loss_D_B).backward()
    scaler_D_B.step(self.opt_D_B)
    scaler_D_B.update()

    return {
        'loss_G': loss_G.item(),
        'loss_D_A': loss_D_A.item(),
        'loss_D_B': loss_D_B.item(),
        'loss_cycle': components['cycle_a'] + components['cycle_b'],
        'loss_identity': components['identity'],
    }

# Replace train_step with the AMP version (bound to this trainer instance)
trainer.train_step = types.MethodType(_amp_train_step, trainer)

print("Mixed precision (AMP) enabled")
print("Expected speedup: ~30-40% (estimate, not measured here)")

Mixed precision (AMP) enabled
Expected speedup: ~30-40% (estimate, not measured here)


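`GradScaler` uses dynamic loss scaling. The loss is multiplied by a scale factor before `backward()`, and gradients are unscaled before the optimizer step. If any gradient is `inf`/`NaN`, the step is skipped. The toy class below mimics only the scale-update rule, using PyTorch's documented defaults (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). It is an illustration, not PyTorch's implementation: for simplicity it merges `step()` and `update()` into one call.

```python
import math


class ToyLossScaler:
    """Simplified model of GradScaler's scale-update rule (illustration only)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def step(self, scaled_grads):
        """Unscale the gradients; return them, or None if the step is skipped."""
        unscaled = [g / self.scale for g in scaled_grads]
        if any(not math.isfinite(g) for g in unscaled):
            # Overflow: skip the optimizer step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Long run of finite gradients: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return unscaled
```

This is why the AMP cell calls `scaler.step(...)` and then `scaler.update()`. The real `step` may skip the optimizer update when it finds non-finite gradients, and `update` then adjusts the scale.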


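## 3. Training

Training alternates between updating the generators and updating the discriminators.
The replay buffer (`replay_buffer_size` = 50 images) shows each discriminator a mix of current and previously generated fakes, which stabilizes training.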
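The trainer's `buffer_A.push_and_pop(...)` is not shown in this notebook. Below is a minimal sketch of the image pool from the original CycleGAN implementation, assuming `replay_buffer_size` plays the role of the pool size. Until the pool is full, new fakes are stored and returned unchanged. After that, each new fake has a probability of 0.5 of being swapped with a random stored one, in which case the older fake is returned; otherwise it is returned directly. To stay framework-agnostic, the class works on a list of items. The name `ReplayBuffer` is hypothetical.

```python
import random


class ReplayBuffer:
    """History of generated images fed to a discriminator (sketch)."""

    def __init__(self, max_size=50, seed=None):
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)

    def push_and_pop(self, images):
        """Take a batch (list) of fakes, return a batch of the same length."""
        out = []
        for img in images:
            if len(self.data) < self.max_size:
                # Pool not full yet: store the new image and return it
                self.data.append(img)
                out.append(img)
            elif self.rng.random() < 0.5:
                # Return an older fake and store the new one in its place
                idx = self.rng.randrange(self.max_size)
                out.append(self.data[idx])
                self.data[idx] = img
            else:
                out.append(img)
        return out
```

With PyTorch tensors, the same loop would iterate over the batch dimension of `fake_a.detach()` and rebuild the output batch with `torch.stack`.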

**Automatic saving:**
- Checkpoints every 10 epochs (`save_every`) in `Drive/SatelliteGAN-Outputs/cyclegan/checkpoints/`
- A->B and B->A images in `Drive/SatelliteGAN-Outputs/cyclegan/generated_images/`
- Loss history (JSON) in `Drive/SatelliteGAN-Outputs/cyclegan/losses/`

**Auto-resume:** if training is interrupted, rerun this cell.
The trainer detects the latest checkpoint and resumes from it automatically.

In [10]:
# Train the model
# n_epochs appears to be the *total* number of epochs (after resuming, the log
# below reads "epoch 30/100"), so a quick 5-10 epoch test only makes sense
# when no checkpoint exists yet in Drive.
# For full training, use CYCLEGAN['n_epochs'] (100).
N_EPOCHS = CYCLEGAN['n_epochs']  # e.g. 5 for a quick test

# Auto-resume picks up existing checkpoints in Drive.
# To force resuming from a specific checkpoint:
#   history = trainer.train(train_loader, n_epochs=N_EPOCHS,
#                           resume_from=f"{DRIVE_OUTPUTS}/cyclegan/checkpoints/epoch_40.pth")

history = trainer.train(train_loader, n_epochs=N_EPOCHS)

Checkpoint loaded: epoch 30
Resuming training at epoch 30/100


Epoch 31/100: 100%|██████████| 2100/2100 [03:14<00:00, 10.80it/s, G=3.624, D_A=0.209, D_B=0.089]


[Epoch 31] G: 3.4758 | D_A: 0.0780 | D_B: 0.0807 | Cycle: 1.5054


Epoch 32/100: 100%|██████████| 2100/2100 [03:11<00:00, 10.97it/s, G=3.501, D_A=0.087, D_B=0.069]


[Epoch 32] G: 3.4406 | D_A: 0.0742 | D_B: 0.0763 | Cycle: 1.4726


Epoch 33/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.00it/s, G=3.171, D_A=0.071, D_B=0.091]


[Epoch 33] G: 3.4130 | D_A: 0.0733 | D_B: 0.0754 | Cycle: 1.4529


Epoch 34/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.03it/s, G=3.012, D_A=0.052, D_B=0.063]


[Epoch 34] G: 3.4209 | D_A: 0.0712 | D_B: 0.0732 | Cycle: 1.4407


Epoch 35/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.03it/s, G=3.049, D_A=0.092, D_B=0.044]


[Epoch 35] G: 3.3941 | D_A: 0.0705 | D_B: 0.0719 | Cycle: 1.4200


Epoch 36/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.05it/s, G=3.051, D_A=0.188, D_B=0.096]


[Epoch 36] G: 3.3643 | D_A: 0.0706 | D_B: 0.0712 | Cycle: 1.4009


Epoch 37/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.02it/s, G=4.043, D_A=0.029, D_B=0.128]


[Epoch 37] G: 3.3643 | D_A: 0.0667 | D_B: 0.0727 | Cycle: 1.4014


Epoch 38/100: 100%|██████████| 2100/2100 [03:11<00:00, 10.97it/s, G=3.418, D_A=0.042, D_B=0.046]


[Epoch 38] G: 3.3461 | D_A: 0.0673 | D_B: 0.0721 | Cycle: 1.3788


Epoch 39/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.85it/s, G=2.977, D_A=0.152, D_B=0.046]


[Epoch 39] G: 3.3379 | D_A: 0.0640 | D_B: 0.0717 | Cycle: 1.3727


Epoch 40/100: 100%|██████████| 2100/2100 [03:11<00:00, 10.96it/s, G=3.481, D_A=0.044, D_B=0.078]


[Epoch 40] G: 3.3387 | D_A: 0.0646 | D_B: 0.0727 | Cycle: 1.3582
Checkpoint saved: /content/drive/MyDrive/SatelliteGAN-Outputs/cyclegan/checkpoints/epoch_40.pth


Epoch 41/100: 100%|██████████| 2100/2100 [03:11<00:00, 10.99it/s, G=2.599, D_A=0.083, D_B=0.080]


[Epoch 41] G: 3.3453 | D_A: 0.0645 | D_B: 0.0699 | Cycle: 1.3604


Epoch 42/100: 100%|██████████| 2100/2100 [03:10<00:00, 11.00it/s, G=3.132, D_A=0.105, D_B=0.073]


[Epoch 42] G: 3.3460 | D_A: 0.0631 | D_B: 0.0637 | Cycle: 1.3421


Epoch 43/100: 100%|██████████| 2100/2100 [03:11<00:00, 10.94it/s, G=3.419, D_A=0.038, D_B=0.085]


[Epoch 43] G: 3.3388 | D_A: 0.0625 | D_B: 0.0659 | Cycle: 1.3397


Epoch 44/100: 100%|██████████| 2100/2100 [03:14<00:00, 10.82it/s, G=3.740, D_A=0.093, D_B=0.080]


[Epoch 44] G: 3.3559 | D_A: 0.0634 | D_B: 0.0619 | Cycle: 1.3389


Epoch 45/100: 100%|██████████| 2100/2100 [03:12<00:00, 10.90it/s, G=3.561, D_A=0.058, D_B=0.044]


[Epoch 45] G: 3.3590 | D_A: 0.0626 | D_B: 0.0603 | Cycle: 1.3366


Epoch 46/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.87it/s, G=2.839, D_A=0.129, D_B=0.032]


[Epoch 46] G: 3.3556 | D_A: 0.0607 | D_B: 0.0607 | Cycle: 1.3293


Epoch 47/100: 100%|██████████| 2100/2100 [03:12<00:00, 10.92it/s, G=3.513, D_A=0.078, D_B=0.092]


[Epoch 47] G: 3.3441 | D_A: 0.0605 | D_B: 0.0628 | Cycle: 1.3236


Epoch 48/100: 100%|██████████| 2100/2100 [03:12<00:00, 10.92it/s, G=3.278, D_A=0.043, D_B=0.037]


[Epoch 48] G: 3.3725 | D_A: 0.0574 | D_B: 0.0577 | Cycle: 1.3138


Epoch 49/100: 100%|██████████| 2100/2100 [03:12<00:00, 10.90it/s, G=2.150, D_A=0.054, D_B=0.099]


[Epoch 49] G: 3.3439 | D_A: 0.0590 | D_B: 0.0577 | Cycle: 1.3063


Epoch 50/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.83it/s, G=2.710, D_A=0.083, D_B=0.087]


[Epoch 50] G: 3.3320 | D_A: 0.0544 | D_B: 0.0593 | Cycle: 1.2983
Checkpoint saved: /content/drive/MyDrive/SatelliteGAN-Outputs/cyclegan/checkpoints/epoch_50.pth


Epoch 51/100: 100%|██████████| 2100/2100 [03:15<00:00, 10.76it/s, G=3.571, D_A=0.044, D_B=0.025]


[Epoch 51] G: 3.3530 | D_A: 0.0557 | D_B: 0.0563 | Cycle: 1.2955


Epoch 52/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.83it/s, G=3.233, D_A=0.042, D_B=0.085]


[Epoch 52] G: 3.3341 | D_A: 0.0535 | D_B: 0.0558 | Cycle: 1.2740


Epoch 53/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.87it/s, G=2.580, D_A=0.031, D_B=0.048]


[Epoch 53] G: 3.3349 | D_A: 0.0519 | D_B: 0.0522 | Cycle: 1.2700


Epoch 54/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.83it/s, G=3.653, D_A=0.047, D_B=0.041]


[Epoch 54] G: 3.3236 | D_A: 0.0497 | D_B: 0.0530 | Cycle: 1.2574


Epoch 55/100: 100%|██████████| 2100/2100 [03:14<00:00, 10.81it/s, G=2.824, D_A=0.028, D_B=0.025]


[Epoch 55] G: 3.3002 | D_A: 0.0508 | D_B: 0.0487 | Cycle: 1.2366


Epoch 56/100: 100%|██████████| 2100/2100 [03:14<00:00, 10.82it/s, G=3.810, D_A=0.113, D_B=0.047]


[Epoch 56] G: 3.3002 | D_A: 0.0494 | D_B: 0.0492 | Cycle: 1.2348


Epoch 57/100: 100%|██████████| 2100/2100 [03:15<00:00, 10.73it/s, G=2.722, D_A=0.027, D_B=0.137]


[Epoch 57] G: 3.3134 | D_A: 0.0451 | D_B: 0.0504 | Cycle: 1.2205


Epoch 58/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.83it/s, G=2.650, D_A=0.039, D_B=0.041]


[Epoch 58] G: 3.3051 | D_A: 0.0471 | D_B: 0.0456 | Cycle: 1.2130


Epoch 59/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.84it/s, G=3.484, D_A=0.039, D_B=0.026]


[Epoch 59] G: 3.3005 | D_A: 0.0437 | D_B: 0.0410 | Cycle: 1.1947


Epoch 60/100: 100%|██████████| 2100/2100 [03:13<00:00, 10.86it/s, G=2.435, D_A=0.063, D_B=0.060]


[Epoch 60] G: 3.2904 | D_A: 0.0423 | D_B: 0.0443 | Cycle: 1.1842
Checkpoint saved: /content/drive/MyDrive/SatelliteGAN-Outputs/cyclegan/checkpoints/epoch_60.pth


Epoch 61/100:  66%|██████▌   | 1388/2100 [02:09<01:06, 10.73it/s, G=3.278, D_A=0.037, D_B=0.039]


KeyboardInterrupt: 
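The progress bars above provide a direct check on the time budget. With 2100 batches per epoch at about 10.9 it/s, one epoch takes a little over 3 minutes, which puts a full 100-epoch run above 5 hours. Here is a small hypothetical helper for that estimate. It ignores checkpointing and image-saving time.

```python
def training_hours(n_epochs, batches_per_epoch, it_per_s):
    """Wall-clock estimate in hours, ignoring checkpointing and evaluation."""
    return n_epochs * batches_per_epoch / it_per_s / 3600


if __name__ == "__main__":
    print(f"Full run: {training_hours(100, 2100, 10.9):.1f} h")
    print(f"Remaining after epoch 60: {training_hours(40, 2100, 10.9):.1f} h")
```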

## 4. Loss curves

In [None]:
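import json
# trainer.train() was interrupted above, so `history` may be undefined:
# reload the history saved in Drive (assumed compatible with plot_training_losses)
if 'history' not in globals():
    with open(f"{DRIVE_OUTPUTS}/cyclegan/losses/loss_history.json") as f:
        history = json.load(f)
# Plot the loss curves (figure saved to Drive)
plot_training_losses(
    history,
    title='CycleGAN - Loss curves',
    save_path=f"{DRIVE_OUTPUTS}/evaluation/figures/cyclegan_losses.png"
)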

## 5. Visualizing the results

In [None]:
# Prepare a few validation images
print("Preparing validation data...")
val_a, val_b = prepare_paired_batch(val_set, eval_transform, n_max=8)

# Generate the translations
with torch.no_grad():
    val_a_dev = val_a.to(DEVICE)
    val_b_dev = val_b.to(DEVICE)

    fake_b = trainer.G_A2B(val_a_dev)    # Normal -> Drought
    fake_a = trainer.G_B2A(val_b_dev)    # Drought -> Normal
    cycle_a = trainer.G_B2A(fake_b)      # Normal -> Drought -> Normal
    cycle_b = trainer.G_A2B(fake_a)      # Drought -> Normal -> Drought

# Show the full results (figure saved to Drive)
show_cyclegan_results(
    val_a_dev.cpu(), fake_b.cpu(), cycle_a.cpu(),
    val_b_dev.cpu(), fake_a.cpu(), cycle_b.cpu(),
    n_samples=4,
    save_path=f"{DRIVE_OUTPUTS}/evaluation/comparisons/cyclegan_cycle_results.png"
)
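The figures give a qualitative view. A simple quantitative complement is the mean absolute error between each input and its cycle reconstruction (`val_a` vs `cycle_a`, `val_b` vs `cycle_b`), which is the quantity the cycle loss drives down. The sketch below works on NumPy arrays, for example `val_a.numpy()` and `cycle_a.cpu().numpy()`. The helper name is hypothetical, and it assumes both batches share the same shape and the same normalized value range.

```python
import numpy as np


def per_image_l1(real: np.ndarray, reconstructed: np.ndarray) -> np.ndarray:
    """Mean absolute error per image for batches shaped (N, C, H, W)."""
    if real.shape != reconstructed.shape:
        raise ValueError(f"shape mismatch: {real.shape} vs {reconstructed.shape}")
    diff = np.abs(real.astype(np.float64) - reconstructed.astype(np.float64))
    # Average over channels and pixels, keeping one value per image
    return diff.reshape(diff.shape[0], -1).mean(axis=1)
```

`per_image_l1(val_a.numpy(), cycle_a.cpu().numpy()).mean()` then summarizes reconstruction quality in one number, and the per-image values help spot the samples where the cycle breaks down.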

In [None]:
# Direct comparison: real normal images vs CycleGAN drought
show_comparison(
    val_a.cpu(), fake_b.cpu(),
    n_samples=6,
    labels=('Normal (real)', 'Drought (CycleGAN)'),
    title='Normal -> Drought translation with CycleGAN',
    save_path=f"{DRIVE_OUTPUTS}/evaluation/comparisons/cyclegan_normal_vs_drought.png"
)

## 6. Summary

**Observations:**
- The CycleGAN learns to map normal images to drought-like images (to be checked visually with the figures in section 5).
- Cycle consistency is designed to preserve spatial structure (roads, field boundaries); in the log above, the cycle loss falls from 1.51 at epoch 31 to 1.18 at epoch 60.
- The identity loss helps preserve color tones.

**Saved to Drive:**
- `cyclegan/checkpoints/`: a checkpoint every 10 epochs, plus `final.pth` once training completes (the run above was interrupted during epoch 61, so the latest checkpoint is `epoch_60.pth`)
- `cyclegan/generated_images/`: A->B and B->A images at each save
- `cyclegan/losses/loss_history.json`: full loss history

**Next step:** train the diffusion model (DDPM) on the drought images
to generate new synthetic samples.