In [None]:
!pip install --upgrade pip

In [None]:
!pip install -r requirements.txt

In [None]:
import os
import time
import random
import numpy as np
import scipy, multiprocessing
import tensorflow as tf
import tensorlayer as tl
from model import get_G, get_D
from config import config


In [None]:
###====================== HYPER-PARAMETERS ===========================###

batch_size = config.TRAIN.batch_size  # use 8 if your GPU memory is small, and change [4, 4] in tl.vis.save_images to [2, 4]
lr_init = config.TRAIN.lr_init
beta1 = config.TRAIN.beta1
## initialize G
n_epoch_init = config.TRAIN.n_epoch_init

n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
shuffle_buffer_size = 128


In [3]:
# create folders to save result images and trained models
save_dir = "samples"
tl.files.exists_or_mkdir(save_dir)
checkpoint_dir = "checkpoint"
tl.files.exists_or_mkdir(checkpoint_dir)

[TL] [!] samples exists ...
[TL] [!] checkpoint exists ...


True

In [4]:
def get_train_data():
    # load dataset
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[0:200]

    ## If your machine have enough memory, please pre-load the entire train set.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
        
    # dataset API and augmentation
    def generator_train():
        for img in train_hr_imgs:
            yield img
    def _map_fn_train(img):
        hr_patch = tf.image.random_crop(img, [384, 384, 3])
        hr_patch = hr_patch / (255. / 2.)
        hr_patch = hr_patch - 1.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        lr_patch = downscale_hr_patches(hr_patch)
        return lr_patch, hr_patch
    train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
        # train_ds = train_ds.repeat(n_epoch_init + n_epoch)
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.prefetch(buffer_size=2)
    train_ds = train_ds.batch(batch_size)
        # value = train_ds.make_one_shot_iterator().get_next()
    return train_ds

def downscale_hr_patches(hr_patch):
    return tf.image.resize(hr_patch, size=[96, 96])

In [5]:
G = get_G((batch_size, 96, 96, 3))
D = get_D((batch_size, 384, 384, 3))
VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

lr_v = tf.Variable(lr_init)
g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

G.train()
D.train()
VGG.train()

train_ds = get_train_data()

[TL] Input  _inputlayer_1: (8, 96, 96, 3)
[TL] Conv2d conv2d_1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d conv2d_2: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_3: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_2: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwise_1: fn: add act: No Activation
[TL] Conv2d conv2d_4: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_3: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_5: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_4: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwise_2

[TL] BatchNorm batchnorm2d_34: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_40: n_filter: 256 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_35: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_41: n_filter: 512 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_36: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_42: n_filter: 1024 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_37: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_43: n_filter: 2048 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_38: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_44: n_filter: 1024 filter_size: (1, 1) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNo

In [None]:
#G

In [None]:
#D

In [None]:
## initialize learning (G)
n_step_epoch = round(n_epoch_init // batch_size)
for epoch in range(n_epoch_init):
    for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
        if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        with tf.GradientTape() as tape:
            fake_hr_patchs = G(lr_patchs)

            mse_f_lr_p = 0.0
            if config.DOWNSCALE_COMPARE : 
                fake_lr_patches = downscale_hr_patches(fake_hr_patchs)

                mse_f_lr_p = tl.cost.mean_squared_error(fake_lr_patches, lr_patchs, is_mean=True)

            mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True) + mse_f_lr_p

        grad = tape.gradient(mse_loss, G.trainable_weights)
        g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
            epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
    if (epoch != 0) and (epoch % 10 == 0):
        tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))
        

Epoch: [0/100] step: [0/12] time: 0.977s, mse: 0.819 
Epoch: [0/100] step: [1/12] time: 0.469s, mse: 0.435 
Epoch: [0/100] step: [2/12] time: 0.540s, mse: 0.510 
Epoch: [0/100] step: [3/12] time: 0.414s, mse: 0.593 
Epoch: [0/100] step: [4/12] time: 0.445s, mse: 0.647 
Epoch: [0/100] step: [5/12] time: 0.465s, mse: 0.534 
Epoch: [0/100] step: [6/12] time: 0.448s, mse: 0.602 
Epoch: [0/100] step: [7/12] time: 0.432s, mse: 0.577 
Epoch: [0/100] step: [8/12] time: 0.350s, mse: 0.536 
Epoch: [0/100] step: [9/12] time: 0.351s, mse: 0.864 
Epoch: [0/100] step: [10/12] time: 0.347s, mse: 0.712 
Epoch: [0/100] step: [11/12] time: 0.348s, mse: 0.615 
Epoch: [0/100] step: [12/12] time: 0.449s, mse: 0.396 
Epoch: [0/100] step: [13/12] time: 0.351s, mse: 0.483 
Epoch: [0/100] step: [14/12] time: 0.357s, mse: 0.520 
Epoch: [0/100] step: [15/12] time: 0.360s, mse: 0.553 
Epoch: [0/100] step: [16/12] time: 0.356s, mse: 0.466 
Epoch: [0/100] step: [17/12] time: 0.356s, mse: 0.289 
Epoch: [0/100] step:

