In [1]:
import tensorflow as tf

# tfe = tf.contrib.eager
# tf.enable_eager_execution()

import os, time, itertools, imageio, pickle
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.gridspec as gridspec

In [2]:
def show_image(images):
    """Plot a batch of images as a near-square grid of grayscale tiles.

    Args:
        images: array with a leading batch axis; each entry is flattened and
            reshaped to a square of side ceil(sqrt(pixels_per_image)).

    Returns:
        None. Draws into a newly created pyplot figure.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    n_images, n_pixels = flat.shape
    grid_side = int(np.ceil(np.sqrt(n_images)))
    img_side = int(np.ceil(np.sqrt(n_pixels)))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for idx in range(n_images):
        cell = plt.subplot(grid[idx])
        plt.axis('off')
        cell.set_xticklabels([])
        cell.set_yticklabels([])
        cell.set_aspect('equal')
        plt.imshow(flat[idx].reshape([img_side, img_side]), cmap='gray')

def show_result(num_epoch, show=False, save=False, path='result.png', session=None):
    """Render a 10x10 grid of generator samples for the fixed (z, y) inputs.

    Evaluates ``G_sample`` on the notebook-level ``fixed_z_`` / ``fixed_y_``
    and draws sample k at grid row k // 10, column k % 10, each reshaped to
    28x28.

    Args:
        num_epoch: value written under the grid as "Epoch <num_epoch>".
        show: if True, call ``plt.show()``; otherwise the figure is closed.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
        session: tf.Session used to evaluate ``G_sample``. When None, falls
            back to the notebook-global ``sess`` (the original behavior), so
            existing callers keep working.
    """
    if session is None:
        # Backward-compatible fallback: the notebook-global training session.
        session = sess
    test_images = session.run(G_sample, {z: fixed_z_, y: fixed_y_})

    size_figure_grid = 10
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    # Single pass over the axes in row-major order: draw sample k, hide ticks.
    for k, cell in enumerate(ax.flat):
        cell.imshow(np.reshape(test_images[k], (28, 28)), cmap='gray')
        cell.get_xaxis().set_visible(False)
        cell.get_yaxis().set_visible(False)

    fig.text(0.5, 0.04, 'Epoch {0}'.format(num_epoch), ha='center')

    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        # Close exactly the figure created here, not whatever is "current".
        plt.close(fig)
        
def preprocess_img(x):
    """Affinely map intensities from [0, 1] onto [-1, 1] (0 -> -1, 1 -> 1).

    Works element-wise on scalars and NumPy arrays alike.
    """
    doubled = x * 2
    return doubled - 1.0

def leaky_relu(x, alpha):
    """Leaky ReLU: ``alpha * x`` where ``x < 0``, otherwise ``x`` unchanged."""
    is_negative = tf.less(x, 0)
    scaled = alpha * x
    return tf.where(is_negative, scaled, x)
    
def sample_noise(batch_size, dim):
    """Draw a (batch_size, dim) float64 matrix of i.i.d. N(0, 1) samples.

    Uses NumPy's global RNG, so results follow ``np.random.seed``.
    """
    return np.random.normal(loc=0, scale=1, size=(batch_size, dim))

def get_session():
    """Create a tf.Session configured with ``gpu_options.allow_growth = True``.

    With allow_growth, GPU memory is allocated on demand rather than
    reserved up front.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)

In [13]:
# label preprocess: fixed inputs for the per-epoch sample grid.
# fixed_z_ stacks the same 10 noise vectors 10 times (100 x 100);
# fixed_y_ holds one-hot labels 0,0,...,0,1,...,9 (10 of each, in order).
onehot = np.eye(10)
temp_z_ = np.random.normal(0, 1, (10, 100))
fixed_z_ = np.tile(temp_z_, (10, 1))
fixed_y_ = onehot[np.repeat(np.arange(10), 10)]

In [4]:
#dataset
class MNIST(object):
    """In-memory MNIST training split that iterates as (X, y) minibatches.

    X is stored as float32 scaled by 1/255 and flattened to (N, -1);
    y keeps the labels exactly as returned by the Keras loader.
    """
    def __init__(self, batch_size, shuffle=False):
        """Load the Keras MNIST training split into memory.

        Args:
            batch_size: number of rows per minibatch yielded by ``__iter__``.
            shuffle: if True, every iteration visits rows in a fresh random
                order (NumPy global RNG).
        """
        train, _ = tf.keras.datasets.mnist.load_data()
        X, y = train
        X = X.astype(np.float32)/255.
        X = X.reshape((X.shape[0], -1))
        self.X, self.y = X, y
        self.batch_size, self.shuffle = batch_size, shuffle

    def __iter__(self):
        """Return an iterator of (X_batch, y_batch) pairs of up to batch_size rows.

        The same index permutation is applied to X and y, so pairs stay
        aligned; the final batch may be shorter than ``batch_size``.
        """
        N, B = self.X.shape[0], self.batch_size
        idxs = np.arange(N)
        if self.shuffle:
            np.random.shuffle(idxs)
        # Index through idxs so that shuffling actually reorders the batches.
        return iter((self.X[idxs[i:i+B]], self.y[idxs[i:i+B]]) for i in range(0, N, B))

