In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def weight_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

def bias_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

def layer(x, shape, activation):
    W = weight_variable(shape)
    b = bias_variable([shape[1]])
    return activation(tf.matmul(x, W) + b)

In [3]:
def get_batch(images, batch_size, index):
    start_idx = index * batch_size
    end_idx = (index + 1) * batch_size        
    batch  = images[start_idx: end_idx]
    return batch

In [4]:
image_size = 28 * 28

input_shape=[None, image_size]
encoder_internal_dim=2048
decoder_internal_dim=2048
latent_dim=2

x = tf.placeholder(tf.float32, input_shape)

softplus = tf.nn.softplus

h_enc1 = layer(x, [image_size, encoder_internal_dim], activation=softplus)
h_enc2 = layer(h_enc1, [encoder_internal_dim, encoder_internal_dim], activation=softplus)
h_enc3 = layer(h_enc2, [encoder_internal_dim, encoder_internal_dim], activation=softplus)

W_mu = weight_variable([encoder_internal_dim, latent_dim])
b_mu = bias_variable([latent_dim])

W_log_sigma = weight_variable([encoder_internal_dim, latent_dim])
b_log_sigma = bias_variable([latent_dim])

    
z_mu = tf.matmul(h_enc3, W_mu) + b_mu
z_log_sigma = tf.matmul(h_enc3, W_log_sigma) + b_log_sigma

# reparametarization trick

# noise gaussian ε ~ N(0, 1)
epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], latent_dim]))

# z = μ+σ^(1/2)*ε
z = z_mu + tf.exp(z_log_sigma/2) * epsilon


h_dec1 = layer(z, [latent_dim, decoder_internal_dim], activation=softplus)
h_dec2 = layer(h_dec1, [decoder_internal_dim, decoder_internal_dim], activation=softplus)
h_dec3 = layer(h_dec2, [decoder_internal_dim, decoder_internal_dim], activation=softplus)

# log(p(x|z)) (p is Bernoulli) reconstruction loss
y = layer(h_dec3, [decoder_internal_dim, image_size], activation=tf.nn.sigmoid)
log_px_given_z = -tf.reduce_sum(x * tf.log(y + 1e-10) + (1 - x) * tf.log(1 - y + 1e-10), 1)


# KLD(q(z|x)||p(z)) ~ -(1/2) * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_div = - (1/2)* tf.reduce_sum(1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma),1)

    
cost = tf.reduce_mean(log_px_given_z + kl_div)



In [5]:
% rm -rf img/
% mkdir img/

def create_images(i, sess, test_images, num_examples=20, image_size=28*28):
    h = w = int(np.sqrt(image_size))
    
    original = get_batch(test_images, num_examples, 0)
    reconstruction = sess.run(y, feed_dict={x: original})
    
    fig, axs = plt.subplots(2, num_examples, figsize=(10, 2))
    for example_i in range(num_examples):
        axs[0][example_i].imshow(np.reshape(original[example_i, :], (h, w)),cmap='gray')
        axs[1][example_i].imshow(np.reshape(np.reshape(reconstruction[example_i, ...], (image_size,)),(h, w)),cmap='gray')
        axs[0][example_i].axis('off')
        axs[1][example_i].axis('off')
    fig.savefig('img/reconstruction_%08d.png' % i)

    
    

def create_latent_scatter_images(i, sess, test_images,test_labels):
    zs = sess.run(z, feed_dict={x: test_images})
    fig, ax = plt.subplots(1, 1)
    ax.clear()
    ax.scatter(zs[:, 0], zs[:, 1], c=np.argmax(test_labels, 1), alpha=0.2)
    ax.set_xlim([-6, 6])
    ax.set_ylim([-6, 6])
    ax.axis("off")
    fig.savefig("img/latent_scatter_%08d.png"% i)

In [6]:
def train(train_images, validation_images, test_images, test_labels, image_size, learning_rate = 0.001):

    sess = tf.Session()
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    sess.run(tf.global_variables_initializer())

    
    batch_size = 100
    num_epochs = 100

    num_train_batches = len(train_images) // batch_size
    num_validation_batches = len(validation_images) // batch_size

    print("num of train batches: ", num_train_batches)
    print("num of validation batches: ", num_validation_batches)

    i = 0
    for epoch in range(num_epochs):
        print("epoch No.", epoch)
        
        for batch_idx in range(num_train_batches):            
            batch  = get_batch(train_images, batch_size, batch_idx)
            sess.run(optimizer, feed_dict={x: batch})
         
        train_cost = sess.run(cost, feed_dict={x: batch})
        i += 1
        create_images(i, sess, test_images, num_examples=20, image_size=image_size)
        create_latent_scatter_images(i, sess, test_images, test_labels)
        print('train cost per a batch: ', train_cost)
      

        valid_cost = 0
        for i in range(num_validation_batches):
            batch  = get_batch(validation_images, batch_size, i)
            valid_cost += sess.run(cost, feed_dict={x: batch})
        print('validation cost per a batch:', valid_cost / num_validation_batches)


In [None]:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets('MNIST_DATA', one_hot=True)

Extracting MNIST_DATA/train-images-idx3-ubyte.gz
Extracting MNIST_DATA/train-labels-idx1-ubyte.gz
Extracting MNIST_DATA/t10k-images-idx3-ubyte.gz
Extracting MNIST_DATA/t10k-labels-idx1-ubyte.gz


In [None]:
train_images = mnist.train.images
validation_images = mnist.validation.images
test_images = mnist.test.images
test_labels = mnist.test.labels

train(train_images, validation_images, test_images, test_labels, image_size)

num of train batches:  550
num of validation batches:  50
epoch No. 0
train cost per a batch:  207.705
validation cost per a batch: 210.610833435
epoch No. 1
train cost per a batch:  208.258
validation cost per a batch: 208.829506836
epoch No. 2
