<a href="https://colab.research.google.com/github/HSE-LAMBDA/MLDM-2021/blob/master/12-vae/MLDM_2021_seminar12_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Datasets

In [None]:
import numpy as np
import tensorflow as tf
print(tf.__version__)
import tensorflow_datasets as tfds
from tensorflow import keras
from tqdm import tqdm

from PIL import Image


In [None]:
# Look the LFW builder up by its registered dataset name. `tfds.builder` is
# the documented public entry point and does not depend on how the
# tensorflow_datasets package lays out its dataset modules internally.
lfw = tfds.builder('lfw')
# Fetch and preprocess the dataset files.
lfw.download_and_prepare()
# With no `split` argument the result is indexed by split name; the cell
# below reads ds['train'].
ds = lfw.as_dataset()

In [None]:
def get_img(x):
  """Return the central crop of a dataset example's image.

  Reads `x['image']` and strips 80 entries from both ends of its first two
  axes; any remaining axes are left untouched.
  """
  border = slice(80, -80)
  return x['image'][border, border]

def _to_thumbnail(img):
  """Convert one cropped image tensor to a 36x36 numpy array via PIL."""
  return np.array(Image.fromarray(img.numpy()).resize((36, 36)))

# Crop every training example with get_img, shrink it to 36x36 and stack the
# results into a single array (tqdm shows the progress).
data = np.array(list(map(_to_thumbnail, tqdm(ds['train'].map(get_img)))))

In [None]:
# Sanity check of the array built above: every image was resized to 36x36, so
# the expected shape is (n_images, 36, 36, n_channels).
data.shape

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Preview the first 25 faces as a 5x5 mosaic:
# (row, col, h, w, c) -> (row, h, col, w, c) -> one (5*36, 5*36, 3) image.
first_faces = data[:25].reshape(5, 5, 36, 36, 3)
mosaic = first_faces.transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3)
plt.imshow(mosaic);

In [None]:
# Alternative configuration: uncomment these lines to train on the LFW faces
# prepared above (36x36 images, 3 channels) instead of MNIST. The MNIST cell
# below assigns the same variables, so skip it when using this configuration.
# dgts = False
# X_train = data.astype('float32') / 255
# print(X_train.min(), X_train.max(), X_train.dtype)
# image_size = 36
# n_channels = 3

For the visualizations below it is better to use the MNIST dataset; however, you may pick any data you prefer.

In [None]:
# MNIST configuration: 28x28 images with a single channel.
dgts = True
image_size = 28
n_channels = 1

(x_train, y_train), (x_test, _) = keras.datasets.mnist.load_data()
# Pool the train and test images into one array.
mnist_digits = np.concatenate((x_train, x_test))
# Append a trailing channel axis and divide by 255 (pixel values are
# presumably 8-bit, which maps them to [0, 1]).
X_train = mnist_digits[..., np.newaxis].astype("float32") / 255

# VAE

In [None]:
ll = tf.keras.layers

# Two latent dimensions for digits (the latent space is plotted as z[0] vs
# z[1] further down), 32 otherwise.
LATENT_DIM = 2 if dgts else 32

img_shape = (image_size, image_size, n_channels)
n_pixels = image_size * image_size * n_channels

# Decoder: latent vector -> two hidden ReLU layers -> per-pixel sigmoid,
# reshaped back into an image tensor.
decoder = tf.keras.Sequential([
  ll.Dense(128, input_shape=(LATENT_DIM,), activation='relu'),
  ll.Dense(128, activation='relu'),
  ll.Dense(n_pixels, activation='sigmoid'),
  ll.Reshape(img_shape),
])

# Encoder trunk: flatten the image and apply two hidden ReLU layers.
encoder_base = tf.keras.Sequential([
  ll.Reshape((n_pixels,), input_shape=img_shape),
  ll.Dense(128, activation='relu'),
  ll.Dense(128, activation='relu'),
])

# Two linear heads on top of the shared trunk; `forward` interprets them as
# the mean and the log standard deviation of the latent Gaussian.
trunk_out = encoder_base.output
latent_mu = ll.Dense(LATENT_DIM)(trunk_out)
latent_logsigma = ll.Dense(LATENT_DIM)(trunk_out)
encoder = tf.keras.Model(inputs=encoder_base.inputs, outputs=[latent_mu, latent_logsigma])

for net in (decoder, encoder):
  net.summary()

