<a href="https://colab.research.google.com/github/TivoGatto/Thesis/blob/master/DRAW/DRAW_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# LIBRARIES
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

import keras
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Lambda, LSTM, Activation, Add, Reshape, Concatenate, Multiply, Flatten
import keras.backend as K

In [None]:
# PARAMETERS
# Sizes of the flattened image, of the LSTM layers and of the latent vector drawn at each cycle.
input_dim = 32 ** 2  # a 32 x 32 image, flattened
intermediate_dim = 256
latent_dim = 100

# DRAW schedule.
T = 20  # Number of cycles
N = 5   # Dimension of the attention window

# Number of effective latent variables: latent_dim fresh variables at each of the T cycles.
total_latent_dimension = T * latent_dim

# Training setup.
epochs, batch_size = 20, 100

# CONSTANT
ATTENTION = False

In [None]:
# FUNCTIONS

def _kl_regularizer():
    """Return the KL regularizer of every sample, summed over the T cycles.

    For q(z|x) = N(z_mean, diag(exp(z_log_var))) and p(z) = N(0, I):

        D_KL = 1/2 * sum_j (z_mean_j^2 + exp(z_log_var_j) - z_log_var_j - 1)

    Reads the module-level lists `z_mean` and `z_log_var` (entries 1..T), which
    are filled while the model graph is built.
    """
    reg_loss = 0
    for t in range(1, T + 1):
        # The "- 1" belongs inside the sum (one per latent dimension); subtracting a
        # single 0.5 per cycle left the reported KL off by a constant.
        reg_loss += 0.5 * K.sum(
            K.square(z_mean[t]) + K.exp(z_log_var[t]) - z_log_var[t] - 1,
            axis=-1,
        )
    return reg_loss


def draw_loss(x_true, x_pred):
    """Training loss of the DRAW model.

    x_true: batch of target images.
    x_pred: final canvas produced by the model, before the sigmoid.

    Returns the batch mean of the reconstruction term (binary cross-entropy, not
    MSE) plus the KL regularizer divided by `total_latent_dimension`.
    """
    x_pred = sigm(x_pred)  # the canvas is unbounded; squash it before the cross-entropy

    xent_loss = keras.losses.binary_crossentropy(x_true, x_pred)  # Reconstruction loss

    return K.mean(xent_loss + _kl_regularizer() / total_latent_dimension)


def Regularizer(x_true, x_pred):
    """Metric tracking D_KL(q(z|x) || p(z)) during training.

    Both arguments are ignored; they are only present because Keras metrics
    receive (y_true, y_pred). The value uses the same KL term and the same
    normalization as `draw_loss`.
    """
    return K.mean(_kl_regularizer()) / total_latent_dimension
 
def Reconstruction(x_true, x_pred):
    """Metric tracking the reconstruction term during training.

    Applies the module-level sigmoid layer `sigm` to the predicted canvas and
    returns the Keras binary cross-entropy against `x_true`.
    """
    squashed = sigm(x_pred)
    return keras.losses.binary_crossentropy(x_true, squashed)

def Conc(args):
    """Join the tensors listed in `args` along their last axis."""
    joined = K.concatenate(args, axis=-1)
    return joined

def Sampling(args):
    """Draw z from q(z|x) with the reparameterization trick.

    args: pair (z_mean, z_log_var).

        eps ~ N(0, I)
        z   = z_mean + exp(z_log_var / 2) * eps

    The noise has shape (rows of z_mean, latent_dim).
    """
    mu, log_var = args
    n_rows = K.shape(mu)[0]
    noise = K.random_normal(shape=(n_rows, latent_dim))
    std = K.exp(0.5 * log_var)

    return mu + std * noise

def Read(args):
    """Read operation of DRAW.

    args: triple (x, x_hat, h_dec). Without attention only x and x_hat are used
    and the result is their concatenation, built with the module-level `conc`
    layer. h_dec is accepted so that the call signature stays the same once an
    attention read is added.

    Raises NotImplementedError when ATTENTION is enabled: the attention read does
    not exist yet, and the previous `pass` made the function return None, which
    only failed later, far away from the cause.
    """
    x, x_hat, _h_dec = args

    if ATTENTION:
        raise NotImplementedError("Read with attention is not implemented")

    return conc([x, x_hat])

