In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow_datasets as tfds
from blocks import Encoder, Decoder

<b>Load and Preprocess the dataset</b>

In [None]:
# Read the 'train' and 'test' splits of celeb_a from `data_dir`; download=False
# means the dataset must already be prepared there.
(ds_train, ds_test_), ds_info = tfds.load(
    'celeb_a',
    split=['train', 'test'],
    shuffle_files=True,
    with_info=True,
    download=False,
    data_dir='/data/',
)

In [None]:
# Number of samples per batch for both the training and test pipelines.
batch_size = 128

def preprocess(sample):
    """Turn one dataset sample into an (input, target) pair for the autoencoder.

    Reads ``sample['image']``, resizes it to 112x112, casts it to float32 and
    divides by 255. The same tensor is returned twice so that the image
    serves as both the model input and the reconstruction target.
    """
    resized = tf.image.resize(sample['image'], [112, 112])
    scaled = tf.cast(resized, tf.float32) / 255.
    return scaled, scaled

# Training pipeline: preprocess each sample, shuffle with a small buffer, then batch.
ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(batch_size*4)
# prefetch() counts dataset elements, and after batch() one element is a whole
# batch. Passing batch_size here would buffer up to `batch_size` batches in
# memory, so let tf.data choose the buffer size instead.
ds_train = ds_train.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Evaluation pipeline: same preprocessing and batching, no shuffle buffer.
ds_test = ds_test_.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Number of examples in each split, as reported by the dataset metadata.
train_num = ds_info.splits['train'].num_examples
test_num = ds_info.splits['test'].num_examples

<b>Build the VAE</b>

In [None]:
class VAE(Model):
    """Autoencoder model built from the project's Encoder and Decoder blocks.

    The encoder output is unpacked as ``(z, mean, logvar)``. The last two are
    stored on the instance during every forward pass so that code outside the
    model (e.g. a loss function) can read them afterwards.
    """

    def __init__(self, z_dim, name='VAE'):
        """Create the encoder/decoder pair; ``z_dim`` is forwarded to both."""
        super().__init__(name=name)
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        # Overwritten by call(); stay None until the first forward pass.
        self.mean = None
        self.logvar = None

    def call(self, x):
        """Encode ``x``, remember the encoder's mean/logvar, and return the decoded ``z``."""
        z, self.mean, self.logvar = self.encoder(x)
        return self.decoder(z)

In [None]:
# define the losses

def vae_kl_loss(y_true, y_pred):
    """KL regularisation term computed from the global ``vae``'s cached statistics.

    Evaluates ``-0.5 * mean(1 + logvar - mean**2 - exp(logvar))``, i.e. the
    closed-form KL expression for a diagonal Gaussian against a standard
    normal, averaged over every element of the tensors.

    ``y_true`` and ``y_pred`` are ignored; they exist only so the function has
    the two-argument signature Keras expects of losses and metrics.
    """
    mu, log_var = vae.mean, vae.logvar
    per_element = 1 + log_var - tf.square(mu) - tf.exp(log_var)
    return - 0.5 * tf.reduce_mean(per_element)

def vae_rc_loss(y_true, y_pred):
    """Reconstruction term: Keras mean squared error between target and output."""
    return tf.keras.losses.MSE(y_true, y_pred)

def vae_loss(y_true, y_pred):
    """Total training loss: reconstruction error plus a down-weighted KL term.

    Returns ``0.01 * vae_kl_loss(...) + vae_rc_loss(...)``.
    """
    kl_weight_const = 0.01
    reconstruction = vae_rc_loss(y_true, y_pred)
    divergence = vae_kl_loss(y_true, y_pred)
    return kl_weight_const*divergence + reconstruction

<b>Instantiate and train the model</b>

In [None]:
# z_dim=200 is forwarded to both Encoder and Decoder (presumably the latent size — confirm in blocks.py).
# vae_kl_loss reads `vae.mean` / `vae.logvar` from this module-level instance, so the name must stay `vae`.
vae = VAE(z_dim=200)

In [None]:
model_path = './models/celeb_a_vae.h5'

# Select the best weights and decide when to stop on the *validation*
# reconstruction loss. Keras logs validation metrics under a "val_" prefix, and
# fit() is given validation_data, so "val_vae_rc_loss" is available. Monitoring
# the training metric would reward overfitting. Lower is better, so mode="min".
checkpoint = ModelCheckpoint(model_path, monitor="val_vae_rc_loss", verbose=1, save_best_only=True,
                             mode="min", save_weights_only=True)

# Stop once the monitored loss has not improved for 4 consecutive epochs.
early = EarlyStopping(monitor="val_vae_rc_loss", mode="min", patience=4)

callbacks_list = [checkpoint, early]

initial_lr = 1e-3

# The training dataset is batched without drop_remainder, so the final partial
# batch is kept and one epoch is ceil(N / batch_size) steps. Rounding to the
# nearest integer undercounts by one whenever the remainder is below half a batch.
steps_per_epoch = int(np.ceil(train_num / batch_size))

# Staircase decay: the learning rate is multiplied by 0.96 once every
# `steps_per_epoch` optimizer steps, i.e. once per epoch.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_lr, decay_steps=steps_per_epoch, decay_rate=0.96, staircase=True)

In [None]:
# Drive RMSprop with the `lr_schedule` defined in the previous cell (starts at
# `initial_lr` and decays once per epoch) rather than a separate hard-coded
# constant, so the schedule configured there actually takes effect.
# The KL and reconstruction terms are also reported individually as metrics.
vae.compile(loss=[vae_loss], 
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=lr_schedule),
            metrics=[vae_kl_loss, vae_rc_loss])

In [None]:
# Train for 2 epochs, evaluating on ds_test after each epoch.
history = vae.fit(
    ds_train,
    validation_data=ds_test,
    epochs=2,
    callbacks=callbacks_list,
)