## VAE


Useful resources:
https://www.youtube.com/watch?v=qJeaCHQ1k2w

## Imports, Constants and Global Variables

In [9]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
import numpy as np
import matplotlib.pyplot as plt
import random as r
from sklearn.cluster import KMeans
import os
import cv2
import absl.logging
import logging
from PIL import Image
from tensorflow.keras.utils import plot_model, Sequence
from tqdm import tqdm
np.random.seed(42)

# Reduce log noise: absl only reports errors, the 'tensorflow' logger only
# reports warnings and above.
absl.logging.set_verbosity(absl.logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.WARNING)
# Turn on memory growth for every physical GPU that TensorFlow lists.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(gpus, "Physical GPUs,", logical_gpus, "Logical GPUs")
    except RuntimeError as e:
        # A RuntimeError raised by the configuration calls above is printed
        # instead of being propagated, so the notebook keeps running.
        print(e)

'''
CONSTANTS
'''

# CONSTANTS
# DATASET_FOLDER is overwritten below once the dataset has been chosen.
DATASET_FOLDER = 'datasets/vae' # Dataset folder with images
MODEL_FOLDER = 'models/vae' # Folder to save and load models
# Loss objects created with Reduction.NONE, i.e. Keras does not reduce them to a scalar.
# NOTE(review): `bce` is not referenced anywhere else in the visible part of this file.
bce = BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
mse = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

# Interactive choice of the dataset; an empty answer falls back to 'mnist'.
chosen_dataset = input('Choose the dataset (mnist, celeba, default : mnist) : ') or 'mnist'
if chosen_dataset == 'mnist' :
    DATASET_FOLDER = 'datasets/mnist'
    DATASET_SPLIT = (0.80, 0.15, 0.05) # Split between train, val and test
    IMAGE_SIZE = (28, 28) # Height, Width
    MODEL_EPOCHS = 1000 # Maximum number of epochs
    MODEL_BATCH_SIZE = 64 # Batch size
    MODEL_LEARNING_RATE = 1e-4
    KL_FACTOR = 1e-5 # Weight applied to the KL term in vae_loss
    MODEL_PATIENCE = 10 # Early-stopping patience, in epochs

elif chosen_dataset == 'celeba' :
    DATASET_FOLDER = 'datasets/celeba'
    DATASET_SPLIT = (0.75,0.20,0.05) # Split between train, val and test
    IMAGE_SIZE = (576, 1024) # Height, Width
    MODEL_EPOCHS = 1000 # Maximum number of epochs
    MODEL_BATCH_SIZE = 4 # Batch size
    MODEL_LEARNING_RATE = 1e-4
    KL_FACTOR = 1e-5 # Weight applied to the KL term in vae_loss
    MODEL_PATIENCE = 10 # Early-stopping patience, in epochs

else :
    raise ValueError(f'Unknown dataset : {chosen_dataset}')


# Interactive choice of the data subset
def select_generator():
    """Ask which subset to use and return the matching generator.

    An empty answer is treated like 'train'.

    Returns:
        The module-level ``train_generator``, ``val_generator`` or
        ``test_generator``.

    Raises:
        ValueError: If the answer is not '', 'train', 'val' or 'test'.
    """
    choice = input('Choose the generator (train, val, test, default : train) : ')
    if choice in ('', 'train'):
        return train_generator
    if choice == 'val':
        return val_generator
    if choice == 'test':
        return test_generator
    raise ValueError('Unknown subset (train, val, test)')

# Interactive choice of the learning rate
def select_learning_rate():
    """Ask for a learning rate and store it in ``MODEL_LEARNING_RATE``.

    An empty answer keeps the current value. Any other answer is converted
    with ``float`` (which raises ``ValueError`` on invalid text) and written
    back to the module-level constant.

    Returns:
        The (possibly updated) value of ``MODEL_LEARNING_RATE``.
    """
    global MODEL_LEARNING_RATE
    answer = input(f'Choose the learning rate (default : {MODEL_LEARNING_RATE}) : ')
    if answer:
        MODEL_LEARNING_RATE = float(answer)
    return MODEL_LEARNING_RATE


## Dataset

### Load the dataset

