<a href="https://colab.research.google.com/github/Lusiji254/Pneumothorax-Segmentation-using-DeepLab/blob/main/Augmentor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

# this creates a symbolic link so that now the path /content/gdrive/My\ Drive/ is equal to /mydrive
!ln -s /content/gdrive/My\ Drive/ /mydrive
!ls /mydrive

In [None]:
import os
import cv2
from tqdm import tqdm
from glob import glob
from albumentations import RandomBrightnessContrast, OpticalDistortion, HorizontalFlip


def create_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    # Unlike an exists() check, it also fails loudly when a regular file occupies
    # the path, instead of silently leaving no directory behind.
    os.makedirs(path, exist_ok=True)

def load_data(path, image_dir="downsampled_train_images", mask_dir="downsampled_train_masks"):
    """Return sorted lists of PNG image and mask paths under ``path``.

    Args:
        path: Dataset root directory.
        image_dir: Subdirectory of ``path`` holding the images.
        mask_dir: Subdirectory of ``path`` holding the masks.

    Returns:
        ``(images, masks)``: two lexicographically sorted lists of ``*.png`` paths.
        Callers pair them by position. This assumes image and mask file names sort
        into the same order (TODO: confirm for this dataset).
    """
    images = sorted(glob(os.path.join(path, image_dir, "*.png")))
    masks = sorted(glob(os.path.join(path, mask_dir, "*.png")))
    return images, masks

def _read_pair(image_path, mask_path):
    """Read an image and its mask from disk as 3-channel (BGR) arrays.

    Raises:
        IOError: if OpenCV cannot decode either file (cv2.imread returns None
            instead of raising on failure).
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if image is None:
        raise IOError(f"Could not read image: {image_path}")
    if mask is None:
        raise IOError(f"Could not read mask: {mask_path}")
    return image, mask


def _write_or_raise(path, array):
    """Write ``array`` to ``path`` with OpenCV, raising IOError if the write fails."""
    # cv2.imwrite reports failure (e.g. a missing output folder) only via a False return.
    if not cv2.imwrite(path, array):
        raise IOError(f"Could not write: {path}")


def augment_data(images, masks, save_path, augment=True, height=256, width=256):
    """Optionally augment image/mask pairs, resize them and write them to disk.

    For every pair, the original and (when ``augment`` is True) three augmented
    variants are resized to ``width`` x ``height``. They are written to
    ``<save_path>/images/`` and ``<save_path>/masks/``. When several variants
    are written per pair, each file name gets a variant-index prefix (``0_``,
    ``1_``, ...); a single variant keeps its original name.

    Args:
        images: Paths to the input images.
        masks: Paths to the input masks, aligned index-by-index with ``images``.
        save_path: Output root; its ``images`` and ``masks`` subfolders must exist.
        augment: If True, also write brightness/contrast, horizontal-flip and
            optical-distortion variants of every pair.
        height: Output height in pixels (default 256).
        width: Output width in pixels (default 256).

    Raises:
        ValueError: if ``images`` and ``masks`` have different lengths.
        IOError: if a file cannot be read or written.
    """
    if len(images) != len(masks):
        raise ValueError(
            f"Got {len(images)} images but {len(masks)} masks; they must pair up."
        )

    # Build the transforms once rather than once per pair; each call still
    # samples fresh random parameters.
    # NOTE(review): with p=0.3 / default p / p=0.3, some "augmented" copies will
    # be unchanged duplicates of the original. Confirm this is intended.
    augmenters = []
    if augment:
        augmenters = [
            RandomBrightnessContrast(p=0.3),
            HorizontalFlip(),
            OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.3),
        ]

    for image_path, mask_path in tqdm(zip(images, masks), total=len(images)):
        image_name = os.path.basename(image_path)
        mask_name = os.path.basename(mask_path)
        image, mask = _read_pair(image_path, mask_path)

        # Variant 0 is always the untouched original.
        variants = [(image, mask)]
        for aug in augmenters:
            augmented = aug(image=image, mask=mask)
            variants.append((augmented["image"], augmented["mask"]))

        # Prefix names only when several variants come from the same source file;
        # otherwise they would overwrite each other.
        add_prefix = len(variants) > 1
        for idx, (variant_image, variant_mask) in enumerate(variants):
            variant_image = cv2.resize(variant_image, (width, height))
            # Nearest-neighbour keeps mask values discrete; the default bilinear
            # interpolation would blend values along mask edges.
            variant_mask = cv2.resize(
                variant_mask, (width, height), interpolation=cv2.INTER_NEAREST
            )

            if add_prefix:
                out_image_name = f"{idx}_{image_name}"
                out_mask_name = f"{idx}_{mask_name}"
            else:
                out_image_name = image_name
                out_mask_name = mask_name

            _write_or_raise(os.path.join(save_path, "images", out_image_name), variant_image)
            _write_or_raise(os.path.join(save_path, "masks", out_mask_name), variant_mask)

if __name__ == "__main__":
    # Loading original images and masks.
    path = "/content/gdrive/MyDrive/Pneumothorax/"
    aug_path = os.path.join(path, "classify_aug_data")
    images, masks = load_data(path)
    print(f"Original Images: {len(images)} - Original Masks: {len(masks)}")

    # Creating the output folders.
    create_dir(os.path.join(aug_path, "images"))
    create_dir(os.path.join(aug_path, "masks"))

    # Applying data augmentation (slow: roughly half an hour for the full set).
    augment_data(images, masks, aug_path, augment=True)

    # Loading augmented images and masks. augment_data writes them to images/ and
    # masks/, not the downsampled_train_* folders that load_data() reads, so glob
    # those folders directly. Calling load_data(aug_path) here always found 0 files.
    images = sorted(glob(os.path.join(aug_path, "images", "*.png")))
    masks = sorted(glob(os.path.join(aug_path, "masks", "*.png")))
    print(f"Augmented Images: {len(images)} - Augmented Masks: {len(masks)}")

 

Original Images: 4758 - Original Masks: 4758


100%|██████████| 4758/4758 [34:06<00:00,  2.32it/s]

Augmented Images: 0 - Augmented Masks: 0





In [None]:
def load_data(path, image_dir="images", mask_dir="masks"):
    """Return sorted lists of PNG image and mask paths under ``path``.

    NOTE(review): this redefinition shadows the earlier ``load_data`` in this
    notebook. Its defaults target the ``images``/``masks`` layout that
    ``augment_data`` writes; pass ``image_dir``/``mask_dir`` to read other layouts.

    Args:
        path: Dataset root directory.
        image_dir: Subdirectory of ``path`` holding the images.
        mask_dir: Subdirectory of ``path`` holding the masks.

    Returns:
        ``(images, masks)``: two lexicographically sorted lists of ``*.png`` paths,
        intended to be paired by position.
    """
    images = sorted(glob(os.path.join(path, image_dir, "*.png")))
    masks = sorted(glob(os.path.join(path, mask_dir, "*.png")))
    return images, masks

if __name__ == "__main__":
    # Loading augmented images and masks.
    images, masks = load_data("/content/gdrive/MyDrive/Pneumothorax/classify_aug_data/")
    print(f"Augmented Images: {len(images)} - Augmented Masks: {len(masks)}")
    # Fail loudly if the augmented set is missing or not paired 1:1; a printed
    # count of 0 is otherwise easy to overlook.
    assert images, "No augmented images found - check the classify_aug_data folder."
    assert len(images) == len(masks), "Augmented images and masks are not paired 1:1."

Augmented Images: 19032 - Augmented Masks: 19032
