# 0727 VAE Tutorial - Assignment
### TA Taewook Nam (namsan@kaist.ac.kr)

Train two VAEs, one with a low latent (z) dimension (e.g. 2) and one with a high latent dimension (e.g. 10), using the code given below, then  
1) Visualize several decoded images and compare their quality.  
2) Explain the reason for the observed difference.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.keras.layers import Dense, Flatten, Conv2D, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model

from tqdm import trange

In [None]:
# Restrict CUDA to device index 0, then query the GPUs TensorFlow can see
# (in a notebook the returned list is displayed as the cell output).
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
tf.config.list_physical_devices(device_type='GPU')

### 0) Prepare model & data (same as in tutorial)

In [None]:
class VAE(Model):
  """Fully connected variational autoencoder.

  The encoder flattens its input and passes it through two ReLU layers; two
  separate heads then produce the latent mean and a non-negative (softplus)
  latent scale. The decoder maps a latent code through two ReLU layers to
  sigmoid-activated values reshaped back to `x_shape`.

  Args:
    x_shape: Shape of a single input sample without the batch axis,
      e.g. (28, 28).
    z_dim: Dimensionality of the latent code.
  """

  # Width of every hidden layer in both the encoder and the decoder.
  h_dim = 500

  def __init__(self, x_shape, z_dim):
    super().__init__()

    # np.prod returns a NumPy integer; convert it to a plain Python int
    # because Dense expects `units` to be a built-in integer.
    x_dim = int(np.prod(x_shape))
    self.z_dim = z_dim

    self.encoder = Sequential([
        Flatten(),
        Dense(self.h_dim, activation='relu'),
        Dense(self.h_dim, activation='relu'),
    ])
    # Separate linear heads on top of the shared encoder features.
    self.mu_dense = Dense(z_dim)
    # softplus keeps the predicted scale non-negative.
    self.sigma_dense = Dense(z_dim, activation='softplus')

    self.decoder = Sequential([
        Dense(self.h_dim, activation='relu'),
        Dense(self.h_dim, activation='relu'),
        Dense(x_dim, activation='sigmoid'),
        Reshape(x_shape)
    ])

  def encode(self, x):
    """Return `(z_mu, z_sigma)`, the latent mean and scale for batch `x`."""
    h = self.encoder(x)
    z_mu = self.mu_dense(h)
    z_sigma = self.sigma_dense(h)
    return z_mu, z_sigma

  def decode(self, z):
    """Map latent codes `z` to reconstructions shaped like the input."""
    return self.decoder(z)


@tf.function
def compute_elbo(x, x_reconst, z_mu, z_sigma):
  """Return the batch-averaged evidence lower bound (ELBO).

  The reconstruction term is a Bernoulli log-likelihood summed over every
  non-batch axis of `x`. The regulariser is the closed-form KL divergence
  between a diagonal Gaussian with mean `z_mu` and scale `z_sigma` and a
  standard normal.

  Args:
    x: Input batch; axis 0 is the batch axis.
    x_reconst: Reconstruction with the same shape as `x`.
    z_mu: Latent means; axis 1 is the latent axis.
    z_sigma: Latent scales, same shape as `z_mu`.

  Returns:
    Scalar tensor: the mean over the batch of (log-likelihood - KL).
  """
  log_eps = 1e-6  # keeps the likelihood log() arguments strictly positive
  var_eps = 1e-8  # floor for the variance inside the KL log()

  # Reduce over all per-sample axes so inputs of any rank work, not only
  # (batch, height, width) images. Fall back to a dynamic range when the
  # static rank is unknown at trace time.
  rank = x.shape.rank
  if rank is not None:
    sample_axes = list(range(1, rank))
  else:
    sample_axes = tf.range(1, tf.rank(x))

  log_likelihood = tf.reduce_sum(
    x*tf.math.log(x_reconst+log_eps) + (1-x)*tf.math.log(1-x_reconst+log_eps),
    axis=sample_axes
  )
  # z_sigma may be exactly 0 (e.g. an underflowing softplus); without the
  # floor log(z_sigma**2) would be -inf and the gradients would become NaN.
  kl = 0.5 * tf.reduce_sum(
    z_mu**2 + z_sigma**2 - tf.math.log(z_sigma**2 + var_eps) - 1,
    axis=1
  )
  elbo = tf.reduce_mean(log_likelihood - kl)
  return elbo

