In [61]:
import numpy as np
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt

from torch.utils.data.dataset import Dataset

In [20]:
# Dataset size and split sizes. The test split is whatever remains after
# the training and validation slices: TOTAL_NUM_IMAGES - TRAIN_SPLIT - VALIDATION_SPLIT.
TOTAL_NUM_IMAGES = 5_000
TRAIN_SPLIT = 4_000
VALIDATION_SPLIT = 500

# Number of segmentation classes (depth of the one-hot mask encoding).
TOTAL_NUM_CLASSES = 13

In [4]:
# Locate the image and mask PNGs.
data_path = Path('../data')
img_path = data_path / 'images'
msk_path = data_path / 'masks'

# Path.glob yields entries in filesystem order, which is arbitrary and may
# differ between the two directories. Sort both lists so index i refers to
# the same sample in each.
# NOTE(review): assumes an image and its mask sort to the same position
# (e.g. identical file names in both folders) -- confirm for this dataset.
images_list = sorted(img_path.glob('*.png'))
masks_list = sorted(msk_path.glob('*.png'))

# Both folders must hold exactly TOTAL_NUM_IMAGES files: the index pairing
# above needs equal lengths, and the split below slices by fixed counts.
# Either condition failing on its own is an error, hence `or`.
if len(images_list) != TOTAL_NUM_IMAGES or len(masks_list) != TOTAL_NUM_IMAGES:
    raise ValueError(
        f'Invalid data: expected {TOTAL_NUM_IMAGES} images and masks, found '
        f'{len(images_list)} images in {img_path} and '
        f'{len(masks_list)} masks in {msk_path}'
    )

# Convert to NumPy arrays so one index permutation can reorder both lists.
images_list = np.array(images_list)
masks_list = np.array(masks_list)

# Seed NumPy's global RNG so the shuffle (and therefore the split) is
# reproducible. This also seeds any later np.random calls in the notebook.
np.random.seed(1)
# Apply the same permutation to images and masks so pairs stay aligned.
shuffle_idx = np.random.permutation(range(TOTAL_NUM_IMAGES))
images_list = images_list[shuffle_idx]
masks_list = masks_list[shuffle_idx]

# First TRAIN_SPLIT samples -> training.
train_images = images_list[:TRAIN_SPLIT]
train_masks = masks_list[:TRAIN_SPLIT]

# Next VALIDATION_SPLIT samples -> validation.
validation_images = images_list[TRAIN_SPLIT:TRAIN_SPLIT + VALIDATION_SPLIT]
validation_masks = masks_list[TRAIN_SPLIT:TRAIN_SPLIT + VALIDATION_SPLIT]

# Everything that remains -> test.
test_images = images_list[TRAIN_SPLIT + VALIDATION_SPLIT:]
test_masks = masks_list[TRAIN_SPLIT + VALIDATION_SPLIT:]

In [58]:
class SegmentationDataset(Dataset):
    """Map-style dataset yielding ``(image, one_hot_mask)`` pairs from file paths.

    Args:
        image_paths: Sequence of paths to the input images.
        mask_paths: Sequence of paths to the label masks, index-aligned with
            ``image_paths`` (``mask_paths[i]`` is the mask for ``image_paths[i]``).
        num_classes: Depth of the one-hot mask encoding. When ``None``
            (default), the notebook-level ``TOTAL_NUM_CLASSES`` is used,
            looked up when an item is loaded.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, num_classes=None):
        # Fail fast: mismatched lists would silently pair the wrong files
        # or raise IndexError deep inside a training loop.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f'image_paths and mask_paths must have the same length, '
                f'got {len(image_paths)} and {len(mask_paths)}'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and preprocess sample ``idx``.

        Returns:
            image: NumPy array of the image's pixel values divided by 255.
            mask: float64 NumPy array of shape ``(H, W, num_classes)``, the
                one-hot encoding of the mask's label plane.
        """
        # Use context managers so each file handle is closed once its pixels
        # are copied into a NumPy array. A bare Image.open leaves the file open
        # until the Image object is garbage-collected.
        with Image.open(self.image_paths[idx]) as img:
            # Dividing by 255 scales 8-bit pixels to [0, 1].
            # NOTE(review): assumes 8-bit images -- confirm.
            image = np.array(img) / 255
        with Image.open(self.mask_paths[idx]) as msk:
            mask = np.array(msk, dtype='int')
        # Multi-channel masks: keep channel 0 as the label plane. A
        # single-channel mask already decodes to a 2-D label array.
        # NOTE(review): presumably labels live in channel 0 -- confirm.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        num_classes = TOTAL_NUM_CLASSES if self.num_classes is None else self.num_classes
        mask = self.one_hot_encode(mask, num_classes)
        return image, mask

    def one_hot_encode(self, mask, num_classes):
        """One-hot encode an integer label array.

        Args:
            mask: Integer array of class labels, expected in ``[0, num_classes)``.
            num_classes: Size of the new trailing one-hot axis.

        Returns:
            float64 array of shape ``mask.shape + (num_classes,)`` with a 1 at
            each element's label index and 0 elsewhere.

        Raises:
            IndexError: If any label is ``>= num_classes``. Negative labels are
                not rejected; they wrap around like ordinary NumPy indices.
        """
        # Flatten so every label gets its own row, set one column per row,
        # then restore the original spatial shape plus the class axis.
        y = mask.ravel()
        one_hot = np.zeros((y.shape[0], num_classes))
        one_hot[np.arange(y.shape[0]), y] = 1
        return np.reshape(one_hot, mask.shape + (num_classes,))

In [59]:
# Build one dataset per split, in train / validation / test order, each
# pairing that split's image paths with its mask paths.
train_dataset, validation_dataset, test_dataset = (
    SegmentationDataset(split_images, split_masks)
    for split_images, split_masks in (
        (train_images, train_masks),
        (validation_images, validation_masks),
        (test_images, test_masks),
    )
)