In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plot
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
batch_size = 32
X_dim = 784
c_dim = 10
h_dim = 1024
d_steps = 5
lamda_cls = 0.1
lamda_rec = 0.1

In [3]:
mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
def plot_images(samples):
    fig = plot.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plot.subplot(gs[i])
        plot.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plot.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [5]:
def discriminator(X, reuse=False):
    with tf.variable_scope("discriminator", reuse=reuse):
        h1 = tf.layers.dense(X, 1024, tf.nn.relu, name="disc_common")

        out_disc = tf.layers.dense(h1, 1, tf.nn.sigmoid, name="disc_disc")
        out_class = tf.layers.dense(h1, c_dim, tf.nn.softmax, name="disc_class")

        return out_disc, out_class

In [6]:
def generator(x, c, reuse=False):
    with tf.variable_scope("generator", reuse=reuse):
        inp = tf.concat([x, c], axis=1)
        h1 = tf.layers.dense(inp, 1024, tf.nn.relu, name="gen_h")
        out = tf.layers.dense(h1, 784, tf.nn.sigmoid, name="gen_out")

        return out

In [7]:
real_image = tf.placeholder(dtype=tf.float32, shape=[None, 784])
real_labels = tf.placeholder(dtype=tf.float32, shape=[None, c_dim])
fake_labels = tf.placeholder(dtype=tf.float32, shape=[None, c_dim])
alpha = tf.placeholder(dtype=tf.float32, shape=[None, 1])

In [8]:
fake_image = generator(real_image, fake_labels, False)
real_disc, real_class = discriminator(real_image)
fake_disc, fake_class = discriminator(fake_image, True)
rec_image = generator(fake_image, real_labels, True)
interpolated = alpha * real_image + (1 - alpha) * fake_image
int_disc, int_cls = discriminator(interpolated, True)

In [9]:
gen_loss_fake = tf.reduce_mean(fake_disc)
gen_loss_rec = tf.reduce_mean(tf.abs(real_image - rec_image))
gen_loss_class = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_class, labels=fake_labels))

gen_loss = gen_loss_fake + lamda_rec * gen_loss_rec + lamda_cls * gen_loss_class

In [10]:
disc_loss_real = tf.reduce_mean(real_disc)
disc_loss_cls = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_class, labels=real_labels))
disc_loss_fake = tf.reduce_mean(fake_disc)

disc_loss = disc_loss_real + disc_loss_fake + lamda_cls * disc_loss_cls
grads = tf.gradients(int_disc, [interpolated])
grad_penalty = tf.reduce_mean(tf.square(tf.norm(grads[0], ord=2) - 1.0))

disc_grad_penalty = 10 * grad_penalty

In [11]:
def generate_fake_label(batch_size):
    idx = np.random.randint(0, 10)
    c = np.zeros([batch_size, c_dim])
    c[range(batch_size), idx] = 1
    return c

In [12]:
def get_alpha(batch_size):
    return np.random.rand(batch_size, 1)

In [13]:
all_vars = tf.trainable_variables()
generator_vars = [var for var in all_vars if var.name.startswith('gen')]
discriminator_vars = [var for var in all_vars if var.name.startswith('disc')]

d_optim = tf.train.AdamOptimizer(0.001).minimize(disc_loss, var_list=discriminator_vars)
d_grad_pen = tf.train.AdamOptimizer(0.001).minimize(disc_grad_penalty, var_list=discriminator_vars)

g_optim = tf.train.AdamOptimizer(0.001).minimize(gen_loss, var_list=generator_vars)

In [14]:
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()

sess.run(init)

In [15]:
for epoch in range(10):
    for it in range(50000 // 32):
        for i in range(d_steps):
            x_batch, y_batch = mnist.train.next_batch(32)
            _, d_loss = sess.run([d_optim, disc_loss], feed_dict={
                real_image: x_batch,
                real_labels: y_batch,
                fake_labels: generate_fake_label(32),
            })
            
            sess.run([d_grad_pen], feed_dict={
                real_image: x_batch,
                real_labels: y_batch,
                fake_labels: generate_fake_label(32),
                alpha: get_alpha(32)
            })
            
        x_batch, y_batch = mnist.train.next_batch(32)
        _, g_loss = sess.run([g_optim, gen_loss], feed_dict={
            real_image: x_batch,
            real_labels: y_batch,
            fake_labels: generate_fake_label(32),
        })
        if it % 100 == 0:
            print('Epoch: {}, Iteration: {}: G: {} ; D: {}'.format(epoch, it, g_loss, d_loss))
        

#             if it % 1000 == 0:
#                 x_batch, y_batch = mnist.train.next_batch(16)
        
#                 c = generate_fake_label(16)
#                 print(c[0])
#                 samples = sess.run(fake_image, feed_dict={
#                     real_image: x_batch,
#                     fake_labels: c
#                 })

#                 fig = plot_images(samples)
#                 plot.show()
#                 plot.close(fig)


Epoch: 0, Iteration: 0: G: 0.1722393035888672 ; D: 0.5331546664237976
Epoch: 0, Iteration: 100: G: 0.10187432169914246 ; D: 0.06894289702177048
Epoch: 0, Iteration: 200: G: 0.09634210914373398 ; D: 0.06784749776124954
Epoch: 0, Iteration: 300: G: 0.08900151401758194 ; D: 0.0697457492351532
Epoch: 0, Iteration: 400: G: 0.08864057064056396 ; D: 0.06923242658376694
Epoch: 0, Iteration: 500: G: 0.08821352571249008 ; D: 0.06893336027860641
Epoch: 0, Iteration: 600: G: 0.0871540755033493 ; D: 0.06766389310359955
Epoch: 0, Iteration: 700: G: 0.07781060039997101 ; D: 0.06948808580636978
Epoch: 0, Iteration: 800: G: 0.0864066630601883 ; D: 0.06948162615299225
Epoch: 0, Iteration: 900: G: 0.08663106709718704 ; D: 0.06974824517965317
Epoch: 0, Iteration: 1000: G: 0.07768114656209946 ; D: 0.06820397824048996
Epoch: 0, Iteration: 1100: G: nan ; D: nan


KeyboardInterrupt: 