In [None]:
import os, shutil

# Placeholder locations -- replace each one with a real directory before running.
# File names are appended to these strings by plain concatenation further down,
# so each path should end with a path separator (e.g. "data/").
# NOTE: setup_training() deletes and recreates model_dir and output_dir.
data_dir = "<Path to the dataset>"
model_dir = "<Path to the folder to store the model>"
output_dir = "<Path for the generated images>"

### Initialize Hyperparameters

In [None]:
# ---- Image geometry ----
down_scale = 4                                     # factor between high-res and low-res size
cropped_height = 64                                # high-res image height
cropped_width = 204                                # high-res image width
image_shape = (cropped_height, cropped_width, 3)   # high-res image shape (height, width, channels)

# ---- Training schedule ----
num_images = 2364       # total number of images (train & test)
split_ratio = 0.8       # fraction of the images used for training
epochs = 1000
batch_size = 32

# ---- Adam optimizer settings ----
learning_rate = 1e-4
epsilon = 1e-8

# ---- Reporting cadence (in epochs) ----
sample_every = 1        # write a sample comparison image every N epochs
save_every = 5          # save the generator/discriminator every N epochs

### Preprocess Images

In [None]:
import numpy as np
from numpy import random
from PIL import Image
import matplotlib.pyplot as plt
import glob

%matplotlib inline

