# Diffusion DDPM
Datasets: MS COCO 2014, MNIST and CelebA

## Imports

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import save_model, load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from tensorflow.keras.utils import plot_model, Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import random as r
from tqdm import tqdm # progress bar
import time
import datetime
from IPython.display import clear_output
from pycocotools.coco import COCO
from PIL import Image
import seaborn as sns

# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPUs detected")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # The configuration call can be rejected at runtime; report and carry on.
        print(err)
    else:
        print(f"GPUs detected: {len(gpus)}")
        print(f"GPUs: {gpus}")

### Constants

In [2]:
# Paths
ANNOTDIR = 'datasets/annotations_trainval2014'
DATADIR = 'datasets/train2014'
# The annotation file is named after the split only (instances_train2014.json),
# so use the last component of DATADIR rather than the whole directory path.
INSTANCEFILE = '{}/annotations/instances_{}.json'.format(ANNOTDIR, os.path.basename(DATADIR))

# Choice of dataset
# Normalise the answer so inputs such as "COCO" or " face " select the intended
# dataset instead of silently falling through to the MNIST default.
CHOSEN_DATASET = input('mnist, coco or face ? (default: mnist) ').strip().lower()
if CHOSEN_DATASET == 'coco':
    # Use the MS COCO 2014 dataset
    INPUT_SHAPE = (224, 224, 3)
    COCO_INSTANCES = COCO(INSTANCEFILE)
    IDS = COCO_INSTANCES.getImgIds()
    r.shuffle(IDS)
    print(f'Number of images in the dataset: {len(IDS)}')
    BATCH_SIZE = 16
    OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=5e-5)
    PATIENCE = 2
elif CHOSEN_DATASET == 'face':
    # Use the face datasets (both directories are merged into one id list)
    INPUT_SHAPE = (64, 64, 3)
    BATCH_SIZE = 32*4
    OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=1e-4)
    IDS = [ f'datasets/Humans_Face/{element}' for element in os.listdir('datasets/Humans_Face') ]
    IDS += [ f'datasets/celeba/{element}' for element in os.listdir('datasets/celeba') ]
    r.shuffle(IDS)
    PATIENCE = 2
else:
    # Use the MNIST dataset (default for any other answer)
    INPUT_SHAPE = (32, 32, 1) # 28x28 images are resized to 32x32 to get a power-of-two size
    BATCH_SIZE = 32
    OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=1e-4)
    IDS = []  # left empty: the MNIST generator builds its own id list
    PATIENCE = 2

STEPS = 1000 # Number of diffusion steps
# First and last beta of the linear noise schedule (see LinearNoiceScheduler).
BETA_1 = 1e-4
BETA_T = 1e-2
#PATH_DATASETS = ['stable-diffusion-face-dataset/512/man','stable-diffusion-face-dataset/512/woman']
# Fractions of the id list used for each split; the generators give the test
# split whatever remains after the train and validation fractions.
TRAIN_SPLIT = 0.8
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.05
EPOCHS = 1000 # Early stopping is used

# Augmentation configuration: DATA_GEN is applied to training images,
# BASE_GEN (no augmentation configured) to the other splits.
DATA_GEN = ImageDataGenerator(
    horizontal_flip=True,
)
BASE_GEN = ImageDataGenerator()

class LinearNoiceScheduler:
    """Linear beta schedule for a DDPM, with forward and reverse step helpers.

    Diffusion steps are 1-indexed (1..T) while the coefficient arrays stored on
    the instance are 0-indexed, so step ``t`` reads index ``t - 1``.

    Args:
        beta_1: first beta of the schedule (defaults to the module constant BETA_1).
        beta_t: last beta of the schedule (defaults to the module constant BETA_T).
        steps: number of diffusion steps T (defaults to the module constant STEPS).
    """

    def __init__(self, beta_1=None, beta_t=None, steps=None):
        # Resolve defaults at call time so ``LinearNoiceScheduler()`` keeps
        # using the module-level configuration.
        beta_1 = BETA_1 if beta_1 is None else beta_1
        beta_t = BETA_T if beta_t is None else beta_t
        steps = STEPS if steps is None else steps

        self.betas = np.linspace(beta_1, beta_t, steps)  # b_1, ..., b_T
        self.sqrt_betas = np.sqrt(self.betas)
        self.alphas = 1 - self.betas
        self.sqrt_alphas = np.sqrt(self.alphas)
        self.c_alphas = np.cumprod(self.alphas)  # cumulative product of the alphas
        self.sqrt_c_alphas = np.sqrt(self.c_alphas)
        self.sqrt_one_minus_c_alphas = np.sqrt(1 - self.c_alphas)

    def add_noise(self, original_images, steps, return_noise=False):
        """Noise images directly to the given step(s) (forward process).

        Computes ``sqrt(c_alpha_t) * x + sqrt(1 - c_alpha_t) * noise`` with
        standard normal noise.

        Args:
            original_images: array of images; the per-step coefficients are
                broadcast as ``(n_steps, 1, 1, 1)``.
            steps: 1-indexed step(s); a scalar or an array with one entry per
                image (or a single entry shared by all images).
            return_noise: when True, also return the sampled noise.

        Returns:
            The noisy images, or ``(noisy_images, noise)`` if ``return_noise``.
        """
        c_step = np.asarray(steps).reshape(-1) - 1  # 1-indexed step -> 0-indexed
        noise = np.random.normal(size=original_images.shape)

        sqrt_alpha_cumprod = self.sqrt_c_alphas[c_step].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_c_alphas[c_step].reshape(-1, 1, 1, 1)

        noisy_image = sqrt_alpha_cumprod * original_images + sqrt_one_minus_alpha_cumprod * noise

        if return_noise:
            return noisy_image, noise
        return noisy_image

    def remove_noise(self, noisy_image, noise_pred, step):
        """Perform one reverse step: from the image at ``step`` to ``step - 1``.

        Args:
            noisy_image: image(s) at diffusion step ``step``.
            noise_pred: noise predicted for ``noisy_image``.
            step: 1-indexed current step; a scalar or an array with one entry
                per image in the batch.

        Returns:
            The image(s) at step ``step - 1``. Fresh noise scaled by
            ``sqrt(beta_t)`` is added for every sample whose step is > 1; no
            noise is added for the final step (step == 1).
        """
        step_arr = np.asarray(step).reshape(-1)
        c_step = step_arr - 1  # 1-indexed step -> 0-indexed
        # One coefficient per sample, broadcast over the remaining image axes.
        coeff_shape = (-1,) + (1,) * (np.ndim(noisy_image) - 1)

        alpha_t = self.alphas[c_step].reshape(coeff_shape)
        sqrt_alpha_t = np.sqrt(alpha_t)
        sqrt_one_minus_cum_alpha_t = self.sqrt_one_minus_c_alphas[c_step].reshape(coeff_shape)

        alpha_factor = (1 - alpha_t) / sqrt_one_minus_cum_alpha_t

        # Estimate of x_{t-1} before the stochastic term.
        updated_image = (1 / sqrt_alpha_t) * (noisy_image - (alpha_factor * noise_pred))

        # Per-sample mask instead of ``if step > 1`` so batches mixing final and
        # non-final steps are handled (the scalar test fails on arrays).
        not_final = (step_arr > 1).reshape(coeff_shape)
        if np.any(not_final):
            noise = np.random.normal(size=noisy_image.shape) * not_final
        else:
            noise = np.zeros(noisy_image.shape)

        sigma_t = self.sqrt_betas[c_step].reshape(coeff_shape)
        return updated_image + sigma_t * noise

def inference(for_plot=False, plot_step=50, image=None, num_steps=STEPS):
    """Run the reverse diffusion loop with the global ``model``.

    Args:
        for_plot: when True, return the list of recorded snapshots instead of
            only the final image.
        plot_step: a snapshot is recorded at every step divisible by this value.
        image: optional starting image (without batch axis); standard normal
            noise of shape INPUT_SHAPE is used when omitted.
        num_steps: step at which the reverse process starts. The step fed to
            the model is still normalised by the global STEPS.

    Returns:
        A list of ``(step, image)`` tuples if ``for_plot`` is True, otherwise
        the final image. The final image is recorded with step 0.
    """
    res = []
    start = np.random.normal(size=INPUT_SHAPE) if image is None else image
    batch_image = np.expand_dims(start, axis=0)
    with tqdm(total=num_steps, desc='Inference', unit='step') as pbar:
        for step in range(num_steps, 0, -1):
            if step % plot_step == 0:
                # Snapshot taken before denoising: this is the image at `step`.
                res.append((step, batch_image[0]))
            numpy_step = np.array([step])
            batch_step = np.array([[step]]) / STEPS # Normalise the step
            noise_pred = model.predict((batch_image, batch_step), verbose=0)
            batch_image = NoiceScheduler.remove_noise(batch_image, noise_pred, numpy_step)
            pbar.update()
    # After the loop the image is at step 0, not at the last loop value (1).
    # A literal 0 also avoids a NameError when num_steps <= 0.
    res.append((0, batch_image[0]))
    if for_plot:
        return res
    return batch_image[0]

