# ESRGAN (Enhanced Super-Resolution GAN)

This is an attempt to re-implement the ESRGAN paper.

Paper: https://arxiv.org/pdf/1809.00219.pdf

Other Resources: 
* https://www.youtube.com/watch?v=qwYOlXRdADI
* https://github.com/xinntao/ESRGAN

In [35]:
import tensorflow as tf

In [68]:
class ResidualDenseBlock(tf.keras.layers.Layer):
    """Five densely connected convolutions wrapped in a scaled residual connection.

    Every convolution receives the channel-wise concatenation of the block
    input and the outputs of all earlier convolutions. The first four
    convolutions are followed by ``LeakyReLU(alpha=0.2)``; the fifth is left
    linear, multiplied by ``beta`` and added back onto the block input.

    Args:
        filters: Number of output channels of the four inner convolutions.
        filters_l: Number of output channels of the last convolution. Its
            output is added to ``inputs``, so it has to match the channel
            count of the tensor fed to the block.
        kernel_size: Kernel size shared by all five convolutions.
        strides: Strides shared by all five convolutions. The concatenations
            and the final addition only line up when each convolution keeps
            the spatial size, which the defaults (unit stride with ``'same'``
            padding) guarantee.
        padding: Padding mode shared by all five convolutions.
        beta: Scale applied to the last convolution before the skip addition.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters = 32, filters_l = 64, kernel_size = (3, 3), strides = (1, 1),
                 padding = 'same', beta = 0.2, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments so `get_config` can reproduce the layer.
        self.filters = filters
        self.filters_l = filters_l
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.beta = beta

        conv_kwargs = dict(kernel_size=kernel_size, strides=strides, padding=padding)

        self.conv0 = tf.keras.layers.Conv2D(filters=filters, **conv_kwargs)
        self.lrelu0 = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.concat1 = tf.keras.layers.Concatenate()
        self.conv1 = tf.keras.layers.Conv2D(filters=filters, **conv_kwargs)
        self.lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.concat2 = tf.keras.layers.Concatenate()
        self.conv2 = tf.keras.layers.Conv2D(filters=filters, **conv_kwargs)
        self.lrelu2 = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.concat3 = tf.keras.layers.Concatenate()
        self.conv3 = tf.keras.layers.Conv2D(filters=filters, **conv_kwargs)
        self.lrelu3 = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.concat4 = tf.keras.layers.Concatenate()
        # Last convolution: no activation, projects back to `filters_l` channels.
        self.conv4 = tf.keras.layers.Conv2D(filters=filters_l, **conv_kwargs)

    def call(self, inputs):
        """Run the dense block and return ``conv4_output * beta + inputs``."""
        x0 = self.lrelu0(self.conv0(inputs))
        x1 = self.lrelu1(self.conv1(self.concat1([inputs, x0])))
        x2 = self.lrelu2(self.conv2(self.concat2([inputs, x0, x1])))
        x3 = self.lrelu3(self.conv3(self.concat3([inputs, x0, x1, x2])))
        x4 = self.conv4(self.concat4([inputs, x0, x1, x2, x3]))

        # Residual scaling: damp the dense branch before adding the skip path.
        return x4 * self.beta + inputs

    def get_config(self):
        """Return the constructor arguments so Keras can serialise and rebuild the layer."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'filters_l': self.filters_l,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'beta': self.beta,
        })
        return config

