In [None]:
import numpy as np
from PIL import Image
import random
from scipy import ndarray, ndimage
import skimage.io as io
import skimage as sk
from skimage import transform
from skimage import util
import cv2
import shutil
import os
import json
from PIL import ImageDraw
from IPython.display import display
import glob
import time
import threading
from IPython.display import clear_output


def rotate(image_array: np.ndarray, angle):
    """Rotate ``image_array`` by ``angle`` degrees with ``scipy.ndimage.rotate``.

    ``reshape=True`` lets scipy enlarge the output so the whole rotated image
    fits; ``order=0`` selects nearest-neighbour sampling, so no new pixel
    values are interpolated.
    """
    rotated = ndimage.rotate(image_array, angle, order=0, reshape=True)
    return rotated


def brightness(img, brightness):
    """Return ``img`` gamma-corrected with exponent ``brightness``, cast to uint8.

    Thin wrapper over ``skimage.exposure.adjust_gamma``; per the skimage docs a
    gamma above 1 darkens the image and below 1 brightens it.
    """
    adjusted = sk.exposure.adjust_gamma(img, gamma=brightness)
    return adjusted.astype(np.uint8)



def rotate_box(corners, angle, cx, cy, h, w):
    """
    Rotate bounding boxes by the same rotation applied to their image.

    Parameters
    ----------

    corners : numpy.ndarray
        Numpy array of shape `N x 8` containing N bounding boxes each described by their
        corner co-ordinates `x1 y1 x2 y2 x3 y3 x4 y4`. An empty array (no boxes)
        is accepted and yields an empty `0 x 8` result.

    angle : float
        angle (degrees) by which the image is rotated

    cx : int
        x coordinate of the center of image (about which the box will be rotated)

    cy : int
        y coordinate of the center of image (about which the box will be rotated)

    h : int
        height of the original image

    w : int
        width of the original image

    Returns
    -------

    numpy.ndarray
        Float array of shape `N x 8` containing the N rotated bounding boxes, each
        described by their corner co-ordinates `x1 y1 x2 y2 x3 y3 x4 y4`, expressed
        in the coordinate frame of the enlarged (un-cropped) rotated canvas.
        No bounds filtering is performed; an array is always returned.
    """
    points = corners.reshape(-1, 2)
    # Append a column of ones -> homogeneous (x, y, 1) rows, so the 2x3 affine
    # matrix can be applied with one matrix product. Taking the dtype from the
    # array itself (not from its first element) keeps this working for N == 0.
    points = np.hstack((points, np.ones((points.shape[0], 1), dtype=points.dtype)))

    M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    # Size of the canvas that holds the whole rotated image.
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    # Shift the rotation so the original center maps to the new canvas center.
    # Intended to match the enlarged output of rotate(..., reshape=True);
    # ndimage may round its output size slightly differently — TODO confirm.
    M[0, 2] += (nW / 2) - cx
    M[1, 2] += (nH / 2) - cy

    return np.dot(M, points.T).T.reshape(-1, 8)

            
        
def draw_labels(img_path, polygons):  # For debugging
    """Debug helper: draw every labelled polygon in red on the image at ``img_path``.

    ``polygons`` maps a label to a flat coordinate sequence
    ``x1 y1 x2 y2 ...``; each coordinate is truncated to ``int`` before
    drawing. Returns the annotated RGB ``PIL.Image``.
    """
    canvas = Image.open(img_path).convert("RGB")
    pen = ImageDraw.Draw(canvas)
    for coords in polygons.values():
        pen.polygon([int(value) for value in coords], outline="red")
    return canvas

def merge(dict1, dict2):
    """Return a new dict with the entries of ``dict1`` overridden by ``dict2``.

    Neither input is modified. On a shared key the value from ``dict2`` wins
    while the key keeps its position from ``dict1``.
    """
    combined = dict(dict1)
    combined.update(dict2)
    return combined



