In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

from PIL import Image
import os

class RescaleTransform(object):
    """Resize a PIL-style image by a constant factor.

    Args:
        scale: multiplicative factor applied to both width and height; must be > 0.

    Calling the transform returns ``image.resize((w, h))`` where ``(width, height)``
    is ``image.size`` and ``w = int(width * scale)``, ``h = int(height * scale)``
    (truncated toward zero). Each side is clamped to at least 1 pixel so that
    repeatedly downscaling a small image never requests an empty size, which
    PIL refuses to resize to.
    """

    def __init__(self, scale):
        # Fail at construction time instead of on the first image.
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        self.scale = scale

    def __call__(self, image):
        width, height = image.size
        new_size = (
            max(1, int(width * self.scale)),
            max(1, int(height * self.scale)),
        )
        return image.resize(new_size)


class HeyZee(Dataset):
    """Paired hazy/clear image dataset that yields multi-scale image pyramids.

    The i-th file of ``folder1_path`` (hazy inputs) is paired with the i-th
    file of ``folder2_path`` (clear targets) after sorting each folder's file
    names, so corresponding images should share names (or at least sort into
    the same order). Extra files in the larger folder are ignored with a warning.

    Each item is ``((I0, ..., In), (O0, ..., On))`` with ``n = levels``:
    level 0 is the full-resolution RGB image passed through ``ToTensor`` and
    every following level is the previous one rescaled by ``scale``.

    Args:
        folder1_path: directory of input (hazy) images.
        folder2_path: directory of target (clear) images.
        scale: per-level resize factor (default 0.5, i.e. halve each level).
        levels: number of rescaled levels after the full-resolution one
            (default 3, giving 4 tensors per side).
    """

    def __init__(self, folder1_path, folder2_path, scale=0.5, levels=3):
        self.folder1_path = folder1_path
        self.folder2_path = folder2_path
        self.levels = levels
        self.scale_transform = RescaleTransform(scale=scale)
        self.tensor_transform = ToTensor()
        self.file_paths = self._get_file_paths()

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of the visible regular files in ``folder``."""
        # os.listdir order is arbitrary and filesystem dependent, so sorting is
        # required for the positional pairing in _get_file_paths to be stable.
        # Hidden entries (.DS_Store, .ipynb_checkpoints) and sub-directories
        # are skipped because Image.open cannot read them.
        return sorted(
            name
            for name in os.listdir(folder)
            if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
        )

    def _get_file_paths(self):
        """Return ``[(input_path, target_path), ...]`` paired by sorted position."""
        folder1_files = self._list_files(self.folder1_path)
        folder2_files = self._list_files(self.folder2_path)
        if len(folder1_files) != len(folder2_files):
            import warnings

            warnings.warn(
                f"{self.folder1_path!r} has {len(folder1_files)} files but "
                f"{self.folder2_path!r} has {len(folder2_files)}; "
                "unmatched extra files are ignored."
            )
        # zip stops at the shorter list, matching the original min(...) bound.
        return [
            (os.path.join(self.folder1_path, f1), os.path.join(self.folder2_path, f2))
            for f1, f2 in zip(folder1_files, folder2_files)
        ]

    def __len__(self):
        return len(self.file_paths)

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` fully into memory as an RGB image and close the file."""
        with Image.open(path) as img:
            # convert() decodes the pixels and returns a new image, so the
            # result stays usable after the file is closed; forcing RGB keeps
            # grayscale/RGBA/palette files at 3 channels so batches collate.
            return img.convert("RGB")

    def _pyramid(self, image):
        """Return ``levels + 1`` tensors: ``image`` followed by each rescaled level."""
        images = [image]
        for _ in range(self.levels):
            images.append(self.scale_transform(images[-1]))
        return tuple(self.tensor_transform(im) for im in images)

    def __getitem__(self, idx):
        """Return ``(input_pyramid, target_pyramid)`` for pair ``idx``."""
        file_path1, file_path2 = self.file_paths[idx]
        return (
            self._pyramid(self._load_rgb(file_path1)),
            self._pyramid(self._load_rgb(file_path2)),
        )




In [2]:


# Hazy inputs / clear targets. HeyZee pairs the two folders' files by
# position, so the folders must correspond one-to-one.
HAZY_DIR = "Data/haze"
CLEAR_DIR = "Data/clear"
BATCH_SIZE = 4
SEED = 42

custom_dataset = HeyZee(HAZY_DIR, CLEAR_DIR)
assert len(custom_dataset) > 0, f"no image pairs found in {HAZY_DIR!r} / {CLEAR_DIR!r}"

# Seeded generator so the shuffled batch order is reproducible across
# Restart & Run All.
# NOTE(review): samples are full resolution (level 0 is 3x4000x6000 in the
# shape check below), so a batch of 4 is well over 1 GB per side; images of
# different sizes will also fail default collation — consider crop/resize.
dataloader = DataLoader(
    custom_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Load one sample to sanity-check the pyramid shapes.
(I0, I1, I2, I3), (O0, O1, O2, O3) = custom_dataset[0]

In [3]:
# Shapes of the input (I*) and target (O*) pyramids, finest level first.
tuple(level.shape for level in (I0, I1, I2, I3, O0, O1, O2, O3))

(torch.Size([3, 4000, 6000]),
 torch.Size([3, 2000, 3000]),
 torch.Size([3, 1000, 1500]),
 torch.Size([3, 500, 750]),
 torch.Size([3, 4000, 6000]),
 torch.Size([3, 2000, 3000]),
 torch.Size([3, 1000, 1500]),
 torch.Size([3, 500, 750]))