In [None]:
import cv2
import numpy as np
import os
import segmentation_models as sm
import natsort
import random
from tqdm import tqdm
import configparser
from libtiff import TIFF
import imgaug.augmenters as iaa

# Pick the `Sequence` base class from the Keras flavour named by the
# SM_FRAMEWORK environment variable: "tf.keras" selects tensorflow.keras,
# any other value selects standalone keras.
# NOTE(review): os.environ[...] raises KeyError when SM_FRAMEWORK is unset, so
# the variable has to be exported before this cell is run.
if os.environ["SM_FRAMEWORK"] == "tf.keras":
    from tensorflow.keras.utils import Sequence
else:
    from keras.utils import Sequence

In [None]:
# Code adapted partially from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly


class MultiLabelGenerator(Sequence):
    """Batch generator of images and multi-channel masks for semantic segmentation.

    Images are read from ``<image_path>/<file_id>.jpg``. Masks are read from
    ``<mask_path>/<file_id>.tif`` (one TIFF page per output channel) when
    ``tiff_flag`` is set, otherwise from ``<mask_path>/<file_id>.npy``.
    Image batches are ``uint8`` arrays of shape
    ``(batch_size, *dim, n_channels)`` and mask batches are ``int8`` arrays of
    shape ``(batch_size, *dim, n_classes)``.
    """

    def __init__(
        self,
        file_ids,
        image_path,
        mask_path,
        batch_size,
        dim,
        n_channels,
        n_classes,
        augment_flag,
        tiff_flag,
        to_fit=True,
        shuffle=True,
    ):
        """Initialization
        :param file_ids: list of file names without extension - the mapping between an image and its label file is based on a common file name with a different extension
        :param image_path: path to images location
        :param mask_path: path to masks location
        :param batch_size: batch size at each iteration
        :param dim: tuple ``(height, width)`` of the images and masks in a batch
        :param n_channels: number of image channels (1 loads the image as grayscale)
        :param n_classes: number of output masks
        :param augment_flag: boolean flag to indicate if augmentation should be enabled
        :param tiff_flag: Flag to indicate if labels should be parsed as TIFF files. Else it is assumed that they are in .npy format
        :param to_fit: True to return X and y, False to return X only
        :param shuffle: True to shuffle label indexes after every epoch
        """

        self.file_ids = file_ids
        self.image_path = image_path
        self.mask_path = mask_path
        self.to_fit = to_fit
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.tiff_flag = tiff_flag
        self.augment_flag = augment_flag
        if augment_flag:
            # Keys are the strings "1".."N"; augment_data() draws one of them
            # at random for every batch.
            self.augmentor_dict = {
                "1": iaa.OneOf([iaa.Fliplr(1)]),
                "2": iaa.OneOf([iaa.Affine(scale={"x": (0.6, 1.2)}, order=[0])]),
                "3": iaa.OneOf([iaa.Affine(rotate=(-20, 20), order=[0])]),
                "4": iaa.OneOf([iaa.PerspectiveTransform(scale=(0.01, 0.2))]),
                "5": iaa.OneOf([iaa.CropAndPad(percent=(-0.30, 0.30))]),
                "6": iaa.OneOf([iaa.ElasticTransformation(alpha=(0.5, 2), sigma=0.25)]),
                "7": iaa.OneOf(
                    [
                        iaa.AdditiveGaussianNoise(
                            loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5
                        )
                    ]
                ),
                "8": iaa.OneOf([iaa.AddToHueAndSaturation((-10, 10))]),
                "9": iaa.OneOf([iaa.GaussianBlur((0, 1.0))]),
                "10": iaa.OneOf([iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5))]),
                "11": iaa.OneOf([iaa.LinearContrast((0.5, 1.0), per_channel=0.5)]),
            }
        # Build the initial (optionally shuffled) index order.
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch
        :return: number of full batches per epoch; a trailing partial batch is dropped
        """

        return len(self.file_ids) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data
        :param index: index of the batch
        :return: X and y when fitting. X only when predicting
        """
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]

        # Find list of IDs
        file_ids_batch = [self.file_ids[k] for k in indexes]

        # Generate data
        X = self._generate_X(file_ids_batch)

        if not self.to_fit:
            return X

        y = self._generate_y(file_ids_batch)
        # Augmentation is only applied when masks are produced as well, so
        # that images and masks receive the same transformation.
        if self.augment_flag:
            X, y = self.augment_data(X, y)
        return X, y

    def augment_data(self, X, y):
        """Augments a batch of data by randomly choosing one of the
        augmentors defined in ``augmentor_dict``. The chosen augmentor is
        applied to the whole batch.
        :param X: batch of images
        :param y: batch of masks
        :return: the transformed images and segmentation maps
        """

        random_choice = random.randint(1, len(self.augmentor_dict))
        X_aug, y_aug = self.augmentor_dict[str(random_choice)](
            images=X, segmentation_maps=y
        )

        return X_aug, y_aug

    def on_epoch_end(self):
        """Updates indexes after each epoch"""

        self.indexes = np.arange(len(self.file_ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _cv2_dsize(self):
        """Return ``self.dim`` in the ``(width, height)`` order used by ``cv2.resize``.

        ``self.dim`` is laid out as ``(height, width)`` because it is unpacked
        directly into the batch array shapes, whereas ``cv2.resize`` expects
        its ``dsize`` argument as ``(width, height)``.
        """
        height, width = self.dim
        return (width, height)

    @staticmethod
    def _with_channel_axis(array):
        """Return ``array`` with a trailing channel axis if it is 2-D."""
        if array.ndim == 2:
            return array[..., np.newaxis]
        return array

    def _generate_X(self, batch_file_ids):
        """Generates data containing batch_size images
        :param batch_file_ids: list of file ids (names without extension) to load
        :return: batch of images
        """
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels), dtype=np.uint8)

        # Generate data
        for i, file_id_name in enumerate(batch_file_ids):
            # Store sample
            X[i,] = self._load_image(os.path.join(self.image_path, file_id_name + ".jpg"))

        return X

    def _generate_y(self, batch_file_ids):
        """Generates data containing batch_size masks
        :param batch_file_ids: list of file ids (names without extension) to load
        :return: batch of masks
        """
        y = np.empty((self.batch_size, *self.dim, self.n_classes), dtype=np.int8)

        # Both mask formats share the same loop; only loader and extension differ.
        if self.tiff_flag:
            load_mask, extension = self._load_tif, ".tif"
        else:
            load_mask, extension = self._load_npy, ".npy"

        for i, file_id_name in enumerate(batch_file_ids):
            # Store sample
            y[i,] = load_mask(os.path.join(self.mask_path, file_id_name + extension))

        return y

    def _load_image(self, image_path):
        """Loads an image and resizes it to ``self.dim``
        :param image_path: path to image to load
        :return: image array with an explicit trailing channel axis
        :raises FileNotFoundError: if the image cannot be read
        """
        # cv2.imread takes IMREAD_* flags; COLOR_* constants are cvtColor
        # conversion codes and do not request a grayscale decode.
        if self.n_channels == 1:
            read_flag = cv2.IMREAD_GRAYSCALE
        else:
            read_flag = cv2.IMREAD_COLOR

        img = cv2.imread(image_path, read_flag)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("Could not read image: {}".format(image_path))

        img = cv2.resize(img, self._cv2_dsize())
        # Grayscale images come back 2-D; add the channel axis so the result
        # fits the (*dim, n_channels) slot of the batch array.
        return self._with_channel_axis(img)

    def _load_tif(self, mask_path):
        """Load a multi-page TIF mask, binarise every page and stack the
        pages as channels
        :param mask_path: path to mask to load
        :return: mask image with one channel per TIFF page
        """

        mask_tif = TIFF.open(mask_path, mode="r")
        temp_list = []
        try:
            for mask_images in mask_tif.iter_images():
                # just intended as a sanity check; practically the mask has
                # either 0 or 255, but this range is set to handle any encoding
                # issues between formats that may cause some stray pixels to
                # take up values a little greater than 0
                mask_images[mask_images <= 15] = 0
                mask_images[mask_images > 15] = 1
                temp_list.append(
                    cv2.resize(
                        mask_images, self._cv2_dsize(), interpolation=cv2.INTER_NEAREST
                    )
                )
        finally:
            # Release the file handle even if a page fails to decode or resize.
            mask_tif.close()

        mask_image = cv2.merge(tuple(temp_list))

        # Guarantee a channel axis for a single-page mask.
        return self._with_channel_axis(mask_image)

    def _load_npy(self, mask_path_npy):
        """Load a .npy mask and resize each of its first ``n_classes``
        channels to ``self.dim``
        :param mask_path_npy: path to the .npy mask to load
        :return: mask image with ``n_classes`` channels
        """

        mask = np.load(mask_path_npy)
        temp_list = [
            cv2.resize(
                mask[:, :, i], self._cv2_dsize(), interpolation=cv2.INTER_NEAREST
            )
            for i in range(self.n_classes)
        ]

        mask_image = cv2.merge(tuple(temp_list))

        # Guarantee a channel axis when n_classes == 1.
        return self._with_channel_axis(mask_image)

