# Variational Auto-encoder

In [1]:
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

# Seed both NumPy's global generator and TensorFlow's graph-level generator
# with the same fixed value.
np.random.seed(1234)
tf.set_random_seed(1234)

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets from the "MNIST_data/" directory, requesting
# one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
def xavier_init(fan_in, fan_out, constant=1):
    """Return a (fan_in, fan_out) float32 tensor of uniform random values.

    Values are drawn from the symmetric interval [-bound, bound), where
    ``bound = constant * sqrt(6 / (fan_in + fan_out))``.
    """
    bound = constant*np.sqrt(6.0/(fan_in + fan_out))
    weight_shape = (fan_in, fan_out)
    return tf.random_uniform(weight_shape, minval=-bound, maxval=bound, dtype=tf.float32)

In [4]:
class VariationalAutoencoder(object):
    """Variational auto-encoder with two softplus hidden layers on each side.

    The encoder maps ``x`` to the mean and log-variance of a diagonal Gaussian
    over the latent code ``z``; the decoder maps ``z`` to a sigmoid
    reconstruction of ``x``.  The constructor builds the graph, the Adam
    optimizer and a ``tf.Session`` and initializes all variables.

    Parameters
    ----------
    n_input : int
        Width of each input vector (second dimension of the ``x`` placeholder).
    n_z : int
        Dimensionality of the latent code.
    network_architecture : dict
        Keyword arguments for ``_initialize_weights``:
        ``n_hidden_encoder_1``, ``n_hidden_encoder_2``,
        ``n_hidden_decoder_1``, ``n_hidden_decoder_2``.
    learning_rate : float
        Learning rate handed to ``tf.train.AdamOptimizer``.
    decoder_distribution : str
        ``'gaussian'`` or ``'bernoulli'``; anything else raises ``ValueError``.
    reparameterization_trick_for_train : bool
        If true, ``z`` is built as ``mean + sigma * epsilon`` and (gaussian
        decoder only) the training cost is the squared-error term; otherwise
        ``z`` comes from ``tf.random_normal(mean, sigma)`` and the training
        cost is the product of the two log-densities.
    reparameterization_trick_for_gradients : bool
        Same choice of decoder cost, but for the objective whose gradients
        ``get_gradients`` reports (gaussian decoder only).

    Raises
    ------
    ValueError
        If ``decoder_distribution`` is not supported.
    """

    def __init__(self, n_input, n_z, network_architecture,
                 learning_rate=0.001, decoder_distribution='gaussian',
                 reparameterization_trick_for_train=True,
                 reparameterization_trick_for_gradients=True):
        self.n_input = n_input
        self.n_z = n_z
        self.network_architecture = network_architecture
        self.learning_rate = learning_rate
        self.decoder_distribution = decoder_distribution
        self.reparameterization_trick_for_train = reparameterization_trick_for_train
        self.reparameterization_trick_for_gradients = reparameterization_trick_for_gradients

        # Mini-batch of inputs; the batch dimension is left unspecified.
        self.x = tf.placeholder(tf.float32, [None, n_input])
        # Input width as a float32 tensor, used in the Gaussian log-density constant.
        self.n_x = tf.cast(tf.shape(self.x)[1], tf.float32)

        self._create_network()
        self._create_loss_optimizer()

        # The file targets the pre-1.0 TensorFlow API (tf.mul / tf.sub), so the
        # matching initializer is kept.
        init = tf.initialize_all_variables()
        self.sess = tf.Session()
        self.sess.run(init)

    def _create_network(self):
        """Build the encoder, the latent sample ``self.z`` and the decoder."""
        self.weights = self._initialize_weights(**self.network_architecture)
        enc = self.weights['encoder']
        dec = self.weights['decoder']

        encoder_layer1 = tf.nn.softplus(tf.add(tf.matmul(self.x, enc['h1']), enc['b1']))
        encoder_layer2 = tf.nn.softplus(tf.add(tf.matmul(encoder_layer1, enc['h2']), enc['b2']))
        self.z_mean = tf.add(tf.matmul(encoder_layer2, enc['out_mean']), enc['out_mean_b'])
        self.z_log_sigma_sq = tf.add(tf.matmul(encoder_layer2, enc['out_log_sigma_sq']),
                                     enc['out_log_sigma_sq_b'])

        if self.reparameterization_trick_for_train:
            # z = mean + sigma * epsilon with epsilon ~ N(0, I).
            epsilon = tf.random_normal((tf.shape(self.x)[0], self.n_z), 0, 1, dtype=tf.float32)
            self.z = tf.add(self.z_mean, tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), epsilon))
        else:
            # NOTE(review): tf.random_normal with tensor mean/stddev appears to be
            # implemented as eps * stddev + mean, so gradients may still flow
            # through this sample -- confirm whether tf.stop_gradient is wanted.
            self.z = tf.random_normal((tf.shape(self.x)[0], self.n_z),
                                      self.z_mean, tf.sqrt(tf.exp(self.z_log_sigma_sq)))

        decoder_layer1 = tf.nn.softplus(tf.add(tf.matmul(self.z, dec['h1']), dec['b1']))
        decoder_layer2 = tf.nn.softplus(tf.add(tf.matmul(decoder_layer1, dec['h2']), dec['b2']))
        self.x_reconstruction = tf.sigmoid(tf.add(tf.matmul(decoder_layer2, dec['out_mean']),
                                                  dec['out_mean_b']))

    def _initialize_weights(self, n_hidden_encoder_1, n_hidden_encoder_2,
                            n_hidden_decoder_1, n_hidden_decoder_2):
        """Create all trainable variables.

        Returns a dict with ``'encoder'`` and ``'decoder'`` sub-dicts; weight
        matrices use ``xavier_init`` and biases start at zero.
        """
        def zero_bias(width):
            return tf.Variable(tf.zeros([width], dtype=tf.float32))

        weights = dict()
        weights['encoder'] = {
            'h1': tf.Variable(xavier_init(self.n_input, n_hidden_encoder_1)),
            'h2': tf.Variable(xavier_init(n_hidden_encoder_1, n_hidden_encoder_2)),
            'out_mean': tf.Variable(xavier_init(n_hidden_encoder_2, self.n_z)),
            'out_log_sigma_sq': tf.Variable(xavier_init(n_hidden_encoder_2, self.n_z)),

            'b1': zero_bias(n_hidden_encoder_1),
            'b2': zero_bias(n_hidden_encoder_2),
            'out_mean_b': zero_bias(self.n_z),
            'out_log_sigma_sq_b': zero_bias(self.n_z)
        }
        weights['decoder'] = {
            'h1': tf.Variable(xavier_init(self.n_z, n_hidden_decoder_1)),
            'h2': tf.Variable(xavier_init(n_hidden_decoder_1, n_hidden_decoder_2)),
            # The output layer consumes decoder_layer2, so its fan-in is the
            # second *decoder* width (previously the encoder width was used,
            # which only worked when both widths happened to be equal).
            'out_mean': tf.Variable(xavier_init(n_hidden_decoder_2, self.n_input)),

            'b1': zero_bias(n_hidden_decoder_1),
            'b2': zero_bias(n_hidden_decoder_2),
            'out_mean_b': zero_bias(self.n_input)
        }
        return weights

    def _create_loss_optimizer(self):
        """Define the training cost, the monitored cost and the Adam ops.

        Raises ``ValueError`` for an unsupported ``decoder_distribution``.
        """
        if self.decoder_distribution == 'gaussian':
            # Per-example squared reconstruction error, shared by every term below.
            squared_error = tf.reduce_sum(tf.square(tf.sub(self.x_reconstruction, self.x)), 1)

            # log p(x|z) for a unit-variance Gaussian centred on the reconstruction.
            self.inference_log_density = -0.5*self.n_x*tf.log(2*np.pi) - 0.5*squared_error
            # log q(z|x) for the diagonal Gaussian produced by the encoder.
            self.recognition_log_density = -0.5*self.n_z*tf.log(2*np.pi) \
                                           - 0.5*tf.reduce_sum(self.z_log_sigma_sq, 1) \
                                           - 0.5*tf.reduce_sum(tf.mul(tf.exp(-self.z_log_sigma_sq),
                                                                      tf.square(tf.sub(self.z, self.z_mean))), 1)

            reparameterized_cost = 0.5*squared_error
            # NOTE(review): this surrogate multiplies the two log-densities
            # without tf.stop_gradient on either factor or on z; a standard
            # score-function estimator would hold log p(x|z) and the sample
            # fixed.  Left unchanged so recorded results stay comparable --
            # confirm the intended estimator.
            log_derivative_cost = -tf.mul(self.inference_log_density, self.recognition_log_density)

            if self.reparameterization_trick_for_train:
                self.decoder_cost = reparameterized_cost
            else:
                self.decoder_cost = log_derivative_cost

            if self.reparameterization_trick_for_gradients:
                self.decoder_cost_for_gradients = reparameterized_cost
            else:
                self.decoder_cost_for_gradients = log_derivative_cost
        elif self.decoder_distribution == 'bernoulli':
            # Per-example cross-entropy; 1e-10 keeps the logs finite.
            self.decoder_cost = -tf.reduce_sum(self.x * tf.log(1e-10 + self.x_reconstruction)
                                               + (1-self.x) * tf.log(1e-10 + 1 - self.x_reconstruction), 1)
            # This branch has a single decoder cost; reuse it for the monitored
            # gradients so the attribute read below always exists.
            self.decoder_cost_for_gradients = self.decoder_cost
        else:
            raise ValueError('Unsupported decoder distribution!')

        # KL divergence between q(z|x) and the standard normal prior, per example.
        self.encoder_cost = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq
                                                 - tf.square(self.z_mean)
                                                 - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(self.decoder_cost + self.encoder_cost)
        self.cost_for_gradients = tf.reduce_mean(self.decoder_cost_for_gradients + self.encoder_cost)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.minimizer = self.optimizer.minimize(self.cost)

        # (gradient, variable) pairs for the monitored objective; never applied.
        self.compute_gradients = self.optimizer.compute_gradients(self.cost_for_gradients)

    def partial_fit(self, X):
        """Run one Adam step on the mini-batch ``X``."""
        self.sess.run(self.minimizer, feed_dict={self.x: X})

    def get_gradients(self, X):
        """Return the gradients of ``cost_for_gradients`` on ``X`` as one flat float64 vector."""
        grads_and_vars = self.sess.run(self.compute_gradients, feed_dict={self.x: X})
        # One concatenate instead of repeated np.append (which re-copies the
        # accumulated array every time); float64 matches the previous result dtype.
        return np.concatenate([np.ravel(grad) for grad, _ in grads_and_vars]).astype(np.float64)

    def transform(self, X):
        """Return the encoder mean ``z_mean`` for each row of ``X``."""
        return self.sess.run(self.z_mean, feed_dict={self.x: X})

    def generate(self, z=None):
        """Decode latent codes into reconstructions.

        ``z`` may be a 2-D batch of codes or a single 1-D code; when omitted a
        single code is drawn from a standard normal.
        """
        if z is None:
            # self.z is (batch, n_z), so sample a batch of one rather than a 1-D vector.
            z = np.random.normal(size=(1, self.n_z))
        else:
            z = np.atleast_2d(z)
        return self.sess.run(self.x_reconstruction, feed_dict={self.z: z})

    def reconstruct(self, X):
        """Encode ``X``, sample ``z`` and return the decoded reconstruction."""
        return self.sess.run(self.x_reconstruction, feed_dict={self.x: X})

    def loss(self, X):
        """Return the scalar training cost (batch mean) on ``X``."""
        return self.sess.run(self.cost, feed_dict={self.x: X})

    def decoder_loss(self, X):
        """Return the per-example decoder cost on ``X``."""
        return self.sess.run(self.decoder_cost, feed_dict={self.x: X})

    def encoder_loss(self, X):
        """Return the per-example encoder (KL) cost on ``X``."""
        return self.sess.run(self.encoder_cost, feed_dict={self.x: X})

    def close(self):
        """Close the TensorFlow session created by the constructor."""
        self.sess.close()

