In [None]:
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from utils.datahandler import DataHandler
from models import  RRDN, Discriminator, VGG_Features
from train import PreTrainer
from train import GANTrainer


In [None]:
# Root folder of the DIV2K dataset. The default is the location used when the
# notebook was written; set the DIV2K_BASE_PATH environment variable to point
# at a different copy of the dataset without editing this cell.
BASE_PATH = os.environ.get(
    'DIV2K_BASE_PATH',
    '/media/nvme0/home/aslan/workspace/SR/DIV2K/',
)

# os.path.join adds the separator only when it is missing, so an override
# without a trailing slash still produces valid folders. With the default
# root the resulting strings are identical to plain concatenation.
TRAIN_HR_PATH = os.path.join(BASE_PATH, 'DIV2K_train_HR/')
TRAIN_LR_PATH = os.path.join(BASE_PATH, 'DIV2K_train_LR_bicubic/X4/')
VALID_HR_PATH = os.path.join(BASE_PATH, 'DIV2K_valid_HR/')
VALID_LR_PATH = os.path.join(BASE_PATH, 'DIV2K_valid_LR_bicubic/X4/')

# Patch size handed to the generator; the discriminator and the VGG feature
# extractor are built with patch_size * 4 in the model cell.
patch_size = 128

# Candidate VGG layers for the perceptual features (original author's notes;
# "mean" is presumably the typical activation magnitude -- TODO confirm):
#   block2_conv2, layer 5,  mean 1.0
#   block5_conv4, layer 20, mean 0.01
# Use [5, 20] to extract both layers; only layer 5 is active here.
feature_layers = [5]

In [None]:
# Disabled import-path tweak, kept for reference:
#sys.path.append('..')

# Generator hyper-parameters; the dict is passed to RRDN unchanged.
rrdn_arch_params = {'C': 4, 'D': 2, 'G': 64, 'G0': 64, 'T': 8, 'x': 4}
generator = RRDN(arch_params=rrdn_arch_params, patch_size=patch_size)

# The discriminator and the VGG feature extractor are both built for a patch
# four times the generator's patch_size -- presumably the size of the
# super-resolved output for the X4 data; TODO confirm against models.py.
hr_patch_size = patch_size * 4
discriminator = Discriminator(patch_size=hr_patch_size)
vgg_perception = VGG_Features(
    patch_size=hr_patch_size,
    layers_to_extract=feature_layers,
)

# Show the architecture name; the GAN training cell appends it to
# 'pre_generator_' when loading the pre-trained weights.
print(generator.arch_name)

In [None]:
# Pre-training stage: PreTrainer receives the three networks plus the
# low-/high-resolution training and validation folders.
pre_trainer = PreTrainer(
    generator=generator,
    discriminator=discriminator,
    feature_extractor=vgg_perception,
    lr_train_dir=TRAIN_LR_PATH,
    hr_train_dir=TRAIN_HR_PATH,
    lr_valid_dir=VALID_LR_PATH,
    hr_valid_dir=VALID_HR_PATH,
    n_validation=100,
)

# NOTE(review): the arguments are positional -- presumably epochs, steps per
# epoch and batch size; confirm against PreTrainer.train in train.py.
pre_trainer.train(2, 100, 2)

# Save through the trainer's own helper; 'pre_generator_' looks like a
# filename prefix (the GAN cell loads 'pre_generator_' + arch_name + '.h5').
# Direct alternative, left disabled:
#pre_trainer.generator.model.save_weights('pre_generator.h5')
pre_trainer.save_best_weights('pre_generator_')

In [None]:
# Weights for the individual loss terms, passed to GANTrainer unchanged.
# NOTE(review): the generator term is 0.0 -- confirm this is intended.
gan_loss_weights = {
    'generator': 0.0,
    'discriminator': 0.1,
    'feature_extractor': 1.0,
}

# Adversarial stage: same networks and data folders as the pre-training cell.
gan_trainer = GANTrainer(
    generator=generator,
    discriminator=discriminator,
    feature_extractor=vgg_perception,
    lr_train_dir=TRAIN_LR_PATH,
    hr_train_dir=TRAIN_HR_PATH,
    lr_valid_dir=VALID_LR_PATH,
    hr_valid_dir=VALID_HR_PATH,
    n_validation=100,
    loss_weights=gan_loss_weights,
)

# Start from the weights written by the pre-training cell. The filename is
# the 'pre_generator_' prefix followed by the architecture name and '.h5'.
pretrained_weights_file = 'pre_generator_' + generator.arch_name + '.h5'
gan_trainer.generator.model.load_weights(pretrained_weights_file)

# NOTE(review): positional arguments -- presumably epochs, steps per epoch
# and batch size; confirm against GANTrainer.train in train.py.
gan_trainer.train(160, 500, 2)

# Save through the trainer's own helper. Direct alternative, left disabled:
#gan_trainer.generator.model.save_weights('esrgan_generator.h5')
gan_trainer.save_best_weights('gan_generator_')

In [None]:
# List the layer names of the VGG feature extractor, one per line, to help
# match the indices in feature_layers to named layers.
for vgg_layer in vgg_perception.model.layers:
    print(vgg_layer.name)