# 课时25 pix2pixGAN实例(TF2.0城市街景数据集)

In [3]:
import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
import warnings
import glob
# Hide library warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")
# Display the installed TensorFlow version (notebook cell output).
tf.__version__

'2.1.0'

>1. **需要注意的是，原始GAN的生成器网络接收的输入是一个一维的向量，因此整个生成器网络其实只相当于AutoEncoder的Decoder部分；而pix2pixGAN的生成器网络接收的输入是一张完整的图片，因此pix2pixGAN的生成器网络是一个完整的AutoEncoder网络。当我们使用pix2pixGAN进行图像翻译之类的任务时，输入与输出之间会共享很多信息，例如图像的轮廓信息等。如果这个AutoEncoder生成器只用普通的卷积网络逐层传递信息，每一层网络都需要保存这些信息，很容易出错。为了避免这种情况，我们使用U-Net网络来搭建生成器：U-Net通过跳跃连接（skip connection）把编码器各层的特征直接传给解码器的对应层。**
>2. **pix2pixGAN本质上是一种图像翻译模型。也就是把一张图像翻译成另一张图像。**

## 1. 定义数据集读取和预处理函数

In [None]:
# Directory that holds the training JPEGs; each file stores the photo and its
# mask side by side (see load_image).
train_images_path = '../'
# glob() needs a wildcard pattern: globbing the bare directory path only
# returns the directory itself, which read_jpg cannot decode.
# Sorting gives a deterministic file order across runs.
train_images_path_list = sorted(glob.glob(pathname=train_images_path + '*.jpg'))
len(train_images_path_list)

In [None]:
train_images_path_list[:3]  # peek at the first three matched file paths

In [None]:
# Preview the first raw training file (photo and mask side by side).
# plt.show() takes no image argument; the image must be drawn with imshow first.
plt.imshow(tf.keras.preprocessing.image.load_img(train_images_path_list[0]))
plt.show()

