In [1]:
import sys
import os
sys.path.append("../")

In [2]:
import os
import cv2
import numpy as np
import albumentations as alb
from scripts.preprocessing import *
from helpers.model_handler import *
import matplotlib.pyplot as plt
import random

In [3]:
IMAGE_DIR = "../dataset/training/images/"
GT_DIR = "../dataset/training/groundtruth/"
IMAGE_DIR_AUGMENTED = "../dataset/training/augmented_images/"
GT_DIR_AUGMENTED = "../dataset/training/augmented_groundtruth/"
# Ground-truth masks are looked up under GT_DIR with these same filenames.
FILES = sorted(os.listdir(IMAGE_DIR))

# Seed the global RNGs so the random draws made by the augmentation cells
# (e.g. the inner-rotation gate and rotate_inner's centre/radius/angle) are
# reproducible across Restart & Run All.
# NOTE(review): whether albumentations draws from these global generators
# depends on the installed albumentations version — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [4]:
def save_augmented(images, masks, size=416, inner_rotation_prob=0.1):
    """Write augmented copies of every (image, mask) pair to the augmented dirs.

    For each transform in a fixed list, every input pair is transformed,
    resized to ``size`` x ``size`` and, with probability ``inner_rotation_prob``,
    additionally passed through ``rotate_inner``. Each result is written as
    ``satImage_NNN.png`` into IMAGE_DIR_AUGMENTED (image) and GT_DIR_AUGMENTED
    (mask); an image and its mask always share the same filename.

    Numbering continues from the number of entries already present in
    IMAGE_DIR_AUGMENTED when the function is called, so re-running appends
    new files rather than starting again at 001.

    Args:
        images: sequence of images; each is converted with ``img_float_to_uint8``
            before augmenting (presumably float arrays in [0, 1] — confirm).
        masks: sequence of masks aligned with ``images``.
        size: side length passed to ``alb.Resize`` (default 416, the original
            hard-coded value).
        inner_rotation_prob: probability of applying ``rotate_inner`` to a
            sample (default 0.1, the original hard-coded value).

    Raises:
        AssertionError: if ``images`` and ``masks`` differ in length.
    """
    assert len(images) == len(masks)

    os.makedirs(IMAGE_DIR_AUGMENTED, exist_ok=True)
    os.makedirs(GT_DIR_AUGMENTED, exist_ok=True)

    transforms = [
        # Each yields len(images) new samples.
        rotate_shift(rotate=1, degree=90),
        rotate_shift(rotate=1, degree=-90),
        rotate_shift(rotate=1, degree=180),
        rotate_shift(h_flip=1),
        rotate_shift(v_flip=1),
        rotate_shift(rotate=1, h_flip=1, degree=90),
        rotate_shift(rotate=1, v_flip=1, degree=90),
        distortion(),
        blur(),
        alb.GaussNoise((50, 100), p=1),
    ]

    # Convert once up front instead of once per transform.
    images_u8 = [img_float_to_uint8(img) for img in images]
    masks_u8 = [img_float_to_uint8(msk) for msk in masks]

    # A single running counter: re-listing the directory per transform (the
    # previous approach) reuses indices after a failed write and silently
    # overwrites earlier outputs.
    index = len(os.listdir(IMAGE_DIR_AUGMENTED))

    for tsfm in transforms:
        # Loop-invariant: build the pipeline once per transform, not per image.
        pipeline = alb.Compose([tsfm, alb.Resize(size, size)])

        for image, mask in zip(images_u8, masks_u8):
            aug = pipeline(image=image, mask=mask)
            img_aug, msk_aug = aug['image'], aug['mask']

            if random.random() < inner_rotation_prob:
                img_aug, msk_aug = rotate_inner(img_aug, msk_aug)

            index += 1
            name = f"satImage_{str(index).zfill(3)}.png"

            # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR. If
            # load_image returns RGB, saved images will have R/B swapped
            # relative to the originals — confirm.
            if not cv2.imwrite(os.path.join(IMAGE_DIR_AUGMENTED, name), img_aug):
                print("Error saving image: ", name)
            if not cv2.imwrite(os.path.join(GT_DIR_AUGMENTED, name), msk_aug):
                print("Error saving mask: ", name)