def plot_inference_gif():
    """Display the reverse diffusion process step by step.

    Starts from standard normal noise of shape INPUT_SHAPE and, for every step
    from STEPS down to 1, clears the cell output and shows the denoised image.
    """
    current = np.random.normal(size=INPUT_SHAPE)[np.newaxis, ...]
    step = STEPS
    while step > 0:
        clear_output(wait=True)
        normalized_step = np.array([[step]]) / STEPS  # step scaled by STEPS for the model
        predicted_noise = model.predict((current, normalized_step), verbose=0)
        current = NoiceScheduler.remove_noise(current, predicted_noise, np.array([step]))
        # Map from [-1, 1] to [0, 1] for display.
        frame = (current[0] + 1) / 2
        plt.imshow(np.clip(frame, 0, 1))
        plt.title(f'{step}')
        plt.axis('off')
        plt.show()
        step -= 1

def plot_inference(image=None):
    """Plot the snapshots recorded by ``inference`` side by side.

    Args:
        image: optional starting image forwarded to ``inference``.
    """
    result = inference(for_plot=True, image=image)
    # squeeze=False always returns a 2-D array of axes, so indexing also works
    # when there is a single snapshot.
    fig, axes = plt.subplots(1, len(result), figsize=(20, 5), squeeze=False)
    for ax, (step, img) in zip(axes[0], result):
        img = (img + 1) / 2 # Convert from [-1, 1] to [0, 1]
        ax.imshow(np.clip(img, 0, 1))
        ax.set_title(f'{int(step)}')
        ax.axis('off')
    plt.show()

def array_stats(array, plot=False):
    """Print shape, mean, min, max and standard deviation of ``array``.

    Args:
        array: NumPy array to summarise.
        plot: when True, also show a histogram of the flattened values.
    """
    # Each field is followed by " \n" and the next line is indented by 9 spaces.
    separator = ' \n' + ' ' * 9
    fields = [
        f'Shape: {array.shape}',
        f'Mean: {np.mean(array)}',
        f'Min: {np.min(array)}',
        f'Max: {np.max(array)}',
        f'Std: {np.std(array)}',
    ]
    print(separator.join(fields))
    if not plot:
        return
    sns.histplot(array.flatten())
    plt.show()

# Loss function
@tf.function
def loss_fn(y_true, y_pred):
    """Return the mean of the squared element-wise difference of the inputs."""
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)

# Module-level schedule instance used by the dataset generators and the sampling helpers.
NoiceScheduler = LinearNoiceScheduler()

## Data preprocessing

### Loading of dataset

In [None]:
# COCO Dataset
class DatasetGeneratorCOCO(Sequence):
    """Keras Sequence yielding noised MS COCO images for diffusion training.

    Each item is a tuple ``(noisy_images, steps / STEPS, noises)`` built from
    the slice of the global ``IDS`` list that belongs to the requested split.

    Args:
        ensemble: split name, one of 'train', 'val' or 'test'.
        **kwargs: forwarded to ``Sequence.__init__``.
    """

    def _getsplit(self, ensemble):
        """Return the ``(start, stop)`` indices of ``ensemble`` within IDS."""
        train_end = int(TRAIN_SPLIT * len(IDS))
        val_end = int((TRAIN_SPLIT + VALIDATION_SPLIT) * len(IDS))
        if ensemble == 'train':
            return 0, train_end
        if ensemble == 'val':
            return train_end, val_end
        if ensemble == 'test':
            return val_end, len(IDS)
        raise ValueError(f"Unknown split {ensemble!r}; expected 'train', 'val' or 'test'")

    def __init__(self, ensemble, **kwargs):
        super().__init__(**kwargs)
        self.ensemble = ensemble

        # Slicing copies the ids, so shuffling this split leaves IDS untouched.
        start, stop = self._getsplit(ensemble)
        self.ids = IDS[start:stop]

    def __len__(self):
        """Number of batches; the last one may be smaller than BATCH_SIZE."""
        return int(np.ceil(len(self.ids) / BATCH_SIZE))

    def __getitem__(self, index):
        """Build batch ``index``: noisy images, normalised steps and the noise used."""
        batch_ids = self.ids[index * BATCH_SIZE : (index + 1) * BATCH_SIZE]
        batch_noisy_images = []
        batch_steps = []
        batch_noises = []
        for image_id in batch_ids:
            # Load the image; the context manager closes the file once the
            # resized/converted copy has been produced.
            file_name = COCO_INSTANCES.imgs[image_id]['file_name']
            with Image.open(f'{DATADIR}/{file_name}') as raw_image:
                # PIL expects (width, height) while INPUT_SHAPE is (height, width, channels).
                image = raw_image.resize((INPUT_SHAPE[1], INPUT_SHAPE[0]))
                image = image.convert('RGB')
            image = img_to_array(image)
            image = (image / 255.0) * 2 - 1 # Normalise to [-1, 1]
            if self.ensemble == 'train':
                image = DATA_GEN.random_transform(image)
            else:
                image = BASE_GEN.random_transform(image)
            # Build the training target by adding noise at a random step.
            # randint's upper bound is exclusive, hence STEPS + 1 so that the
            # last step (where sampling starts) is also seen during training.
            step = np.array([np.random.randint(1, STEPS + 1)])
            noisy_image, noise_used = NoiceScheduler.add_noise(np.expand_dims(image, axis=0), step, return_noise=True)
            batch_noisy_images.append(noisy_image[0])
            batch_steps.append(step)
            batch_noises.append(noise_used[0])

        batch_noisy_images = np.array(batch_noisy_images, dtype='float64')
        batch_steps = np.array(batch_steps, dtype='float64') / STEPS
        batch_noises = np.array(batch_noises, dtype='float64')

        return (batch_noisy_images, batch_steps, batch_noises)

    def on_epoch_end(self):
        """Reshuffle the training ids; other splits keep their order."""
        if self.ensemble == 'train':
            np.random.shuffle(self.ids)

# Face Dataset
class DatasetGeneratorFace(Sequence):
    """Keras Sequence yielding noised face images for diffusion training.

    The entries of the global ``IDS`` list are used directly as image file
    paths. Each item is a tuple ``(noisy_images, steps / STEPS, noises)``.

    Args:
        ensemble: split name, one of 'train', 'val' or 'test'.
        **kwargs: forwarded to ``Sequence.__init__``.
    """

    def _getsplit(self, ensemble):
        """Return the ``(start, stop)`` indices of ``ensemble`` within IDS."""
        train_end = int(TRAIN_SPLIT * len(IDS))
        val_end = int((TRAIN_SPLIT + VALIDATION_SPLIT) * len(IDS))
        if ensemble == 'train':
            return 0, train_end
        if ensemble == 'val':
            return train_end, val_end
        if ensemble == 'test':
            return val_end, len(IDS)
        raise ValueError(f"Unknown split {ensemble!r}; expected 'train', 'val' or 'test'")

    def __init__(self, ensemble, **kwargs):
        super().__init__(**kwargs)
        self.ensemble = ensemble

        # Slicing copies the ids, so shuffling this split leaves IDS untouched.
        start, stop = self._getsplit(ensemble)
        self.ids = IDS[start:stop]

    def __len__(self):
        """Number of batches; the last one may be smaller than BATCH_SIZE."""
        return int(np.ceil(len(self.ids) / BATCH_SIZE))

    def __getitem__(self, index):
        """Build batch ``index``: noisy images, normalised steps and the noise used."""
        batch_ids = self.ids[index * BATCH_SIZE : (index + 1) * BATCH_SIZE]
        batch_noisy_images = []
        batch_steps = []
        batch_noises = []
        for image_path in batch_ids:
            # Load the image
            image = load_img(image_path, target_size=(INPUT_SHAPE[0], INPUT_SHAPE[1]), color_mode='rgb')
            image = img_to_array(image)
            image = (image / 255.0) * 2 - 1 # Normalise to [-1, 1]
            if self.ensemble == 'train':
                image = DATA_GEN.random_transform(image)
            else:
                image = BASE_GEN.random_transform(image)
            # Build the training target by adding noise at a random step.
            # randint's upper bound is exclusive, hence STEPS + 1 so that the
            # last step (where sampling starts) is also seen during training.
            step = np.array([np.random.randint(1, STEPS + 1)])
            noisy_image, noise_used = NoiceScheduler.add_noise(np.expand_dims(image, axis=0), step, return_noise=True)
            batch_noisy_images.append(noisy_image[0])
            batch_steps.append(step)
            batch_noises.append(noise_used[0])

        batch_noisy_images = np.array(batch_noisy_images, dtype='float64')
        batch_steps = np.array(batch_steps, dtype='float64') / STEPS
        batch_noises = np.array(batch_noises, dtype='float64')

        return (batch_noisy_images, batch_steps, batch_noises)

    def on_epoch_end(self):
        """Reshuffle the training ids; shuffling validation/test data is pointless."""
        if self.ensemble == 'train':
            np.random.shuffle(self.ids)

    def get_random_images(self, n):
        """Return ``n`` randomly chosen clean images of this split, scaled to [-1, 1]."""
        images = []
        for _ in range(n):
            # The ids already are complete file paths (as used in __getitem__),
            # so they must not be prefixed with a directory again.
            image_path = r.choice(self.ids)
            image = load_img(image_path, target_size=(INPUT_SHAPE[0], INPUT_SHAPE[1]))
            image = img_to_array(image)
            image = (image / 255.0) * 2 - 1 # Normalise to [-1, 1]
            images.append(image)
        return images
    
