In [1]:
import utils as U
import dataset.cufed5 as c5

from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as F

import torchvision.transforms as T
import torchvision.transforms.functional as TF

import images as I

import msoj
import refsr

import metrics as M



In [2]:

pkg = c5.Package("test/raw/CUFED5")

def upscale_bicubic(t):
    if t.dim() == 3:
        t = t.unsqueeze(0)
    assert t.dim() == 4
    assert t.size(0) == 1
    
    return F.interpolate(t, scale_factor=4, mode="bicubic", align_corners=True)
    
is_cuda=True
refsr_model = refsr.get_default_sr_model(cuda=is_cuda)
vgg_model = refsr.get_default_vgg_model(cuda=is_cuda)
sr_model = upscale_bicubic
method = refsr.RefSR(sr_model, refsr_model, vgg_model)

ssim_base = []
psnr_base = []
ssim = []
psnr = []
for i in tqdm(range(len(pkg))):
    hr = U.data.ImageDataset(f"data/test/proc/CUFED5/{i:03}/hr/*.png", T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.cuda().unsqueeze(0) * 255)
    ]))
    lr = U.data.ImageDataset(f"data/test/proc/CUFED5/{i:03}/lr/*.png", T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.cuda().unsqueeze(0) * 255)
    ]))
    
    for j in range(len(hr)):
        try:
            x, y = lr[j], hr[j]
            y2 = sr_model(x)

            y, y2 = y / 255, y2 / 255
            
            psnr_base.append(M.pytorch_psnr(y, y2).item())
            ssim_base.append(M.pytorch_ssim(y, y2).item())
            
#             del x, y, y2
        except AssertionError as ae:
            print(f"BASE: There was an assertion error [{i}, {j}]")
            print(ae)
        except Exception as e:
            print(f"BASE: There was an error [{i}, {j}]")
            raise e

    refs = U.data.ImageDataset(f"data/test/proc/CUFED5/{i:03}/s1/*.png", T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.cuda().unsqueeze(0) * 255)
    ]))
    refs = list(refs)

    for j in range(len(hr)):
        try:
            x, y = lr[j], hr[j]
            y2 = method.upscale_with_ref(x, refs)

            y, y2 = y / 255, y2 / 255
            
            psnr.append(M.pytorch_psnr(y, y2).item())
            ssim.append(M.pytorch_ssim(y, y2).item())

            del x, y, y2
        except AssertionError as ae:
            print(f"NTT: There was an assertion error [{i}, {j}]")
            print(ae)
        except Exception as e:
            print(f"NTT: There was an error [{i}, {j}]")
            raise e
    del refs
    
    print("BASE", sum(ssim_base) / len(ssim_base), sum(psnr_base) / len(psnr_base))
    print("NTT", sum(ssim) / len(ssim), sum(psnr) / len(psnr))


  0%|                                                                                                                                                                  | 0/126 [00:00<?, ?it/s]

tensor(0.2610, device='cuda:0') tensor(0.2614, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2808, device='cuda:0') tensor(0.2794, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2780, device='cuda:0') tensor(0.2784, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2117, device='cuda:0') tensor(0.2106, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2001, device='cuda:0') tensor(0.1996, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3681, device='cuda:0') tensor(0.3710, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3492, device='cuda:0') tensor(0.3512, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3121, device='cuda:0') tensor(0.3168, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
BASE 0.5650530122220516 

  1%|█▏                                                                                                                                                        | 1/126 [00:07<16:38,  7.99s/it]

tensor(0.7155, device='cuda:0') tensor(0.7188, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.8015, device='cuda:0') tensor(0.8053, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.8289, device='cuda:0') tensor(0.8354, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.7746, device='cuda:0') tensor(0.7801, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.4493, device='cuda:0') tensor(0.4507, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.4839, device='cuda:0') tensor(0.4830, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.4955, device='cuda:0') tensor(0.4938, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.4817, device='cuda:0') tensor(0.4816, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
BASE 0.5897363442927599 

  2%|██▍                                                                                                                                                       | 2/126 [00:14<15:44,  7.62s/it]

tensor(0.2966, device='cuda:0') tensor(0.2927, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3014, device='cuda:0') tensor(0.2966, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3195, device='cuda:0') tensor(0.3176, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3602, device='cuda:0') tensor(0.3615, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.6239, device='cuda:0') tensor(0.6236, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.6554, device='cuda:0') tensor(0.6578, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.5457, device='cuda:0') tensor(0.5478, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.5699, device='cuda:0') tensor(0.5729, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
BASE 0.6032237447798252 

  2%|███▋                                                                                                                                                      | 3/126 [00:21<14:57,  7.29s/it]

tensor(0.3536, device='cuda:0') tensor(0.3541, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3662, device='cuda:0') tensor(0.3660, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3476, device='cuda:0') tensor(0.3462, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3353, device='cuda:0') tensor(0.3360, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3608, device='cuda:0') tensor(0.3618, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3784, device='cuda:0') tensor(0.3762, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3787, device='cuda:0') tensor(0.3820, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2650, device='cuda:0') tensor(0.2660, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.4288, device='c

  3%|████▉                                                                                                                                                     | 4/126 [00:32<17:21,  8.54s/it]

tensor(0.6715, device='cuda:0') tensor(0.6735, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.7188, device='cuda:0') tensor(0.7166, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.6218, device='cuda:0') tensor(0.6175, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.5463, device='cuda:0') tensor(0.5431, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.1844, device='cuda:0') tensor(0.1836, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3135, device='cuda:0') tensor(0.3137, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.3242, device='cuda:0') tensor(0.3254, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])
tensor(0.2077, device='cuda:0') tensor(0.2065, device='cuda:0')
torch.Size([1, 3, 120, 120]) torch.Size([1, 3, 120, 120])


KeyboardInterrupt: 

In [None]:
M.pytorch_psnr(y, y2)

In [None]:
I.to_pil_image(y2 * 127)[0][0]