In [69]:
class RRDB(tf.keras.layers.Layer):
    """Residual-in-Residual Dense Block: three dense blocks inside a scaled skip.

    Args:
        filters: ``filters`` passed to each inner ``ResidualDenseBlock``.
        filters_l: ``filters_l`` passed to each inner ``ResidualDenseBlock``.
        beta: Scale applied to the output of the third dense block before it
            is added to the input of this layer. It is NOT forwarded to the
            inner blocks, which keep their own default residual scale.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters = 32, filters_l = 64, beta = 0.2, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments so `get_config` can reproduce the layer.
        self.filters = filters
        self.filters_l = filters_l

        self.RDB1 = ResidualDenseBlock(filters = filters, filters_l = filters_l)
        self.RDB2 = ResidualDenseBlock(filters = filters, filters_l = filters_l)
        self.RDB3 = ResidualDenseBlock(filters = filters, filters_l = filters_l)
        self.beta = beta

    def call(self, inputs):
        """Chain the three dense blocks and return ``output * beta + inputs``."""
        out = self.RDB1(inputs)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * self.beta + inputs

    def get_config(self):
        """Return the constructor arguments so Keras can serialise and rebuild the layer."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'filters_l': self.filters_l,
            'beta': self.beta,
        })
        return config

In [79]:
def generator(inp_shape, filters = 32, filters_l = 64, num_blocks = 23):
    """Build the generator as a functional Keras model named ``'Generator'``.

    Layout: a 3x3 convolution to ``filters_l`` channels, ``num_blocks`` RRDB
    layers, a trunk convolution whose output is added to the first
    convolution's output, two (nearest-neighbour 2x upsample -> 3x3 conv ->
    LeakyReLU) stages, then two more convolutions ending with as many
    channels as the input image.

    Args:
        inp_shape: Shape of one input image without the batch axis; its last
            entry is used as the number of output channels.
        filters: Channels inside each RRDB and in the upsampling/head convolutions.
        filters_l: Channels of the trunk (first convolution, RRDB outputs,
            post-trunk convolution).
        num_blocks: Number of RRDB layers stacked in the trunk.

    Returns:
        A ``tf.keras.models.Model`` mapping the input image to the upscaled image.
    """
    def conv3x3(channels):
        # Every convolution in the generator shares this geometry.
        return tf.keras.layers.Conv2D(filters=channels, kernel_size=(3, 3),
                                      strides=(1, 1), padding='same')

    lr_input = tf.keras.layers.Input(shape=inp_shape, dtype=tf.float32, name='generator_input')

    shallow = conv3x3(filters_l)(lr_input)

    trunk = shallow
    for _ in range(num_blocks):
        trunk = RRDB(filters=filters, filters_l=filters_l)(trunk)
    trunk = conv3x3(filters_l)(trunk)

    # Long skip connection around the whole RRDB trunk.
    features = tf.keras.layers.Add()([trunk, shallow])

    # Two 2x stages give an overall 4x spatial upscaling.
    for _ in range(2):
        features = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')(features)
        features = conv3x3(filters)(features)
        features = tf.keras.layers.LeakyReLU(alpha=0.2)(features)

    features = conv3x3(filters)(features)
    features = tf.keras.layers.LeakyReLU(alpha=0.2)(features)
    sr_output = conv3x3(inp_shape[-1])(features)

    return tf.keras.models.Model(lr_input, sr_output, name='Generator')

In [1]:
# gen = generator((64, 64, 3))
# gen.summary()

In [84]:
def discriminator(inp_shape):
    """Build the discriminator as a functional Keras model named ``'Discriminator'``.

    Layout: one 3x3 convolution + LeakyReLU, seven Conv -> BatchNorm ->
    LeakyReLU stages that alternate stride 1 and stride 2 while widening
    from 64 to 512 channels, then Flatten -> Dense(1024) -> LeakyReLU ->
    Dense(1).

    Args:
        inp_shape: Shape of one input image without the batch axis.

    Returns:
        A ``tf.keras.models.Model`` producing one un-activated score per image
        (no sigmoid is applied to the final ``Dense`` layer).
    """
    image = tf.keras.layers.Input(shape=inp_shape, dtype=tf.float32,
                                  name=f'discriminator_input_{inp_shape}')

    # Stem: the only convolution that is not followed by batch normalisation.
    h = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(image)
    h = tf.keras.layers.LeakyReLU(alpha=0.2)(h)

    # (channels, stride) of each Conv -> BatchNorm -> LeakyReLU stage, in order.
    stages = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    )
    for channels, stride in stages:
        h = tf.keras.layers.Conv2D(filters=channels, kernel_size=(3, 3),
                                   strides=(stride, stride), padding='same')(h)
        h = tf.keras.layers.BatchNormalization()(h)
        h = tf.keras.layers.LeakyReLU(alpha=0.2)(h)

    h = tf.keras.layers.Flatten()(h)
    h = tf.keras.layers.Dense(units=1024)(h)
    h = tf.keras.layers.LeakyReLU(alpha=0.2)(h)
    # Deliberately no sigmoid here: the model returns raw scores.
    score = tf.keras.layers.Dense(units=1)(h)

    return tf.keras.models.Model(image, score, name='Discriminator')

