In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

In [2]:
def xavier_init(size):
    """Create a random-normal initial value scaled by the input dimension.

    Args:
        size: Shape of the tensor to create; ``size[0]`` is used as the
            input dimension.

    Returns:
        The result of ``tf.random_normal`` for ``size`` with standard
        deviation ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    stddev = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=stddev)

In [3]:
# Discriminator: input placeholder and the parameters of a 784 -> 128 -> 1
# fully connected network.
X = tf.placeholder(dtype=tf.float32, shape=(None, 784))

# Hidden layer parameters (784 inputs -> 128 units).
D_W1 = tf.Variable(initial_value=xavier_init([784, 128]))
D_b1 = tf.Variable(initial_value=tf.zeros([128]))

# Output layer parameters (128 units -> 1 output).
D_W2 = tf.Variable(initial_value=xavier_init([128, 1]))
D_b2 = tf.Variable(initial_value=tf.zeros([1]))

# All discriminator parameters, collected in one list.
theta_D = [D_W1, D_W2, D_b1, D_b2]

In [4]:
# Generator: noise placeholder and the parameters of a 100 -> 128 -> 784
# fully connected network.
Z = tf.placeholder(dtype=tf.float32, shape=(None, 100))

# Hidden layer parameters (100 noise dims -> 128 units).
G_W1 = tf.Variable(initial_value=xavier_init([100, 128]))
G_b1 = tf.Variable(initial_value=tf.zeros([128]))

# Output layer parameters (128 units -> 784 outputs).
G_W2 = tf.Variable(initial_value=xavier_init([128, 784]))
G_b2 = tf.Variable(initial_value=tf.zeros([784]))

# All generator parameters, collected in one list.
theta_G = [G_W1, G_W2, G_b1, G_b2]

In [5]:
def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn from ``np.random.uniform(-1, 1)``."""
    noise_shape = [m, n]
    return np.random.uniform(low=-1., high=1., size=noise_shape)


def generator(z):
    """Run noise ``z`` through the two-layer generator network.

    Uses the module-level parameters ``G_W1``, ``G_b1``, ``G_W2`` and
    ``G_b2``: a ReLU hidden layer followed by a sigmoid output layer.

    Args:
        z: Noise input that can be matrix-multiplied with ``G_W1``.

    Returns:
        The sigmoid of the output layer's pre-activation.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    pre_activation = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(pre_activation)


def discriminator(x):
    """Run ``x`` through the two-layer discriminator network.

    Uses the module-level parameters ``D_W1``, ``D_b1``, ``D_W2`` and
    ``D_b2``: a ReLU hidden layer followed by a single linear output.

    Args:
        x: Input that can be matrix-multiplied with ``D_W1``.

    Returns:
        A ``(probability, logit)`` pair, where ``probability`` is the
        sigmoid of ``logit``.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit


def plot(samples):
    """Draw ``samples`` as 28x28 grayscale tiles on a 4x4 grid.

    Args:
        samples: Iterable of arrays, each reshapeable to ``(28, 28)``.
            The grid has 16 cells, one per sample in iteration order.

    Returns:
        The matplotlib figure holding the tiles; the caller is
        responsible for saving and closing it.
    """
    figure = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell_index, image in enumerate(samples):
        # plt.subplot makes the new axes current and returns it, so the
        # per-tile styling can go through the axes handle directly.
        tile = plt.subplot(grid[cell_index])
        tile.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        tile.imshow(image.reshape(28, 28), cmap='Greys_r')

    return figure

In [None]:
# Wire the graph: generated samples, plus the discriminator applied to real
# inputs (X) and to generated inputs. Both discriminator calls use the same
# module-level D_* variables.
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# Probability-space formulation of the losses, kept for reference:
# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

# Alternative losses (computed from the logits), used for training:
# -------------------
# Discriminator: real inputs are labelled 1, generated inputs are labelled 0.
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))

D_loss = D_loss_real + D_loss_fake

# Generator: generated inputs are scored against the label 1.
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# Each optimizer only updates its own network's variables.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

# Training configuration.
mb_size = 128          # mini-batch size for both networks
Z_dim = 100            # noise width; must match the Z placeholder's second dim
n_iters = 1000000      # total number of training iterations
report_every = 1000    # save a sample grid and print losses every N iterations
out_dir = 'out/'       # directory receiving the sample grids

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# NOTE: this code used to sit inside `with tf.device("/gpu:0"):`. A tf.device
# scope only applies to ops *created* inside it, and the model graph is built
# above, so the scope only covered the initializer op and did not pin the
# model to the GPU. It is dropped here; TensorFlow's default placement
# applies, exactly as it already did for the model ops.
#
# The session is intentionally left open: the trained variable values live in
# it, so closing it here would discard them for any later cell. Call
# sess.close() once the session is no longer needed.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(out_dir, exist_ok=True)

