In [1]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import color, transform

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from IPython.display import clear_output

## Parameters

In [2]:
batch_size = 64

## Load the dataset

In [3]:
def normalize(x, max_value):
    """ If x takes its values between 0 and max_value, normalize it between -1 and 1"""
    return (x / float(max_value)) * 2 - 1

In [4]:
mnist = input_data.read_data_sets("/datasets/mnist/", one_hot=False)

Extracting /datasets/mnist/train-images-idx3-ubyte.gz
Extracting /datasets/mnist/train-labels-idx1-ubyte.gz
Extracting /datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting /datasets/mnist/t10k-labels-idx1-ubyte.gz


In [5]:
def transform_mnist(X):
    X = X.reshape(len(X), 28, 28)
    X = np.array([transform.resize(im, [32,32]) for im in X])
    X = normalize(X, 1)
    X = X.reshape(len(X), 32, 32, 1)
    
    return X

In [6]:
X_mnist = transform_mnist(mnist.train.images)

  warn("The default mode, 'constant', will be changed to 'reflect' in "


## Create the model

### Useful functions

In [7]:
def leaky_relu(x):
    alpha = 0.05
    return tf.maximum(x, alpha * x)

In [8]:
def instance_normalization(x, name):
    with tf.variable_scope("instance_norm"):
        with tf.variable_scope(name):
            epsilon = 1e-5
            mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
            scale = tf.get_variable('scale',[x.get_shape()[-1]], 
                                    initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
            offset = tf.get_variable('offset',[x.get_shape()[-1]],initializer=tf.constant_initializer(0.0))
            out = scale*tf.div(x-mean, tf.sqrt(var+epsilon)) + offset

            return out

### Placeholders

In [9]:
ipt_mnist = tf.placeholder(tf.float32, shape=[None, 32, 32, 1], name="ipt_mnist")

### Encoder

In [15]:
def encoder(x):
    """Encoder for the VAE. Map the input to a latent embedding space
    
    Parameters
    ----------
    x : tensor of shape = [?, 32, 32, 1]
        Normally takes a real image (except if you use cycle-consistency)
    scope : {'source', 'target'}
        Corresponds to the domain of x

    Returns
    -------
    mu : tensor of shape = [?, 8, 8, 1024]
        Mean of the embedding space conditionned on x
    log_sigma_sq : tensor of shape = [?, 8, 8, 1024]
        log of the variance of the embedding space conditionned on x
    z : tensor of shape = [?, 8, 8, 1024]
        Random sample generated from mu(x) and sigma(x)
        
    """
    
    initializer = tf.contrib.layers.xavier_initializer()
    
    with tf.variable_scope("encoder", reuse=None):
        # Layer 1: 32x32x1 --> 16x16x64
        conv1 = tf.layers.conv2d(x, 64, [5, 5], strides=2, padding='SAME', 
                                 kernel_initializer=initializer, activation=leaky_relu)
        conv1 = instance_normalization(conv1, "conv1")
    
        # Layer 2: 16x16x64 --> 8x8x128
        conv2 = tf.layers.conv2d(conv1, 128, [5, 5], strides=2, padding='SAME', 
                                 kernel_initializer=initializer, activation=leaky_relu)
        conv2 = instance_normalization(conv2, "conv2")
        
        # Layer 3: 8x8x128 --> 8x8x256
        conv3 = tf.layers.conv2d(conv2, 256, [8, 8], strides=1, padding='SAME', 
                                 kernel_initializer=initializer, activation=leaky_relu)
        conv3 = instance_normalization(conv3, "conv3")
        
        # Layer 4: 8x8x256 --> 8x8x512
        conv4 = tf.layers.conv2d(conv3, 512, [1, 1], strides=1, padding='SAME', 
                                 kernel_initializer=initializer, activation=leaky_relu)
        conv4 = instance_normalization(conv4, "conv4")

        # Layer 5 : 8x8x512 --> 8x8x1024
        mu = tf.layers.conv2d(conv4, 1024, [1, 1], strides=1, padding='SAME', 
                              kernel_initializer=initializer, activation=None)
        log_sigma_sq = tf.layers.conv2d(conv4, 1024, [1, 1], strides=1, padding='SAME', 
                              kernel_initializer=initializer, activation=None)
        
        z = mu + tf.multiply(tf.exp(log_sigma_sq / 2), tf.random_normal([tf.shape(x)[0],8,8,1024],0,1,dtype=tf.float32)) # latent space
        
    return mu, log_sigma_sq, z

In [16]:
def decoder(z):
    """Decoder for the VAE. Map the latent space to the image space
    
    Parameters
    ----------
    x : tensor of shape = [?, 8, 8, 1024]
        Normally takes an encoded image (point in the embedding space)
    scope : {'source', 'target'}
        Corresponds to the domain of x

    Returns
    -------
    deconv5 : tensor of shape = [?, 32, 32, 3]
        Generated image
    """
    initializer = tf.contrib.layers.xavier_initializer()
    
    with tf.variable_scope("decoder", reuse=None): # shared weights
        # Layer 1: 8x8x1024 --> 8x8x512
        deconv1 = tf.layers.conv2d_transpose(z, 512, [4, 4], strides=2, padding='SAME', kernel_initializer=initializer, activation=leaky_relu)
        deconv1 = instance_normalization(deconv1, "deconv1")
        
        # Layer 2: 8x8x512 --> 16x16x256
        deconv2 = tf.layers.conv2d_transpose(deconv1, 256, [4, 4], strides=2, padding='SAME', kernel_initializer=initializer, activation=leaky_relu)
        deconv2 = instance_normalization(deconv2, "deconv2")
        
        # Layer 3: 16x16x256 --> 32x32x512
        deconv3 = tf.layers.conv2d_transpose(deconv2, 128, [4, 4], strides=1, padding='SAME', kernel_initializer=initializer, activation=leaky_relu)
        deconv3 = instance_normalization(deconv3, "deconv3")
        
        # Layer 3: 16x16x256 --> 32x32x512
        deconv4 = tf.layers.conv2d_transpose(deconv3, 64, [4, 4], strides=1, padding='SAME', kernel_initializer=initializer, activation=leaky_relu)
        deconv4 = instance_normalization(deconv4, "deconv4")

        # Layer 6: 16x16x64 --> 32x32x3
        deconv5 = tf.layers.conv2d_transpose(deconv4, 1, [1, 1], strides=1, padding='SAME', kernel_initializer=initializer, activation=tf.nn.tanh)
        
    return deconv5

### Define the graph

In [17]:
E_mean, E_log_sigma_sq, E_space = encoder(ipt_mnist)

In [18]:
D_rec = decoder(E_space)

### Losses

In [20]:
lambda_rec = 1
lambda_kl = 10

In [25]:
def log(tensor):
    return tf.log(tensor + 1e-7)

In [21]:
def normalize(x):
    return (x + 1) / 2

In [22]:
def reconstruction_loss(x, x_rec):
    x, x_rec = normalize(x), normalize(x_rec)
    return - tf.reduce_mean(x * log(x_rec) + (1 - x) * log(1 - x_rec))

In [23]:
def latent_loss(mean, log_std_sq):
    return 0.5 * tf.reduce_mean(tf.square(mean) + tf.exp(log_std_sq) - log_std_sq - 1.)

In [26]:
vae_loss = lambda_rec * reconstruction_loss(ipt_mnist, D_rec) \
           + lambda_kl * latent_loss(E_mean, E_log_sigma_sq)

### Solvers

In [28]:
model_vars = tf.trainable_variables()

In [30]:
with tf.variable_scope("optim", reuse=None):
    vae_solver = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(vae_loss, var_list=model_vars)

## Run the model

In [31]:
sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=False))

