In [None]:
%matplotlib inline

from cytools import bbox_overlaps
import os
from os.path import join as pjoin
from os.path import exists as pexists
import json
import math
import random
import time
from random import shuffle

import cv2
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Root folder that the image paths in `filenames_paths` are joined onto.
data_dir = "/home/toni/datasets/openimages"

# Seed from the wall clock so each run walks the images in a different order.
random.seed(int(time.time()))

# Every line of the names file has its last four characters dropped
# (presumably a ".jpg"-style extension -- confirm against the file) and
# "/results.json" appended. The file is opened in a `with` block so the
# handle is closed as soon as the list is built.
with open("/tmp/sorted_4_oi_names.txt") as names_file:
    results_paths = [x.strip()[:-4] + "/results.json" for x in names_file]
shuffle(results_paths)

# Image path matching each results path, index for index.
filenames_paths = [x.replace("/results.json", ".jpg") for x in results_paths]

# Per-teacher result folders; `results_paths` entries are joined onto these.
centernet_path = "/opt/results/CenterNet-104_480000"
atss_path = "/opt/results/ATSS"

# Upper bound on how many images the driver cell visualises.
number_of_examples = 1000
# Class name -> integer class id (the ids cover 0..79 exactly once).
# Keys are written without spaces or underscores because
# parse_teacher_results() strips both from the teacher's category name before
# looking it up here; a key containing "_" could therefore never be matched.
CLASSES = {
    'airplane': 4,
    'apple': 47,
    'backpack': 24,
    'banana': 46,
    'baseballbat': 34,
    'baseballglove': 35,
    'bear': 21,
    'bed': 59,
    'bench': 13,
    'bicycle': 1,
    'bird': 14,
    'boat': 8,
    'book': 73,
    'bottle': 39,
    'bowl': 45,
    'broccoli': 50,
    'bus': 5,
    'cake': 55,
    'car': 2,
    'carrot': 51,
    'cat': 15,
    'cellphone': 67,
    'chair': 56,
    'clock': 74,
    'couch': 57,
    'cow': 19,
    'cup': 41,
    'diningtable': 60,
    'dog': 16,
    'donut': 54,
    'elephant': 20,
    'firehydrant': 10,
    'fork': 42,
    'frisbee': 29,
    'giraffe': 23,
    'hairdrier': 78,
    'handbag': 26,
    'horse': 17,
    'hotdog': 52,
    'keyboard': 66,
    'kite': 33,
    'knife': 43,
    'laptop': 63,
    'microwave': 68,
    'motorcycle': 3,
    'mouse': 64,
    'orange': 49,
    'oven': 69,
    'parkingmeter': 12,
    'person': 0,
    'pizza': 53,
    'pottedplant': 58,
    'refrigerator': 72,
    'remote': 65,
    'sandwich': 48,
    'scissors': 76,
    'sheep': 18,
    'sink': 71,
    'skateboard': 36,
    'skis': 30,
    'snowboard': 31,
    'spoon': 44,
    'sportsball': 32,
    'stopsign': 11,
    'suitcase': 28,
    'surfboard': 37,
    'teddybear': 77,
    'tennisracket': 38,
    'tie': 27,
    'toaster': 70,
    'toilet': 61,
    'toothbrush': 79,
    'trafficlight': 9,
    'train': 6,
    'truck': 7,
    'tv': 62,
    'umbrella': 25,
    'vase': 75,
    'wineglass': 40,
    'zebra': 22
}

def print_results(results, model, image, color):
    """Draw detection boxes and their class labels on top of an image.

    Args:
        results: iterable of rows laid out as
            [x1, y1, x2, y2, score(s)..., class_id]; only the first four
            values and the last one are read here.
        model: short model name, used as the label prefix ("<model>_<class>").
        image: image array handed to ``imshow``.
        color: matplotlib colour used for the box edge and label background.
    """
    # CLASSES maps name -> id. Invert it once so a class id can be turned back
    # into its name (dict views cannot be indexed, and their order is
    # alphabetical rather than by id).
    id_to_name = {class_id: name for name, class_id in CLASSES.items()}

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(image, aspect='equal')
    ax.axis('off')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    for x in results:
        xmin, ymin, xmax, ymax = x[:4]

        # Rows coming from numpy arrays carry the id as a float.
        class_id = int(x[-1])
        label = model + "_" + id_to_name.get(class_id, str(class_id))

        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor=color,
                                   linewidth=4.0))

        ax.text(xmin+1, ymin-3, '{:s}'.format(label), bbox=dict(facecolor=color, ec='black', lw=2, alpha=0.5),
                fontsize=15, color='white', weight='bold')

    plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)


def merge_bboxes(bboxA, bboxB, method="mean"):
    """Merge two boxes given as equal-length coordinate sequences.

    Args:
        bboxA, bboxB: array-likes with the same number of coordinates.
        method: merge strategy; only "mean" (element-wise average) exists.

    Returns:
        A float numpy array holding the merged coordinates.

    Raises:
        ValueError: if ``method`` is not a supported strategy (previously the
            function fell through and returned ``None``).
    """
    if method == "mean":
        # asarray lets plain lists be averaged instead of concatenated by `+`.
        return (np.asarray(bboxA, dtype=float) + np.asarray(bboxB, dtype=float)) / 2

    raise ValueError("Unknown merge method: {!r}".format(method))
    

    
