In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import copy
import os
import random
import csv

In [2]:
class DataAugmentation():
    """Generate synthetic images by pasting segmented objects onto backgrounds.

    Object images are read from ``IMAGE_PATH`` and their label masks from
    ``LABEL_PATH`` (one sub-directory per object class). Every object is cut
    out at ``NUM_OF_SCALES`` different scales. ``create_augment_vector`` draws
    a random plan (background, objects, positions) for each output image and
    ``perform_augmentation`` renders that plan and writes the image, the
    ground-truth label map and, optionally, bounding boxes to ``OUTPUT_PATH``.
    """

    def __init__(self, label_def, IMAGE_DIMENSION= [480, 640], NUM_OF_SCALES= 2, 
                 GET_OBJ_DET_LABEL= True, NUM_OF_IMAGES= 50, IMAGE_PATH= './objects/image', 
                 LABEL_PATH= './objects/label', LABEL_IMG_TYPE= '.png', MAX_OBJECTS_PER_IMAGE= 6,
                 BACKGROUND_PATH= './backgrounds', OUTPUT_PATH= './data_augmentation_results'):
        """Load all backgrounds and object cut-outs from disk.

        Args:
            label_def: Mapping of class name -> integer label value. Must
                contain the key ``'background'``.
            IMAGE_DIMENSION: ``[height, width]`` of the generated images.
            NUM_OF_SCALES: Number of evenly spaced scales per object.
            GET_OBJ_DET_LABEL: When true, bounding boxes are produced as well.
            NUM_OF_IMAGES: Default number of images to plan.
            IMAGE_PATH: Root directory of the object images.
            LABEL_PATH: Root directory of the object label masks.
            LABEL_IMG_TYPE: File extension of the label masks.
            MAX_OBJECTS_PER_IMAGE: Largest number of objects pasted per image.
            BACKGROUND_PATH: Directory holding the background images.
            OUTPUT_PATH: Root directory the results are written to.
        """
        self.label_def = label_def
        # Copy so the shared default list can never be mutated through us.
        self.image_dimension = list(IMAGE_DIMENSION)
        self.num_of_scales = NUM_OF_SCALES
        self.get_obj_det_label = GET_OBJ_DET_LABEL
        self.num_of_images = NUM_OF_IMAGES
        self.image_path = IMAGE_PATH
        self.label_path = LABEL_PATH
        self.label_img_type = LABEL_IMG_TYPE
        self.max_objects_per_image = MAX_OBJECTS_PER_IMAGE
        self.background_path = BACKGROUND_PATH
        self.output_path = OUTPUT_PATH

        # State that other methods fill in later; defined here so that calling
        # the methods in any order fails with a clear message, not an
        # AttributeError.
        self.preview_data = False
        self.save_data_preview = False
        self.augment_vector = None

        self.files_count, object_paths = self.fetch_image_gt_paths()
        self.background_images = self.get_background_images()
        self.objects = self.get_synthetic_objects_list(object_paths)

    def fetch_image_gt_paths(self):
        """Collect ``[image_path, label_path]`` pairs for every object file.

        Returns:
            ``(files_count, object_paths)`` where ``object_paths`` maps each
            sub-directory name of ``self.image_path`` to its list of pairs and
            ``files_count`` is the total number of pairs.
        """
        object_paths = dict()
        files_count = 0
        for obj_dir in sorted(os.listdir(self.image_path)):
            obj_dir_path = os.path.join(self.image_path, obj_dir)
            # Stray files next to the class directories are not object classes.
            if not os.path.isdir(obj_dir_path):
                continue
            obj_files = list()
            for file_name in sorted(os.listdir(obj_dir_path)):
                # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b").
                stem = os.path.splitext(file_name)[0]
                obj_files.append([os.path.join(obj_dir_path, file_name),
                                  os.path.join(self.label_path, obj_dir,
                                               stem + self.label_img_type)])
            files_count += len(obj_files)
            object_paths[obj_dir] = obj_files

        return files_count, object_paths

    def get_background_images(self):
        """Read every decodable image in ``self.background_path``.

        Files OpenCV cannot decode are skipped. Backgrounds whose size differs
        from ``self.image_dimension`` are resized so that image and label map
        always have the same shape.
        """
        height, width = self.image_dimension
        background_images = list()
        for file_name in sorted(os.listdir(self.background_path)):
            image = cv2.imread(os.path.join(self.background_path, file_name))
            if image is None:
                # cv2.imread returns None for anything it cannot decode.
                continue
            if image.shape[0] != height or image.shape[1] != width:
                image = cv2.resize(image, (width, height))
            background_images.append(image)

        return background_images

    def create_augment_vector(self, augment_vector_length= None):
        """Draw a random plan for every image to be generated.

        The plan is stored in ``self.augment_vector`` as a list of dicts with
        the keys ``'background_image'``, ``'num_objects_to_place'``,
        ``'what_objects'`` (indices into ``self.objects``) and ``'locations'``.

        Args:
            augment_vector_length: Number of images to plan; defaults to
                ``self.num_of_images``.

        Raises:
            ValueError: If there are no objects or no background images, or
                ``self.max_objects_per_image`` is smaller than 1.
        """
        num_of_images = self.num_of_images if augment_vector_length is None else augment_vector_length
        # Index the objects that were actually loaded, not the number of files
        # found on disk (directories missing from label_def are never loaded).
        num_objects = len(self.objects)
        if num_objects == 0:
            raise ValueError('No objects available to place.')
        if not self.background_images:
            raise ValueError('No background images available.')
        max_per_image = min(self.max_objects_per_image, num_objects)
        if max_per_image < 1:
            raise ValueError('max_objects_per_image must be at least 1.')

        objects_index = np.arange(0, num_objects)
        num_backgrounds = len(self.background_images)
        self.augment_vector = list()

        for i in range(num_of_images):
            # Reshuffle each time all backgrounds have been used once.
            if i % num_backgrounds == 0:
                np.random.shuffle(self.background_images)

            # Shuffle before selecting so that the first image is random too.
            np.random.shuffle(objects_index)
            # randint's upper bound is exclusive, hence the +1.
            num_objects_to_place = int(np.random.randint(1, high= max_per_image + 1))
            what_objects = [int(idx) for idx in objects_index[:num_objects_to_place]]

            self.augment_vector.append({
                'background_image': self.background_images[i % num_backgrounds],
                'num_objects_to_place': num_objects_to_place,
                'what_objects': what_objects,
                'locations': self.get_random_locations(num_objects_to_place)})

    def get_random_locations(self, num_objects_to_place):
        """Return an ``(n, 2)`` array of random ``[row, col]`` paste positions.

        Positions lie on a 30-pixel grid and stay at least 40 pixels away from
        the bottom/right border (0..439 x 0..599 for the default 480x640).
        """
        row_limit = max(1, self.image_dimension[0] - 40)
        col_limit = max(1, self.image_dimension[1] - 40)
        location = [[random.randrange(0, row_limit, 30), random.randrange(0, col_limit, 30)]
                    for _ in range(num_objects_to_place)]

        return np.array(location)

    def plot_img_and_label(self, image, label, img_num, obj_det_label= None):
        """Show and/or save the image next to its label map.

        Does nothing unless ``self.preview_data`` or ``self.save_data_preview``
        is set. Bounding boxes in ``obj_det_label`` are drawn into a copy of
        the label map with the value ``len(self.label_def) + 1``.
        """
        if not (self.preview_data or self.save_data_preview):
            return

        label = label.copy()
        if obj_det_label is not None:
            height, width = self.image_dimension
            border_value = len(self.label_def) + 1
            for l in obj_det_label:
                top, left, bottom, right = l[1], l[2], l[3], l[4]
                # Boxes may extend past the image; clamp the drawn span.
                row_start, row_stop = max(top, 0), min(bottom + 1, height)
                col_start, col_stop = max(left, 0), min(right + 1, width)
                # Left and right edges, 3 pixels thick.
                label[row_start:row_stop, left:left + 3] = border_value
                label[row_start:row_stop, max(right - 3, 0):right] = border_value
                # Top and bottom edges, 3 pixels thick.
                label[top:top + 3, col_start:col_stop] = border_value
                label[max(bottom - 3, 0):bottom, col_start:col_stop] = border_value

        figure = plt.figure()
        figure.set_figheight(15)
        figure.set_figwidth(15)
        figure.add_subplot(1, 2, 1)
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        figure.add_subplot(1, 2, 2)
        plt.imshow(label)
        if self.save_data_preview:
            preview_dir = os.path.join(self.output_path, 'image_and_gt')
            os.makedirs(preview_dir, exist_ok=True)
            preview_file = os.path.join(preview_dir, str(img_num) + '.png')
            plt.savefig(preview_file)
            result = cv2.imread(preview_file, 1)
            # NOTE(review): fixed crop window, presumably tuned to trim the
            # margins of a 15x15 inch figure at the default DPI -- confirm.
            result = result[320:750, 50:1030, :]
            cv2.imwrite(preview_file, result)
        if self.preview_data:
            plt.show()
        else:
            plt.close(figure)

    def find_obj_loc_and_vals(self, image, label, label_value, obj_name):
        """Extract the pixels of ``image`` where ``label == label_value``.

        Returns:
            Dict with ``'obj_loc'`` (``(n, 2)`` array of ``[row, col]``),
            ``'obj_vals'`` (the image values at those positions),
            ``'label_vals'`` (``n`` copies of ``label_value``) and
            ``'obj_name'``.
        """
        obj_loc = np.argwhere(label == label_value)
        # One fancy-indexing lookup instead of a Python loop over the pixels.
        obj_vals = np.asarray(image)[obj_loc[:, 0], obj_loc[:, 1]]
        label_vals = np.ones(len(obj_loc)) * label_value

        return {'obj_loc': obj_loc, 'obj_vals': obj_vals, 'label_vals': label_vals, 
                'obj_name': obj_name}

    def get_different_scales(self, image, image_label, label_value, obj_name):
        """Cut the object out at ``self.num_of_scales`` evenly spaced scales.

        The scales are ``k / num_of_scales`` for ``k = 1 .. num_of_scales``, so
        the last one is the original size.
        """
        scale_difference = 1.0 / self.num_of_scales
        scales = [i * scale_difference for i in range(1, self.num_of_scales + 1)]

        # NOTE(review): the label mask is resized with OpenCV's default
        # interpolation, so only pixels that still equal label_value exactly
        # are kept; this looks intentional (it trims blended edge pixels) --
        # confirm before switching to nearest-neighbour.
        return [self.find_obj_loc_and_vals(
                    cv2.resize(image, (0, 0), fx=scale, fy=scale),
                    cv2.resize(image_label, (0, 0), fx=scale, fy=scale),
                    label_value, obj_name)
                for scale in scales]

    def get_synthetic_objects_list(self, object_paths):
        """Load every object of every non-background class at all scales.

        Args:
            object_paths: Mapping as returned by ``fetch_image_gt_paths``.

        Returns:
            Flat list of object dicts (see ``find_obj_loc_and_vals``).

        Raises:
            FileNotFoundError: If an image or its label mask cannot be read.
        """
        import warnings

        objects = list()
        for key in self.label_def:
            # Compare by value; identity comparison on strings is unreliable.
            if key == 'background':
                continue
            if key not in object_paths:
                warnings.warn('No image directory found for class %r.' % (key,))
                continue
            for image_file, label_file in object_paths[key]:
                img = cv2.imread(image_file)
                label = cv2.imread(label_file, 0)
                if img is None or label is None:
                    raise FileNotFoundError(
                        'Could not read image/label pair: %s, %s' % (image_file, label_file))
                for scaled in self.get_different_scales(img, label, self.label_def[key], key):
                    if len(scaled['obj_loc']) == 0:
                        # Nothing to paste; would fail later when placed.
                        warnings.warn('No pixels with label %r found in %s.'
                                      % (self.label_def[key], label_file))
                        continue
                    objects.append(scaled)

        return objects

    def get_augmented_image(self, original_image, original_label, obj_vals, location):
        """Paste one object into copies of the image and the label map.

        The object is shifted so that the top-left corner of its bounding box
        lands on ``location`` (``[row, col]``). Pixels that fall outside
        ``self.image_dimension`` are dropped. The inputs are not modified.

        Returns:
            ``(image, label)`` or, when ``self.get_obj_det_label`` is set,
            ``(image, label, [obj_name, min_row, min_col, max_row, max_col])``.
            The box describes the whole shifted object and may therefore
            extend beyond the image.

        Raises:
            ValueError: If the object has no pixels.
        """
        augmented_image = original_image.copy()
        augmented_label = original_label.copy()

        obj_loc = np.asarray(obj_vals['obj_loc'])
        if len(obj_loc) == 0:
            raise ValueError('Object %r has no pixels to place.' % (obj_vals['obj_name'],))

        rows = obj_loc[:, 0] - obj_loc[:, 0].min() + location[0]
        cols = obj_loc[:, 1] - obj_loc[:, 1].min() + location[1]

        # Row 0 and column 0 are valid pixels, so the lower bounds are inclusive.
        inside = ((rows >= 0) & (rows < self.image_dimension[0]) &
                  (cols >= 0) & (cols < self.image_dimension[1]))
        augmented_image[rows[inside], cols[inside]] = np.asarray(obj_vals['obj_vals'])[inside]
        augmented_label[rows[inside], cols[inside]] = np.asarray(obj_vals['label_vals'])[inside]

        if self.get_obj_det_label:
            # Plain ints keep the CSV output independent of NumPy's repr.
            rect_points = [int(rows.min()), int(cols.min()), int(rows.max()), int(cols.max())]
            obj_det_label = [obj_vals['obj_name']] + rect_points
            return augmented_image, augmented_label, obj_det_label

        return augmented_image, augmented_label

    def perform_augmentation(self, preview_data= False, save_data_preview= False):
        """Render the current augmentation plan and write the results to disk.

        For every planned image ``<index>.png`` is written to the ``image``
        and ``ground_truth`` sub-directories of ``self.output_path`` and, when
        ``self.get_obj_det_label`` is set, ``<index>.csv`` with one bounding
        box per row to ``obj_det``. Missing directories are created.

        Args:
            preview_data: Show each image/label pair with matplotlib.
            save_data_preview: Save each image/label preview to ``image_and_gt``.

        Raises:
            RuntimeError: If ``create_augment_vector`` has not been called.
        """
        if self.augment_vector is None:
            raise RuntimeError('Call create_augment_vector() before perform_augmentation().')

        self.preview_data = preview_data
        self.save_data_preview = save_data_preview

        image_dir = os.path.join(self.output_path, 'image')
        ground_truth_dir = os.path.join(self.output_path, 'ground_truth')
        obj_det_dir = os.path.join(self.output_path, 'obj_det')
        # cv2.imwrite only returns False when the directory is missing, so
        # make sure the targets exist up front.
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(ground_truth_dir, exist_ok=True)
        if self.get_obj_det_label:
            os.makedirs(obj_det_dir, exist_ok=True)

        background_label = np.ones(tuple(self.image_dimension)) * self.label_def['background']
        for index, vector in enumerate(self.augment_vector):
            augmented_image = vector['background_image']
            augmented_label = background_label.copy()
            obj_det_label = list()

            for i in range(vector['num_objects_to_place']):
                result = self.get_augmented_image(augmented_image, augmented_label,
                                                  self.objects[vector['what_objects'][i]],
                                                  vector['locations'][i])
                if self.get_obj_det_label:
                    augmented_image, augmented_label, rect_label = result
                    obj_det_label.append(rect_label)
                else:
                    augmented_image, augmented_label = result

            self.plot_img_and_label(augmented_image, augmented_label, index,
                                    obj_det_label if self.get_obj_det_label else None)
            cv2.imwrite(os.path.join(image_dir, str(index) + '.png'), augmented_image)
            cv2.imwrite(os.path.join(ground_truth_dir, str(index) + '.png'), augmented_label)

            if self.get_obj_det_label:
                # newline='' is what the csv module expects for output files.
                with open(os.path.join(obj_det_dir, str(index) + '.csv'), 'w', newline='') as f:
                    csv.writer(f, delimiter=',').writerows(obj_det_label)

In [3]:
# DataAugmentation draws from both the stdlib `random` module and NumPy's
# global generator; seed both so that Restart & Run All reproduces the same
# augmented data set.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Class name -> integer value that marks the class in the label images and in
# the generated ground-truth maps. 'background' fills every pixel that is not
# covered by a pasted object.
label_def = {'axis': 11, 'bearing': 10, 'container_box_blue': 14, 
             'container_box_red': 15, 'em_01': 17, 'em_02': 18,
             'f20_20_B': 1, 'f20_20_G': 3, 'm20': 6, 'm20_100': 5, 'm30': 7,
             'motor': 13, 'r20': 8, 's40_40_B': 2, 'background': 19}

# Load objects and backgrounds, plan the images, then render them and save a
# side-by-side preview of every image and its label map.
augmenter = DataAugmentation(label_def)
augmenter.create_augment_vector()
augmenter.perform_augmentation(save_data_preview= True)

In [4]:
# Draw a fresh random plan and render it again, this time without showing or
# saving previews. Output files are named by image index, so this run
# overwrites the images, label maps and CSV files of the previous run.
augmenter.create_augment_vector(augment_vector_length=None)
augmenter.perform_augmentation(preview_data=False, save_data_preview=False)