# Enhanced Deep Residual Networks (EDSR) and Super-Resolution Generative Adversarial Networks (SRGAN)



Notebook authors: Dipanjan (DJ) Sarkar & Ozgun Haznedar

## In this notebook, an EDSR model is built and trained for x3 super-resolution, and then fine-tuned as the generator of an SRGAN.

In [None]:
# Check which GPU (if any) this Colab runtime provides.
!nvidia-smi

In [None]:
# Mount Google Drive: the weight files below are saved to and loaded from
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import tensorflow as tf
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from glob import glob

In [None]:
# TensorFlow version this notebook was run with (useful for reproducibility).
tf.__version__

In [None]:
# Set to True to download and extract the DIV2K training images (LR + HR).
DOWNLOAD_DIV2K = False

if DOWNLOAD_DIV2K:
  # Everything downstream works at x3 (random_crop uses scale=3, the EDSR
  # upsampler uses depth_to_space(..., 3), inference resizes by 3), so fetch
  # the bicubic x3 LR set. With x4 LR images, the LR crop offsets multiplied
  # by 3 would not line up with the matching HR region.
  lr_dataset = 'DIV2K_train_LR_bicubic_X3.zip'
  lr_dataset_url = f'http://data.vision.ee.ethz.ch/cvl/DIV2K/{lr_dataset}'

  hr_dataset = 'DIV2K_train_HR.zip'
  hr_dataset_url = f'http://data.vision.ee.ethz.ch/cvl/DIV2K/{hr_dataset}'

  download_dir = os.path.abspath('./div2k_data/images')

  tf.keras.utils.get_file(lr_dataset, lr_dataset_url,
                          cache_subdir=download_dir,
                          extract=True)
  tf.keras.utils.get_file(hr_dataset, hr_dataset_url,
                          cache_subdir=download_dir,
                          extract=True)

In [None]:
"""
download_dir = './div2k_data/images'
download_dir = os.path.abspath(download_dir)
"""

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Placeholders: point these at the LR (bicubic) and HR training image folders.
lr_dir = "LR TRAIN DIRECTORY"
hr_dir = "HR TRAIN DIRECTORY"

# LR and HR files are paired purely by sorted filename order.
lr_images = sorted(glob(lr_dir + '/*.*'))
hr_images = sorted(glob(hr_dir + '/*.*'))

# Dataset.zip silently truncates to the shorter input, so a missing file would
# shift every LR/HR pair after it. Fail loudly instead.
assert lr_images, f"No LR images found in {lr_dir!r}"
assert len(lr_images) == len(hr_images), (
    f"LR/HR image count mismatch: {len(lr_images)} LR vs {len(hr_images)} HR")


def _decoded_png_dataset(paths):
  """Builds an in-memory-cached dataset of decoded 3-channel PNG images.

  Args:
    paths: list of image file paths; element order is preserved.

  Returns:
    A tf.data.Dataset yielding one decoded uint8 image per path.
  """
  ds = tf.data.Dataset.from_tensor_slices(paths)
  ds = ds.map(tf.io.read_file, num_parallel_calls=AUTOTUNE)
  ds = ds.map(lambda raw: tf.image.decode_png(raw, channels=3),
              num_parallel_calls=AUTOTUNE)
  # Cache after decoding so PNGs are only read and decoded on the first pass.
  return ds.cache()


lr_ds = _decoded_png_dataset(lr_images)
hr_ds = _decoded_png_dataset(hr_images)

In [None]:
def random_crop(lr_img, hr_img, hr_crop_size=96, scale=3):
  """Takes spatially aligned random crops from an LR/HR image pair.

  A random top-left corner is drawn in LR pixel coordinates and multiplied by
  `scale` to get the HR corner. The two crops therefore cover the same region
  only if the HR image is exactly `scale` times the size of the LR image.

  Args:
    lr_img: low-resolution image, indexed as [height, width, ...]; each side
      must be at least hr_crop_size // scale pixels.
    hr_img: the matching high-resolution image.
    hr_crop_size: side length of the HR crop (default 96).
    scale: HR/LR size ratio (default 3).

  Returns:
    (lr_crop, hr_crop) with side lengths hr_crop_size // scale and
    hr_crop_size respectively.

  Raises:
    ValueError: if hr_crop_size is not a multiple of scale.
  """
  if hr_crop_size % scale:
    raise ValueError(
        f"hr_crop_size ({hr_crop_size}) must be a multiple of scale ({scale}).")
  lr_crop_size = hr_crop_size // scale
  lr_img_shape = tf.shape(lr_img)[:2]

  # maxval is exclusive; the +1 lets the crop touch the right/bottom edge.
  lr_w = tf.random.uniform(shape=(),
                           maxval=lr_img_shape[1] - lr_crop_size + 1,
                           dtype=tf.int32)
  lr_h = tf.random.uniform(shape=(),
                           maxval=lr_img_shape[0] - lr_crop_size + 1,
                           dtype=tf.int32)

  hr_w = lr_w * scale
  hr_h = lr_h * scale
  lr_img_crop = lr_img[lr_h:lr_h + lr_crop_size,
                       lr_w:lr_w + lr_crop_size]
  hr_img_crop = hr_img[hr_h:hr_h + hr_crop_size,
                       hr_w:hr_w + hr_crop_size]

  return lr_img_crop, hr_img_crop


