In [4]:
import tensorflow as tf

# tfe = tf.contrib.eager
# tf.enable_eager_execution()

import os, time, itertools, imageio, pickle
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.gridspec as gridspec
import platform

In [5]:
#load dataset
def load_pickle(f):
    """Unpickle one object from the binary file object ``f``.

    On Python 3 the stream is decoded with ``encoding='latin1'``; on
    Python 2 the plain ``pickle.load`` is used.

    Raises:
        ValueError: if the interpreter's major version is neither 2 nor 3.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    py_version = platform.python_version_tuple()
    major = py_version[0]
    if major == '3':
        return pickle.load(f, encoding='latin1')
    if major == '2':
        return pickle.load(f)
    raise ValueError("invalid python version: {}".format(py_version))
    
def load_CIFAR_batch(filename):
    """Load a single pickled CIFAR batch file.

    Args:
        filename: path to a batch file whose unpickled dict has the keys
            'data' and 'labels'.

    Returns:
        (X, Y): X is a float32 array of shape (N, 32, 32, 3) (channels last),
        Y is an array built from the 'labels' entry.
    """
    with open(filename, 'rb') as f:
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        datadict = load_pickle(f)
    # -1 infers the number of images from the data instead of assuming
    # exactly 10000 per file, so truncated or custom batches also load.
    X = datadict['data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float32")
    Y = np.array(datadict['labels'])
    return X, Y

def load_CIFAR10(ROOT):
    """Load the five training batches and the test batch found under ROOT.

    Returns:
        (Xtr, Ytr, Xte, Yte): training images/labels concatenated over
        'data_batch_1' .. 'data_batch_5', followed by the contents of
        'test_batch'.
    """
    train_paths = [os.path.join(ROOT, 'data_batch_%d' % (b, )) for b in range(1, 6)]
    batches = [load_CIFAR_batch(batch_path) for batch_path in train_paths]
    Xtr = np.concatenate([images for images, _ in batches])
    Ytr = np.concatenate([labels for _, labels in batches])
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
# Read CIFAR-10 from the local batch directory; the test labels are discarded.
cifar10_dir = 'cifar-10-batches-py'
train_image, train_label, test_image, _ = load_CIFAR10(cifar10_dir)
# Scale both image arrays by 1/255, in place.
train_image /= 255.0
test_image /= 255.0

In [6]:
def show_result(num_epoch, show = False, save = False, path = 'result.png'):
    """Render a 10x10 grid of generator samples for the fixed evaluation inputs.

    Reads the module-level names ``sess``, ``G_sample``, ``z``, ``y``,
    ``fixed_z_`` and ``fixed_y_``.

    Args:
        num_epoch: number shown in the 'Epoch N' caption.
        show: display the figure when True, otherwise close it.
        save: write the figure to ``path`` when True.
        path: output file used when ``save`` is True.

    NOTE(review): samples are passed to imshow exactly as the generator
    emits them (tanh output); confirm the value range is what imshow expects.
    """
    generated = sess.run(G_sample, {z: fixed_z_, y: fixed_y_})

    grid = 10
    fig, ax = plt.subplots(grid, grid, figsize=(5, 5))
    # Hide the tick axes on every cell of the grid.
    for row, col in itertools.product(range(grid), range(grid)):
        ax[row, col].get_xaxis().set_visible(False)
        ax[row, col].get_yaxis().set_visible(False)

    # Fill the grid row by row with the generated images.
    for idx in range(grid * grid):
        row, col = divmod(idx, grid)
        cell = ax[row, col]
        cell.cla()
        cell.imshow(np.reshape(generated[idx], (32, 32, 3)))

    fig.text(0.5, 0.04, 'Epoch {0}'.format(num_epoch), ha='center')

    if save:
        plt.savefig(path)

    if not show:
        plt.close()
    else:
        plt.show()

def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the discriminator and generator loss curves stored in ``hist``.

    Args:
        hist: dict with the keys 'D_losses' and 'G_losses' (sequences of
            equal length, one value per x position).
        show: display the figure when True, otherwise close it.
        save: write the figure to ``path`` when True.
        path: output file used when ``save`` is True.
    """
    d_curve = hist['D_losses']
    g_curve = hist['G_losses']
    epochs = range(len(d_curve))

    for curve, curve_label in ((d_curve, 'D_loss'), (g_curve, 'G_loss')):
        plt.plot(epochs, curve, label=curve_label)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if not show:
        plt.close()
    else:
        plt.show()

