### Imports

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy
import numpy as np
import matplotlib.pyplot as plt
import random as r
from sklearn.cluster import KMeans


### Dataset

In [2]:
# Load the MNIST dataset
mnist = tf.keras.datasets.mnist
split = 0.2       # fraction of the training images held out for validation
kl_factor = 1e-3  # weight applied to the KL term of the VAE loss

# Train / test partitions as returned by the dataset loader
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel values so that they lie between 0 and 1
x_train = x_train / 255.0
x_test = x_test / 255.0

# The leading `split` fraction becomes the validation set; the remainder is kept for training
n_val = int(len(x_train) * split)
x_val, x_train = x_train[:n_val], x_train[n_val:]
y_val, y_train = y_train[:n_val], y_train[n_val:]


### Modèle

In [None]:
# Encodeur
def encoder_model():
    """Build the convolutional encoder of the VAE.

    The network maps a (28, 28, 1) input to a 2-dimensional latent space and
    exposes three outputs: the latent mean, the latent log-variance and a
    sample drawn from them.

    Returns:
        tuple: ``(model, z_mean, z_log_var)`` where ``model`` is the Keras
        model named ``'encoder'`` with outputs ``[z_mean, z_log_var, z]`` and
        the two other items are the symbolic output tensors of the
        ``'z_mean'`` and ``'z_log_var'`` layers.
    """
    def sampling(args):
        """Return mean + exp(0.5 * log_var) * noise, with one random-normal draw per latent entry."""
        mean, log_var = args
        n_rows = tf.shape(mean)[0]
        n_cols = tf.shape(mean)[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        return mean + tf.exp(0.5 * log_var) * noise

    inputs = layers.Input(shape=(28, 28, 1))

    # Convolutional trunk: (filters, strides) for each 3x3 convolution, in order
    features = inputs
    for filters, strides in ((32, 1), (64, 2), (128, 2), (128, 1)):
        features = layers.Conv2D(
            filters, 3, activation='leaky_relu', strides=strides, padding='same'
        )(features)

    # Dense head shared by the two latent parameter layers
    features = layers.Flatten()(features)
    features = layers.Dense(128, activation='leaky_relu')(features)
    features = layers.BatchNormalization()(features)

    z_mean = layers.Dense(2, name='z_mean')(features)
    z_log_var = layers.Dense(2, name='z_log_var')(features)
    # Latent sample computed from the two parameter tensors
    z = layers.Lambda(sampling, output_shape=(2,), name='z')([z_mean, z_log_var])

    model = models.Model(inputs, [z_mean, z_log_var, z], name='encoder')
    return model, z_mean, z_log_var

encoder, z_mean, z_log_var = encoder_model()

# Décodeur
def decoder_model():
    """Build the decoder of the VAE.

    A 2-dimensional latent vector is projected to a 7x7x64 feature map, passed
    through a stack of transposed convolutions (two of them with stride 2) and
    finally reshaped to a (28, 28) image with sigmoid-activated values.

    Returns:
        The Keras model named ``'decoder'``.
    """
    latent_inputs = layers.Input(shape=(2,))

    # Project the latent vector and fold it into a spatial feature map
    features = layers.Dense(7*7*64, activation='leaky_relu')(latent_inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Reshape((7, 7, 64))(features)

    # Transposed convolutions: (filters, strides) for each 3x3 layer, in order
    for filters, strides in ((128, 1), (128, 2), (64, 2), (64, 1)):
        features = layers.Conv2DTranspose(
            filters, 3, activation='leaky_relu', strides=strides, padding='same'
        )(features)
    features = layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(features)

    # Drop the trailing singleton channel dimension
    output = layers.Reshape((28, 28))(features)
    return models.Model(latent_inputs, output, name='decoder')

decoder = decoder_model()

class VAE(tf.keras.Model):
    """Variational auto-encoder chaining an encoder and a decoder.

    The encoder must return ``(z_mean, z_log_var, z)`` and the decoder must map
    ``z`` to a reconstruction. Losses are computed by the module-level
    ``vae_loss`` function and averaged in three ``Mean`` trackers.

    Training and validation are implemented by ``train_step`` and ``test_step``;
    ``call`` is a pure forward pass, so ``vae(x)`` and ``vae.predict(x)`` do not
    alter the loss trackers.

    Args:
        encoder: Keras model producing ``[z_mean, z_log_var, z]``.
        decoder: Keras model producing the reconstruction from ``z``.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of the losses, reported through `metrics`
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @staticmethod
    def _inputs_from(data):
        """Return the input batch from a step's ``data`` argument.

        ``data`` may be a tuple/list such as ``(x, y)`` or ``(x,)``, or the
        bare batch itself; targets are ignored because the model reconstructs
        its own input.
        """
        if isinstance(data, (tuple, list)):
            return data[0]
        return data

    def _update_trackers(self, reconstruction_loss, kl_loss, total_loss):
        """Feed one batch worth of losses into the three metric trackers."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

    def call(self, inputs):
        """Encode ``inputs``, decode the sampled latent vector and return the reconstruction."""
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step and return the running training losses.

        Returns:
            dict with keys ``"loss"``, ``"reconstruction_loss"`` and ``"kl_loss"``.
        """
        x = self._inputs_from(data)

        # Only the forward pass needs to be recorded by the tape
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(x, training=True)
            reconstructed = self.decoder(z, training=True)
            reconstruction_loss, kl_loss, total_loss = vae_loss(x, reconstructed, z_mean, z_log_var)

        # Backward pass and weight update, outside the recording context
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self._update_trackers(reconstruction_loss, kl_loss, total_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }

    def test_step(self, data):
        """Evaluate one batch in inference mode and return the running losses.

        The returned keys are the tracker names (``"total_loss"``,
        ``"reconstruction_loss"``, ``"kl_loss"``), so validation logs appear as
        ``val_total_loss``, ``val_reconstruction_loss`` and ``val_kl_loss``.
        """
        x = self._inputs_from(data)

        z_mean, z_log_var, z = self.encoder(x, training=False)
        reconstructed = self.decoder(z, training=False)
        reconstruction_loss, kl_loss, total_loss = vae_loss(x, reconstructed, z_mean, z_log_var)

        self._update_trackers(reconstruction_loss, kl_loss, total_loss)

        return {tracker.name: tracker.result() for tracker in self.metrics}

    @property
    def metrics(self):
        """Trackers exposed to Keras so that it reports and resets them."""
        return [self.total_loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

def vae_loss(y, vae_output, z_mean, z_log_var):
    """Compute the VAE losses for one batch.

    Args:
        y: Target images.
        vae_output: Reconstructed images.
        z_mean: Latent means produced by the encoder.
        z_log_var: Latent log-variances produced by the encoder.

    Returns:
        tuple: ``(reconstruction_loss, kl_loss, total_loss)`` where
        ``kl_loss`` is already scaled by the module-level ``kl_factor`` and
        ``total_loss`` is the sum of the two other terms.
    """
    # Binary cross-entropy reduced to a single mean value
    reconstruction_loss = tf.reduce_mean(binary_crossentropy(y, vae_output))

    # KL divergence: summed over axis 1, averaged over the remaining axis, then weighted
    kl_per_sample = -0.5 * tf.reduce_sum(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
        axis=1,
    )
    kl_loss = tf.reduce_mean(kl_per_sample) * kl_factor

    total_loss = reconstruction_loss + kl_loss
    return reconstruction_loss, kl_loss, total_loss

# Trick to satisfy the Keras API: `compile` is given a loss even though the
# real loss is computed inside the model's own training step.
def zero_loss(y_true, y_pred):
    """Placeholder loss: ignores both arguments and returns the constant 0.0."""
    return tf.constant(0.0)

vae = VAE(encoder, decoder)
vae.compile(
    optimizer='adam',
    loss=zero_loss,
)
vae.summary()

### Entraînement

In [None]:
# Early-stopping callback: watches the validation total loss, tolerates one
# epoch without improvement, then restores the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_total_loss',
    mode='min',
    patience=1,
    restore_best_weights=True,
    verbose=2,
)

# Train the VAE; the images serve as both inputs and targets
history = vae.fit(
    x_train,
    x_train,
    validation_data=(x_val, x_val),
    epochs=30,
    batch_size=64,
    callbacks=[early_stopping],
)

### Résultats

In [None]:
# Learning curves: every recorded loss, training on the left and validation on the right
fig, (train_ax, val_ax) = plt.subplots(1, 2, figsize=(12, 6))

for key in ('loss', 'reconstruction_loss', 'kl_loss'):
    train_ax.plot(history.history[key], label=key)
train_ax.legend()
train_ax.set_title('Training loss')

for key in ('val_total_loss', 'val_reconstruction_loss', 'val_kl_loss'):
    val_ax.plot(history.history[key], label=key)
val_ax.legend()
val_ax.set_title('Validation loss')

plt.show()

### Test du VAE

In [None]:
# Set to True to pick a random image instead of the first one of the chosen set
random_input = False
input_type = input("Quel ensemble d'image voulez-vous visualiser ? (train/val/test) : ")

# Accepted answers and the image set each one selects (an empty answer means "train")
datasets_by_choice = {
    "train": x_train,
    "": x_train,
    "val": x_val,
    "test": x_test,
}
if input_type not in datasets_by_choice:
    # Raise instead of calling exit(): this stops the cell without shutting the kernel down
    raise ValueError("Choix invalide")
chosen_set = datasets_by_choice[input_type]

# randrange excludes its upper bound, so idx is always a valid index
# (randint(0, len(...)) could return len(...) and raise IndexError)
idx = r.randrange(len(chosen_set)) if random_input else 0
# Add a leading batch dimension of size 1
x_sample = np.expand_dims(chosen_set[idx], axis=0)

# Encode the image
z_mean, z_log_var, z = encoder.predict(x_sample, verbose=0)
print("z_mean: ", z_mean)
print("z_log_var: ", z_log_var)
print("z: ", z)

# Decode the sampled latent vector
x_reconstructed = decoder.predict(z, verbose=0)

# Show the original image next to its reconstruction
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Image originale")
plt.imshow(x_sample[0], cmap='gray')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Image reconstruite")
plt.imshow(x_reconstructed[0], cmap='gray')
plt.axis('off')
plt.show()


### Map du dataset
Carte qui montre la distribution du dataset dans l'espace latent.

In [None]:
# Encode the whole training set; the encoder returns [z_mean, z_log_var, z]
x_train_encoded = encoder.predict(x_train)
x_train_z_mean, x_train_z_log_var, x_train_z = (
    x_train_encoded[0],
    x_train_encoded[1],
    x_train_encoded[2],
)

# One scatter plot per encoder output, coloured by the digit label
for output_name, points in (
    ("z_mean", x_train_z_mean),
    ("z_log_var", x_train_z_log_var),
    ("z", x_train_z),
):
    plt.figure(figsize=(10, 10))
    plt.scatter(points[:, 0], points[:, 1], c=y_train, cmap='viridis')
    plt.title(f"Projection des données MNIST dans l'espace latent ({output_name})")
    plt.colorbar()
plt.show()


### Identification de clusters
Le but est d'obtenir les moyennes et les écarts-types des clusters (classes) afin de les utiliser pour la génération de données.

#### K-means

In [None]:

# Cluster the sampled latent vectors into 10 groups.
# random_state makes the clustering reproducible from one run to the next;
# n_init is set explicitly so the number of restarts does not depend on the
# default of the installed scikit-learn version.
kmeans = KMeans(n_clusters=10, n_init=10, random_state=42)
kmeans.fit(x_train_z)

# Cluster index assigned to each training point
labels = kmeans.labels_

# Show the clusters in the latent space
plt.figure(figsize=(10, 10))
plt.scatter(x_train_z[:, 0], x_train_z[:, 1], c=labels, cmap='tab10')
plt.title('K-means clustering des données MNIST dans l\'espace latent (z)')
plt.colorbar()
plt.show()

#### Moyenne des prédictions

In [None]:
# Boolean row selector for each digit class
digit_masks = {digit: (y_train == digit) for digit in range(10)}

# Per-class average of the encoder outputs, keyed by digit
mean_z_mean = {
    digit: x_train_z_mean[mask].mean(axis=0) for digit, mask in digit_masks.items()
}
mean_z_log_var = {
    digit: x_train_z_log_var[mask].mean(axis=0) for digit, mask in digit_masks.items()
}

# Print both tables, one line per digit
for heading, table in (
    ("Moyenne de z_mean pour chaque chiffre:", mean_z_mean),
    ("\nMoyenne de z_log_var pour chaque chiffre:", mean_z_log_var),
):
    print(heading)
    for digit, mean in table.items():
        print(f"Chiffre {digit}: {mean}")