In [20]:
from __future__ import absolute_import, division, print_function

import json
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from metric_utils import *

# Imported after the wildcard import above so these module names cannot be
# shadowed by anything `metric_utils` exports.
import torch
import torchmetrics

## Predictions Format

The evaluation below expects predictions as a JSON file in the following format.
(Names in capitals prefixed with a dollar sign, such as `$IMG_NAME`, are placeholders that are replaced by the actual value of the property.)
```json
    {
        "$IMG_NAME": {
            "labels": [$CLASS_IDX: int, ..., $CLASS_IDX: int],
            "boxes": [[$X1, $Y1, $X2, $Y2], ..., [$X1, $Y1, $X2, $Y2]],
            "scores": [$CONFIDENCE_SCORE: float, ..., $CONFIDENCE_SCORE: float]
        },

        ...

        "$IMG_NAME": {
            "labels": [$CLASS_IDX: int, ..., $CLASS_IDX: int],
            "boxes": [[$X1, $Y1, $X2, $Y2], ..., [$X1, $Y1, $X2, $Y2]],
            "scores": [$CONFIDENCE_SCORE: float, ..., $CONFIDENCE_SCORE: float]
        }
    }
```

## Ground Truth Format

Compliant with the [COCO format](https://cocodataset.org/#format-data)
```json
{
    "images": [{
        "id": int, 
        "width": int, 
        "height": int, 
        "file_name": str
    }],
    "annotations": [{
        "id": int, 
        "image_id": int, 
        "category_id": int, 
        "segmentation": RLE or [polygon], 
        "area": float, 
        "bbox": [x,y,width,height], 
        "iscrowd": 0 or 1
    }],

    "categories": [{
        "id": int, 
        "name": str
    }]
}
```

## Visualization setup

In [21]:
# Global seaborn appearance for every figure in this notebook: plain white
# background, with the 'poster' plotting context.
sns.set_style('white')
sns.set_context('poster')

# Twenty hex colours, one dark/light pair per row
# (the same values as matplotlib's `tab20` colormap — TODO confirm).
COLORS = [
    '#1f77b4', '#aec7e8',
    '#ff7f0e', '#ffbb78',
    '#2ca02c', '#98df8a',
    '#d62728', '#ff9896',
    '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94',
    '#e377c2', '#f7b6d2',
    '#7f7f7f', '#c7c7c7',
    '#bcbd22', '#dbdb8d',
    '#17becf', '#9edae5',
]

## Loading gt and pred annotations

In [22]:
# Load the COCO-style ground truth and the predictions to evaluate.
with open('test.json') as infile:
    annotations = json.load(infile)

with open('rg-boxes.json') as infile:
    pred_boxes = json.load(infile)

# Image id -> file name up to its first '.'; these names are the keys of
# `gt_boxes` below.
# NOTE(review): split('.')[0] also cuts names that contain additional dots —
# confirm that file names never do.
id_to_filename = {
    img['id']: img['file_name'].split('.')[0] for img in annotations['images']
}

# Category id -> human-readable class name.
class_id_to_name = {cl['id']: cl['name'] for cl in annotations['categories']}

# Per-image ground truth: {'boxes': [[x1, y1, x2, y2], ...], 'labels': [category_id, ...]}.
# Images without any annotation get no entry here.
gt_boxes = {}
for ann in annotations['annotations']:
    img_name = id_to_filename[ann['image_id']]

    # Image width/height are deliberately not looked up: `annotations['images']`
    # is a list, so indexing it with an image id is only valid when ids happen
    # to equal list positions, and the values were not needed anyway.

    # COCO stores [x, y, width, height]; convert to corner form [x1, y1, x2, y2].
    x, y, width, height = ann['bbox']

    entry = gt_boxes.setdefault(img_name, {'boxes': [], 'labels': []})
    entry['boxes'].append([x, y, x + width, y + height])
    entry['labels'].append(ann['category_id'])

In [23]:
def _tensorize_in_place(entry):
    """Convert every value of a boxes/labels/scores dict to a tensor, in place.

    `torch.as_tensor` hands back existing tensors unchanged, so re-running
    this cell on already converted entries is safe.

    Returns the same dict for convenience.
    """
    for key, value in entry.items():
        tensor = torch.as_tensor(value)
        if key == 'boxes' and tensor.numel() == 0:
            # An empty list converts to shape (0,); keep the (N, 4) box layout.
            tensor = tensor.reshape(0, 4)
        entry[key] = tensor
    return entry


pred_list = []
gt_list = []
no_pred = []  # names of images whose prediction entry is empty (skipped below)

# Sorted iteration keeps the order of both lists deterministic; the two lists
# stay index-aligned (same image at the same position).
for img in sorted(pred_boxes):
    pred_entry = _tensorize_in_place(pred_boxes[img])

    if img in gt_boxes:
        gt_entry = _tensorize_in_place(gt_boxes[img])
    else:
        # Predicted image without any ground-truth annotation: evaluate it
        # against an empty target instead of failing with a KeyError.
        gt_entry = {
            'boxes': torch.zeros((0, 4)),
            'labels': torch.zeros(0, dtype=torch.int64),
        }

    if not pred_entry:
        # Handle missing predictions: remember the image and leave it out.
        # NOTE(review): skipped images add neither predictions nor ground truth
        # to the metric, and ground-truth images absent from `pred_boxes` are
        # never visited, so objects missed on them are not counted — confirm
        # this is intended.
        no_pred.append(img)
        continue

    pred_list.append(pred_entry)
    gt_list.append(gt_entry)

In [24]:
# Detection metric over all collected images; `class_metrics=True` presumably
# requests per-class values in addition to the overall ones — TODO confirm.
# NOTE(review): the name `map50` suggests mAP at IoU 0.5 only; verify which
# IoU thresholds `torchmetrics.MAP` actually reports.
map50 = torchmetrics.MAP(class_metrics=True)

# Both lists are index-aligned: one prediction dict and one target dict per image.
map50.update(preds=pred_list, target=gt_list)

# Last expression of the cell, so the computed result is displayed.
map50.compute()