def random_flip(lr_img, hr_img):
  """Mirrors both images left-right together, with probability 0.5.

  Returns the (lr, hr) pair, either unchanged or both flipped, so the two
  images stay aligned.
  """
  draw = tf.random.uniform(shape=(), maxval=1)

  def _keep():
    return lr_img, hr_img

  def _mirror():
    return (tf.image.flip_left_right(lr_img),
            tf.image.flip_left_right(hr_img))

  return tf.cond(draw < 0.5, _keep, _mirror)


def random_rotate(lr_img, hr_img):
  """Rotates both images by the same random multiple of 90 degrees (0-3 turns)."""
  quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
  rotated_lr = tf.image.rot90(lr_img, quarter_turns)
  rotated_hr = tf.image.rot90(hr_img, quarter_turns)
  return rotated_lr, rotated_hr

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16
# Shuffle across every image pair (lr_images is the LR file list used to
# build lr_ds); buffer_size must be at least 1.
SHUFFLE_BUFFER_SIZE = max(1, len(lr_images))

train_ds = tf.data.Dataset.zip((lr_ds, hr_ds))
# Without a shuffle, each pass over the cached, sorted images produces the same
# batch groupings. Shuffling after the cache (and before repeat) reshuffles on
# every pass instead of replaying one fixed order.
train_ds = train_ds.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE,
                            reshuffle_each_iteration=True)
