In [None]:
import os
from io import BytesIO
from zipfile import ZipFile

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from skimage.transform import resize
from skimage.io import imread

from keras.models import Sequential, Model
from keras.models import load_model
from keras.metrics import mean_squared_error
from keras.layers import InputLayer, Conv2D, Conv2DTranspose, Dense, Layer, UpSampling2D, MaxPooling2D, Flatten, Reshape, Input, Activation, LeakyReLU, Lambda
from keras.layers import BatchNormalization
from keras.layers import GaussianNoise
from keras.callbacks import ModelCheckpoint
import keras.backend as K

In [None]:
# Index the image archive and, if the preprocessed cache is missing, build it.
# `keys` (the .jpg member names) is reused by later cells to size the blob.
with ZipFile('datasets/img_align_celeba.zip', 'r') as zf:
    keys = [name for name in zf.namelist() if name.endswith('.jpg')]

    # Resizing every image is slow, so the uint8 blob is only built when it
    # does not exist yet. Delete 'datasets/blob' to force a rebuild.
    if not os.path.exists('datasets/blob'):
        # Write to a temporary file and rename at the end so an interrupted
        # run never leaves a half-filled blob that looks complete.
        blob = np.memmap('datasets/blob.tmp', dtype='uint8', mode='w+', shape=(len(keys), 64, 64, 3))
        for i, k in enumerate(keys):
            # preserve_range keeps pixel values on the 0-255 scale before
            # they are stored as uint8.
            blob[i] = resize(imread(BytesIO(zf.read(k))), output_shape=(64, 64, 3), mode='constant', preserve_range=True)
        blob.flush()
        del blob
        os.replace('datasets/blob.tmp', 'datasets/blob')

In [None]:
# Memory-map the preprocessed blob read-only: one 64x64x3 uint8 image per
# .jpg entry found in the archive.
dataset = np.memmap(
    'datasets/blob',
    dtype='uint8',
    mode='r',
    shape=(len(keys), 64, 64, 3),
)

In [None]:
# Autoencoder setup: the targets are the inputs themselves (same array object).
train_X = train_y = dataset

In [None]:
def build_encoder():
    """Build the convolutional encoder.

    Two stages of (BatchNorm -> Conv -> Conv -> 2x2 MaxPool -> ReLU) with 32
    and then 64 filters, followed by a final BatchNorm and a Flatten.

    Returns:
        A `Sequential` taking (64, 64, 3) images and producing a flat vector
        of 16*16*64 features (each pooling halves the spatial size).
    """
    encoder = Sequential()
    encoder.add(InputLayer(input_shape=(64,64,3)))

    for n_filters in (32, 64):
        encoder.add(BatchNormalization())
        encoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        encoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        encoder.add(MaxPooling2D(pool_size=(2,2)))
        encoder.add(Activation('relu'))

    # normalize, then flatten the 16x16x64 feature map
    encoder.add(BatchNormalization())
    encoder.add(Flatten())

    return encoder

In [None]:
def build_strided_deconv_decoder(input_shape):
    """Build a decoder that upsamples with strided transposed convolutions.

    Args:
        input_shape: shape of the latent vector (without the batch axis),
            passed straight to `InputLayer`.

    Returns:
        A `Sequential` mapping a latent vector to a (64, 64, 3) image whose
        values lie in the [0, 255] range.
    """
    model = Sequential()
    model.add(InputLayer(input_shape))

    # unpack the latent vector into a 16x16x64 feature map
    # no normalization here: the latent code is expected to be roughly N(0,1)
    model.add(Dense(16*16*64))
    model.add(Reshape((16, 16, 64)))

    # 16x16 -> 32x32
    model.add(Conv2DTranspose(32, kernel_size=(3,3), strides=2, padding='same'))
    model.add(Conv2DTranspose(32, kernel_size=(3,3), padding='same'))
    model.add(Activation('relu'))

    # 32x32 -> 64x64, three output channels
    model.add(Conv2DTranspose(3, kernel_size=(3,3), strides=2, padding='same'))
    model.add(Conv2DTranspose(3, kernel_size=(3,3), padding='same'))
    model.add(Activation('tanh'))
    # tanh is in [-1, 1]; (x+1) is in [0, 2], so scale by 127.5 (not 255)
    # to land on the [0, 255] RGB range
    model.add(Lambda(lambda x: 127.5*(x+1)))

    return model

