In [None]:
class GOOSE_SemanticDataset(Dataset):
    """
    Example Pytorch Dataset Module for semantic tasks with GOOSE.

    Every item is an (image, label) pair read from disk on access. Both go
    through the same optional center-crop and resize so that the semantic map
    stays pixel-aligned with the image.
    """

    def __init__(self, dataset_dict: List[Dict], crop: bool = True, resize_size: Iterable[int] = None):
        """
        Parameters:
            dataset_dict  [Iter]    : List of Dicts with the images information generated by
                                      *goose_create_dataDict*. Each entry must provide the keys
                                      'img_path' and 'semantic_path'.

            crop          [Bool]    : Whether to make a square crop of the images or not

            resize_size   [Iter]    : Target size of the images (after the crop if crop == True),
                                      or None to skip resizing. It is handed to the image's
                                      ``resize`` method, which for PIL images expects
                                      (width, height) -- not (height, width).
        """
        self.dataset_dict   = dataset_dict
        self.transforms     = transforms.Compose([
            transforms.ToTensor(),
            ])
        # Materialize once: a one-shot iterable (e.g. a generator) would be
        # exhausted after the first item if it were stored as-is.
        self.resize_size    = None if resize_size is None else tuple(resize_size)
        self.crop           = crop

    @staticmethod
    def _load(path, mode):
        """Open the image at *path*, convert it to *mode* and release the file handle."""
        # convert() returns a fully loaded image that no longer depends on the
        # file, so the handle can be closed deterministically right here.
        with Image.open(path) as source:
            return source.convert(mode)

    def preprocess(self, image):
        """
        Apply the configured center-crop and resize to an image.

        Parameter:
            image                       : Image to process (or None)

        Returns:
            The processed image, or None when *image* is None
        """
        if image is None:
            return None

        if self.crop:
            # Square-Crop in the center, using the shorter side as edge length
            side = min(image.width, image.height)
            image = transforms.CenterCrop((side, side))(image)

        if self.resize_size is not None:
            # Nearest-neighbour never blends neighbouring pixels, so label ids
            # are not turned into meaningless in-between values. The same
            # filter is used for the RGB image.
            image = image.resize(self.resize_size, resample=Image.NEAREST)

        return image

    def __getitem__(self, i):
        """
        Parameter:
            i   [int]                   : Index of the image to get

        Returns:
            image_tensor [torch.Tensor] : 3 x H x W Tensor (output of transforms.ToTensor)

            label_tensor [torch.Tensor] : H x W Tensor (dtype long) as semantic map
        """
        entry = self.dataset_dict[i]

        image = self.preprocess(self._load(entry['img_path'], 'RGB'))
        label = self.preprocess(self._load(entry['semantic_path'], 'L'))

        image_tensor = self.transforms(image)
        # Labels bypass ToTensor on purpose: they must stay integer class ids.
        label_tensor = torch.from_numpy(np.array(label)).long()

        return image_tensor, label_tensor

    def __len__(self):
        """Return the number of samples described by the dataset dict."""
        return len(self.dataset_dict)