In [1]:
# package import
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


# Training configuration.
mb_size = 128   # number of real images (and noise vectors) per training step
z_dim = 100     # width of the noise vector fed to the generator

# Load MNIST from the local data directory. Only the images are consumed by
# the training loop; the labels returned by next_batch are discarded there.
mnist = input_data.read_data_sets(train_dir='../../MNIST_data', one_hot=True)



  from ._conv import register_converters as _register_converters


Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
# Network construction

# Discriminator net: input placeholder plus the parameters of a
# 784 -> 128 -> 1 fully connected network.
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')

# First layer: Xavier-initialised weights, zero-initialised bias.
D_W1 = tf.get_variable(
    name='D_W1',
    shape=[784, 128],
    initializer=tf.contrib.layers.xavier_initializer(),
)
D_b1 = tf.get_variable(
    name='D_b1',
    shape=[128],
    initializer=tf.constant_initializer(0),
)

# Output layer: a single logit per input row.
D_W2 = tf.get_variable(
    name='D_W2',
    shape=[128, 1],
    initializer=tf.contrib.layers.xavier_initializer(),
)
D_b2 = tf.get_variable(
    name='D_b2',
    shape=[1],
    initializer=tf.constant_initializer(0),
)

# Variables handed to the discriminator's optimizer via var_list.
theta_D = [D_W1, D_W2, D_b1, D_b2]


In [3]:
# Generator net: noise placeholder plus the parameters of a
# z_dim -> 128 -> 784 fully connected network.
#
# The latent width is taken from z_dim rather than a literal 100 so that the
# placeholder, the first weight matrix and the get_sample_z(..., z_dim) calls
# that feed this placeholder all stay in agreement if z_dim is changed.
z = tf.placeholder(tf.float32, shape=[None, z_dim], name='z')

# First layer: Xavier-initialised weights, zero-initialised bias.
G_W1 = tf.get_variable(name='G_W1', shape=[z_dim, 128], initializer=tf.contrib.layers.xavier_initializer())
G_b1 = tf.get_variable(name='G_b1', shape=[128], initializer=tf.constant_initializer(0))

# Output layer: 784 values per row, matching the width of the X placeholder.
G_W2 = tf.get_variable(name='G_W2', shape=[128, 784], initializer=tf.contrib.layers.xavier_initializer())
G_b2 = tf.get_variable(name='G_b2', shape=[784], initializer=tf.constant_initializer(0))

# Variables handed to the generator's optimizer via var_list.
theta_G = [G_W1, G_W2, G_b1, G_b2]

In [4]:
# Network Definition
def generator(z):
    """Map a batch of noise vectors to generated samples.

    Args:
        z: tensor that is matrix-multiplied with G_W1, so its last
            dimension must match G_W1's first dimension.

    Returns:
        Tensor of sigmoid activations (values in (0, 1)) with one row per
        input row and as many columns as G_W2 has (declared as 784).
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    pre_sigmoid = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(pre_sigmoid)

def discriminator(x):
    """Score a batch of inputs with the two-layer discriminator.

    Args:
        x: tensor that is matrix-multiplied with D_W1, so its last
            dimension must match D_W1's first dimension.

    Returns:
        Tuple ``(prob, logit)``: the sigmoid of the output-layer value and
        the raw output-layer value itself, one per input row.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

# Wire the graph: generated samples from the noise placeholder, then the
# discriminator applied once to real inputs and once to generated ones.
# Both discriminator calls use the same D_* variables.
G_sample = generator(z=z)
D_real, D_logit_real = discriminator(x=X)
D_fake, D_logit_fake = discriminator(x=G_sample)

In [5]:
# Loss Function
# Standard GAN objectives, computed from the discriminator *logits* instead of
# log(clip(sigmoid(logit))). Clipping the probability makes the gradient
# exactly zero wherever the clip is active, so a confident discriminator
# (D_fake below the clip threshold) would leave the generator with no
# learning signal. sigmoid_cross_entropy_with_logits evaluates the same
# -log(sigmoid(x)) / -log(1 - sigmoid(x)) terms in a numerically stable form.
#
#   D_loss = mean(-log D(x)) + mean(-log(1 - D(G(z))))
#   G_loss = mean(-log D(G(z)))          (non-saturating generator loss)
#
# Summing the two means equals the mean of the summed terms when the real and
# generated batches have the same size, which is how the training loop feeds
# them.
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

In [6]:
# One Adam optimizer (default hyperparameters) per network. var_list restricts
# each training op to its own network's variables, so a discriminator step
# does not move the generator's parameters and vice versa.
D_optimizer = tf.train.AdamOptimizer().minimize(
    loss=D_loss,
    var_list=theta_D,
)
G_optimizer = tf.train.AdamOptimizer().minimize(
    loss=G_loss,
    var_list=theta_G,
)

