### Dependencies

In [None]:
import os
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
from keras import layers, Model
from keras.models import Sequential
from keras.layers import Conv2D, PReLU, BatchNormalization, Flatten, UpSampling2D, LeakyReLU, Dense, Input, add
from keras.applications import VGG19
from tqdm import tqdm

### Generator Block

In [None]:
# Residual block used by the generator.
def residual(ip):
    """Apply one generator residual block to ``ip`` and return the result.

    Two stages of 3x3 convolution (64 filters) + batch normalisation
    (momentum 0.5); only the first stage is followed by a PReLU whose
    parameters are shared over the two spatial axes. The block input is
    then added element-wise onto the output (identity skip connection),
    so ``ip`` must already carry 64 channels.
    """
    x = ip
    for stage in range(2):
        x = Conv2D(64, (3, 3), padding="same")(x)
        x = BatchNormalization(momentum=0.5)(x)
        if stage == 0:
            # The activation sits between the two convolutions only.
            x = PReLU(shared_axes=[1, 2])(x)
    return add([ip, x])

# Upscaling block for the generator.
def upscale(ip):
    """Double the height and width of ``ip``.

    A 3x3 convolution with 256 filters, then ``UpSampling2D(size=2)``, then
    a PReLU whose parameters are shared across the two spatial axes.
    """
    conv = Conv2D(256, (3, 3), padding="same")
    upsample = UpSampling2D(size=2)
    activation = PReLU(shared_axes=[1, 2])
    return activation(upsample(conv(ip)))

### Generator Model

In [None]:
# GENERATOR MODEL
def generator_model(generator_ip, num_res_block):
    """Build the generator network on top of the Keras input ``generator_ip``.

    Pipeline:
      1. 9x9 conv (64 filters) + PReLU; this output is kept for a long skip.
      2. ``num_res_block`` residual blocks (see ``residual``).
      3. 3x3 conv (64 filters) + batch norm, added back onto the step-1 output.
      4. Two ``upscale`` blocks, each doubling height and width (4x overall).
      5. 9x9 conv down to 3 channels with no activation (linear output).

    Returns:
        A ``keras.Model`` mapping ``generator_ip`` to the upscaled image.
    """
    # Named ``x`` rather than ``layers`` so the local does not shadow the
    # ``keras.layers`` module imported in the dependencies cell.
    x = Conv2D(64, (9, 9), padding="same")(generator_ip)
    x = PReLU(shared_axes=[1, 2])(x)
    long_skip = x

    for _ in range(num_res_block):
        x = residual(x)

    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization(momentum=0.5)(x)
    x = add([x, long_skip])

    for _ in range(2):
        x = upscale(x)

    output = Conv2D(3, (9, 9), padding="same")(x)
    return Model(inputs=generator_ip, outputs=output)

### Discriminator Block

In [None]:
def disc_block(ip, filters, strides=1, bn=True):
    """Apply one discriminator conv block to ``ip``.

    Args:
        ip: Input tensor.
        filters: Number of 3x3 convolution filters.
        strides: Convolution stride; 2 halves the spatial size ("same" padding).
        bn: Whether to insert batch normalisation (momentum 0.8) after the conv.

    Returns:
        The LeakyReLU(alpha=0.2)-activated output tensor.
    """
    x = Conv2D(filters, (3, 3), strides=strides, padding="same")(ip)
    x = BatchNormalization(momentum=0.8)(x) if bn else x
    return LeakyReLU(alpha=0.2)(x)

### Discriminator Model

In [None]:
def discriminator_model(disc_ip):
    """Build the discriminator on top of the Keras input ``disc_ip``.

    Eight ``disc_block``s whose filter count doubles every two blocks
    (64 -> 512) and whose every second block uses stride 2. These are
    followed by Flatten, a 1024-unit Dense + LeakyReLU(0.2), and a single
    sigmoid unit. Training labels real images 1 and generated images 0.

    Returns:
        A ``keras.Model`` mapping ``disc_ip`` to a validity score in (0, 1).
    """
    df = 64
    # (filters, strides, use_batchnorm) for each conv block, in order.
    block_specs = [
        (df, 1, False),
        (df, 2, True),
        (df * 2, 1, True),
        (df * 2, 2, True),
        (df * 4, 1, True),
        (df * 4, 2, True),
        (df * 8, 1, True),
        (df * 8, 2, True),
    ]

    x = disc_ip
    for filters, strides, use_bn in block_specs:
        x = disc_block(x, filters, strides=strides, bn=use_bn)

    x = Flatten()(x)
    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(disc_ip, validity)

### VGG19

