In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import copy
import os
import random

In [2]:
class B_it_bots_Segmentation_Dataset():
    
    def __init__(self, label_def, IMAGE_DIMENSION = [128, 128],
                     IMAGE_PATH='./objects/image', 
                     LABEL_PATH='./objects/label', IMG_TYPE='.jpg',
                    BACKGROUNDS_PATH ='./backgrounds/training/',
                    NUM_OF_BACKGROUNDS=20, NUM_OF_SCALES=20,
                    EXCLUDE_CLASSES=['motor', 'bearing_box_ax01', 
                                     'em_01', 'em_02', 'r20'],
                    SET_ZOOM_RANGE=[1.2,1.5]):
        
        self.label_def = label_def
        self.image_dimension = IMAGE_DIMENSION
        self.image_path = IMAGE_PATH
        self.label_path = LABEL_PATH
        self.img_type = IMG_TYPE
        self.num_of_backgrounds = NUM_OF_BACKGROUNDS
        self.num_of_scales = NUM_OF_SCALES
        self.exclude_classes = EXCLUDE_CLASSES
        self.zoom_range = SET_ZOOM_RANGE
        
        self.resize_images = lambda img_list: [cv2.resize(img, tuple(self.image_dimension))
                                    for img in img_list]
        
        self.object_paths = self.fetch_image_gt_paths()
        self.background_images = self.get_background_images(
                                        BACKGROUNDS_PATH, self.num_of_backgrounds)
        self.objects_to_details = self.get_synthetic_objects_list()
    
    def fetch_image_gt_paths(self):
        
        object_paths = dict()
        for clsses in os.listdir(self.label_path):
            if not any(clsses == exclude
                   for exclude in self.exclude_classes):
                cls_path = os.path.join(self.label_path, clsses)
                obj_files = list()
                for files in sorted(os.listdir(cls_path)):
                    obj_files.append([os.path.join(self.image_path, clsses, 
                            files.split('.')[0]+self.img_type), 
                             os.path.join(self.label_path, clsses, files)])
                object_paths[clsses] = obj_files.copy()
            
        return object_paths
    
    
    def get_background_images(self, backgrounds_path,
                             number_of_backgrounds):
        
        np.random.seed(1)
        background_files = os.listdir(backgrounds_path)
        background_files = [os.path.join(backgrounds_path, file) 
                            for file in background_files]
        np.random.shuffle(background_files)
        background_files = background_files[0:number_of_backgrounds]
        background_images = list()
        for file in background_files:
            background_images.append(cv2.imread(file))
        
        background_images = self.resize_images(background_images)
        return background_images
    
    
    def find_obj_loc_and_vals(self, image, label, label_value, obj_name):
        
        obj_loc = np.argwhere(label==label_value)
        obj_vals = [image[tuple(loc)] for loc in obj_loc]
        obj_vals = np.array(obj_vals)
        label_vals = np.ones(len(obj_loc)) * label_value
        rect_points = [min(obj_loc[:,0]), min(obj_loc[:,1]), max(obj_loc[:,0]), max(obj_loc[:,1])]
        obj_area = (rect_points[2] - rect_points[0]) * (rect_points[3] - rect_points[1])
        
        return {'obj_loc': obj_loc, 'obj_vals': obj_vals, 'label_vals': label_vals, 
                'obj_name': obj_name, 'rect_points': rect_points, 'obj_area': obj_area}
    
    def get_different_scales(self, image, image_label, label_value, obj_name):
        
        scales = np.linspace(self.zoom_range[0], self.zoom_range[1], 
                             num= self.num_of_scales*2)
        np.random.shuffle(scales)
        scales = scales[0: self.num_of_scales]
            
        scaled_objects = list()

        for i in range(0, self.num_of_scales):
            resized_img = cv2.resize(image, (0,0), fx=scales[i], fy=scales[i])
            resized_label = cv2.resize(image_label, (0,0), fx=scales[i], fy=scales[i])
            
            if not np.any(image_label==label_value):
                raise ValueError('Object {} with label {} found in image'.format(
                                obj_name, label_value))
            
            if not np.any(resized_label==label_value):
                print ('Scaled object {} lost...Taking original scale...'.format(
                        obj_name))
                resized_img, resized_label = image, image_label
                
            scaled_objects.append(self.find_obj_loc_and_vals(
                    resized_img, resized_label, 
                    label_value, obj_name))

        return scaled_objects
    
    def get_synthetic_objects_list(self):
        
        objects = list()
        objects_to_details = dict()
    
        for key in self.object_paths.keys():
            path_list = np.array(self.object_paths[key]).T
            images_in_cls = [cv2.imread(path, 1) for path in path_list[0]]
            labels_in_cls = [cv2.imread(path, 0) for path in path_list[1]]
            images_in_cls = self.resize_images(images_in_cls)
            labels_in_cls = self.resize_images(labels_in_cls)
            for img, label in zip(images_in_cls, labels_in_cls):
                objects += self.get_different_scales(
                        img, label, label_def[key], key)
            objects_to_details[key] = objects.copy()
            objects.clear()

        return objects_to_details
    
    def get_visual_image(self, label):
        colormap = np.asarray([[128, 64, 128], [244, 35, 232], [70, 70, 70], 
                               [102, 102, 156], [190, 153, 153], [153, 153, 153], 
                               [250, 170, 30], [220, 220, 0], [107, 142, 35], 
                               [152, 251, 152], [70, 130, 180], [220, 20, 60], 
                               [255, 0, 0], [0, 0, 142], [0, 0, 70], 
                                [0, 60, 100], [0, 80, 100],[0, 0, 230],
                                  [119, 11, 32], [0, 0, 0]])
        
        return colormap[np.array(label, dtype=np.uint8)]
    
    def do_random_blur(self, image, blur_prob=0.1, kernel_size=5):
        
        do_blur = np.random.rand(1) < blur_prob
        if do_blur:
            image = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
        
        return image
            
    def do_random_flip(self, image, label, flip_prob=0.1, axis=0):
        
        do_flip = np.random.rand(1) < flip_prob
        if do_flip:
            image = cv2.flip(image, axis)
            label = cv2.flip(label, axis)
            
        return image, label
    
    def do_random_rotation(self, image, label, key, rotation_prob=0.1, angle='RANDOM'):
        
        do_rotate = np.random.rand(1) < rotation_prob
        if do_rotate:
            if angle == 'RANDOM':
                angle = np.random.randint(10, 90)
            cols, rows = tuple(self.image_dimension)
            M = cv2.getRotationMatrix2D((cols/2,rows/2), angle, 1)
            image = cv2.warpAffine(image, M, (cols, rows))
            label = cv2.warpAffine(label, M, (cols, rows),  
                                   borderMode=cv2.BORDER_TRANSPARENT)
            label[label==0] = self.label_def['background']
            label[label!=19] = self.label_def[key]
            
        return image, label
        
            
    def perform_augmentation(self, SAVE_DIR='./b_it/',_IMAGE_FORMAT='%04d', 
                            _MASK_FORMAT='%04d', _RAW_FORMAT='%04d',
                            DO_FLIP=True, FLIP_PROB=0.1, FLIP_AXIS=0,
                            DO_BLUR=True, BLUR_PROB=0.1, KERNEL_SIZE=5,
                            DO_ROTATE=True, ROTATE_PROB=0.1, ANGLE='RANDOM'):
        
        background_label = np.ones(tuple(self.image_dimension)) * self.label_def['background']
        
        for key in self.object_paths.keys():
            clss_objects = self.objects_to_details[key]
            for img_number in range(len(clss_objects)):
                augmented_image = self.background_images[img_number%self.num_of_backgrounds].copy()
                augmented_label = background_label.copy()
                obj_details = copy.deepcopy(clss_objects[img_number])
                location_offsets = [0.25* dim for dim in self.image_dimension]
                location = [random.randrange(0, self.image_dimension[0] - location_offsets[0], 
                                             location_offsets[0]), 
                            random.randrange(0, self.image_dimension[1] - location_offsets[1],
                                             location_offsets[1])]

                row_shift = min(obj_details['obj_loc'][:,0]) - location[0]
                col_shift = min(obj_details['obj_loc'][:,1]) - location[1]
                obj_details['obj_loc'][:,0] -= row_shift
                obj_details['obj_loc'][:,1] -= col_shift

                for index,loc in enumerate(obj_details['obj_loc']):
                    if 0 < loc[0] < self.image_dimension[0] and 0 < loc[1] < self.image_dimension[1]:
                        augmented_image[tuple(loc)] = obj_details['obj_vals'][index]
                        augmented_label[tuple(loc)] = obj_details['label_vals'][index]
                        
                if DO_BLUR:
                    augmented_image = self.do_random_blur(augmented_image,
                                                         blur_prob=BLUR_PROB, 
                                                         kernel_size=KERNEL_SIZE)
                if DO_FLIP:
                    augmented_image, augmented_label = self.do_random_flip(augmented_image,
                                                                          augmented_label,
                                                                          flip_prob=FLIP_PROB,
                                                                          axis=FLIP_AXIS)
                if DO_ROTATE:
                    augmented_image, augmented_label = self.do_random_rotation(augmented_image,
                                                                              augmented_label,
                                                                               key,
                                                                              rotation_prob=ROTATE_PROB,
                                                                              angle=ANGLE)
                
                img_directory = os.path.join(SAVE_DIR, 'image', 
                                             obj_details['obj_name'])
                if not os.path.isdir(img_directory): 
                    os.makedirs(img_directory)
                raw_directory = os.path.join(SAVE_DIR, 'raw',
                                            obj_details['obj_name'])
                if not os.path.isdir(raw_directory): 
                    os.makedirs(raw_directory)
                visual_directory = os.path.join(SAVE_DIR, 'mask',
                                               obj_details['obj_name'])
                if not os.path.isdir(visual_directory): 
                    os.makedirs(visual_directory)

                cv2.imwrite(os.path.join(img_directory,
                                        _IMAGE_FORMAT % (img_number+1) + '.jpg'), 
                            augmented_image)
                cv2.imwrite(os.path.join(raw_directory,
                                        _RAW_FORMAT % (img_number+1) + '.png'), 
                            augmented_label)
                cv2.imwrite(os.path.join(visual_directory, 
                                        _MASK_FORMAT % (img_number+1) + '.jpg'), 
                            self.get_visual_image(augmented_label))

In [3]:
label_def = {'f20_20_B': 1, 's40_40_B': 2, 'f20_20_G': 3, 's40_40_G': 4,  'm20_100': 5, 
             'm20': 6, 'm30': 7, 'r20': 8, 'bearing_box_ax01': 9, 'bearing': 10, 'axis': 11, 
             'distance_tube': 12, 'motor': 13, 'container_box_blue': 14, 'container_box_red': 15, 
             'bearing_box_ax16': 16, 'em_01': 17, 'em_02': 18, 'background': 19}

b_it_bots_dataset = B_it_bots_Segmentation_Dataset(label_def)

Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Taking original scale...
Scaled object distance_tube lost...Takin

In [4]:
b_it_bots_dataset.perform_augmentation(DO_ROTATE=False)