In [10]:
import numpy as np
import cv2
from copy import deepcopy
import os
from PIL import Image

In [66]:
class ObjVar:
    def __init__(self, position, bbox, grids, class_name, area):
        self.position = position
        self.bbox = bbox
        self.grids = grids
        self.class_name = class_name
        self.area = area


class ObjDetector(object):
    def __init__(self):
        self.labels_dict = {
        0:  {'name': 'unlabeled',      'color': (0,   0,   0),   'ct_scape_id': 0,  'train_id': 255},
        1:  {'name': 'road',           'color': (128, 64,  128), 'ct_scape_id': 7,  'train_id': 1},
        2:  {'name': 'sidewalk',       'color': (244, 35,  232), 'ct_scape_id': 8,  'train_id': 2},
        3:  {'name': 'building',       'color': (70,  70,  70),  'ct_scape_id': 11, 'train_id': 3},
        4:  {'name': 'wall',           'color': (102, 102, 156), 'ct_scape_id': 12, 'train_id': 4},
        5:  {'name': 'fence',          'color': (190, 153, 153), 'ct_scape_id': 13, 'train_id': 5},
        6:  {'name': 'pole',           'color': (153, 153, 153), 'ct_scape_id': 17, 'train_id': 6},
        7:  {'name': 'traffic light',   'color': (250, 170, 30),  'ct_scape_id': 19, 'train_id': 7},
        8:  {'name': 'traffic sign',    'color': (220, 220, 0),   'ct_scape_id': 20, 'train_id': 8},
        9:  {'name': 'vegetation',     'color': (107, 142, 35),  'ct_scape_id': 21, 'train_id': 9},
        10: {'name': 'terrain',        'color': (152, 251, 152), 'ct_scape_id': 22, 'train_id': 10},
        11: {'name': 'sky',            'color': (70,  130, 180), 'ct_scape_id': 23, 'train_id': 11},
        12: {'name': 'person',         'color': (220, 20,  60),  'ct_scape_id': 24, 'train_id': 12},
        13: {'name': 'rider',          'color': (255, 0,   0),   'ct_scape_id': 25, 'train_id': 13},
        14: {'name': 'car',            'color': (0,   0,   142), 'ct_scape_id': 26, 'train_id': 14},
        15: {'name': 'truck',          'color': (0,   0,   70),  'ct_scape_id': 27, 'train_id': 15},
        16: {'name': 'bus',            'color': (0,   60,  100), 'ct_scape_id': 28, 'train_id': 16},
        17: {'name': 'train',          'color': (0,   80,  100), 'ct_scape_id': 31, 'train_id': 17},
        18: {'name': 'motorcycle',     'color': (0,   0,   230), 'ct_scape_id': 32, 'train_id': 18},
        19: {'name': 'bicycle',        'color': (119, 11,  32),  'ct_scape_id': 33, 'train_id': 19},
        20: {'name': 'dynamic',        'color': (111, 74,  0),   'ct_scape_id': 5,  'train_id': 20},
        21: {'name': 'ground',         'color': (81,  0,   81),  'ct_scape_id': 6,  'train_id': 21},
        22: {'name': 'parking',        'color': (250, 170, 160), 'ct_scape_id': 9,  'train_id': 22},
        23: {'name': 'rail track',     'color': (230, 150, 140), 'ct_scape_id': 10, 'train_id': 23},
        24: {'name': 'guard rail',     'color': (180, 165, 180), 'ct_scape_id': 14, 'train_id': 24},
        25: {'name': 'bridge',         'color': (150, 100, 100), 'ct_scape_id': 15, 'train_id': 25},
        26: {'name': 'tunnel',         'color': (150, 120, 90),  'ct_scape_id': 16, 'train_id': 26},
        27: {'name': 'polegroup',      'color': (153, 153, 153), 'ct_scape_id': 18, 'train_id': 27},
        28: {'name': 'caravan',        'color': (0,   0,   90),  'ct_scape_id': 29, 'train_id': 28},
        29: {'name': 'trailer',        'color': (0,   0,   110), 'ct_scape_id': 30, 'train_id': 29}
    }
        
    def get_grid_span(self, x, y, w, h, h_g_l, v_g_l):
        g_x = []
        g_y = []
        
        i = 1
        while not (x >= h_g_l[i - 1] and x < h_g_l[i]):
            i += 1
        g_x.append(i)

        while not (x+w >= h_g_l[i - 1] and x+w < h_g_l[i]):
            i += 1  
        if not i in g_x:
            g_x.append(i)
        if len(g_x) > 1:
            for ii in range(g_x[0]+1, g_x[1]):
                g_x.append(ii)
                
        g_x.sort()
        
        j = 1
        while not (y >= v_g_l[j - 1] and y < v_g_l[j]):
            j += 1 
        g_y.append(j)

        while not (y+h >= v_g_l[j - 1] and y+h < v_g_l[j]):
            j += 1  
        if not j in g_x:
            g_y.append(j)
        if len(g_y) > 1:
            for jj in range(g_y[0]+1, g_y[1]):
                g_y.append(jj)      
        g_y.sort()
            
        final_g = []
        for x in g_x:
            for y in g_y:
                final_g.append((y-1)*3 + x)
                
        return final_g

    
    def get_localization(self, img_lbl, detect_obj=[], grid_count=3):
        obj_dict = {}
        for key in self.labels_dict.keys():
            obj_id = self.labels_dict[key]['train_id']
            temp_img = deepcopy(img_lbl)
            temp_img[temp_img != obj_id] = 0
            temp_img[temp_img == obj_id] = 255
            
            temp_img = cv2.GaussianBlur(temp_img, (45, 45), 0)
            temp_img = cv2.dilate(temp_img, np.ones((35, 35), np.uint8), iterations=1)
            temp_img = cv2.erode(temp_img, np.ones((25, 25), np.uint8), iterations=1)

            th = int(np.max(temp_img) * 0.6)
            temp_img[temp_img > th] = 255
            temp_img[temp_img <= th] = 0
            
            contours, hierarchy = cv2.findContours(
                temp_img.astype(np.uint8), 
                cv2.RETR_EXTERNAL, 
                cv2.CHAIN_APPROX_SIMPLE
            )
            
            horizontal_grid_lines = [(i) * (img_lbl.shape[1] // 3) for i in range(grid_count + 1)]
            vertical_grid_lines =  [(i) * (img_lbl.shape[0] // 3) for i in range(grid_count + 1)]
            
            horizontal_grid_lines[-1] = horizontal_grid_lines[-1] + 1
            vertical_grid_lines[-1] = vertical_grid_lines[-1] + 1
       
            for cnt in contours:
                area = cv2.contourArea(cnt)
                if area < 500:
                    continue

                x,y,w,h = cv2.boundingRect(cnt)
                
                if x >= img_lbl.shape[1]:
                    x = x - 1
                
                if x+w >= img_lbl.shape[1]:
                    w = w - 1
                
                if y >= img_lbl.shape[0]:
                    y = y - 1
                
                if y+h >= img_lbl.shape[0]:
                    h = h - 1 

                if h>20 and w>20:
                    if key not in obj_dict.keys():
                        obj_dict[key] = []
                        
                    cur_obj = ObjVar(
                        position=[(y + (h//2)), (x + (w//2))],
                        bbox=[y, x, y+h, x+w],
                        grids=self.get_grid_span(x, y, w, h, horizontal_grid_lines, vertical_grid_lines),
                        class_name=self.labels_dict[key]['name'],
                        area=area
                    )
                    
                    obj_dict[key].append(cur_obj)

        return obj_dict

In [67]:
image_dir = os.path.join(os.getcwd(), '../des_gen_img')

source_images = [os.path.join(image_dir, img) for img in os.listdir(image_dir) if '.png' in img]
source_images.sort()


In [68]:
img = np.asarray(Image.open(source_images[0]))

obj_detector = ObjDetector()

obj_det_dict = obj_detector.get_localization(img_lbl=img)

obj_det_dict

{1: [<__main__.ObjVar at 0x23520aadc70>],
 2: [<__main__.ObjVar at 0x23520aadeb0>,
  <__main__.ObjVar at 0x23520aadac0>,
  <__main__.ObjVar at 0x23520acbdc0>],
 3: [<__main__.ObjVar at 0x23520acbf40>,
  <__main__.ObjVar at 0x2351ee680d0>,
  <__main__.ObjVar at 0x23520b84c10>,
  <__main__.ObjVar at 0x23520b84880>,
  <__main__.ObjVar at 0x23520b84130>],
 4: [<__main__.ObjVar at 0x23520b84d60>],
 5: [<__main__.ObjVar at 0x23520b5b5e0>],
 7: [<__main__.ObjVar at 0x23520b5b4f0>],
 9: [<__main__.ObjVar at 0x23520b5bd00>,
  <__main__.ObjVar at 0x23520b5b790>,
  <__main__.ObjVar at 0x23520b5b250>,
  <__main__.ObjVar at 0x23520b5b820>],
 10: [<__main__.ObjVar at 0x23520bd80a0>],
 11: [<__main__.ObjVar at 0x23520bd81f0>],
 12: [<__main__.ObjVar at 0x23520bd8220>],
 14: [<__main__.ObjVar at 0x23520bd8280>,
  <__main__.ObjVar at 0x23520bd82e0>,
  <__main__.ObjVar at 0x23520bd8160>],
 15: [<__main__.ObjVar at 0x23520bd8340>],
 16: [<__main__.ObjVar at 0x23520bd8430>],
 20: [<__main__.ObjVar at 0x23