In [None]:
from keras.models import Sequential, Model
from keras.metrics import mean_squared_error
from keras.layers import InputLayer, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Layer, UpSampling2D, MaxPooling2D, Flatten, Reshape, Input, Activation, Lambda
from keras.callbacks import ModelCheckpoint
import keras.backend as K
from skimage.transform import resize
from skimage.io import imread
from zipfile import ZipFile
from io import BytesIO
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf

# 1. Generating Faces

In [None]:
%%time
# Decode every JPEG of the CelebA archive, dequantize it, resize it to 64x64
# and store the result in a disk-backed float32 array ('datasets/blob').
with ZipFile('datasets/img_align_celeba.zip', 'r') as zf:
    keys = [name for name in zf.namelist() if name.endswith('.jpg')]
    dataset = np.memmap('datasets/blob', dtype='float32', mode='w+', shape=(len(keys), 64, 64, 3))
    for i, k in enumerate(keys):
        if i % 1000 == 0:
            print('{0:.2f}%'.format(100 * i / len(keys)))
        img = imread(BytesIO(zf.read(k)))
        # Uniform dequantization: spread each integer intensity over [v, v+1).
        img = img + np.random.uniform(0, 1, size=img.shape)
        img = resize(img, output_shape=(64, 64, 3), mode='constant', preserve_range=True)
        # preserve_range=True keeps the [0, 256) scale, so rescale to [0, 1):
        # the decoders in this notebook end with a sigmoid and the VAE loss
        # multiplies targets and predictions by 256 to get back to pixel units.
        dataset[i] = img / 256.0
    # Make sure everything is written to disk before the file is re-opened.
    dataset.flush()

In [None]:
# Re-open the preprocessed images read-only. The blob holds float32 images of
# shape (64, 64, 3); the number of images is inferred from the file size
# instead of being hard-coded, so a blob built from a subset still loads.
dataset = np.memmap('datasets/blob', dtype='float32', mode='r').reshape(-1, 64, 64, 3)

In [None]:
# Distribution of the pixel values found in the first 100 preprocessed images.
plt.hist(np.ravel(dataset[:100]), density=True, bins=500)
plt.gcf().savefig('figures/preprocessing-color-distribution-histogram', dpi=300)

In [None]:
# Autoencoder setting: the targets are the very same array as the inputs.
train_X = train_y = dataset

# 2. Model

In [None]:
from keras.layers import LeakyReLU, Activation

def build_encoder():
    """Build the convolutional encoder used by the autoencoders of this notebook.

    The network takes 64x64x3 images and applies four 5x5 convolutions with
    stride 2 and 'same' padding (128, 256, 512 and 1024 filters), each one
    followed by batch normalization and a LeakyReLU. Every stride-2 stage
    halves the spatial size (64 -> 32 -> 16 -> 8 -> 4).

    Returns:
        A ``Sequential`` model whose output is the flattened 4x4x1024
        feature map (16384 values per image).
    """
    encoder = Sequential()
    encoder.add(InputLayer(input_shape=(64, 64, 3)))

    for n_filters in (128, 256, 512, 1024):
        encoder.add(Conv2D(n_filters, kernel_size=(5, 5), strides=2, padding='same'))
        encoder.add(BatchNormalization())
        encoder.add(LeakyReLU())

    # flatten the 4x4x1024 feature map into a single vector
    encoder.add(Flatten())

    return encoder

# 3. Architectures

In [None]:
def build_strided_deconv_decoder(input_shape):
    """Build a decoder that upsamples with strided transposed convolutions.

    The latent vector is projected by a dense layer to a 4x4x1024 volume and
    then goes through four stride-2 transposed convolutions, each followed by
    batch normalization and an activation, doubling the spatial size every
    time (4 -> 8 -> 16 -> 32 -> 64).

    Args:
        input_shape: shape of one latent vector, e.g. ``(100,)``.

    Returns:
        A ``Sequential`` model producing 64x64x3 outputs squashed by a sigmoid.
    """
    decoder = Sequential()
    decoder.add(InputLayer(input_shape))

    # unpack into a 4x4 map with 1024 channels
    # (the input is not normalized here: it is expected to follow a N(0,1))
    decoder.add(Dense(4 * 4 * 1024))
    decoder.add(Reshape((4, 4, 1024)))

    # (filters, kernel size, activation) of each upsampling stage, in order
    stages = (
        (512, (5, 5), 'relu'),
        (256, (5, 5), 'relu'),
        (128, (5, 5), 'relu'),
        (3, (3, 3), 'sigmoid'),
    )
    for n_filters, kernel, activation in stages:
        decoder.add(Conv2DTranspose(n_filters, kernel_size=kernel, strides=2, padding='same'))
        decoder.add(BatchNormalization())
        decoder.add(Activation(activation))

    return decoder

