In [None]:
import math
import os
import time

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

In [None]:
# Define some handy network layers
def lrelu(x, rate=0.1):
    """Leaky ReLU activation.

    Computes ``max(min(x * rate, 0), x)``. For ``0 <= rate <= 1`` this passes
    positive inputs through unchanged and scales negative inputs by ``rate``.

    Args:
        x: input tensor.
        rate: slope applied to negative inputs.

    Returns:
        Tensor with the same shape as ``x``.
    """
    negative_branch = tf.minimum(x * rate, 0)
    return tf.maximum(negative_branch, x)

def conv2d_lrelu(inputs, num_outputs, kernel_size, stride):
    """2-D convolution with Xavier-initialized weights, followed by ``lrelu``.

    The layer is built with ``activation_fn=tf.identity`` so the leaky ReLU
    can be applied explicitly afterwards.

    Args:
        inputs: input feature-map tensor.
        num_outputs: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.

    Returns:
        The leaky-ReLU-activated convolution output.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    pre_activation = tf.contrib.layers.convolution2d(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=xavier,
        activation_fn=tf.identity)
    return lrelu(pre_activation)

def conv2d_t_relu(inputs, num_outputs, kernel_size, stride):
    """Transposed 2-D convolution with Xavier-initialized weights, followed by ReLU.

    The layer is built with ``activation_fn=tf.identity`` so the ReLU can be
    applied explicitly afterwards.

    Args:
        inputs: input feature-map tensor.
        num_outputs: number of output channels.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.

    Returns:
        The ReLU-activated transposed-convolution output.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    pre_activation = tf.contrib.layers.convolution2d_transpose(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=xavier,
        activation_fn=tf.identity)
    return tf.nn.relu(pre_activation)

def fc_lrelu(inputs, num_outputs):
    """Fully connected layer with Xavier-initialized weights, followed by ``lrelu``.

    Args:
        inputs: 2-D input tensor.
        num_outputs: number of output units.

    Returns:
        The leaky-ReLU-activated layer output.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    pre_activation = tf.contrib.layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=xavier,
        activation_fn=tf.identity)
    return lrelu(pre_activation)

def fc_relu(inputs, num_outputs):
    """Fully connected layer with Xavier-initialized weights, followed by ReLU.

    Args:
        inputs: 2-D input tensor.
        num_outputs: number of output units.

    Returns:
        The ReLU-activated layer output.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    pre_activation = tf.contrib.layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=xavier,
        activation_fn=tf.identity)
    return tf.nn.relu(pre_activation)

In [None]:
# Encoder and decoder use the DC-GAN architecture
# 28 x 28 x 1
def encoder(x, z_dim):
    """Map an image batch to the parameters of a diagonal Gaussian q(z|x).

    All variables live under the ``'encoder'`` variable scope.

    Args:
        x: image batch; the shape comments below assume ``None x 28 x 28 x 1``.
        z_dim: dimensionality of the latent code.

    Returns:
        ``(mean, stddev)``, each ``None x z_dim``. ``stddev`` is a sigmoid output
        clamped from below at 0.005, so it lies in ``[0.005, 1]``.
    """
    with tf.variable_scope('encoder'):
        features = conv2d_lrelu(x, 64, 4, 2)          # None x 14 x 14 x 64
        features = conv2d_lrelu(features, 128, 4, 2)  # None x 7 x 7 x 128
        flat_size = np.prod(features.get_shape().as_list()[1:])
        features = tf.reshape(features, [-1, flat_size])  # None x (7*7*128)
        hidden = fc_lrelu(features, 1024)
        # Mean head is created before the stddev head; keep this order so variable names stay stable.
        mean = tf.contrib.layers.fully_connected(hidden, z_dim, activation_fn=tf.identity)
        raw_stddev = tf.contrib.layers.fully_connected(hidden, z_dim, activation_fn=tf.sigmoid)
        stddev = tf.maximum(raw_stddev, 0.005)
    return mean, stddev

