In [1]:
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image, ImageFilter
import math
from os import listdir
import random
import cv2
import pandas as pd

import pywt

import keras
from keras import backend as K
from keras.models import model_from_json
from keras.models import Model, Sequential
from keras.layers import Input, Convolution2D, Conv2D, MaxPooling2D, Dense, Dropout, Flatten, Conv2D, BatchNormalization, GlobalAveragePooling2D
from keras.callbacks import LearningRateScheduler
from keras.optimizers import SGD, Adam
from keras import regularizers, optimizers, losses, initializers

import tensorflow as tf
from tensorflow.python.client import device_lib

import skimage
from skimage import measure

from PIL import Image

Using TensorFlow backend.


In [2]:
from srcnn import *
from utility import *

%matplotlib inline

In [3]:
#Division by 4
window_size= 64
scale = 4
stride = 16

In [4]:
def calcPSNR(original,prediction):
    return measure.compare_psnr(original, prediction,data_range=255)

In [5]:
def calcSSIM(orig,pred):
    return measure.compare_ssim(orig,pred,data_range=255, multichannel=True)

In [6]:
def calcRMSE(orig,pred):
    return np.sqrt(((orig-pred)**2).mean())

In [7]:
#Returns original, bicubic and model prediction
def getImages(filename):
    im = Image.open("set14/"+filename)

    img = np.asarray(im.convert('YCbCr'))

    img = img[:,:,0]
    org, bic = get_input_images(img, 4)

    h, w = retSize(org)

    sub_img = getSubImages(bic)
    
    low_w = []
    for i in range(len(sub_img)):
        dwt_w = pywt.dwt2(sub_img[i], 'haar')

        dwt_w = np.asfarray([dwt_w[0], dwt_w[1][0], dwt_w[1][1], dwt_w[1][2]])

        low_w.append(dwt_w)
        
    low_wavelet = np.asarray(low_w)
    
    pred_w = model.predict(low_wavelet)
    
    res = low_wavelet + pred_w

    sub_imgs = []
    for i in range(len(res)) :
        patch = res[i]
        dwt = (patch[0], (patch[1], patch[2], patch[3]))
        wavelet = pywt.idwt2(dwt, 'haar')
        sub_imgs.append(wavelet)
        
    im = patch_to_image(sub_imgs, h, w)
    return org,bic,im

In [8]:
def convertToYCbCr(x):
    return [image.convert('YCbCr') for image in x]
	
	#Display images
def print_luminance(img, ch=1) :
    if (ch == 1) :
        plt.imshow(img, cmap=plt.get_cmap('gray'), interpolation='nearest')
    else :
        plt.imshow(img[:,:,0], cmap=plt.get_cmap('gray'), interpolation='nearest')
		
