In [1]:
import time
import os
import argparse
import numpy as np
import tensorflow as tf
import tensorbayes as tb
from tensorbayes.layers import Constant, Placeholder
from tensorbayes.layers import Dense, BatchNormalization, GaussianUpdate
from tensorbayes.layers import GaussianSample, Duplicate
from tensorbayes.nbutils import show_graph
from tensorbayes.utils import progbar
from tensorbayes.distributions import log_bernoulli_with_logits
from tensorbayes.distributions import log_normal
from tensorflow.examples.tutorials.mnist import input_data
from tensorbayes.nputils import log_sum_exp
from tensorbayes.tbutils import log_sum_exp as tf_log_sum_exp

In [2]:
def show_default_graph():
    """Hand the default TensorFlow graph's GraphDef to ``show_graph``."""
    graph_def = tf.get_default_graph().as_graph_def()
    show_graph(graph_def)

In [11]:
class Dummy(object):
    """Empty attribute holder; the notebook's settings are attached to an instance."""


args = Dummy()
# All run settings are attached in one call; the rest of the notebook reads
# them back as ``args.<name>``.
vars(args).update(
    run=32,
    bs=256,
    lr=5e-4,
    nonlin='elu',
    eps=1e-8,
    save_dir='/scratch/users/rshu15',
    n_checks=100,
)

In [12]:
# Resolve the configured nonlinearity name to the TensorFlow op used by the
# hidden layers; anything other than 'relu' / 'elu' is rejected up front.
if args.nonlin not in ('relu', 'elu'):
    raise Exception("Unexpected nonlinearity arg")
activate = tf.nn.relu if args.nonlin == 'relu' else tf.nn.elu

# Directory for this run's results; trailing '/' characters are stripped from
# save_dir first so the formatted path has no doubled separator.
args.save_dir = args.save_dir.rstrip('/')
model_dir = '{:s}/results/lvae{:d}'.format(args.save_dir, args.run)


def log_bern(x, logits):
    """Call ``log_bernoulli_with_logits`` with ``args.eps`` as its third argument."""
    return log_bernoulli_with_logits(x, logits, args.eps)


def log_norm(x, mu, var):
    """Call ``log_normal`` with its fourth argument fixed to 0.0."""
    return log_normal(x, mu, var, 0.0)


mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [13]:
# Convenience layers and graph blocks
def split_array(arr, size):
    """Yield consecutive slices of ``arr`` holding at most ``size`` items each.

    The last slice is shorter when ``len(arr)`` is not a multiple of ``size``.
    """
    for start in range(0, len(arr), size):
        stop = start + size
        yield arr[start:stop]
        
def name(index, suffix):
    """Build a scope name of the form ``z<index>_<suffix>``."""
    return '_'.join(['z{:d}'.format(index), suffix])

def layer(x, size, scope, bn=True, activation=None):
    """Dense layer with optional batch normalisation and activation.

    Everything is created inside ``tf.variable_scope(scope)``:
    ``Dense(x, size)``, then ``BatchNormalization`` when ``bn`` is true (it is
    given the module-level ``phase`` tensor), then ``activation`` when one is
    supplied. Returns the resulting tensor.
    """
    with tf.variable_scope(scope):
        out = Dense(x, size, scope='dense')
        if bn:
            out = BatchNormalization(out, phase, scope='bn')
        if activation is None:
            return out
        return activation(out)

def encode_block(x, h_size, z_size, idx):
    """Build one encoder stage and return its ``(mean, variance)`` pair.

    Two hidden layers of width ``h_size`` (using the module-level ``activate``
    nonlinearity) live under the ``z<idx>_encode`` scope. The mean and variance
    heads of width ``z_size`` live under ``z<idx>_encode/likelihood``; the
    variance head goes through softplus and has ``args.eps`` added to it.
    """
    with tf.variable_scope(name(idx, 'encode')):
        hidden = x
        for hidden_scope in ('layer1', 'layer2'):
            hidden = layer(hidden, h_size, hidden_scope, activation=activate)
    with tf.variable_scope(name(idx, 'encode/likelihood')):
        mean = layer(hidden, z_size, 'mean')
        variance = layer(hidden, z_size, 'variance', activation=tf.nn.softplus)
        variance = variance + args.eps
    return (mean, variance)

def infer_block(likelihood, prior, idx):
    """Form the posterior for level ``idx`` and draw a sample from it.

    ``likelihood`` and ``prior`` are tuples of Gaussian parameters. With no
    prior the likelihood is used as the posterior unchanged; otherwise the two
    tuples are concatenated and passed to ``GaussianUpdate``. Returns
    ``(z, posterior)`` where ``z`` comes from ``GaussianSample(*posterior)``.
    """
    with tf.variable_scope(name(idx, 'sample')):
        if prior is not None:
            # Tuple concatenation: likelihood parameters first, then the prior's.
            update_inputs = likelihood + prior
            posterior = GaussianUpdate(*update_inputs, scope='pwm')
        else:
            posterior = likelihood
        z = GaussianSample(*posterior, scope='sample')
    return z, posterior

