Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve mAP performance #742

Merged
merged 28 commits into from
Jan 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2db054f
Simplify id generation
twsl Dec 27, 2021
34f673e
rework and speed up _find_best_gt_match
Dec 28, 2021
c66de8c
fix gpu test and move inputs to gpu
Dec 28, 2021
b926929
fix: boxes xywh format
OlofHarrysson Jan 3, 2022
ebff078
add: Refactor to avoid duplicate calculations
OlofHarrysson Jan 4, 2022
d260d81
precision,recall,scores on correct device (-20%)
twsl Jan 10, 2022
ee1d0a0
arguments to python lists
twsl Jan 10, 2022
4fd42ba
enumerate instead of range
twsl Jan 10, 2022
45806f7
compute on device
twsl Jan 10, 2022
ab91460
Remove exception
twsl Jan 10, 2022
a714acf
Replace prec score loop
twsl Jan 10, 2022
27783a0
Fix auc flattening
twsl Jan 10, 2022
d67808f
Merge branch 'PyTorchLightning:master' into fix/map-perf
twsl Jan 11, 2022
7274117
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2022
43a2261
Remove deprecated functions, and warnings - Text (#773)
ashutoshml Jan 18, 2022
d68aaaa
draft to run metric on cpu only
Jan 19, 2022
738369c
Merge branch 'master' into fix/map-perf
tkupek Jan 19, 2022
0ab73ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2022
48675de
move tensors to cpu on compute, need to be on GPU for multi GPU syncing
Jan 26, 2022
b3884c3
Remove performance test script
twsl Jan 27, 2022
38ed36a
Merge branch 'master' into fix/map-perf
twsl Jan 27, 2022
5bf856f
Remove unused imports
twsl Jan 27, 2022
f65ce31
changelog
SkafteNicki Jan 27, 2022
c154707
Merge branch 'master' into fix/map-perf
Borda Jan 27, 2022
789eaf5
Merge branch 'master' into fix/map-perf
SkafteNicki Jan 31, 2022
bf3f4f0
fix mypy
SkafteNicki Jan 31, 2022
5b2ff30
suggestions
SkafteNicki Jan 31, 2022
a04dbab
Update torchmetrics/detection/map.py
SkafteNicki Jan 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed Matthews correlation coefficient when the denominator is 0 ([#781](https://github.com/PyTorchLightning/metrics/pull/781))


- Improve mAP performance ([#742](https://github.com/PyTorchLightning/metrics/pull/742))


## [0.7.0] - 2022-01-17

### Added
Expand Down
13 changes: 12 additions & 1 deletion tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,24 @@ def test_empty_ground_truths():
_gpu_test_condition = not torch.cuda.is_available()


def _move_to_gpu(input):
for x in input:
for key in x.keys():
if torch.is_tensor(x[key]):
x[key] = x[key].to("cuda")
return input


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
def test_map_gpu():
"""Test predictions on single gpu."""
metric = MeanAveragePrecision()
metric = metric.to("cuda")
metric.update(_inputs.preds[0], _inputs.target[0])
preds = _inputs.preds[0]
targets = _inputs.target[0]

metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
metric.compute()


Expand Down
170 changes: 89 additions & 81 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import torch
from torch import IntTensor, Size, Tensor
from torch import IntTensor, Tensor

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8
Expand Down Expand Up @@ -156,7 +156,7 @@ class MeanAveragePrecision(Metric):

Args:
box_format:
Input format of given boxes. Supported formats are [xyxy’, ‘xywh’, ‘cxcywh].
Input format of given boxes. Supported formats are [`xyxy`, `xywh`, `cxcywh`].
iou_thresholds:
IoU thresholds for evaluation. If set to `None` it corresponds to the stepped range `[0.5,...,0.95]`
with step `0.05`. Else provide a list of floats.
Expand Down Expand Up @@ -222,6 +222,12 @@ class MeanAveragePrecision(Metric):
If ``class_metrics`` is not a boolean
"""

detection_boxes: List[Tensor]
detection_scores: List[Tensor]
detection_labels: List[Tensor]
groundtruth_boxes: List[Tensor]
groundtruth_labels: List[Tensor]

def __init__(
self,
box_format: str = "xyxy",
Expand Down Expand Up @@ -251,10 +257,10 @@ def __init__(
if box_format not in allowed_box_formats:
raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}")
self.box_format = box_format
self.iou_thresholds = Tensor(iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1))
self.rec_thresholds = Tensor(rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1))
self.max_detection_thresholds = IntTensor(max_detection_thresholds or [1, 10, 100])
self.max_detection_thresholds, _ = torch.sort(self.max_detection_thresholds)
self.iou_thresholds = iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1).tolist()
self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist()
max_det_thr, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100]))
self.max_detection_thresholds = max_det_thr.tolist()
self.bbox_area_ranges = {
"all": (0 ** 2, int(1e5 ** 2)),
"small": (0 ** 2, 32 ** 2),
Expand Down Expand Up @@ -318,20 +324,16 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
_input_validator(preds, target)

for item in preds:
self.detection_boxes.append(
_fix_empty_tensors(box_convert(item["boxes"], in_fmt=self.box_format, out_fmt="xyxy"))
if item["boxes"].size() == Size([1, 4])
else _fix_empty_tensors(item["boxes"])
)
boxes = _fix_empty_tensors(item["boxes"])
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
self.detection_boxes.append(boxes)
self.detection_labels.append(item["labels"])
self.detection_scores.append(item["scores"])

for item in target:
self.groundtruth_boxes.append(
_fix_empty_tensors(box_convert(item["boxes"], in_fmt=self.box_format, out_fmt="xyxy"))
if item["boxes"].size() == Size([1, 4])
else _fix_empty_tensors(item["boxes"])
)
boxes = _fix_empty_tensors(item["boxes"])
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
self.groundtruth_boxes.append(boxes)
self.groundtruth_labels.append(item["labels"])

def _get_classes(self) -> List:
Expand Down Expand Up @@ -423,22 +425,22 @@ def _evaluate_image(
nb_iou_thrs = len(self.iou_thresholds)
nb_gt = len(gt)
nb_det = len(det)
gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device)
det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool)
det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
gt_ignore = ignore_area_sorted
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

if torch.numel(ious) > 0:
for idx_iou, t in enumerate(self.iou_thresholds):
for idx_det in range(nb_det):
m = MeanAveragePrecision._find_best_gt_match(
t, nb_gt, gt_matches, idx_iou, gt_ignore, ious, idx_det
)
for idx_det, _ in enumerate(det):
m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det)
if m != -1:
det_ignore[idx_iou, idx_det] = gt_ignore[m]
det_matches[idx_iou, idx_det] = True
gt_matches[idx_iou, m] = True
det_matches[idx_iou, idx_det] = 1
gt_matches[idx_iou, m] = 1

# set unmatched detections outside of area range to ignore
det_areas = box_area(det).to(self.device)
det_areas = box_area(det)
det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
ar = det_ignore_area.reshape((1, nb_det))
det_ignore = torch.logical_or(
Expand All @@ -454,7 +456,7 @@ def _evaluate_image(

@staticmethod
def _find_best_gt_match(
thr: int, nb_gt: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
thr: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
) -> int:
"""Return id of best ground truth match with current detection.

Expand All @@ -474,23 +476,14 @@ def _find_best_gt_match(
idx_det:
Id of current detection.
"""
# information about best match so far (m=-1 -> unmatched)
iou = min([thr, 1 - 1e-10])
match_id = -1
for idx_gt in range(nb_gt):
# if this gt already matched, and not a crowd, continue
if gt_matches[idx_iou, idx_gt]:
continue
# if dt matched to reg gt, and on ignore gt, stop
if match_id > -1 and not gt_ignore[match_id] and gt_ignore[idx_gt]:
break
# continue to next gt unless better match made
if ious[idx_det, idx_gt] < iou:
continue
# if match successful and best so far, store appropriately
iou = ious[idx_det, idx_gt]
match_id = idx_gt
return match_id
previously_matched = gt_matches[idx_iou]
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
# Remove previously matched or ignored gts
remove_mask = previously_matched | gt_ignore
gt_ious = ious[idx_det] * ~remove_mask
match_idx = gt_ious.argmax().item()
if gt_ious[match_idx] > thr:
return match_idx
return -1

def _summarize(
self,
Expand Down Expand Up @@ -521,28 +514,30 @@ def _summarize(
prec = results["precision"]
# IoU
if iou_threshold is not None:
thr = torch.where(iou_threshold == self.iou_thresholds)[0]
prec = prec[thr]
prec = prec[:, :, :, area_inds, mdet_inds]
thr = self.iou_thresholds.index(iou_threshold)
prec = prec[thr, :, :, area_inds, mdet_inds]
else:
prec = prec[:, :, :, area_inds, mdet_inds]
else:
# dimension of recall: [TxKxAxM]
prec = results["recall"]
if iou_threshold is not None:
thr = torch.where(iou_threshold == self.iou_thresholds)[0]
prec = prec[thr]
prec = prec[:, :, area_inds, mdet_inds]
thr = self.iou_thresholds.index(iou_threshold)
prec = prec[thr, :, :, area_inds, mdet_inds]
else:
prec = prec[:, :, area_inds, mdet_inds]

mean_prec = Tensor([-1]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1])
return mean_prec

def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetricResults]:
"""Calculate the precision, recall and scores for all supplied label classes to calculate mAP/mAR.
def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]:
"""Calculate the precision and recall for all supplied classes to calculate mAP/mAR.

Args:
class_ids:
List of label class Ids.
"""
img_ids = torch.arange(len(self.groundtruth_boxes), dtype=torch.int).tolist()
img_ids = range(len(self.groundtruth_boxes))
max_detections = self.max_detection_thresholds[-1]
area_ranges = self.bbox_area_ranges.values()

Expand All @@ -568,12 +563,11 @@ def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetric
scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))

# move tensors if necessary
self.max_detection_thresholds = self.max_detection_thresholds.to(self.device)
self.rec_thresholds = self.rec_thresholds.to(self.device)
rec_thresholds_tensor = Tensor(self.rec_thresholds)

# retrieve E at each category, area range, and max number of detections
for idx_cls in range(nb_classes):
for idx_bbox_area in range(nb_bbox_areas):
for idx_cls, _ in enumerate(class_ids):
for idx_bbox_area, _ in enumerate(self.bbox_area_ranges):
for idx_max_det_thrs, max_det in enumerate(self.max_detection_thresholds):
recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores(
recall,
Expand All @@ -583,19 +577,24 @@ def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetric
idx_bbox_area=idx_bbox_area,
idx_max_det_thrs=idx_max_det_thrs,
eval_imgs=eval_imgs,
rec_thresholds=self.rec_thresholds,
rec_thresholds=rec_thresholds_tensor,
max_det=max_det,
nb_imgs=nb_imgs,
nb_bbox_areas=nb_bbox_areas,
)

results = {
"dimensions": [nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs],
"precision": precision,
"recall": recall,
"scores": scores,
}
return precision, recall

def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]:
"""Summarizes the precision and recall values to calculate mAP/mAR.

Args:
precisions:
Precision values for different thresholds
recalls:
Recall values for different thresholds
"""
results = dict(precision=precisions, recall=recalls)
map_metrics = MAPMetricResults()
map_metrics.map = self._summarize(results, True)
last_max_det_thr = self.max_detection_thresholds[-1]
Expand All @@ -612,7 +611,7 @@ def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetric
mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_thr)
mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_thr)