In [None]:
def decoder(z, reuse=False):
    """Map latent codes to per-pixel Gaussian parameters of p(x|z).

    All variables live under the ``'decoder'`` variable scope.

    Args:
        z: latent batch, ``None x z_dim``.
        reuse: if True, reuse the variables created by an earlier call, so that
            several decoder graphs (e.g. training and sampling) share weights.

    Returns:
        ``(mean, stddev)`` image tensors. Both come from sigmoid outputs, and
        ``stddev`` is additionally clamped from below at 0.005.
    """
    with tf.variable_scope('decoder') as scope:
        if reuse:
            scope.reuse_variables()
        hidden = fc_relu(z, 1024)
        hidden = fc_relu(hidden, 7 * 7 * 128)
        batch = tf.shape(hidden)[0]
        hidden = tf.reshape(hidden, tf.stack([batch, 7, 7, 128]))  # None x 7 x 7 x 128
        hidden = conv2d_t_relu(hidden, 64, 4, 2)
        # Mean head is created before the stddev head; keep this order so reuse finds the same variables.
        mean = tf.contrib.layers.convolution2d_transpose(hidden, 1, 4, 2, activation_fn=tf.sigmoid)
        raw_stddev = tf.contrib.layers.convolution2d_transpose(hidden, 1, 4, 2, activation_fn=tf.sigmoid)
        stddev = tf.maximum(raw_stddev, 0.005)
    return mean, stddev

In [None]:
# Build the computation graph for training.
# Latent codes use the reparameterization trick: z = mean + stddev * eps, eps ~ N(0, I).
z_dim = 5
x_dim = [28, 28, 1]
train_x = tf.placeholder(tf.float32, shape=[None] + x_dim)

train_zmean, train_zstddev = encoder(train_x, z_dim)
train_z = train_zmean + train_zstddev * tf.random_normal(tf.stack([tf.shape(train_x)[0], z_dim]))
# Batch-averaged log-determinant of the diagonal posterior covariance (sum of log variances).
zstddev_logdet = tf.reduce_mean(tf.reduce_sum(2.0 * tf.log(train_zstddev), axis=1))

train_xmean, train_xstddev = decoder(train_z)
# One stochastic reconstruction drawn from the decoder's per-pixel Gaussian.
train_xr = train_xmean + train_xstddev * tf.random_normal(tf.stack([tf.shape(train_x)[0]] + x_dim))
xstddev_logdet = tf.reduce_mean(tf.reduce_sum(2.0 * tf.log(train_xstddev), axis=(1, 2, 3)))

In [None]:
# Sampling graph: decode latent codes fed in directly, sharing the training decoder's weights.
gen_z = tf.placeholder(tf.float32, shape=(None, z_dim))
gen_xmean, gen_xstddev = decoder(gen_z, reuse=True)

In [None]:
# Per-example Gaussian negative log-likelihood of train_x under the decoder output for gen_z:
# summed over pixels, (x - mu)^2 / (2 sigma^2) + log(sigma) + log(2 pi) / 2.
sample_nll = 0.5 * tf.div(tf.square(train_x - gen_xmean), tf.square(gen_xstddev)) + tf.log(gen_xstddev)
sample_nll = sample_nll + 0.5 * math.log(2 * np.pi)
sample_nll = tf.reduce_sum(sample_nll, axis=(1, 2, 3))

In [None]:
def compute_kernel(x, y):
    """Pairwise Gaussian (RBF) kernel matrix between the rows of ``x`` and ``y``.

    Entry ``[i, j]`` is ``exp(-mean_k (x[i, k] - y[j, k])**2 / dim)``. This is the
    squared Euclidean distance scaled by ``1 / dim**2``, where ``dim`` is the
    number of columns of ``x``.

    Args:
        x: ``[n, dim]`` tensor.
        y: ``[m, dim]`` tensor with the same number of columns.

    Returns:
        ``[n, m]`` kernel matrix.
    """
    dim = tf.shape(x)[1]
    # Broadcasting [n, 1, dim] against [1, m, dim] gives the same pairwise differences
    # as tiling both inputs to [n, m, dim], without materializing the two tiled copies.
    diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
    return tf.exp(-tf.reduce_mean(tf.square(diff), axis=2) / tf.cast(dim, tf.float32))

