In [None]:
import xml.etree.ElementTree as ET
import numpy as np

In [None]:
# Number of cells along each side of the square grid laid over the image
grid_size = 5

# Category labels; position 0 is the '__background__' placeholder
class_names = [
    '__background__',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

# Lookup from a category label to its position in class_names
name_to_ix = dict(zip(class_names, range(len(class_names))))
name_to_ix

In [None]:
# Parse the annotation file and keep a handle on its top-level element
annotation_path = 'test.xml'
tree = ET.parse(annotation_path)
root = tree.getroot()

In [None]:
# Peek at the direct children of the root element
for child in root:
    print(child)

In [None]:
# Read the image dimensions from the <size> element; they are used later to
# lay out the grid cells. Values are taken in document order, which the
# unpacking of `dims` that follows expects to be width, height, channels.
size_node = root.find('size')
dims = [int(child.text) for child in size_node]

In [None]:
# Give the three entries of dims their own names
(width, height, channels) = dims

In [None]:
# Collect every <object> element so its fields can be read in later cells
objects_array = root.findall('object')
objects_array

In [None]:
# Bounding-box coordinates and class name for every object in the image.
# Each box is stored as [xmin, ymin, xmax, ymax], the order the later cells
# unpack it in.
BNDBOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')

name_arr = []
bndbox_array = []
for obj in objects_array:
    box_node = obj.find('bndbox')
    # Look the coordinates up by tag so the result does not depend on the
    # order in which the file happens to list them.
    coord_texts = [box_node.findtext(tag) for tag in BNDBOX_TAGS]
    if any(text is None for text in coord_texts):
        # Expected tags are absent: fall back to plain document order
        coord_texts = [child.text for child in box_node]
    bndbox_array.append([int(text) for text in coord_texts])
    name_arr.append(obj.find('name').text)

# Position of each object's class in class_names (0 is '__background__')
index_arr = [name_to_ix[name] for name in name_arr]

index_arr


In [None]:
def object_in_cell(obj_coords, cell_coords):
    """Return True when the centre of a bounding box lies inside a grid cell.

    Parameters
    ----------
    obj_coords : sequence of four numbers
        Box corners as ``(xmin, ymin, xmax, ymax)``.
    cell_coords : sequence of four numbers
        Cell corners as ``(xmin, ymin, xmax, ymax)``.

    Returns
    -------
    bool
        True if the box centre falls within the cell, False otherwise.

    Notes
    -----
    The cell is treated as half-open: its left/top edges belong to it, its
    right/bottom edges belong to the neighbouring cell. This way a centre
    sitting exactly on a shared edge is claimed by one cell only instead of
    by every cell that touches that edge.
    """
    obj_xmin, obj_ymin, obj_xmax, obj_ymax = obj_coords
    cell_xmin, cell_ymin, cell_xmax, cell_ymax = cell_coords
    obj_x_mid = (obj_xmin + obj_xmax) / 2
    obj_y_mid = (obj_ymin + obj_ymax) / 2
    return (cell_xmin <= obj_x_mid < cell_xmax) and (cell_ymin <= obj_y_mid < cell_ymax)

In [None]:
# For every grid cell, record which object (index into bndbox_array) has its
# centre in that cell; -1 marks a cell with no object. The first index is the
# grid row (vertical position), the second the grid column (horizontal).
grid_placements_index = np.full((grid_size, grid_size), -1)

cell_w = width / grid_size
cell_h = height / grid_size

top, bottom = 0, cell_h
for row in range(grid_size):
    # Each row starts again at the left edge of the image
    left, right = 0, cell_w
    for col in range(grid_size):
        bounds = [left, top, right, bottom]
        # Only the first matching object is kept for a cell
        grid_placements_index[row, col] = next(
            (k for k, box in enumerate(bndbox_array) if object_in_cell(box, bounds)),
            -1,
        )
        left += cell_w
        right += cell_w
    top += cell_h
    bottom += cell_h

grid_placements_index

array([[-1, -1, -1, -1, -1],
       [-1,  1, -1, -1, -1],
       [-1, -1,  0, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1]])

In [None]:
# Number of categories once '__background__' is left out
num_classes = len(class_names) - 1

# Per-cell label template: one object flag, four box values, one slot per category
default_y_label = [0] * (1 + 4 + num_classes)

y_label_arr = np.zeros((grid_size, grid_size, len(default_y_label)), dtype='float')

# Initialize each cell with the default label (broadcast over the whole grid)
y_label_arr[:, :] = default_y_label

def label_array(grid_placements_index):
    """Fill the per-cell label vectors for every grid cell that holds an object.

    Parameters
    ----------
    grid_placements_index : 2-D indexable of int
        For each ``[row][col]`` cell, the index of the object assigned to it
        (into ``bndbox_array`` / ``index_arr``), or a negative value when the
        cell is empty.

    Returns
    -------
    numpy.ndarray
        The module-level ``y_label_arr``, written in place. Each occupied cell
        holds ``[1, cx, cy, bw, bh] + one_hot_class`` where the box centre and
        size are divided by the image ``width`` / ``height``. Empty cells are
        left untouched.

    Raises
    ------
    ValueError
        If an object's class id is below 1, i.e. it has no slot in the
        background-free one-hot vector.
    """
    for row in range(grid_size):
        for col in range(grid_size):
            obj_ix = grid_placements_index[row][col]
            if obj_ix < 0:
                continue

            # Class ids come from class_names, where 0 is '__background__',
            # while the one-hot vector has no background slot: shift by one.
            class_ix = index_arr[obj_ix]
            if class_ix < 1:
                raise ValueError(
                    f"object {obj_ix} has class id {class_ix}, which has no one-hot slot"
                )
            encode_class = [0] * num_classes
            encode_class[class_ix - 1] = 1

            xmin, ymin, xmax, ymax = bndbox_array[obj_ix]
            cx = ((xmin + xmax) / 2) / width     # center x normalized
            cy = ((ymin + ymax) / 2) / height    # center y normalized
            bw = (xmax - xmin) / width           # box width normalized
            bh = (ymax - ymin) / height          # box height normalized
            y_label_arr[row][col][:] = [1, cx, cy, bw, bh] + encode_class
    return y_label_arr
y_true = label_array(grid_placements_index)

array([1.        , 0.251     , 0.36666667, 0.494     , 0.392     ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 1.        , 0.        ])