<a href="https://colab.research.google.com/github/aju22/Enhanced-SRGAN/blob/main/Enhanced_Super_Resolution_GAN_(ESRGAN).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from keras import layers
from tensorflow import keras

# Model Architecture

![](https://esrgan.readthedocs.io/en/latest/_images/architecture.png)

In [None]:
class ConvBlock(layers.Layer):
  """Conv2D followed by an optional LeakyReLU.

  Args:
    out_c: number of output filters of the convolution.
    activation: when truthy, apply LeakyReLU(alpha=0.3); otherwise apply an
      identity ('linear') activation.
    **kwargs: forwarded unchanged to ``layers.Conv2D`` (e.g. kernel_size,
      strides, padding).
  """

  def __init__(self, out_c, activation=True, **kwargs):
    super().__init__()
    self.cnn = layers.Conv2D(out_c, **kwargs)
    if activation:
      self.activation = layers.LeakyReLU(alpha=0.3)
    else:
      self.activation = layers.Activation('linear')

  def call(self, x):
    features = self.cnn(x)
    return self.activation(features)

In [None]:
class UpsampleBlock(layers.Layer):
  """Spatial upsampling followed by a 3x3 'same' conv and LeakyReLU(0.3).

  Args:
    in_c: number of filters of the 3x3 convolution.
    scale_factor: upsampling factor handed to ``layers.UpSampling2D``.
  """

  def __init__(self, in_c, scale_factor=2):
    super().__init__()
    self.upsample = layers.UpSampling2D(size=scale_factor)
    self.conv = layers.Conv2D(in_c, kernel_size=3, strides=1, padding='same')
    self.activation = layers.LeakyReLU(alpha=0.3)

  def call(self, x):
    upsampled = self.upsample(x)
    return self.activation(self.conv(upsampled))

In [None]:
class DenseResidualBlock(layers.Layer):
  """Residual dense block built from five densely connected 3x3 ConvBlocks.

  Each of the first four ConvBlocks (``channels`` filters, with activation)
  receives the channel-wise concatenation of the block input and every
  earlier ConvBlock output. The fifth ConvBlock (``in_c`` filters, no
  activation) maps the accumulated features back to ``in_c`` channels. Its
  output is scaled by ``res_beta`` and added to the block input.

  Args:
    in_c: channel count of the block input, and therefore of its output.
    channels: filters produced by each of the first four ConvBlocks.
    res_beta: scaling applied to the residual branch before the skip add.
  """

  def __init__(self, in_c, channels=32, res_beta=0.2):
    super().__init__()
    self.res_beta = res_beta
    self.blocks = [
        ConvBlock(
            in_c if idx == 4 else channels,
            activation=idx < 4,
            kernel_size=3,
            strides=1,
            padding='same',
        )
        for idx in range(5)
    ]

  def call(self, x):
    *dense_blocks, fusion = self.blocks
    features = x
    for block in dense_blocks:
      # Dense connectivity: append each new output along the channel axis.
      features = tf.concat([features, block(features)], axis=3)
    return self.res_beta * fusion(features) + x

In [None]:
class RRDB(layers.Layer):
  """Residual-in-residual dense block.

  Three DenseResidualBlocks run in sequence. Their output is scaled by
  ``res_beta`` and added back to the block input.

  Args:
    in_c: channel count of the input, and of the output.
    res_beta: scaling applied to the residual branch before the skip add.
  """

  def __init__(self, in_c, res_beta=0.2):
    super().__init__()
    self.res_beta = res_beta
    dense_blocks = [DenseResidualBlock(in_c) for _ in range(3)]
    self.rrdb = keras.Sequential(dense_blocks)

  def call(self, x):
    residual = self.rrdb(x)
    return x + self.res_beta * residual

In [None]:
class Generator(tf.keras.Model):
  """ESRGAN-style generator: RRDB trunk, two upsampling stages, output head.

  With the default UpsampleBlock factor of 2, the two stages enlarge the
  spatial size 4x (e.g. 128 -> 512).

  Args:
    in_c: channel count of the produced image (the last conv's filters).
    num_c: feature channels used throughout the trunk.
    num_blocks: number of RRDB blocks in the trunk.
  """

  def __init__(self, in_c=3, num_c=32, num_blocks=23):
    super().__init__()
    conv3x3 = dict(kernel_size=3, strides=1, padding='same')

    self.conv1 = layers.Conv2D(num_c, **conv3x3)
    self.residuals = keras.Sequential([RRDB(num_c) for _ in range(num_blocks)])
    self.conv2 = layers.Conv2D(num_c, **conv3x3)
    self.upsample = keras.Sequential([UpsampleBlock(num_c) for _ in range(2)])
    self.out = keras.Sequential([
        layers.Conv2D(num_c, **conv3x3),
        layers.LeakyReLU(0.2),
        layers.Conv2D(in_c, **conv3x3),
    ])

  def call(self, x):
    shallow = self.conv1(x)
    # Long skip connection around the whole RRDB trunk.
    trunk = self.conv2(self.residuals(shallow))
    upscaled = self.upsample(trunk + shallow)
    return self.out(upscaled)

In [None]:
class Discriminator(tf.keras.Model):
  """Convolutional critic that returns one unbounded score per image.

  Args:
    in_c: number of input channels (not used; the first convolution infers
      its input channels from the data).
    features: filters of each ConvBlock. Blocks at odd indices use stride 2,
      so every pair halves the spatial size. A tuple default avoids a shared
      mutable default; any iterable of ints is accepted.
  """

  def __init__(self, in_c=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
    super().__init__()
    self.blocks = keras.Sequential([
        ConvBlock(
            feature,
            activation=True,
            kernel_size=3,
            strides=1 + idx % 2,
            padding='same',
        )
        for idx, feature in enumerate(features)
    ])

    self.classifier = keras.Sequential([
        # NOTE(review): a fixed 6x6 pool (not an adaptive pool) makes the
        # flattened length, and hence the first Dense kernel shape, depend on
        # the input resolution. The model presumably only accepts the image
        # size it was first built with -- confirm.
        layers.AveragePooling2D(pool_size=(6, 6)),
        layers.Flatten(),
        layers.Dense(1024),
        layers.LeakyReLU(0.2),
        layers.Dense(1),  # no activation: raw score
    ])

  def call(self, x):
    return self.classifier(self.blocks(x))

In [None]:
# Instantiate both networks with their default hyper-parameters.
generator, discriminator = Generator(), Discriminator()

In [None]:
# Dummy batch of 5 random res x res, 3-channel tensors used for shape checks.
res = 128
x = tf.random.uniform(shape=(5, res, res, 3))

In [None]:
# Forward the dummy batch through the generator (previously misspelled
# `genarator`, which raised NameError). Expect a 4x larger spatial size.
gen_out = generator(x)
gen_out.shape

TensorShape([5, 512, 512, 3])

In [None]:
# Score the generated batch with the discriminator; one output per image.
disc_out = discriminator(gen_out)
disc_out.shape

TensorShape([5, 1])

# Training and Loss

In [None]:
class VGGLoss(layers.Layer):
  """Perceptual loss: MSE between VGG19 feature maps of two image batches.

  Features come from the ``block5_conv4`` layer of an ImageNet-weighted VGG19
  with its classifier head removed. All VGG weights are frozen.

  Args:
    input_shape: image shape for the VGG input, either (height, width) or
      (height, width, channels). A 2-tuple gets 3 channels appended. Pass
      (None, None, 3) to accept any spatial size.
  """

  def __init__(self, input_shape=(224, 224)):
    # Keras layers must be initialised before sublayers are assigned as
    # attributes, otherwise attribute tracking fails.
    super().__init__()

    input_shape = tuple(input_shape)
    if len(input_shape) == 2:
      input_shape = input_shape + (3,)

    # include_top=False drops the flatten/dense head, leaving a fully
    # convolutional feature extractor.
    vgg = tf.keras.applications.VGG19(
        include_top=False, weights='imagenet', input_shape=input_shape)
    vgg.trainable = False
    # NOTE(review): Keras Conv2D layers fuse their ReLU, so these are
    # post-activation features. Inputs are also not passed through
    # `vgg19.preprocess_input`; confirm the expected image value range.
    self.vgg = tf.keras.Model(vgg.input, vgg.get_layer('block5_conv4').output)
    self.vgg.trainable = False

    self.loss = keras.losses.MeanSquaredError()

  def call(self, inputs, targets):
    """Return the scalar MSE between VGG features of ``inputs`` and ``targets``."""
    vgg_input = self.vgg(inputs)
    vgg_target = self.vgg(targets)

    return self.loss(vgg_target, vgg_input)

In [None]:
# Both networks use Adam with learning rate 2e-4 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
# Build the perceptual-loss network once, outside the traced function.
# Constructing it inside train_step would reload VGG19 and create fresh
# variables on every call, which tf.function does not allow after the first
# trace. (None, None, 3) lets it accept the generator's upscaled images.
vgg_loss_fn = VGGLoss(input_shape=(None, None, 3))


@tf.function
def train_step(low_res, high_res, step):
  """Run one joint optimisation step of the generator and discriminator.

  Args:
    low_res: batch of low-resolution generator inputs.
    high_res: matching batch of ground-truth high-resolution images, expected
      to have the same shape as ``generator(low_res)``.
    step: training step counter (currently unused; kept for callers).

  Returns:
    Dict of scalar loss tensors with keys 'gen_loss', 'disc_loss', 'l1_loss',
    'vgg_loss' and 'adversarial_loss'.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

    gen_output = generator(low_res, training=True)

    disc_real = discriminator(high_res, training=True)
    disc_fake = discriminator(gen_output, training=True)

    # Critic objective (WGAN-style): raise scores on real images and lower
    # them on generated ones. The fake term must be added, not subtracted;
    # subtracting it would push fake scores up just like the generator does.
    # NOTE(review): no gradient penalty or weight clipping is applied.
    disc_loss = tf.math.reduce_mean(disc_fake) - tf.math.reduce_mean(disc_real)

    l1_loss = 1e-2 * tf.keras.losses.MeanAbsoluteError()(high_res, gen_output)
    # Reuse the fake scores rather than running the discriminator again.
    adversarial_loss = 5e-3 * (-tf.math.reduce_mean(disc_fake))
    vgg_loss = vgg_loss_fn(gen_output, high_res)

    gen_loss = l1_loss + vgg_loss + adversarial_loss

  generator_gradients = gen_tape.gradient(gen_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  return {
      'gen_loss': gen_loss,
      'disc_loss': disc_loss,
      'l1_loss': l1_loss,
      'vgg_loss': vgg_loss,
      'adversarial_loss': adversarial_loss,
  }