In [None]:
import os
import time
import random
import numpy as np
import scipy, multiprocessing
import tensorflow as tf
import tensorlayer as tl
from model import get_G, get_D
from config import config


In [2]:
###====================== HYPER-PARAMETERS ===========================###

_train_cfg = config.TRAIN  # shorthand for the training section of the project config

# Settings shared by both training phases.
# use 8 if your GPU memory is small, and change [4, 4] in tl.vis.save_images to [2, 4]
batch_size = _train_cfg.batch_size
lr_init = _train_cfg.lr_init
beta1 = _train_cfg.beta1

# Phase 1 -- initialize G: number of generator-only epochs.
n_epoch_init = _train_cfg.n_epoch_init

# Phase 2 -- adversarial training: epoch count and learning-rate decay schedule.
n_epoch = _train_cfg.n_epoch
lr_decay = _train_cfg.lr_decay
decay_every = _train_cfg.decay_every

# Size of the buffer handed to tf.data's shuffle() when the dataset is built.
shuffle_buffer_size = 128


In [3]:
# create folders to save result images, trained models and training logs
save_dir = "samples"
tl.files.exists_or_mkdir(save_dir)
checkpoint_dir = "checkpoint"
tl.files.exists_or_mkdir(checkpoint_dir)
# The training loops append to files under "logs/". open(..., "a") does not
# create missing parent directories, so the folder has to exist beforehand.
log_dir = "logs"
tl.files.exists_or_mkdir(log_dir)

[TL] [!] samples exists ...
[TL] [!] checkpoint exists ...


True

In [None]:
def get_train_data(lr_img_path='DIV2K/DIV2K_train_LR_wild/'):
    """Build the batched tf.data pipeline of low-resolution training patches.

    All PNG images found in ``lr_img_path`` are read into memory once. Each
    image is then randomly cropped to 96x96x3, rescaled from [0, 255] to
    [-1, 1] and randomly flipped left/right before being shuffled and batched.

    Args:
        lr_img_path: Directory containing the low-resolution training images.
            Defaults to the DIV2K "wild" LR training folder.

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of up to ``batch_size``
        patches (the last batch of an epoch may be smaller).

    Raises:
        ValueError: If no PNG file is found in ``lr_img_path``.
    """
    # The dot is escaped so that ".png" is matched literally; an unescaped "."
    # matches any character and would also accept names such as "foo_png".
    train_lr_img_list = sorted(tl.files.load_file_list(path=lr_img_path, regx=r'.*\.png', printable=False))
    if not train_lr_img_list:
        # An empty list would produce an empty dataset and training loops that
        # silently do nothing; fail loudly instead.
        raise ValueError("No .png training images found in '{}'".format(lr_img_path))

    ## If your machine have enough memory, please pre-load the entire train set.
    train_lr_imgs = tl.vis.read_images(train_lr_img_list, path=lr_img_path, n_threads=32)

    # dataset API and augmentation
    def generator_train():
        """Yield the pre-loaded images one at a time."""
        for img in train_lr_imgs:
            yield img

    def _map_fn_train(img):
        """Random 96x96 crop, rescale to [-1, 1], random horizontal flip."""
        lr_patch = tf.image.random_crop(img, [96, 96, 3])
        lr_patch = lr_patch / (255. / 2.)
        lr_patch = lr_patch - 1.
        lr_patch = tf.image.random_flip_left_right(lr_patch)
        return lr_patch

    train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    # Batch first, then prefetch, so that whole batches (not single patches)
    # are prepared ahead of the training step that consumes them.
    train_ds = train_ds.batch(batch_size)
    train_ds = train_ds.prefetch(buffer_size=2)
    return train_ds

def downscale_hr_patches(hr_patch, size=(96, 96)):
    """Resize a batch of generated patches down to the low-resolution patch size.

    Args:
        hr_patch: Image tensor accepted by ``tf.image.resize`` (the training
            loops pass the generator output).
        size: Target ``(height, width)``. Defaults to ``(96, 96)``, the crop
            size used for the low-resolution training patches.

    Returns:
        The resized tensor, using ``tf.image.resize``'s default interpolation.
    """
    return tf.image.resize(hr_patch, size=list(size))

