In [1]:
import numpy as np
import cv2
import os
import gc
import math

from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, Dataset

import tqdm

In [2]:
# Create the checkpoint directory. The absolute path matches the location the
# training loop saves to, and exist_ok=True keeps the cell re-runnable.
os.makedirs("/kaggle/working/model/", exist_ok=True)

In [3]:
def CFA(pic: np.ndarray):
    """Sample an H x W x 3 RGB image through an RGGB Bayer colour filter array.

    Each pixel keeps only the channel selected by the 2x2 RGGB cell (R at
    even row / even column, G on the two mixed-parity positions, B at odd
    row / odd column); the other two channels are set to zero.

    Args:
        pic: array of shape (height, width, 3) in RGB channel order.

    Returns:
        A new array of the same shape holding ``pic`` multiplied by the mask.
    """
    rows, cols, _ = pic.shape
    # One 2x2 Bayer cell, one-hot along the channel axis.
    bayer_cell = np.array([[[1, 0, 0], [0, 1, 0]],
                           [[0, 1, 0], [0, 0, 1]]])
    # Repeat the cell often enough to cover odd sizes, then trim the excess.
    repeats = ((rows + 1) // 2, (cols + 1) // 2, 1)
    mask = np.tile(bayer_cell, repeats)[:rows, :cols, :]
    return pic * mask

In [4]:
# Apply a Bayer CFA mask to a 4-D tensor; `pattern` names the sensor layout
# (RGGB, BGGR, GBRG or GRBG).
def CFA_d4(pic: torch.Tensor, pattern: str):
    """Sample a batch of RGB images through a Bayer colour filter array.

    For the non-RGGB layouts the images are first rolled by one pixel so that
    their layout lines up with the RGGB mask; the result is NOT rolled back,
    so it stays shifted relative to the input. Any other pattern value
    (including "RGGB") applies the mask without shifting.

    Args:
        pic: tensor of shape (batch, 3, height, width) in RGB channel order.
        pattern: "BGGR", "GBRG", "GRBG" or "RGGB".

    Returns:
        A new tensor with the same shape and dtype as ``pic``, with the
        unsampled channels zeroed. It lives on the CUDA device when one is
        available, otherwise on the CPU. ``pic`` itself is not modified.

    Raises:
        ValueError: if ``pic`` does not have exactly 3 channels.
    """
    b, c, h, w = pic.shape
    if c != 3:
        raise ValueError("CFA_d4 expects 3 channels (RGB), got {}".format(c))

    # Prefer the GPU as before, but keep working on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pic = pic.to(device)

    # (row shift, column shift) that maps each layout onto RGGB.
    shifts = {"BGGR": (-1, -1), "GBRG": (-1, 0), "GRBG": (0, -1)}
    shift = shifts.get(pattern)
    if shift is not None:
        # One batched roll over the H and W axes replaces the per-sample loop.
        pic = torch.roll(pic, shifts=shift, dims=(2, 3))

    # RGGB mask built directly on the target device: R at (even, even),
    # G at (even, odd) and (odd, even), B at (odd, odd).
    mask = torch.zeros((3, h, w), dtype=pic.dtype, device=device)
    mask[0, 0::2, 0::2] = 1
    mask[1, 0::2, 1::2] = 1
    mask[1, 1::2, 0::2] = 1
    mask[2, 1::2, 1::2] = 1

    # Broadcasting over the batch axis; the product is a fresh tensor.
    return pic * mask


In [5]:
# Bilinear demosaicing: convolve each colour plane of a CFA-sampled image with
# its interpolation kernel.
def my_fil(pic):
    """Fill in the missing samples of a CFA-masked RGB image by filtering.

    The R and B planes share one 3x3 kernel and the G plane uses another;
    each plane is filtered independently with ``cv2.filter2D`` at the input
    depth.

    Args:
        pic: array of shape (height, width, 3) in RGB channel order.

    Returns:
        A new contiguous array of the same shape, in RGB channel order.
    """
    # Unpacking also rejects anything that is not 3-dimensional up front.
    rows, cols, _ = pic.shape

    # Kernel for the red and blue planes.
    rb_kernel = np.array([[0.25, 0.5, 0.25],
                          [0.5, 1., 0.5],
                          [0.25, 0.5, 0.25]])
    # Kernel for the green plane.
    g_kernel = np.array([[0., 0.25, 0.],
                         [0.25, 1., 0.25],
                         [0., 0.25, 0.]])

    # Split straight into R, G, B; no detour through BGR order is needed
    # because every plane is filtered on its own.
    red, green, blue = cv2.split(np.ascontiguousarray(pic))
    filtered = [
        cv2.filter2D(plane, -1, kernel=kernel)
        for plane, kernel in ((red, rb_kernel), (green, g_kernel), (blue, rb_kernel))
    ]
    return cv2.merge(filtered)

def my_fil_d4(pic):
    """Bilinear-interpolate a batch of CFA-masked RGB images.

    Each colour plane of each image is filtered independently with
    ``cv2.filter2D`` on the CPU: the R and B planes share one 3x3 kernel and
    the G plane uses another.

    Args:
        pic: tensor of shape (batch, 3, height, width) in RGB channel order.

    Returns:
        A new tensor of the same shape, on the CUDA device when one is
        available and otherwise on the CPU. The result is detached from the
        autograd graph, and ``pic`` is left unmodified.

    Raises:
        ValueError: if ``pic`` does not have exactly 3 channels.
    """
    # b, c, h, w -> batch_size, channel, height, width
    b, c, h, w = pic.shape
    if c != 3:
        raise ValueError("my_fil_d4 expects 3 channels (RGB), got {}".format(c))

    # Kernel for the red and blue planes.
    rb_kernel = np.array([[0.25, 0.5, 0.25],
                          [0.5, 1., 0.5],
                          [0.25, 0.5, 0.25]])
    # Kernel for the green plane.
    g_kernel = np.array([[0., 0.25, 0.],
                         [0.25, 1., 0.25],
                         [0., 0.25, 0.]])
    kernels = (rb_kernel, g_kernel, rb_kernel)

    # For CPU tensors this NumPy array shares memory with `pic`, so results
    # go into a separate buffer instead of being written back into it.
    batch = pic.detach().cpu().numpy()
    result = np.empty(batch.shape, dtype=batch.dtype)

    for i in range(b):
        for ch, kernel in enumerate(kernels):
            plane = np.ascontiguousarray(batch[i, ch])
            result[i, ch] = cv2.filter2D(plane, -1, kernel=kernel)

    # Prefer the GPU as before, but keep working on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.from_numpy(result).to(device)
        
        

In [6]:
class MyDataset(Dataset):
    """Dataset that builds (demosaicing input, ground-truth RGB) pairs on the fly.

    Every file in ``filepath`` is read with OpenCV, converted to RGB float64,
    cropped so both sides are multiples of 16, and scaled by 1/255. The
    label is that image; the input is the same image passed through ``CFA``
    and then ``my_fil``.

    Args:
        filepath: directory containing the image files.
        transform: optional callable applied separately to input and label.
    """

    def __init__(self, filepath, transform=None):
        self.filepath = filepath
        self.transform = transform
        # List the directory once and sort it, so an index always maps to the
        # same file and each access avoids a fresh directory scan.
        self.imgs = sorted(os.listdir(filepath))

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair for the ``index``-th file."""
        # os.path.join works whether or not filepath ends with a separator.
        path = os.path.join(self.filepath, self.imgs[index])

        temp = cv2.imread(path)
        if temp is None:
            raise ValueError("cv2.imread could not decode image: {}".format(path))

        # BGR -> RGB, as a contiguous float64 copy.
        temp = np.ascontiguousarray(temp[:, :, ::-1], dtype=np.float64)
        h, w, _ = temp.shape

        # Crop to multiples of 16 to avoid size mismatches between the
        # network's down- and up-sampling paths.
        temp = temp[0:(h - h % 16), 0:(w - w % 16), :]

        # Normalise 8-bit values to [0, 1].
        temp = temp / 255.0

        label = temp
        data = my_fil(CFA(temp))

        if self.transform is not None:
            data = self.transform(data)
            label = self.transform(label)

        return data, label

    def __len__(self):
        """Number of files found in the directory when the dataset was built."""
        return len(self.imgs)

In [7]:
def train_data_get(data_path):
    """Load every image in ``data_path`` into memory as training pairs.

    Each file is read with OpenCV, converted to RGB float64, cropped so both
    sides are multiples of 16 and scaled by 1/255. The label is that image;
    the input is the same image passed through ``CFA`` and then ``my_fil``.

    Args:
        data_path: directory containing the image files (a trailing
            separator is optional).

    Returns:
        ``(train_list, train_label)``: two lists of H x W x 3 arrays, ordered
        by sorted file name.

    Raises:
        ValueError: if a file cannot be decoded by ``cv2.imread``.
    """
    img_list = []
    labels = []

    # Sorted so the result order is deterministic across runs and platforms.
    for name in sorted(os.listdir(data_path)):
        path = os.path.join(data_path, name)
        temp = cv2.imread(path)
        if temp is None:
            raise ValueError("cv2.imread could not decode image: {}".format(path))

        # BGR -> RGB, as a contiguous float64 copy.
        temp = np.ascontiguousarray(temp[:, :, ::-1], dtype=np.float64)
        h, w, _ = temp.shape

        # Crop to multiples of 16 to avoid size mismatches between the
        # network's down- and up-sampling paths.
        temp = temp[0:(h - h % 16), 0:(w - w % 16), :]

        # Normalise 8-bit values to [0, 1].
        temp = temp / 255.0

        temp_label = temp
        temp = my_fil(CFA(temp))

        img_list.append(temp)
        labels.append(temp_label)

    train_list, train_label = img_list, labels
    print("Train data get complete.")
    return train_list, train_label

def val_data_get(data_path):
    """Load every image in ``data_path`` into memory as validation pairs.

    Each file is read with OpenCV, converted to RGB float64, cropped so both
    sides are multiples of 16 and scaled by 1/255. The label is that image;
    the input is the same image passed through ``CFA`` and then ``my_fil``.

    Args:
        data_path: directory containing the image files (a trailing
            separator is optional).

    Returns:
        ``(val_list, val_label)``: two lists of H x W x 3 arrays, ordered by
        sorted file name.

    Raises:
        ValueError: if a file cannot be decoded by ``cv2.imread``.
    """
    img_list = []
    labels = []

    # Sorted so the result order is deterministic across runs and platforms.
    for name in sorted(os.listdir(data_path)):
        path = os.path.join(data_path, name)
        temp = cv2.imread(path)
        if temp is None:
            raise ValueError("cv2.imread could not decode image: {}".format(path))

        # BGR -> RGB, as a contiguous float64 copy.
        temp = np.ascontiguousarray(temp[:, :, ::-1], dtype=np.float64)
        h, w, _ = temp.shape

        # Crop to multiples of 16 to avoid size mismatches between the
        # network's down- and up-sampling paths.
        temp = temp[0:(h - h % 16), 0:(w - w % 16), :]

        # Normalise 8-bit values to [0, 1].
        temp = temp / 255.0

        temp_label = temp
        temp = my_fil(CFA(temp))

        img_list.append(temp)
        labels.append(temp_label)

    val_list, val_label = img_list, labels
    print("Validation data get complete.")
    return val_list, val_label

def Gehler_Shi_test_data_get(data_path):
    """Load the Gehler-Shi test images into memory as (input, label) pairs.

    Each file is read unchanged (keeping its original bit depth), converted
    to RGB float64, optionally black-level corrected, divided by 4095 and
    cropped to its top-left 1344 x 1344 region. The label is that crop; the
    input is the same crop passed through ``CFA`` and then ``my_fil``.

    Args:
        data_path: directory containing the image files (a trailing
            separator is optional).

    Returns:
        ``(test_list, test_label)``: two lists of H x W x 3 arrays, ordered
        by sorted file name.

    Raises:
        ValueError: if a file cannot be decoded by ``cv2.imread``.
    """
    img_list = []
    labels = []

    # Sorted so the result order is deterministic across runs and platforms.
    for name in sorted(os.listdir(data_path)):
        path = os.path.join(data_path, name)
        temp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if temp is None:
            raise ValueError("cv2.imread could not decode image: {}".format(path))

        # BGR -> RGB, as a contiguous float64 copy.
        temp = np.ascontiguousarray(temp[:, :, ::-1], dtype=np.float64)

        # Remove the black level of 129 for Canon 5D images, which are
        # recognised here by their height.
        h = temp.shape[0]
        if h == 2193 or h == 1460:
            temp = np.maximum(0., temp - 129.)

        # Normalise by 4095 (2**12 - 1). The reference code divides by the
        # image maximum instead.
        temp = temp / 4095.

        # The Gehler-Shi images are not tiled. A fixed 1344 x 1344 crop (a
        # multiple of 16, as the U-Net needs) gives one common size; images
        # smaller than that are left at their own size by the slice.
        temp = temp[0:1344, 0:1344, :]

        # Label is the clean crop; input is CFA-sampled, then interpolated.
        temp_label = temp
        temp = my_fil(CFA(temp))

        img_list.append(temp)
        labels.append(temp_label)

    test_list, test_label = img_list, labels
    print("Gehler-Shi test data get complete.")
    return test_list, test_label

In [8]:
# Dataset locations (DIV2K high-resolution train / validation splits).
train_dir1 = "/kaggle/input/div2k-dataset/DIV2K_train_HR/DIV2K_train_HR/"
val_dir = "/kaggle/input/div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR/"

# Training hyper-parameters.
train_batch_size = 16
train_number_epoch = 50

# The same transform is applied to both the network input and the label.
transform = transforms.Compose([transforms.ToTensor()])

# Images are loaded lazily by MyDataset rather than preloaded with
# train_data_get / val_data_get. The Gehler-Shi test set is not wired up here.
trainset1 = MyDataset(train_dir1, transform=transform)
print("MyDataset Correct")
trainloader1 = DataLoader(trainset1, batch_size=train_batch_size, shuffle=True)

# Validation runs one image per batch.
valset = MyDataset(val_dir, transform=transform)
valloader = DataLoader(valset, batch_size=1, shuffle=True)

MyDataset Correct


In [9]:
# Basic convolution block
class Conv(nn.Module):
    """Two 3x3 convolutions (stride 1, padding 1), each followed by LeakyReLU.

    The spatial size is preserved; channels go from ``C_in`` to ``C_out``.
    """

    def __init__(self, C_in, C_out):
        super().__init__()
        stages = [
            nn.Conv2d(C_in, C_out, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(C_out, C_out, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        features = self.layer(x)
        return features


# Down-sampling module
class DownSampling(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution plus LeakyReLU.

    The number of channels ``C`` is unchanged.
    """

    def __init__(self, C):
        super().__init__()
        # A strided convolution does the 2x down-sampling instead of pooling.
        strided_conv = nn.Conv2d(C, C, kernel_size=3, stride=2, padding=1)
        self.Down = nn.Sequential(strided_conv, nn.LeakyReLU())

    def forward(self, x):
        reduced = self.Down(x)
        return reduced


# Up-sampling module
class UpSampling(nn.Module):
    """Double the spatial size, halve the channels, then append a skip tensor.

    ``forward(x, r)`` up-samples ``x`` by 2 with nearest-neighbour
    interpolation, reduces it from ``C`` to ``C // 2`` channels with a 1x1
    convolution and concatenates ``r`` after it along the channel axis.
    """

    def __init__(self, C):
        super().__init__()
        # 1x1 convolution that halves the channel count.
        self.Up = nn.Conv2d(C, C // 2, kernel_size=1, stride=1)

    def forward(self, x, r):
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        reduced = self.Up(enlarged)
        # Up-sampled features first, then the matching encoder features.
        return torch.cat([reduced, r], dim=1)


# Backbone network
class UNet(nn.Module):
    """U-Net mapping a 3-channel image to a 3-channel image in (0, 1).

    The encoder has four stride-2 down-sampling steps (64 -> 1024 channels).
    The decoder has four up-sampling steps, each concatenated with the
    encoder feature map of the same resolution. A final 3x3 convolution and
    a sigmoid produce the output.
    """

    def __init__(self):
        super().__init__()

        # Encoder: convolution blocks interleaved with down-sampling.
        self.C1 = Conv(3, 64)
        self.D1 = DownSampling(64)
        self.C2 = Conv(64, 128)
        self.D2 = DownSampling(128)
        self.C3 = Conv(128, 256)
        self.D3 = DownSampling(256)
        self.C4 = Conv(256, 512)
        self.D4 = DownSampling(512)
        self.C5 = Conv(512, 1024)

        # Decoder: four up-sampling steps. Each UpSampling halves the
        # channels and appends the skip tensor, so the following Conv sees
        # the full channel count again.
        self.U1 = UpSampling(1024)
        self.C6 = Conv(1024, 512)
        self.U2 = UpSampling(512)
        self.C7 = Conv(512, 256)
        self.U3 = UpSampling(256)
        self.C8 = Conv(256, 128)
        self.U4 = UpSampling(128)
        self.C9 = Conv(128, 64)

        # Output head: sigmoid applied to a 3-channel prediction.
        self.Th = torch.nn.Sigmoid()
        self.pred = torch.nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        # Contracting path; keep each stage's output for the skip connections.
        skip1 = self.C1(x)
        skip2 = self.C2(self.D1(skip1))
        skip3 = self.C3(self.D2(skip2))
        skip4 = self.C4(self.D3(skip3))
        bottleneck = self.C5(self.D4(skip4))

        # Expanding path; every step merges the skip of matching resolution.
        decoded = self.C6(self.U1(bottleneck, skip4))
        decoded = self.C7(self.U2(decoded, skip3))
        decoded = self.C8(self.U3(decoded, skip2))
        decoded = self.C9(self.U4(decoded, skip1))

        prediction = self.pred(decoded)
        return self.Th(prediction)
    


In [10]:
# Optimisation hyper-parameters.
lr = 1e-4
epochs = train_number_epoch

# Best losses seen so far; Pre_train declares both as globals.
best_train_loss = float('inf')
best_val_loss = float('inf')

# Model, loss and optimiser.
net = UNet().cuda()
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)

# Loader aliases used for training; no test loader is set up.
train_dataloader = trainloader1
val_dataloader = valloader

In [11]:
def _psnr_db(mse, floor=1.0e-10, cap=100):
    """PSNR in dB for a unit peak value; ``cap`` is returned when ``mse < floor``."""
    if mse < floor:
        return cap
    return 20 * math.log10(1 / math.sqrt(mse))


def Pre_train(trainloader, valloader, model, loss_func, optimizer):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch this prints the average train and validation loss and
    the average per-channel PSNR and CPSNR over the validation batches. When
    the average validation loss improves on the module-level
    ``best_val_loss``, the weights of ``model`` are saved to
    ``/kaggle/working/model/best_pretrain_unet.pth``.

    Reads the module global ``epochs`` and updates ``best_train_loss`` and
    ``best_val_loss``. The PSNR figures are computed from ``loss_func`` and
    are only meaningful when it is a mean-squared error on data in [0, 1].

    Args:
        trainloader: iterable of ``(input, label)`` batches for training.
        valloader: iterable of ``(input, label)`` batches for validation;
            tensors are indexed as (batch, channel, height, width).
        model: network to train; moved to CUDA when available, else CPU.
        loss_func: callable ``loss_func(output, label)`` returning a scalar.
        optimizer: optimiser already bound to ``model``'s parameters.

    Raises:
        ValueError: if either loader yields no batches.
    """
    global best_train_loss
    global best_val_loss
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    save_path = '/kaggle/working/model/best_pretrain_unet.pth'

    for epoch in range(0, epochs):
        # ---- train part ----
        # Validation below switches to eval mode, so restore train mode here.
        model.train()
        total_train_loss = 0
        counter1 = 0
        for input_pic, label in trainloader:
            # Move to the device and cast to float32 to match the weights.
            input_pic = input_pic.to(device, dtype=torch.float)
            label = label.to(device, dtype=torch.float)

            optimizer.zero_grad()
            # The input is already bilinearly interpolated; the network
            # reconstructs the 3-channel image from it.
            output = model(input_pic)
            loss = loss_func(output, label)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            counter1 += 1

        if counter1 == 0:
            raise ValueError("trainloader yielded no batches")

        # ---- val part ----
        model.eval()
        total_val_loss = 0
        # Running sums of the R, G, B PSNR values and of the CPSNR.
        total_psnr = [0, 0, 0]
        total_cpsnr = 0
        counter2 = 0

        with torch.no_grad():
            for input_pic2, label2 in valloader:
                # CFA sampling and interpolation were done by the dataset.
                input_pic2 = input_pic2.to(device, dtype=torch.float)
                label2 = label2.to(device, dtype=torch.float)

                output2 = model(input_pic2)
                loss2 = loss_func(output2, label2)
                total_val_loss += loss2.item()

                # Per-channel PSNR from the per-channel loss.
                for ch in range(3):
                    mse_ch = loss_func(output2[:, ch, :, :], label2[:, ch, :, :])
                    total_psnr[ch] += _psnr_db(mse_ch.item())

                # CPSNR: PSNR of the loss over all three channels together,
                # which is exactly the validation loss computed above.
                total_cpsnr += _psnr_db(loss2.item())

                counter2 += 1

        if counter2 == 0:
            raise ValueError("valloader yielded no batches")

        train_loss_avg = total_train_loss / counter1
        val_loss_avg = total_val_loss / counter2

        # Averages over all validation batches.
        psnr_r_avg = total_psnr[0] / counter2
        psnr_g_avg = total_psnr[1] / counter2
        psnr_b_avg = total_psnr[2] / counter2
        cpsnr_avg = total_cpsnr / counter2

        print("Epoch number: {} , Train loss: {:.4f}, Val loss: {:.4f}, \nPSNR_R: {:.4f}, PSNR_G: {:.4f}, PSNR_B: {:.4f}, CPSNR: {:.4f}\n".format(epoch, train_loss_avg, val_loss_avg, psnr_r_avg, psnr_g_avg, psnr_b_avg, cpsnr_avg))

        if train_loss_avg < best_train_loss:
            best_train_loss = train_loss_avg

        if val_loss_avg < best_val_loss:
            best_val_loss = val_loss_avg
            # Save the model that was actually trained (not a global), and
            # make sure the target directory exists first.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(model.state_dict(), save_path)
            