## Image super-resolution using GAN 

A Generative Adversarial Network (GAN) is a deep neural network architecture based on a game-theoretic approach, in which two components of the model, a generator and a discriminator, compete with each other. 

Here, the generator is trained to produce high-resolution images from low-resolution images. The discriminator is trained to distinguish original high-resolution images from generated high-resolution images. This competition pushes the generator to produce better super-resolution images.







## Importing libraries 



In [None]:
import tensorflow as tf
from keras import Input
from keras.applications import VGG19, InceptionResNetV2
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam

import glob
import time
import os
import cv2
import base64
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

from imageio import imread
from skimage.transform import resize as imresize
from copy import deepcopy
from tqdm import tqdm
from pprint import pprint
from PIL import Image
from sklearn.model_selection import train_test_split

 ### Defining hyperparameters



In [None]:
# Number of iterations of the training loop.
epochs = 100

# Number of image pairs sampled per training step.
batch_size = 8

# Shape (height, width, channels) of the generator input.
low_resolution_shape = (64, 64, 3)

# Shape (height, width, channels) of the target / generated images.
high_resolution_shape = (256, 256, 3)

# Single Adam instance reused by every compile() call in this notebook.
# The positional arguments are the learning rate (0.0002) and beta_1 (0.5).
common_optimizer = Adam(0.0002, 0.5)