In [5]:
# Build the generator for 96x96 RGB inputs and the discriminator for
# 384x384 RGB inputs (batch dimension fixed to batch_size).
G = get_G((batch_size, 96, 96, 3))
D = get_D((batch_size, 384, 384, 3))
# Pre-trained VGG19 cut at 'pool4'; the adversarial loop uses its outputs
# as feature maps for the VGG loss term.
VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

# One learning-rate variable is passed to all three Adam optimizers; the
# adversarial loop later decays it with lr_v.assign(...).
lr_v = tf.Variable(lr_init)
# g_optimizer_init drives the generator-only initialisation loop;
# g_optimizer and d_optimizer drive the adversarial loop.
g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

# Switch the networks to training mode before the loops below call them.
# NOTE(review): no optimizer in this notebook updates VGG -- confirm that
# training mode (rather than eval mode) is really intended for it.
G.train()
D.train()
VGG.train()



[TL] Input  _inputlayer_1: (8, 96, 96, 3)
[TL] Conv2d conv2d_1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d conv2d_2: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_3: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_2: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwise_1: fn: add act: No Activation
[TL] Conv2d conv2d_4: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_3: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_5: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_4: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwise_2

[TL] BatchNorm batchnorm2d_34: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_40: n_filter: 256 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_35: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_41: n_filter: 512 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_36: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_42: n_filter: 1024 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_37: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_43: n_filter: 2048 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_38: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: False
[TL] Conv2d conv2d_44: n_filter: 1024 filter_size: (1, 1) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNo

FileNotFoundError: [Errno 2] No such file or directory: 'DIV2/DIV2K_train_LR_wild/'

In [7]:
# Read the low-resolution images and build the batched dataset that both
# training loops below iterate over.
train_ds = get_train_data()

[TL] read 32 from DIV2K/DIV2K_train_LR_wild/
[TL] read 64 from DIV2K/DIV2K_train_LR_wild/
[TL] read 96 from DIV2K/DIV2K_train_LR_wild/
[TL] read 128 from DIV2K/DIV2K_train_LR_wild/
[TL] read 160 from DIV2K/DIV2K_train_LR_wild/
[TL] read 192 from DIV2K/DIV2K_train_LR_wild/
[TL] read 224 from DIV2K/DIV2K_train_LR_wild/
[TL] read 256 from DIV2K/DIV2K_train_LR_wild/
[TL] read 288 from DIV2K/DIV2K_train_LR_wild/
[TL] read 320 from DIV2K/DIV2K_train_LR_wild/
[TL] read 352 from DIV2K/DIV2K_train_LR_wild/
[TL] read 384 from DIV2K/DIV2K_train_LR_wild/
[TL] read 416 from DIV2K/DIV2K_train_LR_wild/
[TL] read 448 from DIV2K/DIV2K_train_LR_wild/
[TL] read 480 from DIV2K/DIV2K_train_LR_wild/
[TL] read 512 from DIV2K/DIV2K_train_LR_wild/
[TL] read 544 from DIV2K/DIV2K_train_LR_wild/
[TL] read 576 from DIV2K/DIV2K_train_LR_wild/
[TL] read 608 from DIV2K/DIV2K_train_LR_wild/
[TL] read 640 from DIV2K/DIV2K_train_LR_wild/
[TL] read 672 from DIV2K/DIV2K_train_LR_wild/
[TL] read 704 from DIV2K/DIV2K_train_

In [None]:
## initialize learning (G)
# Pre-train the generator alone: upscale each LR batch, resize the result back
# down and minimise the MSE against the original LR batch.

# open(..., "a") does not create parent directories, so make sure the log
# folder exists before the first write.
os.makedirs("logs", exist_ok=True)

# Steps per epoch is the number of full batches the dataset yields (the loop
# below stops at the first short batch). The previous value,
# n_epoch_init // batch_size, was unrelated to the dataset size. Counting costs
# one extra pass over the (in-memory) input pipeline.
n_step_epoch = sum(1 for batch in train_ds if batch.shape[0] == batch_size)

