In [4]:
from glob import glob
import xml.etree.ElementTree as ET
from tqdm import tqdm

import numpy as np

# Directory that load_dataset() scans for annotation XML files.
ANNOTATIONS_PATH = "./VOC2012/Annotations"
# Number of clusters (k) requested from kmeans().
CLUSTERS = 5

In [7]:
def load_dataset(path):
    """Collect normalized bounding-box sizes from annotation XML files.

    Every ``*.xml`` file directly inside ``path`` is parsed. For each
    ``object`` element the box width and height are divided by the image
    width and height read from ``./size``, so both values are fractions of
    the image size.

    Images whose declared width or height is not positive, and boxes whose
    width or height is not positive, are skipped: the former would divide by
    zero and the latter would make ``iou`` raise "Box has no area".

    Args:
        path (str): directory containing the annotation files.

    Returns:
        np.ndarray: float array of shape (n, 2) holding one (width, height)
        row per kept box; shape (0, 2) when no box was found.
    """
    dataset = []
    # Match on the ".xml" extension (a bare "*xml" pattern also matches names
    # such as "notxml") and sort so the row order does not depend on the
    # arbitrary order in which the filesystem lists the files.
    for xml_file in tqdm(sorted(glob("{}/*.xml".format(path)))):
        tree = ET.parse(xml_file)

        height = int(tree.findtext("./size/height"))
        width = int(tree.findtext("./size/width"))
        if height <= 0 or width <= 0:
            # Cannot normalize against an empty image.
            continue

        for obj in tree.iter("object"):
            # int(float(...)) also accepts coordinates written with a
            # decimal point, e.g. "12.0".
            xmin = int(float(obj.findtext("bndbox/xmin"))) / width
            ymin = int(float(obj.findtext("bndbox/ymin"))) / height
            xmax = int(float(obj.findtext("bndbox/xmax"))) / width
            ymax = int(float(obj.findtext("bndbox/ymax"))) / height

            box_width = xmax - xmin
            box_height = ymax - ymin
            if box_width <= 0 or box_height <= 0:
                # Degenerate box: it has no area, so IoU is undefined for it.
                continue

            dataset.append([box_width, box_height])

    # reshape keeps the (n, 2) contract even when the list is empty.
    return np.array(dataset, dtype=float).reshape(-1, 2)

In [8]:
# Normalized (width, height) of every annotated box found under ANNOTATIONS_PATH.
data = load_dataset(ANNOTATIONS_PATH)

100%|██████████| 17125/17125 [00:02<00:00, 6100.72it/s]


In [22]:
def iou(box, clusters):
    """Intersection over union (IoU) between one box and each of k clusters.

    Only sizes are compared: the overlap is taken as
    min(widths) * min(heights).

    Args:
        box (np.array): single box as (w, h)
        clusters (np.array): numpy array of shape (k, 2), one (w, h) per row

    Returns:
        np.array of shape (k,): IoU of ``box`` with every cluster

    Raises:
        ValueError: if any overlap width or overlap height equals zero
    """
    box_w, box_h = box[0], box[1]
    cluster_w = clusters[:, 0]
    cluster_h = clusters[:, 1]

    # The scalar box dimensions broadcast against the k cluster dimensions.
    overlap_w = np.minimum(cluster_w, box_w)
    overlap_h = np.minimum(cluster_h, box_h)

    if np.any(overlap_w == 0) or np.any(overlap_h == 0):
        raise ValueError("Box has no area")

    overlap = overlap_w * overlap_h
    union = box_w * box_h + cluster_w * cluster_h - overlap

    return overlap / union

In [25]:
def avg_iou(boxes, clusters):
    """Mean, over all boxes, of each box's best IoU with any cluster.

    Args:
        boxes (np.array): numpy array of shape (r, 2)
        clusters (np.array): numpy array of shape (k, 2)

    Returns:
        The average of the per-box maximum IoU values.
    """
    best_overlaps = []
    for idx in range(boxes.shape[0]):
        overlaps = iou(boxes[idx], clusters)
        best_overlaps.append(np.max(overlaps))
    return np.mean(best_overlaps)

