# 课时15 cGAN(conditional GAN)代码实现

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# The notebook is written against the TensorFlow 1.x graph API
# (tf.placeholder, tf.layers, tf.Session); the recorded run used 1.13.1.
tf.__version__

'1.13.1'

## 1. 导入MNIST数据集

In [19]:
import os

from tensorflow.examples.tutorials.mnist import input_data

# Directory holding the MNIST .gz files. Override it with the MNIST_DIR
# environment variable; the default is the author's original local path.
# A raw string keeps the Windows backslashes from being parsed as (invalid)
# escape sequences.
MNIST_DIR = os.environ.get(
    'MNIST_DIR',
    r'E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST')
# one_hot=True: labels are one-hot vectors, matching the [None, 10]
# `condition_label` placeholder used by the model.
mnist = input_data.read_data_sets(MNIST_DIR, one_hot=True)

Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\train-images-idx3-ubyte.gz
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\train-labels-idx1-ubyte.gz
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\t10k-images-idx3-ubyte.gz
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\t10k-labels-idx1-ubyte.gz


## 2. 定义模型搭建需要的各个组件

In [20]:
# Input placeholders for the cGAN graph.
def get_inputs(noise_dim, image_height, image_width, image_depth):
    """Create the float32 placeholders fed at every training / sampling step.

    Args:
        noise_dim: length of each generator noise vector.
        image_height: height of the real images.
        image_width: width of the real images.
        image_depth: number of channels of the real images.

    Returns:
        Tuple ``(inputs_real, inputs_noise, condition_label)`` with shapes
        ``[None, image_height, image_width, image_depth]``,
        ``[None, noise_dim]`` and ``[None, 10]``.
    """
    image_shape = [None, image_height, image_width, image_depth]
    noise_shape = [None, noise_dim]
    # 10 columns because the MNIST labels are loaded with one_hot=True.
    label_shape = [None, 10]

    placeholders = (
        tf.placeholder(tf.float32, shape=image_shape, name='inputs_real'),
        tf.placeholder(tf.float32, shape=noise_shape, name='inputs_noise'),
        tf.placeholder(tf.float32, shape=label_shape, name='condition_label'),
    )
    return placeholders

In [21]:
# Generator: (noise, one-hot condition) -> 28 x 28 x output_dim image.
def generator(noise_img, output_dim, condition_label, is_train=True, keep_prob=0.8):
    """Build (or reuse) the conditional generator.

    Args:
        noise_img: noise tensor, shape [batch, noise_dim].
        output_dim: number of channels of the generated image (1 for MNIST).
        condition_label: one-hot condition tensor, shape [batch, 10] here.
        is_train: True builds the training copy: creates the variables
            (reuse=False), runs batch norm in training mode and applies
            dropout. False reuses the existing variables for sampling:
            batch norm uses its moving statistics and dropout is disabled.
        keep_prob: dropout keep probability used while training.

    Returns:
        Tensor of shape [batch, 28, 28, output_dim] with values in [-1, 1]
        (tanh output).
    """
    def _dropout(x):
        # Dropout is a training-time regulariser only. Applying it while
        # sampling would randomly zero a (1 - keep_prob) fraction of the
        # activations of every generated image.
        return tf.nn.dropout(x, keep_prob=keep_prob) if is_train else x

    with tf.variable_scope(name_or_scope='generator', reuse=(not is_train)):
        # Noise and condition are concatenated before entering the network.
        noise_img_ = tf.concat(values=[noise_img, condition_label], axis=1)

        # Fully connected: (noise_dim + n_classes) ===> 4 x 4 x 512
        layer_1 = tf.layers.dense(noise_img_, 4*4*512)
        layer_1 = tf.reshape(layer_1, [-1, 4, 4, 512])
        layer_1 = tf.layers.batch_normalization(layer_1, training=is_train)
        layer_1 = _dropout(tf.nn.relu(layer_1))

        # 4 x 4 x 512 ===> 7 x 7 x 256 (kernel 4, stride 1, 'valid')
        layer_2 = tf.layers.conv2d_transpose(inputs=layer_1, filters=256,
                                             kernel_size=4,
                                             strides=1, padding='valid')
        layer_2 = tf.layers.batch_normalization(layer_2, training=is_train)
        layer_2 = _dropout(tf.nn.relu(layer_2))

        # 7 x 7 x 256 ===> 14 x 14 x 128
        layer_3 = tf.layers.conv2d_transpose(inputs=layer_2, filters=128,
                                             kernel_size=3,
                                             strides=2, padding='same')
        layer_3 = tf.layers.batch_normalization(layer_3, training=is_train)
        layer_3 = _dropout(tf.nn.relu(layer_3))

        # 14 x 14 x 128 ===> 28 x 28 x output_dim
        logits = tf.layers.conv2d_transpose(inputs=layer_3, filters=output_dim,
                                            kernel_size=3,
                                            strides=2, padding='same')
        outputs = tf.tanh(logits)
        return outputs

