<a href="https://colab.research.google.com/github/AtSourav/CatsnDogs_VAE/blob/main/CatsnDogs_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

We're going to build a variational autoencoder (VAE) in Keras, with a convolutional (CNN) encoder and decoder.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
input_size = (160,160,3)
batch_size = 256
latent_dim = 16

In [4]:
encoder_input = keras.Input(shape=input_size)

x = layers.Conv2D(16, 3, padding="valid")(encoder_input)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")(x)    # with strides=None this defaults to pool_size
x = layers.BatchNormalization(axis=-1)(x) # Conv2D defaults to data_format="channels_last", so axis=-1 is the channel axis: BatchNorm keeps separate statistics (and gamma/beta) for each channel
x = layers.LeakyReLU()(x)

x = layers.Conv2D(32, 3, padding="valid")(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")(x)
x = layers.BatchNormalization(axis=-1)(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2D(64, 3, padding="valid")(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")(x)
x = layers.BatchNormalization(axis=-1)(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2D(128, 3, padding="valid")(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")(x)
x = layers.BatchNormalization(axis=-1)(x)
x = layers.LeakyReLU()(x)

x = layers.Flatten()(x)       # spatial size 160 -> 8 after four conv/pool blocks, so the flattened vector has 8*8*128 = 8192 entries

# we won't add any extra dense layers for now

z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

def sampling(arg):
  z_m, z_log_v = arg
  batch = tf.shape(z_m)[0]
  dim = tf.shape(z_m)[1]
  eps = tf.random.normal(shape=(batch,dim))
  return z_m + tf.exp(0.5*z_log_v)*eps

z = layers.Lambda(sampling)([z_mean,z_log_var])   # wrap the sampling function in a Lambda layer so it can be used as a layer in the functional model

encoder = keras.Model(encoder_input, [z_mean, z_log_var, z], name='encoder')           # the second argument specifies that the encoder outputs [z_mean, z_log_var, z] for each input.
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                    Output Shape           Param #     Connected to
 input_2 (InputLayer)            [(None, 160, 160, 3)]  0           []

 conv2d_1 (Conv2D)               (None, 158, 158, 16)   448         ['input_2[0][0]']

 max_pooling2d_1 (MaxPooling2D)  (None, 79, 79, 16)     0           ['conv2d_1[0][0]']

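The `sampling` function in the encoder cell is the reparameterization trick: rather than drawing z directly from N(z_mean, exp(z_log_var)), it draws eps from N(0, I) and returns z_mean + exp(0.5 * z_log_var) * eps, so the randomness sits in eps and gradients can flow through `z_mean` and `z_log_var`. The sketch below mirrors it in NumPy instead of TensorFlow, with made-up encoder outputs, and checks empirically that the samples have the intended per-dimension mean and variance. `sample_z` is a hypothetical stand-in for `sampling`, not part of the notebook.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    """Hypothetical NumPy stand-in for `sampling`: z = mu + sigma * eps, eps ~ N(0, I)."""
    eps = rng.standard_normal(size=z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps   # exp(0.5 * log var) is the standard deviation

rng = np.random.default_rng(0)
n_samples, latent_dim = 100_000, 16

# Made-up encoder outputs, repeated for every row so the sample statistics can be checked.
mu = np.linspace(-1.0, 1.0, latent_dim)
log_var = np.linspace(-2.0, 1.0, latent_dim)
z_mean = np.tile(mu, (n_samples, 1))
z_log_var = np.tile(log_var, (n_samples, 1))

z = sample_z(z_mean, z_log_var, rng)

# Largest deviations of the sample mean and variance from their targets.
print(np.abs(z.mean(axis=0) - mu).max())
print(np.abs(z.var(axis=0) - np.exp(log_var)).max())
```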

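The encoder has 360,576 parameters (see the VAE summary below), most of them in the two Dense layers that map the 8192-dimensional flattened vector to `z_mean` and `z_log_var`. That may be too many for the size of the dataset, so overfitting is possible; we'll see how training goes. Let's build the decoder now.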
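Before moving on to the decoder, the encoder's shapes and parameter count can be checked by hand. Each block applies a 3×3 "valid" convolution (which trims 2 pixels from each spatial dimension) followed by 2×2 max pooling (which halves it, rounding down), so 160 becomes 8 after four blocks. The plain-Python sketch below, with hypothetical helper names, reproduces the 8×8×128 flattened size and the 360,576 encoder parameters reported in the VAE summary.

```python
def conv_out(size, kernel=3):
    """Spatial size after a stride-1 convolution with padding='valid'."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Spatial size after max pooling with padding='valid' and strides == pool_size."""
    return size // pool

def conv_params(k, c_in, c_out):
    """Kernel weights plus one bias per filter."""
    return k * k * c_in * c_out + c_out

size, c_in = 160, 3
conv_total, bn_total = 0, 0
for c_out in [16, 32, 64, 128]:
    size = pool_out(conv_out(size))
    conv_total += conv_params(3, c_in, c_out)
    bn_total += 4 * c_out            # gamma, beta, moving mean, moving variance per channel
    c_in = c_out

flat_dim = size * size * c_in
latent_dim = 16
dense_total = 2 * (flat_dim * latent_dim + latent_dim)   # z_mean and z_log_var heads

print(size, flat_dim)                                            # 8 8192
print(conv_total, bn_total, dense_total,
      conv_total + bn_total + dense_total)                       # 97440 960 262176 360576
```

The two Dense heads alone contribute 262,176 of those parameters (about 73%), so they are the first place to look if the model turns out to be too large for the dataset.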

In [5]:
latent_input = keras.Input(shape=(latent_dim,))

x = layers.Dense(10*10*128)(latent_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((10,10,128))(x)

x = layers.Conv2DTranspose(64, 3, strides=2, padding='same')(x)
x = layers.BatchNormalization()(x)                      # axis=-1 (the channel axis) is the default
x = layers.LeakyReLU()(x)

x = layers.Conv2DTranspose(32, 3, strides=2, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2DTranspose(16, 3, strides=2, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

decoder_output = layers.Conv2DTranspose(3, 3, activation='sigmoid', strides=2, padding='same')(x)

decoder = keras.Model(latent_input, decoder_output, name="decoder")
decoder.summary()

Model: "decoder"
_______________________________________________________________________________
 Layer (type)                                 Output Shape              Param #
 input_3 (InputLayer)                         [(None, 16)]              0

 dense (Dense)                                (None, 12800)             217600

 leaky_re_lu_4 (LeakyReLU)                    (None, 12800)             0

 reshape (Reshape)                            (None, 10, 10, 128)       0

 conv2d_transpose (Conv2DTranspose)           (None, 20, 20, 64)        73792

 batch_normalization_5 (BatchNormalization)   (None, 20, 20, 64)        256
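On the decoder side, each `Conv2DTranspose` with `strides=2` and `padding='same'` doubles the height and width, so the four transposed convolutions take the 10×10×128 tensor from the `Dense`/`Reshape` pair up to 160×160×3, matching `input_size`. The sketch below (plain Python, hypothetical helper names) traces that and recomputes the decoder's 315,363 parameters from the VAE summary.

```python
latent_dim = 16
start, channels = 10, 128

# Dense layer: weights plus biases for 10*10*128 = 12800 outputs.
dense_params = latent_dim * start * start * channels + start * start * channels

def conv_t_params(k, c_in, c_out):
    """Conv2DTranspose kernel (k x k x c_out x c_in) plus one bias per output channel."""
    return k * k * c_in * c_out + c_out

size, c_in = start, channels
conv_total, bn_total = 0, 0
for c_out in [64, 32, 16]:
    size *= 2                        # strides=2 with padding='same' doubles height and width
    conv_total += conv_t_params(3, c_in, c_out)
    bn_total += 4 * c_out            # gamma, beta, moving mean, moving variance
    c_in = c_out

size *= 2                            # final sigmoid Conv2DTranspose down to 3 channels
conv_total += conv_t_params(3, c_in, 3)

print(size)                                   # 160
print(dense_params + conv_total + bn_total)   # 315363
```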

Now let's build the full VAE by chaining the encoder and decoder models: the decoder takes the sampled `z` from the encoder as its input.

In [6]:
decoder_out = decoder(encoder(encoder_input)[2])          # we feed in z from the encoder output
VAE = keras.Model(encoder_input, decoder_out, name='VAE')

VAE.summary()

Model: "VAE"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 encoder (Functional)        [(None, 16),              360576    
                              (None, 16),                        
                              (None, 16)]                        
                                                                 
 decoder (Functional)        (None, 160, 160, 3)       315363    
                                                                 
Total params: 675,939
Trainable params: 675,235
Non-trainable params: 704
_________________________________________________________________
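The 704 non-trainable parameters in the summary belong to the BatchNormalization layers: each channel has a moving mean and a moving variance, updated as running averages during training and used in place of the batch statistics at inference. With `axis=-1`, each channel gets its own statistics, computed over the batch, height and width axes. The NumPy sketch below reproduces the count and shows the training-mode normalization on a random tensor; it leaves out the learned scale (gamma) and offset (beta), which Keras initializes to 1 and 0.

```python
import numpy as np

# Channels of every BatchNormalization layer, read off the layer definitions above.
encoder_bn_channels = [16, 32, 64, 128]
decoder_bn_channels = [64, 32, 16]
bn_channels = sum(encoder_bn_channels) + sum(decoder_bn_channels)

non_trainable = 2 * bn_channels              # moving_mean and moving_variance per channel
total = 360_576 + 315_363                    # encoder + decoder totals from the summary
trainable = total - non_trainable
print(non_trainable, trainable, total)       # 704 675235 675939

# Training-mode normalization with axis=-1 on a random (batch, height, width, channels) tensor:
# one mean and one variance per channel, computed over the batch and spatial axes.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 5, 5, 16))
mean = x.mean(axis=(0, 1, 2), keepdims=True)       # shape (1, 1, 1, 16)
var = x.var(axis=(0, 1, 2), keepdims=True)
x_hat = (x - mean) / np.sqrt(var + 1e-3)          # 1e-3 is Keras' default epsilon; gamma=1, beta=0 assumed
```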


Next we define the loss function. It has two terms: a reconstruction loss, which rewards the decoder for reproducing the input image, and a KL divergence term, which pushes the approximate posterior q(z|x) produced by the encoder toward the prior p(z) = N(0, I). (Minimizing their sum maximizes the evidence lower bound, which indirectly also brings q(z|x) closer to the true posterior.) Since the pixel values are continuous rather than binary, the Bernoulli log loss (binary cross-entropy) is not a properly normalized log-likelihood for them, so we use mean squared error for the reconstruction loss. In a different version, we shall try to implement a continuous version of the Bernoulli log loss based on arXiv:1907.06845.
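For a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form, `-0.5 * sum(1 + z_log_var - z_mean**2 - exp(z_log_var))` over the latent dimensions, so it needs no sampling. The NumPy sketch below computes it next to an MSE-style reconstruction term; summing the squared error over pixels and channels (rather than averaging) is an assumption on our part, and `reconstruction_loss`, `kl_loss` and `log_normal` are hypothetical names. As a sanity check it compares the closed form with a Monte Carlo estimate of E_q[log q(z) - log p(z)].

```python
import numpy as np

def reconstruction_loss(x, x_hat):
    """Squared error summed over pixels and channels (assumed choice), one value per image."""
    return np.sum((x - x_hat) ** 2, axis=(1, 2, 3))

def kl_loss(z_mean, z_log_var):
    """Closed-form KL( N(z_mean, exp(z_log_var)) || N(0, I) ), one value per example."""
    return -0.5 * np.sum(1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)

def log_normal(z, mean, log_var):
    """Log density of a diagonal Gaussian, summed over latent dimensions."""
    return -0.5 * np.sum(np.log(2 * np.pi) + log_var + (z - mean) ** 2 / np.exp(log_var), axis=-1)

rng = np.random.default_rng(0)
latent_dim = 16
mu = rng.normal(size=(1, latent_dim))
log_var = rng.normal(scale=0.5, size=(1, latent_dim))

# Monte Carlo check: the average of log q(z) - log p(z) over z ~ q should approach the closed form.
z = mu + np.exp(0.5 * log_var) * rng.standard_normal(size=(200_000, latent_dim))
mc_kl = np.mean(log_normal(z, mu, log_var) - log_normal(z, 0.0, np.zeros(latent_dim)))
print(kl_loss(mu, log_var)[0], mc_kl)

# Reconstruction term on random stand-in images of the notebook's input size.
x = rng.uniform(size=(2, 160, 160, 3))
x_hat = rng.uniform(size=(2, 160, 160, 3))
print(reconstruction_loss(x, x_hat))
```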

We treat the relative weight between the two terms of the total loss as a hyperparameter; we may investigate the effect of varying it later.
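One way to express the weighted objective is total = reconstruction + beta * KL, averaged over the batch, where `beta` is a hypothetical name for the relative-weight hyperparameter. The NumPy sketch below uses random stand-in data to show how the balance between the two terms shifts with `beta`. Because the summed squared error grows with the number of pixel values (160×160×3 = 76,800 per image) while the KL term grows with `latent_dim`, the useful range of `beta` depends on whether the reconstruction term is summed or averaged.

```python
import numpy as np

def vae_loss(x, x_hat, z_mean, z_log_var, beta=1.0):
    """Batch-averaged reconstruction + beta * KL; `beta` is a hypothetical name for the relative weight."""
    recon = np.sum((x - x_hat) ** 2, axis=(1, 2, 3))                                  # per-image summed squared error
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)   # per-image KL to N(0, I)
    return np.mean(recon + beta * kl), np.mean(recon), np.mean(kl)

rng = np.random.default_rng(1)
x = rng.uniform(size=(4, 160, 160, 3))
x_hat = np.clip(x + 0.05 * rng.standard_normal(x.shape), 0.0, 1.0)   # stand-in "reconstruction"
z_mean = rng.normal(size=(4, 16))
z_log_var = rng.normal(scale=0.5, size=(4, 16))

for beta in (0.1, 1.0, 10.0):
    total, recon, kl = vae_loss(x, x_hat, z_mean, z_log_var, beta=beta)
    print(f"beta={beta}: total={total:.1f} recon={recon:.1f} kl={kl:.2f}")
```

In the notebook itself this would be written with TensorFlow ops, for example in an overridden `train_step` of a `keras.Model` subclass; that wiring is not shown here.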