In [40]:
import torch
from torch import nn
import h5py
import numpy as np
from torch.utils.data import Dataset
import argparse
import os
import cv2
import copy
 
from torch import Tensor
import torch.optim as optim
 
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

import PIL.Image as pil_image
from skimage.metrics import structural_similarity as compare_ssim

### 搭建SRCNN 3层卷积模型，Conve2d（输入层数，输出层数，卷积核大小，步长，填充层）

In [16]:
class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)
 
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

### 分别构建Train和Val的Dataloader

In [38]:
class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file
 
    def __getitem__(self, idx): #通过np.expand_dims方法得到组合的新数据
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)
 
    def __len__(self):   #得到数据大小
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
 
# 与TrainDataset类似
class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file
 
    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)
 
    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

### 定义一些工具函数

In [18]:
"""
       只操作y通道
       因为我们感兴趣的不是颜色变化(存储在 CbCr 通道中的信息)而只是其亮度(Y 通道);
       根本原因在于相较于色差，人类视觉对亮度变化更为敏感。
"""
def convert_rgb_to_y(img):
    if type(img) == np.ndarray:
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))
 
"""
        RGB转YCBCR
        Y=0.257*R+0.564*G+0.098*B+16
        Cb=-0.148*R-0.291*G+0.439*B+128
        Cr=0.439*R-0.368*G-0.071*B+128
"""
def convert_rgb_to_ycbcr(img):
    if type(img) == np.ndarray:
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
 
"""
        YCBCR转RGB
        R=1.164*(Y-16)+1.596*(Cr-128)
        G=1.164*(Y-16)-0.392*(Cb-128)-0.813*(Cr-128)
        B=1.164*(Y-16)+2.017*(Cb-128)
"""
def convert_ycbcr_to_rgb(img):
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        return torch.cat([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))
 
# PSNR 计算
def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
 
# 计算 平均数，求和，长度
class AverageMeter(object):
    def __init__(self):
        self.reset()
 
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
 
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

### 开始训练

In [23]:
# 初始参数设定
train_file='./dataset/91-image_x2.h5'
eval_file='./dataset/Set5_x2.h5'
output_dir='./'
scale=2
lr=1e-4
batch_size=16
num_workers=0
num_epochs=400
seed=123

# 输出放入固定文件夹里
outputs_dir = os.path.join(output_dir, 'x{}'.format(scale))
print(outputs_dir)
# 没有该文件夹就新建一个文件夹
if not os.path.exists(output_dir):
    print(outputs_dir)
    print()
    os.makedirs(output_dir)

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(seed)

model = SRCNN().to(device)

# 恢复训练，从之前结束的那个地方开始
# model.load_state_dict(torch.load('outputs/x3/epoch_173.pth'))

# 设置损失函数为MSE
criterion = nn.MSELoss()

# 优化函数Adam，lr代表学习率，
optimizer = optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

# 预处理训练集
train_dataset = TrainDataset(train_file)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True)
# 预处理验证集
eval_dataset = EvalDataset(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

# 拷贝权重
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

lossLog = []
psnrLog = []

for epoch in range(1, num_epochs + 1):
    # 模型训练入口
    model.train()

    # 变量更新，计算epoch平均损失
    epoch_losses = AverageMeter()

    # 进度条
    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch:{}/{}'.format(epoch, num_epochs))

        # 每个batch计算一次
        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)
            # 送入模型训练
            preds = model(inputs)
            # 获得损失
            loss = criterion(preds, labels)

            # 显示损失值与长度
            epoch_losses.update(loss.item(), len(inputs))

            # 梯度清零
            optimizer.zero_grad()

            # 反向传播
            loss.backward()

            # 更新参数
            optimizer.step()
            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))
    # 记录lossLog 方面画图
    lossLog.append(np.array(epoch_losses.avg))
    np.savetxt("lossLog.txt", lossLog)
    # 保存模型
    torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

    model.eval()
    epoch_psnr = AverageMeter()

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            preds = model(inputs).clamp(0.0, 1.0)
        epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    # 记录psnr
    psnrLog.append(Tensor.cpu(epoch_psnr.avg))
    np.savetxt('psnrLog.txt', psnrLog)
    # 找到更好的权重参数，更新
    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

    torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))

./x2


epoch:1/400: 100%|███████████████████████████████████████████████| 21904/21904 [01:19<00:00, 276.13it/s, loss=0.003701]


eval psnr: 34.66
best epoch: 1, psnr: 34.66


epoch:2/400: 100%|███████████████████████████████████████████████| 21904/21904 [01:20<00:00, 271.10it/s, loss=0.000784]


eval psnr: 34.85
best epoch: 2, psnr: 34.85


epoch:3/400: 100%|███████████████████████████████████████████████| 21904/21904 [01:20<00:00, 271.79it/s, loss=0.000730]


eval psnr: 35.17
best epoch: 3, psnr: 35.17


epoch:4/400: 100%|███████████████████████████████████████████████| 21904/21904 [01:19<00:00, 275.93it/s, loss=0.000708]


eval psnr: 35.30
best epoch: 4, psnr: 35.30


epoch:5/400:  78%|████████████████████████████████████▍          | 17008/21904 [01:01<00:17, 274.45it/s, loss=0.000695]


KeyboardInterrupt: 

### 开始测试

In [46]:
# 设置权重参数目录，处理图像目录，放大倍数
weights_file="best.pth"
image_file="imgs/2.jpg"
scale=2

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = SRCNN().to(device)

state_dict = model.state_dict()  # 通过 model.state_dict()得到模型有哪些 parameters and persistent buffers
for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
    if n in state_dict.keys():
        state_dict[n].copy_(p)
    else:
        raise KeyError(n)

model.eval()

image = pil_image.open(image_file).convert('RGB')

# 经过一个插值操作，首先将原始图片重设尺寸，使之可以被放大倍数scale整除
# 得到低分辨率图像Lr，即三次插值后的图像，同时保存输出
image_width = (image.width // scale) * scale
image_height = (image.height // scale) * scale
image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC)
image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)
image.save(image_file.replace('.', '_bicubic_x{}.'.format(scale)))
# 将图像转化为数组类型，同时图像转为ycbcr类型
image = np.array(image).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(image)
# 得到 ycbcr中的 y 通道
y = ycbcr[..., 0]
y /= 255.  # 归一化处理
y = torch.from_numpy(y).to(device)
y = y.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
    preds = model(y).clamp(0.0, 1.0)

psnr = calc_psnr(y, preds)   # 计算y通道的psnr值
print('PSNR: {:.2f}'.format(psnr))  # 格式化输出PSNR值

preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

# 将img的数据格式由（channels,imagesize,imagesize）转化为（imagesize,imagesize,channels）,进行格式的转换后方可进行显示。
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])

output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output = pil_image.fromarray(output) 
output.save(image_file.replace('.', '_srcnn_x{}.'.format(scale))) 

imgPathList=[]
imgPathList.append(image_file.replace('.', '_bicubic_x{}.'.format(scale)))
imgPathList.append(image_file.replace('.', '_srcnn_x{}.'.format(scale)))


imgList=[]
for imgPath in imgPathList:
    img = np.fromfile(imgPath, dtype=np.uint8)
    img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
    imgList.append(img)

ssim = compare_ssim(imgList[0], imgList[1], multichannel=True)
print('SSIM: {:.2f}'.format(ssim))  # ssim

PSNR: 36.92
SSIM: 0.97
