In [1]:
import numpy as np
import tensorflow as tf
import tensorbayes as tb
from tensorbayes.layers import constant, placeholder
from tensorbayes.layers import dense, batch_normalization
from tensorbayes.layers import sample_gaussian
from tensorbayes.utils import show_graph
from tensorbayes.utils import progbar
from tensorbayes.distributions import normal_log_pdf, bernoulli_log_pdf

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST splits from the local "MNIST_data/" directory, with labels
# returned one-hot encoded.
mnist = input_data.read_data_sets(
    "MNIST_data/",
    one_hot=True,
)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
# General settings.
# Nonlinearity passed as `activation` to the hidden layers ('layer1'/'layer2')
# of the encode and decode blocks.
activate = tf.nn.elu

In [4]:
def LadderMerge(zm1, zv1, zm2, zv2, scope):
    """Combine two (mean, variance) pairs by precision weighting.

    The merged variance is the reciprocal of the summed precisions
    (1/zv1 + 1/zv2); the merged mean is the precision-weighted sum of the
    two means, scaled by the merged variance.

    Args:
        zm1, zv1: mean and variance of the first distribution.
        zm2, zv2: mean and variance of the second distribution.
        scope: name of the tf.name_scope that wraps the created ops.

    Returns:
        Tuple (merged_mean, merged_variance).
    """
    with tf.name_scope(scope):
        with tf.name_scope('variance'):
            # Precision = inverse variance; no epsilon is added, so the
            # variances are expected to be strictly positive.
            precision1 = 1.0/zv1
            precision2 = 1.0/zv2
            merged_var = 1.0/(precision1 + precision2)
        with tf.name_scope('mean'):
            weighted_sum = zm1 * precision1 + zm2 * precision2
            merged_mean = weighted_sum * merged_var
        return merged_mean, merged_var

In [5]:
def layer(x, size, scope, bn=True, activation=None):
    """Dense layer, optionally followed by batch normalisation and an activation.

    Args:
        x: input tensor.
        size: number of output units passed to `dense`.
        scope: variable scope under which the sub-layers are created.
        bn: when truthy, apply `batch_normalization` after the dense layer.
        activation: optional callable applied last; skipped when None.

    Returns:
        The output tensor of the last applied stage.

    Note:
        Batch normalisation reads the module-level name `phase`, which is not
        a parameter here, so it must be defined before this function is called.
    """
    with tf.variable_scope(scope):
        out = dense(x, size, scope='dense')
        if bn:
            out = batch_normalization(out, phase, scope='bn')
        return out if activation is None else activation(out)

def name(index, suffix):
    """Build a scope name of the form 'z<index>_<suffix>'.

    Args:
        index: integer level index (formatted with '{:d}').
        suffix: string appended after the underscore.

    Returns:
        The combined scope name, e.g. name(1, 'encode') -> 'z1_encode'.
    """
    level_tag = 'z{:d}'.format(index)
    return '_'.join((level_tag, suffix))
    
def encode_block(x, h_size, z_size, idx):
    """Build one encoder level: two hidden layers plus mean/variance heads.

    Args:
        x: input tensor for this level.
        h_size: width of the two hidden layers.
        z_size: width of the mean and variance outputs.
        idx: level index used to derive the variable-scope names.

    Returns:
        Tuple (mean, variance); the variance head uses a softplus activation.
    """
    with tf.variable_scope(name(idx, 'encode')):
        hidden = x
        for layer_name in ('layer1', 'layer2'):
            hidden = layer(hidden, h_size, layer_name, activation=activate)
    with tf.variable_scope(name(idx, 'encode/likelihood')):
        mean = layer(hidden, z_size, 'mean')
        variance = layer(hidden, z_size, 'variance', activation=tf.nn.softplus)
    return mean, variance
    
def infer_block(likelihood, prior, idx):
    """Form the posterior for level `idx` and draw a sample from it.

    Args:
        likelihood: (mean, variance) tuple from the encoder.
        prior: (mean, variance) tuple from the level above, or None.
        idx: level index used to derive the variable-scope name.

    Returns:
        Tuple (z, posterior): the sample and the (mean, variance) pair it was
        drawn from. With no prior the posterior is the likelihood itself;
        otherwise it is the LadderMerge of likelihood and prior.
    """
    with tf.variable_scope(name(idx, 'sample')):
        if prior is not None:
            # Tuple concatenation yields (zm1, zv1, zm2, zv2) for LadderMerge.
            posterior = LadderMerge(*(likelihood + prior), scope='pwm')
        else:
            posterior = likelihood
        z = sample_gaussian(*posterior, scope='sample')
    return z, posterior

