In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def weight_variable(shape):
    """Create a trainable weight variable of the given shape.

    Initial values are sampled from a normal distribution with mean 0.0
    and standard deviation 0.01.

    Args:
        shape: shape of the variable, passed straight to `tf.random_normal`.

    Returns:
        A `tf.Variable` holding the sampled initial values.
    """
    initial_weights = tf.random_normal(shape, mean=0.0, stddev=0.01)
    return tf.Variable(initial_weights)

def bias_variable(shape):
    """Create a trainable bias variable of the given shape.

    Like the weights, biases start from small random values drawn from a
    normal distribution (mean 0.0, standard deviation 0.01).

    Args:
        shape: shape of the variable, passed straight to `tf.random_normal`.

    Returns:
        A `tf.Variable` holding the sampled initial values.
    """
    initial_biases = tf.random_normal(shape, stddev=0.01, mean=0.0)
    return tf.Variable(initial_biases)

In [3]:
def layer(x, shape, activation):
    """Build one fully connected layer: `activation(x @ W + b)`.

    Args:
        x: input tensor that is matrix-multiplied by the layer weights.
        shape: two-element sequence; it is the shape of the weight matrix,
            and `shape[1]` is the length of the bias vector.
        activation: callable applied to the affine output.

    Returns:
        The tensor produced by `activation`.
    """
    weights = weight_variable(shape)
    n_outputs = shape[1]
    biases = bias_variable([n_outputs])
    pre_activation = tf.matmul(x, weights) + biases
    return activation(pre_activation)