In [22]:
# Discriminator: scores an image given its condition label.
def discriminator(inputs_img, condition_label, reuse=False, alpha=0.01):
    """Build (or reuse) the conditional discriminator.

    Args:
        inputs_img: images reshapeable to [batch, 28 * 28 * 1].
        condition_label: one-hot labels concatenated with the flattened image.
        reuse: reuse the variables of an earlier call (used for the fake branch).
        alpha: negative slope of the leaky ReLU.

    Returns:
        (logits, outputs): raw scores of shape [batch, 1] and their sigmoid.
    """
    def leaky(x):
        # leaky ReLU written as max(alpha * x, x)
        return tf.maximum(alpha * x, x)

    def down_block(x, n_filters, batch_norm):
        # 3x3, stride-2 'same' convolution: 28 -> 14 -> 7 -> 4 spatially.
        x = tf.layers.conv2d(inputs=x, filters=n_filters, kernel_size=3,
                             strides=2, padding='same')
        if batch_norm:
            x = tf.layers.batch_normalization(x, training=True)
        return tf.nn.dropout(leaky(x), keep_prob=0.8)

    with tf.variable_scope(name_or_scope='discriminator', reuse=reuse):
        # The image must be flattened before it can be joined with the label;
        # the joint vector is projected back to 28*28 and reshaped to an image.
        joint = tf.concat(values=[tf.reshape(inputs_img, (-1, 28*28*1)), condition_label],
                          axis=1)
        net = tf.reshape(tf.layers.dense(joint, 28*28*1), [-1, 28, 28, 1])

        net = down_block(net, 128, batch_norm=False)  # [28,28,1]   -> [14,14,128], no BN on first layer
        net = down_block(net, 256, batch_norm=True)   # [14,14,128] -> [7,7,256]
        net = down_block(net, 512, batch_norm=True)   # [7,7,256]   -> [4,4,512]

        logits = tf.layers.dense(tf.reshape(net, (-1, 4*4*512)), 1)
        return logits, tf.sigmoid(logits)

In [23]:
# Losses for the conditional GAN.
def get_loss(inputs_real, inputs_noise, condition_label, image_depth, smooth=0.1):
    """Build the generator and discriminator losses.

    Args:
        inputs_real: real images fed to the discriminator.
        inputs_noise: generator noise, shape [batch, noise_dim].
        condition_label: one-hot labels shared by the real and fake branches.
        image_depth: channel count, passed to the generator as `output_dim`.
        smooth: one-sided label smoothing; the discriminator's target for
            real images is `1 - smooth`.

    Returns:
        (g_loss, d_loss) scalar tensors (mean sigmoid cross-entropy).
    """
    g_outputs = generator(inputs_noise, image_depth, condition_label, is_train=True)
    d_logits_real, _ = discriminator(inputs_real, condition_label)
    d_logits_fake, _ = discriminator(g_outputs, condition_label, reuse=True)

    def _xent(logits, target):
        # Mean sigmoid cross-entropy against a constant target probability.
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=tf.ones_like(logits) * target))

    # The generator wants its samples judged fully real (target 1). Smoothing
    # belongs only on the discriminator's real targets: with a 0.9 target here
    # the gradient (sigmoid(logit) - 0.9) would push the generator towards
    # *less* convincing samples as soon as D(G(z)) exceeds 0.9.
    g_loss = _xent(d_logits_fake, 1.0)

    d_loss_real = _xent(d_logits_real, 1.0 - smooth)
    d_loss_fake = _xent(d_logits_fake, 0.0)
    d_loss = tf.add(d_loss_real, d_loss_fake)

    return g_loss, d_loss

