In [None]:
# Print the attached NVIDIA GPU(s), driver version and current memory usage.
!nvidia-smi

### Import Relevant Modules

In [None]:
import tensorflow as tf
import keras
import time
import matplotlib.pyplot as plt
from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input

### Preparing the data

In [None]:
# Glob patterns for the high-resolution training images (DIV2K train split + BSDS200).
Path = [
    "../input/div2k/DIV2K_train_HR/DIV2K_train_HR/*.png",
    "../input/llsrdatasets/My_datasets/Train_data/BSDS200/*.png",
]
# Size of the random HR crop taken from every training image.
height, width = 224, 224

In [None]:
def read_file(path):
    """Load the PNG at `path` and return a random `height` x `width` RGB crop.

    The crop is float32 with the decoded pixel values unchanged (0-255 for 8-bit PNGs).
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
    return tf.image.random_crop(pixels, [height, width, 3])

In [None]:
def reshape_normalize(img, max_val=255.0):
    """Linearly map pixel values from [0, max_val] to [-1, 1].

    The rest of the notebook assumes this range: previews invert it with `x * 0.5 + 0.5`.

    Args:
        img: image tensor/array (or scalar) with values in [0, max_val].
        max_val: largest possible pixel value; 255.0 for 8-bit images.

    Returns:
        `img` rescaled so 0 -> -1 and max_val -> 1.
    """
    # Dividing by max_val / 2 (127.5 for 8-bit), not 127, so that 255 maps exactly to 1.0.
    return img / (max_val / 2.0) - 1.0

In [None]:
def load_img(path):
    """Read a random HR crop from `path` and normalize it for the networks."""
    return reshape_normalize(read_file(path))

In [None]:
# Number of HR crops per training batch.
BATCH_SIZE = 8

# list_files shuffles the file names by default. Decoding/cropping runs in parallel and
# batches are prefetched so the GPU does not wait on PNG decoding.
dataset = (
    tf.data.Dataset.list_files(Path)
    .map(load_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

### Creating the generator

In [None]:
def dense_block(input, growth_channels=64, residual_scale=0.2):
    """Residual dense block used inside `rddb_block`.

    Four 3x3 conv + LeakyReLU layers are densely connected: each is fed the concatenation
    of the block input and every previous layer's activation. A fifth 3x3 conv fuses all
    of them back to 64 channels. Its output is scaled by `residual_scale` and added to the
    block input.

    Args:
        input: 4-D feature tensor. It must have 64 channels so the residual Add is valid
            (the stem conv in build_generator produces 64).
        growth_channels: number of filters in each of the four inner convs.
        residual_scale: multiplier applied to the fused output before the residual add.

    Returns:
        Tensor with the same shape as `input`.
    """
    initializer = tf.random_normal_initializer(0.0, 0.02)
    # Keep the raw activations and concatenate them once per layer. Concatenating earlier
    # concatenations would repeat the input and early features many times over.
    features = [input]
    for _ in range(4):
        x = features[0] if len(features) == 1 else tf.keras.layers.Concatenate()(features)
        x = tf.keras.layers.Conv2D(growth_channels, kernel_size=3, strides=1, padding='same',
                                   kernel_initializer=initializer)(x)
        features.append(tf.keras.layers.LeakyReLU()(x))

    fused = tf.keras.layers.Concatenate()(features)
    # 64 output filters so the result matches the input's channel count for the Add.
    fused = tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same',
                                   kernel_initializer=initializer)(fused)
    scaled = fused * residual_scale
    return tf.keras.layers.Add()([scaled, input])

In [None]:
def rddb_block(input):
    """Residual-in-residual dense block.

    Chains three dense blocks, scales their output by 0.2 and adds it back to `input`.
    """
    features = input
    for _ in range(3):
        features = dense_block(features)
    scaled = features * 0.2
    return tf.keras.layers.Add()([input, scaled])

In [None]:
def build_generator(num_rrdb_blocks=16):
    """Build the 4x super-resolution generator.

    Stages:
      1. 3x3 conv stem to 64 channels.
      2. A trunk of `num_rrdb_blocks` residual-in-residual dense blocks.
      3. A 3x3 conv whose output is concatenated with the stem features.
      4. Two depth_to_space(2) pixel-shuffle stages (4x total upscale).
      5. A final 9x9 conv to 3 channels.

    Args:
        num_rrdb_blocks: number of rddb_block units in the trunk (16 by default).

    Returns:
        A tf.keras.Model mapping (batch, H, W, 3) images to (batch, 4H, 4W, 3).
        The output has no activation, so it is not clamped to any range.
    """
    initializer = tf.random_normal_initializer(0.0, 0.02)
    lr_input = tf.keras.layers.Input(shape=[None, None, 3])
    stem = tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, kernel_initializer=initializer,
                                  padding='same')(lr_input)

    trunk = stem
    for _ in range(num_rrdb_blocks):
        trunk = rddb_block(trunk)

    x = tf.keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same')(trunk)
    x = tf.keras.layers.Concatenate()([stem, x])        # 64 + 128 = 192 channels
    x = tf.nn.depth_to_space(x, 2)                      # 2x spatial, 48 channels
    x = tf.keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same')(x)
    x = tf.nn.depth_to_space(x, 2)                      # 4x spatial, 32 channels
    # NOTE(review): no nonlinearity sits between these convs and pixel shuffles, so the
    # upsampling tail is a purely linear map of the trunk features. ESRGAN applies a
    # LeakyReLU after each upsampling conv; consider adding one (changes the model).
    x = tf.keras.layers.Conv2D(256, kernel_size=3, strides=1, padding='same')(x)
    sr_output = tf.keras.layers.Conv2D(3, kernel_size=9, strides=1, kernel_initializer=initializer,
                                       padding='same')(x)
    return tf.keras.Model(inputs=lr_input, outputs=sr_output)

generator = build_generator()

### Creating the discriminator

In [None]:
def build_discriminator():
    """Build the paired-input discriminator.

    Two `height` x `width` RGB images are concatenated on the channel axis and corrupted
    with Gaussian noise (stddev 0.2, active only in training). They then pass through a
    plain conv + LeakyReLU layer, six conv + BatchNorm + LeakyReLU stages, and
    Dense(1024) -> LeakyReLU -> Dense(1) -> sigmoid.

    There is no Flatten, so the Dense layers act on the channel axis of the 4-D feature
    map. The output is therefore one probability per spatial location: shape
    (batch, height/16, width/16, 1), because four stages use stride 2.

    Returns:
        A tf.keras.Model with inputs [first_image, second_image].
    """
    init = tf.random_normal_initializer(0.0, 0.02)
    first = tf.keras.layers.Input(shape=[height, width, 3])
    second = tf.keras.layers.Input(shape=[height, width, 3])

    x = tf.keras.layers.Concatenate()([first, second])
    x = tf.keras.layers.GaussianNoise(0.2)(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, kernel_initializer=init, padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # (filters, stride) of each Conv -> BatchNorm -> LeakyReLU stage, in order.
    stages = [(64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 2)]
    for filters, stride in stages:
        x = tf.keras.layers.Conv2D(filters, kernel_size=3, strides=stride,
                                   kernel_initializer=init, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU()(x)

    x = tf.keras.layers.Dense(1024)(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dense(1)(x)
    x = tf.keras.layers.Activation('sigmoid')(x)

    return tf.keras.Model(inputs=[first, second], outputs=x)

In [None]:
# Build the discriminator once; train_step trains this global instance and fit saves it.
discriminator = build_discriminator()

### Loading the ImageNet-pretrained VGG19 model (used as the feature extractor for the perceptual loss)

In [None]:
# include_top=False drops the fully connected classifier head. That head is not needed
# to reach block4_conv4, it is a large extra download, and it would pin the input to 224x224.
v = VGG19(weights='imagenet', include_top=False)

### Truncating VGG19 at `block4_conv4` to use it as a feature extractor

In [None]:
# Feature extractor for the perceptual loss: VGG19 from its input up to block4_conv4.
# NOTE(review): Keras' VGG19 conv layers apply ReLU internally, so these are
# post-activation features (ESRGAN uses pre-activation ones); confirm this is intended.
vgg = tf.keras.Model(v.input, v.get_layer('block4_conv4').output)

### Defining the loss functions

In [None]:
def generator_loss(disc_output_rf, disc_output_fr, gen_output, hr_image):
    """Total generator loss: perceptual + 5e-3 * adversarial + 1e-2 * L1.

    Args:
        disc_output_rf: discriminator output for the pair (real, generated).
        disc_output_fr: discriminator output for the pair (generated, real).
        gen_output: generator output; assumed to be in [-1, 1] like the targets.
        hr_image: ground-truth HR batch, normalized to [-1, 1].

    Returns:
        (total_loss, l1, percept_loss, lgra).
        NOTE(review): lgra, and therefore total_loss, keeps the discriminator's per-location
        shape rather than being a scalar. GradientTape.gradient then differentiates its sum.
    """
    # Adversarial term: the opposite targets of discriminator_loss.
    lgra = (tf.keras.losses.binary_crossentropy(tf.ones_like(disc_output_fr), disc_output_fr) +
            tf.keras.losses.binary_crossentropy(tf.zeros_like(disc_output_rf), disc_output_rf))
    l1 = tf.reduce_mean(tf.abs(hr_image - gen_output))

    # VGG19's preprocess_input expects RGB in [0, 255], not [-1, 1]. Rescale first, otherwise
    # VGG sees images that are nearly constant (just minus the ImageNet mean).
    # NOTE(review): this raises the perceptual term's magnitude; loss weights / learning
    # rates may need re-tuning.
    hr_feature = vgg(preprocess_input((hr_image + 1.0) * 127.5))
    sr_feature = vgg(preprocess_input((gen_output + 1.0) * 127.5))
    percept_loss = tf.reduce_mean(tf.losses.mean_squared_error(hr_feature, sr_feature))

    total_loss = percept_loss + (5e-3) * lgra + (1e-2) * l1
    return total_loss, l1, percept_loss, lgra

def discriminator_loss(disc_output_rf, disc_output_fr):
    """Discriminator BCE loss.

    Pushes D(real, generated) toward 1 and D(generated, real) toward 0. The result keeps
    the discriminator's per-location shape (it is not reduced to a scalar).
    """
    bce = tf.keras.losses.binary_crossentropy
    real_first_term = bce(tf.ones_like(disc_output_rf), disc_output_rf)
    fake_first_term = bce(tf.zeros_like(disc_output_fr), disc_output_fr)
    return real_first_term + fake_first_term

### Defining the optimizers for the models

In [None]:
# Adam (beta_1=0.9) for both networks; the discriminator's learning rate is 25x smaller.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1.25e-5, beta_1=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-7, beta_1=0.9)

In [None]:
#generator = tf.keras.models.load_model("../input/image-super-resolution-gan/generator24")
#discriminator = tf.keras.models.load_model("../input/image-super-resolution-gan/discriminator24")

### Training the models

In [None]:
@tf.function
def train_step(target, input_image, epoch):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        target: HR batch, normalized to [-1, 1].
        input_image: the corresponding LR batch fed to the generator.
        epoch: not used inside the step. Pass it as a tensor: a Python int that changes
            between calls makes tf.function retrace the whole graph.

    Returns:
        (gen_total_loss, l1, percept_loss, disc_loss, psnr, lgra), where psnr holds one
        value per image.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        disc_output_fr = discriminator([gen_output, target], training=True)
        disc_output_rf = discriminator([target, gen_output], training=True)
        gen_total_loss, l1, percept_loss, lgra = generator_loss(
            disc_output_rf, disc_output_fr, gen_output, target)
        disc_loss = discriminator_loss(disc_output_rf, disc_output_fr)

    # Images live in [-1, 1], so the dynamic range passed to PSNR is 2.0 (not 1.0).
    # PSNR is only a metric, so it is computed outside the tapes.
    psnr = tf.image.psnr(gen_output, target, max_val=2.0)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    return gen_total_loss, l1, percept_loss, disc_loss, psnr, lgra

In [None]:
def _downscale(hr_batch):
    """Bicubically downscale an HR batch by 4x, matching the generator's upscale factor."""
    # tf.image.resize takes [new_height, new_width].
    return tf.image.resize(hr_batch, [height // 4, width // 4], method='bicubic')


def _show_comparison(lr_batch, hr_batch, sr_batch):
    """Plot the first LR / HR / super-resolved image of a batch side by side.

    Images are assumed to be in [-1, 1] and are mapped back to [0, 1] for display.
    """
    panels = [lr_batch[0], hr_batch[0], sr_batch[0]]
    titles = ["4x Downscaled", "Original HR", "Upscaled"]
    fig = plt.figure(figsize=(32, 32))
    for position, (image, title) in enumerate(zip(panels, titles), start=1):
        plt.subplot(1, 3, position)
        plt.title(title, fontsize=25)
        # Clip: the generator output is unbounded, and imshow expects floats in [0, 1].
        plt.imshow(tf.clip_by_value(image * 0.5 + 0.5, 0.0, 1.0))
    plt.show()
    plt.close(fig)  # free the figure; previews are produced repeatedly during training


def fit(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset` (batches of HR images in [-1, 1]).

    For each epoch, this function:
      - shows a preview of one batch;
      - trains on every batch (LR inputs are bicubic 4x downscales), with a preview
        every 20 steps;
      - prints per-step averages of the losses and PSNR;
      - saves both models to "generator<epoch>" / "discriminator<epoch>".
    """
    for epoch in range(epochs):
        for hr_img in dataset.take(1):
            lr_img = _downscale(hr_img)
            pre_img = generator(lr_img, training=False)
            # [-1, 1] images have a dynamic range of 2.0.
            psnr = tf.image.psnr(pre_img, hr_img, max_val=2.0)
            tf.print("PSNR = ", psnr)
            _show_comparison(lr_img, hr_img, pre_img)
        print("Epoch : ", epoch)

        # A tensor (not a fresh Python int each epoch) avoids retracing train_step.
        epoch_tensor = tf.constant(epoch, dtype=tf.int64)
        steps = 0
        sum_gloss = sum_dloss = sum_psnr = sum_l1 = sum_lgra = sum_perceptl = 0.0
        for n, hr_image in dataset.enumerate():
            lr_image = _downscale(hr_image)
            if n % 20 == 0:
                _show_comparison(lr_image, hr_image, generator(lr_image, training=False))
            gen_total_loss, l1, percept_loss, disc_loss, psnr, lgra = train_step(
                hr_image, lr_image, epoch_tensor)
            # Reduce every metric to a scalar so the epoch figure is a mean over steps.
            sum_gloss += tf.reduce_mean(gen_total_loss)
            sum_dloss += tf.reduce_mean(disc_loss)
            sum_psnr += tf.reduce_mean(psnr)
            sum_l1 += tf.reduce_mean(l1)
            sum_lgra += tf.reduce_mean(lgra)
            sum_perceptl += tf.reduce_mean(percept_loss)
            steps += 1

        steps = max(steps, 1)  # avoid dividing by zero on an empty dataset
        print("Perceptual :")
        print(float(sum_perceptl / steps))
        print("PSNR :")
        print(float(sum_psnr / steps))
        print("L1 :")
        print(float(sum_l1 / steps))
        print("Gloss :")
        print(float(sum_gloss / steps))
        print("Dloss :")
        print(float(sum_dloss / steps))
        print("Lgra :")
        print(float(sum_lgra / steps))
        print()
        generator.save("generator" + str(epoch))
        discriminator.save("discriminator" + str(epoch))
        

In [None]:
# Train for 25 epochs; models are saved at the end of every epoch.
fit(dataset, epochs=25)

In [None]:
#generator.save("generator final")

In [None]:
#discriminator.save("discriminator final")