In [None]:
def build_vgg(hr_shape):
    """Return a feature extractor cut from an ImageNet-pretrained VGG19.

    The headless VGG19 (``include_top=False``) is built for ``hr_shape`` and
    truncated at ``layers[10]``. NOTE(review): with the standard Keras VGG19
    layer ordering this should be ``block3_conv4``; confirm for the
    installed Keras version.

    NOTE(review): inputs are not passed through VGG19 ``preprocess_input``.
    Both the generated and real images go through the same model, so the
    comparison is consistent, but the features differ from canonical
    ImageNet preprocessing.
    """
    base = VGG19(weights="imagenet", include_top=False, input_shape=hr_shape)
    feature_layer = base.layers[10]
    return Model(inputs=base.inputs, outputs=feature_layer.output)

### Combined Model

In [None]:
def comb_model(gen, disc, vgg, lr_ip, hr_ip):
    """Build the combined model used to train the generator.

    The LR input is passed through ``gen``; the generated image is then fed
    both to ``disc`` (adversarial output) and to ``vgg`` (content/feature
    output). ``disc`` is frozen here so that training this combined model
    only updates the generator (and whatever of ``vgg`` is trainable).

    ``hr_ip`` is declared as a model input but is not connected to either
    output; callers still pass an HR batch for it.

    Returns:
        ``keras.Model`` with inputs ``[lr_ip, hr_ip]`` and outputs
        ``[validity, generated_features]``.
    """
    disc.trainable = False
    sr_img = gen(lr_ip)
    perceptual = vgg(sr_img)
    judged = disc(sr_img)
    return Model(inputs=[lr_ip, hr_ip], outputs=[judged, perceptual])

### Loading Data

In [None]:
n = 5000

# Directories holding the paired training images.
lr_dir = "data/lr_images/"
hr_dir = "data/hr_images/"


def _load_rgb_images(directory, filenames):
    """Read every file in ``filenames`` from ``directory`` as an RGB array.

    Raises:
        OSError: if OpenCV cannot decode a file. ``cv2.imread`` returns None
            instead of raising, which would otherwise surface later as an
            obscure ``cv2.cvtColor`` error.
    """
    images = []
    for name in filenames:
        path = os.path.join(directory, name)
        bgr = cv2.imread(path)
        if bgr is None:
            raise OSError(f"Could not read image: {path}")
        # OpenCV decodes to BGR channel order; convert so downstream code sees RGB.
        images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    return images


# os.listdir returns entries in arbitrary order. Sort both listings;
# otherwise lr_images[i] and hr_images[i] need not be the same picture.
# NOTE(review): assumes matching LR/HR files sort to the same position
# (e.g. identical filenames in both directories); confirm for the dataset.
lr_list = sorted(os.listdir(lr_dir))[:n]
hr_list = sorted(os.listdir(hr_dir))[:n]
if len(lr_list) != len(hr_list):
    raise ValueError(
        f"Unpaired data: {len(lr_list)} LR images vs {len(hr_list)} HR images"
    )

# Low Resolution (downscaled) images for training
lr_images = np.array(_load_rgb_images(lr_dir, lr_list))
# High Resolution images for training
hr_images = np.array(_load_rgb_images(hr_dir, hr_list))

#### Sanity Check on Imported Data

In [None]:
# Show one randomly chosen LR/HR pair to confirm the data loaded and paired.
img_no = random.randint(0, len(lr_images) - 1)
lr_sample = lr_images[img_no]
hr_sample = hr_images[img_no]

# The images are already (height, width, 3) RGB arrays, so no reshape is
# needed. The titles report the real sizes instead of hard-coding them, so
# other resolutions no longer break the plot.
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(lr_sample)
plt.title(f'Low Resolution Image ({lr_sample.shape[1]}x{lr_sample.shape[0]})')
plt.subplot(122)
plt.imshow(hr_sample)
plt.title(f'High Resolution Image ({hr_sample.shape[1]}x{hr_sample.shape[0]})')
plt.show()

### Train-Test Split

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# Cast to float32 before dividing: uint8 / 255 produces float64, which
# doubles memory (about 7.9 GB vs 3.9 GB for 5000 256x256 RGB images) with
# no benefit, since Keras computes in float32 by default.
lr_images = lr_images.astype(np.float32) / 255.0  # type: ignore
hr_images = hr_images.astype(np.float32) / 255.0  # type: ignore


In [None]:
# Hold out a third of the LR/HR pairs for testing; the fixed seed makes the split reproducible.
splits = train_test_split(lr_images, hr_images, test_size=0.33, random_state=42)
lr_train, lr_test, hr_train, hr_test = splits

In [None]:
# Per-image shapes (axes 1-3 of the 4-D data arrays, i.e. without the
# sample axis), used to size the Keras Input layers.
hr_shape = tuple(hr_train.shape[axis] for axis in (1, 2, 3))  # type: ignore
lr_shape = tuple(lr_train.shape[axis] for axis in (1, 2, 3))  # type: ignore