# Running index of the saved sample grids (000.png, 001.png, ...).
i = 0

for it in range(n_iters):
    if it % report_every == 0:
        # Snapshot the generator *before* this iteration's update.
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        # Save through the figure handle instead of the implicit current figure.
        fig.savefig(out_dir + '{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)  # avoid accumulating open figures over the run

    # Labels are not used for GAN training.
    X_mb, _ = mnist.train.next_batch(mb_size)

    # One discriminator step, then one generator step, each with fresh noise.
    _, D_loss_curr = sess.run(
        [D_solver, D_loss],
        feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run(
        [G_solver, G_loss],
        feed_dict={Z: sample_Z(mb_size, Z_dim)})

    if it % report_every == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz
Iter: 0
D loss: 1.295
G_loss: 2.643

Iter: 1000
D loss: 0.004728
G_loss: 8.202

Iter: 2000
D loss: 0.0668
G_loss: 5.993

Iter: 3000
D loss: 0.04547
G_loss: 5.443

Iter: 4000
D loss: 0.1222
G_loss: 5.964

Iter: 5000
D loss: 0.1843
G_loss: 5.182

Iter: 6000
D loss: 0.2679
G_loss: 4.738

Iter: 7000
D loss: 0.3952
G_loss: 3.696

Iter: 8000
D loss: 0.3171
G_loss: 4.581

Iter: 9000
D loss: 0.5797
G_loss: 3.397

Iter: 10000
D loss: 0.3963
G_loss: 3.444

Iter: 11000
D loss: 0.4453
G_loss: 2.925

Iter: 12000
D loss: 0.4056
G_loss: 3.016

Iter: 13000
D loss: 0.4776
G_loss: 2.591

Iter: 14000
D loss: 0.5528
G_loss: 2.14

Iter: 15000
D loss: 0.3888
G_loss: 3.111

Iter: 16000
D loss: 0.5058
G_loss: 3.114

Iter: 17000
D loss: 0.6089
G_loss: 2.869

Iter: 18000
D loss: 0.5533
G_loss: 2.4

Iter: 189000
D loss: 0.4042
G_loss: 3.015

Iter: 190000
D loss: 0.4036
G_loss: 2.957

Iter: 191000
D loss: 0.6271
G_loss: 3.505

Iter: 192000
D loss: 0.4837
G_loss: 3.097

Iter: 193000
D loss: 0.3326
G_loss: 3.019

Iter: 194000
D loss: 0.3584
G_loss: 3.465

Iter: 195000
D loss: 0.4611
G_loss: 2.684

Iter: 196000
D loss: 0.4098
G_loss: 3.498

Iter: 197000
D loss: 0.4252
G_loss: 2.975

Iter: 198000
D loss: 0.4292
G_loss: 3.014

Iter: 199000
D loss: 0.5691
G_loss: 3.057

Iter: 200000
D loss: 0.3986
G_loss: 3.321

Iter: 201000
D loss: 0.4018
G_loss: 3.143

Iter: 202000
D loss: 0.4193
G_loss: 2.819

Iter: 203000
D loss: 0.3113
G_loss: 2.986

Iter: 204000
D loss: 0.4187
G_loss: 2.723

Iter: 205000
D loss: 0.3697
G_loss: 3.043

Iter: 206000
D loss: 0.336
G_loss: 3.164

Iter: 207000
D loss: 0.5095
G_loss: 3.257

Iter: 208000
D loss: 0.3814
G_loss: 2.786

Iter: 209000
D loss: 0.3432
G_loss: 2.883

Iter: 210000
D loss: 0.4857
G_loss: 3.061

Iter: 211000
D loss: 0.3856
G_loss: 3.025

Iter: 212000

Iter: 381000
D loss: 0.3552
G_loss: 3.284

Iter: 382000
D loss: 0.3879
G_loss: 3.209

Iter: 383000
D loss: 0.3126
G_loss: 3.087

Iter: 384000
D loss: 0.283
G_loss: 3.344

Iter: 385000
D loss: 0.3081
G_loss: 3.152

Iter: 386000
D loss: 0.2783
G_loss: 2.889

Iter: 387000
D loss: 0.3942
G_loss: 3.049

Iter: 388000
D loss: 0.4197
G_loss: 3.429

Iter: 389000
D loss: 0.289
G_loss: 3.089

Iter: 390000
D loss: 0.3312
G_loss: 3.18

Iter: 391000
D loss: 0.2963
G_loss: 2.752

Iter: 392000
D loss: 0.2797
G_loss: 3.185

Iter: 393000
D loss: 0.263
G_loss: 3.855

Iter: 394000
D loss: 0.3689
G_loss: 3.014

Iter: 395000
D loss: 0.2723
G_loss: 2.821