def rotate_shift(rotate=0, h_flip=0, v_flip=0, degree=0):
    """Build a fixed-angle rotation followed by optional flips.

    Args:
        rotate: probability of applying the rotation.
        h_flip: probability of a horizontal flip.
        v_flip: probability of a vertical flip.
        degree: rotation angle; used as both ends of ``alb.Rotate``'s limit so
            the angle is fixed rather than sampled.

    Returns:
        An ``alb.Compose`` of Rotate -> HorizontalFlip -> VerticalFlip.
    """
    fixed_angle = (degree, degree)
    steps = [
        alb.Rotate(limit=fixed_angle, p=rotate),
        alb.HorizontalFlip(p=h_flip),
        alb.VerticalFlip(p=v_flip),
    ]
    return alb.Compose(steps)

def distortion():
    """Return a transform that always applies one of two geometric warps.

    ``alb.OneOf`` with ``p=1`` always fires and applies one of its children:
    ``OpticalDistortion(0.3, 0.3)`` or ``GridDistortion(5, 0.3)``.
    """
    candidates = [
        alb.OpticalDistortion(0.3, 0.3),
        alb.GridDistortion(5, 0.3),
    ]
    return alb.OneOf(candidates, p=1)

def blur():
    """Return a transform that always applies one blur.

    ``alb.OneOf`` with ``p=1`` always fires and applies one of its children:
    median blur or Gaussian blur, both with albumentations' default arguments.
    """
    options = [alb.MedianBlur(), alb.GaussianBlur()]
    return alb.OneOf(options, p=1)

def rotate_inner(image, groundtruth, min_radius=40):
    """Rotate a random disc inside ``image`` and ``groundtruth`` by a random angle.

    A centre (x, y) and radius r are drawn so that the disc lies entirely inside
    the image and ``r >= min_radius``. Both arrays are rotated about (x, y) by
    the same uniformly random angle in [0, 360) degrees, and only the pixels
    inside the disc are taken from the rotated version; everything outside the
    disc is unchanged.

    Args:
        image: H x W or H x W x C array.
        groundtruth: mask with the same H x W as ``image`` (2-D or with a
            trailing channel axis).
        min_radius: smallest allowed disc radius in pixels (default 40, the
            original hard-coded value).

    Returns:
        ``(image_augmented, groundtruth_augmented)`` with the input shapes.

    Raises:
        ValueError: if the image is smaller than ``2 * min_radius`` in either
            dimension, so no disc of that radius fits (previously this case
            looped forever).
    """
    height, width = image.shape[:2]
    if width < 2 * min_radius or height < 2 * min_radius:
        raise ValueError(
            f"image of size {width}x{height} is too small for a disc of "
            f"radius {min_radius}"
        )

    # The previous implementation rejection-sampled (x, y) over the whole image
    # and retried whenever min(x, y, width - x, height - y) < min_radius. The
    # accepted centres are uniformly distributed over exactly this rectangle,
    # so sample it directly.
    x = random.randint(min_radius, width - min_radius)
    y = random.randint(min_radius, height - min_radius)
    r = random.randint(min_radius, min(x, y, width - x, height - y))

    # 2-D filled disc; works for both grayscale and multi-channel inputs.
    disc = np.zeros((height, width), dtype=np.uint8)
    cv2.circle(disc, (x, y), r, 255, thickness=-1)
    inside = disc.astype(bool)

    angle = random.uniform(0, 360)
    M = cv2.getRotationMatrix2D((x, y), angle, 1)

    # warpAffine drops a trailing singleton channel, so restore input shapes.
    rotated_image = cv2.warpAffine(image, M, (width, height)).reshape(image.shape)
    # Nearest-neighbour for the mask so no new label values are interpolated
    # at edges (matches how albumentations resamples masks).
    rotated_gt = cv2.warpAffine(
        groundtruth, M, (groundtruth.shape[1], groundtruth.shape[0]),
        flags=cv2.INTER_NEAREST,
    ).reshape(groundtruth.shape)

    image_sel = inside[..., None] if image.ndim == 3 else inside
    gt_sel = inside[..., None] if groundtruth.ndim == 3 else inside

    I_augmented = np.where(image_sel, rotated_image, image)
    G_augmented = np.where(gt_sel, rotated_gt, groundtruth)

    return I_augmented, G_augmented

In [5]:
# Masks share filenames with their images, so both lists are aligned by FILES order.
image_paths = [os.path.join(IMAGE_DIR, fname) for fname in FILES]
mask_paths = [os.path.join(GT_DIR, fname) for fname in FILES]
images = list(map(load_image, image_paths))
masks = list(map(load_image, mask_paths))
print(len(images), len(masks))

100 100


In [6]:
# Writes augmented samples into IMAGE_DIR_AUGMENTED / GT_DIR_AUGMENTED;
# numbering continues after existing files, so re-running appends.
save_augmented(images=images, masks=masks)