In [2]:
import cv2
import random

from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from albumentations import (
    HorizontalFlip,
    VerticalFlip,
    CenterCrop,
    Crop,
    Transpose,
    MedianBlur,
    RandomRotate90,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    RandomBrightnessContrast,
    RandomGamma,
    HueSaturationValue,
    RGBShift,
    MotionBlur,
    GaussianBlur,
    GaussNoise,
    ChannelShuffle,
    CoarseDropout
)

In [4]:
def creat_dir(file_path: Path):
    if not file_path.exists():
        file_path.mkdir(parents=True, exist_ok=True)

In [16]:
def deal_img(images: list[Path], masks: list[Path], save_path, augment=True):
    """ Performing data augmentation. """

    def read_data(img, mas):
        """ Read the image and mask from the given path. """
        image_ = cv2.imread(str(img), cv2.IMREAD_COLOR)
        mask_ = cv2.imread(str(mas), cv2.IMREAD_COLOR)
        return image_, mask_

    crop_size = (192-32, 256-32)
    size = (256, 192)
    temp_images = []
    temp_masks = []
    for image, mask in tqdm(zip(images, masks), total=len(images)):
        x, y = read_data(image, mask)

        if augment:
            ## Center Crop
            aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[1])
            augmented = aug(image=x, mask=y)
            x1 = augmented['image']
            y1 = augmented['mask']

            ## Crop
            x_min = 0
            y_min = 0
            x_max = x_min + size[0]
            y_max = y_min + size[1]

            aug = Crop(p=1, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)
            augmented = aug(image=x, mask=y)
            x2 = augmented['image']
            y2 = augmented['mask']

            ## Random Rotate 90 degree
            aug = RandomRotate90(p=1)
            augmented = aug(image=x, mask=y)
            x3 = augmented['image']
            y3 = augmented['mask']

            ## Transpose
            aug = Transpose(p=1)
            augmented = aug(image=x, mask=y)
            x4 = augmented['image']
            y4 = augmented['mask']

            ## ElasticTransform
            aug = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
            augmented = aug(image=x, mask=y)
            x5 = augmented['image']
            y5 = augmented['mask']

            ## Grid Distortion
            aug = GridDistortion(p=1)
            augmented = aug(image=x, mask=y)
            x6 = augmented['image']
            y6 = augmented['mask']

            ## Optical Distortion
            aug = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
            augmented = aug(image=x, mask=y)
            x7 = augmented['image']
            y7 = augmented['mask']

            ## Vertical Flip
            aug = VerticalFlip(p=1)
            augmented = aug(image=x, mask=y)
            x8 = augmented['image']
            y8 = augmented['mask']

            ## Horizontal Flip
            aug = HorizontalFlip(p=1)
            augmented = aug(image=x, mask=y)
            x9 = augmented['image']
            y9 = augmented['mask']

            ## Grayscale
            x10 = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
            y10 = y

            ## Grayscale Vertical Flip
            aug = VerticalFlip(p=1)
            augmented = aug(image=x10, mask=y10)
            x11 = augmented['image']
            y11 = augmented['mask']

            ## Grayscale Horizontal Flip
            aug = HorizontalFlip(p=1)
            augmented = aug(image=x10, mask=y10)
            x12 = augmented['image']
            y12 = augmented['mask']

            ## Grayscale Center Crop
            aug = CenterCrop(p=1, height=crop_size[0], width=crop_size[1])
            augmented = aug(image=x10, mask=y10)
            x13 = augmented['image']
            y13 = augmented['mask']

            ##
            aug = RandomBrightnessContrast(p=1)
            augmented = aug(image=x, mask=y)
            x14 = augmented['image']
            y14 = augmented['mask']

            aug = RandomGamma(p=1)
            augmented = aug(image=x, mask=y)
            x15 = augmented['image']
            y15 = augmented['mask']

            aug = HueSaturationValue(p=1)
            augmented = aug(image=x, mask=y)
            x16 = augmented['image']
            y16 = augmented['mask']

            aug = RGBShift(p=1)
            augmented = aug(image=x, mask=y)
            x17 = augmented['image']
            y17 = augmented['mask']

            aug = MotionBlur(p=1, blur_limit=7)
            augmented = aug(image=x, mask=y)
            x20 = augmented['image']
            y20 = augmented['mask']

            aug = MedianBlur(p=1, blur_limit=9)
            augmented = aug(image=x, mask=y)
            x21 = augmented['image']
            y21 = augmented['mask']

            aug = GaussianBlur(p=1)
            augmented = aug(image=x, mask=y)
            x22 = augmented['image']
            y22 = augmented['mask']

            aug = GaussNoise(p=1)
            augmented = aug(image=x, mask=y)
            x23 = augmented['image']
            y23 = augmented['mask']

            aug = ChannelShuffle(p=1)
            augmented = aug(image=x, mask=y)
            x24 = augmented['image']
            y24 = augmented['mask']

            aug = CoarseDropout(p=1, max_holes=8, max_height=32, max_width=32)
            augmented = aug(image=x, mask=y)
            x25 = augmented['image']
            y25 = augmented['mask']

            temp_images.extend([
                x, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10,
                x11, x12, x13, x14, x15, x16, x17, x21, x20,
                x22, x23, x24, x25
            ]) # ,
            temp_masks.extend([
                y, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10,
                y11, y12, y13, y14, y15, y16, y17, y21, y20,
                y22, y23, y24, y25
            ])

        else:
            temp_images = [x]
            temp_masks  = [y]

    for num, (i, m) in tqdm(enumerate(zip(temp_images, temp_masks), start=1),
                            total=len(temp_images)):
        i = cv2.resize(i, size)
        m = cv2.resize(m, size)

        image_path = Path(save_path) / Path("x") / Path(f"{num}.png")
        mask_path  = Path(save_path) / Path("y") / Path(f"{num}.png")

        cv2.imwrite(str(image_path), i)
        cv2.imwrite(str(mask_path), m)