Epoch: [6/100] step: [1/12] time: 0.439s, mse: 0.087 
Epoch: [6/100] step: [2/12] time: 0.440s, mse: 0.102 
Epoch: [6/100] step: [3/12] time: 0.404s, mse: 0.081 
Epoch: [6/100] step: [4/12] time: 0.603s, mse: 0.089 
Epoch: [6/100] step: [5/12] time: 0.421s, mse: 0.084 
Epoch: [6/100] step: [6/12] time: 0.443s, mse: 0.090 
Epoch: [6/100] step: [7/12] time: 0.422s, mse: 0.057 
Epoch: [6/100] step: [8/12] time: 0.357s, mse: 0.083 
Epoch: [6/100] step: [9/12] time: 0.354s, mse: 0.097 
Epoch: [6/100] step: [10/12] time: 0.353s, mse: 0.057 
Epoch: [6/100] step: [11/12] time: 0.355s, mse: 0.058 
Epoch: [6/100] step: [12/12] time: 0.351s, mse: 0.072 
Epoch: [6/100] step: [13/12] time: 0.354s, mse: 0.075 
Epoch: [6/100] step: [14/12] time: 0.538s, mse: 0.076 
Epoch: [6/100] step: [15/12] time: 0.356s, mse: 0.092 
Epoch: [6/100] step: [16/12] time: 0.350s, mse: 0.071 
Epoch: [6/100] step: [17/12] time: 0.352s, mse: 0.058 
Epoch: [6/100] step: [18/12] time: 0.350s, mse: 0.102 
Epoch: [6/100] step

Epoch: [12/100] step: [1/12] time: 0.389s, mse: 0.076 
Epoch: [12/100] step: [2/12] time: 0.428s, mse: 0.051 
Epoch: [12/100] step: [3/12] time: 0.406s, mse: 0.054 
Epoch: [12/100] step: [4/12] time: 0.563s, mse: 0.059 
Epoch: [12/100] step: [5/12] time: 0.447s, mse: 0.060 
Epoch: [12/100] step: [6/12] time: 0.439s, mse: 0.044 
Epoch: [12/100] step: [7/12] time: 0.452s, mse: 0.057 
Epoch: [12/100] step: [8/12] time: 0.354s, mse: 0.045 
Epoch: [12/100] step: [9/12] time: 0.355s, mse: 0.130 
Epoch: [12/100] step: [10/12] time: 0.351s, mse: 0.032 
Epoch: [12/100] step: [11/12] time: 0.347s, mse: 0.069 
Epoch: [12/100] step: [12/12] time: 0.367s, mse: 0.065 
Epoch: [12/100] step: [13/12] time: 0.355s, mse: 0.054 
Epoch: [12/100] step: [14/12] time: 0.510s, mse: 0.067 
Epoch: [12/100] step: [15/12] time: 0.348s, mse: 0.070 
Epoch: [12/100] step: [16/12] time: 0.350s, mse: 0.084 
Epoch: [12/100] step: [17/12] time: 0.354s, mse: 0.058 
Epoch: [12/100] step: [18/12] time: 0.358s, mse: 0.057 
E