In [5]:
# JPEG decoding helper.
def read_jpg(path):
    """Read the file at `path` and decode it as a 3-channel JPEG image tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)

In [4]:
# Normalization helper.
def normalize(mask, image):
    """Map pixel values from [0, 255] to [-1, 1] for both the mask and the image.

    Both tensors are cast to float32 first. Returns the pair in the same
    (mask, image) order.
    """
    def _to_signed_unit(tensor):
        # x / 127.5 - 1 sends 0 -> -1 and 255 -> 1.
        return tf.cast(tensor, tf.float32) / 127.5 - 1

    return _to_signed_unit(mask), _to_signed_unit(image)

In [9]:
# Training-sample loader (includes random flip augmentation).
def load_image(image_path):
    """Load one training pair from a side-by-side JPEG.

    The file holds the photo in its left half and the mask in its right half.
    Both halves are resized to 256x256, flipped horizontally together with
    probability 0.5 (data augmentation), and normalized to [-1, 1].

    Args:
        image_path: path of the JPEG file (a scalar string tensor when used
            inside ``Dataset.map``).

    Returns:
        (input_mask, input_image): float32 tensors of shape [256, 256, 3].
    """
    image = read_jpg(image_path)
    width = tf.shape(image)[1]
    # Mask and photo are stored in one image, so split it down the middle.
    w = width // 2
    input_mask = image[:, w:, :]
    input_image = image[:, :w, :]

    # The halves are presumably already 256x256, so these resizes mainly give
    # the tensors a static [256, 256, 3] shape inside the dataset pipeline.
    input_mask = tf.image.resize(input_mask, (256, 256))
    input_image = tf.image.resize(input_image, (256, 256))

    # Flip mask and photo together so they stay pixel-aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_mask, input_image = normalize(input_mask, input_image)
    return input_mask, input_image

In [11]:
# Batch size for training, and a shuffle buffer as large as the whole training
# set so that shuffle() draws uniformly over all files.
BATCH_SIZE, BUFFER_SIZE = 64, len(train_images_path_list)

In [8]:
# Build the training input pipeline.
dataset = tf.data.Dataset.from_tensor_slices(train_images_path_list)
# Shuffle the cheap path strings before decoding. This way the dataset-sized
# shuffle buffer holds file names instead of decoded 256x256x3 float images.
# (Dataset has no `shuffer` method; the correct name is `shuffle`.)
dataset = dataset.shuffle(BUFFER_SIZE)
# Decode, split and augment files in parallel.
dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
# prefetch lets the CPU prepare the next batch while the GPU trains on the current one.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [None]:
# Preview one training pair. dataset.take(1) yields one *batch*, so [0]
# selects its first sample.
for mask, img in dataset.take(1):
    plt.subplot(1, 2, 1)
    # The tensors are normalized to [-1, 1]; array_to_img rescales them into a
    # displayable 0-255 range, so they render correctly.
    plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[0]))
    plt.subplot(1, 2, 2)
    # Use the loop variable `img`; `image` is not defined at this scope.
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
plt.show()

In [None]:
# ================================================================================

In [None]:
# With the training data ready, collect the test files the same way.
test_images_path = '../'
# glob() needs a wildcard pattern: the bare directory path would only match
# the directory itself. Sorting gives a deterministic order.
test_images_path_list = sorted(glob.glob(pathname=test_images_path + '*.jpg'))
dataset_test = tf.data.Dataset.from_tensor_slices(test_images_path_list)

In [13]:
# Test-sample loader (no augmentation).
def load_image_test(image_path):
    """Load one test pair from a side-by-side JPEG.

    This does the same as ``load_image`` without the random flip: the photo
    (left half) and the mask (right half) are split, resized to 256x256 and
    normalized to [-1, 1].

    Args:
        image_path: path of the JPEG file (a scalar string tensor when used
            inside ``Dataset.map``).

    Returns:
        (input_mask, input_image): float32 tensors of shape [256, 256, 3].
    """
    image = read_jpg(image_path)
    width = tf.shape(image)[1]
    # Mask and photo are stored in one image, so split it down the middle.
    w = width // 2
    input_mask = image[:, w:, :]
    input_image = image[:, :w, :]

    # The halves are presumably already 256x256, so these resizes mainly give
    # the tensors a static [256, 256, 3] shape inside the dataset pipeline.
    input_mask = tf.image.resize(input_mask, (256, 256))
    input_image = tf.image.resize(input_image, (256, 256))

    input_mask, input_image = normalize(input_mask, input_image)
    return input_mask, input_image

In [None]:
# Decode/split every test file, then group into batches of 16
# (test data is neither shuffled nor flipped).
dataset_test = dataset_test.map(load_image_test).batch(16)

In [None]:
# Preview one test pair. dataset_test.take(1) yields one *batch*, so [0]
# selects its first sample.
for mask, img in dataset_test.take(1):
    plt.subplot(1, 2, 1)
    # The tensors are normalized to [-1, 1]; array_to_img rescales them into a
    # displayable 0-255 range, so they render correctly.
    plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[0]))
    plt.subplot(1, 2, 2)
    # Use the loop variable `img`; `image` is not defined at this scope.
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
plt.show()

## 2. 定义上采样和下采样模型

In [18]:
def down_sampling(filters, kernel_size, apply_bn_flag=True):
    """Return a downsampling block that halves height and width.

    Layers: Conv2D (stride 2, 'same' padding, no bias) -> optional
    BatchNormalization -> LeakyReLU.
    """
    stack = [tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=2,
                                    padding='same', use_bias=False)]
    if apply_bn_flag:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

In [19]:
def up_sampling(filters, kernel_size, apply_dropout_flag=False):
    """Return an upsampling block that doubles height and width.

    Layers: Conv2DTranspose (stride 2, 'same' padding, no bias) ->
    BatchNormalization -> optional Dropout(0.5) -> ReLU.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(filters=filters, kernel_size=kernel_size,
                                        strides=2, padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout_flag:
        stack.append(tf.keras.layers.Dropout(rate=0.5))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

In [24]:
# Generator model (U-Net architecture).
def Generator():
    """Build the U-Net generator that maps a [256, 256, 3] image to a [256, 256, 3] image.

    The encoder halves the spatial size eight times (256 -> 1). The decoder
    doubles it seven times (1 -> 128). After each decoder step it concatenates
    the encoder feature map of the same size, so the channel count doubles
    after every concatenation. A final stride-2 transposed convolution
    restores 256x256 with 3 tanh-activated channels.
    """
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))

    # (filters, apply_bn_flag) per encoder stage; spatial size after each
    # stage: 128, 64, 32, 16, 8, 4, 2, 1.
    encoder_specs = [(64, False), (128, True), (256, True), (512, True),
                     (512, True), (512, True), (512, True), (512, True)]
    encoder = [down_sampling(filters=n_filters, kernel_size=3, apply_bn_flag=use_bn)
               for n_filters, use_bn in encoder_specs]

    # (filters, apply_dropout_flag) per decoder stage; spatial size after each
    # stage: 2, 4, 8, 16, 32, 64, 128 (before concatenating the skip map).
    decoder_specs = [(512, True), (512, True), (512, True), (512, False),
                     (256, False), (128, False), (64, False)]
    decoder = [up_sampling(filters=n_filters, kernel_size=3, apply_dropout_flag=use_dropout)
               for n_filters, use_dropout in decoder_specs]

    x = inputs
    encoder_outputs = []
    for block in encoder:
        x = block(x)
        encoder_outputs.append(x)

    # Skip connections: every encoder output except the innermost 1x1 one,
    # deepest first, so each matches the spatial size of the decoder output.
    for block, skip in zip(decoder, encoder_outputs[-2::-1]):
        x = tf.keras.layers.concatenate([block(x), skip])

    # Inputs were normalized to [-1, 1], so the output layer uses tanh.
    outputs = tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=3, strides=2,
                                              padding='same', activation='tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [25]:
