<a href="https://colab.research.google.com/github/TivoGatto/Thesis/blob/master/RAE/RAE_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# LIBRARIES
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, ReLU, Dense, Flatten, Reshape, Conv2DTranspose, Lambda
from keras.datasets import mnist
from keras.regularizers import l2
import keras.backend as K

Using TensorFlow backend.


In [None]:
# Parameters
input_dim = (32, 32, 1)  # image shape fed to the encoder; MNIST digits are zero-padded to this size
latent_dim = 16  # width of the latent code z produced by the encoder

epochs = 100  # number of training epochs passed to fit()
batch_size = 100  # mini-batch size used for training and prediction

lamb = 0.05  # L2 coefficient applied to the decoder kernels
beta = 1  # weight of the latent-norm penalty inside rae_loss

In [None]:
# Functions
def rae_loss(z):
    """Build the training loss for the regularized autoencoder.

    The returned function has the Keras ``(x_true, x_pred)`` loss signature
    and closes over ``z``, the latent tensor, so the penalty on the latent
    code can be added to the reconstruction error.
    """
    def loss(x_true, x_pred):
        # Flatten every image to a vector of np.prod(input_dim) values.
        flat_shape = (-1, np.prod(input_dim))
        residual = K.reshape(x_true, flat_shape) - K.reshape(x_pred, flat_shape)

        # Half squared L2 error per sample, and half squared L2 norm of z.
        reconstruction = 0.5 * K.sum(K.square(residual), axis=-1)
        latent_penalty = 0.5 * K.sum(K.square(z), axis=-1)

        # Average over the batch; beta weights the latent penalty.
        return K.mean(reconstruction + beta * latent_penalty)
    return loss

def recon(x_true, x_pred):
    """Metric: batch mean of the half squared L2 reconstruction error."""
    # Flatten every image to a vector of np.prod(input_dim) values.
    flat_shape = (-1, np.prod(input_dim))
    residual = K.reshape(x_true, flat_shape) - K.reshape(x_pred, flat_shape)

    per_sample_error = 0.5 * K.sum(K.square(residual), axis=-1)
    return K.mean(per_sample_error)

def RAE(z):
    """Build a metric reporting the batch mean of the half squared L2 norm of ``z``.

    The inner function takes ``(x_true, x_pred)`` only to match the Keras
    metric signature; neither argument is used.
    """
    def rae(x_true, x_pred):
        latent_penalty = 0.5 * K.sum(K.square(z), axis=-1)
        return K.mean(latent_penalty)
    return rae

def pad(x, d):
    """Zero-pad a stack of single-channel images to shape ``d``.

    ``x`` is a 3-D array ``(count, height, width)`` and ``d`` is the target
    per-image shape as a tuple, e.g. ``(32, 32, 1)``. Each image is placed in
    the top-left corner of a zero-filled canvas, so the padding is added on
    the bottom and right edges. Returns a new array of shape ``(count,) + d``.
    """
    count, height, width = x.shape

    # Start from an all-zero canvas and copy the images into its corner.
    canvas = np.zeros(shape=(count,) + d)
    canvas[:, :height, :width] = np.reshape(x, (count, height, width, 1))

    return canvas

In [None]:
# Dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Zero-pad the MNIST images to input_dim, i.e. shape (32, 32, 1), divide the
# pixel values by 255 and store the result as float32.
x_train = (pad(x_train, input_dim) / 255).astype('float32')
x_test = (pad(x_test, input_dim) / 255).astype('float32')

print('x_train shape: ' + str(x_train.shape))
print('x_test shape: ' + str(x_test.shape))

In [None]:
# Model Architecture
# ENCODER
x = Input(shape=input_dim) # Shape (32, 32, 1)