return results, map_metrics, mar_metrics
return map_metrics, mar_metrics

@staticmethod
def __calculate_recall_precision_scores(
Expand Down Expand Up @@ -664,19 +663,17 @@ def __calculate_recall_precision_scores(
recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0

# Remove zigzags for AUC
for i in range(nd - 1, 0, -1):
if pr[i] > pr[i - 1]:
pr[i - 1] = pr[i]
diff_zero = torch.zeros((1,))
diff = torch.ones((1,))
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
while not torch.all(diff == 0):
diff = torch.clamp(torch.cat((pr[1:] - pr[:-1], diff_zero), 0), min=0)
pr += diff

inds = torch.searchsorted(rc, rec_thresholds, right=False)
# TODO: optimize
try:
for ri, pi in enumerate(inds): # range(min(len(inds), len(pr))):
# pi = inds[ri]
prec[ri] = pr[pi]
score[ri] = det_scores_sorted[pi]
except Exception:
pass
num_inds = inds.argmax() if inds.max() >= nd else nb_rec_thrs
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
score[:num_inds] = det_scores_sorted[inds]
precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec
scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score

Expand Down Expand Up @@ -709,18 +706,29 @@ def compute(self) -> dict:
- map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
- mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
"""
overall, map, mar = self._calculate(self._get_classes())

map_per_class_values: Tensor = Tensor([-1])
mar_max_dets_per_class_values: Tensor = Tensor([-1])
# move everything to CPU, as we are faster here
self.detection_boxes = [box.cpu() for box in self.detection_boxes]
self.detection_labels = [label.cpu() for label in self.detection_labels]
self.detection_scores = [score.cpu() for score in self.detection_scores]
self.groundtruth_boxes = [box.cpu() for box in self.groundtruth_boxes]
self.groundtruth_labels = [label.cpu() for label in self.groundtruth_labels]

classes = self._get_classes()
precisions, recalls = self._calculate(classes)
map, mar = self._summarize_results(precisions, recalls)

# if class mode is enabled, evaluate metrics per class
map_per_class_values: Tensor = Tensor([-1])
mar_max_dets_per_class_values: Tensor = Tensor([-1])
if self.class_metrics:
map_per_class_list = []
mar_max_dets_per_class_list = []

for class_id in self._get_classes():
_, cls_map, cls_mar = self._calculate([class_id])
for class_idx, _ in enumerate(classes):
cls_precisions = precisions[:, :, class_idx].unsqueeze(dim=2)
cls_recalls = recalls[:, class_idx].unsqueeze(dim=1)
cls_map, cls_mar = self._summarize_results(cls_precisions, cls_recalls)
map_per_class_list.append(cls_map.map)
mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"])

Expand Down