def combine(bboxesA, bboxesB, thr=0.7): # bboxes -> [x1, y1, x2, y2, score1, score2, ..., scoreN, category]
    """Merge two sets of detections by IoU matching.

    Each row of A is paired with the row of B it overlaps most. When that IoU
    exceeds ``thr`` the pair is replaced by one row whose box and scores are
    the element-wise mean of the two; otherwise the A row is kept unchanged.
    Rows of B that were not consumed by any merge are appended at the end.

    Args:
        bboxesA, bboxesB: 2-D arrays with rows
            [x1, y1, x2, y2, score1, ..., scoreN, category]. Either may be
            empty.
        thr: IoU a pair must exceed to be merged.

    Returns:
        A list of 1-D arrays using the same row layout as the inputs.
    """
    bboxesA = np.asarray(bboxesA)
    bboxesB = np.asarray(bboxesB)

    # With one side empty there is nothing to match; pass the other through.
    if bboxesA.size == 0:
        return list(bboxesB)
    if bboxesB.size == 0:
        return list(bboxesA)

    # NOTE(review): assumes bbox_overlaps(A, B) returns a (len(A), len(B))
    # IoU matrix -- confirm against cytools.
    ious = bbox_overlaps(bboxesA[:, :4], bboxesB[:, :4])

    # Reduce along B so that position i describes row i of A.
    max_ious = ious.max(1)
    rels = ious.argmax(1)
    oks = max_ious > thr

    cluster = []
    merged_b = set()
    for i, (ok, idx) in enumerate(zip(oks, rels)):
        if ok:
            new_bbox = merge_bboxes(bboxesA[i, :4], bboxesB[idx, :4], method='mean') # [x1, y1, x2, y2]
            # Average only the score columns; the trailing category is
            # excluded on both sides so the two slices have equal length.
            new_score = (bboxesA[i, 4:-1] + bboxesB[idx, 4:-1]) / 2 # [score1, score2, ..., scoreN]

            # NOTE(review): the merged row keeps A's category; the two
            # categories are not compared before merging.
            cluster.append(np.concatenate([new_bbox, new_score, bboxesA[i, -1:]]))
            merged_b.add(int(idx))
        else:
            cluster.append(bboxesA[i])

    # Every B row that took part in no merge is kept exactly once.
    for j in range(len(bboxesB)):
        if j not in merged_b:
            cluster.append(bboxesB[j])

    # [x1, y1, x2, y2, score1, score2, ..., scoreN, category]
    return cluster
                           

def parse_teacher_results(results):
    """Flatten teacher detections into [*bbox, score, class_id] rows.

    The "category_id" text of each detection has underscores and spaces
    removed before being looked up in CLASSES, so a name missing from that
    table raises KeyError.
    """
    return [
        [
            *detection["bbox"],
            detection["score"],
            CLASSES[detection["category_id"].replace("_", "").replace(" ", "")],
        ]
        for detection in results
    ]
          
            
        
def cluster():
    """Visualise per-teacher and combined detections, one image at a time.

    Walks ``results_paths`` from its first entry to its last. For every entry
    that has a results file under both ``centernet_path`` and ``atss_path``
    the two result sets are parsed and combined, three figures are shown
    (CenterNet, ATSS, combined), and the index of the entry is yielded.
    Entries missing either results file, or whose image cannot be read, are
    skipped. The generator ends when the list is exhausted.

    Yields:
        int: index into ``results_paths`` / ``filenames_paths`` of the image
        that was just displayed.
    """
    for count, results_path in enumerate(results_paths):
        centernet_file = pjoin(centernet_path, results_path)
        atss_file = pjoin(atss_path, results_path)

        # Only images for which both teachers have results are usable.
        if not (pexists(centernet_file) and pexists(atss_file)):
            continue
        print(centernet_file)

        # Read Centernet results
        with open(centernet_file) as f:
            centernet_results = json.load(f)

        # The third and fourth CenterNet values are offset by the first two,
        # i.e. they are treated as width/height and turned into x2/y2.
        for result in centernet_results:
            result["bbox"][2] += result["bbox"][0]
            result["bbox"][3] += result["bbox"][1]

        # Read ATSS results
        # NOTE(review): ATSS boxes are used as stored, with no width/height
        # conversion -- confirm they are already [x1, y1, x2, y2].
        with open(atss_file) as f:
            atss_results = json.load(f)

        centernet_rows = parse_teacher_results(centernet_results)
        atss_rows = parse_teacher_results(atss_results)

        # Fold the teachers together; a teacher with no detections adds nothing.
        cluster_result = []
        for teacher_rows in (centernet_rows, atss_rows):
            if len(teacher_rows) == 0:
                continue
            if len(cluster_result) == 0:
                cluster_result = teacher_rows
            else:
                cluster_result = combine(np.array(cluster_result), np.array(teacher_rows))

        # Get image
        image_file = pjoin(data_dir, filenames_paths[count])
        image = cv2.imread(image_file)
        if image is None:
            # imread signals an unreadable or missing file by returning None.
            print("Could not read image: {}".format(image_file))
            continue
        image = image[:, :, ::-1]  # reverse the channel axis for display

        # Print results (parsed rows, which is the layout print_results reads)
        print_results(centernet_rows, "centernet", image, "green")
        print_results(atss_rows, "atss", image, "red")
        print_results(cluster_result, "cluster", image, "yellow")

        yield count
            

In [None]:
# Advance the generator at most `number_of_examples` times. Checking the limit
# before each step (rather than after) avoids rendering one extra example.
examples = cluster()
for _ in range(number_of_examples):
    # cluster() yields integer indices, so None can only mean "exhausted".
    if next(examples, None) is None:
        break
