In [1]:
import tensorflow as tf
import numpy as np




In [2]:
# Residual block used by the generator (conv-BN-PReLU-conv-BN plus identity skip).
def residual_block(ip):
    """Apply one residual block to a Keras tensor.

    Args:
        ip: Keras tensor to transform. Because the block's output is added
            to ``ip``, ``ip`` must already have 64 channels (the filter count
            of both convolutions below).

    Returns:
        Keras tensor with the same shape as ``ip``.
    """
    residual_model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(ip)
    residual_model = tf.keras.layers.BatchNormalization(momentum=0.5)(residual_model)
    # Share the learned PReLU slope across axes 1 and 2 (the spatial axes in
    # the default channels_last layout).
    residual_model = tf.keras.layers.PReLU(shared_axes=[1, 2])(residual_model)
    residual_model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(residual_model)
    residual_model = tf.keras.layers.BatchNormalization(momentum=0.5)(residual_model)

    # `Add` is a layer: instantiate it first, then call it on the list of tensors.
    return tf.keras.layers.Add()([ip, residual_model])
    

In [3]:
def pixel_shuffler_block(ip):
    """Upsample a Keras tensor by a factor of 2 in both spatial dimensions.

    Applies a 3x3 convolution with 256 filters, then ``UpSampling2D(size=2)``
    (nearest-neighbour by default), then a PReLU.

    NOTE(review): despite the name, this does not perform a sub-pixel
    (depth-to-space) shuffle; that would need ``tf.nn.depth_to_space``.
    The UpSampling2D behaviour is kept here.

    Args:
        ip: 4-D Keras tensor (channels_last assumed — TODO confirm).

    Returns:
        Keras tensor with doubled height and width and 256 channels.
    """
    # Layer class names are case-sensitive: Conv2D / UpSampling2D.
    ps_model = tf.keras.layers.Conv2D(256, (3, 3), padding='same')(ip)
    ps_model = tf.keras.layers.UpSampling2D(size=2)(ps_model)
    ps_model = tf.keras.layers.PReLU(shared_axes=[1, 2])(ps_model)

    return ps_model

In [4]:
def create_generator(generator, no_res_blocks=16, no_pixel_shuffler_blocks=2):
    """Build the super-resolution generator as a Keras functional model.

    Args:
        generator: Keras ``Input`` tensor holding the low-resolution image.
        no_res_blocks: number of ``residual_block`` calls in the trunk.
        no_pixel_shuffler_blocks: number of ``pixel_shuffler_block`` calls;
            each one doubles height and width.

    Returns:
        ``tf.keras.Model`` mapping ``generator`` to a 3-channel image.
    """
    model = tf.keras.layers.Conv2D(64, (9, 9), padding='same')(generator)
    model = tf.keras.layers.PReLU(shared_axes=[1, 2])(model)

    # Keep the pre-trunk features for the long skip connection below.
    temp = model

    for _ in range(no_res_blocks):
        model = residual_block(model)

    model = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(model)
    model = tf.keras.layers.BatchNormalization(momentum=0.5)(model)
    # `Add` is a layer: instantiate it, then call it on the tensors.
    model = tf.keras.layers.Add()([model, temp])

    for _ in range(no_pixel_shuffler_blocks):
        model = pixel_shuffler_block(model)

    # NOTE(review): no output activation, so predictions are unbounded, while
    # the notebook scales training images to [0, 1]. Consider a sigmoid.
    output = tf.keras.layers.Conv2D(3, (9, 9), padding='same')(model)

    return tf.keras.Model(inputs=generator, outputs=output)
    

In [5]:
def create_disriminator(discriminator):
    """Build the discriminator as a Keras functional model.

    The (misspelled) function name is kept so existing callers keep working.

    Args:
        discriminator: Keras ``Input`` tensor holding an image to classify.

    Returns:
        ``tf.keras.Model`` producing one sigmoid score per image.
    """
    filters = [128, 256, 512]
    strides = [1, 2]

    model = tf.keras.layers.Conv2D(64, (3, 3), strides=1, padding='same')(discriminator)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    model = tf.keras.layers.Conv2D(64, (3, 3), strides=2, padding='same')(model)
    model = tf.keras.layers.BatchNormalization(momentum=0.8)(model)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    # For each width: a stride-1 conv, then a stride-2 conv that shrinks H and W.
    # (`n_filters` avoids shadowing the builtin `filter`.)
    for n_filters in filters:
        for stride in strides:
            model = tf.keras.layers.Conv2D(n_filters, (3, 3), strides=stride, padding='same')(model)
            model = tf.keras.layers.BatchNormalization(momentum=0.8)(model)
            model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)

    # `tf.keras` (not `tf.kera`) — the typo raised AttributeError.
    model = tf.keras.layers.Flatten()(model)
    model = tf.keras.layers.Dense(1024)(model)
    model = tf.keras.layers.LeakyReLU(alpha=0.2)(model)
    model = tf.keras.layers.Dense(1, activation='sigmoid')(model)

    return tf.keras.Model(inputs=discriminator, outputs=model)

