# Diffusion
Dataset : MSCOCO 2014

## Imports

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import save_model, load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from tensorflow.keras.utils import plot_model, Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import random as r
from tqdm import tqdm # progress bar
import time
import datetime
from IPython.display import clear_output
from pycocotools.coco import COCO
from PIL import Image

# Turn on memory growth for every GPU TensorFlow can see and report what was found.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPUs detected")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPUs detected: {len(gpus)}")
        print(f"GPUs: {gpus}")
    except RuntimeError as err:
        # A RuntimeError raised during configuration is reported, not propagated.
        print(err)

### Constants

In [None]:
# Paths
ANNOTDIR = 'annotations_trainval2014'
DATADIR = 'train2014'
# Resolves to '<ANNOTDIR>/annotations/instances_<DATADIR>.json'
INSTANCEFILE = '{}/annotations/instances_{}.json'.format(ANNOTDIR, DATADIR)

# Model input shape; also used as the shape of the initial noise in inference()
INPUT_SHAPE = (224, 224, 3)
BATCH_SIZE = 32
STEPS = 1000 # Number of diffusion steps
# First and last beta of the linear noise schedule
BETA_1 = 1e-4
BETA_T = 1e-2
# NOTE(review): not referenced anywhere else in the visible source; looks like
# a leftover from a different dataset - confirm before removing.
PATH_DATASETS = ['stable-diffusion-face-dataset/512/man','stable-diffusion-face-dataset/512/woman']
# Dataset split fractions (the test split is taken as the remainder in DatasetGenerator._getsplit)
TRAIN_SPLIT = 0.8
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.05
EPOCHS = 1000 # Upper bound only: early stopping is used
PATIENCE = 2 # Early-stopping patience, in epochs
# Loads the COCO annotation index as soon as this cell runs
COCO_INSTANCES = COCO(INSTANCEFILE)
# Placeholder; rebound once the network is built or loaded
model = None

# Augmentation configuration
DATA_GEN = ImageDataGenerator(
    rotation_range=1,
    horizontal_flip=True,
)
# Generator with no augmentation options, used for the non-training splits
BASE_GEN = ImageDataGenerator()

class LinearNoiceScheduler:
    """Linear beta schedule with the forward (noising) and reverse (denoising) updates.

    Attributes (all 1-D arrays of length ``steps``):
        betas: linearly spaced values from ``beta_1`` to ``beta_t``.
        alphas: ``1 - betas``.
        sqrt_alphas: ``sqrt(alphas)``.
        c_alphas: cumulative product of ``alphas``.
        sqrt_c_alphas: ``sqrt(c_alphas)``.
        sqrt_one_minus_c_alphas: ``sqrt(1 - c_alphas)``.
    """

    def __init__(self, beta_1=None, beta_t=None, steps=None):
        """Precompute the schedule.

        Args:
            beta_1: first beta; defaults to the module constant ``BETA_1``.
            beta_t: last beta; defaults to the module constant ``BETA_T``.
            steps: number of diffusion steps; defaults to ``STEPS``.
        """
        # Defaults are resolved at call time so the module constants remain
        # the single source of truth when no argument is given.
        if beta_1 is None:
            beta_1 = BETA_1
        if beta_t is None:
            beta_t = BETA_T
        if steps is None:
            steps = STEPS

        self.betas = np.linspace(beta_1, beta_t, steps)
        self.alphas = 1 - self.betas
        self.sqrt_alphas = np.sqrt(self.alphas)
        self.c_alphas = np.cumprod(self.alphas)
        self.sqrt_c_alphas = np.sqrt(self.c_alphas)
        self.sqrt_one_minus_c_alphas = np.sqrt(1 - self.c_alphas)

    def add_noise(self, original_images, steps, noise=None):
        """Noise images directly to the given step(s).

        Computes ``sqrt(c_alpha[t]) * x + sqrt(1 - c_alpha[t]) * noise``.

        Args:
            original_images: array of shape (H, W, C) or (N, H, W, C).
            steps: a single integer step or a 1-D integer array of steps.
                A scalar is accepted (it used to raise ``IndexError``) and is
                treated as a one-element array.
            noise: optional noise array with the same shape as
                ``original_images``; drawn from a standard normal when omitted.

        Returns:
            The noisy images. The per-step coefficients are shaped
            (len(steps), 1, 1, 1), so the result always has a leading batch
            axis (a single (H, W, C) image yields (len(steps), H, W, C)).
        """
        original_images = np.asarray(original_images)
        if noise is None:
            noise = np.random.normal(size=original_images.shape)

        # Promote scalars to 1-D so the coefficient lookup can be reshaped
        # for broadcasting over the image axes.
        steps = np.atleast_1d(steps)
        signal_scale = self.sqrt_c_alphas[steps][:, np.newaxis, np.newaxis, np.newaxis]
        noise_scale = self.sqrt_one_minus_c_alphas[steps][:, np.newaxis, np.newaxis, np.newaxis]

        return signal_scale * original_images + noise_scale * noise

    def remove_noise(self, noisy_image, pred_noise, step, stability_factor=True):
        """Apply one reverse step, from noise level ``step`` to ``step - 1``.

        Args:
            noisy_image: current image(s) at noise level ``step``.
            pred_noise: noise estimate for ``noisy_image`` (same shape).
            step: integer index into the schedule.
            stability_factor: when True and ``step > 0``, fresh Gaussian noise
                scaled by ``sqrt(betas[step])`` is added to the result.

        Returns:
            The updated image(s).
        """
        # Schedule parameters for this step
        alpha_t = self.alphas[step]
        sqrt_alpha_t = self.sqrt_alphas[step]
        sqrt_one_minus_cum_alpha_t = self.sqrt_one_minus_c_alphas[step]

        alpha_factor = (1 - alpha_t) / sqrt_one_minus_cum_alpha_t

        # Mean of the previous-step image given the noise estimate
        updated_image = (1 / sqrt_alpha_t) * (noisy_image - (alpha_factor * pred_noise))

        # No noise is injected on the final step (step == 0)
        if step > 0 and stability_factor:
            sigma_t = np.sqrt(self.betas[step])
            noise = np.random.normal(size=noisy_image.shape)
            updated_image += sigma_t * noise

        return updated_image

