<a href="https://colab.research.google.com/github/Shobhit2000/Super_Resolution/blob/master/Super_Resolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Importing Libraries**

In [None]:
import cv2
import numpy as np
import os
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, Input, Conv2D, Flatten, Reshape, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K

In [None]:
# Mount Google Drive at /content/drive; later cells read and write saved models
# and training checkpoints under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Data Preprocessing Testing**

In [None]:
# img = cv2.imread('test.jpg')
# img_480 = cv2.resize(img, (640, 480))
# img_144 = cv2.resize(img, (256, 144))
# img_480_resize = cv2.resize(img_144, (640, 480))

# cv2_imshow(img_480)
# cv2_imshow(img_480_resize)

# https://keras.io/examples/vision/super_resolution_sub_pixel/
#  check this for preprocessing

**Download Dataset**

In [None]:
# Fetch the COCO test2017 archive (~6.2 GB). `-c` resumes a partial download and
# skips the transfer when the archive is already complete, so re-running this
# cell does not download the whole file again.
!wget -c http://images.cocodataset.org/zips/test2017.zip

--2022-09-23 06:31:45--  http://images.cocodataset.org/zips/test2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.173.81
Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.173.81|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6646970404 (6.2G) [application/zip]
Saving to: ‘test2017.zip’


2022-09-23 06:40:11 (12.6 MB/s) - ‘test2017.zip’ saved [6646970404/6646970404]



In [None]:
# Extract quietly (-q) and never overwrite existing files (-n): a re-run neither
# prints one line per image nor stalls on unzip's interactive "replace?" prompt.
!unzip -q -n test2017.zip

**Parameters**

**Generator MODEL**

In [None]:
def residual(temp_x):
    """Plain residual block: conv -> ReLU -> conv, summed with the block input.

    Both convolutions use 64 filters, a 3x3 kernel, stride 1 and "same"
    padding, so spatial size is preserved. The input is added element-wise to
    the 64-channel branch output, so its shape has to be compatible with that.

    Args:
        temp_x: Input feature tensor.

    Returns:
        The tensor `temp_x + branch(temp_x)`.
    """
    conv_kwargs = dict(kernel_size=(3, 3), padding="same", strides=1)

    branch = layers.Conv2D(64, **conv_kwargs)(temp_x)
    branch = layers.ReLU()(branch)
    branch = layers.Conv2D(64, **conv_kwargs)(branch)

    # Skip connection around the two convolutions.
    return tf.math.add(temp_x, branch)

In [None]:
def generator_residual(temp_x):
    """Generator residual block: (conv -> batch-norm) twice with a ReLU in between.

    Both convolutions use 64 filters, a 3x3 kernel, stride 1 and "same"
    padding. The block input is added element-wise to the branch output, so
    its shape has to be compatible with a 64-channel tensor of the same
    spatial size.

    Args:
        temp_x: Input feature tensor.

    Returns:
        The tensor `temp_x + branch(temp_x)`.
    """
    conv_kwargs = dict(kernel_size=(3, 3), padding="same", strides=1)

    branch = layers.Conv2D(64, **conv_kwargs)(temp_x)
    branch = layers.BatchNormalization()(branch)
    branch = layers.ReLU()(branch)

    branch = layers.Conv2D(64, **conv_kwargs)(branch)
    branch = layers.BatchNormalization()(branch)

    # Skip connection; no activation is applied after the sum.
    return tf.math.add(temp_x, branch)