In [27]:
# generator and discriminator
def generator(z, y):
    """Conditional generator: map noise ``z`` plus condition ``y`` to 784 outputs.

    Concatenates z and y along axis 1, applies a 128-unit ReLU dense layer and
    then a 784-unit tanh dense layer (Xavier-initialised kernels), all under
    variable scope 'generator'. Returns the tanh activations, in (-1, 1).
    """
    with tf.variable_scope('generator'):
        xavier = tf.contrib.layers.xavier_initializer()
        zy = tf.concat([z, y], 1)
        hidden = tf.layers.dense(zy, 128, activation=tf.nn.relu, kernel_initializer=xavier)
        return tf.layers.dense(hidden, 784, activation=tf.nn.tanh, kernel_initializer=xavier)

def discriminator(x, y):
    """Conditional discriminator: score (x, y) pairs with one raw logit per row.

    Builds its layers under variable scope 'discriminator'. A second call
    creates layers with the same auto-generated names, so it must run with
    variable reuse enabled on the enclosing scope.

    Args:
        x: flattened image batch (assumed (batch, 784) — TODO confirm at call site).
        y: condition batch (assumed one-hot, (batch, 10)).

    Returns:
        Unscaled logits of shape (batch, 1); no sigmoid is applied here.
    """
    with tf.variable_scope('discriminator'):
        w_init = tf.contrib.layers.xavier_initializer()
        # NOTE(review): maxout(inputs, num_units) takes a max over
        # width / num_units pieces per unit. If x is 784 wide and y is 10 wide,
        # each unit here has exactly one piece, so both calls are identity ops
        # and the network is effectively a plain MLP. A real maxout layer would
        # first project to k * num_units features (k > 1) and then apply maxout.
        x1 = tf.contrib.layers.maxout(x, 784)
        y1 = tf.contrib.layers.maxout(y, 10)
        combine = tf.concat([x1, y1], 1)
        # 128-unit hidden layer with leaky ReLU (slope 0.2 for negatives).
        dense1 = tf.layers.dense(combine, 128, kernel_initializer=w_init)
        l_relu = leaky_relu(dense1, alpha=0.2)
        # Linear 1-unit head: logits for sigmoid cross-entropy.
        output = tf.layers.dense(l_relu, 1, kernel_initializer=w_init)
        return output
        

In [28]:
def gan_loss(logits_real, logits_fake):
    """Sigmoid cross-entropy GAN losses computed from discriminator logits.

    Returns:
        (D_loss, G_loss):
        D_loss is mean BCE(real logits vs 1) + mean BCE(fake logits vs 0).
        G_loss is mean BCE(fake logits vs 1), i.e. the non-saturating
        generator loss.
    """
    xent = tf.nn.sigmoid_cross_entropy_with_logits

    # Discriminator: real samples should score 1, generated samples 0.
    d_real = tf.reduce_mean(xent(labels=tf.ones_like(logits_real), logits=logits_real))
    d_fake = tf.reduce_mean(xent(labels=tf.zeros_like(logits_fake), logits=logits_fake))
    D_loss = d_real + d_fake

    # Generator: wants its samples scored as real (label 1).
    G_loss = tf.reduce_mean(xent(labels=tf.ones_like(logits_fake), logits=logits_fake))
    return D_loss, G_loss

def get_solver(learning_rate=0.0002, beta1=0.5):
    """Return a (discriminator, generator) pair of independent Adam optimizers.

    Both optimizers use the same ``learning_rate`` and ``beta1``.
    """
    solvers = tuple(
        tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
        for _ in range(2)
    )
    return solvers

In [29]:
#put it together
# Build the full conditional-GAN training graph from a clean default graph.
tf.reset_default_graph()

# number of images for each batch
# NOTE(review): this value is not used anywhere in this cell; the training
# function takes its own batch_size argument.
batch_size = 100


# x: flattened real images, y: 10-wide label/condition vectors, z: 100-d noise.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
z = tf.placeholder(tf.float32, [None, 100])

G_sample = generator(z, y)
# Build the discriminator twice with shared weights: once on real images,
# then, after enabling reuse on this scope, on generated images.
with tf.variable_scope("") as scope:
    logits_real = discriminator(x, y)
    scope.reuse_variables()
    logits_fake = discriminator(G_sample, y)

# Restrict each optimizer to its own network via variable-scope prefix match.
D_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'discriminator')
G_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'generator')

D_solver, G_solver = get_solver()
D_loss, G_loss = gan_loss(logits_real, logits_fake)

