In [3]:
from google.colab import drive
# Mount Google Drive so the dataset, checkpoint and weight paths under
# '/content/drive/My Drive/' used throughout this notebook are reachable.
drive.mount('/content/drive')


Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [None]:
import os
# Sanity check: list the contents of the DIV2K_train_HR folder on the mounted Drive.
os.listdir('/content/drive/My Drive/DIV2K_train_HR/')

In [None]:
# One-time setup: extract the DIV2K_valid_LR_wild archive into My Drive
# (presumably producing the DIV2K_valid_LR_wild/ folder used as the LR validation path — verify).
# Disabled earlier attempt, kept for reference only: system_raw would try to *execute*
# the .zip file rather than extract it.
# get_ipython().system_raw("/content/drive/My Drive/DIV2K_train_HR.zip")
!unzip "/content/drive/My Drive/DIV2K_valid_LR_wild.zip" -d '/content/drive/My Drive/'

In [4]:
import os

import numpy as np
import tensorflow as tf
from easydict import EasyDict as edict
from tensorflow.keras import Input, Model
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import Conv2D, BatchNormalization, LeakyReLU, PReLU, Add, Lambda, Dense, Flatten
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError, MeanSquaredError
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.python.data.experimental import AUTOTUNE

In [3]:
# Central configuration for data locations, checkpoint/weight output folders and
# model/training hyper-parameters; every other cell reads these via `config.<NAME>`.
config = edict({
  # DIV2K image folders on the mounted Drive
  'HR_TRAIN_PATH': '/content/drive/My Drive/DIV2K_train_HR/',
  'HR_VALID_PATH': '/content/drive/My Drive/DIV2K_valid_HR/',
  'LR_TRAIN_PATH': '/content/drive/My Drive/DIV2K_train_LR/',
  'LR_VALID_PATH': '/content/drive/My Drive/DIV2K_valid_LR_wild/',

  # Where checkpoints and final weights are written
  'GEN_PRE_CHECKPOINT_DIR': '/content/drive/My Drive/SRGAN/checkpoints/pretrained_gen/',
  'SRGAN_CHECKPOINT_DIR': '/content/drive/My Drive/SRGAN/checkpoints/',
  'FINAL_WEIGHTS_DIR': '/content/drive/My Drive/SRGAN/weights/',

  # Model / training hyper-parameters
  'LR_DOWNSCALE': 4,          # HR side length / LR side length
  'HR_CROP_SIZE': 96,         # side of the square HR training crops
  'NUM_EPOCHS': 20000,        # SRGAN training length (one "epoch" = one batch step)
  'NUM_RES_BLOCKS': 16,       # residual blocks in the generator trunk
  'NUM_UPSAMPLE_BLOCKS': 2,   # x2 pixel-shuffle stages in the generator
})

In [None]:
# Per-channel image mean (presumably DIV2K's RGB mean — verify), rescaled from [0, 1] to [0, 255].
DIV2k_images_mean = 255 * np.array([0.4488, 0.4371, 0.4040])


def normalize_to_range_01(img):
  '''Scales pixel values from [0, 255] into [0, 1] by dividing by 255.'''
  pixel_max = 255.0
  return img / pixel_max

def normalize_to_range_n11(img):
  '''Scales pixel values from [0, 255] into [-1, 1].'''
  half_range = 127.5
  scaled = img / half_range
  return scaled - 1

def denormalize_n11(img):
  '''Maps values from [-1, 1] back to the [0, 255] pixel range (inverse of normalize_to_range_n11).'''
  shifted = img + 1
  return shifted * 127.5