def decode_block(z_like, z_prior, h_size, x_size, idx):
    """Sample level `idx` and decode it into parameters for level `idx - 1`.

    Args:
        z_like: (mean, variance) likelihood for this level from the encoder.
        z_prior: (mean, variance) prior from the level above, or None.
        h_size: width of the two hidden decoder layers.
        x_size: width of the decoder output for level `idx - 1`.
        idx: index of the level being sampled; scopes are named after `idx - 1`.

    Returns:
        Tuple (z, z_post, out) where `z` is the sample, `z_post` the
        (mean, variance) posterior it was drawn from, and `out` is either
        the logits tensor (when `idx - 1 == 0`) or a (mean, variance) prior
        for the level below.
    """
    z, z_post = infer_block(z_like, z_prior, idx)
    with tf.variable_scope(name(idx - 1, 'decode')):
        h = layer(z, h_size, 'layer1', activation=activate)
        h = layer(h, h_size, 'layer2', activation=activate)
    with tf.variable_scope(name(idx - 1, 'decode/prior')):
        if (idx - 1) == 0:
            # Bottom level: emit raw logits without batch norm. The width
            # comes from `x_size` rather than a hard-coded 784 so the block
            # follows whatever output size the caller requests.
            logits = layer(h, x_size, 'logits', bn=False)
            return z, z_post, logits
        # Intermediate level: emit a Gaussian prior for the level below,
        # with softplus applied to the variance head.
        x_m = layer(h, x_size, 'mean')
        x_v = layer(h, x_size, 'variance', activation=tf.nn.softplus)
        return z, z_post, (x_m, x_v)

In [6]:
# Clear the default graph before building the model.
tf.reset_default_graph()
# Boolean placeholder read by the batch-normalisation call in `layer`;
# the training cell feeds it by name as 'phase:0'.
phase = placeholder(None, tf.bool, name='phase')
# Input batch with 784 features per row; fed by name as 'x:0'.
x = placeholder((None, 784), name='x')
with tf.name_scope('z0'):
    # Stochastic binarisation: an entry is 1.0 where x exceeds a fresh
    # uniform draw from [0, 1), and 0.0 otherwise.
    z0 = tf.cast(tf.greater(x, tf.random_uniform(tf.shape(x), 0, 1)), tf.float32)

# Encode (bottom-up): each level takes the mean (element 0) of the
# (mean, variance) pair produced by the level below.
z1_like = encode_block(z0, 512, 64, idx=1)
z2_like = encode_block(z1_like[0], 256, 32, idx=2)
z3_like = encode_block(z2_like[0], 128, 16, idx=3)
# Decode (top-down). z3_prior is the fixed (0, 1) pair used only in the kl3
# term below; the idx=3 decode_block receives no prior, so its posterior is
# the encoder likelihood unchanged.
z3_prior = (constant(0), constant(1))
z3, z3_post, z2_prior = decode_block(z3_like, None, 128, 32, idx=3)
z2, z2_post, z1_prior = decode_block(z2_like, z2_prior, 256, 64, idx=2)
z1, z1_post, z0_logits = decode_block(z1_like, z1_prior, 512, 784, idx=1)

with tf.name_scope('loss'):
    with tf.name_scope('recon'):
        # Negative Bernoulli log-likelihood of the binarised input under the logits.
        recon = -bernoulli_log_pdf(z0, z0_logits)
    # Each kl term is log-density under the posterior minus log-density under
    # the prior, evaluated at the sampled z (a sample-based estimate).
    with tf.name_scope('kl1'):
        kl1   = -normal_log_pdf(z1, *z1_prior) + normal_log_pdf(z1, *z1_post)
    with tf.name_scope('kl2'):
        kl2   = -normal_log_pdf(z2, *z2_prior) + normal_log_pdf(z2, *z2_post)
    with tf.name_scope('kl3'):
        kl3   = -normal_log_pdf(z3, *z3_prior) + normal_log_pdf(z3, *z3_post)
    # Presumably one value per example (the training cell calls .mean() on it) -- TODO confirm.
    loss  = recon + kl1 + kl2 + kl3

In [7]:
# Graph visualisation is disabled; change the guard to True to render the
# default graph inline.
if False:
    graph_def = tf.get_default_graph().as_graph_def()
    show_graph(graph_def)

In [8]:
# The learning rate is a placeholder so it can be supplied per run as 'lr:0'.
lr = placeholder(None, name='lr')
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer(lr)
# Make the training op depend on everything registered in UPDATE_OPS
# (presumably the batch-normalisation statistics updates) so those ops run
# before each optimisation step.
with tf.control_dependencies(update_ops):
    train_step = optimizer.minimize(loss)

In [9]:
# Open a session on the default graph and initialise every global variable.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

In [10]:
# history rows are [evaluation index, mean train loss, mean test loss].
history = []
iterep = 500  # number of training iterations between evaluations
for i in range(iterep * 2):
    x_train, y_train = mnist.train.next_batch(100)
    train_feed = {'x:0': x_train,
                  'phase:0': True,
                  'lr:0': 2e-4}
    sess.run(train_step, feed_dict=train_feed)
    progbar(i, iterep)
    if (i + 1) % iterep != 0:
        continue
    # End of an evaluation period: score the full train and test sets with
    # phase set to False.
    epoch = (i + 1) // iterep
    tr = sess.run(loss,
                  feed_dict={'x:0': mnist.train.images,
                             'phase:0': False})
    t = sess.run(loss,
                 feed_dict={'x:0': mnist.test.images,
                            'phase:0': False})
    history.append([epoch, tr.mean(), t.mean()])
    print(history[-1])

[1, 227.50777, 226.16895]
[2, 161.21005, 160.13354]
