<a href="https://colab.research.google.com/github/Annmodels/mnist_recon/blob/master/InfoGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [0]:
# Load MNIST through the TF 1.x tutorial helper, using MNIST_DIR as the data directory.
from tensorflow.examples.tutorials.mnist import input_data

MNIST_DIR = '/content/data'
mnist = input_data.read_data_sets(MNIST_DIR, one_hot = True)

In [0]:
# Optimisation hyper-parameters.
lrate = 2e-4        # learning rate handed to every Adam optimiser
batch_size = 32     # samples drawn per training step
epochs = 50000      # number of training steps (one mini-batch each), not passes over the data

In [0]:
# Sizes of the tensors flowing through the networks.
image_dim = 28 * 28   # flattened 28x28 image
z_dim = 16            # length of the noise vector z
c_dim = 10            # length of the one-hot latent code c

In [0]:
# Width of the single hidden layer in each network.
d_h1 = 128   # discriminator
g_h1 = 256   # generator
q_h1 = 128   # auxiliary (Q) network

In [0]:
def xavier_init(shape):
  """Return a random-normal tensor of the given shape used to initialise variables.

  The standard deviation is 1 / sqrt(shape[0] / 2), i.e. it depends only on
  the first entry of `shape`.

  Args:
    shape: list of ints; shape of the tensor to sample.

  Returns:
    A `tf.random_normal` tensor of shape `shape`.
  """
  first_dim = shape[0]
  sigma = 1. / tf.sqrt(first_dim / 2.0)
  return tf.random_normal(shape = shape, stddev = sigma)

In [0]:
def _new_var(shape):
  """Create a trainable variable initialised with xavier_init(shape)."""
  return tf.Variable(xavier_init(shape))


# Weight matrices: one hidden layer and one output layer per network.
disc_wt = {
    "disc_hidn": _new_var([image_dim, d_h1]),
    "disc_final": _new_var([d_h1, 1]),
}
gen_wt = {
    # The generator consumes the noise vector and the latent code concatenated.
    "gen_hidn": _new_var([z_dim + c_dim, g_h1]),
    "gen_final": _new_var([g_h1, image_dim]),
}
q_wt = {
    "q_hidn": _new_var([image_dim, q_h1]),
    "q_final": _new_var([q_h1, c_dim]),
}

# Bias vectors matching the layers above (same random initialiser as the weights).
disc_bias = {
    "disc_hidn": _new_var([d_h1]),
    "disc_final": _new_var([1]),
}
gen_bias = {
    "gen_hidn": _new_var([g_h1]),
    "gen_final": _new_var([image_dim]),
}
q_bias = {
    "q_hidn": _new_var([q_h1]),
    "q_final": _new_var([c_dim]),
}

In [0]:
# Graph inputs; the leading dimension is left open so any batch size can be fed.
z_input = tf.placeholder(dtype = tf.float32, shape = (None, z_dim))      # noise vectors
c_input = tf.placeholder(dtype = tf.float32, shape = (None, c_dim))      # latent codes
x_input = tf.placeholder(dtype = tf.float32, shape = (None, image_dim))  # flattened real images

In [0]:
def z_distribution(m,n):
  """Draw an (m, n) array of noise values uniformly from [-1, 1).

  Args:
    m: number of rows (samples).
    n: number of columns (noise dimensions).

  Returns:
    A float numpy array of shape (m, n).
  """
  low, high = -1., 1.
  return np.random.uniform(low, high, size = (m, n))

def c_distribution(m,c_dim):
  """Draw m one-hot latent codes, each category being equally likely.

  Args:
    m: number of codes (rows) to draw.
    c_dim: number of categories (columns).

  Returns:
    An integer numpy array of shape (m, c_dim) with exactly one 1 per row.
  """
  # Use float division explicitly: under Python 2 `1/c_dim` is integer
  # division and yields 0, which makes every probability zero and sends all
  # draws to the last category.
  uniform_probs = c_dim * [1.0 / c_dim]
  return np.random.multinomial(1, uniform_probs, size = m)

In [0]:
def Discriminator(x):
  """Score images with a one-hidden-layer network ending in a sigmoid.

  Args:
    x: float tensor of flattened images, one per row.

  Returns:
    Tensor of sigmoid outputs (one column per row of `x`), used by the losses
    as the probability that the image is real.
  """
  w_hid, b_hid = disc_wt["disc_hidn"], disc_bias["disc_hidn"]
  w_out, b_out = disc_wt["disc_final"], disc_bias["disc_final"]

  hidden = tf.nn.relu(tf.add(tf.matmul(x, w_hid), b_hid))
  logit = tf.add(tf.matmul(hidden, w_out), b_out)
  return tf.nn.sigmoid(logit)

def Generator(z,c):
  """Map noise and latent code to an image with a one-hidden-layer network.

  Args:
    z: float tensor of noise vectors, one per row.
    c: float tensor of latent codes, one per row.

  Returns:
    Tensor of sigmoid outputs with `image_dim` columns, one generated image per row.
  """
  # Noise and code are joined column-wise into a single generator input.
  latent = tf.concat([z, c], axis = 1)

  w_hid, b_hid = gen_wt["gen_hidn"], gen_bias["gen_hidn"]
  w_out, b_out = gen_wt["gen_final"], gen_bias["gen_final"]

  hidden = tf.nn.relu(tf.add(tf.matmul(latent, w_hid), b_hid))
  logit = tf.add(tf.matmul(hidden, w_out), b_out)
  return tf.nn.sigmoid(logit)

