In [1]:
from __future__ import absolute_import, division, print_function
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from metric_utils import *
import torchmetrics

## Predictions Format

The model accepts predictions as JSON files in the following format
(The notation in caps with a dollar sign (`$IMG_NAME`) means that such key will be replaced by the actual value of the property)
```json
    {
        "$IMG_NAME": {
            "labels": [$CLASS_IDX: int, ..., $CLASS_IDX: int],
            "boxes": [[$X1, $X2, $Y1, $Y2], ..., [$X1, $X2, $Y1, $Y2]],
            "scores": [$CONFIDENCE_SCORE: float, ..., $CONFIDENCE_SCORE: float]
        },

        ...
        
        "$IMG_NAME": {
            "labels": [$CLASS_IDX: int, ..., $CLASS_IDX: int],
            "boxes": [[$X1, $X2, $Y1, $Y2], ..., [$X1, $X2, $Y1, $Y2]],
            "scores": [$CONFIDENCE_SCORE: float, ..., $CONFIDENCE_SCORE: float]
        }
    }
```

## Ground Truth Format

Same as before, without the scores
```json
    {
        "$IMG_NAME": {
            "labels": [$CLASS_NAME: string, ..., $CLASS_NAME: string],
            "boxes": [[$X1, $X2, $Y1, $Y2], ..., [$X1, $X2, $Y1, $Y2]],
        },

        ...
        
        "$IMG_NAME": {
            "labels": [$CLASS_NAME: string, ..., $CLASS_NAME: string],
            "boxes": [[$X1, $X2, $Y1, $Y2], ..., [$X1, $X2, $Y1, $Y2]],
        }
    }
```

## Loading gt and pred annotations

In [2]:
# gt_file = 'gt_boxes.json'
# pred_file = 'rg-boxes@0.9cs.json.json'
gt_file = 'gt_less_5.json'
pred_file = 'pred_less_5.json'

with open(gt_file) as infile:
    gt_boxes = json.load(infile)

with open(pred_file) as infile:
    pred_boxes = json.load(infile)


class_name_to_id = {
    'sidelobe': 1,
    'source': 2,
    'galaxy': 3,
}

In [3]:
import torch
pred_list = []
gt_list = []
no_pred = []

for img in sorted(pred_boxes):
    for k, v in pred_boxes[img].items():
        if k == 'labels':
            pred_boxes[img]['labels'] = torch.tensor(list(map(lambda x: class_name_to_id[x], pred_boxes[img]['labels'])))
        else:
            pred_boxes[img][k] = torch.tensor(v)

    gt_boxes[img]['boxes'] = torch.tensor(gt_boxes[img]['boxes'])
    gt_boxes[img]['labels'] = torch.tensor(list(map(lambda x: class_name_to_id[x], gt_boxes[img]['labels'])))


    if not pred_boxes[img]:
        # Handle missing predictions
        continue
    pred_list.append(pred_boxes[img])
    gt_list.append(gt_boxes[img])

In [4]:
map50 = torchmetrics.MAP(class_metrics=True)
map50.update(pred_list, gt_list)
map50.compute()

  rank_zero_deprecation(


{'map': tensor([0.4054]),
 'map_50': tensor([0.7310]),
 'map_75': tensor([0.3741]),
 'map_small': tensor([0.4054]),
 'map_medium': tensor([-1.]),
 'map_large': tensor([-1.]),
 'mar_1': tensor([0.3429]),
 'mar_10': tensor([0.4441]),
 'mar_100': tensor([0.4441]),
 'mar_small': tensor([0.4441]),
 'mar_medium': tensor([-1.]),
 'mar_large': tensor([-1.]),
 'map_per_class': tensor([0.3172, 0.4936]),
 'mar_100_per_class': tensor([0.3200, 0.5682])}

In [5]:
pred_list

[{'labels': tensor([2, 2, 1, 2, 1, 2, 2, 2]),
  'boxes': tensor([[100.6596,   2.9818, 105.0881,   7.7063],
          [ 37.9099,  23.5687,  41.3657,  27.0133],
          [ 27.6616,  24.9992, 102.2815, 104.7608],
          [121.9637,  67.8532, 127.9469,  72.9755],
          [ 28.8867,  75.8943,  59.8381, 103.4442],
          [ 77.1975, 114.7743,  83.5611, 122.3947],
          [113.6003,  86.0253, 118.7156,  90.2984],
          [ 55.2425,  19.2636,  62.3436,  25.4450]]),
  'scores': tensor([0.9948, 0.9282, 0.9249, 0.9982, 0.9550, 0.9953, 0.9979, 0.9982])},
 {'labels': tensor([2, 2, 2]),
  'boxes': tensor([[ 9.6838e+01,  3.6944e+01,  1.0195e+02,  4.1980e+01],
          [ 2.7586e+01, -5.0275e-03,  4.2092e+01,  6.6433e+00],
          [ 6.2486e+01,  6.2937e+01,  6.7974e+01,  6.7900e+01]]),
  'scores': tensor([0.9779, 0.9953, 0.9992])},
 {'labels': tensor([2]),
  'boxes': tensor([[63.1644, 61.7706, 68.4250, 69.3829]]),
  'scores': tensor([0.9981])},
 {'labels': tensor([2]),
  'boxes': tensor([