<a href="https://colab.research.google.com/github/Daankrol/DL-VAE/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Convolutional Variational Autoencoder
Sources:
- https://www.tensorflow.org/tutorials/generative/cvae
- https://www.topbots.com/generating-new-faces-with-variational-autoencoders/
- https://keras.io/examples/generative/vae/

To prevent Google Colab from disconnecting because of an idle timeout, run the following JavaScript in the Google Chrome developer console. It clicks Colab's "Connect" button every 60 seconds.
```js
function ClickConnect() {
  console.log("Working");
  document.querySelector("colab-toolbar-button#connect").click();
}
setInterval(ClickConnect, 60000);
```

In [None]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow.keras.layers import Input, InputLayer, Conv2D, BatchNormalization, LeakyReLU, Flatten, Dense, Lambda, Reshape, Conv2DTranspose, Activation
import tensorflow_probability as tfp
import time
# Geometry of one binarized MNIST digit: (height, width, channels).
IMAGE_DIM = (28, 28, 1)
# Mini-batch size used by both the train and test tf.data pipelines.
BATCH_SIZE = 32
# Sizes of the official MNIST train / test splits.
# NOTE(review): the loading cell concatenates both splits for training, so the
# training array actually holds TRAIN_SIZE + TEST_SIZE images.
TRAIN_SIZE, TEST_SIZE = 60000, 10000
# Dimensionality of the latent code z.
Z_DIM = 100

## Dataset loading and preprocessing

In [None]:
(train_ds, _), (test_ds, _) = tf.keras.datasets.mnist.load_data()
# Labels are discarded (`_`): the autoencoder only needs the images.
# Both MNIST splits are stacked so the model trains on every available digit.
# NOTE(review): test_ds is therefore a subset of the training data, so any
# evaluation on it is not held-out. Remove this concatenation if a true test
# set is needed.
train_ds = np.concatenate([train_ds, test_ds], axis=0)

def preprocess_images(images, image_shape=(28, 28, 1), threshold=0.5):
  """Reshape raw images to (N, H, W, C) and binarize them.

  Args:
    images: numpy array whose first axis indexes examples and whose remaining
      elements can be reshaped to ``image_shape`` (e.g. MNIST's (N, 28, 28)
      array). Pixel values are assumed to lie in [0, 255] — TODO confirm for
      non-MNIST inputs.
    image_shape: per-example (height, width, channels) target shape. Defaults
      to the 28x28 single-channel MNIST geometry.
    threshold: cut-off applied after scaling to [0, 1]; pixels strictly above
      it become 1.0, all others 0.0.

  Returns:
    float32 array of shape (N, *image_shape) containing only 0.0 and 1.0.
  """
  scaled = images.reshape((images.shape[0],) + tuple(image_shape)) / 255.
  return np.where(scaled > threshold, 1.0, 0.0).astype('float32')

# Binarize both splits into float32 (N, 28, 28, 1) arrays.
train_images = preprocess_images(train_ds)
test_images = preprocess_images(test_ds)

# Size each shuffle buffer from the actual number of examples. After the
# train+test concatenation in the loading step, the training array is larger
# than TRAIN_SIZE, and a buffer smaller than the dataset does not shuffle
# uniformly.
train_ds = (tf.data.Dataset.from_tensor_slices(train_images)
            .shuffle(len(train_images))
            .batch(BATCH_SIZE))
test_ds = (tf.data.Dataset.from_tensor_slices(test_images)
           .shuffle(len(test_images))
           .batch(BATCH_SIZE))

## Encoder

In [None]:
tf.keras.backend.clear_session()
filters = [32,64]
num_layers = len(filters)

print(IMAGE_DIM)
encoder_input = Input(shape=IMAGE_DIM, name='encoder_input')

# Downsampling stack: each block is a stride-2 'same' conv (shrinking H and W
# by a factor of 2), then batch norm, then LeakyReLU (alpha=0.3 by default).
x = encoder_input
for idx, n_filters in enumerate(filters):
  x = Conv2D(n_filters, kernel_size=3, strides=2, padding='same', name=f'encoder_conv_{idx}')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x)

# Keep the (H, W, C) feature-map shape so the decoder can reshape back to it.
pre_flatten_shape = tf.keras.backend.int_shape(x)[1:]
print('Pre flatten shape:', pre_flatten_shape)

# Two parallel dense heads produce the latent Gaussian's mean and log-variance.
x = Flatten()(x)
z_mu = Dense(Z_DIM, name='mu')(x)
z_log_var = Dense(Z_DIM, name='log_var')(x)

