In [133]:
"""
Image Super-Resolution using Convolutional
Autoencoders
Author: Amruth Karun M V
Date: 06-Nov-2021
"""

import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras import Model, Input, regularizers
from tensorflow.keras.layers import (
    Dense, Conv2D, MaxPool2D, 
    UpSampling2D, Add)
from tensorflow.keras.callbacks import EarlyStopping
from keras.preprocessing import image
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings;
warnings.filterwarnings('ignore')


# Folder holding the CIFAR-10 sample PNGs. Keep the trailing '/' because
# glob patterns are built by plain string concatenation (path + '*.png').
INPUT_PATH = "../input/cifar10/cifar10_sample/"

def show_sample_image():
    """
    Displays one randomly chosen sample image from INPUT_PATH.
    Prints the image's index in the (sorted) file listing and draws it
    on the current matplotlib figure.
    Arguments: None
    Returns: None (draws the image with plt.imshow)
    Raises:
        FileNotFoundError -- if INPUT_PATH contains no '*.png' files or
                             the chosen file cannot be decoded
    """

    cifar_sample = sorted(glob.glob(INPUT_PATH + '*.png'))
    if not cifar_sample:
        raise FileNotFoundError("No .png images found in " + INPUT_PATH)
    # Choose among however many images actually exist; the folder is not
    # guaranteed to hold exactly 20.
    random_index = random.randrange(len(cifar_sample))
    print("Image: ", random_index)
    img_path = cifar_sample[random_index]
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a failed read by returning None instead of raising.
        raise FileNotFoundError("Could not read image " + img_path)
    # OpenCV decodes channels as BGR, but matplotlib interprets them as RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        
def load_images(input_path=None, test_size=0.2, random_state=32):
    """
    Loads the CIFAR-10 sample images and splits them into train / val sets.
    The sample folder is presumably 20 images (2 per class); this is not
    checked here, and every '*.png' in the folder is used.
    Arguments:
        input_path   -- folder holding the PNGs, with a trailing '/'
                        (default: INPUT_PATH, looked up at call time)
        test_size    -- fraction of images placed in the validation set
        random_state -- seed passed to train_test_split
    Returns: (train_x, val_x) arrays of shape (N, 32, 32, channels) with
             pixel values divided by 255
    Raises:
        FileNotFoundError -- if no '*.png' files are found
    """

    if input_path is None:
        input_path = INPUT_PATH
    # glob() order depends on the filesystem. Sort it so that random_state
    # yields the same split on every machine.
    cifar_sample = sorted(glob.glob(input_path + '*.png'))
    print("Total images = ", len(cifar_sample))
    if not cifar_sample:
        raise FileNotFoundError("No .png images found in " + input_path)
    all_images = []
    for path in tqdm(cifar_sample):
        # load_img takes target_size as (height, width) and resizes to it.
        img = image.img_to_array(image.load_img(path, target_size=(32, 32)))
        all_images.append(img / 255.)
    all_images = np.array(all_images)
    train_x, val_x = train_test_split(all_images, random_state=random_state,
                                      test_size=test_size)
    return train_x, val_x

def pixalate_image(image, scale_percent = 50):
    """
    Lowers the effective resolution of an image while keeping its size.
    The image is shrunk to scale_percent of its width/height and then
    resized back to exactly its original width/height, both times with
    cv2.INTER_AREA.
    Arguments:
        image         -- input image array, indexed (height, width[, channels])
        scale_percent -- size of the intermediate image, in percent of the
                         original (must be > 0)
    Returns: Pixelated image with the same height and width as `image`
    Raises:
        ValueError -- if scale_percent <= 0
    """

    if scale_percent <= 0:
        raise ValueError(
            "scale_percent must be positive, got {}".format(scale_percent))
    height, width = image.shape[:2]
    # Keep at least one pixel per side; cv2.resize rejects a zero dsize.
    small_dim = (max(1, int(width * scale_percent / 100)),
                 max(1, int(height * scale_percent / 100)))
    small_image = cv2.resize(image, small_dim, interpolation = cv2.INTER_AREA)
    # Resize back to the exact original (width, height). Recomputing the size
    # from the shrunk image loses a pixel for odd dimensions (e.g. 33 -> 16 -> 32).
    low_res_image = cv2.resize(small_image, (width, height),
                               interpolation = cv2.INTER_AREA)
    return low_res_image


def _pixelate_batch(images):
    """Apply pixalate_image to each image along the first axis of `images`."""
    pixelated = [pixalate_image(img) for img in images]
    if not pixelated:
        # Keep the (0, H, W, C) shape/dtype instead of collapsing to shape (0,).
        return np.empty_like(images)
    return np.array(pixelated)


def get_low_res_image(train_x, val_x):
    """
    Get low resolution images for the train and validation sets.
    Each image is passed through pixalate_image with its default
    scale_percent.
    Arguments:
        train_x  -- train set, a batch of images (first axis = image index)
        val_x    -- validation set, same layout as train_x
    Returns: (train_x_px, val_x_px) pixelated copies of both sets
    """

    train_x_px = _pixelate_batch(train_x)
    val_x_px = _pixelate_batch(val_x)
    return train_x_px, val_x_px

def _conv(filters, activation='relu'):
    """Create a 3x3 'same'-padded Conv2D layer with the network's L1 kernel penalty."""
    # NOTE(review): 10e-10 == 1e-9; possibly 1e-10 was intended. Kept as-is.
    return Conv2D(filters, (3, 3), activation=activation, padding='same',
                  kernel_regularizer=regularizers.l1(10e-10))


