## Importance Weighted Autoencoder

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time

In [None]:
# Model / training hyperparameters.
z_dim = 50         # dimensionality of the latent code z
batch_size = 100   # MNIST images per optimizer step
k = 5              # importance samples drawn per image during training
nb_steps = 400000  # total number of optimizer steps

In [None]:
def render_images(np_x, n_rows=10, n_cols=10, height=28, width=28, title='Generation'):
    """Show a batch of flattened images as one tiled grayscale figure.

    Images fill the grid column by column: image ``n`` is drawn in column
    ``n // n_rows`` and row ``n % n_rows`` (the same layout the original
    split/concatenate implementation produced for the 10 x 10 case).

    Args:
        np_x: array-like holding exactly ``n_rows * n_cols`` images of
            ``height * width`` pixels each (e.g. shape [100, 784]).
        n_rows, n_cols: grid dimensions (default 10 x 10).
        height, width: size of a single image in pixels (default 28 x 28).
        title: figure title.
    """
    # Axis 0 of this view is the grid column, axis 1 the grid row.
    tiles = np.asarray(np_x).reshape((n_cols, n_rows, height, width))
    # [col, row, y, x] -> [row, y, col, x] so one reshape stitches the tiles
    # into a single (n_rows*height, n_cols*width) canvas that stays 2-D.
    canvas = tiles.transpose(1, 2, 0, 3).reshape((n_rows * height, n_cols * width))
    plt.imshow(canvas, cmap='Greys_r')
    plt.title(title)
    plt.show()

In [None]:
def encoder(x, z_dim=20, reuse=False):
    """Encode inputs into the parameters of a diagonal Gaussian q(z|x).

    Two 200-unit ReLU layers feed two linear heads: one for the mean and one
    whose softplus (plus 1e-6, keeping it strictly positive) is the standard
    deviation. Variables live under the "encoder" scope; pass ``reuse=True``
    to share them with an already-built encoder.

    Returns:
        (mu, sigma), each with ``z_dim`` columns.
    """
    with tf.variable_scope("encoder", reuse=reuse):
        hidden = x
        for _ in range(2):
            hidden = tf.layers.dense(hidden, 200, activation=tf.nn.relu)
        mean = tf.layers.dense(hidden, z_dim, activation=None)
        raw_scale = tf.layers.dense(hidden, z_dim, activation=None)
        std = 1e-6 + tf.nn.softplus(raw_scale)
    return mean, std

In [None]:
def decoder(z, z_dim=20, reuse=False):
    """Map latent codes to 784 (= 28 * 28) per-pixel probabilities.

    Two 200-unit ReLU layers followed by a sigmoid output layer, with
    variables under the "decoder" scope (``reuse=True`` shares them).
    ``z_dim`` is accepted only for signature symmetry with ``encoder``; the
    input width is inferred from ``z`` itself.
    """
    with tf.variable_scope("decoder", reuse=reuse):
        hidden = z
        for _ in range(2):
            hidden = tf.layers.dense(hidden, 200, activation=tf.nn.relu)
        probs = tf.layers.dense(hidden, 784, activation=tf.nn.sigmoid)
    return probs

