In [None]:
# All of the imports
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Read the MNIST dataset from the 'MNIST_data' directory, with one-hot labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Every image is fed to the network as a flat vector: 28 * 28 = 784 values.
input_dim = 28 * 28

# Number of examples in the training split.
num_sample = mnist.train.num_examples

In [None]:
class VariationAutoencoder(object):
    """Variational autoencoder built from fully connected layers (TF1 graph mode).

    The encoder maps a flattened input ``x`` of width ``input_dim`` to the mean
    and log-variance of a diagonal Gaussian over an ``n_z``-dimensional latent
    ``z``. The decoder maps ``z`` back to a reconstruction ``x_hat`` of the same
    width as ``x``. Training minimises the batch mean of the reconstruction
    cross-entropy plus the KL divergence between the latent distribution and a
    standard normal.

    The constructor builds the graph in the current default graph, opens a
    ``tf.Session`` and initialises all global variables.
    """

    def __init__(self, learning_rate=1e-4, batch_size=100, n_z=5):
        """Build the graph and start a session.

        Args:
            learning_rate: step size passed to the Adam optimizer.
            batch_size: stored on the instance; the graph itself accepts any
                batch size because the input placeholder's first dim is None.
            n_z: dimensionality of the latent space.
        """
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_z = n_z

        self.build()

        # Launch the session.
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def build(self):
        """Create the placeholders, encoder, sampler, decoder, loss and train op."""
        self.x = tf.placeholder(name='x', dtype=tf.float32,
                                shape=(None, input_dim))

        # Encoder: input_dim -> 512 -> 384 -> 256
        # slim.fc(input, output_dim, scope, activation function)
        f1 = fc(self.x, 512, scope='enc_fullc1', activation_fn=tf.nn.elu)
        f2 = fc(f1, 384, scope='enc_fullc2', activation_fn=tf.nn.elu)
        f3 = fc(f2, 256, scope='enc_fullc3', activation_fn=tf.nn.elu)

        # Two linear heads on top of the encoder, each of width n_z:
        # the latent mean ...
        self.z_mu = fc(f3, self.n_z, scope='enc_fc4_mu', activation_fn=None)
        # ... and the latent log-variance, log(sigma^2).
        self.log_sigma_z_sq = fc(f3, self.n_z, scope='enc_fc4_sigma_sq',
                                 activation_fn=None)

        # Reparameterisation: z = mu + sigma * eps with eps ~ N(0, 1), so the
        # sample stays differentiable with respect to mu and log(sigma^2).
        eps = tf.random_normal(shape=tf.shape(self.log_sigma_z_sq),
                               mean=0, stddev=1, dtype=tf.float32)
        # sigma = exp(0.5 * log(sigma^2)); mathematically equal to
        # sqrt(exp(log(sigma^2))) but overflows later for large log-variances.
        sigma_z = tf.exp(0.5 * self.log_sigma_z_sq)
        self.z = self.z_mu + sigma_z * eps

        # Decoder: z -> 256 -> 384 -> 512 -> input_dim.
        # Each layer consumes the previous layer's output (not z directly).
        d1 = fc(self.z, 256, scope='dec_fc1', activation_fn=tf.nn.elu)
        d2 = fc(d1, 384, scope='dec_fc2', activation_fn=tf.nn.elu)
        d3 = fc(d2, 512, scope='dec_fc3', activation_fn=tf.nn.elu)
        # Sigmoid output so every element of x_hat lies in (0, 1).
        self.x_hat = fc(d3, input_dim, scope='dec_fc4',
                        activation_fn=tf.sigmoid)

        # Losses
        # Reconstruction loss: per-example binary cross-entropy between x and
        # x_hat, summed over the input dimensions.
        # H(x, x_hat) = - \Sigma x * log(x_hat) + (1-x) * log(1 - x_hat)
        epsilon = 1e-10  # keeps the log arguments away from zero
        recon_loss = -tf.reduce_sum(
            self.x * tf.log(self.x_hat + epsilon)
            + (1 - self.x) * tf.log(1 - self.x_hat + epsilon),
            axis=1
        )

        # Latent loss: closed-form KL divergence between the latent
        # distribution N(z_mu, sigma^2) and the prior N(0, 1), summed over the
        # latent dimensions.
        latent_loss = -0.5 * tf.reduce_sum(
            1 + self.log_sigma_z_sq - tf.square(self.z_mu)
            - tf.exp(self.log_sigma_z_sq),
            axis=1
        )

        # Total loss: mean over the batch of the two per-example losses.
        self.total_loss = tf.reduce_mean(recon_loss + latent_loss)

        # Optimizer
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.total_loss)

    def run_single_step(self, x):
        """Run one optimisation step on batch ``x`` and return the total loss.

        The loss is fetched in the same ``sess.run`` call as the train op.
        """
        _, loss = self.sess.run([self.train_op, self.total_loss],
                                feed_dict={self.x: x})
        return loss

    def reconstructor(self, x):
        """Reconstruction, x -> x_hat: encode ``x``, sample z, decode."""
        return self.sess.run(self.x_hat, feed_dict={self.x: x})

    def generator(self, z):
        """Generation, z -> x_hat: decode the given latent vectors.

        ``z`` is fed directly into the sampled-latent tensor, so the encoder
        is bypassed and no input image is required.
        """
        return self.sess.run(self.x_hat, feed_dict={self.z: z})

    def transformer(self, x):
        """Transformation, x -> z: return a latent sample for each row of ``x``.

        The result is a random draw from the encoded distribution (it includes
        the noise term), not the deterministic mean ``z_mu``.
        """
        return self.sess.run(self.z, feed_dict={self.x: x})

        
        

In [None]:
def trainer(learning_rate = 1e-4, batch_size = 100, num_epochs = 100, n_z = 10):
    """Train a VariationAutoencoder on the MNIST training split and return it.

    Args:
        learning_rate: step size forwarded to the model's optimizer.
        batch_size: number of images per mini-batch; must be positive and no
            larger than the number of training samples.
        num_epochs: number of epochs to run; each epoch performs
            ``num_sample // batch_size`` optimisation steps.
        n_z: dimensionality of the model's latent space.

    Returns:
        The trained VariationAutoencoder instance.

    Raises:
        ValueError: if ``batch_size`` would give zero mini-batches per epoch.
    """
    # Validate before building the graph: with zero batches per epoch the
    # inner loop never runs and `loss` would be unbound at the print below.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    batches_per_epoch = num_sample // batch_size
    if batches_per_epoch == 0:
        raise ValueError(
            'batch_size (%r) is larger than the number of training samples (%r)'
            % (batch_size, num_sample))

    # Create the model
    model = VariationAutoencoder(learning_rate=learning_rate,
                                 batch_size=batch_size, n_z=n_z)

    # Training loop
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            # Obtain a mini-batch
            batch = mnist.train.next_batch(batch_size)

            # One forward and backward pass. Only the first element of the
            # batch (the images) is used; the labels are ignored.
            loss = model.run_single_step(batch[0])

        # Reports the loss of the last mini-batch of the epoch.
        print('[Epoch %i] Loss: %.6f' % (epoch, loss))

    print('Done')
    return model

In [None]:
# Train for 5 epochs with a 5-dimensional latent space.
model = trainer(
    learning_rate=1e-4,
    batch_size=100,
    num_epochs=5,
    n_z=5,
)