In [1]:
import numpy as np
from os import listdir
from os.path import join
import math

from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch
from torchvision.transforms import ToTensor

from net.model import Generator

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ===========================================================
# model import & setting
# ===========================================================

filepath='/home/guozy/BISHE/MyNet/result/weight_1501_2000/2000_checkpoint.pkl'
checkpoint = torch.load(filepath, map_location='cuda:0')

model = Generator(n_residual_blocks=16, upsample_factor=4, base_filter=64, num_channel=3).to("cuda:0")
model.load_state_dict(checkpoint['G_state_dict'])
model.eval()

Generator(
  (upsample): Upsample(scale_factor=4.0, mode=bicubic)
  (head): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResidualBlock(
      (leakyrelu): LeakyReLU(negative_slope=0.2)
      (conv3_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv5_1): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (conv5_2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (confusion): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (avg_conv1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
      (avg_conv2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
      (sigmoid): Sigmoid()
    )
    (1): ResidualBlock(
      (leakyrelu): LeakyReLU(negative_slope=0.2)
      (conv3_1): Conv2d(64, 64, kernel_size=

In [3]:
# ===========================================================
# compare origin with upsample in resolve
# ===========================================================

image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
image_width = image.width  * 4
image_height = image.height * 4
origin_to_upsample_by_Bicubic = image.resize((image_width, image_height), resample=Image.Resampling.BICUBIC)
origin_to_upsample_by_Bicubic.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_Bicubic.jpg')

image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
x = ToTensor()(image) 
x = x.to('cuda:0').unsqueeze(0)
out = model(x)
out = out.detach().squeeze(0).clamp(0,1)
out = out.permute(1,2,0).cpu().numpy() * 255.0
origin_to_upsample_by_NN = Image.fromarray(out.astype(np.uint8))
origin_to_upsample_by_NN.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_NN.jpg')


In [4]:
# ===========================================================
# compare origin with downsample in one image
# ===========================================================
image = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png").convert('RGB')
image_width = (image.width // 4) * 4
image_height = (image.height // 4) * 4
if image_height != image.height or image_width != image.width:
    image = image.resize((image_width, image_height), resample=Image.Resampling.BICUBIC)
image.save('/home/guozy/BISHE/MyNet/rebuild/origin.jpg')

downsample=image.resize((image.width // 4, image.height // 4), resample=Image.Resampling.BICUBIC)
downsample.save('/home/guozy/BISHE/MyNet/rebuild/downsample.jpg')
downsample_to_origin_by_Bicubic=downsample.resize((image.width, image.height), resample=Image.Resampling.BICUBIC)
downsample_to_origin_by_Bicubic.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_Bicubic.jpg')

x = (ToTensor()(downsample))
x = x.to('cuda:0').unsqueeze(0)
out = model(x).squeeze(0).clamp(0,1)
out = out.detach().permute(1,2,0).cpu().numpy() * 255.0
downsample_to_origin_by_NN = Image.fromarray(out.astype(np.uint8))
downsample_to_origin_by_NN.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_NN.jpg')

In [7]:
def count_for_NN_index_in_Y(image_name):

    image_dir = '/home/guozy/BISHE/dataset/' + image_name
    image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

    avg_psnr_NN = 0
    avg_ssim_NN = 0
    avg_mse_NN = 0

    for image_filename in image_filenames:

        # part1
        image = Image.open(image_filename).convert('RGB')
        image_width = (image.width // 4) * 4
        image_height = (image.height // 4) * 4
        if image_height != image.height or image_width != image.width:
            image = image.resize((image_width, image_height), resample=Image.Resampling.BICUBIC)
            
        downsample = image.resize((image.width // 4, image.height // 4), resample=Image.Resampling.BICUBIC)

        # part2
        x = (ToTensor()(downsample)).to('cuda:0').unsqueeze(0)
        with torch.no_grad():
            out = model(x).clamp(0,1)
        out = out.squeeze(0).permute(1,2,0).cpu().numpy()

        # part3
        image = np.array(image, dtype=np.float32) / 255.0
        image_y = rgb2ycbcr(image)[:,:,0]
        image_y =  image_y.astype(np.uint8)

        out = out.astype(np.float32)
        out_y = rgb2ycbcr(out)[:,:,0]
        out_y = out_y.astype(np.uint8)
        
        s2 = ssim(out_y, image_y, channel_axis=None)
        m2 = mse(out_y, image_y)
        avg_psnr_NN += 10 * math.log10(255*255/m2)
        avg_ssim_NN += s2
        avg_mse_NN += m2

    avg_psnr_NN /= len(image_filenames)
    avg_ssim_NN /= len(image_filenames)
    avg_mse_NN /= len(image_filenames)

    print(image_name + ': psnr:{} , ssim:{}, mse:{}\n'.format(avg_psnr_NN,avg_ssim_NN,avg_mse_NN))

image_names = ['Set5','Set14', 'BSD100']
for image_name in image_names:
    count_for_NN_index_in_Y(image_name)

Set5: psnr:29.566317575733137 , ssim:0.8782743945598469, mse:89.29434466270428

Set14: psnr:26.917976130585032 , ssim:0.7882531381409968, mse:161.76664359103862

BSD100: psnr:26.6489482993831 , ssim:0.7502482776230104, mse:185.18541354166666



In [10]:
def count_for_Bicubic_index_in_Y(image_name):

    image_dir = '/home/guozy/BISHE/dataset/' + image_name
    image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

    avg_psnr_NN = 0
    avg_ssim_NN = 0
    avg_mse_NN = 0

    for image_filename in image_filenames:

        # part1
        image = Image.open(image_filename).convert('RGB')
        image_width = (image.width // 4) * 4
        image_height = (image.height // 4) * 4
        if image_height != image.height or image_width != image.width:
            image = image.resize((image_width, image_height), resample=Image.Resampling.BICUBIC)
            
        downsample = image.resize((image.width // 4, image.height // 4), resample=Image.Resampling.BICUBIC)

        # part2
        out = downsample.resize((image.width, image.height), resample=Image.Resampling.BICUBIC)

        # part3
        image = np.array(image, dtype=np.float32) / 255.0
        image_y = rgb2ycbcr(image)[:,:,0]
        image_y =  image_y.astype(np.uint8)

        out = np.array(out, dtype=np.float32) / 255.0
        out_y = rgb2ycbcr(out)[:,:,0]
        out_y = out_y.astype(np.uint8)
        
        s2 = ssim(out_y, image_y, channel_axis=None)
        m2 = mse(out_y, image_y)
        avg_psnr_NN += 10 * math.log10(255*255/m2)
        avg_ssim_NN += s2
        avg_mse_NN += m2

    avg_psnr_NN /= len(image_filenames)
    avg_ssim_NN /= len(image_filenames)
    avg_mse_NN /= len(image_filenames)

    print(image_name + ': psnr:{} , ssim:{}, mse:{}\n'.format(avg_psnr_NN,avg_ssim_NN,avg_mse_NN))

image_names = ['Set5','Set14', 'BSD100']
for image_name in image_names:
    count_for_Bicubic_index_in_Y(image_name)

Set5: psnr:28.422888781587876 , ssim:0.8222543106870681, mse:139.30414506010658

Set14: psnr:25.952405865950247 , ssim:0.724638411100365, mse:198.75281232847402

BSD100: psnr:25.973030472953692 , ssim:0.6861715205320509, mse:210.16578886718756

