In [None]:
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as np
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
import os

In [None]:
BATCH_SIZE = 512  # examples per training step; also the size of each plotted label batch
LR = 2.0e-5       # learning rate passed to the Adam optimizer

In [None]:
def get_data_samples(N):
    """Return an int32 tensor of N class labels drawn uniformly from {0, 1, 2, 3}.

    The tensor is a TF random op, so it yields a fresh draw each time it is
    evaluated (unless its value is overridden through a feed_dict).
    """
    return tf.random_uniform(shape=[N], minval=0, maxval=4, dtype=tf.int32)

def encoder_func(x):
    """Encoder network: x -> (zmean, zlogstd).

    Three 64-unit ELU fully-connected layers, followed by two independent
    linear heads of width 2. The graph cell uses the outputs as the mean and
    log standard deviation of the 2-D latent code
    (z = zmean + eps * exp(zlogstd)).
    """
    hidden = x
    for _layer in range(3):
        hidden = slim.fully_connected(hidden, 64, activation_fn=tf.nn.elu)

    # Both heads read the same final hidden activation.
    zmean = slim.fully_connected(hidden, 2, activation_fn=None)
    zlogstd = slim.fully_connected(hidden, 2, activation_fn=None)
    return zmean, zlogstd


def decoder_func(z):
    """Decoder network: latent code z -> 4 unnormalised output logits.

    Three 64-unit ELU fully-connected layers followed by a linear layer of
    width 4 (one logit per data class).
    """
    hidden = z
    for _layer in range(3):
        hidden = slim.fully_connected(hidden, 64, activation_fn=tf.nn.elu)
    return slim.fully_connected(hidden, 4, activation_fn=None)

def create_scatter(x_test_labels, eps_test, savepath=None):
    """Scatter-plot inferred 2-D latent codes, one colour per label batch.

    For each batch in ``x_test_labels``:
    - the batch is fed in place of ``x_real_labels``, and ``eps_test`` is fed
      in place of ``eps``;
    - the resulting ``z_inferred`` points are plotted.

    The function reads the module-level globals ``sess``, ``z_inferred``,
    ``x_real_labels`` and ``eps``. It can therefore only be called after the
    graph-building and session cells have run.

    Args:
        x_test_labels: sequence of label batches. In this notebook it holds one
            batch of BATCH_SIZE copies of each class 0..3.
        eps_test: noise values fed in place of ``eps``. Keeping them fixed
            across calls makes successive plots comparable.
        savepath: optional output path. When given, the figure is saved there
            and then closed, so repeated calls (e.g. from the training loop)
            do not accumulate open figures. When omitted, the figure is left
            open so it can be displayed inline.
    """
    fig = plt.figure(figsize=(5, 5), facecolor='w')

    # Iterate over the batches actually supplied instead of a hard-coded count.
    for labels in x_test_labels:
        z_out = sess.run(z_inferred, feed_dict={x_real_labels: labels, eps: eps_test})
        plt.scatter(z_out[:, 0], z_out[:, 1], edgecolor='none', alpha=0.5)

    plt.xlim(-3, 3); plt.ylim(-3.5, 3.5)
    plt.axis('off')

    if savepath:
        plt.savefig(savepath)
        # Without this, every call leaves a figure open; the training loop
        # calls this hundreds of times, leaking memory and triggering
        # matplotlib's too-many-open-figures warning.
        plt.close(fig)

# Wrap each network builder in a TF template so that repeated calls share one
# set of variables, under the scopes 'encoder' and 'decoder' respectively.
encoder, decoder = (
    tf.make_template(scope_name, builder)
    for scope_name, builder in (('encoder', encoder_func), ('decoder', decoder_func))
)

In [None]:
# Build the VAE training graph. `eps` and `x_real_labels` are TF random ops, so
# every sess.run that does not feed them draws a fresh batch and fresh noise
# (create_scatter overrides both through feed_dict).
eps = tf.random_normal([BATCH_SIZE, 2])  # standard-normal noise, one 2-D sample per example
x_real_labels = get_data_samples(BATCH_SIZE)  # int32 labels in {0, 1, 2, 3}
x_real = tf.one_hot(x_real_labels, 4)
zmean, zlogstd = encoder(x_real)
# Reparameterisation trick: z = mean + eps * std with std = exp(log-std), so
# gradients flow into zmean/zlogstd through the sampled z.
z_inferred = zmean + eps*tf.exp(zlogstd)
x_reconstr_logits = decoder(z_inferred)

# Reconstruction error, summed over the 4 output entries per example.
# NOTE(review): sigmoid cross-entropy scores the one-hot entries as 4
# independent Bernoullis; a categorical likelihood (softmax cross-entropy)
# would match one-hot data exactly - confirm the Bernoulli choice is intended.
reconstr_err = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=x_real, logits=x_reconstr_logits),
    axis=1
)

# Single-sample estimate of KL(q(z|x) || N(0, I)). Per latent dimension,
# log q(z|x) - log p(z) = 0.5*z^2 - zlogstd - 0.5*eps^2. The -0.5*eps^2 term is
# replaced by its expectation (-0.5), so the estimate is unbiased but remains
# stochastic through z.
KL = tf.reduce_sum(0.5*tf.square(z_inferred) - zlogstd - 0.5, 1)

# Batch mean of (reconstruction error + KL), i.e. an estimate of the NEGATIVE ELBO.
loss = tf.reduce_mean(reconstr_err + KL)
optimizer = tf.train.AdamOptimizer(LR)
train_op = optimizer.minimize(loss)


In [None]:
# An InteractiveSession installs itself as the default session, so the
# initializer op can be run directly against it.
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

In [None]:
# Training loop. The test inputs (one label batch per class plus one noise
# draw) are created once, so the saved latent scatter plots are comparable
# from frame to frame.
N_STEPS = 100000   # number of optimisation steps
PLOT_EVERY = 100   # save a latent-space scatter plot every PLOT_EVERY steps

x_test_labels = [[i] * BATCH_SIZE for i in range(4)]
eps_test = np.random.randn(BATCH_SIZE, 2)

outdir = './out_vae'
if not os.path.exists(outdir):
    os.makedirs(outdir)

progress = tqdm_notebook(range(N_STEPS))
for i in progress:
    # `loss` is the batch mean of (reconstruction error + KL), i.e. the
    # NEGATIVE ELBO. Label it as such instead of calling it "ELBO".
    neg_elbo_out, _ = sess.run([loss, train_op])

    progress.set_description('-ELBO = %.2f' % neg_elbo_out)
    if i % PLOT_EVERY == 0:
        create_scatter(x_test_labels, eps_test, savepath=os.path.join(outdir, '%08d.png' % i))