<a href="https://colab.research.google.com/github/Annmodels/mnist_recon/blob/master/AAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [0]:
# Download the four Fashion-MNIST archives (training/test images and labels)
# into the current working directory; the next cell moves them into
# MNIST_Fashion/, which is the directory read_data_sets() is pointed at.
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz

In [0]:
!mkdir MNIST_Fashion
!mv *.gz MNIST_Fashion/

In [0]:
# Load the downloaded Fashion-MNIST archives with the TF1 tutorial reader.
# NOTE(review): `tensorflow.examples.tutorials` exists only in TF 1.x (it was
# removed in TF 2), so this notebook requires a TF 1.x runtime.
# The reconstruction loss later uses the images as sigmoid cross-entropy labels,
# so pixel values are assumed to lie in [0, 1] — confirm with the reader's dtype.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_Fashion/')

In [0]:
# Preview a single training image as a tiny 28x28 grayscale thumbnail.
# Note: this advances the training stream by one example via next_batch(1).
plt.figure(figsize=(1, 1))
sample_image = mnist.train.next_batch(1)[0].reshape([28, 28])
plt.imshow(sample_image, cmap="gray")
plt.show()

In [0]:
# Optimisation hyperparameters shared by all three Adam optimizers.
# NOTE: `epochs` is used as the number of training *steps* by the training loop:
# each iteration draws a single mini-batch of `batch_size` images.
lrate, batch_size, epochs = 0.0002, 32, 50000

In [0]:
image_dim = 784       # flattened 28x28 image
nn_dim = 128          # hidden-layer width used by the discriminator, encoder and decoder
latent_var_dim = 10   # size of the encoder's latent code
z_noise_dim = 10      # size of the prior noise samples fed to the discriminator

# The discriminator scores both prior samples (width z_noise_dim) and encoder
# codes (width latent_var_dim) through the same first-layer weight matrix, and
# the decoder is later fed prior samples directly, so the two widths must match.
assert z_noise_dim == latent_var_dim, "z_noise_dim must equal latent_var_dim"
assert image_dim == 28 * 28, "images are reshaped to 28x28 for display"

In [0]:
# Graph inputs (batch dimension left open):
#   z_input - prior noise samples; the discriminator's "real" branch, and the
#             decoder's input when sampling new images.
#   x_input - flattened images fed to the encoder and used as reconstruction targets.
z_input = tf.placeholder(dtype=tf.float32, shape=[None, z_noise_dim], name="input_noise")
x_input = tf.placeholder(dtype=tf.float32, shape=[None, image_dim], name="real_input")

In [0]:
def xavier_init(shape):
  """Return a random-normal initial-value tensor of the given `shape`.

  The standard deviation is 1 / sqrt(shape[0] / 2) == sqrt(2 / shape[0]), so it
  scales with the first dimension (the fan-in for a [in, out] weight matrix).
  Despite the name, this is He-style scaling rather than Glorot/Xavier.

  The scale is a plain Python number known at graph-construction time, so it
  is computed with NumPy instead of adding sqrt/div ops to the TF graph on
  every call.

  Args:
    shape: list of ints; shape[0] must be positive.

  Returns:
    A float32 tf.Tensor of normally distributed values with the given shape.

  Raises:
    ValueError: if shape[0] is not positive (the stddev would be inf/NaN).
  """
  import numpy as np

  fan_in = shape[0]
  if fan_in <= 0:
    raise ValueError("xavier_init: shape[0] must be positive, got %r" % (fan_in,))
  stddev = float(np.sqrt(2.0 / fan_in))
  return tf.random_normal(shape=shape, stddev=stddev)

In [0]:
# Weight matrices for each network, keyed by layer name. Each is
# [fan_in, fan_out] and drawn from xavier_init.
disc_wt = {
    "disc_hidn": tf.Variable(xavier_init([latent_var_dim, nn_dim])),
    "disc_final": tf.Variable(xavier_init([nn_dim, 1])),
}
en_wt = {
    "en_1": tf.Variable(xavier_init([image_dim, nn_dim])),
    "en_hidn": tf.Variable(xavier_init([nn_dim, latent_var_dim])),
}
d_wt = {
    "d_hidn": tf.Variable(xavier_init([latent_var_dim, nn_dim])),
    "d_final": tf.Variable(xavier_init([nn_dim, image_dim])),
}

