In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
import imageio

In [2]:
from tensorflow.examples.tutorials.mnist import input_data 

# Load MNIST via the (deprecated, see the warnings below) TF tutorial helper.
# Images presumably come back as flat 784-vectors in [0, 1]; the training loop
# reshapes them to 28x28x1 and rescales them to [-1, 1].
MNIST_DIR = './MNIST_data/'
mnist = input_data.read_data_sets(MNIST_DIR)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [3]:
batch_size = 100
z_dim = 100  # length of the generator's input noise vector

OUTPUT_DIR = './output/'

# Seed NumPy (noise, sample latents) and TF's graph-level seed (weight init)
# so runs are more reproducible. The graph-level seed only affects ops created
# after this call, so it must run before the model cells below.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Real images: MNIST digits are 28x28 grayscale, i.e. 28x28x1.
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
# Fed True during training steps and False when sampling; forwarded to every
# batch-norm layer.
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def leakyrelu(x, leak=0.2):
    """Leaky ReLU: ``x`` where ``x >= 0``, ``leak * x`` elsewhere.

    Written as an element-wise maximum of ``x`` and ``leak * x``, which equals
    leaky ReLU only when ``0 <= leak <= 1`` (the default 0.2 qualifies).
    """
    negative_branch = leak * x
    return tf.maximum(x, negative_branch)

# Discriminator

In [4]:
# Discriminator
def discriminator(image, reuse=None, is_training=is_training):
    """Return real-vs-fake logits for a batch of images.

    Parameters
    ----------
    image : tf.Tensor
        Batch of images. In this notebook it is either the ``X`` placeholder or
        the generator output, both shaped (batch, 28, 28, 1).
    reuse : None or bool
        Forwarded to ``tf.variable_scope('discriminator')``. Use None on the
        first call, which creates the weights. Use True on later calls so the
        discriminator applied to real images and the one applied to generated
        images share a single set of weights.
    is_training : bool tf.Tensor
        Forwarded to every batch-norm layer. Defaults to the module-level
        ``is_training`` placeholder.

    Returns
    -------
    tf.Tensor, shape (batch, 1)
        Unscaled logits (no sigmoid), consumed by
        ``tf.nn.sigmoid_cross_entropy_with_logits``.
    """
    bn_decay = 0.5  # batch-norm moving-average decay
    with tf.variable_scope('discriminator', reuse=reuse):
        # The first conv block has no batch norm.
        net = leakyrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='SAME'))
        # Three conv -> batch norm -> leaky ReLU blocks. Each stride-2 'SAME'
        # conv halves H and W (28 -> 14 -> 7 -> 4 -> 2 across all four convs).
        # Layers are created in the same order as a hand-unrolled version, so
        # variable names (and therefore reuse=True lookups) are unchanged.
        for n_filters in (128, 256, 512):
            net = tf.layers.conv2d(net, kernel_size=5, filters=n_filters, strides=2, padding='SAME')
            net = tf.contrib.layers.batch_norm(net, is_training=is_training, decay=bn_decay)
            net = leakyrelu(net)
        flat = tf.contrib.layers.flatten(net)
        logits = tf.layers.dense(flat, units=1)
        return logits

# Generator

In [5]:
# Generator
def generator(z, is_training=is_training, reuse=None):
    """Map a batch of noise vectors to 28x28x1 images in [-1, 1].

    Parameters
    ----------
    z : tf.Tensor, shape (batch, z_dim)
        Noise input (the ``noise`` placeholder in this notebook).
    is_training : bool tf.Tensor
        Forwarded to every batch-norm layer. Defaults to the module-level
        ``is_training`` placeholder.
    reuse : None or bool
        Forwarded to ``tf.variable_scope('generator')``, mirroring
        ``discriminator``. Keep the default (None) for the first call. Pass
        True to build another generator graph that shares the weights already
        created, e.g. a separate sampling head.

    Returns
    -------
    tf.Tensor, shape (batch, 28, 28, 1)
        tanh-activated images, so values lie in [-1, 1].
    """
    momentum = 0.9  # batch-norm moving-average decay
    with tf.variable_scope('generator', reuse=reuse):
        d = 3
        # Project the noise to a 3x3x512 feature map.
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))

        # Each stride-2 'SAME' transposed conv doubles H and W: 3 -> 6 -> 12 -> 24.
        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='SAME')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='SAME')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='SAME')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        # A stride-1 'VALID' transposed conv with a 5x5 kernel grows 24 -> 24 + 5 - 1 = 28.
        # tanh keeps outputs in [-1, 1], matching the rescaled real images.
        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=1, strides=1, padding='VALID',
                                        activation=tf.nn.tanh, name='g')
        return h4

# Loss Function

