In [76]:
import cv2
import os
import matplotlib.pyplot as plt
import data_loader
import SRCNN_model
import torch
import math
from tqdm import tqdm
import torch.nn as nn
from criteria import SSIM, PSNR
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

# Root of the DIV2K dataset. Set the DIV2K_PATH environment variable to point
# elsewhere instead of editing this machine-specific default.
base_path = os.environ.get("DIV2K_PATH", r"D:\programming\dataset\DIV2K")
# High-resolution validation images.
valid_hr = os.path.join(base_path, "DIV2K_valid_HR")
# Low-resolution validation images (folder name suggests x2 bicubic).
valid_lr = os.path.join(base_path, "DIV2K_valid_LR_bicubic_X2")

In [106]:
def get_SRCNN_model(model_save_path=None):
    """
    Build an SRCNN model and move it to the best available device.

    If ``model_save_path`` is None or does not exist on disk, a freshly
    initialised model is returned and training starts at epoch 0. Otherwise
    the checkpoint is loaded; it is expected to be a dict with keys
    ``"state_dict"`` and ``"epoch"`` (the format written by ``SSNR_train``).

    Args:
        model_save_path: optional path to a checkpoint file.

    Returns:
        (model, device, current_epoch)
    """
    # Resolve the device first so the checkpoint can be mapped onto it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SRCNN_model.SRCNN(padding=True)
    if model_save_path is None or not os.path.exists(model_save_path):
        print("init new model parameter")
        model.init_weights()
        current_epoch = 0
    else:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        para = torch.load(model_save_path, map_location=device)
        model.load_state_dict(para["state_dict"])
        current_epoch = para["epoch"]
        print("load model parameter")
    model.to(device)
    print("use device: ", device)
    return model, device, current_epoch


