# Understanding Variational Autoencoder

This notebook documents my exploration of how to train directed probabilistic models and deep latent-variable models using variational inference. Here, I explore different aspects of the paper *Auto-Encoding Variational Bayes* and the book *An Introduction to Variational Autoencoders*, both written/co-authored by ***Diederik P. Kingma and Max Welling***. 
<br><br>
My reading notes on both works are in this [note](here) and on my [blog](blog_link). I write in these spaces for my own recollection, and I hope anyone who comes across them finds them useful for a quick, intuitive understanding of this topic.

This notebook has two sections:
1) An exploration and understanding of the concepts and objective of the Variational AutoEncoder.
   - Mathematically and visually examine the model parameters and the prior and posterior distributions.
   - Explore the ELBO (Evidence Lower Bound).
   - Reparametrization Trick
   - Optimization methods for updating parameters of the model.  <br><br>
  
2) Understand how to train a Variational Autoencoder, using two examples:
   - A Bernoulli VAE with a Gaussian prior.
   - A Gaussian VAE with a Gaussian prior.

# 00 - Set up

In [None]:
# import dependencies
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np

## Helper Functions

In [None]:
def select_plot_indices(num_items, num_panels, shuffle):
    """Return the indices of the items to draw, at most one per panel.

    Picks ``min(num_items, num_panels)`` distinct indices: a random subset when
    ``shuffle`` is true, otherwise the leading indices in order.
    """
    order = np.random.permutation(num_items) if shuffle else np.arange(num_items)
    return order[:min(num_items, num_panels)]


def plot_samples(arr, num_rows = 2, num_cols = 6, title = "Sample Dataset", cmap = None, shuffle = True):
    """Plot images from ``arr`` on a ``num_rows`` x ``num_cols`` grid.

    Args:
        arr: indexable collection of images that ``Axes.imshow`` accepts.
        num_rows: number of grid rows.
        num_cols: number of grid columns; at most ``num_rows * num_cols``
            images are drawn and any surplus panels are left blank.
        title: figure super-title.
        cmap: colormap forwarded to ``imshow``.
        shuffle: if true, draw a random subset of ``arr``; otherwise draw the
            first images in order. ``arr`` itself is never modified.
    """
    # squeeze=False keeps ``ax`` a 2-D array even for a 1x1 grid.
    fig, ax = plt.subplots(num_rows, num_cols, figsize = (num_cols * 2, num_rows * 2), squeeze = False)
    axes = ax.flatten()
    indices = select_plot_indices(len(arr), len(axes), shuffle)
    for i, axis in enumerate(axes):
        # Panels beyond the number of available images stay empty.
        if i < len(indices):
            axis.imshow(arr[indices[i]], cmap = cmap)
        axis.axis('off')
    fig.suptitle(title)
    plt.tight_layout()

# 02 - Bernoulli VAE

## 2A - Bernoulli VAE

### Load and Preprocess Data

In [None]:
# MNIST is used for the Bernoulli VAE; its pixels are binarized further down.
# Per the Keras docs each split is an (images, labels) pair; the official test
# split is used as the validation set here.
mnist_train, mnist_val = tf.keras.datasets.mnist.load_data()

In [None]:
# Number of images in each split (first axis of the image array).
print('Train set:', mnist_train[0].shape[0])
print('Validation set:', mnist_val[0].shape[0])

In [None]:
def preprocess_mnist(data, threshold = 0.2, dtype = np.float32):
    """Binarize raw images and append a trailing channel axis.

    Intensities are scaled by 1/255 and thresholded, producing the 0/1 targets
    a Bernoulli decoder models.

    Args:
        data: array of raw images, presumably uint8 in [0, 255] with shape
            (N, H, W) as loaded from ``keras.datasets.mnist``.
        threshold: scaled intensities strictly below this become 0, all
            others become 1.
        dtype: output dtype. float32 is what the Keras model consumes and
            needs half the memory of the default integer array that
            ``np.where`` would otherwise return.

    Returns:
        Array of shape ``data.shape + (1,)`` containing only 0s and 1s.
    """
    data = np.expand_dims(data, axis = -1)
    scaled = data / 255
    return np.where(scaled < threshold, 0, 1).astype(dtype)

In [None]:
# Binarize the images (scale to [0, 1], then threshold) and add a channel axis.
mnist_train_images = preprocess_mnist(mnist_train[0])
mnist_val_images = preprocess_mnist(mnist_val[0])