In [32]:
sess.run(tf.global_variables_initializer())
vae_loss_list = []
iter_list = []
i = 0

In [36]:
nb_iter = 100000
nb_iter_epoch = 1
verbose = True
i_init = i

for i in range(i, nb_iter+i):
    for j in range(nb_iter_epoch):
        sample_mnist = X_mnist[np.random.choice(len(X_mnist), batch_size)]

        _, vae_loss_curr = sess.run([vae_solver, vae_loss], feed_dict={ipt_mnist: sample_mnist})

    iter_list.append(i)

    vae_loss_list.append(vae_loss_curr)
    
    if verbose:
        clear_output(wait=True)
        print('Iter: {} / {}'.format(i, i_init + nb_iter - 1))
        print('VAE loss: {:.4}'.format(vae_loss_curr))
        print()

Iter: 8 / 99999
VAE loss: 1.569



KeyboardInterrupt: 

## Display the results

In [None]:
def unnormalize(x):
    return np.array((x + 1)/2)

In [None]:
X_rec = unnormalize(sess.run(D_rec, feed_dict={ipt_mnist: X_mnist[:batch_size]}))

In [None]:
plt.rcParams['figure.figsize'] = (15, 10)
# plt.axes().set_aspect('equal', 'datalim')
#plt.axis([-3,3,-3,3])

index = 7

plt.subplot(1,1,2)
plt.imshow(unnormalize(X_mnist[index]))
plt.axis('off')


plt.subplot(2,1,2)
# plt.imshow(np.transpose(X_target_trans[index].reshape(3,32,32), (1,2,0)))
plt.imshow(X_rec[index].reshape(32,32,3))
plt.axis('off')