# MNIST Dataset
class DatasetGeneratorMNIST(Sequence):
    """Keras Sequence yielding noised MNIST digits for diffusion training.

    Image paths are collected from ``MNIST/training/<digit>`` and
    ``MNIST/validation/<digit>``, merged, shuffled and then split by the
    TRAIN_SPLIT / VALIDATION_SPLIT fractions. Each item is a tuple
    ``(noisy_images, steps / STEPS, noises)``.

    Args:
        ensemble: split name, one of 'train', 'val' or 'test'.
        **kwargs: forwarded to ``Sequence.__init__``.
    """

    # Shuffled list of every image path, built once and shared by all
    # instances. Shuffling separately in each instance would make the
    # train/val/test slices overlap (data leakage between splits).
    _shared_ids = None

    def _getsplit(self, ensemble):
        """Return the ``(start, stop)`` indices of ``ensemble`` within ``self.ids``."""
        total = len(self.ids)
        train_end = int(TRAIN_SPLIT * total)
        val_end = int((TRAIN_SPLIT + VALIDATION_SPLIT) * total)
        if ensemble == 'train':
            return 0, train_end
        if ensemble == 'val':
            return train_end, val_end
        if ensemble == 'test':
            return val_end, total
        raise ValueError(f"Unknown split {ensemble!r}; expected 'train', 'val' or 'test'")

    def __init__(self, ensemble, **kwargs):
        super().__init__(**kwargs)
        self.ensemble = ensemble

        # Collect the image paths of every digit class from both folders.
        self.dict_training = {
            digit: [ f'MNIST/training/{digit}/' + element for element in os.listdir(f'MNIST/training/{digit}') ]
            for digit in range(10)
        }
        self.dict_validation = {
            digit: [ f'MNIST/validation/{digit}/' + element for element in os.listdir(f'MNIST/validation/{digit}') ]
            for digit in range(10)
        }
        # Merge the two dictionaries
        self.dict = {
            digit: self.dict_training[digit] + self.dict_validation[digit]
            for digit in range(10)
        }

        if IDS:
            # An externally provided id list takes precedence.
            self.ids = IDS
        else:
            if DatasetGeneratorMNIST._shared_ids is None:
                all_ids = [path for digit in range(10) for path in self.dict[digit]]
                r.shuffle(all_ids)
                DatasetGeneratorMNIST._shared_ids = all_ids
            self.ids = DatasetGeneratorMNIST._shared_ids
        # Slicing copies the ids, so shuffling this split leaves the shared list untouched.
        start, stop = self._getsplit(ensemble)
        self.ids = self.ids[start:stop]

    def __len__(self):
        """Number of batches; the last one may be smaller than BATCH_SIZE."""
        return int(np.ceil(len(self.ids) / BATCH_SIZE))

    def __getitem__(self, index):
        """Build batch ``index``: noisy images, normalised steps and the noise used."""
        batch_ids = self.ids[index * BATCH_SIZE : (index + 1) * BATCH_SIZE]
        batch_noise = []
        batch_noisy_image = []
        batch_steps = []
        for image_path in batch_ids:
            # Load the image
            image = load_img(image_path, color_mode='grayscale', target_size=(INPUT_SHAPE[0], INPUT_SHAPE[1]))
            image = img_to_array(image)
            image = ((image / 255.0) * 2) - 1 # Normalise to [-1, 1]
            # Build the training target by adding noise at a random step.
            # randint's upper bound is exclusive, hence STEPS + 1 so that the
            # last step (where sampling starts) is also seen during training.
            step = np.array([np.random.randint(1, STEPS + 1)])
            batch_steps.append(step)
            noisy_image, noise_used = NoiceScheduler.add_noise(np.expand_dims(image, axis=0), step, return_noise=True)
            batch_noise.append(noise_used[0])
            batch_noisy_image.append(noisy_image[0])

        batch_noise = np.array(batch_noise, dtype='float64')
        batch_noisy_image = np.array(batch_noisy_image, dtype='float64')
        batch_steps = np.array(batch_steps, dtype='float64') / STEPS # Scale steps by STEPS

        return (batch_noisy_image, batch_steps, batch_noise)

    def on_epoch_end(self):
        """Reshuffle the training ids; other splits keep their order."""
        if self.ensemble == 'train':
            np.random.shuffle(self.ids)

    def get_random_images(self, n):
        """Return ``n`` randomly chosen clean images of this split, scaled to [-1, 1]."""
        images = []
        for _ in range(n):
            image_path = r.choice(self.ids)
            image = load_img(image_path, color_mode='grayscale', target_size=(INPUT_SHAPE[0], INPUT_SHAPE[1]))
            image = img_to_array(image)
            image = ((image / 255.0) * 2) - 1 # Normalise to [-1, 1]
            images.append(image)
        return images

# Pick the generator class for the chosen dataset (MNIST for any other value)
# and build the three splits in the order train, val, test.
generator_class = {
    'coco': DatasetGeneratorCOCO,
    'face': DatasetGeneratorFace,
}.get(CHOSEN_DATASET, DatasetGeneratorMNIST)
train_generator = generator_class('train')
val_generator = generator_class('val')
test_generator = generator_class('test')

# Report the size of each split in batches and in items.
for _label, _generator in (
    ("d'entrainement", train_generator),
    ('de validation', val_generator),
    ('de test', test_generator),
):
    print(f'Taille du dataset {_label}: {len(_generator)} batches, {len(_generator.ids)} items')

### Example

In [None]:
# Pick a random batch
batch_index = np.random.randint(0, len(train_generator))
noisy_images, steps, noises = train_generator[batch_index]
print(f'Noisy_images shape: {noisy_images.shape}')
print(f'Noises shape: {noises.shape}')

# Pick a random sample inside the batch
image_index = np.random.randint(0, len(noisy_images))
noisy_image = noisy_images[image_index]
noisy_image = (noisy_image + 1) / 2 # Map pixel values back to [0, 1]
noise = noises[image_index]
# Print the shapes of the selected sample, not of the whole batch.
print(f'Noisy_image shape: {noisy_image.shape}')
print(f'Noise shape: {noise.shape}')

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(np.clip(noisy_image, 0, 1))
axes[0].set_title('Image')
axes[0].axis('off')

# The noise is shown as-is, clipped to [0, 1] (negative values appear black).
axes[1].imshow(np.clip(noise, 0, 1))
axes[1].set_title(f'Noise')
axes[1].axis('off')

plt.show()

## Model

### Gaussian noise tests

#### Curves

In [None]:

# Pick a random batch. random.randint includes both bounds, so the upper
# bound must be len - 1 to stay a valid index.
r_batch = r.randint(0, len(train_generator) - 1)
batch = train_generator[r_batch]
selected_images, selected_steps, selected_labels = batch
# NOTE(review): the generator returns images that are already noised at a
# random step, so the curves below start from noisy inputs - confirm intended.
noise = np.random.normal(size=selected_images.shape[1:])

# Lists collecting the mean and the standard deviation at each step
means = []
stds = []

step_jump = 10

# Diffusion steps are 1-indexed: starting at 0 would wrap around to the last
# entry of the schedule (index -1) and corrupt the first point of the curves.
step_range = range(1, STEPS + 1, step_jump)

