In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

In [2]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
    
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig

def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1.0 / tf.sqrt(in_dim/2.0)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

In [3]:
X = tf.placeholder(tf.float32, shape=[None, 784])

D_W1 = tf.Variable(xavier_init([784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')

D_W2 = tf.Variable(xavier_init([128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name='D_b2')

theta_D = [D_W1, D_W2, D_b1, D_b2]

In [4]:
Z = tf.placeholder(tf.float32, shape=[None, 100])

G_W1 = tf.Variable(xavier_init([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')

G_W2 = tf.Variable(xavier_init([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')

theta_G = [G_W1, G_W2, G_b1, G_b2]

In [6]:
def sample_Z(m, n):                                                             
    return np.random.uniform(-1., 1., size=[m, n])

def generator(Z):                                                               
    G_h1 = tf.nn.relu( tf.matmul(Z, G_W1)+G_b1 )                                
    G_log_prob = tf.matmul(G_h1, G_W2)+G_b2                                   
    G_prob = tf.nn.sigmoid(G_log_prob)                                                     
    return G_prob                                                               
                                                                        
def discriminator(X):                                                           
    D_h1 = tf.nn.relu( tf.matmul(X, D_W1)+D_b1 )         
    D_logit = tf.matmul(D_h1, D_W2)+D_b2                                      
    D_prob = tf.nn.sigmoid(D_logit)                                             
                                                                                
    return D_prob, D_logit 

In [7]:
G_sample = generator(Z)                                                         
D_real, D_logit_real = discriminator(X)                                         
D_fake, D_logit_fake = discriminator(G_sample)

In [8]:
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake 

G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

In [14]:
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)          
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

In [15]:
np.set_printoptions(precision=3, suppress=True)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

mb_size = 128                                                                   
Z_dim = 100  
k = 1

if not os.path.exists('out/'):                                                  
    os.makedirs('out/')  

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    
    # Restore variables from disk.
    if os.path.isfile('model.meta'):
        # Restore variables from disk.
        saver.restore(sess, './model')
        print('Model restored.')
        
    # Do some work with the model.                                               
    i = 0                                                                                                                                                        
    for it in range(1000000):                                                       
        if it % 2000 == 0:                                                          
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})        

            fig = plot(samples)                                                     
            plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')  
            i += 1                                                                  
            plt.close(fig)                                                          
        
        for step in range(k):
            X_mb, _ = mnist.train.next_batch(mb_size)                                   
            _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

        if it % 10000 == 0:                                                          
            print('Iter:{}'.format(it), end=' ')       
            #print('D loss:\n', D_loss_curr[:4])                             
            print('G_loss: {}'.format(G_loss_curr))                              
            
    # Save the variables to disk.
    save_path = saver.save(sess, 'model')
    print('Model saved in file: {}'.format(save_path))

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Iter:0 G_loss: 2.2089664936065674
Iter:10000 G_loss: 4.761900424957275
Iter:20000 G_loss: 2.649146556854248
Iter:30000 G_loss: 2.6667377948760986
Iter:40000 G_loss: 2.868349313735962
Iter:50000 G_loss: 3.03558349609375
Iter:60000 G_loss: 3.073139190673828
Iter:70000 G_loss: 3.570329427719116
Iter:80000 G_loss: 2.701941967010498
Iter:90000 G_loss: 2.5218796730041504
Iter:100000 G_loss: 3.2024054527282715
Iter:110000 G_loss: 2.662503480911255
Iter:120000 G_loss: 3.133331298828125
Iter:130000 G_loss: 3.053745746612549
Iter:140000 G_loss: 3.3171310424804688
Iter:150000 G_loss: 3.196390151977539
Iter:160000 G_loss: 3.2536442279815674
Iter:170000 G_loss: 3.504542827606201
Iter:180000 G_loss: 3.3438353538513184
Iter:190000 G_loss: 3.1465351581573486
Iter:200000 G_loss: 3.4816670417785645
Iter:210000 G_