def save_mrcnn_labels(augmented_images, json_file_path, classes):
    """Write augmented-image boxes as polygon regions, merged into a JSON file.

    Parameters
    ----------
    augmented_images : dict
        Maps an augmented image file name to ``{class label: coords}``, where
        ``coords`` holds the 8 corner values ``x1 y1 x2 y2 x3 y3 x4 y4``.
    json_file_path : str
        Output path. If the file already exists its top-level entries are kept;
        entries with the same key as a newly generated one are replaced.
    classes : list of str
        Class names. A region's ``class`` attribute is ``classes.index(label) + 1``
        (so ids start at 1). An unknown label raises ``ValueError``.
    """
    existing_data = dict()
    if os.path.exists(json_file_path):
        with open(json_file_path) as json_file:
            existing_data = json.load(json_file)

    mrcnn_data = dict()
    for augmented_image, boxes in augmented_images.items():
        # Strip only the last extension ("a.b.jpg" -> "a.b").
        image_name = augmented_image.rpartition(".")[0]
        key = image_name + ".json"
        entry = dict()
        # NOTE(review): 'filename' is stored without the image extension;
        # confirm the downstream dataset loader expects that.
        entry['filename'] = image_name
        entry['file_attributes'] = dict()
        entry['regions'] = []
        for label, coords in boxes.items():
            region = dict()
            region["region_attributes"] = dict({"class": classes.index(label) + 1})
            region["shape_attributes"] = dict()
            region["shape_attributes"]["name"] = "polygon"
            # Plain floats: json cannot serialize numpy integer scalars.
            region["shape_attributes"]["all_points_y"] = [
                float(v) for v in np.take(coords, [1, 3, 5, 7])
            ]
            region["shape_attributes"]["all_points_x"] = [
                float(v) for v in np.take(coords, [0, 2, 4, 6])
            ]
            entry['regions'].append(region)
        mrcnn_data[key] = entry

    new_labels = {**existing_data, **mrcnn_data}

    with open(json_file_path, 'w') as f:
        json.dump(new_labels, f)



def parse_via_json(json_file_path, images_dir_path):
    """Read rectangle annotations from a VIA project JSON file.

    Parameters
    ----------
    json_file_path : str
        Path to the VIA project file; annotations are read from its
        ``_via_img_metadata`` section.
    images_dir_path : str
        Prefix prepended verbatim to each image ``filename`` (include the
        trailing separator).

    Returns
    -------
    dict
        ``{images_dir_path + filename: {class_name: [x1, y1, x2, y2, x3, y3, x4, y4]}}``
        with corners ordered top-left, top-right, bottom-right, bottom-left.
        Spaces in class names are replaced by underscores.

    Notes
    -----
    Only regions whose ``shape_attributes`` have ``x``/``y``/``width``/``height``
    are used. Regions lacking those fields are skipped, as are regions whose
    ``region_attributes['class']`` is missing, empty, or not a dict. The class
    name is the first key of that dict (presumably a VIA checkbox attribute —
    confirm with the project setup). Because annotations are keyed by class name,
    a later region of the same class in one image overwrites an earlier one.
    """
    with open(json_file_path) as json_file:
        labels = json.load(json_file)['_via_img_metadata']

    images = dict()
    for entry in labels.values():
        regions = entry["regions"]
        image_name = entry['filename']

        annotations = dict()
        for region in regions:
            try:
                shape = region['shape_attributes']
                x1 = shape['x']
                x2 = shape['x'] + shape['width']
                y1 = shape['y']
                y3 = shape['y'] + shape['height']
                region_name = list(region['region_attributes']['class'].keys())[0]
            except (KeyError, IndexError, TypeError, AttributeError):
                # Non-rectangle shape or missing/malformed class attribute.
                continue

            region_name = region_name.replace(" ", "_")
            annotations[region_name] = [x1, y1, x2, y1, x2, y3, x1, y3]

        images[images_dir_path + image_name] = annotations

    return images

