In [0]:
%tensorflow_version 1.x

In [27]:
!pip install git+https://github.com/thu-ml/zhusuan.git

Collecting git+https://github.com/thu-ml/zhusuan.git
  Cloning https://github.com/thu-ml/zhusuan.git to /tmp/pip-req-build-seo9dwbx
  Running command git clone -q https://github.com/thu-ml/zhusuan.git /tmp/pip-req-build-seo9dwbx
Building wheels for collected packages: zhusuan
  Building wheel for zhusuan (setup.py) ... [?25l[?25hdone
  Created wheel for zhusuan: filename=zhusuan-0.4.0-py2.py3-none-any.whl size=73591 sha256=b1d4cf70ba09898308ccbc46fb9431c41b641979cc31724de93cd3932bf65b86
  Stored in directory: /tmp/pip-ephem-wheel-cache-3_v7hqh4/wheels/45/cb/f5/bfd913ae94924c3151ac7d20dab61be39e90a2b07bdc6cb75e
Successfully built zhusuan


In [0]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import gzip
import tarfile
import zipfile
import math
import time
import numpy as np
import six
from six.moves import urllib, range
from six.moves import cPickle as pickle
import tensorflow as tf
import matplotlib.pyplot as plt
import zhusuan as zs
import conf
from utils import save_image_collections

In [29]:
def download_dataset(url, path):
    print('Downloading data from %s' % url)
    urllib.request.urlretrieve(url, path)


def to_one_hot(x, depth):
    """
    Get one-hot representation of a 1-D numpy array of integers.
    :param x: 1-D Numpy array of type int.
    :param depth: A int.
    :return: 2-D Numpy array of type int.
    """
    ret = np.zeros((x.shape[0], depth))
    ret[np.arange(x.shape[0]), x] = 1
    return ret


def load_mnist_realval(path, one_hot=True, dequantify=False):
    """
    Loads the real valued MNIST dataset.
    :param path: Path to the dataset file.
    :param one_hot: Whether to use one-hot representation for the labels.
    :param dequantify:  Whether to add uniform noise to dequantify the data
        following (Uria, 2013).
    :return: The MNIST dataset.
    """
    if not os.path.isfile(path):
        data_dir = os.path.dirname(path)
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(data_dir)
        download_dataset('http://ml.cs.tsinghua.edu.cn/~ziyu/static'
                         '/mnist.pkl.gz', path)

    f = gzip.open(path, 'rb')
    if six.PY2:
        train_set, valid_set, test_set = pickle.load(f)
    else:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
    f.close()
    x_train, t_train = train_set[0], train_set[1]
    x_valid, t_valid = valid_set[0], valid_set[1]
    x_test, t_test = test_set[0], test_set[1]
    
    if dequantify:
        x_train += np.random.uniform(0, 1. / 256,
                                     size=x_train.shape).astype('float32')
        x_valid += np.random.uniform(0, 1. / 256,
                                     size=x_valid.shape).astype('float32')
        x_test += np.random.uniform(0, 1. / 256,
                                    size=x_test.shape).astype('float32')
    n_y = t_train.max() + 1
    t_transform = (lambda x: to_one_hot(x, n_y)) if one_hot else (lambda x: x)
    print("Downloaded")
    return x_train, t_transform(t_train), x_valid, t_transform(t_valid), \
        x_test, t_transform(t_test)


data_path = './mnist.pkl.gz'
x_train, t_train, x_valid, t_valid, x_test, t_test = \
    load_mnist_realval(data_path)

x_valid = np.random.binomial(1, x_valid, size=x_valid.shape)
x_test = np.random.binomial(1, x_test, size=x_test.shape)
print("Dataset Ready")

Downloaded
Dataset Ready


In [0]:
@zs.meta_bayesian_net(scope="gen", reuse_variables=True)
def build_gen(x_dim, z_dim, n, n_particles=1):
    bn = zs.BayesianNet()
    z_mean = tf.zeros([n, z_dim])
    z = bn.normal("z", z_mean, std=1., group_ndims=1, n_samples=n_particles)
    h = tf.layers.dense(z, 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    x_logits = tf.layers.dense(h, x_dim)
    bn.deterministic("x_mean", tf.sigmoid(x_logits))
    bn.bernoulli("x", x_logits, group_ndims=1)
    return bn


@zs.reuse_variables(scope="q_net")
def build_q_net(x,y, z_dim, n_z_per_x):
    bn = zs.BayesianNet()
    inputs = tf.concat([x,y],axis=-1)
    h = tf.layers.dense(tf.cast(inputs, tf.float32), 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    z_mean = tf.layers.dense(h, z_dim)
    z_logstd = tf.layers.dense(h, z_dim)
    bn.normal("z", z_mean, logstd=z_logstd, group_ndims=1, n_samples=n_z_per_x)
    return bn

In [0]:
x_dim = x_train.shape[1]
y_dim = t_train.shape[1]
z_dim = 40
  
n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
y_input = tf.placeholder(tf.float32, shape=[None, y_dim], name="y")
y = tf.cast(tf.less(tf.random_uniform(tf.shape(y_input)), y_input),
                tf.int32)
n = tf.placeholder(tf.int32, shape=[], name="n")

model = build_gen(x_dim, z_dim, n, n_particles)
variational = build_q_net(x,y, z_dim, n_particles)
lower_bound = zs.variational.elbo(
        model, {"x": x}, variational=variational, axis=0)
cost = tf.reduce_mean(lower_bound.sgvb())
lower_bound = tf.reduce_mean(lower_bound)

is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
infer_op = optimizer.minimize(cost)

x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

In [0]:
epochs = 300
batch_size = 128
iters = x_train.shape[0] // batch_size
save_freq = 10
test_freq = 100
test_batch_size = 400
test_iters = x_test.shape[0] // test_batch_size
result_path = "results/vae"

In [33]:
with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                y_batch = t_train[t * batch_size:(t+1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={x_input: x_batch,
                                            y_input:y_batch,
                                            n_particles: 1,
                                            n: batch_size})
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))
            
            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_y_batch = t_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]                      
                    test_lb = sess.run(lower_bound,
                                       feed_dict={x: test_x_batch,
                                                  y: test_y_batch,
                                                  n_particles: 1,
                                                  n: test_batch_size})
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={x: test_x_batch,
                                                  y: test_y_batch,
                                                  n_particles: 1000,
                                                  n: test_batch_size})
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))
                if epoch % save_freq == 0:
                  images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
                  name = os.path.join(result_path,
                                    "cvae.epoch.{}.png".format(epoch))
                  save_image_collections(images, name)
                