In [4]:
def montage_batch(images):
    """Tile a batch of images into a single square grid image.

    The images are laid out row by row on a mid-gray (0.5) canvas with a
    one-pixel border around every tile. The grid has
    `ceil(sqrt(N))` tiles per side, so trailing cells stay gray when `N`
    is not a perfect square.

    Args:
        images: array of shape `(N, H, W, C)` where `C` broadcasts against
            3 channels (i.e. 1 or 3), or a grayscale batch of shape
            `(N, H, W)`, which is treated as `(N, H, W, 1)`.

    Returns:
        A float array of shape
        `(H * n + n + 1, W * n + n + 1, 3)` with `n = ceil(sqrt(N))`.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        # Grayscale batch without a channel axis: add one so each tile
        # broadcasts across the three canvas channels.
        images = images[..., np.newaxis]

    n_images, img_h, img_w = images.shape[:3]
    n_plots = int(np.ceil(np.sqrt(n_images)))
    m = np.full(
        (img_h * n_plots + n_plots + 1,
         img_w * n_plots + n_plots + 1, 3), 0.5)

    for idx in range(n_images):
        # Row-major placement: idx = row * n_plots + col.
        row, col = divmod(idx, n_plots)
        # The "+ 1 + row" / "+ 1 + col" terms account for the 1-pixel borders.
        top = 1 + row + row * img_h
        left = 1 + col + col * img_w
        m[top:top + img_h, left:left + img_w, :] = images[idx]
    return m

In [5]:
def create_images(i, num_examples, sess, mnist):
    """Save three diagnostic figures for the current training step.

    Files written to the working directory (``<i>`` is zero-padded to 8
    digits):

    * ``manifold_<i>.png`` - decoder outputs for a
      ``num_examples x num_examples`` grid of latent codes spanning
      [-3, 3] on both axes.
    * ``reconstruction_<i>.png`` - ``num_examples`` test inputs (top row)
      and their reconstructions (bottom row).
    * ``image_manifold_<i>.png`` - scatter plot of the first two latent
      coordinates of every test image, coloured by ``argmax`` of its label
      row.

    The function reads the module-level graph tensors ``x``, ``y`` and
    ``z``. Every figure is closed before returning, so repeated calls
    during training do not accumulate open matplotlib figures.

    Args:
        i: integer used in the output file names.
        num_examples: number of test images to reconstruct and the side
            length of the latent grid.
        sess: session used to evaluate ``y`` and ``z``.
        mnist: dataset object providing ``test.next_batch``,
            ``test.images`` and ``test.labels``.
    """
    test_xs, _ = mnist.test.next_batch(num_examples)
    xs, ys = mnist.test.images, mnist.test.labels

    figures = []
    try:
        # Plot example reconstructions from latent layer
        fig_manifold, ax_manifold = plt.subplots(1, 1)
        figures.append(fig_manifold)
        grid = np.linspace(-3, 3, num_examples)
        # Decode the whole grid in one session call; rows are ordered with
        # the first latent coordinate varying slowest.
        latent_codes = np.array(
            [[img_i, img_j] for img_i in grid for img_j in grid],
            dtype=np.float32).reshape(-1, 2)
        recon = sess.run(y, feed_dict={z: latent_codes})
        imgs_cat = np.reshape(recon, (-1, 28, 28, 1))
        ax_manifold.imshow(montage_batch(imgs_cat))
        fig_manifold.savefig('manifold_%08d.png' % i)

        # Plot example reconstructions
        # squeeze=False keeps a 2-D axes array even when num_examples == 1.
        fig_reconstruction, axs_reconstruction = plt.subplots(
            2, num_examples, figsize=(10, 2), squeeze=False)
        figures.append(fig_reconstruction)
        recon = sess.run(y, feed_dict={x: test_xs})
        for example_i in range(num_examples):
            ax_input = axs_reconstruction[0][example_i]
            ax_output = axs_reconstruction[1][example_i]
            ax_input.imshow(
                np.reshape(test_xs[example_i, :], (28, 28)), cmap='gray')
            ax_output.imshow(
                np.reshape(recon[example_i, ...], (28, 28)), cmap='gray')
            ax_input.axis('off')
            ax_output.axis('off')
        fig_reconstruction.savefig('reconstruction_%08d.png' % i)

        # Plot manifold of latent layer
        fig_image_manifold, ax_image_manifold = plt.subplots(1, 1)
        figures.append(fig_image_manifold)
        zs = sess.run(z, feed_dict={x: xs})
        ax_image_manifold.scatter(zs[:, 0], zs[:, 1],
                                  c=np.argmax(ys, 1), alpha=0.2)
        ax_image_manifold.set_xlim([-6, 6])
        ax_image_manifold.set_ylim([-6, 6])
        ax_image_manifold.axis('off')
        fig_image_manifold.savefig('image_manifold_%08d.png' % i)
    finally:
        # Release figure memory even if plotting or saving failed.
        for fig in figures:
            plt.close(fig)

In [6]:
image_size = 28 * 28

input_shape = [None, image_size]
encoder_internal_dim = 2048
decoder_internal_dim = 2048
latent_dim = 2

# Input placeholder: one flattened image of `image_size` values per row.
x = tf.placeholder(tf.float32, input_shape)

softplus = tf.nn.softplus

# Encoder: three fully connected softplus layers.
h_enc1 = layer(x, [image_size, encoder_internal_dim], activation=softplus)
h_enc2 = layer(h_enc1, [encoder_internal_dim, encoder_internal_dim], activation=softplus)
h_enc3 = layer(h_enc2, [encoder_internal_dim, encoder_internal_dim], activation=softplus)

# Linear heads producing the parameters of q(z|x): mean and log(sigma).
W_mu = weight_variable([encoder_internal_dim, latent_dim])
b_mu = bias_variable([latent_dim])

W_log_sigma = weight_variable([encoder_internal_dim, latent_dim])
b_log_sigma = bias_variable([latent_dim])

z_mu = tf.matmul(h_enc3, W_mu) + b_mu
z_log_sigma = tf.matmul(h_enc3, W_log_sigma) + b_log_sigma

# Reparameterization trick.

# Gaussian noise eps ~ N(0, 1), one row per input example.
epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], latent_dim]))

# z = mu + sigma * eps, with sigma = exp(log sigma).
# `z_log_sigma` is the log standard deviation, which is the same
# parameterization the KL term below relies on (log sigma^2 = 2 * log sigma).
z = z_mu + tf.exp(z_log_sigma) * epsilon

# Decoder: three fully connected softplus layers and a sigmoid output.
h_dec1 = layer(z, [latent_dim, decoder_internal_dim], activation=softplus)
h_dec2 = layer(h_dec1, [decoder_internal_dim, decoder_internal_dim], activation=softplus)
h_dec3 = layer(h_dec2, [decoder_internal_dim, decoder_internal_dim], activation=softplus)

y = layer(h_dec3, [decoder_internal_dim, image_size], activation=tf.nn.sigmoid)

# Reconstruction loss with a Bernoulli p(x|z). Despite its name this holds
# the NEGATIVE log-likelihood -log p(x|z) (note the leading minus), summed
# over pixels. The 1e-10 keeps tf.log away from log(0).
log_px_given_z = -tf.reduce_sum(x * tf.log(y + 1e-10) + (1 - x) * tf.log(1 - y + 1e-10), 1)

# KLD(q(z|x) || p(z)) = -(1/2) * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_div = -0.5 * tf.reduce_sum(1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma), 1)

# Negative evidence lower bound, averaged over the batch.
cost = tf.reduce_mean(log_px_given_z + kl_div)



In [7]:
def train(mnist, learning_rate=0.001, batch_size=100, num_epochs=50,
          num_examples=20):
    """Train the model defined by the module-level `cost` with Adam.

    Every 10th training batch the current batch cost is printed and
    `create_images` is called to write diagnostic figures. After each
    epoch the mean cost over the validation batches is printed.

    The session is opened in a `with` block so it is closed when training
    finishes or is interrupted.

    Args:
        mnist: dataset object providing `train` and `validation` splits
            with `num_examples` and `next_batch(batch_size)`.
        learning_rate: learning rate passed to `tf.train.AdamOptimizer`.
        batch_size: number of examples per training/validation batch.
        num_epochs: number of passes over the training batches.
        num_examples: value forwarded to `create_images` as its
            `num_examples` argument.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

    num_train_batches = mnist.train.num_examples // batch_size
    num_validation_batches = mnist.validation.num_examples // batch_size

    print("num of train batches: ", num_train_batches)
    print("num of validation batches: ", num_validation_batches)

    with tf.Session() as sess:
        # Created after minimize() so the optimizer's own variables are
        # initialized too.
        sess.run(tf.global_variables_initializer())

        # Running counter used to number the saved figure files.
        i = 0
        for epoch in range(num_epochs):
            print("epoch No.", epoch)

            for batch_idx in range(num_train_batches):
                batch_xs, _ = mnist.train.next_batch(batch_size)

                # train
                sess.run(optimizer, feed_dict={x: batch_xs})

                if batch_idx % 10 == 0:
                    train_cost = sess.run(cost, feed_dict={x: batch_xs})
                    i += 1
                    create_images(i, num_examples, sess, mnist)
                    print('train cost per a batch: ', train_cost)

            valid_cost = 0
            for _ in range(num_validation_batches):
                batch_xs, _ = mnist.validation.next_batch(batch_size)
                valid_cost += sess.run(cost, feed_dict={x: batch_xs})
            # Skip the report when the validation split is smaller than one
            # batch; dividing by zero batches would raise.
            if num_validation_batches:
                print('validation cost per a batch:', valid_cost / num_validation_batches)