def Write(args):
    """Write operation of DRAW (attention variant only).

    args: the decoder output h_dec.

    The model wraps this function in a Lambda layer only when ATTENTION is
    enabled; without attention it uses a Dense(input_dim) layer instead, so in
    that case this function has nothing to compute and returns None, as before.

    Raises NotImplementedError when ATTENTION is enabled: the attention write
    does not exist yet, and the previous `pass` silently returned None.
    """
    if ATTENTION:
        raise NotImplementedError("Write with attention is not implemented")

    return None

def Add_time(args):
    """Insert a length-1 axis at position 1 of `args`; used to feed the LSTM layers."""
    with_time_axis = K.expand_dims(args, axis=1)
    return with_time_axis
 
def sigmoid(x):
    """Compute sigm(x) = 1 / (1 + exp(-x)) element-wise with NumPy.

    x: scalar or array-like (lists are accepted). Floating-point arrays keep
       their dtype; anything else is converted to float64.

    Returns an array of the same shape as x, or a NumPy scalar for scalar input.

    Only exp(-|x|) is ever evaluated, so large negative inputs no longer overflow
    inside np.exp (the direct formula emitted a RuntimeWarning for them).
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)

    e = np.exp(-np.abs(x))  # always in (0, 1]
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) -- the same function.
    out = np.where(x >= 0, 1 / (1 + e), e / (1 + e))

    return out[()] if out.ndim == 0 else out
 
def pad(x, d):
    """Zero-pad a batch of 2-D images into the top-left corner of a larger canvas.

    x: array of shape (size, h, w).
    d: tuple with the per-image shape of the result, e.g. (32, 32, 1).

    Returns a new zero-initialized array of shape (size,) + d whose [:h, :w]
    corner holds the input images (with a trailing axis of length 1 added).
    """
    n_images, rows, cols = x.shape

    canvas = np.zeros(shape=(n_images, ) + d)
    canvas[:, :rows, :cols] = x[:, :, :, np.newaxis]

    return canvas

In [None]:
# IMPORT DATA
# Load MNIST, zero-pad every digit to shape (32, 32, 1), divide the pixel values by 255,
# cast to float32 and flatten each image into a vector of length input_dim.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = (pad(x_train, (32, 32, 1)) / 255).astype('float32').reshape((-1, input_dim))
x_test = (pad(x_test, (32, 32, 1)) / 255).astype('float32').reshape((-1, input_dim))

print('x_train shape: ' + str(x_train.shape))
print('x_test shape: ' + str(x_test.shape))

x_train shape: (60000, 1024)
x_test shape: (10000, 1024)


In [None]:
# MODEL
# Builds the unrolled DRAW graph: T read -> encode -> sample -> decode -> write steps
# that accumulate into a canvas, then wraps the graph in the `vae` model and compiles it.
# NOTE: later cells index `vae.layers` by position, so the order in which layers are
# created and called here should not be changed.

# Wrap the helper functions in layers so they can be used in the Model.
sigm = Activation("sigmoid")
conc = Lambda(Conc)
sampling = Lambda(Sampling)
add_time = Lambda(Add_time)

# The read operation is the same Lambda in both branches; only the write operation differs:
# a Lambda around Write with attention, a Dense(input_dim) layer without.
if ATTENTION:
    read = Lambda(Read)
    write = Lambda(Write)
else:
    read = Lambda(Read)
    write = Dense(input_dim)


# Initialize Encoder and Decoder as LSTM layers. Each object is created once and called
# at every step of the loop below.
encoder = LSTM(intermediate_dim, stateful=False, return_state=True, name='Encoder')
decoder = LSTM(intermediate_dim, stateful=False, return_state=True, name='Decoder')


# Initialize (T + 1)-length lists that store the per-step tensors. Entries 1..T are filled in
# the loop below; entry 0 holds the initial value where one is assigned, otherwise it stays 0.
x_hat = [0] * (T + 1)
r     = [0] * (T + 1)
z     = [0] * (T + 1)

# z_mean and z_log_var are also read by draw_loss and Regularizer.
z_mean = [0] * (T + 1)
z_log_var = [0] * (T + 1)

h_enc = [0] * (T + 1)
h_dec = [0] * (T + 1)
C     = [0] * (T + 1)


c_enc = [0] * (T + 1)
c_dec = [0] * (T + 1)

# Constant all-zero inputs used as the initial third LSTM output (stored in c_enc / c_dec).
c_enc_init = Input(tensor=K.zeros(shape=(batch_size,  intermediate_dim)), name='c_enc')
c_dec_init = Input(tensor=K.zeros(shape=(batch_size,  intermediate_dim)), name='c_dec')

c_enc[0] = c_enc_init
c_dec[0] = c_dec_init

# Initialize Input layers. The batch size is fixed to batch_size for every input.
x = Input(shape=(input_dim, ), batch_shape=(batch_size, input_dim), name='Input_img')
h_enc_init = Input(tensor=K.zeros(shape=(batch_size,  intermediate_dim)), name='h_enc')
h_dec_init = Input(tensor=K.zeros(shape=(batch_size,  intermediate_dim)), name='h_dec')
C_0 = Input(tensor=K.zeros(shape=(batch_size, input_dim)), name='C_input')

# And assign them to the first element of our lists
h_enc[0] = h_enc_init
h_dec[0] = h_dec_init
# The all-zero C_0 input is passed through a Dense layer to obtain the initial canvas.
# NOTE(review): this Dense layer is presumably trainable, so the initial canvas would be its
# learned bias rather than exactly C_0 -- confirm this is intended.
C[0]     = Dense(input_dim, name='C_0')(C_0)

for t in range(1, T+1):
    x_hat[t] = keras.layers.Subtract(name='x_hat_'+str(t))([x, sigm(C[t-1])]) # Error Image: input minus sigmoid of the previous canvas
    r[t]     = read([x, x_hat[t], h_dec[t-1]])            # Information about x, x_hat, h_dec[t-1]

    # The encoder sees [r[t], h_dec[t-1]] as a length-1 sequence. Of its three outputs the first
    # is kept as h_enc[t], the third as c_enc[t]; the second is discarded.
    h_enc[t], _, c_enc[t] = encoder(add_time(conc([r[t], h_dec[t-1]])), initial_state=[h_enc[t-1], c_enc[t-1]]) # Encoded x

    # A new pair of Dense layers is created at every step, so these projections are not shared across steps.
    z_mean[t] = Dense(latent_dim, name='z_mean_'+str(t))(h_enc[t]) 
    z_log_var[t] = Dense(latent_dim, name='z_log_var_'+str(t))(h_enc[t])

    z[t] = sampling([z_mean[t], z_log_var[t]])

    h_dec[t], _, c_dec[t] = decoder(add_time(z[t]), initial_state=[h_dec[t-1], c_dec[t-1]]) # Decoded z

    C[t] = Add(name='C_'+str(t))([C[t-1], write(h_dec[t])]) # Update the canvas

# The model output is the final canvas C[T] without a sigmoid; draw_loss and Reconstruction apply it.
vae = Model([x, C_0, h_dec_init, h_enc_init, c_enc_init, c_dec_init], C[T])

optmizers = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.5)
vae.compile(optimizer=optmizers, loss=draw_loss, metrics=[Reconstruction, Regularizer])

In [None]:
# Fit model
# The images are both input and target; 10% of the training data is held out for validation.
# NOTE(review): the epoch count is hard-coded to 50 here and does not use the `epochs`
# value defined with the other parameters.
hist = vae.fit(
    x_train,
    x_train,
    batch_size=batch_size,
    epochs=50,
    verbose=1,
    validation_split=0.1,
)

# Generation and Reconstruction

In [None]:
# TEST RECONSTRUCTION
# Reconstruct the MNIST digits of x_train. The model returns the canvas before the sigmoid,
# so sigmoid(x) = 1 / (1 + exp(-x)) is applied to the prediction.
x_recon = sigmoid(vae.predict(x_train, batch_size=batch_size))

digit_size = 32  # Each image is assumed to be digit_size x digit_size
n = 5            # Number of images shown

# One row per digit: original on the left, reconstruction on the right.
figure = np.zeros(shape=(digit_size * n, digit_size * 2))
for i in range(n):
    top, bottom = i * digit_size, (i + 1) * digit_size

    X_true = x_train[i].reshape((digit_size, digit_size))
    X_recon = x_recon[i].reshape((digit_size, digit_size))

    figure[top:bottom, :digit_size] = X_true
    figure[top:bottom, digit_size:] = X_recon

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.show()

In [None]:
# RE-DEFINE MODELS NEEDED FOR GENERATION
# The trained weights are read from the layer objects themselves (`decoder`, `write`)
# instead of from positions in `vae.layers`, which silently point at other layers as
# soon as the graph changes.
WEIGHTS = decoder.get_weights()  # Weights of the trained LSTM Decoder
LSTM_decoder = LSTM(intermediate_dim, stateful=False, return_state=True)    # Fresh LSTM layer that will receive them

# DEFINE DECODER MODEL
# `batch_shape` alone fully specifies each input; the former `shape=(latent_dim)` argument
# was a plain int (no trailing comma) and redundant.
z_in = Input(batch_shape=(batch_size, latent_dim))
h_in = Input(batch_shape=(batch_size, intermediate_dim))
c_in = Input(batch_shape=(batch_size, intermediate_dim))

# h_out and c_out are the first and third outputs of the LSTM, as in the training graph.
h_out, _, c_out = LSTM_decoder(add_time(z_in), initial_state=[h_in, c_in])
DECODER = Model([z_in, h_in, c_in], h_out)

# Load WEIGHTS into the new layer (it has been built by the call above).
LSTM_decoder.set_weights(WEIGHTS)

# DEFINE WRITE MODEL
h_in = Input(batch_shape=(batch_size, intermediate_dim))
write_copy = Dense(input_dim)
x_out = write_copy(h_in)
WRITE = Model(h_in, x_out)

write_copy.set_weights(write.get_weights()) # Load the weights of the trained write layer

In [None]:
# GENERATE NEW IMAGES WITH PRIOR p(z) = N(0, I)

# DECODER only returns h, so the second LSTM state was never carried from one step to
# the next (c_dec_gen stayed at zero), unlike in the training graph. This model shares
# the graph, and therefore the weights, of DECODER but returns both states.
DECODER_STEP = Model(DECODER.inputs, [h_out, c_out])

h_dec_gen = np.zeros(shape=(batch_size, intermediate_dim)) # h_dec_0 for generation
c_dec_gen = np.zeros(shape=(batch_size, intermediate_dim)) # c_dec_0 for generation

# NOTE(review): training starts from the output of the 'C_0' Dense layer, generation from an
# all-zero canvas -- confirm whether the learned initial canvas should be used here.
C = [0] * (T + 1)                  # Canvas
C[0] = np.zeros(shape=(input_dim)) # Initialize the canvas with a zero-valued image
for t in range(1, T + 1):
    Z = np.random.normal(size=(batch_size, latent_dim))
    h_dec_gen, c_dec_gen = DECODER_STEP.predict([Z, h_dec_gen, c_dec_gen], batch_size=batch_size)

    # WRITE was built with a fixed batch size, so predict with that same batch size.
    C[t] = C[t-1] + WRITE.predict(h_dec_gen, batch_size=batch_size)

# Now C is a list of length T + 1 where:
#     C[t] is the canvas at time t, t = 0, ..., T
#     C[t], for t >= 1, is a matrix of dimension (batch_size, input_dim) that contains
#          a batch_size number of generated images at time t
#
# We want to show these results: one row per image, one column per time step.
# Column 0 is never written and stays black.

digit_size = 32
n = 10

figure = np.zeros((digit_size * n, digit_size * (T + 1)))
for t in range(1, T + 1):
    C[t] = sigmoid(C[t])
    C[t] = np.reshape(C[t], (-1, digit_size, digit_size))

    for i in range(n):
        figure[i * digit_size : (i+1) * digit_size, t * digit_size : (t+1)*digit_size] = C[t][i]

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.show()

# Metrics Evaluation

First, we want to evaluate the model's ability to generate high-quality samples.

In [None]:
# FID Score
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf

from scipy.linalg import sqrtm
from skimage.transform import resize

from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets.mnist import load_data

# Functions needed to compute FID score
def scale_images(images, new_shape):
    """Resize every image in `images` to `new_shape` and return them as one array.

    Each image is passed to skimage's `resize` with 0 as third argument
    (nearest-neighbor interpolation).
    """
    return np.asarray([resize(image, new_shape, 0) for image in images])


def calculate_fid(model, images1, images2):
    """Frechet Inception Distance between two sets of images.

    model: object whose `predict` maps a batch of images to a 2-D array of
           activations (one row per image).
    images1, images2: the two batches to compare.

    Returns ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 . sigma2)),
    where mu and sigma are the mean and covariance of the activations.
    """
    def _activation_stats(images):
        # Mean vector and covariance matrix (features in columns) of the activations.
        activations = model.predict(images)
        return activations.mean(axis=0), np.cov(activations, rowvar=False)

    mu1, sigma1 = _activation_stats(images1)
    mu2, sigma2 = _activation_stats(images2)

    mean_term = np.sum((mu1 - mu2) ** 2.0)

    covmean = sqrtm(np.dot(sigma1, sigma2))
    if np.iscomplexobj(covmean):
        # sqrtm may return a complex matrix; only its real part is used.
        covmean = covmean.real

    return mean_term + np.trace(sigma1 + sigma2 - 2.0 * covmean)

sample_size = 10000
image_side = 32  # the flattened images are image_side x image_side


def _generate_samples(n_samples):
    """Generate `n_samples` images with values in (0, 1) from the prior p(z) = N(0, I).

    `decoder` is a bare LSTM layer and has no `predict`; a sample of this model is
    obtained by running T decoder/write steps, each with a fresh z, and applying
    the sigmoid to the final canvas. Returns an array of shape
    (n_samples, input_dim).
    """
    # Same graph (and weights) as DECODER, but returning both LSTM states.
    step = Model(DECODER.inputs, [h_out, c_out])

    n_batches = -(-n_samples // batch_size)  # ceiling division
    generated = []
    for _ in range(n_batches):
        h_state = np.zeros(shape=(batch_size, intermediate_dim))
        c_state = np.zeros(shape=(batch_size, intermediate_dim))
        canvas = np.zeros(shape=(batch_size, input_dim))
        for _cycle in range(T):
            z_t = np.random.normal(size=(batch_size, latent_dim))
            h_state, c_state = step.predict([z_t, h_state, c_state], batch_size=batch_size)
            canvas = canvas + WRITE.predict(h_state, batch_size=batch_size)
        generated.append(sigmoid(canvas))

    return np.concatenate(generated, axis=0)[:n_samples]


sample = np.random.randint(0, len(x_test), size=sample_size)

# Both sets are stored flattened; restore the (height, width, channel) layout before resizing.
x_gen = _generate_samples(sample_size).reshape((-1, image_side, image_side, 1))
x_real = x_test[sample].reshape((-1, image_side, image_side, 1))

# scale_images and calculate_fid are defined in this cell; there is no `evaluate` module.
# NOTE(review): sample_size float images of 299 x 299 need a lot of memory -- consider batching.
x_gen = scale_images(x_gen, (299, 299, 1))
x_real = scale_images(x_real, (299, 299, 1))
print('Scaled', x_gen.shape, x_real.shape)

# The images are in [0, 1]; bring them to [0, 255], the range preprocess_input expects.
x_gen_t = preprocess_input(x_gen * 255.0)
x_real_t = preprocess_input(x_real * 255.0)

# InceptionV3 needs three channels: replicate the single gray channel.
x_gen = np.repeat(x_gen_t, 3, axis=-1)
x_real = np.repeat(x_real_t, 3, axis=-1)
print('Final', x_gen.shape, x_real.shape)

# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

# fid between real and generated images
fid = calculate_fid(model, x_real, x_gen)
print('FID (different): %.3f' % fid)

### Deactivated Latent Variables, Variance Loss and Variance Law


In [None]:
def count_deactivated_variables(z_var, treshold = 0.8):
    """Count the latent variables whose variance, averaged over axis 0, exceeds `treshold`."""
    mean_variance = np.mean(z_var, axis=0)
    above = mean_variance > treshold

    return above.sum()

def loss_variance(x_true, x_recon):
    """Absolute difference between the average per-sample variance of two batches.

    Each batch is flattened to (samples, features); the variance is taken over
    the features of every sample and then averaged over the samples.
    """
    def _mean_sample_variance(batch):
        flat = np.reshape(batch, (-1, np.prod(batch.shape[1:])))
        return np.mean(np.var(flat, axis=1), axis=0)

    return np.abs(_mean_sample_variance(x_true) - _mean_sample_variance(x_recon))

########################################################################################################################
# SHOW THE RESULTS
########################################################################################################################

# `encoder` is a bare LSTM layer and has no `predict`. The posterior parameters of all T
# cycles are tensors of the `vae` graph, so a model with the same inputs can return them.
# The results get their own names: the module-level lists z_mean / z_log_var are read by
# draw_loss and Regularizer and must not be overwritten.
latent_stats = Model(vae.inputs, z_mean[1:] + z_log_var[1:])
stats = latent_stats.predict(x_test, batch_size=batch_size)
z_mean_test = np.concatenate(stats[:T], axis=-1)     # (samples, T * latent_dim)
z_log_var_test = np.concatenate(stats[T:], axis=-1)  # (samples, T * latent_dim)

z_var = np.exp(z_log_var_test)
n_deact = count_deactivated_variables(z_var)
print('We have a total of ', total_latent_dimension, ' latent variables. ', n_deact, ' of them are deactivated')

var_law = np.mean(np.var(z_mean_test, axis=0) + np.mean(z_var, axis=0))
print('Variance law has a value of: ', var_law)

# Compare x_test with its own reconstruction; the model returns the canvas before the
# sigmoid, so the sigmoid is applied to obtain images.
x_recon = sigmoid(vae.predict(x_test, batch_size=batch_size))
print('We lost ', loss_variance(x_test, x_recon), 'Variance of the original data')

### Latent space matching

In [None]:
# We want to verify if q(z) = p(z).

# Moments Matching
# Generate samples from q(z) and from p(z)
# p(z) = N(0, I)
# q(z) = E_q(x)[q(z|x)]
#
# For every moment we compare the log-moments
from scipy.stats import moment

n = len(x_test)

# `encoder` is a bare LSTM layer and has no `predict`. The sampled latent variables of all
# T cycles are tensors of the `vae` graph (the list `z`), so a model with the same inputs
# returns one joint sample of q(z|x) per test image.
latent_sampler = Model(vae.inputs, z[1:])
q_samples = np.concatenate(latent_sampler.predict(x_test, batch_size=batch_size), axis=-1)

# Draw the prior samples with the same shape, (n, T * latent_dim), so both sets are comparable.
p_samples = np.random.normal(size=q_samples.shape)


def _log_mean_moment(samples, order):
    """Log of the average, over the latent dimensions, of the central moment of given order."""
    return np.log(np.mean(moment(samples, moment=order, axis=0)))


# NOTE(review): scipy's `moment` returns central moments. The first central moment is
# identically 0 and the third can be negative, so their logarithms are -inf / nan for
# both distributions; only the second-moment comparison is informative as written.

# First moment matching:
p_first_moment = _log_mean_moment(p_samples, 1)
q_first_moment = _log_mean_moment(q_samples, 1)

print("\n")
print("First log-moment of p(z): " + str(p_first_moment))
print("First log-moment of q(z): " + str(q_first_moment))
print("\n")

# Second moment matching:
p_second_moment = _log_mean_moment(p_samples, 2)
q_second_moment = _log_mean_moment(q_samples, 2)

print("\n")
print("Second log-moment of p(z): " + str(p_second_moment))
print("Second log-moment of q(z): " + str(q_second_moment))
print("\n")

# Third moment matching:
p_third_moment = _log_mean_moment(p_samples, 3)
q_third_moment = _log_mean_moment(q_samples, 3)

print("\n")
print("Third log-moment of p(z): " + str(p_third_moment))
print("Third log-moment of q(z): " + str(q_third_moment))
print("\n")