Epoch 1 (2.2s): Lower bound = -180.97695922851562
Epoch 2 (1.9s): Lower bound = -130.2580108642578
Epoch 3 (1.9s): Lower bound = -118.62026977539062
Epoch 4 (1.9s): Lower bound = -113.35120391845703
Epoch 5 (1.9s): Lower bound = -110.1614990234375
Epoch 6 (2.0s): Lower bound = -108.26241302490234
Epoch 7 (1.9s): Lower bound = -106.75016021728516
Epoch 8 (1.9s): Lower bound = -105.6695327758789
Epoch 9 (1.9s): Lower bound = -104.6786880493164
Epoch 10 (1.9s): Lower bound = -103.86714935302734
>>> TEST (10.1s)
>> Test lower bound = -103.10693359375
>> Test log likelihood (IS) = -97.78926086425781
Epoch 11 (1.8s): Lower bound = -103.25358581542969
Epoch 12 (1.8s): Lower bound = -102.75040435791016
Epoch 13 (1.9s): Lower bound = -102.16028594970703
Epoch 14 (1.9s): Lower bound = -101.72972869873047
Epoch 15 (1.9s): Lower bound = -101.29808044433594
Epoch 16 (2.0s): Lower bound = -100.91851043701172
Epoch 17 (1.9s): Lower bound = -100.6188735961914
Epoch 18 (2.0s): Lower bound = -100.276924

In [34]:
!zip -r /content/file.zip /content/results/vae
from google.colab import files

updating: content/results/vae/ (stored 0%)
updating: content/results/vae/cvae.epoch.220.png (stored 0%)
updating: content/results/vae/cvae.epoch.10.png (stored 0%)
updating: content/results/vae/cvae.epoch.120.png (stored 0%)
updating: content/results/vae/cvae.epoch.240.png (stored 0%)
updating: content/results/vae/cvae.epoch.20.png (stored 0%)
updating: content/results/vae/cvae.epoch.210.png (stored 0%)
updating: content/results/vae/cvae.epoch.40.png (stored 0%)
updating: content/results/vae/cvae.epoch.110.png (stored 0%)
updating: content/results/vae/cvae.epoch.160.png (stored 0%)
updating: content/results/vae/cvae.epoch.70.png (stored 0%)
updating: content/results/vae/cvae.epoch.100.png (stored 0%)
updating: content/results/vae/cvae.epoch.280.png (stored 0%)
updating: content/results/vae/cvae.epoch.270.png (stored 0%)
updating: content/results/vae/cvae.epoch.30.png (stored 0%)
updating: content/results/vae/cvae.epoch.50.png (stored 0%)
updating: content/results/vae/cvae.epoch.290.png

In [35]:
def construct_numvec(digit, z = None):
    out = np.zeros((1, z_dim + y_dim))
    out[:, digit + z_dim] = 1.
    if z is None:
        return(out)
    else:
        for i in range(len(z)):
            out[:,i] = z[i]
        return(out)
    
sample_3 = construct_numvec(3)
print(sample_3)

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 0.]]


In [0]:
dig = 9
sides = 8
max_z = 1.5

img_it = 0
for i in range(0, sides):
    z1 = (((i / (sides-1)) * max_z)*2) - max_z
    for j in range(0, sides):
        z2 = (((j / (sides-1)) * max_z)*2) - max_z
        z_ = [z1, z2]
        vec = construct_numvec(dig, z_)
        decoded = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])
        #decoded = model.observe()
        plt.subplot(sides, sides, 1 + img_it)
        img_it +=1
        plt.imshow(x_gen, cmap = plt.cm.gray), axis('off')
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=.2)
plt.show()