In [None]:
# Plain (non-variational) autoencoder: encoder -> 100-unit bottleneck ->
# strided-deconvolution decoder, fitted on the first 100 images only.
simple_ae_model = Sequential()
for stage in (build_encoder(), Dense(100), build_strided_deconv_decoder((100,))):
    simple_ae_model.add(stage)
simple_ae_model.compile(loss='mean_squared_error', optimizer='adam')
simple_ae_model.fit(train_X[:100], train_y[:100], epochs=10)

In [None]:
# Reconstruction of the first training image by the plain autoencoder.
plt.imshow(simple_ae_model.predict(train_X[:1])[0])

In [None]:
class NearestNeighborUpsampling2D(Layer):
    """Upsample 4-D feature maps with nearest-neighbor interpolation.

    Axes 1 and 2 of the input are multiplied by ``size[0]`` and ``size[1]``
    respectively; axes 0 and 3 are left untouched.

    Args:
        size: pair of integer upsampling factors for axes 1 and 2.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """
    def __init__(self, size=(2,2), **kwargs):
        self.size = tuple(size)
        super(NearestNeighborUpsampling2D, self).__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no weights; only mark it as built.
        super(NearestNeighborUpsampling2D, self).build(input_shape)

    def call(self, x):
        # Use the dynamic shape so the layer works with unknown spatial sizes.
        dims = K.shape(x)
        rows = dims[1]
        cols = dims[2]
        return tf.image.resize_nearest_neighbor(x, (self.size[0] * rows, self.size[1] * cols))

    def compute_output_shape(self, input_shape):
        batch, rows, cols, channels = input_shape
        # An unknown (None) spatial dimension stays unknown after scaling.
        out_rows = None if rows is None else self.size[0] * rows
        out_cols = None if cols is None else self.size[1] * cols
        return (batch, out_rows, out_cols, channels)

    def get_config(self):
        # Serialize `size` so a saved model is rebuilt with the same factors
        # instead of silently falling back to the default (2, 2).
        config = super(NearestNeighborUpsampling2D, self).get_config()
        config['size'] = self.size
        return config

def build_nearest_upsampling_decoder(input_shape):
    """Build a decoder that alternates nearest-neighbor upsampling and convolutions.

    The latent vector is projected to a 16x16x64 volume, then upsampled twice
    by a factor of 2 (16 -> 32 -> 64). Each upsampling is followed by two 3x3
    'same' convolutions; the last convolution uses a sigmoid.

    Args:
        input_shape: shape of one latent vector, e.g. ``(100,)``.

    Returns:
        A ``Sequential`` model producing 64x64x3 outputs.
    """
    decoder = Sequential()
    decoder.add(InputLayer(input_shape))

    # unpack into a 16x16 map with 64 channels
    decoder.add(Dense(16 * 16 * 64))
    decoder.add(Reshape((16, 16, 64)))

    # 16x16 -> 32x32, refined by two 32-filter convolutions
    decoder.add(NearestNeighborUpsampling2D(size=(2, 2)))
    for _ in range(2):
        decoder.add(Conv2D(32, kernel_size=(3, 3), padding='same', activation='relu'))

    # 32x32 -> 64x64, reduced to 3 channels; the sigmoid comes last
    decoder.add(NearestNeighborUpsampling2D(size=(2, 2)))
    for activation in ('relu', 'sigmoid'):
        decoder.add(Conv2D(3, kernel_size=(3, 3), padding='same', activation=activation))

    return decoder

