<a href="https://colab.research.google.com/github/HSE-LAMBDA/MLDM-2020/blob/master/day-12/MLDM_2020_seminar12_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
print(tf.__version__)
import tensorflow_datasets as tfds
from tqdm import tqdm

from PIL import Image

In [None]:
# Look the dataset up by its registered name instead of reaching for the
# builder class through a package path: tfds.builder() is the documented
# entry point and does not depend on where the class lives inside the
# tensorflow_datasets package.
lfw = tfds.builder('lfw')
# Download the archive and write the prepared dataset to the default data directory.
lfw.download_and_prepare()
# No split is requested here; `ds` is indexed by split name further down
# (ds['train']).
ds = lfw.as_dataset()

In [None]:
def get_img(x):
  """Return the image of example `x` with an 80-pixel border cropped away.

  Args:
    x: mapping with an 'image' entry that supports 2-D slicing on its
      first two axes.

  Returns:
    x['image'] with the first and last 80 rows and columns removed; any
    trailing axes are left untouched.
  """
  border = slice(80, -80)
  return x['image'][border, border]

def _shrink(face):
  """Resize one cropped face tensor to 36x36 with PIL and return a NumPy array."""
  return np.array(Image.fromarray(face.numpy()).resize((36, 36)))

# Crop every training example with get_img, shrink it, and stack the
# results into a single array (one entry per image along axis 0).
data = np.array([_shrink(face) for face in tqdm(ds['train'].map(get_img))])

In [None]:
# Display the shape of the stacked image array; the first axis indexes the
# images, each of which was resized to 36x36 when `data` was built.
data.shape

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Tile the first 25 faces into a single 5x5 mosaic image.
face_grid = (
  data[:25]
  .reshape(5, 5, 36, 36, 3)       # (grid row, grid col, y, x, channel)
  .transpose((0, 2, 1, 3, 4))     # (grid row, y, grid col, x, channel)
  .reshape(5 * 36, 5 * 36, 3)     # merge (grid row, y) and (grid col, x)
)
plt.imshow(face_grid);

In [None]:
# Cast the pixels to float32 and divide by 255 (maps 8-bit values to [0, 1],
# assuming `data` holds 8-bit pixels).
X_train = data.astype(np.float32) / 255
# Sanity check: value range and dtype of the training array.
print(X_train.min(), X_train.max(), X_train.dtype)

In [None]:
ll = tf.keras.layers

LATENT_DIM = 32

# Shapes and sizes shared by the encoder and the decoder.
IMG_SHAPE = (36, 36, 3)
N_PIXEL_VALUES = 36 * 36 * 3
HIDDEN_UNITS = 128

# Decoder: latent vector -> two ReLU layers -> sigmoid outputs reshaped
# into an image of IMG_SHAPE.
decoder = tf.keras.Sequential([
  ll.Dense(HIDDEN_UNITS, input_shape=(LATENT_DIM,), activation='relu'),
  ll.Dense(HIDDEN_UNITS, activation='relu'),
  ll.Dense(N_PIXEL_VALUES, activation='sigmoid'),
  ll.Reshape(IMG_SHAPE),
])

# Encoder trunk: flatten the image, then two ReLU layers.
encoder_base = tf.keras.Sequential([
  ll.Reshape((N_PIXEL_VALUES,), input_shape=IMG_SHAPE),
  ll.Dense(HIDDEN_UNITS, activation='relu'),
  ll.Dense(HIDDEN_UNITS, activation='relu')
])

# Two linear heads on top of the shared trunk: the latent mean and the
# log of the latent standard deviation.
encoder_features = encoder_base.output
latent_mu = ll.Dense(LATENT_DIM, activation=None)(encoder_features)
latent_logsigma = ll.Dense(LATENT_DIM, activation=None)(encoder_features)
encoder = tf.keras.Model(
  inputs=encoder_base.inputs,
  outputs=[latent_mu, latent_logsigma],
)

decoder.summary()
encoder.summary()

In [None]:
def gen_images(mu, logsigma):
  """Sample latent codes via the reparameterization trick and decode them.

  Draws eps ~ N(0, I) and decodes z = mu + exp(logsigma) * eps.

  Args:
    mu: tensor of latent means.
    logsigma: tensor of log standard deviations, same shape as `mu`.

  Returns:
    The decoder output for the sampled latent codes.
  """
  # Use the runtime shape rather than the static `mu.shape`: inside a traced
  # tf.function the batch dimension may be unknown (None), and a partially
  # known static shape cannot be passed to tf.random.normal.
  eps = tf.random.normal(shape=tf.shape(mu))
  return decoder(eps * tf.exp(logsigma) + mu)

