## CycleGAN

### Imports

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses, callbacks
from tensorflow.keras.models import save_model, load_model, Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.utils import plot_model, Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import random as r
from tqdm import tqdm # progress bar
from IPython.display import clear_output
import seaborn as sns
import math
from tqdm import trange
from collections import defaultdict
import json
import pandas as pd      # si tu l’as déjà dans l’environnement

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPUs detected")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPUs detected: {len(gpus)}")
        print(f"GPUs: {gpus}")
    except RuntimeError as e:
        # Reported (not raised) so the notebook keeps running, e.g. when the
        # GPUs were already initialised before this cell ran.
        print(e)

# Enable XLA JIT compilation globally.
tf.config.optimizer.set_jit(True)

# Seed every random number generator used in the notebook (same order as before).
seed = 42
for _seeder in (tf.random.set_seed, np.random.seed, r.seed):
    _seeder(seed)



### Constants

In [None]:
class LinearScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linearly decay the learning rate from ``initial_lr`` down to ``min_lr``.

    The rate is computed from the ``step`` argument the optimizer passes in
    (its update counter), so the schedule is stateless.

    The previous version advanced a Python counter on every call. Inside a
    ``tf.function`` that counter only moved at trace time, which froze the
    learning rate into the graph. Every logging call (``_decayed_lr``) also
    advanced it.

    After ``total_iteration`` optimizer steps the rate stays at ``min_lr``
    instead of continuing below it and eventually turning negative.

    Args:
        initial_lr: learning rate at step 0.
        min_lr: learning rate reached at ``total_iteration`` steps and kept afterwards.
        total_iteration: number of optimizer steps over which the decay happens.
    """

    def __init__(self, initial_lr, min_lr, total_iteration):
        super().__init__()
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.total_iteration = total_iteration

    def __call__(self, step):
        # Fraction of the schedule already consumed, clamped to [0, 1];
        # max(..., 1) guards against a zero-length schedule.
        span = tf.cast(max(self.total_iteration, 1), tf.float32)
        progress = tf.clip_by_value(tf.cast(step, tf.float32) / span, 0.0, 1.0)
        return self.min_lr + (self.initial_lr - self.min_lr) * (1.0 - progress)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "initial_lr": self.initial_lr,
            "min_lr": self.min_lr,
            "total_iteration": self.total_iteration,
        }


# Constants
IMAGE_SHAPE = (256, 256, 3)   # network input shape, also the random-crop size
BATCH_SIZE = 1
TOTAL_ITERATIONS = 200000
initial_lr, min_lr = 2e-4, 5e-6   # start / end values of the linear LR decay

# Weights of the generator loss terms: adversarial, cycle-consistency, identity
COEF_AV, COEF_CYC, COEF_ID = 1, 10, 5
# Number of generator updates per discriminator update
RATIO_GEN, RATIO_DISC = 3, 1
TRAINING_CYCLE = (
    ['generator' for _ in range(RATIO_GEN)]
    + ['discriminator' for _ in range(RATIO_DISC)]
)

# One learning-rate schedule and one Adam optimizer per network family
gen_lr_schedule = LinearScheduler(initial_lr, min_lr, TOTAL_ITERATIONS)
disc_lr_schedule = LinearScheduler(initial_lr, min_lr, TOTAL_ITERATIONS)
GENERATOR_OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=gen_lr_schedule, beta_1=0.5)
DISCRIMINATOR_OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=disc_lr_schedule, beta_1=0.5)

AUTOTUNE = tf.data.AUTOTUNE
PATIENCE = 2   # early-stopping patience; not used by the training loop below

def smooth(series, span=100):
    """Return an exponentially-weighted moving average of *series*.

    ``span`` acts as an approximate window length (smoothing factor
    ``2 / (span + 1)``). The recursive form (``adjust=False``) is used, and
    the result is returned as a NumPy array.
    """
    ema = pd.Series(series).ewm(span=span, adjust=False).mean()
    return ema.to_numpy()

def plot_history(history, span=100):
    """Plot every recorded series of *history* after EMA smoothing.

    Args:
        history: mapping from a metric name to its list of recorded values.
        span: smoothing span forwarded to ``smooth``.
    """
    plt.figure(figsize=(20, 5))
    for name, values in history.items():
        plt.plot(smooth(values, span=span), label=name)
    labels = ((plt.title, 'Training History'), (plt.xlabel, 'Iteration'), (plt.ylabel, 'Value'))
    for setter, text in labels:
        setter(text)
    plt.legend()
    plt.show()

def plot_inference(image=None):
    """Show the results of ``inference(for_plot=True, image=image)`` side by side.

    ``inference`` is not defined in this section. It is expected to return a
    sized sequence of ``(step, img)`` pairs (NOTE(review): confirm this
    contract). Each image is mapped from [-1, 1] to [0, 1] and clipped before
    display. NOTE(review): the generators in this notebook end with a
    sigmoid, so check that ``inference`` really returns [-1, 1] images.

    Args:
        image: optional input forwarded to ``inference``.
    """
    result = inference(for_plot=True, image=image)
    if len(result) == 0:
        return  # nothing to draw; plt.subplots(1, 0) would raise
    # squeeze=False keeps a 2-D axes array even for a single result, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(result), figsize=(20, 5), squeeze=False)
    for ax, (step, img) in zip(axes[0], result):
        img = (img + 1) / 2  # convert from [-1, 1] to [0, 1]
        ax.imshow(np.clip(img, 0, 1))
        ax.set_title(f'{int(step)}')
        ax.axis('off')
    plt.show()

def decode_jpeg(path):
    """Read the JPEG file at *path* and return it as a 3-channel float32 image in [0, 1]."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.convert_image_dtype(
        tf.image.decode_jpeg(raw_bytes, channels=3),
        tf.float32,
    )