In [None]:
class NearestNeighborUpsampling2D(Layer):
    """Upsample a 4D image tensor with nearest-neighbor interpolation.

    Axes 1 and 2 of the input are multiplied by `size[0]` and `size[1]`
    respectively; the batch and last axes are left unchanged.

    Args:
        size: pair of integer scale factors for axes 1 and 2.
        **kwargs: forwarded to `Layer.__init__`.
    """
    def __init__(self, size=(2,2), **kwargs):
        # tuple() so a list (e.g. from a deserialized config) behaves the same
        self.size = tuple(size)
        super(NearestNeighborUpsampling2D, self).__init__(**kwargs)

    def build(self, input_shape):
        # no trainable weights; just mark the layer as built
        super(NearestNeighborUpsampling2D, self).build(input_shape)

    def call(self, x):
        # use the dynamic shape so the layer also works when the static
        # spatial dimensions are unknown
        height = K.shape(x)[1]
        width = K.shape(x)[2]
        return tf.image.resize_nearest_neighbor(x, (self.size[0] * height, self.size[1] * width))

    def compute_output_shape(self, input_shape):
        # Report the enlarged spatial dimensions; without this override the
        # layer would advertise its input shape as its output shape.
        height = None if input_shape[1] is None else self.size[0] * input_shape[1]
        width = None if input_shape[2] is None else self.size[1] * input_shape[2]
        return (input_shape[0], height, width, input_shape[3])

    def get_config(self):
        # include `size` so the layer can be re-created from a saved config
        config = super(NearestNeighborUpsampling2D, self).get_config()
        config['size'] = self.size
        return config

def build_nearest_upsampling_decoder(input_shape):
    """Build a decoder that upsamples with nearest-neighbor resizing.

    Each of the two stages doubles the spatial size with
    `NearestNeighborUpsampling2D`, then applies two 3x3 convolutions and a
    ReLU. The final ReLU means the output is non-negative but not bounded
    above.

    Args:
        input_shape: shape of the latent vector (without the batch axis),
            passed straight to `InputLayer`.

    Returns:
        A `Sequential` mapping a latent vector to a 3-channel image.
    """
    decoder = Sequential()
    decoder.add(InputLayer(input_shape))

    # unpack the latent vector into a 16x16x64 feature map
    decoder.add(Dense(16*16*64))
    decoder.add(Reshape((16, 16, 64)))

    # 16x16 -> 32x32 (32 channels), then 32x32 -> 64x64 (3 channels)
    for n_filters in (32, 3):
        decoder.add(NearestNeighborUpsampling2D(size=(2,2)))
        decoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        decoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        decoder.add(Activation('relu'))

    return decoder

In [None]:
class BilinearUpSampling2D(Layer):
    """Upsample a 4D image tensor with bilinear interpolation.

    Axes 1 and 2 of the input are multiplied by `size[0]` and `size[1]`
    respectively; the batch and last axes are left unchanged.

    Args:
        size: pair of integer scale factors for axes 1 and 2.
        **kwargs: forwarded to `Layer.__init__`.
    """
    def __init__(self, size=(2,2), **kwargs):
        # tuple() so a list (e.g. from a deserialized config) behaves the same
        self.size = tuple(size)
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def build(self, input_shape):
        # no trainable weights; just mark the layer as built
        super(BilinearUpSampling2D, self).build(input_shape)

    def call(self, x):
        # use the dynamic shape so the layer also works when the static
        # spatial dimensions are unknown
        height = K.shape(x)[1]
        width = K.shape(x)[2]
        return tf.image.resize_bilinear(x, (self.size[0] * height, self.size[1] * width))

    def compute_output_shape(self, input_shape):
        # Report the enlarged spatial dimensions; without this override the
        # layer would advertise its input shape as its output shape.
        height = None if input_shape[1] is None else self.size[0] * input_shape[1]
        width = None if input_shape[2] is None else self.size[1] * input_shape[2]
        return (input_shape[0], height, width, input_shape[3])

    def get_config(self):
        # include `size` so the layer can be re-created from a saved config
        config = super(BilinearUpSampling2D, self).get_config()
        config['size'] = self.size
        return config

