# UNet++ Dataset

This notebook defines a custom PyTorch `Dataset` class that loads image/mask pairs for training UNet++, then loads one batch to check the resulting tensor shapes, value ranges, and dtypes.

In [9]:
import torch
from skimage.io import imread
from torch.utils import data
import imagecodecs

class SegmentationDataSet(data.Dataset):
    """Map-style dataset yielding (input, target) tensor pairs for segmentation.

    Each item is read from disk on demand with ``imread``: the input image with
    ``as_gray=False`` and the target mask with ``as_gray=True``. An optional
    ``transform(x, y)`` is applied to the loaded arrays before they are
    converted with ``torch.from_numpy`` and cast to ``self.inputs_dtype``
    (``torch.float32``) and ``self.targets_dtype`` (``torch.long``).

    Args:
        inputs: Paths of the input images.
        targets: Paths of the target masks; ``targets[i]`` is paired with
            ``inputs[i]`` by position.
        transform: Optional callable ``(x, y) -> (x, y)`` applied to the loaded
            arrays; it must return objects ``torch.from_numpy`` accepts.

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self, inputs: list, targets: list, transform=None):
        # Pairs are matched by position and __len__ counts only `inputs`, so a
        # mismatch would either raise IndexError inside a DataLoader worker or
        # silently ignore surplus targets. Fail early instead.
        if len(inputs) != len(targets):
            raise ValueError(
                f'inputs and targets must have the same length, '
                f'got {len(inputs)} and {len(targets)}'
            )
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.inputs)

    def __getitem__(self, index: int):
        """Load, optionally transform, and tensorize the pair at ``index``.

        Returns:
            Tuple ``(x, y)`` of tensors with dtypes ``self.inputs_dtype`` and
            ``self.targets_dtype``.
        """
        input_path = self.inputs[index]
        target_path = self.targets[index]

        # NOTE(review): no axis reordering happens here, so x keeps the layout
        # imread returns; this notebook's sample batch is channels-last,
        # (2, 892, 696, 3). PyTorch conv layers expect (C, H, W) — transpose in
        # `transform` or in the model if needed.
        x = imread(input_path, as_gray=False)
        # NOTE(review): per the scikit-image docs, as_gray=True converts colour
        # images to floats; the long cast below truncates such values, so masks
        # are presumably stored as single-channel integer label images —
        # confirm for new data.
        y = imread(target_path, as_gray=True)

        if self.transform is not None:
            x, y = self.transform(x, y)

        x = torch.from_numpy(x).type(self.inputs_dtype)
        y = torch.from_numpy(y).type(self.targets_dtype)
        return x, y

In [10]:
# Smoke test: load one batch of sample image/mask pairs and inspect it.
DATA_DIR = './dataset/images_train'
SAMPLE_IDS = [0, 1]
SEED = 42

# Inputs and masks share a numeric stem: data/<i>.png <-> labels/<i>.tif.
inputs = [f'{DATA_DIR}/data/{i}.png' for i in SAMPLE_IDS]
targets = [f'{DATA_DIR}/labels/{i}.tif' for i in SAMPLE_IDS]

training_dataset = SegmentationDataSet(inputs=inputs, targets=targets, transform=None)
# A seeded generator makes the shuffle order reproducible across
# Restart Kernel -> Run All.
training_dataloader = data.DataLoader(
    dataset=training_dataset,
    batch_size=2,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

x, y = next(iter(training_dataloader))
# Every input in the batch must come with its mask.
assert x.shape[0] == y.shape[0]

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')
print(f'y = shape: {y.shape}; class: {y.unique()}; type: {y.dtype}')

x = shape: torch.Size([2, 892, 696, 3]); type: torch.float32
x = min: 0.0; max: 255.0
y = shape: torch.Size([2, 892, 696]); class: tensor([0, 1]); type: torch.int64