In [None]:
class Generator:
  '''SRResNet-style generator (the G network of SRGAN).

  The Keras model is built by the constructor and exposed as `self.gen_model`;
  both training classes read `generator.gen_model` right after construction.

  Model contract:
    input  -- (batch, h, w, 3) float tensor, pixel values expected in [0, 255]
              (divided by 255 inside the model).
    output -- (batch, h * 2**NUM_UPSAMPLE_BLOCKS, w * 2**NUM_UPSAMPLE_BLOCKS, 3),
              values in [0, 255]: a tanh output in [-1, 1] mapped back by denormalize_n11.
  '''

  def __init__(self):
    self.num_res_blocks = config.NUM_RES_BLOCKS
    self.num_upsample_blocks = config.NUM_UPSAMPLE_BLOCKS
    # Build eagerly so `gen_model` exists as soon as the object does.
    self.gen_network()

  def pixel_shuffle(self, data, scale):
    '''Sub-pixel shuffle: (h, w, c * scale**2) -> (h * scale, w * scale, c).'''
    return tf.nn.depth_to_space(data, scale)

  def conv_and_upsample(self, data, scale):
    '''Conv -> pixel shuffle -> PReLU; multiplies the spatial size by `scale`.

    The conv emits 64 * scale**2 channels so the shuffle yields 64 channels
    (256 for the default scale of 2).
    '''
    data = Conv2D(filters = 64 * scale ** 2, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(data)
    # Lambda wraps the raw TF op so it can be used as a Keras layer.
    data = Lambda(lambda t: self.pixel_shuffle(t, scale))(data)
    data = PReLU(shared_axes = [1, 2])(data)
    return data

  def resnet_block(self, data):
    '''Residual block: conv-BN-PReLU-conv-BN plus an identity skip (expects 64 input channels).'''
    res_data = Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(data)
    res_data = BatchNormalization()(res_data)
    res_data = PReLU(shared_axes = [1, 2])(res_data)
    res_data = Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(res_data)
    res_data = BatchNormalization()(res_data)
    res_data = Add()([data, res_data])
    return res_data

  def gen_network(self):
    '''Builds the generator graph, stores it in `self.gen_model` and returns it.'''
    in_data = tf.keras.Input(shape = (None, None, 3))
    data = Lambda(normalize_to_range_01)(in_data)
    data = Conv2D(filters = 64, kernel_size = 9, strides = (1, 1), padding = 'same')(data)
    data = BatchNormalization()(data)
    data = LeakyReLU()(data)
    # Long skip connection around the whole residual trunk.
    data_copy = data

    for _ in range(self.num_res_blocks):
      data = self.resnet_block(data)

    data = Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(data)
    data = BatchNormalization()(data)
    data = Add()([data, data_copy])

    for _ in range(self.num_upsample_blocks):
      data = self.conv_and_upsample(data, 2)

    # tanh bounds the output to [-1, 1] so denormalize_n11 maps it onto [0, 255].
    data = Conv2D(filters = 3, kernel_size = 9, strides = (1, 1), padding = 'same', activation = 'tanh')(data)
    data = Lambda(denormalize_n11)(data)
    self.gen_model = tf.keras.Model(in_data, data)
    return self.gen_model



##############################################################################################
##############################################################################################

class Discriminator:
  '''SRGAN discriminator (the D network): scores HR crops as real vs generated.

  The Keras model is built by the constructor and exposed as `self.disc_model`;
  SRGAN_Training reads `discriminator.disc_model` right after construction.

  Model contract:
    input  -- (batch, HR_CROP_SIZE, HR_CROP_SIZE, 3) float tensor, pixels expected
              in [0, 255] (rescaled to [-1, 1] inside the model).
    output -- (batch, 1) sigmoid score; 1 means "real HR image".
  '''

  def __init__(self):
    self.disc_network()

  def disc_block(self, data, n_filters, stride):
    '''3x3 conv -> BatchNorm -> LeakyReLU; a stride of 2 halves the spatial size.'''
    data = Conv2D(filters = n_filters, kernel_size = 3, strides = (stride, stride), padding = 'same')(data)
    data = BatchNormalization()(data)
    data = LeakyReLU()(data)
    return data

  def disc_network(self):
    '''Builds the discriminator graph, stores it in `self.disc_model` and returns it.'''
    crop_size = config.HR_CROP_SIZE
    in_data = tf.keras.Input(shape = (crop_size, crop_size, 3))
    data = Lambda(normalize_to_range_n11)(in_data)
    # The first conv must consume the normalized tensor, not the raw input.
    data = Conv2D(filters = 64, kernel_size = 3, strides = (1, 1), padding = 'same')(data)
    data = LeakyReLU()(data)

    # (filters, stride) for each conv block; the four stride-2 blocks downsample by 16 overall.
    for n_filters, stride in ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2)):
      data = self.disc_block(data, n_filters = n_filters, stride = stride)

    # Flatten so the dense head yields one score per image rather than one per spatial position.
    data = tf.keras.layers.Flatten()(data)
    data = Dense(1024)(data)
    data = LeakyReLU()(data)
    data = Dense(1, activation = 'sigmoid')(data)
    self.disc_model = tf.keras.Model(in_data, data)
    return self.disc_model





