Adapted by Elena Gronskaya from ISR_Training_Tutorial.ipynb: https://github.com/idealo/image-super-resolution/tree/master/notebooks

The purpose of this notebook is to set up an RRDN model from the ISR repo, train it, and run sweeps over training hyperparameters and loss-weight combinations.

# Install ISR

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/image-super-resolution
!python setup.py install

# if using local repo of ISR
# else use !pip install ISR and see ISR_module_adjustments notebook for changes
# to run locally
#!pip install gast>=0.3.2
#!pip install ISR

In [None]:
!pip install 'h5py==2.10.0' --force-reinstall
!pip install -U PyYAML
!pip install imagecodecs

# Train

## Create the models
Import the models from the ISR package and create:

- an RRDN super-resolution network
- a discriminator network for GAN training
- a VGG19 feature extractor to train with a perceptual loss function

Carefully select
- 'x': this is the upscaling factor (2 by default)
- 'layers_to_extract': these are the layers from the VGG19 that will be used in the perceptual loss (leave the default if you're not familiar with it)
- 'lr_patch_size': this is the size of the patches that will be extracted from the LR images and fed to the ISR network during training time

Play around with the other architecture parameters

In [None]:
from ISR.models import RRDN
from ISR.models import Discriminator
from ISR.models import Cut_VGG19

In [None]:
# Upscaling factor and patch geometry: each HR patch is `scale` times the size
# of the LR patch it is paired with.
scale = 3
lr_train_patch_size = 60
hr_train_patch_size = scale * lr_train_patch_size

# VGG19 layers whose activations are used by the perceptual loss.
layers_to_extract = [5, 9]

# Generator (RRDN), perceptual feature extractor (VGG19) and GAN discriminator.
rrdn = RRDN(
    arch_params=dict(C=4, D=3, G=64, G0=64, T=10, x=scale),
    patch_size=lr_train_patch_size,
)
f_ext = Cut_VGG19(
    patch_size=hr_train_patch_size,
    layers_to_extract=layers_to_extract,
)
discr = Discriminator(
    patch_size=hr_train_patch_size,
    kernel_size=3,
)

## Give the models to the Trainer
The Trainer object will combine the networks, manage your training data and keep you up-to-date with the training progress through Tensorboard and the command line.

Here we use only the generator's pixel-wise loss (MAE). The perceptual (feature extractor) and adversarial (discriminator) losses are disabled by setting their weights to 0 in `loss_weights`.

In [None]:
from ISR.train import Trainer

# Loss weighting: only the generator's pixel-wise loss contributes; the
# perceptual (feature extractor) and adversarial (discriminator) terms are
# switched off by giving them weight 0.
loss_weights = {
  'generator': 1,
  'feature_extractor' : 0,
  'discriminator': 0
}
# Loss function per component. Passed to the Trainer below so that changes
# here actually take effect (the dict used to be defined but never passed).
losses = {
  'generator': 'mae',
  'feature_extractor': 'mse',
  'discriminator': 'binary_crossentropy'
}

# Output folders for TensorBoard logs and saved weights (relative paths).
log_dirs = {'logs': './logs', 'weights': './weights'}

# Learning-rate schedule; presumably the rate is multiplied by `decay_factor`
# every `decay_frequency` epochs — see the ISR Trainer docs.
learning_rate = {'initial_value': 0.0004, 'decay_factor': .5, 'decay_frequency': 30}

# Flatness schedule (ISR-specific); presumably a threshold for skipping
# low-detail training patches that grows over training — confirm in ISR docs.
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

# TODO: replace the placeholder directory strings with the real dataset folders.
trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='PATH TO LR TRAINING DATA',
    hr_train_dir='PATH TO HR TRAINING DATA',
    lr_valid_dir='PATH TO LR VALIDATION DATA',
    hr_valid_dir='PATH TO HR VALIDATION DATA',
    loss_weights=loss_weights,
    losses=losses,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='div2k',
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=60,
    )

Choose the number of epochs, the steps per epoch and the batch size, then start training.

In [None]:
# cd to directory where you want the weights and TensorBoard logs to be saved
cd /content/drive/MyDrive

In [None]:
# Training schedule for the single run configured above.
train_settings = dict(
    epochs=400,
    steps_per_epoch=40,
    batch_size=4,
    # Validation PSNR (presumably on the Y/luminance channel); 'max' means
    # larger values are treated as better.
    monitored_metrics={'val_generator_PSNR_Y': 'max'},
)
trainer.train(**train_settings)

## Train in a loop to test different hyperparameters


In [None]:
# Search grid for the hyperparameter sweep below.
patch_sizes = [20, 40, 60, 80]   # LR training patch sizes
batch_sizes = [2, 4, 8]

# Three learning-rate schedules: each step down divides the initial rate by 10
# and doubles the decay interval (30 -> 60 -> 90 epochs).
learning_rates = [
    {'initial_value': initial, 'decay_factor': .5, 'decay_frequency': every}
    for initial, every in ((0.0004, 30), (0.00004, 60), (0.000004, 90))
]

In [None]:
import itertools
import time
import traceback

from ISR.train import Trainer

# Grid search over LR patch size x batch size x learning-rate schedule.
# A failing configuration is reported and skipped so the sweep keeps going.

