In [186]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import cv2

import math
import numpy as np

import os

In [150]:
def upscale(img, scale):
    height = img.shape[0]*scale
    width = img.shape[1]*scale
    out = [[0]*width for i in range(height)]
    
    for i in range(height):
        for j in range(width):
            out[i][j] = img[i//scale][j//scale]
    return out

In [153]:
def psnr(label, outputs, max_val=1.):
    """
    Compute Peak Signal to Noise Ratio (the higher the better).
    PSNR = 20 * log10(MAXp) - 10 * log10(MSE).
    https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    First we need to convert torch tensors to NumPy operable.
    """
    label = label
    outputs = outputs
    img_diff = outputs - label
    rmse = math.sqrt(np.mean((img_diff) ** 2))
    if rmse == 0:
        return 100
    else:
        PSNR = 20 * math.log10(max_val / rmse)
        return PSNR

In [199]:
def generate_patches(im_high, im_low, size, ct):
    xcount = im_high.shape[0]//size
    ycount = im_high.shape[1]//size
    
    for i in range(xcount):
        for j in range(ycount):
            ct = ct + 1
            patch_high = np.array(im_high)[i*size:(i+1)*size, j*size:(j+1)*size]
            cv2.imwrite("data/Set14/patches_high_res/%s.png" % str(ct), cv2.cvtColor(patch_high*255, cv2.COLOR_GRAY2BGR))
            patch_low = np.array(im_low)[i*size:(i+1)*size, j*size:(j+1)*size]
            cv2.imwrite("data/Set14/patches_low_res/%s.png" % str(ct), cv2.cvtColor(patch_low*255, cv2.COLOR_GRAY2BGR))
    
    return ct
    

In [204]:
def make_pair(filename, ct):
    img = mpimg.imread(filename)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray_down = cv2.resize(gray, (int(gray.shape[1]/2), int(gray.shape[0]/2)), interpolation=cv2.INTER_CUBIC)
    gray_down = upscale(gray_down, 2)
    
    ct = generate_patches(gray, gray_down, 40, ct)
    return ct

In [205]:
directory = "data/Set14"

In [206]:
ct = 0
for filename in os.listdir(directory):
    f = os.path.join(directory, filename)
    # checking if it is a file
    if os.path.isfile(f) and "DS_Store" not in f:
        ct = make_pair(str(f), ct)
        print(f + " %i" % ct)

data/Set14/monarch.png 228
data/Set14/flowers.png 336
data/Set14/bridge.png 480
data/Set14/ppt3.png 688
data/Set14/zebra.png 814
data/Set14/lenna.png 958
data/Set14/barbara.png 1210
data/Set14/face.png 1246
data/Set14/comic.png 1300
data/Set14/pepper.png 1444
data/Set14/man.png 1588
data/Set14/coastguard.png 1644
data/Set14/foreman.png 1700
data/Set14/baboon.png 1844
