In [18]:
import os
import cv2
import argparse
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tensorflow import keras

In [19]:
# Segmentation target class names (presumably "ich" = intracranial hemorrhage).
CLASSES = ["ich"]
# Dataset paths below are resolved relative to the kernel's working directory.
PROJECT_DIR = os.getcwd()

In [21]:
class Dataset(object):
    """Image/mask segmentation dataset backed by two parallel directories.

    Every file in ``images_dir`` is paired with the file of the same name in
    ``masks_dir``. Indexing returns an ``(image, mask)`` tuple after optional
    augmentation and preprocessing.

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing one mask per image, stored under the
            same file name as the image.
        classes: class names (case-insensitive) to extract from each mask.
            Every name must appear in ``Dataset.CLASSES``. ``None`` selects
            all of ``Dataset.CLASSES``.
        augmentation: optional callable invoked as ``f(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same contract as
            ``augmentation``; applied after it.
    """

    CLASSES = ["ich"]

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is deterministic across runs and machines.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # classes=None means "every known class" instead of failing on iteration.
        if classes is None:
            classes = self.CLASSES
        # Convert class names to the pixel values compared against the mask.
        # NOTE(review): the value used is the class's *index* in CLASSES, so
        # "ich" -> 0. If masks store lesions as non-zero (e.g. 255) and
        # background as 0, this extracts background instead — confirm the
        # mask encoding.
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def get_basename(self, i):
        """Return the file name (without directory) of sample ``i``."""
        return os.path.basename(self.images_fps[i])

    def __getitem__(self, i):
        """Load sample ``i`` and return it as ``(image, mask)``.

        The image is converted from OpenCV's BGR order to RGB. The mask becomes
        a float array with one binary channel per requested class, plus a
        background channel when more than one class is requested.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # cv2.imread reports failure by returning None rather than raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")

        # One binary channel per requested class value, stacked on the last axis.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # With more than one class channel, append a background channel that
        # is 1 wherever no class channel is set.
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples (entries found in ``images_dir``)."""
        return len(self.ids)

In [20]:
class Dataloader(keras.utils.Sequence):
    """Load data from a dataset and form batches.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.

    Trailing samples that do not fill a complete batch are dropped
    (``__len__`` uses floor division).
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        # Keras 3's PyDataset (aliased as Sequence) expects its __init__ to run;
        # on older Keras this resolves to object.__init__ and is harmless.
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch ``i`` as a list of arrays stacked along a new axis 0.

        For a dataset yielding ``(image, mask)`` this is ``[images, masks]``.

        Raises:
            IndexError: if ``i`` is outside ``[0, len(self))``. This also makes
                plain ``for`` iteration stop after the last full batch.
        """
        if not 0 <= i < len(self):
            raise IndexError(f"batch index {i} out of range for {len(self)} batches")

        # Look samples up through self.indexes so the per-epoch shuffle in
        # on_epoch_end actually changes which samples form each batch.
        start = i * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        data = [self.dataset[j] for j in batch_indexes]

        # transpose list of per-sample tuples into one stacked array per field
        return [np.stack(samples, axis=0) for samples in zip(*data)]

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

In [43]:
def main(args):
    """Build the training dataset and dataloader, then print a preview of batch shapes.

    Args:
        args: parsed arguments; ``args.dataset`` is the dataset root and must
            contain ``train/image`` and ``train/mask`` subdirectories.
    """
    import itertools

    batch_size = 16
    num_preview_batches = 2  # only inspect the first few batches

    x_train_dir = os.path.join(args.dataset, "train", "image")
    y_train_dir = os.path.join(args.dataset, "train", "mask")

    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        classes=CLASSES,
    )
    train_dataloader = Dataloader(train_dataset, batch_size=batch_size, shuffle=True)

    # islice stops after the requested batches without loading an extra one
    # (an enumerate/break loop would fetch one more batch before breaking).
    for batch in itertools.islice(train_dataloader, num_preview_batches):
        # batch[0] is the stacked image array of the batch.
        print(batch[0].shape)
    

In [44]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preview batches from the segmentation training set."
    )
    parser.add_argument(
        "--dataset",
        required=True,  # main() cannot run without a dataset root
        metavar="/path/to/dataset",
        help="dataset root containing train/image and train/mask",
    )
    # argv is passed explicitly: inside a Jupyter kernel sys.argv typically
    # holds the kernel launcher's own arguments, which this parser would reject.
    main(parser.parse_args([
        "--dataset", os.path.join(PROJECT_DIR, "datasets", "ICH_420", "export", "Positive")
    ]))


(16, 512, 512, 3)
(16, 512, 512, 3)