In [None]:
class BilinearUpSampling2D(Layer):
    """Upsample 4-D feature maps with bilinear interpolation.

    Axes 1 and 2 of the input are multiplied by ``size[0]`` and ``size[1]``
    respectively; axes 0 and 3 are left untouched.

    Args:
        size: pair of integer upsampling factors for axes 1 and 2.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """
    def __init__(self, size=(2,2), **kwargs):
        self.size = tuple(size)
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no weights; only mark it as built.
        super(BilinearUpSampling2D, self).build(input_shape)

    def call(self, x):
        # Use the dynamic shape so the layer works with unknown spatial sizes.
        dims = K.shape(x)
        rows = dims[1]
        cols = dims[2]
        return tf.image.resize_bilinear(x, (self.size[0] * rows, self.size[1] * cols))

    def compute_output_shape(self, input_shape):
        batch, rows, cols, channels = input_shape
        # An unknown (None) spatial dimension stays unknown after scaling.
        out_rows = None if rows is None else self.size[0] * rows
        out_cols = None if cols is None else self.size[1] * cols
        return (batch, out_rows, out_cols, channels)

    def get_config(self):
        # Serialize `size` so a saved model is rebuilt with the same factors
        # instead of silently falling back to the default (2, 2).
        config = super(BilinearUpSampling2D, self).get_config()
        config['size'] = self.size
        return config

def build_bilinear_upsampling_decoder(input_shape):
    """Build a decoder that alternates bilinear upsampling and convolutions.

    The latent vector is projected to a 16x16x64 volume, then upsampled twice
    by a factor of 2 (16 -> 32 -> 64). Each upsampling is followed by two 3x3
    'same' convolutions; the last convolution uses a sigmoid.

    Args:
        input_shape: shape of one latent vector, e.g. ``(100,)``.

    Returns:
        A ``Sequential`` model producing 64x64x3 outputs.
    """
    decoder = Sequential()
    decoder.add(InputLayer(input_shape))

    # unpack into a 16x16 map with 64 channels
    decoder.add(Dense(16 * 16 * 64))
    decoder.add(Reshape((16, 16, 64)))

    # 16x16 -> 32x32, refined by two 32-filter convolutions
    decoder.add(BilinearUpSampling2D(size=(2, 2)))
    for _ in range(2):
        decoder.add(Conv2D(32, kernel_size=(3, 3), padding='same', activation='relu'))
    # NOTE(review): this ReLU follows a convolution that already applies a
    # ReLU, so it looks redundant; kept so the layer stack is unchanged.
    decoder.add(Activation('relu'))

    # 32x32 -> 64x64, reduced to 3 channels; the sigmoid comes last
    decoder.add(BilinearUpSampling2D(size=(2, 2)))
    for activation in ('relu', 'sigmoid'):
        decoder.add(Conv2D(3, kernel_size=(3, 3), padding='same', activation=activation))

    return decoder

In [None]:
class GaussianSample(Layer):
    """Draw a sample from N(mu, sigma) and register the KL penalty to N(0, 1).

    The layer is called on a pair ``[mu, log_sigma]`` of tensors with the same
    shape, where ``log_sigma`` is the log of the standard deviation. It returns
    a single tensor shaped like ``mu``.
    """
    def call(self, inputs):
        mu, log_sigma = inputs
        # inspired from: https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py#L34
        # but here we let the model approximate the stdev instead of the variance
        epsilon = K.random_normal(shape=K.shape(mu), mean=0, stddev=1)
        # mean + std * epsilon ~ N(mean, std) since epsilon ~ N(0, 1)

        # Kullback-Leibler divergence to N(0,1), summed over the last axis and
        # averaged over the remaining ones
        kl_loss = -0.5 * K.sum(1 + 2*log_sigma - K.square(mu) - K.exp(2*log_sigma), axis=-1)
        self.add_loss(K.mean(kl_loss), inputs=[mu, log_sigma])

        return mu + (K.exp(log_sigma) * epsilon)

    def compute_output_shape(self, input_shape):
        # Two inputs, one output: report the shape of `mu` explicitly rather
        # than relying on the base class, which presumably echoes back the
        # shapes of both inputs (to be confirmed against the Keras version used).
        mu_shape = input_shape[0]
        return tuple(mu_shape)
    