def augment(img, load_size=286, crop_shape=None):
    """Random-jitter augmentation: resize, random crop, then random horizontal flip.

    Args:
        img: a single image tensor.
        load_size: side length the image is resized to before cropping.
        crop_shape: shape of the random crop. Defaults to ``IMAGE_SHAPE`` so
            the crop always matches the model input instead of duplicating
            its size here.

    Returns:
        The augmented image, of shape ``crop_shape``.
    """
    if crop_shape is None:
        crop_shape = IMAGE_SHAPE
    img = tf.image.resize(img, [load_size, load_size])
    img = tf.image.random_crop(img, list(crop_shape))
    img = tf.image.random_flip_left_right(img)
    return img


### Dataset

In [None]:
# Summer <-> Winter
def _build_domain_pipeline(directory, cache_to_ram, cache_file):
    """Return an infinite, shuffled, augmented, batched dataset of the images in *directory*."""
    # Sorting makes the file order, and therefore the seeded shuffle,
    # reproducible (os.listdir order is arbitrary).
    files = [os.path.join(directory, name) for name in sorted(os.listdir(directory))]
    if not files:
        raise ValueError(f"No images found in {directory}")

    ds = tf.data.Dataset.from_tensor_slices(files)
    ds = ds.map(decode_jpeg, num_parallel_calls=AUTOTUNE)
    # Only the deterministic decode step is cached. Shuffling and the random
    # crop/flip must run after the cache; before it, the first epoch's order
    # and augmentations would be replayed forever.
    # NOTE: images are now cached at their source resolution (before resize).
    ds = ds.cache() if cache_to_ram else ds.cache(cache_file)
    # Global shuffle with a new order every epoch. The buffer holds the whole
    # (decoded) domain.
    ds = ds.shuffle(len(files), reshuffle_each_iteration=True)
    ds = ds.repeat()                                                  # infinite loop
    ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH_SIZE, drop_remainder=True)
    return ds.prefetch(AUTOTUNE)


