In [1]:
import numpy as np
import cv2
import os
import gc
import math

from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, Dataset


In [2]:
def CFA(pic: np.ndarray):
    #h, w -> height,width(cv2-format)
    h, w, _ = pic.shape
    RGGB = np.array([[[1, 0, 0], [0, 1, 0]], [[0, 1, 0], [0, 0, 1]]])
    #模板平铺的倍数
    time_h = int(np.ceil(h / 2))
    time_w = int(np.ceil(w / 2))
    #平铺模板
    CFA = np.tile(RGGB, (time_h, time_w, 1))
    CFA = CFA[:h, :w, :]
    #CFA模板滤波
    processed = pic * CFA
    return processed

In [3]:
#对四维Tensor进行CFA滤波，pattern为滤波模式字符串（RGGB, BGGR等）
def CFA_d4(pic: torch.Tensor, pattern: str):
    #b, c, h, w -> batch_size, channel, height, width
    b, c, h, w = pic.shape
    pic = pic.cuda()
    
    processed = torch.zeros(pic.shape)
    processed = processed.cuda()
    
    RGGB = np.array([[[1, 0, 0], [0, 1, 0]], [[0, 1, 0], [0, 0, 1]]])
    
    #模板平铺的倍数
    time_h = int(np.ceil(h / 2))
    time_w = int(np.ceil(w / 2))
    
    #tiled RGGB
    CFA = np.tile(RGGB, (time_h, time_w, 1))
    processed2 = torch.clone(pic)
    processed2 = processed2.cuda()
   
    CFA = CFA[:h, :w, :]
    #CFA -> h*w*3,RGB
    CFA3 = CFA.transpose((2,0,1))
    #CFA3 -> 3*h*w,RGB
    # numpy 转 tensor
    CFA3 = torch.from_numpy(CFA3)
    CFA3 = CFA3.cuda()
    
    for i in range(b):
        if pattern == "BGGR":
            #translate BGGR -> RGGB
            processed2[i] = torch.roll(processed2[i],shifts=(-1, -1), dims = (1,2))
        elif pattern == "GBRG":
            #translate GBRG -> RGGB
            processed2[i] = torch.roll(processed2[i],shifts=(-1, 0), dims = (1,2))
        elif pattern == "GRBG":
            #translate GRBG -> RGGB
            processed2[i] = torch.roll(processed2[i],shifts=(0, -1), dims = (1,2))
    for i in range(b):
        processed2[i] = processed2[i] * CFA3
        
    return processed2


In [4]:
# img_o为输入，fil为卷积核，进行卷积，实现双线性插值
def my_fil(pic):
    h, w, _ = pic.shape
    #RGB -> BGR, to split
    pic = pic[:,:,::-1]
    #解决内存不连续问题
    pic = pic.copy()
    [B, G, R] = cv2.split(pic)
    filter = np.array([[[0.25, 0., 0.25],
                        [0.5, 0.25, 0.5],
                        [0.25, 0., 0.25]],

                       [[0.5, 0.25, 0.5],
                        [1., 1., 1.],
                        [0.5, 0.25, 0.5]],

                       [[0.25, 0., 0.25],
                        [0.5, 0.25, 0.5],
                        [0.25, 0., 0.25]]])

    B_fil = cv2.filter2D(B, -1, kernel=filter[:, :, 2])
    G_fil = cv2.filter2D(G, -1, kernel=filter[:, :, 1])
    R_fil = cv2.filter2D(R, -1, kernel=filter[:, :, 0])

    pic_new = cv2.merge([B_fil, G_fil, R_fil])
    #BGR -> RGB
    pic_new = pic_new[:,:,::-1]
    #解决内存不连续问题
    pic_new = pic_new.copy()
    return pic_new

