In [1]:

import utils
# Folder pair handed to the project's path-matching helper.
HAZY_DIR, CLEAR_DIR = 'Data/haze', 'Data/clear'
path_pairs = utils.get_file_paths(HAZY_DIR, CLEAR_DIR)

custom_dataset = utils.HeyZee(path_pairs)

# Sample 0 unpacks into two 4-tuples of tensors (shapes shown in the next cell).
first_inputs, first_targets = custom_dataset[0]
I0, I1, I2, I3 = first_inputs
O0, O1, O2, O3 = first_targets

In [2]:
tuple(t.shape for t in (I0, I1, I2, I3, O0, O1, O2, O3))

(torch.Size([3, 4000, 6000]),
 torch.Size([3, 2000, 3000]),
 torch.Size([3, 1000, 1500]),
 torch.Size([3, 500, 750]),
 torch.Size([3, 4000, 6000]),
 torch.Size([3, 2000, 3000]),
 torch.Size([3, 1000, 1500]),
 torch.Size([3, 500, 750]))

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

# do gamma correction
#img_gamma1 = np.power(img, gamma).clip(0,255).astype(np.uint8)

# gamma correction
def gammaCorrection(src, gamma):
    """Apply a power-law (gamma) curve to an 8-bit image via a lookup table.

    Every intensity ``v`` in 0..255 is mapped to ``255 * (v / 255) ** (1 / gamma)``,
    truncated to uint8. Because the exponent is ``1 / gamma``, values of
    ``gamma > 1`` brighten the image and ``gamma < 1`` darken it.

    Args:
        src: image (or single channel) accepted by ``cv2.LUT``.
        gamma: gamma value; must be non-zero.

    Returns:
        The remapped image produced by ``cv2.LUT``.
    """
    exponent = 1 / gamma
    lut = np.array(
        [255 * (level / 255) ** exponent for level in range(256)],
        np.uint8,
    )
    return cv2.LUT(src, lut)


def do_gamma_correction(folder, gamma_B, gamma_G, gamma_R, save_folder=None):
    """Gamma-correct every image in ``folder`` channel by channel and save it.

    Images are read with ``cv2.imread`` (BGR channel order). Each channel is
    passed through ``gammaCorrection`` with its own gamma, the channels are
    re-stacked, and the result is written under the same filename into
    ``save_folder``.

    Args:
        folder: directory to read images from.
        gamma_B: gamma applied to the blue channel (index 0).
        gamma_G: gamma applied to the green channel (index 1).
        gamma_R: gamma applied to the red channel (index 2).
        save_folder: output directory; created if it does not exist. When
            omitted, falls back to the module-level ``savepath`` global so
            existing calls keep working.

    Raises:
        IOError: if OpenCV fails to write an output image.
    """
    if save_folder is None:
        # Backward compatibility: callers used to set the global `savepath`.
        save_folder = savepath
    os.makedirs(save_folder, exist_ok=True)

    for filename in os.listdir(folder):
        img = cv2.imread(os.path.join(folder, filename))
        if img is None:
            # cv2.imread returns None for entries it cannot decode (sub-folders,
            # hidden/non-image files); skip them rather than crash on indexing.
            continue

        gamma_img = np.dstack((
            gammaCorrection(img[:, :, 0], gamma_B),
            gammaCorrection(img[:, :, 1], gamma_G),
            gammaCorrection(img[:, :, 2], gamma_R),
        ))

        out_path = os.path.join(save_folder, filename)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out_path, gamma_img):
            raise IOError(f"failed to write gamma-corrected image to {out_path}")


# Per-dataset gamma correction of the hazy training images.
# do_gamma_correction is called below without an explicit output folder, so it
# writes into the global `savepath`; `savepath` must therefore be assigned
# before each call.

