In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plot

In [2]:
train_data = pd.read_csv('train.csv')
train_data = train_data.drop('label', axis=1)
train_data = train_data.as_matrix() / 255

In [3]:
IMAGE_SIZE = 28
BATCH_SIZE = 32
NUM_ITERATIONS = 10000
HIDDEN_LAYERS_GEN = 128
LEARNING_RATE = 1e-4

In [12]:
# Loading Batches
epochs_completed = 0
index_in_epoch = 0
num_examples = train_data.shape[0]


def next_batch(batch_size):
    global train_data
    global index_in_epoch
    global epochs_completed
    
    start = index_in_epoch
    index_in_epoch += batch_size
    
    if index_in_epoch > num_examples:
        # finished epoch
        epochs_completed += 1
        # shuffle the data
        perm = np.arange(num_examples)
        np.random.shuffle(perm)
        train_data = train_data[perm]
        # start next epoch
        start = 0
        index_in_epoch = batch_size
        assert batch_size <= num_examples
    end = index_in_epoch
    return train_data[start:end]

In [5]:
def get_sample_z(size=(1, 100)):
    return np.random.normal(size=size)


def display_image(image_data):
    img = image_data.reshape([IMAGE_SIZE, IMAGE_SIZE])
    plot.axis('off')
    plot.imshow(img, cmap=matplotlib.cm.binary)
    plot.show()

In [6]:
Z_in = tf.placeholder(tf.float32, shape=[None, 100])
image_in = tf.placeholder(tf.float32, shape=[None,
                                             IMAGE_SIZE * IMAGE_SIZE])


def generator(z):
    with tf.variable_scope('generator'):
        h = tf.layers.dense(z, HIDDEN_LAYERS_GEN)
        h = tf.minimum(h, 0.01)
        logits = tf.layers.dense(h, IMAGE_SIZE * IMAGE_SIZE)
        output = tf.nn.sigmoid(logits)
        return output


def discriminator(image, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        h = tf.layers.dense(image, HIDDEN_LAYERS_GEN)
        h = tf.minimum(h, 0.01)
        logits = tf.layers.dense(h, 1)
        return logits

In [7]:
gen_sample = generator(Z_in)

discriminator_data = discriminator(image_in)
discriminator_model = discriminator(gen_sample, reuse=True)

In [8]:
discriminator_loss = tf.reduce_mean(discriminator_data) - \
                     tf.reduce_sum(discriminator_model)

generator_loss = -tf.reduce_mean(discriminator_model)

In [9]:
all_vars = tf.trainable_variables()
generator_vars = [var for var in all_vars if var.name.startswith('generator')]
discriminator_vars = [var for var in all_vars if var.name.startswith('discriminator')]

discriminator_optimize = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)\
    .minimize(discriminator_loss, var_list=discriminator_vars)
generator_optimize = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)\
    .minimize(generator_loss, var_list=generator_vars)

clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discriminator_vars]

In [10]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [13]:
for i in range(NUM_ITERATIONS):
    for a in range(5):
        image_batch = next_batch(BATCH_SIZE)
        _, disc_loss, _ = sess.run([discriminator_optimize, discriminator_loss, clip_D], 
                                   feed_dict={Z_in: get_sample_z([BATCH_SIZE, 100]),
                                              image_in: image_batch})
        
    _, gen_loss = sess.run([generator_optimize, generator_loss], 
                           feed_dict={Z_in: get_sample_z([BATCH_SIZE, 100])})
    
    if i % 100:
        print('Step {} => Discriminator: {} | Generator: {}'.format(i, disc_loss, gen_loss))
        
        if i % 1000 == 0:
            sample = sess.run(gen_sample, feed_dict={Z_in: get_sample_z()})
            display_image(sample)

Step 1 => Discriminator: -1.1637146472930908 | Generator: -0.043234728276729584
Step 2 => Discriminator: -1.7513519525527954 | Generator: -0.06284071505069733
Step 3 => Discriminator: -2.403160333633423 | Generator: -0.08430031687021255
Step 4 => Discriminator: -3.1358673572540283 | Generator: -0.10936858505010605
Step 5 => Discriminator: -3.994913339614868 | Generator: -0.13778886198997498
Step 6 => Discriminator: -4.952796936035156 | Generator: -0.16920188069343567
Step 7 => Discriminator: -6.009079456329346 | Generator: -0.20426663756370544
Step 8 => Discriminator: -7.214214324951172 | Generator: -0.24364250898361206
Step 9 => Discriminator: -8.484606742858887 | Generator: -0.28547078371047974
Step 10 => Discriminator: -9.832845687866211 | Generator: -0.3291471004486084
Step 11 => Discriminator: -11.187735557556152 | Generator: -0.37284645438194275
Step 12 => Discriminator: -12.608689308166504 | Generator: -0.417092502117157
Step 13 => Discriminator: -14.00932788848877 | Generator: 

