In [20]:
import numpy as np
import tensorflow as tf
import tensorbayes as tb
from tensorbayes.layers import Constant, Placeholder
from tensorbayes.layers import Dense, BatchNormalization
from tensorbayes.layers import GaussianSample
from tensorbayes.nbutils import show_graph
from tensorbayes.utils import progbar
from tensorbayes.distributions import log_bernoulli_with_logits, log_normal

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits from the local data directory, with labels one-hot encoded.
mnist_dir = "MNIST_data/"
mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [12]:
def LadderMerge(zm1, zv1, zm2, zv2, scope):
    """Merge two Gaussian (mean, variance) pairs by precision weighting.

    The merged variance is the reciprocal of the summed precisions
    (1/variance), and the merged mean is the precision-weighted sum of the
    two means scaled by the merged variance.

    Args:
        zm1, zv1: mean and variance of the first Gaussian.
        zm2, zv2: mean and variance of the second Gaussian.
        scope: name handed to `tf.name_scope` to group the created ops.

    Returns:
        Tuple `(zm, zv)`: the merged mean and merged variance.
    """
    with tf.name_scope(scope):
        with tf.name_scope('variance'):
            precision_a = 1.0 / zv1
            precision_b = 1.0 / zv2
            merged_var = 1.0 / (precision_a + precision_b)
        with tf.name_scope('mean'):
            weighted_sum = zm1 * precision_a + zm2 * precision_b
            merged_mean = weighted_sum * merged_var
    return merged_mean, merged_var

In [13]:
def layer(x, size, phase, scope, bn=True, activation=None):
    """Dense layer, optionally followed by batch normalization and an activation.

    Args:
        x: input tensor.
        size: number of output units passed to `Dense`.
        phase: forwarded to `BatchNormalization` (the notebook feeds it a
            boolean placeholder; presumably train vs. inference mode).
        scope: variable scope name wrapping the whole layer.
        bn: when true, apply `BatchNormalization` after the dense op.
        activation: optional callable applied last; skipped when `None`.

    Returns:
        The output tensor of the layer.
    """
    with tf.variable_scope(scope):
        out = Dense(x, size, scope='dense')
        if bn:
            out = BatchNormalization(out, phase, scope='bn')
        if activation is None:
            return out
        return activation(out)

In [14]:
def vae_graph(x, phase):
    """Build the two-level ladder VAE graph.

    Inference path: four 256-unit ReLU layers. The z1 Gaussian parameters are
    read off after the second layer and the z2 parameters after the fourth.
    Generative path: z2 is sampled from its inferred parameters and decoded
    into the z1 prior parameters. The encoder's z1 parameters are then merged
    with that prior through `LadderMerge`, z1 is sampled from the merged
    parameters, and z1 is decoded into 784 Bernoulli logits.

    Args:
        x: input tensor (the notebook passes the binarized 784-wide image batch).
        phase: forwarded to every `layer` call for batch normalization.

    Returns:
        Tuple `(px_logits, z1, z1m, z1v, z1m_prior, z1v_prior,
        z2, z2m, z2v, z2m_prior, z2v_prior)`, where `z1m`/`z1v` are the
        merged (post-`LadderMerge`) parameters.
    """
    def relu_block(inp, name):
        # 256-unit hidden layer with ReLU activation.
        return layer(inp, 256, phase, name, activation=tf.nn.relu)

    def gaussian_params(inp, mean_name, var_name):
        # 50-dim mean (linear) and variance (softplus) heads, mean head first.
        mean = layer(inp, 50, phase, mean_name)
        var = layer(inp, 50, phase, var_name, activation=tf.nn.softplus)
        return mean, var

    with tf.variable_scope('inference'):
        enc1 = relu_block(relu_block(x, 'layer1'), 'layer2')
        enc2 = relu_block(relu_block(enc1, 'layer3'), 'layer4')
    with tf.variable_scope('inference/z1_param'):
        z1m_enc, z1v_enc = gaussian_params(enc1, 'm1', 'v1')
    with tf.variable_scope('inference/z2_param'):
        z2m, z2v = gaussian_params(enc2, 'm2', 'v2')
    with tf.variable_scope('generate'):
        # Top-level prior parameters for z2: constant mean 0 and variance 1.
        z2m_prior = Constant(0)
        z2v_prior = Constant(1)
        z2 = GaussianSample(z2m, z2v, 'z2')
        dec2 = relu_block(relu_block(z2, 'layer1'), 'layer2')
        z1m_prior, z1v_prior = gaussian_params(dec2, 'm1', 'v1')
        # Precision-weighted merge of the encoder's z1 estimate with the prior.
        z1m, z1v = LadderMerge(z1m_enc, z1v_enc, z1m_prior, z1v_prior, 'pwm')
        z1 = GaussianSample(z1m, z1v, 'z1')
        dec1 = relu_block(relu_block(z1, 'layer3'), 'layer4')
        px_logits = layer(dec1, 784, phase, 'logits', bn=False)
    return (px_logits,
            z1, z1m, z1v, z1m_prior, z1v_prior,
            z2, z2m, z2v, z2m_prior, z2v_prior)