class GaussianImportanceWeightedSample(GaussianSample):
    """Average several draws from N(mu, sigma) instead of using a single one.

    Args:
        samples: number of draws to average; must be at least 1.
        **kwargs: forwarded to the base layer constructor.

    Raises:
        ValueError: if ``samples`` is smaller than 1.

    NOTE(review): the output is a plain mean of the draws (no per-sample
    weights are computed), so this reduces the variance of z rather than
    implementing an importance-weighted bound; confirm this is the intent.
    """
    def __init__(self, samples=5, **kwargs):
        if samples < 1:
            raise ValueError('samples must be >= 1, got {}'.format(samples))
        self.samples = samples
        super(GaussianImportanceWeightedSample, self).__init__(**kwargs)

    def call(self, inputs):
        mu, log_sigma = inputs
        # The parent call yields the first draw and registers the KL term.
        # It must run exactly once: the KL term does not depend on the noise,
        # so calling it for every draw would add the same penalty `samples`
        # times to the model loss.
        z = super(GaussianImportanceWeightedSample, self).call(inputs)
        sigma = K.exp(log_sigma)
        for _ in range(self.samples - 1):
            epsilon = K.random_normal(shape=K.shape(mu), mean=0, stddev=1)
            z = z + mu + sigma * epsilon
        return z / self.samples

    def get_config(self):
        # Serialize `samples` so a saved model is rebuilt with the same value.
        config = super(GaussianImportanceWeightedSample, self).get_config()
        config['samples'] = self.samples
        return config

In [None]:
def build_model(encoder, decoder, latent_dim=100):
    """Assemble and compile a variational autoencoder for 64x64x3 images.

    Args:
        encoder: zero-argument callable returning the encoder model.
        decoder: callable taking an ``input_shape`` keyword argument and
            returning the decoder model.
        latent_dim: size of the latent vector z (defaults to 100, the value
            that used to be hard-coded).

    Returns:
        A compiled ``Model`` with two outputs: the reconstruction ``y`` and
        the sampled latent vector ``z``. Only ``y`` has a loss attached; the
        KL term is contributed by the ``GaussianSample`` layer.
    """
    x = Input((64, 64, 3))

    enc = encoder()
    h = enc(x)

    # keep values in reasonable intervals
    h = BatchNormalization()(h)
    #h = Dense(100, activation='tanh')(h)

    z_mean = Dense(latent_dim)(h)
    z_log_std = Dense(latent_dim)(h)

    z = GaussianSample()([z_mean, z_log_std])

    dec = decoder(input_shape=(latent_dim,))

    y = dec(z)

    model = Model(inputs=[x], outputs=[y, z])

    # Reconstruction error measured on 256-scaled values, which weights it
    # more heavily against the KL term than a plain MSE would.
    # NOTE(review): this function is named `mse`, which is also a built-in
    # Keras loss alias; a model saved with it and re-loaded without
    # custom_objects presumably gets the unscaled built-in loss -- verify.
    def mse(y_true, y_pred):
        return mean_squared_error(256*y_true, 256*y_pred)
    # `None` disables the loss on the second output (z).
    model.compile(optimizer='adam', loss=[mse, None])

    return model

In [None]:
# VAE with the strided-deconvolution decoder: trained on the whole dataset,
# holding out 10% for validation and checkpointing the model to disk.
vae_strided_model = build_model(build_encoder, build_strided_deconv_decoder)
vae_strided_model_history = vae_strided_model.fit(
    train_X,
    train_y,
    validation_split=0.1,
    epochs=10,
    callbacks=[ModelCheckpoint('models/vae-strided-deconv-decoder.h5')],
)

In [None]:
# VAE with the nearest-neighbor upsampling decoder, trained on the first
# 10000 images and checkpointed to disk.
vae_nearest_model = build_model(build_encoder, build_nearest_upsampling_decoder)
vae_nearest_model_history = vae_nearest_model.fit(
    train_X[:10000],
    train_y[:10000],
    epochs=10,
    callbacks=[ModelCheckpoint('models/vae-nearest-neighbor-upsampling-decoder.h5')],
)