Step 110 => Discriminator: -54.369571685791016 | Generator: -1.7201052904129028
Step 111 => Discriminator: -54.349098205566406 | Generator: -1.7220323085784912
Step 112 => Discriminator: -54.321800231933594 | Generator: -1.7229167222976685
Step 113 => Discriminator: -54.30638122558594 | Generator: -1.721365213394165
Step 114 => Discriminator: -54.441341400146484 | Generator: -1.7238311767578125
Step 115 => Discriminator: -54.39077377319336 | Generator: -1.7219679355621338
Step 116 => Discriminator: -54.41047286987305 | Generator: -1.7272963523864746
Step 117 => Discriminator: -54.4925651550293 | Generator: -1.7236649990081787
Step 118 => Discriminator: -54.375118255615234 | Generator: -1.7260887622833252
Step 119 => Discriminator: -54.514190673828125 | Generator: -1.7259008884429932
Step 120 => Discriminator: -54.5282096862793 | Generator: -1.7286385297775269
Step 121 => Discriminator: -54.573062896728516 | Generator: -1.727268099784851
Step 122 => Discriminator: -54.6378059387207 | Ge

Step 217 => Discriminator: -64.29344177246094 | Generator: -2.0235538482666016
Step 218 => Discriminator: -63.75193405151367 | Generator: -2.0450305938720703
Step 219 => Discriminator: -64.95553588867188 | Generator: -2.0232555866241455
Step 220 => Discriminator: -64.27729797363281 | Generator: -2.037165880203247
Step 221 => Discriminator: -64.60663604736328 | Generator: -2.040452480316162
Step 222 => Discriminator: -64.66670989990234 | Generator: -2.044635772705078
Step 223 => Discriminator: -64.3510971069336 | Generator: -2.053497552871704
Step 224 => Discriminator: -64.67524719238281 | Generator: -2.0382368564605713
Step 225 => Discriminator: -64.88170623779297 | Generator: -2.058718681335449
Step 226 => Discriminator: -65.3595962524414 | Generator: -2.0681447982788086
Step 227 => Discriminator: -65.50130462646484 | Generator: -2.0619513988494873
Step 228 => Discriminator: -64.76737976074219 | Generator: -2.070699453353882
Step 229 => Discriminator: -65.2901611328125 | Generator: -2

Step 323 => Discriminator: -76.36387634277344 | Generator: -2.434129238128662
Step 324 => Discriminator: -76.60416412353516 | Generator: -2.439876079559326
Step 325 => Discriminator: -75.96553802490234 | Generator: -2.4032092094421387
Step 326 => Discriminator: -76.27760314941406 | Generator: -2.4049556255340576
Step 327 => Discriminator: -76.7311782836914 | Generator: -2.455289363861084
Step 328 => Discriminator: -77.18022155761719 | Generator: -2.4233834743499756
Step 329 => Discriminator: -77.58645629882812 | Generator: -2.448122024536133
Step 330 => Discriminator: -76.67543029785156 | Generator: -2.443269729614258
Step 331 => Discriminator: -76.94512939453125 | Generator: -2.4391274452209473
Step 332 => Discriminator: -77.25393676757812 | Generator: -2.4472451210021973
Step 333 => Discriminator: -77.29739379882812 | Generator: -2.4614813327789307
Step 334 => Discriminator: -78.67655181884766 | Generator: -2.4415762424468994
Step 335 => Discriminator: -78.03227996826172 | Generator:

Step 430 => Discriminator: -86.43709564208984 | Generator: -2.743553876876831
Step 431 => Discriminator: -87.10638427734375 | Generator: -2.7528085708618164
Step 432 => Discriminator: -87.05824279785156 | Generator: -2.741741895675659
Step 433 => Discriminator: -86.7693099975586 | Generator: -2.7530012130737305
Step 434 => Discriminator: -86.73371887207031 | Generator: -2.7460246086120605
Step 435 => Discriminator: -87.42711639404297 | Generator: -2.740833044052124
Step 436 => Discriminator: -87.57511138916016 | Generator: -2.724519968032837
Step 437 => Discriminator: -87.27001953125 | Generator: -2.8008694648742676
Step 438 => Discriminator: -88.63658142089844 | Generator: -2.8002562522888184
Step 439 => Discriminator: -88.30267333984375 | Generator: -2.7762627601623535
Step 440 => Discriminator: -88.10730743408203 | Generator: -2.7682039737701416
Step 441 => Discriminator: -87.68501281738281 | Generator: -2.7751619815826416
Step 442 => Discriminator: -87.35478210449219 | Generator: -

KeyboardInterrupt: 