# Biases start at zero rather than being drawn from xavier_init. That helper
# scales by shape[0], which for a 1-D bias is the layer's *output* width, so
# random biases would be arbitrarily scaled. For example, the single
# discriminator output bias would get stddev sqrt(2/1) ~= 1.41, a large random
# offset on its logit before training even starts.
disc_bias = {
    "disc_hidn": tf.Variable(tf.zeros([nn_dim])),
    "disc_final": tf.Variable(tf.zeros([1])),
}
en_bias = {
    "en_1": tf.Variable(tf.zeros([nn_dim])),
    "en_hidn": tf.Variable(tf.zeros([latent_var_dim])),
}
d_bias = {
    "d_hidn": tf.Variable(tf.zeros([nn_dim])),
    "d_final": tf.Variable(tf.zeros([image_dim])),
}

In [0]:
def Discriminator(x):
  """Map latent-space vectors to a probability in (0, 1).

  A one-hidden-layer ReLU MLP followed by a sigmoid.

  Args:
    x: float tensor whose last dimension is latent_var_dim
       (assumed [batch, latent_var_dim]).

  Returns:
    Sigmoid probabilities with last dimension 1.
  """
  hidden = tf.nn.relu(tf.matmul(x, disc_wt["disc_hidn"]) + disc_bias["disc_hidn"])
  logits = tf.matmul(hidden, disc_wt["disc_final"]) + disc_bias["disc_final"]
  return tf.nn.sigmoid(logits)

def Encoder(x):
  """Encode flattened images into latent codes.

  A one-hidden-layer ReLU MLP with a linear (unbounded) output layer.

  Args:
    x: float tensor whose last dimension is image_dim
       (assumed [batch, image_dim]).

  Returns:
    Latent codes with last dimension latent_var_dim.
  """
  hidden = tf.nn.relu(tf.matmul(x, en_wt["en_1"]) + en_bias["en_1"])
  return tf.matmul(hidden, en_wt["en_hidn"]) + en_bias["en_hidn"]

def Decoder(x):
  """Decode latent vectors back into image space.

  A one-hidden-layer ReLU MLP whose output layer has image_dim units.

  Args:
    x: float tensor whose last dimension is latent_var_dim
       (assumed [batch, latent_var_dim]).

  Returns:
    A (prob, logits) pair: the sigmoid of the output layer and the raw
    pre-sigmoid logits. Both have last dimension image_dim.
  """
  hidden = tf.nn.relu(tf.matmul(x, d_wt["d_hidn"]) + d_bias["d_hidn"])
  logits = tf.matmul(hidden, d_wt["d_final"]) + d_bias["d_final"]
  return tf.nn.sigmoid(logits), logits

In [0]:
# Autoencoder path: images -> latent codes -> reconstruction logits.
# Only the logits (second element) are needed for the loss.
latent_output = Encoder(x_input)
final_output = Decoder(latent_output)[1]

# One shared discriminator scores prior samples (the "real" branch) and
# encoder codes (the "fake" branch).
disc_real_output = Discriminator(z_input)
disc_fake_output = Discriminator(latent_output)

In [0]:
# Small constant that keeps tf.log away from log(0) = -inf (which also gives
# NaN gradients). In float32 the discriminator's sigmoid can saturate to
# exactly 0.0 or 1.0.
LOG_EPS = 1e-8

# Discriminator: push D(prior sample) -> 1 and D(encoder code) -> 0.
Discriminator_Loss = -tf.reduce_mean(
    tf.log(disc_real_output + LOG_EPS) + tf.log(1.0 - disc_fake_output + LOG_EPS))
