In [None]:
import tensorflow as tf
import time
from src.utils import check_GPU_is_available
from tensorflow import distribute
SEED = 69
# Seed Python's `random`, NumPy and TensorFlow in one call; `tf.random.set_seed`
# alone leaves the Python/NumPy generators unseeded, so runs were not fully reproducible.
tf.keras.utils.set_random_seed(SEED)
check_GPU_is_available(is_required=True)
# Mirror model variables across all visible GPUs (allows multi-GPU training, if available).
strategy = distribute.MirroredStrategy()

In [None]:
import json
import os

## Read the training config

In [None]:
json_cfg_path = "./GPU_train_cfg.json"
# Use a context manager so the file handle is closed right after parsing
# (the previous `json.load(open(...))` leaked it), and read as UTF-8 explicitly
# instead of relying on the platform's default encoding.
with open(json_cfg_path, "r", encoding="utf-8") as cfg_file:
    train_cfg = json.load(cfg_file)
print(json.dumps(train_cfg, indent=4))

## Prepare the data loaders

In [None]:
from src import create_data_loader

In [None]:
# params
dataset_tfrecords_path = train_cfg["dataset"]["TFR_TRAIN_PATH"]
lr_img_shape = train_cfg["dataset"]["LR_SHAPE"]
hr_img_shape = train_cfg["dataset"]["HR_SHAPE"]
# Global batch size: the configured TRAIN_BATCH_SIZE scaled by the number of
# replicas the distribution strategy runs in sync.
batch_size = train_cfg["trainer"]["TRAIN_BATCH_SIZE"] * strategy.num_replicas_in_sync

# deploy
# Raise explicitly: the previous `assert ..., print(...)` produced an AssertionError
# whose message was `None`, and the check vanished entirely under `python -O`.
if not os.path.isdir(dataset_tfrecords_path):
    raise FileNotFoundError(
        f"Can't find dataset at the expected directory ({dataset_tfrecords_path}), "
        "please run 'download_dataset.py' first")

# os.listdir returns entries in arbitrary order; list once and sort so the
# shard order fed to the data loaders is deterministic across runs/machines.
tfrecord_names = sorted(os.listdir(dataset_tfrecords_path))
train_tfrecords = [os.path.join(dataset_tfrecords_path, x) for x in tfrecord_names if "train" in x]
valid_tfrecords = [os.path.join(dataset_tfrecords_path, x) for x in tfrecord_names if "validation" in x]
# Fail fast instead of handing an empty file list to the training pipeline.
if not train_tfrecords:
    raise FileNotFoundError(
        f"No training TFRecord files (names containing 'train') found in {dataset_tfrecords_path}")

train_dataloader = create_data_loader(train_tfrecords, batch_size=batch_size, lr_img_shape=lr_img_shape, hr_img_shape=hr_img_shape, train_mode=True)
valid_dataloader = create_data_loader(valid_tfrecords, batch_size=batch_size, lr_img_shape=lr_img_shape, hr_img_shape=hr_img_shape, train_mode=False)
print("Train tf records: {} - Validation tf records: {}".format(len(train_tfrecords), len(valid_tfrecords)))

### If a pretrained model is available, load it

In [None]:
# Reuse a previously pretrained generator when its checkpoint exists, so the
# pretraining stage can be skipped; otherwise `generator` stays None.
# NOTE(review): this path is hard-coded with a fixed timestamp, while the
# pretraining stage names its output with the current time, so a freshly
# pretrained model is not picked up automatically - update the path by hand.
generator = None
pretrained_model_path = "./assets/pretrained/srgan_gen_div2k-bicubic_x4_1711220278.keras"

if not os.path.exists(pretrained_model_path):
    print("[INFO] pretrained SRGAN generator not found, will train from scratch!")
else:
    from keras import config
    from src.model.srgan import SRGAN

    with strategy.scope():
        # Global Keras opt-out of safe-mode deserialization.
        # NOTE(review): only `load_weights` is used below, so this may not be
        # required - confirm before removing.
        config.enable_unsafe_deserialization()
        scaling_factor, feature_maps, residual_blocks = (
            train_cfg["dataset"]["SCALING_FACTOR"],
            train_cfg["srgan"]["FEATURE_MAPS"],
            train_cfg["srgan"]["RES_BLOCKS"],
        )
        print("[INFO] loading the pretrained SRGAN generator...")
        # Rebuild the architecture inside the strategy scope, then restore the saved weights into it.
        generator = SRGAN.generator(scaling_factor=scaling_factor, feature_maps=feature_maps, residual_blocks=residual_blocks)
        generator.load_weights(pretrained_model_path)
        print("[INFO] pretrained SRGAN generator loaded!")

