[View in Colaboratory](https://colab.research.google.com/github/Xinliang-Zhao/Deep_Learning/blob/master/variational_autoencoder_tensorboard.ipynb)

In [0]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data

# Short alias for the TF 1.x contrib distributions package used by the model.
tfd = tf.contrib.distributions

# Output directory for everything the TensorBoard projector reads: the event
# file / projector config, the embedding checkpoint, the sprite sheet and the
# label metadata.
LOG_DIR = 'embedded_test_sample'
# Name of the tf.Variable that holds the embedding to visualise.
NAME_TO_VISUALISE_VARIABLE = "mnist_embedding"

# Sprite sheet and per-point metadata files referenced by the projector config.
path_for_mnist_sprites, path_for_mnist_metadata = (
    os.path.join(LOG_DIR, 'mnistdigits.png'),
    os.path.join(LOG_DIR, 'metadata.tsv'),
)

def make_encoder(data, code_size):
  """Build the approximate posterior q(z | x) as a diagonal Gaussian.

  The input is flattened and sent through two 200-unit ReLU layers; two
  linear heads then give the mean and the (softplus, hence non-negative)
  per-dimension scale of the code.

  Args:
    data: input batch; all non-batch axes are flattened.
    code_size: dimensionality of the latent code.

  Returns:
    A ``tfd.MultivariateNormalDiag`` over codes of size ``code_size``.
  """
  hidden = tf.layers.flatten(data)
  # Same layer creation order as before, so variable names are unchanged.
  for _ in range(2):
    hidden = tf.layers.dense(hidden, 200, tf.nn.relu)
  mean = tf.layers.dense(hidden, code_size)
  std = tf.layers.dense(hidden, code_size, tf.nn.softplus)
  return tfd.MultivariateNormalDiag(mean, std)


def make_prior(code_size):
  """Return the standard-normal prior p(z) = N(0, I) over ``code_size``-dim codes."""
  return tfd.MultivariateNormalDiag(tf.zeros(code_size), tf.ones(code_size))


def make_decoder(code, data_shape):
  """Build the likelihood p(x | z) as independent Bernoullis over all pixels.

  Args:
    code: latent code batch (assumed shape [batch, code_size]; both callers
      in this notebook pass samples of the code distribution).
    data_shape: per-example output shape, e.g. ``[28, 28]``; list or tuple.

  Returns:
    A ``tfd.Independent`` Bernoulli whose event covers every axis of
    ``data_shape``, so ``log_prob`` yields one value per example.
  """
  data_shape = list(data_shape)  # allow tuples; needed for the concat below
  x = tf.layers.dense(code, 200, tf.nn.relu)
  x = tf.layers.dense(x, 200, tf.nn.relu)
  logit = tf.layers.dense(x, int(np.prod(data_shape)))
  logit = tf.reshape(logit, [-1] + data_shape)
  # Reinterpret *all* data axes as event axes (previously hard-coded to 2,
  # which is only correct for 2-D data_shape).
  return tfd.Independent(tfd.Bernoulli(logits=logit), len(data_shape))


def plot_codes(fig, ax, codes, labels):
  """Scatter-plot 2-D latent codes coloured by label, with a colorbar.

  Args:
    fig: figure owning ``ax``; used to attach the colorbar.
    ax: axes to draw into.
    codes: array of shape (n, >=2); only the first two columns are plotted.
    labels: length-n values used for colouring the points.
  """
  im = ax.scatter(codes[:, 0], codes[:, 1], s=20, c=labels, alpha=1, cmap='viridis')
  fig.colorbar(im, ax=ax)
  ax.set_aspect('equal')
  # One shared range (min/max over all of ``codes``) for both axes keeps the
  # plot square, consistent with the equal aspect ratio.
  low, high = codes.min() - .1, codes.max() + .1
  ax.set_xlim(low, high)
  ax.set_ylim(low, high)
  # Use booleans: matplotlib deprecated the 'on'/'off' strings here, and a
  # non-empty string like 'off' is truthy, which would *show* ticks/labels.
  ax.tick_params(
      axis='both', which='both', left=False, bottom=False,
      labelleft=False, labelbottom=False)


def plot_samples(ax, samples):
  """Show each sample image in the subplot at the same position of ``ax``.

  ``ax`` must be indexable by position with at least as many entries as
  there are samples. Axis decorations are hidden so only the image shows.
  """
  for slot, image in enumerate(samples):
    panel = ax[slot]
    panel.imshow(image, cmap='gray')
    panel.axis('off')


# Input placeholder: a batch of 28x28 images.
data = tf.placeholder(tf.float32, [None, 28, 28])

# make_template shares variables across calls, so the decoder used in the
# loss and the one used for sampling below are the same network.
make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

embedding_size = 2
# Define the model.
prior = make_prior(code_size=embedding_size)
posterior = make_encoder(data, code_size=embedding_size)
code = posterior.sample()

# Define the loss: maximise the ELBO, E_q[log p(x|z)] - KL(q(z|x) || p(z)),
# using one posterior sample per example.
likelihood = make_decoder(code, [28, 28]).log_prob(data)
divergence = tfd.kl_divergence(posterior, prior)
elbo = tf.reduce_mean(likelihood - divergence)
optimize = tf.train.AdamOptimizer(0.001).minimize(-elbo)

# Mean images decoded from 10 prior samples.
samples = make_decoder(prior.sample(10), [28, 28]).mean()

mnist = input_data.read_data_sets('MNIST_data/')
# NOTE(review): only training images labelled 1 are used, while evaluation
# and the exported embedding cover every test digit -- confirm intended.
train_idx = [idx for (idx, label) in enumerate(mnist.train.labels) if label == 1]
train_images = mnist.train.images[train_idx]
train_labels = mnist.train.labels[train_idx]
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=len(train_idx), reshuffle_each_iteration=True)
train_dataset = train_dataset.batch(100)
train_dataset = train_dataset.repeat()
iterator = train_dataset.make_one_shot_iterator()
train_features, train_labels = iterator.get_next()
epoch_num = 10
fig, ax = plt.subplots(nrows=epoch_num, ncols=1, figsize = (10, 10 * epoch_num))
# The test feed never changes, so build (and reshape) it once.
test_feed = {data: mnist.test.images.reshape([-1, 28, 28])}
test_codes = list()
# MonitoredSession's default Scaffold already runs the global variable
# initializer, so no explicit init op is needed here.
with tf.train.MonitoredSession() as sess:
  for epoch in range(epoch_num):
    # Fetch both in one run so the reported ELBO and plotted codes come
    # from the same posterior sample.
    test_elbo, test_codes = sess.run([elbo, code], feed_dict=test_feed)
    print('Epoch', epoch, 'elbo', test_elbo)
    ax[epoch].set_ylabel('Epoch {}'.format(epoch))
    plot_codes(fig, ax[epoch], test_codes, mnist.test.labels)
    for _ in range(600):
      batch = sess.run(train_features).reshape([-1, 28, 28])
      sess.run(optimize, feed_dict={data: batch})
  # Codes in the loop are taken *before* each epoch's training; recompute
  # them with the fully trained encoder for the exported embedding.
  test_codes = sess.run(code, feed_dict=test_feed)