# Compute the mean and the standard deviation at each step
for step in step_range:
    noisy_images = NoiceScheduler.add_noise(selected_images, np.array([step]))
    means.append(np.mean(noisy_images))
    stds.append(np.std(noisy_images))

# Plot the results
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(step_range, means, label='Mean')
plt.xlabel('Step')
plt.ylabel('Mean')
plt.title('Mean of Noisy Images')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(step_range, stds, label='Standard Deviation')
plt.xlabel('Step')
plt.ylabel('Standard Deviation')
plt.title('Standard Deviation of Noisy Images')
plt.legend()
plt.tight_layout()

# Keep a few random images; max() guards against a single-image batch, for
# which randint(1, 0) would raise.
idx = r.randint(1, max(1, len(selected_images) - 1))
test_images = selected_images[:idx]

# Show noise
plt.figure(figsize=(6, 6))
noise_img = np.clip(noise, 0, 1)
plt.imshow(noise_img)
plt.title('Noise')
plt.axis('off')
# Show noise distribution
plt.figure(figsize=(6, 6))
plt.hist(noise.flatten(), bins=150)
plt.title('Noise Distribution')
plt.xlabel('Intensity')
plt.ylabel('Frequency')

plt.show()



In [None]:
# Plot the cumulative product of the alphas (NoiceScheduler.c_alphas) against
# the 0-based array index, i.e. diffusion step minus one.
plt.figure(figsize=(6, 6))
plt.plot(NoiceScheduler.c_alphas)
plt.xlabel('Step')
plt.ylabel('Cumprod alpha')

#### Test on one sample

In [None]:

# Pick a random validation batch and use its first image
r_image = r.randint(0, len(val_generator)-1)
batch = val_generator[r_image]
# NOTE(review): the generator returns images that are already noised (the
# original comment marked this line as "to be fixed") - confirm intended.
test_images, test_step, test_labels = batch
test_image = test_images[0]
jump = STEPS // 10
fig, axes = plt.subplots(1, 11, figsize=(20, 4))
for slot in range(11):
    # Steps are 1-indexed, so the first slot uses step 1 instead of 0.
    step = max(1, slot * jump)
    # Step t reads index t - 1 of the schedule; the previous "step - 2"
    # wrapped around to the last entry for step 1.
    c_alpha = NoiceScheduler.c_alphas[step - 1]
    noisy_images = NoiceScheduler.add_noise(test_image[np.newaxis], np.array([step]))
    noisy_images = (noisy_images + 1) / 2
    noisy_image_plt = np.clip(noisy_images, 0, 1)
    axes[slot].imshow(noisy_image_plt[0])
    axes[slot].axis('off')
    axes[slot].set_title(f'{step} - {c_alpha:.2f}')

plt.show()

### Model definition

In [None]:
# Fonctions

def conv_block(input_tensor, num_filters):
    """Apply two successive (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        input_tensor: tensor fed to the first convolution.
        num_filters: number of filters of both convolutions.

    Returns:
        The output tensor of the second stage.
    """
    out = input_tensor
    for _ in range(2):
        out = layers.Conv2D(num_filters, (3, 3), padding="same")(out)
        out = layers.BatchNormalization()(out)
        out = layers.Activation("relu")(out)
    return out

def resnet_block(input_tensor, num_filters, block_num):
    """Residual block: 1x1 projection, ``block_num`` conv stages, skip connection.

    Args:
        input_tensor: tensor to process.
        num_filters: number of filters of every convolution in the block.
        block_num: number of chained (3x3 conv -> batch norm -> leaky ReLU) stages.

    Returns:
        The sum of the last stage's output and the 1x1-projected input.
    """
    # Project the input to num_filters channels so it can be added at the end.
    input_tensor = layers.Conv2D(num_filters, (1, 1), padding="same")(input_tensor)
    # Chain the stages on x (as resnet_block_film does). Feeding input_tensor
    # to every stage discarded all but the last one and left x undefined when
    # block_num was 0.
    x = input_tensor
    for _ in range(block_num):
        x = layers.Conv2D(num_filters, (3, 3), padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("leaky_relu")(x)
    x = layers.Add()([x, input_tensor])
    return x

def encoder_block(input_tensor, num_filters):
    """Run ``conv_block`` and then a 2x2 max pooling.

    Returns:
        A tuple ``(features, pooled)``: the conv_block output (before pooling)
        and its max-pooled version.
    """
    features = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D((2, 2))(features)
    return features, pooled

def decoder_block(input_tensor, skip_features, num_filters):
    """Upsample with a stride-2 transposed conv, merge the skip, then conv_block.

    Args:
        input_tensor: tensor coming from the previous (coarser) level.
        skip_features: tensor concatenated with the upsampled input.
        num_filters: number of filters of the transposed conv and the conv_block.
    """
    upsampled = layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input_tensor)
    merged = layers.Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def conv_block_time(input_tensor, num_filters, time_tensor):
    """Convolution block whose features are scaled by a time embedding.

    Args:
        input_tensor: image feature tensor.
        num_filters: number of filters of both convolutions and of the time projection.
        time_tensor: tensor carrying the diffusion step information.

    Returns:
        Batch-normalised, leaky-ReLU-activated sum of the time-scaled features
        and a second convolution of ``input_tensor``.
    """
    # Feature branch: conv -> batch norm -> leaky ReLU.
    features = layers.Conv2D(num_filters, (3, 3), padding="same")(input_tensor)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("leaky_relu")(features)

    # Time branch: one scale per filter, reshaped to broadcast over the two
    # middle axes of the feature tensor.
    time_scale = layers.Dense(num_filters)(time_tensor)
    time_scale = layers.BatchNormalization()(time_scale)
    time_scale = layers.Activation("leaky_relu")(time_scale)
    time_scale = layers.Reshape((1, 1, num_filters))(time_scale)
    scaled_features = features * time_scale

    # Second convolution, applied to the block input and added to the scaled features.
    shortcut = layers.Conv2D(num_filters, (3, 3), padding="same")(input_tensor)
    merged = layers.Add()([scaled_features, shortcut])
    merged = layers.BatchNormalization()(merged)
    return layers.Activation("leaky_relu")(merged)

def self_attention_block(input_tensor, num_dims, num_heads=8, ff_dim=2):
    """Self-attention followed by a two-layer feed-forward part, each with a residual add.

    Args:
        input_tensor: tensor used as both query and value of the attention.
        num_dims: key/value size per head and output width of the feed-forward part.
        num_heads: number of attention heads.
        ff_dim: width multiplier of the hidden feed-forward layer.
    """
    attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=num_dims, value_dim=num_dims
    )(input_tensor, input_tensor)
    attended = layers.Add()([attention, input_tensor])
    normed = layers.LayerNormalization()(attended)

    # Feed-forward part, with a residual connection from the normalised tensor.
    hidden = layers.Dense(num_dims*ff_dim, activation='leaky_relu')(normed)
    hidden = layers.Dense(num_dims, activation='leaky_relu')(hidden)
    hidden = layers.LayerNormalization()(hidden)
    return layers.Add()([hidden, normed])

def film_layer(input_tensor, step, num_filters):
    """Scale and shift ``input_tensor`` per filter, conditioned on ``step``.

    Two tanh-activated dense projections of ``step`` give the scale and the
    shift; both are reshaped to (1, 1, num_filters) before being applied.
    """
    scale = layers.Dense(num_filters, activation='tanh')(step)
    shift = layers.Dense(num_filters, activation='tanh')(step)

    broadcast_shape = (1, 1, num_filters)
    scale = layers.Reshape(broadcast_shape)(scale)
    shift = layers.Reshape(broadcast_shape)(shift)
    return input_tensor * scale + shift

def resnet_block_film(input_tensor, num_filters, block_num, step):
    """Residual block whose conv stages are modulated by ``film_layer``.

    Args:
        input_tensor: tensor to process.
        num_filters: number of filters of every convolution in the block.
        block_num: number of chained (conv -> batch norm -> FiLM -> leaky ReLU) stages.
        step: conditioning tensor forwarded to ``film_layer``.

    Returns:
        The sum of the last stage's output and the 1x1-projected input.
    """
    # 1x1 projection so the residual has num_filters channels.
    residual = layers.Conv2D(num_filters, (1, 1), padding="same")(input_tensor)
    out = residual
    for _ in range(block_num):
        out = layers.Conv2D(num_filters, (3, 3), padding="same")(out)
        out = layers.BatchNormalization()(out)
        out = film_layer(out, step, num_filters) # FiLM modulation before the activation
        out = layers.Activation("leaky_relu")(out)
    return layers.Add()([out, residual])

