In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plot
from tensorflow.examples.tutorials.mnist import input_data
from helper import *

In [2]:
batch_size = 32
X_dim = 784
c_dim = 10
h_dim = 1024
d_steps = 5
lamda_cls = 0.1
lamda_rec = 0.1

In [3]:
mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
def plot_images(samples):
    fig = plot.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plot.subplot(gs[i])
        plot.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plot.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [5]:
def discriminator(x, reuse=False):
    with tf.variable_scope("discriminator"):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False
        
        x = tf.reshape(x, [-1, 28, 28, 1])
        print(x.get_shape())
        
        x = lrelu(conv2d(x, 64, kernel_size=4, strides=[1, 2, 2, 1], name='disc_conv_1'))
        print(x.get_shape())
        x = lrelu(conv2d(x, 128, kernel_size=4, strides=[1, 2, 2, 1], name='disc_conv_2'))
        print(x.get_shape())
        x = lrelu(conv2d(x, 256, kernel_size=4, strides=[1, 2, 2, 1], name='disc_conv_3'))
        print(x.get_shape())
        
        x = conv2d(x, 1+10, kernel_size=1, strides=[1, 1, 1, 1], name='disc_conv_4')
        print(x.get_shape())
        x = tf.reshape(tf.reduce_mean(x,axis=[1, 2]), [-1, 1 + 10])
        print(x.get_shape())
        src = x[:,0]
        cls = x[:,1:]
        return src, cls

In [6]:
def generator(x, c, reuse=False):
    with tf.variable_scope("generator"):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False
        
        c = tf.reshape(c, [-1, 1, 1, 10])
        
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = tf.concat([x, c], axis=3)
        print(x.get_shape())
        
        x = relu(instance_norm(conv2d(x, 64, kernel_size=7, strides=[1, 1, 1, 1], 
                                      name='gen_ds_conv1'), 'in1_1'))
        print(x.get_shape())
        x = relu(instance_norm(conv2d(x, 128, kernel_size=4,strides=[1, 2, 2, 1], 
                                      name='gen_ds_conv2'), 'in1_2'))
        print(x.get_shape())
        x = relu(instance_norm(conv2d(x, 256, kernel_size=4, strides=[1, 2, 2, 1], 
                                      name='gen_ds_conv3'), 'in1_3'))
        print(x.get_shape())
        
        x = relu(instance_norm(conv2d(x, 256, kernel_size=3, strides=[1, 1, 1, 1], 
                                      name='gen_bn_conv1'), 'in2_1'))
        print(x.get_shape())
        x = relu(instance_norm(conv2d(x, 256, kernel_size=3, strides=[1, 1, 1, 1], 
                                      name='gen_bn_conv2'), 'in2_2'))
        print(x.get_shape())
        x = relu(instance_norm(conv2d(x, 256, kernel_size=3, strides=[1, 1, 1, 1], 
                                      name='gen_bn_conv3'), 'in2_3'))
        print(x.get_shape())
        
        x = relu(instance_norm(deconv_2d(x, [-1, 14, 14, 128], kernel_size=4, strides=[1, 2, 2, 1], 
                                        name='gen_us_deconv1'), 'in3_1'))
        print(x.get_shape())
        x = relu(instance_norm(deconv_2d(x, [-1, 28, 28, 64], kernel_size=4, strides=[1, 2, 2, 1], 
                                        name='gen_us_deconv2'), 'in3_2'))
        print(x.get_shape())
        x = tanh(deconv_2d(x, [-1, 28, 28, 1], kernel_size=7, strides=[1, 1, 1, 1], 
                          name='gen_us_dwconv3'))
        print(x.get_shape())
        
        return x

In [7]:
real_image = tf.placeholder(dtype=tf.float32, shape=[None, 784])
real_labels = tf.placeholder(dtype=tf.float32, shape=[None, c_dim])
fake_labels = tf.placeholder(dtype=tf.float32, shape=[None, c_dim])
alpha = tf.placeholder(dtype=tf.float32, shape=[None, 1])

In [8]:
print('fake_img')
fake_image = generator(real_image, fake_labels, False)
print('real_disc')
real_disc, real_class = discriminator(real_image)
print('fake_disc')
fake_disc, fake_class = discriminator(fake_image, True)
print('rec_image')
rec_image = generator(fake_image, real_labels, True)
print('interpolated')
interpolated = alpha * real_image + (1 - alpha) * fake_image
print('disc_int')
int_disc, int_cls = discriminator(interpolated, True)

fake_img


ValueError: Dimension 1 in both shapes must be equal, but are 28 and 1 for 'generator/concat' (op: 'ConcatV2') with input shapes: [?,28,28,1], [?,1,1,10], [] and with computed input tensors: input[2] = <3>.

In [None]:
gen_loss_fake = tf.reduce_mean(fake_disc)
gen_loss_rec = tf.reduce_mean(tf.abs(real_image - rec_image))
gen_loss_class = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_class, labels=fake_labels))

gen_loss = gen_loss_fake + lamda_rec * gen_loss_rec + lamda_cls * gen_loss_class

In [None]:
disc_loss_real = tf.reduce_mean(real_disc)
disc_loss_cls = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_class, labels=real_labels))
disc_loss_fake = tf.reduce_mean(fake_disc)

grads = tf.gradients(int_disc, [interpolated])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1,]))
grad_penalty = tf.reduce_mean(tf.square(slopes - 1.))
    
disc_loss = disc_loss_real + disc_loss_fake + 10 * grad_penalty + lamda_cls * disc_loss_cls

In [None]:
def generate_fake_label(batch_size):
    idx = np.random.randint(0, 10)
    c = np.zeros([batch_size, c_dim])
    c[range(batch_size), idx] = 1
    print(c.shape)
    return c

In [None]:
def get_alpha(batch_size):
    return np.random.rand(batch_size, 1)

In [None]:
all_vars = tf.trainable_variables()
generator_vars = [var for var in all_vars if var.name.startswith('gen')]
discriminator_vars = [var for var in all_vars if var.name.startswith('disc')]

g_optim = tf.train.AdamOptimizer(0.001).minimize(gen_loss, var_list=generator_vars)

d_optim = tf.train.AdamOptimizer(0.001).minimize(disc_loss, var_list=discriminator_vars)

In [None]:
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()

sess.run(init)

In [None]:
for epoch in range(10):
    for it in range(50000 // 32):
        for i in range(d_steps):
            x_batch, y_batch = mnist.train.next_batch(32)
            print(x_batch.shape)
            _, d_loss = sess.run([d_optim, disc_loss], feed_dict={
                real_image: x_batch,
                real_labels: y_batch,
                fake_labels: generate_fake_label(32),
                alpha: get_alpha(32)
            })
            
        x_batch, y_batch = mnist.train.next_batch(32)
        _, g_loss = sess.run([g_optim, gen_loss], feed_dict={
            real_image: x_batch,
            real_labels: y_batch,
            fake_labels: generate_fake_label(32),
        })
        if it % 100 == 0:
            print('Epoch: {}, Iteration: {}: G: {} ; D: {}'.format(epoch, it, g_loss, d_loss))
        

            if it % 1000 == 0:
                x_batch, y_batch = mnist.train.next_batch(16)
        
                c = generate_fake_label(16)
                print(c[0])
                samples = sess.run(fake_image, feed_dict={
                    real_image: x_batch,
                    fake_labels: c
                })

                fig = plot_images(samples)
                plot.show()
                plot.close(fig)
