<a href="https://colab.research.google.com/github/4tif4nwer/image-super-resolution/blob/main/SRGAN_Training_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**If training on colab, be sure to use a GPU (runtime > Change runtime type > GPU)**

In [1]:
# uncomment and run the lines below if running in google colab
# !pip install tensorflow==2.4.3
!git clone https://github.com/4tif4nwer/image-super-resolution.git
!mv image-super-resolution/* ./

Cloning into 'image-super-resolution'...
remote: Enumerating objects: 266, done.[K
remote: Counting objects: 100% (266/266), done.[K
remote: Compressing objects: 100% (194/194), done.[K
remote: Total 266 (delta 125), reused 184 (delta 58), pack-reused 0[K
Receiving objects: 100% (266/266), 149.33 MiB | 13.62 MiB/s, done.
Resolving deltas: 100% (125/125), done.


In [2]:
import tensorflow as tf
from tensorflow import keras
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
tf.test.is_built_with_cuda()
print(tf.version.VERSION)
import sys
sys.version

Num GPUs Available:  1
2.7.0


'3.7.12 (default, Sep 10 2021, 00:21:48) \n[GCC 7.5.0]'

# SRResNet and SRGAN Training for Image Super Resolution

An Implementation of SRGAN: https://arxiv.org/pdf/1609.04802.pdf

In [3]:
import os
import time
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import Mean
from PIL import Image
from google.colab import files

from datasets.div2k.parameters import Div2kParameters 
from datasets.div2k.loader import create_training_and_validation_datasets
from utils.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from utils.metrics import psnr_metric
from utils.config import config
from utils.callbacks import SaveCustomCheckpoint
from models.srresnet import build_srresnet
from models.srgan import build_discriminator


## Prepare the dataset

In [4]:
dataset_key = "realistic_difficult_x4"

data_path = config.get("data_path", "") 

div2k_folder = os.path.abspath(os.path.join(data_path, "div2k"))

dataset_parameters = Div2kParameters(dataset_key, save_data_directory=div2k_folder)

In [5]:
hr_crop_size = 96

In [6]:
train_mappings = [
    lambda lr, hr: random_crop(lr, hr, hr_crop_size=hr_crop_size, scale=dataset_parameters.scale), 
    random_flip, 
    random_rotate, 
    random_lr_jpeg_noise]

In [7]:
train_dataset, valid_dataset = create_training_and_validation_datasets(dataset_parameters, train_mappings)

valid_dataset_subset = valid_dataset.take(10) # only taking 10 examples here to speed up evaluations during training

Couldn't find directory:  /content/div2k/DIV2K_train_LR_difficult
/content/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_difficult.zip
Begin caching in /content/div2k/cache/DIV2K_train_LR_difficult/cache.
Completed caching in /content/div2k/cache/DIV2K_train_LR_difficult/cache.
Couldn't find directory:  /content/div2k/DIV2K_train_HR
/content/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Begin caching in /content/div2k/cache/DIV2K_train_HR/cache.
Completed caching in /content/div2k/cache/DIV2K_train_HR/cache.
Couldn't find directory:  /content/div2k/DIV2K_valid_LR_difficult
/content/div2k
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_difficult.zip
Begin caching in /content/div2k/cache/DIV2K_valid_LR_difficult/cache.
Completed caching in /content/div2k/cache/DIV2K_valid_LR_difficult/cache.
Couldn't find directory:  /content/div2k/DIV2K_valid_HR
/content/div2k
Downloading data from

## Train the SRResNet generator model

In [8]:
generator = build_srresnet(scale=dataset_parameters.scale)

In [9]:
generator.load_weights("weights/srresnet_realistic_difficult_x4/generator.h5")

In [10]:
checkpoint_dir=f'./ckpt/sr_resnet_{dataset_key}'

learning_rate=1e-4

checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                 epoch=tf.Variable(0),
                                 psnr=tf.Variable(0.0),
                                 optimizer=Adam(learning_rate),
                                 model=generator)

checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                directory=checkpoint_dir,
                                                max_to_keep=3)

if checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {checkpoint.step.numpy()} with validation PSNR {checkpoint.psnr.numpy()}.')

In [11]:
weights_directory = f"weights/srresnet_{dataset_key}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.tf'
checkpoint.model.save_weights(weights_file)

## Train SRGAN using SRResNet as the generator

In [13]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
generator = build_srresnet(scale=dataset_parameters.scale)
generator.load_weights(weights_file)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f63f04b2850>

In [15]:
discriminator = build_discriminator(hr_crop_size=hr_crop_size)

In [16]:
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(vgg.input, vgg.layers[layer_5_4].output)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


In [17]:
binary_cross_entropy = BinaryCrossentropy()
mean_squared_error = MeanSquaredError()

In [18]:
learning_rate=PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5])

In [19]:
generator_optimizer = Adam(learning_rate=learning_rate)
discriminator_optimizer = Adam(learning_rate=learning_rate)

In [20]:
srgan_checkpoint_dir=f'./ckpt/srgan_{dataset_key}'

srgan_checkpoint = tf.train.Checkpoint(step=tf.Variable(32000+45000+45000),
                                       psnr=tf.Variable(16.581737518310547),
                                       generator_optimizer=Adam(learning_rate),
                                       discriminator_optimizer=Adam(learning_rate),
                                       generator=generator,
                                       discriminator=discriminator)

srgan_checkpoint_manager = tf.train.CheckpointManager(checkpoint=srgan_checkpoint,
                                                directory=srgan_checkpoint_dir,
                                                max_to_keep=3)

In [21]:
if srgan_checkpoint_manager.latest_checkpoint:
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()} with validation PSNR {srgan_checkpoint.psnr.numpy()}.')

In [22]:
@tf.function
def train_step(lr, hr):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)

        sr = srgan_checkpoint.generator(lr, training=True)

        hr_output = srgan_checkpoint.discriminator(hr, training=True)
        sr_output = srgan_checkpoint.discriminator(sr, training=True)

        con_loss = calculate_content_loss(hr, sr)
        gen_loss = calculate_generator_loss(sr_output)
        perc_loss = con_loss + 0.001 * gen_loss
        disc_loss = calculate_discriminator_loss(hr_output, sr_output)

    gradients_of_generator = gen_tape.gradient(perc_loss, srgan_checkpoint.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, srgan_checkpoint.discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, srgan_checkpoint.generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, srgan_checkpoint.discriminator.trainable_variables))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    sr = preprocess_input(sr)
    hr = preprocess_input(hr)
    sr_features = perceptual_model(sr) / 12.75
    hr_features = perceptual_model(hr) / 12.75
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss + sr_loss


In [23]:
def saveitpls():
  weights_directory = f"weights/srgan_{dataset_key}"
  os.makedirs(weights_directory, exist_ok=True)
  weights_file = f'{weights_directory}/generator.tf'
  srgan_checkpoint.generator.save_weights(weights_file)
  srgan_checkpoint.generator.save_weights(f'{weights_directory}/generator.h5')
  !zip -r srgan_weights.zip weights/srgan_realistic_difficult_x4/
  
  !cp -r "srgan_weights.zip" /content/drive/MyDrive/
  
  # files.download("srgan_weights.zip")


In [24]:
srgan_checkpoint.generator.load_weights("weights/srgan_realistic_difficult_x4/generator.tf")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f63806153d0>

In [None]:
perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

step = srgan_checkpoint.step.numpy()
steps = 200000

monitor_folder = f"monitor_training/srgan_{dataset_key}"
os.makedirs(monitor_folder, exist_ok=True)

now = time.perf_counter()

for lr, hr in train_dataset.take(steps - step):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()
    print(step)

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step % 1000 == 0:
        psnr_values = []
        
        for lr, hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(lr)[0]
            sr = tf.clip_by_value(sr, 0, 255)
            sr = tf.round(sr)
            sr = tf.cast(sr, tf.uint8)
            
            psnr_value = psnr_metric(hr, sr)[0]
            psnr_values.append(psnr_value)
            psnr = tf.reduce_mean(psnr_values)
        saveitpls()  
        image = Image.fromarray(sr.numpy())
        image.save(f"{monitor_folder}/{step}.png" )
        
        duration = time.perf_counter() - now
        
        now = time.perf_counter()
        
        print(f'{step}/{steps}, psnr = {psnr}, perceptual loss = {perceptual_loss_metric.result():.4f}, discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)')
        
        perceptual_loss_metric.reset_states()
        discriminator_loss_metric.reset_states()
        
        srgan_checkpoint.psnr.assign(psnr)
        srgan_checkpoint_manager.save()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
159216
159217
159218
159219
159220
159221
159222
159223
159224
159225
159226
159227
159228
159229
159230
159231
159232
159233
159234
159235
159236
159237
159238
159239
159240
159241
159242
159243
159244
159245
159246
159247
159248
159249
159250
159251
159252
159253
159254
159255
159256
159257
159258
159259
159260
159261
159262
159263
159264
159265
159266
159267
159268
159269
159270
159271
159272
159273
159274
159275
159276
159277
159278
159279
159280
159281
159282
159283
159284
159285
159286
159287
159288
159289
159290
159291
159292
159293
159294
159295
159296
159297
159298
159299
159300
159301
159302
159303
159304
159305
159306
159307
159308
159309
159310
159311
159312
159313
159314
159315
159316
159317
159318
159319
159320
159321
159322
159323
159324
159325
159326
159327
159328
159329
159330
159331
159332
159333
159334
159335
159336
159337
159338
159339
159340
159341
159342
159343
159344
159345
159346
159347
159348
1593