# Four stride-2 convolution stages, each followed by batch normalisation and
# ReLU. The spatial size halves at every stage:
# (16, 16, 128) -> (8, 8, 256) -> (4, 4, 512) -> (2, 2, 1024)
h = x
for n_filters in (128, 256, 512, 1024):
    h = Conv2D(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

# Flatten the final feature map and project it linearly to the latent code.
h = Flatten()(h)
z = Dense(latent_dim)(h)

encoder = Model(x, z)

# DECODER
z_in = Input(shape=(latent_dim, ))

# Project the latent code to an (8, 8, 1024) feature map.
h = Dense(8 * 8 * 1024, kernel_regularizer=l2(lamb))(z_in)
h = Reshape((8, 8, 1024))(h)
h = BatchNormalization()(h)
h = ReLU()(h)

# Two stride-2 transposed convolutions, each followed by batch normalisation
# and ReLU: (8, 8, 1024) -> (16, 16, 512) -> (32, 32, 256)
for n_filters in (512, 256):
    h = Conv2DTranspose(n_filters, 4, strides=(2, 2), padding='same', kernel_regularizer=l2(lamb))(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

# Final stride-1 transposed convolution maps to a single output channel; no
# activation is applied.
x_decoded = Conv2DTranspose(1, 4, strides=(1, 1), padding='same', kernel_regularizer=l2(lamb))(h)

decoder = Model(z_in, x_decoded)

# RAE: encoder and decoder chained end to end. The model variable keeps the
# name `vae` because the fit and reconstruction code refers to it by that name.
x_recon = decoder(z)

vae = Model(x, x_recon)

# Compile the model
# Adam is taken from the same standalone `keras` package as Model and the
# layers; combining standalone-Keras models with `tensorflow.keras`
# optimizers is not supported.
from keras.optimizers import Adam
optimizer = Adam(lr=0.001)

# Loss: reconstruction error plus beta-weighted latent penalty. The two terms
# are also reported separately as metrics.
vae.compile(optimizer=optimizer, loss=rae_loss(z), metrics=[recon, RAE(z)])

In [None]:
# Fit model
hist = vae.fit(
    x_train,
    x_train,  # autoencoder: the input is also the reconstruction target
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_split=0.1,  # hold out 10% of x_train for validation
)

In [None]:
# Learn latent space distribution
# Encode the training set and fit a density to the latent codes; that density
# is what generation samples from afterwards.
z_train = encoder.predict(x_train)

prior_for_qz = "Gaussian" # Choose between GMM or Gaussian
if prior_for_qz == "GMM":
    from sklearn.mixture import GaussianMixture

    # 10-component Gaussian mixture fitted on the training latent codes.
    z_density = GaussianMixture(n_components=10, max_iter=100)
    z_density.fit(z_train)

    print("Learned GMM")
elif prior_for_qz == "Gaussian":
    from scipy.stats import norm

    # norm.fit returns a single (mean, std) pair estimated from z_train, so
    # the fitted Gaussian uses the same two scalars for every latent dimension.
    mean, std = norm.fit(z_train)
    print("Learned Gaussian")
else:
    print("Distribution not found")

# Generation and Reconstruction

In [None]:
# Reconstruction
# Show n randomly chosen test digits (top row) above their reconstructions
# (bottom row) in a single image.
n = 10
digit_size = input_dim[0]
image_stack_shape = (-1, digit_size, digit_size)

# Drop the channel axis so each image can be pasted into the 2-D figure.
x_recon = np.reshape(vae.predict(x_test, batch_size=batch_size), image_stack_shape)
x_test = np.reshape(x_test, image_stack_shape)
figure = np.zeros((2 * digit_size, n * digit_size))

for i in range(n):
    sample = np.random.randint(0, len(x_recon))
    columns = slice(i * digit_size, (i + 1) * digit_size)
    figure[:digit_size, columns] = x_test[sample]
    figure[digit_size:, columns] = x_recon[sample]

# Restore the channel axis on x_test for the code that runs afterwards.
x_test = np.reshape(x_test, (-1, ) + input_dim)

plt.style.use('default')
plt.imshow(figure, cmap='gray')
plt.show()

In [None]:
# Generation
n = 10  # the figure is an n x n grid of generated digits
digit_size = input_dim[0]  # side length of one (square) digit image
figure = np.zeros((digit_size * n, digit_size * n))

# Draw n**2 latent codes from the density fitted on the training latents.
if prior_for_qz == "GMM":
    # GaussianMixture.sample returns (samples, component_labels); only the
    # samples are latent codes.
    z_sample = z_density.sample(n**2)[0]
elif prior_for_qz == "Gaussian":
    z_sample = np.random.normal(size=(n**2, latent_dim), loc=mean, scale=std)
else:
    # Fail here rather than with a NameError (or a stale z_sample) below.
    raise ValueError("Distribution not found: %r" % (prior_for_qz,))

# Decode every latent code in one batched call instead of n**2 separate
# single-sample predict calls.
x_decoded = decoder.predict(z_sample, batch_size=batch_size)

for i in range(n):
    for j in range(n):
        digit = x_decoded[i + n * j].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
            j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.show()

# Metrics Evaluation

In [None]:
# FID Score
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf

from scipy.linalg import sqrtm
from skimage.transform import resize

from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets.mnist import load_data

# Functions needed to compute FID score
def scale_images(images, new_shape):
    """Resize every image in ``images`` to ``new_shape``.

    Each image is resized with ``resize`` using interpolation order 0
    (nearest neighbour); the results are stacked into one array.
    """
    # Third positional argument of resize is the interpolation order.
    return np.asarray([resize(image, new_shape, 0) for image in images])


def calculate_fid(model, images1, images2):
    """Return the Frechet Inception Distance between two image sets.

    ``model.predict`` is applied to each set to obtain feature activations.
    The distance is computed from the mean and covariance of the two
    activation sets: squared distance between the means plus
    ``trace(cov1 + cov2 - 2 * sqrt(cov1 @ cov2))``.
    """
    # Feature activations for both image sets.
    features_a = model.predict(images1)
    features_b = model.predict(images2)

    # First and second moments of each activation set (rows are samples).
    mean_a = features_a.mean(axis=0)
    cov_a = np.cov(features_a, rowvar=False)
    mean_b = features_b.mean(axis=0)
    cov_b = np.cov(features_b, rowvar=False)

    mean_term = np.sum((mean_a - mean_b) ** 2.0)

    # The matrix square root can come back complex; keep only the real part.
    cov_sqrt = sqrtm(cov_a.dot(cov_b))
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    return mean_term + np.trace(cov_a + cov_b - 2.0 * cov_sqrt)

sample_size = 10000  # number of generated and of real images compared

# Draw latent codes from the density fitted on the training latents.
if prior_for_qz == "GMM":
    # GaussianMixture.sample returns (samples, component_labels); only the
    # samples are latent codes.
    z_sample = z_density.sample(sample_size)[0]
elif prior_for_qz == "Gaussian":
    # np.random.normal takes the mean through `loc`; it has no `mean` keyword.
    z_sample = np.random.normal(size=(sample_size, latent_dim), loc=mean, scale=std)
else:
    # Fail here rather than with a NameError (or a stale z_sample) below.
    raise ValueError("Distribution not found: %r" % (prior_for_qz,))

# Generated images, and an equally sized random draw (with replacement) of
# real test images.
sample = np.random.randint(0, len(x_test), size=sample_size)
x_gen = decoder.predict(z_sample)
x_real = x_test[sample]

# Upscale both sets to the 299x299 input size used by the Inception model
# below. scale_images is the helper defined in this file.
x_gen = scale_images(x_gen, (299, 299, 1))
x_real = scale_images(x_real, (299, 299, 1))
print('Scaled', x_gen.shape, x_real.shape)

# NOTE(review): keras' inception_v3.preprocess_input is documented for pixel
# values in [0, 255], while x_real is in [0, 1] here (x_test was divided by
# 255) -- confirm whether a rescale is needed before comparing these FID
# values with other work.
x_gen_t = preprocess_input(x_gen)
x_real_t = preprocess_input(x_real)

# Copy the single grey channel into three identical channels, as required by
# the (299, 299, 3) model input.
x_gen = np.repeat(x_gen_t, 3, axis=-1)
x_real = np.repeat(x_real_t, 3, axis=-1)
print('Final', x_gen.shape, x_real.shape)

# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

# fid between real and generated images; calculate_fid is the helper defined
# in this file.
fid = calculate_fid(model, x_real, x_gen)
print('FID (different): %.3f' % fid)