In [8]:
# Load MNIST from the 'MNIST_data' directory with one-hot labels, then
# start training with the default hyperparameters.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
train(mnist)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
num of train batches:  550
num of validation batches:  50
epoch No. 0
train cost per a batch:  1.25401e+13
train cost per a batch:  228.756
train cost per a batch:  226.027
train cost per a batch:  225.633
train cost per a batch:  220.778
train cost per a batch:  225.618




train cost per a batch:  222.447
train cost per a batch:  220.672
train cost per a batch:  224.332
train cost per a batch:  225.418
train cost per a batch:  230.276
train cost per a batch:  224.265
train cost per a batch:  220.873
train cost per a batch:  224.523
train cost per a batch:  213.963
train cost per a batch:  223.381
train cost per a batch:  228.039
train cost per a batch:  228.262
train cost per a batch:  222.985
train cost per a batch:  226.975
train cost per a batch:  223.748
train cost per a batch:  224.983
train cost per a batch:  227.536
train cost per a batch:  222.676
train cost per a batch:  221.639
train cost per a batch:  229.009
train cost per a batch:  228.755
train cost per a batch:  226.053
train cost per a batch:  217.852
train cost per a batch:  225.329
train cost per a batch:  222.843
train cost per a batch:  234.125
train cost per a batch:  231.86
train cost per a batch:  221.132
train cost per a batch:  223.133
train cost per a batch:  226.057
train cost 

KeyboardInterrupt: 