# Variational Autoencoders in ZhuSuan

### This code implements a VAE with a Gaussian prior and a Bernoulli likelihood using ZhuSuan.
#### The framework has been set up; please fill in only the places marked with "TODO" comments.
#### Detailed instructions are given in the comments. For detailed usage of ZhuSuan, please see the documentation at http://zhusuan.readthedocs.io/en/latest/concepts.html

#### If you have any questions, please contact me: lucheng.lc15@gmail.com

In [None]:
import os
import time
import tensorflow as tf
import numpy as np
import zhusuan as zs

import matplotlib.pyplot as plt
from skimage import io, img_as_ubyte
from skimage.exposure import rescale_intensity

## I. Load MNIST Dataset
For convenience, we use TensorFlow's Keras utilities to load the dataset.

In [None]:
# load_data() returns uint8 pixel values in [0, 255]. Section V binarizes the
# images by sampling `uniform < pixel`, which treats each pixel as a Bernoulli
# probability only when it lies in [0, 1]; with raw [0, 255] values it would
# degenerate into a hard `pixel > 0` threshold. Scale to [0, 1] float32 here.
(data_train, label_train), _ = tf.keras.datasets.mnist.load_data()
data_train = data_train.astype(np.float32) / 255.

In [None]:
# Show the shapes of the loaded training images and labels.
print(*(arr.shape for arr in (data_train, label_train)))

## II. Some Utils for Training and Visualization

Here are some helper functions used in the following steps. You don't need to know the details.

In [None]:
def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    """Yield aligned mini-batches taken from several arrays at once.

    Exactly one of ``num_batches`` / ``batch_size`` must be given. Every array
    is indexed with the same (optionally shuffled) row indices, so the i-th
    rows of all arrays stay together.

    :param arrays: sequence of arrays sharing the same length along axis 0.
    :param num_batches: split the rows into this many near-equal batches.
    :param batch_size: split the rows into batches of this many rows.
    :param shuffle: if true, permute rows with ``np.random.shuffle`` first.
    :param include_final_partial_batch: if false, skip any batch whose length
        differs from ``batch_size``.
    :return: generator of tuples, one sliced array per input array.
    """
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    total = arrays[0].shape[0]
    assert all(arr.shape[0] == total for arr in arrays[1:])

    order = np.arange(total)
    if shuffle:
        np.random.shuffle(order)

    # array_split accepts either explicit split points or a section count.
    if num_batches is None:
        split_points = np.arange(0, total, batch_size)[1:]
    else:
        split_points = num_batches

    for idx in np.array_split(order, split_points):
        if not include_final_partial_batch and len(idx) != batch_size:
            continue
        yield tuple(arr[idx] for arr in arrays)

In [None]:
def save_image_collections(x, filename, shape=(10, 10), scale_each=False,
                           transpose=False):
    """Tile a batch of images into one grid image and save it to ``filename``.

    :param x: numpy array
        Input image collections. (number_of_images, rows, columns, channels) or
        (number_of_images, channels, rows, columns)
    :param filename: str
        Output path, passed to ``skimage.io.imsave``.
    :param shape: tuple
        The shape of the final big image, as (grid_rows, grid_columns).
        Images beyond grid_rows * grid_columns are dropped (a message is printed).
    :param scale_each: bool
        If true, rescale the intensity of each image to [0, 1]. This works on
        a copy, so the caller's array is never modified.
    :param transpose: bool
        If true, transpose x to (number_of_images, rows, columns, channels),
        i.e., put channels behind.
    :return: `uint8` numpy array
        The output image (singleton dimensions squeezed), as written to disk.
    """
    x = np.asarray(x)
    n = x.shape[0]
    if transpose:
        x = x.transpose(0, 2, 3, 1)
    if scale_each:
        # Rescale on a floating-point copy: writing the [0, 1] results back
        # into the input would mutate the caller's array, and for integer
        # inputs would truncate every value to 0 or 1.
        x = x.astype(np.result_type(x.dtype, np.float32))
        for i in range(n):
            x[i] = rescale_intensity(x[i], out_range=(0, 1))
    n_channels = x.shape[3]
    x = img_as_ubyte(x)
    r, c = shape
    if r * c < n:
        print('Shape too small to contain all images')
    h, w = x.shape[1:3]
    ret = np.zeros((h * r, w * c, n_channels), dtype='uint8')
    # Fill the grid row-major: image k goes to grid cell (k // c, k % c).
    for k in range(min(n, r * c)):
        i, j = divmod(k, c)
        ret[i * h:(i + 1) * h, j * w:(j + 1) * w, :] = x[k]
    # Drop the channel axis for grayscale images so imsave writes a 2-D image.
    ret = ret.squeeze()
    io.imsave(filename, ret)
    return ret

## III. Build the Generative Model by ZhuSuan

Define the generative model according to the generative process.

* TODO: complete the network.
    