In [2]:
# Batch generator for the MNIST dataset
class Generator_MNIST(Sequence):
    """Keras ``Sequence`` yielding MNIST batches as ``(images, images)`` pairs.

    The first instance loads MNIST, concatenates the official train and test
    parts and re-splits them according to ``DATASET_SPLIT``. The result is
    cached on the class so later instances reuse it.

    Args:
        ensemble: Subset served by this instance: 'train', 'val' or 'test'.
    """

    # Class-level cache {'train' | 'val' | 'test': (images, labels)}.
    mnist = None

    def __init__(self, ensemble):
        super().__init__()
        self.ensemble = ensemble
        if Generator_MNIST.mnist is None:
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
            # Merge everything so the split follows DATASET_SPLIT.
            all_images = np.concatenate([x_train, x_test])
            all_labels = np.concatenate([y_train, y_test])

            total = len(all_images)
            train_end = int(total * DATASET_SPLIT[0])
            val_end = train_end + int(total * DATASET_SPLIT[1])

            # The test subset receives whatever remains after train and val.
            self.train_images, self.train_labels = all_images[:train_end], all_labels[:train_end]
            self.val_images, self.val_labels = all_images[train_end:val_end], all_labels[train_end:val_end]
            self.test_images, self.test_labels = all_images[val_end:], all_labels[val_end:]

            Generator_MNIST.mnist = {
                'train': (self.train_images, self.train_labels),
                'val': (self.val_images, self.val_labels),
                'test': (self.test_images, self.test_labels)
            }

        # Pick the subset requested by the caller.
        self.images, self.labels = Generator_MNIST.mnist[ensemble]
        self.indices = np.arange(len(self.images))

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        full_batches, remainder = divmod(len(self.images), MODEL_BATCH_SIZE)
        return full_batches + (1 if remainder else 0)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(x, x)`` with pixel values divided by 255."""
        first = index * MODEL_BATCH_SIZE
        last = min(first + MODEL_BATCH_SIZE, len(self.indices))
        selected = self.indices[first:last]
        batch_x = self.images[selected] / 255.0
        # Auto-encoder setting: the target is the input image itself.
        return batch_x, batch_x

    def on_epoch_end(self):
        """Shuffle the sample order used by the next epoch."""
        np.random.shuffle(self.indices)
    
# Batch generator for the CelebA dataset
class Generator_CelebA(Sequence):
    """Keras ``Sequence`` yielding CelebA batches as ``(images, images)`` pairs.

    The first instance lists the ``.jpg`` / ``.png`` files found in
    ``DATASET_FOLDER`` and splits them according to ``DATASET_SPLIT``. The
    split is cached on the class so later instances reuse it.

    Args:
        ensemble: Subset served by this instance: 'train', 'val' or 'test'.
    """

    # Class-level cache {'train' | 'val' | 'test': [filenames]}.
    celeba = None

    def __init__(self, ensemble):
        super().__init__()
        self.ensemble = ensemble
        # Needed by __getitem__ to rebuild the full path of each image.
        self.directory_path = DATASET_FOLDER
        if Generator_CelebA.celeba is None:
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # train/val/test split identical from one run to the next.
            filenames = sorted(
                f for f in os.listdir(self.directory_path)
                if f.endswith(('.jpg', '.png'))
            )

            num_samples = len(filenames)
            num_train = int(num_samples * DATASET_SPLIT[0])
            num_val = int(num_samples * DATASET_SPLIT[1])

            # The test subset receives whatever remains after train and val.
            Generator_CelebA.celeba = {
                'train': filenames[:num_train],
                'val': filenames[num_train:num_train + num_val],
                'test': filenames[num_train + num_val:],
            }

        # Pick the subset requested by the caller.
        self.image_filenames = Generator_CelebA.celeba[ensemble]
        self.indices = np.arange(len(self.image_filenames))

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.image_filenames) / MODEL_BATCH_SIZE))

    def __getitem__(self, index):
        """Load batch ``index`` from disk and return it as ``(x, x)``.

        Images are converted to RGB, resized to ``IMAGE_SIZE`` and scaled by
        1/255, giving a float32 array of shape (batch, height, width, 3).
        """
        start = index * MODEL_BATCH_SIZE
        end = min(start + MODEL_BATCH_SIZE, len(self.indices))
        height, width = IMAGE_SIZE
        batch_images = []
        # Go through self.indices so the shuffle done in on_epoch_end is
        # actually reflected in the batches.
        for i in self.indices[start:end]:
            image_path = os.path.join(self.directory_path, self.image_filenames[i])
            with Image.open(image_path) as img:
                # PIL's resize takes (width, height) while IMAGE_SIZE is
                # (height, width); RGB conversion guarantees 3 channels.
                img = img.convert('RGB').resize((width, height))
                batch_images.append(np.asarray(img, dtype=np.float32) / 255.0)

        batch_images = np.array(batch_images)
        # Auto-encoder setting: the target is the input image itself.
        return batch_images, batch_images

    def on_epoch_end(self):
        """Shuffle the sample order used by the next epoch."""
        np.random.shuffle(self.indices)

# Build one generator per subset (train, val, test) for the chosen dataset.
if chosen_dataset not in ('mnist', 'celeba'):
    raise ValueError(f'Unknown dataset : {chosen_dataset}')

_generator_class = Generator_MNIST if chosen_dataset == 'mnist' else Generator_CelebA
train_generator, val_generator, test_generator = (
    _generator_class(subset) for subset in ('train', 'val', 'test')
)

### Test the dataset

In [None]:
# Draw a random batch from the selected subset, then a random image from it.
generator = select_generator()
idx = r.randrange(len(generator))
batch = generator[idx][0]
print(f'Batch shape : {batch.shape}')
print(f'Batch type : {type(batch)}')
idx = r.randrange(len(batch))
image = batch[idx]

# Display the drawn image without axes.
plt.figure()
plt.imshow(image)
plt.axis('off')



## Modèle

In [None]:
# Latent sampling
def sampling(args):
    """Draw a latent vector as ``mean + exp(0.5 * log_var) * noise``.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of 2-D tensors (both axes are read
            through ``tf.shape``).

    Returns:
        A tensor with the same shape as ``z_mean``, where ``noise`` comes from
        ``tf.keras.backend.random_normal``.
    """
    mean, log_var = args
    mean_shape = tf.shape(mean)
    noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
    std = tf.exp(0.5 * log_var)
    return mean + std * noise

# MNIST encoder
def encoder_MNIST():
    """Build the convolutional encoder used for MNIST.

    The network takes images of shape ``IMAGE_SIZE + (1,)`` and produces a
    5-dimensional latent distribution.

    Returns:
        tuple: ``(model, z_mean, z_log_var)`` where ``model`` outputs
        ``[z_mean, z_log_var, z]`` and the two other items are the symbolic
        tensors of the corresponding layers.
    """
    # Must stay equal to the latent input size declared in decoder_MNIST.
    latent_dim = 5

    inputs = layers.Input(shape=IMAGE_SIZE + (1,))
    x = layers.Conv2D(32, 3, activation='leaky_relu', strides=1, padding='same')(inputs)
    x = layers.Conv2D(64, 3, activation='leaky_relu', strides=2, padding='same')(x)
    x = layers.Conv2D(128, 3, activation='leaky_relu', strides=2, padding='same')(x)
    x = layers.Conv2D(128, 3, activation='leaky_relu', strides=1, padding='same')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='leaky_relu')(x)
    x = layers.BatchNormalization()(x)
    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
    # The declared output shape has to match the latent size; it previously
    # said (2,) although sampling returns a tensor shaped like z_mean.
    z = layers.Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
    return models.Model(inputs, [z_mean, z_log_var, z], name='encoder'), z_mean, z_log_var

# CelebA encoder
def encoder_celebA():
    """Build the convolutional encoder used for CelebA.

    The network takes images of shape ``IMAGE_SIZE + (3,)``, applies four
    stride-2 convolutions and produces a 50-dimensional latent distribution.

    Returns:
        tuple: ``(model, z_mean, z_log_var)`` where ``model`` outputs
        ``[z_mean, z_log_var, z]`` and the two other items are the symbolic
        tensors of the corresponding layers.
    """
    inputs = layers.Input(shape=IMAGE_SIZE + (3,))
    features = inputs
    for filters in (32, 64, 128, 256):
        features = layers.Conv2D(
            filters, 3, activation='leaky_relu', strides=2, padding='same'
        )(features)
    features = layers.Flatten()(features)
    features = layers.Dense(256, activation='leaky_relu')(features)

    z_mean = layers.Dense(50, name='z_mean')(features)
    z_log_var = layers.Dense(50, name='z_log_var')(features)
    z = layers.Lambda(sampling, output_shape=(50,), name='z')([z_mean, z_log_var])

    encoder_model = models.Model(inputs, [z_mean, z_log_var, z], name='encoder_celeba')
    return encoder_model, z_mean, z_log_var

# MNIST decoder
def decoder_MNIST():
    """Build the transposed-convolution decoder used for MNIST.

    Maps a 5-dimensional latent vector to a 28x28 image with values produced
    by a sigmoid.

    Returns:
        The Keras model named 'decoder'.
    """
    latent_inputs = layers.Input(shape=(5,))
    features = layers.Dense(7*7*64, activation='leaky_relu')(latent_inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Reshape((7, 7, 64))(features)

    # (filters, strides) of the successive hidden transposed convolutions.
    for filters, strides in ((128, 1), (128, 2), (64, 2), (64, 1)):
        features = layers.Conv2DTranspose(
            filters, 3, activation='leaky_relu', strides=strides, padding='same'
        )(features)
    features = layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(features)

    # Drop the trailing channel dimension.
    output = layers.Reshape((28, 28))(features)
    return models.Model(latent_inputs, output, name='decoder')

# CelebA decoder
def decoder_celebA():
    """Build the transposed-convolution decoder used for CelebA.

    Maps a 50-dimensional latent vector to an RGB image of shape
    ``IMAGE_SIZE + (3,)`` with values produced by a sigmoid.

    The seed feature map is derived from ``IMAGE_SIZE`` (one eighth of each
    side, since three stride-2 layers follow) so that the reconstruction has
    the same size as the encoder input. With ``IMAGE_SIZE == (64, 128)`` this
    gives the 8x16 seed that was previously hard-coded.

    Returns:
        The Keras model named 'decoder_celeba'.

    Raises:
        ValueError: If a side of ``IMAGE_SIZE`` is not a multiple of 8.
    """
    height, width = IMAGE_SIZE
    if height % 8 or width % 8:
        raise ValueError(f'IMAGE_SIZE must be divisible by 8, got {IMAGE_SIZE}')
    seed_height, seed_width = height // 8, width // 8

    latent_inputs = layers.Input(shape=(50,))
    # NOTE(review): the size of this Dense layer grows with IMAGE_SIZE; for
    # large images consider adding upsampling stages instead.
    x = layers.Dense(seed_height * seed_width * 256, activation='leaky_relu')(latent_inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Reshape((seed_height, seed_width, 256))(x)
    x = layers.Conv2DTranspose(128, 3, activation='leaky_relu', strides=2, padding='same')(x)
    x = layers.Conv2DTranspose(64, 3, activation='leaky_relu', strides=2, padding='same')(x)
    x = layers.Conv2DTranspose(32, 3, activation='leaky_relu', strides=2, padding='same')(x)
    x = layers.Conv2DTranspose(3, 3, activation='sigmoid', padding='same')(x)
    return models.Model(latent_inputs, x, name='decoder_celeba')

# VAE model
class VAE(tf.keras.Model):
    """Variational auto-encoder combining an encoder and a decoder.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for a batch of images.
        decoder: Model mapping ``z`` back to a batch of images.
        **kwargs: Forwarded to ``tf.keras.Model``.

    Three ``Mean`` metrics track the total, reconstruction and KL losses.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Metric trackers
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def call(self, inputs):
        """Encode then decode ``inputs`` and return the reconstruction.

        The loss terms are also computed and pushed to the metric trackers.
        NOTE(review): no ``test_step`` is defined here, so this is presumably
        what feeds the ``val_*`` metrics during validation - confirm.
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)

        # Compute the loss
        reconstruction_loss, kl_loss, total_loss = vae_loss(inputs, reconstructed, z_mean, z_log_var)

        # Update the metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return reconstructed

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: Batch as produced by the generators; only ``data[0]`` (the
                images) is used because input and target are the same.

        Returns:
            dict: Current values of 'loss', 'reconstruction_loss' and 'kl_loss'.
        """
        x = data[0]

        # Only the forward pass and the loss need to be recorded by the tape.
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(x, training=True)
            reconstructed = self.decoder(z, training=True)
            reconstruction_loss, kl_loss, total_loss = vae_loss(x, reconstructed, z_mean, z_log_var)

        # Backward pass, done once recording has stopped so that the gradient
        # computation itself is not traced by the tape.
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the metric trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }

    @property
    def metrics(self):
        """Return the three loss trackers as the model's metrics."""
        return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