def preprocess_img(x):
    """Apply the affine map ``2 * x - 1`` (e.g. 0 -> -1, 0.5 -> 0, 1 -> 1)."""
    doubled = 2 * x
    return doubled - 1.0

def leaky_relu(x, alpha=0.2, name="lrelu"):
    """Leaky ReLU: returns ``x`` where ``x >= 0`` and ``alpha * x`` elsewhere.

    Args:
        x: input tensor.
        alpha: slope applied to negative inputs.
        name: name given to the resulting op (previously accepted but ignored).

    Returns:
        A tensor with the same shape as ``x``.
    """
    is_negative = tf.less(x, 0)
    # Forward `name` so the op is identifiable in the graph.
    return tf.where(is_negative, alpha * x, x, name=name)
    
def sample_noise(batch_size, dim):
    """Draw a (batch_size, dim) array of standard-normal noise from NumPy's global RNG."""
    return np.random.randn(batch_size, dim)

def get_session():
    """Create a ``tf.Session`` whose config has ``gpu_options.allow_growth`` enabled."""
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)

In [7]:
# Fixed evaluation inputs used to render the sample grid:
# the same 10 noise vectors are reused for every class, and the labels run
# 0,0,...,0,1,...,9 (ten consecutive rows per class), one-hot encoded.
onehot = np.eye(10)
temp_z_ = np.random.normal(0, 1, (10, 100))
fixed_z_ = np.concatenate([temp_z_] * 10, 0)

fixed_labels = [digit for digit in range(10) for _ in range(10)]
fixed_y_ = onehot[fixed_labels]

In [8]:
#dataset
# class MNIST(object):
#     def __init__(self, batch_size, shuffle=False):
#         train, _ = tf.keras.datasets.mnist.load_data()
#         X, y = train
#         X = X.astype(np.float32)/255.
#         X = X.reshape((X.shape[0], -1))
#         self.X, self.y = X, y
#         self.batch_size, self.shuffle = batch_size, shuffle
    
#     def __iter__(self):
#         N, B = self.X.shape[0], self.batch_size
#         idxs = np.arange(N)
#         if self.shuffle:
#             np.random.shuffle(idxs)
#         return iter((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))

In [11]:
# generator and discriminator
def generator(z, y):
    """Conditional generator: map noise ``z`` and labels ``y`` to images.

    ``z`` and ``y`` are concatenated along axis 1, projected to a
    4x4x1024 feature map, then upsampled by three stride-2 transposed
    convolutions (512, 256, 3 filters). The last layer uses tanh.
    All variables live under the 'generator' variable scope.
    """
    with tf.variable_scope('generator'):
        w_init = tf.contrib.layers.xavier_initializer()
        latent = tf.concat([z, y], 1)
        hidden = tf.layers.dense(latent, 4*4*1024, activation=tf.nn.relu,
                                 kernel_initializer=w_init)
        feature_map = tf.reshape(hidden, (-1, 4, 4, 1024))

        # Settings shared by every transposed convolution.
        upsample = dict(kernel_size=3, strides=2, padding='SAME',
                        kernel_initializer=w_init)
        for n_filters in (512, 256):
            feature_map = tf.layers.conv2d_transpose(
                feature_map, filters=n_filters, activation=tf.nn.relu, **upsample)
        return tf.layers.conv2d_transpose(
            feature_map, filters=3, activation=tf.nn.tanh, **upsample)