plt.show()

embedding_var = tf.Variable(test_codes, name=NAME_TO_VISUALISE_VARIABLE)
summary_writer = tf.summary.FileWriter(LOG_DIR)

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name

# Specify where you find the metadata (one row per test image).
# NOTE(review): the commented alternatives suggest relative-path resolution
# differs between TensorBoard versions -- confirm against the version used.
embedding.metadata_path = path_for_mnist_metadata #'metadata.tsv'

# Specify where you find the sprite (created in the next cell)
embedding.sprite.image_path = path_for_mnist_sprites #'mnistdigits.png'
embedding.sprite.single_image_dim.extend([28,28])

# Say that you want to visualise the embeddings
projector.visualize_embeddings(summary_writer, config)
summary_writer.close()

with tf.Session() as sess:
  # Initialise and save ONLY the embedding. A global initializer in this
  # fresh session would re-randomise the encoder/decoder weights, and the
  # checkpoint would then hold untrained parameters.
  sess.run(embedding_var.initializer)
  saver = tf.train.Saver([embedding_var])
  saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"), 1)

In [0]:
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data

# Must match the training cell: its projector config points at these files.
LOG_DIR = 'embedded_test_sample'
NAME_TO_VISUALISE_VARIABLE = "mnist_embedding"

path_for_mnist_sprites, path_for_mnist_metadata = (
    os.path.join(LOG_DIR, 'mnistdigits.png'),
    os.path.join(LOG_DIR, 'metadata.tsv'),
)

mnist = input_data.read_data_sets('MNIST_data/')

def create_sprite_image(images):
    """Tile images into one square sprite sheet, row-major.

    Args:
        images: array (or list of arrays) shaped count x height x width.

    Returns:
        A 2-D float array of side ``ceil(sqrt(count))`` tiles; tiles beyond
        ``count`` are left filled with ones.
    """
    images = np.array(images) if isinstance(images, list) else images
    count, img_h, img_w = images.shape[0], images.shape[1], images.shape[2]
    side = int(np.ceil(np.sqrt(count)))

    sprite = np.ones((img_h * side, img_w * side))
    for idx in range(min(count, side * side)):
        row, col = divmod(idx, side)
        top, left = row * img_h, col * img_w
        sprite[top:top + img_h, left:left + img_w] = images[idx]
    return sprite

def vector_to_matrix_mnist(mnist_digits):
    """Reshape flat MNIST digits (batch, 28*28) into images (batch, 28, 28)."""
    image_shape = (-1, 28, 28)
    return np.reshape(mnist_digits, image_shape)

def invert_grayscale(mnist_digits):
    """Swap black and white by mapping each intensity v to 1 - v.

    Assumes intensities are scaled to [0, 1] -- confirm for other inputs.
    """
    white = 1
    return white - mnist_digits
  
# Build the projector sprite from the test images, in test-set order.
to_visualise = mnist.test.images
to_visualise = vector_to_matrix_mnist(to_visualise)
to_visualise = invert_grayscale(to_visualise)

sprite_image = create_sprite_image(to_visualise)

# LOG_DIR may not exist yet if this cell runs before the training cell.
if not os.path.isdir(LOG_DIR):
    os.makedirs(LOG_DIR)

plt.imsave(path_for_mnist_sprites,sprite_image,cmap='gray')
plt.imshow(sprite_image,cmap='gray')

# One row per test image, in the same order as the sprite and the embedding.
# The header row names the columns (used because there is more than one).
with open(path_for_mnist_metadata,'w') as f:
    f.write("Index\tLabel\n")
    for index,label in enumerate(mnist.test.labels):
        f.write("%d\t%d\n" % (index,label))