# Summer <-> Winter
def build_dataset_wintersummer(ensemble, cache_to_ram=True):
    """Build an infinite iterator over unpaired (summer, winter) image batches.

    Args:
        ensemble: 'train' or 'test'. Selects the
            ``datasets/summer_winter/{ensemble}_summer`` and
            ``datasets/summer_winter/{ensemble}_winter`` folders.
        cache_to_ram: cache decoded images in memory if True. Otherwise cache
            them on disk, in files named per domain AND per ``ensemble`` so
            that the train and test splits never share a cache.

    Returns:
        An iterator yielding ``(summer_batch, winter_batch)`` pairs. The two
        domains are shuffled independently.

    Raises:
        ValueError: if ``ensemble`` is neither 'train' nor 'test', or a folder is empty.
    """
    if ensemble not in ('train', 'test'):
        raise ValueError("ensemble must be either 'train' or 'test'")
    summer_ds = _build_domain_pipeline(
        f'datasets/summer_winter/{ensemble}_summer', cache_to_ram, f'summer_{ensemble}.cache'
    )
    winter_ds = _build_domain_pipeline(
        f'datasets/summer_winter/{ensemble}_winter', cache_to_ram, f'winter_{ensemble}.cache'
    )
    return iter(tf.data.Dataset.zip((summer_ds, winter_ds)))

# Infinite (summer, winter) batch iterators for the training and test splits.
TrainGen, TestGen = (
    build_dataset_wintersummer(split, cache_to_ram=True) for split in ('train', 'test')
)

# Visualisation des données
def plot_data_gen(gen, n=5):
    """Display *n* consecutive (A, B) pairs drawn from *gen*.

    Row 0 shows the first image of each domain-A batch and row 1 the first
    image of the matching domain-B batch.

    Args:
        gen: iterator yielding ``(batch_A, batch_B)`` tuples.
        n: number of pairs (columns) to draw; must be >= 1.
    """
    # squeeze=False keeps ``axes`` 2-D even when n == 1. With the default
    # squeeze, axes[0][i] would index into a single Axes object and crash.
    fig, axes = plt.subplots(2, n, figsize=(20, 5), squeeze=False)
    for i in range(n):
        image_A, image_B = next(gen)
        for row, (batch, title) in enumerate(((image_A, "A"), (image_B, "B"))):
            ax = axes[row, i]
            ax.imshow(batch[0])
            ax.axis('off')
            ax.set_title(title)
    plt.tight_layout()
    plt.show()

plot_data_gen(TrainGen, n=5)

    

### Model

In [None]:
def conv_block(x, filters, kernel_size=3, strides=1, padding='same', activation='leaky_relu', norm=True, dropout=0.0):
    """Apply Conv2D -> [BatchNorm] -> [Activation] -> [Dropout] to the Keras tensor *x*.

    Args:
        x: input Keras tensor.
        filters, kernel_size, strides, padding: forwarded to ``layers.Conv2D``.
        activation: activation name; a falsy value skips the activation layer.
        norm: add BatchNormalization after the convolution. It is used in
            place of InstanceNormalization, which is not available with TF 2.10.
        dropout: dropout rate; a Dropout layer is only added when it is > 0.

    Returns:
        The output Keras tensor.
    """
    stages = [layers.Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)]
    if norm:
        stages.append(layers.BatchNormalization())
    if activation:
        stages.append(layers.Activation(activation))
    if dropout > 0.0:
        stages.append(layers.Dropout(dropout))
    for stage in stages:
        x = stage(x)
    return x

def residual_block(x, nb, filters, kernel_size=3, strides=1, padding='same', activation='leaky_relu', norm=True, dropout=0.0):
    """ResNet-style block: *nb* stacked ``conv_block`` calls plus an identity skip.

    The sum ``body + x`` then goes through optional BatchNorm (InstanceNorm
    is not available with TF 2.10), activation and dropout.

    The skip connection has no projection, so the body output must match
    the shape of *x* for ``layers.Add``. In practice this means ``filters``
    equals x's channel count, with ``strides=1`` and 'same' padding.
    """
    body = x
    for _ in range(nb):
        body = conv_block(body, filters, kernel_size, strides, padding, activation, norm, dropout)
    merged = layers.Add()([body, x])
    if norm:
        merged = layers.BatchNormalization()(merged)
    if activation:
        merged = layers.Activation(activation)(merged)
    if dropout > 0.0:
        merged = layers.Dropout(dropout)(merged)
    return merged