def inference():
    """Run the full reverse process from Gaussian noise using the global ``model``.

    Returns:
        A list of ``(step, image)`` tuples: the initial noise tagged with
        ``STEPS``, followed by the image after every step divisible by 100.
        Each image keeps its leading batch axis of size 1.
    """
    current = np.random.normal(size=INPUT_SHAPE)[np.newaxis, ...]
    snapshots = [(STEPS, current)]
    with tqdm(total=STEPS, desc='Inference', unit='step') as progress:
        for t in reversed(range(STEPS)):
            predicted = model.predict(current, verbose=0)
            current = NoiceScheduler.remove_noise(current, predicted, t)
            progress.update()
            if not t % 100:
                snapshots.append((t, current))
    return snapshots

def plot_inference():
    """Run ``inference()`` and show every returned snapshot side by side."""
    snapshots = inference()
    _, axes = plt.subplots(1, len(snapshots), figsize=(20, 5))
    for ax, (step, img) in zip(axes, snapshots):
        # Map pixel values from [-1, 1] to [0, 1] for display
        rescaled = (img + 1) / 2
        ax.imshow(np.clip(rescaled[0], 0, 1))
        ax.set_title(f'{step}')
        ax.axis('off')
    plt.show()

# Shared scheduler instance used by inference(), the data generator and the test cells
NoiceScheduler = LinearNoiceScheduler()


## Data preprocessing

### Loading of dataset

