In [1]:
# pip install -U albumentations

In [1]:
import cv2
import matplotlib.pyplot as plt
import albumentations as A
import random
import os

random.seed(42)

In [2]:
class DataReader:
    """
    A helper class to read image and mask data from specified paths.
    """
    @staticmethod
    def _read_or_raise(path, flags=None):
        """
        Reads one file with cv2.imread and fails loudly when nothing was loaded.
        Args:
            path (str): Path to the file to read.
            flags (int, optional): cv2.imread flag; when omitted, cv2's default is used.
        Returns:
            numpy.ndarray: The decoded pixel data.
        Raises:
            FileNotFoundError: If cv2.imread returned None for the path.
        """
        loaded = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        # cv2.imread does not raise on a missing/unreadable file, it returns None;
        # surface that here with the offending path instead of failing later.
        if loaded is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return loaded

    @staticmethod
    def read_data(image_path, mask_path):
        """
        Reads an image and its corresponding mask from the given file paths.
        Args:
            image_path (str): Path to the image file.
            mask_path (str): Path to the mask file.
        Returns:
            tuple: A tuple containing the original image (converted from BGR to RGB)
                and the mask (read as grayscale) as numpy arrays.
        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        image = DataReader._read_or_raise(image_path)
        original_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_mask = DataReader._read_or_raise(mask_path, cv2.IMREAD_GRAYSCALE)
        return original_image, original_mask
    
    @staticmethod
    def read_mask(image_path, mask_path):
        """
        Reads a mask from the given file path.
        Args:
            image_path (str): Unused; kept so the signature mirrors read_data.
            mask_path (str): Path to the mask file.
        Returns:
            numpy.ndarray: The mask read as a grayscale array.
        Raises:
            FileNotFoundError: If the mask cannot be read.
        """
        return DataReader._read_or_raise(mask_path, cv2.IMREAD_GRAYSCALE)


class DirectoryManager:
    """
    A helper class to manage directory operations such as ensuring directories exist.
    """
    @staticmethod
    def ensure_directory_exists(file_path):
        """
        Ensures that the directory for the given file path exists.
        Creates the directory (and any missing parents) if it does not exist.
        Args:
            file_path (str): The file path for which the directory needs to be checked/created.
        """
        directory = os.path.dirname(file_path)
        # A bare file name has no directory component; os.makedirs("") would raise.
        if not directory:
            return
        # exist_ok avoids the check-then-create race when several processes
        # try to create the same directory at the same time.
        os.makedirs(directory, exist_ok=True)



class Visualizer:
    """
    A helper class for visualizing images and masks.
    """
    @staticmethod
    def visualize(image, mask, original_image=None, original_mask=None, title=None):
        """
        Visualizes the transformed image and mask, optionally next to the originals.
        With no originals a 2x1 grid (image over mask) is drawn; with both originals
        a 2x2 grid is drawn with the originals in the left column.
        Args:
            image (numpy.ndarray): Transformed image to be visualized.
            mask (numpy.ndarray): Transformed mask to be visualized.
            original_image (numpy.ndarray, optional): Original image for comparison.
            original_mask (numpy.ndarray, optional): Original mask for comparison.
            title (str, optional): Title for the visualization plot.
        Raises:
            ValueError: If only one of original_image / original_mask is provided.
        """
        fontsize = 18
        has_original_image = original_image is not None
        has_original_mask = original_mask is not None
        # The comparison layout needs both originals; fail before any figure is created.
        if has_original_image != has_original_mask:
            raise ValueError(
                "original_image and original_mask must be provided together")

        if not has_original_image:
            f, ax = plt.subplots(2, 1, figsize=(8, 8))
            ax[0].imshow(image)
            ax[1].imshow(mask)
        else:
            f, ax = plt.subplots(2, 2, figsize=(8, 8))
            ax[0, 0].imshow(original_image)
            ax[0, 0].set_title('Original image', fontsize=fontsize)
            ax[1, 0].imshow(original_mask)
            ax[1, 0].set_title('Original mask', fontsize=fontsize)
            ax[0, 1].imshow(image)
            ax[0, 1].set_title('Transformed image', fontsize=fontsize)
            ax[1, 1].imshow(mask)
            ax[1, 1].set_title('Transformed mask', fontsize=fontsize)

        # Apply the title in both layouts (it used to be ignored for the 2x1 grid).
        if title is not None:
            f.suptitle(title, fontsize=fontsize)



class ImageSaver:
    """
    A helper class for saving images to the disk.
    """
    @staticmethod
    def save_image(image, file_path):
        """
        Saves an image to the specified file path, creating its directory first.
        Args:
            image (numpy.ndarray): Image to be saved.
            file_path (str): Path where the image will be saved.
        Raises:
            OSError: If cv2.imwrite reports that the file was not written.
        """
        DirectoryManager.ensure_directory_exists(file_path)
        # cv2.imwrite signals failure through its boolean return value rather
        # than an exception; do not let a failed write pass silently.
        if not cv2.imwrite(file_path, image):
            raise OSError(f"Failed to write image to {file_path}")


In [3]:
class DataAugmentationWorker:
    """
    A class to perform data augmentation for image and mask pairs.
    Attributes:
        pixel_transformations (dict): A dictionary of pixel-level transformations.
        image_path (str): Path to the original image.
        mask_path (str): Path to the original mask.
        original_image (numpy.ndarray): Original image read from image_path (RGB order).
        original_mask (numpy.ndarray): Original mask read from mask_path.
        all_transformations (list): List of all transformation combinations.
        augmented_data_root_path (str): Root path to save augmented images and masks.
    """

    def __init__(self, image_path, mask_path, augmented_data_root_path):
        """
        Initialize the DataAugmentationWorker with image paths and transformations.
        Args:
            image_path (str): Path to the original image.
            mask_path (str): Path to the original mask.
            augmented_data_root_path (str): Root path to save augmented data.
        """
        self.pixel_transformations = {
            "ChannelShuffle": A.ChannelShuffle(p=1),
            "CLAHE": A.CLAHE(p=1),
            "Equalize": A.Equalize(p=1),
            "RandomBrightnessContrast": A.RandomBrightnessContrast(p=1)
        }
        self.image_path = image_path
        self.mask_path = mask_path
        self.original_image, self.original_mask = DataReader.read_data(image_path, mask_path)
        self.all_transformations = []
        self.augmented_data_root_path = augmented_data_root_path

    def create_transformations(self):
        """
        Create a list of transformations, combining rotation and pixel transformations.
        One transformation is built per 30-degree band in [0, 180); for each band
        the rotation direction and the pixel-level transformation are picked at random.
        Returns:
            list: List of combined transformations.
        """
        all_transformations = []
        pixel_transformation_names = list(self.pixel_transformations.keys())
        for i in range(0, 180, 30):
            # Both candidate intervals are written as [low, high]; the negative
            # band used to be passed as [high, low].
            random_interval = random.choice([[i, i + 30], [-(i + 30), -i]])
            random_pixel_transformation = random.choice(pixel_transformation_names)
            transformation = A.Compose([
                A.Rotate(limit=random_interval, p=1),
                self.pixel_transformations[random_pixel_transformation]
            ])
            all_transformations.append(transformation)
        return all_transformations

    def apply_transformations(self):
        """
        Apply the created transformations to the original image and mask.
        The transformations are created on first use and reused afterwards.
        Returns:
            list: List of tuples containing transformed images and masks.
        """
        all_transformed_data = []

        if not self.all_transformations:
            self.all_transformations = self.create_transformations()

        for transformation in self.all_transformations:
            transformed_data = transformation(
                image=self.original_image, 
                mask=self.original_mask
            )
            all_transformed_data.append(
                (transformed_data["image"], transformed_data["mask"]))
        return all_transformed_data

    def visualize_all_transformations(self):
        """
        Visualize all the transformations applied to the original image and mask.
        """
        for transformed_image, transformed_mask in self.apply_transformations():
            Visualizer.visualize(transformed_image, transformed_mask, 
                                 self.original_image, self.original_mask)

    def save_one_image_mask_couple(self, image, mask, image_path, mask_path):
        """
        Save a single image and mask pair to the specified paths.
        Args:
            image (numpy.ndarray): The image to be saved.
            mask (numpy.ndarray): The mask to be saved.
            image_path (str): Path to save the image.
            mask_path (str): Path to save the mask.
        """
        ImageSaver.save_image(image, image_path)
        ImageSaver.save_image(mask, mask_path)
        print(f"Saved {image_path}")
        print(f"Saved {mask_path}")

    @staticmethod
    def _build_augmented_path(path, suffix, marker=""):
        """
        Insert "_<suffix>" into a file name, in front of an optional trailing marker.
        Args:
            path (str): Base output path, e.g. "out/masks/a_mask.png".
            suffix (str): Text describing the augmentation.
            marker (str, optional): Stem ending (such as "_mask") that must stay last.
        Returns:
            str: e.g. "out/masks/a_<suffix>_mask.png"; when the stem does not end
                with the marker, the suffix is simply appended to the stem.
        """
        stem, extension = os.path.splitext(path)
        if marker and stem.endswith(marker):
            stem = stem[:-len(marker)]
        else:
            marker = ""
        return f"{stem}_{suffix}{marker}{extension}"

    def save_all_transformations(self):
        """
        Save the original pair and all transformed image and mask pairs under
        the augmentation root path.
        Output paths are derived by replacing "../data" in the source paths with
        augmented_data_root_path. Each augmented file name carries the transform
        names and the 30-degree band (30*i, 30*i + 30) of its rotation.
        NOTE: the band in the file name is always written with positive bounds,
        even when the negative rotation direction was drawn.
        """
        all_transformed_data = self.apply_transformations()
        base_image_path = self.image_path.replace("../data", self.augmented_data_root_path)
        base_mask_path = self.mask_path.replace("../data", self.augmented_data_root_path)

        self.save_one_image_mask_couple(
            cv2.cvtColor(self.original_image, cv2.COLOR_RGB2BGR), 
            self.original_mask, 
            base_image_path, 
            base_mask_path
        )
        for i, ((transformed_image, transformed_mask), transformation) in \
                enumerate(zip(all_transformed_data, self.all_transformations)):

            transformations_names = [
                t['__class_fullname__']
                for t in transformation.to_dict()['transform']['transforms']
            ]
            image_mask_name_suffix = "_".join(
                [transformations_names[0], str(30 * i), str(30 * i + 30)] + transformations_names[1:]
            ).lower()

            image_path = self._build_augmented_path(base_image_path, image_mask_name_suffix)
            mask_path = self._build_augmented_path(
                base_mask_path, image_mask_name_suffix, marker="_mask")

            # The working image is RGB (see DataReader.read_data) while
            # cv2.imwrite expects BGR, so convert back before saving; otherwise
            # the red and blue channels end up swapped on disk.
            self.save_one_image_mask_couple(
                cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR), 
                transformed_mask, image_path, mask_path
            )


In [4]:
# test = DataAugmentationWorker(image_path="../data/images/image_max_1.png",
#                               mask_path="../data/masks/image_max_1_mask.png",
#                               augmented_data_root_path='./to_del')
# test.save_all_transformations()

In [5]:
def apply_data_augmentation(image_path, mask_path, augmented_data_root_path):
    """
    Augment one image/mask pair and write the results to disk.
    Builds a DataAugmentationWorker for the pair and asks it to save the
    original and every transformed variant.
    Args:
        image_path (str): Path to the source image.
        mask_path (str): Path to the matching mask.
        augmented_data_root_path (str): Root path under which outputs are saved.
    """
    DataAugmentationWorker(
        image_path, mask_path, augmented_data_root_path
    ).save_all_transformations()

### Apply data augmentation to the training dataset

In [6]:
from tqdm import tqdm
from joblib import Parallel, delayed

training_image_root_path = '../data/train/images/'
training_mask_root_path = '../data/train/masks/'

# Only PNG files are images of this dataset; anything else in the folder is
# skipped instead of being handed to the worker. Sorting gives a stable
# processing order.
training_image_names = sorted(
    name for name in os.listdir(training_image_root_path)
    if name.endswith('.png')
)

# NOTE(review): the output root equals the input root ('../data'), so the
# augmented files are written into the very folders listed above. Running
# this cell a second time will therefore also augment the files generated by
# the first run.
# NOTE(review): random.seed(42) from the first cell presumably does not reach
# the joblib worker processes, so the drawn augmentations are likely not
# reproducible between runs -- TODO confirm.
_ = Parallel(n_jobs=-1)(
    delayed(apply_data_augmentation)(
        os.path.join(training_image_root_path, image_name),
        # "<stem>.png" -> "<stem>_mask.png"
        os.path.join(training_mask_root_path,
                     os.path.splitext(image_name)[0] + '_mask.png'),
        '../data',
    )
    for image_name in tqdm(training_image_names)
)

100%|████████████████████████████████████████████████████████████████████████████████| 328/328 [00:09<00:00, 34.71it/s]


### Apply data augmentation to the validation dataset

In [7]:
from tqdm import tqdm
from joblib import Parallel, delayed

val_image_root_path = '../data/val/images/'
val_mask_root_path = '../data/val/masks/'

# Only PNG files are images of this dataset; anything else in the folder is
# skipped instead of being handed to the worker. Sorting gives a stable
# processing order.
val_image_names = sorted(
    name for name in os.listdir(val_image_root_path)
    if name.endswith('.png')
)

# NOTE(review): the output root equals the input root ('../data'), so the
# augmented files are written into the very folders listed above. Running
# this cell a second time will therefore also augment the files generated by
# the first run.
# NOTE(review): random.seed(42) from the first cell presumably does not reach
# the joblib worker processes, so the drawn augmentations are likely not
# reproducible between runs -- TODO confirm.
_ = Parallel(n_jobs=-1)(
    delayed(apply_data_augmentation)(
        os.path.join(val_image_root_path, image_name),
        # "<stem>.png" -> "<stem>_mask.png"
        os.path.join(val_mask_root_path,
                     os.path.splitext(image_name)[0] + '_mask.png'),
        '../data',
    )
    for image_name in tqdm(val_image_names)
)

100%|█████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 123.24it/s]