def attention_with_step(input_tensor, step, num_dims, num_heads=8, ff_dim=2):
    """Attention block whose query is the input plus a step embedding.

    The step is projected to ``num_dims`` values and reshaped to
    (side, side, 1) with ``side = int(num_dims ** 0.5)`` before being added to
    ``input_tensor``. The attended tensor then goes through the same
    feed-forward part as ``self_attention_block``.
    """
    side = int(num_dims**(1/2))
    step_embedding = layers.Dense(num_dims)(step)
    step_embedding = layers.Reshape((side, side, 1))(step_embedding)
    query = layers.Add()([input_tensor, step_embedding])

    attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=num_dims, value_dim=num_dims
    )(query, input_tensor)
    attended = layers.Add()([attention, input_tensor])
    normed = layers.LayerNormalization()(attended)

    hidden = layers.Dense(num_dims*ff_dim, activation='leaky_relu')(normed)
    hidden = layers.Dense(num_dims, activation='leaky_relu')(hidden)
    hidden = layers.LayerNormalization()(hidden)
    return layers.Add()([hidden, normed])


'''
MODELS
'''

# COCO

def build_unetCOCO(input_shape):
    """Build a four-level U-Net with no step input.

    Encoder levels use 64/128/256/512 filters, the bottleneck 1024, and the
    decoder mirrors the encoder. The decoder output is concatenated with the
    model input before the final 1x1 convolution to 3 channels.
    """
    inputs = layers.Input(input_shape)

    # Encoder: keep every pre-pooling output as a skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder: consume the skips from the deepest to the shallowest.
    for filters in (512, 256, 128, 64):
        x = decoder_block(x, skips.pop(), filters)

    x = layers.Concatenate()([x, inputs])
    outputs = layers.Conv2D(3, (1, 1))(x)  # output is a 3-channel image (RGB)

    return tf.keras.Model(inputs, outputs, name='Diffusion_UNetCOCO')

def build_unetCOCOv2(input_shape):
    """Build a three-level, step-conditioned U-Net.

    The scalar step input is first embedded by a 128-unit dense layer and then
    passed to every ``conv_block_time``. Encoder levels use 64/128/256 filters,
    the bottleneck 512. The model takes ``(image, step)`` and outputs 3 channels.
    """
    input_image = layers.Input(input_shape)

    input_step = layers.Input((1,))
    step = layers.Dense(128, activation='leaky_relu')(input_step)

    # Encoder: two time-conditioned blocks per level, then 2x2 max pooling.
    skips = []
    x = input_image
    for filters in (64, 128, 256):
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_block_time(x, 512, step)
    x = conv_block_time(x, 512, step)
    x = layers.UpSampling2D((2, 2))(x)

    # Decoder: every level but the last one is followed by an upsampling.
    decoder_filters = (256, 128, 64)
    last_level = len(decoder_filters) - 1
    for level, filters in enumerate(decoder_filters):
        x = layers.Concatenate()([x, skips.pop()])
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)
        if level != last_level:
            x = layers.UpSampling2D((2, 2))(x)

    x = layers.Concatenate()([x, input_image])
    outputs = layers.Conv2D(64, (3, 3), padding="same")(x)
    outputs = layers.Conv2D(3, (1, 1))(outputs)  # output is a 3-channel image (RGB)

    return tf.keras.Model((input_image, input_step), outputs, name='Diffusion_UNetCOCOv2')

# FACE

def build_unetFACEv2(input_shape):
    """Build a step-conditioned U-Net that pools before every encoder level.

    The input is average-pooled four times in total (three encoder levels with
    128/256/512 filters plus the 1024-filter bottleneck) and upsampled four
    times in the decoder. The raw step input is fed directly to every
    ``conv_block_time``. The model takes ``(image, step)`` and outputs 3 channels.
    """
    input_image = layers.Input(input_shape)
    step = layers.Input((1,))

    # Encoder: pool first, then two time-conditioned blocks per level.
    skips = []
    x = input_image
    for filters in (128, 256, 512):
        x = layers.AveragePooling2D((2, 2))(x)
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)
        skips.append(x)

    # Bottleneck
    x = layers.AveragePooling2D((2, 2))(x)
    x = conv_block_time(x, 1024, step)
    x = conv_block_time(x, 1024, step)

    # Decoder: upsample, merge the matching skip, two time-conditioned blocks.
    for filters in (512, 256, 128):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Concatenate()([x, skips.pop()])
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)

    # Back to the input resolution, merged with the input image itself.
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Concatenate()([x, input_image])
    outputs = layers.Conv2D(128, (3, 3), padding="same")(x)
    outputs = layers.Conv2D(3, (1, 1))(outputs)  # output is a 3-channel image (RGB)

    return tf.keras.Model((input_image, step), outputs, name='Diffusion_UNetFACEv2')

