## Encoder and Decoder networks

## 4. Build the encoder and decoder networks
* You should now create the encoder and decoder for the variational autoencoder algorithm.
* You should design these networks yourself, subject to the following constraints:
   * The encoder and decoder networks should be built using the `Sequential` class.
   * The encoder and decoder networks should use probabilistic layers where necessary to represent distributions.
   * The prior distribution should be a zero-mean, isotropic Gaussian (identity covariance matrix).
   * The encoder network should add the KL divergence loss to the model.
* Print the model summary for the encoder and decoder networks.
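
For reference, these constraints correspond to the usual VAE objective, the evidence lower bound. The notation here is an assumption and does not come from the assignment text: $q_\phi(z \mid x)$ is the distribution output by the encoder, $p_\theta(x \mid z)$ is the distribution output by the decoder, and $p(z) = \mathcal{N}(0, I)$ is the prior.

$$
\mathrm{ELBO}(x) = \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

Training maximizes this quantity. The second term is the KL divergence loss that the encoder is required to add to the model.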

In [23]:
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Conv2D, UpSampling2D, BatchNormalization, Reshape

tfb = tfp.bijectors
tfd = tfp.distributions
tfpl = tfp.layers

%matplotlib inline

In [24]:
latent_dim = 50

# Zero-mean Gaussian prior; with no scale argument, the covariance is the identity
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim))

In [25]:
def get_kl_regularizer(prior_distribution):
    """
    This function should create an instance of the KLDivergenceRegularizer 
    according to the above specification. 
    The function takes the prior_distribution, which should be used to define 
    the distribution.
    Your function should then return the KLDivergenceRegularizer instance.
    """
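    return tfpl.KLDivergenceRegularizer(
        prior_distribution,
        weight=1.0,                              # multiplier applied to the KL term
        use_exact_kl=False,                      # Monte Carlo estimate instead of an analytic KL
        test_points_fn=lambda q: q.sample(3),    # draw 3 samples from the encoder distribution q
        test_points_reduce_axis=None)            # average over all sample and batch axes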

In [26]:
kl_regularizer = get_kl_regularizer(prior)
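
With `use_exact_kl=False`, the regularizer estimates the KL divergence from samples instead of using a closed form. The sketch below illustrates that idea in plain Python, under a simplifying assumption: the posterior is a diagonal Gaussian, whereas the encoder in this notebook uses a full-covariance `MultivariateNormalTriL`. The helper names (`kl_exact`, `kl_monte_carlo`, `log_prob_diag_normal`) are hypothetical and not part of TensorFlow Probability.

```python
import math
import random

def kl_exact(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return 0.5 * sum(s * s + m * m - 1.0 - 2.0 * math.log(s)
                     for m, s in zip(mu, sigma))

def log_prob_diag_normal(z, mu, sigma):
    """Log-density of N(mu, diag(sigma^2)) evaluated at the point z."""
    return sum(-0.5 * math.log(2.0 * math.pi) - math.log(s)
               - 0.5 * ((x - m) / s) ** 2
               for x, m, s in zip(z, mu, sigma))

def kl_monte_carlo(mu, sigma, num_samples, rng):
    """Estimate the KL as the mean of log q(z) - log p(z) over samples z ~ q."""
    zeros, ones = [0.0] * len(mu), [1.0] * len(mu)
    total = 0.0
    for _ in range(num_samples):
        z = [rng.gauss(m, s) for m, s in zip(mu, sigma)]
        total += (log_prob_diag_normal(z, mu, sigma)
                  - log_prob_diag_normal(z, zeros, ones))
    return total / num_samples

rng = random.Random(0)
mu, sigma = [0.5, -1.0, 0.0], [1.2, 0.7, 1.0]    # illustrative posterior parameters
exact = kl_exact(mu, sigma)
few = kl_monte_carlo(mu, sigma, 3, rng)          # noisy, analogous to q.sample(3)
many = kl_monte_carlo(mu, sigma, 20000, rng)     # approaches the closed-form value
print(f"exact={exact:.4f}  3 samples={few:.4f}  20000 samples={many:.4f}")
```

A three-sample estimate is noisy for any single input, but it is unbiased, so the noise tends to average out over batches and training steps.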

In [27]:
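# Encoder: four stride-2 convolutions, then a Dense layer that parameterizes
# a full-covariance Gaussian over the latent space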

encoder = Sequential([
    Conv2D(filters=32, padding='SAME', kernel_size=(4,4), strides=(2,2), activation='relu', input_shape=(32,32,3)),
    BatchNormalization(),
    Conv2D(filters=64, padding='SAME', kernel_size=(4,4), strides=(2,2), activation='relu'),
    BatchNormalization(),
    Conv2D(filters=128, padding='SAME', kernel_size=(4,4), strides=(2,2), activation='relu'),
    BatchNormalization(),
    Conv2D(filters=256, padding='SAME', kernel_size=(4,4), strides=(2,2), activation='relu'),
    BatchNormalization(),
    Flatten(),
    Dense(tfpl.MultivariateNormalTriL.params_size(latent_dim)),
    tfpl.MultivariateNormalTriL(event_size=latent_dim, activity_regularizer=kl_regularizer)
    ])
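
The encoder's layer shapes and parameter counts can be checked by hand. The sketch below assumes the standard rules: a `'SAME'`-padded convolution outputs `ceil(size / stride)` per spatial dimension, a `Conv2D` layer has `k * k * in_channels * filters + filters` parameters, and a lower-triangular Gaussian over `n` dimensions needs `n + n * (n + 1) / 2` parameters. The helper names are hypothetical.

```python
import math

def conv_same_output(size, stride):
    """Spatial output size of a 'SAME'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)

def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter for a square-kernel Conv2D."""
    return kernel * kernel * in_channels * filters + filters

def tril_params_size(event_size):
    """Mean vector plus lower-triangular scale factor of a full-covariance Gaussian."""
    return event_size + event_size * (event_size + 1) // 2

size, channels = 32, 3            # matches input_shape=(32, 32, 3)
trace = []
for filters in (32, 64, 128, 256):
    params = conv_params(4, channels, filters)
    size, channels = conv_same_output(size, 2), filters
    trace.append(((size, size, channels), params))

for shape, params in trace:
    print(shape, params)

flat_units = size * size * channels
dense_units = tril_params_size(50)    # latent_dim = 50
print("flattened units:", flat_units)
print("Dense units feeding MultivariateNormalTriL:", dense_units)
```

The first three rows should agree with the shapes and parameter counts that `encoder.summary()` prints for the convolutional layers.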

In [28]:
decoder = Sequential([
    Dense(4096, activation='relu', input_shape=(latent_dim,)),
    Reshape((4,4,256)),
    UpSampling2D(size=(2,2)),
    Conv2D(filters=128, padding='SAME', kernel_size=(3,3), activation='relu'),
    UpSampling2D(size=(2,2)),
    Conv2D(filters=64, padding='SAME', kernel_size=(3,3), activation='relu'),
    UpSampling2D(size=(2,2)),
    Conv2D(filters=32, padding='SAME', kernel_size=(3,3), activation='relu'),
    UpSampling2D(size=(2,2)),
    Conv2D(filters=128, padding='SAME', kernel_size=(3,3), activation='relu'),
    Conv2D(filters=3, padding='SAME', kernel_size=(3,3)),
    Flatten(),
    tfpl.IndependentBernoulli(event_shape=(64,64,3))
])
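
The decoder's shapes can be traced the same way. This is a minimal sketch with hypothetical helper names, assuming that `UpSampling2D(size=(2,2))` doubles each spatial dimension and that a stride-1 `'SAME'` convolution changes only the channel count.

```python
def upsample(shape, factor=2):
    """UpSampling2D: multiply height and width by the factor, keep channels."""
    h, w, c = shape
    return (h * factor, w * factor, c)

def conv_same(shape, filters):
    """Stride-1 'SAME' convolution: spatial size unchanged, channels replaced."""
    h, w, _ = shape
    return (h, w, filters)

shape = (4, 4, 256)               # after Dense(4096) and Reshape((4, 4, 256))
shapes = [shape]
for filters in (128, 64, 32, 128):
    shape = conv_same(upsample(shape), filters)
    shapes.append(shape)
shape = conv_same(shape, 3)       # final Conv2D(filters=3), no upsampling before it
shapes.append(shape)

flat = shape[0] * shape[1] * shape[2]
bernoulli_params = 64 * 64 * 3    # one logit per pixel-channel for event_shape=(64, 64, 3)
print(shapes)
print(flat, bernoulli_params)
```

The flattened output has exactly as many units as `IndependentBernoulli(event_shape=(64,64,3))` needs logits. Note that the decoder emits 64×64×3 images while the encoder above declares `input_shape=(32,32,3)`. If reconstructions are meant to match the inputs, one of the two would presumably need adjusting.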

In [29]:
encoder.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_18 (Conv2D)           (None, 16, 16, 32)        1568      
_________________________________________________________________
batch_normalization_12 (Batc (None, 16, 16, 32)        128       
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 8, 8, 64)          32832     
_________________________________________________________________
batch_normalization_13 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 4, 4, 128)         131200    
_________________________________________________________________
batch_normalization_14 (Batc (None, 4, 4, 128)         512       
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 2, 2, 256)        

In [30]:
decoder.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_6 (Dense)              (None, 4096)              208896    
_________________________________________________________________
reshape_2 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
up_sampling2d_4 (UpSampling2 (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 8, 8, 128)         295040    
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 16, 16, 64)        73792     
_________________________________________________________________
up_sampling2d_6 (UpSampling2 (None, 32, 32, 64)       

In [31]:
vae = Model(inputs=encoder.inputs, outputs=decoder(encoder.outputs))
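
Since the encoder's regularizer already contributes the KL term, a common way to train a model like `vae` is to use the negative log-likelihood of the input under the decoder's output distribution as the loss. That choice is an assumption here; the notebook has not defined a loss at this point. The sketch below shows, in plain Python with hypothetical helper names, what a Bernoulli negative log-likelihood computes from logits for one tiny "image".

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_log_prob(x, logit):
    """Log-probability of x under Bernoulli(sigmoid(logit)), written in terms of the logit."""
    return x * logit - softplus(logit)

def reconstruction_nll(pixels, logits):
    """Negative log-likelihood summed over all pixel-channels of one image."""
    return -sum(bernoulli_log_prob(x, l) for x, l in zip(pixels, logits))

pixels = [1.0, 0.0, 1.0, 0.0]                # illustrative 4-value "image"
confident = [4.0, -4.0, 4.0, -4.0]           # logits that agree with the pixels
uninformed = [0.0, 0.0, 0.0, 0.0]            # logits giving probability 0.5 everywhere
nll_confident = reconstruction_nll(pixels, confident)
nll_uninformed = reconstruction_nll(pixels, uninformed)
print(nll_confident, nll_uninformed)
```

Logits that agree with the pixel values give a much lower loss than uninformative ones, and this is the signal that drives the decoder toward accurate reconstructions.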