def vae_loss(y, vae_output, z_mean, z_log_var):
    """Compute the reconstruction, KL and total losses of the VAE.

    Args:
        y: Target images.
        vae_output: Images reconstructed by the decoder.
        z_mean: Latent means returned by the encoder.
        z_log_var: Latent log-variances returned by the encoder.

    Returns:
        tuple: ``(reconstruction_loss, kl_loss, total_loss)``. The
        reconstruction term is whatever the module-level ``mse`` object
        returns; the KL term is a scalar already multiplied by ``KL_FACTOR``;
        the total is their sum.
    """
    reconstruction_loss = mse(y, vae_output)

    # KL divergence: averaged over the latent axis, then over the batch.
    kl_terms = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    kl_per_sample = -0.5 * tf.reduce_mean(kl_terms, axis=1)
    kl_loss = tf.reduce_mean(kl_per_sample) * KL_FACTOR

    total_loss = reconstruction_loss + kl_loss
    return reconstruction_loss, kl_loss, total_loss

# Placeholder loss handed to model.compile; the real loss is computed by vae_loss.
def zero_loss(y_true, y_pred):
    """Return a constant 0.0 tensor, ignoring ``y_true`` and ``y_pred``."""
    return tf.constant(0.0)


# Build the encoder / decoder pair matching the chosen dataset.
if chosen_dataset == 'mnist':
    encoder, z_mean, z_log_var = encoder_MNIST()
    decoder = decoder_MNIST()
