# Augmentation des données

In [None]:
import warnings
warnings.simplefilter(action='ignore')
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imageio
import os

In [None]:
# Seed imgaug's global RNG so the augmentation run is reproducible.
ia.seed(1)


def _build_augmenter():
    """Build the default imgaug pipeline used by ``augment``.

    Pipeline (applied in this fixed order):
      1. horizontal flip with probability 0.5;
      2. between 1 and 3 (in random order) of: a mild perspective transform,
         a zoom-in (scale 1.0-1.4), sharpening, blur-or-noise, and one
         brightness/contrast/saturation change;
      3. coarse dropout on 20% of inputs.
    """
    def sometimes(aug):
        # Wrap ``aug`` so it is applied to only ~20% of inputs.
        return iaa.Sometimes(0.2, aug)

    return iaa.Sequential(
        [
            iaa.Fliplr(0.5),  # horizontal flip
            iaa.SomeOf(
                (1, 3),
                [
                    iaa.PerspectiveTransform(scale=(0.003, 0.008)),  # add perspective
                    iaa.Affine(scale=(1.0, 1.4)),  # zoom in only, never out
                    iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),  # alpha / lightness
                    iaa.OneOf([  # blur or noise
                        iaa.GaussianBlur((0.0, 1.8)),
                        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.02 * 255)),
                    ]),
                    iaa.OneOf([  # brightness / contrast / saturation
                        iaa.AddToBrightness((-25, 25)),
                        iaa.AllChannelsCLAHE(clip_limit=1),
                        iaa.GammaContrast((0.75, 1.2)),
                        iaa.MultiplySaturation((0.65, 1.3)),
                    ]),
                ],
                random_order=True,
            ),
            sometimes(iaa.CoarseDropout((0.03, 0.06), size_percent=(0.07, 0.1))),
        ],
        random_order=False,
    )


# Built once, right after seeding, instead of being reconstructed on every
# augment() call (the call site below invokes augment() once per image).
_DEFAULT_AUGMENTER = _build_augmenter()


def augment(image, image2, n_augment=4, *, augmenter=None):
    """Generate ``n_augment`` randomly augmented copies of an image and its label map.

    Args:
        image: Image array to augment. The default pipeline's colour
            augmenters (e.g. AddToBrightness, MultiplySaturation) presumably
            expect uint8 RGB input -- confirm for non-Cityscapes data.
        image2: Label/segmentation map for ``image``. It is wrapped in
            ``SegmentationMapsOnImage`` with ``shape=image.shape``, so it must
            share the image's height/width and have an int/uint/bool dtype.
        n_augment: Number of augmented copies to produce.
        augmenter: Optional imgaug augmenter used instead of the default
            pipeline from ``_build_augmenter()``.

    Returns:
        ``(images_aug, segmaps_aug)``: two lists of length ``n_augment``.
        The label maps are returned as plain arrays (via ``get_arr()``).
    """
    seq = _DEFAULT_AUGMENTER if augmenter is None else augmenter
    segmap = SegmentationMapsOnImage(image2, shape=image.shape)

    images_aug = []
    segmaps_aug = []
    for _ in range(n_augment):
        # Image and segmap go through a single call so imgaug samples one set
        # of random parameters for both, keeping geometric changes aligned.
        image_aug, segmap_aug = seq(image=image, segmentation_maps=segmap)
        images_aug.append(image_aug)
        segmaps_aug.append(segmap_aug.get_arr())

    return images_aug, segmaps_aug

In [None]:
original_imgs_path = "./data/leftImg8bit"


def _subdirs(path):
    """Return the sorted names of the immediate subdirectories of ``path``.

    Raises FileNotFoundError when ``path`` does not exist. (Indexing the first
    ``os.walk`` result instead fails with an opaque IndexError, because
    ``os.walk`` silently yields nothing for a missing directory.)
    """
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


cities = {"train": _subdirs("/".join([original_imgs_path, "train"]))}

# Parallel lists: source image, its label map, and the output image path.
paths = []
tgt_paths = []
out_paths = []

for split in ["train"]:
    for city in cities[split]:
        tmp_path = "/".join([original_imgs_path, split, city])
        for root, dirs, files in os.walk(tmp_path):
            # os.walk order is filesystem-dependent; sort so the seeded
            # augmentation assigns the same transforms to the same files.
            dirs.sort()
            for name in sorted(files):
                if name.endswith("_leftImg8bit.png"):
                    # Join with the directory actually being walked so files
                    # in nested subdirectories get correct paths.
                    paths.append("/".join([root, name]))
                    tgt_paths.append(f"./data/labels/{name}".replace("leftImg8bit", "gtFine_labelIds"))
                    out_paths.append(f"./data/augments/{name}")

In [None]:
# The output directories (e.g. ./data/augments) may not exist yet; create
# them once up front so imageio.imwrite does not fail.
for out_dir in {os.path.dirname(p) for p in out_paths}:
    os.makedirs(out_dir, exist_ok=True)

# For each source image / label pair, write its augmented copies as
# "<original stem>_augment_<n>.png": images under the output path, labels
# next to the original label file.
for img_path, tgt_path, out_path in zip(paths, tgt_paths, out_paths):
    img = imageio.imread(img_path)
    tgt = imageio.imread(tgt_path)
    img_t, tgt_t = augment(img, tgt)

    out_stem = os.path.splitext(out_path)[0]
    tgt_stem = os.path.splitext(tgt_path)[0]
    for n, (aug_img, aug_tgt) in enumerate(zip(img_t, tgt_t)):
        imageio.imwrite(f"{out_stem}_augment_{n}.png", aug_img)
        imageio.imwrite(f"{tgt_stem}_augment_{n}.png", aug_tgt)