<a href="https://colab.research.google.com/github/TivoGatto/VAEs/blob/master/beta_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# LIBRARIES
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Lambda, Reshape, BatchNormalization, ReLU, Flatten
from keras.optimizers import Adam
import keras.backend as K

Using TensorFlow backend.


In [0]:
# Parameters

# Shape of one input image (height, width, channels); used for the encoder
# Input layer and to reshape the dataset.
input_dim = (32, 32, 1)
# Width of the dense layer that follows the flattened encoder features.
intermediate_dim = 256
# Dimensionality of the latent code z (size of z_mean and z_log_var).
latent_dim = 4

# Mini-batch size passed to fit() and predict().
batch_size = 200
# Number of training epochs passed to fit().
epochs = 20
# Standard deviation of the Gaussian noise drawn in sampling().
epsilon_std = 1.0

# Weight of the |KL - C| penalty in vae_loss.
beta = 3
# Target value the batch-mean KL term is pulled towards in vae_loss.
C = 3

# SAVING OPTIONS
# When True, the trained models are written to .h5 files after training.
SAVE_MODEL = True
# When True, architecture diagrams are rendered to .png files with plot_model.
PRINT_MODEL = True

In [0]:
# Functions

def vae_loss(x_true, x_pred):
    """Training objective: reconstruction error plus beta * |KL - C|.

    The reconstruction term is half the per-sample sum of squared errors
    (not a cross-entropy), averaged over the batch.  The KL term is the
    divergence of the diagonal Gaussian described by the module-level
    tensors ``z_mean`` / ``z_log_var`` from a standard normal, also averaged
    over the batch.

    Args:
        x_true: batch of target images.
        x_pred: batch of decoder outputs, same shape as ``x_true``.

    Returns:
        Scalar loss tensor.
    """
    flat_dim = np.prod(input_dim)
    target = K.reshape(x_true, (-1, flat_dim))
    output = K.reshape(x_pred, (-1, flat_dim))

    squared_error = K.sum(K.square(target - output), axis=-1)
    recon_term = K.mean(0.5 * squared_error)

    kl_per_sample = 0.5 * K.sum(
        K.square(z_mean) + K.exp(z_log_var) - 1 - z_log_var, axis=-1
    )
    kl_term = K.mean(kl_per_sample)

    # The penalty is on the distance of the batch-mean KL from the target C.
    penalty = beta * K.abs(kl_term - C)
    return K.mean(recon_term + penalty)
 
def reconstruction(x_true, x_pred):
    """Metric: batch mean of half the per-sample sum of squared errors.

    Args:
        x_true: batch of target images.
        x_pred: batch of decoder outputs, same shape as ``x_true``.

    Returns:
        Scalar tensor with the reconstruction error.
    """
    n_pixels = np.prod(input_dim)
    residual = K.reshape(x_true, (-1, n_pixels)) - K.reshape(x_pred, (-1, n_pixels))
    per_sample = 0.5 * K.sum(K.square(residual), axis=-1)
    return K.mean(per_sample)

def regularizer(x_true, x_pred):
    """Metric: batch-mean KL divergence of the encoder posterior from N(0, I).

    Both arguments are ignored; they are present only so the function has
    the ``(y_true, y_pred)`` signature used for the metrics passed to
    ``compile``.  The value is computed from the module-level tensors
    ``z_mean`` and ``z_log_var``.
    """
    kl_elements = K.square(z_mean) + K.exp(z_log_var) - 1 - z_log_var
    kl_batch_mean = K.mean(0.5 * K.sum(kl_elements, axis=-1))
    return K.mean(kl_batch_mean)

def sampling(args):
    """Reparameterisation trick: draw z = mean + std * epsilon.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors with shape
            ``(batch, latent)``.

    Returns:
        Tensor of the same shape as ``z_mean`` holding one sample per row.
    """
    z_mean, z_log_var = args

    # Take the batch size from the tensor at run time instead of the global
    # `batch_size`, so the layer also works for a final partial batch or for
    # predict() calls that use a different batch size.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0, stddev=epsilon_std)

    return z_mean + K.exp(0.5 * z_log_var) * epsilon

def pad(x, n):
    """Zero-pad a stack of 2-D images to ``n`` x ``n``.

    Each image is placed in the top-left corner of its ``n`` x ``n`` canvas;
    the remaining rows and columns are zeros.

    Args:
        x: array-like of shape ``(N, height, width)``.
        n: side length of the padded output; must be at least ``height``
            and ``width``.

    Returns:
        ``numpy.ndarray`` of shape ``(N, n, n)`` and dtype float64.

    Raises:
        ValueError: if ``x`` is not 3-dimensional or an image does not fit
            into an ``n`` x ``n`` canvas.
    """
    x = np.asarray(x)
    if x.ndim != 3:
        raise ValueError(
            "pad expects an array of shape (N, height, width), got shape %s" % (x.shape,)
        )

    count, height, width = x.shape
    if height > n or width > n:
        raise ValueError(
            "cannot pad images of size %dx%d to %dx%d" % (height, width, n, n)
        )

    data = np.zeros(shape=(count, n, n))
    # A single slice assignment copies every image at once.
    data[:, :height, :width] = x

    return data

In [0]:
# Model
# ENCODER
x = Input(shape=(input_dim))