def Auxillary_nn(x):
  """Estimate the latent code of an image with a one-hidden-layer softmax network.

  Args:
    x: float tensor of flattened images, one per row.

  Returns:
    Tensor of softmax probabilities with `c_dim` columns, one row per image.
  """
  w_hid, b_hid = q_wt["q_hidn"], q_bias["q_hidn"]
  w_out, b_out = q_wt["q_final"], q_bias["q_final"]

  hidden = tf.nn.relu(tf.add(tf.matmul(x, w_hid), b_hid))
  logits = tf.add(tf.matmul(hidden, w_out), b_out)
  return tf.nn.softmax(logits)

In [0]:
# Wire the three networks together.
samples_generated = Generator(z_input,c_input)
real_output_disc = Discriminator(x_input)
fake_output_disc = Discriminator(samples_generated)
Estimated_c = Auxillary_nn(samples_generated)

# Small constant added inside every log so that log(0) never occurs.
eps = 1e-7

# Discriminator: push D(real) towards 1 and D(fake) towards 0.
Disc_Loss = - tf.reduce_mean(tf.log(real_output_disc + eps)+tf.log(1.0 - fake_output_disc + eps))
# Generator (non-saturating form): push D(fake) towards 1.
Gen_Loss = -tf.reduce_mean(tf.log(fake_output_disc + eps))
# Q network: mean cross-entropy between the code fed to the generator and the
# code recovered from the generated image. It must be *minimised* so that the
# code stays recoverable. The previous expression negated the cross-entropy a
# second time, so minimising it drove the networks to make the code
# unrecoverable.
Q_Loss = tf.reduce_mean(-tf.reduce_sum(tf.log(Estimated_c + eps)*c_input,1))

In [0]:
# Trainable variables of each network (weights first, then biases); these are
# passed as `var_list` so each optimiser only updates its own network(s).
Gen_var = [gen_wt[k] for k in ("gen_hidn", "gen_final")] + \
          [gen_bias[k] for k in ("gen_hidn", "gen_final")]
Disc_var = [disc_wt[k] for k in ("disc_hidn", "disc_final")] + \
           [disc_bias[k] for k in ("disc_hidn", "disc_final")]
Q_var = [q_wt[k] for k in ("q_hidn", "q_final")] + \
        [q_bias[k] for k in ("q_hidn", "q_final")]

In [0]:
# One Adam optimiser per objective, each restricted to the variables it may change.
disc_adam = tf.train.AdamOptimizer(learning_rate = lrate)
Disc_optimize = disc_adam.minimize(Disc_Loss, var_list = Disc_var)

gen_adam = tf.train.AdamOptimizer(learning_rate = lrate)
Gen_optimize = gen_adam.minimize(Gen_Loss, var_list = Gen_var)

# The Q objective updates both the generator and the auxiliary network.
q_adam = tf.train.AdamOptimizer(learning_rate = lrate)
Q_optimize = q_adam.minimize(Q_Loss, var_list = Gen_var + Q_var)

In [0]:
# Create the session and initialise every variable. The session is left open
# on purpose: the sampling cell below runs the trained generator with it.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Each iteration is one training step on a fresh mini-batch: a discriminator
# update, then a generator update, then a Q (generator + auxiliary) update.
for epoch in range(epochs):
  x_batch,_ = mnist.train.next_batch(batch_size)  # labels are not used
  z_noise = z_distribution(batch_size , z_dim)
  c_noise = c_distribution(batch_size , c_dim)

  # The generator and Q steps need only the latent inputs, not real images.
  latent_feed = {z_input:z_noise,c_input:c_noise}

  _,Disc_loss_epoch = sess.run([Disc_optimize,Disc_Loss],feed_dict = {x_input:x_batch,z_input:z_noise,c_input:c_noise})
  _,Gen_loss_epoch = sess.run([Gen_optimize,Gen_Loss],feed_dict = latent_feed)
  # Fetch Q_Loss here; previously Gen_Loss was fetched, so the value reported
  # as "Aux_loss" was actually the generator loss.
  _,Aux_loss_epoch = sess.run([Q_optimize,Q_Loss],feed_dict = latent_feed)
  if epoch%2000 == 0:
    print("steps: {0}  Disc_loss: {1} , Gen_loss: {2} , Aux_loss: {3}".format(epoch,Disc_loss_epoch,Gen_loss_epoch,Aux_loss_epoch))

In [0]:
# Draw an n x n grid of generated images that all share one latent code.
#
# Reuse the generator output already built for training. Calling Generator()
# again would add a duplicate copy of its ops to the graph every time this
# cell is re-run, while producing exactly the same values.
test_output = samples_generated
n = 6
digit_id = 0  # which entry of the one-hot code to switch on (0 .. c_dim-1)
canvas = np.empty((28*n,28*n))
for i in range(n):
  # Only n images are placed in each row, so sample exactly n latent vectors.
  z_noise = z_distribution(n , z_dim)

  # Fixed code: every sample in the row uses the same one-hot vector.
  # For random codes instead, use: c_noise = c_distribution(n , c_dim)
  c_noise = np.zeros((n,c_dim))
  c_noise[:,digit_id] = 1

  g = sess.run(test_output , feed_dict = {z_input:z_noise , c_input:c_noise})
  g = 1.0 - g  # invert intensities before plotting
  for j in range(n):
    canvas[i*28:(i+1)*28 , j*28:(j+1)*28] = g[j].reshape([28,28])

plt.figure(figsize = (n,n))
plt.imshow(canvas , origin = "upper" , cmap = "gray")
plt.show()