In [None]:
def gen_images(mu, logsigma):
  """Decode latent codes sampled around `mu`.

  Draws standard-normal noise with the same shape as `mu`, scales it by
  exp(logsigma), shifts it by `mu` and feeds the result to `decoder`.

  Args:
    mu: tensor of latent means.
    logsigma: tensor of log standard deviations, broadcastable against `mu`.

  Returns:
    The output of `decoder` for the sampled latent codes.
  """
  # tf.shape() is evaluated at run time, so this also works inside a
  # tf.function whose batch dimension is not statically known (in that case
  # mu.shape would contain None and could not be used as a sample shape).
  noise = tf.random.normal(shape=tf.shape(mu))
  return decoder(noise * tf.exp(logsigma) + mu)

# @tf.function decorator below compiles the function
# it decorates into a static graph. This improves the performance
# but there are some pitfalls one should be aware of when using it,
# check out https://www.tensorflow.org/guide/function
# for more details
@tf.function
def forward(batch):
  """Compute the VAE training loss for one batch of images.

  The batch is encoded, a latent code is sampled and decoded by
  `gen_images`, and the result is the batch mean of the per-sample squared
  reconstruction error plus 0.3 times the per-sample KL term.
  """
  mu, logsigma = encoder(batch)
  reconstruction = gen_images(mu, logsigma)

  # Squared error, summed over the height, width and channel axes.
  sq_err = (batch - reconstruction)**2
  loss_mse = tf.reduce_sum(sq_err, axis=(1, 2, 3))

  # KL divergence between N(mu, sigma^2) and N(0, 1) with sigma = exp(logsigma),
  # summed over the latent dimensions.
  kl_per_dim = -logsigma + 0.5 * (mu**2 + tf.exp(2 * logsigma) - 1)
  loss_KL = tf.reduce_sum(kl_per_dim, axis=1)

  return tf.reduce_mean(loss_mse + 0.3 * loss_KL)

# A single Adam optimizer updates the encoder and the decoder together.
opt_g = tf.optimizers.Adam()

@tf.function
def gen_step(batch):
  """Run one optimization step on `batch` and return the loss.

  The loss from `forward` is differentiated with respect to all trainable
  variables of the encoder and the decoder, and `opt_g` applies the update.
  """
  trainable = encoder.trainable_variables + decoder.trainable_variables
  with tf.GradientTape() as tape:
    g_loss = forward(batch)
  gradients = tape.gradient(g_loss, trainable)
  opt_g.apply_gradients(zip(gradients, trainable))
  return g_loss


In [None]:
from IPython.display import clear_output
from tqdm import trange

In [None]:
def plot_mn(images, m=5, n=5, shuffle=False):
    """Show `m * n` images as a mosaic with `m` rows and `n` columns.

    Args:
        images: array of shape (N, h, w, c) with N >= m * n.
        m: number of rows in the mosaic.
        n: number of columns in the mosaic.
        shuffle: if True, show a random selection of images instead of the
            first `m * n`.
    """
    if shuffle:
        images = images[np.random.permutation(len(images))[:m * n]]
    _, h, w, c = images.shape
    tiles = images[:m * n].reshape(m, n, h, w, c)
    # (row, col, h, w, c) -> (row, h, col, w, c) -> (m*h, n*w, c): keeping the
    # channel axis lets the mosaic work for multi-channel images as well.
    mosaic = tiles.transpose(0, 2, 1, 3, 4).reshape(m * h, n * w, c)
    if c == 1:
        # Single-channel images are drawn from a 2-D array, as before.
        mosaic = mosaic[..., 0]
    plt.imshow(mosaic)

In [None]:
BATCH_SIZE = 256

N_EPOCHS = 10 if dgts else 50

# Mean training loss of every finished epoch, for the loss curve.
losses = []
for i_ep in range(N_EPOCHS):
  # Fresh random ordering of the samples for this epoch.
  shuffle_ids = np.random.choice(len(X_train), len(X_train), replace=False)
  epoch_loss = 0
  for i_img in trange(0, len(X_train), BATCH_SIZE):
    # Slice the ids first, then index: this gathers only the current batch
    # instead of building a shuffled copy of the whole dataset per step.
    batch = X_train[shuffle_ids[i_img:i_img + BATCH_SIZE]]
    # Weight by the batch size so the (possibly smaller) last batch
    # contributes proportionally to the epoch mean.
    epoch_loss += gen_step(batch).numpy() * len(batch)

  epoch_loss /= len(X_train)
  losses.append(epoch_loss)

  # Decay the learning rate by 1% after every epoch.
  opt_g.learning_rate.assign(opt_g.learning_rate * 0.99)

  # mu = 0 and logsigma = 0 make gen_images decode pure standard-normal noise.
  imgs = (gen_images(tf.zeros(shape=(25, LATENT_DIM)),
                    tf.zeros(shape=(25, LATENT_DIM))).numpy() * 255).astype('uint8')
  clear_output(wait=True)
  plt.figure(figsize=(12, 7))
  # Left panel: 5x5 mosaic of generated samples.
  plt.subplot(1, 2, 1)
  if dgts:
    plot_mn(imgs)
  else:
    plt.imshow(imgs.reshape((5, 5, 36, 36, 3)).transpose(0, 2, 1, 3, 4).reshape(36 * 5, 36 * 5, 3))
  # Right panel: loss per epoch on a log scale.
  plt.subplot(1, 2, 2)
  plt.plot(losses)
  plt.yscale('log')
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.show()
  print("Done with epoch #", i_ep)

