In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from keras.models import Sequential
from keras import layers, Model
from google.colab import drive

In [2]:
# Colab-only: the dataset lives on Google Drive, which must be mounted before
# any path under /content/drive can be read.
drive.mount("/content/drive")

db_dir = "/content/drive/MyDrive/Colab/dataset"

In [4]:
def get_train_data():
    """Load the paired low-/high-resolution training images.

    Reads every file in ``<db_dir>/train/lr`` and ``<db_dir>/train/hr`` with
    OpenCV (``IMREAD_COLOR``, i.e. BGR channel order) and scales pixel values
    by 1/255.

    Returns:
        (X, Y): arrays of LR and HR images stacked along axis 0. Image ``i`` of
        X is paired with image ``i`` of Y by sorted filename.

    Raises:
        ValueError: if the LR and HR folders hold different numbers of files,
            or a file cannot be decoded as an image.
    """
    lr_dir = os.path.join(db_dir, "train", "lr")
    hr_dir = os.path.join(db_dir, "train", "hr")

    # os.listdir returns entries in arbitrary order (possibly different for the
    # two folders); sorting both lines up LR[i] with HR[i] when the files share
    # a naming scheme.
    lr_names = sorted(os.listdir(lr_dir))
    hr_names = sorted(os.listdir(hr_dir))
    if len(lr_names) != len(hr_names):
        raise ValueError(
            f"train/lr has {len(lr_names)} files but train/hr has {len(hr_names)}"
        )

    def read_all(folder, names):
        # Read every named image in `folder` and scale it by 1/255.
        images = []
        for name in names:
            path = os.path.join(folder, name)
            img = cv2.imread(path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise ValueError(f"could not read image: {path}")
            images.append(img)
        return np.array(images) / 255

    return read_all(lr_dir, lr_names), read_all(hr_dir, hr_names)


train_lr, train_hr = get_train_data()

In [None]:
def get_test_data():
    """Load the paired low-/high-resolution test images.

    Reads every file in ``<db_dir>/test/lr`` and ``<db_dir>/test/hr`` with
    OpenCV (``IMREAD_COLOR``, i.e. BGR channel order) and scales pixel values
    by 1/255.

    Returns:
        (X, Y): arrays of LR and HR images stacked along axis 0. Image ``i`` of
        X is paired with image ``i`` of Y by sorted filename.

    Raises:
        ValueError: if the LR and HR folders hold different numbers of files,
            or a file cannot be decoded as an image.
    """
    lr_dir = os.path.join(db_dir, "test", "lr")
    hr_dir = os.path.join(db_dir, "test", "hr")

    # os.listdir returns entries in arbitrary order (possibly different for the
    # two folders); sorting both lines up LR[i] with HR[i] when the files share
    # a naming scheme.
    lr_names = sorted(os.listdir(lr_dir))
    hr_names = sorted(os.listdir(hr_dir))
    if len(lr_names) != len(hr_names):
        raise ValueError(
            f"test/lr has {len(lr_names)} files but test/hr has {len(hr_names)}"
        )

    def read_all(folder, names):
        # Read every named image in `folder` and scale it by 1/255.
        images = []
        for name in names:
            path = os.path.join(folder, name)
            img = cv2.imread(path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise ValueError(f"could not read image: {path}")
            images.append(img)
        return np.array(images) / 255

    return read_all(lr_dir, lr_names), read_all(hr_dir, hr_names)


test_lr, test_hr = get_test_data()

In [None]:
num_res_block = 16
# Per-image shapes (everything after the batch axis), taken from the loaded arrays.
hr_shape = train_hr.shape[1:]
lr_shape = train_lr.shape[1:]
# The generator upsamples by 4x (two UpSampling2D(size=2) blocks) and its output
# is fed to a discriminator built on hr_shape, so HR must be exactly 4x LR.
# Failing here gives a clearer error than a shape mismatch deep inside Keras.
assert hr_shape[:2] == (4 * lr_shape[0], 4 * lr_shape[1]), (
    f"HR shape {hr_shape} is not 4x LR shape {lr_shape}"
)


# Short aliases for the Keras layers used by the model-building cells below.
Conv2D = layers.Conv2D
BatchNormalization = layers.BatchNormalization
PReLU = layers.PReLU
UpSampling2D = layers.UpSampling2D
Dense = layers.Dense
add = layers.add
LeakyReLU = layers.LeakyReLU
Input = layers.Input
Flatten = layers.Flatten


lr_ip = Input(shape=lr_shape)
hr_ip = Input(shape=hr_shape)

# Separate statements: `print(a), print(b)` would build and display a (None, None) tuple.
print(lr_ip)
print(hr_ip)

## Generator building blocks

In [None]:
def res_block(ip):
    """Residual block: (3x3/64 conv -> BN -> PReLU -> 3x3/64 conv -> BN) + identity.

    The branch ends with 64 channels, so `ip` must also have 64 channels for the
    element-wise add of the skip connection.
    """
    branch = ip
    for stage in range(2):
        branch = Conv2D(64, (3, 3), padding="same")(branch)
        branch = BatchNormalization(momentum=0.5)(branch)
        # Only the first conv/BN stage is followed by an activation.
        if stage == 0:
            branch = PReLU(shared_axes=[1, 2])(branch)
    return add([ip, branch])

def upscale_block(ip):
    """Double the spatial size: 3x3/256 conv -> UpSampling2D(size=2) -> PReLU."""
    conv = Conv2D(256, (3, 3), padding="same")
    upsample = UpSampling2D(size=2)
    # PReLU slopes are shared across height and width (one per channel).
    act = PReLU(shared_axes=[1, 2])
    return act(upsample(conv(ip)))

## Discriminator building block

In [None]:
def discriminator_block(ip, filters, strides=1, bn=True):
    """One discriminator stage: 3x3 conv -> LeakyReLU(0.2) -> optional BatchNorm.

    Args:
        ip: input tensor.
        filters: number of conv output channels.
        strides: conv stride (2 halves the spatial size with "same" padding).
        bn: if True, append BatchNormalization(momentum=0.8) after the activation.
    """
    x = Conv2D(filters, (3, 3), strides=strides, padding="same")(ip)
    x = LeakyReLU(alpha=0.2)(x)
    return BatchNormalization(momentum=0.8)(x) if bn else x

## Generator model

In [None]:
def create_gen(gen_ip):
    """Build the SRGAN-style generator Model on the LR input `gen_ip`.

    Structure: 9x9/64 conv + PReLU, then `num_res_block` residual blocks, then
    3x3/64 conv + BN added back to the head features (global skip), then two
    x2 upscale blocks (x4 overall) and a final 9x9 conv to 3 channels with no
    output activation.
    """
    head = Conv2D(64, (9, 9), padding="same")(gen_ip)
    head = PReLU(shared_axes=[1, 2])(head)

    # Residual trunk; `head` is kept for the long skip connection.
    trunk = head
    for _ in range(num_res_block):
        trunk = res_block(trunk)
    trunk = Conv2D(64, (3, 3), padding="same")(trunk)
    trunk = BatchNormalization(momentum=0.5)(trunk)
    x = add([trunk, head])

    for _ in range(2):
        x = upscale_block(x)

    sr_out = Conv2D(3, (9, 9), padding="same")(x)
    return Model(inputs=gen_ip, outputs=sr_out)

## Discriminator model

In [None]:
def create_disc(disc_ip):
    """Build the discriminator Model on the HR-sized input `disc_ip`.

    Eight conv stages (filters 64 -> 512, every second stage with stride 2,
    BatchNorm on all but the first), then Flatten -> Dense(1024) ->
    LeakyReLU(0.2) -> Dense(1, sigmoid). The training loop labels real HR
    images 1 and generated images 0.
    """
    df = 64
    # (filters, strides, use_bn) for each conv stage, applied in order.
    conv_plan = (
        (df, 1, False),
        (df, 2, True),
        (df * 2, 1, True),
        (df * 2, 2, True),
        (df * 4, 1, True),
        (df * 4, 2, True),
        (df * 8, 1, True),
        (df * 8, 2, True),
    )

    features = disc_ip
    for n_filters, stride, use_bn in conv_plan:
        features = discriminator_block(features, n_filters, strides=stride, bn=use_bn)

    head = Flatten()(features)
    head = Dense(df * 16)(head)
    head = LeakyReLU(alpha=0.2)(head)
    validity = Dense(1, activation='sigmoid')(head)

    return Model(disc_ip, validity)

## VGG19 feature extractor (perceptual loss)

In [None]:
from keras.applications import VGG19

def build_vgg(input_shape=(128, 128, 3), layer_index=9):
    """Build an ImageNet VGG19 feature extractor for the perceptual loss.

    Args:
        input_shape: shape of the images fed to it. It must match the
            generator's output (the HR image shape). Default (128, 128, 3).
        layer_index: index into ``vgg.layers`` whose output is used as the
            feature map. Default 9 (presumably ``block3_conv3`` in the Keras
            VGG19 layout with include_top=False; confirm with ``vgg.summary()``).

    Returns:
        A Model from the VGG input to ``[vgg.layers[layer_index].output]``.
        The caller is responsible for freezing it (``trainable = False``).

    NOTE(review): Keras' VGG19 weights expect inputs passed through
    ``keras.applications.vgg19.preprocess_input``. The notebook feeds
    BGR images scaled by 1/255 directly. Confirm this is intended.
    """
    vgg = VGG19(weights="imagenet", input_shape=input_shape, include_top=False)
    outputs = [vgg.layers[layer_index].output]
    return Model(vgg.input, outputs)

## Combined model (generator + frozen discriminator + VGG19)

In [None]:
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    """Build the combined model used to train the generator.

    The LR input goes through ``gen_model``. The generated image is then scored
    by ``disc_model`` (adversarial output) and mapped through ``vgg`` (feature
    output for the perceptual loss).

    Args:
        gen_model: generator Model (LR image -> generated HR-sized image).
        disc_model: discriminator Model. Side effect: its ``trainable`` flag is
            set to False here.
        vgg: feature-extractor Model applied to the generated image (the
            notebook sets ``vgg.trainable = False`` before calling this).
        lr_ip: Keras Input for the LR images.
        hr_ip: Keras Input for the HR images. It is listed as a model input so
            training can feed ``[lr, hr]``, but no output depends on it.

    Returns:
        Model with inputs ``[lr_ip, hr_ip]`` and outputs
        ``[validity, gen_features]``.
    """
    gen_img = gen_model(lr_ip)
    gen_features = vgg(gen_img)

    # Freeze the discriminator before it is wired into (and later compiled with)
    # the combined model, so the combined model's updates target the generator.
    disc_model.trainable = False
    validity = disc_model(gen_img)

    return Model(inputs=[lr_ip, hr_ip], outputs=[validity, gen_features])

In [None]:
# Build the three networks and compile the two models that are trained directly.
generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
# Compiled here, while still trainable; create_comb (below) sets
# discriminator.trainable = False only afterwards.
discriminator.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])

