# Using The Super Resolution Convolutional Neural Network for Image Restoration
Welcome to this tutorial on single-image super-resolution. The goal of super-resolution (SR) is to recover a high-resolution image from a low-resolution input, or as they might say on any modern crime show, enhance!

To accomplish this goal, we will be deploying the super-resolution convolutional neural network (SRCNN) using Keras. This network was published in the paper "Image Super-Resolution Using Deep Convolutional Networks" by Chao Dong et al. in 2014. You can read the full paper at https://arxiv.org/abs/1501.00092.

As the title suggests, the SRCNN is a deep convolutional neural network that learns an end-to-end mapping from low-resolution to high-resolution images. As a result, we can use it to improve the image quality of low-resolution images. To evaluate the performance of this network, we will be using three image quality metrics: peak signal-to-noise ratio (PSNR), mean squared error (MSE), and the structural similarity (SSIM) index.
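
As a rough illustration of the first two metrics, the sketch below computes MSE and PSNR for a small synthetic 8-bit image pair. The helper names (`mse_per_pixel`, `psnr_db`) and the test arrays are illustrative assumptions, not the functions defined later in this notebook. SSIM is left out because it needs a windowed computation that is better delegated to scikit-image.

```python
import numpy as np

def mse_per_pixel(a, b):
    """Mean squared error averaged over every element of two same-shaped arrays."""
    diff = a.astype(np.float64) - b.astype(np.float64)
    return float(np.mean(diff ** 2))

def psnr_db(a, b, max_value=255.0):
    """PSNR in decibels: 10 * log10(MAX^2 / MSE); infinite for identical inputs."""
    err = mse_per_pixel(a, b)
    if err == 0:
        return float('inf')
    return float(10.0 * np.log10(max_value ** 2 / err))

# hypothetical 4x4 grayscale "reference" and a copy with every pixel off by 5
reference = np.full((4, 4), 100, dtype=np.uint8)
distorted = reference + 5

print(mse_per_pixel(reference, distorted))       # 25.0
print(round(psnr_db(reference, distorted), 2))   # 34.15
```

A lower MSE and a higher PSNR both indicate that the two images are closer to each other.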

Furthermore, we will be using OpenCV, the Open Source Computer Vision Library. OpenCV was originally developed by Intel and is used for many real-time computer vision applications. In this particular project, we will be using it to pre- and post-process our images. As you will see later, we will frequently be converting our images back and forth between the RGB, BGR, and YCrCb color spaces. This is necessary because the SRCNN was trained on the luminance (Y) channel in the YCrCb color space.
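
To make the color-space step concrete, here is a minimal NumPy sketch of the BGR to YCrCb conversion for 8-bit images, using the coefficients given in OpenCV's color-conversion documentation. The function name `bgr_to_ycrcb` is hypothetical, and its rounding may differ slightly from `cv2.cvtColor`; in the notebook itself the conversion is always done with `cv2.cvtColor`.

```python
import numpy as np

def bgr_to_ycrcb(bgr):
    """Approximate cv2.COLOR_BGR2YCrCb for a uint8 image of shape (H, W, 3)."""
    bgr = bgr.astype(np.float64)
    b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]
    y = 0.299 * r + 0.587 * g + 0.114 * b      # luminance
    cr = (r - y) * 0.713 + 128.0               # red-difference chroma
    cb = (b - y) * 0.564 + 128.0               # blue-difference chroma
    ycrcb = np.stack([y, cr, cb], axis=-1)
    return np.clip(np.rint(ycrcb), 0, 255).astype(np.uint8)

# a neutral gray pixel has no chroma, so Cr and Cb sit at the midpoint 128
gray = np.full((1, 1, 3), 200, dtype=np.uint8)
print(bgr_to_ycrcb(gray)[0, 0])   # [200 128 128]

# the SRCNN input is only the first (Y) channel, scaled to [0, 1]
y_channel = bgr_to_ycrcb(gray)[..., 0].astype(np.float64) / 255
print(y_channel.shape)            # (1, 1)
```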

During this project, you will learn how to:

- use the PSNR, MSE, and SSIM image quality metrics,
- process images using OpenCV,
- convert between the RGB, BGR, and YCrCb color spaces,
- build deep neural networks in Keras,
- deploy and evaluate the SRCNN.

In [22]:
# check package versions
import sys
import keras
import cv2
import numpy
import matplotlib
import skimage

print('Python: {}'.format(sys.version))
print('Keras: {}'.format(keras.__version__))
print('OpenCV: {}'.format(cv2.__version__))
print('NumPy: {}'.format(numpy.__version__))
print('Matplotlib: {}'.format(matplotlib.__version__))
print('Scikit-Image: {}'.format(skimage.__version__))

Python: 3.11.7 | packaged by Anaconda, Inc. | (main, Dec 15 2023, 18:05:47) [MSC v.1916 64 bit (AMD64)]
Keras: 3.6.0
OpenCV: 4.10.0
NumPy: 1.26.4
Matplotlib: 3.8.0
Scikit-Image: 0.22.0


In [30]:
# import the necessary packages
import math
import os
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
from matplotlib import pyplot as plt
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# python magic function, displays pyplot figures in the notebook
%matplotlib inline

In [32]:
# define a function for peak signal-to-noise ratio (PSNR)
def psnr(target, ref):
         
    # assume RGB/BGR image
    target_data = target.astype(float)
    ref_data = ref.astype(float)

    diff = ref_data - target_data
    diff = diff.flatten('C')

    rmse = np.sqrt(np.mean(diff ** 2.))

    return 20 * math.log10(255. / rmse)

# define function for mean squared error (MSE)
def mse(target, ref):
    # sum of the squared differences over all pixels and channels,
    # divided by the number of pixel positions (height x width)
    err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
    err /= float(target.shape[0] * target.shape[1])

    return err

# define function that combines all three image quality metrics
def compare_images(target, ref, win_size=5):
    # Mean Squared Error (MSE)
    mse_value = mse(target, ref)

    # Peak Signal-to-Noise Ratio (PSNR); it is undefined for identical images,
    # so 100 dB is used as a placeholder in that case
    psnr_value = 100 if mse_value == 0 else psnr(target, ref)

    # Structural Similarity Index (SSIM); channel_axis=2 marks the color axis
    # (it replaces the older multichannel=True argument)
    ssim_value = ssim(target, ref, win_size=win_size, channel_axis=2)

    return psnr_value, mse_value, ssim_value


The test images are taken from the SRCNN project page: http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html

In [36]:
# prepare degraded images by introducing quality distortions via resizing

def prepare_images(path, factor):

    # make sure the output directory exists
    os.makedirs('images', exist_ok=True)

    # loop through the files in the directory
    for file in os.listdir(path):

        # open the file; skip anything OpenCV cannot read as an image
        img = cv2.imread(os.path.join(path, file))
        if img is None:
            continue

        # find old and new image dimensions
        h, w, _ = img.shape
        new_height = int(h / factor)
        new_width = int(w / factor)

        # resize the image - down
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

        # resize the image - up
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)

        # save the image
        print('Saving {}'.format(file))
        cv2.imwrite('images/{}'.format(file), img)
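
The effect of this down-then-up resizing can be sketched without OpenCV. The example below is only a stand-in under stated assumptions: it uses block averaging and pixel repetition rather than the bilinear interpolation (`cv2.INTER_LINEAR`) used above, and the `degrade` helper is hypothetical. It shows that the image keeps its size after the round trip while fine detail is lost.

```python
import numpy as np

def degrade(img, factor):
    """Downscale by block averaging, then upscale by repeating pixels."""
    h, w = img.shape
    h2, w2 = h - h % factor, w - w % factor            # crop to a multiple of factor
    blocks = img[:h2, :w2].astype(np.float64)
    small = blocks.reshape(h2 // factor, factor, w2 // factor, factor).mean(axis=(1, 3))
    return np.repeat(np.repeat(small, factor, axis=0), factor, axis=1)

# a checkerboard is the finest detail a pixel grid can hold
checker = (np.indices((4, 4)).sum(axis=0) % 2) * 255.0
restored = degrade(checker, 2)

print(restored.shape)          # (4, 4): same size as the input
print(np.unique(restored))     # [127.5]: the pattern has been averaged away
```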


In [38]:
prepare_images('source1/', 2)

Saving baboon.bmp
Saving baby_GT.bmp
Saving barbara.bmp
Saving bird_GT.bmp
Saving butterfly_GT.bmp
Saving coastguard.bmp
Saving comic.bmp
Saving face.bmp
Saving flowers.bmp
Saving foreman.bmp
Saving head_GT.bmp
Saving lenna.bmp
Saving monarch.bmp
Saving pepper.bmp
Saving ppt3.bmp
Saving woman_GT.bmp
Saving zebra.bmp


In [40]:
# define the SRCNN model
def model():
    
    # define model type
    SRCNN = Sequential()
    
    # add model layers
    SRCNN.add(Conv2D(filters=128, kernel_size = (9, 9), kernel_initializer='glorot_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1)))
    SRCNN.add(Conv2D(filters=64, kernel_size = (3, 3), kernel_initializer='glorot_uniform',
                     activation='relu', padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size = (5, 5), kernel_initializer='glorot_uniform',
                     activation='linear', padding='valid', use_bias=True))
    
    # define optimizer
    adam = Adam(learning_rate=0.0003)
    
    # compile model
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    
    return SRCNN
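
Because the first and last layers use `padding='valid'`, the network output is smaller than its input. The short sketch below works out by how much; the helper `valid_conv_shrink` and the layer list are illustrative, written from the kernel sizes and padding modes in the model definition above.

```python
def valid_conv_shrink(kernel_size, padding):
    """Pixels removed per spatial dimension by one stride-1 convolution."""
    return kernel_size - 1 if padding == 'valid' else 0

# (kernel size, padding) for the three Conv2D layers defined above
layers = [(9, 'valid'), (3, 'same'), (5, 'valid')]

total = sum(valid_conv_shrink(k, p) for k, p in layers)
print(total)        # 12 pixels lost per dimension
print(total // 2)   # 6 pixels lost on each side

# e.g. a 255 x 255 input yields a 243 x 243 output
print(255 - total)  # 243
```

This 6-pixel border is why the reference and degraded images are cropped by 6 pixels on each side before they are compared with the network output.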

Pre-trained weights for this architecture are available at https://github.com/MarkPrecursor/SRCNN-keras

In [43]:
# define necessary image processing functions

def modcrop(img, scale):
    # crop height and width down to the nearest multiple of scale
    tmpsz = img.shape
    sz = tmpsz[0:2]
    sz = sz - np.mod(sz, scale)
    img = img[0:sz[0], 0:sz[1]]
    return img


def shave(image, border):
    img = image[border: -border, border: -border]
    return img
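
A quick shape check illustrates what these two helpers are meant to do. The functions are restated so the example runs on its own (`modcrop_demo` and `shave_demo` are hypothetical names). The intent assumed here is that `modcrop` trims each spatial dimension down to a multiple of `scale` and `shave` removes a fixed border from every side.

```python
import numpy as np

def modcrop_demo(img, scale):
    """Crop height and width down to the nearest multiple of scale."""
    h, w = img.shape[:2]
    return img[:h - h % scale, :w - w % scale]

def shave_demo(img, border):
    """Remove `border` pixels from every side of the image."""
    return img[border:-border, border:-border]

img = np.zeros((100, 50, 3), dtype=np.uint8)   # hypothetical 100 x 50 BGR image

cropped = modcrop_demo(img, 3)
print(cropped.shape)                 # (99, 48, 3)
print(shave_demo(cropped, 6).shape)  # (87, 36, 3)
```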

In [57]:
# Define the main prediction function
def predict(image_path):
    try:
        # Load the srcnn model with weights
        srcnn = model()
        srcnn.load_weights('3051crop_weight_200.h5')

        # Load the degraded and reference images
        path, file = os.path.split(image_path)
        degraded = cv2.imread(image_path)
        ref = cv2.imread('source1/{}'.format(file))

        # Ensure images are loaded correctly
        if degraded is None or ref is None:
            print(f"Error: Image(s) not found for {file}")
            return None

        # Preprocess the image with modcrop
        ref = modcrop(ref, 3)
        degraded = modcrop(degraded, 3)

        # Convert the image to YCrCb (SRCNN is trained on Y channel)
        temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
        
        # Normalize and prepare input
        Y = np.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
        Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255

        # Perform super-resolution with srcnn
        pre = srcnn.predict(Y, batch_size=1)

        # Post-process output
        pre *= 255
        pre[pre[:] > 255] = 255
        pre[pre[:] < 0] = 0
        pre = pre.astype(np.uint8)

        # Copy Y channel back to image and convert to BGR
        temp = shave(temp, 6)
        temp[:, :, 0] = pre[0, :, :, 0]
        output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
        
        # Remove border from reference and degraded image
        ref = shave(ref.astype(np.uint8), 6)
        degraded = shave(degraded.astype(np.uint8), 6)

        # Image quality calculations
        scores = []
        scores.append(compare_images(degraded, ref))
        scores.append(compare_images(output, ref))

        # Return images and scores
        return ref, degraded, output, scores

    except Exception as e:
        print(f"Error in predict function: {str(e)}")
        return None
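
The clipping step in the post-processing matters because the network output can slightly overshoot the [0, 1] range, and casting out-of-range floats straight to `uint8` is not guaranteed to saturate. A small sketch with made-up values (not actual network output):

```python
import numpy as np

# made-up network output on the [0, 1] scale, with slight under- and overshoot
pre = np.array([-0.02, 0.5, 1.03])

# scale to the 8-bit range, clip, then cast (the cast truncates 127.5 to 127)
restored = np.clip(pre * 255, 0, 255).astype(np.uint8)
print(restored)   # [  0 127 255]
```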

In [59]:
ref, degraded, output, scores = predict('images/flowers.bmp')

# print all scores for all images
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))


# display images as subplots
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
axs[0].set_title('Original')
axs[1].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[1].set_title('Degraded')
axs[2].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
axs[2].set_title('SRCNN')

# remove the x and y ticks
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])


Error in predict function: OpenCV(4.10.0) :-1: error: (-5:Bad argument) in function 'cvtColor'
> Overload resolution failed:
>  - src is not a numpy array, neither a scalar
>  - Expected Ptr<cv::UMat> for argument 'src'



TypeError: cannot unpack non-iterable NoneType object

In [52]:
# make sure the output directory exists before saving figures
os.makedirs('output', exist_ok=True)

for file in os.listdir('images'):
    try:
        # Perform super-resolution
        result = predict('images/{}'.format(file))
        
        if result is None:
            continue

        ref, degraded, output, scores = result
        
        # Display images as subplots
        fig, axs = plt.subplots(1, 3, figsize=(20, 8))
        axs[0].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
        axs[0].set_title('Original')
        axs[1].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
        axs[1].set_title('Degraded')
        axs[1].set(xlabel='PSNR: {}\nMSE: {}\nSSIM: {}'.format(scores[0][0], scores[0][1], scores[0][2]))
        axs[2].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
        axs[2].set_title('SRCNN')
        axs[2].set(xlabel='PSNR: {}\nMSE: {}\nSSIM: {}'.format(scores[1][0], scores[1][1], scores[1][2]))

        # Remove x and y ticks
        for ax in axs:
            ax.set_xticks([])
            ax.set_yticks([])

        print('Saving {}'.format(file))
        fig.savefig('output/{}.png'.format(os.path.splitext(file)[0]))
        plt.close()

    except Exception as e:
        print(f"Error processing {file}: {str(e)}")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 62s/step
Saving baboon.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 66s/step
Saving baby_GT.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 114s/step
Saving barbara.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 374ms/step
Saving bird_GT.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 277ms/step
Saving butterfly_GT.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 440ms/step
Saving coastguard.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 377ms/step
Saving comic.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 337ms/step
Saving face.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 847ms/step
Saving flowers.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 423ms/step
Saving foreman.bmp
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 363ms/step
Saving

In [55]:
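# Streamlit app
def enhance(bgr_image):
    # run the SRCNN on an in-memory BGR image (no reference image is needed)
    srcnn = model()
    srcnn.load_weights('3051crop_weight_200.h5')

    # the network operates on the normalized Y channel only
    temp = cv2.cvtColor(modcrop(bgr_image, 3), cv2.COLOR_BGR2YCrCb)
    Y = temp[np.newaxis, :, :, 0:1].astype(float) / 255
    pre = np.clip(srcnn.predict(Y, batch_size=1) * 255, 0, 255).astype(np.uint8)

    # copy the restored Y channel back and return RGB for st.image
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    return cv2.cvtColor(temp, cv2.COLOR_YCrCb2RGB)

def main():
    st.title("Super Resolution with SRCNN")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "bmp"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        if st.button('Enhance Image'):
            # Convert PIL image (RGB) to OpenCV format (BGR)
            opencv_image = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)

            # Perform super-resolution; predict() expects a file path and a
            # reference image, so the in-memory helper is used instead
            output_image = enhance(opencv_image)

            # Display the result
            st.image(output_image, caption='Enhanced Image', use_column_width=True)

if __name__ == "__main__":
    main()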

2024-10-13 13:54:41.027 
  command:

    streamlit run F:\anaconda3\Lib\site-packages\ipykernel_launcher.py [ARGUMENTS]
