Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7143321
Add test for case when there are preds but no GT for a class
dhananjaisharma10 Nov 27, 2022
1621983
Make Precision 0 when no GT but preds are made; -1 if no preds
dhananjaisharma10 Nov 27, 2022
8a12f5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2022
9929553
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Dec 19, 2022
96cee42
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Dec 23, 2022
ad753bd
Merge branch 'master' into issue_1184/fix_mean_ap
justusschock Jan 9, 2023
105ebd8
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 6, 2023
d6f538e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
5c1361b
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 7, 2023
ebbd78e
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 20, 2023
062f0e8
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 21, 2023
79fd401
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 22, 2023
79dbe19
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Feb 27, 2023
6b460c4
Merge branch 'master' into issue_1184/fix_mean_ap
dhananjaisharma10 Mar 2, 2023
401b1de
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Mar 4, 2023
25fb35f
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Mar 31, 2023
b9cfa71
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Apr 17, 2023
cbc43b3
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Apr 17, 2023
48f19f8
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Apr 17, 2023
42b3cab
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Apr 28, 2023
99724c3
Merge branch 'master' into issue_1184/fix_mean_ap
Borda May 2, 2023
39ae08c
Merge branch 'master' into issue_1184/fix_mean_ap
Borda May 9, 2023
37be226
Merge branch 'master' into issue_1184/fix_mean_ap
Borda May 15, 2023
8abd2d4
Merge branch 'master' into issue_1184/fix_mean_ap
Borda May 17, 2023
0b6766e
Merge branch 'master' into issue_1184/fix_mean_ap
Borda Jun 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,21 +794,25 @@ def __calculate_recall_precision_scores(
return recall, precision, scores

det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])

# different sorting method generates slightly different results.
# mergesort is used to be consistent as Matlab implementation.
# Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0)
dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype
# Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort
inds = torch.argsort(det_scores.to(dtype), descending=True)
det_scores_sorted = det_scores[inds]

det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])

npig = torch.count_nonzero(gt_ignore == False) # noqa: E712
if npig == 0:
# If there are any predictions, make Precision 0; otherwise, -1.
npreds = torch.count_nonzero(det_ignore == False) # noqa: E712
if npreds != 0:
precision[:, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = 0.0
return recall, precision, scores

det_scores_sorted = det_scores[inds]
det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
tps = torch.logical_and(det_matches, torch.logical_not(det_ignore))
fps = torch.logical_and(torch.logical_not(det_matches), torch.logical_not(det_ignore))

Expand Down
46 changes: 46 additions & 0 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,52 @@ def test_missing_gt():
assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions."


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_class_metrics_with_missing_gt():
    """Verify per-class MAP when some predicted classes have no ground truth at all.

    Four detections are made, one per class, but targets exist only for classes 0 and 1.
    Overall MAP must therefore be below 1; per-class MAP should be 1 for the classes that
    have targets and 0 for the classes whose only predictions are false positives.
    """
    # Example source: Issue https://github.com/Lightning-AI/metrics/issues/1184
    # One detection per class; the last two boxes match no target (false positives).
    detection_boxes = [
        [0, 0, 20, 20],
        [30, 30, 50, 50],
        [70, 70, 90, 90],  # FP
        [100, 100, 120, 120],  # FP
    ]
    preds = [
        {
            "boxes": torch.Tensor(detection_boxes),
            "scores": torch.Tensor([0.6] * 4),
            "labels": torch.IntTensor(list(range(4))),
        }
    ]

    # Ground truth covers only classes 0 and 1.
    targets = [
        {
            "boxes": torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]),
            "labels": torch.IntTensor([0, 1]),
        }
    ]

    metric = MeanAveragePrecision(class_metrics=True)
    metric.update(preds, targets)
    result = metric.compute()

    assert result["map"] < 1, "MAP cannot be 1, as for class 2 and 3, there are some predictions, but not targets."

    result_map_per_class = result.get("map_per_class", None)
    assert result_map_per_class is not None, "map_per_class must be present in results."
    assert isinstance(result_map_per_class, Tensor), "map_per_class must be a tensor"
    assert len(result_map_per_class) == 4, "map_per_class must be of length 4, same as the number of classes."
    assert result_map_per_class[0].item() == 1.0, "map for class 0 must be 1."
    assert result_map_per_class[1].item() == 1.0, "map for class 1 must be 1."
    assert result_map_per_class[2].item() == 0.0, "map for class 2 must be 0."
    assert result_map_per_class[3].item() == 0.0, "map for class 3 must be 0."


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_segm_iou_empty_gt_mask():
"""Test empty ground truths."""
Expand Down