<a href="https://colab.research.google.com/github/TivoGatto/Thesis/blob/master/Naive_VAE/Naive_VAE_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Implementation of a naive VAE, with the architecture taken from Tolstikhin et al.

In [None]:
# LIBRARIES
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, ReLU, Dense, Flatten, Reshape, Conv2DTranspose, Lambda
from keras.datasets import cifar10
import keras.backend as K

In [None]:
# Parameters
# Shape of one input image; used as the encoder's Input shape and flattened
# inside the loss.
input_dim = (32, 32, 3)
# Width of the z_mean / z_log_var heads and of the decoder's input vector.
latent_dim = 128

# Training-loop settings passed to vae.fit().
epochs = 100
batch_size = 100

In [None]:
# Functions
def vae_loss(z_mean, z_log_var):
    """Build a Keras-style loss function for the VAE.

    The returned closure captures ``z_mean`` and ``z_log_var`` so that the
    KL term can be computed alongside the reconstruction term.

    Args:
        z_mean: tensor of latent means produced by the encoder.
        z_log_var: tensor of latent log-variances produced by the encoder.

    Returns:
        A function ``loss(x_true, x_pred)`` returning a scalar tensor: the
        batch mean of (reconstruction term + KL term).
    """
    def loss(x_true, x_pred):
        # Collapse every image to a flat vector so the sum runs over all
        # pixels and channels at once.
        n_features = np.prod(input_dim)
        flat_true = K.reshape(x_true, (-1, n_features))
        flat_pred = K.reshape(x_pred, (-1, n_features))

        # Half the summed squared error per sample.
        squared_error = K.square(flat_true - flat_pred)
        reconstruction = 0.5 * K.sum(squared_error, axis=-1)

        # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and
        # a standard normal, summed over the latent dimensions.
        kl_terms = K.square(z_mean) + K.exp(z_log_var) - 1 - z_log_var
        kl_divergence = 0.5 * K.sum(kl_terms, axis=-1)

        return K.mean(reconstruction + kl_divergence)
    return loss

def sampling(args):
    """Draw a latent sample with the reparameterisation trick.

    Args:
        args: a pair ``(z_mean, z_log_var)`` of 2-D tensors holding the
            mean and log-variance of the approximate posterior.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is standard
        normal noise with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args

    # Take the noise shape from the incoming tensor instead of hard-coding
    # the batch size, so the layer works for any batch size (e.g. a final
    # partial batch, or predict() on a handful of images).
    n_samples = K.shape(z_mean)[0]      # dynamic batch dimension
    n_latent = K.int_shape(z_mean)[1]   # static latent dimension
    eps = K.random_normal(shape=(n_samples, n_latent))

    return z_mean + K.exp(0.5 * z_log_var) * eps

In [None]:
# Dataset
# Load CIFAR-10. The labels are unpacked but not used in the cells shown here.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Cast both image arrays to float32 for training.
# NOTE(review): no rescaling is applied here, so pixel values presumably stay
# in the raw 0-255 range -- confirm whether normalisation is intended.
x_train, x_test = x_train.astype('float32'), x_test.astype('float32')

print(f'x_train shape: {x_train.shape}')
print(f'x_test shape: {x_test.shape}')

In [None]:
# Model Architecture
# ENCODER
x = Input(shape=input_dim) # Shape (32, 32, 3)

# Four stride-2 convolution stages, each followed by batch norm and ReLU.
# With 'same' padding every stage halves the spatial size:
# (32, 32) -> (16, 16) -> (8, 8) -> (4, 4) -> (2, 2).
h = x
for n_filters in (128, 256, 512, 1024):
    h = Conv2D(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)
# h now has shape (2, 2, 1024)

h = Flatten()(h)

# Two linear heads give the mean and log-variance of the latent Gaussian;
# the Lambda layer draws a sample z from it.
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z = Lambda(sampling)([z_mean, z_log_var])

# The encoder exposes the sample together with both distribution parameters.
encoder = Model(x, [z, z_mean, z_log_var])

# DECODER
z_in = Input(shape=(latent_dim, ))

# Project the latent vector to an (8, 8, 1024) feature map.
h = Dense(8 * 8 * 1024)(z_in)
h = Reshape((8, 8, 1024))(h)
h = BatchNormalization()(h)
h = ReLU()(h)

# Two stride-2 transposed convolutions upsample (8, 8) -> (16, 16) -> (32, 32)
# while the channel count drops 1024 -> 512 -> 256.
for n_filters in (512, 256):
    h = Conv2DTranspose(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

# Final stride-1 layer maps to 3 channels with no activation: shape (32, 32, 3).
x_decoded = Conv2DTranspose(3, 4, strides=(1, 1), padding='same')(h)

decoder = Model(z_in, x_decoded)

# VAE
# Chain the pieces end to end: the sampled latent z from the encoder graph is
# fed through the decoder model to obtain the reconstruction.
x_recon = decoder(z)

# Full autoencoder: input image -> reconstructed image.
vae = Model(inputs=x, outputs=x_recon)

In [None]:
# Fit model
from tensorflow.keras.optimizers import Adam

# Use `learning_rate`: the `lr` keyword is a deprecated alias that newer
# Keras / TensorFlow releases no longer accept.
# NOTE(review): the optimizer comes from `tensorflow.keras` while the model is
# built with the standalone `keras` imports -- confirm both resolve to the same
# Keras implementation in the target environment.
optimizer = Adam(learning_rate=0.001)

# The loss closure receives the encoder's z_mean / z_log_var tensors so the KL
# term can be added to the reconstruction error.
vae.compile(optimizer=optimizer, loss=vae_loss(z_mean, z_log_var))

# Autoencoder training: the input images are also the targets. The last 10% of
# x_train is held out for validation.
hist = vae.fit(x_train, x_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_split=0.1)