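# This notebook must be run with a kernel that has TensorFlow 1.x and Edward installed (the code uses `tf.placeholder` and `tf.layers`); a kernel from the `tf` virtual environment works.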

# A variational autoencoder (VAE) using Edward, a probabilistic programming library built on TensorFlow.

## First we do setup -- imports, data handling and model parameters

### Imports

In [1]:
import warnings
warnings.simplefilter("ignore")

import edward as ed
import numpy as np
import tensorflow as tf
from observations import mnist
from matplotlib.image import imsave



### Helper for generating data

In [2]:
def generator(array, batch_size):
    """Yield an endless stream of stochastically binarized minibatches.

    Rows are taken in order along ``array``'s first axis, wrapping around to
    the start when the end is reached. Every batch therefore has exactly
    ``batch_size`` rows, even when ``batch_size`` exceeds the number of rows.

    Each batch is scaled from [0, 255] to [0, 1]. It is then binarized by
    drawing every pixel from a Bernoulli distribution whose success
    probability is the scaled intensity. The draw is fresh each time a row is
    revisited.

    Args:
        array: Array of pixel intensities in [0, 255], one example per row
            along axis 0.
        batch_size: Number of rows per yielded batch; must be positive.

    Yields:
        Integer array of 0/1 values with shape
        ``(batch_size,) + array.shape[1:]``.

    Raises:
        ValueError: On the first ``next()``, if ``array`` has no rows or
            ``batch_size`` is not positive (either would otherwise yield
            empty batches forever).
    """
    n_rows = array.shape[0]
    if n_rows == 0:
        raise ValueError("array must contain at least one row")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    start = 0  # index of the next row to emit; always kept in [0, n_rows)
    while True:
        # Modular indices wrap around the end of the array as many times as
        # needed. A single slice + concatenate cannot do this when
        # batch_size > n_rows.
        indices = np.arange(start, start + batch_size) % n_rows
        batch = np.take(array, indices, axis=0)
        start = (start + batch_size) % n_rows
        batch = batch.astype(np.float32) / 255.0  # normalize pixel intensities to [0, 1]
        # Stochastic binarization: each pixel is 1 with probability equal to
        # its normalized intensity.
        batch = np.random.binomial(1, batch)
        yield batch

### Set up model parameters and load data

In [3]:
# Model sizes: the dimensionality of the latent space and the number of
# images in every minibatch (the graph below is built for this fixed size).
z_dim = 2
batch_size = 1000

# Make runs reproducible by seeding Edward's random number generators.
ed.set_seed(123)

# Load MNIST from ./data; labels are discarded (bound to `_`).
(x_train, _), (x_test, _) = mnist('./data')
# Endless stream of binarized training minibatches, consumed during training.
x_train_generator = generator(x_train, batch_size)


## Next, define the models -- both generative and variational

### The generative model -- from the latent Z (Gaussian) to the observed X (Bernoulli)

In [4]:
# Define a the generative  model
# Sizes corresponding to a minibatch
# Note the simple definition of variables
z = ed.models.Normal(loc=tf.zeros([batch_size, z_dim]),
                     scale=tf.ones([batch_size, z_dim]))
hidden_gen = tf.layers.dense(z, 64, activation=tf.nn.relu)
hidden_gen = tf.layers.dense(hidden_gen, 256, activation=tf.nn.relu)
x = ed.models.Bernoulli(logits=tf.layers.dense(hidden_gen, 28 * 28, activation=None))

### The variational model -- Taking us from a placeholder for the data to the latent Z.

In [5]:
# Variational approximation q(z | x), sized for one minibatch.
# x_ph is the TensorFlow placeholder fed with binarized images at training
# time: one flattened 28x28 image per row.
x_ph = tf.placeholder(tf.int32, [batch_size, 28 * 28])
# Encoder network: two 256-unit ReLU layers applied to the images cast to float.
hidden_vb = tf.cast(x_ph, tf.float32)
for n_units in (256, 256):
    hidden_vb = tf.layers.dense(hidden_vb, n_units, activation=tf.nn.relu)
# Gaussian posterior approximation. A linear head gives the mean, and a
# softplus head keeps the scale positive. The mean head is created before the
# scale head, which fixes the auto-generated layer names.
qz_loc = tf.layers.dense(hidden_vb, z_dim, activation=None)
qz_scale = tf.layers.dense(hidden_vb, z_dim, activation=tf.nn.softplus)
qz = ed.models.Normal(loc=qz_loc, scale=qz_scale)


## Set up inference machinery

In [6]:
# KLqp performs variational inference by minimizing KL(q || p).
# latent_vars pairs the prior z with its approximation qz. data binds the
# observed x to the placeholder, so p(x, z) and q(z | x) see the same minibatch.
inference = ed.KLqp(latent_vars={z: qz}, data={x: x_ph})
# A standard TensorFlow optimizer: RMSProp with learning rate 0.01.
# epsilon=1.0 is added inside the square root of the update's denominator.
optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
# Build the training ops, then initialize every variable in the default session.
inference.initialize(optimizer=optimizer)
tf.global_variables_initializer().run()

## Do the training -- simply iterate over epochs, then over mini-batches. 

In [7]:
n_epoch = 10
# At least one update per epoch. This avoids dividing by zero below when the
# training set holds fewer images than one minibatch.
n_iter_per_epoch = max(1, x_train.shape[0] // batch_size)
for epoch in range(1, n_epoch + 1):
    # Flush so the epoch label is visible while the epoch is still running.
    print("Epoch: {:3d}: ".format(epoch), end='', flush=True)
    avg_loss = 0.0

    for _ in range(n_iter_per_epoch):
        x_batch = next(x_train_generator)
        info_dict = inference.update(feed_dict={x_ph: x_batch})
        # Per the setup above, the loss is the -ELBO of the whole minibatch.
        avg_loss += info_dict['loss']

    # Report a per-image lower bound on the average log marginal likelihood.
    # The loss is -ELBO, so log p(x) is lower-bounded by -loss after dividing
    # by the number of images that produced the loss.
    # This is a stochastic estimate, averaged over an epoch in which the
    # parameters keep changing.
    avg_loss /= (n_iter_per_epoch * batch_size)
    print("log p(x) >= {:0.3f}".format(-avg_loss))

Epoch:   1: log p(x) >= -296.420
Epoch:   2: log p(x) >= -196.557
Epoch:   3: log p(x) >= -184.731
Epoch:   4: log p(x) >= -175.976
Epoch:   5: log p(x) >= -171.846
Epoch:   6: log p(x) >= -168.593
Epoch:   7: log p(x) >= -166.887
Epoch:   8: log p(x) >= -164.657
Epoch:   9: log p(x) >= -162.874
Epoch:  10: log p(x) >= -161.167