In [26]:
def kmeans(boxes, k, dist=np.median, seed=None):
    """Cluster (width, height) boxes with k-means, using 1 - IoU as distance.

    Args:
        boxes (np.array): numpy array of shape (r, 2), where r is the number
            of rows
        k (int): number of clusters, 1 <= k <= r
        dist: reduction used to recompute a cluster from its member boxes;
            it is called as ``dist(members, axis=0)``. Defaults to np.median.
        seed (int, optional): seed for picking the initial clusters. With the
            default ``None`` fresh entropy is used, so runs differ; pass an
            int for repeatable results.

    Returns:
        clusters: numpy array of shape (k, 2)

    Raises:
        ValueError: if k is not between 1 and the number of boxes, or (raised
            by ``iou``) if a box or cluster has zero width or height
    """
    rows = boxes.shape[0]
    if not 1 <= k <= rows:
        raise ValueError(
            "k must be between 1 and the number of boxes ({}), got {}".format(rows, k)
        )

    distances = np.empty((rows, k))
    # None means "no previous assignment yet". Starting from an all-zero array
    # made the loop stop before any update whenever the first assignment was
    # all cluster 0 (always the case for k == 1), returning a random box.
    last_clusters = None

    # A local generator makes runs repeatable via `seed` and leaves the
    # global numpy random state untouched.
    rng = np.random.default_rng(seed)

    # Indexing with an index array copies, so updating `clusters` below never
    # modifies `boxes`.
    clusters = boxes[rng.choice(rows, k, replace=False)]

    while True:
        for row in tqdm(range(rows)):
            distances[row] = 1 - iou(boxes[row], clusters)

        nearest_clusters = np.argmin(distances, axis=1)

        # Converged once no box changed cluster since the previous pass.
        if last_clusters is not None and (last_clusters == nearest_clusters).all():
            break

        for i in range(k):
            members = boxes[nearest_clusters == i]
            # An empty cluster keeps its previous box: reducing an empty
            # selection gives NaN, and argmin would then assign every box to
            # that NaN cluster.
            if len(members) > 0:
                clusters[i] = dist(members, axis=0)

        last_clusters = nearest_clusters

    return clusters

In [27]:
# Cluster the normalized box sizes; `out` holds one (w, h) row per cluster.
out = kmeans(data, CLUSTERS)

100%|██████████| 40138/40138 [00:00<00:00, 70699.31it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70240.45it/s]
100%|██████████| 40138/40138 [00:00<00:00, 71738.60it/s]
100%|██████████| 40138/40138 [00:00<00:00, 69983.00it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70235.91it/s]
100%|██████████| 40138/40138 [00:00<00:00, 68439.22it/s]
100%|██████████| 40138/40138 [00:00<00:00, 71358.05it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70482.03it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70236.20it/s]
100%|██████████| 40138/40138 [00:00<00:00, 69749.31it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70363.07it/s]
100%|██████████| 40138/40138 [00:00<00:00, 69093.54it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70979.45it/s]
100%|██████████| 40138/40138 [00:00<00:00, 69628.58it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70358.93it/s]
100%|██████████| 40138/40138 [00:00<00:00, 71230.37it/s]
100%|██████████| 40138/40138 [00:00<00:00, 70605.82it/s]
100%|██████████| 40138/40138 [0

In [29]:
# The figure printed as "Accuracy" is the mean best IoU between every box and
# its closest cluster, expressed as a percentage.
mean_best_iou = avg_iou(data, out)
print("Accuracy: {:.2f}%".format(mean_best_iou * 100))
print("Boxes:\n {}".format(out))

# Width/height and height/width ratio of each cluster, rounded to 2 decimals.
cluster_w = out[:, 0]
cluster_h = out[:, 1]
ratios = np.around(cluster_w / cluster_h, decimals=2).tolist()
inverse_ratios = np.around(cluster_h / cluster_w, decimals=2).tolist()

print("Ratios:\n {}".format(sorted(ratios)))
print("Inverse Ratios:\n {}".format(sorted(inverse_ratios)))

Accuracy: 61.24%
Boxes:
 [[0.812      0.82933333]
 [0.402      0.608     ]
 [0.042      0.07207207]
 [0.194      0.37866667]
 [0.1        0.17066667]]
Ratios:
 [0.51, 0.58, 0.59, 0.66, 0.98]
Inverse Ratios:
 [1.02, 1.51, 1.71, 1.72, 1.95]