def getSubImage(img, height, width, centerX, centerY, len_interval, scale = 1) :
    window = window_size // scale
    return img[ height   : height//2 + window + centerX*len_interval,
                width // 2  + centerY*len_interval : width//2 + window + centerY*len_interval : ]
				
def retSize(x):
    return (x.shape[0], x.shape[1])
	
def getSubImages(image, stride = stride) :
    sub = []
    height, width = retSize(image)
    for h in range(0, height - window_size, stride) :
        for w in range (0, width - window_size, stride) : 
            sub.append(image[h : h + window_size, w : w + window_size])
    return sub

def appendSubbands(l1,l2,l3,l4, dwt):
    l1 = np.append(l1 , dwt[0])
    l2 = np.append(l2, dwt[1][0])
    l3 = np.append(l3, dwt[1][1])
    l4 = np.append(l4, dwt[1][2])
    return (l1,l2,l3,l4)

def reshape(matrix, dwt_shape, channel='channels_last', ch=1) :
    if (channel == 'channels_last') :
        return matrix.reshape(-1, dwt_shape[0], dwt_shape[1], ch)
    else :
        return matrix.reshape(ch, -1, dwt_shape[0], dwt_shape[1])

def equals(a, b) :
    t = a == b
    count = 0
    for a in range(0, len(t)) :
        for b in range(len(t[0])) :
            for c in range(len(t[1])) :
                if ( t[a, b, c] == False):
                    count = count + 1
    return count / (32*32*4) < 0.20
    

def getSubImages(image, stride = stride) :
    sub = []
    height, width = retSize(image)
    for h in range(0, height - window_size, stride) :
        for w in range (0, width - window_size, stride) : 
            sub.append(image[h : h + window_size, w : w + window_size])
    return sub

def plot4images(im1, im2, im3, im4, ch=1) :
    size = (10, 10)
    fig = plt.figure(figsize=size)
    
    plt.subplot(221)
    plt.title("Approximation")
    print_luminance(im1, ch )
    plt.subplot(222)
    plt.title("Horizontal")
    print_luminance(im2, ch)
    plt.subplot(223)
    plt.title("Vertical")
    print_luminance(im3, ch)
    
    plt.subplot(224)
    plt.title("Diagonal")
    print_luminance(im4, ch)

def plot_waveletTrans(wt, ch=1) :
    plot4images(wt[0], wt[1][0], wt[1][1], wt[1][2], ch)



In [9]:
PATH = ""
TRAIN_FILE = PATH + "set14/"
obj_files = sorted(listdir(TRAIN_FILE )) #the filenames  the call to sorted is actually necessary to ensure independence to filesystems

In [10]:
high_res=obj_files[0::2]
low_res=obj_files[1::2]

In [140]:
space_grey = load_model("sptial_deep_model_50_ycbcr")
wave_grey = load_model("SRCNN_ycbcr_50_residual")

In [227]:
def reconstructSpatialImage(model, bic):
    height, width = retSize(bic)
    padding = 100
    padded_img = np.pad(bic, pad_width = padding, mode='symmetric')

    space_img = np.zeros((height,width))
    stride = 4
    window_size = 32
    half_window = window_size//2
    for h in range(padding, padded_img.shape[0]-padding, stride):
        h_l = h - padding
        h_h = min(height, h-padding+1+stride)
        for w in range(padding, padded_img.shape[1]-padding, stride):
            d = np.reshape(padded_img[h-half_window:h+half_window, 
                                      w-half_window:w+half_window], (window_size,window_size,1))
            window = np.reshape(model.predict(np.expand_dims(d,axis=0)/255., 1),(window_size,window_size)) #+ padded_img[h-half_window:h+half_window, w-half_window:w+half_window]/255.
            w_h = min(width, w-padding+1+stride)
            w_l = w - padding
            space_img[h_l:h_h, w_l:w_h] = window[half_window:half_window+h_h-h_l,
                                                 half_window:half_window+w_h-w_l]
            del d
            del window
    del padded_img
    return space_img

In [228]:
def reconstructWaveletImage(model, bic):
    height, width = retSize(bic)
    padding = 100
    padded_img = np.pad(bic, pad_width = padding, mode='symmetric')

    wave_image = np.zeros((height,width))
    stride = 4
    window_size = 64
    half_window = window_size//2
    for h in range(padding, padded_img.shape[0]-padding, stride):
        h_l = h - padding
        h_h = min(height, h-padding+1+stride)
        for w in range(padding, padded_img.shape[1]-padding, stride):
            d = padded_img[h-half_window:h+half_window, 
                                      w-half_window:w+half_window]
            dwt_w = pywt.dwt2(d, 'haar')
            dwt_w = np.asfarray([dwt_w[0], dwt_w[1][0], dwt_w[1][1], dwt_w[1][2]])
            patch = model.predict(np.expand_dims(dwt_w,axis=0)/255., 1)
            dwt = (patch[0][0], (patch[0][1], patch[0][2], patch[0][3]))
            window = pywt.idwt2(dwt, 'haar')
            w_h = min(width, w-padding+1+stride)
            w_l = w - padding
            wave_image[h_l:h_h, w_l:w_h] = window[half_window:half_window+h_h-h_l,
                                                 half_window:half_window+w_h-w_l]
            del d
            del window
            del dwt_w
            del patch
            del dwt
    del padded_img
    return wave_image

In [229]:
def reconstructWaveAndSpaceImage(img, model_wave, model_space):
    org, bic = get_input_images(img, 4.)
    space = reconstructSpatialImage(model_space, bic)
    print("done with space")
    wave = reconstructWaveletImage(model_wave, bic)
    print("done with freq")
    return space, wave, org, bic

In [247]:
def getPredictions(path, model_wave, model_space):
    img = Image.open("set14/"+path)
    img = np.asarray(img.convert('YCbCr'))[:,:,0]
    return reconstructWaveAndSpaceImage(img, wave_grey, space_grey)

In [248]:
results=[]

for f in high_res:
    print(f)
    space_pred, wave_pred, org, bic = getPredictions(f, wave_grey, space_grey)
    results.append((org,bic,wave_pred, space_pred))
    
np.save("results", results)

img_001_SRF_4_HR.png
done with space
done with freq
img_002_SRF_4_HR.png
done with space
done with freq
img_003_SRF_4_HR.png
done with space
done with freq
img_004_SRF_4_HR.png
done with space
done with freq
img_005_SRF_4_HR.png
done with space
done with freq
img_006_SRF_4_HR.png
done with space
done with freq
img_007_SRF_4_HR.png
done with space
done with freq
img_008_SRF_4_HR.png
done with space
done with freq
img_009_SRF_4_HR.png
done with space
done with freq
img_010_SRF_4_HR.png
done with space
done with freq
img_011_SRF_4_HR.png
done with space
done with freq
img_012_SRF_4_HR.png
done with space
done with freq
img_013_SRF_4_HR.png
done with space
done with freq
img_014_SRF_4_HR.png
done with space
done with freq


In [249]:
results=np.load("results.npy")

### Methods to crunch numbers

In [250]:
def convertToPicture(img):
    if(len(img.shape) == 3):
        img = Image.fromarray(img, 'RGB')
    else:
        img = Image.fromarray(img)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img

In [251]:
def plot_results(i):
    result=results[i]
    org=result[0].astype(float)
    bic=result[1].astype(float)
    sr=result[2].astype(float)
    spred=result[3].astype(float)
    
    psnr_bic=round(calcPSNR(org,bic),4)
    psnr_sr=round(calcPSNR(org,sr),4)
    psnr_sd=round(calcPSNR(org, spred),4)
    org,bic,sr, spred = ims[i]
    
    plot_images([org,bic,sr, spred], ["Original", "Bic (PSNR: "+str(psnr_bic)+")", "SR (PSNR: "+str(psnr_sr)+")", "Spatial (PSNR: "+str(psnr_sd)+")"], size= (20,20), ch=1)

In [252]:
def displayTab(t, cols, indexValues):
    df=pd.DataFrame(data=t)
    df.columns=cols
    df.index=indexValues
    return df

def waveletTransf(image):
    dwt_w = pywt.dwt2(image, wavelet='haar')
    return np.asfarray([dwt_w[0], dwt_w[1][0], dwt_w[1][1], dwt_w[1][2]])

def waveDomainMeasure(image, obj, measure):
    dwt_image = waveletTransf(image)
    dwt_obj = waveletTransf(obj)
    if measure == 'ssim':
        op = calcSSIM
    if measure == 'rmse':
        op = calcRMSE
    return (op(dwt_obj[0],dwt_image[0]),
            op(dwt_obj[1],dwt_image[1]),
            op(dwt_obj[2],dwt_image[2]),
            op(dwt_obj[3],dwt_image[3]))
        
def waveMeasureWaveletAndSpatial(bic, image_space, image_wave, obj, measure):
    return (waveDomainMeasure(bic, obj, measure),
            waveDomainMeasure(image_space, obj, measure), 
            waveDomainMeasure(image_wave, obj, measure))

"""Returns mean measure of the dataset result (assumed of the form [original, bicubic, wavelet, spatial])
for each wavelet frequency, between spatial image and original and wavelet with original.
Returns the result in the order space, wave"""
def averageWaveMeasure(results, measure, f):
    waveAndSpace = []
    for res in results:
        waveAndSpace.append(waveMeasureWaveletAndSpatial(res[1], 
                                                         f(res[3]), 
                                                         f(res[2]), 
                                                         res[0], measure)) # remember the order: original, bicubic, wavelet, spatial
    return np.mean(waveAndSpace, axis=0)

"""Returns mean measure of the dataset result (assumed of the form [original, bicubic, wavelet, spatial])
taken in space, between spatial net image and original and wavelet net with original.
Returns the result in the order space, wave"""
def averageSpaceMeasure(results, measure, f):
    waveAndSpace= []
    if measure == 'psnr':
        op = calcPSNR
    if measure == 'rmse':
        op = calcRMSE
    if measure == 'ssim':
        op = calcSSIM
    for res in results:
        waveAndSpace.append((op(res[0], res[1]),
                             op(res[0], f(res[3])), 
                             op(res[0], f(res[2]))))
    return np.mean(waveAndSpace, axis=0)

In [257]:
def displayMeasureResults(results,f):
    t1 = pd.concat([displayTab(averageWaveMeasure(results, 'ssim', f), 
                      ["mean LL Band SSIM", "mean LH Band SSIM", "mean HL Band SSIM", "mean HH Band SSIM"], 
           ["Bicubic","Spatial SRCNN", "Wavelet SRCNN"]),
           displayTab(averageWaveMeasure(results, 'rmse', f), 
                      ["mean LL Band RMSE", "mean LH Band RMSE", "mean HL Band RMSE", "mean HH Band RMSE"], 
           ["Bicubic","Spatial SRCNN", "Wavelet SRCNN"])], axis=1)
    t2 = pd.concat([displayTab(averageSpaceMeasure(results, 'psnr', f), ["mean PSNR"], ["Bicubic", "Spatial SRCNN", "Wavelet SRCNN"]),
          displayTab(averageSpaceMeasure(results, 'rmse',f), ["mean RMSE"], ["Bicubic","Spatial SRCNN", "Wavelet SRCNN"]),
          displayTab(averageSpaceMeasure(results, 'ssim',f), ["mean SSIM"], ["Bicubic","Spatial SRCNN", "Wavelet SRCNN"])],axis=1)
    return pd.concat([t1,t2], axis=1)

## Clipping impact on SSIM and RMSE

In [258]:
def convertNoClip(x):
    return np.uint8(x*255)

In [259]:
def convertClip(x):
    return np.uint8(np.clip(x,0,1)*255)

In [260]:
displayMeasureResults(results, convertNoClip)

Unnamed: 0,mean LL Band SSIM,mean LH Band SSIM,mean HL Band SSIM,mean HH Band SSIM,mean LL Band RMSE,mean LH Band RMSE,mean HL Band RMSE,mean HH Band RMSE,mean PSNR,mean RMSE,mean SSIM
Bicubic,0.70435,0.534244,0.490209,0.618344,29.99238,13.816186,13.45169,6.976369,23.389095,7.027991,0.667652
Spatial SRCNN,0.718573,0.55494,0.516398,0.628458,29.486667,14.585224,14.454077,8.132235,23.572869,6.786253,0.688161
Wavelet SRCNN,0.719035,0.554807,0.517135,0.625497,29.037812,14.70728,13.958787,8.458644,23.611924,6.759918,0.688812


In [261]:
displayMeasureResults(results, convertClip)

Unnamed: 0,mean LL Band SSIM,mean LH Band SSIM,mean HL Band SSIM,mean HH Band SSIM,mean LL Band RMSE,mean LH Band RMSE,mean HL Band RMSE,mean HH Band RMSE,mean PSNR,mean RMSE,mean SSIM
Bicubic,0.70435,0.534244,0.490209,0.618344,29.99238,13.816186,13.45169,6.976369,23.389095,7.027991,0.667652
Spatial SRCNN,0.723667,0.561043,0.521837,0.63387,27.099446,13.057959,12.841793,6.759755,24.150642,6.783948,0.691636
Wavelet SRCNN,0.723896,0.561452,0.522789,0.631493,26.964752,13.046753,12.675799,6.77725,24.198652,6.755819,0.692389


We can notice here that clipping actually has a positive impact on the quality of the produced images, on all metrics. In general, the wavelet outperforms the spatial by a small margin in every domain and metric. In turn, the spatial outperforms the bicubic.

In [262]:
i=0
ims = []
for res in results:
    im_o= convertToPicture(res[0])
    im_w = convertToPicture(convertClip(res[2]))
    im_b = convertToPicture(res[1])
    im_s = convertToPicture(convertClip(res[3]))
    im_o.save("srcnn_results/"+str(i)+"_original.png")
    im_w.save("srcnn_results/"+str(i)+"_wave.png")
    im_b.save("srcnn_results/"+str(i)+"_bic.png")
    im_s.save("srcnn_results/"+str(i)+"_space.png")
    i=i+1

The wavelet net outperforms the spatial net on all wavelet bands, which provide significant clue as to why images produced by the wavelet net actually seem better.

In [None]:
def calculate_measures(measure):
    bic_values=[]
    sr_values=[]
    spatial_values=[]

    indexValues=[]
    
    cols=["Bicubic","SRCNN","Spatial SRCNN"]
    cols=[x+" "+measure for x in cols]

    i=1
    for result in results:
        org=result[0].astype(float)
        bic=result[1].astype(float)
        sr=result[2].astype(float)
        spatial=result[3].astype(float)
        
        if measure=="psnr":
            bic_values.append(round(calcPSNR(org,bic),4))
            sr_values.append(round(calcPSNR(org,sr),4))
            spatial_values.append(round(calcPSNR(org,spatial),4))

        if measure=="rmse":
            bic_values.append(round(calcRMSE(org,bic),4))
            sr_values.append(round(calcRMSE(org,sr),4))
            spatial_values.append(round(calcRMSE(org,spatial),4))

        if measure=="ssim":
            bic_values.append(round(calcSSIM(org,bic),4))
            sr_values.append(round(calcSSIM(org,sr),4))
            spatial_values.append(round(calcSSIM(org,spatial),4))
            
        indexValues.append("Image "+str(i))
        i=i+1
        
    df=pd.DataFrame(data=[bic_values,sr_values,spatial_values]).transpose()
    df.columns=cols
    df.index=indexValues
    
    return df

In [None]:
def plotMeanStd(df,measure):
    mean_values=[]
    std_values=[]
    
    mean_values.append(df["Bicubic "+measure].mean())
    mean_values.append(df["SRCNN "+measure].mean())
    mean_values.append(df["Spatial SRCNN "+measure].mean())
    
    std_values.append(df["Bicubic "+measure].std())
    std_values.append(df["SRCNN "+measure].std())
    std_values.append(df["Spatial SRCNN "+measure].std())
    
    
    xticks=['Bic','SR','SR+Spatial']
    x = np.array([1, 2, 3])
    
    plt.xticks(x,xticks)
    plt.errorbar(x, mean_values, std_values, linestyle='None', ecolor="red",marker='o')
    plt.show()

In [None]:
#Will find values that are greater than 1 std away from mean
def find_outliers(df,col):
    mean=df[col].mean()
    std=df[col].std()
    return df[(df[col]>mean+std) | (df[col]<mean-std)][col]

<h2>PSNR</h2>

In [None]:
df_psnr=calculate_measures("psnr")

In [None]:
df_psnr

In [None]:
print(df_psnr.idxmin())
print("")
print(df_psnr.idxmax())

Lowest and highest PSNR was from image 1 and 6 respectively. Note that the lowest/highest value is consistent regardless
of method. We can find the images that either performed rather well or poorly by looking at how much they deviate from
mean (here looking at mean +- 1 std)

In [None]:
find_outliers(df_psnr,"Bicubic psnr")

In [None]:
find_outliers(df_psnr,"SRCNN psnr")

In [None]:
find_outliers(df_psnr,"Spatial SRCNN psnr")

For all 3 methods, it was the same set of images that deviates a lot.<br>
Low measure: Image 1, 5, 13 <br>
High measure: Image 6, 9

Plot of mean and standard deviation for all 3 methods 

In [None]:
plotMeanStd(df_psnr,"psnr")

<h2>RMSE</h2>

In [None]:
df_rmse=calculate_measures("rmse")

In [None]:
df_rmse

In [None]:
print(df_rmse.idxmin())
print("")
print(df_rmse.idxmax())

In [None]:
find_outliers(df_rmse,"Bicubic rmse")

In [None]:
find_outliers(df_rmse,"SRCNN rmse")

In [None]:
find_outliers(df_rmse,"Spatial SRCNN rmse")

In [None]:
plotMeanStd(df_rmse,"rmse")

<h2>SSIM</h2>

In [None]:
df_ssim=calculate_measures("ssim")

In [None]:
df_ssim

In [None]:
print(df_ssim.idxmin())
print("")
print(df_ssim.idxmax())

In [None]:
find_outliers(df_ssim,"Bicubic ssim")

In [None]:
find_outliers(df_ssim,"SRCNN ssim")

In [None]:
find_outliers(df_ssim,"Spatial SRCNN ssim")

In [None]:
plotMeanStd(df_ssim,"ssim")

PSNR and RMSE behave similarly in regards to which images where super resolved the best and worst and which ones. SSIM on the other hand has a different subset of images with measures far from mean.

<h2>Plots</h2>

In [None]:
for i in range(14):
    plot_results(i)