def my_fil_d4(pic):
    #b, c, h, w -> batch_size, channel, height, width
    b, c, h, w = pic.shape
    filter = np.array([[[0.25, 0., 0.25],
                        [0.5, 0.25, 0.5],
                        [0.25, 0., 0.25]],

                       [[0.5, 0.25, 0.5],
                        [1., 1., 1.],
                        [0.5, 0.25, 0.5]],

                       [[0.25, 0., 0.25],
                        [0.5, 0.25, 0.5],
                        [0.25, 0., 0.25]]])
    
    # pic转numpy
    pic = pic.detach().cpu().numpy() if pic.requires_grad else pic.cpu().numpy()
    pic3 = pic
    for i in range(b):
        pic2 = pic[i]
        #CHW -> HWC, to split
        pic2 = pic2.transpose((1, 2, 0))
        #RGB -> BGR, to split
        pic2 = pic2[:,:,::-1]
        #解决内存不连续问题
        pic2 = pic2.copy()
        
        [B, G, R] = cv2.split(pic2)
        
        B_fil = cv2.filter2D(B, -1, kernel=filter[:, :, 2])
        G_fil = cv2.filter2D(G, -1, kernel=filter[:, :, 1])
        R_fil = cv2.filter2D(R, -1, kernel=filter[:, :, 0])
        
        pic_new = cv2.merge([B_fil, G_fil, R_fil])
        #BGR -> RGB
        pic_new = pic_new[:,:,::-1]
        #解决内存不连续问题
        pic_new = pic_new.copy()
        #HWC -> CHW
        pic_new = pic_new.transpose((2, 0, 1))
        pic3[i] = pic_new
    pic3 = torch.from_numpy(pic3)
    pic3 = pic3.cuda()
    return pic3
        
        

In [5]:
class Datasets(Dataset):
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]
        if self.transform is not None:
            data = self.transform(data)
            label = self.transform(label)
        return data, label

    def __len__(self):
        return len(self.data)

In [6]:
def train_data_get(data_path):
    img_list = []
    labels = []
    imgs = os.listdir(data_path)

    for i in range(len(imgs)):
        path = data_path + imgs[i]
        temp = cv2.imread(path)
        #BGR -> RGB
        temp = temp[:,:,::-1]
        #解决内存不连续问题
        temp = temp.copy()
        temp = np.float64(temp)
        h, w, _ = temp.shape

        temp = temp[0:(h - h%16), 0:(w-w%16), :]  # 避免上下采样时出现维度不匹配问题

        #归一化
        temp = temp / 255.0
        
        temp_label = temp
        temp = CFA(temp)
        temp = my_fil(temp)
        
        img_list.append(temp)
        labels.append(temp_label)

    train_list, train_label = img_list, labels
    print("Train data get complete.")
    return train_list, train_label

def val_data_get(data_path):
    img_list = []
    labels = []
    imgs = os.listdir(data_path)

    for i in range(len(imgs)):
        path = data_path + imgs[i]
        temp = cv2.imread(path)
        #BGR -> RGB
        temp = temp[:,:,::-1]
        #解决内存不连续问题
        temp = temp.copy()
        temp = np.float64(temp)
        h, w, _ = temp.shape

        temp = temp[0:(h - h%16), 0:(w-w%16), :]  # 避免上下采样时出现维度不匹配问题
        
        #归一化
        temp = temp / 255.0
        
        temp_label = temp
        temp = CFA(temp)
        temp = my_fil(temp)
        
        img_list.append(temp)
        labels.append(temp_label)
        
    val_list, val_label = img_list, labels
    print("Validation data get complete.")
    return val_list, val_label

