# Variational Autoencoder with MNIST

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def create_encoder(x, zdim):
    """Build the encoder network that parameterises a diagonal Gaussian over z.

    A single 250-unit ReLU hidden layer is shared by two linear output heads.

    Args:
        x: input tensor fed to the first dense layer (in this notebook, the
            784-wide MNIST placeholder).
        zdim: number of latent dimensions produced by each head.

    Returns:
        Tuple ``(mu, log_sigma)``: the mean and the log standard deviation,
        each with last dimension ``zdim``.
    """
    hidden = tf.layers.dense(x, 250, activation=tf.nn.relu)
    # Both heads are linear: the mean is unbounded, and the log std-dev is
    # unconstrained because callers exponentiate it to obtain sigma.
    mean_head = tf.layers.dense(hidden, zdim)
    log_std_head = tf.layers.dense(hidden, zdim)
    return mean_head, log_std_head

In [None]:
def sampling_op(distribution):
    """Draw one latent sample per row using the reparameterisation trick.

    Computes ``z = mu + exp(log_sigma) * eps`` with ``eps ~ N(0, I)``. Because
    the randomness lives entirely in ``eps``, the sample stays differentiable
    with respect to ``mu`` and ``log_sigma``.

    Args:
        distribution: pair ``(mu, log_sigma)``, e.g. as returned by
            ``create_encoder``.

    Returns:
        Tensor with the same shape as ``mu``.
    """
    mean, log_std = distribution
    noise = tf.random_normal(shape=tf.shape(mean))
    std = tf.exp(log_std)
    return mean + std * noise

In [None]:
def create_decoder(z):
    """Build the decoder network mapping a latent code back to pixel space.

    One 250-unit ReLU hidden layer feeds a 784-unit sigmoid output layer. The
    outputs lie in (0, 1) and are used as per-pixel Bernoulli means by the
    binary cross-entropy loss (784 = a flattened 28x28 MNIST image).

    Args:
        z: latent tensor (one code per row).

    Returns:
        Tensor of reconstructed pixel intensities, last dimension 784.
    """
    hidden = tf.layers.dense(z, 250, activation=tf.nn.relu)
    return tf.layers.dense(hidden, 784, activation=tf.sigmoid)

In [None]:
# Saving image samples
def save_sample(images_array, filename, shape):
    """Save a batch of grayscale images as one tiled 8-bit image file.

    Args:
        images_array: array of shape ``(N, rows, cols)`` with pixel values
            expected in [0, 1] (they are scaled by 255 before saving).
        filename: destination path, passed to PIL's ``Image.save``.
        shape: ``[tiles_down, tiles_across]``, the number of tiles along the
            vertical (first) and horizontal (second) axis of the mosaic.
            Images fill the grid column by column; unused tiles stay black.

    Raises:
        ValueError: if ``images_array`` holds more images than the grid fits.
    """
    mosaic = _tile_images(images_array, shape)
    # Clip before scaling so out-of-range values saturate instead of wrapping
    # around when cast to uint8.
    pixels = (np.clip(mosaic, 0.0, 1.0) * 255).astype(np.uint8)
    final_img = Image.fromarray(pixels, mode="L")
    final_img.save(filename)