In [None]:
@zs.meta_bayesian_net(scope="gen", reuse_variables=True)
def build_gen(x_dim, z_dim, n, n_particles=1):
    """Generative model p(z) p(x | z) of the VAE (the decoder).

    z has a standard normal prior with ``z_dim`` dimensions for each of the
    ``n`` data points (``n_particles`` samples each). A two-layer ReLU MLP
    (500 units per layer) decodes z into Bernoulli logits over the ``x_dim``
    pixels.

    :param x_dim: int, dimension of an observation x.
    :param z_dim: int, dimension of the latent variable z.
    :param n: int or scalar int Tensor, number of data points.
    :param n_particles: int, number of samples of z drawn per data point.
    :return: zs.BayesianNet with stochastic nodes "z" and "x" and the
        deterministic node "x_mean". The decorator wraps this builder so that
        callers receive a meta Bayesian net (used via ``.observe()``).
    """
    bn = zs.BayesianNet()
    z_mean = tf.zeros([n, z_dim])
    z = bn.normal("z", z_mean, std=1., group_ndims=1, n_samples=n_particles)
    h = tf.layers.dense(z, 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    x_logits = tf.layers.dense(h, x_dim)
    # Pixel probabilities; used to visualize generated images.
    bn.deterministic("x_mean", tf.sigmoid(x_logits))
    # group_ndims=1: the x_dim pixels form one event (log-probs summed over them).
    bn.bernoulli("x", x_logits, group_ndims=1)
    return bn

Define the variational posterior model.
- TODO: complete the network

In [None]:
@zs.reuse_variables(scope="q_net")
def build_q_net(x, z_dim, n_z_per_x):
    """Variational posterior q(z | x) of the VAE (the encoder).

    A two-layer ReLU MLP (500 units per layer) maps x to the mean and the log
    standard deviation of a diagonal Gaussian over z.

    :param x: Tensor of observations; cast to float32 before the first layer.
    :param z_dim: int, dimension of the latent variable z.
    :param n_z_per_x: int, number of samples of z drawn for each x.
    :return: zs.BayesianNet containing the stochastic node "z".
    """
    bn = zs.BayesianNet()
    # The caller may pass integer (binarized) observations; dense layers need floats.
    x = tf.cast(x, tf.float32)
    h = tf.layers.dense(x, 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    z_mean = tf.layers.dense(h, z_dim)
    # Predict log(std) rather than std, so the implied std is always positive.
    z_logstd = tf.layers.dense(h, z_dim)
    # group_ndims=1: the z_dim coordinates form one event (log-densities summed).
    bn.normal("z", z_mean, logstd=z_logstd, group_ndims=1,
              n_samples=n_z_per_x)
    return bn

## IV. Set the Hyperparameters

The following hyperparameters work well, but feel free to modify them.

In [None]:
# --- Data / model dimensions ---
x_dim = 28 ** 2      # flattened image size (images are reshaped to 28*28 vectors)
z_dim = 40           # dimensionality of the latent variable z
n_particles = 1      # samples of z drawn per data point

# --- Optimization and logging ---
learning_rate = 1e-3
epochs = 3000
batch_size = 128
save_freq = 10       # save generated samples every `save_freq` epochs

## V. Build the Model and the Loss

In [None]:
# Placeholder for a mini-batch of flattened images. Pixel values are used
# below as Bernoulli probabilities, so they should lie in [0, 1].
x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
# Change x to binary form.
# Stochastic binarization: a pixel becomes 1 when a uniform sample falls
# below its value, producing int32 0/1 observations.
x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input), tf.int32)
# Number of data points; build_gen uses it to size its prior samples of z
# (batch_size during training, 100 when generating images).
n = tf.placeholder(tf.int32, shape=[], name="n")

model = build_gen(x_dim, z_dim, n, n_particles)
variational = build_q_net(x, z_dim, n_particles)

# Define the ELBO and optimize it by SGVB.
# axis=0 is presumably the leading sample axis that ZhuSuan adds for
# n_samples=n_particles draws of z — confirm against the ZhuSuan docs.
lower_bound = zs.variational.elbo(model, {"x": x}, variational=variational, axis=0)
# sgvb() gives a surrogate cost to minimize; its batch mean is the training loss.
cost = tf.reduce_mean(lower_bound.sgvb())
# Rebinds the name: from here on `lower_bound` is the scalar batch-mean ELBO
# that the training loop reports.
lower_bound = tf.reduce_mean(lower_bound)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
infer_op = optimizer.minimize(cost)

# Generated images: x_mean of the generative model with nothing observed,
# reshaped to (num_images, 28, 28, 1) for save_image_collections.
x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

## VI. Train the Model and Save the Generated Samples

In [None]:
# All generated samples are saved to this path.
result_path = 'images'
# exist_ok=True makes re-running this cell safe without a separate existence
# check, which could race with another process creating the directory.
os.makedirs(result_path, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(1, epochs + 1):
        # Store -t0 now; adding t1 at the end of the epoch yields t1 - t0.
        time_epoch = -time.time()
        lbs = []
        # Only full batches are used, because `n` is always fed as batch_size.
        batches = iterbatches([data_train, label_train], batch_size=batch_size,
                              include_final_partial_batch=False)
        for x_batch, _ in batches:
            # Flatten each image to a 28*28 vector to match x_input's shape.
            x_batch = x_batch.reshape([-1, 28*28])
            _, lb = sess.run([infer_op, lower_bound],
                             feed_dict={x_input: x_batch,
                                        n: batch_size})
            lbs.append(lb)
        time_epoch += time.time()
        print("Epoch {} ({:.1f}s): Lower bound = {}".format(
            epoch, time_epoch, np.mean(lbs)))

        # Periodically generate 100 images (tiled 10x10 by default) and save them.
        if epoch % save_freq == 0 or epoch == 1:
            images = sess.run(x_gen, feed_dict={n: 100})
            name = os.path.join(result_path,
                                "vae_epoch_{}.png".format(epoch))
            save_image_collections(images, name)