##############################################################################################
##############################################################################################



class SRGAN_Training:
  '''Adversarial training of the generator against the discriminator (SRGAN stage).

  Generator loss     = MSE between rescaled VGG19 feature maps of generated and
                       real HR images + 0.0001 * adversarial BCE term.
  Discriminator loss = BCE(real -> 1) + BCE(generated -> 0).
  One "epoch" here is one optimizer step on one batch. The step counter, both
  models and both optimizers are checkpointed to config.SRGAN_CHECKPOINT_DIR.
  '''

  def __init__(self):
    # Learning rate: 1e-4 for the first 1e5 optimizer steps, 1e-5 afterwards.
    rate_values = [0.0001, 0.00001]
    rate_boundary = [100000]
    self.learning_rate = PiecewiseConstantDecay(boundaries = rate_boundary, values = rate_values)

    generator = Generator()
    discriminator = Discriminator()
    # Build the networks here if the constructors did not already do so.
    if not hasattr(generator, 'gen_model'):
      generator.gen_network()
    if not hasattr(discriminator, 'disc_model'):
      discriminator.disc_network()

    self.gen_model = generator.gen_model
    self.disc_model = discriminator.disc_model

    # Content-loss feature extractor: VGG19 truncated at block5_conv4, looked up
    # by name rather than by its position in `vgg.layers`.
    vgg = VGG19(include_top = False, input_shape = (None, None, 3))
    self.vgg_model = tf.keras.Model(vgg.input, vgg.get_layer('block5_conv4').output)

    self.generator_optimizer = Adam(learning_rate = self.learning_rate)
    self.discriminator_optimizer = Adam(learning_rate = self.learning_rate)

    # The discriminator ends in a sigmoid, so its outputs are probabilities, not logits.
    self.binary_cross_entropy = BinaryCrossentropy(from_logits = False)
    self.mean_squared_error = MeanSquaredError()

    self.srgan_checkpoint = tf.train.Checkpoint(curr_epoch = tf.Variable(0),
                                                g_optim = self.generator_optimizer,
                                                d_optim = self.discriminator_optimizer,
                                                g_model = self.gen_model,
                                                d_model = self.disc_model)
    self.srgan_checkpoint_manager = tf.train.CheckpointManager(self.srgan_checkpoint,
                                                               directory = config.SRGAN_CHECKPOINT_DIR,
                                                               max_to_keep = 3)

  def disc_loss(self, logits_real, logits_fake):
    '''BCE pushing discriminator scores on real images to 1 and on generated ones to 0.

    Despite the names, both arguments are sigmoid probabilities.
    '''
    disc_real_loss = self.binary_cross_entropy(tf.ones_like(logits_real), logits_real)
    disc_fake_loss = self.binary_cross_entropy(tf.zeros_like(logits_fake), logits_fake)
    return disc_real_loss + disc_fake_loss

  def vgg_loss(self, vgg_fake, vgg_real):
    '''Content loss: MSE between VGG feature maps of generated and real images.'''
    return self.mean_squared_error(vgg_fake, vgg_real)

  def gen_loss(self, fake_hr_images):
    '''Adversarial loss for the generator.

    `fake_hr_images` is the discriminator's score on generated images (not the
    images themselves); the generator is rewarded when those scores approach 1.
    '''
    return self.binary_cross_entropy(tf.ones_like(fake_hr_images), fake_hr_images)

  def train_step(self, lr, hr):
    '''One optimizer step for both networks.

    Args:
      lr: batch of low-resolution images, pixel values in [0, 255].
      hr: matching batch of high-resolution crops, pixel values in [0, 255].
    Returns:
      (perceptual_loss, discriminator_loss) scalar tensors.
    '''
    # Dataset images are uint8; the models and VGG preprocessing need floats.
    lr = tf.cast(lr, tf.float32)
    hr = tf.cast(hr, tf.float32)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      fake_hr_images = self.gen_model(lr, training = True)
      logits_fake    = self.disc_model(fake_hr_images, training = True)
      logits_real    = self.disc_model(hr, training = True)
      # Both inputs are [0, 255] RGB, as VGG19 preprocess_input expects; features
      # are rescaled by 1/12.75 (factor taken from the SRGAN paper — confirm).
      vgg_fake       = self.vgg_model(preprocess_input(fake_hr_images)) / 12.75
      vgg_real       = self.vgg_model(preprocess_input(hr)) / 12.75

      # NOTE(review): the SRGAN paper weights the adversarial term by 1e-3; confirm 1e-4 is intended.
      perceptual_loss    = self.vgg_loss(vgg_fake, vgg_real) + (0.0001 * self.gen_loss(logits_fake))
      discriminator_loss = self.disc_loss(logits_real, logits_fake)

    generator_gradients = gen_tape.gradient(perceptual_loss, self.gen_model.trainable_variables)
    discriminator_gradients = disc_tape.gradient(discriminator_loss, self.disc_model.trainable_variables)

    self.generator_optimizer.apply_gradients(zip(generator_gradients, self.gen_model.trainable_variables))
    self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.disc_model.trainable_variables))

    return perceptual_loss, discriminator_loss

  def restore_checkpoint(self, resume_training = False):
    '''Restores training state.

    With `resume_training` and an existing SRGAN checkpoint, restores the full
    SRGAN state. Otherwise loads only the generator from the newest checkpoint in
    config.GEN_PRE_CHECKPOINT_DIR, or keeps random weights if there is none.
    '''
    latest_srgan = self.srgan_checkpoint_manager.latest_checkpoint
    if resume_training and latest_srgan:
      self.srgan_checkpoint.restore(latest_srgan)
      return

    latest_ckpt = tf.train.latest_checkpoint(config.GEN_PRE_CHECKPOINT_DIR)
    if latest_ckpt is None:
      print('no pre-trained generator checkpoint found in ' + config.GEN_PRE_CHECKPOINT_DIR + '; using randomly initialised generator')
      return
    # The pre-training checkpoint is a tf.train.Checkpoint holding the generator
    # under the `model` key, so restore through a matching object graph;
    # expect_partial() ignores the optimizer/counters stored alongside it.
    tf.train.Checkpoint(model = self.gen_model).restore(latest_ckpt).expect_partial()

  def train(self, train_data, resume_training = True, save_every = 1):
    '''Runs adversarial training until config.NUM_EPOCHS steps have been taken.

    Args:
      train_data: dataset yielding (hr, lr) batches; each element is one step.
      resume_training: continue from the latest SRGAN checkpoint if one exists;
        otherwise start from step 0 with the pre-trained generator.
      save_every: write an SRGAN checkpoint every this many steps (1 = every step).
    '''
    self.restore_checkpoint(resume_training)
    if not resume_training:
      self.srgan_checkpoint.curr_epoch.assign(0)
    start_epoch = int(self.srgan_checkpoint.curr_epoch.numpy())
    # Clamp at zero: Dataset.take(-1) would mean "take every element".
    remaining_epochs = max(0, config.NUM_EPOCHS - start_epoch)

    perc_loss_log = tf.keras.metrics.Mean('perc_loss', dtype = tf.float32)
    disc_loss_log = tf.keras.metrics.Mean('disc_loss', dtype = tf.float32)

    # Iterate the dataset once so each step gets a fresh batch (calling take(1)
    # every step would restart the dataset and reuse its first batch).
    for hr, lr in train_data.take(remaining_epochs):
      self.srgan_checkpoint.curr_epoch.assign_add(1)
      epoch = int(self.srgan_checkpoint.curr_epoch.numpy())

      perceptual_loss, discriminator_loss = self.train_step(lr, hr)
      perc_loss_log.update_state(perceptual_loss)
      disc_loss_log.update_state(discriminator_loss)

      if epoch % 10 == 0:
        print('In epoch ' + str(epoch) + ' perceptual loss is ' + str(float(perc_loss_log.result())) + ' and discriminator loss is ' + str(float(disc_loss_log.result())))
        perc_loss_log.reset_states()
        disc_loss_log.reset_states()

      if epoch % save_every == 0:
        self.srgan_checkpoint_manager.save()

    # Decide on the absolute step counter so this also works after resuming.
    if int(self.srgan_checkpoint.curr_epoch.numpy()) >= config.NUM_EPOCHS:
      os.makedirs(config.FINAL_WEIGHTS_DIR, exist_ok = True)
      self.gen_model.save_weights(os.path.join(config.FINAL_WEIGHTS_DIR, 'final_generator_weights.h5'))
      self.disc_model.save_weights(os.path.join(config.FINAL_WEIGHTS_DIR, 'final_discriminator_weights.h5'))