def build_GeneratorModel(name="Generator"):
    """Build the U-Net-like residual generator.

    Structure:
      * Encoder: two stride-2 stages (64 then 128 filters), each followed by
        two residual blocks.
      * Bottleneck: a third stride-2 convolution (256 filters) and nine
        residual blocks.
      * Decoder: three bilinear x2 upsampling stages. Each one concatenates
        the matching encoder output (the raw input for the last stage), then
        applies a convolution and two residual blocks.
      * Output: a 3-channel sigmoid convolution, so pixels lie in [0, 1].
    """
    def res_stack(t, filters, count=2):
        # `count` residual blocks of two conv_blocks each, shape-preserving.
        for _ in range(count):
            t = residual_block(t, 2, filters, kernel_size=3, strides=1, padding='same', activation='leaky_relu', norm=True)
        return t

    def down(t, filters):
        # Halve the spatial resolution, then refine.
        t = conv_block(t, filters, kernel_size=3, strides=2, padding='same', activation='leaky_relu', norm=True)
        return res_stack(t, filters)

    def up(t, skip, filters):
        # Double the spatial resolution, fuse the skip connection, then refine.
        t = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(t)
        t = layers.Concatenate()([t, skip])
        t = conv_block(t, filters, kernel_size=3, strides=1, padding='same', activation='leaky_relu', norm=True)
        return res_stack(t, filters)

    inputs = layers.Input(shape=IMAGE_SHAPE)

    enc1 = down(inputs, 64)
    enc2 = down(enc1, 128)

    bottleneck = conv_block(enc2, 256, kernel_size=3, strides=2, padding='same', activation='leaky_relu', norm=True)
    bottleneck = res_stack(bottleneck, 256, count=9)

    dec = up(bottleneck, enc2, 128)
    dec = up(dec, enc1, 64)
    dec = up(dec, inputs, 64)

    outputs = conv_block(dec, 3, kernel_size=3, strides=1, padding='same', activation='sigmoid', norm=False)
    return Model(inputs=inputs, outputs=outputs, name=name)

def build_DiscriminatorModel(name="Discriminator"):
    """PatchGAN-style discriminator.

    It stacks four ``conv_block`` calls with the default BatchNorm +
    'leaky_relu' (64/128/256/512 filters, the first three with stride 2).
    These are followed by a 1-channel sigmoid convolution, so the output is
    a spatial grid of per-patch "real" probabilities rather than one score.

    Because the first kernel is 7x7, each output unit sees a 73x73 input
    patch; the classic 4x4 first kernel would give 70x70.
    """
    image_input = layers.Input(shape=IMAGE_SHAPE)
    x = image_input
    for filters, kernel, stride in ((64, 7, 2), (128, 4, 2), (256, 4, 2), (512, 4, 1)):
        x = conv_block(x, filters, kernel_size=kernel, strides=stride)
    # No BatchNorm on the last layer: with it the outputs drift towards 0.5
    # (loss_d stuck around 1.38).
    x = conv_block(x, 1, kernel_size=4, activation='sigmoid', norm=False)
    return Model(inputs=image_input, outputs=x, name=name)

