In [None]:
import torch
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import glob
import random
import os
from torch.utils.data import DataLoader

In [None]:
# Seed Python's `random` once, before any dataset is built: populate_train_list
# calls random.shuffle, so this fixes the order of the train/val example lists
# for a full Restart & Run All.
random.seed(a=1143)

In [None]:
def populate_train_list(orig_images_path, hazy_images_path, train_ratio=0.7):
    """Pair hazy images with their ground-truth images and split them into train/val.

    Pairs are formed by file name. For every ``*.png`` / ``*.jpg`` file in
    ``hazy_images_path``, the ground-truth image is expected to have the same
    file name in ``orig_images_path``. Its existence is not checked here.

    The first ``train_ratio`` fraction of file names, in sorted order, goes to
    the training list and the rest goes to the validation list. Each list is
    then shuffled with the global ``random`` module.

    Args:
        orig_images_path: Directory holding the ground-truth (clear) images.
        hazy_images_path: Directory holding the hazy input images.
        train_ratio: Fraction of images assigned to training (default 0.7).

    Returns:
        ``(train_list, val_list)``. Each is a list of ``(orig_path, hazy_path)`` tuples.
    """
    hazy_files = glob.glob(os.path.join(hazy_images_path, "*.png"))
    hazy_files += glob.glob(os.path.join(hazy_images_path, "*.jpg"))

    # Key each pair by its bare file name. os.path.basename is portable,
    # whereas splitting on "\\" only works for Windows-style paths.
    # Sorting (and de-duplicating via the set) makes the split independent of
    # the filesystem's glob order. This matters because the train and val
    # datasets each call this function separately and must agree on the split.
    names = sorted({os.path.basename(path) for path in hazy_files})

    n_train = int(train_ratio * len(names))

    def _pairs(keys):
        # (ground truth, hazy input): GT comes from orig_images_path and the
        # hazy image comes from hazy_images_path.
        return [
            (os.path.join(orig_images_path, key), os.path.join(hazy_images_path, key))
            for key in keys
        ]

    train_list = _pairs(names[:n_train])
    val_list = _pairs(names[n_train:])

    random.shuffle(train_list)
    random.shuffle(val_list)

    return train_list, val_list


class dehazing_loader(data.Dataset):
    """Map-style dataset yielding ``(ground_truth, hazy)`` image tensor pairs.

    Args:
        orig_images_path: Directory of ground-truth images.
        hazy_images_path: Directory of hazy input images.
        mode: ``'train'`` selects the training split. Any other value selects
            the validation split.
        size: Target size passed to ``PIL.Image.resize``, interpreted as
            ``(width, height)``. The default ``(480, 640)`` therefore yields
            tensors of shape ``(3, 640, 480)``.
    """

    def __init__(self, orig_images_path, hazy_images_path, mode='train', size=(480, 640)):
        self.size = tuple(size)
        self.train_list, self.val_list = populate_train_list(orig_images_path, hazy_images_path)
        if mode == 'train':
            self.data_list = self.train_list
            print("Total training examples:", len(self.train_list))
        else:
            self.data_list = self.val_list
            print("Total validation examples:", len(self.val_list))

    def _load(self, path):
        """Load one image as a float32 ``(3, H, W)`` tensor scaled to ``[0, 1]``."""
        # The `with` block closes the file handle deterministically.
        # convert('RGB') drops an alpha channel and expands grayscale/palette
        # images, so every sample has 3 channels. Without it, permute(2, 0, 1)
        # fails on 2-D arrays and batches with mixed channel counts cannot be stacked.
        with Image.open(path) as src:
            # LANCZOS is the filter that the removed ANTIALIAS alias pointed to
            # (Image.ANTIALIAS no longer exists in Pillow >= 10).
            img = src.convert('RGB').resize(self.size, Image.LANCZOS)
        arr = np.asarray(img) / 255.0
        return torch.from_numpy(arr).float().permute(2, 0, 1)

    def __getitem__(self, index):
        """Return the ``(orig, hazy)`` tensor pair stored at ``index``."""
        data_orig_path, data_hazy_path = self.data_list[index]
        return self._load(data_orig_path), self._load(data_hazy_path)

    def __len__(self):
        return len(self.data_list)

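## Cautions
若不在main中执行，num_workers要设置成0

If this code is not run from a script under an `if __name__ == "__main__":` guard (for example, when it runs inside this notebook), set `num_workers` to 0. On platforms where DataLoader worker processes are started with the `spawn` method (Windows, and macOS by default), the workers cannot re-import a Dataset class that was defined interactively, so multi-worker loading fails.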

In [None]:
# Loader settings. num_workers stays 0 because these cells are not run under an
# `if __name__ == "__main__":` guard (see the Cautions note).
# NOTE(review): the data-check loop at the end of this notebook iterates
# range(1), not num_epochs.
num_epochs = 10
train_batch_size = 8
val_batch_size = 8
num_workers = 0
# Dataset directories for ground-truth (GT) and hazy images. To point at a
# different machine's data, set the UIEB_GT_DIR / UIEB_HAZY_DIR environment
# variables instead of editing these machine-specific defaults.
ori_path = os.environ.get('UIEB_GT_DIR', r'E:\workspace\work2\UIEB\train\GT')
hazy_path = os.environ.get('UIEB_HAZY_DIR', r'E:\workspace\work2\UIEB\train\hazy')

In [None]:
# Training split. mode already defaults to 'train'; it is spelled out here to
# mirror the validation cell.
train_dataset = dehazing_loader(ori_path, hazy_path, mode='train')

In [None]:
# Validation split: same directories, with mode='val' selecting the held-out
# images. This builds a second, independent file listing. It agrees with the
# training cell's split as long as both calls enumerate the files in the same order.
val_dataset = dehazing_loader(ori_path, hazy_path, mode='val')

In [None]:
# Neither loader reshuffles. Each split was shuffled once when populate_train_list
# built it, so every epoch visits the examples in that same fixed order.
# NOTE(review): shuffle=True is the usual choice for a training loader.
# pin_memory=True can speed up host-to-GPU copies when training on CUDA.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                          shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size,
                        shuffle=False, num_workers=num_workers)

In [None]:
# Data-pipeline smoke test: one pass over both loaders, printing batch shapes.
# Each batch unpacks into the (orig, hazy) tensor pair produced by dehazing_loader.
for epoch in range(1):
    print('epoch {}'.format(epoch))
    for stage, loader in (('train', train_loader), ('val', val_loader)):
        for iteration, (img_ori, img_hazy) in enumerate(loader):
            # The dataset resizes both images of a pair to the same size, so
            # their batch shapes should match.
            assert img_ori.shape == img_hazy.shape, (img_ori.shape, img_hazy.shape)
            print('At {} {} stage, ori_size:{}, hazy_size:{}'.format(
                iteration, stage, img_ori.shape, img_hazy.shape))