In [6]:
from keras.applications import VGG19

In [7]:
def create_vgg(hr_shape):
    """Return a VGG19 feature extractor used for the content (perceptual) loss.

    Args:
        hr_shape: input shape ``(height, width, channels)`` of the images
            whose features will be extracted.

    Returns:
        ``tf.keras.Model`` mapping an image to the output of ``layers[10]`` of
        an ImageNet-pretrained VGG19 without its classifier head (presumably
        ``block3_conv4`` in the stock Keras VGG19 — TODO confirm).

    NOTE(review): VGG19's ImageNet weights expect inputs processed with
    ``tf.keras.applications.vgg19.preprocess_input``; the images fed here are
    not preprocessed that way.
    """
    # Take VGG19 from tf.keras so it comes from the same Keras as every other
    # model in the notebook; mixing standalone `keras` and `tf.keras` objects
    # can fail to connect.
    vgg = tf.keras.applications.VGG19(weights='imagenet', include_top=False, input_shape=hr_shape)

    # Keras Model takes `outputs=` (plural); `output=` is not accepted.
    return tf.keras.Model(inputs=vgg.inputs, outputs=vgg.layers[10].output)

In [8]:
def final_model(generator, discriminator, vgg, lr_ip, hr_ip):
    """Stack generator, discriminator and VGG into the combined GAN model.

    The model's two outputs are the discriminator's score for the generated
    image (adversarial term) and the VGG features of the generated image
    (content term).

    Side effects: sets ``vgg.trainable`` and ``discriminator.trainable`` to
    ``False`` so this combined model only updates the generator's weights.

    Args:
        generator: low-res -> super-res ``tf.keras.Model``.
        discriminator: image -> score ``tf.keras.Model``.
        vgg: image -> features ``tf.keras.Model``.
        lr_ip: Keras ``Input`` for low-resolution images (fed to ``generator``).
        hr_ip: Keras ``Input`` for high-resolution images; listed as a model
            input but not used by either output.

    Returns:
        ``tf.keras.Model`` with inputs ``[lr_ip, hr_ip]`` and outputs
        ``[disc_result, gen_features]``.
    """
    gen_img = generator(lr_ip)

    # Freeze the feature extractor so the content-loss targets stay fixed,
    # independent of whether the caller remembered to do it.
    vgg.trainable = False
    gen_features = vgg(gen_img)

    # Freeze the discriminator inside the combined model.
    discriminator.trainable = False
    disc_result = discriminator(gen_img)

    return tf.keras.Model(inputs=[lr_ip, hr_ip], outputs=[disc_result, gen_features])

In [14]:
import os
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt

In [None]:
n = 5000  # maximum number of image pairs used for training
lr_path = "data/lr_images"
hr_path = "data/hr_images"


def _load_rgb_images(directory, file_names):
    """Read each file in ``directory`` with OpenCV and return RGB arrays.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read a file (it returns None).
    """
    images = []
    for file_name in file_names:
        # os.path.join supplies the separator that plain `+` concatenation omitted.
        path = os.path.join(directory, file_name)
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return images


# os.listdir order is arbitrary, and LR/HR images are paired by position, so
# sort both listings. Assumes corresponding LR/HR files sort into the same
# order (e.g. identical file names) — TODO confirm.
lr_list = sorted(os.listdir(lr_path))[:n]
hr_list = sorted(os.listdir(hr_path))[:n]
assert len(lr_list) == len(hr_list), "LR and HR folders must hold the same number of images"

lr_images = np.array(_load_rgb_images(lr_path, lr_list))
# HR images are read from hr_path (previously read from lr_path by mistake).
hr_images = np.array(_load_rgb_images(hr_path, hr_list))

In [None]:
# Display a random LR/HR pair side by side. Each entry is already an
# H x W x 3 RGB array (see the loading cell), so no reshape is needed.
# Index by the actual array length: fewer than `n` images may have been loaded.
test_img_no = random.randint(0, len(lr_images) - 1)

fig, (ax_lr, ax_hr) = plt.subplots(1, 2, figsize=(12, 6))
ax_lr.imshow(lr_images[test_img_no])
ax_lr.set_title("LR image")
ax_hr.imshow(hr_images[test_img_no])
ax_hr.set_title("HR image")
plt.show()

In [15]:
from sklearn.model_selection import train_test_split

In [None]:
# Normalize to [0, 1] as float32. Build new arrays instead of dividing in place:
# `lr_images /= 255` fails on the uint8 arrays produced by cv2 (numpy will not
# cast the float result back to uint8), and would re-scale on every re-run.
lr_images_norm = lr_images.astype(np.float32) / 255.0
hr_images_norm = hr_images.astype(np.float32) / 255.0

# 80/20 train/test split with a fixed seed so the split is reproducible.
lr_train, lr_test, hr_train, hr_test = train_test_split(
    lr_images_norm, hr_images_norm, test_size=0.2, random_state=30
)


