# 1-2 Variational AutoEncoder

<img src="./img/vae.png" alt="variationalautoencoder" width="500" align="left"/>

In [None]:
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt
from matplotlib import gridspec as gridspec

In [None]:
# Directory that train() passes to VAE.fit for checkpoints / TensorBoard logs
# and that predict() reads the latest checkpoint from.
CKPT_DIR = "../generated_output/VAE"

In [None]:
# Optimisation settings used as the defaults of train().
LEARNING_RATE = 2e-4  # step size handed to the Adam optimizer
BATCH_SIZE = 128      # number of images per training batch
# Batches drawn per epoch; presumably ceil(60000 / 128) for the MNIST
# training set -- TODO confirm if the dataset or batch size changes.
STEPS_PER_EPOCH = 469
EPOCHS = 5            # number of passes of STEPS_PER_EPOCH batches

In [None]:
# Network architecture defaults used by VAE.__init__.
IMAGE_DIM = 28 * 28   # length of one flattened 28x28 image
LATENT_DIM = 10       # size of the latent vector
# Units of each hidden Dense layer, one list entry per layer.
# NOTE: "ENDOCER" is a misspelling of "ENCODER"; the name is kept as-is
# because VAE.__init__ refers to it.
ENDOCER_HIDDEN_DIM = [256]
DECODER_HIDDEN_DIM = [256]

<img src="./img/vae_recloss.png" alt="vae_recloss" width="400" align="top"/>
<img src="./img/vae_regloss.png" alt="vae_regloss" width="450" align="top"/>