# VGG19 serves as a frozen feature extractor for the perceptual loss.
vgg = build_vgg()
vgg.trainable = False


# Combined model outputs: [discriminator validity, VGG features of the generated image].
# Losses: binary cross-entropy on validity (weight 1e-3) + MSE on VGG features (weight 1).
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(loss=["binary_crossentropy","mse"], loss_weights=[1e-3, 1], optimizer="adam")

In [None]:
# Layer-by-layer architecture and parameter counts of the generator.
generator.summary()

In [None]:
# Layer-by-layer architecture and parameter counts of the discriminator.
discriminator.summary()

In [None]:
# Combined model: generator followed by the (frozen) discriminator and VGG19.
gan_model.summary()

In [None]:
batch_size = 20
# LR and HR images are paired by index, so both sets must be the same size.
assert len(train_lr) == len(train_hr), "train LR/HR image counts differ"
# Only whole batches are kept: the training loop builds label arrays of exactly
# batch_size rows, so a trailing partial batch (len % batch_size images) is dropped.
# Note: these slices are numpy views into train_lr / train_hr, not copies.
n_batches = len(train_hr) // batch_size
batch_starts = range(0, n_batches * batch_size, batch_size)
train_hr_batches = [train_hr[i:i + batch_size] for i in batch_starts]
train_lr_batches = [train_lr[i:i + batch_size] for i in batch_starts]

