# 课时17 infoGAN代码实现

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.__version__

'1.13.1'

## 1. 导入MNIST数据集

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot labels. The Windows path is written as a raw string so
# that its backslashes are never parsed as escape sequences (the plain-string form
# relied on invalid escapes such as "\S" and "\P", which Python warns about).
# NOTE(review): this is an absolute, machine-specific path - adjust it locally.
mnist = input_data.read_data_sets(r'E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST', one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\t10k-images-idx3-ubyte.gz
Extracting E:\SoftWare_Installing\Pycharm\Pycharm WorkPlace\GAN生成对抗网络入门与实战\data\MNIST\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


## 2. 定义模型搭建需要的各个组件

In [2]:
# Graph inputs
def get_inputs(noise_dim, image_height, image_width, image_depth):
    """Create the three float32 placeholders that feed the infoGAN graph.

    Args:
        noise_dim: length of the noise vector given to the generator.
        image_height: height of one real image.
        image_width: width of one real image.
        image_depth: number of channels of one real image.

    Returns:
        Tuple ``(inputs_real, inputs_noise, condition_label)`` of placeholders
        with a free (``None``) batch dimension.
    """
    real_shape = [None, image_height, image_width, image_depth]
    inputs_real = tf.placeholder(tf.float32, real_shape, name='inputs_real')

    inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')

    # condition_label is a one-hot code over 10 categories, hence shape [None, 10].
    # Unlike cGAN, infoGAN does not take this code from the MNIST labels: we sample
    # it ourselves and train so that the mutual information between the code and
    # the generated image increases.
    # In this demo the code is meant to control the digit class. infoGAN can also
    # use several codes (e.g. stroke thickness or slant); such codes need not have
    # 10 dimensions - [None, 50], [None, 100], ... are equally possible.
    condition_label = tf.placeholder(tf.float32, [None, 10], name='condition_label')

    return inputs_real, inputs_noise, condition_label

In [16]:
# Generator network
def generator(noise_img, output_dim, condition_label, is_train=True):
    """Map a noise vector plus a latent code to an image in [-1, 1].

    Args:
        noise_img: noise tensor of shape [batch, noise_dim].
        output_dim: number of channels of the generated image.
        condition_label: latent code tensor, concatenated to the noise on axis 1.
        is_train: used as the batch-norm ``training`` flag; the variable scope is
            opened with ``reuse=(not is_train)``, so ``is_train=False`` re-uses
            variables created by an earlier call.

    Returns:
        tanh-activated image tensor. With the layer settings below this is
        28 x 28 x ``output_dim``.

    Note:
        Dropout (keep_prob=0.8) is applied on every call, independent of
        ``is_train``.
    """
    def _bn_relu_dropout(tensor):
        # batch norm -> ReLU -> dropout, shared by the three hidden layers
        normalized = tf.layers.batch_normalization(tensor, training=is_train)
        return tf.nn.dropout(tf.nn.relu(normalized), keep_prob=0.8)

    with tf.variable_scope('generator', reuse=not is_train):
        # The noise and the latent code are merged before entering the network.
        latent = tf.concat([noise_img, condition_label], axis=1)

        # Fully connected layer, reshaped to 4 x 4 x 512
        hidden = tf.reshape(tf.layers.dense(latent, 4 * 4 * 512), [-1, 4, 4, 512])
        hidden = _bn_relu_dropout(hidden)

        # 4 x 4 x 512 ===> 7 x 7 x 256
        hidden = tf.layers.conv2d_transpose(hidden, 256, 4, strides=1, padding='valid')
        hidden = _bn_relu_dropout(hidden)

        # 7 x 7 x 256 ===> 14 x 14 x 128
        hidden = tf.layers.conv2d_transpose(hidden, 128, 3, strides=2, padding='same')
        hidden = _bn_relu_dropout(hidden)

        # 14 x 14 x 128 ===> 28 x 28 x output_dim
        logits = tf.layers.conv2d_transpose(hidden, output_dim, 3, strides=2, padding='same')
        return tf.tanh(logits)

In [17]:
# Discriminator network (unlike the cGAN discriminator, the infoGAN one does not
# receive condition_label)
def discriminator(inputs_img, reuse=False, alpha=0.01):
    """Score images as real or fake.

    Args:
        inputs_img: image batch; the final reshape assumes the feature map is
            4 x 4 x 512, which holds for 28 x 28 inputs.
        reuse: passed to the ``discriminator`` variable scope.
        alpha: negative-side slope of the leaky ReLU, implemented as
            ``max(alpha * x, x)``.

    Returns:
        Tuple ``(logits, outputs)`` where ``outputs = sigmoid(logits)``.

    Note:
        Batch norm is hard-wired to ``training=True`` and dropout
        (keep_prob=0.8) is always active.
    """
    def _leaky_relu_dropout(tensor):
        return tf.nn.dropout(tf.maximum(alpha * tensor, tensor), keep_prob=0.8)

    with tf.variable_scope('discriminator', reuse=reuse):
        # [28, 28, 1] ===> [14, 14, 128]; the first layer has no batch norm
        features = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
        features = _leaky_relu_dropout(features)

        # [14, 14, 128] ===> [7, 7, 256] ===> [4, 4, 512]
        for n_filters in (256, 512):
            features = tf.layers.conv2d(features, n_filters, 3, strides=2, padding='same')
            features = tf.layers.batch_normalization(features, training=True)
            features = _leaky_relu_dropout(features)

        # Flatten [4, 4, 512] and project to a single logit per image
        flat = tf.reshape(features, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flat, 1)
        return logits, tf.sigmoid(logits)

In [3]:
# This is what sets infoGAN apart: an additional Q network.
# Its purpose is to strengthen, through training, the mutual information between
# the generated image and condition_label.
# In practice Q is itself a classifier: it looks at a generated image and predicts
# which condition_label produced it.
# Because Q has the same structure as the discriminator, the two could be merged
# into one network with an extra output head.
def get_Q(g_out, reuse=False, alpha=0.01):
    """Predict the latent code from a generated image.

    Args:
        g_out: image batch produced by the generator; the final reshape assumes
            a 4 x 4 x 512 feature map, which holds for 28 x 28 inputs.
        reuse: passed to the ``Q`` variable scope.
        alpha: negative-side slope of the leaky ReLU (``max(alpha * x, x)``).

    Returns:
        Softmax probabilities of shape [batch, 10].

    Note:
        Batch norm is hard-wired to ``training=True`` and dropout
        (keep_prob=0.8) is always active.
    """
    def _leaky_relu_dropout(tensor):
        return tf.nn.dropout(tf.maximum(alpha * tensor, tensor), keep_prob=0.8)

    with tf.variable_scope("Q", reuse=reuse):
        # 28 x 28 x 1 to 14 x 14 x 128; the first layer has no batch norm
        hidden = tf.layers.conv2d(g_out, 128, 3, strides=2, padding='same')
        hidden = _leaky_relu_dropout(hidden)

        # 14 x 14 x 128 to 7 x 7 x 256, then 7 x 7 x 256 to 4 x 4 x 512
        for n_filters in (256, 512):
            hidden = tf.layers.conv2d(hidden, n_filters, 3, strides=2, padding='same')
            hidden = tf.layers.batch_normalization(hidden, training=True)
            hidden = _leaky_relu_dropout(hidden)

        flat = tf.reshape(hidden, (-1, 4 * 4 * 512))

        # In this demo the latent code is meant to select the digit class, so the
        # head has 10 units and uses the multi-class softmax activation.
        logits = tf.layers.dense(flat, 10)
        return tf.nn.softmax(logits)

In [4]:
# Loss construction
def get_loss(inputs_real, inputs_noise, condition_label, image_depth, smooth=0.1):
    """Wire generator, discriminator and Q together and build the three losses.

    Args:
        inputs_real: placeholder holding real images.
        inputs_noise: placeholder holding generator noise.
        condition_label: placeholder holding the latent code.
        image_depth: channel count passed to the generator as its output_dim.
        smooth: label-smoothing amount; "real" targets become ``1 - smooth``
            (applied to both the generator loss and the real-image
            discriminator loss).

    Returns:
        Tuple ``(g_loss, d_loss, q_loss)`` of scalar tensors.
    """
    fake_images = generator(inputs_noise, image_depth, condition_label, is_train=True)
    real_logits, real_probs = discriminator(inputs_real)
    fake_logits, fake_probs = discriminator(fake_images, reuse=True)
    q_probs = get_Q(fake_images)

    def _mean_sigmoid_xent(logits, labels):
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    # Generator: wants the discriminator to call its images real.
    g_loss = _mean_sigmoid_xent(fake_logits, tf.ones_like(fake_probs) * (1 - smooth))

    # Discriminator: real images -> (smoothed) 1, generated images -> 0.
    d_loss_real = _mean_sigmoid_xent(real_logits, tf.ones_like(real_probs) * (1 - smooth))
    d_loss_fake = _mean_sigmoid_xent(fake_logits, tf.zeros_like(fake_probs))
    d_loss = tf.add(d_loss_real, d_loss_fake)

    # Q should recover the condition_label that was fed to the generator, so its
    # loss is an ordinary classification (cross-entropy) loss.
    # The 1e-8 keeps tf.log finite when a predicted probability is 0.
    per_example_xent = -tf.reduce_sum(tf.log(q_probs + 1e-8) * condition_label, 1)
    q_loss = tf.reduce_mean(per_example_xent)

    return g_loss, d_loss, q_loss

In [29]:
# Optimizers
def get_optimizer(g_loss, d_loss, q_loss, beta1=0.4, learning_rate=0.001):
    """Create one Adam training op per loss.

    Args:
        g_loss: generator loss tensor.
        d_loss: discriminator loss tensor.
        q_loss: Q-network loss tensor.
        beta1: Adam ``beta1``.
        learning_rate: Adam learning rate.

    Returns:
        Tuple ``(g_opt, d_opt, q_opt)`` of training ops.
    """
    trainable = tf.trainable_variables()

    def _scoped(prefix):
        # Variables are assigned to a network by the prefix of their name.
        return [variable for variable in trainable if variable.name.startswith(prefix)]

    g_vars = _scoped("generator")
    d_vars = _scoped("discriminator")
    q_vars = _scoped("Q")

    def _adam():
        return tf.train.AdamOptimizer(learning_rate, beta1=beta1)

    # Every training op waits for the ops registered under UPDATE_OPS.
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        g_opt = _adam().minimize(g_loss, var_list=g_vars)
        d_opt = _adam().minimize(d_loss, var_list=d_vars)
        # Note: the Q loss updates the generator variables as well as Q's own.
        q_opt = _adam().minimize(q_loss, var_list=g_vars + q_vars)

    return g_opt, d_opt, q_opt

In [21]:
# Plotting helper used to display the images produced by the generator
def plot_image(samples):
    """Show up to 25 generated digits side by side in a single row.

    Args:
        samples: array of generator outputs in [-1, 1]; each entry must be
            reshapeable to 28 x 28. Entries beyond the 25th are ignored because
            the figure has 25 columns.
    """
    # The generator ends in tanh, so its output lies in [-1, 1]; shift and scale
    # it to [0, 1] for display.
    samples = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True,
                             sharey=True, figsize=(50, 2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
    # Render now so that figures appear next to the matching log line and are not
    # left open (one figure is created per call, and training calls this often).
    plt.show()


# The rest of the notebook calls this helper as `plot_images`; keep both names
# valid so those calls no longer raise NameError.
plot_images = plot_image

In [32]:
# Hyper-parameters
epochs = 7             # passes over the MNIST training set (read as a global by train)
batch_size = 64        # examples per optimisation step
noise_size = 100       # length of the generator noise vector
n_samples = 25         # passed to train() as its n_samples argument
learning_rate = 0.001  # Adam learning rate
beta1 = 0.4            # Adam beta1

In [40]:
def train(noise_size, data_shape, batch_size, n_samples):
    """Build the infoGAN graph in the current default graph and train it.

    Besides its arguments, this function reads the module-level names ``mnist``,
    ``epochs``, ``beta1`` and ``learning_rate``.

    Every 101 steps it saves a checkpoint, records the three losses, and plots
    images generated from the first ``n_samples`` noise/code rows of the current
    batch. A checkpoint is also written at the end of every epoch.

    Args:
        noise_size: length of the generator noise vector.
        data_shape: ``[batch, height, width, depth]``; index 0 is not used.
        batch_size: examples per optimisation step.
        n_samples: number of images generated for each preview plot (capped by
            ``batch_size`` through slicing).

    Returns:
        List of ``(d_loss, g_loss, q_loss)`` tuples, one per logging step.
    """
    import os  # local import: only needed to create the checkpoint directory

    # One checkpoint location for both the periodic and the end-of-epoch save;
    # it is the location the restore cell of this notebook reads from.
    checkpoint_path = "../tf_saver_files/class_7_of_infoGAN/generator.ckpt"

    # Recorded losses
    losses = []
    steps = 0

    inputs_real, inputs_noise, condition_label = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
    g_loss, d_loss, q_loss = get_loss(inputs_real, inputs_noise, condition_label, data_shape[-1])
    g_train_opt, d_train_opt, q_train_opt = get_optimizer(g_loss, d_loss, q_loss, beta1, learning_rate)

    # Build the preview op ONCE. Calling generator(...) inside the loop would add
    # new nodes to the graph at every logging step.
    sample_op = generator(inputs_noise, data_shape[-1], condition_label, is_train=False)

    saver = tf.train.Saver()
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Iterate over epochs
        for e in range(epochs):
            for batch_i in range(mnist.train.num_examples // batch_size):
                steps += 1
                # The MNIST labels are not used: infoGAN samples its own code.
                batch_images_, _ = mnist.train.next_batch(batch_size)

                batch_images = batch_images_.reshape((batch_size, data_shape[1], data_shape[2], data_shape[3]))
                # Rescale to [-1, 1] to match the generator's tanh output
                # (assumes the pixels arrive in [0, 1] - TODO confirm).
                batch_images = batch_images * 2 - 1

                # Noise and a uniformly sampled one-hot latent code
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                c_labels = np.random.multinomial(1, 10 * [0.1], size=batch_size)

                feed = {inputs_real: batch_images,
                        inputs_noise: batch_noise,
                        condition_label: c_labels}

                # One update each for generator, discriminator and Q
                sess.run(g_train_opt, feed_dict=feed)
                sess.run(d_train_opt, feed_dict=feed)
                sess.run(q_train_opt, feed_dict=feed)

                if steps % 101 == 0:
                    saver.save(sess, checkpoint_path)
                    train_loss_d, train_loss_g, train_loss_q = sess.run(
                        [d_loss, g_loss, q_loss], feed_dict=feed)

                    losses.append((train_loss_d, train_loss_g, train_loss_q))
                    # Show preview images
                    samples = sess.run(sample_op,
                                       feed_dict={inputs_noise: batch_noise[:n_samples],
                                                  condition_label: c_labels[:n_samples]})
                    plot_image(samples)
                    print("Epoch {}/{}....".format(e+1, epochs), 
                          "Discriminator Loss: {:.4f}....".format(train_loss_d),
                          "Generator Loss: {:.4f}....". format(train_loss_g))
            saver.save(sess, checkpoint_path)

    return losses

In [41]:
# Train inside a fresh graph that is the default graph only for this block.
with tf.Graph().as_default():
    train(noise_size=noise_size,
          data_shape=[-1, 28, 28, 1],
          batch_size=batch_size,
          n_samples=n_samples)

KeyboardInterrupt: 

In [None]:
# [batch, height, width, depth] of the images; -1 leaves the batch size open.
data_shape = [-1, 28, 28, 1]

In [None]:
# Recreate placeholders, losses and training ops in the current default graph,
# using the same three builder calls that train() makes.
inputs_real, inputs_noise, condition_label = get_inputs(
    noise_dim=noise_size,
    image_height=data_shape[1],
    image_width=data_shape[2],
    image_depth=data_shape[3],
)
g_loss, d_loss, q_loss = get_loss(
    inputs_real=inputs_real,
    inputs_noise=inputs_noise,
    condition_label=condition_label,
    image_depth=data_shape[-1],
)
g_train_opt, d_train_opt, q_train_opt = get_optimizer(
    g_loss=g_loss,
    d_loss=d_loss,
    q_loss=q_loss,
    beta1=beta1,
    learning_rate=learning_rate,
)

In [None]:
# Restore the trained weights into a new session.
saver = tf.train.Saver()
sess = tf.Session()
# tf.train.latest_checkpoint expects the DIRECTORY that holds the `checkpoint`
# state file, not the "generator.ckpt" prefix; given the prefix it finds nothing
# and returns None.
model_file = tf.train.latest_checkpoint('../tf_saver_files/class_7_of_infoGAN')
if model_file is None:
    # Fail with an actionable message instead of passing None to saver.restore.
    raise FileNotFoundError(
        "No checkpoint found in '../tf_saver_files/class_7_of_infoGAN'; "
        "run the training cell first.")
saver.restore(sess, model_file)

In [None]:
# Sizes used for sampling: 25 images, each from a 100-dimensional noise vector.
noise_size = 100
batch_size = 25

In [None]:
# Latent code starts as all zeros (one row per image, 10 code dimensions);
# the noise is drawn uniformly from [-1, 1).
c_labels = np.zeros(shape=(batch_size, 10))
batch_noise = np.random.uniform(low=-1, high=1, size=(batch_size, noise_size))

In [None]:
# Set the 4th dimension (index 3) of c_labels to 1 for every row. According to the
# author's notes, in their trained model activating this dimension made the
# generator produce the digit 2.
# Each dimension of condition_label controls the generation of one attribute, but
# WHICH position controls WHICH attribute is not fixed in advance - there is no
# guaranteed one-to-one mapping such as "dimension 2 produces digit 2" (otherwise
# the 4th dimension would not be the one that produces digit 2 here).
# So after training, the attribute controlled by each dimension differs from
# person to person and from run to run; what stays constant is only that each
# dimension controls one attribute, not the position at which it does so.
c_labels[:, 3] = 1

In [None]:
# Turn both inputs into float32 tensors so they can be concatenated inside
# generator(), which joins the noise and the code along axis 1.
batch_noise = tf.cast(batch_noise, dtype=tf.float32)
c_labels = tf.cast(c_labels, dtype=tf.float32)

In [None]:
# Generate single-channel images from the prepared noise and latent code,
# re-using the existing generator variables (is_train=False).
samples = sess.run(
    generator(noise_img=batch_noise,
              output_dim=1,
              condition_label=c_labels,
              is_train=False)
)

In [None]:
# The plotting helper defined in this notebook is named `plot_image`.
plot_image(samples)