def sampling(args):
  """Draw z ~ N(mu, exp(log_var)) using the reparameterization trick.

  Args:
    args: pair ``(mu, log_var)`` of tensors; the Lambda layer passes
      ``[z_mu, z_log_var]``, presumably of equal shape (both are Z_DIM-wide
      Dense outputs).

  Returns:
    ``mu + sigma * eps`` where ``sigma = exp(log_var * 0.5)`` and ``eps`` is
    standard-normal noise shaped like ``mu``.
  """
  mean, log_variance = args
  # The randomness lives only in `noise`, so gradients reach mean and
  # log_variance through the deterministic shift-and-scale below.
  noise = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mean), mean=0., stddev=1.)
  sigma = tf.exp(log_variance * 0.5)
  return mean + sigma * noise

# Sample z inside the graph (reparameterization trick) so the encoder's output
# is a stochastic latent code rather than the mean.
encoder_output = Lambda(sampling, name='encoder_output')([z_mu, z_log_var])

# The model exposes only the sampled z. z_mu and z_log_var remain reachable
# through the module-level tensors (e.g. for a KL term in a VAE loss).
encoder = tf.keras.Model(encoder_input, encoder_output)
encoder.summary()

(28, 28, 1)
Pre flatten shape: (7, 7, 64)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 14, 14, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 14, 14, 32)   128         encoder_conv_0[0][0]             
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 14, 14, 32)   0           batch_normalization[0][0]        
____________________________________________________

## Decoder

In [None]:
dec_filters = [64,32]
dec_num_layers = len(dec_filters)

# Project z back to the encoder's pre-flatten feature map, then upsample.
decoder_input = Input(shape=(Z_DIM,), name='decoder_input')
x = Dense(np.prod(pre_flatten_shape))(decoder_input)
x = Reshape(pre_flatten_shape)(x)

# Iterate over the decoder's own filter list so its depth follows dec_filters
# rather than the encoder's num_layers. Each stride-2 transposed conv doubles
# H and W, so dec_filters needs as many entries as the encoder's filters to
# get back to IMAGE_DIM (checked below).
for idx, n_filters in enumerate(dec_filters):
  x = Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding='same', name=f'decoder_conv_{idx}')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x) # alpha=0.3 by default

# Final 1-channel projection at full resolution. No output activation is
# applied (a sigmoid was tried and disabled), so pixel predictions are
# unbounded.
x = Conv2DTranspose(1, kernel_size=3, strides=1, padding='same', name=f'decoder_conv_{dec_num_layers}')(x)

decoder_output = x
assert tuple(tf.keras.backend.int_shape(decoder_output)[1:]) == tuple(IMAGE_DIM), (
    'decoder output shape does not match IMAGE_DIM; dec_filters must have as '
    'many entries as the encoder filters')
decoder = tf.keras.Model(decoder_input, decoder_output)
decoder.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 3136)              316736    
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_0 (Conv2DTransp (None, 14, 14, 64)        36928     
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
decoder_conv_1 (Conv2DTransp (None, 28, 28, 32)        1846

## Compilation and training

In [None]:
epochs = 10
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras releases no longer accept.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
mse = tf.keras.losses.MeanSquaredError()
def r_loss(y_true, y_pred):
    """Per-example reconstruction loss: mean squared error over axes 1, 2, 3.

    For rank-4 (batch, H, W, C) inputs this returns one value per example,
    i.e. a tensor of shape (batch,). TensorFlow has no `tf.mean`; the
    reduction op is `tf.reduce_mean`.

    NOTE(review): currently unused — the autoencoder is compiled with `mse`.
    """
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])
 
simple_ae = tf.keras.Model(encoder_input, decoder(encoder_output))
simple_ae.summary()
simple_ae.compile(optimizer=optimizer, loss=mse)

# train_ds yields image batches only, but a compiled loss needs (input, target)
# pairs. An autoencoder reconstructs its input, so the target is the input.
# batch_size is not passed because the dataset is already batched.
ae_train_ds = train_ds.map(lambda images: (images, images))
history = simple_ae.fit(ae_train_ds, epochs=epochs)

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 14, 14, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 14, 14, 32)   128         encoder_conv_0[0][0]             
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 14, 14, 32)   0           batch_normalization[0][0]        
____________________________________________________________________________________________

ValueError: ignored

In [None]:
class VAE(tf.keras.Model):
  def __init__(self, Z_DIM):
    super(VAE, self).__init__()
    self.Z_DIM = Z_DIM
    self.encoder = encoder
    self.decoder = decoder
  
  @tf.function
  def sample(self, eps=None):
    if eps is None:
      eps = tf.random.normal(shape=(100, ))