def build_bilinear_upsampling_decoder(input_shape):
    """Build a decoder that upsamples with bilinear resizing.

    Each of the two stages doubles the spatial size with
    `BilinearUpSampling2D`, then applies two 3x3 convolutions and a ReLU.
    The final ReLU means the output is non-negative but not bounded above.

    Args:
        input_shape: shape of the latent vector (without the batch axis),
            passed straight to `InputLayer`.

    Returns:
        A `Sequential` mapping a latent vector to a 3-channel image.
    """
    decoder = Sequential()
    decoder.add(InputLayer(input_shape))

    # unpack the latent vector into a 16x16x64 feature map
    decoder.add(Dense(16*16*64))
    decoder.add(Reshape((16, 16, 64)))

    # 16x16 -> 32x32 (32 channels), then 32x32 -> 64x64 (3 channels)
    for n_filters in (32, 3):
        decoder.add(BilinearUpSampling2D(size=(2,2)))
        decoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        decoder.add(Conv2D(n_filters, kernel_size=(3,3), padding='same'))
        decoder.add(Activation('relu'))

    return decoder

In [None]:
def build_model(encoder, decoder, latent_dims=100):
    """Assemble and compile a variational autoencoder for 64x64 RGB images.

    Args:
        encoder: zero-argument callable returning the encoder model; it is
            applied to a (64, 64, 3) input.
        decoder: callable accepting an `input_shape` keyword and returning
            the decoder model; it is applied to the sampled latent code.
        latent_dims: size of the latent code (defaults to 100).

    Returns:
        A compiled `Model` with outputs `[reconstruction, sampled_latent]`.
        The loss is attached with `add_loss`, so `fit` only needs the inputs.
    """
    x = Input((64, 64, 3))

    enc = encoder()
    h = enc(x)

    z_mean = Dense(latent_dims)(h)
    # tanh bounds the log stdev in [-1, 1], i.e. the stdev in [1/e, e]
    z_log_std = Dense(latent_dims, activation='tanh')(h)

    def sample(inputs):
        mean, log_std = inputs
        # inspired from: https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py#L34
        # but here we let the model approximate the stdev instead of the variance
        epsilon = K.random_normal(shape=(K.shape(mean)[0], latent_dims), mean=0, stddev=1)
        # mean + std * epsilon ~ N(mean, std) since epsilon ~ N(0, 1)
        return mean + (K.exp(log_std) * epsilon)

    z = Lambda(sample)([z_mean, z_log_std])

    dec = decoder(input_shape=(latent_dims,))

    y = dec(z)

    # Kullback-Leibler divergence to N(0, 1), written in terms of log stdev
    # (hence the factors of 2): summed over latent dims, one value per sample
    kl_loss = - 0.5 * K.sum(1 + 2*z_log_std - K.square(z_mean) - K.exp(2*z_log_std), axis=-1)

    # reconstruction error is averaged over every pixel of the batch;
    # the KL term is averaged over the batch
    vae_loss = mean_squared_error(K.flatten(x), K.flatten(y)) + K.mean(kl_loss)

    model = Model(inputs=[x], outputs=[y, z])
    model.add_loss(vae_loss)
    # the loss is already attached above, so no per-output loss is given
    model.compile(optimizer='adam', loss=[None, None])

    return model

In [None]:
# Train the VAE whose decoder upsamples with strided transposed convolutions.
# The loss is attached inside build_model, so fit() receives inputs only.
vae_strided_model = build_model(build_encoder, build_strided_deconv_decoder)
vae_strided_model_history = vae_strided_model.fit(
    train_X[:1000],
    epochs=30,
    validation_split=0.33,
    callbacks=[ModelCheckpoint('models/vae-strided-deconv-decoder.h5')],
)

In [None]:
# Train the VAE whose decoder upsamples with nearest-neighbor resizing.
# The loss is attached inside build_model, so fit() receives inputs only.
vae_nearest_model = build_model(build_encoder, build_nearest_upsampling_decoder)
vae_nearest_model_history = vae_nearest_model.fit(
    train_X[:10000],
    epochs=30,
    validation_split=0.33,
    callbacks=[ModelCheckpoint('models/vae-nearest-neighbor-upsampling-decoder.h5')],
)

In [None]:
# Train the VAE whose decoder upsamples with bilinear resizing.
# The loss is attached inside build_model, so fit() receives inputs only.
vae_bilinear_model = build_model(build_encoder, build_bilinear_upsampling_decoder)
vae_bilinear_model_history = vae_bilinear_model.fit(
    train_X[:10000],
    epochs=30,
    validation_split=0.33,
    callbacks=[ModelCheckpoint('models/vae-bilinear-upsampling-decoder.h5')],
)

In [None]:
# Show original / reconstruction pairs side by side: 4 rows of 4 pairs,
# i.e. the first 16 training images, each followed by its reconstruction.
n_rows, n_pairs = 4, 4
n_shown = n_rows * n_pairs

fig, axs = plt.subplots(n_rows, 2 * n_pairs, figsize=(8*2, 4*2))
fig.subplots_adjust(wspace=0, hspace=0)