# GeneratorModel_A produces domain-A images from domain B ('GeneratorB2A');
# DiscriminatorModel_A judges domain-A images. The _B models are symmetric.
GeneratorModel_A, DiscriminatorModel_A = build_GeneratorModel('GeneratorB2A'), build_DiscriminatorModel('DiscriminatorA')
GeneratorModel_B, DiscriminatorModel_B = build_GeneratorModel('GeneratorA2B'), build_DiscriminatorModel('DiscriminatorB')
# Model.summary() prints by itself and returns None. Wrapping it in print()
# only added a stray "None" line after each summary.
for _model in (GeneratorModel_A, DiscriminatorModel_A, GeneratorModel_B, DiscriminatorModel_B):
    _model.summary()

### Training

In [None]:
# Binary cross-entropy on probabilities (from_logits=False). It is used for
# the adversarial terms and, below, as the reconstruction loss for the cycle
# and identity terms, with pixel values as targets. It assumes image tensors
# lie in [0, 1] — TODO confirm against the input pipeline.
loss_bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # because outputs go through a sigmoid

@tf.function(jit_compile=True)
def train_step(images, Generators, Discriminators, turn):
    """Run one XLA-compiled CycleGAN optimisation step.

    Args:
        images: pair ``(images_a, images_b)`` of batches from domains A and B.
        Generators: pair ``(gen_a, gen_b)``. ``gen_a`` maps B -> A and
            ``gen_b`` maps A -> B (see how the fakes are built below).
        Discriminators: pair ``(disc_a, disc_b)`` scoring domain-A and
            domain-B images respectively.
        turn: selects what gets updated. It is tested with ``in``, so it can
            be a string ('generator' / 'discriminator') or a list containing
            either or both. It is a Python argument of a tf.function, so each
            distinct value is presumably traced separately.

    Returns:
        ``(total_g_loss, total_d_loss, loss_data)``.
        - A loss is ``None`` when its network was not updated this step.
        - ``loss_data`` only contains the components actually computed; the
          generator components are stored already weighted.

    NOTE(review): statement order matters here. Every ``training=True`` call
    on these models presumably updates their BatchNorm moving statistics, so
    reordering the calls would change the trained models.
    """
    gen_a, gen_b = Generators
    disc_a, disc_b = Discriminators
    images_a, images_b = images
    total_g_loss, total_d_loss = None, None
    loss_data = {}
    # === Generator update ===
    if 'generator' in turn :
        with tf.GradientTape() as gen_tape:
            # Translate in both directions, then back again (cycle)
            fake_images_a = gen_a(images_b, training=True)
            fake_images_b = gen_b(images_a, training=True)
            recov_images_a = gen_a(fake_images_b, training=True)
            recov_images_b = gen_b(fake_images_a, training=True)
            # Discriminator scores of the fakes
            disc_output_a = disc_a(fake_images_a, training=True)
            disc_output_b = disc_b(fake_images_b, training=True)
            # Identity: each generator applied to an image already in its output domain
            id_a = gen_a(images_a, training=True)
            id_b = gen_b(images_b, training=True)

            # 1. Cycle-consistency loss (A -> B -> A and B -> A -> B reconstructions)
            cc_loss = loss_bce(images_a, recov_images_a) + loss_bce(images_b, recov_images_b)
            # 2. Adversarial loss: fakes are labelled "real" (1) because we want to fool the discriminators
            real_labels_a = tf.ones_like(disc_output_a)
            real_labels_b = tf.ones_like(disc_output_b)
            disc_loss = loss_bce(real_labels_a, disc_output_a) + loss_bce(real_labels_b, disc_output_b) # We want to fool the discriminator
            # 3. Identity loss
            id_loss = loss_bce(images_a, id_a) + loss_bce(images_b, id_b)


            # Weighted sum; the logged components below are the weighted values
            total_g_loss = COEF_CYC * cc_loss + COEF_AV * disc_loss + COEF_ID * id_loss
            loss_data.update({
                "cc_loss": COEF_CYC * cc_loss,
                "disc_loss": COEF_AV * disc_loss,
                "id_loss": COEF_ID * id_loss,
            })

            # Backward: a single optimizer step over both generators' variables;
            # variables without a gradient (g is None) are skipped
            vars_g = gen_a.trainable_variables + gen_b.trainable_variables
            grads_g = gen_tape.gradient(total_g_loss, vars_g)
            GENERATOR_OPTIMIZER.apply_gradients([(g, v) for g, v in zip(grads_g, vars_g) if g is not None])
    
    # === Discriminator update ===
    if 'discriminator' in turn :
        fake_images_a = gen_a(images_b, training=True) # Computed outside the tape so no gradient flows into the generators
        fake_images_b = gen_b(images_a, training=True) # Computed outside the tape so no gradient flows into the generators
        with tf.GradientTape() as disc_tape:
            # Forward pass through the discriminators, on real and fake images
            real_output_a = disc_a(images_a, training=True)
            fake_output_a = disc_a(fake_images_a, training=True)

            real_output_b = disc_b(images_b, training=True)
            fake_output_b = disc_b(fake_images_b, training=True)
            
            # Discriminator loss: real images -> 1, generated images -> 0
            real_labels_a = tf.ones_like(real_output_a)
            fake_labels_a = tf.zeros_like(fake_output_a)
            real_labels_b = tf.ones_like(real_output_b)
            fake_labels_b = tf.zeros_like(fake_output_b)
            disc_loss_a_real = loss_bce(real_labels_a, real_output_a) 
            disc_loss_a_fake = loss_bce(fake_labels_a, fake_output_a)
            disc_loss_b_real = loss_bce(real_labels_b, real_output_b)
            disc_loss_b_fake = loss_bce(fake_labels_b, fake_output_b)
            
            # Unweighted sum of the four terms; each term is also logged
            total_d_loss =  disc_loss_a_real + disc_loss_a_fake + disc_loss_b_real + disc_loss_b_fake
            loss_data.update({
                "disc_loss_a_real": disc_loss_a_real,
                "disc_loss_a_fake": disc_loss_a_fake,
                "disc_loss_b_real": disc_loss_b_real,
                "disc_loss_b_fake": disc_loss_b_fake,
            })

            # Backward: a single optimizer step over both discriminators' variables
            vars_d = disc_a.trainable_variables + disc_b.trainable_variables
            grads_d = disc_tape.gradient(total_d_loss, vars_d)
            DISCRIMINATOR_OPTIMIZER.apply_gradients([(g, v) for g, v in zip(grads_d, vars_d) if g is not None])
    
    # --- Debug: fail fast if any discriminator weight became NaN/Inf ---
    for v in disc_a.trainable_variables + disc_b.trainable_variables:
        tf.debugging.assert_all_finite(v, "NaN/Inf in D weights")

    return total_g_loss, total_d_loss, loss_data

