# SRGAN (Super-Resolution GAN)

This is an attempt to re-implement the SRGAN paper, "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".

Paper: https://arxiv.org/pdf/1609.04802.pdf

Other Resources: 
* https://www.youtube.com/watch?v=fx-rXMcKlQc
* https://arxiv.org/pdf/1609.05158.pdf (sub-pixel convolution / pixel-shuffle upsampling paper)

In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Side length, in pixels, of the square low-resolution images.
lr_img_size = 64
# Side length of the high-resolution images: four times the low-resolution side.
hr_img_size = 4 * lr_img_size
# Number of channels per image.
channels = 3

In [3]:
class ResidualBlock(tf.keras.layers.Layer):
    """Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm branch added back onto its input.

    Args:
        filters: Number of filters used by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        strides: Strides used by both convolutions.
        padding: Padding mode used by both convolutions.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.

    Note:
        ``call`` returns ``branch(inputs) + inputs``, so the branch output must be
        shape-compatible with ``inputs`` for the addition to succeed (e.g. the
        default strides/padding applied to an input that has ``filters`` channels).
    """

    def __init__(self, filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same', **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so that get_config() can report them.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        self.conv1 = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides,
                                            padding = padding)
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.prelu = tf.keras.layers.PReLU()

        self.conv2 = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides,
                                            padding = padding)
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training = None):
        """Apply the residual branch to `inputs` and add `inputs` back.

        Args:
            inputs: Input tensor of the block.
            training: Forwarded to both BatchNormalization layers so that the
                training/inference mode is passed explicitly.

        Returns:
            The element-wise sum of the branch output and `inputs`.
        """
        x = self.conv1(inputs)
        x = self.batch_norm1(x, training = training)
        x = self.prelu(x)

        x = self.conv2(x)
        x = self.batch_norm2(x, training = training)
        return tf.add(x, inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
        })
        return config

In [4]:
class PixelShufflerx2(tf.keras.layers.Layer):
    """Trade channels for resolution: ``(batch, H, W, C)`` -> ``(batch, 2*H, 2*W, C // 4)``.

    Every group of four consecutive input channels ``4*c .. 4*c + 3`` becomes one
    2x2 spatial patch of output channel ``c``:
    ``out[b, 2*h + i, 2*w + j, c] == inputs[b, h, w, 4*c + 2*i + j]``.

    The height, width and channel sizes are read from the static shape of the
    input, so they must be known when the layer is called.
    """

    def call(self, inputs):
        """Rearrange `inputs` so that height and width double and channels shrink by 4.

        Args:
            inputs: Rank-4 tensor laid out as (batch, height, width, channels).

        Returns:
            Tensor of shape (batch, 2 * height, 2 * width, channels // 4).

        Raises:
            ValueError: If the static channel count is unknown or not a multiple of 4.
        """
        inp_shape = inputs.shape
        height, width, depth = inp_shape[1], inp_shape[2], inp_shape[3]

        # Without this check a channel count that is not a multiple of 4 could be
        # silently absorbed by the -1 (batch) dimension of the reshapes below,
        # mixing values of different samples instead of failing.
        if depth is None or depth % 4 != 0:
            raise ValueError(
                f'PixelShufflerx2 needs a known channel count divisible by 4, got {depth}.')

        out_depth = depth // 4
        # Split the channel axis into (out_depth, 2, 2): output channel, row offset, column offset.
        out = tf.reshape(inputs, (-1, height, width, out_depth, 2, 2))
        # Move each offset next to the spatial axis it expands: (batch, H, 2, W, 2, out_depth).
        out = tf.transpose(out, perm = [0, 1, 4, 2, 5, 3])
        # Merge (H, 2) -> 2*H and (W, 2) -> 2*W.
        out = tf.reshape(out, (-1, height*2, width*2, out_depth))
        return out

