In [1]:
from torch.utils.data import Dataset
import torch
from PIL import Image
from pathlib import Path
import numpy as np
import random
import sys

sys.path.insert(0, '/home/taylor/PycharmProjects/uav-classif')

In [2]:
from utils.dataset.transforms import transforms as t
import os
from torch.utils.data import DataLoader

train_data_dir = '/home/taylor/PycharmProjects/uav-classif/kelp_species/train_input/data/train'

In [3]:
%%timeit

class SegmentationDataset(Dataset):
    def __init__(self, ds_path, ext=".png", transform=None, target_transform=None):
        super().__init__()
        self._images = sorted(Path(ds_path).joinpath("x").glob("*" + ext))
        self._labels = sorted(Path(ds_path).joinpath("y").glob("*" + ext))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(list(self._images))

    def __getitem__(self, idx):
        img = Image.open(self._images[idx]).convert('RGB')
        target = Image.open(self._labels[idx])

        seed = np.random.randint(2147483647)

        torch.manual_seed(seed)
        random.seed(seed)  # apply this seed to img transforms
        if self.transform is not None:
            img = self.transform(img)

        torch.manual_seed(seed)
        random.seed(seed)  # apply this seed to target transforms
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

ds_train = SegmentationDataset(train_data_dir, transform=t.train_transforms,
                                       target_transform=t.train_target_transforms)

dl_train = DataLoader(ds_train, shuffle=True, batch_size=32, pin_memory=True,
                          drop_last=True, num_workers=os.cpu_count())

198 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit

class SegmentationDataset(Dataset):
    def __init__(self, ds_path, ext=".png", transform=None, target_transform=None):
        super().__init__()
        self._images = list(Path(ds_path).joinpath("x").iterdir())
        self._labels = list(Path(ds_path).joinpath("y").iterdir())
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(list(self._images))

    def __getitem__(self, idx):        
        img = Image.open(self._images[idx]).convert('RGB')
        target = Image.open(self._labels[idx])

        seed = np.random.randint(2147483647)

        torch.manual_seed(seed)
        random.seed(seed)  # apply this seed to img transforms
        if self.transform is not None:
            img = self.transform(img)

        torch.manual_seed(seed)
        random.seed(seed)  # apply this seed to target transforms
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

ds_train = SegmentationDataset(train_data_dir, transform=t.train_transforms,
                                       target_transform=t.train_target_transforms)

dl_train = DataLoader(ds_train, shuffle=True, batch_size=32, pin_memory=True,
                          drop_last=True, num_workers=os.cpu_count())

30 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
