In [1]:
from __future__ import print_function, division
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import models
import utils
import glob
import visualize

%matplotlib inline
# Default plotting style for the sample grids shown during training.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),  # set default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})
%load_ext autoreload
%autoreload 2

def get_session():
    """Return a new TF1 session with GPU memory growth enabled.

    ``allow_growth`` makes TensorFlow allocate GPU memory on demand rather than
    reserving it all when the session starts.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)

  from ._conv import register_converters as _register_converters


KeyboardInterrupt: 

In [None]:
def pre_process(img):
    """Map values from [0, 1] to [-1, 1] via ``2 * img - 1`` (works on scalars and arrays)."""
    doubled = img * 2
    return doubled - 1

def post_process(img):
    """Inverse of ``pre_process``: map values from [-1, 1] back to [0, 1]."""
    shifted = img + 1
    return shifted / 2

def sample_noise(batch_size, dim):
    """Draw a ``(batch_size, dim)`` array of generator noise, uniform on [-1, 1)."""
    noise_shape = (batch_size, dim)
    return np.random.uniform(low=-1, high=1, size=noise_shape)
   
def random_t(batch_size, num_classes=10):
    """Return ``batch_size`` random one-hot label rows.

    Args:
        batch_size: number of rows to generate (0 gives an empty ``(0, num_classes)`` array).
        num_classes: width of each one-hot row; defaults to 10 (MNIST digits).

    Returns:
        float64 array of shape ``(batch_size, num_classes)`` with exactly one 1 per row,
        at a class index drawn uniformly with ``np.random.randint``.
    """
    labels = np.random.randint(0, num_classes, size=batch_size)
    noise_t = np.zeros((batch_size, num_classes))
    # Each row index appears once, so plain assignment sets exactly one entry per row.
    noise_t[np.arange(batch_size), labels] = 1
    return noise_t

# Build the fixed, sorted "formal" label grid used when showing samples:
# 100 one-hot rows, ten consecutive rows per class
# (rows 0-9 -> class 0, rows 10-19 -> class 1, ..., rows 90-99 -> class 9).
# Vectorised so no loop index leaks into the notebook's global namespace.
formal_t = np.repeat(np.eye(10), 10, axis=0)

In [None]:
# Load MNIST with one-hot labels. Images are returned flattened by default,
# which is why run_a_gan reshapes them to [-1, 28, 28, 1].
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir='./dataset/MNIST_data', one_hot=True)

batch_size = 64  # minibatch size used for training below

In [None]:
def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
              show_every=250, print_every=50, batch_size=128, num_epoch=10, n_critic=5):
    """Train the conditional GAN, periodically showing a sorted sample grid.

    Each outer iteration runs ``n_critic`` critic (discriminator) updates and then
    one generator update. The function reads the notebook globals ``mnist``,
    ``noise_dim``, ``G_sample``, ``formal_t`` and the placeholders ``x``, ``z``,
    ``t`` and ``t_false``.

    Args:
        sess: TF1 session with variables already initialised.
        G_train_step, D_train_step: optimizer ops for the generator / critic.
        G_loss, D_loss: scalar loss tensors fetched for logging.
        G_extra_step, D_extra_step: lists of UPDATE_OPS (possibly empty) run together
            with the matching train step.
        show_every: show a sample grid every this many outer iterations.
        print_every: print the losses every this many outer iterations.
        batch_size: minibatch size for real images and noise.
        num_epoch: approximate number of passes over ``mnist.train``.
        n_critic: critic updates per generator update.

    NOTE(review): real images are fed in MNIST's native range and ``pre_process`` is
    never applied; if the generator's output layer is tanh, the critic sees real
    data in [0, 1] and fakes in [-1, 1] -- confirm against ``models``.
    """
    # One noise row per label row so the sample grid's z and t batches match.
    n_show = formal_t.shape[0]
    max_iter = mnist.train.num_examples * num_epoch // (batch_size * n_critic)
    for it in range(max_iter):
        if it % show_every == 0:
            noise_z = sample_noise(n_show, noise_dim)
            samples = sess.run(G_sample, feed_dict={z: noise_z, t: formal_t})
            visualize.show_images(samples)
            plt.show()
            print()

        for _ in range(n_critic):
            noise_z = sample_noise(batch_size, noise_dim)
            minibatch, mini_t = mnist.train.next_batch(batch_size)
            # Keep this a NumPy array: feed_dict values must not be tf.Tensors.
            minibatch = np.reshape(minibatch, [-1, 28, 28, 1])
            mini_t_bar = random_t(batch_size)
            _, D_loss_curr, _ = sess.run(
                [D_train_step, D_loss, D_extra_step],
                feed_dict={x: minibatch, z: noise_z, t: mini_t, t_false: mini_t_bar})

        noise_z = sample_noise(batch_size, noise_dim)
        # G_sample is conditioned on t, so the generator step must feed labels as well.
        _, G_loss_curr, _ = sess.run(
            [G_train_step, G_loss, G_extra_step],
            feed_dict={z: noise_z, t: random_t(batch_size)})

        if it % print_every == 0:
            print('Iter: {}/{}, D: {:.4}, G:{:.4}'.format(it, max_iter, D_loss_curr, G_loss_curr))

    print('Final images')
    noise_z = sample_noise(n_show, noise_dim)
    samples = sess.run(G_sample, feed_dict={z: noise_z, t: formal_t})
    visualize.show_images(samples, final="WGAN-GP")
    plt.show()

In [None]:
# Build the conditional-GAN graph from a clean default graph, so re-running this cell is safe.
tf.reset_default_graph()

noise_dim = 100

x = tf.placeholder(tf.float32, [None, 28, 28, 1])   # real images
z = tf.placeholder(tf.float32, [None, noise_dim])   # generator noise
t = tf.placeholder(tf.float32, [None, 10])          # one-hot labels paired with x and G_sample
t_false = tf.placeholder(tf.float32, [None, 10])    # one-hot labels paired with real x as a "fake" pair

generator = models.generator_C_WGAN
discriminator = models.discriminator_C_WGAN
# Alternative architectures:
#generator = models.generator_C_231
#discriminator = models.discriminator_C_WGAN_231
#discriminator = models.discriminator_C_DCGAN_231

G_sample = generator(z, t, reuse=False)
# Only the first discriminator call creates its variables; all later calls must reuse
# them so every critic score comes from the same network.
logits_real = discriminator(x, t, reuse=False)
logits_fake_1 = discriminator(G_sample, t, reuse=True)
logits_fake_2 = discriminator(x, t_false, reuse=True)

D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

In [None]:
def wgangp_loss(logits_real, logits_fake_1, logits_fake_2, batch_size, x, G_sample,
                t_real=None, lam=10):
    """Conditional WGAN-GP critic and generator losses.

    Args:
        logits_real: critic scores for real images with their labels.
        logits_fake_1: critic scores for generated images.
        logits_fake_2: critic scores for the second "fake" pairing
            (in this notebook: real images with the ``t_false`` labels).
        batch_size: unused; kept so the signature matches ``dcgan_loss``.
        x: real image batch tensor.
        G_sample: generated images, same shape as ``x``.
        t_real: label tensor used to condition the critic on the interpolated
            images; defaults to the notebook-global placeholder ``t``.
        lam: gradient-penalty weight.

    Returns:
        ``(D_loss, G_loss)`` as scalar tensors.
    """
    if t_real is None:
        t_real = t  # notebook-global label placeholder

    # Both fake terms are reduced so D_loss is a scalar.
    D_loss = (tf.reduce_mean(logits_fake_1) + tf.reduce_mean(logits_fake_2)) / 2 - tf.reduce_mean(logits_real)
    G_loss = -tf.reduce_mean(logits_fake_1)

    # One interpolation coefficient per example, broadcast over the non-batch dims:
    # shape [batch, 1, ..., 1].
    ndims = x.shape.ndims
    eps_shape = tf.concat((tf.shape(x)[0:1], tf.tile([1], [ndims - 1])), axis=0)
    eps = tf.random_uniform(shape=eps_shape, minval=0., maxval=1.)
    x_hat = x + eps * (G_sample - x)

    pred = discriminator(x_hat, t_real, reuse=True)
    # The penalty uses the gradient w.r.t. the interpolated input x_hat (not x).
    grad_D_x_hat = tf.gradients(pred, x_hat)[0]

    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad_D_x_hat), axis=list(range(1, ndims))))
    grad_pen = lam * tf.reduce_mean((grad_norm - 1.) ** 2)

    return D_loss + grad_pen, G_loss

def dcgan_loss(logits_real, logits_fake_1, logits_fake_2, batch_size, x, G_smaple):
    """Conditional DCGAN losses using sigmoid cross-entropy on the logits.

    The discriminator is trained to score ``logits_real`` as 1. It is trained to score
    ``logits_fake_1`` and ``logits_fake_2`` as 0, with each fake term weighted 1/2.
    The generator is trained to get ``logits_fake_1`` scored as 1.

    ``batch_size``, ``x`` and ``G_smaple`` (sic) are unused. They keep the signature
    interchangeable with ``wgangp_loss``.

    Returns:
        ``(D_loss, G_loss)`` as scalar tensors.
    """
    def mean_xent(logits, target_like):
        # Mean sigmoid cross-entropy against a constant target shaped like `logits`.
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=target_like(logits), logits=logits))

    real_term = mean_xent(logits_real, tf.ones_like)
    fake_gen_term = mean_xent(logits_fake_1, tf.zeros_like)
    fake_pair_term = mean_xent(logits_fake_2, tf.zeros_like)

    D_loss = real_term + fake_gen_term / 2 + fake_pair_term / 2
    G_loss = mean_xent(logits_fake_1, tf.ones_like)
    return D_loss, G_loss
    
# Choose the adversarial objective (DCGAN alternative kept for comparison).
#D_loss, G_loss = dcgan_loss(logits_real, logits_fake_1, logits_fake_2, batch_size, x, G_sample)
D_loss, G_loss = wgangp_loss(logits_real, logits_fake_1, logits_fake_2, batch_size, x, G_sample)

# Separate Adam optimizers (learning_rate=2e-4, beta1=0.5) for the critic and the
# generator, each restricted to its own variables.
D_train_step = tf.train.AdamOptimizer(2e-4, 0.5).minimize(D_loss, var_list=D_vars)
G_train_step = tf.train.AdamOptimizer(2e-4, 0.5).minimize(G_loss, var_list=G_vars)

# Ops registered in UPDATE_OPS under each scope (may be empty lists if the models
# register none).
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')

In [None]:
# Initialise all variables and train for 50 epochs with 5 critic steps per generator step.
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    run_a_gan(sess,
              G_train_step, G_loss,
              D_train_step, D_loss,
              G_extra_step, D_extra_step,
              batch_size=batch_size,
              num_epoch=50,
              n_critic=5)