In [None]:
# Inspect the preprocessed shape: (num_images, H, W, 1) after the added channel axis.
mnist_train_images.shape

In [None]:
def make_vae_targets(images):
    """Build Keras targets with the same nesting that ``VAEncoder.call`` returns.

    The reconstruction target is the image batch itself. The 'latent' output
    gets two all-zero placeholder arrays because the KL loss never reads y_true.
    """
    num_images = images.shape[0]
    placeholders = [np.zeros(num_images) for _ in range(2)]
    return ({'reconstruction' : images}, {'latent' : placeholders})


mnist_train_y = make_vae_targets(mnist_train_images)
mnist_val_y = make_vae_targets(mnist_val_images)

In [None]:
# Preview a random selection of the binarized training images.
plot_samples(mnist_train_images, num_cols = 8, num_rows = 2, cmap = 'Greys_r')

### Define VAE architecture

In [None]:
# Shape of a single image (batch axis dropped) and the size of the latent code z.
INPUT_SHAPE = mnist_train_images.shape[1:]
LATENT_DIM = 4
INPUT_SHAPE

In [None]:
def build_encoder(input_shape, latent_dim = 2, downsample = 2, filter_size = 32, kernel_size = 3, padding = 'same'): 
    """Build the convolutional encoder q(z|x) as a functional Keras model.

    Applies ``downsample`` blocks of strided Conv2D -> BatchNormalization ->
    LeakyReLU (block ``i`` uses ``filter_size * 2**i`` filters), flattens, and
    projects to ``2 * latent_dim`` units that are split into the Gaussian mean
    and log-variance.

    Args:
        input_shape: shape of one input image, without the batch axis.
        latent_dim: dimensionality of the latent code.
        downsample: number of stride-2 convolution blocks.
        filter_size: filter count of the first block (doubled per block).
        kernel_size: convolution kernel size.
        padding: convolution padding mode.

    Returns:
        A model named 'encoder' mapping an image batch to ``[mu, logvar]``,
        each of shape (batch, latent_dim).
    """
    encoder_input = tf.keras.layers.Input(shape = input_shape, name = 'encoder_input')
    x = encoder_input
    for i in range(downsample):
        # Stride 2 downsamples spatially; the bias is redundant before BatchNorm.
        x = tf.keras.layers.Conv2D(filter_size * 2**i, kernel_size = kernel_size, padding = padding, strides = 2, use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Flatten()(x)
    # A single Dense layer produces both statistics; split it in half.
    x = tf.keras.layers.Dense(latent_dim * 2)(x)
    mu, logvar = tf.keras.layers.Lambda(lambda t: tf.split(t, 2, axis = -1))(x)
    model = tf.keras.models.Model(inputs = encoder_input, outputs = [mu, logvar], name = 'encoder')
    return model
        

In [None]:
def build_decoder(latent_dim = 3, upsample = 2, base_size = 7, filter_size = 32, padding = 'same', kernel_size = 3):
    """Build the convolutional decoder p(x|z) as a functional Keras model.

    A Dense layer maps z to a ``base_size x base_size x filter_size`` feature
    map. This is followed by ``upsample`` blocks of stride-2 Conv2DTranspose ->
    BatchNormalization -> LeakyReLU and a final 1-channel Conv2DTranspose with
    no activation. With ``padding='same'`` the output side length is
    ``base_size * 2**upsample``.

    Filter counts mirror ``build_encoder``: ``filter_size * 2**i`` for
    ``i = upsample-1 .. 0`` (widest block first). For ``upsample`` of 1 or 2
    this matches the previous ``filter_size * i`` schedule exactly.

    Note: the default ``latent_dim`` (3) differs from ``build_encoder``'s
    default (2); pass it explicitly so both halves agree.

    Returns:
        A model named 'decoder' whose output is the raw (pre-sigmoid)
        decoder activation.
    """
    decoder_input = tf.keras.layers.Input(shape = (latent_dim,))
    x = decoder_input
    x = tf.keras.layers.Dense(base_size * base_size * filter_size, activation = 'relu')(x)
    x = tf.keras.layers.Reshape([base_size, base_size, filter_size])(x)
    for i in range(upsample - 1, -1, -1):
        x = tf.keras.layers.Conv2DTranspose(filter_size * 2**i, kernel_size = kernel_size, padding = padding, strides = 2, use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU()(x)
    # Single output channel, no activation: a sigmoid is applied by the caller if needed.
    x = tf.keras.layers.Conv2DTranspose(1, kernel_size = kernel_size, padding = padding, strides = 1)(x)
    model = tf.keras.models.Model(inputs = decoder_input, outputs = x,  name = 'decoder')
    return model

In [None]:
# Standalone encoder, built only to inspect its architecture; VAEncoder
# constructs its own encoder internally.
encoder = build_encoder(INPUT_SHAPE, LATENT_DIM)
encoder.summary()

In [None]:
# Standalone decoder, built only to inspect its architecture; VAEncoder
# constructs its own decoder internally.
decoder = build_decoder(LATENT_DIM)
decoder.summary()

In [None]:
class VAEncoder(tf.keras.models.Model):
    """Convolutional variational autoencoder with a diagonal-Gaussian posterior.

    ``call`` returns ``({'reconstruction': x_hat}, {'latent': [mu, logvar]})`` so
    the reconstruction loss and the KL term can be attached to separate outputs
    in ``compile``.
    """

    def __init__(self, input_shape, latent_dim, encoder_params = None, decoder_params = None, apply_sigmoid = False, **kwargs):
        """Create the encoder and decoder sub-models.

        Args:
            input_shape: shape of one input image, without the batch axis.
            latent_dim: dimensionality of the latent code ``z``.
            encoder_params: optional extra keyword arguments for
                ``build_encoder`` (must not contain ``input_shape`` or
                ``latent_dim``).
            decoder_params: optional extra keyword arguments for
                ``build_decoder`` (must not contain ``latent_dim``).
            apply_sigmoid: if true, ``decode`` (and hence ``call``) applies a
                sigmoid to the decoder output.
            **kwargs: forwarded to ``tf.keras.models.Model``.

        Raises:
            ValueError: if the parameter dicts try to override arguments this
                class supplies itself.
        """
        super().__init__(**kwargs)
        # ``None`` defaults avoid one mutable dict shared across all instances.
        encoder_params = dict(encoder_params or {})
        decoder_params = dict(decoder_params or {})
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        for key in ('input_shape', 'latent_dim'):
            if key in encoder_params:
                raise ValueError(f'{key} must not be part of encoder params')
        if 'latent_dim' in decoder_params:
            raise ValueError('latent_dim must not be part of decoder params')
        self.input_dim = input_shape
        self.latent_dim = latent_dim
        self.encoder = build_encoder(self.input_dim, self.latent_dim, **encoder_params)
        self.decoder = build_decoder(self.latent_dim, **decoder_params)
        self.apply_sigmoid = apply_sigmoid

    def encode(self, inputs):
        """Return the posterior parameters ``(mu, logvar)`` for ``inputs``."""
        mu, logvar = self.encoder(inputs)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.

        Writing the sample as a deterministic function of ``mu`` and ``logvar``
        plus independent noise lets gradients flow back into the encoder.
        """
        eps = tf.random.normal(shape=tf.shape(mu))
        return mu + tf.exp(logvar * 0.5) * eps

    def decode(self, inputs, apply_sigmoid = False):
        """Decode latent codes.

        A sigmoid is applied if requested here or at construction time;
        otherwise the raw decoder output is returned.
        """
        out = self.decoder(inputs)
        if apply_sigmoid or self.apply_sigmoid:
            out = tf.sigmoid(out)
        return out

    def call(self, inputs):
        """Encode, sample ``z`` with the reparameterization trick, and decode."""
        mu, logvar = self.encode(inputs)
        z = self.reparameterize(mu, logvar)
        x = self.decode(z)
        return ({'reconstruction' : x}, {'latent' : [mu, logvar]})

    @tf.function
    def sample(self,  eps = None, num_samples = 20):
        """Decode ``eps``, or ``num_samples`` draws from N(0, I) when ``eps`` is None.

        The sigmoid is always applied, regardless of ``apply_sigmoid``.
        """
        if eps is None:
            eps = tf.random.normal(shape=(num_samples, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)


In [None]:
def kl_divergence():
    """Build a Keras loss for KL(N(mu, exp(logvar)) || N(0, I)).

    The returned callable ignores ``y_true`` and reads ``y_pred`` as the pair
    ``[mu, logvar]``. It sums over every element into a single scalar:

        0.5 * sum(exp(logvar) + mu**2 - 1 - logvar)
    """
    def inner_func(y_true, y_pred):
        mean, log_variance = y_pred[0], y_pred[1]
        per_element = tf.exp(log_variance) + tf.square(mean) - 1 - log_variance
        return 0.5 * tf.reduce_sum(per_element)
    return inner_func

In [None]:
# Full VAE. apply_sigmoid=True makes the reconstruction output a probability,
# which matches BinaryCrossentropy(from_logits=False) below.
vae_model = VAEncoder(INPUT_SHAPE, LATENT_DIM, apply_sigmoid = True)
# Build with a symbolic batch dimension so the weights exist before training.
vae_model.build([None] + list(INPUT_SHAPE))
vae_model.compile(
    # One loss per output, mirroring the structure returned by VAEncoder.call:
    # BCE with reduction='sum' (summed, not averaged, like the KL term) for the
    # reconstruction, and the KL divergence for [mu, logvar].
    loss = (
        {'reconstruction': tf.keras.losses.BinaryCrossentropy(from_logits = False, reduction = 'sum')},
        {'latent': kl_divergence()}
    ),
    # KL weighted by 3, similar in spirit to a beta-VAE with beta = 3.
    loss_weights = ({'reconstruction': 1}, {'latent': 3}),
    # MAE on reconstructions only; no metrics on mu / logvar.
    metrics = ({'reconstruction' : ['mae']}, {'latent': [None, None]}),
    optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3)
)

In [None]:
# Train on the binarized images; targets follow the output structure of
# VAEncoder.call (reconstruction target plus placeholder latent targets).
history = vae_model.fit(mnist_train_images, mnist_train_y, validation_data = (mnist_val_images, mnist_val_y), epochs = 50, batch_size = 32)

In [None]:
# Training vs. validation curve of one per-output loss. The key names are the
# ones Keras generated for this model's nested outputs; if they change, check
# `history.history.keys()` (TODO confirm which loss term this key is).
plt.plot(history.history['output_2_1_loss'], label = 'train')
plt.plot(history.history['val_output_2_1_loss'], label = 'validation')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
def generate_and_plot_samples(model, hard_sigmoid = 0, images = None):
    """Encode 20 random images, decode them, and plot originals and reconstructions.

    Args:
        model: a ``VAEncoder``-like model providing ``encode``,
            ``reparameterize`` and ``sample``.
        hard_sigmoid: if non-zero, binarize the decoded values at this
            threshold (values <= threshold become 0, others 1).
        images: array of images to draw from; defaults to the global
            ``mnist_val_images``. It is not modified.
    """
    if images is None:
        images = mnist_val_images
    # Random subset without shuffling the source array in place.
    chosen = np.random.choice(len(images), size = 20, replace = False)
    arr = images[chosen]
    mu, logvar = model.encode(arr)
    z = model.reparameterize(mu, logvar)
    # Decoding latents of real images yields (stochastic) reconstructions of ``arr``.
    samples = model.sample(eps = z).numpy()
    if hard_sigmoid:
        samples = np.where(samples <= hard_sigmoid, 0, 1)
    plot_samples(arr, num_cols = 10, num_rows = 2, cmap = 'Greys_r', shuffle = False, title = 'Original Images')
    plot_samples(samples, num_cols = 10, num_rows = 2, cmap = 'Greys_r', shuffle = False, title = 'Reconstructed Images')

In [None]:
# Show originals next to their decoded versions, binarized at 0.5.
generate_and_plot_samples(vae_model, hard_sigmoid = 0.5)

In [None]:
# Generate new images by decoding 30 latents drawn from the N(0, I) prior,
# then binarize the sigmoid outputs at 0.5.
samples = vae_model.sample(num_samples = 30).numpy()
samples = np.where(samples <= 0.5, 0, 1)
plot_samples(samples, num_cols = 10, num_rows = 3, cmap = 'Greys_r', title = 'Sampled Images')

## 2B - Gaussian VAE

In [None]:
# Dataset 2: use the cartoons dataset here instead,
# to demonstrate a Gaussian output distribution.
# In another notebook: a Gaussian prior with a (non-diagonal) covariance, to show
# that pixels are spatially related, i.e. pixels one step apart can be correlated.