In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import (Reshape, Conv2DTranspose, Add, Conv2D, Dense,
                                     Flatten, InputLayer, BatchNormalization, Input)
from tensorflow.keras.optimizers import Adam

## Data Preparation

In [3]:
# Load MNIST, discard the labels, and merge both splits into one array:
# the VAE is trained unsupervised, so no held-out labelled set is kept here.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()

# Stack the splits, append a trailing channel axis, and rescale the uint8
# pixel values to float32 in [0, 1].
mnist_digits = (
    np.concatenate([x_train, x_test], axis=0)[..., np.newaxis].astype("float32") / 255
)

In [4]:
# Wrap the image array in a tf.data.Dataset; slicing is along the first axis,
# so each dataset element is a single image.
dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)

In [5]:
# Number of images per batch produced by the input pipeline.
BATCH_SIZE = 128
# Dimensionality of the latent vector (units of the encoder's mean / log-variance heads
# and of the decoder's input).
LATENT_DIM = 2

In [7]:
# Input pipeline, built one stage at a time:
#   1. shuffle with a 1024-element buffer, reshuffled at the start of every pass
#   2. group into batches of BATCH_SIZE
#   3. prefetch so the next batch is prepared while the current one is consumed
train_dataset = dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

In [8]:
# Display the dataset's repr to check the batched element shape and dtype.
train_dataset

<PrefetchDataset shapes: (None, 28, 28, 1), types: tf.float32>

## Model

### Sampling

In [9]:
class Sampling(Layer):
    """Reparameterization layer for a VAE.

    Draws a latent sample ``z = mean + std * epsilon`` where
    ``std = exp(0.5 * log_var)`` and ``epsilon`` is standard-normal noise,
    so the sample stays differentiable with respect to ``mean`` and ``log_var``.
    """

    def call(self, inputs):
        """Sample a latent vector.

        Args:
            inputs: a pair ``(mean, log_var)`` of tensors with the same shape.

        Returns:
            A tensor shaped like ``mean`` holding one random sample per entry.
        """
        mean, log_var = inputs
        # Noise takes the full dynamic shape of `mean`, so the layer works for
        # any rank rather than only (batch, latent_dim) inputs.
        epsilon = tf.random.normal(shape=tf.shape(mean))
        # exp(0.5 * log_var) turns the log-variance into a standard deviation.
        return mean + tf.math.exp(0.5 * log_var) * epsilon

### Encoder

In [11]:
# Encoder: two stride-2 convolutions, a small dense bottleneck, then two
# parallel linear heads (mean and log-variance) feeding the Sampling layer.
encoder_inputs = Input(shape=(28, 28, 1))

x = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(encoder_inputs)
x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(x)

x = Flatten()(x)
x = Dense(units=16, activation='relu')(x)

# No activation on the heads: both outputs are unconstrained real values.
mean = Dense(units=LATENT_DIM)(x)
log_var = Dense(units=LATENT_DIM)(x)

z = Sampling()([mean, log_var])

# The model exposes the sample together with the distribution parameters,
# which the loss needs for its regularization term.
encoder_model = Model(inputs=encoder_inputs, outputs=[z, mean, log_var], name='encoder')
encoder_model.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 14, 14, 32)   320         ['input_2[0][0]']                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 7, 7, 64)     18496       ['conv2d[0][0]']                 
                                                                                                  
 flatten (Flatten)              (None, 3136)         0           ['conv2d_1[0][0]']               
                                                                                            

### Decoder

In [12]:
# Decoder: project the latent vector to a 7x7x64 feature map, upsample twice
# with stride-2 transposed convolutions, and emit one sigmoid channel in [0, 1].
# `shape` must be a tuple; `(LATENT_DIM)` without the trailing comma is just an int.
latent_inputs = Input(shape=(LATENT_DIM,))

x = Dense(7*7*64, activation='relu')(latent_inputs)
x = Reshape((7,7,64))(x)

x = Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same')(x)
x = Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)

# Final layer keeps the spatial size (default stride 1) and maps to a single channel.
decoder_output = Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(x)
decoder_model = Model(latent_inputs, decoder_output, name='decoder')
decoder_model.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_3 (Dense)             (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 64)       36928     
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 32)       18464     
 ranspose)                                                       
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 1)        289 

### Overall VAE Model

In [13]:
# End-to-end VAE: image -> encoder -> sampled latent -> decoder -> reconstruction.
vae_input = Input(name="vae_input", shape=(28, 28, 1))

# The encoder returns (z, mean, log_var); only the sample feeds the decoder here.
z, _, _ = encoder_model(vae_input)
output = decoder_model(z)