In [None]:
epochs = 20
assert train_hr_batches, "no full training batches; check batch_size vs dataset size"

# Targets are identical for every batch, so build them once.
gen_label = np.zeros((batch_size, 1))   # generated images -> "fake"
real_label = np.ones((batch_size, 1))   # real HR images  -> "real"

for e in range(epochs):
    g_losses = []
    d_losses = []
    for lr_imgs, hr_imgs in zip(train_lr_batches, train_hr_batches):
        # 1) Discriminator step on one batch of generated and one batch of real images.
        gen_imgs = generator.predict_on_batch(lr_imgs)
        # NOTE(review): toggling `trainable` after compile() may not change which
        # weights train_on_batch updates in every Keras version (tf.keras needs a
        # re-compile); confirm for the installed version.
        discriminator.trainable = True
        d_loss_gen = discriminator.train_on_batch(gen_imgs, gen_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs, real_label)
        discriminator.trainable = False
        d_losses.append(0.5 * np.add(d_loss_gen, d_loss_real))

        # 2) Generator step: fool the discriminator and match VGG features of the
        #    real HR images. predict_on_batch avoids predict()'s per-call
        #    overhead and progress-bar output for a single batch.
        image_features = vgg.predict_on_batch(hr_imgs)
        g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs], [real_label, image_features])
        g_losses.append(g_loss)

    # Per-epoch averages (element-wise for the discriminator's [loss, accuracy]).
    g_loss = np.mean(g_losses, axis=0)
    d_loss = np.mean(d_losses, axis=0)

    print("epoch:", e+1 ,"g_loss:", g_loss, "d_loss:", d_loss)

