__Variational autoencoder (VAE) on MNIST, adapted from the [Keras VAE example](https://keras.io/examples/generative/vae/)__

__WIP__

In [1]:
import os
import tensorflow as tf
from tensorflow import keras
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, ReLU, Dropout, Layer
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint,LearningRateScheduler, EarlyStopping
from keras.utils import plot_model
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

In [2]:
def load_mnist():
    """Load MNIST with images scaled to [0, 1] and a trailing channel axis.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The image arrays are
        float32 with a size-1 channel axis appended to the shape returned by
        ``mnist.load_data()``; the label arrays are passed through unchanged.
    """
    train_split, test_split = mnist.load_data()

    def _prepare(images):
        # Cast to float32, scale pixel values into [0, 1], then append a
        # channel dimension of size 1 so the arrays fit a Conv2D input.
        scaled = images.astype('float32') / 255.
        return scaled.reshape(scaled.shape + (1,))

    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (_prepare(train_images), train_labels), (_prepare(test_images), test_labels)

# Normalized MNIST splits (float32 in [0, 1], trailing channel axis) from load_mnist().
(x_train, y_train), (x_test, y_test) = load_mnist()

In [3]:
# Sanity check: the training images should now carry a trailing channel axis.
x_train.shape

(60000, 28, 28, 1)

In [4]:
# Reset the global state Keras keeps (e.g. auto-generated layer names)
# before building new models in this notebook.
K.clear_session()

In [5]:

# Reload MNIST in its raw layout (no channel axis) and scale pixels to [0, 1].
# NOTE(review): this duplicates load_mnist(); the float64 arrays produced here
# are not used by any of the cells visible in this notebook.
(trainX, trainy), (testX, testy) = mnist.load_data()

trainX, testX = trainX / 255, testX / 255

# Print the shape of each split as a quick sanity check.
print(f'Train: X={trainX.shape}, y={trainy.shape}')
print(f'Test: X={testX.shape}, y={testy.shape}')

Train: X=(60000, 28, 28), y=(60000,)
Test: X=(10000, 28, 28), y=(10000,)


In [6]:
class Sampling(Layer):
    """Draw the latent code z from (z_mean, z_log_var) via reparameterization.

    ``call`` receives ``[z_mean, z_log_var]`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon`` drawn from a
    standard normal, i.e. a sample of N(z_mean, exp(z_log_var)) that stays
    differentiable with respect to both inputs. The noise shape is taken from
    the first two dimensions of ``z_mean`` (assumed (batch, latent_dim)).
    """

    def call(self, inputs):
        mean, log_variance = inputs
        # The batch size is only known at run time, hence the dynamic tf.shape.
        dynamic_shape = tf.shape(mean)
        noise = tf.random.normal(shape=(dynamic_shape[0], dynamic_shape[1]))
        std_dev = tf.exp(0.5 * log_variance)
        return mean + std_dev * noise

class VAE(keras.Model):
    """Variational autoencoder trained through a custom ``train_step``.

    Args:
        encoder: model mapping a batch of inputs to ``(z_mean, z_log_var, z)``.
        decoder: model mapping a latent batch ``z`` to reconstructions that
            can be compared element-wise with the inputs.
        **kwargs: forwarded to ``keras.Model``.

    The optimized loss is ``reconstruction_loss + kl_loss`` where

    * ``reconstruction_loss`` is ``keras.losses.mean_squared_error`` (a mean
      over the last axis), summed over axes 1 and 2, then averaged over the
      batch;
    * ``kl_loss`` is the closed-form KL divergence of N(z_mean, exp(z_log_var))
      from N(0, I), summed over latent dimensions and averaged over the batch.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running averages of each loss term, reported by fit().
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset their state between
        # epochs instead of accumulating across the whole run.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Encode ``inputs``, sample a latent code and return its decoding.

        ``keras.Model`` provides no usable default ``call``; without this
        override ``model(x)`` and ``model.predict(x)`` raise
        ``NotImplementedError``. Training does not go through here:
        ``train_step`` calls the encoder and decoder directly because it also
        needs ``z_mean`` and ``z_log_var`` for the KL term.
        """
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one gradient step on a batch and return the running losses."""
        # fit(x, y) or tuple-yielding datasets deliver (x, y[, sample_weight]);
        # an autoencoder only reconstructs its inputs.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mean_squared_error(data, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

class VariationalAutoEncoder:
    """Builder for the convolutional encoder and decoder of a VAE.

    Construction creates two Keras functional models:

    * ``self.encoder``: input of shape ``input_dim`` -> ``[mu, log_var, z]``,
      where ``z`` is produced by the ``Sampling`` layer from ``mu`` and
      ``log_var``.
    * ``self.decoder``: latent vector of length ``latent_space`` -> output
      with ``_DEC_FILTERS[-1]`` channels passed through a sigmoid.

    The layer stacks are driven by the class-level tables below (index ``i``
    configures layer ``i``). With the defaults and ``padding='same'``, a
    (28, 28, 1) input is reduced to 7x7 before flattening (strides 1, 2, 2, 1)
    and the decoder upsamples it back to 28x28 with the mirrored strides.
    """

    # Encoder Conv2D filters and strides, one entry per layer.
    _ENC_FILTERS = [32, 64, 64, 64]
    _ENC_STRIDES = [1, 2, 2, 1]
    # Kernel sizes, shared by the encoder and decoder layer at the same index.
    _FILTER_SIZE = [3, 3, 3, 3]
    # Decoder Conv2DTranspose filters and strides; the last filter count is
    # the number of output channels.
    _DEC_FILTERS = [64, 64, 32, 1]
    _DEC_STRIDES = [1, 2, 2, 1]
    # Insert BatchNormalization after each hidden ReLU.
    _BATCH_NORM = True
    # NOTE(review): not read by _build -- no Dropout layers are added.
    _DROPOUT = 0.2

    def __init__(self, input_dim, latent_space, r_loss_factor=10, learning_rate=0.001):
        """Store the configuration and build ``self.encoder`` / ``self.decoder``.

        Args:
            input_dim: shape of one input sample without the batch axis,
                e.g. ``(28, 28, 1)``.
            latent_space: number of latent dimensions.
            r_loss_factor: reconstruction-loss weight; stored for callers,
                not read by ``_build``.
            learning_rate: stored for callers, not read by ``_build``.
        """
        self.input_dim = input_dim
        self.latent = latent_space
        # Keep the caller's value; the previous chained assignment
        # (``= r_loss_factor = 10``) silently forced it to 10.
        self.r_loss_factor = r_loss_factor
        self.learning_rate = learning_rate
        self._build()

    def _build(self):
        """Create ``self.encoder`` and then the matching ``self.decoder``."""
        shape_before_flattening = self._build_encoder()
        self._build_decoder(shape_before_flattening)

    def _build_encoder(self):
        """Build ``self.encoder``; return the feature-map shape before Flatten."""
        encoder_input = Input(shape=tuple(self.input_dim), name='encoder_input')
        x = encoder_input

        for i, filters in enumerate(self._ENC_FILTERS):
            x = Conv2D(filters=filters,
                       kernel_size=self._FILTER_SIZE[i],
                       strides=self._ENC_STRIDES[i],
                       name='encoder_conv' + str(i),
                       padding='same')(x)
            x = ReLU()(x)
            if self._BATCH_NORM:
                x = BatchNormalization()(x)

        # Needed by the decoder to undo the Flatten below.
        shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)

        mu = Dense(self.latent, name='mu')(x)
        log_var = Dense(self.latent, name='log_var')(x)
        z = Sampling(name='encoder_output')([mu, log_var])

        self.encoder = Model(encoder_input, [mu, log_var, z], name='encoder')
        return shape_before_flattening

    def _build_decoder(self, shape_before_flattening):
        """Build ``self.decoder`` mapping a latent vector back to input space."""
        # Keras expects a shape tuple here, not a bare int.
        decoder_input = Input(shape=(self.latent,), name='decoder_input')
        x = Dense(int(np.prod(shape_before_flattening)))(decoder_input)
        x = Reshape(shape_before_flattening)(x)

        last_index = len(self._DEC_FILTERS) - 1
        for i, filters in enumerate(self._DEC_FILTERS):
            x = Conv2DTranspose(filters=filters,
                                kernel_size=self._FILTER_SIZE[i],
                                strides=self._DEC_STRIDES[i],
                                name='decoder_conv_t' + str(i),
                                padding='same')(x)
            if i < last_index:
                x = ReLU()(x)
                if self._BATCH_NORM:
                    x = BatchNormalization()(x)
            else:
                # Sigmoid keeps outputs in (0, 1), matching pixels scaled by 1/255.
                x = Activation('sigmoid')(x)

        self.decoder = Model(decoder_input, x, name='Decoder')

In [10]:
vae = VariationalAutoEncoder(input_dim=(28, 28, 1), latent_space=2)

# Train on every MNIST digit: the training and test splits combined.
# Reuse load_mnist() so these images get exactly the same preprocessing as
# above (float32 scaled to [0, 1], trailing channel axis). This also keeps
# x_train / x_test holding those normalized arrays instead of silently
# rebinding them to the raw uint8 images from mnist.load_data().
(x_train, _), (x_test, _) = load_mnist()
mnist_digits = np.concatenate([x_train, x_test], axis=0)

In [11]:
# Wrap the encoder/decoder built above in the custom-train_step VAE model.
model = VAE(vae.encoder, vae.decoder)
# Honour the learning rate configured on the builder rather than ignoring it
# (it currently equals Adam's documented default of 0.001).
model.compile(optimizer=keras.optimizers.Adam(learning_rate=vae.learning_rate))
model.fit(mnist_digits, epochs=3, batch_size=16)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x2691f503610>