In [None]:
# VAE with the bilinear upsampling decoder, trained on the first 1000 images
# and checkpointed to disk.
vae_bilinear_model = build_model(build_encoder, build_bilinear_upsampling_decoder)
vae_bilinear_model_history = vae_bilinear_model.fit(
    train_X[:1000],
    train_y[:1000],
    epochs=10,
    callbacks=[ModelCheckpoint('models/vae-bilinear-upsampling-decoder.h5')],
)

In [None]:
from keras.models import load_model
# Restore the three checkpointed VAEs; the custom layers have to be handed
# to Keras explicitly so that it can rebuild them.
vae_strided_model = load_model(
    'models/vae-strided-deconv-decoder.h5',
    custom_objects={
        'GaussianSample': GaussianSample,
        'NearestNeighborUpsampling2D': NearestNeighborUpsampling2D,
    },
)
vae_nearest_model = load_model(
    'models/vae-nearest-neighbor-upsampling-decoder.h5',
    custom_objects={
        'GaussianSample': GaussianSample,
        'NearestNeighborUpsampling2D': NearestNeighborUpsampling2D,
    },
)
vae_bilinear_model = load_model(
    'models/vae-bilinear-upsampling-decoder.h5',
    custom_objects={
        'GaussianSample': GaussianSample,
        'BilinearUpSampling2D': BilinearUpSampling2D,
    },
)

In [None]:
# Originals on the first row, then one row of reconstructions per model.
# The grid is sized from the model list so that disabling a model does not
# leave empty, framed axes in the saved figure.
reconstruction_models = [vae_strided_model] #, vae_nearest_model, vae_bilinear_model]
n_examples = 8
n_rows = len(reconstruction_models) + 1

fig, axs = plt.subplots(n_rows, n_examples, figsize=(n_examples*2, n_rows*2))
fig.subplots_adjust(wspace=0, hspace=0)

for j in range(n_examples):
    axs[0][j].imshow(train_X[j])
    axs[0][j].axis('off')

for i, m in enumerate(reconstruction_models):
    # predict() returns [reconstructions, z]; only the reconstructions are shown
    r = m.predict(train_X[:n_examples])
    for j in range(n_examples):
        axs[i+1][j].imshow(r[0][j])
        axs[i+1][j].axis('off')
plt.savefig('figures/examples-of-reconstructions', dpi=300)

In [None]:
from scipy.stats import norm
# Sampled latent vectors (second model output) for the first 1000 images:
# histogram of the first 10 latent dimensions against the N(0, 1) density.
z_dist = vae_strided_model.predict(train_X[:1000])[1]
z_grid = np.linspace(-6, 6)
plt.hist(z_dist[:, :10], density=True, histtype='step', bins=50)
plt.plot(z_grid, norm.pdf(z_grid))
plt.gca().set(xlabel='z', ylabel='Densité')
plt.savefig('figures/latent-space-distribution', dpi=300)

# 4. Variants

In [None]:
def build_wae_model(encoder, decoder, latent_dim=100):
    """Assemble and compile the multi-sample variant of the VAE.

    Compared with a single-draw VAE, the latent vector is produced by
    ``GaussianImportanceWeightedSample`` (5 draws), the encoder features go
    through a 100-unit ReLU layer, and the log standard deviation is bounded
    by a tanh.

    Args:
        encoder: zero-argument callable returning the encoder model.
        decoder: callable taking an ``input_shape`` keyword argument and
            returning the decoder model.
        latent_dim: size of the latent vector z (defaults to 100, the value
            that used to be hard-coded).

    Returns:
        A compiled ``Model`` with two outputs: the reconstruction ``y`` and
        the latent vector ``z``. Only ``y`` has a loss attached.
    """
    x = Input((64, 64, 3))

    enc = encoder()
    # hidden layer width, independent of the latent size
    h = Dense(100, activation='relu')(enc(x))

    z_mean = Dense(latent_dim)(h)
    # tanh keeps the log standard deviation within [-1, 1]
    z_log_std = Dense(latent_dim, activation='tanh')(h)

    z = GaussianImportanceWeightedSample(samples=5)([z_mean, z_log_std])

    dec = decoder(input_shape=(latent_dim,))

    y = dec(z)

    model = Model(inputs=[x], outputs=[y, z])
    # `None` disables the loss on the second output (z).
    model.compile(optimizer='adam', loss=['mean_squared_error', None])

    return model