# NTIRE2020 hazy images (the call is currently disabled).
readpath = "./dehaze_dataset/train/NTIRE2020/hazy"
savepath = "./data_hazy_gc/20_hazyRGB_gamma"
#do_gamma_correction(readpath, gamma_B=1.9, gamma_G=1.6, gamma_R=1.24)
# For the 2020 GT images the gamma values should be R=1.07, G=1.17, B=1.05;
# verify the adjusted per-channel mean and variance.

# NTIRE2021 hazy images.
readpath = "./dehaze_dataset/train/NTIRE2021/hazy"
savepath = "./data_hazy_gc/21_hazyRGB_gamma"
do_gamma_correction(readpath, gamma_B=1, gamma_G=0.85, gamma_R=0.72)
# For the 2021 GT images the gamma values should be R=0.65, G=0.79, B=0.92;
# verify the adjusted per-channel mean and variance.

In [None]:
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision
import torchvision.transforms.functional as TF
import random
import os

# Data augmentation: random rotation or flip, applied identically to a hazy/clean pair
def augment(hazy, clean):
    """Apply the same random geometric transform to a hazy/clean image pair.

    One of six outcomes is drawn uniformly:
      * 0 -> rotate both images by a random angle from {90, 180, 270},
      * 1 -> flip both vertically,
      * 2 -> flip both horizontally,
      * 3, 4, 5 -> return the pair unchanged (half of all calls).
    The rotation angle is drawn on every call, even when it ends up unused.

    Returns:
        The (hazy, clean) pair after the chosen transform.
    """
    method = random.choice([0, 1, 2, 3, 4, 5])
    angle = random.choice([90, 180, 270])

    if method >= 3:
        # Identity outcome: pass the pair through untouched.
        return hazy, clean

    if method == 0:
        def transform(image):
            return transforms.functional.rotate(image, angle)
    elif method == 1:
        transform = torchvision.transforms.RandomVerticalFlip(p=1)
    else:
        transform = torchvision.transforms.RandomHorizontalFlip(p=1)

    # Hazy first, then clean, with one shared op so the pair stays pixel-aligned.
    return transform(hazy), transform(clean)


class dehaze_train_dataset(Dataset):
    """Paired hazy/clean image dataset used for dehazing training.

    ``train_dir`` must contain ``train.txt`` (one image filename per line) and
    the sub-folders ``hazy/`` and ``clean/``. Each listed filename is looked up
    in both sub-folders to form an image pair.
    """

    def __init__(self, train_dir):
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.list_train = self._read_file_list(os.path.join(train_dir, 'train.txt'))
        # e.g. './NTIRE2021_Train_Hazy/' -> './NTIRE2021_Train_Hazy/hazy/'
        self.root_hazy = os.path.join(train_dir, 'hazy/')
        self.root_clean = os.path.join(train_dir, 'clean/')
        self.file_len = len(self.list_train)

    @staticmethod
    def _read_file_list(list_path):
        """Return the non-blank lines of ``list_path`` with surrounding whitespace stripped."""
        # Context manager so the list file handle is closed after reading.
        with open(list_path) as list_file:
            names = (line.strip() for line in list_file)
            return [name for name in names if name]

    def __getitem__(self, index, is_train=True):
        """Load image pair ``index`` and return it as ``(hazy, clean)`` tensors.

        With ``is_train`` (the default, which is what ``dataset[index]`` uses),
        a random 256x256 patch is cropped at the same location in both images
        and the same random augmentation is applied to both. With
        ``is_train=False`` the full images are converted without cropping or
        augmentation.
        """
        name = self.list_train[index]
        hazy = Image.open(self.root_hazy + name)
        clean = Image.open(self.root_clean + name)
        if is_train:
            # One crop window for both images keeps the pair pixel-aligned.
            i, j, h, w = transforms.RandomCrop.get_params(hazy, output_size=(256, 256))
            hazy = TF.crop(hazy, i, j, h, w)
            clean = TF.crop(clean, i, j, h, w)
            # Data augmentation shared by both images.
            hazy, clean = augment(hazy, clean)
        return self.transform(hazy), self.transform(clean)

    def __len__(self):
        return self.file_len