def SSNR_train(model, device, train_dataloader, val_dataloader,
               current_epoch, model_save_path, max_epoch=10,
               lr_begin=0.01, lr_decay=0.95):
    """
    Train SRCNN with SGD + momentum and an exponentially decaying learning rate.

    The base learning rate for epoch ``e`` is ``lr_begin * lr_decay ** e``.
    The last layer (``conv3``) uses one tenth of that rate. After every epoch
    the model is checkpointed to ``model_save_path`` as
    ``{"epoch": ..., "state_dict": ...}``, so ``get_SRCNN_model`` can resume.
    The optimizer state is not saved, so momentum restarts on resume.

    Args:
        model: network exposing ``conv1``/``conv2``/``conv3`` submodules.
        device: torch device the batches are moved to.
        train_dataloader: yields ``(target, model_input)`` batch pairs.
        val_dataloader: accepted for interface compatibility; not used yet.
        current_epoch: epoch to resume from (0 for a fresh model).
        model_save_path: checkpoint file, overwritten after every epoch.
        max_epoch: total number of epochs to train up to.
        lr_begin: base learning rate at epoch 0.
        lr_decay: multiplicative learning-rate decay applied per epoch.
    """
    criterion = nn.MSELoss()
    # Build the optimizer once so SGD momentum buffers carry over between
    # epochs (re-creating it each epoch silently reset them). Only the
    # per-group learning rates are changed from epoch to epoch.
    optimizer = torch.optim.SGD([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters()},
    ], lr=lr_begin, momentum=0.9)
    # Learning-rate multiplier per param group: conv3 trains 10x slower.
    group_lr_scale = (1.0, 1.0, 0.1)
    model.train()

    with tqdm(total=len(train_dataloader) * max_epoch) as t:
        # When resuming, advance the bar past the epochs already trained.
        t.update(len(train_dataloader) * current_epoch)
        while current_epoch < max_epoch:
            epoch_lr = math.pow(lr_decay, current_epoch) * lr_begin
            for group, scale in zip(optimizer.param_groups, group_lr_scale):
                group['lr'] = epoch_lr * scale

            # Distinct names so the batch no longer shadows the learning rate.
            for hr_batch, lr_batch in train_dataloader:
                hr_batch, lr_batch = hr_batch.to(device), lr_batch.to(device)
                loss = criterion(model(lr_batch), hr_batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t.update(1)
            # NOTE: val_dataloader is not evaluated here yet.
            current_epoch += 1
            torch.save({"epoch": current_epoch, "state_dict": model.state_dict()},
                       model_save_path)
            
            
def script(h5_train_path, model_save_path, batch_size, max_epoch=20):
    """
    Train, or resume training of, one SRCNN model on one training set.

    ``h5_train_path`` and ``model_save_path`` are joined onto ``base_path``.
    Per ``os.path.join``, an absolute path given here is used as-is.
    Validation data comes from the module-level ``valid_hr``/``valid_lr``.
    If the checkpoint already exists, training resumes from its stored epoch.

    Args:
        h5_train_path: training data (h5) path, relative to ``base_path``.
        model_save_path: checkpoint path, relative to ``base_path``.
        batch_size: training batch size.
        max_epoch: total number of epochs to train up to (default 20).
    """
    h5_train_path = os.path.join(base_path, h5_train_path)
    model_save_path = os.path.join(base_path, model_save_path)

    train_dataloader, val_dataloader = data_loader.create_train_val_data_loader(
        h5_train_path, valid_hr, valid_lr, batch_size=batch_size)

    model, device, current_epoch = get_SRCNN_model(
        model_save_path=model_save_path)

    SSNR_train(model, device, train_dataloader, val_dataloader,
               current_epoch, model_save_path, max_epoch=max_epoch)
    
# Train one SRCNN per upscaling factor (x2..x5). Each factor has its own
# training set and checkpoint, and cached GPU memory is released between runs.
for scale_factor in (2, 3, 4, 5):
    script("train_scale{}_channel".format(scale_factor),
           "SRCNN_X{}.model".format(scale_factor),
           64)
    torch.cuda.empty_cache()

load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████| 30320/30320 [41:44<00:00, 12.11it/s]


load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████| 30320/30320 [41:33<00:00, 12.16it/s]


load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████| 30320/30320 [41:58<00:00, 12.04it/s]


load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████| 30320/30320 [41:50<00:00, 12.08it/s]


In [80]:
class ValDataset_scale(Dataset):
    """
    Validation dataset that synthesises its low-resolution inputs on the fly.

    Each HR image ``NNNN.png`` in ``valid_hr`` is bicubic-downscaled by
    ``scale``, then bicubic-upscaled back to its original size, so input and
    target share one shape. Only the Y channel of YCrCb is kept. Each item
    is a pair ``(hr, lr)`` of float32 arrays of shape (1, H, W), divided by 255.

    Args:
        valid_hr: directory holding the HR PNG images.
        scale: integer downscaling factor.
        start_index: file number of dataset index 0 (default 551).
        num_images: number of consecutively numbered images (default 250).

    NOTE(review): the official DIV2K validation split is numbered 0801-0900;
    confirm that ``valid_hr`` really holds the files this range points at.
    """

    def __init__(self, valid_hr, scale, start_index=551, num_images=250):
        super(ValDataset_scale, self).__init__()
        self.valid_hr = valid_hr
        self.scale = scale
        self.start_index = start_index
        self.num_images = num_images

    def __getitem__(self, idx):
        """Return the ``(hr, lr)`` Y-channel pair for dataset index ``idx``."""
        file_no = idx + self.start_index
        hrp = os.path.join(self.valid_hr, "{:0>4}".format(file_no) + ".png")
        # cv2.imread returns None rather than raising, which previously showed
        # up as an opaque AttributeError on ``hr.shape``; fail clearly instead.
        if not os.path.isfile(hrp):
            raise FileNotFoundError("validation image not found: {}".format(hrp))
        hr = cv2.imread(hrp)
        if hr is None:
            raise IOError("could not decode image: {}".format(hrp))

        scale = self.scale
        height, width, _ = hr.shape
        lr_height, lr_width = height // scale, width // scale
        # Degrade: bicubic down by `scale`, then bicubic back up to HR size.
        lr = cv2.resize(hr, (lr_width, lr_height), interpolation=cv2.INTER_CUBIC)
        lr = cv2.resize(lr, (width, height), interpolation=cv2.INTER_CUBIC)

        # cv2 loads BGR; keep only the Y channel, with a leading channel axis.
        hr_y = np.expand_dims(cv2.cvtColor(hr, cv2.COLOR_BGR2YCR_CB)[:, :, 0], 0)
        lr_y = np.expand_dims(cv2.cvtColor(lr, cv2.COLOR_BGR2YCR_CB)[:, :, 0], 0)
        return hr_y.astype(np.float32) / 255, lr_y.astype(np.float32) / 255

    def __len__(self):
        return self.num_images
    
def get_ValDataset_scale(valid_hr, scale):
    """
    Wrap ``ValDataset_scale(valid_hr, scale)`` in a DataLoader, one image per batch.

    batch_size stays 1 — presumably because validation images can differ in
    size and could not be stacked by the default collate function.
    """
    return DataLoader(dataset=ValDataset_scale(valid_hr, scale), batch_size=1)


In [104]:
def test_PSNR_with_scale(model, model_name, valid_hr, scale, device):
    """
    Evaluate ``model`` on validation images degraded by ``scale``.

    The HR and degraded Y channels come from ``get_ValDataset_scale``. The
    model output is clamped to [0, 1] and then both images are rescaled to
    0-255 before computing PSNR and SSIM. Puts ``model`` into eval mode and
    runs without autograd.

    Args:
        model: trained network.
        model_name: label used only in the printed summary.
        valid_hr: directory of HR validation images.
        scale: degradation factor applied to the validation images.
        device: torch device to run inference on.

    Returns:
        (mean_psnr, mean_ssim) over the validation set; also printed.
    """
    val_loader = get_ValDataset_scale(valid_hr, scale)
    psnr_list = []
    ssim_list = []
    model.eval()
    # no_grad: inference only, so don't build autograd graphs (saves memory/time).
    with torch.no_grad(), tqdm(total=len(val_loader)) as t:
        for hr_y, lr_y in val_loader:
            hr_y = hr_y.numpy().squeeze() * 255
            sr_y = model(lr_y.to(device)).cpu().numpy().squeeze()
            # Clamp the prediction to the normalised range [0, 1] before rescaling.
            sr_y = np.clip(sr_y, 0, 1) * 255
            psnr_list.append(PSNR(sr_y, hr_y, 255))
            ssim_list.append(SSIM(sr_y, hr_y))
            t.update(1)

    mean_psnr = np.average(psnr_list)
    mean_ssim = np.average(ssim_list)
    print("start test: {} model, validation set scale: {}, PSNR: {}, SSIM: {}".\
          format(model_name, scale, mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim


In [109]:
# Cross-evaluate every trained model on validation data degraded at every
# scale (x2..x5) to see how each model transfers to other factors.
for model_path in ("SRCNN_X2.model", "SRCNN_X3.model", "SRCNN_X4.model", "SRCNN_X5.model"):
    model, device, current_epoch = get_SRCNN_model(os.path.join(base_path, model_path))
    for factor in range(2, 6):
        test_PSNR_with_scale(model, model_path, valid_hr, factor, device)
    # Release cached GPU memory before loading the next model.
    torch.cuda.empty_cache()


load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:44<00:00,  1.52it/s]


start test: SRCNN_X2.model model, validation set scale: 2, PSNR: 31.213251188778457, SSIM: 0.8980322504792011


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:41<00:00,  1.55it/s]


start test: SRCNN_X2.model model, validation set scale: 3, PSNR: 27.339824494524954, SSIM: 0.8134522054615131


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:32<00:00,  1.63it/s]


