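Reference: Ledig et al., "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (SRGAN) — https://arxiv.org/pdf/1609.04802.pdf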

In [1]:

from google.colab import drive
drive.mount('/content/drive')
"""
Change directory to where this file is located
"""
%cd /content/drive/MyDrive/

MessageError: ignored

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import requests
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import tqdm
import os
import shutil
import re
import cv2
import time
import logging
from PIL import Image

In [None]:
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints records through ``tqdm.tqdm.write``.

    Writing through tqdm keeps log lines from breaking an active progress bar.
    """

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and print it with ``tqdm.tqdm.write``.

        Ordinary exceptions raised while formatting or writing are passed to
        ``handleError``. Anything that is not an ``Exception`` subclass
        (``KeyboardInterrupt``, ``SystemExit``, ``GeneratorExit`` ...) is left
        to propagate instead of being swallowed.
        """
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)


log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# Re-running this cell redefines the class, so stale handlers are matched by
# class name and removed first; otherwise every message would be printed once
# per execution of the cell.
for _stale_handler in [h for h in log.handlers if type(h).__name__ == 'TqdmLoggingHandler']:
    log.removeHandler(_stale_handler)
log.addHandler(TqdmLoggingHandler())

In [None]:

# Load the `aflw2k3d` dataset through TensorFlow Datasets (the captured output
# below shows it being downloaded on first use).
# NOTE(review): the next cell repeats this exact call.
data=tfds.load('aflw2k3d')

[1mDownloading and preparing dataset aflw2k3d/1.0.0 (download: 83.36 MiB, generated: Unknown size, total: 83.36 MiB) to /root/tensorflow_datasets/aflw2k3d/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

KeyboardInterrupt: ignored

In [None]:
# Number of leading examples of the 'train' split held out for testing.
N_TEST=600

data=tfds.load('aflw2k3d')
# The first N_TEST examples form the hold-out set; everything after them is
# used for training, so the two subsets do not overlap.
test_data=data['train'].take(N_TEST)
train_data=data['train'].skip(N_TEST)
# The former `tqdm.tqdm(train_data)` line was dropped: it only constructed a
# progress bar that was never iterated.

In [None]:
@tf.function
def build_data(data):
  """Turn one dataset example into a (low-res, high-res) training pair.

  Args:
    data: example dict with an 'image' entry
      (assumes pixel values in 0-255 — TODO confirm against the dataset).

  Returns:
    (lr, hr): `hr` is the image resized to 512x512 and divided by 255 (float32),
    `lr` is `hr` resized down to 64x64.
  """
  resized=tf.image.resize(data['image'],(512,512))
  # data = tf.keras.layers.GaussianNoise()(data)
  # Fix: the normalised tensor is now the one returned as `hr`. Previously the
  # un-normalised 0-255 image was returned while `lr` was in [0, 1], which is
  # inconsistent with PSNR's peak value of 1 and the generator's tanh output.
  hr=tf.dtypes.cast(resized / 255,tf.float32)

  lr=tf.image.resize(hr,(64,64))
  #lr=tf.image.resize(lr,(128,128),method=tf.image.ResizeMethod.BICUBIC)
  return lr, hr

In [None]:
# Display one raw training image without axes.
for sample in train_data.take(1):
  fig, ax = plt.subplots()
  ax.imshow(sample['image'])
  ax.axis('off')
  plt.show()


In [None]:
def bicubic_interpolate(image,shape):
  """Resize `image` to `shape` with OpenCV's bicubic interpolation.

  `shape` is forwarded unchanged as the size argument of `cv2.resize`.
  """
  return cv2.resize(image, shape, interpolation=cv2.INTER_CUBIC)

In [None]:
# Display the low-resolution input produced by build_data for five examples.
for sample in train_data.take(5):
  lr, hr = build_data(sample)

  fig, ax = plt.subplots()
  ax.imshow(lr)
  ax.axis('off')
  plt.show()


In [None]:
#Generator 
def residual_block_gen(ch=64,k_s=3,st=1):
  """Build the conv + LeakyReLU unit used inside the generator.

  Args:
    ch: number of convolution filters.
    k_s: kernel size.
    st: stride applied along both spatial axes.

  Returns:
    A `tf.keras.Sequential` holding the convolution followed by LeakyReLU.
    The skip connection itself is added by the caller.
  """
  conv = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same')
  activation = tf.keras.layers.LeakyReLU()
  return tf.keras.Sequential([conv, activation])

def Upsample_block(x, ch=256, k_s=3, st=1):
  """Apply conv -> depth_to_space(2) -> LeakyReLU to `x` and return the result.

  Args:
    x: input feature map.
    ch: number of convolution filters applied before the pixel shuffle.
    k_s: kernel size.
    st: stride applied along both spatial axes.
  """
  features = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same')(x)
  # Sub-pixel "pixel shuffle": channels are traded for spatial resolution.
  shuffled = tf.nn.depth_to_space(features, 2)
  return tf.keras.layers.LeakyReLU()(shuffled)

# Generator graph: 9x9 stem conv, four conv units with dense skip connections,
# a 9x9 conv with a long skip back to the stem, three x2 sub-pixel upsampling
# stages (x8 overall) and a 9x9 tanh output conv.
input_lr=tf.keras.layers.Input(shape=(None,None,3))
input_conv=tf.keras.layers.Conv2D(64,9,padding='same')(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(input_conv)

# Each unit's output is summed with the stem and every earlier summed output.
block_1=residual_block_gen()(input_conv)
block_1=tf.keras.layers.Add()([block_1, input_conv])

block_2=residual_block_gen()(block_1)
block_2=tf.keras.layers.Add()([block_1, input_conv, block_2])

block_3=residual_block_gen()(block_2)
block_3=tf.keras.layers.Add()([block_1, input_conv, block_2, block_3])

block_4=residual_block_gen()(block_3)
# Fix: block_3 was missing from this sum, unlike the pattern used by every
# previous Add. Add has no weights, so the parameter count is unchanged.
block_4=tf.keras.layers.Add()([block_1, input_conv, block_2, block_3, block_4])

SRRes=tf.keras.layers.Conv2D(64,9,padding='same')(block_4)
SRRes=tf.keras.layers.Add()([SRRes,input_conv])

# Three pixel-shuffle stages, each doubling the spatial resolution.
SRRes=Upsample_block(SRRes)
SRRes=Upsample_block(SRRes)
SRRes=Upsample_block(SRRes)
output_sr=tf.keras.layers.Conv2D(3,9,activation='tanh',padding='same')(SRRes)

SRResnet=tf.keras.models.Model(input_lr,output_sr)

In [None]:
# Render the generator's layer graph, annotated with tensor shapes.
tf.keras.utils.plot_model(SRResnet,show_shapes=True)

In [None]:
def residual_block_disc(ch=64,k_s=3,st=1):
  """Build the conv + batch-norm + LeakyReLU unit used by the discriminator.

  Args:
    ch: number of convolution filters.
    k_s: kernel size.
    st: stride applied along both spatial axes.

  Returns:
    A `tf.keras.Sequential` with the three layers in that order.
  """
  conv = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same')
  norm = tf.keras.layers.BatchNormalization()
  activation = tf.keras.layers.LeakyReLU()
  return tf.keras.Sequential([conv, norm, activation])

# Discriminator graph. Despite its name, `input_lr` here is the 512x512 image
# being judged (real high-resolution or generated).
input_lr=tf.keras.layers.Input(shape=(512,512,3))
stem=tf.keras.layers.Conv2D(64,3,padding='same')(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(stem)

# Filters and stride of each of the seven conv units, listed in order.
channel_nums=[64,128,128,256,256,512,512]
stride_sizes=[2,1,2,1,2,1,2]

disc=input_conv
for n_channels, stride in zip(channel_nums, stride_sizes):
  disc=residual_block_disc(ch=n_channels,st=stride)(disc)

# Classification head: flatten -> Dense(1024) + LeakyReLU -> sigmoid score.
disc=tf.keras.layers.Flatten()(disc)
disc=tf.keras.layers.Dense(1024)(disc)
disc=tf.keras.layers.LeakyReLU()(disc)
disc_output=tf.keras.layers.Dense(1,activation='sigmoid')(disc)

discriminator=tf.keras.models.Model(input_lr,disc_output)

In [None]:
# Render the discriminator's layer graph, annotated with tensor shapes.
tf.keras.utils.plot_model(discriminator,show_shapes=True)

In [None]:
def PSNR(y_true,y_pred,max_val=1.0):
  """Peak signal-to-noise ratio in dB, computed from the MSE over all elements.

  Args:
    y_true: reference tensor.
    y_pred: tensor compared against the reference.
    max_val: peak signal value of the data; the default of 1.0 reproduces the
      previous hard-coded behaviour (inputs scaled to a unit range).

  Returns:
    20 * log10(max_val / sqrt(mse)). The mean is taken over every element, so a
    batch yields one value rather than one per image. Identical inputs give a
    zero MSE and therefore an infinite result.
  """
  mse=tf.reduce_mean( (y_true - y_pred) ** 2 )
  return 20 * log10(max_val / (mse ** 0.5))

def log10(x):
  """Base-10 logarithm of `x`, computed as ln(x) / ln(10)."""
  natural_log = tf.math.log(x)
  # ln(10) is built with the same dtype as ln(x) so the division type-checks.
  ln_ten = tf.math.log(tf.constant(10, dtype=natural_log.dtype))
  return natural_log / ln_ten

def pixel_MSE(y_true,y_pred):
  """Mean squared error between `y_true` and `y_pred` over all elements."""
  squared_error = (y_true - y_pred) ** 2
  return tf.reduce_mean(squared_error)

In [None]:
# Load ImageNet-pretrained VGG19 without its classifier head and list its
# layers (the last expression is displayed as the cell output).
# NOTE(review): the next cell builds the same model again.
VGG19=tf.keras.applications.VGG19(weights='imagenet',include_top=False,input_shape=(128,128,3))
VGG19.layers

In [None]:
# ImageNet-pretrained VGG19 feature extractor (no classifier head); the loss
# functions below call its layers one at a time.
VGG19=tf.keras.applications.VGG19(weights='imagenet',include_top=False,input_shape=(128,128,3))

# NOTE(review): VGG_i / VGG_j are not referenced by any cell visible here.
VGG_i,VGG_j=2,2
def VGG_loss(y_hr,y_sr,i_m=2,j_m=2):
  """Accumulated feature-space MSE between `y_hr` and `y_sr` over VGG19 layers.

  Both tensors are pushed through `VGG19.layers` in order. After every Conv2D
  layer the mean squared difference of the two feature maps, scaled by 0.006,
  is added to the running total.

  Args:
    y_hr: reference (high-resolution) batch.
    y_sr: generated (super-resolved) batch.
    i_m: number of MaxPooling2D layers seen at the stopping point.
    j_m: number of Conv2D layers seen since the last pooling layer at the
      stopping point.

  Returns:
    The accumulated loss. The walk stops *before* applying the layer at which
    both counters equal (i_m, j_m), so that layer is not included.

  NOTE(review): the pooling counter starts at 0, so these indices are not the
  1-based block numbers of the SRGAN paper's VGG(i, j) notation. If the model
  ends with a pooling layer, (5, 4) can never match and every layer is used.
  Confirm this is intended.
  """
  pools_seen = 0
  convs_in_block = 0
  total = 0.0
  feat_hr, feat_sr = y_hr, y_sr
  for layer in VGG19.layers:
    kind = type(layer).__name__
    is_conv = kind == 'Conv2D'
    if is_conv:
      convs_in_block += 1
    elif kind == 'MaxPooling2D':
      pools_seen += 1
      convs_in_block = 0
    if (pools_seen, convs_in_block) == (i_m, j_m):
      break

    feat_hr = layer(feat_hr)
    feat_sr = layer(feat_sr)
    if is_conv:
      total += tf.reduce_mean((feat_hr - feat_sr) ** 2) * 0.006

  return total
def VGG_loss_intuitive(y_true,y_pred):
  """Sum of scaled feature MSEs after *every* layer in `VGG19.layers`.

  Unlike `VGG_loss`, no layer type is filtered out and there is no early stop:
  each layer's outputs contribute mean((a - b)^2) * 0.006 to the total.
  """
  total = 0.0
  feat_true, feat_pred = y_true, y_pred
  for layer in VGG19.layers:
    feat_true, feat_pred = layer(feat_true), layer(feat_pred)
    gap = feat_true - feat_pred
    total += tf.reduce_mean(gap ** 2) * 0.006
  return total

In [None]:
# One plain-SGD optimiser per network.
generator_optimizer=tf.keras.optimizers.SGD(0.0001)
discriminator_optimizer=tf.keras.optimizers.SGD(0.0001)

# Weight of the adversarial term added to the generator loss.
adv_ratio=0.001
# Metrics that train_step reports alongside the losses.
evaluate=['PSNR']

# Available training configurations: name -> (content loss, adversarial on/off).
# Pick one through ACTIVE_LOSS_CONFIG instead of (un)commenting code.
LOSS_CONFIGS={
  'MSE': (pixel_MSE, False),
  'VGG2.2': (lambda y_hr,y_sr:VGG_loss(y_hr,y_sr,i_m=2,j_m=2), False),
  'VGG5.4': (lambda y_hr,y_sr:VGG_loss(y_hr,y_sr,i_m=5,j_m=4), False),
  'SRGAN-MSE': (pixel_MSE, True),
  'SRGAN-VGG2.2': (lambda y_hr,y_sr:VGG_loss(y_hr,y_sr,i_m=2,j_m=2), True),
  'SRGAN-VGG5.4': (lambda y_hr,y_sr:VGG_loss(y_hr,y_sr,i_m=5,j_m=4), True),
}

# Configuration actually used for training (same as the former "Real loss").
ACTIVE_LOSS_CONFIG='SRGAN-VGG5.4'
loss_func,adv_learning = LOSS_CONFIGS[ACTIVE_LOSS_CONFIG]

In [None]:
# Binary cross-entropy shared by the generator and discriminator adversarial losses.
cross_entropy = tf.keras.losses.BinaryCrossentropy()
def discriminator_loss(real_output, fake_output):
    """Discriminator loss: BCE(real -> 1) + BCE(fake -> 0).

    Args:
        real_output: discriminator scores for real images.
        fake_output: discriminator scores for generated images.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return cross_entropy(real_targets, real_output) + cross_entropy(fake_targets, fake_output)
def generator_loss(fake_output):
    """Adversarial generator loss: BCE of the fake scores against all-ones targets."""
    wanted_scores = tf.ones_like(fake_output)
    return cross_entropy(wanted_scores, fake_output)

@tf.function()
def train_step(data,loss_func=pixel_MSE,adv_learning=True,evaluate=('PSNR',),adv_ratio=0.001):
  """Run one optimisation step on a (low-res, high-res) batch.

  Args:
    data: pair `(low_resolution, high_resolution)`.
    loss_func: content loss called as `loss_func(high_resolution, super_resolution)`.
    adv_learning: when true, the adversarial term is added to the generator
      loss and the discriminator is updated as well.
    evaluate: iterable of metric names to report; only 'PSNR' is recognised and
      other names are ignored. The default is an immutable tuple, not a list.
    adv_ratio: weight of the adversarial term in the generator loss.

  Returns:
    Dict of scalars: 'reconstruction' (content loss), 'adv_g' and 'adv_d' when
    `adv_learning` is on, plus one entry per recognised metric.
  """
  logs={}
  gen_loss,disc_loss=0,0

  low_resolution,high_resolution=data
  # Separate tapes so each network's gradients come from its own loss.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    super_resolution = SRResnet(low_resolution, training=True)
    gen_loss=loss_func(high_resolution,super_resolution)

    logs['reconstruction']=gen_loss

    if adv_learning:
      real_output = discriminator(high_resolution, training=True)
      fake_output = discriminator(super_resolution, training=True)

      adv_loss_g = generator_loss(fake_output) * adv_ratio
      gen_loss += adv_loss_g

      disc_loss = discriminator_loss(real_output, fake_output)

      logs['adv_g']=adv_loss_g
      logs['adv_d']=disc_loss

  gradients_of_generator = gen_tape.gradient(gen_loss, SRResnet.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, SRResnet.trainable_variables))

  if adv_learning:
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
  for x in evaluate:
    if x=='PSNR':
      logs[x]=PSNR(high_resolution,super_resolution)

  return logs

In [None]:
# Checkpoint files written by the save cell (file names kept exactly as saved).
GENERATOR_CKPT='SRResNet-generato.h5'
DISCRIMINATOR_CKPT='SRResNet-discriminator.h5'

# Resume only when both checkpoints exist, so a fresh run trains from the
# freshly built models instead of failing on a missing file.
if os.path.exists(GENERATOR_CKPT) and os.path.exists(DISCRIMINATOR_CKPT):
  # compile=False: training uses a custom loop, so no compiled state is needed.
  SRResnet=tf.keras.models.load_model(GENERATOR_CKPT,compile=False)
  discriminator=tf.keras.models.load_model(DISCRIMINATOR_CKPT,compile=False)
else:
  print('No saved checkpoints found; continuing with the freshly initialised models.')

In [None]:
EPOCHS=50
BATCH_SIZE=32

# The input pipelines do not change between epochs, so they are built once
# here instead of on every pass of the epoch loop.
train_dataset_mapped = train_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE)
# NOTE(review): the validation pipeline is built but not consumed by this loop.
val_dataset_mapped = test_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  progress = tqdm.tqdm(train_dataset_mapped, position=0, leave=True, desc=f'epoch {epoch + 1}/{EPOCHS}')
  for image_batch in progress:
    logs=train_step(image_batch,loss_func,adv_learning,evaluate,adv_ratio)
    # One line per step with every value returned by train_step.
    for k, v in logs.items():
      print(k,':',v,end='  ')
    print()

In [None]:
# Save both networks to HDF5 files in the current working directory.
# The generator file name is spelled 'generato' here and in the load cell;
# keep the two in sync if it is ever renamed.
SRResnet.save('SRResNet-generato.h5')
discriminator.save('SRResNet-discriminator.h5')

In [None]:
# Qualitative check on eight training examples: low-res input, bicubic
# upscaling, ground truth and the generator's output.
train_dataset_mapped = train_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE)
for lr_img, hr_img in train_dataset_mapped.take(8):
  lr_np = lr_img.numpy()

  # Fix: add the batch axis with np.newaxis. The former hard-coded
  # reshape(1,128,128,3) fails on the 64x64x3 low-res image from build_data.
  pred = SRResnet(lr_np[np.newaxis, ...])

  panels = [
    ('Low resolution', lr_np),
    ('Bicubic', bicubic_interpolate(lr_np,(512,512))),
    ('High resolution', hr_img.numpy()),
    # The tanh output can fall below 0, so clip it to imshow's float range.
    ('Super resolution', np.clip(pred[0].numpy(), 0.0, 1.0)),
  ]
  for title, picture in panels:
    plt.figure()
    plt.imshow(picture)
    plt.title(title)
    plt.show()

In [None]:
# NOTE(review): `downscale_image` and `model` are not defined in any cell
# visible here. This cell presumably depends on state from another cell or
# notebook; confirm before relying on Restart & Run All.
# NOTE(review): `start` is recorded but no elapsed time is computed in this
# cell, and `train_dataset_mapped` is assigned from test_data but not used.
start = time.time()
train_dataset_mapped = test_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE)
for x in train_data.take(10):
  im=x['image']
  lr_image=downscale_image(x['image'])
  # Run `model` on the downscaled image, cast to uint8 and drop size-1 axes.
  esr=tf.squeeze(tf.cast(model(lr_image), tf.uint8))
  

In [None]:
# Side-by-side comparison for ten examples. Panels, left to right: original
# image, SRResnet output, bicubic resize, output of `model`.
# NOTE(review): `downscale_image` and `model` are not defined in any cell
# visible here; confirm where they come from.
train_dataset_mapped = test_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE)
for x in train_data.take(10):
  imgg=x['image']
  lr_image=downscale_image(imgg)
  im=x['image'].numpy()

  lr=tf.squeeze(tf.cast(model(lr_image), tf.uint8))
  pred=SRResnet(np.array([x['image']])/255)
  bic=bicubic_interpolate(im,(im.shape[1]//4,im.shape[0]//4))

  fig, axes = plt.subplots(1, 4, figsize=(20,5))
  for ax, picture in zip(axes, (im, pred[0].numpy(), bic, lr)):
    ax.imshow(picture)
    ax.axis('off')

  plt.show()