# Course: Deep Learning
# Author: Sandro Camargo sandrocamargo@unipampa.edu.br
# Variational Autoencoder (VAE) for MNIST

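Dataset: MNIST. The code below loads it with `tf.keras.datasets.mnist.load_data()`; the original note points to '/content/sample_data/' as its location in Colab.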

Dataset description: https://www.tensorflow.org/datasets/catalog/mnist

To open this code in your Google Colab environment, [click here](https://colab.research.google.com/github/Sandrocamargo/deep-learning/blob/master/dl_class09_variationalautoencoder.ipynb).


# Loading libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model

# Loading dataset

In [None]:
# Keep only the first (training) split returned by the loader; the second
# split is discarded.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Convert to float32, divide by 255 (presumably mapping uint8 pixels into
# [0, 1]) and append a trailing channel axis so the images match the
# (28, 28, 1) encoder input.
x_train = (x_train.astype("float32") / 255.0)[..., np.newaxis]

# Building the variational autoencoder

In [None]:
latent_dim = 2   # size of the latent space (original note: much better for a CNN)

# ===========================================
# ENCODER
# ===========================================
inputs = layers.Input(shape=(28, 28, 1))

# Two stride-2 convolutions (32, then 64 filters), then a dense layer on the
# flattened feature map.
x = inputs
for n_filters in (32, 64):
    x = layers.Conv2D(
        filters=n_filters,
        kernel_size=3,
        strides=2,
        padding="same",
        activation="relu",
    )(x)
x = layers.Flatten()(x)
x = layers.Dense(units=64, activation="relu")(x)

# Two linear heads of width latent_dim; the sampling step treats them as the
# mean and the log-variance of the latent distribution.
z_mean = layers.Dense(units=latent_dim)(x)
z_log_var = layers.Dense(units=latent_dim)(x)

def sampling(args):
    """Draw a latent sample using the reparameterization trick.

    Args:
        args: Pair ``(zm, zv)`` of tensors with the same shape, where ``zm``
            is used as the mean and ``zv`` as the log-variance.

    Returns:
        ``zm + exp(0.5 * zv) * eps``, with ``eps`` standard-normal noise of
        the same shape as ``zm``.
    """
    zm, zv = args
    # Take the noise shape from the input tensor itself instead of the
    # module-level ``latent_dim``, so the function works for any latent size
    # and cannot drift out of sync with the encoder heads.
    eps = tf.random.normal(shape=tf.shape(zm))
    return zm + tf.exp(0.5 * zv) * eps

# Stochastic latent code computed from the two encoder heads.
z = layers.Lambda(sampling)([z_mean, z_log_var])

# The encoder returns the mean, the log-variance and the sampled code, in
# that order.
encoder = Model(inputs=inputs, outputs=[z_mean, z_log_var, z])

# ===========================================
# DECODER
# ===========================================
latent_inputs = layers.Input(shape=(latent_dim,))

# Project the latent vector to a 7x7x64 feature map.
x = layers.Dense(units=7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape(target_shape=(7, 7, 64))(x)

# Two stride-2 transposed convolutions (64, then 32 filters).
for n_filters in (64, 32):
    x = layers.Conv2DTranspose(
        filters=n_filters,
        kernel_size=3,
        strides=2,
        padding="same",
        activation="relu",
    )(x)

# Single-channel output squashed by a sigmoid.
outputs = layers.Conv2DTranspose(
    filters=1,
    kernel_size=3,
    padding="same",
    activation="sigmoid",
)(x)

decoder = Model(inputs=latent_inputs, outputs=outputs)

# ===========================================
# VAE class
# ===========================================
class VAE(Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for a batch of
            inputs.
        decoder: Model mapping latent codes ``z`` to reconstructions.
        beta: Weight applied to the KL term of the loss (default ``1.0``).
    """

    def __init__(self, encoder, decoder, beta=1.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        # Running means, so the values reported by fit() are averages over
        # the batches seen so far rather than the value of a single batch.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.recon_loss_tracker = tf.keras.metrics.Mean(name="recon_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the trackers so Keras can reset them between epochs."""
        return [
            self.total_loss_tracker,
            self.recon_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None):
        """Encode ``inputs``, sample a latent code and return its decoding."""
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: A batch of inputs, or a tuple/list whose first element is
                the batch of inputs (any further elements are ignored).

        Returns:
            Dict with the running means of the total loss (``"loss"``), the
            reconstruction term (``"recon"``) and the KL term (``"kl"``).
        """
        # fit() may hand over (x,) or (x, y); only the inputs are needed.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # Per-pixel binary cross-entropy, summed over the two spatial
            # axes to give one reconstruction value per sample.
            recon_loss = tf.reduce_sum(
                tf.keras.losses.binary_crossentropy(data, reconstruction),
                axis=(1, 2)
            )

            # KL divergence between N(z_mean, exp(z_log_var)) and the
            # standard normal, summed over the latent dimensions.
            kl_loss = -0.5 * tf.reduce_sum(
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                axis=1
            )

            loss = tf.reduce_mean(recon_loss + self.beta * kl_loss)

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        self.total_loss_tracker.update_state(loss)
        self.recon_loss_tracker.update_state(tf.reduce_mean(recon_loss))
        self.kl_loss_tracker.update_state(tf.reduce_mean(kl_loss))

        # Same log keys as before, so history.history['loss'/'recon'/'kl']
        # keeps working.
        return {
            "loss": self.total_loss_tracker.result(),
            "recon": self.recon_loss_tracker.result(),
            "kl": self.kl_loss_tracker.result(),
        }

# Wrap the encoder and decoder in the trainable VAE; beta weights the KL term.
vae = VAE(encoder=encoder, decoder=decoder, beta=1.0)

# Compiling and training

In [None]:
vae.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
# The VAE wrapper is a subclassed model that has not been called on data at
# this point, so it is still unbuilt and its own summary is unavailable or
# uninformative. Summarize the two functional sub-models instead.
encoder.summary()
decoder.summary()

# ===========================================
# Training
# ===========================================
history = vae.fit(x_train, epochs=50, batch_size=128)

# Viewing loss

In [None]:
# Total loss and its reconstruction component, one point per epoch.
plt.plot(history.history['loss'])
plt.plot(history.history['recon'])
plt.title('Loss Function')
# The curves are a binary cross-entropy reconstruction term (plus KL for the
# total), not a categorical cross-entropy, so use a neutral axis label.
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training Loss','Reconstruction Loss'], loc='upper right')
# Save before show(): the file is written while the figure is still current.
plt.savefig("trainingerror.pdf")
plt.show()

In [None]:
# KL term of the loss, one point per epoch.
plt.plot(history.history['kl'])
plt.title('KL Function')
# This curve is the KL divergence term, not a cross-entropy.
plt.ylabel('KL Divergence')
plt.xlabel('Epoch')
plt.legend(['KL'], loc='lower right')
# Save before show(): the file is written while the figure is still current.
plt.savefig("trainingkl.pdf")
plt.show()

# Viewing results

In [None]:
# ===========================================
# Generate 100 images (10 rows of 10 random samples)
# ===========================================

import os
os.makedirs("synthetic_digits", exist_ok=True)

# Samples are drawn at random from the latent space, so the outer index is
# only a row number used for the figure layout and the file names; it does
# not select a digit class.
for row in range(10):
    print("Gerando imagens ...")
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))

    # Draw the whole row of latent vectors at once and decode them in a
    # single forward pass instead of ten batch-of-one calls.
    z_batch = np.random.normal(size=(10, latent_dim))
    images = decoder(z_batch).numpy().reshape(-1, 28, 28)

    for i, img in enumerate(images):
        axes[i].imshow(img, cmap="gray")
        axes[i].axis("off")

        # Save each sample as its own PNG, named <row>_<column>.png.
        plt.imsave(f"synthetic_digits/{row}_{i}.png", img, cmap="gray")

    plt.show()

print("Imagens salvas em synthetic_digits/")