# 课时24 Pix2Pix_GAN代码实现

In [1]:
import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
import warnings
import glob
warnings.filterwarnings("ignore")
tf.__version__

'1.13.1'

>**需要注意的是，原始的GAN的生成器网络接收的输入的一个一维的向量，因此整体的生成器网络其实只是一个AutoEncode的Decode部分；而pix2pixGAN的生成器网络部分接收的输入是一个完整的图片，因此pix2pixGAN的生成器网络是一个完整的AutoEncode网络。当我们使用pix2pixGAN网络进行类似图像翻译的任务的时候，输入与输出之间会共享很多的信息，例如图像轮廓信息等。而这个AutoEncode生成器网络在使用普通的卷积网络进行传递的信息传递的时候，每一层网络都要存储这些信息，会很容易出错，因此为了避免这样的情况发生，我们使用U-Net网络来搭建生成器网络。**

## 1. 定义Generator模块

In [3]:
# 生成器采用的是U-Net网络结果，U-Net采用的也是Encode-Decode的结构
# 其中Encode是卷积结构，Decode是反卷积结构
def generator(inputs_real, is_train=True, alpha=0.01):
    # [256, 256, 3]
    with tf.variable_scope(name_or_scope='generator', reuse=(not is_train)):
        # Encode网络部分：
        # [128, 128, 64]
        conv_1 = tf.layers.conv2d(inputs=inputs_real, filters=64, kernel_size=(3, 3), padding='same')
        conv_1 = tf.nn.relu(conv_1)
        conv_1 = tf.layers.max_pooling2d(inputs=conv_1, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [64, 64, 128]
        conv_2 = tf.layers.conv2d(inputs=conv_1, filters=128, kernel_size=(3, 3), padding='same')
        conv_2 = tf.nn.relu(conv_2)
        conv_2 = tf.layers.max_pooling2d(inputs=conv_2, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [32, 32, 256]
        conv_3 = tf.layers.conv2d(inputs=conv_2, filters=256, kernel_size=(3, 3), padding='same')
        conv_3 = tf.nn.relu(conv_3)
        conv_3 = tf.layers.max_pooling2d(inputs=conv_3, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [16, 16, 512]
        conv_4 = tf.layers.conv2d(inputs=conv_3, filters=512, kernel_size=(3, 3), padding='same')
        conv_4 = tf.nn.relu(conv_4)
        conv_4 = tf.layers.max_pooling2d(inputs=conv_4, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [8, 8, 512]
        conv_5 = tf.layers.conv2d(inputs=conv_4, filters=512, kernel_size=(3, 3), padding='same')
        conv_5 = tf.nn.relu(conv_5)
        conv_5 = tf.layers.max_pooling2d(inputs=conv_5, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [4, 4, 512]
        conv_6 = tf.layers.conv2d(inputs=conv_5, filters=512, kernel_size=(3, 3), padding='same')
        conv_6 = tf.nn.relu(conv_6)
        conv_6 = tf.layers.max_pooling2d(inputs=conv_6, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [2, 2, 512]
        conv_7 = tf.layers.conv2d(inputs=conv_6, filters=512, kernel_size=(3, 3), padding='same')
        conv_7 = tf.nn.relu(conv_7)
        conv_7 = tf.layers.max_pooling2d(inputs=conv_7, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # [1, 1, 512]
        conv_8 = tf.layers.conv2d(inputs=conv_7, filters=512, kernel_size=(3, 3), padding='same')
        conv_8 = tf.nn.relu(conv_8)
        conv_8 = tf.layers.max_pooling2d(inputs=conv_8, pool_size=(2, 2), strides=(2, 2), padding='same')
        
        # Dncode网络部分：
        # [2, 2, 512]
        conv_9 = tf.layers.conv2d_transpose(inputs=conv_8, filters=512, kernel_size=3, strides=2, padding='same')
        conv_9 = tf.layers.batch_normalization(conv_9, training=is_train)
        conv_9 = tf.nn.relu(conv_9)
        conv_9 = tf.nn.dropout(conv_9, keep_prob=0.5)
        
        # [4, 4, 512]
        conv_10 = tf.concat([conv_9, conv_7], axis=3)
        conv_10 = tf.layers.conv2d_transpose(inputs=conv_10, filters=512, kernel_size=3, strides=2, padding='same')
        conv_10 = tf.layers.batch_normalization(conv_10, training=is_train)
        conv_10 = tf.nn.relu(conv_10)
        conv_10 = tf.nn.dropout(conv_10, keep_prob=0.5)
        
        # [8, 8, 512]
        conv_11 = tf.concat([conv_10, conv_6], axis=3)
        conv_11 = tf.layers.conv2d_transpose(inputs=conv_11, filters=512, kernel_size=3, strides=2, padding='same')
        conv_11 = tf.layers.batch_normalization(conv_11, training=is_train)
        conv_11 = tf.nn.relu(conv_11)
        conv_11 = tf.nn.dropout(conv_11, keep_prob=0.5)
        
        # [16, 16, 512]
        conv_12 = tf.concat([conv_11, conv_5], axis=3)
        conv_12 = tf.layers.conv2d_transpose(inputs=conv_12, filters=512, kernel_size=3, strides=2, padding='same')
        conv_12 = tf.layers.batch_normalization(conv_12, training=is_train)
        conv_12 = tf.nn.relu(conv_12)
        
        # [32, 32, 256]
        conv_13 = tf.concat([conv_12, conv_4], axis=3)
        conv_13 = tf.layers.conv2d_transpose(inputs=conv_13, filters=256, kernel_size=3, strides=2, padding='same')
        conv_13 = tf.layers.batch_normalization(conv_13, training=is_train)
        conv_13 = tf.nn.relu(conv_13)
        
        # [64, 64, 128]
        conv_14 = tf.concat([conv_13, conv_3], axis=3)
        conv_14 = tf.layers.conv2d_transpose(inputs=conv_14, filters=128, kernel_size=3, strides=2, padding='same')
        conv_14 = tf.layers.batch_normalization(conv_14, training=is_train)
        conv_14 = tf.nn.relu(conv_14)
        
        # [128, 128, 64]
        conv_15 = tf.concat([conv_14, conv_2], axis=3)
        conv_15 = tf.layers.conv2d_transpose(inputs=conv_15, filters=64, kernel_size=3, strides=2, padding='same')
        conv_15 = tf.layers.batch_normalization(conv_15, training=is_train)
        conv_15 = tf.nn.relu(conv_15)
        
        # [256, 256, 3]
        conv_16 = tf.concat([conv_15, conv_1], axis=3)
        conv_16 = tf.layers.conv2d_transpose(inputs=conv_16, filters=3, kernel_size=3, strides=2, padding='same')
        conv_16 = tf.layers.batch_normalization(conv_16, training=is_train)
        
        # 图像归一化
        outputs = tf.nn.tanh(conv_16)
        return outputs

## 2. 定义Discriminator模块

In [9]:
# 需要注意的是Discriminator接收成对的输入
def discriminator(inputs_real, inputs_cartoon, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        # 由于要同时对输入的图像和卡通图像进行判别，得出一个结果(1/0)，所以将这两个输入的图像进行合并
        layer0 = tf.concat([inputs_real, inputs_cartoon], axis=3)

        layer1 = tf.layers.conv2d(layer0, 64, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha*layer1, layer1)
        
        layer2 = tf.layers.conv2d(layer1, 128, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha*layer2, layer2)
        
        layer3 = tf.layers.conv2d(layer2, 256, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha*layer3, layer3)
        
        layer4 = tf.layers.conv2d(layer3, 512, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=True)
        layer4 = tf.maximum(alpha*layer4, layer4)
        
        flatten = tf.reshape(layer4, (-1, 16*16*512))
        logits = tf.layers.dense(flatten, 1)
        # 通过sigmoid判断输入的类别(1/0)
        outputs = tf.sigmoid(logits)
        
        return logits, outputs

## 3. 定义Loss模块
- D网络损失函数：
>- 1. 输入真实的成对图像希望判定为1；
>- 2. 输入生成图像和原图像希望判定位0；
- G网络损失函数：
>- 1. 输入生成图像和原图像希望判别为1；

In [10]:
def get_loss(inputs_image, inputs_cartoons, smooth=0.01):
    g_outputs = generator(inputs_image, is_train=True)
    d_logits_real, d_outputs_real = discriminator(inputs_image, inputs_cartoons)
    d_logits_fake, d_outputs_fake = discriminator(inputs_image, g_outputs, reuse=True)
    
    # 计算d_loss
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                         labels=tf.ones_like(d_outputs_real)*(1-smooth)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_outputs_fake)))
    
    # 计算g_loss
    g_loss_gan = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                        labels=tf.ones_like(d_outputs_fake)*(1-smooth)))
    # 为了约束生成网络，将生成网络生成好了的图像与原始的输入生成网络的图像的标签进行L1约束
    # 将生成了的图像扁平化成一个一维向量，方便L1约束计算
    g_outputs_logits = tf.reshape(g_outputs, [-1, 256*256*3])
    # 输入生成网络的图像原始的对应标签也扁平化成一维向量，方便L1约束计算
    inputs_cartoons_logits = tf.reshape(inputs_cartoons, [-1, 256*256*3])
    g_loss_L1 = tf.reduce_mean(tf.reduce_sum(tf.abs(g_outputs_logits-inputs_cartoons_logits)))
    
    # 计算Loss和
    g_loss = tf.add(g_loss_gan, g_loss_L1)
    d_loss = tf.add(d_loss_real, d_loss_fake)
    
    return g_loss, d_loss

## 4. 定义optimizer模块

In [11]:
def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    train_vars = tf.trainable_variables()
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    
    # Optimizer
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    
    return g_opt, d_opt

## 5. 定义辅助函数

In [16]:
def plot_images(samples):
    # 由于生成模型生成的数据是使用tanh输出的，所以其值处于[-1, 1]之间，不适合画图
    # 因此这里先需要将图像的值转换到[0, 1]之间才行
    samples = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(10,2))
    for img, ax in zip(samples, axes):
        # 然后这里需要注意的是这里需将其reshape成原始图像对应的大小尺寸
        ax.imshow(img.reshape((250, 200, 3)))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)

In [18]:
def show_generator_output(sess, samp_images):
    samples = sess.run(generator(samp_images, False))
    samples = sess.run(tf.image.resize_image_with_crop_or_pad(samples, 250, 200))
    return samples

## 6. 定义训练模块

In [19]:
# 定义参数
learning_rate = 0.001
beta1 = 0.4

In [34]:
def train():
    # 存储loss
    losses = []
    steps = 300
        
    image_filenames = glob.glob('../日月光华-GAN大型数据集/1.0部分代码和数据集/数据集/pix2pixGAN数据集/training_photos/*.jpg')
    cartoon_filenames = glob.glob('日月光华-GAN大型数据集/1.0部分代码和数据集/数据集/pix2pixGAN数据集/training__sketches/*.jpg')
    
    image_que = tf.train.slice_input_producer([image_filenames, cartoon_filenames], 
                                              shuffle=True)
    image_ = tf.read_file(image_que[0])
    image = tf.image.decode_jpeg(image_, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(image, 256, 256)
    new_img = tf.image.per_image_standardization(image)
        
    cartoon_ = tf.read_file(image_que[1])
    cartoon = tf.image.decode_jpeg(cartoon_, channels=3)
    cartoon = tf.image.resize_image_with_crop_or_pad(cartoon, 256, 256)
    new_cartoon = tf.image.per_image_standardization(cartoon)
    
    batch_size = 5
    capacity = 3 + 2*batch_size
          
    image_batch, cartoon_batch = tf.train.batch([new_img, new_cartoon], batch_size=batch_size, capacity=capacity)
    
    g_loss, d_loss = get_loss(image_batch, cartoon_batch)
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
    
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        sess.run(tf.global_variables_initializer())
        
        # 迭代epoch
        for e in range(steps):
            # run optimizer
            _ = sess.run(g_train_opt)
            _ = sess.run(d_train_opt)
                
            if e % 50 == 0:
                saver.save(sess,'../tf_saver_files/class_10_of_pix2pix/generator.ckpt',global_step=e)
                train_loss_d = d_loss.eval()
                
                train_loss_g = g_loss.eval()
                losses.append((train_loss_d, train_loss_g))
                # 显示图片
                samples = show_generator_output(sess, image_batch)
                plot_images(samples)
                print("Epoch {}/{}....".format(e+1, steps), 
                      "Discriminator Loss: {:.4f}....".format(train_loss_d),
                      "Generator Loss: {:.4f}....". format(train_loss_g))
        saver.save(sess,'../tf_saver_files/class_10_of_pix2pix/generator.ckpt',global_step=steps)
        coord.request_stop()
        coord.join(threads)                  

In [35]:
with tf.Graph().as_default():
    train()

TypeError: Input 'filename' of 'ReadFile' Op has type float32 that does not match expected type of string.