In [None]:
# Per-sample image shapes (height, width, channels), i.e. the batch axis dropped.
hr_shape = tuple(hr_train.shape[axis] for axis in (1, 2, 3))
lr_shape = tuple(lr_train.shape[axis] for axis in (1, 2, 3))


In [None]:
# Create the generator, discriminator and VGG feature extractor.
lr_ip = tf.keras.layers.Input(shape=lr_shape)
hr_ip = tf.keras.layers.Input(shape=hr_shape)

generator = create_generator(lr_ip, no_res_blocks=16, no_pixel_shuffler_blocks=2)
generator.summary()
# The generator's output is compared to HR images, so the shapes must agree.
assert tuple(generator.output.shape[1:]) == tuple(hr_shape), (
    "generator output shape must match the HR image shape"
)

discriminator = create_disriminator(hr_ip)
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.summary()

# Build VGG for the real HR shape rather than a hard-coded (128, 128, 3).
vgg = create_vgg(hr_shape)
vgg.trainable = False
vgg.summary()


In [None]:
# Combine the models into the GAN used to train the generator.
gan_model = final_model(generator, discriminator, vgg, lr_ip, hr_ip)
# Total loss = 1e-3 * adversarial BCE + 1 * MSE on VGG features.
# The keyword is `loss_weights`; the misspelled `loss_weghts` is not a compile() argument.
gan_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer='adam')
gan_model.summary()

In [None]:
batch_size = 1

# Split the training set into full mini-batches. A trailing remainder smaller
# than batch_size is dropped, because the training loop builds label arrays
# with exactly batch_size rows.
n_train_batches = hr_train.shape[0] // batch_size  # integer division: range() needs an int
train_lr_batches = []
train_hr_batches = []
for i in range(n_train_batches):
    start_index = i * batch_size
    end_index = start_index + batch_size
    train_hr_batches.append(hr_train[start_index:end_index])
    train_lr_batches.append(lr_train[start_index:end_index])
    

In [1]:
from tqdm import tqdm

In [None]:
epochs = 10

# Labels are identical for every batch, so build them once.
fake_labels = np.zeros((batch_size, 1))
real_labels = np.ones((batch_size, 1))

for e in range(epochs):
    gen_losses = []
    dis_losses = []

    for batch in tqdm(range(len(train_hr_batches))):
        # Distinct names so the full `lr_images` / `hr_images` arrays from the
        # loading cell are not overwritten by a single batch.
        lr_batch = train_lr_batches[batch]
        hr_batch = train_hr_batches[batch]

        fake_images = generator.predict_on_batch(lr_batch)

        # Discriminator step: generated images labelled 0, real HR images labelled 1.
        discriminator.trainable = True
        d_loss_gen = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss_real = discriminator.train_on_batch(hr_batch, real_labels)
        discriminator.trainable = False

        # With metrics=['accuracy'] each result is expected to be [loss, accuracy];
        # average the fake and real results element-wise.
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)

        # Content-loss targets: VGG features of the real HR batch.
        image_features = vgg.predict_on_batch(hr_batch)

        # Generator step through the combined model: fool the discriminator
        # (label 1) and match the HR VGG features.
        g_result = gan_model.train_on_batch([lr_batch, hr_batch], [real_labels, image_features])
        # For a multi-output model Keras returns a list whose first entry is
        # the total weighted loss; keep only that.
        g_loss = g_result[0] if isinstance(g_result, (list, tuple)) else g_result

        dis_losses.append(d_loss)
        gen_losses.append(g_loss)

    # Per-epoch reporting and checkpointing. These now sit inside the epoch
    # loop, so they run after every epoch rather than once after training.
    avg_g_loss = np.mean(np.array(gen_losses), axis=0)
    avg_d_loss = np.mean(np.array(dis_losses), axis=0)

    print("epoch: ", e + 1, "gen_loss: ", avg_g_loss, "dis_loss: ", avg_d_loss)

    if (e + 1) % 5 == 0:
        generator.save("gen_e_" + str(e + 1) + ".h5")

    

In [None]:
# Load the final checkpoint and compare LR input, generated SR image and
# ground-truth HR image for one random test sample.
generator = tf.keras.models.load_model('gen_e_10.h5', compile=False)

# np.random.randint with size=1 returns a length-1 index array, so the
# selections keep a leading batch axis of size 1, which predict() expects.
ix = np.random.randint(0, len(lr_test), 1)
src_image, tar_image = lr_test[ix], hr_test[ix]

gen_image = generator.predict(src_image)

fig, axes = plt.subplots(1, 3, figsize=(16, 8))
panels = (
    ("LR Image", src_image[0]),
    # The generator has a linear output; clip to the [0, 1] range imshow expects.
    ("SR", np.clip(gen_image[0], 0.0, 1.0)),
    ("HR Images", tar_image[0]),
)
for ax, (title, image) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(image)

plt.show()