##############################################################################################
##############################################################################################


class Generator_Training:
  '''Pixel-wise (MSE) pre-training of the generator.

  A checkpoint is written to config.GEN_PRE_CHECKPOINT_DIR only when the
  validation PSNR improves, so the stored checkpoint is the best one seen so far.
  One "epoch" here is one optimizer step on one batch.
  '''

  def __init__(self):
    generator = Generator()
    # Build the network here if the constructor did not already do so.
    if not hasattr(generator, 'gen_model'):
      generator.gen_network()
    self.learning_rate = 0.0001
    self.mse_loss = MeanSquaredError()
    self.generator_optimizer = Adam(learning_rate = self.learning_rate)
    self.generator_model = generator.gen_model
    self.checkpoint_dir = config.GEN_PRE_CHECKPOINT_DIR
    self.checkpoint = tf.train.Checkpoint(curr_epoch = tf.Variable(0),
                                          psnr_value = tf.Variable(-1.0),
                                          optimizer = self.generator_optimizer,
                                          model = self.generator_model)
    self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint,
                                                         directory = self.checkpoint_dir,
                                                         max_to_keep = 5)

  def restore_recent_checkpoint(self):
    '''Restores the newest pre-training checkpoint, if any exists.'''
    if self.checkpoint_manager.latest_checkpoint:
      self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
      print('restored checkpoint successfully at epoch ' + str(self.checkpoint.curr_epoch.numpy()))

  def train_step(self, lr, hr):
    '''One MSE optimizer step; `lr`/`hr` are [0, 255] image batches. Returns the loss.'''
    lr = tf.cast(lr, tf.float32)
    hr = tf.cast(hr, tf.float32)

    with tf.GradientTape() as grad_tape:
      fake_hr_image = self.checkpoint.model(lr, training = True)
      loss = self.mse_loss(hr, fake_hr_image)

    gradients = grad_tape.gradient(loss, self.checkpoint.model.trainable_variables)
    self.checkpoint.optimizer.apply_gradients(zip(gradients, self.checkpoint.model.trainable_variables))
    return loss

  def get_fake_hr_images(self, lr):
    '''Runs the generator in inference mode and returns uint8 images clipped to [0, 255].'''
    fake_hr_image = self.checkpoint.model(lr, training = False)
    fake_hr_image = tf.clip_by_value(fake_hr_image, 0, 255)
    fake_hr_image = tf.round(fake_hr_image)
    return tf.cast(fake_hr_image, tf.uint8)

  def evaluate(self, valid_data, max_batches = None):
    '''Mean PSNR (max_val 255) of generated vs. real HR images.

    Args:
      valid_data: dataset yielding (hr, lr) batches. It must be finite unless
        `max_batches` is given.
      max_batches: optional cap on the number of batches evaluated.
    Raises:
      ValueError: if `valid_data` repeats forever without `max_batches`, or yields nothing.
    '''
    if max_batches is not None:
      valid_data = valid_data.take(max_batches)
    elif isinstance(valid_data, tf.data.Dataset) and \
        tf.data.experimental.cardinality(valid_data) == tf.data.experimental.INFINITE_CARDINALITY:
      raise ValueError('valid_data repeats forever; pass max_batches or build it without repeat()')

    psnr_values = []
    for hr, lr in valid_data:
      fake_hr_image = self.get_fake_hr_images(tf.cast(lr, tf.float32))
      psnr = tf.image.psnr(fake_hr_image, tf.cast(hr, tf.uint8), max_val = 255)
      # Flatten so batched (vector) and single-image (scalar) results concatenate alike.
      psnr_values.append(tf.reshape(psnr, [-1]))

    if not psnr_values:
      raise ValueError('valid_data produced no batches')
    return tf.reduce_mean(tf.concat(psnr_values, axis = 0))

  def train_generator(self, train_data, valid_data, evaluate_step = 1000, total_epochs = 100000, max_eval_batches = None):
    '''Pre-trains the generator until `total_epochs` steps have been taken.

    Args:
      train_data: dataset yielding (hr, lr) batches; each element is one step.
      valid_data: dataset yielding (hr, lr) batches for PSNR evaluation.
      evaluate_step: evaluate (and possibly checkpoint) every this many steps.
      total_epochs: total number of steps, counting steps from restored checkpoints.
      max_eval_batches: optional cap on batches used per evaluation.
    '''
    self.restore_recent_checkpoint()
    # Clamp at zero: Dataset.take(-1) would mean "take every element".
    remaining = max(0, total_epochs - int(self.checkpoint.curr_epoch.numpy()))

    for hr, lr in train_data.take(remaining):
      self.checkpoint.curr_epoch.assign_add(1)
      epoch = int(self.checkpoint.curr_epoch.numpy())
      loss = self.train_step(lr, hr)

      if epoch % 10 == 0:
        print('generator training at epoch ' + str(epoch) + ' and loss is ' + str(float(loss)))
      if epoch % evaluate_step == 0:
        psnr_value = self.evaluate(valid_data, max_eval_batches)

        if self.checkpoint.psnr_value < psnr_value:
          # assign() keeps the tracked checkpoint variable instead of replacing it.
          self.checkpoint.psnr_value.assign(psnr_value)
          self.checkpoint_manager.save()

    # curr_epoch ends at total_epochs after the last step, so test with >= after the loop.
    if int(self.checkpoint.curr_epoch.numpy()) >= total_epochs:
      os.makedirs(config.FINAL_WEIGHTS_DIR, exist_ok = True)
      self.checkpoint.model.save_weights(os.path.join(config.FINAL_WEIGHTS_DIR, 'generator_weights.h5'))