def compute_mmd(x, y):
    """Biased (V-statistic) estimate of the squared MMD between two sample sets.

    Returns ``mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)`` using
    ``compute_kernel``. The means include the diagonal kernel entries.

    Args:
        x: ``[n, dim]`` samples.
        y: ``[m, dim]`` samples; ``n`` and ``m`` may differ.

    Returns:
        Scalar tensor.
    """
    mean_xx = tf.reduce_mean(compute_kernel(x, x))
    mean_yy = tf.reduce_mean(compute_kernel(y, y))
    mean_xy = tf.reduce_mean(compute_kernel(x, y))
    return mean_xx + mean_yy - 2 * mean_xy

# MMD between 200 fresh draws from the standard Gaussian prior and the batch's encoded latent codes.
# loss_mmd is reported during training but is not added to `loss`.
true_samples = tf.random_normal(shape=tf.stack([200, z_dim]))
loss_mmd = compute_mmd(true_samples, train_z)

In [None]:
# KL(q(z|x) || N(0, I)) for the diagonal-Gaussian posterior, summed over latent dimensions:
# per dimension -log(s) + s^2/2 + m^2/2 - 1/2. Despite the name, this is only the KL part
# of the negative ELBO (the reconstruction part is loss_nll).
loss_elbo = tf.reduce_sum(-tf.log(train_zstddev) + 0.5 * tf.square(train_zstddev) +
                          0.5 * tf.square(train_zmean) - 0.5, axis=1)
# Average over the batch and divide by the number of pixels, so the scale matches the
# per-pixel mean used for loss_nll.
loss_elbo = tf.reduce_mean(loss_elbo) / np.prod(x_dim)

In [None]:
# Gaussian negative log-likelihood of train_x under the decoder's per-pixel mean/stddev,
# averaged over batch and pixels. The log(2*pi)/2 constant is identical for every element,
# so adding it after the mean is equivalent to adding it per element.
loss_nll = tf.div(tf.square(train_x - train_xmean), tf.square(train_xstddev)) / 2.0 + tf.log(train_xstddev)
loss_nll = tf.reduce_mean(loss_nll)
loss_nll += math.log(2 * np.pi) / 2.0

In [None]:
# Objective actually optimized: per-pixel reconstruction NLL plus the per-pixel-scaled KL term.
# (loss_mmd is only monitored.)
loss = loss_nll + loss_elbo
trainer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

In [None]:
# Load MNIST through the TF tutorial helper, using the 'mnist_data' directory.
# Images are presumably returned flattened: later cells reshape them to [-1] + x_dim before feeding train_x.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('mnist_data')

In [None]:
def convert_to_display(samples, max_samples=100):
    """Tile a batch of single-channel images into one 2-D array for ``imshow``.

    Args:
        samples: numpy array of shape ``[batch_size, height, width, 1]``.
        max_samples: upper bound on how many images are used. The grid side is
            ``floor(sqrt(min(max_samples, batch_size)))``; images beyond
            ``side * side`` are dropped.

    Returns:
        Array of shape ``[height * side, width * side]``, in which image
        ``r * side + c`` occupies grid row ``r``, column ``c``.
    """
    n_used = min(max_samples, samples.shape[0])
    side = int(math.floor(math.sqrt(n_used)))
    height, width = samples.shape[1], samples.shape[2]
    # [side*side, H, W, 1] -> [H, side*side, W, 1] -> [H, row, col, W]
    tiles = np.transpose(samples[:side * side], axes=[1, 0, 2, 3]).reshape(height, side, side, width)
    # [H, row, col, W] -> [row, H, col, W] -> [row*H, col*W]
    return np.transpose(tiles, axes=[1, 0, 2, 3]).reshape(height * side, width * side)