In [None]:
def objective(z, mu, sigma, x, x_hat, training=True, num_samples=None):
    """Return minus the importance-weighted (IWAE) bound, averaged over data points.

    All inputs hold S importance samples for each of B data points, stacked
    sample-major along axis 0 (row ``s * B + b`` is sample ``s`` of data point
    ``b``), which is the layout ``tf.tile(x, [S, 1])`` produces.

    Per sample the log importance weight is
        log w_s = log p(x | z_s) + log p(z_s) - log q(z_s | x),
    and per data point the bound is log((1/S) * sum_s w_s), evaluated stably
    with log-sum-exp. Its gradient is sum_s w~_s * grad(log w_s) with the
    normalized weights w~_s, i.e. the IWAE gradient estimator.

    Args:
        z: reparameterized latent samples, shape [S*B, z_dim].
        mu, sigma: mean and standard deviation of the diagonal Gaussian
            q(z|x), same shape as ``z``.
        x: inputs, shape [S*B, D]; expected in [0, 1] for the Bernoulli
            likelihood.
        x_hat: Bernoulli means from the decoder, shape [S*B, D].
        training: picks the default S when ``num_samples`` is None: the
            global ``k`` if True, 5000 if False.
        num_samples: S; overrides the default chosen by ``training``.

    Returns:
        Scalar tensor: the negative bound averaged over the B data points.
    """
    if num_samples is None:
        num_samples = k if training else 5000
    half_log2pi = 0.5 * float(np.log(2.0 * np.pi))
    # log q(z|x): diagonal Gaussian N(mu, sigma^2), summed over latent dims.
    log_QzGx = tf.reduce_sum(
        -half_log2pi - tf.log(sigma)
        - tf.squared_difference(z, mu) / (2.0 * tf.square(sigma)), -1)
    # log p(z): standard normal prior.
    log_Pz = tf.reduce_sum(-half_log2pi - 0.5 * tf.square(z), -1)
    # log p(x|z): Bernoulli log-likelihood kept PER SAMPLE -- each importance
    # weight needs its own reconstruction term, so no averaging over rows here.
    log_PxGz = tf.reduce_sum(
        x * tf.log(x_hat + 1e-8) + (1 - x) * tf.log(1 - x_hat + 1e-8), -1)
    log_weights = tf.reshape(log_PxGz + log_Pz - log_QzGx, [num_samples, -1])
    # log-mean-exp over the sample axis; its gradient already applies the
    # (implicitly stopped) normalized weights, unlike differentiating
    # sum(w~ * log w) directly.
    log_mean_w = tf.reduce_logsumexp(log_weights, 0) - float(np.log(num_samples))
    return -tf.reduce_mean(log_mean_w)

In [None]:
# Training graph: every image in the batch is replicated k times (sample-major,
# so row s * batch_size + b is copy s of image b) and pushed through the
# reparameterized encoder/decoder pair.
x = tf.placeholder(tf.float32, [batch_size, 784])
x_k = tf.tile(x, [k, 1])
mu, sigma = encoder(x_k, z_dim=z_dim)
eps = tf.random_normal([k * batch_size, z_dim], mean=0, stddev=1, dtype=tf.float32)
z = mu + sigma * eps
x_hat = decoder(z)

In [None]:
# Training objective over the k replicated samples per image.
loss = objective(z, mu, sigma, x_k, x_hat, training=True)

In [None]:
# Evaluation graph: one image replicated 5000 times, sharing all weights with
# the training graph (reuse=True); feeds the non-training path of `objective`.
nb_test_samples = 5000
x_test = tf.placeholder(tf.float32, [1, 784])
x_k_test = tf.tile(x_test, [nb_test_samples, 1])
mu_test, sigma_test = encoder(x_k_test, z_dim=z_dim, reuse=True)
eps_test = tf.random_normal([nb_test_samples, z_dim], mean=0, stddev=1, dtype=tf.float32)
z_test = mu_test + sigma_test * eps_test
x_hat_test = decoder(z_test, reuse=True)

In [None]:
# Evaluation objective built from the 5000-sample evaluation graph.
test_loss = objective(z_test, mu_test, sigma_test, x_k_test, x_hat_test, training=False)

In [None]:
# Adam on the training loss. The initializer is created afterwards so that it
# also covers the slot variables Adam adds in minimize().
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
optim_op = optimizer.minimize(loss)
init_op = tf.global_variables_initializer()

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.InteractiveSession(config=config)

In [None]:
# Load MNIST (cached under data_dir) with each image flattened (reshape=True).
from tensorflow.examples.tutorials.mnist import input_data
data_dir = "/tmp/data/"
mnist = input_data.read_data_sets(data_dir, one_hot=True, reshape=True)

In [None]:
# Train, reporting progress and plotting reconstructions every `log_every` steps.
sess.run(init_op)
log_every = 5000
running_loss = 0.0
start_time = time.time()
for stp in range(1, nb_steps + 1):
    x_np, _ = mnist.train.next_batch(batch_size)
    _, loss_np = sess.run([optim_op, loss], feed_dict={x: x_np})
    running_loss += loss_np
    if stp % log_every == 0:
        elapsed = time.time() - start_time
        # Report the loss averaged over the window; a single minibatch value
        # is very noisy.
        print('Step: {:d} in {:.2f}s :: Loss: {:.3f}'.format(stp, elapsed, running_loss / log_every))
        running_loss = 0.0
        # Reconstruct the current batch: feeding a fresh next_batch(100) would
        # consume training examples and fails whenever batch_size != 100,
        # because the placeholder `x` is [batch_size, 784].
        x_hat_np = sess.run(x_hat, feed_dict={x: x_np})
        render_images(x_hat_np[:100])
        # Restart the clock after plotting so the next report times training only.
        start_time = time.time()