In [1]:
import numpy as np
import cv2

from PIL import Image
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch
import torch.backends.cudnn as cudnn
from torchvision.transforms import ToTensor

from net.model import Generator, Discriminator, VGG19
from net.solver import MyNetTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ===========================================================
# model import & setting
# ===========================================================

filepath='/home/guozy/BISHE/MyNet/result/2023-04-24_00:48:44/checkpoints/epoch_1_G.pt'
# checkpoint = torch.load(filepath)
# model = Generator(n_residual_blocks=16, upsample_factor=4, base_filter=64, num_channel=3)
# model.load_state_dict(checkpoint['model_state_dict'], strict=True)
# state_dict = model.state_dict()
# for n, p in checkpoint['model_state_dict'].items():
#     if n in state_dict.keys():
#         state_dict[n].copy_(p)
#     else:
#         raise KeyError(n)
model = torch.load(filepath)
model.to('cuda:0').eval()

Generator(
  (prelu): PReLU(num_parameters=1)
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (residual_block1): ResidualBlock(
    (prelu): PReLU(num_parameters=1)
    (sigmoid): Sigmoid()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (confusion): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (avg_conv1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
    (avg_conv2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (residual_block2): ResidualBlock(
    (prelu): PReLU(num_parameters=1)
    (sigmoid): Sigmoid()
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel

In [3]:
# ===========================================================
# compare with origin to upsample
# ===========================================================

img = cv2.imread("/home/guozy/BISHE/dataset/Set5/butterfly.png")
w=img.shape[1]  
h=img.shape[0]
upsample=cv2.resize(img,(w*4,h*4),interpolation=cv2.INTER_CUBIC)
cv2.imwrite('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_Bicubic.jpg',upsample,[int(cv2.IMWRITE_JPEG_QUALITY), 100])

img = Image.open("/home/guozy/BISHE/dataset/Set5/butterfly.png")
data = ToTensor()(img) 
data = data.to('cuda:0').unsqueeze(0)
out = model(data)
out = out.detach().squeeze(0).clip(0, 255)
print(out)
out = out.permute(1,2,0).cpu().numpy()
out = Image.fromarray(out, mode="RGB")
out.save('/home/guozy/BISHE/MyNet/rebuild/origin_to_upsample_by_NN.jpg')


tensor([[[4.2559e-02, 2.0149e-01, 0.0000e+00,  ..., 1.7049e-01,
          5.1804e-02, 1.9847e-01],
         [6.4296e-02, 1.8473e-01, 1.3561e-01,  ..., 2.8621e-01,
          2.1096e-01, 9.2691e-02],
         [2.1567e-01, 5.1997e-02, 2.7934e-01,  ..., 2.0908e-01,
          8.3931e-02, 0.0000e+00],
         ...,
         [1.4871e-02, 0.0000e+00, 1.4482e-01,  ..., 9.8711e-02,
          1.1434e-01, 1.1403e-02],
         [0.0000e+00, 0.0000e+00, 5.4809e-02,  ..., 2.2520e-01,
          9.1170e-02, 0.0000e+00],
         [1.1716e-01, 0.0000e+00, 2.9692e-02,  ..., 9.3137e-02,
          1.4883e-01, 1.6313e-02]],

        [[4.6350e-02, 0.0000e+00, 0.0000e+00,  ..., 9.9560e-02,
          0.0000e+00, 1.6428e-01],
         [2.1281e-01, 0.0000e+00, 0.0000e+00,  ..., 1.3751e-01,
          6.4249e-02, 0.0000e+00],
         [8.5572e-02, 0.0000e+00, 2.1689e-01,  ..., 2.4251e-01,
          0.0000e+00, 1.0089e-01],
         ...,
         [9.9592e-02, 6.5644e-02, 5.2636e-02,  ..., 2.1097e-01,
          0.000

In [8]:
# ===========================================================
# compare with downsample to origin in psnr and ssim
# ===========================================================
img = cv2.imread("/home/guozy/BISHE/dataset/Set5/butterfly.png")
w=img.shape[1]  
h=img.shape[0]
downsample=cv2.resize(img,(w//4,h//4),interpolation=cv2.INTER_CUBIC)
downsample_to_origin=cv2.resize(downsample,(w,h),interpolation=cv2.INTER_CUBIC)
cv2.imwrite('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_Bicubic.jpg',downsample_to_origin,[int(cv2.IMWRITE_JPEG_QUALITY), 100])

data = (ToTensor()(downsample))
data = data.to(device).unsqueeze(0)
out = model(data).squeeze(0)
out = out.detach().permute(1,2,0).cpu().numpy()
out = out.clip(0, 255)
print(out)
result = Image.fromarray(out, mode="RGB")
result.save('/home/guozy/BISHE/MyNet/rebuild/downsample_to_origin_by_NN.jpg')

img = np.array(img, dtype=np.int8)
out = out.astype(np.int8)
downsample_to_origin = downsample_to_origin.astype(np.int8)

p1 = psnr(downsample_to_origin, img)
s1 = ssim(downsample_to_origin, img, channel_axis=2)  
m1 = mse(downsample_to_origin, img)
print('BiCubic下采样超分辨重建结果： psnr:{} , ssim:{}, mse:{}\n'.format(p1,s1,m1))

p2 = psnr(out, img)
s2 = ssim(out, img, channel_axis=2)
m2= mse(out, img)
print('NN下采样超分辨重建结果： psnr:{} , ssim:{}, mse:{}\n'.format(p2,s2,m2))

[[[0.14517902 0.20639068 0.15247548]
  [0.1977173  0.19617209 0.24282092]
  [0.18045627 0.14961956 0.33742842]
  ...
  [0.21964769 0.24105231 0.24710858]
  [0.2553778  0.24047641 0.23833653]
  [0.24186611 0.17041352 0.1908362 ]]

 [[0.10942198 0.27130547 0.19052558]
  [0.11737037 0.12715462 0.1499329 ]
  [0.14469834 0.17699537 0.25030607]
  ...
  [0.319872   0.27019176 0.23707563]
  [0.21117492 0.28546563 0.28278354]
  [0.22077927 0.25930727 0.17926832]]

 [[0.16722828 0.2057205  0.21603441]
  [0.07372984 0.10948735 0.17512372]
  [0.16333146 0.20140061 0.21301165]
  ...
  [0.24573794 0.2492166  0.22074775]
  [0.25535467 0.2500736  0.21436815]
  [0.14055714 0.21458486 0.25274965]]

 ...

 [[0.17674953 0.18559152 0.3585888 ]
  [0.1675743  0.2111363  0.34699383]
  [0.14112453 0.22935608 0.38471603]
  ...
  [0.33552754 0.4430482  0.5547945 ]
  [0.36924222 0.3616642  0.44250605]
  [0.35424885 0.3458019  0.46039036]]

 [[0.16249882 0.19822973 0.32187292]
  [0.18673658 0.2814146  0.3407562 ]
