In [1]:
import numpy as np

from PIL import Image
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch
from torchvision.transforms import ToTensor

from net.model import Generator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ===========================================================
# model import & setting
# ===========================================================

filepath='/home/guozy/BISHE/MyNet/result/2023-04-25_18:58:17/checkpoints/156_checkpoint.pkl'
checkpoint = torch.load(filepath, map_location='cuda:0')

model = Generator(n_residual_blocks=16, upsample_factor=4, base_filter=64, num_channel=3).to("cuda:0")
model.load_state_dict(checkpoint['G_state_dict'])
model.eval()

In [4]:
# ===========================================================
# compare origin with upsample in resolve
# ===========================================================

image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
image_width = image.width  * 4
image_height = image.height * 4
origin_to_upsample_by_Bicubic = image.resize((image_width, image_height), resample=Image.BICUBIC)
origin_to_upsample_by_Bicubic.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_Bicubic.jpg')

image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
x = ToTensor()(image) 
x = x.to('cuda:0').unsqueeze(0)
out = model(x)
out = out.detach().squeeze(0)
out = out.permute(1,2,0).cpu().numpy() * 255.0
origin_to_upsample_by_NN = Image.fromarray(out.astype(np.uint8))
origin_to_upsample_by_NN.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_NN.jpg')


  origin_to_upsample_by_Bicubic = image.resize((image_width, image_height), resample=Image.BICUBIC)


In [5]:
# ===========================================================
# compare origin with downsample in one image
# ===========================================================
image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
image_width = (image.width // 4) * 4
image_height = (image.height // 4) * 4
if image_height != image.height or image_width != image.width:
    image = image.resize((image_width, image_height), resample=Image.BICUBIC)
image.save('/home/guozy/BISHE/MyNet/rebuild/origin.jpg')

downsample=image.resize((image.width // 4, image.height // 4), resample=Image.BICUBIC)
downsample.save('/home/guozy/BISHE/MyNet/rebuild/downsample.jpg')
downsample_to_origin_by_Bicubic=downsample.resize((image.width, image.height), resample=Image.BICUBIC)
downsample_to_origin_by_Bicubic.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_Bicubic.jpg')

x = (ToTensor()(downsample))
x = x.to('cuda:0').unsqueeze(0)
out = model(x).squeeze(0)
out = out.detach().permute(1,2,0).cpu().numpy() * 255.0
downsample_to_origin_by_NN = Image.fromarray(out.astype(np.uint8))
downsample_to_origin_by_NN.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_NN.jpg')

  downsample=image.resize((image.width // 4, image.height // 4), resample=Image.BICUBIC)
  downsample_to_origin_by_Bicubic=downsample.resize((image.width, image.height), resample=Image.BICUBIC)


In [6]:
# ===========================================================
# compare origin with downsample in whole dataset
# ===========================================================
from os import listdir
from os.path import join

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
    
# image_dir = '/home/guozy/BISHE/dataset/Set5/'
# image_dir = '/home/guozy/BISHE/dataset/Set14/'
image_dir = '/home/guozy/BISHE/dataset/BSD100/'

image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

avg_psnr_NN = 0
avg_ssim_NN = 0
avg_mse_NN = 0

for image_filename in image_filenames:

    # part1
    image = Image.open(image_filename).convert('RGB')
    image_width = (image.width // 4) * 4
    image_height = (image.height // 4) * 4
    if image_height != image.height or image_width != image.width:
        image = image.resize((image_width, image_height), resample=Image.BICUBIC)
        
    downsample = image.resize((image.width // 4, image.height // 4), resample=Image.Resampling.BICUBIC)

    # part2
    x = (ToTensor()(downsample)).to('cuda:0').unsqueeze(0)
    with torch.no_grad():
        out = model(x)
    out = out.mul(255.0).squeeze(0).permute(1,2,0).cpu().numpy()

    # part3
    image = np.array(image, dtype=np.uint8)
    out = out.astype(np.uint8)
    
    p2 = psnr(out, image)
    s2 = ssim(out, image, channel_axis=2)
    m2= mse(out, image)
    avg_psnr_NN += p2
    avg_ssim_NN +=s2
    avg_mse_NN += m2

avg_psnr_NN /= len(image_filenames)
avg_ssim_NN /= len(image_filenames)
avg_mse_NN /= len(image_filenames)

print('NN: psnr:{} , ssim:{}, mse:{}\n'.format(avg_psnr_NN,avg_ssim_NN,avg_mse_NN))

NN: psnr:25.23191410507125 , ssim:0.7213611305163192, mse:254.76432784288187

