In [6]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from lfw_dataset import fetch_lfw_dataset

data, attrs = fetch_lfw_dataset()

In [7]:
X_train = data[:10000].reshape((10000, -1))
print(X_train.shape)
X_val = data[10000:].reshape((-1, X_train.shape[1]))
print(X_val.shape)

image_h = data.shape[1]
image_w = data.shape[2]

(10000, 6075)
(3143, 6075)


In [8]:
X_train = np.float32(X_train)
X_train = X_train/255
X_val = np.float32(X_val)
X_val = X_val/255

In [10]:
def plot_gallery(images, h, w, n_row=3, n_col=6):
    """Helper function to plot a gallery of portraits"""
    plt.figure(figsize=(1.5 * n_col, 1.7 * n_row))
    plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
    for i in range(n_row * n_col):
        plt.subplot(n_row, n_col, i + 1)
        plt.imshow(images[i].reshape((h, w, 3)), cmap=plt.cm.gray, vmin=-1, vmax=1, interpolation='nearest')
        plt.xticks(())
        plt.yticks(())

In [None]:
plot_gallery(X_train, image_h, image_w)

In [23]:
class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    @tf.function
    def KL_divergence(self, mu, logsigma):
        kl_loss = -0.5 * tf.reduce_sum(1 + logsigma - tf.square(mu) - tf.exp(logsigma), axis=1)
        return tf.reduce_mean(kl_loss)

    @tf.function
    def log_likelihood(self, x, z):
        x_decoded = self.decoder(z)
        log_likelihood = tf.reduce_sum(tf.keras.losses.binary_crossentropy(x, x_decoded), axis=1)
        return tf.reduce_mean(log_likelihood)


    def train_step(self, data):
        x = data

        with tf.GradientTape() as tape:
            tape.watch(x)
            z_mean, z_logsigma = self.encoder(x)
            z = self.reparameterize(z_mean, z_logsigma)
            x_decoded_mean = self.decoder(z)
            kl_loss = self.KL_divergence(z_mean, z_logsigma)
            log_likelihood = self.log_likelihood(x, x_decoded_mean)
            total_loss = kl_loss + log_likelihood

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {
            "total_loss": total_loss,
            "kl_loss": kl_loss,
            "log_likelihood": log_likelihood
        }

    def call(self, inputs):
        return self.encoder(inputs)


In [12]:
class DisplayCallback(tf.keras.callbacks.Callback):
    def __init__(self, model, rate):
        super(DisplayCallback, self).__init__()
        self.model = model
        self.rate = rate

    def on_epoch_end(self, epoch, logs=None):
        model = self.model
        if epoch % self.rate == 0:
            idx = np.random.choice(X_train.shape[0])
            plt.subplot(221)
            plt.imshow(X_train[idx].reshape(
                (image_h, image_w, 3)
            ))
            plt.subplot(222)
            plt.imshow(tf.reshape(
                model(X_train[tf.newaxis, idx]), (image_h, image_w, 3)
            ))
            idx = np.random.choice(X_val.shape[0])
            plt.subplot(223)
            plt.imshow(X_val[idx].reshape(
                (image_h, image_w, 3)
            ))
            plt.subplot(224)
            plt.imshow(tf.reshape(
                model(X_val[tf.newaxis, idx]), (image_h, image_w, 3)
            ))
            plt.show()

In [13]:
encoder = tf.keras.Sequential()
decoder = tf.keras.Sequential()
encoder.add(tf.keras.layers.Dense(64, activation='relu', kernel_initializer= tf.keras.initializers.GlorotUniform()))
encoder.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer= tf.keras.initializers.GlorotUniform()))
decoder.add(tf.keras.layers.Dense(128, activation='relu', kernel_initializer= tf.keras.initializers.GlorotUniform()))
decoder.add(tf.keras.layers.Dense(64, activation='sigmoid', kernel_initializer= tf.keras.initializers.GlorotUniform()))

In [14]:
vae = VAE(encoder, decoder)
vae.compile(optimizer=tf.keras.optimizers.Adam())
callback = DisplayCallback(vae, 2)

In [24]:
history = vae.fit(
    X_train, X_train,
    epochs=30,
    validation_data = (X_val, X_val),
    callbacks=[callback]
                  )

Epoch 1/30


TypeError: ignored