class Augmentation:
    """Produce rotated, gamma-adjusted copies of annotated images on a thread pool.

    Each ``run`` call schedules one task per whole-degree angle in
    ``range(0, 360)``. A finished task that produced boxes stores
    ``{augmented file name: {label: rotated corner coords}}`` in
    ``augmented_images``. ``pending`` counts tasks that were scheduled and have
    not yet finished (successfully or with an error); callers poll it to wait.
    """

    def __init__(self, DIR_PATH):
        from multiprocessing.pool import ThreadPool
        self.pool = ThreadPool(12)
        self.augmented_images = dict()
        self.is_processing = False
        self.pending = 0
        # ``pending`` is incremented in the caller's thread (run) and
        # decremented by the pool's result-handler thread (callback / err);
        # the lock keeps the read-modify-write updates from being lost.
        self._pending_lock = threading.Lock()
        self.DIR_PATH = DIR_PATH
        # NOTE: deletes any previous contents of the output directory.
        shutil.rmtree(self.DIR_PATH, True)
        os.makedirs(self.DIR_PATH, mode=0o777, exist_ok=True)

    def __del__(self):
        # Guard against a partially-initialised instance (e.g. __init__ failed).
        pool = getattr(self, "pool", None)
        if pool is not None:
            pool.close()
            pool.join()

    def _task_finished(self):
        """Mark one scheduled task as done."""
        with self._pending_lock:
            self.pending -= 1

    def callback(self, response):
        """Success callback for ``map_async``: ``response`` is a one-item result list."""
        try:
            aug_image_name, response = response[0]
            if aug_image_name:
                self.augmented_images[aug_image_name] = response
        finally:
            self._task_finished()

    def err(self, error):
        """Error callback for ``map_async``: report, and still count the task as done."""
        print("Error", error)
        # Without this decrement a single failed task leaves ``pending`` above
        # zero forever and any loop waiting on it never exits.
        self._task_finished()

    def run(self, img_path, labels):
        """Schedule 360 rotation tasks for the image at ``img_path``.

        ``labels`` maps class name -> 8 corner coordinates. Images without any
        labels are skipped, since there are no boxes to transform.
        """
        if not labels:
            return

        img = sk.io.imread(img_path)
        if img.ndim == 2:
            # Grayscale: replicate into three channels so it matches RGB input.
            img = np.stack((img,) * 3, axis=-1)
        img = img[:, :, :3]  # drop any alpha channel
        w, h = img.shape[1], img.shape[0]

        coordinates = np.array(list(labels.values()))
        base_name = os.path.basename(img_path)

        for angle in range(0, 360, 1):
            with self._pending_lock:
                self.pending += 1
            # Gamma value, passed to the task and embedded in the file name.
            sharpness = random.randint(100, 200)/100.0
            aug_image_name = "{3}-{2}-{0}-{1}.jpg".format(angle, sharpness, random.randint(1111111, 9999999), base_name)
            save_path = self.DIR_PATH + aug_image_name
            self.r = self.pool.map_async(
                self.augment_image,
                [(img, coordinates, w, h, angle, sharpness, save_path, labels)],
                callback=self.callback,
                error_callback=self.err,
            )

    def augment_image(self, args):
        """Rotate one image and its boxes; save the image and return its boxes.

        ``args`` is the tuple built in ``run``. Returns
        ``(augmented file name, {label: rotated coords})``, or ``(False, False)``
        when there are no rotated boxes (nothing is written in that case).
        """
        image, coordinates, width, height, angle, sharpness, save_path, labels = args
        cx, cy = width // 2, height // 2

        rotated_boxes = rotate_box(coordinates, angle, cx, cy, height, width)
        if rotated_boxes is None or len(rotated_boxes) == 0:
            return (False, False)

        augmented = rotate(image, angle)
        # Use the gamma chosen in ``run`` so the file name reflects what was applied.
        augmented = brightness(augmented, sharpness)
        sk.io.imsave(save_path, augmented)

        response = dict(zip(labels.keys(), rotated_boxes))
        return (os.path.basename(save_path), response)

# Main Cell

In [None]:
DIR_PATH = "/media/sohaib/additional_/maskrcnn/images/train/"
doc_labels = "/media/sohaib/additional_/maskrcnn/images/train/via_project_17Apr2021_12h3m.json"


import multiprocessing
from time import sleep

images = parse_via_json(doc_labels, DIR_PATH)
# Sorted so every run assigns the same class ids (index + 1 in
# save_mrcnn_labels) and writes labels.txt in the same order; iteration order
# of a set of strings can change between interpreter runs (hash randomization).
classes = sorted({name for annotations in images.values() for name in annotations})

# Need at least two images so both the training and validation splits are non-empty.
if len(images) <= 1:
    print("I need at least 2 images")
    exit()

# Splitting Training and testing data: first 90% (in parse order) train, rest val.
total_training = int(len(images) * 0.9)
train = Augmentation(DIR_PATH + "/train/")
val = Augmentation(DIR_PATH + "/val/")

# Creating Images (tasks run on each Augmentation's thread pool)
print("Creating Images....")
for counter, image_path in enumerate(images):
    target = train if counter < total_training else val
    target.run(image_path, images[image_path])

# Waiting for all of the images to complete
while train.pending > 0 or val.pending > 0:
    print(f"Pending {train.pending + val.pending}")
    sleep(1)
    clear_output(wait=True)


# Creating Json of augmented images
print("Images created. Now saving labels....")
save_mrcnn_labels(train.augmented_images, DIR_PATH + "/train/via_region_data.json", classes)
save_mrcnn_labels(val.augmented_images, DIR_PATH + "/val/via_region_data.json", classes)
print("================================================================================================================")
print("Following are the class names used in this dataset. You need to pass these to MRCNN prediction script/notebook")
print("================================================================================================================")
print("\n")
class_list = '"' + '","'.join(classes) + '"'
print(class_list)
with open(DIR_PATH + "labels.txt", "w") as labels_file:
    labels_file.write(class_list)


# Preview annotated images

In [None]:
from IPython.display import display

# Re-read the original VIA annotations and show each source image with its boxes.
images = parse_via_json(doc_labels, DIR_PATH)

# To inspect an augmented image instead, something like:
# img = draw_labels('/root/data/shahid/thai2/train/download (9).png-1624966-30-1.76.jpg', train.augmented_images['download (9).png-1624966-30-1.76.jpg'])
# display(img)
for image_path, annotations in images.items():
    img = draw_labels(image_path, annotations)
    display(img)