# Encoder (adversarial step): non-saturating loss pushing D(encoder code) -> 1.
Encoder_Loss = -tf.reduce_mean(tf.log(disc_fake_output + LOG_EPS))
# Reconstruction: per-pixel sigmoid cross-entropy between the decoder logits and
# the input pixels. The labels are assumed to lie in [0, 1].
Decoder_Loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=final_output, labels=x_input))

In [0]:
# Per-network variable lists (weights first, then biases, in layer order), so
# each optimizer can be restricted to a chosen subset of parameters.
Discriminator_var = [params[key]
                     for params in (disc_wt, disc_bias)
                     for key in ("disc_hidn", "disc_final")]
Encoder_var = [params[key]
               for params in (en_wt, en_bias)
               for key in ("en_1", "en_hidn")]
Decoder_var = [params[key]
               for params in (d_wt, d_bias)
               for key in ("d_hidn", "d_final")]

In [0]:
def _adam_step(loss, var_list):
  """Build a fresh Adam minimize op for `loss` over `var_list` at rate `lrate`."""
  return tf.train.AdamOptimizer(learning_rate=lrate).minimize(loss, var_list=var_list)

# One independent Adam optimizer per training phase:
#   Discriminator_optimize - discriminator parameters only.
#   Encoder_optimize       - encoder parameters only (adversarial step).
#   Decoder_optimize       - encoder + decoder parameters (reconstruction step).
Discriminator_optimize = _adam_step(Discriminator_Loss, Discriminator_var)
Encoder_optimize = _adam_step(Encoder_Loss, Encoder_var)
Decoder_optimize = _adam_step(Decoder_Loss, Encoder_var + Decoder_var)

In [0]:
# Create and initialise a session. It is left open because the sampling cell
# below reuses `sess`.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# NOTE: despite its name, `epoch` counts training *steps*: each iteration
# consumes a single mini-batch of `batch_size` images.
for epoch in range(epochs):
  x_batch,_ = mnist.train.next_batch(batch_size)
  # Prior samples for the discriminator's "real" input, uniform on [-1, 1).
  z_noise = np.random.uniform(-1.,1.,size = [batch_size , z_noise_dim]) 
  # Phase 1 - reconstruction: updates encoder + decoder on Decoder_Loss.
  # (D_loss_epoch holds the *decoder* loss, not a discriminator loss.)
  _,D_loss_epoch = sess.run([Decoder_optimize,Decoder_Loss],feed_dict = {x_input:x_batch})
  # Phase 2 - discriminator: prior samples vs. encoder codes of this batch;
  # only discriminator variables are updated.
  _,Disc_loss_epoch = sess.run([Discriminator_optimize,Discriminator_Loss],feed_dict = {x_input:x_batch,z_input:z_noise}) 
  # Phase 3 - adversarial: updates only the encoder so that its codes score
  # high under the discriminator.
  _,En_loss_epoch = sess.run([Encoder_optimize,Encoder_Loss],feed_dict = {x_input:x_batch})
  if epoch%2000 == 0:
    print("steps: {0}  D_loss: {1} , En_loss: {2} , Disc_loss: {3}".format(epoch,D_loss_epoch,En_loss_epoch,Disc_loss_epoch))

In [0]:
# Generate new images: decode prior noise (the same uniform [-1, 1) distribution
# shown to the discriminator as "real" during training) into an n x n grid.
test_output,_ = Decoder(z_input)
n = 6
canvas = np.empty((28*n, 28*n))
for i in range(n):
  # Draw exactly n noise vectors, one per column of this row, so the grid size
  # does not depend on batch_size. Drawing batch_size rows would waste decoder
  # work and raise IndexError whenever n > batch_size.
  z_noise = np.random.uniform(-1., 1., size=[n, z_noise_dim])
  g = sess.run(test_output, feed_dict={z_input: z_noise})
  # Invert intensities (1 - g) so that bright decoder outputs render dark in
  # the gray colormap.
  g = 1.0 - g
  for j in range(n):
    canvas[i*28:(i+1)*28, j*28:(j+1)*28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()