def discriminator(x, y):
    """Conditional discriminator: score image/label pairs with one logit each.

    Args:
        x: images, reshaped to (batch, 32, 32, 3).
        y: labels with 10 values per sample, reshaped to (batch, 1, 1, 10).

    Returns:
        Tensor of shape (batch, 1) holding unnormalized logits.

    The batch dimension is inferred from the inputs, so any batch size can
    be fed (it was previously pinned to 100). All variables live under the
    'discriminator' variable scope.
    """
    with tf.variable_scope('discriminator'):
        w_init = tf.contrib.layers.xavier_initializer()
        # -1 lets the batch dimension follow the fed data.
        x = tf.reshape(x, (-1, 32, 32, 3))
        y = tf.reshape(y, (-1, 1, 1, 10))
        # Repeat each label vector over the 32x32 grid so it can be stacked
        # onto the image as 10 extra channels.
        y_map = tf.tile(y, [1, 32, 32, 1])
        combine = tf.concat([x, y_map], 3)  # (batch, 32, 32, 13)
        conv1 = tf.layers.conv2d(combine, filters=64, kernel_size=3, strides=1, padding='SAME', kernel_initializer=w_init)  # (batch, 32, 32, 64)
        lrelu1 = leaky_relu(conv1)
        conv2 = tf.layers.conv2d(lrelu1, filters=128, kernel_size=4, strides=2, padding='SAME', kernel_initializer=w_init)  # (batch, 16, 16, 128)
        lrelu2 = leaky_relu(conv2)
        conv3 = tf.layers.conv2d(lrelu2, filters=256, kernel_size=4, strides=2, padding='SAME', kernel_initializer=w_init)  # (batch, 8, 8, 256)
        lrelu3 = leaky_relu(conv3)
        # No activation is applied to the last convolution before flattening.
        conv4 = tf.layers.conv2d(lrelu3, filters=512, kernel_size=4, strides=2, padding='SAME', kernel_initializer=w_init)  # (batch, 4, 4, 512)
        flatten = tf.layers.flatten(conv4)
        output = tf.layers.dense(flatten, 1, kernel_initializer=w_init)
        return output
        

In [12]:
def gan_loss(logits_real, logits_fake):
    """Sigmoid cross-entropy GAN losses.

    Args:
        logits_real: discriminator logits for real samples.
        logits_fake: discriminator logits for generated samples.

    Returns:
        (D_loss, G_loss): D_loss is the mean loss of real-vs-ones plus the
        mean loss of fake-vs-zeros; G_loss is the mean loss of fake-vs-ones.
    """
    def _mean_bce(labels, logits):
        # Mean sigmoid cross-entropy of `logits` against `labels`.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

    real_term = _mean_bce(tf.ones_like(logits_real), logits_real)
    fake_term = _mean_bce(tf.zeros_like(logits_fake), logits_fake)
    D_loss = real_term + fake_term

    G_loss = _mean_bce(tf.ones_like(logits_fake), logits_fake)
    return D_loss, G_loss

def get_solver(learning_rate=0.0002, beta1=0.5):
    """Return two independent Adam optimizers ``(D_solver, G_solver)``.

    Both are built with the same ``learning_rate`` and ``beta1``.
    """
    D_solver, G_solver = [
        tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
        for _ in range(2)
    ]
    return D_solver, G_solver

In [13]:
# Assemble the conditional-GAN graph: inputs, generator, discriminator
# (applied twice with shared weights), losses and optimizer steps.
# Statement order matters here because it determines variable creation/reuse.
# Drop any previously built default graph before constructing this one.
tf.reset_default_graph()

# number of images for each batch
# NOTE(review): the training cell calls run_a_gan without passing this value,
# so run_a_gan's own default batch_size is what is actually used.
batch_size = 100


# Graph inputs; the leading (batch) dimension is left unspecified.
x = tf.placeholder(tf.float32, [None, 32, 32, 3])  # images
y = tf.placeholder(tf.float32, [None, 10])         # one-hot labels
z = tf.placeholder(tf.float32, [None, 100])        # noise vectors

G_sample = generator(z, y)

# Call the discriminator on real and on generated images. reuse_variables()
# between the two calls makes the second call share the first call's weights.
with tf.variable_scope("") as scope:
    logits_real = discriminator(x, y)
    scope.reuse_variables()
    logits_fake = discriminator(G_sample, y)

# Trainable variables split by scope name so each optimizer only updates
# its own network.
D_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'discriminator')
G_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'generator')

D_solver, G_solver = get_solver()
D_loss, G_loss = gan_loss(logits_real, logits_fake)