In [None]:
mnist = tf.keras.datasets.mnist

# Labels are loaded alongside the images; the code in this notebook only
# uses the images.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Divide by 255 and store as float32 so the images can be used as the
# targets of the Bernoulli log-likelihood in compute_elbo.
x_train = tf.cast(x_train / 255.0, tf.float32)
x_test = tf.cast(x_test / 255.0, tf.float32)

# A shuffle buffer smaller than the dataset only shuffles locally, so size
# the buffer to the whole training set for a uniformly random epoch order.
train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(x_train.shape[0])
    .batch(100)
)
test_ds = tf.data.Dataset.from_tensor_slices(x_test).batch(100)

### 1) Train VAE with low z dimension

In [None]:
vae2 = VAE((28, 28), 2)
optimizer = tf.keras.optimizers.Adam()

# Maximise the ELBO for 100 epochs, writing a checkpoint every fifth epoch.
for epoch_i in trange(100):
  for x in train_ds:
    with tf.GradientTape() as tape:
      # Encode, draw a latent sample from Normal(z_mu, z_sigma), decode.
      z_mu, z_sigma = vae2.encode(x)
      z = tfp.distributions.Normal(loc=z_mu, scale=z_sigma).sample()
      x_reconst = vae2.decode(z)

      # The optimizer minimises, so the loss is the negative ELBO.
      elbo = compute_elbo(x, x_reconst, z_mu, z_sigma)
      loss = -elbo
    weights = vae2.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

  # epoch_i is zero-based, so epochs 5, 10, ... correspond to 4, 9, ...
  if epoch_i % 5 == 4:
    vae2.save_weights(f'ckpt2/{epoch_i + 1}')

### 2) Train VAE with high z dimension

In [None]:
vae10 = VAE((28, 28), 10)
optimizer = tf.keras.optimizers.Adam()

# Maximise the ELBO for 100 epochs, writing a checkpoint every fifth epoch.
for epoch_i in trange(100):
  for x in train_ds:
    with tf.GradientTape() as tape:
      # Encode, draw a latent sample from Normal(z_mu, z_sigma), decode.
      z_mu, z_sigma = vae10.encode(x)
      z = tfp.distributions.Normal(loc=z_mu, scale=z_sigma).sample()
      x_reconst = vae10.decode(z)

      # The optimizer minimises, so the loss is the negative ELBO.
      elbo = compute_elbo(x, x_reconst, z_mu, z_sigma)
      loss = -elbo
    weights = vae10.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

  # epoch_i is zero-based, so epochs 5, 10, ... correspond to 4, 9, ...
  if epoch_i % 5 == 4:
    vae10.save_weights(f'ckpt10/{epoch_i + 1}')

### 3) Compare decoded images

In [None]:
# Index of the test image to reconstruct with both models.
idx = 3

# Restore the weights saved after the final (100th) training epoch.
vae2.load_weights('ckpt2/100')
vae10.load_weights('ckpt10/100')

# Encode only the image being shown (as a batch of one) instead of pushing
# the whole test set through each model.
sample = x_test[idx][tf.newaxis, ...]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
ax1.imshow(sample[0], cmap='gray')
ax1.set_title('original')

# Decode the latent mean rather than a random sample so the reconstruction
# is deterministic.
z_mu, _ = vae2.encode(sample)
x_reconst = vae2.decode(z_mu)
ax2.imshow(x_reconst[0], cmap='gray')
ax2.set_title(f'z_dim = {vae2.z_dim}')

z_mu, _ = vae10.encode(sample)
x_reconst = vae10.decode(z_mu)
ax3.imshow(x_reconst[0], cmap='gray')
ax3.set_title(f'z_dim = {vae10.z_dim}')

for ax in (ax1, ax2, ax3):
  ax.axis('off')
plt.show()