In [None]:
class VAE():
    """Variational autoencoder made of dense layers, trained through tf.keras.

    The encoder maps a flat image vector to the mean and log-variance of a
    diagonal Gaussian over the latent space, a latent sample is drawn with the
    reparameterization trick, and the decoder maps that sample back to an
    image through a sigmoid output layer.

    Args:
        image_dim: Length of one flattened input image.
        latent_dim: Size of the latent vector.
        encoder_hidden_dim: Iterable with the unit count of each hidden
            encoder layer.
        decoder_hidden_dim: Iterable with the unit count of each hidden
            decoder layer.
    """

    def __init__(self, image_dim=IMAGE_DIM, latent_dim=LATENT_DIM, encoder_hidden_dim=ENDOCER_HIDDEN_DIM, decoder_hidden_dim=DECODER_HIDDEN_DIM):
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # Copy the layer specs so instances never share (or accidentally
        # mutate) the module-level default lists.
        self.encoder_hidden_dim = list(encoder_hidden_dim)
        self.decoder_hidden_dim = list(decoder_hidden_dim)
        # Both are populated by _set_model(); None means "not built yet".
        self.model = None
        self.vae_loss = None

    def _encoder_model(self, feature):
        """Return (latent_mean, latent_log_var) tensors computed from `feature`."""
        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            net = feature
            for units in self.encoder_hidden_dim:
                net = tf.layers.Dense(units, activation=tf.nn.relu, kernel_initializer=tf.initializers.he_normal())(net)
            # Two linear heads on top of the shared hidden stack.
            latent_mean = tf.layers.Dense(self.latent_dim, kernel_initializer=tf.initializers.he_normal())(net)
            latent_log_var = tf.layers.Dense(self.latent_dim, kernel_initializer=tf.initializers.he_normal())(net)
            return latent_mean, latent_log_var

    def _sampler_model(self, args):
        """Draw a latent sample from `args = (latent_mean, latent_log_var)`.

        Uses the reparameterization trick: mean + std * eps with
        eps ~ N(0, 1), so the sample stays differentiable w.r.t. the encoder
        outputs.
        """
        latent_mean, latent_log_var = args
        with tf.variable_scope('sampler', reuse=tf.AUTO_REUSE):
            snd_sample = tf.random_normal(tf.shape(latent_log_var), dtype=tf.float32, mean=0., stddev=1.0)
            # exp(log_var / 2) == sqrt(var) == standard deviation.
            latent_std = tf.exp(latent_log_var / 2)
            return latent_mean + latent_std * snd_sample

    def _decoder_model(self, feature):
        """Return the reconstructed image tensor for the latent tensor `feature`."""
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            net = feature
            for units in self.decoder_hidden_dim:
                net = tf.layers.Dense(units, activation=tf.nn.relu, kernel_initializer=tf.initializers.he_normal())(net)
            # Sigmoid keeps every output in (0, 1), as required by the
            # binary cross-entropy reconstruction loss.
            recon = tf.layers.Dense(self.image_dim, activation=tf.nn.sigmoid, kernel_initializer=tf.initializers.he_normal())(net)
            return recon

    def _vae_loss(self, inputs, outputs, latent_mean, latent_log_var):
        """Build the Keras loss callable for the given graph tensors.

        The returned function has the (y_true, y_pred) signature Keras
        expects but ignores both arguments: the loss is computed from the
        tensors captured here (`inputs`, `outputs`, `latent_mean`,
        `latent_log_var`).

        Returns:
            A callable returning the batch mean of reconstruction loss plus
            KL regularisation loss.
        """
        def loss_fn(y_true, y_pred):
            # Per-sample binary cross-entropy summed over the pixels.
            rec_loss = tf.reduce_sum(tf.keras.backend.binary_crossentropy(inputs, outputs), axis=1)
            # Closed-form KL divergence between N(mean, exp(log_var)) and
            # the standard normal, summed over the latent dimensions.
            reg_loss = -0.5 * tf.reduce_sum(1 + latent_log_var - tf.square(latent_mean) - tf.exp(latent_log_var), axis=1)
            return tf.reduce_mean(rec_loss + reg_loss)
        return loss_fn

    def _set_model(self):
        """Build a new Keras model and its loss; stores them on the instance.

        Sets `self.model` (input image -> reconstruction) and `self.vae_loss`.
        Any previously built model is replaced.
        """
        inputs = tf.keras.Input(shape=[self.image_dim])
        latent_mean, latent_log_var = self._encoder_model(inputs)
        sample = tf.keras.layers.Lambda(self._sampler_model)([latent_mean, latent_log_var])
        outputs = self._decoder_model(sample)
        self.model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
        self.vae_loss = self._vae_loss(inputs, outputs, latent_mean, latent_log_var)

    def fit(self, x, y, learning_rate, epochs, steps_per_epoch, ckpt_dir):
        """Build, compile and train the model, checkpointing every epoch.

        Args:
            x: Model inputs, forwarded to `Model.fit`.
            y: Targets, forwarded to `Model.fit` (the loss itself ignores them).
            learning_rate: Learning rate for the Adam optimizer.
            epochs: Number of epochs.
            steps_per_epoch: Number of batches per epoch.
            ckpt_dir: Directory receiving `cp-XXXX.ckpt` weight checkpoints
                and a `Graph` sub-directory with TensorBoard logs.
        """
        self.learning_rate = learning_rate
        # Create the checkpoint directory itself. The previous code created
        # only its parent (os.path.dirname(ckpt_dir)), leaving ckpt_dir
        # missing when the checkpoint callback first writes into it.
        if not os.path.isdir(ckpt_dir):
            os.makedirs(ckpt_dir)
        self._set_model()
        self.model.compile(optimizer=tf.keras.optimizers.Adam(lr=self.learning_rate), loss=self.vae_loss)
        self.model.summary()
        cp_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_dir+'/cp-{epoch:04d}.ckpt', verbose=1, period=1, save_weights_only=True)
        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=ckpt_dir+'/Graph', histogram_freq=0, write_graph=True, write_images=True)
        self.model.fit(x, y, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[cp_callback, tb_callback])

    def load_weights(self, ckpt_dir):
        """Build the model and restore the most recent checkpoint in `ckpt_dir`.

        Raises:
            IOError: If no checkpoint can be found in `ckpt_dir`.
        """
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest is None:
            # Fail with a clear message instead of passing None to Keras.
            raise IOError('No checkpoint found in directory: {}'.format(ckpt_dir))
        self._set_model()
        self.model.load_weights(latest)

    def predict(self, features):
        """Return the model's reconstruction for a batch of `features`.

        The model prepared by `fit` or `load_weights` is reused. A model is
        built here only when none exists yet; rebuilding unconditionally
        would replace `self.model` with newly constructed layers that are not
        guaranteed to carry the trained / restored weights.
        """
        if self.model is None:
            self._set_model()
        return self.model.predict(features)

    def batch(self, features, batch_size, is_training):
        """Return the `get_next()` op of a shuffled, batched (input, target) dataset.

        Every element pairs an input with itself as target. The dataset
        repeats indefinitely when `is_training` is truthy and is iterated
        exactly once otherwise.
        """
        self.batch_size = batch_size
        # repeat(count=None) repeats forever; count=1 is a single pass.
        count = None if is_training else 1
        dataset = tf.data.Dataset.from_tensor_slices((features, features))
        batch_dataset = dataset.shuffle(features.shape[0]).repeat(count=count).batch(self.batch_size)
        return batch_dataset.make_one_shot_iterator().get_next()