In [None]:
result = generator.predict_on_batch(train_lr_batches[0])

# First sample of the first training batch. All arrays kept here stay in cv2's
# BGR order; only the displayed copies are converted. The training batches are
# numpy views into train_lr / train_hr, so converting them in place would
# silently corrupt the training data (and flip back on every re-run).
lr_sample = train_lr_batches[0][0]
hr_sample = train_hr_batches[0][0]

# Naive baseline: cv2.resize of the LR input up to the HR size.
# cv2.resize takes dsize as (width, height).
hr_height, hr_width = hr_sample.shape[:2]
resized = cv2.resize(lr_sample, (hr_width, hr_height))


def show_bgr(img, title):
    """Display a BGR image (values expected in [0, 1]) as RGB; `img` is not modified."""
    rgb = cv2.cvtColor(np.float32(img), cv2.COLOR_BGR2RGB)
    # The generator's last conv has no activation, so its output can leave
    # [0, 1]; clip explicitly for display.
    plt.imshow(np.clip(rgb, 0.0, 1.0))
    plt.title(title)
    plt.axis("off")
    plt.show()


show_bgr(result[0], "Generator output")
show_bgr(resized, "Resized LR (baseline)")
show_bgr(lr_sample, "LR input")
show_bgr(hr_sample, "HR ground truth")

In [None]:
# Score the combined model on the test set: the "real" label for the
# adversarial output and VGG features of the true HR images for the perceptual output.
test_label = np.ones((len(test_lr), 1))
test_imgs_features = vgg.predict(test_hr)
test_res = gan_model.evaluate(
    [test_lr, test_hr], [test_label, test_imgs_features], batch_size=10
)
# gan_model was compiled without `metrics`, so these values are losses only (no accuracy).
print("Test losses:", test_res)

In [None]:
def psnr(ref, target, max_val=255.0):
    """Peak signal-to-noise ratio between two images, in dB.

    Args:
        ref: reference image (any numeric dtype).
        target: image to compare; same shape as `ref` and the same channel order.
        max_val: largest possible pixel value. The default 255 suits 0-255
            images; use 1.0 for images scaled to [0, 1].

    Returns:
        ``10 * log10(max_val**2 / MSE)`` as a float, or ``inf`` when the images
        are identical (MSE == 0) instead of dividing by zero.
    """
    # Cast before subtracting so unsigned integer inputs can't wrap around.
    error = ref.astype(np.float64) - target.astype(np.float64)
    mse = np.mean(error ** 2)
    if mse == 0:
        return float("inf")
    return float(10 * np.log10(max_val ** 2 / mse))

In [None]:
# PSNR against the HR ground truth on the 0-255 scale. Every image is compared
# in the same channel order, so the result does not depend on BGR vs RGB.
hr_ref = train_hr_batches[0][0] * 255
print("PSNR generator vs HR:", psnr(result[0] * 255, hr_ref))
print("PSNR resized LR vs HR:", psnr(resized * 255, hr_ref))