# predict() returns [reconstruction, sampled_latent]
reconstructions = vae_strided_model.predict(train_X[:n_shown])

for i in range(n_rows):
    for p in range(n_pairs):
        # one distinct image per pair (the column index advances by 2 per
        # pair, the image index by 1)
        idx = i * n_pairs + p
        original_ax, recon_ax = axs[i][2*p], axs[i][2*p + 1]
        original_ax.imshow(train_X[idx])
        # clip before casting so out-of-range values do not wrap around
        recon_ax.imshow(np.clip(reconstructions[0][idx], 0, 255).astype('uint8'))
        original_ax.axis('off')
        recon_ax.axis('off')
plt.savefig('figures/examples-of-reconstructions', dpi=300)

In [None]:
# Distribution of the sampled latent code over the dataset, first 10 dims.
# NOTE(review): the original cell read an undefined global `model`; the
# strided VAE is used here to match the reconstruction cell - confirm.
# A sub-model that outputs only the latent sample is used so that predict()
# does not also return a reconstruction for every image in train_X.
latent_sampler = Model(inputs=vae_strided_model.inputs, outputs=vae_strided_model.outputs[1])
z_dist = latent_sampler.predict(train_X)
for i in range(10):
    plt.hist(z_dist[:,i], bins=100, histtype='step')
plt.savefig('figures/latent-space-distribution', dpi=300)

In [None]:
# Latent traversal: row i activates latent dimension i alone, column j sets
# its value to mu_space[j]; every other dimension stays at 0.
# NOTE(review): the original cell read an undefined global `model`; the
# strided VAE is used here to match the reconstruction cell - confirm.
decoder = vae_strided_model.layers[-1]  # presumably the decoder sub-model - verify
latent_dims = decoder.input_shape[-1]
n_dims, n_steps = 20, 10

fig, axs = plt.subplots(n_dims, n_steps, figsize=(2*n_steps, 2*n_dims))
fig.subplots_adjust(wspace=0, hspace=0)
mu_space = np.linspace(0, 20, num=n_steps)
# one one-hot latent vector per displayed dimension
one_hot = np.eye(latent_dims)[:n_dims]
for j in range(n_steps):
    # predict() avoids building new graph ops on every iteration;
    # clip before casting so out-of-range values do not wrap around
    latent_repr = np.clip(decoder.predict(mu_space[j]*one_hot), 0, 255).astype('uint8')
    for i in range(n_dims):
        axs[i][j].imshow(latent_repr[i])
        axs[i][j].axis('off')
plt.savefig('figures/progression', dpi=300)

In [None]:
# Small fully-convolutional denoiser: two 3x3 convolutions followed by two
# 3x3 transposed convolutions, all with 'same' padding and ReLU, so the
# spatial size stays at 64x64 throughout.
denoiser = Sequential()

denoiser.add(Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(64, 64, 3), padding='same'))
denoiser.add(Conv2D(64, kernel_size=(3,3), activation='relu', padding='same'))

denoiser.add(Conv2DTranspose(32, kernel_size=(3,3), padding='same', activation='relu'))
denoiser.add(Conv2DTranspose(3, kernel_size=(3,3), padding='same', activation='relu'))

denoiser.compile(optimizer='adam', loss='mean_squared_error')

# inputs are the first 1000 images plus unit-variance Gaussian noise
# (on the 0-255 pixel scale); targets are the clean images
noisy_X = train_X[:1000] + np.random.normal(0, 1, train_X[:1000].shape)
denoiser.fit(noisy_X, train_y[:1000], epochs=100, validation_split=0.33, callbacks=[ModelCheckpoint('models/denoiser.h5')])

In [None]:
# Reload the denoiser from its checkpoint so the training cell above can be
# skipped once 'models/denoiser.h5' exists.
denoiser = load_model('models/denoiser.h5')

In [None]:
# Compare a VAE reconstruction (left) with its denoised version (right).
# NOTE(review): the original cell read an undefined global `model`; the
# strided VAE is used here to match the reconstruction cell - confirm.
# The VAE samples its latent code at random, so predict once and reuse the
# result; otherwise the two panels would show different reconstructions.
reconstruction = vae_strided_model.predict(train_X[[0]])[0]
denoised = denoiser.predict(reconstruction)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
# clip before casting so out-of-range values do not wrap around
axs[0].imshow(np.clip(reconstruction[0], 0, 255).astype('uint8'))
axs[1].imshow(np.clip(denoised[0], 0, 255).astype('uint8'))