def Gehler_Shi_test_data_get(data_path):
    img_list = []
    labels = []
    imgs = os.listdir(data_path)

    for i in range(len(imgs)):
        
        path = data_path + imgs[i]
        temp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        #BGR -> RGB
        temp = temp[:,:,::-1]
        #解决内存不连续问题
        temp = temp.copy()
        temp = np.float64(temp)

        # 去除black level(对于CANON5D的图片)
        h, w, _ = temp.shape
        if h == 2193 or h == 1460:
            temp = np.maximum(0., temp - 129.)

        temp = temp / 4095.  # 归一化(参考代码中是除以最大值，或许可以考虑除以4095)
        h, w, _ = temp.shape

        #对Gehler_Shi测试集不进行切分
        #将图像的大小缩小为16的倍数以适应UNET
        #为了让数据大小统一，切一个统一的大小（设为1344*1344）
        temp = temp[0:1344, 0:1344, :]
        
        #label为原值，input经过CFA滤波并双线性插值
        temp_label = temp
        temp = np.float64(CFA(temp))
        temp = my_fil(temp)

        img_list.append(temp)
        labels.append(temp_label)

    test_list, test_label = img_list, labels
    print("Gehler-Shi test data get complete.")
    return test_list, test_label

In [7]:
transform = transforms.Compose([transforms.ToTensor()])
train_batch_size = 1
train_number_epoch = 2

train_dir1 = "/kaggle/input/moiretrain10000/"
val_dir = "/kaggle/input/moireval/"
test_dir = "/kaggle/input/gehlertest10/"

train_list1, train_label1 = train_data_get(train_dir1)
val_list, val_label = val_data_get(val_dir)
test_list, test_label = Gehler_Shi_test_data_get(val_dir)

trainset1 = Datasets(data=train_list1, label=train_label1, transform=transform)
trainloader1 = DataLoader(trainset1, batch_size=train_batch_size, shuffle=True)
valset = Datasets(data=val_list, label=val_label, transform=transform)
valloader = DataLoader(valset, batch_size=1, shuffle=True)
testset = Datasets(data=test_list, label=test_label, transform=transform)
testloader = DataLoader(testset, batch_size=1, shuffle=True)

Train data get complete.
Validation data get complete.
Gehler-Shi test data get complete.


In [8]:
# common

class BasicBlock(nn.Sequential):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)


def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), stride=stride, bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


In [9]:
# tools

def normalize(x):
    return x.mul_(2).add_(-1)

def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
    padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()
    
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x



In [10]:
# attention

class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True,
                 conv=default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1 - i / 10 for i in range(level)]
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        # theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base, 1, dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        # build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i] != 1:
                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
            # feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            # sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                            strides=[self.stride, self.stride],
                                            rates=[1, 1],
                                            padding='same')  # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            # feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            # sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            # group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
            # normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi / max_wi
            # matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi * self.softmax_scale, dim=1)

            if not self.average:
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
            y.append(yi)

        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
        return y


In [11]:
#panet

def make_model():
    return PANET()

class PANET(nn.Module):
    def __init__(self, conv=default_conv):
        super(PANET, self).__init__()

        n_resblocks = 10
        res_scale = 1
        n_feats = 64
        kernel_size = 3
        n_colors = 3

        rgb_range = 1
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
        msa = PyramidAttention()
        # define head module
        m_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, nn.PReLU(), res_scale=res_scale
            ) for _ in range(n_resblocks // 2)
        ]
        m_body.append(msa)
        for i in range(n_resblocks // 2):
            m_body.append(ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=res_scale))

        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        # m_tail = [
        #    common.Upsampler(conv, scale, n_feats, act=False),
        #    conv(n_feats, args.n_colors, kernel_size)
        # ]
        m_tail = [
            conv(n_feats, n_colors, kernel_size)
        ]

        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        # x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)

        res += x

        x = self.tail(res)
        # x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))


In [12]:
net = PANET().cuda()
loss_func = nn.MSELoss()
best_train_loss = float('inf')
best_val_loss = float('inf')
lr = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
train_dataloader = trainloader1
val_dataloader = valloader
test_dataloader = testloader
epochs = train_number_epoch

