In [1]:
import os
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as L
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from lfw_dataset import fetch_lfw_dataset

# Load the LFW face images and their accompanying attributes. The array layout
# and value range are defined by the project-local lfw_dataset module
# (presumably uint8 pixels indexed as (image, height, width, channel) — TODO confirm).
(data, attrs) = fetch_lfw_dataset()

In [None]:
# Flatten each image into one feature vector and rescale to [0, 1]
# (assumes 8-bit pixel values in [0, 255] — confirm against lfw_dataset).
# The first N_TRAIN images are used for training, the remainder for validation.
# NOTE(review): the split is positional, not shuffled; if fetch_lfw_dataset returns
# images grouped by person, validation faces may come from identities unseen in
# training — confirm this is intended.
N_TRAIN = 10000
# Computed once from the per-image shape so both splits are flattened identically,
# and so the reshape does not assume the dataset holds at least N_TRAIN images.
n_features = int(np.prod(data.shape[1:]))
X_train = data[:N_TRAIN].reshape((-1, n_features)).astype(np.float32) / 255.0
X_val = data[N_TRAIN:].reshape((-1, n_features)).astype(np.float32) / 255.0

In [None]:
# Spatial size of each face image: axes 1 and 2 of the raw `data` array
# (presumably kept for reshaping flat decoder outputs back into pictures).
image_h, image_w = data.shape[1], data.shape[2]

In [None]:
# Shape of one flattened training example (X_train is 2-D after the reshape above).
input_shape = X_train.shape[1:]
# Dimensionality of the latent code z. It should be a genuine bottleneck: well below
# the 256-unit hidden layers of the encoder/decoder that project into and out of it.
# A value such as 6075 (wider than those layers, and presumably equal to the flattened
# image size) leaves the model with no compressed representation to learn, and makes
# the KL term and prior sampling operate over thousands of mostly unused dimensions.
latent_dim = 100

In [None]:
# Encoder q(z|x): flattened image -> parameters of a diagonal Gaussian over z.
# Two linear heads share one hidden layer: the mean and the log-variance.
encoder_inputs = tf.keras.Input(shape=input_shape)
encoder_hidden = L.Dense(units=256, activation='relu')(encoder_inputs)
z_mean = L.Dense(units=latent_dim)(encoder_hidden)
z_logvar = L.Dense(units=latent_dim)(encoder_hidden)
encoder = tf.keras.Model(inputs=encoder_inputs, outputs=[z_mean, z_logvar])

# Decoder p(x|z): latent code -> flattened image. The sigmoid keeps every output
# in (0, 1), matching the /255-scaled inputs.
decoder_inputs = tf.keras.Input(shape=(latent_dim,))
decoder_hidden = L.Dense(units=256, activation='relu')(decoder_inputs)
decoder_outputs = L.Dense(units=np.prod(input_shape), activation='sigmoid')(decoder_hidden)
decoder = tf.keras.Model(inputs=decoder_inputs, outputs=decoder_outputs)