def _tile_images(images_array, shape):
    """Arrange a stack of 2-D images into a single mosaic array.

    Tile ``i`` is placed at grid cell ``(i % tiles_down, i // tiles_down)``,
    i.e. top to bottom, then left to right.

    Args:
        images_array: array of shape ``(N, rows, cols)``; rows and cols may differ.
        shape: ``[tiles_down, tiles_across]``.

    Returns:
        Float array of shape ``(rows * tiles_down, cols * tiles_across)``.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    tile_rows = images_array.shape[1]
    tile_cols = images_array.shape[2]
    tiles_down, tiles_across = shape[0], shape[1]

    if len(images_array) > tiles_down * tiles_across:
        raise ValueError(
            "{} images do not fit in a {}x{} grid".format(
                len(images_array), tiles_down, tiles_across
            )
        )

    # Size each mosaic axis from the matching image axis (rows vs cols), so
    # non-square images tile correctly.
    mosaic = np.zeros((tile_rows * tiles_down, tile_cols * tiles_across))
    for i, image in enumerate(images_array):
        top = (i % tiles_down) * tile_rows
        left = (i // tiles_down) * tile_cols
        mosaic[top:top + tile_rows, left:left + tile_cols] = image.reshape(
            tile_rows, tile_cols
        )
    return mosaic

In [None]:
# Load MNIST using "MNIST-data/" as the data directory (presumably downloading it
# there on first use — verify). one_hot is not passed, so the labels are used
# below as scalar class values (e.g. as scatter colours); the images are fed to
# the 784-wide placeholder, i.e. flattened 28x28 vectors.
# NOTE(review): `tensorflow.examples.tutorials` is not available in TF 2.x —
# confirm the TensorFlow version this notebook is pinned to.
mnist = input_data.read_data_sets("MNIST-data/")

In [None]:
# Build Model
tf.reset_default_graph()

zdim = 2    # latent dimensionality; 2 lets the latent space be scatter-plotted directly
beta = 1    # weight on the KL term; beta = 1 gives the standard VAE objective

learning_rate = 0.001

# Flattened 28x28 MNIST images, one per row.
X = tf.placeholder(dtype=tf.float32, shape=[None, 784])

encoder = create_encoder(X, zdim)
sampling = sampling_op(encoder)
Xh = create_decoder(sampling)

# Reconstruction term: per-example binary cross-entropy between the input pixels
# and the decoder outputs, summed over pixels. The 1e-8 guards against log(0).
r_loss = -tf.reduce_sum(X * tf.log(1e-8 + Xh) + (1 - X) * tf.log(1e-8 + 1 - Xh), 1)

# Regularisation term: closed-form KL(N(mu, sigma^2) || N(0, I)) per example,
#     0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1),
# scaled by beta. log(sigma^2) is written as 2 * log_sigma, which is exact and
# avoids exponentiating and then taking the log again. (The previous expression
# had no 0.5 factor, so the effective KL weight was 2 * beta.)
mu, log_sigma = encoder
sigma = tf.exp(log_sigma)
kl_div = beta * 0.5 * tf.reduce_sum(
    tf.square(mu) + tf.square(sigma) - 2.0 * log_sigma - 1.0, 1
)

# Negative (beta-weighted) ELBO, averaged over the batch; both terms are >= 0.
loss = tf.reduce_mean(r_loss) + tf.reduce_mean(kl_div)
train = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

In [None]:
import os

batch_size = 100
# NOTE: each loop iteration consumes a single minibatch, so `epochs` is really
# the number of training steps, not the number of passes over the training set.
epochs = 10000
display_step = 10

# save_sample writes into this folder; create it up front so the first save
# does not fail with FileNotFoundError on a fresh checkout.
os.makedirs("Generated", exist_ok=True)

# A fixed set of 25 test images (a 5x5 grid), reconstructed at every display
# step so progress is always compared against the same targets.
test_xs, _ = mnist.test.next_batch(25)
save_sample(test_xs.reshape(-1, 28, 28), "Generated/target.bmp", [5, 5])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(epochs):
    xs, _ = mnist.train.next_batch(batch_size)
    sess.run(train, feed_dict={X: xs})

    if epoch % display_step == 0:
        # Loss on the minibatch just trained on, after the update. It is a
        # separate run, so it uses a fresh latent sample.
        print("Epoch", epoch, "Loss", sess.run(loss, feed_dict={X: xs}))
        a = sess.run(Xh, feed_dict={X: test_xs}).reshape(-1, 28, 28)
        save_sample(a, "Generated/" + str(epoch) + ".bmp", [5, 5])

In [None]:
# Scatter the 2-D latent codes of 3000 test images, coloured by digit label.
plot_xs, plot_ys = mnist.test.next_batch(3000)

# `sampling` yields one stochastic draw of z per image (not the encoder mean),
# so the plotted points include the sampling noise.
sampled_zvector = sess.run(sampling, feed_dict={X: plot_xs})

fig, ax = plt.subplots(figsize=(7, 7))
points = ax.scatter(sampled_zvector[:, 0], sampled_zvector[:, 1], c=plot_ys)
fig.colorbar(points, ax=ax)
plt.show()