# Settings shared by every run (loop-invariant).
layers_to_extract = [5, 9]
scale = 3
# Forwarded to the Trainer as `fallback_save_every_n_epochs` (previously
# defined but never passed, so the Trainer's default was silently used).
fallback_save_every_n_epochs = 10
# Samples per epoch; steps_per_epoch is derived from it so every batch size
# processes the same number of samples per epoch.
samples_per_epoch = 20 * 4

# itertools.product iterates in the same order as the nested loops
# (patch outermost, learning rate innermost).
for patch, batch, lr in itertools.product(patch_sizes, batch_sizes, learning_rates):
  try:
    lr_train_patch_size = patch
    hr_train_patch_size = lr_train_patch_size * scale

    # Fresh, untrained models for every configuration.
    rrdn  = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
    f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
    discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

    # Rebuilt per run so each Trainer gets its own dict objects, in case the
    # Trainer mutates them.
    loss_weights = {
      'generator': 1,
      'feature_extractor' : 0,
      'discriminator': 0
    }
    losses = {
      'generator': 'mae',
      'feature_extractor': 'mse',
      'discriminator': 'binary_crossentropy'
    }
    log_dirs = {'logs': './logs', 'weights': './weights'}
    learning_rate = lr
    flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

    # TODO: replace the placeholder directory strings with the real dataset folders.
    trainer = Trainer(
        generator=rrdn,
        discriminator=discr,
        feature_extractor=f_ext,
        lr_train_dir='PATH TO LR TRAINING DATA',
        hr_train_dir='PATH TO HR TRAINING DATA',
        lr_valid_dir='PATH TO LR VALIDATION DATA',
        hr_valid_dir='PATH TO HR VALIDATION DATA',
        loss_weights=loss_weights,
        losses=losses,
        learning_rate=learning_rate,
        flatness=flatness,
        dataname='div2k',
        log_dirs=log_dirs,
        fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=40,
        )

    trainer.train(
        epochs=200,
        steps_per_epoch=samples_per_epoch // batch,
        batch_size=batch,
        monitored_metrics={'val_generator_PSNR_Y':'max'})

  except Exception:
    print('!!!!!!!!!!!!!!!TRAINING FAILED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
    print('batch_size: '+str(batch))
    print('patch_size: '+str(patch))
    print(lr)
    # Full traceback (not just str(e)) so the failure can actually be diagnosed.
    traceback.print_exc()
    # Pause before the next configuration (presumably to let the runtime
    # recover, e.g. release GPU memory — confirm this is still needed).
    time.sleep(90)


### Train in a loop to test different loss-weight combinations (for RRDN)

In [None]:
# Loss-weight combinations to compare. The generator weight stays at 1; the
# (feature_extractor, discriminator) pairs range from "pixel loss only" (0, 0)
# to increasingly strong perceptual/adversarial terms.
_extractor_discriminator_pairs = [
    (0, 0),
    (0.0001, 0.0001),
    (0.001, 0.001),
    (0.01, 0.01),
    (0.04, 0.02),
]
loss_weights_list = [
    {'generator': 1, 'feature_extractor': fe_weight, 'discriminator': d_weight}
    for fe_weight, d_weight in _extractor_discriminator_pairs
]

In [None]:
import time
import traceback

from ISR.train import Trainer

# Train one RRDN per loss-weight combination in `loss_weights_list`, with a
# fixed patch size, batch size and learning-rate schedule. A failing
# combination is reported and skipped so the remaining runs still execute.

# Settings shared by every run (loop-invariant).
lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 3
hr_train_patch_size = lr_train_patch_size * scale
# Forwarded to the Trainer as `fallback_save_every_n_epochs` (previously
# defined but never passed, so the Trainer's default was silently used).
fallback_save_every_n_epochs = 10

for loss in loss_weights_list:
  try:
    # Fresh, untrained models for every combination.
    rrdn  = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
    f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
    discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

    loss_weights = loss
    # Rebuilt per run so each Trainer gets its own dict objects, in case the
    # Trainer mutates them.
    losses = {
      'generator': 'mae',
      'feature_extractor': 'mse',
      'discriminator': 'binary_crossentropy'
    }
    log_dirs = {'logs': './logs', 'weights': './weights'}
    learning_rate = {'initial_value': 0.0004, 'decay_factor': .5, 'decay_frequency': 30}
    flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

    # TODO: replace the placeholder directory strings with the real dataset folders.
    trainer = Trainer(
        generator=rrdn,
        discriminator=discr,
        feature_extractor=f_ext,
        lr_train_dir='PATH TO LR TRAINING DATA',
        hr_train_dir='PATH TO HR TRAINING DATA',
        lr_valid_dir='PATH TO LR VALIDATION DATA',
        hr_valid_dir='PATH TO HR VALIDATION DATA',
        loss_weights=loss_weights,
        losses=losses,
        learning_rate=learning_rate,
        flatness=flatness,
        dataname='l2_s8',
        log_dirs=log_dirs,
        fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=40,
        )

    trainer.train(
        epochs=400,
        steps_per_epoch=40,
        batch_size=4,
        monitored_metrics={'val_generator_PSNR_Y':'max'})

  except Exception:
    print('!!!!!!!!!!!!!!!TRAINING FAILED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
    print(loss)
    # Full traceback (not just str(e)) so the failure can actually be diagnosed.
    traceback.print_exc()
    # Pause before the next combination (presumably to let the runtime
    # recover, e.g. release GPU memory — confirm this is still needed).
    time.sleep(90)