In [None]:
def train(features, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, ckpt_dir=CKPT_DIR):
    """Train a freshly constructed VAE on `features`.

    Args:
        features: Training data handed to `VAE.batch`, which uses it as both
            input and target.
        batch_size: Number of samples per batch.
        learning_rate: Learning rate forwarded to `VAE.fit`.
        epochs: Number of training epochs.
        steps_per_epoch: Number of batches per epoch.
        ckpt_dir: Directory forwarded to `VAE.fit` for its outputs.
    """
    model = VAE()
    inputs, targets = model.batch(features, batch_size, is_training=True)
    model.fit(
        x=inputs,
        y=targets,
        learning_rate=learning_rate,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        ckpt_dir=ckpt_dir,
    )

In [None]:
def predict(features):
    """Reconstruct one sample with a VAE restored from CKPT_DIR.

    A leading batch axis is added to `features` before it is passed to the
    model, so `features` is expected to be a single sample (the plotting loop
    passes one flattened test image). A new VAE is constructed and its weights
    are loaded on every call.

    Returns:
        Whatever `VAE.predict` returns for the one-element batch.
    """
    single_batch = np.expand_dims(features, axis=0)
    model = VAE()
    model.load_weights(CKPT_DIR)
    reconstruction = model.predict(single_batch)
    return reconstruction

In [None]:
def image_plot(true, recon):
    """Show an original image and its reconstruction side by side.

    Args:
        true: Array reshapeable to 28x28; drawn in the left panel.
        recon: Array reshapeable to 28x28; drawn in the right panel.
    """
    plt.figure(figsize=(6, 3))
    grid = gridspec.GridSpec(1, 2)
    grid.update(wspace=0.05)
    # Left panel = original, right panel = reconstruction.
    for column, image in enumerate((true, recon)):
        plt.subplot(grid[column])
        plt.axis('off')
        plt.imshow(image.reshape([28, 28]), cmap='gray_r')
    plt.show()

In [None]:
# Load MNIST (labels are discarded), scale the pixel values by 1/255,
# flatten every image to an IMAGE_DIM-long vector and cast to float32.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)
x_test = (x_test / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)

In [None]:
# Train the VAE on the flattened training images with the default
# hyperparameters; outputs go to CKPT_DIR.
train(x_train)

In [None]:
# Show 10 randomly chosen test images next to their reconstructions.
for i in range(10):
    # randint's upper bound is exclusive, so use the dataset size: every test
    # image (including the last one) can be drawn, and the index stays valid
    # if the test set size changes. The previous hard-coded 9999 could never
    # select the final image.
    j = np.random.randint(0, x_test.shape[0])
    image_plot(x_test[j], predict(x_test[j]))