In [13]:
def Pre_train(trainloader, valloader, model, loss_func, optimizer):
    global best_train_loss
    global best_val_loss
    device = torch.device('cuda')
    model = model.to(device)
    for epoch in range(0, epochs):
        # train part
        total_train_loss = 0
        counter1 = 0
        for i, data in enumerate(trainloader, 0):
            input_pic, label = data
            #转移到GPU上，类型变为float以与模型参数类型匹配
            input_pic, label = input_pic.to(device, dtype=torch.float), label.to(device, dtype=torch.float)
            
            
            # output 已经双线性插值 进网络重建，得到三通道图片
            output = model(input_pic)

            optimizer.zero_grad()
            
            loss = loss_func(output, label)
            loss.requires_grad_(True)
            loss.backward()
            
            optimizer.step()
            total_train_loss += loss.item()
            counter1 += 1

        # val part
        model.eval()
        total_val_loss = 0
        
        #初始化三通道PSNR值与CPSNR值
        total_psnr_r = 0
        total_psnr_g = 0
        total_psnr_b = 0
        total_cpsnr = 0
        
        counter2 = 0
        
        with torch.no_grad():
            for j, data2 in enumerate(valloader, 0):
                input_pic2, label2 = data2
                
                #CFA滤波与双线性插值的过程已在val_data_get函数中完成
                input_pic2, label2 = input_pic2.to(device, dtype=torch.float), label2.to(device, dtype=torch.float)
                # 进网络
                
                output2 = model(input_pic2)
                
                # RGGB采样
                output2 = output2.cuda()
                
                loss2 = loss_func(output2, label2)
                
                total_val_loss += loss2.item()
                
                #为计算PSNR部分而进行MSE的计算
                
                #RGB三通道的MSE值
                
                mse_r = loss_func(output2[:, 0, :, :],
                                 label2[:, 0, :, :])
                mse_g = loss_func(output2[:, 1, :, :],
                                 label2[:, 1, :, :])
                mse_b = loss_func(output2[:, 2, :, :],
                                 label2[:, 2, :, :])
                #CPSNR，即三通道PSNR值的平均值
                cmse = loss_func(output2[:, :, :, :],
                                 label2[:, :, :, :])
                
                #计算所有图片RGB三通道的PSNR值及CPSNR之和
                if mse_r.item() < 1.0e-10:  # 均方误差小到过分了
                    total_psnr_r += 100
                else:
                    total_psnr_r += 20 * math.log10(1 / math.sqrt(mse_r.item()))
                
                if mse_g.item() < 1.0e-10:  # 均方误差小到过分了
                    total_psnr_g += 100
                else:
                    total_psnr_g += 20 * math.log10(1 / math.sqrt(mse_g.item()))
                
                if mse_b.item() < 1.0e-10:  # 均方误差小到过分了
                    total_psnr_b += 100
                else:
                    total_psnr_b += 20 * math.log10(1 / math.sqrt(mse_b.item()))
                
                if cmse.item() < 1.0e-10:  # 均方误差小到过分了
                    total_cpsnr += 100
                else:
                    total_cpsnr += 20 * math.log10(1 / math.sqrt(cmse.item()))
            
                counter2 += 1
        
                
            train_loss_avg = total_train_loss / counter1
            val_loss_avg = total_val_loss / counter2
            
            #计算所有图片RGB三通道的PSNR值及CPSNR之均值
            psnr_r_avg = total_psnr_r / counter2
            psnr_g_avg = total_psnr_g / counter2
            psnr_b_avg = total_psnr_b / counter2
            cpsnr_avg = total_cpsnr / counter2
            
            print("Epoch number: {} , Train loss: {:.4f}, Val loss: {:.4f}, \nPSNR_R: {:.4f}, PSNR_G: {:.4f}, PSNR_B: {:.4f}, CPSNR: {:.4f}\n".format(epoch, train_loss_avg, val_loss_avg, psnr_r_avg, psnr_g_avg, psnr_b_avg, cpsnr_avg))        

        if val_loss_avg < best_val_loss:
            best_val_loss = val_loss_avg
            torch.save(net.state_dict(), '/kaggle/working/best_pretrain_panet.pth')
            

In [14]:
# To restore Panet to the last checkpoint
net.load_state_dict(torch.load('/kaggle/input/best-pretrain-panet-3epoch/best_pretrain_panet_3epoch.pth'))