In [None]:
class VAE(tf.keras.Model):
    """Variational autoencoder built from separate encoder and decoder models.

    ``encoder(x)`` must return ``[mu, logsigma]``. Throughout this class ``logsigma``
    is treated as the *log-variance* of the approximate posterior q(z|x): the sampler
    uses ``exp(logsigma / 2)`` as the standard deviation and the KL term uses
    ``exp(logsigma)`` as the variance.

    Losses are computed in ``train_step`` / ``test_step`` rather than through
    ``compile(loss=...)``. They are reported as epoch averages under the keys
    ``loss``, ``kl_loss`` and ``recon_loss``; validation values get Keras' ``val_`` prefix.
    """

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Running means, so that logged values are averaged over the epoch
        # instead of being the values of the most recent batch.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.recon_loss_tracker = tf.keras.metrics.Mean(name="recon_loss")

    @property
    def metrics(self):
        """Trackers Keras resets at the start of each epoch and each evaluation."""
        return [self.total_loss_tracker, self.kl_loss_tracker, self.recon_loss_tracker]

    def KL_divergence(self, mu, logsigma):
        """Closed-form KL( N(mu, exp(logsigma)) || N(0, I) ).

        Summed over latent dimensions (axis 1), then averaged over the batch.
        """
        kl_loss = -0.5 * tf.reduce_sum(1 + logsigma - tf.square(mu) - tf.exp(logsigma), axis=1)
        return tf.reduce_mean(kl_loss)

    def log_likelihood(self, x, z):
        """Reconstruction loss of ``x`` given latent samples ``z``.

        Despite its name, this returns a quantity to *minimise*. It is the per-example
        sum of squared errors between ``x`` and ``decoder(z)``, averaged over the batch.
        Up to scale and an additive constant, this is the negative log-likelihood
        under a fixed-variance Gaussian decoder.
        """
        recon_loss = tf.reduce_sum(tf.square(x - self.decoder(z)), axis=1)
        return tf.reduce_mean(recon_loss)

    @staticmethod
    def _unpack_inputs(data):
        """Return the input batch from whatever Keras hands to a train/test step."""
        # fit(x, y) and validation_data=(x, y) deliver an (x, y[, sample_weight])
        # tuple. The autoencoder reconstructs x itself and ignores any target.
        if isinstance(data, tuple):
            return data[0]
        return data

    def _compute_losses(self, x):
        """Encode, sample and decode ``x``; return (total, kl, recon) losses."""
        mu, logsigma = self.encoder(x)
        z = self.gaussian_sampler(mu, logsigma)
        recon_loss = self.log_likelihood(x, z)
        kl_loss = self.KL_divergence(mu, logsigma)
        return recon_loss + kl_loss, kl_loss, recon_loss

    def _update_metrics(self, total_loss, kl_loss, recon_loss):
        """Accumulate one batch's losses and return the current running means."""
        self.total_loss_tracker.update_state(total_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.recon_loss_tracker.update_state(recon_loss)
        return {metric.name: metric.result() for metric in self.metrics}

    def train_step(self, x):
        """One optimisation step that minimises reconstruction loss + KL on a batch."""
        x = self._unpack_inputs(x)
        with tf.GradientTape() as tape:
            total_loss, kl_loss, recon_loss = self._compute_losses(x)

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._update_metrics(total_loss, kl_loss, recon_loss)

    def test_step(self, x):
        """Evaluate the training losses without updating weights (used for validation_data)."""
        x = self._unpack_inputs(x)
        total_loss, kl_loss, recon_loss = self._compute_losses(x)
        return self._update_metrics(total_loss, kl_loss, recon_loss)

    def call(self, x):
        """Reconstruct ``x`` with one stochastic pass: encoder -> sample z -> decoder."""
        mu, logsigma = self.encoder(x)
        z = self.gaussian_sampler(mu, logsigma)
        return self.decoder(z)

    def gaussian_sampler(self, mu, logsigma):
        """Reparameterised sample z = mu + exp(logsigma / 2) * eps, with eps ~ N(0, I)."""
        epsilon = tf.random.normal(shape=tf.shape(mu))
        return mu + tf.exp(logsigma / 2) * epsilon

In [None]:
# Assemble the VAE. Its losses are computed inside the model's own train_step,
# so compile() receives only an optimizer (Adam at its default learning rate).
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))

In [None]:
# Draw latent codes from the N(0, I) prior and decode them into flattened images.
# NOTE: this cell comes before the training cell, so on Restart & Run All these
# samples come from an untrained decoder. Re-run it after `vae.fit` to see faces.
num_samples = 30
# numpy's keyword is `size` (not `shape`). Cast to float32 to match the model's dtype.
z_samples = np.random.normal(size=(num_samples, latent_dim)).astype(np.float32)
generated_images = vae.decoder(z_samples)

In [None]:
# Train the VAE. No target array is passed: the model computes its loss from the
# inputs alone inside its custom train_step, which receives each batch from fit()
# as-is. With no `y`, that batch is just x, ready for the encoder.
history = vae.fit(X_train,
                  epochs=10,
                  shuffle=True,
                  validation_data=(X_val, X_val))