def build_unetFACEv3(input_shape):
    """Build a step-conditioned U-Net with a dense (MLP) bottleneck.

    Same encoder/decoder layout as ``build_unetFACEv2`` but with max pooling,
    and with the convolutional bottleneck replaced by
    flatten -> concat(step) -> Dense(2048) -> Dense -> reshape.

    Args:
        input_shape: shape of the input image, with statically known height
            and width. The image is pooled four times and upsampled four times,
            so both presumably need to be divisible by 16 for the skip
            concatenations to line up.

    Returns:
        A ``tf.keras.Model`` taking ``(image, step)`` and producing 3 channels.

    Raises:
        ValueError: if the spatial size at the bottleneck is not statically known.
    """
    input_image = layers.Input(input_shape)
    step = layers.Input((1,))

    # Encoder: pool first, then two time-conditioned blocks per level.
    skips = []
    x = input_image
    for filters in (128, 256, 512):
        x = layers.MaxPooling2D((2, 2))(x)
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)
        skips.append(x)

    # Bottleneck MLP
    p4 = layers.MaxPooling2D((2, 2))(x)
    # Derive the bottleneck resolution from the pooled tensor instead of
    # hard-coding 8x8, which only matched 128x128 inputs.
    bottleneck_h, bottleneck_w = p4.shape[1], p4.shape[2]
    if bottleneck_h is None or bottleneck_w is None:
        raise ValueError('build_unetFACEv3 requires a fixed input height and width')

    x = layers.Flatten()(p4)
    x = layers.Concatenate()([x, step])
    x = layers.Dense(2048, activation='leaky_relu')(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(bottleneck_h * bottleneck_w * 128, activation='leaky_relu')(x)
    x = layers.Reshape((bottleneck_h, bottleneck_w, 128))(x)
    x = layers.BatchNormalization()(x) # Bottleneck

    # Decoder: upsample, merge the matching skip, two time-conditioned blocks.
    for filters in (512, 256, 128):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Concatenate()([x, skips.pop()])
        x = conv_block_time(x, filters, step)
        x = conv_block_time(x, filters, step)

    # Back to the input resolution, merged with the input image itself.
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Concatenate()([x, input_image])
    outputs = layers.Conv2D(128, (3, 3), padding="same")(x)
    outputs = layers.Conv2D(3, (1, 1))(outputs)  # output is a 3-channel image (RGB)

    # The name previously read 'Diffusion_UNetFACEv2' (copy-paste from v2).
    model = tf.keras.Model((input_image, step), outputs, name='Diffusion_UNetFACEv3')
    return model

def build_unetFACEv4(input_shape):
    '''
    Build a U-Net of FiLM-conditioned ResNet blocks with a dense (MLP) bottleneck.

    Best val loss : 0.023802 (recorded by the original author before the fixes
    noted in the comments below; re-measure after retraining).

    Args:
        input_shape: shape of the input image, with statically known height
            and width. The image is pooled four times and upsampled four times,
            so both presumably need to be divisible by 16 for the skip
            concatenations to line up.

    Returns:
        A ``tf.keras.Model`` taking ``(image, step)`` and producing 3 channels.
        The model is also handed to ``ModelManager.save_model`` with a short
        description before being returned.

    Raises:
        ValueError: if the spatial size at the bottleneck is not statically known.
    '''
    input_image = layers.Input(input_shape)
    # Keep the Input tensor under its own name. It used to be overwritten by
    # the Dropout output, which was then passed to tf.keras.Model as an input
    # even though it is not an Input tensor.
    # NOTE(review): this makes the 10% input dropout part of the model graph -
    # confirm that this is the desired training behaviour.
    dropped_image = layers.Dropout(0.10)(input_image)
    step = layers.Input((1,))

    # Encoder: max pooling followed by four chained FiLM ResNet blocks per level.
    # (The 128-filter level used to restart from its pooled input on every
    # iteration, so only its last block was actually part of the model.)
    skips = []
    x = dropped_image
    for filters in (64, 128, 256, 512):
        x = layers.MaxPooling2D((2, 2))(x)
        for _ in range(4):
            x = resnet_block_film(x, filters, 2, step)
        skips.append(x)
    c1, c2, c3, c4 = skips

    # Bottleneck
    b_num_filters = 256
    # Derive the bottleneck resolution from the deepest encoder output instead
    # of hard-coding 8x8, which only matched 128x128 inputs.
    bottleneck_h, bottleneck_w = c4.shape[1], c4.shape[2]
    if bottleneck_h is None or bottleneck_w is None:
        raise ValueError('build_unetFACEv4 requires a fixed input height and width')

    b_step = layers.Dense(512, activation='leaky_relu')(step)
    b = layers.Conv2D(b_num_filters, (1, 1))(c4)
    b = layers.Flatten()(b)
    b = layers.Concatenate()([b, b_step])
    b = layers.Dropout(0.45)(b)

    # Dense stack: 2024 -> 2024*4 -> 2024, each followed by batch norm and dropout.
    for units in (2024, 2024*4, 2024):
        b = layers.Dense(units, activation='leaky_relu')(b)
        b = layers.BatchNormalization()(b)
        b = layers.Dropout(0.45)(b)

    b = layers.Dense(bottleneck_h * bottleneck_w * b_num_filters, activation='leaky_relu')(b)
    b = layers.Reshape((bottleneck_h, bottleneck_w, b_num_filters))(b)
    b = layers.BatchNormalization()(b)
    b = layers.Dropout(0.45)(b)

    # Decoder: the deepest level is merged without upsampling, the others are
    # upsampled first; each level ends with four FiLM ResNet blocks.
    d = layers.Concatenate()([b, c4])
    for _ in range(4):
        d = resnet_block_film(d, 512, 2, step)

    for filters, skip in ((256, c3), (128, c2), (64, c1)):
        d = layers.UpSampling2D((2, 2))(d)
        d = layers.Concatenate()([d, skip])
        for _ in range(4):
            d = resnet_block_film(d, filters, 2, step)

    # Back to the input resolution, merged with the (dropped-out) input image.
    d0 = layers.UpSampling2D((2, 2))(d)
    outputs = layers.Concatenate()([d0, dropped_image])
    for _ in range(4):
        outputs = resnet_block_film(outputs, 64, 2, step)
    outputs = layers.Conv2D(3, (1, 1))(outputs)  # output is a 3-channel image (RGB)

    model = tf.keras.Model((input_image, step), outputs, name='Diffusion_UNetFACEv4')
    description = 'CNN U-Net with FiLM layers and ResNet blocks and a bottleneck MLP'
    ModelManager.save_model(model, description)
    return model

def _positional_encoding_map(num_positions, side):
    """Build a ``(1, side, side, num_positions)`` map from an Embedding over fixed indices.

    The indices ``0 .. num_positions - 1`` are each embedded into
    ``side * side`` values and the result is reshaped into a feature-map layout
    that can be added to a ``side x side`` feature map with ``num_positions``
    channels.

    NOTE(review): the Embedding is applied to a constant index tensor rather
    than to a Keras input, so its weights may not be tracked or trained by the
    enclosing model - confirm.
    """
    positions = tf.expand_dims(tf.range(0, num_positions, 1), axis=0)
    encoding = layers.Embedding(
        input_dim=num_positions,  # Maximum index + 1
        output_dim=side * side,  # Each index is embedded into a (side * side)-dimensional vector
    )(positions)
    return layers.Reshape((side, side, num_positions))(encoding)


def _film_attention_stage(x, filters, side, positional_encoding, step):
    """Apply one FiLM ResNet block, add the positional map, then run self-attention.

    Args:
        x: feature map whose spatial size is ``side x side``.
        filters: number of channels produced by the ResNet block.
        side: spatial height/width of the feature map.
        positional_encoding: map added to the block output before attention.
        step: diffusion-step tensor forwarded to ``resnet_block_film``.

    Returns:
        The attended features reshaped back to ``(side, side, filters)``.
    """
    x = resnet_block_film(x, filters, 2, step)
    x = layers.Add()([x, positional_encoding])
    # Move channels first, then flatten the spatial grid: (filters, side, side) -> (filters, side * side)
    x = layers.Permute((3, 1, 2))(x)
    x = layers.Reshape((x.shape[1], x.shape[2] * x.shape[3]))(x)
    x = self_attention_block(x, side * side, 8, 2)
    # Swap the two axes back and restore the spatial layout
    x = layers.Permute((2, 1))(x)
    return layers.Reshape((side, side, filters))(x)


def build_unetFACEv5(input_shape):
    """
    Build the FACE v5 U-Net: FiLM ResNet blocks, self-attention on the two
    deepest stages, and a dense (MLP) bottleneck.

    The hard-coded 16x16 and 8x8 reshapes after the third and fourth poolings
    mean the model expects 128x128 inputs.

    Best val loss : 0.025702
    NOTE(review): that figure was recorded when the second encoder stage applied
    every block to the pooled input (only the last block was connected). The
    blocks are now chained like in the other stages - re-measure.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` taking ``(input_image, step)`` and returning a
        3-channel map. The model is also registered through
        ``ModelManager.save_model``.
    """
    input_image = layers.Input(input_shape)  # 128x128x3
    step = layers.Input((1,))

    # Encoder
    c1 = layers.MaxPooling2D((2, 2))(input_image)  # 64x64
    for _ in range(4):
        c1 = resnet_block_film(c1, 64, 2, step)

    c2 = layers.MaxPooling2D((2, 2))(c1)  # 32x32
    for _ in range(4):
        # Fix: chain on c2 (the original fed the pooled tensor to every block,
        # leaving the first three blocks disconnected from the output).
        c2 = resnet_block_film(c2, 128, 2, step)

    positionnal_encoding_256 = _positional_encoding_map(256, 16)

    c3 = layers.MaxPooling2D((2, 2))(c2)  # 16x16
    for _ in range(5):
        c3 = _film_attention_stage(c3, 256, 16, positionnal_encoding_256, step)

    positionnal_encoding_512 = _positional_encoding_map(512, 8)

    c4 = layers.MaxPooling2D((2, 2))(c3)  # 8x8
    for _ in range(6):
        c4 = _film_attention_stage(c4, 512, 8, positionnal_encoding_512, step)

    # Bottleneck: 1x1 conv, flatten, concatenate the step features, then an MLP
    b_num_filters = 256
    b_step = layers.Dense(512, activation='leaky_relu')(step)
    b = layers.Conv2D(b_num_filters, (1, 1))(c4)
    b = layers.Flatten()(b)
    b = layers.Concatenate()([b, b_step])
    b = layers.Dropout(0.35)(b)

    # 1024 -> 4096 -> 1024: the 1024-unit layers are the narrowest points of the MLP
    for units in (1024, 1024 * 4, 1024):
        b = layers.Dense(units, activation='leaky_relu')(b)
        b = layers.BatchNormalization()(b)
        b = layers.Dropout(0.35)(b)

    b = layers.Dense(8 * 8 * b_num_filters, activation='leaky_relu')(b)
    b = layers.Reshape((8, 8, b_num_filters))(b)
    b = layers.BatchNormalization()(b)
    b = layers.Dropout(0.35)(b)

    # Decoder
    d4 = layers.Concatenate()([b, c4])
    for _ in range(6):
        d4 = _film_attention_stage(d4, 512, 8, positionnal_encoding_512, step)

    d3 = layers.UpSampling2D((2, 2))(d4)
    d3 = layers.Concatenate()([d3, c3])
    for _ in range(5):
        d3 = _film_attention_stage(d3, 256, 16, positionnal_encoding_256, step)

    d2 = layers.UpSampling2D((2, 2))(d3)
    d2 = layers.Concatenate()([d2, c2])
    for _ in range(4):
        d2 = resnet_block_film(d2, 128, 2, step)

    d1 = layers.UpSampling2D((2, 2))(d2)
    d1 = layers.Concatenate()([d1, c1])
    for _ in range(4):
        d1 = resnet_block_film(d1, 64, 2, step)

    d0 = layers.UpSampling2D((2, 2))(d1)
    outputs = layers.Concatenate()([d0, input_image])
    for _ in range(4):
        outputs = resnet_block_film(outputs, 64, 2, step)
    outputs = layers.Conv2D(3, (1, 1))(outputs)  # output is a 3-channel image (RGB)

    model = tf.keras.Model((input_image, step), outputs, name='Diffusion_UNetFACEv5')
    description = 'CNN U-Net with FiLM layers and ResNet blocks and a attention mechanism bottleneck MLP'
    ModelManager.save_model(model, description)
    return model

def build_unetFACEv6(input_shape):
    """Build the FACE v6 U-Net made of plain ResNet blocks and strided convolutions.

    The encoder halves the resolution four times with stride-2 convolutions
    (64 -> 128 -> 256 -> 512 -> 1024 filters); the decoder mirrors it with
    transposed convolutions and skip connections.

    The ``step`` input is not consumed by any layer: its sum is added to the
    output multiplied by zero, presumably only to keep the two-input signature
    connected to the graph.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` taking ``(input_image, step)`` and returning a
        3-channel map.
    """
    raw_image = layers.Input(input_shape)
    # The zero-rate dropout output is what gets passed to tf.keras.Model as the image input.
    input_image = layers.Dropout(0.00)(raw_image)
    step = layers.Input((1,))
    step_expanded = tf.reduce_sum(step)

    def stack(x, filters, count=4):
        """Chain `count` ResNet blocks with 3x3 kernels."""
        for _ in range(count):
            x = resnet_block(x, filters, 3)
        return x

    def downsample(x, filters):
        """Stride-2 convolution followed by batch norm and leaky ReLU."""
        x = layers.Conv2D(filters, (3, 3), padding="same", strides=2)(x)
        x = layers.BatchNormalization()(x)
        return layers.Activation("leaky_relu")(x)

    # Encoder: keep the output of every stage for the skip connections
    skips = [stack(input_image, 64)]
    for filters in (128, 256, 512):
        skips.append(stack(downsample(skips[-1], filters), filters))

    # Bottleneck
    b = stack(downsample(skips[-1], 1024), 1024, count=2)
    b = layers.Conv2D(1024, (3, 3), padding="same")(b)
    b = layers.BatchNormalization()(b)
    b = layers.Activation("leaky_relu")(b)

    # Decoder: upsample, concatenate the matching encoder stage, refine
    x = b
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (3, 3), strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skip])
        x = stack(x, filters)

    outputs = layers.Conv2D(3, (1, 1))(x)  # output is a 3-channel image (RGB)
    outputs = outputs + 0 * step_expanded

    return tf.keras.Model((input_image, step), outputs, name='Diffusion_UNetFACEv6')

