In [1]:
from network.srcnn import srcnn
import tensorflow as tf

import os
import numpy
import math
import cv2

from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


### Inference config

In [2]:
#IMG_NAME = "./imageValidate/yayoi_first_087.png"
#IMG_NAME = "./imageValidate2/comic.bmp"
IMG_NAME = "./imageValidate2/butterfly_GT.bmp"


IMG_OPENCV_NAME = "opencv.png"
IMG_DNN_NAME = "srcnn.png"

INTERPOLATION = cv2.INTER_CUBIC
SCALE = 2
SIZE_CONV = 6

FILEPATH_MODEL = "./model/"
FILENAME_MODEL = "yayoi_srcnn_935_2x_model.h5"

PSNR = True

### OpenCV interpolation methods
INTER_NEAREST - a nearest-neighbor interpolation<br>
INTER_LINEAR - a bilinear interpolation (used by default)<br>
INTER_AREA - resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire’-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.<br>
INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood<br>
INTER_LANCZOS4 - a Lanczos interpolation over 8x8 pixel neighborhood<br>

In [3]:
model = tf.keras.models.load_model(FILEPATH_MODEL + FILENAME_MODEL, compile=False)
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, None, None, 128)   10496     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, None, None, 64)    73792     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, None, None, 1)     1601      
Total params: 85,889
Trainable params: 85,889
Non-trainable params: 0
_________________________________________________________________
None


In [4]:
tick1 = cv2.getTickCount()

### Generate OpenCV resized image for reference

In [5]:
img = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
shape = img.shape

if PSNR:
    img = cv2.resize(img, (int(shape[1] / SCALE), int(shape[0] / SCALE)), INTERPOLATION)
    img = cv2.resize(img, (shape[1], shape[0]), INTERPOLATION)
else:
    img = cv2.resize(img, (shape[1] * SCALE, shape[0] * SCALE), INTERPOLATION)

cv2.imwrite(IMG_OPENCV_NAME, img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])


True

### Generate super resolution image by SRCNN

In [6]:
img = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
shape = img.shape

if PSNR:
    imgY = cv2.resize(img[:, :, 0], (int(shape[1] / SCALE), int(shape[0] / SCALE)), INTERPOLATION)
    imgY = cv2.resize(imgY, (shape[1], shape[0]), cv2.INTER_CUBIC)
    img[:, :, 0] = imgY
else:
    imgY = cv2.resize(img[:, :, 0], (int(shape[1] * SCALE), int(shape[0] * SCALE)), INTERPOLATION)
    img[:, :, 0] = imgY

tensorY = numpy.zeros((1, shape[0], shape[1], 1), dtype=float)
tensorY[0, :, :, 0] = imgY.astype(float) / 255.

tensorOutput = model.predict(tensorY, batch_size=1) * 255.
tensorOutput[tensorOutput[:] > 255] = 255
tensorOutput[tensorOutput[:] < 0] = 0
tensorOutput = tensorOutput.astype(numpy.uint8)

img[SIZE_CONV: -SIZE_CONV, SIZE_CONV: -SIZE_CONV, 0] = tensorOutput[0, :, :, 0]

img = cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR)
cv2.imwrite(IMG_DNN_NAME, img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

True

### PSNR calculation

In [7]:
im1 = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
im2 = cv2.imread(IMG_OPENCV_NAME, cv2.IMREAD_COLOR)
im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)
im3 = cv2.imread(IMG_DNN_NAME, cv2.IMREAD_COLOR)
im3 = cv2.cvtColor(im3, cv2.COLOR_BGR2RGB)

if PSNR:
    print("opencv:")
    print(cv2.PSNR(im1, im2))
    print("srcnn:")
    print(cv2.PSNR(im1, im3))

opencv:
24.782076560337416
srcnn:
29.741241339454547


### display image

In [8]:
pltShow = False

if pltShow:    
    plt.figure(num='comparison',figsize=(16,16))

    plt.subplot(1,3,1)
    plt.title('origin image')
    plt.imshow(im1)

    plt.subplot(1,3,2)
    plt.title('OpenCV')
    plt.imshow(im2)

    plt.subplot(1,3,3)
    plt.title("srcnn")
    plt.imshow(im3)

In [9]:
tick2 = cv2.getTickCount()
tick = math.floor( ((tick2 - tick1) * 1000) / cv2.getTickFrequency())

if tick >= 60000:
    mins = math.floor(tick / 60000)
    secs = math.floor((tick - mins * 60000) / 1000)
    msec = tick - mins * 60000 - secs * 1000
    print("Inference processed time: " + str(mins) + " mins " + str(secs) + " secs " + str(msec) + " ms")
elif tick >= 1000:
    secs = math.floor(tick / 1000)
    msec = tick - secs * 1000
    print("Inference processed time: " + str(secs) + " secs " + str(msec) + " ms")
else:
    print("Inference processed time: " + str(tick) + " ms")

Inference processed time: 903 ms
