In [75]:
import math
import os
import warnings

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torchmetrics
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms.functional import crop, pad
from tqdm import tqdm
warnings.filterwarnings("ignore")

#### Path definitions

In [50]:
# Root of the DIV2K dataset; set the DIV2K_ROOT environment variable to override
base_path = os.environ.get("DIV2K_ROOT", r"D:\programming\dataset\DIV2K")
# HR training images (ground-truth targets)
train_hr = os.path.join(base_path, "DIV2K_train_HR")
# bicubic x2 LR training images (network inputs)
train_lr = os.path.join(base_path, "DIV2K_train_LR_bicubic_X2")
# HR validation images (ground-truth targets)
val_hr = os.path.join(base_path, "DIV2K_valid_HR")
# bicubic x2 LR validation images (network inputs)
val_lr = os.path.join(base_path, "DIV2K_valid_LR_bicubic_X2")

#### Prepare datasets

This dataset is used to test how the patch size influences training. The LR patch sizes are 16, 32, 64, 128 and 256, with corresponding HR patch sizes of 32, 64, 128, 256 and 512. Patches at the image border are zero-padded to full size.

In [76]:
def create_Y_channel(hr_path, lr_path, h5_folder, patch_size, num_shards=6, images_per_shard=100):
    """
    Build a sharded Y-channel (luma) LR/HR patch dataset from DIV2K x2 folders.

    Image numbers are split into `num_shards` consecutive runs of
    `images_per_shard`, starting at 1 (defaults: 0001-0100 -> 0.hdf5, ...,
    0501-0600 -> 5.hdf5). Each shard is written by create_Y_channel_core.
    The patch-pair count of every shard is written, one per line and in
    shard order, to ``h5_folder/index.txt``.

    Args:
        hr_path: folder with the HR PNG images.
        lr_path: folder with the x2 bicubic LR PNG images.
        h5_folder: output folder (created if missing).
        patch_size: LR patch side length (HR patches are twice as large).
        num_shards: number of HDF5 shards to write.
        images_per_shard: image numbers covered by each shard.
    """
    os.makedirs(h5_folder, exist_ok=True)
    length_list = list()
    for number in range(num_shards):
        first = number * images_per_shard + 1
        image_number_range = range(first, first + images_per_shard)
        h5path = os.path.join(h5_folder, str(number) + ".hdf5")
        length_list.append(
            create_Y_channel_core(hr_path, lr_path, h5path, image_number_range, patch_size))
    with open(os.path.join(h5_folder, "index.txt"), "w") as file:
        file.writelines(str(length) + "\n" for length in length_list)
         
        
def create_Y_channel_core(hr_path, lr_path, h5path, image_number_range, patch_size):
    """
    Write one HDF5 shard of Y-channel LR/HR patch pairs (called by create_Y_channel).

    For every image number whose HR file ``{n:04}.png`` exists in hr_path, the
    matching ``{n:04}x2.png`` is read from lr_path. Both are converted to YCrCb,
    and the Y channel is kept, scaled to [0, 1]. The LR image is tiled into
    patch_size x patch_size patches. The LR patch at (i, j) is paired with
    the HR patch at (2i, 2j) of side 2*patch_size. Border patches are
    zero-padded to full size.

    Args:
        hr_path: folder with the HR PNG images.
        lr_path: folder with the x2 bicubic LR PNG images.
        h5path: output HDF5 path (overwritten).
        image_number_range: image numbers to process; missing HR files are skipped.
        patch_size: LR patch side length.

    Returns:
        Number of patch pairs written.

    Writes datasets 'lr' (N, 1, patch_size, patch_size) and
    'hr' (N, 1, 2*patch_size, 2*patch_size), both float32.

    Raises:
        FileNotFoundError: if an HR image exists but it or its LR image cannot be read.
    """
    scale = 2  # x2 data: HR coordinates/sizes are twice the LR ones
    hr_size = patch_size * scale
    hrPatches = list()
    lrPatches = list()
    with tqdm(total=len(image_number_range)) as t:
        for image_number in image_number_range:
            name = "{:0>4}".format(image_number)
            hrp = os.path.join(hr_path, name + ".png")
            lrp = os.path.join(lr_path, name + "x2.png")
            if not os.path.exists(hrp):
                t.update(1)
                continue
            hr_bgr = cv2.imread(hrp)
            lr_bgr = cv2.imread(lrp)
            # cv2.imread returns None instead of raising on a missing/unreadable file
            if hr_bgr is None or lr_bgr is None:
                raise FileNotFoundError("could not read image pair: {}, {}".format(hrp, lrp))

            # channel 0 of OpenCV's YCrCb is Y (luma)
            hr = cv2.cvtColor(hr_bgr, cv2.COLOR_BGR2YCR_CB)[:, :, 0].astype(np.float32) / 255
            lr = cv2.cvtColor(lr_bgr, cv2.COLOR_BGR2YCR_CB)[:, :, 0].astype(np.float32) / 255

            height, width = lr.shape
            for i in range(0, height, patch_size):
                for j in range(0, width, patch_size):
                    lr_tile = lr[i:i + patch_size, j:j + patch_size]
                    l_patch = np.zeros((1, patch_size, patch_size), dtype=np.float32)
                    l_patch[0, :lr_tile.shape[0], :lr_tile.shape[1]] = lr_tile

                    hr_tile = hr[i * scale:i * scale + hr_size, j * scale:j * scale + hr_size]
                    h_patch = np.zeros((1, hr_size, hr_size), dtype=np.float32)
                    h_patch[0, :hr_tile.shape[0], :hr_tile.shape[1]] = hr_tile

                    lrPatches.append(l_patch)
                    hrPatches.append(h_patch)
            t.update(1)

    length = len(hrPatches)
    # reshape keeps the datasets 4-D even when no image was found in this shard
    lr_data = np.asarray(lrPatches, dtype=np.float32).reshape(-1, 1, patch_size, patch_size)
    hr_data = np.asarray(hrPatches, dtype=np.float32).reshape(-1, 1, hr_size, hr_size)
    # context manager closes the file even if dataset creation fails
    with h5py.File(h5path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=lr_data)
        h5_file.create_dataset('hr', data=hr_data)
    return length


In [77]:
# Y-channel dataset: LR patch size 64, HR patch size 128
lr_patch_size = 64
h5_folder = os.path.join(
    base_path, "train_Ychannel_size_{}_{}".format(lr_patch_size, 2 * lr_patch_size))
create_Y_channel(train_hr, train_lr, h5_folder, lr_patch_size)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  7.94it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.00it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.78it/s]


This dataset is used to test how the choice of input channels influences training. Patches keep all three colour channels in OpenCV's B, G, R order, so that the R, G, B and full RGB variants can be compared. LR patch size = 64.

In [73]:
def create_BGR_channel(hr_path, lr_path, h5_folder, patch_size, num_shards=6, images_per_shard=100):
    """
    Build a sharded 3-channel (BGR) LR/HR patch dataset from DIV2K x2 folders.

    Image numbers are split into `num_shards` consecutive runs of
    `images_per_shard`, starting at 1 (defaults: 0001-0100 -> 0.hdf5, ...,
    0501-0600 -> 5.hdf5). Each shard is written by create_BGR_channel_core.
    The patch-pair count of every shard is written, one per line and in
    shard order, to ``h5_folder/index.txt``.

    Args:
        hr_path: folder with the HR PNG images.
        lr_path: folder with the x2 bicubic LR PNG images.
        h5_folder: output folder (created if missing).
        patch_size: LR patch side length (HR patches are twice as large).
        num_shards: number of HDF5 shards to write.
        images_per_shard: image numbers covered by each shard.
    """
    os.makedirs(h5_folder, exist_ok=True)
    length_list = list()
    for number in range(num_shards):
        first = number * images_per_shard + 1
        image_number_range = range(first, first + images_per_shard)
        h5path = os.path.join(h5_folder, str(number) + ".hdf5")
        length_list.append(
            create_BGR_channel_core(hr_path, lr_path, h5path, image_number_range, patch_size))
    with open(os.path.join(h5_folder, "index.txt"), "w") as file:
        file.writelines(str(length) + "\n" for length in length_list)
         
        
def create_BGR_channel_core(hr_path, lr_path, h5path, image_number_range, patch_size):
    """
    Write one HDF5 shard of 3-channel LR/HR patch pairs (called by create_BGR_channel).

    For every image number whose HR file ``{n:04}.png`` exists in hr_path, the
    matching ``{n:04}x2.png`` is read from lr_path. Both keep OpenCV's BGR
    channel order and are scaled to [0, 1]. The LR image is tiled into
    patch_size x patch_size patches. The LR patch at (i, j) is paired with
    the HR patch at (2i, 2j) of side 2*patch_size. Border patches are
    zero-padded and stored channel-first.

    Args:
        hr_path: folder with the HR PNG images.
        lr_path: folder with the x2 bicubic LR PNG images.
        h5path: output HDF5 path (overwritten).
        image_number_range: image numbers to process; missing HR files are skipped.
        patch_size: LR patch side length.

    Returns:
        Number of patch pairs written.

    Writes datasets 'lr' (N, 3, patch_size, patch_size) and
    'hr' (N, 3, 2*patch_size, 2*patch_size), both float32.

    Raises:
        FileNotFoundError: if an HR image exists but it or its LR image cannot be read.
    """
    scale = 2  # x2 data: HR coordinates/sizes are twice the LR ones
    hr_size = patch_size * scale
    hrPatches = list()
    lrPatches = list()
    with tqdm(total=len(image_number_range)) as t:
        for image_number in image_number_range:
            name = "{:0>4}".format(image_number)
            hrp = os.path.join(hr_path, name + ".png")
            lrp = os.path.join(lr_path, name + "x2.png")
            if not os.path.exists(hrp):
                t.update(1)
                continue
            hr_bgr = cv2.imread(hrp)
            lr_bgr = cv2.imread(lrp)
            # cv2.imread returns None instead of raising on a missing/unreadable file
            if hr_bgr is None or lr_bgr is None:
                raise FileNotFoundError("could not read image pair: {}, {}".format(hrp, lrp))

            hr = hr_bgr.astype(np.float32) / 255  # keep all 3 BGR channels
            lr = lr_bgr.astype(np.float32) / 255
            height, width, _ = lr.shape
            for i in range(0, height, patch_size):
                for j in range(0, width, patch_size):
                    lr_tile = lr[i:i + patch_size, j:j + patch_size]
                    l_patch = np.zeros((3, patch_size, patch_size), dtype=np.float32)
                    # HWC tile -> CHW patch
                    l_patch[:, :lr_tile.shape[0], :lr_tile.shape[1]] = lr_tile.transpose(2, 0, 1)

                    hr_tile = hr[i * scale:i * scale + hr_size, j * scale:j * scale + hr_size]
                    h_patch = np.zeros((3, hr_size, hr_size), dtype=np.float32)
                    h_patch[:, :hr_tile.shape[0], :hr_tile.shape[1]] = hr_tile.transpose(2, 0, 1)

                    lrPatches.append(l_patch)
                    hrPatches.append(h_patch)
            t.update(1)

    length = len(hrPatches)
    # reshape keeps the datasets 4-D even when no image was found in this shard
    lr_data = np.asarray(lrPatches, dtype=np.float32).reshape(-1, 3, patch_size, patch_size)
    hr_data = np.asarray(hrPatches, dtype=np.float32).reshape(-1, 3, hr_size, hr_size)
    # context manager closes the file even if dataset creation fails
    with h5py.File(h5path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=lr_data)
        h5_file.create_dataset('hr', data=hr_data)
    return length


In [74]:
# BGR (3-channel) patch dataset, LR patch size 64
h5_folder = os.path.join(base_path, "train_BGR_channel")
create_BGR_channel(hr_path=train_hr, lr_path=train_lr, h5_folder=h5_folder, patch_size=64)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  8.14it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.91it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.55it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.15it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  9.08it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.57it/s]


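This dataset is used to test how the upscaling factor influences pre-upsampling performance. The LR input is synthesised by bicubic-downscaling each HR image by the scale factor and then bicubic-upscaling it back to the HR size.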

In [121]:
def create_Y_channel_bicubic(hr_path, h5_folder, patch_size, scale=2, num_shards=6, images_per_shard=100):
    """
    Build a sharded pre-upsampled Y-channel dataset from HR images only.

    Image numbers are split into `num_shards` consecutive runs of
    `images_per_shard`, starting at 1 (defaults: 0001-0100 -> 0.hdf5, ...,
    0501-0600 -> 5.hdf5). Each shard is written by
    create_Y_channel_bicubic_core, which synthesises the LR input from the HR
    image using `scale`. The patch-pair count of every shard is written,
    one per line and in shard order, to ``h5_folder/index.txt``.

    Args:
        hr_path: folder with the HR PNG images.
        h5_folder: output folder (created if missing).
        patch_size: side length of both LR and HR patches (same size here).
        scale: down/up-sampling factor used to synthesise the LR input.
        num_shards: number of HDF5 shards to write.
        images_per_shard: image numbers covered by each shard.
    """
    os.makedirs(h5_folder, exist_ok=True)
    length_list = list()
    for number in range(num_shards):
        first = number * images_per_shard + 1
        image_number_range = range(first, first + images_per_shard)
        h5path = os.path.join(h5_folder, str(number) + ".hdf5")
        length_list.append(
            create_Y_channel_bicubic_core(hr_path, h5path, image_number_range, patch_size, scale))
    with open(os.path.join(h5_folder, "index.txt"), "w") as file:
        file.writelines(str(length) + "\n" for length in length_list)
         
        
def create_Y_channel_bicubic_core(hr_path, h5path, image_number_range, patch_size, scale):
    """
    Write one HDF5 shard of pre-upsampled Y-channel patch pairs.

    Called by create_Y_channel_bicubic. For every image number whose HR file
    ``{n:04}.png`` exists in hr_path, the LR input is synthesised from the HR
    image in two steps:
    - bicubic-downscale to (width // scale, height // scale);
    - bicubic-upscale back to the original HR size.

    Both images are converted to YCrCb, and the Y channel is kept, scaled to
    [0, 1]. Both are tiled at the same positions into patch_size x patch_size
    patches, with border patches zero-padded.

    Args:
        hr_path: folder with the HR PNG images.
        h5path: output HDF5 path (overwritten).
        image_number_range: image numbers to process; missing HR files are skipped.
        patch_size: side length of both LR and HR patches.
        scale: down/up-sampling factor.

    Returns:
        Number of patch pairs written.

    Writes datasets 'lr' and 'hr', both float32 of shape (N, 1, patch_size, patch_size).

    Raises:
        FileNotFoundError: if an HR image exists but cannot be read.
    """
    hrPatches = list()
    lrPatches = list()
    with tqdm(total=len(image_number_range)) as t:
        for image_number in image_number_range:
            hrp = os.path.join(hr_path, "{:0>4}".format(image_number) + ".png")
            if not os.path.exists(hrp):
                t.update(1)
                continue

            hr = cv2.imread(hrp)
            # cv2.imread returns None instead of raising on an unreadable file
            if hr is None:
                raise FileNotFoundError("could not read image: " + hrp)
            height, width, _ = hr.shape
            # NOTE(review): OpenCV's INTER_CUBIC does not anti-alias when shrinking,
            # so this degradation presumably differs from MATLAB imresize bicubic - confirm.
            lr = cv2.resize(hr, (width // scale, height // scale), interpolation=cv2.INTER_CUBIC)
            lr = cv2.resize(lr, (width, height), interpolation=cv2.INTER_CUBIC)

            # channel 0 of OpenCV's YCrCb is Y (luma)
            hr = cv2.cvtColor(hr, cv2.COLOR_BGR2YCR_CB)[:, :, 0].astype(np.float32) / 255
            lr = cv2.cvtColor(lr, cv2.COLOR_BGR2YCR_CB)[:, :, 0].astype(np.float32) / 255

            for i in range(0, height, patch_size):
                for j in range(0, width, patch_size):
                    # LR and HR have the same size, so both use the same window
                    for image, patches in ((lr, lrPatches), (hr, hrPatches)):
                        tile = image[i:i + patch_size, j:j + patch_size]
                        patch = np.zeros((1, patch_size, patch_size), dtype=np.float32)
                        patch[0, :tile.shape[0], :tile.shape[1]] = tile
                        patches.append(patch)
            t.update(1)

    length = len(hrPatches)
    # reshape keeps the datasets 4-D even when no image was found in this shard
    lr_data = np.asarray(lrPatches, dtype=np.float32).reshape(-1, 1, patch_size, patch_size)
    hr_data = np.asarray(hrPatches, dtype=np.float32).reshape(-1, 1, patch_size, patch_size)
    # context manager closes the file even if dataset creation fails
    with h5py.File(h5path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=lr_data)
        h5_file.create_dataset('hr', data=hr_data)
    return length


In [123]:
# Pre-upsampled Y-channel datasets (patch size 128) for scale factors 3, 4 and 5
for scale in (3, 4, 5):
    h5_folder = os.path.join(base_path, "train_scale{}_channel".format(scale))
    create_Y_channel_bicubic(train_hr, h5_folder, 128, scale=scale)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.03it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.04it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.21it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.11it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.58it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.16it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.21it/s]
100%|███████████████████████████████████

#### Legacy code

In [None]:
# Legacy code kept for reference; no longer used by the cells above.

def create_train_data(hr_path, lr_path, h5path, pSize=33, pStride=33, padding=6,
                      image_numbers=range(301, 551)):
    """
    Legacy builder for a pre-upsampled Y-channel training set with cropped targets.

    For each image number, the x2 LR image is bicubic-upsampled to the HR size.
    Both images are converted to YCrCb, and the Y channel is kept, scaled to
    [0, 1]. Input patches of pSize x pSize are taken from the upsampled LR image
    with stride pStride; partial border windows are skipped. Each target is the
    same HR window with `padding` pixels cropped from every side.

    Args:
        hr_path: folder with the HR PNG images.
        lr_path: folder with the x2 bicubic LR PNG images.
        h5path: output HDF5 path (overwritten).
        pSize: input patch side length.
        pStride: stride between patches.
        padding: border cropped from each target patch.
        image_numbers: image numbers to read (default 0301-0550).

    Writes datasets 'lr' (N, 1, pSize, pSize) and
    'hr' (N, 1, pSize - 2*padding, pSize - 2*padding).

    Raises:
        FileNotFoundError: if an HR or LR image cannot be read.
    """
    out_size = pSize - padding * 2
    hrPatches = list()
    lrPatches = list()
    for number in image_numbers:
        hrp = os.path.join(hr_path, "{:0>4}".format(number) + ".png")
        lrp = os.path.join(lr_path, "{:0>4}".format(number) + "x2.png")
        hr = cv2.imread(hrp)
        lr = cv2.imread(lrp)
        # cv2.imread returns None instead of raising on a missing/unreadable file
        if hr is None or lr is None:
            raise FileNotFoundError("could not read image pair: {}, {}".format(hrp, lrp))
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2YCR_CB)

        lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
        lr = cv2.resize(lr, (hr.shape[1], hr.shape[0]), interpolation=cv2.INTER_CUBIC)
        lr = cv2.cvtColor(lr, cv2.COLOR_RGB2YCR_CB)

        hr = hr[:, :, 0].astype(np.float32) / 255  # Y channel
        lr = lr[:, :, 0].astype(np.float32) / 255  # Y channel
        height, width = hr.shape
        # y/x rather than i/j so the loop does not shadow the image counter
        for y in range(0, height - pSize + 1, pStride):
            for x in range(0, width - pSize + 1, pStride):
                l = lr[y:y + pSize, x:x + pSize]
                lrPatches.append(np.reshape(l, (1, pSize, pSize)))
                h = hr[y + padding:y + pSize - padding, x + padding:x + pSize - padding]
                hrPatches.append(np.reshape(h, (1, out_size, out_size)))
    # open the output only once all patches exist, and close it even on error
    with h5py.File(h5path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=np.array(lrPatches))
        h5_file.create_dataset('hr', data=np.array(hrPatches))
    

# Validation dataset: returns YCrCb images (not RGB); the HR image is border-cropped by `padding`.

class ValDataset(Dataset):
    """
    Validation split made of DIV2K training images numbered 0551 and up.

    Item ``idx`` loads ``{idx + 551:04}.png`` from the module-level ``train_hr``
    folder and ``{idx + 551:04}x2.png`` from ``train_lr``. It returns
    ``(hr, lr)``, both float32 YCrCb arrays scaled to [0, 1]:
    - ``hr`` is the HR image with ``padding`` pixels cropped from every border;
    - ``lr`` is the LR image bicubic-upsampled to the full (uncropped) HR size.

    Args:
        padding: border width cropped from the HR image (0 keeps it whole).
        length: number of items. If None, it is the number of consecutive HR
            images that exist starting at 0551.
    """

    _FIRST_IMAGE = 551  # dataset index 0 maps to image 0551

    def __init__(self, padding, length=None):
        super(ValDataset, self).__init__()
        self.padding = padding
        self.length = length

    def _paths(self, idx):
        """Return the (HR, LR) file paths for dataset index ``idx``."""
        name = "{:0>4}".format(idx + self._FIRST_IMAGE)
        return (os.path.join(train_hr, name + ".png"),
                os.path.join(train_lr, name + "x2.png"))

    def __len__(self):
        # DataLoader's default sampler calls len(); without it the loader fails
        if self.length is not None:
            return self.length
        count = 0
        while os.path.exists(self._paths(count)[0]):
            count += 1
        return count

    def __getitem__(self, idx):
        hrp, lrp = self._paths(idx)
        hr = cv2.imread(hrp)
        lr = cv2.imread(lrp)
        # cv2.imread returns None instead of raising on a missing/unreadable file
        if hr is None or lr is None:
            raise FileNotFoundError("could not read validation pair: {}, {}".format(hrp, lrp))
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2YCR_CB)

        lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
        lr = cv2.resize(lr, (hr.shape[1], hr.shape[0]), interpolation=cv2.INTER_CUBIC)
        lr = cv2.cvtColor(lr, cv2.COLOR_RGB2YCR_CB)

        # explicit end indices: hr[p:-p] would be empty when p == 0
        p = self.padding
        hr = hr[p:hr.shape[0] - p, p:hr.shape[1] - p]
        hr = hr.astype(np.float32) / 255
        lr = lr.astype(np.float32) / 255
        return hr, lr
    

def test_psnr_interpolation(padding=6):
    """
    Score plain bicubic interpolation (the LR input itself) on the validation set.

    Iterates the module-level ``val_dataloader``. Its batches are presumably
    ``(hr, lr)`` tensors produced by ValDataset: float YCrCb in [0, 1], with
    hr already border-cropped by the same padding (confirm). Each image is
    converted back to 8-bit RGB, and lr is cropped by ``padding`` so both sizes
    match. Each pair is scored with the module-level ``PSNR`` and ``SSIM`` helpers.

    Args:
        padding: border cropped from lr; must equal the ValDataset padding.

    Returns:
        (mean PSNR, mean SSIM) over all validation images.
    """
    psnr_list = list()
    ssim_list = list()

    def to_rgb_uint8(ycrcb):
        # round before casting: x / 255 * 255 in float32 can land just below an
        # integer, and astype(np.uint8) alone would truncate it
        ycrcb = np.clip(np.rint(ycrcb * 255), 0, 255).astype(np.uint8)
        return cv2.cvtColor(ycrcb, cv2.COLOR_YCR_CB2RGB)

    for hr_batch, lr_batch in val_dataloader:
        # iterate over the batch so any batch size works
        for hr, lr in zip(hr_batch.numpy(), lr_batch.numpy()):
            lr = lr[padding:lr.shape[0] - padding, padding:lr.shape[1] - padding]
            hr = to_rgb_uint8(hr)
            lr = to_rgb_uint8(lr)
            psnr_list.append(PSNR(hr, lr, 255))
            ssim_list.append(SSIM(hr, lr))
    return np.average(psnr_list), np.average(ssim_list)

# psnr, ssim = test_psnr_interpolation()
# print("psnr for interpolation is: {:.3f}".format(psnr))
# print("ssim for interpolation is: {:.3f}".format(ssim))