# Augmentations: aligned random crop, then a random 90-degree rotation and a
# random horizontal flip applied identically to LR and HR.
train_ds = train_ds.map(random_crop, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.map(random_flip, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.repeat()
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
from tensorflow.keras.layers import Add, Conv2D, Input, Lambda
from tensorflow.keras.models import Model
import numpy as np
import tensorflow as tf

# Per-channel RGB mean of DIV2K, rescaled from [0, 1] to the 0-255 pixel scale.
DIV2K_RGB_MEAN = 255 * np.array((0.4488, 0.4371, 0.4040))


def normalize(x):
  """Subtracts the DIV2K channel mean and divides by 127.5."""
  centered = x - DIV2K_RGB_MEAN
  return centered / 127.5


def denormalize(x):
  """Inverse of `normalize`: multiplies by 127.5 and adds the mean back."""
  rescaled = x * 127.5
  return rescaled + DIV2K_RGB_MEAN


def residual_block(inp):
  """Builds one EDSR residual block: conv-ReLU-conv plus an identity skip.

  The block has no batch normalization. The branch always outputs 64
  channels, so `inp` must also have 64 channels for the Add to be valid.
  """
  branch = Conv2D(64, 3, padding='same', activation='relu')(inp)
  branch = Conv2D(64, 3, padding='same')(branch)
  return Add()([inp, branch])


def _subpixel_upsample(x, factor):
  """Upsamples `x` by `factor` using a 3x3 conv followed by depth_to_space."""
  x = Conv2D(64 * (factor ** 2), 3, padding='same')(x)
  return Lambda(lambda t: tf.nn.depth_to_space(t, factor))(x)


def edsr_model_arch(num_residual_blocks, scale=3):
  """Creates an EDSR super-resolution model.

  Args:
    num_residual_blocks: number of residual blocks in the network body.
    scale: upscaling factor. 2 and 3 use one sub-pixel stage, and 4 uses two
      x2 stages. The default of 3 builds exactly the original single x3
      stage, so previously saved weight files still load.

  Returns:
    A Keras Model named "edsr_model" that accepts RGB images of any spatial
    size (presumably on the 0-255 scale, given `normalize`/`denormalize`) and
    returns images `scale` times larger. The output is not clipped.

  Raises:
    ValueError: if `scale` is not 2, 3 or 4.
  """
  if scale in (2, 3):
    upsample_factors = [scale]
  elif scale == 4:
    upsample_factors = [2, 2]
  else:
    raise ValueError(f"Unsupported scale {scale!r}; expected 2, 3 or 4.")

  inp = Input(shape=(None, None, 3))
  x = Lambda(normalize)(inp)

  # Body: head conv, residual blocks, then a conv plus a global skip connection.
  x = rb = Conv2D(64, 3, padding='same')(x)
  for _ in range(num_residual_blocks):
    rb = residual_block(rb)
  rb = Conv2D(64, 3, padding='same')(rb)
  x = Add()([x, rb])

  # Tail: sub-pixel upsampling, then projection back to 3 channels.
  for factor in upsample_factors:
    x = _subpixel_upsample(x, factor)
  x = Conv2D(3, 3, padding='same')(x)

  out = Lambda(denormalize)(x)

  return Model(inp, out, name="edsr_model")

In [None]:
# Stand-alone EDSR model with 16 residual blocks, trained with an L1 loss below.
edsr_model = edsr_model_arch(num_residual_blocks=16)

In [None]:
# Layer-by-layer overview with output shapes and parameter counts.
edsr_model.summary()

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

# Adam with a piecewise-constant learning rate: 1e-4 for the first 200k steps,
# 5e-5 until step 400k, then 2.5e-5. EDSR is trained with an L1
# (mean absolute error) pixel loss.
optim = Adam(
    learning_rate=PiecewiseConstantDecay(
        boundaries=[200000, 400000],
        values=[1e-4, 5e-5, 2.5e-5]))
edsr_model.compile(optimizer=optim, loss='mean_absolute_error')

In [None]:
# 500 epochs x 1000 steps = 500k optimizer steps, so training passes both
# learning-rate boundaries (200k and 400k) of the schedule.
history = edsr_model.fit(train_ds, epochs=500, steps_per_epoch=1000)

In [None]:
f, ax = plt.subplots(1, 2, figsize=(10, 4))

epochs = history.epoch
# Take steps/epoch from the recorded fit() parameters instead of repeating the
# literal from the fit call; fall back to 1000 if it was not recorded.
steps_per_epoch = history.params.get('steps') or 1000
# Assumes optim.lr returns the learning-rate schedule object (as for the
# TF2 legacy Adam); it is evaluated at the first step of each epoch.
learning_rates = [optim.lr(e * steps_per_epoch).numpy() for e in epochs]
losses = history.history['loss']

ax[0].plot(epochs, learning_rates, 'k--')
ax[0].set(title='Learning Rate History', xlabel='Epoch',
          ylabel='Learning rate')
ax[1].plot(epochs, losses, 'k')
ax[1].set(title='Loss History', xlabel='Epoch',
          ylabel='Mean absolute error')
f.tight_layout();

In [None]:
from PIL import Image

def run_sr_inference(img_path, model, scale=3):
  """Super-resolves one LR PNG and loads its HR counterpart for comparison.

  Args:
    img_path: path to the low-resolution PNG. The HR target is read from the
      same path with "val_LR" replaced by "val_HR". If the path does not
      contain "val_LR", the LR file itself is read back as the "target".
    model: model applied to a (1, H, W, 3) float32 batch of 0-255 pixel values.
    scale: factor for the bicubic baseline resize (default 3, matching the
      x3 EDSR model built in this notebook).

  Returns:
    (lr, bicubic, sr, hr):
      lr: the LR image as float32, values 0-255.
      bicubic: uint8 numpy array, LR resized by `scale` with PIL bicubic.
      sr: the model output clipped to [0, 255] and cast to uint8.
      hr: the HR target image as float32, values 0-255.
  """
  lr_img = tf.image.decode_png(tf.io.read_file(img_path), channels=3)

  # Bicubic baseline. PIL's resize takes (width, height).
  lr_img_npy = lr_img.numpy()
  height, width = lr_img_npy.shape[:2]
  upsamp_img = np.asarray(
      Image.fromarray(lr_img_npy).resize(size=(width * scale, height * scale),
                                         resample=Image.BICUBIC))

  lr_batch = tf.cast(tf.expand_dims(lr_img, axis=0), tf.float32)
  sr_batch = model(lr_batch)
  sr_batch = tf.cast(tf.clip_by_value(sr_batch, 0, 255), tf.uint8)

  trg_img_path = img_path.replace("val_LR", "val_HR")
  trg_img = tf.image.decode_png(tf.io.read_file(trg_img_path), channels=3)
  trg_img = tf.cast(trg_img, tf.float32)

  return lr_batch[0], upsamp_img, sr_batch[0], trg_img


def plot_edsr_results(orig, bicubic, super_res, target=None):
  """Shows the LR input, the bicubic baseline, the SR output and the HR target.

  Args:
    orig: LR image on the 0-255 scale (divided by 255 for display).
    bicubic: bicubic-upsampled image as uint8.
    super_res: super-resolved image as uint8.
    target: optional HR ground truth on the 0-255 scale (e.g. the fourth
      value returned by run_sr_inference). It is shown in the top-right panel
      when given; otherwise that panel is left blank.
  """
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
  panels = [
      (axes[0, 0], orig / 255., 'Original Image'),
      (axes[0, 1], None if target is None else target / 255.,
       'Target HR Image'),
      (axes[1, 0], bicubic, 'Bicubic Upsampled Image'),
      (axes[1, 1], super_res, 'Super Resolution Image'),
  ]
  for ax, image, title in panels:
    ax.axis("off")
    if image is not None:
      ax.imshow(image)
      ax.set_title(title)
  fig.tight_layout();

In [None]:
# Placeholder: path to an LR validation PNG. run_sr_inference reads the HR
# target from the same path with "val_LR" replaced by "val_HR".
lr_img_loc = 'sample image'
# run_sr_inference returns four values (LR, bicubic, SR, HR target), so all
# four must be unpacked.
lr, bicubic, sr, hr_target = run_sr_inference(lr_img_loc,
                                              edsr_model)
plot_edsr_results(lr, bicubic, sr)

In [None]:
# Persist the pre-trained EDSR weights; the SRGAN section loads this file to
# initialize its generator.
# NOTE(review): the filename says "x4", but edsr_model_arch upsamples by 3
# (depth_to_space factor 3). Consider renaming, and update the load paths too.
edsr_model.save_weights('/content/drive/MyDrive/edsr_16-res-block-x4.h5')


In [None]:
from tensorflow.keras.layers import BatchNormalization, \
          Conv2D, Dense, Flatten, Input, LeakyReLU, Lambda
from tensorflow.keras.models import Model

def minmax_normalize(x):
  """Linearly maps pixel values from [0, 255] onto [-1, 1]."""
  scaled = x / 127.5
  return scaled - 1


def _disc_conv_block(x, filters, strides, batchnorm=True):
  """Applies Conv2D(3x3, 'same'), optional BatchNorm(momentum=0.8), then LeakyReLU(0.2)."""
  x = Conv2D(filters=filters, kernel_size=3,
             strides=strides, padding='same')(x)
  if batchnorm:
    x = BatchNormalization(momentum=0.8)(x)
  return LeakyReLU(alpha=0.2)(x)


def srgan_discriminator_arch(hr_size=96, base_filters=64):
  """Creates the SRGAN discriminator.

  Args:
    hr_size: side length of the square HR patches it scores (default 96,
      the HR crop size used for training).
    base_filters: filters in the first stage; doubled at each of the four
      stages (64, 128, 256, 512 by default).

  Returns:
    A Keras Model mapping (hr_size, hr_size, 3) images on the 0-255 scale to a
    single sigmoid output. Training labels HR images 1 and generated images 0.
  """
  inp = Input(shape=(hr_size, hr_size, 3))
  x = Lambda(minmax_normalize)(inp)

  # Each stage is a stride-1 conv followed by a stride-2 conv; only the very
  # first conv of the network skips batch normalization.
  x = _disc_conv_block(x, base_filters, strides=1, batchnorm=False)
  x = _disc_conv_block(x, base_filters, strides=2)
  for stage in range(1, 4):
    filters = base_filters * 2 ** stage
    x = _disc_conv_block(x, filters, strides=1)
    x = _disc_conv_block(x, filters, strides=2)

  x = Flatten()(x)
  x = Dense(1024)(x)
  x = LeakyReLU(alpha=0.2)(x)
  out = Dense(1, activation='sigmoid')(x)

  return Model(inp, out)

In [None]:
# SRGAN generator: same EDSR architecture, initialized from the pre-trained
# EDSR weights saved earlier.
generator = edsr_model_arch(num_residual_blocks=16)
generator.load_weights('/content/drive/MyDrive/edsr_16-res-block-x4.h5')

# Freshly initialized discriminator.
discriminator = srgan_discriminator_arch()

In [None]:
# Generator layer overview with output shapes and parameter counts.
generator.summary()

In [None]:
# Top-to-bottom architecture diagram of the generator, with tensor shapes.
tf.keras.utils.plot_model(generator, rankdir='TB', show_shapes=True)

In [None]:
# Discriminator layer overview with output shapes and parameter counts.
discriminator.summary()

In [None]:
# Top-to-bottom architecture diagram of the discriminator, with tensor shapes.
tf.keras.utils.plot_model(discriminator, rankdir='TB', show_shapes=True)

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

# One piecewise-constant schedule shared by both optimizers: 1e-4 for the
# first 100k steps, 1e-5 until step 200k, then 5e-6.
lr_schedule = PiecewiseConstantDecay(
    boundaries=[100000, 200000],
    values=[1e-4, 1e-5, 5e-6],
)
generator_optim = Adam(learning_rate=lr_schedule)
discriminator_optim = Adam(learning_rate=lr_schedule)

In [None]:
from tensorflow.keras.applications import vgg19

vgg = vgg19.VGG19(input_shape=(None, None, 3),
                  weights='imagenet',
                  include_top=False)
# VGG is only a fixed feature extractor for the content loss. Its weights are
# never passed to an optimizer; freezing makes that explicit.
vgg.trainable = False
# Features from block5_conv4 (the layer at index 20 of vgg.layers): the last
# conv of the fifth block, before the final pooling layer. Selecting it by
# name avoids a magic index.
cutvgg_model = Model(vgg.input, vgg.get_layer('block5_conv4').output)

cutvgg_model.summary()

In [None]:
# Shared Keras loss objects: MSE for the VGG content loss, and BCE (default
# from_logits=False, i.e. it expects probabilities like the discriminator's
# sigmoid output) for the adversarial losses.
mse_loss = tf.keras.losses.MeanSquaredError()
bce_loss = tf.keras.losses.BinaryCrossentropy()

def compute_generator_loss(sr_pred):
  """Adversarial generator loss: BCE of the discriminator's SR scores vs. 'real' (1) labels."""
  real_labels = tf.ones_like(sr_pred)
  return bce_loss(real_labels, sr_pred)

def compute_discriminator_loss(hr_pred, sr_pred):
  """Discriminator loss: BCE pushing HR scores to 1 plus BCE pushing SR scores to 0."""
  real_term = bce_loss(tf.ones_like(hr_pred), hr_pred)
  fake_term = bce_loss(tf.zeros_like(sr_pred), sr_pred)
  return real_term + fake_term

@tf.function
def compute_content_loss(hr, sr):
  """VGG19 feature-space (content) loss between HR targets and SR outputs.

  Both batches go through vgg19.preprocess_input and the truncated VGG
  (cutvgg_model). The resulting feature maps are divided by 12.75 and
  compared with MSE.
  """
  # Assumes hr/sr are RGB on the 0-255 scale, which is what
  # vgg19.preprocess_input expects. TODO confirm for the generator output.
  hr = vgg19.preprocess_input(hr)
  sr = vgg19.preprocess_input(sr)
  # NOTE(review): the 1/12.75 feature rescaling presumably follows the SRGAN
  # paper's setup; confirm.
  hr_deep_features = cutvgg_model(hr) / 12.75
  sr_deep_features = cutvgg_model(sr) / 12.75
  return mse_loss(hr_deep_features, sr_deep_features)

In [None]:
@tf.function
def train_step(lr_img_batch, hr_img_batch):
  """Runs one SRGAN training step on a batch of LR/HR pairs.

  A single forward pass feeds both updates. The generator minimizes the
  perceptual loss (VGG content loss + 1e-3 * adversarial loss). The
  discriminator minimizes BCE on real (HR) vs. generated (SR) predictions.

  Args:
    lr_img_batch: batch of low-resolution images (cast to float32 here).
    hr_img_batch: matching batch of high-resolution targets (cast to float32).

  Returns:
    (perceptual_loss, disc_loss) loss tensors for this step.
  """
  # Two tapes record the same forward pass; each network is then
  # differentiated only with respect to its own loss.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    lr_batch = tf.cast(lr_img_batch, tf.float32)
    hr_batch = tf.cast(hr_img_batch, tf.float32)

    sr_batch = generator(lr_batch, training=True)
    real_scores = discriminator(hr_batch, training=True)
    fake_scores = discriminator(sr_batch, training=True)

    content_loss = compute_content_loss(hr_batch, sr_batch)
    adversarial_loss = compute_generator_loss(fake_scores)
    perceptual_loss = content_loss + 1e-3 * adversarial_loss
    disc_loss = compute_discriminator_loss(real_scores, fake_scores)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables
  gen_grads = gen_tape.gradient(perceptual_loss, gen_vars)
  disc_grads = disc_tape.gradient(disc_loss, disc_vars)

  generator_optim.apply_gradients(zip(gen_grads, gen_vars))
  discriminator_optim.apply_gradients(zip(disc_grads, disc_vars))

  return perceptual_loss, disc_loss

In [None]:
total_steps = 300000
step = 0
pl_batch = []  # perceptual losses since the last report
dl_batch = []  # discriminator losses since the last report
# train_ds yields (lr, hr) batches, so the second element is the HR target
# batch, not an SR prediction.
for lr_batch, hr_batch in tqdm(train_ds.take(total_steps), total=total_steps):
  pl, dl = train_step(lr_batch, hr_batch)
  pl_batch.append(pl)
  dl_batch.append(dl)
  step += 1

  if step % 1000 == 0:
    # tqdm.write prints without corrupting the progress bar, unlike print().
    tqdm.write('Step: {step}/{steps}: Perceptual Loss: {ploss:.5f}, Discriminator Loss: {dloss:.5f}'.format(
        step=step, steps=total_steps,
        ploss=np.mean(pl_batch),
        dloss=np.mean(dl_batch)
    ))
    pl_batch = []
    dl_batch = []

In [None]:
# Save the SRGAN-fine-tuned generator weights.
# NOTE(review): "x4" in the filename does not match the x3 model architecture.
generator.save_weights('/content/drive/MyDrive/srgan_finetuned_edsr_16-res-block-x4.h5')

In [None]:
# Load the pre-trained EDSR weights and the SRGAN-fine-tuned weights into two
# separate models with the same architecture.
# NOTE(review): these are absolute paths into one user's Google Drive; adjust
# them before running elsewhere.
weight_edsr = f"/content/drive/MyDrive/South Pole/training_weights_logs/weights/second_module/l8_s2_training/edsr_16-res-block-x4.h5"
weight_srgan = f"/content/drive/MyDrive/South Pole/training_weights_logs/weights/second_module/l8_s2_training/srgan_finetuned_edsr_16-res-block-x4.h5"


edsr_orig_model = edsr_model_arch(num_residual_blocks=16)
edsr_orig_model.load_weights(weight_edsr)

edsr_finetuned_model = edsr_model_arch(num_residual_blocks=16)
edsr_finetuned_model.load_weights(weight_srgan)

In [None]:
# Create one super-resolved prediction file for each PNG in the validation set.

model = edsr_finetuned_model
model_name = "M2_SRGAN_l8_s2"
output_directory = "PREDICTIONS DIRECTORY"  # placeholder
input_directory = "LR IMAGES DIRECTORY"     # placeholder

output_directory = os.path.join(output_directory, model_name)
# exist_ok=True lets the cell be re-run without raising FileExistsError.
os.makedirs(output_directory, exist_ok=True)

# Regular files whose extension is exactly "png"; sorted for a stable order.
filenames = sorted(
    os.path.join(input_directory, name)
    for name in os.listdir(input_directory)
    if name.split(".")[-1] == "png"
    and os.path.isfile(os.path.join(input_directory, name)))

for input_path in tqdm(filenames):
  output_path = os.path.join(output_directory, os.path.basename(input_path))

  lr_img = tf.image.decode_png(tf.io.read_file(input_path), channels=3)
  lr_batch = tf.cast(tf.expand_dims(lr_img, axis=0), tf.float32)

  sr_img = model(lr_batch)
  sr_img = tf.cast(tf.clip_by_value(sr_img, 0, 255), tf.uint8)
  # Write the clipped uint8 pixels unchanged. keras save_img defaults to
  # scale=True, which min-max stretches every image to 0..255 and would
  # alter the predicted values.
  tf.io.write_file(output_path, tf.image.encode_png(sr_img[0]))