# The @tf.function decorator compiles the decorated function into a
# static graph. This improves performance, but it comes with pitfalls
# one should be aware of; see https://www.tensorflow.org/guide/function
# for details.
@tf.function
def forward(batch):
  """Compute the VAE training loss for one batch of images.

  The loss is the per-sample sum of squared reconstruction errors plus
  0.2 times the KL divergence between the encoder's diagonal Gaussian
  N(mu, exp(logsigma)^2) and the standard normal, averaged over the batch.

  Args:
    batch: batch of images accepted by `encoder`.

  Returns:
    Scalar tensor with the mean loss over the batch.
  """
  mu, logsigma = encoder(batch)
  reconstruction = gen_images(mu, logsigma)

  # Reconstruction term: squared error summed over height, width, channels.
  squared_error = (batch - reconstruction)**2
  loss_mse = tf.reduce_sum(squared_error, axis=(1, 2, 3))

  # KL term per latent dimension: -log(sigma) + (mu^2 + sigma^2 - 1) / 2.
  kl_per_dim = -logsigma + 0.5 * (mu**2 + tf.exp(2 * logsigma) - 1)
  loss_KL = tf.reduce_sum(kl_per_dim, axis=1)

  return tf.reduce_mean(loss_mse + 0.2 * loss_KL)

# One Adam optimizer (default hyperparameters) updates both networks.
opt_g = tf.optimizers.Adam()

@tf.function
def gen_step(batch):
  """Run one optimization step on `batch` and return the loss.

  Args:
    batch: batch of images passed on to `forward`.

  Returns:
    The loss computed by `forward` for this batch (before the update).
  """
  # Encoder and decoder are trained jointly, so collect both variable lists.
  trainable = encoder.trainable_variables + decoder.trainable_variables
  with tf.GradientTape() as tape:
    loss = forward(batch)
  gradients = tape.gradient(loss, trainable)
  opt_g.apply_gradients(zip(gradients, trainable))
  return loss


In [None]:
from IPython.display import clear_output
from tqdm import trange

In [None]:
BATCH_SIZE = 256

N_EPOCHS = 100

# Mean training loss of every finished epoch (plotted after each epoch).
losses = []
for i_ep in range(N_EPOCHS):
  # Fresh random order of the sample indices for this epoch.
  shuffle_ids = np.random.choice(len(X_train), len(X_train), replace=False)
  epoch_loss = 0
  for i_img in trange(0, len(X_train), BATCH_SIZE):
    # Slice the index permutation first and gather only this batch's rows.
    # Indexing X_train[shuffle_ids] and slicing afterwards yields the same
    # batch but copies the whole dataset on every single step.
    batch = X_train[shuffle_ids[i_img:i_img + BATCH_SIZE]]
    # Weight by the batch size so the (possibly smaller) last batch is
    # averaged correctly.
    epoch_loss += gen_step(batch).numpy() * len(batch)

  epoch_loss /= len(X_train)
  losses.append(epoch_loss)

  # Exponential learning-rate decay: 1% smaller after every epoch.
  opt_g.learning_rate.assign(opt_g.learning_rate * 0.99)

  # mu = 0 and logsigma = 0 make gen_images decode 25 draws from N(0, I).
  imgs = (gen_images(tf.zeros(shape=(25, LATENT_DIM)),
                    tf.zeros(shape=(25, LATENT_DIM))).numpy() * 255).astype('uint8')
  clear_output(wait=True)
  plt.figure(figsize=(12, 7))
  # Left panel: 5x5 mosaic of generated faces.
  plt.subplot(1, 2, 1)
  plt.imshow(imgs.reshape((5, 5, 36, 36, 3)).transpose(0, 2, 1, 3, 4).reshape(36 * 5, 36 * 5, 3))
  # Right panel: training loss per epoch on a log scale.
  plt.subplot(1, 2, 2)
  plt.plot(losses)
  plt.yscale('log')
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.show()
  print("Done with epoch #", i_ep)

In [None]:
# The encoder returns [latent_mu, latent_logsigma] for every training image.
codes = encoder.predict(X_train)
# Decode the latent means only (no sampling noise) to get reconstructions.
latent_means = codes[0]
reco = decoder.predict(latent_means)

In [None]:
shuffle_ids = np.random.choice(len(X_train), len(X_train), replace=False)
# Only the first 25 shuffled indices are shown. Selecting them before
# indexing avoids building shuffled copies of the full `data` and `reco`
# arrays; the displayed rows are the same.
show_ids = shuffle_ids[:25]

plt.figure(figsize=(12, 6), dpi=100)

# Left: 5x5 mosaic of the original training images.
plt.subplot(1, 2, 1)
plt.imshow(data[show_ids].reshape(5, 5, 36, 36, 3).transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3));
plt.title('Train')

# Right: the reconstructions of the same 25 images, in the same layout.
plt.subplot(1, 2, 2)
plt.imshow(reco[show_ids].reshape(5, 5, 36, 36, 3).transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3));
plt.title('Reconstructed');