In [2]:
# disc = discriminator(gen.output_shape[1:])
# disc.summary()

In [126]:
def get_vgg_model(out_layer = 'block5_conv4', act_off = True):
    """Return a frozen VGG19 truncated at ``out_layer`` for perceptual features.

    Args:
        out_layer: Name of the VGG19 layer whose output becomes the model output.
        act_off: When true, the ``activation`` attribute of the truncated
            model's last layer is set to ``None``.
            NOTE(review): this assumes ``layers[-1]`` is ``out_layer`` - confirm.

    Returns:
        A ``tf.keras.models.Model`` named ``'vgg_model'`` that shares its
        layers with the ImageNet-pretrained VGG19 (no classifier head).
    """
    backbone = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    backbone.trainable = False

    feature_map = backbone.get_layer(out_layer).output
    extractor = tf.keras.models.Model(backbone.inputs, feature_map, name='vgg_model')

    if act_off:
        # Drop the activation on the final layer of the truncated model.
        extractor.layers[-1].activation = None
    return extractor
# Module-level feature extractor read by `vgg_loss`; building it loads the
# ImageNet VGG19 weights once when this cell runs.
vgg_model = get_vgg_model(act_off = True)
# Keep the extractor frozen so its weights are never reported as trainable.
vgg_model.trainable = False

In [131]:
def vgg_loss(real, gen, rescale_factor = 1/12.75):
    """Mean squared difference between the VGG features of two image batches.

    Args:
        real: Reference images; per the original note, pixel range [-1, 1].
        gen: Generated images in the same range as ``real``.
        rescale_factor: Multiplier applied to the VGG-preprocessed images
            before they are fed to the module-level ``vgg_model``.

    Returns:
        Scalar tensor: mean of the squared feature differences.
    """
    def extract_features(images):
        # Map [-1, 1] to [0, 255], apply VGG19 preprocessing, then rescale.
        pixels = (images + 1)*127.5
        prepared = tf.keras.applications.vgg19.preprocess_input(pixels) * rescale_factor
        return vgg_model(prepared)

    real_features = extract_features(real)
    gen_features = extract_features(gen)

    squared_error = tf.math.square(real_features - gen_features)
    return tf.math.reduce_mean(squared_error)

def l1_loss(real, gen):
    """Return the mean absolute difference between ``real`` and ``gen``."""
    absolute_error = tf.math.abs(real - gen)
    return tf.math.reduce_mean(absolute_error)

# Binary cross-entropy shared by the two relativistic losses.
# `from_logits = True` makes the loss apply the sigmoid itself, so it must be
# given raw (un-activated) scores, not probabilities.
bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)
def relativistic_generator_loss(disc_real_out, disc_gen_out):
    """Relativistic average adversarial loss for the generator.

    The generator is rewarded when generated images score higher than the
    average real image (target 1) and real images score lower than the
    average generated image (target 0).

    Args:
        disc_real_out: Raw discriminator scores (logits) for real images,
            batch on axis 0.
        disc_gen_out: Raw discriminator scores (logits) for generated images,
            batch on axis 0.

    Returns:
        Scalar tensor: sum of the two cross-entropy terms.
    """
    # `bce` is configured with from_logits=True, so it must receive the raw
    # relativistic logits. Applying tf.nn.sigmoid first would squash them twice.
    real_logits = disc_real_out - tf.math.reduce_mean(disc_gen_out, axis=0, keepdims=True)
    real_loss = bce(tf.zeros_like(real_logits), real_logits)

    gen_logits = disc_gen_out - tf.math.reduce_mean(disc_real_out, axis=0, keepdims=True)
    gen_loss = bce(tf.ones_like(gen_logits), gen_logits)
    return real_loss + gen_loss
    