In [None]:
# Glob pattern (not a plain directory) matching every training image file.
data_dir = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/*.*"

# Generator



In [None]:
def residual_block(x):
    """Apply one residual block to ``x`` and return the resulting tensor.

    The block is Conv -> ReLU -> BatchNorm -> Conv -> BatchNorm, followed by
    an element-wise addition with the block input (identity shortcut). Both
    convolutions use 64 filters, 3x3 kernels, stride 1 and ``same`` padding,
    so ``x`` must already have 64 channels for the addition to be valid.
    """
    conv_settings = dict(kernel_size=3, strides=1, padding="same")
    bn_momentum = 0.8
    shortcut = x

    # First convolution stage, with non-linearity
    branch = Conv2D(filters=64, **conv_settings)(shortcut)
    branch = Activation(activation="relu")(branch)
    branch = BatchNormalization(momentum=bn_momentum)(branch)

    # Second convolution stage, left linear before the merge
    branch = Conv2D(filters=64, **conv_settings)(branch)
    branch = BatchNormalization(momentum=bn_momentum)(branch)

    # Identity shortcut
    return Add()([branch, shortcut])

In [None]:
def build_generator(input_shape=(64, 64, 3), residual_blocks=16):
    """Build the SRGAN generator, which upscales its input by a factor of 4.

    Args:
        input_shape: Shape ``(height, width, channels)`` of the low-resolution
            input image. Defaults to the 64x64 RGB shape used in this notebook.
        residual_blocks: Number of residual blocks placed between the first
            convolution and the long skip connection. Defaults to 16.

    Returns:
        A Keras ``Model`` named ``'generator'``. Its output has 3 channels,
        four times the spatial size of the input (two 2x up-sampling stages)
        and a ``tanh`` activation, i.e. values in [-1, 1].
    """
    momentum = 0.8

    # Input to the generator is the low-resolution image
    input_layer = Input(shape=input_shape)

    # Pre-residual block: 9x9 convolution with 64 filters, stride 1, ReLU
    gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(input_layer)

    # Stack of residual blocks
    res = gen1
    for _ in range(residual_blocks):
        res = residual_block(res)

    # Post-residual block: convolution followed by batch normalization
    gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
    gen2 = BatchNormalization(momentum=momentum)(gen2)

    # Long skip connection: add the pre-residual output (gen1) to the
    # post-residual output (gen2)
    gen3 = Add()([gen2, gen1])

    # First 2x up-sampling stage
    gen4 = UpSampling2D(size=2)(gen3)
    gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
    gen4 = Activation('relu')(gen4)

    # Second 2x up-sampling stage
    gen5 = UpSampling2D(size=2)(gen4)
    gen5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen5)
    gen5 = Activation('relu')(gen5)

    # Final convolution maps back to 3 channels; tanh bounds the output
    gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
    output = Activation('tanh')(gen6)

    model = Model(inputs=[input_layer], outputs=[output], name='generator')
    return model

# Discriminator



In [None]:
def build_discriminator():
    """Build the discriminator that scores 256x256 RGB images.

    The network is a stack of eight 3x3 convolutions with LeakyReLU
    activations (batch normalization on all but the first), followed by two
    ``Dense`` layers. There is no flatten step, so the dense layers act on
    the channel axis of the feature map: four stride-2 convolutions reduce
    256x256 to 16x16, and the output is a 16x16x1 map of sigmoid scores.

    Returns:
        A Keras ``Model`` named ``'discriminator'``.
    """
    leakyrelu_alpha = 0.2
    momentum = 0.8

    # Input is a high-resolution image
    input_shape = (256, 256, 3)
    input_layer = Input(shape=input_shape)

    # First convolution: the only one without batch normalization
    features = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
    features = LeakyReLU(alpha=leakyrelu_alpha)(features)

    # Remaining seven convolutions as (filters, stride) pairs; every
    # stride-2 entry halves the spatial resolution
    conv_stages = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    )
    for n_filters, stride in conv_stages:
        features = Conv2D(n_filters, kernel_size=3, strides=stride, padding='same')(features)
        features = LeakyReLU(alpha=leakyrelu_alpha)(features)
        features = BatchNormalization(momentum=momentum)(features)

    # Fully connected layer (applied along the last axis)
    hidden = Dense(units=1024)(features)
    hidden = LeakyReLU(alpha=leakyrelu_alpha)(hidden)

    # Final fully connected layer for real/fake classification
    output = Dense(units=1, activation='sigmoid')(hidden)

    return Model(inputs=[input_layer], outputs=[output], name='discriminator')

# Pre-trained VGG19 

Pre-trained VGG19 will be used for feature extraction from real images and generated images



In [None]:
def build_vgg():
    """Build a VGG19 feature extractor for 256x256 RGB images.

    The returned model maps an image to the activations of VGG19 layer
    index 9 (``block3_conv3`` in the stock VGG19 layout), using ImageNet
    weights.

    Two things differ from simply loading ``VGG19(weights="imagenet")``:

    * ``include_top=False`` with an explicit ``input_shape`` is used, because
      the variant with the classifier head only accepts 224x224 inputs while
      this notebook works with 256x256 images. The convolutional layers and
      their indices are the same with or without the head.
    * The truncated network is created as a new ``Model`` from the VGG input
      to the chosen layer's output. Reassigning ``vgg.outputs`` is not a
      supported way to truncate a model and is not honoured by current Keras
      releases.

    Returns:
        A Keras ``Model`` taking ``(256, 256, 3)`` images and returning the
        selected VGG19 feature maps.
    """
    # Dimension corresponding to the high-resolution image
    input_shape = (256, 256, 3)

    # Convolutional part of VGG19 pre-trained on ImageNet
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)

    # Feature extractor ending at the 9th layer
    model = Model(inputs=vgg.input, outputs=vgg.layers[9].output)
    return model

## Sampling images

Implementing a function for sampling images

In [None]:
def sample_images(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    """Load a random batch of images at two resolutions.

    Args:
        data_dir: Glob pattern matching the candidate image files.
        batch_size: Number of images to draw. ``np.random.choice`` is used
            with its default settings, so files are drawn with replacement.
        high_resolution_shape: Target shape for the high-resolution copies.
        low_resolution_shape: Target shape for the low-resolution copies.

    Returns:
        A tuple ``(high_resolution_images, low_resolution_images)`` of numpy
        arrays. Each image is read as RGB, cast to float32 and resized; with
        probability 0.5 both copies of an image are mirrored left-to-right.
    """
    # Every file matching the pattern is a candidate
    candidate_paths = glob.glob(data_dir)

    # Random selection for this batch
    chosen_paths = np.random.choice(candidate_paths, size=batch_size)

    hr_batch, lr_batch = [], []
    for image_path in chosen_paths:
        # Decode to an RGB array and switch to float before resizing
        pixels = imread(image_path, as_gray=False, pilmode='RGB').astype(np.float32)

        # Both resolutions are derived from the same source image
        hr_image = imresize(pixels, high_resolution_shape)
        lr_image = imresize(pixels, low_resolution_shape)

        # Augmentation: random horizontal flip, applied to both copies
        if np.random.random() < 0.5:
            hr_image, lr_image = np.fliplr(hr_image), np.fliplr(lr_image)

        hr_batch.append(hr_image)
        lr_batch.append(lr_image)

    # Stack the per-image arrays into batch arrays
    return np.array(hr_batch), np.array(lr_batch)

## Saving images

Implementing a function to save images

In [None]:
def save_images(low_resolution_image, original_image, generated_image, path):
    """Save a side-by-side comparison figure of three images.

    The panels are, from left to right: the original high-resolution image,
    the low-resolution input and the generated high-resolution image.

    Args:
        low_resolution_image: Low-resolution image array with values in [-1, 1].
        original_image: Ground-truth high-resolution image array in [-1, 1].
        generated_image: Generator output array in [-1, 1] (tanh range).
        path: Destination passed to ``Figure.savefig``.
    """
    panels = (
        ("ORIGINAL", original_image),
        ("LOW_RESOLUTION", low_resolution_image),
        ("GENERATED", generated_image),
    )

    fig = plt.figure()
    try:
        for position, (title, image) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, len(panels), position)
            # imshow expects float RGB data in [0, 1] and clips anything
            # outside it, so map the [-1, 1] images back to [0, 1] first.
            displayable = np.clip((np.asarray(image) + 1.0) / 2.0, 0.0, 1.0)
            ax.imshow(displayable)
            ax.axis("off")
            ax.set_title(title)

        fig.savefig(path)
    finally:
        # Figures created through pyplot stay registered until closed; this
        # function is called once per sample, so release each one explicitly.
        plt.close(fig)


## VGG19 compilation

Compiling the pre-trained VGG19 network

In [None]:
# Build the VGG19 feature extractor and freeze it before compiling.
vgg = build_vgg()
vgg.trainable = False
# NOTE(review): the rest of the notebook only calls vgg.predict() or embeds vgg
# in the adversarial model, so this compile() looks unnecessary. Confirm before removing.
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
vgg.summary()

## Discriminator compilation

Compiling discriminator network

In [None]:

# Build the discriminator and compile it while it is trainable.
discriminator = build_discriminator()
discriminator.trainable = True
# metrics=['accuracy'] makes train_on_batch return [loss, accuracy] rather than a scalar.
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

## Generator build


Building generator

In [None]:
# The generator is not compiled on its own; it is trained through the adversarial model.
generator = build_generator()

## Adversarial model compilation

Compiling an adversarial model that includes the generator, the discriminator and a pre-trained VGG19 network

In [None]:
def build_adversarial_model(generator, discriminator, vgg):
    """Build and compile the combined model used to train the generator.

    The model feeds a low-resolution image through ``generator`` and returns
    two outputs for the generated image: the discriminator's score and the
    VGG feature maps. It is compiled with binary cross-entropy on the score
    (weight 1e-3) and mean squared error on the features (weight 1).

    Shapes and the optimizer come from the module-level
    ``high_resolution_shape``, ``low_resolution_shape`` and
    ``common_optimizer``.

    Side effect: ``discriminator.trainable`` is set to ``False`` so that only
    the generator's weights are updated through this model. The discriminator
    itself is deliberately NOT recompiled here: Keras fixes the set of
    trainable weights at ``compile()`` time, so recompiling it while frozen
    would also stop ``discriminator.train_on_batch`` from updating it.

    Args:
        generator: Generator model mapping low- to high-resolution images.
        discriminator: Discriminator model, already compiled as trainable.
        vgg: Feature-extraction model applied to the generated images.

    Returns:
        The compiled adversarial ``Model`` with inputs
        ``[low_resolution, high_resolution]`` and outputs ``[probs, features]``.
    """
    # Input layer for high-resolution images. It is part of the model's
    # input signature but is not connected to either output.
    input_high_resolution = Input(shape=high_resolution_shape)

    # Input layer for low-resolution images
    input_low_resolution = Input(shape=low_resolution_shape)

    # Generating high-resolution images from low-resolution images
    generated_high_resolution_images = generator(input_low_resolution)

    # Extracting feature maps from generated images
    features = vgg(generated_high_resolution_images)

    # Freeze the discriminator inside the combined model: while the generator
    # is being trained, the discriminator must not change.
    discriminator.trainable = False

    # Discriminator score for the generated high-resolution image
    probs = discriminator(generated_high_resolution_images)

    # Create and compile the adversarial model
    adversarial_model = Model([input_low_resolution, input_high_resolution], [probs, features])
    adversarial_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=common_optimizer)
    return adversarial_model

In [None]:
# Combined generator + frozen discriminator + VGG model used for generator updates.
adversarial_model = build_adversarial_model(generator, discriminator, vgg)


# Training loop on CelebA dataset

Training on CelebA Dataset



We will train SRGAN in 2 stages, both of which are repeated in every iteration of the training loop:

* In the first stage, we train the discriminator.
* In the second stage we train the adversarial network, inside which we are training the generator, but the discriminator is frozen.


In [None]:
def _sample_normalized_batch(n_images):
    """Sample ``n_images`` HR/LR pairs and rescale them with ``x / 127.5 - 1``."""
    hr_batch, lr_batch = sample_images(data_dir=data_dir, batch_size=n_images,
                                       low_resolution_shape=low_resolution_shape,
                                       high_resolution_shape=high_resolution_shape)
    return hr_batch / 127.5 - 1., lr_batch / 127.5 - 1.


# Loss histories cover the whole run, so they are created once, before the
# loop, instead of being reset at the start of every iteration.
d_history = []
g_history = []

# Real / fake target maps for the discriminator. Their (16, 16, 1) trailing
# shape has to match the discriminator's output; they never change, so they
# are built once.
real_labels = np.ones((batch_size, 16, 16, 1))
fake_labels = np.zeros((batch_size, 16, 16, 1))

for epoch in range(epochs):
    print("Epoch:{}".format(epoch))

    # ---- Stage 1: train the discriminator ----

    # Sampling batch of images
    high_resolution_images, low_resolution_images = _sample_normalized_batch(batch_size)

    # Generating high-resolution images from low-resolution images
    generated_high_resolution_images = generator.predict(low_resolution_images)

    # Train the discriminator on real and on generated images.
    # The discriminator is compiled with metrics=['accuracy'], so
    # train_on_batch returns [loss, accuracy]; only element 0 is the loss.
    d_loss_real = discriminator.train_on_batch(high_resolution_images, real_labels)[0]
    d_loss_fake = discriminator.train_on_batch(generated_high_resolution_images, fake_labels)[0]

    # Total discriminator loss: arithmetic mean of the real and fake losses
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_history.append(d_loss)
    print("D_loss:", d_loss)

    # ---- Stage 2: train the generator through the adversarial model ----

    # Sampling batch of images
    high_resolution_images, low_resolution_images = _sample_normalized_batch(batch_size)

    # Extracting feature maps for true high-resolution images
    image_features = vgg.predict(high_resolution_images)

    # The generator is asked to make the discriminator output "real" and to
    # match the VGG features of the true high-resolution images.
    g_loss = adversarial_model.train_on_batch([low_resolution_images, high_resolution_images],
                                              [real_labels, image_features])

    # With two outputs, train_on_batch returns
    # [total weighted loss, adversarial loss, feature loss]; the total is the
    # quantity actually being minimised, so that is what gets recorded.
    g_total_loss = g_loss[0]
    g_history.append(g_total_loss)
    print("G_loss:", g_total_loss)

    # Save image samples every 20 iterations
    if epoch % 20 == 0:
        high_resolution_images, low_resolution_images = _sample_normalized_batch(batch_size)

        generated_images = generator.predict_on_batch(low_resolution_images)

        for index, img in enumerate(generated_images):
            save_images(low_resolution_images[index], high_resolution_images[index], img,
                        path="/kaggle/working/img_{}_{}".format(epoch, index))


## Save models weights

Saving model weights

In [None]:
# Persist the weights of both networks; each model goes to its own HDF5 file.
for trained_model, weights_path in (
        (generator, "/kaggle/working/generator.h5"),
        (discriminator, "/kaggle/working/discriminator.h5"),
):
    trained_model.save_weights(weights_path)

# Evaluation mode



In [None]:
# Restore the saved weights. In a fresh session the networks have to be
# rebuilt first:
#discriminator = build_discriminator()
#generator = build_generator()
generator.load_weights("/kaggle/working/generator.h5")
discriminator.load_weights("/kaggle/working/discriminator.h5")

# Draw 10 random image pairs for evaluation
high_resolution_images, low_resolution_images = sample_images(
    data_dir=data_dir,
    batch_size=10,
    high_resolution_shape=high_resolution_shape,
    low_resolution_shape=low_resolution_shape,
)

# Rescale pixel values with x / 127.5 - 1 before feeding the generator
high_resolution_images = high_resolution_images / 127.5 - 1.
low_resolution_images = low_resolution_images / 127.5 - 1.

# Super-resolve the low-resolution batch
generated_images = generator.predict_on_batch(low_resolution_images)

## Save images



In [None]:
# Write one comparison figure per evaluated sample
for index, img in enumerate(generated_images):
    lr_sample = low_resolution_images[index]
    hr_sample = high_resolution_images[index]
    target_path = "/kaggle/working/gen_{}".format(index)
    save_images(lr_sample, hr_sample, img, path=target_path)