In [24]:
# Adam optimisers for the two networks.
def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    """Create one Adam training op per network.

    Each op updates only the trainable variables whose names start with its
    network's scope prefix ('generator' / 'discriminator').

    Args:
        g_loss: generator loss tensor.
        d_loss: discriminator loss tensor.
        beta1: Adam first-moment decay rate.
        learning_rate: Adam step size.

    Returns:
        (g_opt, d_opt) training ops.
    """
    all_vars = tf.trainable_variables()

    def scoped(prefix):
        return [v for v in all_vars if v.name.startswith(prefix)]

    g_vars, d_vars = scoped('generator'), scoped('discriminator')

    # Every step also runs the ops registered in UPDATE_OPS (in this notebook,
    # the batch-normalization moving-statistics updates).
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)

    return g_opt, d_opt

In [25]:
# Plotting helper used to show the generator's samples during and after training.
def plot_image(samples):
    """Plot generator samples side by side in one row and display the figure.

    Args:
        samples: sequence of generator outputs in [-1, 1] (tanh range), each
            reshapeable to 28 x 28. Any number of samples is supported; the
            figure is 2 inches wide per image (50 inches for 25 images).
    """
    # The generator ends in tanh ([-1, 1]); map to [0, 1] for imshow.
    images = (np.asarray(samples) + 1) / 2
    n_images = len(images)
    if n_images == 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single image, so axes[0] is
    # always iterable (plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(nrows=1, ncols=n_images, sharex=True,
                             sharey=True, figsize=(2 * n_images, 2),
                             squeeze=False)
    for img, ax in zip(images, axes[0]):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
    # Display now rather than when the cell finishes: train() calls this
    # repeatedly inside one long cell, which would otherwise keep dozens of
    # figures open until the end.
    plt.show()

In [26]:
# Sample images from the generator.
def show_generator_output(sess, n_images, inputs_noise, output_dim, condition_label,
                          labels=None):
    """Generate `n_images` images conditioned on one-hot labels.

    The sampling sub-graph (generator with is_train=False) is built once per
    graph and cached in a graph collection. Calling generator() on every
    invocation would add a new copy of the generator to the graph each time
    this is called from the training loop.

    Args:
        sess: session whose graph holds the trained generator variables.
        n_images: number of images to generate.
        inputs_noise: noise placeholder, shape [None, noise_dim].
        output_dim: channel count of the generated images.
        condition_label: one-hot label placeholder.
        labels: optional one-hot labels, one row per image. Defaults to
            mnist.train.labels[50:50 + n_images] (i.e. [50:75] for 25 images).

    Returns:
        numpy array of samples with the trailing channel axis squeezed.
    """
    noise_dim = inputs_noise.get_shape().as_list()[-1]
    # Uniform noise in [-1, 1], one row per image.
    example_noise = np.random.uniform(-1, 1, size=[n_images, noise_dim])
    if labels is None:
        # Size the slice by n_images so noise and labels always have the same
        # number of rows (tf.concat in the generator requires it).
        labels = mnist.train.labels[50:50 + n_images]

    graph = sess.graph
    cache_key = 'cgan_generator_samples/%s/%s/%s' % (
        inputs_noise.name, condition_label.name, output_dim)
    cached = graph.get_collection(cache_key)
    if cached:
        sample_op = cached[0]
    else:
        with graph.as_default():
            sample_op = generator(inputs_noise, output_dim, condition_label, False)
        graph.add_to_collection(cache_key, sample_op)

    samples = sess.run(sample_op,
                       feed_dict={inputs_noise: example_noise,
                                  condition_label: labels})
    return np.squeeze(samples, -1)

In [27]:
# Hyper-parameters (read by train() and passed to get_optimizer / show_generator_output).
noise_size = 100       # length of the generator's noise vector
batch_size = 64        # images per training step
n_samples = 25         # images generated for each preview during training
epochs = 5             # passes over mnist.train
beta1 = 0.4            # Adam first-moment decay
learning_rate = 0.01   # Adam step size