Epoch: [18/100] step: [0/12] time: 0.403s, mse: 0.055 
Epoch: [18/100] step: [1/12] time: 0.431s, mse: 0.042 
Epoch: [18/100] step: [2/12] time: 0.451s, mse: 0.075 
Epoch: [18/100] step: [3/12] time: 0.432s, mse: 0.072 
Epoch: [18/100] step: [4/12] time: 0.588s, mse: 0.048 
Epoch: [18/100] step: [5/12] time: 0.451s, mse: 0.045 
Epoch: [18/100] step: [6/12] time: 0.468s, mse: 0.041 
Epoch: [18/100] step: [7/12] time: 0.443s, mse: 0.034 
Epoch: [18/100] step: [8/12] time: 0.359s, mse: 0.038 
Epoch: [18/100] step: [9/12] time: 0.358s, mse: 0.055 
Epoch: [18/100] step: [10/12] time: 0.359s, mse: 0.043 
Epoch: [18/100] step: [11/12] time: 0.376s, mse: 0.037 
Epoch: [18/100] step: [12/12] time: 0.347s, mse: 0.050 
Epoch: [18/100] step: [13/12] time: 0.352s, mse: 0.033 
Epoch: [18/100] step: [14/12] time: 0.503s, mse: 0.042 
Epoch: [18/100] step: [15/12] time: 0.368s, mse: 0.055 
Epoch: [18/100] step: [16/12] time: 0.356s, mse: 0.039 
Epoch: [18/100] step: [17/12] time: 0.353s, mse: 0.033 
Ep

Epoch: [23/100] step: [23/12] time: 0.355s, mse: 0.042 
Epoch: [23/100] step: [24/12] time: 0.352s, mse: 0.026 
Epoch: [24/100] step: [0/12] time: 0.461s, mse: 0.049 
Epoch: [24/100] step: [1/12] time: 0.462s, mse: 0.037 
Epoch: [24/100] step: [2/12] time: 0.451s, mse: 0.044 
Epoch: [24/100] step: [3/12] time: 0.418s, mse: 0.045 
Epoch: [24/100] step: [4/12] time: 0.576s, mse: 0.051 
Epoch: [24/100] step: [5/12] time: 0.454s, mse: 0.045 
Epoch: [24/100] step: [6/12] time: 0.452s, mse: 0.033 
Epoch: [24/100] step: [7/12] time: 0.425s, mse: 0.028 
Epoch: [24/100] step: [8/12] time: 0.346s, mse: 0.027 
Epoch: [24/100] step: [9/12] time: 0.345s, mse: 0.031 
Epoch: [24/100] step: [10/12] time: 0.368s, mse: 0.038 
Epoch: [24/100] step: [11/12] time: 0.354s, mse: 0.041 
Epoch: [24/100] step: [12/12] time: 0.351s, mse: 0.038 
Epoch: [24/100] step: [13/12] time: 0.358s, mse: 0.037 
Epoch: [24/100] step: [14/12] time: 0.503s, mse: 0.034 
Epoch: [24/100] step: [15/12] time: 0.350s, mse: 0.038 
Ep

In [None]:
G.trainable_weights

In [None]:
G.save_weights(os.path.join(checkpoint_dir, 'g-initial.h5'))

In [None]:
## adversarial learning (G, D)
n_step_epoch = round(n_epoch // batch_size)
for epoch in range(n_epoch):
    for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
        if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        with tf.GradientTape(persistent=True) as tape:
            fake_patchs = G(lr_patchs)
            logits_fake = D(fake_patchs)
            logits_real = D(hr_patchs)
            feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
            feature_real = VGG((hr_patchs+1)/2.)
            d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
            d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
            d_loss = d_loss1 + d_loss2
            g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))

            # mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
            mse_f_lr_p = 0.0
            if config.DOWNSCALE_COMPARE: 
                fake_lr_patches = downscale_hr_patches(fake_hr_patchs)

                mse_f_lr_p = tl.cost.mean_squared_error(fake_lr_patches, lr_patchs, is_mean=True)

            mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True) + mse_f_lr_p

            vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
            g_loss = mse_loss + vgg_loss + g_gan_loss
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        
        print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
            epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

    # update the learning rate
    if epoch != 0 and (epoch % decay_every == 0):
        new_lr_decay = lr_decay**(epoch // decay_every)
        lr_v.assign(lr_init * new_lr_decay)
        log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
        print(log)

    if (epoch != 0) and (epoch % 10 == 0):
        tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_{}.png'.format(epoch)))
        G.save_weights(os.path.join(checkpoint_dir, 'g-{epoch}.h5'.format(epoch=epoch)))
        D.save_weights(os.path.join(checkpoint_dir, 'd-{epoch}.h5'.format(epoch=epoch)))