vae = Model(inputs=vae_input, outputs=output, name='vae')
vae.summary()

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vae_input (InputLayer)      [(None, 28, 28, 1)]       0         
                                                                 
 encoder (Functional)        [(None, 2),               69076     
                              (None, 2),                         
                              (None, 2)]                         
                                                                 
 decoder (Functional)        (None, 28, 28, 1)         65089     
                                                                 
Total params: 134,165
Trainable params: 134,165
Non-trainable params: 0
_________________________________________________________________


In [15]:
# Show the first three top-level layers of the composed model.
for layer in vae.layers[:3]:
    print(layer)

<keras.engine.input_layer.InputLayer object at 0x0000024581EEF160>
<keras.engine.functional.Functional object at 0x0000024581E91FA0>
<keras.engine.functional.Functional object at 0x0000024581EA34F0>


## Training

In [16]:
# Adam optimizer used by the custom training step to apply gradients.
OPTIMIZER = Adam(learning_rate = 1e-3)
# Number of full passes over train_dataset performed by the training loop.
EPOCHS = 20

In [17]:
def custom_loss(y_true, y_pred, mean, log_var):
    """Compute the VAE loss: reconstruction term plus KL regularization term.

    Args:
        y_true: target images (in this notebook, the input batch itself).
        y_pred: decoder reconstructions, same shape as ``y_true``.
        mean: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        A scalar tensor: both terms are summed per sample and then averaged
        over the batch before being added together.
    """
    # Binary cross-entropy per pixel, summed over axes 1 and 2 of the result.
    pixel_bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    reconstruction = tf.reduce_mean(tf.reduce_sum(pixel_bce, axis=(1, 2)))

    # Closed-form KL divergence between N(mean, exp(log_var)) and N(0, 1),
    # evaluated element-wise and summed over the latent axis.
    kl_elements = -0.5 * (1 + log_var - tf.square(mean) - tf.exp(log_var))
    regularization = tf.reduce_mean(tf.reduce_sum(kl_elements, axis=1))

    return reconstruction + regularization

In [22]:
@tf.function
def training_block(x_batch):
    """Run one optimization step on a batch and return its loss.

    The forward pass (encoder, decoder, loss) is recorded on a gradient tape;
    gradients with respect to the VAE's trainable weights are then applied
    with the module-level OPTIMIZER.

    Args:
        x_batch: a batch of input images; it also serves as the reconstruction target.

    Returns:
        The scalar loss tensor computed for this batch.
    """
    with tf.GradientTape() as tape:
        z, mean, log_var = encoder_model(x_batch)
        reconstruction = decoder_model(z)
        loss = custom_loss(x_batch, reconstruction, mean, log_var)

    weights = vae.trainable_weights
    grads = tape.gradient(loss, weights)
    OPTIMIZER.apply_gradients(zip(grads, weights))
    return loss

In [23]:
def custom_learn(epochs):
    """Train the VAE with the custom training step for a number of epochs.

    For each epoch, every batch of ``train_dataset`` is passed to
    ``training_block``; the loss of the last batch of the epoch is printed.

    Args:
        epochs: number of passes over ``train_dataset``.
    """
    for epoch in range(1, epochs+1):
        print('Training starts for epoch number {}'.format(epoch))

        # Stays None if the dataset yields no batches, so the print below
        # cannot fail on an unbound name.
        loss = None
        for x_batch in train_dataset:
            loss = training_block(x_batch)
        # Reports the loss of the final batch only, not an epoch average.
        print('Training Loss is : ', loss)
    print('Trining Complete!!!')

In [24]:
# Run the custom training loop for EPOCHS epochs.
custom_learn(EPOCHS)

Training starts for epoch number 1
Training Loss is :  tf.Tensor(180.26564, shape=(), dtype=float32)
Training starts for epoch number 2
Training Loss is :  tf.Tensor(169.36844, shape=(), dtype=float32)
Training starts for epoch number 3
Training Loss is :  tf.Tensor(159.76198, shape=(), dtype=float32)
Training starts for epoch number 4
Training Loss is :  tf.Tensor(157.54752, shape=(), dtype=float32)
Training starts for epoch number 5
Training Loss is :  tf.Tensor(155.40848, shape=(), dtype=float32)
Training starts for epoch number 6
Training Loss is :  tf.Tensor(148.90096, shape=(), dtype=float32)
Training starts for epoch number 7
Training Loss is :  tf.Tensor(146.37741, shape=(), dtype=float32)
Training starts for epoch number 8
Training Loss is :  tf.Tensor(152.75414, shape=(), dtype=float32)
Training starts for epoch number 9
Training Loss is :  tf.Tensor(153.19708, shape=(), dtype=float32)
Training starts for epoch number 10
Training Loss is :  tf.Tensor(152.50475, shape=(), dtyp

## Overriding train_step method