def decode_block(z_like, z_prior, h_size, x_size, idx):
    """Sample level ``idx`` and decode it into parameters for level ``idx - 1``.

    Args:
        z_like: Gaussian parameter tuple from the encoder for level ``idx``.
        z_prior: Gaussian parameter tuple from the level above, or ``None``.
        h_size: width of the two hidden decoder layers.
        x_size: width of the output produced for level ``idx - 1``.
        idx: index of the latent level being sampled.

    Returns:
        ``(z, z_post, out)`` where ``z`` / ``z_post`` come from ``infer_block``.
        When ``idx - 1 == 0`` ``out`` is a logits tensor of width ``x_size``
        (no batch norm); otherwise it is a ``(mean, variance)`` tuple of width
        ``x_size`` with ``args.eps`` added to the softplus variance.
    """
    z, z_post = infer_block(z_like, z_prior, idx)
    below = idx - 1
    with tf.variable_scope(name(below, 'decode')):
        h = layer(z, h_size, 'layer1', activation=activate)
        h = layer(h, h_size, 'layer2', activation=activate)
    with tf.variable_scope(name(below, 'decode/prior')):
        if below == 0:
            # Bottom of the ladder: emit logits. The width follows x_size
            # instead of a hard-coded 784 so the argument is actually honoured.
            logits = layer(h, x_size, 'logits', bn=False)
            return z, z_post, logits
        x_m = layer(h, x_size, 'mean')
        x_v = layer(h, x_size, 'variance', activation=tf.nn.softplus) + args.eps
        x_prior = (x_m, x_v)
        return z, z_post, x_prior

In [14]:
# Ladder VAE set-up: builds the three-level encoder/decoder graph and the loss.
# The statement order below fixes the order in which ops and variables are
# created, so keep it stable if the graph must stay compatible with a checkpoint.
tf.reset_default_graph()
# `phase` is handed to BatchNormalization inside layer(); `iw` and `mc` are fed
# as integers and define the first two axes of the reshaped loss further down.
phase = Placeholder(None, tf.bool, name='phase')
iw = Placeholder(None, 'int32', name='iw')
mc = Placeholder(None, 'int32', name='mc')
x = Placeholder((None, 784), name='x')
with tf.name_scope('z0'):
    # Stochastic binarisation: 1.0 where x exceeds a fresh uniform draw, else 0.0.
    z0 = tf.cast(tf.greater(x, tf.random_uniform(tf.shape(x), 0, 1)), tf.float32)

# encode: each stage is fed the mean (element 0) of the stage below it
z1_like = encode_block(z0, 512, 64, idx=1)
z2_like = encode_block(z1_like[0], 256, 32, idx=2)
z3_like = encode_block(z2_like[0], 128, 16, idx=3)

# tiling for monte carlo and importance weighted samples
# NOTE(review): Duplicate is opaque from here; presumably it repeats the batch
# iw * mc times so the reshape to [iw, mc, -1] below lines up -- confirm in tensorbayes.
z0 = Duplicate(z0, iw, mc) # not naming. Makes visualization prettier for some reason
z1_like = tuple([Duplicate(val, iw, mc, 'dup') for val in z1_like])
z2_like = tuple([Duplicate(val, iw, mc, 'dup') for val in z2_like])
z3_like = tuple([Duplicate(val, iw, mc, 'dup') for val in z3_like])

# decode
# The top level gets prior=None, so infer_block uses z3_like as the posterior
# unchanged; z3_prior (constants 0 and 1) is only used in the kl3 term below.
z3_prior = (Constant(0), Constant(1))
z3, z3_post, z2_prior = decode_block(z3_like, None, 128, 32, idx=3)
z2, z2_post, z1_prior = decode_block(z2_like, z2_prior, 256, 64, idx=2)
z1, z1_post, z0_logits = decode_block(z1_like, z1_prior, 512, 784, idx=1)

with tf.name_scope('loss'):
    with tf.name_scope('recon'):
        # Negative Bernoulli log-likelihood of the binarised input under the logits.
        recon = -log_bern(z0, z0_logits)
    # Each kl term is evaluated at the drawn sample: log q(z) - log p(z).
    with tf.name_scope('kl1'):
        kl1   = -log_norm(z1, *z1_prior) + log_norm(z1, *z1_post)
    with tf.name_scope('kl2'):
        kl2   = -log_norm(z2, *z2_prior) + log_norm(z2, *z2_post)
    with tf.name_scope('kl3'):
        kl3   = -log_norm(z3, *z3_prior) + log_norm(z3, *z3_post)
    per_sample_loss  = recon + kl1 + kl2 + kl3
    per_sample_loss = tf.reshape(per_sample_loss, [iw, mc, -1])
    # log-mean-exp of the negated loss over axis 0 (the iw axis):
    # log_sum_exp(...) - log(iw).
    per_sample_gain = tb.tbutils.log_sum_exp(-per_sample_loss, axis=0) - tf.log(tf.cast(iw, 'float32'))
    # Final scalar: negative mean of the gain over the remaining axes.
    loss = -tf.reduce_mean(per_sample_gain)

In [15]:
# Visualise the graph currently registered as TensorFlow's default graph.
show_default_graph()

In [16]:
# Open a session and restore variables from the checkpoint under model_dir.
# The session is intentionally left open: the evaluation cell below reuses it.
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, model_dir + '/model.ckpt')

In [17]:
t = time.time()
img_list = list(split_array(mnist.test.images, 50))
n = len(img_list)
l = []
for i in xrange(n):
    l += sess.run([loss], feed_dict={'x:0': img_list[i],
                                     'phase:0': False,
                                     'iw:0': 5000,
                                     'mc:0': 1})
    progbar(i, n)
print 'Computation time:', time.time() - t
print 'Computed loss:', np.mean(l)

Computation time: 83.0723218918
Computed loss: 81.7514