D_train_step = D_solver.minimize(D_loss, var_list=D_var)
G_train_step = G_solver.minimize(G_loss, var_list=G_var)
# Per-network UPDATE_OPS (e.g. batch-norm statistics). The layers above are
# plain dense layers, so these lists are presumably empty — verify if
# normalization layers are added.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'generator')

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.


In [32]:
#run a gan
def run_a_gan(sess, D_train_step, D_loss, G_train_step, G_loss, D_extra_step,G_extra_step,\
              batch_size=100, num_epoch=100):
    """Train the conditional GAN on MNIST and save a sample grid after each epoch.

    Each iteration takes one discriminator step on a real minibatch with its
    true labels, then one generator step on fresh noise with uniformly random
    digit labels covering all classes.

    Args:
        sess: tf.Session whose variables are already initialised.
        D_train_step, G_train_step: optimizer ops for each network.
        D_loss, G_loss: scalar loss tensors, fetched for logging.
        D_extra_step, G_extra_step: extra ops (e.g. UPDATE_OPS lists) run in
            the same ``sess.run`` call as the matching train step; may be empty.
        batch_size: minibatch size for both steps.
        num_epoch: number of passes over the training set.

    Side effects:
        Reads MNIST via ``input_data.read_data_sets("MNIST_data/")``, prints
        the losses of the last minibatch of each epoch, and saves
        "My_CGAN_result/maxout1_/<epoch>.png" via ``show_result``.
    """
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    x_train = preprocess_img(mnist.train.images)
    y_train = mnist.train.labels
    num_classes = onehot.shape[0]
    num_batches = len(x_train) // batch_size

    result_dir = os.path.join('My_CGAN_result', 'maxout1_')
    os.makedirs(result_dir, exist_ok=True)

    for epoch in range(num_epoch):
        # Keep the logged values defined even if there is no full batch.
        D_loss_curr = G_loss_curr = float('nan')
        for step in range(num_batches):
            lo, hi = step * batch_size, (step + 1) * batch_size

            # update discriminator on a real minibatch with its true labels
            x_ = x_train[lo:hi]
            x_ = x_.reshape((x_.shape[0], -1))
            y_ = y_train[lo:hi]
            z_ = sample_noise(batch_size, 100)
            _, _, D_loss_curr = sess.run([D_train_step, D_extra_step, D_loss],
                                         feed_dict={x: x_, y: y_, z: z_})

            # update generator; randint's upper bound is exclusive, so use
            # num_classes to include the last digit. Indexing with a 1-D label
            # vector keeps shape (batch_size, num_classes) even for batch_size 1.
            g_labels = np.random.randint(0, num_classes, batch_size)
            y_ = onehot[g_labels]
            z_ = sample_noise(batch_size, 100)
            # G_loss depends only on z and y, so x is not fed here.
            _, _, G_loss_curr = sess.run([G_train_step, G_extra_step, G_loss],
                                         feed_dict={z: z_, y: y_})

        print("Epoch:{}, D:{:.4}, G:{:.4}".format(epoch+1, D_loss_curr, G_loss_curr))

        PATH = os.path.join(result_dir, str(epoch + 1) + '.png')
        show_result(epoch+1, save=True, path=PATH)


In [None]:
# Train inside a managed session. `sess` is bound at notebook level because
# show_result reads the global `sess` when rendering samples.
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    run_a_gan(
        sess=sess,
        D_train_step=D_train_step,
        D_loss=D_loss,
        G_train_step=G_train_step,
        G_loss=G_loss,
        D_extra_step=D_extra_step,
        G_extra_step=G_extra_step,
    )

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Epoch:1, D:0.9081, G:0.9884
Epoch:2, D:1.146, G:0.9925
Epoch:3, D:1.18, G:1.505
Epoch:4, D:1.23, G:1.337
Epoch:5, D:1.232, G:0.8375
Epoch:6, D:1.253, G:0.9206
Epoch:7, D:1.404, G:0.8874
Epoch:8, D:1.059, G:1.283
Epoch:9, D:1.013, G:1.031
Epoch:10, D:1.483, G:0.7436
Epoch:11, D:1.17, G:0.8256
Epoch:12, D:1.32, G:0.6497
Epoch:13, D:1.075, G:0.8659
Epoch:14, D:0.9432, G:1.351
Epoch:15, D:0.7803, G:1.805
Epoch:16, D:0.7497, G:1.92
Epoch:17, D:0.9555, G:1.367
Epoch:18, D:1.084, G:1.01
Epoch:19, D:0.9863, G:1.108
Epoch:20, D:1.042, G:1.054
Epoch:21, D:0.9336, G:1.457
Epoch:22, D:1.147, G:1.112
Epoch:23, D:1.358, G:0.6721
Epoch:24, D:1.043, G:1.039
Epoch:25, D:1.096, G:0.8088
Epoch:26, D:0.9489, G:1.233
Epoch:27, D:1.098, G:1.389
Epoch:28, D:0.8794, G:1.209
Epoch:29, D:0.8883, G:1.427
Epoch:30, D:1.038