
Commit

Add Average Recall to OD metrics (#460)
ntlind committed Mar 1, 2024
1 parent 9744a32 commit c08a923
Showing 15 changed files with 639 additions and 107 deletions.
36 changes: 36 additions & 0 deletions api/tests/functional-tests/backend/metrics/test_detection.py
@@ -95,6 +95,42 @@ def test__compute_detection_metrics(
},
# mAP METRICS AVERAGED OVER IOUS
{"ious": iou_thresholds, "value": 0.637},
# AR METRICS
{
"ious": iou_thresholds,
"value": 0.45,
"label": {"key": "class", "value": "2"},
},
{
"ious": iou_thresholds,
"value": -1,
"label": {"key": "class", "value": "3"},
},
{
"ious": iou_thresholds,
"value": 0.58,
"label": {"key": "class", "value": "49"},
},
{
"ious": iou_thresholds,
"value": 0.78,
"label": {"key": "class", "value": "0"},
},
{
"ious": iou_thresholds,
"value": 0.8,
"label": {"key": "class", "value": "1"},
},
{
"ious": iou_thresholds,
"value": 0.65,
"label": {"key": "class", "value": "4"},
},
# mAR METRICS
{
"ious": iou_thresholds,
"value": 0.652,
},
]

assert len(metrics) == len(expected)
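For context on the new expected values above: the trailing mAR entry looks like the mean of the per-label AR values, with the -1 entry skipped (it appears to act as a sentinel for a label whose recall is undefined). A quick arithmetic check of that reading, written for illustration only and not part of the commit:

# Sanity check (illustrative, not from the commit): mAR as the mean of the
# per-label AR values, excluding the -1 sentinel.
per_label_ar = [0.45, -1, 0.58, 0.78, 0.8, 0.65]
defined = [v for v in per_label_ar if v != -1]
mean_ar = sum(defined) / len(defined)
assert round(mean_ar, 3) == 0.652  # (0.45 + 0.58 + 0.78 + 0.8 + 0.65) / 5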
17 changes: 16 additions & 1 deletion api/tests/functional-tests/crud/test_create_delete.py
Expand Up @@ -1100,7 +1100,9 @@ def method_to_test(

assert set([m.type for m in metrics]) == {
"AP",
"AR",
"APAveragedOverIOUs",
"mAR",
"mAP",
"mAPAveragedOverIOUs",
}
@@ -1110,7 +1112,16 @@
) == {0.2}

# should be five labels (since thats how many are in groundtruth set)
assert len(set(m.label_id for m in metrics if m.label_id is not None)) == 5
assert (
len(
set(
m.label_id
for m in metrics
if m.label_id is not None and m.type != "AR"
)
)
== 5
)

# test getting metrics from evaluation settings id
pydantic_metrics = crud.get_evaluations(
@@ -1145,7 +1156,9 @@ def method_to_test(
for m in metrics_pydantic:
assert m.type in {
"AP",
"AR",
"APAveragedOverIOUs",
"mAR",
"mAP",
"mAPAveragedOverIOUs",
}
Expand All @@ -1166,7 +1179,9 @@ def method_to_test(
for m in metrics_pydantic:
assert m.type in {
"AP",
"AR",
"APAveragedOverIOUs",
"mAR",
"mAP",
"mAPAveragedOverIOUs",
}
81 changes: 59 additions & 22 deletions api/tests/unit-tests/backend/metrics/test_detection.py
@@ -3,8 +3,8 @@
from valor_api import schemas
from valor_api.backend.metrics.detection import (
RankedPair,
_ap,
_calculate_101_pt_interp,
_calculate_ap_and_ar,
_compute_mean_detection_metrics_from_aps,
)

@@ -23,7 +23,7 @@ def test__compute_mean_detection_metrics_from_aps():
assert _compute_mean_detection_metrics_from_aps([]) == list()


def test__ap():
def test__calculate_ap_and_ar():

pairs = {
"0": [
@@ -57,77 +57,114 @@ def test__ap():
iou_thresholds = [0.5, 0.75, 0.9]

# Calculated by hand
reference_metrics = [
reference_ap_metrics = [
schemas.APMetric(
iou=0.5, value=1.0, label=schemas.Label(key="name", value="car")
iou=0.5,
value=1.0,
label=schemas.Label(key="name", value="car", score=None),
),
schemas.APMetric(
iou=0.5, value=0.0, label=schemas.Label(key="name", value="dog")
iou=0.75,
value=0.442244224422442,
label=schemas.Label(key="name", value="car", score=None),
),
schemas.APMetric(
iou=0.5,
value=0.25,
label=schemas.Label(key="name", value="person"),
iou=0.9,
value=0.11221122112211224,
label=schemas.Label(key="name", value="car", score=None),
),
schemas.APMetric(
iou=0.75, value=0.44, label=schemas.Label(key="name", value="car")
iou=0.5,
value=0.0,
label=schemas.Label(key="name", value="dog", score=None),
),
schemas.APMetric(
iou=0.75, value=0.0, label=schemas.Label(key="name", value="dog")
iou=0.75,
value=0.0,
label=schemas.Label(key="name", value="dog", score=None),
),
schemas.APMetric(
iou=0.75,
value=0.25,
label=schemas.Label(key="name", value="person"),
iou=0.9,
value=0.0,
label=schemas.Label(key="name", value="dog", score=None),
),
schemas.APMetric(
iou=0.9, value=0.11, label=schemas.Label(key="name", value="car")
iou=0.5,
value=0.25742574257425743,
label=schemas.Label(key="name", value="person", score=None),
),
schemas.APMetric(
iou=0.9, value=0.0, label=schemas.Label(key="name", value="dog")
iou=0.75,
value=0.25742574257425743,
label=schemas.Label(key="name", value="person", score=None),
),
schemas.APMetric(
iou=0.9,
value=0.25742574257425743,
label=schemas.Label(key="name", value="person", score=None),
),
]

reference_ar_metrics = [
schemas.ARMetric(
ious=[0.5, 0.75, 0.9],
value=0.6666666666666666, # average of [{'iou_threshold':.5, 'recall': 1}, {'iou_threshold':.75, 'recall':.66}, {'iou_threshold':.9, 'recall':.33}]
label=schemas.Label(key="name", value="car", score=None),
),
schemas.ARMetric(
ious=[0.5, 0.75, 0.9],
value=0.0,
label=schemas.Label(key="name", value="dog", score=None),
),
schemas.ARMetric(
ious=[0.5, 0.75, 0.9],
value=0.25,
label=schemas.Label(key="name", value="person"),
label=schemas.Label(key="name", value="person", score=None),
),
]

grouper_ids_associated_with_gts = set(["0", "1", "2"])

detection_metrics = _ap(
ap_metrics, ar_metrics = _calculate_ap_and_ar(
sorted_ranked_pairs=pairs,
number_of_groundtruths_per_grouper=number_of_groundtruths_per_grouper,
grouper_mappings=grouper_mappings,
iou_thresholds=iou_thresholds,
grouper_ids_associated_with_gts=grouper_ids_associated_with_gts,
recall_score_threshold=0.0,
)

assert len(reference_metrics) == len(detection_metrics)
for pd, gt in zip(detection_metrics, reference_metrics):
assert len(ap_metrics) == len(reference_ap_metrics)
assert len(ar_metrics) == len(reference_ar_metrics)
for pd, gt in zip(ap_metrics, reference_ap_metrics):
assert pd.iou == gt.iou
assert truncate_float(pd.value) == truncate_float(gt.value)
assert pd.label == gt.label
for pd, gt in zip(ar_metrics, reference_ar_metrics):
assert pd.ious == gt.ious
assert truncate_float(pd.value) == truncate_float(gt.value)
assert pd.label == gt.label

# Test iou threshold outside 0 < t <= 1
for illegal_thresh in [-1.1, -0.1, 0, 1.1]:
with pytest.raises(ValueError):
_ap(
_calculate_ap_and_ar(
sorted_ranked_pairs=pairs,
number_of_groundtruths_per_grouper=number_of_groundtruths_per_grouper,
grouper_mappings=grouper_mappings,
iou_thresholds=iou_thresholds + [0],
grouper_ids_associated_with_gts=grouper_ids_associated_with_gts,
recall_score_threshold=0.0,
)

# Test score threshold outside 0 <= t <= 1
for illegal_thresh in [-1.1, -0.1, 1.1]:
with pytest.raises(ValueError):
_ap(
_calculate_ap_and_ar(
sorted_ranked_pairs=pairs,
number_of_groundtruths_per_grouper=number_of_groundtruths_per_grouper,
grouper_mappings=grouper_mappings,
iou_thresholds=iou_thresholds,
grouper_ids_associated_with_gts=grouper_ids_associated_with_gts,
score_threshold=illegal_thresh,
recall_score_threshold=illegal_thresh,
)
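The inline comment on the expected "car" AR value spells out how the new metric is averaged: recall is computed at each IoU threshold and then averaged across the thresholds. A minimal sketch of that averaging step, assuming a hypothetical per-IoU recall mapping rather than the commit's actual implementation:

# Illustrative only: `recalls_by_iou` is a hypothetical input, not an API from
# this commit. AR for a label averages recall across the evaluated IoU thresholds.
recalls_by_iou = {0.5: 1.0, 0.75: 2 / 3, 0.9: 1 / 3}  # per-IoU recall for "car"
ar_car = sum(recalls_by_iou.values()) / len(recalls_by_iou)
assert abs(ar_car - 0.6666666666666666) < 1e-9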
2 changes: 2 additions & 0 deletions api/valor_api/backend/metrics/classification.py
@@ -223,6 +223,8 @@ def _compute_confusion_matrix_at_grouper_key(
The filter to be used to query predictions.
groundtruth_filter : schemas.Filter
The filter to be used to query groundtruths.
grouper_key: str
The key of the grouper used to calculate the confusion matrix.
grouper_mappings: dict[str, dict[str | int, any]]
A dictionary of mappings that connect groupers to their related labels.
(11 additional changed files not shown.)
