<a href="https://colab.research.google.com/github/TA-aiacademy/course_3.0/blob/v2-5_gan/08_v2-5_GAN/Part2/01_Pix2Pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pix2Pix (Image-to-Image Translation with Conditional Adversarial Networks)

### 本章節內容大綱
* [Build a Pix2Pix Model](#Build-a-Pix2Pix-Model)
* [Build a UNET model as Generator](#Build-a-UNET-model-as-Generator)
* [Build a PatchGAN as Discriminator](#Build-a-PatchGAN-as-Discriminator)
* [Loss function](#Loss-function)

<br>
<img src="https://hackmd.io/_uploads/BJbwlIWl6.jpg" width=700  />

Pix2Pix 是 CGAN (conditional GAN) 的一種。與 vanilla GAN 不同，它的 input 是一張圖片，output 則是我們希望轉換成的樣子（vanilla GAN 無法控制 output 的內容）。例如上面的例子：把標記好的 segmentation 轉成照片、把黑白照片轉成彩色照片等等。訓練時需要成對的資料 (paired data：標籤圖與對應的真實圖片)，因此屬於 supervised learning。

# Import

In [None]:
''' basic package '''
import os

# Restrict the notebook to the first GPU. This is set before TensorFlow is
# imported so the restriction is in place before TF can touch the CUDA runtime.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import time
import glob
from IPython.display import display, Image, clear_output

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Dense, Conv2DTranspose, Conv2D, BatchNormalization,
    LeakyReLU, Dropout, Reshape, Flatten
)

from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [None]:

# Download the facades dataset into the Keras cache folder (on the course hub
# this is home/jovyan/.keras) and extract it there.
_URL = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file(
    'facades.tar.gz',
    origin=_URL,
    extract=True,
)

# Root folder of the extracted data; the train/ and test/ JPEGs are read from here.
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')

# Config

In [None]:
# Shuffle buffer size for the training pipeline.
BUFFER_SIZE = 400  # total of 400 training images
# Number of image pairs per training step.
BATCH_SIZE = 1
# Spatial size of the crops / resized images fed to the networks.
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [None]:
# Show one training file exactly as stored on disk (photo and label map side by side).
image = tf.image.decode_jpeg(tf.io.read_file(PATH + 'train/100.jpg'))  # uint8 tensor
plt.imshow(image)

In [None]:
def load(image_file):
    """Read one paired JPEG and split it into ``(input_image, real_image)``.

    Each file stores two images concatenated horizontally. The left
    ``width // 2`` columns are returned as ``real_image`` and the remaining
    columns as ``input_image``.

    Args:
        image_file: path (string or scalar string tensor) of the JPEG.

    Returns:
        Tuple ``(input_image, real_image)`` of float32 tensors with pixel
        values in [0, 255] and exactly 3 channels.
    """
    image = tf.io.read_file(image_file)
    # Decode to a uint8 tensor. Force 3 channels: random_crop and both models
    # assume RGB, and a grayscale JPEG would otherwise decode to 1 channel.
    image = tf.image.decode_jpeg(image, channels=3)

    # The image and its label are stored side by side, so cut at the middle column.
    half_width = tf.shape(image)[1] // 2
    real_image = image[:, :half_width, :]
    input_image = image[:, half_width:, :]

    # Convert to float32 for the resize/normalize arithmetic that follows.
    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image

In [None]:
# Visual check of load(): the model input (label map) on the left, the real photo on the right.
inp, re = load(PATH + 'train/100.jpg')
fig, axs = plt.subplots(nrows=1, ncols=2)
axs[0].imshow(inp / 255.0)  # load() returns floats in [0, 255]; scale to [0, 1]
axs[1].imshow(re / 255.0)

In [None]:
def resize(input_image, real_image, height, width):
    """Resize both images of a pair to ``(height, width)``.

    Nearest-neighbour interpolation is used for both images, so no new
    blended pixel values are introduced.

    Returns:
        Tuple ``(input_image, real_image)`` of resized tensors.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)
    return resized_input, resized_real


def random_crop(input_image, real_image):
    """Cut the same random ``IMG_HEIGHT x IMG_WIDTH`` window out of both images.

    The pair is stacked into one tensor first so a single random offset is
    applied to both, keeping input and target pixel-aligned.

    Returns:
        Tuple ``(input_image, real_image)`` of cropped tensors.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    # Keep the whole stack axis (2) and all 3 channels; only height/width are cropped.
    pair = tf.image.random_crop(pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    cropped_input, cropped_real = pair[0], pair[1]
    return cropped_input, cropped_real


def normalize(input_image, real_image):
    """Rescale pixel values from [0, 255] to [-1, 1] for both images.

    The [-1, 1] range matches the tanh output of the generator.

    Returns:
        Tuple ``(input_image, real_image)`` of rescaled values.
    """
    half_range = 127.5
    return input_image / half_range - 1, real_image / half_range - 1

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
    """Apply identical data augmentation to both images of a pair.

    Steps: resize to 286 x 286, take a random ``IMG_HEIGHT x IMG_WIDTH`` crop,
    then mirror horizontally when a uniform draw exceeds 0.5.

    Returns:
        Tuple ``(input_image, real_image)`` of augmented tensors.
    """
    # Enlarge first so the following crop can shift its window around.
    input_image, real_image = resize(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)

    # One coin flip decides mirroring for both images, keeping them aligned.
    do_flip = tf.random.uniform(()) > 0.5
    if do_flip:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
# Apply random_jitter to the same pair four times to eyeball the augmentation.
plt.figure(figsize=(6, 6))
for panel in range(1, 5):
    rj_inp, rj_re = random_jitter(inp, re)
    plt.subplot(2, 2, panel)
    plt.imshow(rj_inp / 255.0)  # still in [0, 255] here; scale for display
    plt.axis('off')
plt.show()

In [None]:
# Per-file preprocessing used by the training tf.data pipeline.
def load_image_train(image_file):
    """Load a training pair, augment it with random_jitter, and scale it to [-1, 1].

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    pair = load(image_file)
    pair = random_jitter(*pair)
    return normalize(*pair)


def load_image_test(image_file):
    """Load a test pair, resize it to ``IMG_HEIGHT x IMG_WIDTH``, and scale it to [-1, 1].

    No random jitter is applied at test time, so the output is deterministic.

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    input_image, real_image = load(image_file)
    resized = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*resized)

In [None]:
# Training pipeline: list files -> decode + augment -> shuffle -> batch -> prefetch.
# num_parallel_calls=AUTOTUNE lets tf.data choose how many files to preprocess
# concurrently.
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
# Prepare the next batches while the current train_step runs on the device.
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)


# Test pipeline: deterministic preprocessing only (no augmentation).
# A parallel map keeps element order by default, so results are unchanged.
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
test_dataset = test_dataset.map(load_image_test,
                                num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [None]:
OUTPUT_CHANNELS = 3  # channels produced by the generator's last layer (3, like the target photos)

# Build a Pix2Pix Model

<br>
<img src="https://hackmd.io/_uploads/HkrYxLWx6.png" width=700  />

Pix2Pix 的 Generator 要將標籤圖片 (labels) 轉為照片，所以整個模型採用類似 Autoencoder 的 encoder–decoder 結構，稱作 U-Net。U-Net 與 Autoencoder 的結構基本上一樣，只是多了 skip connection：把 encoder 每一層的輸出直接接到 decoder 中相同解析度的層。這樣細節資訊不必全部擠過中間的瓶頸，生成圖像的品質因而提升。
而Discriminator則與之前差不多，目的是要透過真實的data作為依據，判斷餵進來的圖是真的還是假的。

這邊先將 Generator 與 Discriminator 需要用到的結構：Downsample與Upsample寫成 function 備用。

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

    Stride 2 with 'same' padding reduces height and width by a factor of 2.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: whether to insert BatchNormalization after the conv.

    Returns:
        A ``tf.keras.Sequential`` containing the block's layers.
    """
    initializer = tf.random_normal_initializer(0., 0.02)  # mean 0, stddev 0.02

    stack = [
        # The conv never has a bias, even when apply_batchnorm is False.
        # When batchnorm follows, its learned offset makes a conv bias redundant.
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(stack)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    The transposed convolution (stride 2, 'same' padding) doubles height and width.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size.
        apply_dropout: whether to insert Dropout(0.5) after batchnorm.

    Returns:
        A ``tf.keras.Sequential`` containing the block's layers.
    """
    initializer = tf.random_normal_initializer(0., 0.02)  # mean 0, stddev 0.02

    stack = [
        # No bias: the BatchNormalization right after it provides the offset.
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(stack)

## Build a UNET model as Generator

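U-Net 的結構如前所述，就是加上 skip connection 的 Autoencoder。Encoder 用 8 個 downsample 把 256x256 的輸入一路壓到 1x1。Decoder 再用 upsample 逐層放大回來，並在每一層把 upsample 的輸出與 encoder 中相同解析度的輸出 concatenate 在一起。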

In [None]:
def Generator():
    """Build the U-Net generator.

    The encoder is 8 ``downsample`` blocks (the first without batchnorm). The
    decoder is 7 ``upsample`` blocks (dropout on the first three). After each
    decoder block, its output is concatenated with the encoder activation of
    the same resolution (skip connection). A final stride-2 Conv2DTranspose
    with tanh restores the input resolution, with OUTPUT_CHANNELS channels in
    [-1, 1].

    Shape comments below assume a (bs, 256, 256, 3) input.

    Returns:
        A ``tf.keras.Model`` mapping an image batch to a translated image batch.
    """
    # Encoder outputs: (bs,128,128,64) (bs,64,64,128) (bs,32,32,256)
    # (bs,16,16,512) (bs,8,8,512) (bs,4,4,512) (bs,2,2,512) (bs,1,1,512)
    encoder = [
        downsample(filters, 4, apply_batchnorm=(level > 0))
        for level, filters in enumerate([64, 128, 256, 512, 512, 512, 512, 512])
    ]

    # Decoder outputs after concatenation with the skip:
    # (bs,2,2,1024) (bs,4,4,1024) (bs,8,8,1024) (bs,16,16,1024)
    # (bs,32,32,512) (bs,64,64,256) (bs,128,128,128)
    decoder = [
        upsample(filters, 4, apply_dropout=(level < 3))
        for level, filters in enumerate([512, 512, 512, 512, 256, 128, 64])
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    # tanh keeps the output in [-1, 1], the same range normalize() gives the targets.
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')  # (bs, 256, 256, 3)

    concat = tf.keras.layers.Concatenate()
    inputs = tf.keras.layers.Input(shape=[None, None, 3])

    # Encoder pass: remember every level's output for the skip connections.
    skips = []
    x = inputs
    for block in encoder:
        x = block(x)
        skips.append(x)

    # Decoder pass: the bottleneck (last encoder output) is not reused as a
    # skip; the remaining activations are consumed from deepest to shallowest.
    for block, skip in zip(decoder, skips[-2::-1]):
        x = concat([block(x), skip])

    outputs = last(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
generator = Generator()

# Check that the generator produces an image of the expected shape.
# training=False runs BatchNorm in inference mode so its statistics are not updated.
# NOTE: `inp` is still in [0, 255] here (not normalized); this is only a shape check.
gen_output = generator(inp[tf.newaxis, ...], training=False)
# The generator ends in tanh, so map [-1, 1] to [0, 1] before imshow
# (otherwise the negative half would be clipped).
plt.imshow(gen_output[0, ...] * 0.5 + 0.5)

## Build a PatchGAN as Discriminator

Discriminator 相對單純，比較不一樣的地方是這邊使用的是 Markovian discriminator (PatchGAN)。它最後的輸出不是把 feature maps 壓成單一個值來做分類，而是一張 NxN 的判斷圖（本例輸入 256x256 時為 30x30）。每個值只負責判斷輸入中一塊區域（70x70 的 patch）是真是假，計算 loss 時再對所有 patch 取平均。

In [None]:
def Discriminator():
    """Build the PatchGAN discriminator.

    The model takes the conditioning image and a target image (real or
    generated), concatenates them along channels, and returns a grid of raw
    logits. For 256x256 inputs the output is (bs, 30, 30, 1). Each logit is
    computed from a 70x70 region of the input pair.

    Returns:
        A ``tf.keras.Model`` with inputs ``[input_image, target_image]``.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')

    # Condition and candidate side by side in the channel axis: (bs, 256, 256, 6).
    x = tf.keras.layers.concatenate([inp, tar], axis=-1)

    # Three stride-2 encoder blocks (the first without batchnorm):
    # (bs,128,128,64) -> (bs,64,64,128) -> (bs,32,32,256)
    for filters, use_batchnorm in ((64, False), (128, True), (256, True)):
        x = downsample(filters, 4, use_batchnorm)(x)

    x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
    x = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=initializer,
                               use_bias=False)(x)  # (bs, 31, 31, 512)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)

    # One logit per patch; no activation because the loss uses from_logits=True.
    logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                    kernel_initializer=initializer)(x)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=logits)

In [None]:
# Build the discriminator and plot its patch-logit map for the
# (input, generated) pair, to confirm the output has the expected form.
discriminator = Discriminator()
disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, :, :, -1], cmap='RdBu_r', vmin=-20, vmax=20)
plt.colorbar()

# Loss function

這邊的 Discriminator loss 與一般 GAN 相同。Generator loss 則略有不同：除了 adversarial loss 之外，還加入了 L1 loss（乘上權重 LAMBDA），用來縮小 output 與真實圖片之間的差異。

- Discriminator loss
    - real_loss: 懲罰把真實資料判斷成生成資料的情況
    - generated_loss: 懲罰把生成資料判斷成真實資料的情況
    
- Generator loss
    - gan_loss: 懲罰生成資料被 Discriminator 判斷成生成資料的情況
    - l1_loss: 生成圖片與該 label image 對應的真實圖片之間的平均絕對誤差（L1 / reconstruction loss，注意這不是 CycleGAN 中的 identity loss）

In [None]:
# Weight of the L1 reconstruction term relative to the adversarial term in generator_loss.
LAMBDA = 100

# Binary cross-entropy on raw logits (the discriminator's last layer has no activation).
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator objective: real pairs -> label 1, generated pairs -> label 0.

    Args:
        disc_real_output: discriminator logits for (input, real target).
        disc_generated_output: discriminator logits for (input, generated image).

    Returns:
        Sum of the two binary cross-entropy terms.
    """
    real_targets = tf.ones_like(disc_real_output)
    fake_targets = tf.zeros_like(disc_generated_output)
    return (loss_object(real_targets, disc_real_output)
            + loss_object(fake_targets, disc_generated_output))


def generator_loss(disc_generated_output, gen_output, target):
    """Generator objective: fool the discriminator while staying close to the target.

    Args:
        disc_generated_output: discriminator logits for (input, generated image).
        gen_output: generated image batch.
        target: ground-truth image batch.

    Returns:
        ``adversarial + LAMBDA * mean(|target - gen_output|)``.
    """
    # Adversarial term: the generator wants its outputs labelled as real (1).
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # L1 reconstruction term; it gives less blurry images than an L2 loss.
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + LAMBDA * reconstruction


# One Adam optimizer per network, both with learning rate 2e-4 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Track both models and both optimizers so training state can be saved and restored.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

In [None]:
EPOCHS = 50  # number of training epochs passed to fit()


def generate_images(model, test_input, tar):
    """Plot input, ground truth and prediction for the first example of a batch.

    Args:
        model: generator to evaluate.
        test_input: batch of input images in [-1, 1].
        tar: batch of ground-truth images in [-1, 1].
    """
    # training=True on purpose: normalization uses statistics of this test_input batch.
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 15))

    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', prediction[0]),
    )
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Map [-1, 1] back to [0, 1] so imshow can display it.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()


@tf.function
def train_step(input_image, target):
    """Run one simultaneous generator/discriminator update on a batch.

    Both gradients are computed from the same forward pass before either
    optimizer applies its update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # The discriminator sees the conditioning image paired with either
        # the real target or the generator's output.
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))


def fit(train_ds, epochs, test_ds, display_every=10, save_every=50):
    """Train the GAN for ``epochs`` epochs.

    Args:
        train_ds: batched dataset of ``(input_image, target)`` pairs.
        epochs: number of passes over ``train_ds``.
        test_ds: batched dataset; one batch is plotted every ``display_every`` epochs.
        display_every: plot a sample prediction every this many epochs.
        save_every: write a checkpoint every this many epochs. A checkpoint
            is also written after the final epoch, so the latest weights can
            always be restored even when ``epochs`` is not a multiple of
            ``save_every``.
    """
    for epoch in range(epochs):
        start = time.time()
        epoch_num = epoch + 1

        # Train on every batch of the epoch.
        for input_image, target in train_ds:
            train_step(input_image, target)

        # Periodically show a sample prediction.
        if epoch_num % display_every == 0:
            for example_input, example_target in test_ds.take(1):
                generate_images(generator, example_input, example_target)

        # Periodically save weights, and always after the last epoch.
        if epoch_num % save_every == 0 or epoch_num == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch_num,
                                                           time.time()-start))

In [None]:
fit(train_ds=train_dataset, epochs=EPOCHS, test_ds=test_dataset)

最後將存好的weight載入，試試看生成一些圖片吧！

In [None]:
# Restore the most recent checkpoint written by fit(). When none exists,
# restoring from None would silently do nothing, so report it instead.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    print('No checkpoint found in {}; using the current in-memory weights.'.format(checkpoint_dir))
else:
    checkpoint.restore(latest_ckpt)

In [None]:
# Translate the first five test batches with the generator and plot them.
for inp, tar in test_dataset.take(5):
    generate_images(model=generator, test_input=inp, tar=tar)