# Visualization

In [None]:
# The encoder has two outputs, so `codes` is [latent_mu, latent_logsigma].
codes = encoder.predict(X_train)
# Reconstruct from the latent means only (codes[0]); no noise is sampled here.
reco = decoder.predict(codes[0])

In [None]:
# Random ordering of the samples; only its first 25 entries are displayed.
# The ids are sliced before indexing so that just those 25 images are
# gathered instead of a shuffled copy of each full array.
shuffle_ids = np.random.choice(len(X_train), len(X_train), replace=False)

plt.figure(figsize=(12, 6), dpi=100)
if dgts:
  # Left: original digits, right: their reconstructions (same indices).
  plt.subplot(1, 2, 1)
  plt.imshow(X_train[shuffle_ids[:25]].reshape(5, 5, 28, 28, 1).transpose((0, 2, 1, 3, 4)).reshape(5 * 28, 5 * 28));
  plt.title('Train')

  plt.subplot(1, 2, 2)
  plt.imshow(reco[shuffle_ids[:25]].reshape(5, 5, 28, 28, 1).transpose((0, 2, 1, 3, 4)).reshape(5 * 28, 5 * 28));
  plt.title('Reconstructed');

else:
  # Same comparison for the 36x36, 3-channel images held in `data`.
  plt.subplot(1, 2, 1)
  plt.imshow(data[shuffle_ids[:25]].reshape(5, 5, 36, 36, 3).transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3));
  plt.title('Train')

  plt.subplot(1, 2, 2)
  plt.imshow(reco[shuffle_ids[:25]].reshape(5, 5, 36, 36, 3).transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3));
  plt.title('Reconstructed');

In [None]:
if dgts:
  # Encode the MNIST training images (the ones y_train has labels for);
  # only the latent means are plotted.
  train_images = x_train[..., np.newaxis].astype("float32") / 255
  z_mean, _ = encoder.predict(train_images)
  plt.figure(figsize=(10, 10))
  # One point per image at its latent mean, coloured by label with a
  # 10-level colormap.
  plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_train, cmap=plt.get_cmap('jet', 10))
  plt.colorbar()
  plt.xlabel("z[0]")
  plt.ylabel("z[1]")
  plt.show()


In [None]:
# display a n*n 2D manifold of digits: decode a regular grid of latent
# points and tile the decoded images into one picture
coord = 1.
n = 10
# linearly spaced coordinates corresponding to the 2D plot
# of digit classes in the latent space
grid_x = np.linspace(-coord, coord, n)
# reversed, so the first image row corresponds to the largest z[1]
grid_y = np.linspace(-coord, coord, n)[::-1]

# Decode all n*n latent points with a single predict() call instead of one
# call per point. mesh_x[i, j] == grid_x[j] and mesh_y[i, j] == grid_y[i],
# so flattening row by row keeps (row i, column j) at index i * n + j.
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
z_grid = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)
digits = decoder.predict(z_grid).reshape(n, n, image_size, image_size)
# (row, col, h, w) -> (row, h, col, w) -> one (n*h, n*w) image
figure = digits.transpose(0, 2, 1, 3).reshape(image_size * n, image_size * n)

plt.figure(figsize=(15, 15))
# put one tick at the centre of every tile, labelled with its latent coordinate
start_range = image_size // 2
end_range = n * image_size + start_range
pixel_range = np.arange(start_range, end_range, image_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap="Greys_r")
plt.show()

Repeat the training procedure, but try decreasing the weight of the KL loss (the `0.3` factor in `forward`), then compare the resulting latent space with the one above.

In [None]:
if dgts:
  # Same latent-mean scatter plot as before, computed with the current
  # encoder weights.
  scaled_train = x_train[..., np.newaxis].astype("float32") / 255
  low_w_zmean, _ = encoder.predict(scaled_train)
  plt.figure(figsize=(10, 10))
  # Points are coloured by label with a 10-level colormap.
  plt.scatter(low_w_zmean[:, 0], low_w_zmean[:, 1], c=y_train, cmap=plt.get_cmap('jet', 10))
  plt.colorbar()
  plt.xlabel("z[0]")
  plt.ylabel("z[1]")
  plt.show()