In [15]:
Pre_train(train_dataloader,val_dataloader, net, loss_func, optimizer)

Epoch number: 0 , Train loss: 0.0011, Val loss: 0.0014, 
PSNR_R: 29.2178, PSNR_G: 33.1509, PSNR_B: 29.5872, CPSNR: 30.1326

Epoch number: 1 , Train loss: 0.0012, Val loss: 0.0014, 
PSNR_R: 29.2142, PSNR_G: 31.8072, PSNR_B: 29.5512, CPSNR: 29.8695



In [16]:
def Pre_test(testloader, model, loss_func):
    global best_test_loss
    device = torch.device('cuda')
    model = model.to(device)

    model.eval()
    total_test_loss = 0

    #初始化三通道PSNR值与CPSNR值
    total_psnr_r = 0
    total_psnr_g = 0
    total_psnr_b = 0
    total_cpsnr = 0

    counter = 0

    with torch.no_grad():
        for j, data in enumerate(testloader, 0):
            input_pic, label = data

            #CFA滤波与双线性插值的过程已在Gehler_Shi_test_data_get函数中完成
            input_pic, label = input_pic.to(device, dtype=torch.float), label.to(device, dtype=torch.float)
            # 进网络

            output = model(input_pic)

            # RGGB采样
            output = output.cuda()

            loss = loss_func(output, label)

            total_test_loss += loss.item()

            #为计算PSNR部分而进行MSE的计算

            #设置去除边界的宽度
            bound = 1

            #RGB三通道的MSE值

            mse_r = loss_func(output[:, 0, bound:-bound,bound:-bound],
                             label[:, 0, bound:-bound,bound:-bound])
            mse_g = loss_func(output[:, 1, bound:-bound,bound:-bound],
                             label[:, 1, bound:-bound,bound:-bound])
            mse_b = loss_func(output[:, 2, bound:-bound,bound:-bound],
                             label[:, 2, bound:-bound,bound:-bound])
            #CPSNR，即三通道PSNR值的平均值
            cmse = loss_func(output[:, :, bound:-bound,bound:-bound],
                             label[:, :, bound:-bound,bound:-bound])

            #计算所有图片RGB三通道的PSNR值及CPSNR之和
            if mse_r.item() < 1.0e-10:  # 均方误差小到过分了
                total_psnr_r += 100
            else:
                total_psnr_r += 20 * math.log10(1 / math.sqrt(mse_r.item()))

            if mse_g.item() < 1.0e-10:  # 均方误差小到过分了
                total_psnr_g += 100
            else:
                total_psnr_g += 20 * math.log10(1 / math.sqrt(mse_g.item()))

            if mse_b.item() < 1.0e-10:  # 均方误差小到过分了
                total_psnr_b += 100
            else:
                total_psnr_b += 20 * math.log10(1 / math.sqrt(mse_b.item()))

            if cmse.item() < 1.0e-10:  # 均方误差小到过分了
                total_cpsnr += 100
            else:
                total_cpsnr += 20 * math.log10(1 / math.sqrt(cmse.item()))

            counter += 1

        test_loss_avg = total_test_loss / counter

        #计算所有图片RGB三通道的PSNR值及CPSNR之均值
        psnr_r_avg = total_psnr_r / counter
        psnr_g_avg = total_psnr_g / counter
        psnr_b_avg = total_psnr_b / counter
        cpsnr_avg = total_cpsnr / counter

        print("Test loss: {:.4f}, \nPSNR_R: {:.4f}, PSNR_G: {:.4f}, PSNR_B: {:.4f}, CPSNR: {:.4f}\n".format(test_loss_avg, psnr_r_avg, psnr_g_avg, psnr_b_avg, cpsnr_avg))        


In [17]:
Pre_test(test_dataloader, net, loss_func)

Test loss: 0.0001, 
PSNR_R: 42.4004, PSNR_G: 39.8077, PSNR_B: 42.6564, CPSNR: 41.4118