In [None]:
##############################################################################################
########################## DATA LOADER AND PREPROCESSOR ######################################
##############################################################################################

In [27]:
import glob

class DatasetLoader_Preprocessing:
  '''Builds paired (HR, LR) tf.data pipelines for DIV2K.

  An HR image `<id>.png` is paired with `<id>x<scale>w.png` in the matching LR
  folder. Elements of the final dataset are (hr, lr) uint8 crop batches.
  '''

  def __init__(self, hr_train, hr_valid, lr_train, lr_valid, scale = 4):
    self.hr_train_path = hr_train
    self.lr_train_path = lr_train
    self.hr_valid_path = hr_valid
    self.lr_valid_path = lr_valid
    self.scale = scale
    # Validation file lists are built now; training lists are built lazily by
    # get_final_dataset the first time 'train' is requested.
    self.get_lr_hr_names(self.hr_valid_path, self.lr_valid_path, 'valid')

  def get_only_image_nums(self, images_path):
    '''Returns the sorted, extension-less names of the .png files in `images_path`.'''
    return sorted(os.path.splitext(name)[0]
                  for name in os.listdir(images_path)
                  if name.endswith('.png'))

  def get_full_images_paths(self, images_path, image_ids):
    '''Returns a new list with `images_path` joined onto every id (the input list is not modified).'''
    return [os.path.join(images_path, image_id) for image_id in image_ids]

  def get_lr_image_names(self, hr_images_path, scale = 4):
    '''LR file names (`<id>x<scale>w.png`) for every HR .png in `hr_images_path`, sorted by id.'''
    return [image_id + 'x' + str(scale) + 'w.png' for image_id in self.get_only_image_nums(hr_images_path)]

  def get_lr_hr_names(self, hr_images_path, lr_images_path, dataset_type):
    '''Stores aligned HR/LR path lists as `<dataset_type>_hr_names` / `<dataset_type>_lr_names`.

    Both lists are derived from the same sorted id list, so index i of each
    refers to the same image.

    Raises:
      ValueError: if `dataset_type` is not 'train' or 'valid'.
    '''
    if dataset_type not in ('train', 'valid'):
      raise ValueError("dataset_type must be 'train' or 'valid', got " + repr(dataset_type))

    image_nums = self.get_only_image_nums(hr_images_path)
    hr_images_full_ids = self.get_full_images_paths(hr_images_path, [num + '.png' for num in image_nums])
    lr_images_full_ids = self.get_full_images_paths(
        lr_images_path, [num + 'x' + str(self.scale) + 'w.png' for num in image_nums])

    if dataset_type == 'train':
      self.train_hr_names = hr_images_full_ids
      self.train_lr_names = lr_images_full_ids
    else:
      self.valid_hr_names = hr_images_full_ids
      self.valid_lr_names = lr_images_full_ids

  def random_crop(self, hr, lr):
    '''Takes aligned random crops: HR_CROP_SIZE square from `hr`, HR_CROP_SIZE // LR_DOWNSCALE from `lr`.

    Assumes `lr` is `hr` downscaled by exactly config.LR_DOWNSCALE.
    NOTE: returns (lr_crop, hr_crop), the reverse of the argument order;
    get_final_dataset relies on random_flip swapping them back.
    '''
    lr_downscale = config.LR_DOWNSCALE
    hr_crop_size = config.HR_CROP_SIZE
    lr_crop_size = hr_crop_size // lr_downscale

    lr_image_shape = tf.shape(lr)[:2]  # (height, width)

    # Pick the origin on the LR grid, then scale it so both crops cover the same region.
    lr_w_start = tf.random.uniform((), minval = 0, maxval = lr_image_shape[1] - lr_crop_size + 1, dtype = tf.dtypes.int32)
    lr_h_start = tf.random.uniform((), minval = 0, maxval = lr_image_shape[0] - lr_crop_size + 1, dtype = tf.dtypes.int32)

    hr_w_start = lr_w_start * lr_downscale
    hr_h_start = lr_h_start * lr_downscale

    lr_crop = lr[lr_h_start : (lr_h_start + lr_crop_size), lr_w_start : (lr_w_start + lr_crop_size)]
    hr_crop = hr[hr_h_start : (hr_h_start + hr_crop_size), hr_w_start : (hr_w_start + hr_crop_size)]

    return lr_crop, hr_crop

  def random_flip(self, hr, lr):
    '''With probability 1/2 flips both images left-right.

    Always returns (second argument, first argument). Applied to random_crop's
    (lr_crop, hr_crop) output, this restores (hr, lr) order.
    '''
    num = tf.random.uniform((), minval = 0, maxval = 2, dtype = tf.dtypes.int32)

    return tf.cond(num == 0,
                   lambda : (tf.image.flip_left_right(lr),
                             tf.image.flip_left_right(hr)),
                   lambda : (lr, hr))

  def _read_png(self, image_path):
    '''Reads one PNG file and decodes it to a (height, width, 3) uint8 tensor.'''
    return tf.io.decode_png(tf.io.read_file(image_path), channels = 3)

  def get_dataset(self, image_ids):
    '''Dataset of decoded RGB images, one per path in `image_ids`, in the same order.'''
    dataset = tf.data.Dataset.from_tensor_slices(image_ids)
    # Read and decode in one parallel map (element order is preserved).
    return dataset.map(self._read_png, num_parallel_calls = AUTOTUNE)

  def get_final_dataset(self, batch_size = 16, dataset_type = 'train', repeat = None, shuffle = None):
    '''Batched dataset of randomly cropped and flipped (hr, lr) uint8 pairs.

    Args:
      batch_size: images per batch (a final partial batch is dropped).
      dataset_type: 'train' or 'valid'.
      repeat: repeat forever; defaults to True for 'train' and False for 'valid'
        (so validation loops terminate).
      shuffle: shuffle image pairs every pass; defaults to True for 'train' only.
    Raises:
      ValueError: for an unknown `dataset_type` or an empty image folder.
    '''
    if dataset_type == 'train':
      if not hasattr(self, 'train_hr_names'):
        self.get_lr_hr_names(self.hr_train_path, self.lr_train_path, 'train')
      hr_names, lr_names = self.train_hr_names, self.train_lr_names
    elif dataset_type == 'valid':
      hr_names, lr_names = self.valid_hr_names, self.valid_lr_names
    else:
      raise ValueError("dataset_type must be 'train' or 'valid', got " + repr(dataset_type))

    if not hr_names:
      raise ValueError('no .png images found for dataset_type ' + repr(dataset_type))

    is_train = dataset_type == 'train'
    repeat = is_train if repeat is None else repeat
    shuffle = is_train if shuffle is None else shuffle

    # Keep HR/LR paths together from the start so shuffling cannot break the pairing.
    dataset = tf.data.Dataset.from_tensor_slices((hr_names, lr_names))
    if shuffle:
      # Shuffle file-name pairs, not decoded images, to keep the buffer small.
      dataset = dataset.shuffle(buffer_size = len(hr_names))
    dataset = dataset.map(lambda hr_path, lr_path: (self._read_png(hr_path), self._read_png(lr_path)),
                          num_parallel_calls = AUTOTUNE)
    dataset = dataset.map(self.random_crop, num_parallel_calls = AUTOTUNE)  # -> (lr, hr)
    dataset = dataset.map(self.random_flip, num_parallel_calls = AUTOTUNE)  # -> (hr, lr)
    dataset = dataset.batch(batch_size, drop_remainder = True)
    if repeat:
      dataset = dataset.repeat()
    return dataset.prefetch(buffer_size = AUTOTUNE)