In [None]:
class LimitedMnist:
    """Cycle through a fixed-size prefix of a dataset in contiguous batches.

    Batches are taken in order. When the next batch would run past ``size``,
    the pointer wraps back to 0, so a trailing partial batch is never returned.
    """

    def __init__(self, size, images=None):
        """Keep the first ``size`` examples.

        Args:
            size: number of leading examples to keep.
            images: array to draw from. Defaults to ``mnist.train.images`` (the
                notebook's global MNIST dataset), looked up only when omitted.

        Raises:
            ValueError: if ``size`` exceeds the number of available examples.
                This was previously an ``assert``, which is skipped under ``python -O``.
        """
        if images is None:
            images = mnist.train.images
        if size > images.shape[0]:
            raise ValueError("size %d exceeds the %d available examples" % (size, images.shape[0]))
        self.data_ptr = 0
        self.size = size
        self.data = images[:size]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples, wrapping to the start when exhausted.

        Raises:
            ValueError: if ``batch_size`` is not in ``[1, size]``. A larger batch
                could never be filled and would silently come back short on
                every call; a non-positive one would yield empty batches.
        """
        if not 0 < batch_size <= self.size:
            raise ValueError("batch_size must be in [1, %d], got %d" % (self.size, batch_size))
        start = self.data_ptr
        self.data_ptr += batch_size
        if self.data_ptr > self.size:
            # Not enough examples left for a full batch: restart from the beginning.
            start = 0
            self.data_ptr = batch_size
        return self.data[start:self.data_ptr]
            

In [None]:
# Small fixed training subset (the first 100 MNIST training images), reused by later cells.
limited_mnist = LimitedMnist(size=100)

In [None]:
# Preview four successive 100-image batches from the limited subset.
# Fix: the method is `next_batch`; `train_next_batch` does not exist and raised AttributeError.
for _ in range(4):
    batch_x = limited_mnist.next_batch(100)
    plt.imshow(convert_to_display(np.reshape(batch_x, [-1] + x_dim)), cmap='Greys_r')
    plt.show()

In [None]:
# Session setup: expose only GPU 0, let TensorFlow grow GPU memory on demand,
# allow soft device placement, then initialize all graph variables.
batch_size = 100
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
sess.run(tf.global_variables_initializer())

In [None]:
# Start training
plt.rcParams["figure.figsize"] = (20, 5)
num_train_iters = 500000
test_batch_size = 500
num_scatter_batches = 20
panel_titles = ["prior sample mean", "prior sample stddev",
                "reconstruction mean", "reconstruction stddev"]
for i in range(num_train_iters):
    batch_x = mnist.train.next_batch(batch_size)[0]
    batch_x = np.reshape(batch_x, [-1] + x_dim)
    _, nll, mmd, elbo, xmean, xstddev, xlogdet, zlogdet = \
        sess.run([trainer, loss_nll, loss_mmd, loss_elbo, train_xmean, train_xstddev, xstddev_logdet, zstddev_logdet],
                 feed_dict={train_x: batch_x})
    if i % 100 == 0:
        # loss_elbo was divided by prod(x_dim); undo that and report the KL per latent dimension.
        print("Iteration %d, Negative log likelihood is %f, mmd loss %f, elbo loss %f"
              % (i, nll, mmd, elbo * np.prod(x_dim) / z_dim))
        print("xlogdet %f  zlogdet %f" % (xlogdet, zlogdet))
    if i % 1000 == 0:
        # Decoder outputs for z ~ N(0, I) next to the reconstructions of the current training batch.
        samples, sample_stddev = sess.run([gen_xmean, gen_xstddev],
                                          feed_dict={gen_z: np.random.normal(size=(100, z_dim))})
        panels = [samples, sample_stddev, xmean, xstddev]
        for position, (title, images) in enumerate(zip(panel_titles, panels), start=1):
            plt.subplot(1, 4, position)
            plt.imshow(convert_to_display(images), cmap='Greys_r')
            plt.title(title)
        plt.show()

        # Encode test batches and scatter the first two latent dimensions, colored by digit label.
        # `_` (not `i`) as the loop variable avoids clobbering the training counter.
        z_list, label_list = [], []
        for _ in range(num_scatter_batches):
            test_x, test_y = mnist.test.next_batch(test_batch_size)
            z_list.append(sess.run(train_z, feed_dict={train_x: np.reshape(test_x, [-1] + x_dim)}))
            label_list.append(test_y)
        z = np.concatenate(z_list, axis=0)
        label = np.concatenate(label_list)
        plt.scatter(z[:, 0], z[:, 1], c=label)
        plt.show()

In [None]:
def compute_log_sum(val):
    """Numerically stable ``-log(mean(exp(-val)))`` over axis 0, averaged over the rest.

    In this notebook ``val`` is ``np.stack`` of per-draw negative log-likelihood
    vectors, i.e. ``[num_draws, batch]``. For each example the function turns the
    per-draw NLLs into the NLL of the Monte Carlo average likelihood, then returns
    the mean over examples.

    Args:
        val: numpy array; reduced along axis 0.

    Returns:
        Scalar float.
    """
    shift = np.min(val, axis=0, keepdims=True)
    # Subtracting the per-column minimum keeps every exponent <= 0, so exp cannot overflow.
    log_mean_exp = np.log(np.mean(np.exp(shift - val), axis=0))
    return np.mean(shift - log_mean_exp)

In [None]:
# Estimate the test-set negative log-likelihood by Monte Carlo over the prior:
# draw z ~ N(0, I), evaluate sample_nll = -log p(x|z), and average p(x|z) over draws in
# log space with compute_log_sum. Each step uses one prior draw per test example.
print("---------------------> Computing true log likelihood")
num_eval_batches = 20
num_iter = 50000
start_time = time.time()
avg_nll = []
for _ in range(num_eval_batches):
    # Author's note: takes about 40min per batch, expected to take 20h in total.
    batch_x = mnist.test.next_batch(batch_size)[0]
    batch_x = np.reshape(batch_x, [-1] + x_dim)
    nll_list = []
    # `step`, not `iter`: a global `iter` would shadow the builtin for the rest of the session.
    for step in range(num_iter):
        random_z = np.random.normal(size=[batch_size, z_dim])
        nll = sess.run(sample_nll, feed_dict={train_x: batch_x, gen_z: random_z})
        nll_list.append(nll)
        if step % 1000 == 0:
            print("%d %f, time used %f" % (step, compute_log_sum(np.stack(nll_list)), time.time() - start_time))
    nll = compute_log_sum(np.stack(nll_list))
    print("Likelihood importance sampled = %f, time used %f" % (nll, time.time() - start_time))
    avg_nll.append(nll)
nll = np.mean(avg_nll)
print("Estimated negative log likelihood is %f, time elapsed %f" % (nll, time.time() - start_time))

In [None]:
# This cell used to be a verbatim copy of the likelihood-estimation cell above.
# Running it repeated that multi-hour estimate on the next 20 test batches
# (mnist.test.next_batch keeps advancing) and overwrote start_time, avg_nll and nll.
# Re-run the cell above if a second, independent estimate is wanted.

In [None]:
def compute_z_extent(num_batches=100):
    """Log-determinant of the covariance of latent samples for the limited training subset.

    Draws ``num_batches`` batches of ``batch_size`` images from ``limited_mnist``.
    Each batch is encoded with a fresh reparameterized sample of ``train_z``, and
    the pooled codes are summarized by ``log|det(cov(z))|``, a scalar measure of
    how much latent volume they occupy.

    Args:
        num_batches: number of batches to encode (default 100, the previous hard-coded value).

    Returns:
        ``logabsdet`` from ``np.linalg.slogdet``; ``-inf`` if the covariance is singular.
    """
    z_batches = []
    for _ in range(num_batches):
        images = np.reshape(limited_mnist.next_batch(batch_size), [-1] + x_dim)
        z_batches.append(sess.run(train_z, feed_dict={train_x: images}))
    z_samples = np.concatenate(z_batches, axis=0)
    # np.cov treats rows as variables, hence the transpose to [z_dim, num_samples].
    _sign, logdet = np.linalg.slogdet(np.cov(z_samples.T))
    return logdet

In [None]:
# Log-determinant of the latent-sample covariance for the limited subset (displayed as the cell output).
compute_z_extent()