elif chosen_dataset == 'celeba':
    encoder, z_mean, z_log_var = encoder_celebA()
    decoder = decoder_celebA()

model = VAE(encoder, decoder)
# Use the configured learning rate: the 'adam' string shortcut would silently
# fall back to the optimizer's default rate and ignore MODEL_LEARNING_RATE.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=MODEL_LEARNING_RATE),
    loss=zero_loss,
)
model.build((None, *IMAGE_SIZE, 1 if chosen_dataset == 'mnist' else 3))
model.summary()

### Plot model

In [None]:
# Draw the model graph, showing tensor shapes and dtypes.
plot_model(model, show_dtype=True, show_shapes=True)

### Entraînement

In [None]:
# Early stopping: watch 'val_total_loss' (lower is better), wait MODEL_PATIENCE
# epochs and restore the best weights when stopping.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_total_loss',
    mode='min',
    patience=MODEL_PATIENCE,
    restore_best_weights=True,
    verbose=2,
)

# Train the VAE
history = model.fit(
    train_generator,
    epochs=MODEL_EPOCHS,
    validation_data=val_generator,
    callbacks=[early_stopping],
)

### Résultats

In [None]:
# Plot the learning curves: one panel for training, one for validation,
# each showing every recorded loss term.
plt.figure(figsize=(12, 6))
panels = (
    ('Training loss', ('loss', 'reconstruction_loss', 'kl_loss')),
    ('Validation loss', ('val_total_loss', 'val_reconstruction_loss', 'val_kl_loss')),
)
for position, (panel_title, keys) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    for key in keys:
        plt.plot(history.history[key], label=key)
    plt.legend()
    plt.title(panel_title)
