In [None]:
import glob
import os
import os.path as osp
from math import log10, sqrt

import cv2
import numpy as np
import torch

import models.RRDBNet_arch as arch

In [None]:
# Glob pattern for the full-resolution source images that get downscaled.
raw_img_folder = "raw_pictures/*"

In [None]:
# Downscale every readable image in raw_img_folder by `downscale_factor` and save
# it as lowres/<basename>.png (the input set for the x4 super-resolution model).
downscale_factor = 4  # must match the x4 model used below
print("Downscaling Raw Pictures:")
# cv2.imwrite returns False instead of raising when the target folder is missing.
os.makedirs('lowres', exist_ok=True)
for img_id, path in enumerate(glob.glob(raw_img_folder), start=1):
    base = osp.splitext(osp.basename(path))[0]
    print(img_id, base)
    # read images
    raw = cv2.imread(path)
    if raw is None:
        # cv2.imread returns None (no exception) for unreadable / non-image files.
        print("  skipped: not a readable image")
        continue
    new_height = raw.shape[0] // downscale_factor
    new_width = raw.shape[1] // downscale_factor
    if new_height == 0 or new_width == 0:
        print("  skipped: image smaller than the downscale factor")
        continue
    # Crop to a multiple of the factor so the x4 output has exactly the same size
    # as the (cropped) original and can be compared pixel-for-pixel.
    raw = raw[:new_height * downscale_factor, :new_width * downscale_factor]
    dimension = (new_width, new_height)  # cv2.resize takes (width, height)
    lowres = cv2.resize(raw, dimension)
    if not cv2.imwrite('lowres/{:s}.png'.format(base), lowres):
        print("  failed to write lowres/{:s}.png".format(base))

In [None]:
model_path = 'models/RRDB_ESRGAN_x4.pth'  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing later with a CUDA error. Set torch.device('cpu') explicitly to force CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Glob pattern for the low-resolution inputs fed to the super-resolution model.
test_img_folder = "lowres/*"

In [None]:
# Build the network and load the pretrained weights from model_path.
# NOTE(review): positional args presumably are (in_nc, out_nc, nf, nb) — confirm
# against models/RRDBNet_arch.py.
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model is moved to `device` below. torch.load unpickles the file, which can
# execute arbitrary code — only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
model.eval()
model = model.to(device)

In [None]:
# Super-resolve every readable image in test_img_folder and save the result as
# results/<basename>.png.
print('Model path {:s}. \nTesting...'.format(model_path))
# cv2.imwrite returns False instead of raising when the target folder is missing.
os.makedirs('results', exist_ok=True)
for idx, path in enumerate(glob.glob(test_img_folder), start=1):
    base = osp.splitext(osp.basename(path))[0]
    print(idx, base)
    # read images
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None (no exception) for unreadable / non-image files.
        print('  skipped: not a readable image')
        continue
    # BGR (H, W, 3) in [0, 255] -> float RGB (1, 3, H, W) in [0, 1], the layout fed to the model.
    img = img * 1.0 / 255
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0).to(device)

    with torch.no_grad():
        # squeeze(0) drops only the batch dim, so a 1-pixel-high image keeps its H axis.
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # RGB (3, H, W) in [0, 1] -> BGR (H, W, 3) uint8 for cv2.imwrite.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)
    if not cv2.imwrite('results/{:s}.png'.format(base), output):
        print('  failed to write results/{:s}.png'.format(base))

In [None]:
# Glob pattern for the super-resolved outputs that are scored against the originals.
results_folder = "results/*"
def PSNR(original, sr, max_pixel=255.0):
    """Peak signal-to-noise ratio, in dB, between two images of the same shape.

    Args:
        original: reference image array (e.g. the uint8 array from cv2.imread).
        sr: reconstructed image array with the same shape as ``original``.
        max_pixel: largest possible pixel value (255 for 8-bit images).

    Returns:
        PSNR in decibels, or 100 when the images are identical (MSE == 0,
        where the true PSNR would be infinite).

    Raises:
        ValueError: if the two images have different shapes.
    """
    # Work in float64: subtracting and squaring uint8 arrays wraps around modulo
    # 256, which silently corrupts the MSE (a difference of 16 squares to 0).
    original = np.asarray(original, dtype=np.float64)
    sr = np.asarray(sr, dtype=np.float64)
    if original.shape != sr.shape:
        # Without this check, broadcast-compatible shapes would be compared silently.
        raise ValueError(
            "PSNR needs images of the same shape, got {} and {}".format(original.shape, sr.shape))
    mse = np.mean((original - sr) ** 2)
    if mse == 0:
        # Identical images: PSNR is infinite, so report a fixed large value.
        return 100
    return 20 * log10(max_pixel / sqrt(mse))

# Score every super-resolved image against its full-resolution original.
print("Calculating PSNRs:")
for idy, path in enumerate(glob.glob(results_folder), start=1):
    base = osp.splitext(osp.basename(path))[0]
    print(idy, base)
    # Originals may have any extension (the downscaling cell globs raw_pictures/*),
    # so look the original up by basename, preferring <base>.png.
    png_path = osp.join('raw_pictures', base + '.png')
    if osp.exists(png_path):
        raw_path = png_path
    else:
        matches = sorted(glob.glob(osp.join('raw_pictures', glob.escape(base) + '.*')))
        raw_path = matches[0] if matches else None
    raw = cv2.imread(raw_path) if raw_path else None
    sr = cv2.imread(path)
    if raw is None or sr is None:
        # cv2.imread returns None (no exception) for missing or unreadable files.
        print("  skipped: could not read the original or the result image")
        continue
    # The x4 output can be a few pixels smaller than the original when the
    # original's sides are not multiples of 4; compare over the common region.
    height = min(raw.shape[0], sr.shape[0])
    width = min(raw.shape[1], sr.shape[1])
    value = str(round(PSNR(raw[:height, :width], sr[:height, :width]), 2))
    print("PSNR: " + value + " dB")