## Generative Adversarial Networks for Natural Language Processing

In [18]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

### State-of-the-art weight initialization strategy

In [33]:
def xavier_init(n_inputs, n_outputs, uniform=True):
  """Build a Glorot/Xavier weight initializer for one layer.

  The scale is derived from the sum of fan-in and fan-out, with the aim of
  keeping gradient magnitudes roughly equal across layers, as proposed in:

    Xavier Glorot and Yoshua Bengio (2010): Understanding the difficulty of
    training deep feedforward neural networks. International conference on
    artificial intelligence and statistics.

  NOTE: a later cell in this notebook defines `xavier_init(size)`, which
  replaces this function once that cell has run.

  Args:
    n_inputs: The number of input nodes into each output.
    n_outputs: The number of output nodes for each input.
    uniform: If true, return a uniform initializer; otherwise a truncated
      normal one.

  Returns:
    A TensorFlow initializer object.
  """
  fan_total = n_inputs + n_outputs

  if not uniform:
    # The constant 3 (instead of 6) is meant to give roughly the same limits
    # as the uniform case, because the truncated normal re-draws values that
    # fall more than 2 standard deviations from the mean.
    sigma = tf.sqrt(3.0 / fan_total)
    return tf.truncated_normal_initializer(stddev=sigma)

  # 6 is the constant used in the paper.
  limit = tf.sqrt(6.0 / fan_total)
  return tf.random_uniform_initializer(-limit, limit)

In [34]:
def xavier_init(size):
    """Sample initial weights using the scaling suggested by He et al.

    He, Zhang, Ren and Sun build on Glorot & Bengio and suggest a variance of
    2 / (number of input units), i.e. a standard deviation of
    1 / sqrt(fan_in / 2).

    NOTE: this definition reuses the name of the initializer-factory
    `xavier_init(n_inputs, n_outputs, uniform)` defined in an earlier cell and
    replaces it once this cell has run.

    Args:
        size: Shape of the weight tensor to create; `size[0]` is taken as the
            number of input units.

    Returns:
        A `tf.random_normal` tensor of shape `size` with the standard
        deviation described above.
    """
    fan_in = size[0]
    he_stddev = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=he_stddev)

### Discriminator

In [22]:
# Discriminator parameters: a 784 -> 128 -> 1 fully connected network.
# X is the input for real examples; the leading dimension is left as None so
# any batch size can be fed. Each example is a row of 784 float32 values.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Hidden layer (784 -> 128): weights come from xavier_init, biases start at 0.
D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

# Output layer (128 -> 1): produces a single logit per example.
D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# All trainable discriminator variables, collected so the discriminator's
# optimizer can be restricted to them via var_list.
theta_D = [D_W1, D_W2, D_b1, D_b2]

In [None]:
def discriminator(x):
    """Run the two-layer discriminator on a batch of inputs.

    Uses the module-level variables D_W1/D_b1 (ReLU hidden layer) and
    D_W2/D_b2 (linear output layer), so every call shares the same weights.

    Args:
        x: Tensor that can be matrix-multiplied with D_W1.

    Returns:
        A `(probability, logit)` pair, where `probability` is the sigmoid of
        `logit`.
    """
    hidden_pre = tf.matmul(x, D_W1) + D_b1
    hidden = tf.nn.relu(hidden_pre)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

### Generator

In [23]:
# Generator parameters: a 100 -> 128 -> 784 fully connected network.
# Z is the noise input; the leading dimension is left as None so any batch
# size can be fed. Each noise vector has 100 float32 entries.
Z = tf.placeholder(tf.float32, shape=[None, 100])

# Hidden layer (100 -> 128): weights come from xavier_init, biases start at 0.
G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

# Output layer (128 -> 784): same width as the discriminator's input X.
G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

# All trainable generator variables, collected so the generator's optimizer
# can be restricted to them via var_list.
theta_G = [G_W1, G_W2, G_b1, G_b2]

In [24]:
def sample_Z(m, n):
    """Draw generator noise from NumPy's global random state.

    Args:
        m: Number of noise vectors (rows).
        n: Length of each noise vector (columns).

    Returns:
        An array of shape (m, n) sampled uniformly from [-1, 1).
    """
    noise_shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=noise_shape)

