#**Variational Autoencoder**
<font color='grey' size='1.5'> Created by Kevin Harnden for *Machine learning for proteins*, Spring 2022. 
This notebook is adapted from an example by [François Chollet](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/vae.ipynb).</font>

In today's in-class activity, we will be building a VAE to generate digits 0-9 using the MNIST dataset.

###Step 1. Setup

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

###Step 2. Create a sampling layer



In [None]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
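
The layer above implements the reparameterization trick: instead of sampling `z` directly, it draws standard normal noise `epsilon` and then shifts and scales it, which keeps `z_mean` and `z_log_var` differentiable. As a rough illustration outside TensorFlow, the NumPy sketch below checks that samples drawn this way have approximately the requested mean and standard deviation. The function name `sample_z`, the seed, and the example values are assumptions made for this sketch, not part of the notebook.

```python
import numpy as np


def sample_z(z_mean, z_log_var, rng):
    """NumPy analogue of the Sampling layer (illustrative only)."""
    # epsilon ~ N(0, 1), same shape as z_mean
    epsilon = rng.standard_normal(size=z_mean.shape)
    # exp(0.5 * log_var) converts a log variance into a standard deviation
    return z_mean + np.exp(0.5 * z_log_var) * epsilon


rng = np.random.default_rng(0)
# 100,000 copies of one 2D latent distribution: mean (1, -2), std (0.5, 2)
z_mean = np.tile([1.0, -2.0], (100_000, 1))
z_log_var = np.tile(np.log([0.5**2, 2.0**2]), (100_000, 1))
z = sample_z(z_mean, z_log_var, rng)

print(z.mean(axis=0))  # close to [1, -2]
print(z.std(axis=0))   # close to [0.5, 2]
```

Predicting the log variance rather than the variance itself lets the network output any real number while the resulting standard deviation stays positive.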

###Step 3. Build the encoder

Create an encoder with two convolutional layers, followed by a flatten layer and a fully connected layer, and then two separate fully connected layers for the mean and the log variance. Finally, the latent vector `z` should be sampled from these two layers using the `Sampling` class defined above.

The encoder inputs should have dimensions of 28x28x1. The convolutional layers should have 32 and 64 filters, respectively, kernel sizes of 3, strides of 2, "same" padding, and use the ReLU activation function. The first fully connected layer should have 16 units and use the ReLU activation function. The latent layers should have 2 units each and be named "z_mean" and "z_log_var".
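
It can help to work out the spatial dimensions before building the model. With "same" padding, a convolution with stride `s` produces an output of size `ceil(input / s)`, so the two stride-2 layers shrink the 28x28 image step by step. The short sketch below does this arithmetic in plain Python; the helper name `conv_same_output_size` is made up for this example.

```python
import math


def conv_same_output_size(input_size, stride):
    """Spatial output size of a convolution with "same" padding."""
    return math.ceil(input_size / stride)


size = 28
sizes = [size]
for _ in range(2):  # two stride-2 convolutional layers
    size = conv_same_output_size(size, stride=2)
    sizes.append(size)

print(sizes)  # [28, 14, 7]

# The second convolutional layer has 64 filters, so flattening its
# 7x7x64 output gives this many values per image
flattened_units = sizes[-1] * sizes[-1] * 64
print(flattened_units)  # 3136
```

These numbers should match the output shapes reported by `encoder.summary()`.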

###Q1. Encoder

This encoder was designed for the MNIST dataset. For more complex datasets, such as images of faces or protein features, what about the encoder would need to be changed?

In [None]:
# Your code here

latent_dim = None

encoder_inputs = None
x = None
x = None
x = None
x = None
z_mean = None
z_log_var = None
z = None

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
#@markdown Sample solution

latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

###Step 4. Build the decoder

Create a decoder with a fully connected layer followed by three transposed convolutional layers.

The decoder inputs should have the same dimensions as the latent space. The fully connected layer should have 7 * 7 * 64 units and use the ReLU activation function, and its output should be reshaped to 7x7x64. The first two transposed convolutional layers should have 64 and 32 filters, respectively, kernel sizes of 3, strides of 2, "same" padding, and use the ReLU activation function. The final transposed convolutional layer should have 1 filter, a kernel size of 3, "same" padding, and use the sigmoid activation function.
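
The decoder's spatial dimensions can be checked the same way as for a regular convolution. With "same" padding, a transposed convolution with stride `s` produces an output of size `input * s`, so the 7x7 feature map grows back to the image size. The plain Python sketch below follows the three layers described above; the helper name `conv_transpose_same_output_size` is made up for this example.

```python
def conv_transpose_same_output_size(input_size, stride):
    """Spatial output size of a transposed convolution with "same" padding."""
    return input_size * stride


size = 7  # spatial size after reshaping the dense layer output to 7x7x64
sizes = [size]
for stride in (2, 2, 1):  # the final layer uses the default stride of 1
    size = conv_transpose_same_output_size(size, stride)
    sizes.append(size)

print(sizes)  # [7, 14, 28, 28]
```

The final layer has a single filter, so the decoder output is 28x28x1, the same shape as the encoder input.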

###Q2. Decoder

How do the layer types and dimensions of the decoder compare to those of the encoder?

In [None]:
# Your code here

latent_inputs = None
x = None
x = None
x = None
x = None
decoder_outputs = None

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
#@markdown Sample solution

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

###Step 5. Define the VAE as a `Model` with a custom `train_step`

In [None]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
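
The two terms in `train_step` can be checked by hand on small arrays. The NumPy sketch below mirrors them under stated assumptions: the helper names `reconstruction_loss` and `kl_loss` and the `eps` clipping value are made up for this example, and Keras applies its own clipping internally, so the numbers may differ very slightly from the TensorFlow version.

```python
import numpy as np


def reconstruction_loss(data, reconstruction, eps=1e-7):
    """Binary cross-entropy summed over pixels, averaged over the batch."""
    p = np.clip(reconstruction, eps, 1.0 - eps)  # avoid log(0)
    bce = -(data * np.log(p) + (1.0 - data) * np.log(1.0 - p))
    # With a single channel, summing over (height, width, channel) matches
    # Keras averaging over the channel axis and then summing over axes (1, 2)
    return bce.sum(axis=(1, 2, 3)).mean()


def kl_loss(z_mean, z_log_var):
    """KL divergence from N(z_mean, exp(z_log_var)) to N(0, 1)."""
    kl = -0.5 * (1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var))
    # sum over latent dimensions, then average over the batch
    return kl.sum(axis=1).mean()


# A toy batch of two "images" with a bright square in the middle
data = np.zeros((2, 28, 28, 1))
data[:, 10:18, 10:18, :] = 1.0

# Predicting 0.5 for every pixel costs ln(2) per pixel: 784 * ln(2), about 543.4
rec_uninformative = reconstruction_loss(data, np.full_like(data, 0.5))
# A perfect reconstruction costs almost nothing
rec_perfect = reconstruction_loss(data, data)

# No KL penalty when the encoder outputs exactly a standard normal
kl_at_prior = kl_loss(np.zeros((4, 2)), np.zeros((4, 2)))
# Moving the mean to 3 costs 4.5 per latent dimension, 9.0 in total
kl_shifted = kl_loss(np.full((4, 2), 3.0), np.zeros((4, 2)))

print(rec_uninformative, rec_perfect, kl_at_prior, kl_shifted)
```

The reconstruction term depends only on the input and the decoder output, while the KL term depends only on `z_mean` and `z_log_var`.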

###Q3. Loss function


*   What variables are used to calculate the reconstruction loss?
*   What variables are used to calculate the KL divergence loss?
*   How will these two losses affect the training?



###Step 6. Train the VAE

In [None]:
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=10, batch_size=128)
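
To see what the preprocessing lines do without downloading MNIST, the sketch below applies the same steps to small random arrays. The `fake_train` and `fake_test` arrays are stand-ins invented for this example; the real dataset has 60,000 training and 10,000 test images.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for the MNIST arrays: uint8 pixel values in [0, 255]
fake_train = rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8)
fake_test = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same steps as above: merge the splits, add a channel axis, scale to [0, 1]
digits = np.concatenate([fake_train, fake_test], axis=0)
digits = np.expand_dims(digits, -1).astype("float32") / 255

print(digits.shape, digits.dtype)  # (10, 28, 28, 1) float32
print(digits.min() >= 0.0, digits.max() <= 1.0)  # True True

# With all 70,000 MNIST images and batch_size=128, each epoch runs this many steps
steps_per_epoch = math.ceil(70_000 / 128)
print(steps_per_epoch)  # 547
```

Scaling the pixels to [0, 1] matters here because the decoder ends in a sigmoid and the reconstruction loss is a binary cross-entropy, both of which assume targets in that range.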

###Step 7. Display a grid of sampled digits


In [None]:
import matplotlib.pyplot as plt


def plot_latent_space(vae, n=30, figsize=15):
    # display an n*n 2D manifold of digits
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(vae)
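
The function above calls the decoder once per grid point, which can be slow for a 30x30 grid. One possible alternative, sketched below, builds all latent coordinates up front and decodes them in a single batch. The names `decode_grid` and `toy_decoder` are hypothetical; `toy_decoder` is a stand-in so the example runs without a trained model.

```python
import numpy as np


def decode_grid(decode_fn, n=30, scale=1.0, digit_size=28):
    """Decode an n x n grid of latent points with a single batched call."""
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    # xx[i, j] = grid_x[j] and yy[i, j] = grid_y[i] (row i, column j)
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_samples = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (n * n, 2)
    decoded = decode_fn(z_samples)  # (n * n, digit_size, digit_size, 1)
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    # Reorder to (grid row, pixel row, grid column, pixel column), then merge
    return tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)


def toy_decoder(z):
    """Stand-in decoder: each tile is filled with the value of z[0]."""
    return np.ones((len(z), 28, 28, 1)) * z[:, 0].reshape(-1, 1, 1, 1)


figure = decode_grid(toy_decoder, n=3)
print(figure.shape)  # (84, 84)
```

With the trained model, the same function could be called as `decode_grid(lambda z: vae.decoder.predict(z, verbose=0))`, and the returned array passed to `plt.imshow` as before.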

###Q4. Latent space visualization

Are all of the numbers present in the latent space visualization? Do you see interpolations between different numbers?

###Step 8. Display how the latent space clusters different digit classes


In [None]:
def plot_label_clusters(vae, data, labels):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)

###Q5. Latent space plot

Based on the plot above, is the latent space regular or irregular? What is your reasoning?

###Q6. Sampling the latent space

Write a function that manually samples the latent space from user-defined values and displays the output as an image (similar to the function defined above). What happens when the values are far outside of the distribution in the plot above?

In [None]:
# Your code here

In [None]:
#@markdown Sample solution

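def sample_latent_space(vae, z0, z1, figsize=5):
    # decode a single user-defined latent point (z0, z1) into an image
    digit_size = 28

    # shape (1, 2): a batch containing one 2D latent vector
    z_sample = np.array([[z0, z1]])
    x_decoded = vae.decoder.predict(z_sample)
    digit = x_decoded[0].reshape(digit_size, digit_size)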

    plt.figure(figsize=(figsize, figsize))
    plt.imshow(digit, cmap="Greys_r")
    plt.show()


sample_latent_space(vae, 0, 0)