feat(evaluator): support logging per class mAP during evaluation (#1026)
* feat(evaluator): support logging per class mAP during evaluation

* make linter happy
FateScript committed Dec 24, 2021
1 parent 12d0df1 commit d669fd3
Showing 2 changed files with 44 additions and 4 deletions.
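A minimal sketch of how the new flag can be enabled from user code; "yolox-s" and batch_size=8 are hypothetical example values, and the get_evaluator keywords follow the positional call visible in tools/eval.py below:

# Hypothetical example values; mirrors the tools/eval.py change in this commit.
from yolox.exp import get_exp

exp = get_exp(exp_file=None, exp_name="yolox-s")
evaluator = exp.get_evaluator(
    batch_size=8, is_distributed=False, testdev=False, legacy=False
)
evaluator.per_class_mAP = True  # opt in to the per-class mAP table

Equivalently, COCOEvaluator can now be constructed directly with per_class_mAP=True via its new keyword argument.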
1 change: 1 addition & 0 deletions tools/eval.py
@@ -143,6 +143,7 @@ def main(exp, args, num_gpu):
logger.info("Model Structure:\n{}".format(str(model)))

evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test, args.legacy)
evaluator.per_class_mAP = True

torch.cuda.set_device(rank)
model.cuda(rank)
47 changes: 43 additions & 4 deletions yolox/evaluators/coco_evaluator.py
@@ -9,10 +9,14 @@
 import tempfile
 import time
 from loguru import logger
+from tabulate import tabulate
 from tqdm import tqdm

+import numpy as np
+
 import torch

+from yolox.data.datasets import COCO_CLASSES
 from yolox.utils import (
     gather,
     is_main_process,
@@ -23,30 +27,63 @@
 )


+def per_class_mAP_table(coco_eval, class_names=COCO_CLASSES, headers=["class", "AP"], columns=6):
+    per_class_mAP = {}
+    precisions = coco_eval.eval["precision"]
+    # precision has dims (iou, recall, cls, area range, max dets)
+    assert len(class_names) == precisions.shape[2]
+
+    for idx, name in enumerate(class_names):
+        # area range index 0: all area ranges
+        # max dets index -1: typically 100 per image
+        precision = precisions[:, :, idx, 0, -1]
+        precision = precision[precision > -1]
+        ap = np.mean(precision) if precision.size else float("nan")
+        per_class_mAP[name] = float(ap * 100)
+
+    num_cols = min(columns, len(per_class_mAP) * len(headers))
+    result_pair = [x for pair in per_class_mAP.items() for x in pair]
+    row_pair = itertools.zip_longest(*[result_pair[i::num_cols] for i in range(num_cols)])
+    table_headers = headers * (num_cols // len(headers))
+    table = tabulate(
+        row_pair, tablefmt="pipe", floatfmt=".3f", headers=table_headers, numalign="left",
+    )
+    return table
+
+
 class COCOEvaluator:
     """
     COCO AP Evaluation class. All the data in the val2017 dataset are processed
     and evaluated by COCO API.
     """

     def __init__(
-        self, dataloader, img_size, confthre, nmsthre, num_classes, testdev=False
+        self,
+        dataloader,
+        img_size: int,
+        confthre: float,
+        nmsthre: float,
+        num_classes: int,
+        testdev: bool = False,
+        per_class_mAP: bool = False,
     ):
         """
         Args:
             dataloader (Dataloader): evaluate dataloader.
-            img_size (int): image size after preprocess. images are resized
+            img_size: image size after preprocess. images are resized
                 to squares whose shape is (img_size, img_size).
-            confthre (float): confidence threshold ranging from 0 to 1, which
+            confthre: confidence threshold ranging from 0 to 1, which
                 is defined in the config file.
-            nmsthre (float): IoU threshold of non-max suppression ranging from 0 to 1.
+            nmsthre: IoU threshold of non-max suppression ranging from 0 to 1.
+            per_class_mAP: Show per-class mAP during evaluation or not. Defaults to False.
         """
         self.dataloader = dataloader
         self.img_size = img_size
         self.confthre = confthre
         self.nmsthre = nmsthre
         self.num_classes = num_classes
         self.testdev = testdev
+        self.per_class_mAP = per_class_mAP

     def evaluate(
         self,
@@ -216,6 +253,8 @@ def evaluate_prediction(self, data_dict, statistics):
             with contextlib.redirect_stdout(redirect_string):
                 cocoEval.summarize()
             info += redirect_string.getvalue()
+            if self.per_class_mAP:
+                info += "per class mAP:\n" + per_class_mAP_table(cocoEval)
             return cocoEval.stats[0], cocoEval.stats[1], info
         else:
             return 0, 0, info
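
For illustration, a self-contained sketch of what per_class_mAP_table consumes and produces; the three class names and the random precision tensor are invented stand-ins for a real pycocotools COCOeval instance:

# Toy stand-in for COCOeval; the class names and numbers are hypothetical.
from types import SimpleNamespace

import numpy as np

from yolox.evaluators.coco_evaluator import per_class_mAP_table

rng = np.random.default_rng(0)
fake_eval = SimpleNamespace(
    # COCOeval stores precision with dims (iou, recall, cls, area range, max dets):
    # 10 IoU thresholds, 101 recall points, 3 classes, 4 area ranges, 3 max-dets settings.
    eval={"precision": rng.random((10, 101, 3, 4, 3))}
)
print(per_class_mAP_table(fake_eval, class_names=("person", "car", "dog")))

With the default headers and column budget, the output is a pipe-format table that repeats the "class" / "AP" column pair up to three times per row (six columns total). Real COCOeval precision arrays also contain -1 entries where no precision is defined, which the helper filters out before averaging.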