# Three stride-2 convolution stages (32 -> 64 -> 128 filters), each followed
# by batch normalisation and a ReLU.
h = x
for n_filters in (32, 64, 128):
    h = Conv2D(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

# Keep the feature-map shape so the decoder can reshape back to it.
shape_before_flattening = K.int_shape(h)
h = Flatten()(h)

h = Dense(intermediate_dim, activation='relu')(h)

# Two linear heads: mean and log-variance of the latent Gaussian.
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

# Sampled latent code produced by the reparameterisation layer.
z = Lambda(sampling)([z_mean, z_log_var])

# The encoder exposes the sample as well as the distribution parameters.
encoder = Model(x, [z, z_mean, z_log_var])

# DECODER
z_in = Input(shape=(latent_dim, ))

# Project the latent code back to the encoder's last feature-map shape.
feature_shape = shape_before_flattening[1:]
h = Dense(np.prod(feature_shape), activation='relu')(z_in)
h = Reshape(feature_shape)(h)

# Three stride-2 transposed-convolution stages (128 -> 64 -> 32 filters),
# each followed by batch normalisation and a ReLU.
for n_filters in (128, 64, 32):
    h = Conv2DTranspose(n_filters, 4, strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ReLU()(h)

# Final stride-1 layer maps to a single channel squashed by a sigmoid.
x_recon = Conv2DTranspose(1, 4, strides=(1, 1), padding='same', activation='sigmoid')(h)

decoder = Model(z_in, x_recon)

# VAE
# End-to-end model: input image -> sampled latent code -> reconstruction.
x_pred = decoder(z)
vae = Model(inputs=x, outputs=x_pred)

optimizer = Adam(lr=5e-4)
# The two metrics report the reconstruction and KL terms separately.
vae.compile(
    loss=vae_loss,
    optimizer=optimizer,
    metrics=[reconstruction, regularizer],
)

In [0]:
# Dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# For each split: scale pixel values by 1/255, zero-pad every image to 32x32,
# then add the trailing channel axis expected by the model.
x_train = np.reshape(pad(x_train.astype('float32')/255, 32), (-1, ) + input_dim)
x_test = np.reshape(pad(x_test.astype('float32')/255, 32), (-1, ) + input_dim)

print("X_train shape: ", x_train.shape)
print("X_test shape: ", x_test.shape)

X_train shape:  (60000, 32, 32, 1)
X_test shape:  (10000, 32, 32, 1)


In [0]:
# TRAIN
# The VAE is trained to reproduce its own input.
hist = vae.fit(x_train, x_train, batch_size=batch_size, epochs=epochs)

if SAVE_MODEL:
    from google.colab import drive
    drive.mount('/content/drive')

    # NOTE(review): Drive is mounted above, but the paths below are relative,
    # so the files are written to the current working directory - confirm
    # this is the intended location.
    for net, path in (
        (vae, 'beta_vae.h5'),
        (encoder, 'beta_vae_encoder.h5'),
        (decoder, 'beta_vae_decoder.h5'),
    ):
        net.save(path, overwrite=True)

if PRINT_MODEL:
    from keras.utils import plot_model

    # Render one architecture diagram per model.
    for net, path in (
        (vae, 'beta_vae.png'),
        (encoder, 'beta_vae_encoder.png'),
        (decoder, 'beta_vae_decoder.png'),
    ):
        plot_model(net, path, show_shapes=True)

Epoch 1/20
 4800/60000 [=>............................] - ETA: 7:56 - loss: 110.6136 - reconstruction: 75.6019 - regularizer: 14.4786

KeyboardInterrupt: ignored

Ho addestrato il modello per 20 epoche e noto che inizialmente il reconstruction error scende, mentre il regularizer rimane praticamente fermo (attorno a 4). Dopo alcune epoche, anche il regularizer inizia a scendere velocemente verso zero.

In [0]:
# Encode the training set; output [0] of the encoder is the sampled code z.
z_values = np.array(encoder.predict(x_train, batch_size=batch_size)[0])

# Scatter the first two latent coordinates, coloured by digit label.
fig, ax = plt.subplots()
points = ax.scatter(z_values[:, 0], z_values[:, 1], c=y_train)
fig.colorbar(points, ax=ax)
ax.set_xlabel("z_0")
ax.set_ylabel("z_1")
ax.set_title("Latent Space Visualization MNIST")
plt.show()

In [0]:
x_recon = vae.predict(x_train, batch_size=batch_size)

# Drop the channel axis so each sample is a plain 2-D image.
x_train_temp = np.reshape(x_train, (-1, 32, 32))
x_recon_temp = np.reshape(x_recon, (-1, 32, 32))

n = 10
fig_size = 32
# Show the first n inputs next to their reconstructions
# (original on the left, reconstruction on the right).
for original, reconstructed in zip(x_train_temp[:n], x_recon_temp[:n]):
    figure = np.zeros(shape=(fig_size, fig_size * 2))
    figure[:, :fig_size] = original
    figure[:, fig_size:] = reconstructed

    plt.imshow(figure, cmap='gray')
    plt.show()

In [0]:
# Generating samples by decoding latent codes drawn from N(0, stddev^2 * I).
# With stddev = 2 the codes are spread wider than a unit-variance normal.
n = 10
digit_size = 32
stddev = 2

z_values = np.random.normal(size=(n ** 2, latent_dim), scale=stddev)
x_generated = decoder.predict(z_values, batch_size=n ** 2)

x_generated_temp = np.reshape(x_generated, (n ** 2, digit_size, digit_size))

# Tile the n*n decoded images into one canvas, filling it row by row.
figure = np.zeros(shape=(n * digit_size, n * digit_size))
for k, digit in enumerate(x_generated_temp):
    j, i = divmod(k, n)
    rows = slice(j * digit_size, (j + 1) * digit_size)
    cols = slice(i * digit_size, (i + 1) * digit_size)
    figure[rows, cols] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.show()