best_val_loss = float("inf")
wait = 0  # best_val_loss / wait (with PATIENCE) are not used by this loop yet

_MODEL_DIR = "models/cyclegan"


def _fmt_loss(value):
    """Format a scalar for the progress bar; '-' when the value was not computed this step."""
    return "-" if value is None else f"{float(value):.4e}"


def _show_translation(batch_pair):
    """Plot originals, translations and cycle reconstructions for one (A, B) batch pair."""
    images_a, images_b = batch_pair
    fake_images_a = GeneratorModel_A(images_b, training=False)
    fake_images_b = GeneratorModel_B(images_a, training=False)
    recov_images_a = GeneratorModel_A(fake_images_b, training=False)
    recov_images_b = GeneratorModel_B(fake_images_a, training=False)
    panels = (
        ((0, 0), images_a, "A original"),
        ((0, 1), fake_images_b, "fake B"),
        ((0, 2), recov_images_a, "recov A"),
        ((1, 0), images_b, "B original"),
        ((1, 1), fake_images_a, "fake A"),
        ((1, 2), recov_images_b, "recov B"),
    )
    fig, axes = plt.subplots(2, 3, figsize=(18, 5))
    for (row, col), batch, title in panels:
        axes[row][col].imshow(batch[0])
        axes[row][col].axis('off')
        axes[row][col].set_title(title)
    plt.tight_layout()
    plt.show()


