In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from matplotlib import pyplot as plt

import numpy as np
import seaborn as sns
import math
import tensorflow as tf
import random
from scipy.misc import imsave
import os
import datetime as dt

# Change figure aesthetics
%matplotlib inline
sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})

## dSprites Dataset

[dSprites](https://github.com/deepmind/dsprites-dataset) is a dataset of 2D shapes procedurally generated from 6 ground truth independent latent factors. These factors are color, shape, scale, rotation, x and y positions of a sprite.

All possible combinations of these latents are present exactly once, generating N = 737280 total images.

* Color: white
* Shape: square, ellipse, heart
* Scale: 6 values linearly spaced in [0.5, 1]
* Orientation: 40 values in [0, 2 pi]
* Position X: 32 values in [0, 1]
* Position Y: 32 values in [0, 1]

In [None]:
# Load dataset
dataset_zip = np.load('dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', encoding = 'latin1')

- `imgs` : (737280 x 64 x 64, uint8) Images in black and white.
- `latents_values` : (737280 x 6, float64) Values of the latent factors.
- `latents_classes` : (737280 x 6, int64) Integer index of the latent factor values. Useful as classification targets.
- `metadata` : some additional information, including the possible latent values.

In [3]:
print('Keys in the dataset:', dataset_zip.keys())
imgs = dataset_zip['imgs']                             # imgs: (737280 x 64 x 64, uint8) Images in black and white.
latents_values = dataset_zip['latents_values']         # latents_values : (737280 x 6, float64) Values of the latent factors.
latents_classes = dataset_zip['latents_classes']       # latents_classes: (737280 x 6, int64) Integer index of the latent factor values. Useful as classification targets.
metadata = dataset_zip['metadata'][()]                 # metadata: some additional information, including the possible latent values.

#print('Metadata: \n', metadata)

Keys in the dataset: ['metadata', 'imgs', 'latents_classes', 'latents_values']


In [0]:
# Define number of values per latents and functions to convert to indices
latents_sizes = metadata['latents_sizes'] # latents_sizes = [ 1  3  6 40 32 32]
latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:], np.array([1,]))) # latents_bases = [737280 245760  40960   1024     32      1]
n_samples = latents_bases[0]

def latent_to_index(latents):
    return np.dot(latents, latents_bases).astype(int)


def sample_latent(size=1):
    samples = np.zeros((size, latents_sizes.size))
    for lat_i, lat_size in enumerate(latents_sizes):
        samples[:, lat_i] = np.random.randint(lat_size, size=size)

    return samples

In [0]:
# image getter methods
def sample_image(shape = 0, scale = 0, orientation = 0, x = 0, y = 0):
    latents = [0, shape, scale, orientation, x, y]
    index = np.dot(latents, self.latents_bases).astype(int)
    return get_images([index])[0]

def sample_images(indices):
    images = []
    for index in indices:
        img = imgs[index]
        img = img.reshape(4096)
        images.append(img)
    return images

def sample_random_images(size):
    indices = [np.random.randint(n_samples) for i in range(size)]
    return sample_images(indices)

In [0]:
epochs = 10000
batch_size = 64
beta = 1000
capacity_limit = 25.0
capacity_change_duration = 100000
learning_rate = 5e-4
checkpoint_dir = "./beta_checkpoints"
log_file = "log_files"

## Model Architecture
- Input(4096, flattened) -> FC(1200) -> FC(1200) -> FC(10) -> FC(1200) -> FC(1200) -> FC(1200) -> FC(4096)
- ReLU activations are used throughout.
- Adam optimiser with a learning rate of 5e-4 is used to train the network.

In [0]:
# model architecture
inputs_ = tf.layers.Flatten()(tf.placeholder(tf.float32, (None, 64,64,1), name="input")) # Input placeholder
capacity = tf.placeholder(tf.float32, shape=[]) # Encoding capcity