In [None]:
class DatasetGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of (clean images, noisy images).

    The image ids returned by ``COCO_INSTANCES.getImgIds()`` are divided into
    contiguous 'train' / 'val' / 'test' ranges according to ``TRAIN_SPLIT``
    and ``VALIDATION_SPLIT``; the test range is whatever remains.
    """

    def _getsplit(self, ensemble):
        """Return the ``(start, stop)`` slice bounds of ``ensemble`` in ``self.imgIds``.

        Raises:
            ValueError: if ``ensemble`` is not 'train', 'val' or 'test'
                (previously this surfaced as an ``UnboundLocalError``).
        """
        total = len(self.imgIds)
        train_end = int(TRAIN_SPLIT * total)
        val_end = int((TRAIN_SPLIT + VALIDATION_SPLIT) * total)
        if ensemble == 'train':
            return 0, train_end
        if ensemble == 'val':
            return train_end, val_end
        if ensemble == 'test':
            return val_end, total
        raise ValueError(
            f"Unknown split {ensemble!r}; expected 'train', 'val' or 'test'"
        )

    def __init__(self, ensemble, **kwargs):
        """Select the image ids belonging to the requested split.

        Args:
            ensemble: one of 'train', 'val', 'test'.
            **kwargs: forwarded to ``Sequence.__init__``.
        """
        super().__init__(**kwargs)
        self.ensemble = ensemble

        # Full list of image ids, then the slice for this split
        self.imgIds = COCO_INSTANCES.getImgIds()
        start, stop = self._getsplit(ensemble)
        self.ids = self.imgIds[start:stop]

    def __len__(self):
        """Number of batches; the last one may be smaller than ``BATCH_SIZE``."""
        return int(np.ceil(len(self.ids) / BATCH_SIZE))

    def __getitem__(self, index):
        """Load batch ``index``.

        Returns:
            ``(batch_images, batch_labels)`` as float64 arrays.
            ``batch_images`` holds the images scaled to [-1, 1] (augmented on
            the training split); ``batch_labels`` holds the same images noised
            at a random step drawn from ``[1, STEPS)``.
        """
        batch_ids = self.ids[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
        # PIL expects (width, height); INPUT_SHAPE is (height, width, channels)
        target_size = (INPUT_SHAPE[1], INPUT_SHAPE[0])
        augmenter = DATA_GEN if self.ensemble == 'train' else BASE_GEN

        batch_images = []
        batch_labels = []
        for img_id in batch_ids:
            # Load the image; the context manager closes the file handle once
            # the pixel data has been copied by resize()/convert().
            file_name = COCO_INSTANCES.imgs[img_id]['file_name']
            with Image.open(f'{DATADIR}/{file_name}') as handle:
                image = handle.resize(target_size).convert('RGB')
            image = img_to_array(image)
            image = (image / 255.0) * 2 - 1  # Scale pixel values to [-1, 1]
            image = augmenter.random_transform(image)
            batch_images.append(image)

            # Build the label by noising the image at a random step.
            # add_noise returns a leading batch axis of size 1, hence the [0].
            step = np.array([np.random.randint(1, STEPS)])
            label = NoiceScheduler.add_noise(image, step)[0]
            batch_labels.append(label)

        batch_labels = np.array(batch_labels, dtype='float64')
        batch_images = np.array(batch_images, dtype='float64')

        return (batch_images, batch_labels)

    def on_epoch_end(self):
        """Shuffle the ids of the training split in place; other splits keep their order."""
        if self.ensemble == 'train':
            np.random.shuffle(self.ids)

# One generator per dataset split
train_generator = DatasetGenerator('train')
val_generator = DatasetGenerator('val')
test_generator = DatasetGenerator('test')

# Report the number of batches and items in each split
for split_label, split_gen in (
    ("d'entrainement", train_generator),
    ("de validation", val_generator),
    ("de test", test_generator),
):
    print(f'Taille du dataset {split_label}: {len(split_gen)} batches, {len(split_gen.ids)} items')

### Exemple

In [None]:
# Pick a random batch from the training split
batch_index = np.random.randint(0, len(train_generator))
images, labels = train_generator[batch_index]
print(f'Images shape: {images.shape}')
print(f'Labels shape: {labels.shape}')

# Pick a random sample inside that batch and map pixel values from [-1, 1] back to [0, 1]
image_index = np.random.randint(0, len(images))
image = (images[image_index] + 1) / 2
label = (labels[image_index] + 1) / 2
print(f'Image shape: {image.shape}')
print(f'Label shape: {label.shape}')

# Show the clean image next to its noisy counterpart
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, picture, title in zip(axes, (image, label), ('Image', 'Label (Noisy Image)')):
    ax.imshow(np.clip(picture, 0, 1))
    ax.set_title(title)
    ax.axis('off')

plt.show()

## Model

### Gaussian noise tests

#### Curves

In [None]:
nb_images = 10

# Pick nb_images random images from the dataset
# NOTE(review): `data` is not defined anywhere in the visible notebook - it
# presumably was an array of 0-255 images from an earlier dataset; confirm
# and define it before running this cell.
selected_images = data[np.random.choice(len(data), nb_images, replace=False)] / 255.0
noise = np.random.normal(size=selected_images.shape[1:])

# Mean and standard deviation of the noisy images at each sampled step
means = []
stds = []

step_jump = 10

for step in range(0, STEPS, step_jump):
    # add_noise indexes its coefficients with [:, np.newaxis, ...], so the step
    # is passed as a 1-element array; a bare int fails on that indexing.
    noisy_images = NoiceScheduler.add_noise(selected_images, np.array([step]))
    means.append(np.mean(noisy_images))
    stds.append(np.std(noisy_images))

# Plot the results
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(range(0, STEPS, step_jump), means, label='Mean')
plt.xlabel('Step')
plt.ylabel('Mean')
plt.title('Mean of Noisy Images')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(0, STEPS, step_jump), stds, label='Standard Deviation')
plt.xlabel('Step')
plt.ylabel('Standard Deviation')
plt.title('Standard Deviation of Noisy Images')
plt.legend()
plt.tight_layout()

# Keep a random-length prefix of the selected images
idx = r.randint(1, len(selected_images)-1)
test_images = selected_images[:idx]

# Show noise
plt.figure(figsize=(6, 6))
noise_img = np.clip(noise, 0, 1)
plt.imshow(noise_img)
plt.title('Noise')
plt.axis('off')
# Show noise distribution
plt.figure(figsize=(6, 6))
plt.hist(noise.flatten(), bins=150)
plt.title('Noise Distribution')
plt.xlabel('Intensity')
plt.ylabel('Frequency')

plt.show()



In [None]:
# Plot the cumulative product of alphas against the step index
_, ax = plt.subplots(figsize=(6, 6))
ax.plot(NoiceScheduler.c_alphas)
ax.set_xlabel('Step')
ax.set_ylabel('Cumprod alpha')

#### Test on one sample

In [None]:
# Number of noise levels to display (one extra panel is drawn)
bruitage = 14

# NOTE(review): `data` is not defined anywhere in the visible notebook;
# confirm where it should come from before running this cell.
test_image = data[0]
test_image = test_image / 255.0
jump = STEPS // bruitage
fig, axes = plt.subplots(1, bruitage+1, figsize=(20, 4))
for slot in range(bruitage+1):
    # Keep the step as a plain int for indexing and formatting; converting a
    # 1-element array with int()/float() is deprecated in recent NumPy.
    step_value = min((slot+1) * jump, STEPS-1)
    step = np.array([step_value])
    c_alpha = NoiceScheduler.c_alphas[step_value]
    noisy_images = NoiceScheduler.add_noise(test_image, step)
    noisy_image_plt = np.clip(noisy_images, 0, 1)
    # add_noise returns a leading batch axis of size 1
    axes[slot].imshow(noisy_image_plt[0])
    axes[slot].axis('off')
    axes[slot].set_title(f'{step_value} - {c_alpha:.2f}')

plt.show()

### Definition du modèle

In [None]:
def conv_block(input_tensor, num_filters):
    """Apply two successive 3x3 Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        input_tensor: Keras tensor to transform.
        num_filters: number of filters of both convolutions.

    Returns:
        The output tensor of the second stage.
    """
    x = input_tensor
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3, 3), padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    return x

def encoder_block(input_tensor, num_filters):
    """Run ``conv_block`` then a 2x2 max-pool.

    Returns:
        ``(features, pooled)``: the pre-pooling features (used as a skip
        connection) and the pooled tensor.
    """
    features = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D(pool_size=(2, 2))(features)
    return features, pooled

def decoder_block(input_tensor, skip_features, num_filters):
    """Upsample with a stride-2 transposed convolution, concatenate the skip
    features, and run ``conv_block`` on the result."""
    upsampled = layers.Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input_tensor)
    merged = layers.Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_unet(input_shape):
    """Build the 'Diffusion_UNet' Keras model.

    Four encoder stages (64 to 512 filters), a 1024-filter bottleneck, four
    decoder stages consuming the encoder skips in reverse order, then a
    concatenation with the raw input and a 1x1 convolution to 3 channels.
    The network's only input is the image tensor.

    Args:
        input_shape: shape passed to ``layers.Input``.

    Returns:
        The uncompiled ``tf.keras.Model``.
    """
    inputs = layers.Input(input_shape)

    # Encoder: remember each stage's features for the skip connections
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder: deepest skip first
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    x = layers.Concatenate()([x, inputs])
    outputs = layers.Conv2D(3, (1, 1))(x)  # output is a 3-channel image (RGB)

    return tf.keras.Model(inputs, outputs, name='Diffusion_UNet')

# Build the model and print its layer summary
model = build_unet(INPUT_SHAPE)
model.summary()

### Plot du modèle

In [None]:
# Write the architecture diagram to '<model name>.png'
plot_model(model, show_shapes=True, show_layer_names=True, to_file=model.name+'.png')

### Training

In [None]:
# Mean-squared-error loss
@tf.function
def loss_fn(y_true, y_pred):
    """Return the mean over all elements of ``(y_true - y_pred) ** 2``."""
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)

# Single optimisation step
@tf.function
def train_step(model, labels, noises, optimizer):
    """Run one gradient update and return the loss.

    Args:
        model: network to train.
        labels: not used by this function.
        noises: batch fed to the model and also used as the loss target.
        optimizer: optimizer applying the gradients.

    NOTE(review): the loss compares the model output with the very tensor it
    was given as input, so the objective is minimised by the identity mapping.
    If the network is meant to predict the added noise (as ``remove_noise``
    expects), the target needs to be that noise - confirm the intent.
    """
    with tf.GradientTape() as tape:
        outputs = tf.cast(model(noises, training=True), tf.float64)
        loss = loss_fn(noises, outputs)
    trainable = model.trainable_variables
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss

# Training loop with early stopping on the validation loss
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
epochs_losses = []
val_epochs_losses = []
min_loss = np.inf  # best validation loss seen so far
counter = 0  # consecutive epochs without validation improvement
model_save_name = f"{model.name}_{datetime.datetime.now().strftime('%Y%m%d%H%M')}.keras"
for epoch in range(EPOCHS):
    # Train the model for one epoch
    with tqdm(total=len(train_generator), desc=f'Epoch {epoch+1}/{EPOCHS}', unit='batch') as pbar:
        epoch_losses = []
        for step in range(len(train_generator)):
            # Fetch the clean images and their noisy versions
            images, noises = train_generator[step]
            # Training step
            loss = train_step(model, images, noises, optimizer)
            # Running mean of the loss, shown in the progress bar
            epoch_losses.append(loss.numpy())
            epoch_loss = np.mean(epoch_losses)
            pbar.set_postfix(Loss=f"{epoch_loss:.6f}")
            pbar.update()
    # Reshuffle the training ids for the next epoch
    train_generator.on_epoch_end()
    epochs_losses.append(np.mean(epoch_losses))
    # Validation
    with tqdm(total=len(val_generator), desc=f'Validation {epoch+1}/{EPOCHS}', unit='batch') as pbar:
        val_losses = []
        for step in range(len(val_generator)):
            images, noises = val_generator[step]
            predictions = model.predict(noises, verbose=0)
            predictions = tf.cast(predictions, tf.float64)
            loss = loss_fn(noises, predictions)
            val_losses.append(loss.numpy())
            pbar.set_postfix(Loss=f"{np.mean(val_losses):.6f}")
            pbar.update()
    val_epochs_losses.append(np.mean(val_losses))
    # Early stopping
    if val_epochs_losses[-1] >= min_loss and epoch > 0:
        counter += 1
        if counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break
    else:
        counter = 0
        # Track the best *validation* loss: the comparison above is made
        # against the validation loss, so the baseline must be one too
        # (it used to be set from the training loss).
        min_loss = val_epochs_losses[-1]
        save_model(model, model_save_name)
    # Show an inference sample after each epoch
    plot_inference()


### Loss curves

## Evaluation

## Tests

### Chargement du modèle

In [None]:
# Rebind the global `model` to a previously saved checkpoint (loaded without compiling)
model = load_model('Diffusion_UNet_202411130839.keras', compile=False)

In [None]:
# Test the model: run the full reverse process and plot the intermediate snapshots
plot_inference()

### Test sur un bruit gaussien

In [None]:
# Denoise pure Gaussian noise step by step, displaying every intermediate image
image = np.random.normal(size=INPUT_SHAPE)
image = np.expand_dims(image, axis=0)
for step in range(STEPS-1, -1, -1):
    noise_pred = model.predict(image, verbose=0)
    image = NoiceScheduler.remove_noise(image, noise_pred, step)
    # Rescale a copy for display only: overwriting `image` here would feed the
    # rescaled values back into the model on the next step.
    preview = (image + 1) / 2
    plt.imshow(np.clip(preview[0], 0, 1))
    plt.title(f'{step}')
    plt.axis('off')
    plt.show()
    clear_output(wait=True)
    


### Test itératif sur une image bruitée (image fournie par le dataset)

In [None]:
# NOTE(review): `data` is not defined anywhere in the visible notebook;
# confirm where it should come from before running this cell.
r_index = np.random.randint(0, len(data), 1)
print(r_index)
images = data[r_index] / 255.0
steps = np.random.randint(0, STEPS, size=len(images))
noises = NoiceScheduler.add_noise(images, steps)
# Single image, so take the one sampled step as a plain int
start_step = int(steps[0])
# The image was noised at level `start_step`, so the first reverse update must
# use that same index (the loop used to start one level too low).
for i in range(start_step, -1, -1):
    noise_pred = model.predict(noises, verbose=0)
    noises = NoiceScheduler.remove_noise(noises, noise_pred, i)
    plt.imshow(np.clip(noises[0], 0, 1))
    plt.title(f'{i}')
    plt.axis('off')
    plt.show()
    clear_output(wait=True)


### Test sur une image bruitée du dataset

In [None]:
# NOTE(review): `data` is not defined anywhere in the visible notebook;
# confirm where it should come from before running this cell.
r_index = np.random.randint(0, len(data), 1)
images = data[r_index] / 255.0
chosen_steps = np.random.randint(0, STEPS, size=len(images))
noises = NoiceScheduler.add_noise(images, chosen_steps)
noise_pred = model.predict(noises, verbose=0)
# One-shot estimate of the clean image, obtained by inverting add_noise:
#   x_t = sqrt(c_alpha) * x_0 + sqrt(1 - c_alpha) * eps
# The estimate must start from the noisy image, not from the original.
sqrt_c_alpha = NoiceScheduler.sqrt_c_alphas[chosen_steps][:, np.newaxis, np.newaxis, np.newaxis]
sqrt_one_minus_c_alpha = NoiceScheduler.sqrt_one_minus_c_alphas[chosen_steps][:, np.newaxis, np.newaxis, np.newaxis]
denoised_images = (noises - sqrt_one_minus_c_alpha * noise_pred) / sqrt_c_alpha
print(f"Chosen steps: {chosen_steps}")
print(f"Cumprod alphas: {NoiceScheduler.c_alphas[chosen_steps]}")
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
axes[0].imshow(np.clip(images[0], 0, 1))
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(np.clip(noises[0], 0, 1))
axes[1].set_title('Noisy')
axes[1].axis('off')
axes[2].imshow(np.clip(denoised_images[0], 0, 1))
axes[2].set_title('Denoised')
axes[2].axis('off')
axes[3].imshow(np.clip(noise_pred[0], 0, 1))
axes[3].set_title('Predicted Noise')
axes[3].axis('off')
plt.show()


### Test sur une image du dataset

In [None]:
# NOTE(review): `data` is not defined anywhere in the visible notebook and
# `r_index` is not used below; confirm whether this line is still needed.
r_index = np.random.randint(0, len(data), 1)
# All-zero image used as the "original".
# NOTE(review): 512x512 differs from INPUT_SHAPE (224, 224, 3) - confirm the
# loaded model accepts this size.
images = np.zeros((1, 512, 512, 3))
chosen_steps = np.random.randint(0, STEPS, size=len(images))
noises = NoiceScheduler.add_noise(images, chosen_steps)
noise_pred = model.predict(noises, verbose=0)
# One-shot estimate of the clean image, obtained by inverting add_noise:
#   x_t = sqrt(c_alpha) * x_0 + sqrt(1 - c_alpha) * eps
# The estimate must start from the noisy image, not from the original.
sqrt_c_alpha = NoiceScheduler.sqrt_c_alphas[chosen_steps][:, np.newaxis, np.newaxis, np.newaxis]
sqrt_one_minus_c_alpha = NoiceScheduler.sqrt_one_minus_c_alphas[chosen_steps][:, np.newaxis, np.newaxis, np.newaxis]
denoised_images = (noises - sqrt_one_minus_c_alpha * noise_pred) / sqrt_c_alpha
print(f"Chosen steps: {chosen_steps}")
print(f"Cumprod alphas: {NoiceScheduler.c_alphas[chosen_steps]}")
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
axes[0].imshow(np.clip(images[0], 0, 1))
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(np.clip(noises[0], 0, 1))
axes[1].set_title('Noisy')
axes[1].axis('off')
axes[2].imshow(np.clip(denoised_images[0], 0, 1))
axes[2].set_title('Denoised')
axes[2].axis('off')
axes[3].imshow(np.clip(noise_pred[0], 0, 1))
axes[3].set_title('Predicted Noise')
axes[3].axis('off')