In [None]:
def generator_model(upscale_factor=2):
    """Build the super-resolution generator and print its summary.

    Layout:
      * 9x9 conv stem (64 filters) + ReLU on a (144, 256, 3) input,
      * five `generator_residual` blocks,
      * 3x3 conv + batch-norm, added back onto the stem output,
      * two upsampling stages, each: 3x3 conv (256 filters), 3x3 conv with
        `upscale_factor ** 2` filters, `depth_to_space`, ReLU. Because the
        channel count equals `upscale_factor ** 2`, each shuffle leaves a
        single-channel map that is `upscale_factor` times larger per side,
      * 9x9 conv down to 3 channels, then a resize to 480x640.

    Args:
        upscale_factor: Block size used by both `depth_to_space` stages.

    Returns:
        An uncompiled `tf.keras.Model` named "Super_Resolution_Generator".
    """
    def same_conv(filters, kernel):
        # Every convolution in this model is stride 1 with "same" padding.
        return layers.Conv2D(filters, kernel_size=(kernel, kernel), strides=1, padding="same")

    low_res = Input(shape=(144, 256, 3))
    stem = same_conv(64, 9)(low_res)
    stem = layers.ReLU()(stem)

    # Residual trunk.
    trunk = stem
    for _ in range(5):
        trunk = generator_residual(trunk)

    features = same_conv(64, 3)(trunk)
    features = layers.BatchNormalization()(features)
    # Long skip connection from the stem around the whole trunk.
    features = tf.math.add(stem, features)

    # Two identical pixel-shuffle upsampling stages.
    for _ in range(2):
        features = same_conv(256, 3)(features)
        # PIXEL_SHUFFLER
        features = same_conv(upscale_factor ** 2, 3)(features)
        features = tf.nn.depth_to_space(features, upscale_factor)
        features = layers.ReLU()(features)

    rgb = same_conv(3, 9)(features)
    # Force the final resolution regardless of the shuffled size.
    high_res = layers.Resizing(480, 640)(rgb)

    model = tf.keras.Model(inputs=low_res, outputs=high_res, name="Super_Resolution_Generator")
    model.summary()

    return model

In [None]:
# GENERATOR_MODEL = generator_model()

**Generator Architecture Diagram**

In [None]:
# from tensorflow.keras.utils import plot_model

# plot_model(GENERATOR_MODEL, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

**Discriminator Model**

In [None]:
def discriminator_model():
    """Build the real-vs-generated image discriminator and print its summary.

    Layout:
      * 3x3 conv stem (64 filters, stride 1) + LeakyReLU(0.2) on a
        (480, 640, 3) input,
      * seven conv -> batch-norm -> LeakyReLU(0.2) blocks whose filter counts
        grow 64 -> 512 while every other block halves the resolution
        (stride 2),
      * global average pooling, Dense(1024) + LeakyReLU(0.2), and a single
        sigmoid output unit.

    Returns:
        An uncompiled `tf.keras.Model` named "Super_Resolution_Discriminator".
    """
    # (filters, stride) for each conv/batch-norm/LeakyReLU block, in order.
    block_specs = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    )

    image = Input(shape=(480, 640, 3))
    features = layers.Conv2D(64, kernel_size=(3, 3), strides=1, padding="same")(image)
    features = layers.LeakyReLU(alpha=0.2)(features)

    for filters, stride in block_specs:
        features = layers.Conv2D(filters, kernel_size=(3, 3), padding="same", strides=stride)(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(alpha=0.2)(features)

    pooled = GlobalAveragePooling2D()(features)
    hidden = layers.Dense(1024)(pooled)
    hidden = layers.LeakyReLU(alpha=0.2)(hidden)
    score = layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs=image, outputs=score, name="Super_Resolution_Discriminator")
    model.summary()

    return model

In [None]:
# DISCRIMINATOR_MODEL = discriminator_model()

In [None]:
# from tensorflow.keras.utils import plot_model

# plot_model(DISCRIMINATOR_MODEL, to_file='model_plot_discriminator.png', show_shapes=True, show_layer_names=True)