In [5]:
def generator(img_size = lr_img_size, channels = channels, num_res_block = 16):
    """Build the generator model that upscales a low-resolution image by a factor of 4.

    Architecture: a 9x9 convolution with PReLU, `num_res_block` ResidualBlock layers,
    a 3x3 convolution with batch normalisation added back to the stem output, two
    (Conv2D(256) -> PixelShufflerx2 -> PReLU) stages that each double the
    resolution, and a final 9x9 convolution followed by tanh.

    Args:
        img_size: Side length of the square input images.
        channels: Number of channels of the input images; the output image has
            the same number of channels.
        num_res_block: Number of ResidualBlock layers in the trunk.

    Returns:
        A `tf.keras.models.Model` named 'Generator' mapping an
        (img_size, img_size, channels) input to a tanh-activated
        (4 * img_size, 4 * img_size, channels) output.
    """
    inp = tf.keras.layers.Input(shape = (img_size, img_size, channels), dtype = tf.float32,
                                name = f'Generator_input_{img_size}x{img_size}x{channels}')

    # Stem; its output is reused by the long skip connection after the residual trunk.
    x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (9, 9), strides = (1, 1), padding = 'same')(inp)
    x = tf.keras.layers.PReLU()(x)

    r = x
    for _ in range(num_res_block):
        r = ResidualBlock()(r)

    out = tf.keras.layers.Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(r)
    out = tf.keras.layers.BatchNormalization()(out)
    out = tf.keras.layers.add([out, x])

    # Two x2 upsampling stages: 256 channels shuffle down to 64 at twice the resolution.
    for _ in range(2):
        out = tf.keras.layers.Conv2D(filters = 256, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(out)
        out = PixelShufflerx2()(out)
        out = tf.keras.layers.PReLU()(out)

    # Emit as many channels as the input has (previously hard-coded to 3, which
    # ignored the `channels` argument for non-RGB inputs).
    out = tf.keras.layers.Conv2D(filters = channels, kernel_size = (9, 9), strides = (1, 1), padding = 'same')(out)
    out = tf.keras.layers.Activation('tanh')(out)
    return tf.keras.models.Model(inp, out, name = 'Generator')

In [6]:
# Module-level generator instance, built with the default arguments of generator().
gen = generator()
# gen.summary()  # uncomment to print the layer-by-layer summary

In [7]:
def discriminator(img_size = hr_img_size, channels = channels):
    """Build the discriminator model that scores an image with a single logit.

    Architecture: a 3x3 convolution with LeakyReLU(0.2), seven
    Conv2D -> BatchNormalization -> LeakyReLU(0.2) stages whose filter counts grow
    from 64 to 512 and whose strides alternate between 2 and 1, then
    Flatten -> Dense(1024) -> LeakyReLU(0.2) -> Dense(1).

    Args:
        img_size: Side length of the square input images.
        channels: Number of channels of the input images.

    Returns:
        A `tf.keras.models.Model` named 'Discriminator'. No sigmoid is applied,
        so the single output unit is a raw logit.
    """
    layers = tf.keras.layers

    def leaky_relu(tensor):
        # Every activation in this network is a LeakyReLU with slope 0.2.
        return layers.LeakyReLU(alpha = 0.2)(tensor)

    inp = layers.Input(shape = (img_size, img_size, channels), dtype = tf.float32,
                       name = f'discriminator_input_{img_size}x{img_size}x{channels}')

    # First convolution is the only one without batch normalisation.
    x = layers.Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(inp)
    x = leaky_relu(x)

    # (filters, stride) of each Conv -> BatchNorm -> LeakyReLU stage, in order.
    conv_plan = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2))
    for n_filters, stride in conv_plan:
        x = layers.Conv2D(filters = n_filters, kernel_size = (3, 3), strides = (stride, stride),
                          padding = 'same')(x)
        x = layers.BatchNormalization()(x)
        x = leaky_relu(x)

    # Classification head producing one unbounded score per image.
    x = layers.Flatten()(x)
    x = layers.Dense(units = 1024)(x)
    x = leaky_relu(x)
    logits = layers.Dense(units = 1)(x)

    return tf.keras.models.Model(inp, logits, name = 'Discriminator')

In [8]:
# Module-level discriminator instance, built with the default arguments of discriminator().
disc = discriminator()
# disc.summary()  # uncomment to print the layer-by-layer summary

In [10]:
# Binary cross-entropy that takes raw logits (from_logits = True); used by
# discriminator_loss and generator_loss.
bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)

def get_vgg_model(out_layer = 'block5_conv4'):
    """Build a frozen VGG19 feature extractor that stops at `out_layer`.

    Args:
        out_layer: Name of the VGG19 layer whose output becomes the model output.

    Returns:
        A `tf.keras.models.Model` named 'VGG_model' that maps the VGG19 inputs to
        the output of `out_layer`. The underlying VGG19 (ImageNet weights, no
        classification top) is marked non-trainable.
    """
    backbone = tf.keras.applications.VGG19(weights = 'imagenet', include_top = False)
    backbone.trainable = False

    feature_map = backbone.get_layer(out_layer).output
    return tf.keras.models.Model(backbone.inputs, feature_map, name = 'VGG_model')
# Module-level VGG19 feature extractor used by vgg_loss.
vgg_model = get_vgg_model()
# Also mark the wrapping model itself as non-trainable.
vgg_model.trainable = False

def mse_loss(real, gen_img):
    """Return the mean, over all elements, of the squared difference `real - gen_img`."""
    squared_error = tf.math.square(real - gen_img)
    return tf.math.reduce_mean(squared_error)

