In [5]:
import os
import sys
import tensorflow as tf
import numpy as np
import math
import skimage
import cv2

from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras import Model
from skimage.measure import compare_ssim as ssim

%matplotlib inline

In [6]:
# helper functions

# function to calculate the peak signal to noise ratio of low resolution and high resolution
def psnr(l_res, h_res):
    
    # convert the image data to floats
    l_resData = l_res.astype(float)
    h_resData = h_res.astype(float)
    
    # calculate the difference
    diff = h_resData - l_resData
    diff = diff.flatten('C')
    
    # calculate the root mean square difference
    rmsd = math.sqrt(np.mean(diff ** 2.))
    
    # calculate the psnr
    psnr = 20 * math.log10(255. / rmsd)
    
    return psnr

# function for mean squared error
def mse(l_res, h_res):
    
    # sum of squared differences of two images
    error = np.sum((l_res.astype(float) - h_res.astype(float)) ** 2)
    
    # divide by total number of pixels
    error /= float(l_res.shape[0] * h_res.shape[1])
    return error

# compare the qulity of low-res and high-res images
def compare_images(l_res, h_res):
    
    results = []
    results.append(psnr(l_res, h_res))
    results.append(mse(l_res, h_res))
    results.append(ssim(l_res, h_res, multichannel=True))
    
    return results


# degrade images

def degrade_images(path, value):
    
    # for all the files in the given path
    for file in os.listdir(path):
        
        # read the file using cv2
        img = cv2.imread(path + '/' + file)
        
        # find the old and new image dimensions
        h, w, c = img.shape
        new_h = int(h / value)
        new_w = int(h / value)
        
        # downsize the image
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        
        # upsize the image
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
        
        # save the image
        print('Saving {}'.format(file))
        cv2.imwrite('Test_degrded/{}'.format(file), img) 
        
        
# image pre processing
def size_mod(img, factor):
    temp_size = img.shape
    size = temp_size[0:2]
    size = size - np.mod(size, factor)
    return img[0:size[0], 1:size[1]]


def crop(img, edge):
    return img[edge:-edge, edge:-edge]        
        
    
    