In [35]:
def train(noise_size, data_shape, batch_size, n_samples,
          checkpoint_path='../tf_saver_files/class_6_cGANS/generator.ckpt'):
    """Build the cGAN graph in the current default graph and train it on MNIST.

    Reads the notebook-level `epochs`, `learning_rate`, `beta1` and `mnist`.
    Every 101 steps it saves a checkpoint, records the losses, plots
    `n_samples` generated images and prints the losses. It also saves a
    checkpoint at the end of each epoch.

    Args:
        noise_size: length of the generator noise vector.
        data_shape: [batch, height, width, channels]; only indices 1..3 are used.
        batch_size: images per training step.
        n_samples: number of images generated for each preview.
        checkpoint_path: checkpoint prefix passed to tf.train.Saver.save.

    Returns:
        List of (discriminator_loss, generator_loss) tuples, one per logging step.
    """
    import os

    losses = []
    step = 0

    inputs_real, inputs_noise, condition_label = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
    g_loss, d_loss = get_loss(inputs_real, inputs_noise, condition_label, data_shape[-1])
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)

    # tf.train.Saver.save refuses to write into a directory that does not exist.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            for _ in range(mnist.train.num_examples // batch_size):
                step += 1
                batch_images_flat, batch_labels = mnist.train.next_batch(batch_size)
                batch_images = batch_images_flat.reshape((batch_size,
                                                          data_shape[1],
                                                          data_shape[2],
                                                          data_shape[3]))
                # Rescale to [-1, 1] to match the generator's tanh output
                # (assumes read_data_sets' default [0, 1] pixel values).
                batch_images = batch_images*2 - 1

                # Generator input noise.
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                feed = {inputs_real: batch_images,
                        inputs_noise: batch_noise,
                        condition_label: batch_labels}

                # One generator step, then one discriminator step, on the same batch.
                sess.run(g_train_opt, feed_dict=feed)
                sess.run(d_train_opt, feed_dict=feed)

                if step % 101 == 0:
                    saver.save(sess, checkpoint_path)
                    # Evaluate both losses in a single run (shared forward pass).
                    train_loss_d, train_loss_g = sess.run([d_loss, g_loss], feed_dict=feed)
                    losses.append((train_loss_d, train_loss_g))

                    # Preview generated images.
                    samples = show_generator_output(sess, n_samples,
                                                    inputs_noise, data_shape[-1],
                                                    condition_label)
                    plot_image(samples)

                    print('Epoch is %i/%i'%(e+1, epochs),
                          'Discriminator Loss is %.3f'%(train_loss_d),
                          ', Generator Loss is %.3f'%(train_loss_g))
            saver.save(sess, checkpoint_path)
    return losses

In [36]:
# Build and train the model in a fresh graph, separate from the default graph
# used by the restore cells below.
training_graph = tf.Graph()
with training_graph.as_default():
    train(noise_size=noise_size,
          data_shape=[-1, 28, 28, 1],
          batch_size=batch_size,
          n_samples=n_samples)

KeyboardInterrupt: 

In [None]:
# Batch dimension left free (-1), then height, width, channels of an MNIST image.
data_shape = [-1] + [28, 28, 1]

In [None]:
# Rebuild the graph exactly as train() does (get_inputs / get_loss /
# get_optimizer) so the Saver below sees the same variable names as the
# checkpoint. Resetting first makes this cell safe to re-run; otherwise the
# second run fails because `generator/...` variables already exist in the
# default graph.
tf.reset_default_graph()
inputs_real, inputs_noise, condition_label = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
g_loss, d_loss = get_loss(inputs_real, inputs_noise, condition_label, data_shape[-1])
g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)

In [None]:
# tf.train.latest_checkpoint expects the checkpoint *directory* (the one that
# holds the `checkpoint` index file), not the `.ckpt` prefix. Given the prefix
# it returns None, and saver.restore(sess, None) raises.
checkpoint_dir = '../tf_saver_files/class_6_cGANS'
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# Generate 25 images from the restored generator and plot them in one row.
samples = show_generator_output(sess, 25, inputs_noise, data_shape[-1], condition_label)
plot_image(samples)  # the helper defined above is `plot_image` (there is no `plot_images`)

In [None]:
# Digit classes of mnist.train.labels[50:75], the labels the samples above are conditioned on.
mnist.train.labels[50:75].argmax(axis=1)

In [None]:
# Each run of the two lines below generates digits for the same labels
# (mnist.train.labels[50:75]); the noise is redrawn every time, so the
# handwriting style of the generated digits varies between runs.
samples = show_generator_output(sess, 25, inputs_noise, data_shape[-1], condition_label)
plot_image(samples)  # the helper defined above is `plot_image` (there is no `plot_images`)