# --- Training ---
progress_bar = trange(TOTAL_ITERATIONS, desc="Training", leave=True)
history = defaultdict(list)
# Loop-invariant model tuples, built once.
Generators = (GeneratorModel_A, GeneratorModel_B)
Discriminators = (DiscriminatorModel_A, DiscriminatorModel_B)
# The first step updates both networks; afterwards `turn` follows TRAINING_CYCLE.
turn = ['generator', 'discriminator']
for index in progress_bar:
    # --- Load a batch pair and train ---
    images = next(TrainGen)
    g_loss, d_loss, loss_data = train_step(images, Generators, Discriminators, turn)

    # --- Display ---
    # train_step returns None for the loss of a network it did not update.
    # Formatting None with ':.4e' raised TypeError on the first
    # generator-only or discriminator-only step.
    current_lr_gen = GENERATOR_OPTIMIZER._decayed_lr(tf.float32).numpy()
    current_lr_disc = DISCRIMINATOR_OPTIMIZER._decayed_lr(tf.float32).numpy()
    postfix = {
        "g_loss": _fmt_loss(g_loss),
        "d_loss": _fmt_loss(d_loss),
        "lr_gen": _fmt_loss(current_lr_gen),
        "lr_disc": _fmt_loss(current_lr_disc),
    }
    postfix.update({key: _fmt_loss(value) for key, value in loss_data.items()})
    progress_bar.set_postfix(postfix)

    # --- Loss history (full precision, not the rounded display strings) ---
    for key, value in loss_data.items():
        history[key].append(float(value))

    # --- Next turn ---
    turn = TRAINING_CYCLE[index % len(TRAINING_CYCLE)]

    # --- Test: visual check on a held-out batch ---
    if index % 1000 == 0:
        _show_translation(next(TestGen))

    # --- Save the models ---
    if index % 10000 == 0 and index != 0:
        os.makedirs(_MODEL_DIR, exist_ok=True)  # save_model/open fail if the folder is missing
        save_model(GeneratorModel_A, f"{_MODEL_DIR}/GeneratorAModel.h5")
        save_model(GeneratorModel_B, f"{_MODEL_DIR}/GeneratorBModel.h5")
        save_model(DiscriminatorModel_A, f"{_MODEL_DIR}/DiscriminatorAModel.h5")
        save_model(DiscriminatorModel_B, f"{_MODEL_DIR}/DiscriminatorBModel.h5")
        # Save history
        history_path = f"{_MODEL_DIR}/history.json"
        with open(history_path, 'w') as f:
            json.dump(history, f, indent=4)
        print(f"Models saved at iteration {index}")
        # Plot history
        plot_history(history, span=100)


In [None]:
# History: smoothed loss curves for the whole run (span=100 is plot_history's default)
plot_history(history, span=100)

### Inférence

In [None]:
# Translate one (A, B) batch pair both ways, reconstruct it, and show all six images.
images = next(iter(TrainGen))
images_a, images_b = images
fake_images_a = GeneratorModel_A(images_b, training=False)
fake_images_b = GeneratorModel_B(images_a, training=False)
recov_images_a = GeneratorModel_A(fake_images_b, training=False)
recov_images_b = GeneratorModel_B(fake_images_a, training=False)

# (row, column) -> (batch to show, panel title)
panels = {
    (0, 0): (images_a, "A original"),
    (0, 1): (fake_images_b, "fake B"),
    (1, 0): (images_b, "B original"),
    (1, 1): (fake_images_a, "fake A"),
    (0, 2): (recov_images_a, "recov A"),
    (1, 2): (recov_images_b, "recov B"),
}
fig, axes = plt.subplots(2, 3, figsize=(18, 5))
for (row, col), (batch, title) in panels.items():
    ax = axes[row][col]
    ax.imshow(batch[0])
    ax.axis('off')
    ax.set_title(title)
plt.tight_layout()
plt.show()