D_train_step = D_solver.minimize(D_loss, var_list=D_var)
G_train_step = G_solver.minimize(G_loss, var_list=G_var)
# Ops registered in the UPDATE_OPS collection under each scope.
# NOTE(review): presumably empty, since neither network visibly uses a layer
# that registers update ops (e.g. batch norm) - confirm.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'generator')

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.conv2d_transpose instead.
Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Use keras.layers.flatten instead.


In [None]:
#run a gan
def run_a_gan(sess, D_train_step, D_loss, G_train_step, G_loss, D_extra_step,G_extra_step,\
              batch_size=100, num_epoch=100):
    """Train the conditional GAN and write per-epoch artifacts to disk.

    Reads the module-level names ``train_image``, ``train_label``, ``onehot``
    and the placeholders ``x``, ``y``, ``z``. Each batch runs one
    discriminator step on real data, then one generator step on freshly
    sampled labels and noise.

    Args:
        sess: session used to run the training ops.
        D_train_step, G_train_step: optimizer ops for each network.
        D_loss, G_loss: loss tensors fetched alongside the train steps.
        D_extra_step, G_extra_step: accepted for interface compatibility;
            they are not run by this function.
        batch_size: number of samples per training step.
        num_epoch: number of passes over the training set.

    Raises:
        ValueError: if the training set holds fewer than ``batch_size`` samples.

    Side effects: creates 'cifar_CGAN_result/conv/' and writes one sample
    grid PNG per epoch, 'train_hist.pkl', 'train_hist.png' and
    'generation_animation.gif' into it.

    NOTE(review): real images are fed as loaded, while the generator ends in
    tanh and preprocess_img is never applied in the visible code - confirm
    that the value ranges are meant to differ.
    """
    #(x_train, y_train), (x_test, _) = tf.keras.datasets.mnist.load_data()
    x_train = train_image
    y_train = train_label

    num_batches = len(x_train) // batch_size
    if num_batches == 0:
        # Without this guard the epoch summary below would fail on
        # undefined loss variables.
        raise ValueError("batch_size {} exceeds the number of training samples {}".format(
            batch_size, len(x_train)))

    train_hist = {'D_losses': [], 'G_losses': []}

    root = 'cifar_CGAN_result/'
    model = 'conv/'
    out_dir = root + model
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for epoch in range(num_epoch):
        # Reset once per epoch (not per batch) so the recorded mean covers
        # every batch of the epoch rather than only the last one.
        D_losses = []
        G_losses = []
        for step in range(num_batches):
            batch = slice(step * batch_size, (step + 1) * batch_size)

            #update discriminator
            x_ = x_train[batch]
            # Index the identity matrix with the 1-D labels: always
            # (batch, 10), including when batch_size == 1.
            y_ = onehot[y_train[batch].astype(np.int32)]
            z_ = sample_noise(batch_size, 100)
            _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x:x_, y:y_, z:z_})
            D_losses.append(D_loss_curr)

            #update generator
            y_ = onehot[np.random.randint(0, 10, batch_size)]
            z_ = sample_noise(batch_size, 100)
            _, G_loss_curr = sess.run([G_train_step, G_loss], feed_dict={z:z_, x:x_, y:y_})
            G_losses.append(G_loss_curr)
        # Console line reports the losses of the epoch's final batch.
        print("Epoch:{}, D:{:.4}, G:{:.4}".format(epoch+1, D_loss_curr, G_loss_curr))

        PATH = out_dir + str(epoch+1) + '.png'

        show_result(epoch+1, save=True, path=PATH)
        train_hist['D_losses'].append(np.mean(D_losses))
        train_hist['G_losses'].append(np.mean(G_losses))

    with open(out_dir + 'train_hist.pkl', 'wb') as f:
        pickle.dump(train_hist, f)

    show_train_hist(train_hist, save=True, path=out_dir + 'train_hist.png')

    # Stitch the per-epoch sample grids into an animation.
    images = [imageio.imread(out_dir + str(e + 1) + '.png') for e in range(num_epoch)]
    imageio.mimsave(out_dir + 'generation_animation.gif', images, fps=5)


In [None]:
# Train inside a managed session. The name `sess` must stay module-level
# because show_result reads it as a global. The `with` block closes the
# session on exit, so no explicit close() is needed.
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    run_a_gan(sess,D_train_step,D_loss,G_train_step,G_loss,D_extra_step,G_extra_step)