def test(test_path, model):
    
    #model = srcnn_model()
    #model.load_weights('srcnn_weights.h5')
    
    # load high res and and low res images
    path, file = os.path.split(test_path)
    lr = cv2.imread(test_path)
    hr = cv2.imread('org_images/{}'.format(file))
    
    # take the mode of the images
    lr = size_mod(lr, 3)
    hr = size_mod(hr, 3)
    
    # convert the images to YCrCb color space
    ycrcb = cv2.cvtColor(lr, cv2.COLOR_BGR2YCrCb)
    print(ycrcb.shape)
    # extract the Y (luminance) channel from YCrCb space
    Y = np.zeros((1, ycrcb.shape[0], ycrcb.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = ycrcb[:, :, 0].astype(float) / 255
    
    # make a prediction using trained model
    prediction = model.predict(Y, batch_size=1)
    
    # post procces the images
    prediction *= 255
    prediction[prediction > 255] = 255
    prediction[prediction < 0] = 0
    prediction = prediction.astype(np.uint8)
    print(prediction.shape)
    # reconstruct the image in BGR space
    # note the predicted image lost the 4 pixels on each side therefore we need the crop
    # the image with a factor of 6
    ycrcb = crop(ycrcb, 4)
    ycrcb[:, :, 0] = prediction[0, :, :, 0] 
    recon_image = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
    
    # remove the border of the lr and hr image for comparison
    lr = crop(lr.astype(np.uint8), 4)
    hr = crop(hr.astype(np.uint8), 4)
    
    # image comparison
    metrics = []
    metrics.append(compare_images(lr, hr))
    metrics.append(compare_images(recon_image, hr))
    
    # return hr, lr, reconstructed image and metrics
    return hr, lr, recon_image, metrics

In [13]:
# image pre processing, split to f*f patches, extract Luminance channel and prepare train data labels

PATCH_SIZE = 32
STRIDE = 14
FACTOR = 2

def image_split(path):
    
    x_train = []
    y_train = []
    for i, file in enumerate(os.listdir(path)):
        
        # read the file using cv2
        hr = cv2.imread(path + '/' + file)
        
        # change the image color channel to YCrCb
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2YCrCb)
        
        # find the old and new image dimensions
        h, w, c = hr.shape
        
        # degrade the images by downsizing and upsizing
        new_h = int(h / FACTOR)
        new_w = int(h / FACTOR) 
        lr = cv2.resize(hr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        lr = cv2.resize(lr, (w, h), interpolation=cv2.INTER_LINEAR)
        
        # number of stride steps
        w_steps = int((w -(PATCH_SIZE - STRIDE)) / STRIDE)
        h_steps = int((h -(PATCH_SIZE - STRIDE)) / STRIDE)
        
        #print('w: {}'.format(w))
        #print('h: {}'.format(h))
        #print('w_steps: {}'.format(w_steps))
        #print('h_steps: {}'.format(h_steps))
        
        Y_hr = np.zeros((hr.shape[0], hr.shape[1], 1), dtype=float)
        Y_hr[:, :, 0] = hr[:, :, 0].astype(float) / 255
        
        Y_lr = np.zeros((lr.shape[0], lr.shape[1], 1), dtype=float)
        Y_lr[:, :, 0] = lr[:, :, 0].astype(float) / 255
        
        for i in range(w_steps - 1):
            for j in range(h_steps - 1):
                
                hr_patch = Y_hr[i * STRIDE: i * STRIDE + PATCH_SIZE , j * STRIDE: j * STRIDE + PATCH_SIZE]
                lr_patch = Y_lr[i * STRIDE: i * STRIDE + PATCH_SIZE , j * STRIDE: j * STRIDE + PATCH_SIZE]
                
                if hr_patch.shape[0] == hr_patch.shape[1]:
                    x_train.append(hr_patch)
                    y_train.append(crop(lr_patch, 4)) 
    return np.array(x_train, dtype=float), np.array(y_train, dtype=float)

In [8]:
degrade_images('Test/', 2)

Saving flowers.bmp
Saving baboon.bmp
Saving barbara.bmp
Saving bridge.bmp
Saving coastguard.bmp
Saving comic.bmp
Saving face.bmp
Saving foreman.bmp
Saving lenna.bmp
Saving man.bmp
Saving monarch.bmp
Saving pepper.bmp
Saving ppt3.bmp
Saving zebra.bmp
Saving baby_GT.bmp
Saving bird_GT.bmp
Saving butterfly_GT.bmp
Saving head_GT.bmp
Saving woman_GT.bmp


In [10]:
# compare the image quality metrics

for img in os.listdir('Test_degrded/'):
    
    lr = cv2.imread('Test_degrded/{}'.format(img))
    hr = cv2.imread('Test/{}'.format(img))
    
    # calculate the metrics
    metrics = compare_images(lr, hr)
    
    # print the results
    print('{}\nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(img, metrics[0], metrics[1], metrics[2]))



flowers.bmp
PSNR: 26.27259062868604
MSE: 460.1956961325967
SSIM: 0.8374632814805686

baboon.bmp
PSNR: 22.13160980892662
MSE: 1194.099825
SSIM: 0.6322050631319908

barbara.bmp
PSNR: 24.985291313715283
MSE: 618.9741102430555
SSIM: 0.7807014586623002

bridge.bmp
PSNR: 25.850528790115554
MSE: 507.1643714904785
SSIM: 0.7804245912255268

coastguard.bmp
PSNR: 27.129127410105276
MSE: 377.8234197443182
SSIM: 0.7491459914768033

comic.bmp
PSNR: 25.127913186306913
MSE: 598.9772077562327
SSIM: 0.8799566225711454

face.bmp
PSNR: 30.99220650287191
MSE: 155.23189718546524
SSIM: 0.8008439492289884

foreman.bmp
PSNR: 29.83350956793885
MSE: 202.69855784406565
SSIM: 0.9250699266756456

lenna.bmp
PSNR: 31.47349297867539
MSE: 138.94800567626953
SSIM: 0.8460989200521499

man.bmp
PSNR: 27.22646369798821
MSE: 369.4496383666992
SSIM: 0.8214950645456561

monarch.bmp
PSNR: 28.69128492283592
MSE: 263.6775309244792
SSIM: 0.9265469628688131

pepper.bmp
PSNR: 29.88947161686106
MSE: 200.1033935546875
SSIM: 0.83579375

In [11]:
# build the SRCNN model

def srcnn_model():
    inputs = tf.keras.Input(shape=(None, None, 1))
    x = Conv2D(128, (9, 9), padding='valid', activation='relu',
               kernel_initializer='glorot_uniform', use_bias=True)(inputs)
    
    x = Conv2D(64, (3, 3), padding='same', activation='relu',
               kernel_initializer='glorot_uniform', use_bias=True)(x)
    
    outputs = Conv2D(1, (5, 5), padding='same', activation='linear',
               kernel_initializer='glorot_uniform', use_bias=True)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='SRCNN')
    
    # define loss and optimizer
    optimizer = tf.keras.optimizers.Adam(lr=0.0003)
    
    # compile the model
    model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['mean_squared_error'])
    
    return model
    

In [12]:
my_model = srcnn_model()
my_model.summary()

Model: "SRCNN"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, None, 1)]   0         
_________________________________________________________________
conv2d (Conv2D)              (None, None, None, 128)   10496     
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 64)    73792     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, None, None, 1)     1601      
Total params: 85,889
Trainable params: 85,889
Non-trainable params: 0
_________________________________________________________________