In [None]:
# Train the multi-sample VAE with the nearest-neighbor upsampling decoder on
# the first 1000 images. The model fitted must be the one built on the line
# above (`wvae_strided_model` was never defined and raised a NameError).
# NOTE(review): the checkpoint file name still says "strided-deconv" although
# the decoder is the nearest-neighbor one; left unchanged in case other code
# reads this path.
wvae_nearest_model = build_wae_model(build_encoder, build_nearest_upsampling_decoder)
wvae_nearest_model_history = wvae_nearest_model.fit(train_X[:1000], train_y[:1000], epochs=10, callbacks=[ModelCheckpoint('models/vae-weighted-strided-deconv-decoder.h5')])

In [None]:
# Originals on the first row, then one row of reconstructions per model.
# The grid needs one row more than the number of models (it used to have
# only 2 rows for 3 rows of images, which raised an IndexError).
comparison_models = [vae_nearest_model, wvae_nearest_model]
n_examples = 8
n_rows = len(comparison_models) + 1

fig, axs = plt.subplots(n_rows, n_examples, figsize=(n_examples*2, n_rows*2))
fig.subplots_adjust(wspace=0, hspace=0)

for j in range(n_examples):
    axs[0][j].imshow(train_X[j])
    axs[0][j].axis('off')

for i, m in enumerate(comparison_models):
    # predict() returns [reconstructions, z]; only the reconstructions are shown
    r = m.predict(train_X[:n_examples])
    for j in range(n_examples):
        # The decoder ends with a sigmoid, so the values lie in [0, 1]:
        # casting them to uint8 would truncate every pixel to 0.
        axs[i+1][j].imshow(r[0][j])
        axs[i+1][j].axis('off')
plt.savefig('figures/wvae-vs-vae', dpi=300)

# Qualitative Evaluations

From now on, we'll use the VAE with nearest-neighbor upsampling since it has shown the best results so far. We will also train it on a larger subset of the dataset.

In [None]:
# Keep training the existing `vae_nearest_model` on the first 10000 images,
# checkpointing to a separate file (uncomment the next line to rebuild the
# model from scratch first).
#vae_nearest_model = build_model(build_encoder, build_nearest_upsampling_decoder)
vae_nearest_model_history = vae_nearest_model.fit(
    train_X[:10000],
    train_y[:10000],
    epochs=10,
    callbacks=[ModelCheckpoint('models/vae-best-nearest-neighbor-upsampling-decoder.h5')],
)

In [None]:
# Reconstruction (first model output) of the second training image.
plt.imshow(vae_nearest_model.predict(train_X[1:2])[0][0])

## a)

In [None]:
# Directory listing through IPython's automagic (equivalent to `%ls`).
ls

In [None]:
# Latent-space traversal: row i activates latent dimension i alone, column j
# sets its magnitude to mu_space[j]. The grid is 10 rows x 5 columns, matching
# both the loops below and the (width, height) of the figure; it used to be
# created as 2 x 10, which raised an IndexError.
n_dims, n_steps = 10, 5
fig, axs = plt.subplots(n_dims, n_steps, figsize=(2*n_steps, 2*n_dims))
fig.subplots_adjust(wspace=0, hspace=0)
mu_space = np.linspace(0, 5, num=n_steps) # 0 to 5 standard deviation
# The last layer of the model is presumably the decoder sub-model (verify).
decoder = vae_nearest_model.layers[-1]
for j in range(n_steps):
    # Row k of the scaled identity is the latent vector with only dimension k
    # set, so latent_repr[k] is the image decoded from it. predict() is used
    # so that no new graph operations are created at every iteration.
    latent_repr = decoder.predict(mu_space[j] * np.eye(100))
    for i in range(n_dims):
        # Show dimension i (not j), and keep the float values: the decoder
        # output is a sigmoid in [0, 1], which a uint8 cast truncates to 0.
        axs[i][j].imshow(latent_repr[i])
        axs[i][j].axis('off')
plt.savefig('figures/vae-nearest-progression-in-latent-space', dpi=300)

## b)