<a href="https://colab.research.google.com/github/Seowon-Ji/Multi-target/blob/master/dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
from skimage import io, transform
import os
import numpy as np
import random
import cv2


class goprodataset(Dataset):
    """Paired blurry/sharp image dataset for deblurring (GoPro-style layout).

    Every entry of ``blur_dir`` defines one sample. Its sharp counterpart is the
    file with the same stem and a ``.png`` extension in ``sharp_dir``. Both
    images receive identical random augmentation so they stay pixel-aligned.

    Args:
        blur_dir: Directory of blurry images; one sample per directory entry.
        sharp_dir: Directory of sharp images, matched to blur images by stem.
        crop: If True, cut one random ``crop_size`` x ``crop_size`` patch at the
            same location from both images.
        crop_size: Side length of the square crop, in pixels.
        multi_scale: If True, also return 1/2 and 1/4 resolution copies.
        rotation: If True, rotate both images by one of 90, 180 or 270 degrees
            (chosen at random).
        color_augment: If True, scale the saturation of both images by one
            shared random factor in roughly [0.8, 1.2].
        tranform: Optional transform, stored as ``self.transform``. The
            misspelled keyword is kept for backward compatibility. It is not
            currently applied in ``__getitem__``.
        transform: Correctly spelled alias for ``tranform``; wins when given.
    """

    def __init__(self, blur_dir, sharp_dir, crop=False, crop_size=256, multi_scale=False,
                 rotation=False, color_augment=False, tranform=None, transform=None):
        self.blur_dir = blur_dir
        self.sharp_dir = sharp_dir
        # Sorted so the index -> file mapping is reproducible across runs;
        # os.listdir returns entries in arbitrary order.
        self.blur_file_list = sorted(os.listdir(blur_dir))
        self.sharp_file_list = sorted(os.listdir(sharp_dir))

        # Store the caller's argument. The old code read a bare ``transform``
        # name, which picked up the module-level ``skimage.transform`` import.
        self.transform = transform if transform is not None else tranform
        self.crop = crop
        self.crop_size = crop_size
        self.multi_scale = multi_scale
        self.rotation = rotation
        self.color_augment = color_augment
        # Not used by __getitem__; kept so external code reading them still works.
        self.rotate90 = transforms.RandomRotation(90)
        self.rotate45 = transforms.RandomRotation(45)
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of entries found in ``blur_dir``."""
        return len(self.blur_file_list)

    @staticmethod
    def _png_name(file_name):
        """Map a ``blur_dir`` entry to the ``.png`` name used for both images."""
        # splitext only strips the final extension, so stems containing dots
        # (e.g. "a.b.png") are preserved.
        return os.path.splitext(file_name)[0] + '.png'

    @staticmethod
    def _load_rgb(path):
        """Load the image at ``path`` fully into memory as RGB and close the file."""
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released right away.
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _crop_offset(extent, crop_size):
        """Return a uniform random start in ``[0, extent - crop_size]``.

        Raises:
            ValueError: if ``extent`` is smaller than ``crop_size``.
        """
        if extent < crop_size:
            raise ValueError(
                f'image dimension {extent} is smaller than crop_size {crop_size}')
        # random.randint is inclusive on both ends, so every valid offset,
        # including the last one, can be drawn.
        return random.randint(0, extent - crop_size)

    def __getitem__(self, idx):
        """Return the augmented blur/sharp pair at ``idx`` as a dict of tensors.

        Keys are ``blur_image`` and ``sharp_image``. When ``multi_scale`` is set,
        the keys are instead ``blur_image_s{1,2,3}`` and ``sharp_image_s{1,2,3}``,
        at full, 1/2 and 1/4 resolution. Values are produced by
        ``transforms.ToTensor``.
        """
        png_name = self._png_name(self.blur_file_list[idx])
        blur_image = self._load_rgb(os.path.join(self.blur_dir, png_name))
        sharp_image = self._load_rgb(os.path.join(self.sharp_dir, png_name))

        if self.rotation:
            # expand=True lets 90/270 degree turns of non-square frames swap
            # width and height, instead of clipping corners and padding black.
            degree = random.choice([90, 180, 270])
            blur_image = transforms.functional.rotate(blur_image, degree, expand=True)
            sharp_image = transforms.functional.rotate(sharp_image, degree, expand=True)

        if self.color_augment:
            # One factor shared by both images keeps the pair consistent.
            sat_factor = random.uniform(0.8, 1.2)
            blur_image = transforms.functional.adjust_saturation(blur_image, sat_factor)
            sharp_image = transforms.functional.adjust_saturation(sharp_image, sat_factor)

        if self.crop:
            width, height = blur_image.size  # PIL reports (width, height)
            top = self._crop_offset(height, self.crop_size)
            left = self._crop_offset(width, self.crop_size)
            box = (left, top, left + self.crop_size, top + self.crop_size)
            blur_image = blur_image.crop(box)
            sharp_image = sharp_image.crop(box)

        if not self.multi_scale:
            return {'blur_image': self.to_tensor(blur_image),
                    'sharp_image': self.to_tensor(sharp_image)}

        # Downscale the PIL images directly, with integer target sizes, to
        # avoid a lossy float-tensor -> PIL -> tensor round trip.
        width, height = sharp_image.size
        half = [height // 2, width // 2]
        quarter = [height // 4, width // 4]
        resize = transforms.functional.resize
        return {
            'blur_image_s1': self.to_tensor(blur_image),
            'blur_image_s2': self.to_tensor(resize(blur_image, half)),
            'blur_image_s3': self.to_tensor(resize(blur_image, quarter)),
            'sharp_image_s1': self.to_tensor(sharp_image),
            'sharp_image_s2': self.to_tensor(resize(sharp_image, half)),
            'sharp_image_s3': self.to_tensor(resize(sharp_image, quarter)),
        }