start test: SRCNN_X2.model model, validation set scale: 4, PSNR: 25.827683196115792, SSIM: 0.7615179068023417


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:29<00:00,  1.67it/s]


start test: SRCNN_X2.model model, validation set scale: 5, PSNR: 24.523672858639323, SSIM: 0.7142607198620783
load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:22<00:00,  1.75it/s]


start test: SRCNN_X3.model model, validation set scale: 2, PSNR: 29.853873357938717, SSIM: 0.8658124547457096


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:16<00:00,  1.83it/s]


start test: SRCNN_X3.model model, validation set scale: 3, PSNR: 27.35640663897105, SSIM: 0.8042295316262232


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:16<00:00,  1.83it/s]


start test: SRCNN_X3.model model, validation set scale: 4, PSNR: 25.938026219865655, SSIM: 0.7578590213403988


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:16<00:00,  1.83it/s]


start test: SRCNN_X3.model model, validation set scale: 5, PSNR: 24.64975121973537, SSIM: 0.71357619338587
load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:19<00:00,  1.79it/s]


start test: SRCNN_X4.model model, validation set scale: 2, PSNR: 28.413295127315696, SSIM: 0.8217282228378349


100%|█████████████████████████████████████████████████████████████████████████████| 250/250 [9:19:00<00:00, 134.16s/it]


start test: SRCNN_X4.model model, validation set scale: 3, PSNR: 27.006977499858934, SSIM: 0.7842415907052646


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:10<00:00,  1.91it/s]


start test: SRCNN_X4.model model, validation set scale: 4, PSNR: 25.878885328413123, SSIM: 0.7491293055485209


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:09<00:00,  1.93it/s]


start test: SRCNN_X4.model model, validation set scale: 5, PSNR: 24.678980411306895, SSIM: 0.710641709337084
load model parameter
use device:  cuda:0


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:25<00:00,  1.72it/s]


start test: SRCNN_X5.model model, validation set scale: 2, PSNR: 27.07451490015635, SSIM: 0.7735019485307718


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:16<00:00,  1.84it/s]


start test: SRCNN_X5.model model, validation set scale: 3, PSNR: 26.364676763476446, SSIM: 0.7538444828616745


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:14<00:00,  1.85it/s]


start test: SRCNN_X5.model model, validation set scale: 4, PSNR: 25.635248545319925, SSIM: 0.7315539685470948


100%|████████████████████████████████████████████████████████████████████████████████| 250/250 [02:17<00:00,  1.82it/s]

start test: SRCNN_X5.model model, validation set scale: 5, PSNR: 24.66717191305021, SSIM: 0.7024253382029826



