In [103]:
import numpy as np
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt

from torch.utils.data import Dataset, DataLoader

In [104]:
# Dataset size and split boundaries. These are image COUNTS, not fractions:
# after shuffling, the first TRAIN_SPLIT images form the training set, the next
# VALIDATION_SPLIT images the validation set, and whatever remains the test set.
TOTAL_NUM_IMAGES = 5_000
TRAIN_SPLIT = 4_000
VALIDATION_SPLIT = 500

# Number of segmentation classes; sets the length of the one-hot mask axis.
TOTAL_NUM_CLASSES = 13

In [105]:
data_path = Path('../data')
img_path = data_path / 'images'
msk_path = data_path / 'masks'

# Path.glob yields entries in filesystem order, which is not guaranteed to be the
# same for two different directories. Sort both lists so that images_list[i] and
# masks_list[i] refer to the same sample (assumes an image and its mask sort to
# the same position, e.g. identical file names -- TODO confirm naming convention).
images_list = sorted(img_path.glob('*.png'))
masks_list = sorted(msk_path.glob('*.png'))

# Both folders must contain exactly TOTAL_NUM_IMAGES files: the permutation below
# indexes range(TOTAL_NUM_IMAGES), and images/masks are paired by position.
if len(images_list) != TOTAL_NUM_IMAGES or len(masks_list) != TOTAL_NUM_IMAGES:
    raise ValueError(
        f'Invalid data: expected {TOTAL_NUM_IMAGES} images and masks, '
        f'found {len(images_list)} images and {len(masks_list)} masks'
    )

images_list = np.array(images_list)
masks_list = np.array(masks_list)

# Fixed-seed shuffle for a reproducible split. A local RandomState(1) produces the
# same permutation as np.random.seed(1) + np.random.permutation, but without
# reseeding (and then re-randomising) numpy's global RNG.
shuffle_idx = np.random.RandomState(1).permutation(TOTAL_NUM_IMAGES)
images_list = images_list[shuffle_idx]
masks_list = masks_list[shuffle_idx]

# Contiguous slices of the shuffled lists: train | validation | test (remainder).
train_images = images_list[:TRAIN_SPLIT]
train_masks = masks_list[:TRAIN_SPLIT]

validation_images = images_list[TRAIN_SPLIT:TRAIN_SPLIT + VALIDATION_SPLIT]
validation_masks = masks_list[TRAIN_SPLIT:TRAIN_SPLIT + VALIDATION_SPLIT]

test_images = images_list[TRAIN_SPLIT + VALIDATION_SPLIT:]
test_masks = masks_list[TRAIN_SPLIT + VALIDATION_SPLIT:]

In [117]:
class SegmentationDataset(Dataset):
    """Map-style dataset of (image, one-hot mask) pairs for semantic segmentation.

    Files are read from disk in ``__getitem__``, resized to ``size x size`` and
    returned channels-first as NumPy arrays:

    * image -- ``float32``, shape ``(C, size, size)``, values scaled to ``[0, 1]``
      (C is the number of channels the file decodes to; single-channel images
      get C = 1).
    * mask  -- ``float32``, shape ``(num_classes, size, size)``, one-hot encoding
      of the integer class id stored in the mask (first channel if the mask
      file has several).

    Args:
        image_paths: indexable sequence of image file paths.
        mask_paths: indexable sequence of mask file paths; ``mask_paths[i]`` is
            paired with ``image_paths[i]``.
        size: side length, in pixels, of the square output.
        num_classes: number of segmentation classes. When omitted, the
            notebook-level ``TOTAL_NUM_CLASSES`` is used.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, size, num_classes=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f'Got {len(image_paths)} image paths but {len(mask_paths)} mask paths'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.size = size
        # Looked up at construction time (not bound as a default argument) so an
        # explicit num_classes works without the notebook global being defined.
        self.num_classes = TOTAL_NUM_CLASSES if num_classes is None else num_classes

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` channels-first."""
        target_size = (self.size, self.size)

        # ``with`` closes the file handle Image.open keeps open until load.
        with Image.open(self.image_paths[idx]) as img:
            image = np.asarray(
                img.resize(target_size, resample=Image.LANCZOS), dtype=np.float32
            ) / 255
        if image.ndim == 2:
            # Single-channel image: add a channel axis so moveaxis below does not
            # swap height and width instead.
            image = image[:, :, np.newaxis]

        # Masks hold discrete class ids, so they must be resized with
        # nearest-neighbour sampling: LANCZOS blends neighbouring ids into labels
        # that never occurred in the mask and can overshoot num_classes - 1.
        with Image.open(self.mask_paths[idx]) as msk:
            mask = np.asarray(
                msk.resize(target_size, resample=Image.NEAREST), dtype=np.int64
            )
        if mask.ndim == 3:
            # Multi-channel mask: class ids are read from the first channel.
            mask = mask[:, :, 0]
        mask = self.one_hot_encode(mask, self.num_classes)

        return np.moveaxis(image, -1, 0), np.moveaxis(mask, -1, 0)

    def one_hot_encode(self, mask, num_classes):
        """One-hot encode an integer array of class ids.

        Args:
            mask: integer array of class ids, any shape.
            num_classes: length of the one-hot axis.

        Returns:
            ``float32`` array of shape ``mask.shape + (num_classes,)``.

        Raises:
            ValueError: if any id lies outside ``[0, num_classes)``. Negative ids
                would otherwise silently wrap around to the last classes.
        """
        mask = np.asarray(mask)
        if mask.size and (mask.min() < 0 or mask.max() >= num_classes):
            raise ValueError(
                f'Mask contains class ids in [{mask.min()}, {mask.max()}], '
                f'expected values in [0, {num_classes - 1}]'
            )
        # Row k of the identity matrix is the one-hot vector for class k.
        return np.eye(num_classes, dtype=np.float32)[mask]

In [118]:
# One dataset per split; every sample is resized to 224 x 224 pixels.
train_dataset = SegmentationDataset(image_paths=train_images, mask_paths=train_masks, size=224)
validation_dataset = SegmentationDataset(image_paths=validation_images, mask_paths=validation_masks, size=224)
test_dataset = SegmentationDataset(image_paths=test_images, mask_paths=test_masks, size=224)

In [119]:
# The training loader reshuffles every epoch so mini-batches differ between
# epochs (without it, each epoch replays the identical batch sequence). The
# evaluation loaders keep a fixed order so metrics and predictions line up
# with the file lists.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)