# MNIST

def build_unetMNIST(input_shape):
    """Build the baseline MNIST U-Net (image input only, no diffusion-step input).

    Three encoder blocks (16, 32, 64 filters), a 128-filter bottleneck and
    three decoder blocks that consume the encoder skips in reverse order.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` mapping the image to a 1-channel map.
    """
    inputs = layers.Input(input_shape)

    # Encoder: each block yields (skip tensor, downsampled tensor)
    skips = []
    x = inputs
    for filters in (16, 32, 64):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, 128)

    # Decoder: the deepest skip is consumed first
    for filters in (64, 32, 16):
        x = decoder_block(x, skips.pop(), filters)

    merged = layers.Concatenate()([x, inputs])
    outputs = layers.Conv2D(1, (1, 1))(merged)  # output is a 1-channel image (grayscale)

    return tf.keras.Model(inputs, outputs, name='Diffusion_UNetMNIST')

def build_unetMNISTv2(input_shape):
    """Build the step-conditioned MNIST U-Net (v2).

    The scalar step is projected to 512 features by a Dense layer and passed
    to every ``conv_block_time``. Four encoder stages (64 to 512 filters) feed
    a dense bottleneck reshaped to 2x2x512, and four decoder stages mirror the
    encoder with skip connections.

    NOTE(review): the skip concatenations appear to require 32x32 inputs
    (four 2x poolings down to a 2x2 bottleneck) - confirm.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` taking ``(input_image, input_step)`` and returning
        a 1-channel map.
    """
    input_image = layers.Input(input_shape)

    # Step encoder
    input_step = layers.Input((1,))
    step = layers.Dense(512, activation='leaky_relu')(input_step)

    def double_conv(x, filters):
        """Two step-conditioned conv blocks with the same width."""
        x = conv_block_time(x, filters, step)
        return conv_block_time(x, filters, step)

    # Image encoder: remember each stage's pre-pooling output for the skips
    skips = []
    x = input_image
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck MLP
    x = layers.Flatten()(x)
    x = layers.Concatenate()([x, step])
    x = layers.Dense(512, activation='leaky_relu')(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(2*2*512, activation='leaky_relu')(x)
    x = layers.Reshape((2, 2, 512))(x)
    x = layers.BatchNormalization()(x)  # Bottleneck

    # Decoder: convolve, upsample, then concatenate the matching encoder stage
    for filters in (512, 256, 128, 64):
        x = double_conv(x, filters)
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Concatenate()([x, skips.pop()])

    d0 = layers.Concatenate()([x, input_image])
    outputs = layers.Conv2D(64, (3, 3), padding="same")(d0)
    outputs = layers.Conv2D(1, (1, 1))(outputs)  # output is a 1-channel image (grayscale)

    return tf.keras.Model((input_image, input_step), outputs, name='Diffusion_UNetMNISTv2')

def build_unetMNISTv2_nostep(input_shape):
    """Build the MNIST U-Net v2 variant that ignores the diffusion step.

    The model still declares a step input so that it accepts
    ``(input_image, input_step)``, but no layer consumes it. Unlike
    ``build_unetMNISTv2``, the last decoder stage does not concatenate the
    full-resolution encoder output, and the final 3x3 convolution has 32 filters.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` taking ``(input_image, input_step)`` and returning
        a 1-channel map.
    """
    input_image = layers.Input(input_shape)

    # Step input (declared but left unconnected)
    input_step = layers.Input((1,))

    def double_conv(x, filters):
        """Two plain conv blocks with the same width."""
        return conv_block(conv_block(x, filters), filters)

    # Image encoder: remember each stage's pre-pooling output
    stage_outputs = []
    x = input_image
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        stage_outputs.append(x)
        x = layers.MaxPooling2D((2, 2))(x)
    _, s22, s32, s42 = stage_outputs

    # Bottleneck MLP
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='leaky_relu')(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(2*2*512, activation='leaky_relu')(x)
    x = layers.Reshape((2, 2, 512))(x)
    x = layers.BatchNormalization()(x)  # Bottleneck

    # Decoder: the last stage has no skip connection
    for filters, skip in ((512, s42), (256, s32), (128, s22), (64, None)):
        x = double_conv(x, filters)
        x = layers.UpSampling2D((2, 2))(x)
        if skip is not None:
            x = layers.Concatenate()([x, skip])

    d0 = layers.Concatenate()([x, input_image])
    outputs = layers.Conv2D(32, (3, 3), padding="same")(d0)
    outputs = layers.Conv2D(1, (1, 1))(outputs)  # output is a 1-channel image (grayscale)

    return tf.keras.Model((input_image, input_step), outputs, name='build_unetMNISTv2_nostep')

def build_unetMNISTv3(input_shape):
    """Build the MNIST U-Net v3 with a self-attention bottleneck.

    Two encoder stages (64 and 128 filters) are followed by two
    ``self_attention_block`` calls, then two decoder stages that concatenate
    the pooled encoder outputs before upsampling.

    Args:
        input_shape: shape of the input image, without the batch axis.

    Returns:
        A ``tf.keras.Model`` taking ``(input_image, input_step)`` and returning
        a 1-channel map.
    """
    input_image = layers.Input(input_shape)

    # Step encoder
    input_step = layers.Input((1,))
    # NOTE(review): this projection is not consumed by any later layer in this function.
    step = layers.Dense(256, activation='leaky_relu')(input_step)

    # Encoder: keep the pooled output of each stage for the decoder
    pooled = []
    x = input_image
    for filters in (64, 128):
        x = conv_block(conv_block(x, filters), filters)
        x = layers.MaxPooling2D((2, 2))(x)
        pooled.append(x)

    # Self-attention bottleneck
    for _ in range(2):
        x = self_attention_block(x, 256)

    # Decoder: convolve, concatenate the pooled encoder output, then upsample
    for filters in (128, 64):
        x = conv_block(conv_block(x, filters), filters)
        x = layers.Concatenate()([x, pooled.pop()])
        x = layers.UpSampling2D((2, 2))(x)

    merged = layers.Concatenate()([x, input_image])
    output = layers.Conv2D(64, (3, 3), padding="same")(merged)
    output = layers.Conv2D(1, (1, 1))(output)  # output is a 1-channel image (grayscale)

    return tf.keras.Model((input_image, input_step), output, name='Diffusion_UNetMNISTv3')


# Build the model: pick the builder for the chosen dataset (MNIST is the fallback)
if CHOSEN_DATASET == 'coco':
    build_model = build_unetCOCOv2
elif CHOSEN_DATASET == 'face':
    build_model = build_unetFACEv6
else:
    build_model = build_unetMNISTv2

model = build_model(INPUT_SHAPE)
model.summary()

### Plot du modèle

In [None]:
# Render the architecture diagram to "<model name>.png", with tensor shapes and layer names
plot_model(model, to_file=f'{model.name}.png', show_shapes=True, show_layer_names=True)

### Training

In [None]:
# Single optimisation step
@tf.function
def train_step(model, noisy_images, steps, noises, optimizer):
    """Run one gradient update and return the batch loss.

    The model is called on ``(noisy_images, steps)`` in training mode, its
    output is cast to float64 and compared with ``noises`` through the
    module-level ``loss_fn`` (y_true = noises, y_pred = model output).

    Args:
        model: model called as ``model((noisy_images, steps), training=True)``.
        noisy_images: batch of noisy images fed to the model.
        steps: batch of steps fed to the model alongside the images.
        noises: target noise the predictions are compared against.
        optimizer: optimizer used to apply the gradients.

    Returns:
        The loss value computed by ``loss_fn`` for this batch.
    """
    with tf.GradientTape() as tape:
        predicted_noise = tf.cast(model((noisy_images, steps), training=True), tf.float64)
        loss = loss_fn(noises, predicted_noise)
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss

# Training loop
epochs_losses, val_epochs_losses = [], []  # per-epoch mean training / validation losses
min_loss, counter = np.inf, 0  # best validation loss so far / epochs without improvement
# NOTE(review): model_save_name is built here but not used within this cell - confirm where the model gets saved.
model_save_name = f"{model.name}_{datetime.datetime.now().strftime('%Y%m%d-%H%M')}.keras"
for epoch in range(EPOCHS):
    # Train the model on every batch of the training generator
    epoch_losses = []
    with tqdm(total=len(train_generator), desc=f'Epoch {epoch+1}/{EPOCHS}', unit='batch') as pbar:
        for batch_index in range(len(train_generator)):
            # Fetch the noisy images, their steps and the target noises
            noisy_images, steps, noises = train_generator[batch_index]
            loss = train_step(model, noisy_images, steps, noises, OPTIMIZER)
            # Metrics: show the running mean of this epoch's losses
            epoch_losses.append(loss.numpy())
            pbar.set_postfix(Loss=f"{np.mean(epoch_losses):.6f}")
            pbar.update()
    train_generator.on_epoch_end()
    epochs_losses.append(np.mean(epoch_losses))

    # Validation
    val_losses = []
    with tqdm(total=len(val_generator), desc=f'Validation {epoch+1}/{EPOCHS}', unit='batch') as pbar:
        for batch_index in range(len(val_generator)):
            noisy_images, steps, noises = val_generator[batch_index]
            predictions = tf.cast(model.predict((noisy_images, steps), verbose=0), tf.float64)
            # Metrics
            val_losses.append(loss_fn(noises, predictions).numpy())
            pbar.set_postfix(Loss=f"{np.mean(val_losses):.6f}")
            pbar.update()
    val_epochs_losses.append(np.mean(val_losses))

    # Early stopping: count consecutive epochs whose validation loss did not improve
    latest_val_loss = val_epochs_losses[-1]
    stalled = epoch > 0 and latest_val_loss >= min_loss
    if not stalled:
        counter = 0
        min_loss = latest_val_loss
    else:
        counter += 1
        if counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Display an inference test
    plot_inference()


### Loss curves

## Evaluation

In [None]:
# Compute the loss of every batch of a generator
def calculate_losses(generator):
    """Return the list of per-batch losses of the global ``model`` over ``generator``.

    Each batch is unpacked as ``(noisy_images, steps, noises)``; the model's
    prediction for ``(noisy_images, steps)`` is cast to float64 and compared
    with ``noises`` through the module-level ``loss_fn``.

    Args:
        generator: indexable batch source supporting ``len()``; its
            ``ensemble`` attribute is shown in the progress-bar label.

    Returns:
        A list with one loss value (``loss.numpy()``) per batch, in batch order.
    """
    batch_losses = []
    with tqdm(total=len(generator), desc=f'Calculating Losses {generator.ensemble}', unit='batch') as progress:
        for batch_index in range(len(generator)):
            noisy_images, steps, noises = generator[batch_index]
            predicted = tf.cast(model.predict((noisy_images, steps), verbose=0), tf.float64)
            batch_losses.append(loss_fn(noises, predicted).numpy())
            progress.update()
    return batch_losses

# Compute the per-batch losses of each generator
train_losses = calculate_losses(train_generator)
val_losses = calculate_losses(val_generator)
test_losses = calculate_losses(test_generator)
loss_sets = {'Train': train_losses, 'Validation': val_losses, 'Test': test_losses}

# Density plot of the losses, one curve per split
plt.figure(figsize=(12, 6))
for split_name, split_losses in loss_sets.items():
    sns.kdeplot(split_losses, label=f'{split_name} Losses', fill=True)
plt.xlabel('Loss')
plt.ylabel('Density')
plt.title('Density Plot of Losses')
plt.legend()
plt.show()

# Mean and standard deviation of the losses of each split
loss_stats = {
    split_name: (np.mean(split_losses), np.std(split_losses))
    for split_name, split_losses in loss_sets.items()
}
train_mean, train_std = loss_stats['Train']
val_mean, val_std = loss_stats['Validation']
test_mean, test_std = loss_stats['Test']

for split_name, (split_mean, split_std) in loss_stats.items():
    print(f'{split_name} Loss - Mean: {split_mean:.6f}, Std: {split_std:.6f}')

## Tests

### Chargement du modèle

In [None]:
# Load a saved model for inference only (compile=False).
# The path is assembled with os.path.join instead of a backslash-separated
# literal: "\d" is an invalid escape sequence in a normal string, and the
# hard-coded separators only resolve on Windows.
model = load_model(
    os.path.join(
        'models',
        'diffusion',
        'diffusion_face_152M_24-11-24',
        'diffusion_face_152M_24-11-24.keras',
    ),
    compile=False,
)
model.summary()

### Test itératif sur un bruit gaussien

In [None]:
# Run the iterative test on Gaussian noise described by this cell's heading
# (plot_inference is defined outside this section).
plot_inference()

### Test itératif sur un bruit gaussien (version fluide)

In [None]:
# Same iterative test on Gaussian noise, in the "smooth" variant named by this
# cell's heading (plot_inference_gif is defined outside this section).
plot_inference_gif()

### Test Interpolation on 2 images

In [None]:
# Settings
lambda_ = 0.5  # blending ratio between the two images
chosen_step = 0.55  # chosen diffusion step as a fraction of the schedule (0 => step 0, 1 => step T)

# Fetch two random test images and give each one a leading batch axis
images = test_generator.get_random_images(2)
image1, image2 = images
image1 = np.expand_dims(image1, axis=0)
image2 = np.expand_dims(image2, axis=0)
# Convert the fraction into an integer step index, wrapped in a 1-element array
chosen_step = np.array([int(chosen_step * STEPS)])
# Blend the two images
mixed_image = lambda_ * image1 + (1 - lambda_) * image2
# Add noise to the blended image
noisy_mixed_image = NoiceScheduler.add_noise(mixed_image, chosen_step)
# Denoise the blended image. The step is read with an explicit index because
# calling int() on a 1-element array is deprecated since NumPy 1.25.
diffused_image = inference(image=noisy_mixed_image[0], num_steps=int(chosen_step[0]))

# Panels to display, left to right. Values are mapped with (x + 1) / 2
# (presumably from [-1, 1] to [0, 1] - confirm); the noisy and denoised
# panels are additionally clipped to [0, 1].
panels = [
    ('Original Image 1', (image1[0] + 1) / 2),
    ('Original Image 2', (image2[0] + 1) / 2),
    ('Mixed Image', (mixed_image[0] + 1) / 2),
    ('Mixed Noisy Image', np.clip((noisy_mixed_image[0] + 1) / 2, 0, 1)),
    ('Diffused Image', np.clip((diffused_image + 1) / 2, 0, 1)),
]

fig, axes = plt.subplots(1, len(panels), figsize=(20, 5))
for ax, (title, panel_image) in zip(axes, panels):
    ax.imshow(panel_image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.show()


### Test Interpolation on >2 images

In [None]:
num_images = 5
images = test_generator.get_random_images(num_images)