for epoch in range(n_epoch_init):
    for step, lr_patchs in enumerate(train_ds):
        if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        with tf.GradientTape() as tape:
            fake_hr_patchs = G(lr_patchs)
            fake_lr_patches = downscale_hr_patches(fake_hr_patchs)

            mse_loss = tl.cost.mean_squared_error(fake_lr_patches, lr_patchs, is_mean=True)

        # Only the forward pass belongs inside the tape context; the gradient
        # computation and the optimizer step are done after recording stops.
        grad = tape.gradient(mse_loss, G.trainable_weights)
        g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))

        log = "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
            epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss)

        # One record per line; the with-block closes the file on exit.
        with open("logs/init_generator.log", "a") as log_file:
            log_file.write(log + "\n")

        print(log)
    # Every 10th epoch (except epoch 0) save the last generated batch as a 2x4 image grid.
    if (epoch != 0) and (epoch % 10 == 0):
        tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))
        

In [None]:
# Save the generator weights reached by the initialisation loop into the checkpoint folder.
G.save_weights(os.path.join(checkpoint_dir, 'g-initial.h5'))

In [None]:
## adversarial learning (G, D)
# Train generator and discriminator together. The generator loss is the sum of
# an MSE term, a VGG feature term and an adversarial term; the discriminator
# loss is the sum of the cross-entropies on its two inputs.

# open(..., "a") does not create parent directories, so make sure the log
# folder exists before the first write.
os.makedirs("logs", exist_ok=True)

# Steps per epoch is the number of full batches the dataset yields (the loop
# below stops at the first short batch). The previous value,
# n_epoch // batch_size, was unrelated to the dataset size. Counting costs one
# extra pass over the (in-memory) input pipeline.
n_step_epoch = sum(1 for batch in train_ds if batch.shape[0] == batch_size)

for epoch in range(n_epoch):
    for step, lr_patchs in enumerate(train_ds):
        if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
            break
        step_time = time.time()
        with tf.GradientTape(persistent=True) as tape:
            fake_patchs = G(lr_patchs)
            fake_lr_patches = downscale_hr_patches(fake_patchs)

            logits_fake = D(fake_patchs)
            # NOTE(review): the "real" branch is fed the downscaled generator
            # output (96x96), while D is constructed for 384x384 inputs in the
            # model-setup cell. Confirm what the real sample is meant to be and
            # that D accepts this shape.
            logits_real = D(fake_lr_patches)

            feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
            # NOTE(review): fake_patchs and fake_lr_patches have different
            # spatial sizes, so the two feature maps presumably do too --
            # verify that the MSE between them below is well-defined.
            feature_real = VGG((fake_lr_patches+1)/2.)

            d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
            d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))

            d_loss = d_loss1 + d_loss2
            g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))

            mse_loss = tl.cost.mean_squared_error(fake_lr_patches, lr_patchs, is_mean=True)

            vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)

            g_loss = mse_loss + vgg_loss + g_gan_loss

        # Gradients are taken after leaving the tape context: calling
        # tape.gradient inside a persistent tape's context records the
        # backward pass as well, which wastes time and memory.
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # A persistent tape keeps its resources until it is dropped.
        del tape

        # Report progress against n_epoch (this phase's epoch count), not n_epoch_init.
        log = "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
            epoch, n_epoch, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss)

        # One record per line; the with-block closes the file on exit.
        with open("logs/gan.log", "a") as log_file:
            log_file.write(log + "\n")

        print(log)

    # update the learning rate
    if epoch != 0 and (epoch % decay_every == 0):
        new_lr_decay = lr_decay**(epoch // decay_every)
        # lr_v is the variable handed to the optimizers when they were created.
        lr_v.assign(lr_init * new_lr_decay)
        log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)

        with open("logs/learning_rate.log", "a") as log_file:
            log_file.write(log + "\n")

        print(log)

    # Every 10th epoch (except epoch 0) save a 2x4 sample grid and checkpoint both networks.
    if (epoch != 0) and (epoch % 10 == 0):
        tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_{}.png'.format(epoch)))
        G.save_weights(os.path.join(checkpoint_dir, 'g-{epoch}.h5'.format(epoch=epoch)))
        D.save_weights(os.path.join(checkpoint_dir, 'd-{epoch}.h5'.format(epoch=epoch)))