In [None]:
# Generated images (creates the generator weights).
g = generator(z=noise)
# Discriminator on real images (creates the discriminator weights) ...
d_real_logits = discriminator(image=X)
# ... and on generated images, reusing the same discriminator weights.
d_fake_logits = discriminator(image=g, reuse=True)

In [9]:
def _mean_sigmoid_xent(logits, target_like):
    """Mean sigmoid cross-entropy of ``logits`` against ``target_like(logits)``.

    ``target_like`` is ``tf.ones_like`` (label "real") or ``tf.zeros_like``
    (label "fake").
    """
    targets = target_like(logits)
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
    return tf.reduce_mean(per_example)

# Discriminator: label real images 1 and generated images 0.
loss_d_real = _mean_sigmoid_xent(d_real_logits, tf.ones_like)
loss_d_fake = _mean_sigmoid_xent(d_fake_logits, tf.zeros_like)
loss_d = loss_d_real + loss_d_fake
# Generator (non-saturating form): push D to label generated images 1.
loss_g = _mean_sigmoid_xent(d_fake_logits, tf.ones_like)

# Optimizers

In [None]:
# Partition trainable variables by variable scope so each optimizer updates
# only its own network. get_collection's `scope` is applied with re.match,
# i.e. a name-prefix match here, the same as str.startswith.
vars_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
vars_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

LEARNING_RATE = 0.0002
ADAM_BETA1 = 0.5

# Batch-norm moving mean/variance updates are registered in UPDATE_OPS. Making
# the train ops depend on them keeps those statistics updated while training.
# NOTE(review): this is the whole collection, so each train op also runs the
# other network's BN updates.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE, beta1=ADAM_BETA1).minimize(
        loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE, beta1=ADAM_BETA1).minimize(
        loss_g, var_list=vars_g)

In [None]:
# Tile a batch of images into one grid for display / saving.
def join(images):
    """Tile a batch of single-channel images into one square mosaic.

    Parameters
    ----------
    images : array-like, shape (N, H, W) or (N, H, W, 1)
        Batch of grayscale images. A list of 2-D arrays also works. A trailing
        channel axis of size 1 is dropped.

    Returns
    -------
    np.ndarray of float, shape (n*H + n + 1, n*W + n + 1), n = ceil(sqrt(N))
        Images placed row-major in an n x n grid. Cells are separated (and
        framed) by 1-pixel borders of value 0.5. Unused cells stay 0.5.
    """
    images = np.asarray(images)
    if images.ndim == 4 and images.shape[-1] == 1:
        images = images[..., 0]
    n_images, img_h, img_w = images.shape[0], images.shape[1], images.shape[2]
    n_plots = int(np.ceil(np.sqrt(n_images)))
    mosaic = np.full((img_h * n_plots + n_plots + 1, img_w * n_plots + n_plots + 1), 0.5)
    for idx in range(n_images):
        row, col = divmod(idx, n_plots)
        # Each cell is the image plus a 1-pixel border; +1 for the outer frame.
        top = 1 + row * (img_h + 1)
        left = 1 + col * (img_w + 1)
        mosaic[top:top + img_h, left:left + img_w] = images[idx]
    return mosaic

# Training

In [None]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

N_ITERATIONS = 50000
SAMPLE_EVERY = 1000  # log losses and save a sample grid every this many steps

# plt.savefig / imageio.mimsave do not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed latent points, so successive sample grids show the same z's evolving.
z_samples = np.random.uniform(-1, 1, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in range(N_ITERATIONS):
    n = np.random.uniform(-1, 1, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    # Map pixels from [0, 1] to [-1, 1] to match the generator's tanh range.
    batch = (batch - 0.5) * 2
    feed = {X: batch, noise: n, is_training: True}

    # Record losses for this batch before the parameter updates.
    d_loss, g_loss = sess.run([loss_d, loss_g], feed_dict=feed)
    loss['d'].append(d_loss)
    loss['g'].append(g_loss)

    # One discriminator step, then two generator steps per iteration.
    sess.run(optimizer_d, feed_dict=feed)
    sess.run(optimizer_g, feed_dict=feed)
    sess.run(optimizer_g, feed_dict=feed)

    if i % SAMPLE_EVERY == 0:
        print(i, d_loss, g_loss)
        generated_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        generated_imgs = (generated_imgs + 1) / 2  # tanh range [-1, 1] -> [0, 1]
        imgs = [img[:, :, 0] for img in generated_imgs]
        grid = join(imgs)
        plt.axis('off')
        plt.imshow(grid, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        samples.append(grid)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig('Loss.png')
plt.show()

# Convert float grids in [0, 1] to uint8 explicitly rather than relying on
# imageio's lossy implicit float -> uint8 conversion.
gif_frames = [(np.clip(frame, 0, 1) * 255).astype(np.uint8) for frame in samples]
imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), gif_frames, fps=5)