In [None]:
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    """Soft Dice loss: one minus the mean Dice coefficient.

    The Dice coefficient is computed as
    `2 * sum(y_pred * y_true) / (sum(y_pred**2 + y_true**2) + epsilon)`,
    where the sums run over every axis except the first and the last. The
    resulting per-(first axis, last axis) scores are then averaged.

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor; its rank determines the reduced axes.
        epsilon: Small constant added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor `1 - mean(dice)`.
    """
    # All axes between the first (batch) and the last (class/channel) one.
    inner_axes = tuple(range(1, len(y_pred.shape) - 1))

    overlap = tf.math.reduce_sum(y_pred * y_true, axis=inner_axes)
    magnitude = tf.math.reduce_sum(tf.math.square(y_pred) + tf.math.square(y_true), inner_axes)

    dice = (2. * overlap) / (magnitude + epsilon)

    # Average over the remaining (batch and class/channel) axes.
    return 1 - tf.math.reduce_mean(dice)

In [None]:
def construct_models(verbose=False):
    """Build and compile the discriminator, the generator and the stacked SRGAN.

    * The discriminator is compiled with binary cross-entropy.
    * The generator is compiled with `soft_dice_loss` so it can also be
      trained directly on (low-res, high-res) pairs.
    * SRGAN is a `Sequential` of generator followed by discriminator. The
      discriminator is marked non-trainable before SRGAN is compiled, so the
      combined model is compiled with the discriminator frozen.

    All three use Adam with a learning rate of 1e-4 and report accuracy as a
    metric. The optimizers are created with `learning_rate=` because `lr=` is
    a deprecated alias that newer Keras optimizers reject.

    Args:
        verbose: Accepted for backward compatibility; currently has no effect.
            `discriminator_model()` and `generator_model()` already print
            their own summaries when they are called.

    Returns:
        Tuple `(GENERATOR_MODEL, DISCRIMINATOR_MODEL, SRGAN)`.
    """
    ### discriminator
    DISCRIMINATOR_MODEL = discriminator_model()
    DISCRIMINATOR_MODEL.compile(loss='binary_crossentropy',
                                optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                                metrics=['accuracy'])

    ### generator
    # Compiled on its own as well, with the Dice loss against the high-res target.
    GENERATOR_MODEL = generator_model()
    GENERATOR_MODEL.compile(loss=soft_dice_loss,
                            optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                            metrics=['accuracy'])

    ### SRGAN
    SRGAN = tf.keras.Sequential()
    SRGAN.add(GENERATOR_MODEL)
    SRGAN.add(DISCRIMINATOR_MODEL)
    # Freeze the discriminator AFTER its own compile and BEFORE compiling the
    # stack: the order of these statements matters.
    DISCRIMINATOR_MODEL.trainable = False
    SRGAN.compile(loss='binary_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  metrics=['accuracy'])

    return GENERATOR_MODEL, DISCRIMINATOR_MODEL, SRGAN
  
# Build the three models used by every later cell. The generator and
# discriminator builders print their own summaries when called.
GENERATOR_MODEL, DISCRIMINATOR_MODEL, SRGAN = construct_models(verbose=True)

**Data Loader: Load images in batches to optimise memory**

In [None]:
# Image file names available for sampling. os.listdir() returns entries in
# arbitrary order and may include non-image files, so keep only image
# extensions (anything else would make cv2.imread() return None) and sort for
# a stable, reproducible ordering.
FILE_LIST = sorted(
    name for name in os.listdir('test2017')
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
)
# Number of (low-res, high-res) pairs per training batch.
BATCH_SIZE = 16

**Load models in case saved**

In [None]:
# Resume from previously saved models when all three files exist; otherwise
# keep the models that are already defined in the notebook.
SAVED_MODEL_DIR = 'drive/MyDrive/saved_model'
_saved_model_paths = {
    name: os.path.join(SAVED_MODEL_DIR, name + '.h5')
    for name in ('generator', 'discriminator', 'srgan')
}

if all(os.path.exists(path) for path in _saved_model_paths.values()):
    # The generator is compiled with the custom `soft_dice_loss`; Keras cannot
    # resolve a custom loss by name while deserializing unless it is passed via
    # `custom_objects`. This is the most likely cause of the ValueError that
    # was recorded for this cell.
    _custom_objects = {'soft_dice_loss': soft_dice_loss}

    GENERATOR_MODEL = tf.keras.models.load_model(_saved_model_paths['generator'],
                                                 custom_objects=_custom_objects)
    DISCRIMINATOR_MODEL = tf.keras.models.load_model(_saved_model_paths['discriminator'])
    SRGAN = tf.keras.models.load_model(_saved_model_paths['srgan'],
                                       custom_objects=_custom_objects)
    # NOTE(review): the three files are deserialized independently, so the
    # generator/discriminator inside the loaded SRGAN are presumably separate
    # objects from GENERATOR_MODEL / DISCRIMINATOR_MODEL and no longer share
    # weights with them - verify before resuming adversarial training.

    # Show the model architecture
    GENERATOR_MODEL.summary()
    DISCRIMINATOR_MODEL.summary()
    SRGAN.summary()
else:
    print('No saved models found under ' + SAVED_MODEL_DIR + '; keeping the current models.')

ValueError: ignored

In [None]:
def image_generator(files, batch_size):
    """Sample one random batch of (low-res, high-res) training pairs.

    Despite its name this is a plain function, not a Python generator: every
    call returns a single freshly sampled batch.

    File names are drawn with `np.random.choice` (sampling with replacement)
    and read from the `test2017/` directory. Each image is resized to 640x480
    (target) and to 256x144 (input), and both are scaled to [0, 1] by dividing
    by 255. Channel order is whatever `cv2.imread` returns (BGR).

    Args:
        files: Sequence of file names relative to `test2017/`.
        batch_size: Number of image pairs to return.

    Returns:
        Tuple `(batch_input, batch_output)` of numpy arrays; for a non-empty
        batch of colour images their shapes are `(batch_size, 144, 256, 3)`
        and `(batch_size, 480, 640, 3)`.

    Raises:
        IOError: If an unreadable file keeps being drawn after several
            replacement attempts.
    """
    max_retries = 10  # bounded so a directory of unreadable files fails loudly

    # Select files (paths/indices) for the batch
    batch_paths = np.random.choice(a=files, size=batch_size)
    batch_input = []
    batch_output = []

    # Read in each input, perform preprocessing and get labels
    for input_path in batch_paths:
        img = cv2.imread('test2017/' + input_path)

        # cv2.imread signals an unreadable or non-image file by returning None;
        # draw a replacement instead of crashing inside cv2.resize.
        retries = 0
        while img is None:
            if retries >= max_retries:
                raise IOError('Could not read a valid image from test2017/ '
                              '(last tried: ' + str(input_path) + ')')
            input_path = np.random.choice(files)
            img = cv2.imread('test2017/' + input_path)
            retries += 1

        # cv2.resize takes the target size as (width, height).
        img_480 = cv2.resize(img, (640, 480))
        img_144 = cv2.resize(img, (256, 144))

        batch_input.append(img_144 / 255)
        batch_output.append(img_480 / 255)

    return np.asarray(batch_input), np.asarray(batch_output)

In [None]:
# Sanity check: number of files available for batch sampling.
print(len(FILE_LIST))

40670


**GAN Training**

In [None]:
# number of discriminator updates per alternating training iteration
DISC_UPDATES = 1  
# number of generator updates per alternating training iteration
GAN_UPDATES = 1 
# sample predictions are plotted whenever the global iteration counter is a multiple of this value
PROGRESS_INTERVAL = 20 

# function for training a GAN
def run_training(generator, discriminator, gan, start_it=0, num_epochs=100, steps_per_epoch=20):
    """Alternately train the discriminator, the generator and the stacked GAN.

    Each training iteration:
      1. samples a batch and lets the generator super-resolve it,
      2. trains the discriminator on real (label 1) and generated (label 0)
         images, `DISC_UPDATES` times,
      3. trains the generator directly against the high-res target,
      4. trains `gan` (generator + discriminator) towards label 1 on fresh
         batches, `GAN_UPDATES` times.

    Every `PROGRESS_INTERVAL` iterations a few generated and real images are
    plotted with their discriminator score. After every epoch the mean losses
    are printed and plotted and a checkpoint is written.

    Relies on notebook-level names: `FILE_LIST`, `BATCH_SIZE`, `DISC_UPDATES`,
    `GAN_UPDATES`, `PROGRESS_INTERVAL`, `image_generator`, `ckpt` and `manager`.

    Args:
        generator: Compiled generator model.
        discriminator: Compiled discriminator model.
        gan: Compiled stacked generator + discriminator model.
        start_it: Initial value of the global iteration counter.
        num_epochs: Number of epochs to run.
        steps_per_epoch: Training iterations per epoch (previously fixed at 20).

    Returns:
        Tuple `(generator, discriminator, gan)` - the same objects, trained in place.

    Raises:
        ValueError: If `steps_per_epoch` is smaller than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError("steps_per_epoch must be at least 1")

    # list for storing loss
    avg_loss_discriminator = []
    avg_loss_generator = []
    avg_loss_srgan = []
    total_it = start_it

    # Labels never change, so build them once: 1 = real, 0 = generated.
    true_labels = np.ones([BATCH_SIZE, 1], dtype=np.float32)
    fake_labels = np.zeros([BATCH_SIZE, 1], dtype=np.float32)

    # Restore once up front. A checkpoint is saved at the end of every epoch,
    # so restoring again at the start of each epoch would only reload the
    # state that was just written.
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # main training loop
    for epoch in range(num_epochs):

        # alternating training loop
        loss_discriminator = []
        loss_generator = []
        loss_srgan = []
        print('Discriminator training')
        for _step in range(steps_per_epoch):

            # select a random set of real images
            x_train, y_train = image_generator(FILE_LIST, BATCH_SIZE)

            # generate a set of fake images using the generator
            super_resolved_pred = generator.predict_on_batch(x_train)

            #### DISCRIMINATOR training loop ####
            # When DISC_UPDATES > 1, resample x_train/y_train and regenerate
            # super_resolved_pred inside this loop so each update sees new data.
            for _update in range(DISC_UPDATES):
                # train_on_batch returns [loss, accuracy] for these models;
                # index 0 is the loss (index 1 would be the accuracy metric).
                d_loss_real = discriminator.train_on_batch(y_train, true_labels)[0]
                d_loss_fake = discriminator.train_on_batch(super_resolved_pred, fake_labels)[0]

            #### GENERATOR Training ####
            gen_loss = generator.train_on_batch(x_train, y_train)[0]

            # display some fake images for visual control of convergence
            if total_it % PROGRESS_INTERVAL == 0:
                batch_vis = min(BATCH_SIZE, 5)
                x_train_visualize, y_train_visualize = image_generator(FILE_LIST, batch_vis)
                super_resolved_pred_visualize = generator.predict_on_batch(x_train_visualize)

                for obj_plot in [super_resolved_pred_visualize, y_train_visualize]:
                    plt.figure(figsize=(batch_vis * 3, 3))

                    for b in range(batch_vis):
                        scores = discriminator.predict_on_batch(np.expand_dims(obj_plot[b], axis=0))
                        disc_score = float(scores[0, 0])
                        plt.subplot(1, batch_vis, b + 1)
                        plt.title(str(round(disc_score, 3)))
                        # Images are loaded with cv2 (BGR) and scaled to [0, 1]:
                        # flip to RGB for matplotlib and clip the generator's
                        # unbounded output into the displayable range.
                        plt.imshow(np.clip(obj_plot[b][..., ::-1], 0.0, 1.0))

                    plt.show()

            #### SRGAN training loop ####
            gan_loss = 0
            for _update in range(GAN_UPDATES):
                # sample a fresh batch of low-res inputs
                x_train, y_train = image_generator(FILE_LIST, BATCH_SIZE)
                # train the generator (through the stacked model) on fake images with label 1
                gan_loss += gan.train_on_batch(x_train, true_labels)[0]

            # store loss
            loss_discriminator.append((d_loss_real + d_loss_fake) / 2.)
            loss_generator.append(gen_loss)
            loss_srgan.append(gan_loss / GAN_UPDATES)
            print(total_it)
            total_it += 1

        # visualize loss
        print('Epoch', epoch)
        print('Discriminator Loss:- ', str(np.mean(loss_discriminator)))
        print('Generator Loss:- ', str(np.mean(loss_generator)))
        print('SRGAN Loss:- ', str(np.mean(loss_srgan)))
        avg_loss_discriminator.append(np.mean(loss_discriminator))
        avg_loss_generator.append(np.mean(loss_generator))
        avg_loss_srgan.append(np.mean(loss_srgan))
        plt.plot(range(len(avg_loss_discriminator)), avg_loss_discriminator)
        plt.plot(range(len(avg_loss_generator)), avg_loss_generator)
        plt.plot(range(len(avg_loss_srgan)), avg_loss_srgan)
        plt.legend(['Discriminator Loss', 'Generator Loss', 'SRGAN Loss'])
        plt.show()

        # Checkpoint after every epoch.
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 1 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
            print("loss {:1.2f}".format(gan_loss))

    return generator, discriminator, gan

In [None]:
# Track the models together with the step counter. With only `step` tracked,
# the checkpoints written during training contain no weights, so "restoring"
# them cannot resume training.
ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                           generator=GENERATOR_MODEL,
                           discriminator=DISCRIMINATOR_MODEL,
                           gan=SRGAN)
# Keep only the most recent checkpoint on Drive.
manager = tf.train.CheckpointManager(ckpt, '/content/drive/MyDrive/Video_Enhancer/DL_Models/SRGAN/tf_ckpts', max_to_keep=1)

# run_training returns the same model objects it was given, trained in place.
generator_celeb, discriminator_celeb, gan_celeb = run_training(GENERATOR_MODEL, DISCRIMINATOR_MODEL, SRGAN,
                                                               num_epochs=100)

**Test**

In [None]:
# Draw a single random (low-res, high-res) pair and super-resolve the low-res image.
x_train_visualize, y_train_visualize = image_generator(FILE_LIST, 1)
super_resolved_pred_visualize = GENERATOR_MODEL.predict_on_batch(x_train_visualize)

# Show the generator output first, then the ground truth, scaled from [0, 1] back to [0, 255].
for frame in (super_resolved_pred_visualize[0], y_train_visualize[0]):
    cv2_imshow(frame * 255)

**Save Model**

In [None]:
# Make sure the target directory exists first, so that saving cannot fail on a
# missing folder on a fresh Drive.
os.makedirs('drive/MyDrive/saved_model', exist_ok=True)

GENERATOR_MODEL.save('drive/MyDrive/saved_model/generator.h5')
DISCRIMINATOR_MODEL.save('drive/MyDrive/saved_model/discriminator.h5')
SRGAN.save('drive/MyDrive/saved_model/srgan.h5')

**Load Model**

In [None]:
# The generator is compiled with the custom `soft_dice_loss`; Keras cannot
# resolve a custom loss by name while deserializing unless it is passed via
# `custom_objects`.
GENERATOR_MODEL = tf.keras.models.load_model('drive/MyDrive/saved_model/generator.h5',
                                             custom_objects={'soft_dice_loss': soft_dice_loss})
DISCRIMINATOR_MODEL = tf.keras.models.load_model('drive/MyDrive/saved_model/discriminator.h5')
SRGAN = tf.keras.models.load_model('drive/MyDrive/saved_model/srgan.h5',
                                   custom_objects={'soft_dice_loss': soft_dice_loss})
# NOTE(review): the three files are deserialized independently, so the
# generator/discriminator inside the loaded SRGAN are presumably separate
# objects from GENERATOR_MODEL / DISCRIMINATOR_MODEL and no longer share
# weights with them - verify before resuming adversarial training.

# Show the model architecture
GENERATOR_MODEL.summary()
DISCRIMINATOR_MODEL.summary()
SRGAN.summary()