def _build_autoencoder():
    """Build and compile the 32x32x3 conv autoencoder with additive skip connections."""
    input_img = Input(shape=(32, 32, 3))
    # Encoder: two conv pairs, each followed by 2x max-pooling.
    x1 = _conv(64)(input_img)
    x2 = _conv(64)(x1)
    x3 = MaxPool2D(padding='same')(x2)
    x4 = _conv(128)(x3)
    x5 = _conv(128)(x4)
    x6 = MaxPool2D(padding='same')(x5)
    encoded = _conv(256)(x6)

    # Decoder: mirror of the encoder. Add() merges the encoder features of
    # the same spatial size back in.
    x7 = UpSampling2D()(encoded)
    x8 = _conv(128)(x7)
    x9 = _conv(128)(x8)
    x10 = Add()([x5, x9])
    x11 = UpSampling2D()(x10)
    x12 = _conv(64)(x11)
    x13 = _conv(64)(x12)
    x14 = Add()([x2, x13])
    # NOTE(review): ReLU output is unbounded above while targets are /255
    # scaled; a sigmoid may fit better. Kept to preserve the model.
    decoded = _conv(3)(x14)

    autoencoder = Model(input_img, decoded)
    # NOTE(review): 'accuracy' is not meaningful for an MSE pixel-regression
    # target. It is kept because get_results() reports results[1] as accuracy.
    autoencoder.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
    return autoencoder


def train_model(train_x_px, val_x_px, train_x=None, val_x=None,
                epochs=256, batch_size=64, weights_path="autoencoder.h5"):
    """
    Builds and trains the Autoencoder network to map low-resolution images
    back to their originals, then saves its weights.
    Arguments:
        train_x_px   -- low resolution train set (model input)
        val_x_px     -- low resolution validation set (model input)
        train_x      -- original train set (training target). If omitted,
                        the module-level `train_x` is used (legacy behaviour).
        val_x        -- original validation set (validation target). If
                        omitted, the module-level `val_x` is used.
        epochs       -- maximum number of epochs (early stopping may end sooner)
        batch_size   -- training batch size
        weights_path -- file the trained weights are saved to
    Returns: Autoencoder Network, holding the weights from the epoch with
             the lowest val_loss
    Raises:
        ValueError -- if the targets are neither passed nor defined globally
    """

    if train_x is None or val_x is None:
        # Backward compatibility: this function used to read the notebook
        # globals `train_x` / `val_x` implicitly. Do that lookup explicitly so
        # a missing target fails clearly instead of with a NameError in fit().
        module_globals = globals()
        if train_x is None:
            train_x = module_globals.get("train_x")
        if val_x is None:
            val_x = module_globals.get("val_x")
        if train_x is None or val_x is None:
            raise ValueError("train_x and val_x (high-resolution targets) must "
                             "be passed to train_model or defined globally")

    autoencoder = _build_autoencoder()
    autoencoder.summary()

    # restore_best_weights: return the lowest-val_loss model rather than the
    # one from `patience` epochs after it.
    early_stopper = EarlyStopping(monitor='val_loss', min_delta=0.0001,
                                  patience=8, verbose=1, mode='auto',
                                  restore_best_weights=True)
    autoencoder.fit(train_x_px, train_x,
                    epochs=epochs,
                    batch_size=batch_size,
                    shuffle=True,
                    validation_data=(val_x_px, val_x),
                    callbacks=[early_stopper])

    # NOTE(review): newer Keras releases require weight files to end in
    # '.weights.h5'. The default keeps the original filename; confirm the version.
    autoencoder.save_weights(weights_path)
    return autoencoder

def get_results(autoencoder, val_x, val_x_px, n=4):
    """
    Evaluate the autoencoder model on the validation data and plot
    example inputs, predictions and ground-truth images.
    Arguments:
        autoencoder    -- trained autoencoder model
        val_x          -- original validation data
        val_x_px       -- low resolution validation data
        n              -- maximum number of examples (columns) to plot
    Returns: None. Prints the evaluation results and shows a figure with
             rows: pixelated input, prediction, original.
    """

    results = autoencoder.evaluate(val_x_px, val_x)
    print('val_loss = {}, val_accuracy = {}'.format(results[0], results[1]))

    # The output layer is a ReLU (unbounded above). Clip to [0, 1] so imshow
    # shows exactly what it would otherwise clip silently.
    predictions = np.clip(autoencoder.predict(val_x_px), 0.0, 1.0)

    # Never index past the end of a small validation set.
    n = min(n, len(val_x_px))
    rows = [('Pixelated Image', val_x_px),
            ('Predicted Image', predictions),
            ('Original Image', val_x)]

    plt.figure(figsize= (20,10))
    for row, (title, images) in enumerate(rows):
        for i in range(n):
            ax = plt.subplot(len(rows), n, row * n + i + 1)
            plt.imshow(images[i])
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.title.set_text(title)

    plt.show()

In [134]:
# Load the sample images and build pixelated (low-resolution) copies.
# To preview one random sample first, call show_sample_image() here.
train_x, val_x = load_images()
train_x_px, val_x_px = get_low_res_image(train_x, val_x)

# Train the model. NOTE: train_model() falls back to the global
# `train_x` / `val_x` names above for its high-resolution targets.
autoencoder = train_model(train_x_px=train_x_px, val_x_px=val_x_px)

# Evaluate on the validation split and plot a few examples.
get_results(autoencoder=autoencoder, val_x=val_x, val_x_px=val_x_px)