lr_ip = Input(shape=lr_shape)
hr_ip = Input(shape=hr_shape)


In [None]:
# Generator with 16 residual blocks, driven by the low-resolution input.
generator = generator_model(generator_ip=lr_ip, num_res_block=16)
generator.summary()

In [None]:
# The discriminator is compiled on its own so it can be trained directly
# with train_on_batch on real and generated HR batches.
discriminator = discriminator_model(disc_ip=hr_ip)
discriminator.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=['accuracy'],
)
discriminator.summary()

In [None]:
# VGG19 feature extractor sized from the data (hr_shape) rather than a
# hard-coded (256, 256, 3), so it matches whatever HR resolution was loaded.
vgg = build_vgg(hr_shape)
# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line.
vgg.summary()
# Frozen: VGG only supplies features for the content loss and must not be trained.
vgg.trainable = False

In [None]:
# Combined model: LR input -> generator -> (frozen discriminator, VGG features).
gan_model = comb_model(gen=generator, disc=discriminator, vgg=vgg, lr_ip=lr_ip, hr_ip=hr_ip)


### Two losses: adversarial loss and content (VGG) loss
* **Adversarial Loss**: defined from the discriminator's predicted probabilities that the generated images are real, over all training samples; it is computed with binary cross-entropy.

* **Content Loss**: the mean squared error (MSE) between the VGG19 feature representations of the reconstructed (generated) image and of the reference high-resolution image. The features are the feature map obtained by the j-th convolution (after activation) before the i-th max-pooling layer of the VGG19 network.

In [None]:
# Output 0 (discriminator validity) uses binary cross-entropy weighted 1e-3;
# output 1 (VGG features) uses MSE weighted 1, so the content loss dominates.
gan_model.compile(
    optimizer="adam",
    loss=["binary_crossentropy", "mse"],
    loss_weights=[1e-3, 1],
)
gan_model.summary()

### Creating batches of images to be fetched during training

In [None]:
batch_size = 1
# Pre-slice the training arrays into consecutive batches. Truncating the
# division drops any trailing partial batch, so every batch holds exactly
# batch_size samples.
num_batches = int(hr_train.shape[0] / batch_size)  # type: ignore
batch_starts = [k * batch_size for k in range(num_batches)]
train_hr_batches = [hr_train[s:s + batch_size] for s in batch_starts]
train_lr_batches = [lr_train[s:s + batch_size] for s in batch_starts]
    

### Train and save model

In [None]:
epochs = 1

# The label arrays never change, so build them once rather than every epoch.
# Label 0 marks generated (fake) images, label 1 marks real images.
fake_label = np.zeros((batch_size, 1))
real_label = np.ones((batch_size, 1))

# Train over epochs:
for e in range(epochs):
    # Per-batch generator and discriminator losses for this epoch.
    g_losses, d_losses = [], []

    for lr_imgs, hr_imgs in tqdm(
        zip(train_lr_batches, train_hr_batches), total=len(train_hr_batches)
    ):
        # Generate SR images from the LR batch (inference only).
        fake_imgs = generator.predict_on_batch(lr_imgs)

        # Train the discriminator to tell generated HR images from real ones.
        discriminator.trainable = True
        d_loss_gen = discriminator.train_on_batch(fake_imgs, fake_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs, real_label)

        # Set the discriminator non-trainable before training the generator.
        discriminator.trainable = False
        # Element-wise mean of the fake and real results. The discriminator
        # was compiled with an accuracy metric, so these are presumably
        # [loss, accuracy] pairs.
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)

        # Target VGG features of the real HR batch for the content loss.
        # predict_on_batch (as for the generator above) avoids the per-call
        # setup overhead and progress-bar output of predict() in this loop.
        img_features = vgg.predict_on_batch(hr_imgs)

        # Train the generator through the combined model. It is pushed towards
        # outputs the frozen discriminator labels real and whose VGG features
        # match those of the real HR images.
        g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs], [real_label, img_features])

        d_losses.append(d_loss)
        g_losses.append(g_loss)

    g_losses = np.array(g_losses)
    d_losses = np.array(d_losses)

    # Average losses for generator and discriminator over the epoch's batches.
    g_loss = g_losses.mean(axis=0)
    d_loss = d_losses.mean(axis=0)

    print(f"Epoch: {e + 1}, Generator Loss: {g_loss}, Discriminator Loss: {d_loss}")

    # Save the generator after every epoch.
    # NOTE(review): the filename contains a double underscore ("gen_e__1.h5");
    # kept as-is in case other code loads these files by name.
    generator.save("gen_e_" + "_" +str(e + 1) + ".h5")