In [7]:
def get_sample_z(r, c):
    """Draw an ``r`` x ``c`` array of noise, uniform on [-1.0, 1.0).

    Args:
        r: number of rows (noise vectors).
        c: number of columns (entries per noise vector).

    Returns:
        numpy array of shape (r, c) drawn from numpy's global random state.
    """
    return np.random.uniform(low=-1.0, high=1.0, size=(r, c))

def plot(samples):
    """Render samples as 28x28 grayscale tiles on a 4x4 grid.

    Args:
        samples: iterable of arrays, each reshapeable to (28, 28). The grid
            has 16 cells, one per sample in iteration order.

    Returns:
        The matplotlib figure holding the tiles. The caller is responsible
        for saving and closing it.
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, image in enumerate(samples):
        ax = plt.subplot(grid[idx])
        # Hide the axes frame and ticks so only the image is visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_aspect('equal')
        ax.imshow(image.reshape(28, 28), cmap='Greys_r')
    return fig

In [8]:
# Create the directory that receives the generated sample images.
# exist_ok=True replaces the separate exists() check: it removes the
# check-then-create race and keeps the cell safe to re-run.
os.makedirs('out/', exist_ok=True)

In [9]:
# Open a session and run the initializer for every variable created above.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Training loop: each iteration performs one discriminator update followed by
# one generator update. Every 1000th iteration a 16-sample image grid is
# saved under out/ (before the update) and the latest losses are printed
# (after the update).
i = 0  # index of the next image file written to out/
for it in range(1000000):
    report_due = (it % 1000 == 0)

    if report_due:
        # Snapshot the generator's current output on 16 fresh noise vectors.
        samples = sess.run(G_sample, feed_dict={z: get_sample_z(16, z_dim)})
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        # Close the figure so figures do not accumulate over the run.
        plt.close(fig)

    # Labels are returned alongside the images but are not used.
    X_mb, _ = mnist.train.next_batch(mb_size)

    # Discriminator step: real batch plus an equally sized noise batch.
    d_feed = {X: X_mb, z: get_sample_z(mb_size, z_dim)}
    _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict=d_feed)

    # Generator step: a fresh noise batch; no real images are needed.
    g_feed = {z: get_sample_z(mb_size, z_dim)}
    _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict=g_feed)

    if report_due:
        print("Iter: {}".format(it))
        print("D_loss: {:.4}".format(D_loss_curr))
        print("G_loss: {:.4}".format(G_loss_curr))
        print()

Iter: 0
D_loss: 1.165
G_loss: 2.939

Iter: 1000
D_loss: 0.008623
G_loss: 8.74

Iter: 2000
D_loss: 0.03489
G_loss: 5.302

Iter: 3000
D_loss: 0.05976
G_loss: 4.894

Iter: 4000
D_loss: 0.1774
G_loss: 5.069

Iter: 5000
D_loss: 0.2549
G_loss: 4.406

Iter: 6000
D_loss: 0.1626
G_loss: 4.887

Iter: 7000
D_loss: 0.3028
G_loss: 4.877

Iter: 8000
D_loss: 0.3253
G_loss: 4.273

Iter: 9000
D_loss: 0.4509
G_loss: 3.637

Iter: 10000
D_loss: 0.5162
G_loss: 3.019

Iter: 11000
D_loss: 0.5617
G_loss: 3.166

Iter: 12000
D_loss: 0.6087
G_loss: 3.186

Iter: 13000
D_loss: 0.5523
G_loss: 2.823

Iter: 14000
D_loss: 0.6055
G_loss: 3.007

Iter: 15000
D_loss: 0.556
G_loss: 2.516

Iter: 16000
D_loss: 0.6092
G_loss: 2.595

Iter: 17000
D_loss: 0.6379
G_loss: 2.696

Iter: 18000
D_loss: 0.7844
G_loss: 2.386

Iter: 19000
D_loss: 0.5443
G_loss: 2.266

Iter: 20000
D_loss: 0.5758
G_loss: 2.391

Iter: 21000
D_loss: 0.7801
G_loss: 2.643

Iter: 22000
D_loss: 0.6461
G_loss: 2.2

Iter: 23000
D_loss: 0.6172
G_loss: 2.282

Iter: 

Iter: 195000
D_loss: 0.4098
G_loss: 2.783

Iter: 196000
D_loss: 0.4065
G_loss: 3.094

Iter: 197000
D_loss: 0.404
G_loss: 2.566

Iter: 198000
D_loss: 0.433
G_loss: 3.089

Iter: 199000
D_loss: 0.5087
G_loss: 3.133

Iter: 200000
D_loss: 0.3514
G_loss: 3.177

Iter: 201000
D_loss: 0.3667
G_loss: 3.428

Iter: 202000
D_loss: 0.3819
G_loss: 3.052

Iter: 203000
D_loss: 0.403
G_loss: 3.103

Iter: 204000
D_loss: 0.4435
G_loss: 3.254

Iter: 205000
D_loss: 0.5775
G_loss: 3.113

Iter: 206000
D_loss: 0.4769
G_loss: 3.012

Iter: 207000
D_loss: 0.3725
G_loss: 2.922

Iter: 208000
D_loss: 0.3502
G_loss: 2.941

Iter: 209000
D_loss: 0.3576
G_loss: 2.948

Iter: 210000
D_loss: 0.3971
G_loss: 3.119

Iter: 211000
D_loss: 0.378
G_loss: 3.422

Iter: 212000
D_loss: 0.3348
G_loss: 2.9

Iter: 213000
D_loss: 0.4788
G_loss: 3.057

Iter: 214000
D_loss: 0.4574
G_loss: 2.817

Iter: 215000
D_loss: 0.39
G_loss: 2.772

Iter: 216000
D_loss: 0.4507
G_loss: 3.536

Iter: 217000
D_loss: 0.4406
G_loss: 2.791

Iter: 218000
D_loss

Iter: 387000
D_loss: 0.368
G_loss: 3.869

Iter: 388000
D_loss: 0.372
G_loss: 3.77

Iter: 389000
D_loss: 0.3362
G_loss: 3.561

Iter: 390000
D_loss: 0.2426
G_loss: 3.38

Iter: 391000
D_loss: 0.3103
G_loss: 3.207

Iter: 392000
D_loss: 0.4216
G_loss: 3.421

Iter: 393000
D_loss: 0.2447
G_loss: 3.757

Iter: 394000
D_loss: 0.4192
G_loss: 3.722

Iter: 395000
D_loss: 0.371
G_loss: 3.691

Iter: 396000
D_loss: 0.3383
G_loss: 3.601

Iter: 397000
D_loss: 0.343
G_loss: 3.359

Iter: 398000
D_loss: 0.2761
G_loss: 4.1

Iter: 399000
D_loss: 0.3676
G_loss: 3.395

Iter: 400000
D_loss: 0.4102
G_loss: 3.515

Iter: 401000
D_loss: 0.4342
G_loss: 3.402

Iter: 402000
D_loss: 0.2464
G_loss: 3.825

Iter: 403000
D_loss: 0.29
G_loss: 4.015

Iter: 404000
D_loss: 0.3816
G_loss: 3.465

Iter: 405000
D_loss: 0.3772
G_loss: 3.052

Iter: 406000
D_loss: 0.2597
G_loss: 3.906

Iter: 407000
D_loss: 0.4328
G_loss: 3.761

Iter: 408000
D_loss: 0.2603
G_loss: 3.65

Iter: 409000
D_loss: 0.2824
G_loss: 3.889

Iter: 410000
D_loss: 0

Iter: 579000
D_loss: 0.5216
G_loss: 3.646

Iter: 580000
D_loss: 0.3136
G_loss: 3.673

Iter: 581000
D_loss: 0.3099
G_loss: 3.693

Iter: 582000
D_loss: 0.4237
G_loss: 3.629

Iter: 583000
D_loss: 0.2876
G_loss: 3.211

Iter: 584000
D_loss: 0.2204
G_loss: 3.331

Iter: 585000
D_loss: 0.3386
G_loss: 3.514

Iter: 586000
D_loss: 0.2496
G_loss: 3.48

Iter: 587000
D_loss: 0.2773
G_loss: 3.71

Iter: 588000
D_loss: 0.2872
G_loss: 3.561

Iter: 589000
D_loss: 0.342
G_loss: 3.717

Iter: 590000
D_loss: 0.2713
G_loss: 3.057

Iter: 591000
D_loss: 0.4296
G_loss: 3.6

Iter: 592000
D_loss: 0.3255
G_loss: 3.698

Iter: 593000
D_loss: 0.2335
G_loss: 3.767

Iter: 594000
D_loss: 0.3462
G_loss: 3.561

Iter: 595000
D_loss: 0.263
G_loss: 3.908

Iter: 596000
D_loss: 0.2849
G_loss: 3.679

Iter: 597000
D_loss: 0.3017
G_loss: 3.474

Iter: 598000
D_loss: 0.1878
G_loss: 3.837

Iter: 599000
D_loss: 0.282
G_loss: 3.224

Iter: 600000
D_loss: 0.2744
G_loss: 3.96

Iter: 601000
D_loss: 0.2757
G_loss: 3.252

Iter: 602000
D_loss