plt.show()

### Test du VAE

In [None]:
# Pick a random image from a random batch of the selected subset and keep a
# leading batch axis of size 1.
generator = select_generator()
idx = r.randrange(len(generator))
batch = generator[idx][0]
idx = r.randrange(len(batch))
sample = np.expand_dims(batch[idx], axis=0)

# Encode the image
z_mean, z_log_var, z = encoder.predict(sample, verbose=0)
print("z_mean: ", z_mean)
print("z_log_var: ", z_log_var)
print("z: ", z)

# Decode the sampled latent vector
x_reconstructed = decoder.predict(z, verbose=0)

# Show the original image next to its reconstruction
plt.figure(figsize=(10, 5))
views = (
    ("Image originale", sample[0]),
    ("Image reconstruite", x_reconstructed[0]),
)
for position, (view_title, picture) in enumerate(views, start=1):
    plt.subplot(1, 2, position)
    plt.title(view_title)
    plt.imshow(picture, cmap='gray')
    plt.axis('off')
plt.show()


#### Moyenne des prédictions

### Génération de données

In [None]:
# Generate new images by sampling the latent space around one cluster.
# NOTE(review): mean_z_mean and mean_z_log_var are not defined in the visible
# part of this file; they are presumably per-cluster averages of the encoder
# outputs, indexed by cluster number - they must exist before this cell runs.
# input() returns text, so convert it before using it as an index.
input_cluster = int(input('Choose a cluster number : '))
n_samples = 9

cluster_mean = np.atleast_1d(mean_z_mean[input_cluster])
cluster_std = np.sqrt(np.exp(mean_z_log_var[input_cluster]))
# Take the latent size from the data instead of hard-coding it, so that the
# samples match what the decoder expects.
latent_dim = cluster_mean.shape[-1]
z_samples = np.random.normal(cluster_mean, cluster_std, (n_samples, latent_dim))

x_samples = decoder.predict(z_samples)
plt.figure(figsize=(10, 10))
# Figure-level title: plt.title would attach to a single axes instead.
plt.suptitle(f'Génération du cluster {input_cluster}')
for position, generated in enumerate(x_samples, start=1):
    plt.subplot(3, 3, position)
    # squeeze drops a possible single-channel axis without assuming 28x28.
    plt.imshow(np.squeeze(generated), cmap='gray')
    plt.axis('off')
plt.show()