In [5]:
# Number of training examples; train() divides it by batch_size to get the
# number of mini-batches per epoch.
n_samples = mnist.train.num_examples
# Model dimensions and optimisation settings shared by every experiment below.
n_input, n_z = 784, 20  # width of each input vector, width of the latent code
batch_size, learning_rate = 100, 0.001

# Hidden-layer widths, keyed by the keyword names that
# VariationalAutoencoder._initialize_weights accepts.
network_architecture = dict(
    n_hidden_encoder_1=100,
    n_hidden_encoder_2=100,
    n_hidden_decoder_1=100,
    n_hidden_decoder_2=100,
)

In [6]:
def train(data, n_samples, n_input, n_z, batch_size,
          network_architecture, learning_rate, decoder_distribution,
          reparameterization_trick_for_train, reparameterization_trick_for_gradients,
          training_epochs=10, display_step=5, variance_display_step=5,
          n_gradient_samples=100):
    """Train a VariationalAutoencoder and periodically report cost and gradient spread.

    Parameters
    ----------
    data
        Object exposing ``data.train.next_batch(batch_size)`` (returning a
        pair whose first element is the input batch) and ``data.test.images``.
    n_samples : int
        Number of training examples; with ``batch_size`` it fixes the number
        of mini-batches per epoch.
    n_input, n_z, network_architecture, learning_rate, decoder_distribution,
    reparameterization_trick_for_train, reparameterization_trick_for_gradients
        Forwarded unchanged to ``VariationalAutoencoder``.
    training_epochs : int
        Number of passes over the training set.
    display_step : int
        Print train/test cost every ``display_step`` epochs.
    variance_display_step : int
        Print the gradient-spread statistic every ``variance_display_step`` epochs.
    n_gradient_samples : int
        How many times the gradient is re-evaluated on the last mini-batch
        for that statistic.

    Returns
    -------
    The trained ``VariationalAutoencoder``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive or exceeds ``n_samples`` (there
        would be no mini-batch to train or to measure gradients on).
    """
    # Validate before building the model so no graph/session is created needlessly.
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))
    total_batch = int(n_samples // batch_size)
    if total_batch < 1:
        raise ValueError('n_samples ({}) must be at least batch_size ({})'
                         .format(n_samples, batch_size))

    vae = VariationalAutoencoder(n_input, n_z, network_architecture, learning_rate,
                                 decoder_distribution, reparameterization_trick_for_train,
                                 reparameterization_trick_for_gradients)
    # range() rather than xrange() so the cell runs on Python 2 and Python 3 alike.
    for epoch in range(training_epochs):
        avg_cost = 0.
        for _batch_index in range(total_batch):
            batch_xs, _ = data.train.next_batch(batch_size)
            vae.partial_fit(batch_xs)
            # Cost is measured after the update, on the same mini-batch.
            cost = vae.loss(batch_xs)
            avg_cost += cost / n_samples * batch_size

        if epoch % display_step == 0:
            print('Epoch: {:04d}, cost = {:.9f}, test cost = {:.9f}' \
                  .format(epoch+1, avg_cost, vae.loss(data.test.images)))

        if epoch % variance_display_step == 0:
            # Re-evaluate the gradient on the epoch's last mini-batch; each
            # evaluation draws a fresh latent sample, so the spread reflects
            # the estimator's noise.
            gradients = np.array([vae.get_gradients(batch_xs)
                                  for _ in range(n_gradient_samples)])
            # NOTE(review): this is the Frobenius norm of the deviations from
            # the mean gradient divided by the sample count, not a variance in
            # the strict sense; kept as-is so numbers stay comparable with
            # earlier runs.
            gradient_variance = np.linalg.norm(gradients - gradients.mean(axis=0)) / n_gradient_samples
            print('Epoch: {:04d}, gradient variance = {:.9f}'.format(epoch+1, gradient_variance))
    return vae

### Train with reparameterization trick and compute gradients with reparameterization trick

In [7]:
# Reparameterized sampling for the training objective and for the monitored gradients.
vae = train(
    data=mnist,
    n_samples=n_samples,
    n_input=n_input,
    n_z=n_z,
    batch_size=batch_size,
    network_architecture=network_architecture,
    learning_rate=learning_rate,
    decoder_distribution='gaussian',
    reparameterization_trick_for_train=True,
    reparameterization_trick_for_gradients=True,
    training_epochs=76,
    variance_display_step=25,
)

Epoch: 0001, cost = 27.249713381, test cost = 25.881767273
Epoch: 0001, gradient variance = 0.123692751
Epoch: 0006, cost = 21.156966511, test cost = 20.803092957
Epoch: 0011, cost = 20.022824222, test cost = 19.922349930
Epoch: 0016, cost = 19.610593064, test cost = 19.531904221
Epoch: 0021, cost = 19.393669718, test cost = 19.426069260
Epoch: 0026, cost = 19.171127066, test cost = 19.149215698
Epoch: 0026, gradient variance = 0.509512763
Epoch: 0031, cost = 19.035913082, test cost = 18.996868134
Epoch: 0036, cost = 18.920415157, test cost = 18.909206390
Epoch: 0041, cost = 18.828198045, test cost = 18.834566116
Epoch: 0046, cost = 18.740361096, test cost = 18.779804230
Epoch: 0051, cost = 18.682113856, test cost = 18.713949203
Epoch: 0051, gradient variance = 0.573750052
Epoch: 0056, cost = 18.628590383, test cost = 18.683971405
Epoch: 0061, cost = 18.579649159, test cost = 18.593568802
Epoch: 0066, cost = 18.533772153, test cost = 18.635084152
Epoch: 0071, cost = 18.495735203, test 

### Train with reparameterization trick, but compute gradients with log-derivative trick

In [7]:
# Reparameterized training objective; monitored gradients use the log-derivative cost.
vae = train(
    data=mnist,
    n_samples=n_samples,
    n_input=n_input,
    n_z=n_z,
    batch_size=batch_size,
    network_architecture=network_architecture,
    learning_rate=learning_rate,
    decoder_distribution='gaussian',
    reparameterization_trick_for_train=True,
    reparameterization_trick_for_gradients=False,
    training_epochs=76,
    variance_display_step=25,
)

Epoch: 0001, cost = 27.249713381, test cost = 25.881767273
Epoch: 0001, gradient variance = 3.698353952
Epoch: 0006, cost = 21.156966511, test cost = 20.803092957
Epoch: 0011, cost = 20.022824222, test cost = 19.922349930
Epoch: 0016, cost = 19.610593064, test cost = 19.531904221
Epoch: 0021, cost = 19.393669718, test cost = 19.426069260
Epoch: 0026, cost = 19.171127066, test cost = 19.149215698
Epoch: 0026, gradient variance = 12.217070787
Epoch: 0031, cost = 19.035913082, test cost = 18.996868134
Epoch: 0036, cost = 18.920415157, test cost = 18.909206390
Epoch: 0041, cost = 18.828198045, test cost = 18.834566116
Epoch: 0046, cost = 18.740361096, test cost = 18.779804230
Epoch: 0051, cost = 18.682113856, test cost = 18.713949203
Epoch: 0051, gradient variance = 13.589430122
Epoch: 0056, cost = 18.628590383, test cost = 18.683971405
Epoch: 0061, cost = 18.579649159, test cost = 18.593568802
Epoch: 0066, cost = 18.533772153, test cost = 18.635084152
Epoch: 0071, cost = 18.495735203, tes

### Train with log-derivative trick and compute gradients with log-derivative trick

In [7]:
# Log-derivative cost for both the training objective and the monitored gradients.
vae = train(
    data=mnist,
    n_samples=n_samples,
    n_input=n_input,
    n_z=n_z,
    batch_size=batch_size,
    network_architecture=network_architecture,
    learning_rate=learning_rate,
    decoder_distribution='gaussian',
    reparameterization_trick_for_train=False,
    reparameterization_trick_for_gradients=False,
    training_epochs=76,
    variance_display_step=25,
)

Epoch: 0001, cost = -90745.160958807, test cost = -92366.554687500
Epoch: 0001, gradient variance = 0.008425934
Epoch: 0006, cost = -92418.387428977, test cost = -92386.843750000
Epoch: 0011, cost = -92433.870312500, test cost = -92427.234375000
Epoch: 0016, cost = -92422.569133523, test cost = -92429.914062500
Epoch: 0021, cost = -92436.126491477, test cost = -92467.539062500
Epoch: 0026, cost = -92414.301704545, test cost = -92460.054687500
Epoch: 0026, gradient variance = 0.000082183
Epoch: 0031, cost = -92435.803494318, test cost = -92393.507812500
Epoch: 0036, cost = -92421.430298295, test cost = -92424.023437500
Epoch: 0041, cost = -92420.640397727, test cost = -92450.015625000
Epoch: 0046, cost = -92444.738536932, test cost = -92429.054687500
Epoch: 0051, cost = -92472.283664773, test cost = -92447.039062500
Epoch: 0051, gradient variance = 0.000041690
Epoch: 0056, cost = -92449.946107955, test cost = -92453.953125000
Epoch: 0061, cost = -92416.251889205, test cost = -92393.3281