In [15]:
# Start from an empty default graph before building the model.
tf.reset_default_graph()

# Graph inputs: a boolean `phase` flag and a batch of 784-wide rows.
phase = Placeholder(None, tf.bool, name='phase')
x = Placeholder((None, 784), name='x')

# Stochastic binarization: an entry becomes 1.0 where x exceeds a fresh
# uniform(0, 1) draw and 0.0 otherwise.
with tf.name_scope('x_binarized'):
    noise = tf.random_uniform(tf.shape(x), 0, 1)
    xb = tf.cast(tf.greater(x, noise), tf.float32)

(px_logits,
 z1, z1m, z1v, z1m_prior, z1v_prior,
 z2, z2m, z2v, z2m_prior, z2v_prior) = vae_graph(xb, phase)

# Loss = reconstruction term plus, for each latent level, the difference
# log q(z) - log p(z) evaluated at the sampled z.
with tf.name_scope('loss'):
    recon = -log_bernoulli_with_logits(xb, px_logits)
    kl1 = -log_normal(z1, z1m_prior, z1v_prior) + log_normal(z1, z1m, z1v)
    kl2 = -log_normal(z2, z2m_prior, z2v_prior) + log_normal(z2, z2m, z2v)
    loss = recon + kl1 + kl2

In [16]:
# Hand the default graph's GraphDef to tensorbayes' notebook graph viewer.
graph_def = tf.get_default_graph().as_graph_def()
show_graph(graph_def)

In [17]:
# Ops registered under UPDATE_OPS (presumably the batch-norm moving-average
# updates -- confirm against tensorbayes' BatchNormalization).
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer()
# Building the minimize op under these control dependencies makes every
# training step run the update ops first.
with tf.control_dependencies(update_ops):
    train_step = optimizer.minimize(loss)

In [18]:
# Session used by the remaining cells for training and evaluation.
sess = tf.Session()
# `tf.global_variables_initializer` replaces the deprecated
# `tf.initialize_all_variables`.
sess.run(tf.global_variables_initializer())

Instructions for updating:
Use `tf.global_variables_initializer` instead.


In [19]:
# Training loop: one optimizer step per iteration. After every `iterep`
# iterations (one "epoch") the loss is evaluated on the full train and test
# sets and recorded.
history = []  # rows of [epoch, mean train loss, mean test loss]
iterep = 1    # iterations per epoch; at 1 the full evaluation runs after every minibatch
for i in range(iterep * 30):
    # Labels are fetched but not used by this model.
    x_train, y_train = mnist.train.next_batch(100)
    sess.run(train_step,
             feed_dict={'x:0': x_train,
                        'phase:0': True})
    progbar(i, iterep)
    if (i + 1) % iterep == 0:
        # Floor division keeps the epoch an integer on Python 2 and Python 3.
        epoch = (i + 1) // iterep
        # Evaluate with phase=False (presumably batch-norm inference mode --
        # confirm against tensorbayes' BatchNormalization).
        tr = sess.run(loss,
                      feed_dict={'x:0': mnist.train.images,
                                 'phase:0': False})
        t = sess.run(loss,
                     feed_dict={'x:0': mnist.test.images,
                                'phase:0': False})
        history.append([epoch, tr.mean(), t.mean()])
        # Function-call form of print works on both Python 2 and Python 3.
        print(history[-1])
# The original trailing `return history` was removed: `return` outside a
# function is a SyntaxError, and `history` is already available to later cells.

[1, 546.46252, 546.55145]
[2, 539.00146, 538.98212]
[3, 531.84485, 531.84521]
[4, 525.0238, 525.0224]


KeyboardInterrupt: 