def down_sample(img, scale=down_scale):
    """Return a bicubically down-scaled copy of an image.

    Args:
        img: image array whose first two axes are (height, width); it is cast
            to uint8 before resizing.
        scale: integer factor by which both height and width are divided
            (floor division).

    Returns:
        The resized image as a NumPy array.
    """
    height, width = img.shape[0], img.shape[1]
    # PIL's resize() takes the target size as (width, height).
    target_size = (width // scale, height // scale)
    pil_img = Image.fromarray(np.uint8(img))
    return np.asarray(pil_img.resize(target_size, Image.BICUBIC))


def normalize(img):
    """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to +1.

    Args:
        img: NumPy array of pixel values.

    Returns:
        An array of the same shape holding ``img / 127.5 - 1`` computed in float32.
    """
    as_float = img.astype(np.float32)
    return as_float / 127.5 - 1.0

In [None]:
import cv2
# Module-level accumulators filled by get_data().
hr_images = []   # one high-resolution image array per PNG file
lr_images = []   # the matching down-sampled image arrays


def get_data(data_dir = data_dir):
    """Load every PNG in ``data_dir`` into the global image lists.

    Each readable image is appended to ``hr_images`` and its down-sampled
    version (see ``down_sample``) to ``lr_images``. Nothing is returned.

    Args:
        data_dir: directory containing the ``*.png`` training images. A
            trailing path separator is optional.
    """
    # Sorting gives a stable load order, so the later positional train/test
    # split selects the same images on every run.
    pattern = os.path.join(data_dir, "*.png")
    for img_name in sorted(glob.glob(pattern)):
        bgr_img = cv2.imread(img_name, cv2.IMREAD_COLOR)
        if bgr_img is None:
            # cv2.imread reports an unreadable/corrupt file by returning None.
            print("Skipping unreadable image: {}".format(img_name))
            continue

        # OpenCV decodes colour images in BGR channel order, while PIL and
        # matplotlib (used for resizing and for displaying samples) assume RGB.
        hr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        hr_images.append(hr_img)
        lr_images.append(down_sample(hr_img))
        
        
# Fill the global lists, then convert them to NumPy arrays for slicing later on.
get_data()
hr_images, lr_images = np.array(hr_images), np.array(lr_images)
print("High resolution image dataset shape: ", hr_images.shape)
print("Low resolution image dataset shape: ", lr_images.shape)

### Utility Function for Training

In [None]:
def load_train_data(data_dir = data_dir, num_img=num_images,
                   split_ration=split_ratio,
                   hr_images=hr_images, lr_images=lr_images):
    """Normalize the loaded images and split them into train and test sets.

    The split is positional: the first ``int(n * split_ration)`` images form
    the training set and every remaining image forms the test set, where
    ``n`` is ``num_img`` capped at the number of images actually supplied.

    Args:
        data_dir: unused; kept so existing callers keep working.
        num_img: nominal number of images the split ratio is applied to.
        split_ration: fraction of the images assigned to the training set.
        hr_images: high-resolution images. The default is the module-level
            value as it existed when this function was defined.
        lr_images: low-resolution images, in the same order as ``hr_images``.

    Returns:
        Tuple ``(hr_train, hr_test, lr_train, lr_test)`` of NumPy arrays.

    Raises:
        ValueError: if no images were supplied.
    """
    hr_all = np.array([normalize(img) for img in hr_images])
    lr_all = np.array([normalize(img) for img in lr_images])

    if len(hr_all) == 0:
        raise ValueError("No images to split: the image lists are empty.")

    # Cap the count at what was really loaded; otherwise a too-large num_img
    # pushes the boundary past the data and leaves the test set empty.
    num_available = min(num_img, len(hr_all))
    num_train = int(num_available * split_ration)

    hr_train, hr_test = hr_all[:num_train], hr_all[num_train:]
    lr_train, lr_test = lr_all[:num_train], lr_all[num_train:]

    return hr_train, hr_test, lr_train, lr_test

# Build the train/test splits once and report the shape of each of them.
hr_train, hr_test, lr_train, lr_test = load_train_data()
print("HR images training dataset shape: ", hr_train.shape, "\t")
print("LR images training dataset shape: ", lr_train.shape, "\t")
print("HR images test dataset shape: ", hr_test.shape, "\t")
print("LR images test dataset shape: ", lr_test.shape, "\t")

In [None]:
from keras.layers.core import Activation, Flatten
from keras.layers import Input, add, LeakyReLU, PReLU
from keras.layers import BatchNormalization, Conv2D, UpSampling2D, Dense
from keras.models import Model, load_model
from keras.optimizers import Adam

In [None]:
def residual_block(model, kernel, filters, strides):
    """Append one residual block (inspired by SRResNet) to ``model``.

       In -> Conv -> BN -> PReLU -> Conv -> BN -> add -> Out
       |___________________________________________^

    Args:
        model: input tensor; it is also the skip connection added at the end.
        kernel: kernel size used by both convolutions.
        filters: number of filters used by both convolutions.
        strides: strides used by both convolutions.

    Returns:
        The tensor produced by adding the block input to the block output.
    """
    shortcut = model
    conv_args = dict(filters=filters, kernel_size=kernel, strides=strides, padding="same")

    x = Conv2D(**conv_args)(model)
    x = BatchNormalization(momentum=0.5)(x)
    x = PReLU(alpha_initializer="zeros", alpha_regularizer=None,
              alpha_constraint=None, shared_axes=[1,2])(x)
    x = Conv2D(**conv_args)(x)
    x = BatchNormalization(momentum=0.5)(x)

    return add([shortcut, x])

In [None]:
def up_sample_block(model, kernel, filters, strides):
    """Append an up-sampling block (could be replaced by a Conv2DTranspose layer).

       In -> Conv -> UpSample -> LReLU -> Out

    Args:
        model: input tensor.
        kernel: kernel size of the convolution.
        filters: number of convolution filters.
        strides: convolution strides.

    Returns:
        The tensor after convolution, ``UpSampling2D`` (default settings) and
        a LeakyReLU built with the argument 0.2.
    """
    convolved = Conv2D(filters=filters, kernel_size=kernel, strides=strides, padding="same")(model)
    enlarged = UpSampling2D()(convolved)
    return LeakyReLU(0.2)(enlarged)

In [None]:
class Generator_NN():
    """Builder for the generator network.

        Input -> Conv -> PReLU -> Res x 16 -> Conv -> BN -> add -> UpSample x 2 -> Conv -> Tanh -> Output
                           |_________________________________^
    """

    def __init__(self, noise_shape):
        """Remember the shape of the generator input (the low-resolution image)."""
        self.noise_shape = noise_shape

    def generator(self):
        """Assemble the generator and return it as an uncompiled Keras ``Model``."""
        lr_input = Input(shape=self.noise_shape)

        # Entry convolution; its activation also feeds the long skip connection.
        x = Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(lr_input)
        x = PReLU(alpha_initializer="zeros", alpha_regularizer=None,
                  alpha_constraint=None, shared_axes=[1,2])(x)
        trunk_input = x

        # Residual trunk.
        for _ in range(16):
            x = residual_block(x, 3, 64, 1)

        x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(x)
        x = BatchNormalization(momentum=0.5)(x)
        x = add([trunk_input, x])

        # Two up-sampling stages.
        for _ in range(2):
            x = up_sample_block(x, 3, 256, 1)

        # Project to 3 channels and squash with tanh.
        x = Conv2D(filters=3, kernel_size=9, strides=1, padding="same")(x)
        sr_output = Activation("tanh")(x)

        return Model(inputs=lr_input, outputs=sr_output)

In [None]:
def discriminator_block(model, filters, kernel, strides):
    """Append one discriminator block to ``model``.

        In -> Conv -> BN -> LReLU -> Out

    Args:
        model: input tensor.
        filters: number of convolution filters.
        kernel: kernel size of the convolution.
        strides: convolution strides.

    Returns:
        The tensor after convolution, batch normalization and LeakyReLU.
    """
    convolved = Conv2D(filters=filters, kernel_size=kernel, strides=strides, padding="same")(model)
    normalized = BatchNormalization(momentum=0.5)(convolved)
    return LeakyReLU(alpha=0.2)(normalized)

In [None]:
class Discriminator_NN():
    """Builder for the discriminator network.

        Input -> Conv -> LReLU -> Dis x 7 -> Flatten -> Dense -> LReLU -> Dense -> Sigmoid -> Output
    """

    def __init__(self, image_shape):
        """Remember the shape of the images the discriminator will judge."""
        self.image_shape = image_shape

    def discriminator(self):
        """Assemble the discriminator and return it as an uncompiled Keras ``Model``."""
        d_input = Input(shape=self.image_shape)

        x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(d_input)
        x = LeakyReLU(alpha=0.2)(x)

        # (filters, strides) of the seven discriminator blocks, in order.
        block_specs = [
            (64, 2),
            (128, 1), (128, 2),
            (256, 1), (256, 2),
            (512, 1), (512, 2),
        ]
        for n_filters, n_strides in block_specs:
            x = discriminator_block(x, n_filters, 3, n_strides)

        # Dense head ending in a single sigmoid unit.
        x = Flatten()(x)
        x = Dense(1024)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Dense(1)(x)
        verdict = Activation("sigmoid")(x)

        return Model(inputs=d_input, outputs=verdict)

In [None]:
from keras.applications.vgg19 import VGG19
import keras.backend as K

class Content_loss():
    """Content loss computed on VGG19 feature maps.

    The loss is the mean squared difference between the ``block5_conv4``
    outputs of an ImageNet-weighted VGG19 for the target and predicted images.
    """

    def __init__(self, image_shape):
        """Store the image shape; the VGG19 extractor is built on first use.

        Args:
            image_shape: shape of the images passed to the loss; used as the
                VGG19 ``input_shape``.
        """
        self.image_shape = image_shape
        self._feature_model = None

    def _get_feature_model(self):
        """Return the frozen VGG19 feature extractor, building it only once."""
        feature_model = getattr(self, "_feature_model", None)
        if feature_model is None:
            vgg19_model = VGG19(include_top=False, weights="imagenet",
                              input_shape=self.image_shape)
            # Freeze everything: the extractor only measures, it is never trained.
            vgg19_model.trainable = False
            for layer in vgg19_model.layers:
                layer.trainable = False

            feature_model = Model(inputs=vgg19_model.input,
                                  outputs=vgg19_model.get_layer("block5_conv4").output)
            feature_model.trainable = False
            self._feature_model = feature_model
        return feature_model

    def content_loss(self, y, y_pred):
        """Return the mean squared error between the VGG19 features of ``y`` and ``y_pred``.

        Args:
            y: target image tensor.
            y_pred: predicted image tensor.
        """
        # Reuse one extractor for every call so the VGG19 weights are loaded
        # once instead of once per compiled model.
        model = self._get_feature_model()
        return K.mean(K.square(model(y) - model(y_pred)))

In [None]:
def GAN_NN(g, d, shape, optimizer, content_loss):
    """Chain generator and discriminator into one compiled model.

    Args:
        g: generator model.
        d: discriminator model; its ``trainable`` flag is set to False here.
        shape: shape of the generator input.
        optimizer: optimizer used to compile the combined model.
        content_loss: loss applied to the generated image output.

    Returns:
        A compiled model with two outputs, the generated image and the
        discriminator's verdict on it, trained with ``content_loss`` (weight 1)
        and binary cross-entropy (weight 1e-3) respectively.
    """
    d.trainable = False

    lr_input = Input(shape=shape)
    generated = g(lr_input)
    verdict = d(generated)

    combined = Model(inputs=lr_input, outputs=[generated, verdict])
    combined.compile(
        loss=[content_loss, "binary_crossentropy"],
        loss_weights=[1., 1e-3],
        optimizer=optimizer,
    )
    return combined

### Sample Generator Outputs

In [None]:
def generate_image(e, g, hr_test, lr_test, output_dir=output_dir,
                  dim=(1, 3), figsize=(15, 5)):
    """Save a low-res / super-resolved / high-res comparison for one random test image.

    Args:
        e: epoch number, used in the output file name ``result_{e}.png``.
        g: generator model; its ``predict`` is called on the chosen sample.
        hr_test: normalized high-resolution test images.
        lr_test: normalized low-resolution test images, in the same order.
        output_dir: directory the figure is written to.
        dim: (rows, columns) of the subplot grid.
        figsize: matplotlib figure size.

    Raises:
        ValueError: if the test set is empty.
    """
    def to_uint8(batch):
        # Undo the [-1, 1] scaling. Round before casting because a plain
        # astype(np.uint8) truncates (254.99999 would become 254); clip so
        # values slightly out of range cannot wrap around.
        scaled = (np.asarray(batch) + 1) * 127.5
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    if len(hr_test) == 0:
        raise ValueError("Cannot sample an image: the test set is empty.")

    # random sample (the upper bound of randint is exclusive)
    idx = random.randint(0, len(hr_test))

    # Run the generator on the chosen image only, kept as a batch of one.
    lr_sample = np.asarray(lr_test[idx:idx + 1])
    sr_sample = np.asarray(g.predict(lr_sample))

    # Panels from left to right: low-res input, generator output, ground truth.
    panels = [
        to_uint8(lr_sample)[0],
        to_uint8(sr_sample)[0],
        to_uint8(hr_test[idx]),
    ]

    plt.figure(figsize=figsize)
    for position, panel in enumerate(panels, start=1):
        plt.subplot(dim[0], dim[1], position)
        plt.imshow(panel, interpolation="nearest")
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "result_{}.png".format(e)))
    plt.close()