with tf.variable_scope("Encoder"):
    fc1 = tf.layers.dense(inputs_, 1200, activation = tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer())
    fc2 = tf.layers.dense(fc1, 1200, activation = tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer())
    fc3 = tf.layers.dense(fc2, 20, kernel_initializer=tf.contrib.layers.xavier_initializer())
    mean = fc3[:,:10]
    log_std_dev = tf.clip_by_value(fc3[:,10:],1e-8,5)

eps = tf.random_normal( tf.shape(mean), 0, 1, dtype=tf.float32 )
z = tf.add(mean, tf.multiply(tf.sqrt(tf.exp(log_std_dev)), eps)) # z = mu + sigma * epsilon

with tf.variable_scope("Decoder"):
    fc4 = tf.layers.dense(z, 1200, activation = tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer())
    fc5 = tf.layers.dense(fc4, 1200, activation = tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer())
    fc6 = tf.layers.dense(fc5, 1200, activation = tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer())
    reconstruct_logit = tf.layers.dense(fc6, 4096, kernel_initializer=tf.contrib.layers.xavier_initializer())
    reconstruct = tf.nn.sigmoid(reconstruct_logit)
with tf.variable_scope("Loss"):
    reconstr_loss = tf.reduce_mean(tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels = inputs_, logits = reconstruct_logit),1)) # Reconstruction loss
    latent_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(1 + 2*log_std_dev - tf.square(mean) - tf.square(tf.exp(log_std_dev)),1)) # Latent loss
    loss = reconstr_loss + beta * tf.abs(latent_loss - capacity)
    
    reconstr_loss_summary_op = tf.summary.scalar('reconstr_loss', reconstr_loss)
    latent_loss_summary_op   = tf.summary.scalar('latent_loss',   latent_loss)
    summary_op = tf.summary.merge([reconstr_loss_summary_op, latent_loss_summary_op])
    
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(loss)

In [0]:
# helping functions
def _calc_encoding_capacity(step):
    if step > capacity_change_duration:
        c = capacity_limit
    else:
        c = capacity_limit * (step / capacity_change_duration)
    return c

def batch_train(sess, xs, step):
    c = _calc_encoding_capacity(step)
    _, reconstruction_loss, latent_z_loss, summary_str = sess.run((optimizer, reconstr_loss, latent_loss, summary_op), feed_dict={inputs_ : xs, capacity : c})
    return reconstruction_loss, latent_z_loss, summary_str
  
def input_to_output(sess, xs):
    # Original VAE output
    return sess.run(reconstruct, feed_dict={inputs_: xs})

def input_to_latent(sess, xs):
    return sess.run([mean, log_std_dev], feed_dict={inputs_: xs})

def latent_to_output(sess, zs):
    """ Generate data by sampling from latent space. """
    return sess.run(reconstruct, feed_dict={z: zs})
saver = tf.train.Saver()

In [9]:
total_batch = n_samples // batch_size
print(total_batch)

11520


In [10]:
# training step
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    summary_writer = tf.summary.FileWriter(log_file, sess.graph)
    reconstruct_check_images = sample_random_images(10)
    indices = list(range(n_samples))
    step = 0
    n1=dt.datetime.now()
    for epoch in range(epochs):
        random.shuffle(indices)
        r_loss_term = 0
        l_loss_term = 0
        for i in range(total_batch):
            batch_indices = indices[batch_size*i : batch_size*(i+1)]
            batch_xs = sample_images(batch_indices)

            # Fit training using batch data
            reconstruction_loss, latent_z_loss, summary_str = batch_train(sess, batch_xs, step)
            summary_writer.add_summary(summary_str, step)
            if step%1000 == 0:
                n2=dt.datetime.now()
                print("step count: "+str(step)+" time: "+str((n2-n1).seconds)+"sec reconstruction loss: "+str(reconstruction_loss)+" latent loss: "+str(latent_z_loss)+" capacity: "+str(_calc_encoding_capacity(step)))
                n1 = n2
            step += 1
            r_loss_term += reconstruction_loss
            l_loss_term += latent_z_loss
        
        print("------------------------------epoch: "+str(epoch)+" reconstruction loss: "+str(r_loss_term/total_batch)+" latent loss: "+str(l_loss_term/total_batch)+"------------------------------")
        # Save checkpoint
        save_path = saver.save(sess, checkpoint_dir + '/' + 'checkpoint', global_step = epoch)

