<a href="https://colab.research.google.com/github/TivoGatto/Thesis/blob/master/Two_Stage/Two_Stage_VAE_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# LIBRARIES
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, ReLU, Dense, Flatten, Reshape, Conv2DTranspose, Lambda
from keras.datasets import cifar10
import keras.backend as K

In [2]:
# Parameters
input_dim = (32, 32, 3)  # shape of one input image as fed to the encoder's Input layer
latent_dim = 16  # width of the latent vectors (first-stage z and second-stage u)

epochs = 5  # number of training epochs, used by both fit calls
batch_size = 100  # mini-batch size passed to fit and predict

beta = 1  # weight of the KL term relative to the reconstruction term in vae_loss

In [3]:
# Functions
def vae_loss(z_mean, z_log_var):
    """Build the first-stage VAE objective for the given posterior tensors.

    The returned Keras loss is the batch mean of
    ``0.5 * ||x_true - x_pred||^2 + beta * KL(N(z_mean, exp(z_log_var)) || N(0, I))``,
    where the squared error is summed over all pixels of a flattened image.
    """
    def loss(x_true, x_pred):
        # Flatten every image so the squared error sums over all of its pixels.
        flat_shape = (-1, np.prod(input_dim))
        residual = K.reshape(x_true, flat_shape) - K.reshape(x_pred, flat_shape)
        reconstruction_term = 0.5 * K.sum(K.square(residual), axis=-1)

        # Closed-form KL divergence of a diagonal Gaussian from the unit Gaussian.
        kl_summands = K.square(z_mean) + K.exp(z_log_var) - 1 - z_log_var
        kl_term = 0.5 * K.sum(kl_summands, axis=-1)

        return K.mean(reconstruction_term + beta * kl_term)
    return loss

def recon(x_true, x_pred):
    """Keras metric: batch mean of half the per-image summed squared error."""
    flat_shape = (-1, np.prod(input_dim))
    residual = K.reshape(x_true, flat_shape) - K.reshape(x_pred, flat_shape)

    per_image_error = 0.5 * K.sum(K.square(residual), axis=-1)
    return K.mean(per_image_error)

def KL(z_mean, z_log_var):
    """Build a Keras metric reporting the batch-mean KL term of the VAE.

    The metric ignores its (x_true, x_pred) arguments; they are only there
    because Keras calls every metric with that signature.
    """
    def kl(x_true, x_pred):
        kl_summands = K.square(z_mean) + K.exp(z_log_var) - 1 - z_log_var
        per_sample_kl = 0.5 * K.sum(kl_summands, axis=-1)
        return K.mean(per_sample_kl)
    return kl

def sampling(args):
    """Reparameterization trick: draw z = z_mean + exp(0.5 * z_log_var) * eps.

    Args:
        args: pair ``(z_mean, z_log_var)`` of 2-D tensors (batch, latent size).

    Returns:
        A tensor shaped like ``z_mean`` holding one sample per row.
    """
    z_mean, z_log_var = args

    # Read the batch size from the incoming tensor instead of hard-coding 100,
    # so the layer also works for partial batches and other predict/fit sizes.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    eps = K.random_normal(shape=(batch, dim))

    return z_mean + K.exp(0.5 * z_log_var) * eps

def pad(x, d):
    """Zero-pad a stack of 2-D images into the top-left corner of a larger canvas.

    Args:
        x: array of shape (size, h, w).
        d: tuple giving the per-image shape of the output, e.g. (H, W, C).

    Returns:
        Array of shape ``(size,) + d``, zero everywhere except the top-left
        h-by-w window, which holds the images (broadcast along the last axis).
    """
    n_images, height, width = x.shape

    canvas = np.zeros(shape=(n_images, ) + d)
    # A trailing singleton axis lets the images broadcast over the channel axis.
    canvas[:, :height, :width] = x.reshape(n_images, height, width, 1)

    return canvas

In [None]:
# Dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# CIFAR-10 is delivered as 8-bit pixel values in [0, 255]. The decoder ends in
# a sigmoid, so its output lives in [0, 1]; scale the images to the same range
# or the squared-error reconstruction target can never be reached.
x_train = x_train.astype('float32') / 255.
x_test  = x_test.astype('float32') / 255.

print('x_train shape: ' + str(x_train.shape))
print('x_test shape: ' + str(x_test.shape))

In [6]:
# First stage VAE

# Model Architecture
# ENCODER
x = Input(shape=input_dim) # Shape (32, 32, 3)

# Four identical stages (strided conv -> batch norm -> ReLU); every stage
# halves the spatial size:
# (16, 16, 128) -> (8, 8, 256) -> (4, 4, 512) -> (2, 2, 1024)
h = x
for n_filters in (128, 256, 512, 1024):
    h = Conv2D(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

h = Flatten()(h)

# Two linear heads give the mean and log-variance of the posterior; z is a
# sample drawn from it by the `sampling` Lambda layer.
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z = Lambda(sampling)([z_mean, z_log_var])

encoder = Model(x, [z, z_mean, z_log_var])

# DECODER
z_in = Input(shape=(latent_dim, ))

# Project the latent vector onto an (8, 8, 1024) feature map.
h = Dense(8 * 8 * 1024)(z_in)
h = Reshape((8, 8, 1024))(h)
h = ReLU()(h)

# Two stride-2 transposed convolutions: (16, 16, 512) -> (32, 32, 256)
for n_filters in (512, 256):
    h = Conv2DTranspose(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = ReLU()(h)

# Stride-1 output layer maps to 3 channels; the sigmoid keeps values in [0, 1].
x_decoded = Conv2DTranspose(3, 4, strides=(1, 1), padding='same', activation='sigmoid')(h)

decoder = Model(z_in, x_decoded)

# VAE: feed the sampled latent z of the encoder through the decoder
x_recon = decoder(z)

vae = Model(x, x_recon)

# Compile model
from tensorflow.keras.optimizers import Adam
# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras releases no longer accept.
optimizer = Adam(learning_rate=0.001)

# The loss and the KL metric close over z_mean / z_log_var so they can read the
# posterior parameters; `recon` and `kl` are reported separately during training.
vae.compile(optimizer=optimizer, loss=vae_loss(z_mean, z_log_var), metrics=[recon, KL(z_mean, z_log_var)])

In [None]:
# Fit model: the VAE is trained to reproduce its own input, with the last 10%
# of the training set held out for validation.
hist = vae.fit(
    x_train,
    x_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_split=0.1,
)

In [None]:
# Second Stage VAE

# Loss
def second_stage_loss(u_mean, u_log_var):
    """Build the second-stage VAE objective for the given posterior tensors.

    The returned Keras loss is the batch mean of half the summed squared
    reconstruction error plus the KL divergence of N(u_mean, exp(u_log_var))
    from the unit Gaussian.
    """
    def loss(x_true, x_pred):
        squared_error = K.square(x_true - x_pred)
        reconstruction_term = 0.5 * K.sum(squared_error, axis=-1)

        kl_summands = K.square(u_mean) + K.exp(u_log_var) - 1 - u_log_var
        kl_term = 0.5 * K.sum(kl_summands, axis=-1)

        return K.mean(reconstruction_term + kl_term)
    return loss

def second_stage_sampling(args):
    """Reparameterization trick: draw u = u_mean + eps * exp(0.5 * u_log_var).

    Args:
        args: pair ``(u_mean, u_log_var)`` of 2-D tensors (batch, latent size).

    Returns:
        A tensor shaped like ``u_mean`` holding one sample per row.
    """
    u_mean, u_log_var = args

    # Size the noise from the incoming tensor rather than the global
    # `batch_size`, so partial batches and other predict sizes also work.
    batch = K.shape(u_mean)[0]
    dim = K.int_shape(u_mean)[1]
    eps = K.random_normal(shape=(batch, dim))

    return u_mean + eps * K.exp(0.5 * u_log_var)

# Encoder
intermediate_dim = 256

# NOTE: `z` and `h` are reused here and now refer to second-stage tensors.
z = Input(shape=(latent_dim, ))

# Three fully connected ReLU layers of equal width.
h = z
for layer_idx in range(3):
    h = Dense(intermediate_dim, activation='relu')(h)

# Linear heads for the posterior mean and log-variance of u.
u_mean = Dense(latent_dim)(h)
u_log_var = Dense(latent_dim)(h)

u = Lambda(second_stage_sampling)([u_mean, u_log_var])
second_encoder = Model(z, [u, u_mean, u_log_var])

# Decoder
u_in = Input(shape=(latent_dim, ))

# Mirror of the encoder body: three fully connected ReLU layers.
h = u_in
for layer_idx in range(3):
    h = Dense(intermediate_dim, activation='relu')(h)

# Linear output layer mapping back to the first-stage latent space.
z_decoded = Dense(latent_dim)(h)
second_decoder = Model(u_in, z_decoded)

# VAE: feed the sampled u of the second encoder through the second decoder
z_reconstructed = second_decoder(u)
second_vae = Model(z, z_reconstructed)

# Give the second stage its own optimizer. Reusing the first model's instance
# would share its step counter and per-variable state between two unrelated
# models, and some Keras versions refuse an optimizer already bound to another
# model's variables.
second_optimizer = Adam(learning_rate=0.001)
second_vae.compile(optimizer=second_optimizer, loss=second_stage_loss(u_mean, u_log_var))

In [None]:
# Generate second stage Dataset
# The encoder returns [z, z_mean, z_log_var]; the second stage trains on the
# sampled latents z (index 0).
encoder_outputs = encoder.predict(x_train, batch_size=batch_size)
z_train = encoder_outputs[0]

In [None]:
# Fit second stage VAE: it learns to reproduce the first-stage latents
second_vae.fit(
    z_train,
    z_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
)

# Generation and Reconstruction

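Now we have two VAE networks. Chained together, they define the following generative process: \\
1) $u \sim p(u) = \mathcal{N}(0, I)$ \\
2) $z \sim q(z) = \mathbb{E}_{p(u)}[q(z \mid u)]$ \\
3) $x \sim p(x) = \mathbb{E}_{q(z)}[p(x \mid z)]$ \\

We can use them to generate new samples.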

In [None]:
# Reconstruction
#
# The second-stage VAE is not needed here because reconstruction does not
# use q(z): the first-stage encoder and decoder are enough.
n = 10
digit_size = input_dim[0]

x_recon = vae.predict(x_test, batch_size=batch_size)

# The decoder's sigmoid output is in [0, 1]. If x_test still holds raw 8-bit
# values, bring it to the same range so both rows are displayed consistently.
x_test_scale = 255. if x_test.max() > 1. else 1.

# Keep the channel axis: the images are RGB, so every tile is
# (digit_size, digit_size, channels). Flattening the channels into extra
# grayscale "images" would scramble the pixels and mismatch the two rows.
figure = np.zeros((2 * digit_size, n * digit_size, input_dim[2]))

for i in range(n):
    sample = np.random.randint(0, len(x_recon))
    columns = slice(i * digit_size, (i + 1) * digit_size)
    figure[:digit_size, columns] = x_test[sample] / x_test_scale  # top row: original
    figure[digit_size:, columns] = x_recon[sample]                # bottom row: reconstruction

plt.style.use('default')
# imshow expects float RGB data in [0, 1]; clip guards against rounding overshoot.
plt.imshow(np.clip(figure, 0., 1.))
plt.show()

In [None]:
# Generation
n = 10 # figure with n x n generated images
digit_size = input_dim[0]

# We want to sample z from q(z) = E_p(u)[q(z|u)] with p(u) = N(0, I):
# draw u from the prior, map it to z with the second-stage decoder, then map
# z to image space with the first-stage decoder.
u_sample = np.random.normal(size=(n**2, latent_dim), scale=1)
z_sample = second_decoder.predict(u_sample, batch_size=batch_size)

# Decode all n*n latents in one call instead of one predict per image. A new
# name is used so the decoder's symbolic output `x_decoded` is not overwritten.
x_generated = decoder.predict(z_sample, batch_size=batch_size)

# Keep the channel axis: the decoder emits RGB images, which cannot be
# reshaped to a single (digit_size, digit_size) grayscale tile.
figure = np.zeros((digit_size * n, digit_size * n, input_dim[2]))
for i in range(n):
    rows = slice(i * digit_size, (i + 1) * digit_size)
    for j in range(n):
        columns = slice(j * digit_size, (j + 1) * digit_size)
        figure[rows, columns] = x_generated[i + n * j]

plt.figure(figsize=(10, 10))
# The sigmoid output is already in [0, 1], the range imshow expects for RGB.
plt.imshow(figure)
plt.show()

# Metrics Evaluation

In [None]:
# FID Score
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf

from scipy.linalg import sqrtm
from skimage.transform import resize

from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets.mnist import load_data

# Functions needed to compute FID score
def scale_images(images, new_shape):
    """Resize every image in `images` to `new_shape` (nearest-neighbour).

    Args:
        images: iterable of image arrays.
        new_shape: target shape handed to `resize` for each image.

    Returns:
        A numpy array stacking the resized images in input order.
    """
    # The third positional argument of resize is the interpolation order;
    # 0 selects nearest-neighbour interpolation.
    return np.asarray([resize(image, new_shape, 0) for image in images])


def calculate_fid(model, images1, images2):
    """Frechet Inception Distance between two sets of images.

    Args:
        model: object whose ``predict(images)`` returns one feature row per image.
        images1: first image set, passed to ``model.predict``.
        images2: second image set, passed to ``model.predict``.

    Returns:
        ``||mu1 - mu2||^2 + trace(sigma1 + sigma2 - 2 * sqrt(sigma1 . sigma2))``,
        where mu/sigma are the mean and covariance of each set's features.
    """
    def _feature_statistics(images):
        # Mean vector and covariance matrix of the model's activations.
        activations = model.predict(images)
        return activations.mean(axis=0), np.cov(activations, rowvar=False)

    mu1, sigma1 = _feature_statistics(images1)
    mu2, sigma2 = _feature_statistics(images2)

    squared_mean_distance = np.sum((mu1 - mu2) ** 2.0)

    covariance_root = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covariance_root):
        # sqrtm may return a complex matrix; only its real part is used.
        covariance_root = covariance_root.real

    return squared_mean_distance + np.trace(sigma1 + sigma2 - 2.0 * covariance_root)

sample_size = 10000
inception_chunk_size = 100  # images upscaled to 299x299 at a time


class ChunkedInceptionFeatures:
    """Adapter giving `calculate_fid` a `predict` that works chunk by chunk.

    Upscaling all images to 299x299x3 at once needs tens of gigabytes, so each
    chunk is resized, preprocessed and pushed through Inception separately and
    only the pooled feature vectors are kept.
    """

    def __init__(self, inception, chunk_size):
        self.inception = inception
        self.chunk_size = chunk_size

    def predict(self, images):
        """Return Inception features for `images`, expected in [0, 1], RGB."""
        features = []
        for start in range(0, len(images), self.chunk_size):
            chunk = scale_images(images[start:start + self.chunk_size], (299, 299, 3))
            # preprocess_input expects pixel values in [0, 255].
            chunk = preprocess_input(chunk * 255.)
            features.append(self.inception.predict(chunk, verbose=0))
        return np.concatenate(features, axis=0)


# Generated images: u ~ N(0, I) -> second-stage decoder -> first-stage decoder.
u_sample = np.random.normal(0, 1, size=(sample_size, latent_dim))
z_sample = second_decoder.predict(u_sample, batch_size=batch_size)
x_gen = decoder.predict(z_sample, batch_size=batch_size)  # sigmoid output, in [0, 1]

# Real images: a random draw (with replacement) from the test set.
sample = np.random.randint(0, len(x_test), size=sample_size)
x_real = x_test[sample]
if x_real.max() > 1.:
    # Raw 8-bit pixel values: bring them to [0, 1] like the generated images.
    x_real = x_real / 255.
print('Sampled', x_gen.shape, x_real.shape)

# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

# fid between real and generated images; the functions are defined in this
# notebook, not in a separate `evaluate` module
fid = calculate_fid(ChunkedInceptionFeatures(model, inception_chunk_size), x_real, x_gen)
print('FID (different): %.3f' % fid)

### Deactivated Latent Variables, Variance Loss and Variance Law


In [None]:
def count_deactivated_variables(z_var, treshold = 0.8):
    """Count latent dimensions whose mean variance (over axis 0) exceeds `treshold`.

    Args:
        z_var: array of variances, one row per sample and one column per latent dimension.
        treshold: a dimension is counted when its mean variance is strictly above this value.

    Returns:
        The number of counted dimensions.
    """
    return np.sum(np.mean(z_var, axis=0) > treshold)

def loss_variance(x_true, x_recon):
    """Absolute gap between the mean per-sample variance of two batches.

    Each sample is flattened to a vector, its variance is taken over its own
    entries, and the variances are averaged over the batch; the function
    returns ``|mean_var(x_true) - mean_var(x_recon)|``.
    """
    def _mean_sample_variance(batch):
        flat = np.reshape(batch, (-1, np.prod(batch.shape[1:])))
        return np.mean(np.var(flat, axis=1), axis=0)

    var_true = _mean_sample_variance(x_true)
    var_recon = _mean_sample_variance(x_recon)

    return np.abs(var_true - var_recon)

########################################################################################################################
# SHOW THE RESULTS
########################################################################################################################

# The encoder returns [z, z_mean, z_log_var]; only the posterior parameters are
# needed here. NOTE: this rebinds z_mean / z_log_var to numpy arrays.
z_mean, z_log_var = encoder.predict(x_test, batch_size=batch_size)[1:]
z_var = np.exp(z_log_var)
n_deact = count_deactivated_variables(z_var)
print('We have a total of ', latent_dim, ' latent variables. ', n_deact, ' of them are deactivated')

var_law = np.mean(np.var(z_mean, axis=0) + np.mean(z_var, axis=0))
print('Variance law has a value of: ', var_law)

# Reconstruct the same images the variance is compared against (the test set,
# not the training set).
x_recon = vae.predict(x_test, batch_size=batch_size)
# The decoder output is in [0, 1]; if x_test still holds raw 8-bit values,
# compare on the same scale.
x_test_unit = x_test / 255. if x_test.max() > 1. else x_test
print('We lost ', loss_variance(x_test_unit, x_recon), 'Variance of the original data')