In [None]:
def prepare_file_names(image_dir, masks_dir):
    """Collects the extension-less file names found in the image and mask directories

    Each directory listing is natural-sorted first and the extension is then
    stripped from every entry, so the result keeps the order of the sorted
    full file names.

    :param image_dir: Path to images
    :param masks_dir: Path to masks
    :return: tuple of two lists: image file names and mask file names, both without extension
    """

    def _sorted_stems(directory):
        # Natural-sort the raw listing, then drop each entry's extension.
        stems = []
        for entry in natsort.natsorted(os.listdir(directory)):
            stem, _extension = os.path.splitext(entry)
            stems.append(stem)
        return stems

    image_names = _sorted_stems(image_dir)
    mask_names = _sorted_stems(masks_dir)

    return image_names, mask_names

In [None]:
# Load the data-generator settings from the INI file.
CONFIG_FILE = "segmentation_training.ini"

config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; check it so a missing file fails with a clear
# message instead of a KeyError on the first section lookup below.
if not config.read(CONFIG_FILE):
    raise FileNotFoundError(
        "Configuration file could not be read: {}".format(CONFIG_FILE)
    )

image_folders = config["IMAGE_FOLDERS"]
generator_params = config["DATA_GENERATOR_PARAMETERS"]

# Locations of the training / validation images and masks.
TRAIN_IMAGES_PATH = image_folders["train_images_path"]
TRAIN_MASKS_PATH = image_folders["train_masks_path"]
VAL_IMAGES_PATH = image_folders["val_images_path"]
VAL_MASKS_PATH = image_folders["val_masks_path"]

# Generator switches and sizes.
AUGMENT_FLAG = generator_params.getboolean("augment_flag")
TIFF_FLAG = generator_params.getboolean("tiff_flag")
# "image_dimensions" is a comma-separated list of integers.
IMAGE_DIMENSIONS = tuple(
    int(value) for value in generator_params["image_dimensions"].split(",")
)
NUM_CHANNELS = int(generator_params["num_channels"])
NUM_CLASSES = int(generator_params["num_classes"])
BATCH_SIZE = int(generator_params["batch_size"])