step count: 0 time: 1sec reconstruction loss: 2838.9702 latent loss: 0.4028828 capacity: 0.0
step count: 1000 time: 13sec reconstruction loss: 543.4658 latent loss: 0.3408327 capacity: 0.25
step count: 2000 time: 14sec reconstruction loss: 493.11334 latent loss: 0.54778117 capacity: 0.5
step count: 3000 time: 13sec reconstruction loss: 483.64294 latent loss: 0.90990674 capacity: 0.75
step count: 4000 time: 13sec reconstruction loss: 487.61887 latent loss: 1.1022894 capacity: 1.0
step count: 5000 time: 13sec reconstruction loss: 409.75787 latent loss: 1.1762537 capacity: 1.25
step count: 6000 time: 14sec reconstruction loss: 386.47943 latent loss: 1.4615921 capacity: 1.5
step count: 7000 time: 13sec reconstruction loss: 405.71588 latent loss: 1.768379 capacity: 1.7500000000000002
step count: 8000 time: 13sec reconstruction loss: 448.90472 latent loss: 2.0521615 capacity: 2.0
step count: 9000 time: 13sec reconstruction loss: 391.42023 latent loss: 2.3645582 capacity: 2.25
step count: 100

KeyboardInterrupt: ignored

In [0]:
# Helper function to show images
def show_images_grid(imgs_, num_images=25):
  ncols = int(np.ceil(num_images**0.5))
  nrows = int(np.ceil(num_images / ncols))
  _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))
  axes = axes.flatten()

  for ax_i, ax in enumerate(axes):
    if ax_i < num_images:
      ax.imshow(imgs_[ax_i], cmap='Greys_r',  interpolation='nearest')
      ax.set_xticks([])
      ax.set_yticks([])
    else:
      ax.axis('off')

In [None]:
# VAE reconstruction check
with tf.Session() as sess:
    saver.restore(sess, checkpoint_dir + '/' + 'checkpoint-40')
    #sess.run(tf.global_variables_initializer())
    rand_imgs = sample_random_images(2)
    rand_reconst_imgs = input_to_output(sess, rand_imgs).reshape(-1,64,64)
    side_by_side = np.array(rand_imgs).reshape(-1,64,64)
    show_images_grid([side_by_side[0],rand_reconst_imgs[0],side_by_side[1],rand_reconst_imgs[1]], 4)

In [None]:
# Disentanglement check: sweep over the latent space
with tf.Session() as sess:
    saver.restore(sess, checkpoint_dir + '/' + 'checkpoint-40')
    rand_imgs = sample_random_images(1)
    latent_mean, latent_log_std = input_to_latent(sess, rand_imgs)
    z_sigma_sq = np.exp(latent_log_std)[0]
    z_mean = latent_mean[0]
    print("Variance: ", end="")
    for x in z_sigma_sq:
        print(x, end=', ')
    
    appended_list = []
    for target_z_index in range(10):
        for ri in range(10):
            value = -3.0 + (6.0 / 9.0) * ri
            z_mean2 = np.zeros((1, 10))
            for i in range(10):
                if( i == target_z_index ):
                    z_mean2[0][i] = value
                else:
                    z_mean2[0][i] = z_mean[i]
            reconstr_img = latent_to_output(sess, z_mean2)
            rimg = reconstr_img[0].reshape(64, 64)
            appended_list.append(rimg)
      #imsave("disentangle_img/check_z{0}_{1}.png".format(target_z_index,ri), rimg)
    
    show_images_grid(appended_list, 100)