# Build the loader (this lists the HR validation folder) and a batched validation pipeline.
data_loader = DatasetLoader_Preprocessing(
  hr_train = config.HR_TRAIN_PATH,
  hr_valid = config.HR_VALID_PATH,
  lr_train = config.LR_TRAIN_PATH,
  lr_valid = config.LR_VALID_PATH,
  scale = 4)
dataset = data_loader.get_final_dataset(batch_size = 16, dataset_type = 'valid')

In [13]:
import glob
import cv2

# Per-channel image mean in RGB order (presumably the DIV2K mean — verify), scaled to [0, 255].
DIV_MEAN = np.array([0.4488, 0.4371, 0.4040]) * 255

# Quick range check of mean-subtracted, /127.5-scaled pixels on one LR training image.
image_id = sorted(glob.glob('/content/drive/My Drive/temp_data/lr_train/*'))
if not image_id:
  raise FileNotFoundError('no files matched /content/drive/My Drive/temp_data/lr_train/*')
img = cv2.imread(image_id[0])
if img is None:
  raise IOError('cv2 could not read ' + image_id[0])
# cv2 loads channels as BGR; convert so they line up with the RGB-ordered DIV_MEAN.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = (img - DIV_MEAN) / 127.5
print(np.min(img), np.max(img))

-0.8976 1.192
