# SRCNN在Set5上的测试

In [1]:
import os
import warnings

import numpy as np
import PIL.Image as pil_image
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
warnings.filterwarnings("ignore",category=DeprecationWarning)


### 定义模型

In [2]:
class SRCNN(nn.Module):
    """Three-convolution super-resolution network (SRCNN).

    Layers: num_channels -> 64 (9x9) -> 32 (5x5) -> num_channels (5x5), with a
    ReLU after the first two convolutions and no activation on the output.
    Every convolution uses stride 1 and padding = kernel_size // 2, so an input
    of shape (N, num_channels, H, W) yields an output of the same shape.

    The attribute names conv1/conv2/conv3 define the state_dict keys, so they
    must match the checkpoint being loaded.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        # Final layer is linear: no ReLU on the reconstruction.
        return self.conv3(mapped)

### 加载模型

In [3]:
# Let cuDNN benchmark convolution algorithms for each input shape it sees.
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SRCNN().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('model/srcnn_x3.pth', map_location=device))
model.eval()

SRCNN(
  (conv1): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU(inplace=True)
)

### 加载测试数据

In [4]:
import h5py
# For testing/inspection only: collect every array stored under the "hr"
# group of the Set5 x3 HDF5 file. The "lr" group is not read here.
with h5py.File('Set5/Set5_x3.h5', 'r') as f:
    hr_group = f["hr"]
    # [:] copies each dataset into memory before the file is closed.
    test_imgs = [hr_group[k][:] for k in hr_group.keys()]


### 色彩空间转换与图像读取函数

In [5]:
def convert_rgb_to_ycbcr(img):
    '''
    Convert an RGB image to the YCbCr color space.

    Accepted inputs:
      * np.ndarray, channels last: shape (H, W, 3) in R, G, B order.
      * torch.Tensor, channels first: shape (3, H, W), or (1, 3, H, W) for a
        batch of one (the leading batch axis is squeezed away).

    The offsets (16, 128) and the /256 scaling assume 8-bit-range values (0-255).

    Returns:
        An (H, W, 3) array or tensor (same library as the input) holding
        Y, Cb, Cr along the last axis.

    Raises:
        TypeError: if img is neither a NumPy array nor a torch tensor.
    '''
    if isinstance(img, np.ndarray):
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    if isinstance(img, torch.Tensor):
        if img.dim() == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # stack creates a new leading channel axis -> (3, H, W); concatenating
        # the 2-D planes would instead give (3H, W) and the permute would fail.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    raise TypeError('Unknown Type: {}'.format(type(img)))

def convert_ycbcr_to_rgb(img):
    '''
    Convert a YCbCr image back to the RGB color space.

    Accepted inputs:
      * np.ndarray, channels last: shape (H, W, 3) holding Y, Cb, Cr.
      * torch.Tensor, channels first: shape (3, H, W), or (1, 3, H, W) for a
        batch of one (the leading batch axis is squeezed away).

    Values are on the same 8-bit scale as convert_rgb_to_ycbcr's output. The
    result is not clipped; callers clip to [0, 255] themselves if needed.

    Returns:
        An (H, W, 3) array or tensor (same library as the input) holding
        R, G, B along the last axis.

    Raises:
        TypeError: if img is neither a NumPy array nor a torch tensor.
    '''
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    if isinstance(img, torch.Tensor):
        if img.dim() == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # stack creates a new leading channel axis -> (3, H, W); concatenating
        # the 2-D planes would instead give (3H, W) and the permute would fail.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    raise TypeError('Unknown Type: {}'.format(type(img)))
       
def readImg2ycbcr(img_path, scale=3):
    '''
    Read an image and build its bicubic-degraded test input.

    The image is first resized so both sides are multiples of `scale`. It is
    then downsampled by `scale` and upsampled back by `scale`, both with bicubic
    interpolation. The resulting blurry image is saved next to the source as
    `<stem>_bicubic_x<scale><ext>` and converted to YCbCr.

    Args:
        img_path: path to the source image file.
        scale: integer down/up-sampling factor (default 3).

    Returns:
        (ycbcr, gt):
          * ycbcr is an (H, W, 3) YCbCr ndarray of the bicubic image.
          * gt is the original, full-size RGB PIL image used as ground truth.
    '''
    # Close the file handle promptly; convert() returns an independent image.
    with pil_image.open(img_path) as src:
        image = src.convert('RGB')
    gt = image

    bicubic = pil_image.Resampling.BICUBIC
    image_width = (image.width // scale) * scale
    image_height = (image.height // scale) * scale
    # NOTE: this resizes (rather than crops) to the nearest multiple of `scale`.
    image = image.resize((image_width, image_height), resample=bicubic)
    image = image.resize((image.width // scale, image.height // scale), resample=bicubic)
    image = image.resize((image.width * scale, image.height * scale), resample=bicubic)

    # splitext only touches the final extension, so dots elsewhere in the path
    # (e.g. './imgs/x.png') are preserved.
    stem, ext = os.path.splitext(img_path)
    image.save('{}_bicubic_x{}{}'.format(stem, scale, ext))

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)
    return ycbcr, gt

### 计算PSNR指标

In [6]:
from torchvision import transforms
def calc_psnr(img1, img2):
    '''
    Peak signal-to-noise ratio between two images, in dB.

    Both inputs pass through torchvision's ToTensor. That transform rescales
    8-bit images (PIL images or uint8 H x W x C arrays) to [0, 1], so the peak
    value is taken as 1. Float arrays are not rescaled by ToTensor.

    Returns a 0-dim tensor. Identical images give inf, because the MSE is 0.
    '''
    to_tensor = transforms.ToTensor()
    batch1 = to_tensor(img1).unsqueeze(0)
    batch2 = to_tensor(img2).unsqueeze(0)
    mse = ((batch1 - batch2) ** 2).mean()
    return 10. * torch.log10(1. / mse)

In [7]:
# Evaluate on the five Set5 images.
img_paths = ['test_imgs/Set5/head.png', 'test_imgs/Set5/woman.png', 'test_imgs/Set5/bird.png', 'test_imgs/Set5/baby.png', 'test_imgs/Set5/butterfly.png']
scale = 3
for ip in img_paths:
    # Pass scale explicitly so the degradation matches the scale in the output name.
    ycbcr, gt = readImg2ycbcr(ip, scale=scale)
    # Y (luma) channel normalized to [0, 1]. Division builds a new array, so
    # ycbcr itself is left untouched.
    y = ycbcr[..., 0] / 255.
    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    y = torch.from_numpy(y).to(device).unsqueeze(0).unsqueeze(0)

    # Inference only: no gradients needed.
    with torch.no_grad():
        # Clamp the prediction to the normalized range.
        preds = model(y).clamp(0.0, 1.0)

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    # Recombine the predicted Y with the bicubic Cb/Cr channels -> (H, W, 3).
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    # Back to RGB, clipped to the valid 8-bit range.
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    # The SR result has multiple-of-scale dimensions; match gt's size for PSNR.
    output = output.resize(gt.size, resample=pil_image.Resampling.BICUBIC)
    psnr = calc_psnr(output, gt)
    print('image:', ip, 'PSNR: {:.2f}'.format(float(psnr)))

    # splitext only touches the final extension, leaving other dots intact.
    stem, ext = os.path.splitext(ip)
    output.save('{}_srcnn_x{}{}'.format(stem, scale, ext))

image: test_imgs/Set5/head.png PSNR: 30.47
image: test_imgs/Set5/woman.png PSNR: 29.77
image: test_imgs/Set5/bird.png PSNR: 32.57
image: test_imgs/Set5/baby.png PSNR: 33.48
image: test_imgs/Set5/butterfly.png PSNR: 26.19
