In [4]:
import numpy as np
import cv2

from PIL import Image
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch
from torchvision.transforms import ToTensor

from net.model import Generator

In [5]:
# ===========================================================
# model import & setting
# ===========================================================

filepath='/home/guozy/BISHE/MyNet/result/baseline/checkpoints/356_checkpoint.pkl'
checkpoint = torch.load(filepath, map_location='cuda:0')

model = Generator(n_residual_blocks=16, upsample_factor=4, base_filter=64, num_channel=3).to("cuda:0")
model.load_state_dict(checkpoint['G_state_dict'])
model.eval()

Generator(
  (prelu): PReLU(num_parameters=1)
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (residual_block1): ResidualBlock(
    (prelu): PReLU(num_parameters=1)
    (sigmoid): Sigmoid()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (confusion): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (avg_conv1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
    (avg_conv2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (residual_block2): ResidualBlock(
    (prelu): PReLU(num_parameters=1)
    (sigmoid): Sigmoid()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel

In [6]:
# ===========================================================
# compare with origin to upsample
# ===========================================================

img = cv2.imread("/home/guozy/BISHE/dataset/Set5/butterfly.png")
w=img.shape[1]  
h=img.shape[0]
upsample=cv2.resize(img,(w*4,h*4),interpolation=cv2.INTER_CUBIC)
cv2.imwrite('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_Bicubic.jpg',upsample)

img = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png")
data = ToTensor()(img) 
data = data.to('cuda:0').unsqueeze(0)
out = model(data)
out = out.detach().squeeze(0)
out = out.permute(1,2,0).cpu().numpy() * 255.0
out = Image.fromarray(out.astype(np.uint8))
out.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_NN.jpg')


In [13]:
# ===========================================================
# compare with downsample to origin in psnr and ssim
# ===========================================================
img = cv2.imread("/home/guozy/BISHE/dataset/Set5/butterfly.png")
w=img.shape[1]  
h=img.shape[0]
downsample=cv2.resize(img,(w//4,h//4),interpolation=cv2.INTER_CUBIC)
downsample_to_origin=cv2.resize(downsample,(w,h),interpolation=cv2.INTER_CUBIC)
cv2.imwrite('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_Bicubic.jpg',downsample_to_origin,[int(cv2.IMWRITE_JPEG_QUALITY), 100])

data = (ToTensor()(downsample))
data = data.to('cuda:0').unsqueeze(0)
out = model(data).squeeze(0)

out = out.detach().permute(1,2,0).cpu().numpy() * 255.0
result = Image.fromarray(out.astype(np.uint8))
result.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_NN.jpg')

img = np.array(img, dtype=np.int8)
out = out.astype(np.int8)
downsample_to_origin = downsample_to_origin.astype(np.int8)

p1 = psnr(downsample_to_origin, img)
s1 = ssim(downsample_to_origin, img, channel_axis=2)  
m1 = mse(downsample_to_origin, img)
print('BiCubic下采样超分辨重建结果： psnr:{} , ssim:{}, mse:{}\n'.format(p1,s1,m1))

p2 = psnr(out, img)
s2 = ssim(out, img, channel_axis=2)
m2= mse(out, img)
print('NN下采样超分辨重建结果： psnr:{} , ssim:{}, mse:{}\n'.format(p2,s2,m2))

tensor(1.9313, device='cuda:0', grad_fn=<MaxBackward1>)
BiCubic下采样超分辨重建结果： psnr:13.7536805302118 , ssim:0.4348710171852829, mse:2739.7586568196616

NN下采样超分辨重建结果： psnr:8.66580031918092 , ssim:-0.01742854856585117, mse:8840.978388468424