## Pretrain the generator

In [None]:
from tensorflow.keras.optimizers import Adam
from src.model.losses import Losses
from src.model.srgan import SRGAN

In [None]:
# params
scaling_factor = train_cfg["dataset"]["SCALING_FACTOR"]
feature_maps = train_cfg["srgan"]["FEATURE_MAPS"]
residual_blocks = train_cfg["srgan"]["RES_BLOCKS"]
pretrain_lr = train_cfg["trainer"]["PRETRAIN_LR"]
pretrain_epochs = train_cfg["trainer"]["PRETRAIN_EPOCHS"]
steps_per_epoch = train_cfg["trainer"]["STEPS_PER_EPOCH"]

# Pretrain the generator alone with an MSE loss, only when no pretrained
# generator was loaded earlier; the result is saved with a timestamped name.
if generator is None:
    pretrained_base_path = "./assets/pretrained"
    # Create the output directory before the long training run so a filesystem
    # problem surfaces immediately; `exist_ok=True` replaces the racy
    # exists()/makedirs() pair.
    os.makedirs(pretrained_base_path, exist_ok=True)

    with strategy.scope():
        losses = Losses(numReplicas=strategy.num_replicas_in_sync)
        generator = SRGAN.generator(
            scaling_factor=scaling_factor,
            feature_maps=feature_maps,
            residual_blocks=residual_blocks)
        generator.compile(
            optimizer=Adam(learning_rate=pretrain_lr),
            loss=losses.mse_loss)
        print("[INFO] pretraining SRGAN generator...")
        generator.fit(train_dataloader,
                      epochs=pretrain_epochs,
                      steps_per_epoch=steps_per_epoch)

        # Timestamp taken after training, so the name reflects when the model was produced.
        model_name = "srgan_gen_"+train_cfg["dataset"]["DATASET_NAME"]+"_"+str(int(time.time()))+".keras"
        pretrained_generator_out_path = os.path.join(pretrained_base_path, model_name)

        print(f"[INFO] saving the SRGAN pretrained generator to {pretrained_generator_out_path}...")
        generator.save(pretrained_generator_out_path)
        print("DONE!")

## Fine-tune the Generator

In [None]:
from src.model.vgg import VGG
from src.srgan_trainer import SRGANTraining

In [None]:
# params: adversarial fine-tuning hyper-parameters
# (feature_maps and steps_per_epoch are read from the config in the pretraining cell)
leaky_alpha, disc_blocks = (train_cfg["srgan"][key] for key in ("LEAKY_ALPHA", "DISC_BLOCKS"))
finetune_lr, finetune_epochs = (train_cfg["trainer"][key] for key in ("FINETUNE_LR", "FINETUNE_EPOCHS"))

with strategy.scope():
    # Losses, models and optimizers are all created inside the strategy scope
    # so their variables are distributed across the replicas.
    losses = Losses(numReplicas=strategy.num_replicas_in_sync)
    vgg = VGG.build()
    discriminator = SRGAN.discriminator(feature_maps=feature_maps, leaky_alpha=leaky_alpha, disc_blocks=disc_blocks)
    # Wrap the (pre)trained generator, the discriminator and VGG into the GAN trainer.
    srgan = SRGANTraining(generator=generator, discriminator=discriminator, vgg=vgg, batch_size=batch_size)
    srgan.compile(d_optimizer=Adam(learning_rate=finetune_lr),
                  g_optimizer=Adam(learning_rate=finetune_lr),
                  bce_loss=losses.bce_loss,
                  mse_loss=losses.mse_loss)
    print("[INFO] fine-tuning SRGAN...")
    srgan.fit(train_dataloader, epochs=finetune_epochs, steps_per_epoch=steps_per_epoch)

In [None]:
# Save the generator after adversarial fine-tuning under a timestamped name.
# NOTE(review): presumably SRGANTraining updates the wrapped `generator` in place,
# which is why `generator` (not `srgan`) is saved here - confirm.
finetuned_base_path = "./assets/finetuned"
model_name = f"srgan_gen_{train_cfg['dataset']['DATASET_NAME']}_{int(time.time())}_finetuned.keras"
finetuned_generator_out_path = os.path.join(finetuned_base_path, model_name)

# exist_ok=True creates the directory when missing without a separate, racy exists() check.
os.makedirs(finetuned_base_path, exist_ok=True)

print(f"[INFO] saving the SRGAN finetuned generator to {finetuned_generator_out_path}...")
generator.save(finetuned_generator_out_path)
print("DONE!")