<a href="https://colab.research.google.com/github/Annmodels/mnist_recon/blob/master/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [0]:
# MNIST loader from the TF1 tutorials package; labels are returned one-hot.
from tensorflow.examples.tutorials.mnist import input_data
database = input_data.read_data_sets(train_dir='/content/data', one_hot=True)

In [0]:
# Optimisation hyper-parameters.
learning_param = 1e-3   # RMSProp learning rate
epochs = 30000          # NOTE: counts mini-batch steps (one next_batch each), not passes over the data
batch_size = 32         # images per training step

In [0]:
# Network sizes: flattened 28x28 input, one hidden layer, 2-D latent space.
image_dim = 28 * 28
nn_dim = 512
latent_var_dim = 2

In [0]:
def xavier(in_shape):
  """Return a random-normal initial-value tensor of shape `in_shape`.

  The standard deviation is 1 / sqrt(in_shape[0] / 2) == sqrt(2 / in_shape[0]),
  i.e. scaled by the first dimension only (He-style scaling; despite the name
  this is not the Glorot formula, which also uses the output dimension).

  Args:
    in_shape: list of ints; in_shape[0] is treated as the fan-in.

  Returns:
    A float32 tensor suitable as a `tf.Variable` initial value.
  """
  fan_in = in_shape[0]
  scale = 1. / tf.sqrt(fan_in / 2.)
  return tf.random_normal(shape=in_shape, stddev=scale)

In [0]:
# Trainable parameters of the VAE's MLP encoder and decoder.
# Weight matrices use the scaled-normal `xavier` initialiser.
weights = {
    "wt_en_hidden": tf.Variable(xavier([image_dim, nn_dim])),            # input -> encoder hidden
    "wt_mean_hidden": tf.Variable(xavier([nn_dim, latent_var_dim])),     # encoder hidden -> latent mean
    "wt_stddev_hidden": tf.Variable(xavier([nn_dim, latent_var_dim])),   # encoder hidden -> latent log-variance
    "wt_d_hidden": tf.Variable(xavier([latent_var_dim, nn_dim])),        # latent -> decoder hidden
    "wt_d": tf.Variable(xavier([nn_dim, image_dim]))                     # decoder hidden -> pixels
}
# Biases start at zero rather than going through `xavier`. That initialiser
# scales by in_shape[0], which for a 1-D bias is the bias's own length, so the
# 2-element latent mean / log-variance biases were drawn with stddev 1.0 and
# gave every example a large random initial offset.
biases = {
    "bias_en_hidden": tf.Variable(tf.zeros([nn_dim])),
    "bias_mean_hidden": tf.Variable(tf.zeros([latent_var_dim])),
    "bias_stddev_hidden": tf.Variable(tf.zeros([latent_var_dim])),
    "bias_d_hidden": tf.Variable(tf.zeros([nn_dim])),
    "bias_d": tf.Variable(tf.zeros([image_dim]))
}

In [0]:
# Batch of flattened input images; the batch dimension is left open (None).
image_x = tf.placeholder(dtype=tf.float32, shape=(None, image_dim))

In [0]:
# Encoder: image -> tanh hidden layer -> parameters of q(z|x).
en_layer = tf.nn.tanh(tf.matmul(image_x, weights["wt_en_hidden"]) + biases["bias_en_hidden"])
mean_layer = tf.matmul(en_layer, weights["wt_mean_hidden"]) + biases["bias_mean_hidden"]
# NOTE: despite its name this tensor is the log-variance log(sigma^2):
# the sample below uses exp(0.5 * it) as sigma, and the KL term uses exp(it).
stddev_layer = tf.matmul(en_layer, weights["wt_stddev_hidden"]) + biases["bias_stddev_hidden"]

# Reparameterisation trick: z = mean + sigma * eps, eps ~ N(0, 1).
epsilon = tf.random_normal(tf.shape(stddev_layer), mean=0.0, stddev=1.0, dtype=tf.float32)
latent_layer = mean_layer + tf.exp(0.5 * stddev_layer) * epsilon

# Decoder: z -> tanh hidden layer -> per-pixel sigmoid probabilities.
d_hidden = tf.nn.tanh(tf.matmul(latent_layer, weights["wt_d_hidden"]) + biases["bias_d_hidden"])
d_output_layer = tf.nn.sigmoid(tf.matmul(d_hidden, weights["wt_d"]) + biases["bias_d"])