def discriminator_loss(disc_real_out, disc_gen_out):
    """Return the discriminator's binary cross-entropy loss.

    Args:
        disc_real_out: Discriminator outputs for real images; scored against a target of 1.
        disc_gen_out: Discriminator outputs for generated images; scored against a target of 0.

    Returns:
        The sum of the two cross-entropy terms.
    """
    target_real = tf.ones_like(disc_real_out)
    target_fake = tf.zeros_like(disc_gen_out)

    loss_on_real = bce(target_real, disc_real_out)
    loss_on_fake = bce(target_fake, disc_gen_out)
    return loss_on_real + loss_on_fake

def generator_loss(disc_gen_out):
    """Return the generator's adversarial loss.

    The discriminator outputs for generated images are scored against a target
    of 1, so the loss is low when the discriminator rates them as real.
    """
    target_real = tf.ones_like(disc_gen_out)
    return bce(target_real, disc_gen_out)

def vgg_loss(real, gen_img, rescale_factor = 1/12.75):
    """Return the mean squared error between VGG19 feature maps of two image batches.

    Args:
        real: Reference images with pixel values in [-1, 1].
        gen_img: Generated images with pixel values in [-1, 1].
        rescale_factor: Factor applied to the preprocessed images before they are
            fed to the VGG model.

    Returns:
        `mse_loss` of the `vgg_model` outputs for `real` and `gen_img`.
    """
    def to_vgg_input(img):
        # (img + 1) * 127.5 maps [-1, 1] to [0, 255], the range handed to
        # VGG19's preprocess_input.
        # NOTE(review): rescale_factor is applied to the network *input* here;
        # confirm this is intended rather than rescaling the feature maps.
        return tf.keras.applications.vgg19.preprocess_input((img + 1)*127.5) * rescale_factor

    real_out = vgg_model(to_vgg_input(real))
    gen_out = vgg_model(to_vgg_input(gen_img))

    # The original called the undefined name `mse`; the helper in this file is `mse_loss`.
    return mse_loss(real_out, gen_out)

In [15]:
# Separate Adam optimizers (same hyper-parameters) for the generator and the
# discriminator, so each network keeps its own optimizer state.
gen_opt = tf.keras.optimizers.Adam(learning_rate = 1e-4, beta_1 = 0.9)
disc_opt = tf.keras.optimizers.Adam(learning_rate = 1e-4, beta_1 = 0.9)

In [16]:
@tf.function
def train_step(lr_img, hr_img):
    """Run one optimisation step of both the generator and the discriminator.

    Args:
        lr_img: Batch of low-resolution images fed to the generator.
        hr_img: Batch of matching high-resolution images.

    Returns:
        Tuple `(gen_loss, disc_loss)` computed on this batch before the update.
    """
    # One tape per network: each is consumed by exactly one gradient() call below.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_hr = gen(lr_img, training = True)

        real_logits = disc(hr_img, training = True)
        fake_logits = disc(fake_hr, training = True)

        disc_loss = discriminator_loss(real_logits, fake_logits)

        # Generator objective: weighted VGG feature loss plus weighted adversarial loss.
        content_term = 6e-3 * vgg_loss(hr_img, fake_hr)
        adversarial_term = 3e-3 * generator_loss(fake_logits)
        gen_loss = content_term + adversarial_term

    gen_vars = gen.trainable_variables
    gen_opt.apply_gradients(zip(gen_tape.gradient(gen_loss, gen_vars), gen_vars))

    disc_vars = disc.trainable_variables
    disc_opt.apply_gradients(zip(disc_tape.gradient(disc_loss, disc_vars), disc_vars))

    return gen_loss, disc_loss

In [17]:
def train(data, epochs = 1):
    """Run `train_step` over every batch of `data` for `epochs` epochs, printing progress.

    Args:
        data: Iterable yielding `(lr_img, hr_img)` batches; it is iterated once per epoch.
        epochs: Number of passes over `data`.

    Returns:
        None. A dot is printed per batch, and the losses of the last batch are
        printed at the end of each epoch.
    """
    for e in range(epochs):
        print(f'Epoch: {e} Starts')

        # Reset per epoch so an empty (or already exhausted) `data` is detected
        # instead of raising UnboundLocalError or reporting a stale loss.
        gen_loss = disc_loss = None
        for lr_img, hr_img in data:
            gen_loss, disc_loss = train_step(lr_img, hr_img)
            print('.', end='')

        if gen_loss is None:
            print('\nNo batches were produced by data; nothing was trained this epoch.')
        else:
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
        print(f'Epoch: {e} Ends\n\n\n')