In [None]:
from tqdm import tqdm

def setup_training(data_dir = data_dir):
    """Reset the model and output directories and create an empty loss log.

    WARNING: any existing ``model_dir`` and ``output_dir`` are deleted with
    all their contents before being recreated.

    Args:
        data_dir: unused; kept so existing callers keep working.
    """
    directories = (model_dir, output_dir)

    # Remove everything first, then recreate, so the result does not depend on
    # how the two directories are nested relative to each other.
    for directory in directories:
        if os.path.isdir(directory):
            shutil.rmtree(directory)
    for directory in directories:
        # makedirs also creates missing parent directories.
        os.makedirs(directory, exist_ok=True)

    # Start the run with an empty loss log.
    with open(os.path.join(model_dir, "loss.txt"), "w"):
        pass


setup_training()

### Training

In [None]:
def SRGAN(epochs=epochs, batch_size=batch_size, split_ratio=split_ratio,
         sample_every = sample_every, save_every=save_every,
         shape=image_shape, scale=down_scale, num_imgs=num_images,
         lr=learning_rate, epsilon=epsilon):
    """Train the generator and discriminator adversarially.

    Each batch iteration first updates the discriminator on real and generated
    images, then updates the generator through the combined GAN model. After
    every epoch the losses of the last batch are printed and appended to
    ``loss.txt`` in ``model_dir``.

    Args:
        epochs: number of epochs to run.
        batch_size: number of images drawn (with replacement) per update.
        split_ratio: fraction of the images used for training.
        sample_every: write a comparison image every this many epochs
            (epoch 1 is always sampled).
        save_every: save both models to ``model_dir`` every this many epochs.
        shape: shape of the high-resolution images.
        scale: factor between high- and low-resolution height/width.
        num_imgs: nominal total number of images, forwarded to the data split.
        lr: Adam learning rate.
        epsilon: Adam epsilon.

    Raises:
        ValueError: if the training set holds fewer than ``batch_size`` images.
    """
    hr_train, hr_test, lr_train, lr_test = load_train_data(
        num_img=num_imgs, split_ration=split_ratio)

    num_batches = len(hr_train) // batch_size
    if num_batches == 0:
        raise ValueError(
            "Training set has {} images, fewer than batch_size={}.".format(
                len(hr_train), batch_size))
    shape_small = (shape[0]//scale, shape[1]//scale, shape[2])

    g_loss = Content_loss(shape)

    def make_optimizer():
        # Every compiled model gets its own Adam instance so optimizer state
        # is not shared between the generator, discriminator and GAN.
        return Adam(lr=lr, epsilon=epsilon)

    g = Generator_NN(shape_small).generator()
    d = Discriminator_NN(shape).discriminator()

    g.compile(loss=g_loss.content_loss, optimizer=make_optimizer())
    d.compile(loss="binary_crossentropy", optimizer=make_optimizer())

    gan = GAN_NN(g, d, shape_small, make_optimizer(), g_loss.content_loss)

    loss_path = os.path.join(model_dir, "loss.txt")

    for e in range(1, epochs+1):
        for _ in tqdm(range(num_batches)):
            # ---- Discriminator update ----
            idxs = random.randint(0, len(hr_train), size=batch_size)
            hr_batch = hr_train[idxs]
            lr_batch = lr_train[idxs]
            sr_batch = g.predict(lr_batch)

            # Noisy soft labels: real ~ N(0.9, 0.05), fake ~ N(0.1, 0.05).
            real_label = 0.05*random.randn(batch_size) + 0.9
            fake_label = 0.05*random.randn(batch_size) + 0.1

            # Flag the discriminator as trainable only around its own update.
            d.trainable = True
            d_loss_real = d.train_on_batch(hr_batch, real_label)
            d_loss_fake = d.train_on_batch(sr_batch, fake_label)
            d_loss = np.add(d_loss_real, d_loss_fake) / 2.0
            d.trainable = False

            # ---- Generator update through the combined model ----
            idxs = random.randint(0, len(hr_train), size=batch_size)
            hr_batch = hr_train[idxs]
            lr_batch = lr_train[idxs]

            # The generator is pushed towards "real" verdicts: N(0.9, 0.05).
            gan_label = 0.05*random.randn(batch_size) + 0.9

            gan_loss = gan.train_on_batch(lr_batch, [hr_batch, gan_label])

        # Losses of the last batch of the epoch.
        summary = "EPOCH {}\td_loss {}\tgan_loss {}".format(e, d_loss, gan_loss)
        print(summary)
        with open(loss_path, "a") as loss_file:
            loss_file.write(summary + "\n")

        if (e==1) or (e%sample_every==0):
            generate_image(e, g, hr_test, lr_test)
        if (e%save_every ==0):
            g.save(os.path.join(model_dir, "g_model{}.h5".format(e)))
            d.save(os.path.join(model_dir, "d_model{}.h5".format(e)))

In [None]:
# Run the full training loop with the default arguments of SRGAN().
SRGAN()

In [None]:
def visualize_result(output_dir = output_dir, every=25):
    """Show the saved sample images of every ``every``-th epoch side by side.

    Only files named ``result_<epoch>.png`` are considered; they are shown in
    ascending epoch order in a single row of subplots.

    Args:
        output_dir: directory containing the sample images.
        every: show the samples whose epoch number is a multiple of this value.
    """
    prefix = "result_"

    selected = []
    for file_name in os.listdir(output_dir):
        stem, ext = os.path.splitext(file_name)
        epoch_text = stem[len(prefix):]
        if ext.lower() != ".png" or not stem.startswith(prefix) or not epoch_text.isdecimal():
            continue  # not a sample image
        epoch = int(epoch_text)
        if epoch % every == 0:
            selected.append((epoch, file_name))

    if not selected:
        print("No sample images to show in {}".format(output_dir))
        return

    selected.sort()

    # squeeze=False keeps axes two-dimensional even when only one image is shown.
    fig, axes = plt.subplots(1, len(selected), squeeze=False)
    for position, (epoch, file_name) in enumerate(selected):
        # os.listdir returns bare names, so the directory has to be re-attached.
        image = np.asarray(Image.open(os.path.join(output_dir, file_name)))
        ax = axes[0][position]
        ax.set_title("sample "+str(position+1))
        ax.imshow(image)
    plt.show()


visualize_result()