In [25]:
import math
import os
from os import listdir
from os.path import join
import numpy as np
import torch
from torch import nn
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torchvision.transforms import InterpolationMode

from torch.utils.data.dataset import Dataset
import torchvision.utils as vutils
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

In [26]:
# Compute device. The notebook was configured for the third GPU (cuda:2); fall
# back to the CPU when that GPU does not exist so the notebook still runs on
# machines with fewer (or no) GPUs instead of failing at the first .to(device).
GPU_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    device = torch.device(f'cuda:{GPU_INDEX}')
else:
    device = torch.device('cpu')

In [27]:
class UpsampleBLock(nn.Module):
    """Upsample a feature map by ``up_scale`` using a sub-pixel convolution.

    A 3x3 convolution expands the channels by ``up_scale ** 2``; PixelShuffle
    folds those extra channels into an ``up_scale``-times larger spatial grid
    (so the channel count returns to ``in_channels``); a PReLU follows.

    Attribute names (conv, pixel_shuffle, prelu) are part of the state-dict
    keys, so they must not change.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        features = self.conv(x)
        upscaled = self.pixel_shuffle(features)
        return self.prelu(upscaled)

class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> PReLU -> conv -> BN, plus an identity skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and have no bias,
    so the output has exactly the same shape as the input.

    Attribute names (conv1, bn1, prelu, conv2, bn2) are part of the state-dict
    keys, so they must not change.
    """

    def __init__(self, channels=64):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_opts)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_opts)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return branch + x

class Generator(nn.Module):
    """Super-resolution generator that upscales 3-channel images by ``scale_factor``.

    Layout:
      * b1   - 9x9 conv (3 -> 64 channels) + PReLU
      * b2   - 16 ResidualBlock(64)
      * b3   - 3x3 conv + BatchNorm; its output is added back to b1's output
      * b4   - log2(scale_factor) UpsampleBLock(64, 2), each doubling H and W
      * tail - 9x9 conv (64 -> 3 channels)

    Submodule names are the state-dict key prefixes, so they must stay stable
    for saved checkpoints to load.

    Args:
        scale_factor: total upscaling factor; must be a power of two >= 1.

    Raises:
        ValueError: if ``scale_factor`` is not a power of two >= 1. Each
            upsampling stage doubles the resolution, so any other factor
            cannot be represented (it would silently build a smaller model).
    """

    def __init__(self, scale_factor):
        # log2 is exact for powers of two, unlike math.log(x, 2), and the
        # round-trip check rejects factors such as 3 or 6.
        upsample_block_num = round(math.log2(scale_factor)) if scale_factor >= 1 else -1
        if upsample_block_num < 0 or 2 ** upsample_block_num != scale_factor:
            raise ValueError(
                f'scale_factor must be a power of two >= 1, got {scale_factor!r}')

        super().__init__()
        self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU())
        self.b2 = nn.Sequential(*[ResidualBlock(64) for _ in range(16)])
        self.b3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64))
        self.b4 = nn.Sequential(*[UpsampleBLock(64, 2) for _ in range(upsample_block_num)])
        self.tail = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        start = self.b1(x)
        # Long skip connection around the residual trunk, then upsample.
        end = self.b4(self.b3(self.b2(start)) + start)
        return self.tail(end)

In [28]:
# Checkpoint holding the trained generator's state dict.
G_dict_path = "epochs/G_stable.pth"
# Root folder containing one sub-folder per SR benchmark set (trailing slash expected).
DATA_PATH = "../../SR_testing_datasets/"

In [29]:
# Benchmark folders to evaluate, each resolved under DATA_PATH.
datasets = [
    f'{DATA_PATH}{name}'
    for name in ('Set5', 'BSDS100', 'Set14')
]
datasets

['../../SR_testing_datasets/Set5',
 '../../SR_testing_datasets/BSDS100',
 '../../SR_testing_datasets/Set14']

In [30]:
# Rebuild the 4x generator and restore its trained weights.
# map_location puts the checkpoint tensors straight onto `device`; without it
# torch.load tries to restore them to the device they were saved from, which
# fails when that GPU does not exist on this machine.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
G = Generator(4)
G.load_state_dict(torch.load(G_dict_path, map_location=device))
# eval(): BatchNorm layers use their running statistics during inference.
# Assigning the result (instead of a bare expression) keeps the module repr
# from being printed as cell output.
G = G.to(device).eval()




In [51]:
# Evaluate the generator on each benchmark set: bicubically downscale every HR
# image by 4, super-resolve it, and report mean SSIM / PSNR over the set. A
# side-by-side strip (bicubic-upscaled LR | SR | HR) is saved per image under
# results/<set name>/.
with torch.no_grad():  # inference only: don't build autograd graphs
    for dataset in datasets:
        dataset_name = os.path.basename(os.path.normpath(dataset))
        out_dir = os.path.join('results', dataset_name)
        os.makedirs(out_dir, exist_ok=True)  # save_image does not create folders

        # List once; sorted for a reproducible processing order.
        img_names = sorted(os.listdir(dataset))
        if not img_names:
            print(dataset_name + ': no images found, skipped')
            continue

        ssim = 0.0
        psnr = 0.0
        for img_name in img_names:
            hr_img = Image.open(os.path.join(dataset, img_name)).convert("RGB")
            width, height = hr_img.size  # PIL reports (width, height)

            # torchvision Resize takes (height, width). Floor division means the
            # SR output can be a few pixels smaller than HR; it is resized back.
            lr_img = transforms.Resize((height // 4, width // 4), InterpolationMode.BICUBIC)(hr_img)
            hr = transforms.ToTensor()(hr_img).unsqueeze(0).to(device)
            lr = transforms.ToTensor()(lr_img).unsqueeze(0).to(device)

            sr = G(lr)
            sr = transforms.Resize((height, width), InterpolationMode.BICUBIC)(sr)

            # Comparison strip: bicubic baseline | SR | HR, with white 20-px gaps.
            bicubic = transforms.Resize((height, width), InterpolationMode.BICUBIC)(lr)
            gap = torch.ones_like(sr[0, :, :, :20])
            panel = torch.dstack([bicubic[0], gap, sr[0], gap, hr[0]])
            stem = os.path.splitext(img_name)[0]
            vutils.save_image(panel, os.path.join(out_dir, stem + '.jpg'))

            ssim += pytorch_ssim.ssim(sr, hr).item()
            mse = ((sr - hr) ** 2).mean()
            # NOTE(review): the PSNR peak is this image's HR maximum rather than
            # the nominal 1.0 of standard PSNR, and SR is not clamped to [0, 1];
            # kept as-is so numbers stay comparable with earlier runs.
            psnr += 10 * math.log10((hr.max() ** 2 / mse).item())

        print(dataset_name)
        print('SSIM: ', round(ssim / len(img_names), 4))
        print('PSNR: ', round(psnr / len(img_names), 4))
        print('______________')

Set5
SSIM:  0.7849
PSNR:  27.1157
______________
BSDS100
SSIM:  0.6411
PSNR:  23.9407
______________
Set14
SSIM:  0.6557
PSNR:  23.9995
______________


In [None]:
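# Earlier recorded metrics, kept for comparison
# (commented out so this cell is valid Python):
# Set5
# SSIM:  0.7824
# PSNR:  27.0172
# ______________
# BSDS100
# SSIM:  0.6366
# PSNR:  23.8347
# ______________
# Set14
# SSIM:  0.6533
# PSNR:  23.9596
# ______________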