In [19]:
# get the training 
x_train, y_train = image_split('Train/')

In [20]:
print('Size:\n Training data: {}\n Training Labels: {}\n'.format(x_train.shape, y_train.shape))

Size:
 Training data: (15272, 32, 32, 1)
 Training Labels: (15272, 24, 24, 1)



In [21]:
history = my_model.fit(x_train, y_train, epochs=2, batch_size=16)

Train on 15272 samples
Epoch 1/2

KeyboardInterrupt: 

In [None]:
path = 'Train/'

data = []
labels = []
for i, file in enumerate(os.listdir(path)):
        
        # read the file using cv2
        img = cv2.imread(path + '/' + file)
        
        # find the old and new image dimensions
        h, w, c = img.shape
        
        ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
        #print(ycrcb.shape)
        # extract the Y (luminance) channel from YCrCb space
        Y = np.zeros((ycrcb.shape[0], ycrcb.shape[1], 1), dtype=float)
        Y[:, :, 0] = ycrcb[:, :, 0].astype(float) / 255
        if i == 0:
            Y2 = Y
            
        print(Y.shape)
        data.append(Y2)
        ycrcb = crop(ycrcb, 4)
        Y = np.zeros((ycrcb.shape[0], ycrcb.shape[1], 1), dtype=float)
        Y[:, :, 0] = ycrcb[:, :, 0].astype(float) / 255
        if i == 0:
            Y3 = Y
        labels.append(Y3)
        print('{} h: {} w: {} c: {} \n'.format(file, h, w, c))
        
data = np.array(data, dtype=float)
labels = np.array(labels, dtype=float)

In [None]:
print(data.shape)
print(labels.shape)

In [None]:
model2 = srcnn_model()

model2.fit(data, labels, epochs=10, batch_size=1)

In [None]:
import gc
gc.collect()

In [None]:
model3 = srcnn_model()

model2.fit(x_train, y_train, epochs=2, batch_size=16)

In [None]:
hr, lr, result, metrics = test('degraded_images/flowers.bmp', model2)

# compare the quality of images
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(metrics[0][0], metrics[0][1], metrics[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(metrics[1][0], metrics[1][1], metrics[1][2]))

# display images side by side
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(hr, cv2.COLOR_BGR2RGB))
axs[0].set_title('Orignal')
axs[1].imshow(cv2.cvtColor(lr, cv2.COLOR_BGR2RGB))
axs[1].set_title('Degraded')
axs[2].imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
axs[2].set_title('Reconstructed')

for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:

    #print(ycrcb.shape)
        # extract the Y (luminance) channel from YCrCb space
        Y = np.zeros((ycrcb.shape[0], ycrcb.shape[1], 1), dtype=float)
        Y[:, :, 0] = ycrcb[:, :, 0].astype(float) / 255
        if i == 0:
            Y2 = Y
            
        print(Y.shape)
        data.append(Y2)
        ycrcb = crop(ycrcb, 4)
        Y = np.zeros((ycrcb.shape[0], ycrcb.shape[1], 1), dtype=float)
        Y[:, :, 0] = ycrcb[:, :, 0].astype(float) / 255
        if i == 0:
            Y3 = Y
        labels.append(Y3)
        print('{} h: {} w: {} c: {} \n'.format(file, h, w, c))

In [None]:
hr, lr, result, metrics = test('degraded_images/flowers.bmp')

# compare the quality of images
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(metrics[0][0], metrics[0][1], metrics[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(metrics[1][0], metrics[1][1], metrics[1][2]))

# display images side by side
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(hr, cv2.COLOR_BGR2RGB))
axs[0].set_title('Orignal')
axs[1].imshow(cv2.cvtColor(lr, cv2.COLOR_BGR2RGB))
axs[1].set_title('Degraded')
axs[2].imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
axs[2].set_title('Reconstructed')

for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])