In [6]:
def file_sort(file_list: list):
        return sorted(file_list, key=lambda x: x.stem)

def get_data():
    creat_dir(Path("./data/train/x"))
    creat_dir(Path("./data/train/y"))
    creat_dir(Path("./data/valid/x"))
    creat_dir(Path("./data/valid/y"))

    x = Path("raw_data/train/Original").glob("*.png")
    y = Path("raw_data/train/Ground Truth").glob("*.png")
    x = file_sort([i for i in x])
    y = file_sort([i for i in y])
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.1)
    return  train_x, test_x, train_y, test_y

In [18]:
train_img, test_img, train_mask, test_mask = get_data()

In [19]:
deal_img(train_img, train_mask, save_path="data/train")
deal_img(test_img, test_mask, save_path="data/valid")

100%|██████████| 550/550 [00:55<00:00,  9.95it/s]
100%|██████████| 13200/13200 [01:53<00:00, 116.02it/s]
100%|██████████| 62/62 [00:07<00:00,  8.36it/s]
100%|██████████| 1488/1488 [00:08<00:00, 181.07it/s]


In [7]:
def deal_test_data():
    creat_dir(Path("./data/test/x"))
    creat_dir(Path("./data/test/y"))

    x_list = Path("raw_data/test/Original").glob("*.tif")
    y_list = Path("raw_data/test/Ground Truth").glob("*.tif")
    x_list = file_sort([i for i in x_list])
    y_list = file_sort([i for i in y_list])

    def convert_img(img_name: str, file_path: str):
        img_name = Path(img_name)
        img = cv2.resize(cv2.imread(str(img_name), cv2.IMREAD_COLOR), (256, 192))
        cv2.imwrite(f"{file_path}/{str(img_name.stem)}.png", img)
    temp_x = map(convert_img, x_list, ["data/test/x"] * len(x_list))
    _ = [i for i in temp_x]
    temp_y = map(convert_img, y_list, ["data/test/y"] * len(y_list))
    _ = [i for i in temp_y]

deal_test_data()