**If training on Colab, be sure to use a GPU (Runtime > Change runtime type > GPU)**

In [1]:
# uncomment and run the lines below if running in google colab
# !pip install tensorflow==2.4.3
# !git clone https://github.com/jlaihong/image-super-resolution.git
# !mv image-super-resolution/* ./

# SRResNet and SRGAN Training for Image Super Resolution

- using AI to improve the quality of images

## Background and Intuition
- SRResNet and SRGAN
  - fall under an area of research called image super-resolution
  - were both introduced in the same paper, which also explains how they are implemented:
    - [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)


- here is the code used to train both of these models

In [6]:
import os
import time
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import Mean
from PIL import Image

from datasets.div2k.parameters import Div2kParameters 
from datasets.div2k.loader import create_training_and_validation_datasets
from utils.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from utils.metrics import psnr_metric
from utils.config import config
from utils.callbacks import SaveCustomCheckpoint
from models.srresnet import build_srresnet
from models.srgan import build_discriminator


## Prepare the dataset
- we have to show the models examples of low-res images and their corresponding high-res images
- in this case it's easy: we can take high-res images and resize them to create the low-res images
- the paper uses a random sample of 350k images from the ImageNet database
  - that's about 150 GB of images, and not all of them are high quality
- we're using the DIV2K dataset instead, since it's high quality and only about 4 GB
  - 2K resolution
  - specifically designed for image super-resolution
  - has specific sets of low-res images that correspond to the high-res images
  - you can read the descriptions here:
    - https://data.vision.ee.ethz.ch/cvl/DIV2K/
---


- according to the author, it should be fairly easy to train with other datasets
- we just need to change the dataset key
- `Div2kParameters` handles downloading and parsing the DIV2K dataset
- the dataset you choose determines the scaling factor
  - any of the x2 datasets uses a scaling factor of 2, and any of the x4 datasets uses a scaling factor of 4
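The scale can be read straight off the key. The helper below is a hypothetical sketch of that rule, not `Div2kParameters`' actual code; it assumes every key ends in `_x<factor>`, as `bicubic_x4` does.

```python
def scale_from_dataset_key(dataset_key: str) -> int:
    """Hypothetical helper: read the upscaling factor from a DIV2K key like 'bicubic_x4'.

    Assumes the key ends in '_x<factor>'; an illustration, not Div2kParameters' code.
    """
    parts = dataset_key.rsplit("_x", 1)
    if len(parts) != 2 or not parts[1].isdigit():
        raise ValueError(f"cannot parse a scale from {dataset_key!r}")
    return int(parts[1])


for key in ["bicubic_x2", "bicubic_x4"]:
    print(key, "->", scale_from_dataset_key(key))
```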

In [7]:
dataset_key = "bicubic_x4"

data_path = config.get("data_path", "") 

div2k_folder = os.path.abspath(os.path.join(data_path, "div2k"))

dataset_parameters = Div2kParameters(dataset_key, save_data_directory=div2k_folder)

- we can't just send whole images to the network
- the images are really big, so the GPU/CPU could run out of memory
- instead, we crop patches out of the images and send those to the model
- the paper uses high-res crops of 96x96 pixels
- so we do the same
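The key detail when cropping is that the low-res and high-res patches must cover the same area of the image. A minimal NumPy sketch of such a paired crop, assuming `lr` and `hr` are `(height, width, 3)` arrays with `hr` exactly `scale` times larger; `paired_random_crop` is a hypothetical name illustrating the idea behind the repo's `random_crop`, not its actual code.

```python
import numpy as np


def paired_random_crop(lr, hr, hr_crop_size=96, scale=4, rng=None):
    """Cut aligned patches: an hr_crop_size x hr_crop_size HR patch and the
    matching (hr_crop_size // scale)-sized LR patch covering the same area."""
    rng = np.random.default_rng() if rng is None else rng
    lr_crop_size = hr_crop_size // scale
    # Pick the top-left corner in LR coordinates; the HR corner is that times scale.
    lr_top = rng.integers(0, lr.shape[0] - lr_crop_size + 1)
    lr_left = rng.integers(0, lr.shape[1] - lr_crop_size + 1)
    hr_top, hr_left = lr_top * scale, lr_left * scale
    lr_patch = lr[lr_top:lr_top + lr_crop_size, lr_left:lr_left + lr_crop_size]
    hr_patch = hr[hr_top:hr_top + hr_crop_size, hr_left:hr_left + hr_crop_size]
    return lr_patch, hr_patch


# Toy example: a 160x160 "HR" image and a 40x40 "LR" image (x4).
rng = np.random.default_rng(0)
hr = rng.integers(0, 256, size=(160, 160, 3), dtype=np.uint8)
lr = hr[::4, ::4]  # crude stand-in for a real downscaled image
lr_patch, hr_patch = paired_random_crop(lr, hr, hr_crop_size=96, scale=4, rng=rng)
print(lr_patch.shape, hr_patch.shape)  # (24, 24, 3) (96, 96, 3)
```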

In [8]:
hr_crop_size = 96

### Defining a list of mappings to use during training
- the first mapping takes a random 96x96 crop from the high-res image and the corresponding, proportionally smaller crop from the low-res image
- e.g. x4 means the images are downscaled by a factor of 4
  - so the low-res patches will be 24x24
- data augmentation
  - horizontal flipping and rotation
  - not mentioned in the paper
  - included since our dataset is a lot smaller than the paper's
- this model is also trained with JPEG noise, so the model learns to remove that as well
- you can read through the code in the div2k `loader.py` file
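For augmentation, the low-res and high-res patches must receive the *same* random transform, or the pair stops matching. A NumPy sketch, assuming `(H, W, C)` arrays; `paired_random_flip` and `paired_random_rotate` are hypothetical stand-ins for the repo's `random_flip` and `random_rotate`, not their actual code.

```python
import numpy as np


def paired_random_flip(lr, hr, rng):
    # Flip both patches left-right with probability 0.5.
    if rng.random() < 0.5:
        return lr[:, ::-1], hr[:, ::-1]
    return lr, hr


def paired_random_rotate(lr, hr, rng):
    # Rotate both patches by the same random multiple of 90 degrees.
    k = int(rng.integers(0, 4))
    return np.rot90(lr, k, axes=(0, 1)), np.rot90(hr, k, axes=(0, 1))


rng = np.random.default_rng(1)
lr = np.arange(2 * 2 * 3).reshape(2, 2, 3)
hr = lr.repeat(4, axis=0).repeat(4, axis=1)  # toy x4 "HR" via nearest-neighbour upsampling
lr_aug, hr_aug = paired_random_rotate(*paired_random_flip(lr, hr, rng), rng)
# Same transform on both: upsampling the augmented LR patch reproduces the augmented HR patch.
print(np.array_equal(hr_aug, lr_aug.repeat(4, axis=0).repeat(4, axis=1)))  # True
```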

In [9]:
train_mappings = [
    lambda lr, hr: random_crop(lr, hr, hr_crop_size=hr_crop_size, scale=dataset_parameters.scale), 
    random_flip, 
    random_rotate, 
    random_lr_jpeg_noise]

- create TensorFlow dataset objects
- dataset objects can cache the images, so the input pipeline spends less time reloading them and training runs faster

In [10]:
train_dataset, valid_dataset = create_training_and_validation_datasets(dataset_parameters, train_mappings)

valid_dataset_subset = valid_dataset.take(10) # only taking 10 examples here to speed up evaluations during training



## Train the SRResNet generator model

- define the SRResNet model with `build_srresnet`
- it follows the architecture given in the paper
- if you've worked with CNNs before (e.g. for classification), you'll know that the width and height dimensions usually decrease as we get deeper into the network
- this is because those networks use pooling layers or apply filters with strides greater than 1
- but for image super-resolution we don't want the width and height to decrease
  - we want the **opposite**
- so SRResNet doesn't use pooling layers, only uses strides of 1, and always uses `same` padding
- this model is restricted to upscaling factors of 2, 4, or 8
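The standard output-size rule for convolutions shows why stride 1 with `same` padding keeps the spatial size unchanged. A pure-Python sketch following Keras' `same`/`valid` conventions; `conv_output_size` is a hypothetical helper for illustration.

```python
import math


def conv_output_size(size, kernel_size, stride=1, padding="same"):
    """Spatial output size of a convolution under Keras' 'same'/'valid' conventions."""
    if padding == "same":
        # 'same' pads the input so that output = ceil(input / stride)
        return math.ceil(size / stride)
    if padding == "valid":
        # no padding: the kernel must fit entirely inside the input
        return (size - kernel_size) // stride + 1
    raise ValueError(f"unknown padding {padding!r}")


print(conv_output_size(24, 9, stride=1))         # 24: a 9x9 'same' conv keeps the size
print(conv_output_size(24, 3, stride=2))         # 12: a stride-2 conv halves it
print(conv_output_size(24, 3, padding="valid"))  # 22: no padding shrinks it
```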

In [11]:
generator = build_srresnet(scale=dataset_parameters.scale)

## SRResNet architecture
- built following the architecture described in the [paper](https://arxiv.org/abs/1609.04802)

```python
def build_srresnet(scale=4, num_filters=64, num_res_blocks=16):
    if scale not in upsamples_per_scale:
        raise ValueError(f"available scales are: {upsamples_per_scale.keys()}")

    num_upsamples = upsamples_per_scale[scale]
```
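`upsamples_per_scale` isn't shown in this excerpt. Each upsample block later in the network doubles the height and width, so a plausible definition, assumed here rather than copied from `srresnet.py`, maps each scale to the number of doublings it needs:

```python
# Assumed mapping: each upsample block doubles height and width,
# so x2, x4 and x8 need 1, 2 and 3 blocks respectively.
upsamples_per_scale = {2: 1, 4: 2, 8: 3}

for scale, num_upsamples in upsamples_per_scale.items():
    assert 2 ** num_upsamples == scale
    print(f"x{scale}: {num_upsamples} upsample block(s), e.g. 24x24 -> {24 * scale}x{24 * scale}")
```

This also explains the restriction to 2, 4, or 8: a factor like 3 isn't a power of 2, so it can't be reached by repeated doubling.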
### Start by defining an input layer

- `(None, None, 3)` lets the model take color images of any height and width
  - the network is fully convolutional, so it doesn't depend on the size of the images
- input images contain pixel values between 0 and 255
  - in the paper, the inputs are scaled to between 0 and 1
- so we do the same
  - data normalization is usually done outside the network
  - here it's done inside the model for convenience
  - anyone using the model doesn't have to remember to normalize images before sending them in; if they forgot, they'd get strange results
```python

    lr = Input(shape=(None, None, 3))
    x = Lambda(normalize_01)(lr)
```
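The bodies of `normalize_01` and `denormalize_m11` aren't shown. Going by their names and the notes above, a NumPy sketch of what they are assumed to compute: the input is scaled from [0, 255] to [0, 1], and at the end of the model the `tanh` output, which lies in [-1, 1], is mapped back to [0, 255]. In the model they would run on tensors inside `Lambda` layers; these are plain-array illustrations.

```python
import numpy as np


def normalize_01(x):
    # [0, 255] pixel values -> [0, 1], applied to the low-res input
    return x / 255.0


def denormalize_m11(x):
    # tanh output in [-1, 1] -> [0, 255] pixel values
    return (x + 1.0) * 127.5


pixels = np.array([0.0, 127.5, 255.0])
print(normalize_01(pixels))                         # 0, 0.5 and 1
print(denormalize_m11(np.array([-1.0, 0.0, 1.0])))  # 0, 127.5 and 255
```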
- the first convolution uses 64 filters of size 9x9; since we don't want the height and width to decrease, we use `same` padding
```python
    x = Conv2D(num_filters, kernel_size=9, padding='same')(x)
```
### Parametric ReLU
- neural networks use non-linear activation functions to learn complex mappings
  - in our case, we're trying to learn the mapping from low-res images to high-res images
  - in practice, the normal ReLU function still works well
  - the problem is that it can cause a lot of dead neurons inside your network, because every input less than 0 is mapped to 0
  - the gradient through those neurons is then 0, so the parameters attached to them stop being updated
- Leaky ReLU
  - instead of setting negative values to 0, we set them to some predefined constant fraction of themselves
- Parametric ReLU
  - we let the network learn that fraction instead
  - by default, this would create a separate parameter value α for every value passed into the layer
- here we set the shared axes to `[1, 2]`
  - meaning we share the parameters across the height and width dimensions, so only one parameter is created per channel
  - since there are 64 filters, there are only 64 channels, and therefore only 64 parameters
  - the output of this PReLU layer is also sent to multiple places, so we assign it to another variable, `x_1`, to use it later
```python
    x = x_1 = PReLU(shared_axes=[1, 2])(x)
```
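What `PReLU(shared_axes=[1, 2])` computes can be sketched in NumPy: positive values pass through, negative values are multiplied by a per-channel α that is broadcast over batch, height, and width. The α values below are hypothetical (Keras' `PReLU` initializes α to zeros by default and learns it).

```python
import numpy as np


def prelu_shared_hw(x, alpha):
    """PReLU on a (batch, height, width, channels) array with one alpha per channel,
    i.e. parameters shared across height and width (like shared_axes=[1, 2])."""
    alpha = alpha.reshape(1, 1, 1, -1)  # broadcast over batch, height, width
    return np.maximum(x, 0) + alpha * np.minimum(x, 0)


num_channels = 64
alpha = np.full(num_channels, 0.25)  # hypothetical learned values
x = np.random.default_rng(0).normal(size=(2, 24, 24, num_channels))
y = prelu_shared_hw(x, alpha)
print(y.shape, alpha.size)  # (2, 24, 24, 64) 64
```

Sharing across height and width also keeps the parameter count independent of the image size, which matters for a model that accepts inputs of any size.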
### Residual block
- we're just grouping common operations and calling it a residual block
- all of the operations in each of the blocks are the same
- each block uses a convolutional layer with 64 filters of size 3x3, followed by a batch normalization layer, a parametric ReLU, another convolutional layer with the same settings, and then another batch norm layer
- we take the output of this last batch norm and perform an element-wise sum with the input of the residual block
  - for the first block, that input is the output of the initial PReLU, so this is the second time that output is used
- we keep stacking residual blocks on top of each other, each time sending the output of the previous block as the input to the next block
  - each block performs an element-wise sum between its input and the output of its final batch norm layer

- this function is responsible for all of them!
- see `srresnet.py`
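```python
from tensorflow.keras.layers import Add, BatchNormalization, Conv2D, PReLU


def residual_block(block_input, num_filters, momentum=0.8):
    # conv -> batch norm -> PReLU -> conv -> batch norm
    x = Conv2D(num_filters, kernel_size=3, padding='same')(block_input)
    x = BatchNormalization(momentum=momentum)(x)
    x = PReLU(shared_axes=[1, 2])(x)
    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization(momentum=momentum)(x)
    # skip connection: element-wise sum with the block's input
    x = Add()([block_input, x])
    return x

# inside build_srresnet, the blocks are stacked one after another:
#     for _ in range(num_res_blocks):
#         x = residual_block(x, num_filters)
```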
- the purpose of all of these residual blocks is to extract features from the input image
  - people have found that earlier layers in a CNN learn to detect low-level features (like lines in a certain direction), while later layers detect shapes
    - as you go deeper, the layers learn more object-level features
    - this isn't a one-to-one comparison for super-resolution, but adding more residual blocks generally produces better images
    - it comes with the trade-off that your network takes longer to process images
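To put a number on that trade-off, here is a worked parameter count per residual block, assuming the settings in `residual_block` above (64 filters, 3x3 kernels) and Keras' convention of counting all four batch-norm variables. Each block adds the same amount, so cost grows linearly with the number of blocks.

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    # weights (k * k * in * out) plus one bias per output channel
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def batchnorm_params(channels):
    # gamma, beta, moving mean, moving variance (Keras counts all four)
    return 4 * channels


def prelu_params(channels):
    # one alpha per channel when shared across height and width
    return channels


num_filters = 64
block = (2 * conv2d_params(3, num_filters, num_filters)
         + 2 * batchnorm_params(num_filters)
         + prelu_params(num_filters))
print(block)       # 74432 parameters per residual block
print(16 * block)  # 1190912 for the default 16 blocks
```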

- after all of the residual blocks, we have another convolutional layer and a batch norm layer
- here we see the 3rd use of the output of the initial parametric ReLU (`x_1`)
- we perform an element-wise sum between it and the output of that batch norm layer
```python
    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([x_1, x])
```
- up until this point, the height and width have stayed the same throughout the network
- now we reach a new type of block
- we call it the upsample block
- it uses a convolutional layer with 256 filters of size 3x3, so that's 4 times as many filters as the previous layers
- it also has a pixel shuffle layer
  - this takes values from the channel dimension and moves them into the height and width dimensions
  - so it doubles the height and width and divides the channel dimension by 4
```python
    for _ in range(num_upsamples):
        x = upsample(x, num_filters * 4)

    x = Conv2D(3, kernel_size=9, padding='same', activation='tanh')(x)
    sr = Lambda(denormalize_m11)(x)

    return Model(lr, sr)
```
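The `upsample` helper isn't shown here, but the pixel-shuffle step it relies on can be sketched in NumPy. The channel ordering below matches `tf.nn.depth_to_space` with NHWC data, one common way to implement pixel shuffle in TensorFlow; whether the repo uses exactly that op is an assumption.

```python
import numpy as np


def pixel_shuffle(x, r=2):
    """Rearrange (H, W, C * r * r) into (H * r, W * r, C).

    Channel ordering as in tf.nn.depth_to_space (NHWC):
    out[h * r + i, w * r + j, c] = x[h, w, (i * r + j) * C + c].
    """
    h, w, channels = x.shape
    c = channels // (r * r)
    x = x.reshape(h, w, r, r, c)       # split channels into (i, j, c)
    x = x.transpose(0, 2, 1, 3, 4)     # -> (h, i, w, j, c)
    return x.reshape(h * r, w * r, c)  # merge (h, i) and (w, j)


features = np.random.default_rng(0).normal(size=(24, 24, 256))  # 64 * 4 channels
out = pixel_shuffle(features, r=2)
print(out.shape)  # (48, 48, 64)
```

Applying this twice (x4) takes a 24x24 patch to 96x96, the high-res crop size used in training.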

In [None]:
checkpoint_dir=f'./ckpt/sr_resnet_{dataset_key}'

learning_rate=1e-4

checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                 epoch=tf.Variable(0),
                                 psnr=tf.Variable(0.0),
                                 optimizer=Adam(learning_rate),
                                 model=generator)

checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                directory=checkpoint_dir,
                                                max_to_keep=3)

if checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {checkpoint.step.numpy()} with validation PSNR {checkpoint.psnr.numpy()}.')

In [None]:
training_steps = 1_000_000

steps_per_epoch = 1000

training_epochs = training_steps / steps_per_epoch

if checkpoint.epoch.numpy() < training_epochs:
    remaining_epochs = int(training_epochs - checkpoint.epoch.numpy())
    print(f"Continuing Training from epoch {checkpoint.epoch.numpy()}. Remaining epochs: {remaining_epochs}.")
    save_checkpoint_callback = SaveCustomCheckpoint(checkpoint_manager, steps_per_epoch)
    checkpoint.model.compile(optimizer=checkpoint.optimizer, loss=MeanSquaredError(), metrics=[psnr_metric])
    checkpoint.model.fit(train_dataset,validation_data=valid_dataset_subset, steps_per_epoch=steps_per_epoch, epochs=remaining_epochs, callbacks=[save_checkpoint_callback])
else:
    print("Training already completed. To continue training, increase the number of training steps")
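`psnr_metric` comes from `utils/metrics.py`, which isn't shown here. The quantity it reports is the standard peak signal-to-noise ratio, PSNR = 10 · log10(MAX² / MSE) in dB, where higher is better. A NumPy sketch assuming 8-bit images (MAX = 255); the repo's version operates on TensorFlow tensors and may differ in details.

```python
import numpy as np


def psnr(hr, sr, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two images with pixel range [0, max_val]."""
    mse = np.mean((hr.astype(np.float64) - sr.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))


hr = np.full((96, 96, 3), 100, dtype=np.uint8)
sr = hr.copy()
sr[::2] += 10  # perturb every other row by 10 levels, so MSE = 50
print(round(psnr(hr, sr), 2))
```

Because PSNR decreases monotonically as MSE grows, training SRResNet with a mean squared error loss directly pushes the validation PSNR up.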

In [None]:
weights_directory = f"weights/srresnet_{dataset_key}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
checkpoint.model.save_weights(weights_file)

## Train SRGAN using SRResNet as the generator

In [None]:
generator = build_srresnet(scale=dataset_parameters.scale)
generator.load_weights(weights_file)

In [None]:
discriminator = build_discriminator(hr_crop_size=hr_crop_size)

In [None]:
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(vgg.input, vgg.layers[layer_5_4].output)
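The magic number `layer_5_4 = 20` selects `block5_conv4` in Keras' VGG19, i.e. the paper's φ5,4 feature map (the fourth convolution before the fifth max-pooling layer). Keras' VGG19 conv layers include their ReLU, so the features are taken after activation, as in the paper. The layer list below is written out by hand for illustration (the input layer's real name varies, e.g. `input_1`):

```python
# Layer names of Keras' VGG19 with include_top=False, in model.layers order
# (index 0 is the input layer). Written out by hand for illustration.
blocks = [(1, 2), (2, 2), (3, 4), (4, 4), (5, 4)]  # (block number, conv layers in block)

layer_names = ["input"]
for block, num_convs in blocks:
    layer_names += [f"block{block}_conv{i}" for i in range(1, num_convs + 1)]
    layer_names.append(f"block{block}_pool")

layer_5_4 = 20
print(layer_names[layer_5_4])  # block5_conv4
```

Selecting the layer by name with `vgg.get_layer("block5_conv4").output` would avoid relying on the index.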

In [None]:
binary_cross_entropy = BinaryCrossentropy()
mean_squared_error = MeanSquaredError()

In [None]:
learning_rate=PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5])
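`PiecewiseConstantDecay` uses `values[0]` while the step is at most `boundaries[0]`, then `values[1]` afterwards. A pure-Python mirror of that rule with the settings above (the helper name is hypothetical); with the 200,000 training steps used below, this gives 100k steps at 1e-4 followed by 100k steps at 1e-5, matching the paper's schedule.

```python
def piecewise_constant(step, boundaries=(100_000,), values=(1e-4, 1e-5)):
    """Mirror of the PiecewiseConstantDecay rule: values[0] while step <= boundaries[0],
    values[1] after that, and so on."""
    for boundary, value in zip(boundaries, values):
        if step <= boundary:
            return value
    return values[-1]


for step in [0, 100_000, 100_001, 200_000]:
    print(step, piecewise_constant(step))
```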

In [None]:
generator_optimizer = Adam(learning_rate=learning_rate)
discriminator_optimizer = Adam(learning_rate=learning_rate)

In [None]:
srgan_checkpoint_dir=f'./ckpt/srgan_{dataset_key}'

# track the optimizers actually used in train_step, so their state (including the
# step count that drives the learning-rate schedule) is saved and restored
srgan_checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                       psnr=tf.Variable(0.0),
                                       generator_optimizer=generator_optimizer,
                                       discriminator_optimizer=discriminator_optimizer,
                                       generator=generator,
                                       discriminator=discriminator)

srgan_checkpoint_manager = tf.train.CheckpointManager(checkpoint=srgan_checkpoint,
                                                      directory=srgan_checkpoint_dir,
                                                      max_to_keep=3)

In [None]:
if srgan_checkpoint_manager.latest_checkpoint:
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()} with validation PSNR {srgan_checkpoint.psnr.numpy()}.')

In [None]:
@tf.function
def train_step(lr, hr):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)

        sr = srgan_checkpoint.generator(lr, training=True)

        hr_output = srgan_checkpoint.discriminator(hr, training=True)
        sr_output = srgan_checkpoint.discriminator(sr, training=True)

        con_loss = calculate_content_loss(hr, sr)
        gen_loss = calculate_generator_loss(sr_output)
        perc_loss = con_loss + 0.001 * gen_loss
        disc_loss = calculate_discriminator_loss(hr_output, sr_output)

    gradients_of_generator = gen_tape.gradient(perc_loss, srgan_checkpoint.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, srgan_checkpoint.discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, srgan_checkpoint.generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, srgan_checkpoint.discriminator.trainable_variables))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    sr = preprocess_input(sr)
    hr = preprocess_input(hr)
    sr_features = perceptual_model(sr) / 12.75
    hr_features = perceptual_model(hr) / 12.75
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss + sr_loss
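These losses follow the standard GAN setup: the discriminator is pushed to output 1 for real HR patches and 0 for generated ones, while the generator is rewarded when the discriminator outputs 1 for its images. `train_step` then adds that adversarial term to the content loss with weight 0.001, the paper's 10^-3 weighting. A NumPy mirror of the two BCE-based functions, assuming the discriminator outputs probabilities (what `BinaryCrossentropy()` expects with its default `from_logits=False`); the clipping constant mimics Keras' epsilon but is a simplification.

```python
import numpy as np


def bce(y_true, y_pred, eps=1e-7):
    # Binary cross-entropy on probabilities, clipped to avoid log(0), averaged over entries.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))


def generator_loss(sr_out):
    # The generator wants the discriminator to call its images real (label 1).
    return bce(np.ones_like(sr_out), sr_out)


def discriminator_loss(hr_out, sr_out):
    # Real HR patches should score 1, generated SR patches 0.
    return bce(np.ones_like(hr_out), hr_out) + bce(np.zeros_like(sr_out), sr_out)


unsure = np.full((4, 1), 0.5)  # a discriminator that can't tell real from fake
print(generator_loss(unsure), discriminator_loss(unsure, unsure))  # about 0.693 and 1.386
```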


In [None]:
perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

step = srgan_checkpoint.step.numpy()
steps = 200000

monitor_folder = f"monitor_training/srgan_{dataset_key}"
os.makedirs(monitor_folder, exist_ok=True)

now = time.perf_counter()

for lr, hr in train_dataset.take(steps - step):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step % 1000 == 0:
        psnr_values = []

        for val_lr, val_hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(val_lr)[0]
            sr = tf.clip_by_value(sr, 0, 255)
            sr = tf.round(sr)
            sr = tf.cast(sr, tf.uint8)

            psnr_value = psnr_metric(val_hr, sr)[0]
            psnr_values.append(psnr_value)

        # average PSNR over the validation subset
        psnr = tf.reduce_mean(psnr_values)

        # save the last validation prediction to monitor progress visually
        image = Image.fromarray(sr.numpy())
        image.save(f"{monitor_folder}/{step}.png")
        
        duration = time.perf_counter() - now
        
        now = time.perf_counter()
        
        print(f'{step}/{steps}, psnr = {psnr}, perceptual loss = {perceptual_loss_metric.result():.4f}, discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)')
        
        perceptual_loss_metric.reset_states()
        discriminator_loss_metric.reset_states()
        
        srgan_checkpoint.psnr.assign(psnr)
        srgan_checkpoint_manager.save()

In [None]:
weights_directory = f"weights/srgan_{dataset_key}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
srgan_checkpoint.generator.save_weights(weights_file)