def relativistic_discriminator_loss(disc_real_out, disc_gen_out):
    """Relativistic average adversarial loss for the discriminator.

    The discriminator is rewarded when real images score higher than the
    average generated image (target 1) and generated images score lower than
    the average real image (target 0).

    Args:
        disc_real_out: Raw discriminator scores (logits) for real images,
            batch on axis 0.
        disc_gen_out: Raw discriminator scores (logits) for generated images,
            batch on axis 0.

    Returns:
        Scalar tensor: sum of the two cross-entropy terms.
    """
    # `bce` is configured with from_logits=True, so it must receive the raw
    # relativistic logits. Applying tf.nn.sigmoid first would squash them twice.
    real_logits = disc_real_out - tf.math.reduce_mean(disc_gen_out, axis=0, keepdims=True)
    real_loss = bce(tf.ones_like(real_logits), real_logits)

    gen_logits = disc_gen_out - tf.math.reduce_mean(disc_real_out, axis=0, keepdims=True)
    gen_loss = bce(tf.zeros_like(gen_logits), gen_logits)
    return real_loss + gen_loss

In [130]:
# One Adam optimizer per network, both created with identical hyper-parameters.
gen_opt = tf.keras.optimizers.Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
)
disc_opt = tf.keras.optimizers.Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
)

In [135]:
@tf.function
def train_step(lr_img, hr_img):
    """Run one optimisation step for both the generator and the discriminator.

    Relies on module-level ``gen`` and ``disc`` models (the cells that build
    them are commented out in this notebook and must be run first) and on the
    module-level optimizers ``gen_opt`` / ``disc_opt``.

    Args:
        lr_img: Batch of low-resolution inputs for the generator.
        hr_img: Batch of matching high-resolution targets.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` for this batch.
    """
    # Two tapes so each network's gradients come from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_out = gen(lr_img, training = True)

        disc_real_out = disc(hr_img, training = True)
        disc_gen_out = disc(gen_out, training = True)

        disc_loss = relativistic_discriminator_loss(disc_real_out, disc_gen_out)
        # Total generator loss: 5e-3 * adversarial + 1e-2 * L1 + perceptual.
        # The stray undefined factor `n` on the L1 term has been removed.
        gen_loss = (relativistic_generator_loss(disc_real_out, disc_gen_out) * 5e-3
                    + l1_loss(hr_img, gen_out) * 1e-2
                    + vgg_loss(hr_img, gen_out))

    gen_grads = gen_tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    # Fixed typo: the gradients must be taken w.r.t. `disc`, not `dics`.
    disc_grads = disc_tape.gradient(disc_loss, disc.trainable_variables)
    disc_opt.apply_gradients(zip(disc_grads, disc.trainable_variables))

    return gen_loss, disc_loss

In [136]:
def train(data, epochs = 1):
    """Run ``train_step`` over ``data`` for ``epochs`` passes, printing progress.

    Prints one dot per batch and, at the end of each epoch, the losses of the
    last batch of that epoch.

    Args:
        data: Iterable yielding ``(lr_img, hr_img)`` pairs; it is iterated once
            per epoch.
        epochs: Number of passes over ``data``.
    """
    for e in range(epochs):
        print(f'Epoch: {e} Starts')

        # Reset per epoch so an empty or exhausted `data` can neither crash the
        # summary below nor report losses left over from an earlier epoch.
        gen_loss = disc_loss = None
        for lr_img, hr_img in data:
            gen_loss, disc_loss = train_step(lr_img, hr_img)
            print('.', end='')

        if gen_loss is None:
            print('\nNo batches were produced by `data`; nothing was trained this epoch.')
        else:
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
        print(f'Epoch: {e} Ends\n\n\n')