In [25]:
def generator(z):
    """Map a batch of noise vectors to generated samples.

    Uses the module-level variables G_W1/G_b1 (ReLU hidden layer) and
    G_W2/G_b2 (linear output layer followed by a sigmoid).

    Args:
        z: Tensor that can be matrix-multiplied with G_W1.

    Returns:
        The sigmoid of the output layer, so every value lies in (0, 1).
    """
    hidden_pre = tf.matmul(z, G_W1) + G_b1
    hidden = tf.nn.relu(hidden_pre)
    out_logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(out_logits)

In [27]:
def plot(samples):
    """Draw samples as 28x28 greyscale images on a 4x4 grid.

    All drawing goes through the explicit Figure/Axes objects created here
    rather than pyplot's implicit "current axes", so the result does not
    depend on which figure or axes happen to be current.

    Args:
        samples: Iterable of arrays, each reshapeable to (28, 28). The grid
            has 16 cells, so more than 16 samples do not fit.

    Returns:
        The matplotlib Figure containing the grid. The figure is not shown,
        saved or closed here; that is left to the caller.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    # Keep the tiles close together so the grid reads as one image.
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        # Hide axis lines, ticks and labels; only the image should be visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [28]:
# Wire the two networks together. discriminator() is called twice, once on
# the real input X and once on the generator output; both calls use the same
# module-level D_W*/D_b* variables, so the weights are shared.
G_sample = generator(Z)
# Only the logits are used by the loss definitions; the sigmoid
# probabilities D_real / D_fake are kept for inspection.
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

In [29]:
# Discriminator loss: sigmoid cross-entropy on the logits, with label 1 for
# real inputs and label 0 for generated inputs, averaged over the batch.
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
# Generator loss: the same cross-entropy on the generated-input logits, but
# with label 1, i.e. the generator is rewarded when its samples are scored
# as real.
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

In [30]:
# One Adam optimizer per network, each constructed with no arguments.
# var_list restricts every update to that network's own variables, so
# minimizing G_loss (which flows through the discriminator) does not change
# the discriminator's weights, and vice versa.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

In [31]:
# --- Training configuration ---
minibatch_size = 128
Z_dim = 100              # length of each noise vector fed to the Z placeholder

NUM_ITERATIONS = 1000000  # number of alternating discriminator/generator steps
REPORT_EVERY = 1000       # save a sample grid and print losses this often
NUM_PREVIEW_SAMPLES = 16  # one image per cell of the 4x4 grid drawn by plot()
OUT_DIR = 'out/'          # sample grids are written here as 000.png, 001.png, ...

# Labels are loaded (one-hot) but not used: only the images are fed below.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# The session is not closed in this cell, so it stays available afterwards.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# exist_ok=True replaces the exists()-then-makedirs() pair: it is safe to
# re-run and has no window in which the directory can appear in between.
os.makedirs(OUT_DIR, exist_ok=True)

i = 0  # running index used to name the next saved sample grid

for it in range(NUM_ITERATIONS):
    if it % REPORT_EVERY == 0:
        # Snapshot the generator *before* this iteration's updates.
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(NUM_PREVIEW_SAMPLES, Z_dim)})

        fig = plot(samples)
        # Save through the figure itself instead of pyplot's current figure.
        fig.savefig(os.path.join(OUT_DIR, '{}.png'.format(str(i).zfill(3))), bbox_inches='tight')
        i += 1
        # Close explicitly; otherwise every snapshot's figure stays in memory.
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(minibatch_size)

    # One discriminator step, then one generator step, each with fresh noise.
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(minibatch_size, Z_dim)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(minibatch_size, Z_dim)})

    if it % REPORT_EVERY == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
Iter: 0
D loss: 1.793
G_loss: 1.83

Iter: 1000
D loss: 0.005425
G_loss: 9.048

Iter: 2000
D loss: 0.02739
G_loss: 6.6

Iter: 3000
D loss: 0.03566
G_loss: 5.768

Iter: 4000
D loss: 0.1334
G_loss: 5.255

Iter: 5000
D loss: 0.187
G_loss: 4.526

Iter: 6000
D loss: 0.3584
G_loss: 3.865

Iter: 7000
D loss: 0.3289
G_loss: 4.048

Iter: 8000
D loss: 0.4025
G_loss: 2.979

Iter: 9000
D loss: 0.5652
G_loss: 3.069

Iter: 10000
D loss: 0.5653
G_loss: 3.092

Iter: 11000
D loss: 0.6922
G_loss: 2.194

Iter: 12000
D loss: 0.7456
G_loss: 

Iter: 184000
D loss: 0.4002
G_loss: 3.03

Iter: 185000
D loss: 0.5614
G_loss: 2.813

Iter: 186000
D loss: 0.4339
G_loss: 2.695

Iter: 187000
D loss: 0.4366
G_loss: 3.178

Iter: 188000
D loss: 0.4254
G_loss: 2.8

Iter: 189000
D loss: 0.4011
G_loss: 2.805

Iter: 190000
D loss: 0.3461
G_loss: 3.089

Iter: 191000
D loss: 0.4336
G_loss: 2.698

Iter: 192000
D loss: 0.381
G_loss: 2.889

Iter: 193000
D loss: 0.381
G_loss: 2.839

Iter: 194000
D loss: 0.4994
G_loss: 2.565

Iter: 195000
D loss: 0.3866
G_loss: 3.228

Iter: 196000
D loss: 0.447
G_loss: 2.54

Iter: 197000
D loss: 0.4347
G_loss: 2.974

Iter: 198000
D loss: 0.4382
G_loss: 2.952

Iter: 199000
D loss: 0.4991
G_loss: 2.791

Iter: 200000
D loss: 0.4671
G_loss: 2.857

Iter: 201000
D loss: 0.4219
G_loss: 3.127

Iter: 202000
D loss: 0.4649
G_loss: 2.819

Iter: 203000
D loss: 0.4194
G_loss: 2.831

Iter: 204000
D loss: 0.3203
G_loss: 3.034

Iter: 205000
D loss: 0.4903
G_loss: 2.609

Iter: 206000
D loss: 0.4918
G_loss: 3.091

Iter: 207000
D los

Iter: 376000
D loss: 0.2926
G_loss: 3.402

Iter: 377000
D loss: 0.3579
G_loss: 3.027

Iter: 378000
D loss: 0.3307
G_loss: 3.286

Iter: 379000
D loss: 0.3569
G_loss: 3.323

Iter: 380000
D loss: 0.3769
G_loss: 3.176

Iter: 381000
D loss: 0.2395
G_loss: 3.081

Iter: 382000
D loss: 0.3862
G_loss: 3.261

Iter: 383000
D loss: 0.3951
G_loss: 3.144

Iter: 384000
D loss: 0.176
G_loss: 3.684

Iter: 385000
D loss: 0.3269
G_loss: 3.029

Iter: 386000
D loss: 0.291
G_loss: 3.464

Iter: 387000
D loss: 0.1834
G_loss: 3.629

Iter: 388000
D loss: 0.1598
G_loss: 3.496

Iter: 389000
D loss: 0.2546
G_loss: 3.611

Iter: 390000
D loss: 0.3678
G_loss: 2.992

Iter: 391000
D loss: 0.1874
G_loss: 3.452

Iter: 392000
D loss: 0.2704
G_loss: 2.993

Iter: 393000
D loss: 0.2211
G_loss: 3.341

Iter: 394000
D loss: 0.3005
G_loss: 3.172

Iter: 395000
D loss: 0.3762
G_loss: 3.89

Iter: 396000
D loss: 0.3607
G_loss: 2.76

Iter: 397000
D loss: 0.3465
G_loss: 3.364

Iter: 398000
D loss: 0.2993
G_loss: 3.222

Iter: 399000
D 

Iter: 568000
D loss: 0.2683
G_loss: 2.95

Iter: 569000
D loss: 0.2294
G_loss: 3.171

Iter: 570000
D loss: 0.2483
G_loss: 3.139

Iter: 571000
D loss: 0.3145
G_loss: 2.793

Iter: 572000
D loss: 0.2064
G_loss: 3.281

Iter: 573000
D loss: 0.2039
G_loss: 3.469

Iter: 574000
D loss: 0.1974
G_loss: 3.346

Iter: 575000
D loss: 0.2608
G_loss: 3.043

Iter: 576000
D loss: 0.2977
G_loss: 3.104

Iter: 577000
D loss: 0.1711
G_loss: 3.302

Iter: 578000
D loss: 0.225
G_loss: 3.048

Iter: 579000
D loss: 0.2431
G_loss: 2.886

Iter: 580000
D loss: 0.297
G_loss: 3.187

Iter: 581000
D loss: 0.2901
G_loss: 3.245

Iter: 582000
D loss: 0.2462
G_loss: 2.979

Iter: 583000
D loss: 0.2423
G_loss: 3.284

Iter: 584000
D loss: 0.1745
G_loss: 3.627

Iter: 585000
D loss: 0.387
G_loss: 2.883

Iter: 586000
D loss: 0.2659
G_loss: 3.171

Iter: 587000
D loss: 0.1617
G_loss: 3.538

Iter: 588000
D loss: 0.1601
G_loss: 3.284

Iter: 589000
D loss: 0.1648
G_loss: 3.14

Iter: 590000
D loss: 0.2406
G_loss: 3.588

Iter: 591000
D l

Iter: 760000
D loss: 0.2729
G_loss: 3.55

Iter: 761000
D loss: 0.1812
G_loss: 3.162

Iter: 762000
D loss: 0.1905
G_loss: 3.677

Iter: 763000
D loss: 0.1788
G_loss: 3.565

Iter: 764000
D loss: 0.2277
G_loss: 3.319

Iter: 765000
D loss: 0.2115
G_loss: 3.591

Iter: 766000
D loss: 0.1696
G_loss: 3.892

Iter: 767000
D loss: 0.2761
G_loss: 3.44

Iter: 768000
D loss: 0.1272
G_loss: 3.624

Iter: 769000
D loss: 0.1941
G_loss: 2.987

Iter: 770000
D loss: 0.151
G_loss: 3.88

Iter: 771000
D loss: 0.2883
G_loss: 3.846

Iter: 772000
D loss: 0.1235
G_loss: 3.671

Iter: 773000
D loss: 0.1935
G_loss: 3.698

Iter: 774000
D loss: 0.3969
G_loss: 3.806

Iter: 775000
D loss: 0.1576
G_loss: 4.208

Iter: 776000
D loss: 0.2489
G_loss: 3.734

Iter: 777000
D loss: 0.3095
G_loss: 3.073

Iter: 778000
D loss: 0.3787
G_loss: 3.777

Iter: 779000
D loss: 0.274
G_loss: 3.777

Iter: 780000
D loss: 0.2544
G_loss: 3.341

Iter: 781000
D loss: 0.1771
G_loss: 3.335

Iter: 782000
D loss: 0.213
G_loss: 3.209

Iter: 783000
D lo

Iter: 952000
D loss: 0.1853
G_loss: 4.052

Iter: 953000
D loss: 0.1568
G_loss: 4.267

Iter: 954000
D loss: 0.2473
G_loss: 3.848

Iter: 955000
D loss: 0.1935
G_loss: 3.398

Iter: 956000
D loss: 0.1153
G_loss: 4.242

Iter: 957000
D loss: 0.2894
G_loss: 3.541

Iter: 958000
D loss: 0.2526
G_loss: 3.722

Iter: 959000
D loss: 0.1921
G_loss: 3.725

Iter: 960000
D loss: 0.1468
G_loss: 4.32

Iter: 961000
D loss: 0.2058
G_loss: 3.469

Iter: 962000
D loss: 0.2457
G_loss: 3.207

Iter: 963000
D loss: 0.1368
G_loss: 3.243

Iter: 964000
D loss: 0.1626
G_loss: 3.524

Iter: 965000
D loss: 0.2024
G_loss: 3.664

Iter: 966000
D loss: 0.2339
G_loss: 3.703

Iter: 967000
D loss: 0.09176
G_loss: 3.814

Iter: 968000
D loss: 0.1933
G_loss: 3.449

Iter: 969000
D loss: 0.1257
G_loss: 3.649

Iter: 970000
D loss: 0.1197
G_loss: 3.889

Iter: 971000
D loss: 0.09248
G_loss: 4.303

Iter: 972000
D loss: 0.1157
G_loss: 4.018

Iter: 973000
D loss: 0.2459
G_loss: 3.745

Iter: 974000
D loss: 0.1363
G_loss: 3.794

Iter: 9750

## After checking the images in the `out` directory, we see that the GAN [mode collapsed](http://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)
One fix for this is to let the discriminator see the ground truth in mini-batches.

**Here's a YouTube video of all 1000 images at a 100 ms delay**

In [1]:
# NOTE(review): everything in this cell is commented out, so running it
# produces no output. Uncomment the two statements below to embed the video
# inline, or remove the cell if the embed is no longer wanted.
# from IPython.display import HTML

# # Youtube
# HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/S_f2qV2_U00?rel=0&amp;controls=0&amp;showinfo=0" frameborder="0" allowfullscreen></iframe>')