# SRGAN 简单测试

### 设置参数

In [7]:
import random

import numpy as np
import torch
from torch.backends import cudnn

# 为了让结果可复现，不太重要
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
# 定义GPU
device = torch.device("cuda", 0)
cudnn.benchmark = True
# 只在y通道测试
only_test_y_channel = True
# 生成器的类型
g_arch_name = "srresnet_x4"
# 模型参数，输入输出通道
in_channels = 3
out_channels = 3
channels = 64
num_rcb = 16
# 上采样率
upscale_factor = 4
mode = "test"
# 测试时进行标记，方便对结果分类管理
exp_name = "SRGAN_x4-Set5"

if mode == "test":
    # Test data address
    lr_dir = f"./data/Set5/LRbicx{upscale_factor}"
    sr_dir = f"./results/test/{exp_name}"
    gt_dir = f"./data/Set5/GTmod12"

    g_model_weights_path = f"./pretrained_models/SRGAN_x4-ImageNet-8c4a7569.pth.tar"

    


### 计算PSNR

In [12]:
from torchvision import transforms
from PIL import Image
def calc_psnr(img1, img2):
    # 需要先变成pil，然后再转换tensor
    img1 = Image.fromarray(img1)
    img2 = Image.fromarray(img2)
    img1 = transforms.ToTensor()(img1).unsqueeze(0)
    img2 = transforms.ToTensor()(img2).unsqueeze(0)
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

In [13]:
import os
import cv2
import torch
import model
from natsort import natsorted
from numpy import ndarray

def image_to_tensor(image: ndarray, range_norm: bool, half: bool):
    # 转化成tensor
    tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()

    # 范围从[0, 1] 转化到 [-1, 1]
    if range_norm:
        tensor = tensor.mul(2.0).sub(1.0)

    # 半精度
    if half:
        tensor = tensor.half()

    return tensor

def tensor_to_image(tensor, range_norm: bool, half: bool):
    '''
    tensor转化成图像
    '''
    if range_norm:
        tensor = tensor.add(1.0).div(2.0)
    if half:
        tensor = tensor.half()

    image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")

    return image


def preprocess_one_image(image_path: str, device: torch.device):
    image = cv2.imread(image_path).astype(np.float32) / 255.0

    # 色彩空间转换
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # 预处理
    tensor = image_to_tensor(image, False, False).unsqueeze_(0)

    # 加载到GPU
    tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)

    return tensor

model_names = sorted(
    name for name in model.__dict__ if
    name.islower() and not name.startswith("__") and callable(model.__dict__[name]))


def main():
    # 初始化
    g_model = model.__dict__[g_arch_name](in_channels=in_channels,
                                                       out_channels=out_channels,
                                                       channels=channels,
                                                       num_rcb=num_rcb)
    g_model = g_model.to(device=device)
    print(f"Build `{g_arch_name}` model successfully.")

    # 加载预训练模型
    checkpoint = torch.load(g_model_weights_path, map_location=lambda storage, loc: storage)
    g_model.load_state_dict(checkpoint["state_dict"])
    print(f"Load `{g_arch_name}` model weights "
          f"`{os.path.abspath(g_model_weights_path)}` successfully.")
    # 测试模式
    g_model.eval()

    # 按照文件名的正常语序进行排序
    file_names = natsorted(os.listdir(lr_dir))
    # Get the number of test image files.
    total_files = len(file_names)

    for index in range(total_files):
        # 低分辨率图像，输入模型
        lr_image_path = os.path.join(lr_dir, file_names[index])
        # 模型结果的保存路径
        sr_image_path = os.path.join(sr_dir, file_names[index])
        # 标签（即原图）的路径
        gt_image_path = os.path.join(gt_dir, file_names[index])

        print(f"Processing `{os.path.abspath(lr_image_path)}`...")
        lr_tensor = preprocess_one_image(lr_image_path, device)
        gt_tensor = preprocess_one_image(gt_image_path, device)

        # 只重建Y通道
        with torch.no_grad():
            sr_tensor = g_model(lr_tensor)

        # 保存
        sr_image = tensor_to_image(sr_tensor, False, False)
        sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
        print('save path:', sr_image_path)
        cv2.imwrite(sr_image_path, sr_image)

        # 计算PSNR
        gt = cv2.imread(gt_image_path)
        pred = sr_image
        print('image:', gt_image_path, 'PSNR:', calc_psnr(gt, pred))

if __name__ == "__main__":
    main()


Build `srresnet_x4` model successfully.
Load `srresnet_x4` model weights `d:\Code\CodeLearning\cvtest-git\cvtest\计算机视觉实践-练习3\SRGAN-test\pretrained_models\SRGAN_x4-ImageNet-8c4a7569.pth.tar` successfully.
Processing `d:\Code\CodeLearning\cvtest-git\cvtest\计算机视觉实践-练习3\SRGAN-test\data\Set5\LRbicx4\baby.png`...
save path: ./results/test/SRGAN_x4-Set5\baby.png
image: ./data/Set5/GTmod12\baby.png PSNR: tensor(30.6421)
Processing `d:\Code\CodeLearning\cvtest-git\cvtest\计算机视觉实践-练习3\SRGAN-test\data\Set5\LRbicx4\bird.png`...
save path: ./results/test/SRGAN_x4-Set5\bird.png
image: ./data/Set5/GTmod12\bird.png PSNR: tensor(29.8084)
Processing `d:\Code\CodeLearning\cvtest-git\cvtest\计算机视觉实践-练习3\SRGAN-test\data\Set5\LRbicx4\butterfly.png`...
save path: ./results/test/SRGAN_x4-Set5\butterfly.png
image: ./data/Set5/GTmod12\butterfly.png PSNR: tensor(25.2556)
Processing `d:\Code\CodeLearning\cvtest-git\cvtest\计算机视觉实践-练习3\SRGAN-test\data\Set5\LRbicx4\head.png`...
save path: ./results/test/SRGAN_x4-Set5\