In [0]:
def lossfunc(org_img, recon_img, mean=None, log_var=None, alpha=1, beta=1):
  """Weighted negative ELBO of the VAE, averaged over the batch.

  Args:
    org_img: input images, one example per row (presumably values in [0, 1]).
    recon_img: decoder output probabilities, same shape as `org_img`.
    mean: mean of q(z|x); defaults to the global `mean_layer`.
    log_var: log-variance of q(z|x); defaults to the global `stddev_layer`
      (which, despite its name, holds log(sigma^2)).
    alpha: weight of the reconstruction term.
    beta: weight of the KL-divergence term.

  Returns:
    Scalar tensor: batch mean of alpha * reconstruction_nll + beta * KL.
  """
  # Resolve the defaults at call time so the globals need only exist when called.
  if mean is None:
    mean = mean_layer
  if log_var is None:
    log_var = stddev_layer
  eps = 1e-10
  # Bernoulli negative log-likelihood summed over pixels (axis 1).
  # The epsilon is added to (1 - recon_img) as a whole. Written as
  # `1e-10 + 1 - recon_img`, the Python constant 1 + 1e-10 rounds to 1.0 in
  # float32, so a saturated sigmoid (recon_img == 1) gave log(0) = -inf and
  # 0 * -inf = NaN for target pixels equal to 1.
  data_fidelity_loss = (org_img * tf.log(eps + recon_img)
                        + (1 - org_img) * tf.log(eps + (1 - recon_img)))
  data_fidelity_loss = -tf.reduce_sum(data_fidelity_loss, 1)

  # Closed-form KL(q(z|x) || N(0, I)) per example.
  KL_div_loss = 1 + log_var - tf.square(mean) - tf.exp(log_var)
  KL_div_loss = -0.5 * tf.reduce_sum(KL_div_loss, 1)

  network_loss = tf.reduce_mean(alpha * data_fidelity_loss + beta * KL_div_loss)
  return network_loss

In [0]:
# Training objective on the VAE reconstruction, plus one RMSProp update op.
lossval = lossfunc(image_x, d_output_layer)
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_param).minimize(loss=lossval)
# Built after minimize() so it also initialises the optimizer's slot variables.
init = tf.global_variables_initializer()

In [0]:
# Train: one RMSProp step per mini-batch, logging the loss every 5000 steps.
sess = tf.Session()
sess.run(init)
for step in range(epochs):
  batch_images, _ = database.train.next_batch(batch_size)
  _, loss = sess.run([optimizer, lossval], feed_dict={image_x: batch_images})
  if step % 5000 == 0:
    print("loss is {0} at itr {1}".format(loss, step))

In [0]:
# Stand-alone decoder ("generator"): feeds latent codes straight through the
# trained decoder weights.
# NOTE(review): this rebinds `d_hidden` and `d_output_layer`, shadowing the
# training-graph tensors of the same names; re-running the loss/optimizer cell
# after this one would build the loss on this generator instead.
noisy_x = tf.placeholder(dtype=tf.float32, shape=(None, latent_var_dim))
d_hidden = tf.nn.tanh(tf.matmul(noisy_x, weights["wt_d_hidden"]) + biases["bias_d_hidden"])
d_output_layer = tf.nn.sigmoid(tf.matmul(d_hidden, weights["wt_d"]) + biases["bias_d"])

In [0]:
# Decode an n x n grid of latent codes and tile the generated digits.
# Latent dim 0 (xlimit) runs left -> right and latent dim 1 (ylimit) runs
# bottom -> top, matching the latent scatter plot further down.
n = 20
xlimit = np.linspace(-2, 2, n)
ylimit = np.linspace(-2, 2, n)

# Decode the whole grid in a single batch instead of n*n separate sess.run
# calls, each on batch_size duplicated copies of one point.
z0_grid, z1_grid = np.meshgrid(xlimit, ylimit)   # [r, c] -> (xlimit[c], ylimit[r])
grid_latent = np.stack([z0_grid.ravel(), z1_grid.ravel()], axis=1)
generated_imgs = sess.run(d_output_layer, feed_dict={noisy_x: grid_latent})

# Tile row r holds ylimit[r]. Flip the tile rows so the largest value is at the
# top (origin="upper" draws array row 0 at the top); each digit stays upright.
tiles = generated_imgs.reshape(n, n, 28, 28)[::-1]
empty_img = tiles.transpose(0, 2, 1, 3).reshape(28 * n, 28 * n)

plt.figure(figsize=(8, 8))
plt.imshow(empty_img, origin="upper", cmap="gray")
plt.xlabel("latent dim 0: {:g} to {:g}".format(xlimit[0], xlimit[-1]))
plt.ylabel("latent dim 1: {:g} to {:g}".format(ylimit[0], ylimit[-1]))
plt.grid(False)
plt.show()
                             

In [0]:
# Embed the test set into the 2-D latent space, coloured by digit label.
# Uses the whole split exactly once. The original used
# test.next_batch(batch_size + 15000), which asks for more than the 10,000-image
# MNIST test split; next_batch then wraps around, so images were repeated.
xsample = database.test.images
ysample = database.test.labels
# Plot posterior means rather than reparameterised samples so the embedding is
# deterministic and not blurred by the sampling noise.
interim = sess.run(mean_layer, feed_dict={image_x: xsample})
assert interim.shape == (len(xsample), latent_var_dim)
colors = np.argmax(ysample, 1)   # one-hot labels -> digit ids
plt.figure(figsize=(8, 6))
plt.scatter(interim[:, 0], interim[:, 1], c=colors, cmap='viridis')
plt.colorbar(label="digit")
plt.xlabel("latent dim 0 (mean)")
plt.ylabel("latent dim 1 (mean)")
plt.grid()


In [0]:
# Release the TensorFlow session; `sess` cannot run further ops after this.
sess.close()