generator = Generator()  # build the U-Net generator model

In [26]:
# Draw the generator graph with tensor shapes; this requires the pydot and graphviz packages.
tf.keras.utils.plot_model(generator, show_shapes=True)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [27]:
generator.summary()  # print per-layer output shapes and parameter counts

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
sequential_31 (Sequential)      (None, 128, 128, 64) 1728        input_4[0][0]                    
__________________________________________________________________________________________________
sequential_32 (Sequential)      (None, 64, 64, 128)  74240       sequential_31[0][0]              
__________________________________________________________________________________________________
sequential_33 (Sequential)      (None, 32, 32, 256)  295936      sequential_32[0][0]              
______________________________________________________________________________________________

In [28]:
def Discriminator():
    """Build the patch discriminator.

    Takes [inputs, targets], each [256, 256, 3], and concatenates them along
    the channel axis into [256, 256, 6]. It outputs a [30, 30, 1] map of raw
    logits (no activation on the last layer), one score per image patch.
    """
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))
    targets = tf.keras.layers.Input(shape=(256, 256, 3))

    x = tf.keras.layers.concatenate([inputs, targets])  # [256, 256, 6]
    # Three stride-2 stages: -> [128, 128, 64] -> [64, 64, 128] -> [32, 32, 256]
    for n_filters, use_bn in ((64, False), (128, True), (256, True)):
        x = down_sampling(filters=n_filters, kernel_size=3, apply_bn_flag=use_bn)(x)

    x = tf.keras.layers.Conv2D(filters=512, kernel_size=3, strides=1,
                               padding='same', use_bias=False)(x)  # [32, 32, 512]
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # 'valid' padding with a 3x3 kernel shrinks 32 -> 30: output is [30, 30, 1].
    logits = tf.keras.layers.Conv2D(filters=1, kernel_size=3, strides=1)(x)
    return tf.keras.Model(inputs=[inputs, targets], outputs=logits)

In [30]:
# Build the discriminator model and print its layer summary.
discriminator = Discriminator()
discriminator.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
concatenate_10 (Concatenate)    (None, 256, 256, 6)  0           input_7[0][0]                    
                                                                 input_8[0][0]                    
__________________________________________________________________________________________________
sequential_49 (Sequential)      (None, 128, 128, 64) 3456        concatenate_10[0][0]       

## 3. 定义损失函数

In [31]:
# from_logits=True because the discriminator's last Conv2D has no activation.
loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [32]:
def generator_loss(d_gen_output, gen_output, target, l1_lambda=1):
    """Generator objective: adversarial loss plus a weighted L1 reconstruction loss.

    Args:
        d_gen_output: discriminator logits for the generated images. They are
            scored against all-ones labels, so the generator is rewarded when
            the discriminator calls its output real.
        gen_output: images produced by the generator.
        target: ground-truth images, in the same value range as gen_output.
        l1_lambda: weight of the L1 term. The default of 1 preserves the
            previous behavior; the pix2pix paper uses 100.

    Returns:
        Scalar tensor ``gan_loss + l1_lambda * mean(|gen_output - target|)``.
    """
    gen_loss = loss_func(y_true=tf.ones_like(d_gen_output), y_pred=d_gen_output)
    l1_loss = tf.reduce_mean(tf.abs(gen_output - target))
    return gen_loss + l1_lambda * l1_loss

In [33]:
# Named discriminator_loss (parallel to generator_loss). Defining a function
# called `discriminator` would rebind the module-level name and replace the
# discriminator *model* built earlier.
def discriminator_loss(d_real_output, d_fake_output):
    """Discriminator objective: classify real pairs as 1 and generated pairs as 0.

    Args:
        d_real_output: discriminator logits for (input, ground-truth) pairs.
        d_fake_output: discriminator logits for (input, generated) pairs.

    Returns:
        Scalar tensor: the sum of the real and fake binary cross-entropy terms.
    """
    real_loss = loss_func(y_true=tf.ones_like(d_real_output), y_pred=d_real_output)
    fake_loss = loss_func(y_true=tf.zeros_like(d_fake_output), y_pred=d_fake_output)

    return real_loss + fake_loss

In [34]:
# Separate Adam optimizers (learning rate 2e-4, beta_1=0.5) for each network.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
def generate_images(model, test_inputs, tar):
    """Plot input, ground truth and prediction side by side for the first sample of a batch.

    The model is called with training=True, so its Dropout layers stay active
    and BatchNormalization uses batch statistics. Images are assumed to lie in
    [-1, 1] and are mapped to [0, 1] for display.
    """
    prediction = model(test_inputs, training=True)
    plt.figure(figsize=(15, 15))

    pictures = [test_inputs[0], tar[0], prediction[0]]
    captions = ['Input Image', 'Ground Truth', 'Predicted Image']

    for position, (caption, picture) in enumerate(zip(captions, pictures), start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()