In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def weight_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

def bias_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

def layer(x, shape, activation):
    W = weight_variable(shape)
    b = bias_variable([shape[1]])
    return activation(tf.matmul(x, W) + b)

In [3]:
def get_batch(images, batch_size, index):
    start_idx = index * batch_size
    end_idx = (index + 1) * batch_size        
    batch  = images[start_idx: end_idx]
    return batch

In [4]:
def montage_batch(images):
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones(
        (images.shape[1] * n_plots + n_plots + 1,
         images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5

    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter, ...]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w, :] = this_img
    return m

In [5]:
% mkdir img/

def create_images(i, num_examples, sess, test_images, test_labels):
    xs = test_images
    ys =  test_labels

    
    fig_manifold, ax_manifold = plt.subplots(1, 1)
    fig_reconstruction, axs_reconstruction = plt.subplots(2, num_examples, figsize=(10, 2))
    fig_image_manifold, ax_image_manifold = plt.subplots(1, 1)
    
    # Plot example reconstructions from latent layer
    imgs = []
    for img_i in np.linspace(-3, 3, num_examples):
        for img_j in np.linspace(-3, 3, num_examples):
            recon = sess.run( y, 
                feed_dict={
                    z: np.array([[img_i, img_j]], dtype=np.float32)
                }
            )
            imgs.append(np.reshape(recon, (1, 28, 28, 1)))
    imgs_cat = np.concatenate(imgs)
    ax_manifold.imshow(montage_batch(imgs_cat))
    fig_manifold.savefig("img/manifold_%08d.png" % i)


    # Plot example reconstructions
    test_xs = get_batch(test_images, num_examples,0)
    recon = sess.run(y, feed_dict={x: test_xs})
    for example_i in range(num_examples):
        axs_reconstruction[0][example_i].imshow(
            np.reshape(test_xs[example_i, :], (28, 28)),
            cmap='gray')
        axs_reconstruction[1][example_i].imshow(
            np.reshape(
                np.reshape(recon[example_i, ...], (784,)),
                (28, 28)),
            cmap='gray')
        axs_reconstruction[0][example_i].axis('off')
        axs_reconstruction[1][example_i].axis('off')
    fig_reconstruction.savefig('img/reconstruction_%08d.png' % i)

    
    
    # Plot manifold of latent layer
    zs = sess.run(z, feed_dict={x: xs})
    ax_image_manifold.clear()
    ax_image_manifold.scatter(zs[:, 0], zs[:, 1],
        c=np.argmax(ys, 1), alpha=0.2)
    ax_image_manifold.set_xlim([-6, 6])
    ax_image_manifold.set_ylim([-6, 6])
    ax_image_manifold.axis('off')
    fig_image_manifold.savefig('img/image_manifold_%08d.png' % i)

mkdir: img/: File exists


In [6]:
image_size = 28 * 28

input_shape=[None, image_size]
encoder_internal_dim=2048
decoder_internal_dim=2048
latent_dim=2

x = tf.placeholder(tf.float32, input_shape)

softplus = tf.nn.softplus

h_enc1 = layer(x, [image_size, encoder_internal_dim], activation=softplus)
h_enc2 = layer(h_enc1, [encoder_internal_dim, encoder_internal_dim], activation=softplus)
h_enc3 = layer(h_enc2, [encoder_internal_dim, encoder_internal_dim], activation=softplus)

W_mu = weight_variable([encoder_internal_dim, latent_dim])
b_mu = bias_variable([latent_dim])

W_log_sigma = weight_variable([encoder_internal_dim, latent_dim])
b_log_sigma = bias_variable([latent_dim])

    
z_mu = tf.matmul(h_enc3, W_mu) + b_mu
z_log_sigma = tf.matmul(h_enc3, W_log_sigma) + b_log_sigma

# reparametarization trick

# noise gaussian ε ~ N(0, 1)
epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], latent_dim]))

# z = μ+σ^(1/2)*ε
z = z_mu + tf.exp(z_log_sigma/2) * epsilon


h_dec1 = layer(z, [latent_dim, decoder_internal_dim], activation=softplus)
h_dec2 = layer(h_dec1, [decoder_internal_dim, decoder_internal_dim], activation=softplus)
h_dec3 = layer(h_dec2, [decoder_internal_dim, decoder_internal_dim], activation=softplus)

# log(p(x|z)) (p is Bernoulli) reconstruction loss
y = layer(h_dec3, [decoder_internal_dim, image_size], activation=tf.nn.sigmoid)
log_px_given_z = -tf.reduce_sum(x * tf.log(y + 1e-10) + (1 - x) * tf.log(1 - y + 1e-10), 1)


# KLD(q(z|x)||p(z)) ~ -(1/2) * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_div = - (1/2)* tf.reduce_sum(1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma),1)

    
cost = tf.reduce_mean(log_px_given_z + kl_div)



In [11]:
def train(train_images, validation_images, test_images, test_labels, learning_rate = 0.001):

    
    
    sess = tf.Session()
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    sess.run(tf.global_variables_initializer())

    
    batch_size = 100
    num_epochs = 1

    num_train_batches = len(train_images) // batch_size
    num_validation_batches = len(validation_images) // batch_size

    print("num of train batches: ", num_train_batches)
    print("num of validation batches: ", num_validation_batches)

    i = 0
    for epoch in range(num_epochs):
        print("epoch No.", epoch)
        
        for batch_idx in range(num_train_batches):            
            batch  = get_batch(train_images, batch_size, batch_idx)
            
            # train
            sess.run(optimizer,feed_dict={x: batch})
            
            if batch_idx % 10 == 0:
                
                train_cost = sess.run(cost, feed_dict={x: batch})
                i += 1
                create_images(i, 20, sess, test_images, test_labels)
                print('train cost per a batch: ', train_cost)
      


        valid_cost = 0
        for i in range(num_validation_batches):
            batch  = get_batch(validation_images, batch_size, i)
            valid_cost += sess.run(cost, feed_dict={x: batch})
        print('validation cost per a batch:', valid_cost / num_validation_batches)


In [9]:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets('MNIST_DATA', one_hot=True)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_DATA/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_DATA/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_DATA/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_DATA/t10k-labels-idx1-ubyte.gz


In [None]:
train_images = mnist.train.images
validation_images = mnist.validation.images
test_images = mnist.test.images
test_labels = mnist.test.labels

train(train_images, validation_images, test_images, test_labels)

num of train batches:  550
num of validation batches:  50
epoch No. 0
train cost per a batch:  4.75853e+13
train cost per a batch:  245.436
train cost per a batch:  221.135
train cost per a batch:  231.673
train cost per a batch:  215.627
train cost per a batch:  221.218




train cost per a batch:  225.954
train cost per a batch:  215.37
train cost per a batch:  235.871
train cost per a batch:  237.109
train cost per a batch:  215.661
train cost per a batch:  215.611
train cost per a batch:  209.666
train cost per a batch:  208.207
train cost per a batch:  211.605
train cost per a batch:  231.556
train cost per a batch:  213.134
train cost per a batch:  231.656
train cost per a batch:  232.033
