In [4]:
from glob import glob
from torch.utils.data import Dataset
import cv2
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import albumentations as albu

In [5]:
# Data augmentation
def get_transforms(size):
    """Build a joint random-augmentation function for (blur, sharp) image pairs.

    The sharp image is registered as an additional ``'image'`` target, so the
    same random flips / 90-degree rotation / crop are applied to both images.

    Args:
        size: side length of the square random crop.

    Returns:
        A callable ``process(a, b) -> (a_aug, b_aug)``.
    """
    aug_fn = albu.Compose([
        albu.HorizontalFlip(),   # random horizontal flip
        albu.VerticalFlip(),     # random vertical flip
        albu.RandomRotate90(),   # random rotation by a multiple of 90 degrees
    ])
    # p=1.0 always applies the crop; it replaces `always_apply=True`, which is
    # deprecated in recent albumentations releases.
    crop_fn = albu.RandomCrop(size, size, p=1.0)
    pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})

    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process

def get_normalize():
    """Return a function converting an (image, target) pair to centred float tensors.

    Each input goes through ``transforms.ToTensor()``. The result is permuted
    back to (H, W, C) order and shifted down by 0.5.
    """
    to_tensor = transforms.ToTensor()

    def _centered_channels_last(arr):
        chw = to_tensor(arr)
        return chw.permute(1, 2, 0) - 0.5

    def process(a, b):
        return _centered_channels_last(a), _centered_channels_last(b)

    return process

In [6]:
# Joint augmentation (flips, rot90, random 256x256 crop) and tensor conversion,
# used by the datasets built below.
transform_fn = get_transforms(size=256)
normalize_fn = get_normalize()

In [7]:
# Collect blur/sharp image paths. PairedDataset pairs them purely by list index,
# so both sorted lists must line up one-to-one.
DATA_ROOT = "./datasets/GoPro"


def list_pngs(split, kind, root=DATA_ROOT):
    """Return the sorted PNG paths found recursively under ``<root>/<split>/<kind>/``."""
    return sorted(glob(f"{root}/{split}/{kind}/**/*.png", recursive=True))


train_blur = list_pngs("train", "blur")
train_sharp = list_pngs("train", "sharp")
val_blur = list_pngs("test", "blur")
# Previously globbed "test/blur", which made the validation targets the blurry inputs.
val_sharp = list_pngs("test", "sharp")

assert len(train_blur) == len(train_sharp), f"train pairs mismatch: {len(train_blur)} blur vs {len(train_sharp)} sharp"
assert len(val_blur) == len(val_sharp), f"val pairs mismatch: {len(val_blur)} blur vs {len(val_sharp)} sharp"

In [8]:
# Pack blur (input) and sharp (target) images into one paired dataset.
class PairedDataset(Dataset):
    """Dataset of (blur, sharp) image pairs matched by list index.

    Each item is read with OpenCV, passed through ``transform_fn`` (applied to
    both images together), then ``normalize_fn``. Finally the last axis is moved
    to the front (HWC -> CHW).

    Args:
        files_a: paths of the blurred (input) images.
        files_b: paths of the sharp (target) images; ``files_b[i]`` must be the
            counterpart of ``files_a[i]``.
        transform_fn: callable ``(a, b) -> (a, b)`` applied to the raw images.
        normalize_fn: callable ``(a, b) -> (a, b)``; its outputs are expected to
            be 3-D and channel-last.

    Raises:
        ValueError: if ``files_a`` and ``files_b`` have different lengths.
    """

    def __init__(self, files_a, files_b, transform_fn, normalize_fn):
        """files_a:blur, files_b:sharp"""
        # Pairs are matched by position, so a length mismatch means misaligned
        # or missing targets; fail early instead of at some later index.
        if len(files_a) != len(files_b):
            raise ValueError(
                f"files_a and files_b must have the same length, got {len(files_a)} and {len(files_b)}"
            )
        self.data_a = files_a
        self.data_b = files_b
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn

    def preprocess(self, img, res):
        """Apply ``normalize_fn`` to the pair and move the last axis first (HWC -> CHW).

        Returns a lazy ``map`` over the two results, meant to be unpacked once.
        """
        def transpose(x):
            # Also accepts torch tensors; the result stays a tensor.
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        return len(self.data_a)

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        # cv2.imread signals missing/unreadable files by returning None rather than raising.
        # NOTE(review): images stay in OpenCV's BGR channel order (no RGB conversion) — confirm that is intended.
        img = cv2.imread(path)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        return img

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``{'a': blur, 'b': sharp}`` in CHW layout."""
        a = self._read_image(self.data_a[idx])
        b = self._read_image(self.data_b[idx])
        a, b = self.transform_fn(a, b)
        a, b = self.preprocess(a, b)
        return {'a': a, 'b': b}

In [9]:
train_datasets = PairedDataset(train_blur, train_sharp, transform_fn, normalize_fn)
# NOTE(review): the validation set reuses transform_fn, so it also gets random
# flips/rot90 and random crops; consider a deterministic crop for evaluation.
val_datasets = PairedDataset(val_blur, val_sharp, transform_fn, normalize_fn)

BATCH_SIZE = 8
train = DataLoader(train_datasets, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps validation runs comparable.
val = DataLoader(val_datasets, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)

In [47]:
# Sanity checks: dataset/loader sizes and per-sample tensor shapes.
print(len(train_datasets))
print(len(val_datasets))

print(len(train))
print(len(val))

# Load each sample once (indexing calls __getitem__) instead of re-reading and
# re-augmenting the same pair for every printed field.
for dataset in (train_datasets, val_datasets):
    sample = dataset[0]
    print(sample['a'].shape)
    print(sample['b'].shape)
    # Blur input and sharp target are cropped jointly, so their shapes must match.
    assert sample['a'].shape == sample['b'].shape

train_test = next(iter(train))
print(train_test['a'].shape)

2103
1111
263
139
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([8, 3, 256, 256])
