From c90e1af193ef9da2edb40e16d2c100036ff1f357 Mon Sep 17 00:00:00 2001 From: Nick L Date: Tue, 11 Jun 2024 10:09:46 -0600 Subject: [PATCH] Simplify `PrecisionRecallCurve` and add `DetailedPrecisionRecallCurve` for advanced debugging (#584) --- .../backend/core/test_evaluation.py | 2 +- .../backend/metrics/test_classification.py | 207 +- .../backend/metrics/test_detection.py | 1666 ++++++++++++++--- .../crud/test_create_delete.py | 8 +- .../crud/test_evaluation_crud.py | 6 +- .../backend/metrics/test_detection.py | 12 + .../unit-tests/schemas/test_evaluation.py | 28 +- api/tests/unit-tests/test_main.py | 10 +- .../backend/metrics/classification.py | 238 ++- api/valor_api/backend/metrics/detection.py | 1315 ++++++++++--- api/valor_api/backend/metrics/metric_utils.py | 7 + api/valor_api/backend/metrics/segmentation.py | 4 +- api/valor_api/schemas/__init__.py | 2 + api/valor_api/schemas/evaluation.py | 21 +- api/valor_api/schemas/metrics.py | 90 +- client/valor/coretypes.py | 8 +- client/valor/schemas/evaluation.py | 25 +- docs/metrics.md | 114 +- docs/technical_concepts.md | 6 +- .../client/datatype/test_data_generation.py | 1 + .../client/metrics/test_classification.py | 138 +- .../client/metrics/test_detection.py | 589 +++++- integration_tests/conftest.py | 12 + ts-client/src/ValorClient.ts | 19 +- ts-client/tests/ValorClient.test.ts | 2 - 25 files changed, 3750 insertions(+), 780 deletions(-) diff --git a/api/tests/functional-tests/backend/core/test_evaluation.py b/api/tests/functional-tests/backend/core/test_evaluation.py index 6b2d2f970..930c82575 100644 --- a/api/tests/functional-tests/backend/core/test_evaluation.py +++ b/api/tests/functional-tests/backend/core/test_evaluation.py @@ -282,7 +282,7 @@ def test_create_evaluation( assert ( rows[0].parameters == schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ).model_dump() ) diff --git a/api/tests/functional-tests/backend/metrics/test_classification.py b/api/tests/functional-tests/backend/metrics/test_classification.py index 4b051efe7..4b53f88ce 100644 --- a/api/tests/functional-tests/backend/metrics/test_classification.py +++ b/api/tests/functional-tests/backend/metrics/test_classification.py @@ -683,9 +683,10 @@ def test_compute_classification( confusion, metrics = _compute_clf_metrics( db, - model_filter, - datum_filter, + prediction_filter=model_filter, + groundtruth_filter=datum_filter, label_map=None, + pr_curve_max_examples=0, metrics_to_return=[ "Precision", "Recall", @@ -928,8 +929,14 @@ def test__compute_curves( grouper_key="animal", grouper_mappings=grouper_mappings, unique_datums=unique_datums, + pr_curve_max_examples=1, + metrics_to_return=[ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ], ) + # check PrecisionRecallCurve pr_expected_answers = { # bird ("bird", 0.05, "tp"): 3, @@ -973,6 +980,196 @@ def test__compute_curves( threshold, metric, ), expected_length in pr_expected_answers.items(): - list_of_datums = curves[value][threshold][metric] - assert isinstance(list_of_datums, list) - assert len(list_of_datums) == expected_length + classification = curves[0].value[value][threshold][metric] + assert classification == expected_length + + # check DetailedPrecisionRecallCurve + detailed_pr_expected_answers = { + # bird + ("bird", 0.05, "tp"): {"all": 3, "total": 3}, + ("bird", 0.05, "fp"): { + "hallucinations": 0, + "misclassifications": 1, + "total": 1, + }, + ("bird", 0.05, "tn"): {"all": 2, "total": 2}, + ("bird", 0.05, "fn"): { + 
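            # The shape checked below is inferred from this test rather than stated in
            # the schema: curves[0] (the simplified PrecisionRecallCurve) now holds a
            # plain integer count per label value / score threshold / metric, e.g.
            #   curves[0].value["bird"][0.05]["tp"] == 3
            # while curves[1] (DetailedPrecisionRecallCurve) appears to hold a nested
            # dict per metric, roughly:
            #   {
            #       "total": <int>,
            #       "observations": {
            #           "misclassifications": {"count": <int>, "examples": [...]},
            #           ...
            #       },
            #   }
            # with each "examples" list capped at pr_curve_max_examples.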
"missed_detections": 0, + "misclassifications": 0, + "total": 0, + }, + # dog + ("dog", 0.05, "tp"): {"all": 2, "total": 2}, + ("dog", 0.05, "fp"): { + "hallucinations": 0, + "misclassifications": 3, + "total": 3, + }, + ("dog", 0.05, "tn"): {"all": 1, "total": 1}, + ("dog", 0.8, "fn"): { + "missed_detections": 1, + "misclassifications": 1, + "total": 2, + }, + # cat + ("cat", 0.05, "tp"): {"all": 1, "total": 1}, + ("cat", 0.05, "fp"): { + "hallucinations": 0, + "misclassifications": 5, + "total": 5, + }, + ("cat", 0.05, "tn"): {"all": 0, "total": 0}, + ("cat", 0.8, "fn"): { + "missed_detections": 0, + "misclassifications": 0, + "total": 0, + }, + } + + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = curves[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # spot check number of examples + assert ( + len( + curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + assert ( + len( + curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + + # repeat the above, but with a higher pr_max_curves_example + curves = _compute_curves( + db=db, + predictions=predictions, + groundtruths=groundtruths, + grouper_key="animal", + grouper_mappings=grouper_mappings, + unique_datums=unique_datums, + pr_curve_max_examples=3, + metrics_to_return=[ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ], + ) + + # these outputs shouldn't have changed + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = curves[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + assert ( + len( + curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 3 + ) + assert ( + len( + ( + curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + ) + == 2 # only two examples exist + ) + + # test behavior if pr_curve_max_examples == 0 + curves = _compute_curves( + db=db, + predictions=predictions, + groundtruths=groundtruths, + grouper_key="animal", + grouper_mappings=grouper_mappings, + unique_datums=unique_datums, + pr_curve_max_examples=0, + metrics_to_return=[ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ], + ) + + # these outputs shouldn't have changed + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = curves[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == 
expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + assert ( + len( + curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 0 + ) + assert ( + len( + ( + curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + ) + == 0 + ) diff --git a/api/tests/functional-tests/backend/metrics/test_detection.py b/api/tests/functional-tests/backend/metrics/test_detection.py index 3ba639051..a31725aa5 100644 --- a/api/tests/functional-tests/backend/metrics/test_detection.py +++ b/api/tests/functional-tests/backend/metrics/test_detection.py @@ -5,8 +5,9 @@ from valor_api import crud, enums, schemas from valor_api.backend.metrics.detection import ( RankedPair, - _compute_curves, + _compute_detailed_curves, _compute_detection_metrics, + _compute_detection_metrics_with_detailed_precision_recall_curve, compute_detection_metrics, ) from valor_api.backend.models import ( @@ -29,373 +30,1486 @@ def _round_dict(d: dict, prec: int = 3) -> None: _round_dict(v, prec) -def test__compute_curves(db: Session): +def test__compute_detailed_curves(db: Session): # these inputs are taken directly from test__compute_detection_metrics (below) sorted_ranked_pairs = { - -1519138795911397979: [ + 3262893736873277849: [ RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[277.11,103.84],[292.44,103.84],[292.44,150.72],[277.11,150.72],[277.11,103.84]]]}', - gt_id=1340, - pd_id=2389, + gt_id=404, + pd_id=397, score=0.953, iou=0.8775260257195348, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[277.11,103.84],[292.44,103.84],[292.44,150.72],[277.11,150.72],[277.11,103.84]]]}', + gt_id=404, + pd_id=397, + score=0.953, + iou=0.8775260257195348, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[462.08,105.09],[493.74,105.09],[493.74,146.99],[462.08,146.99],[462.08,105.09]]]}', + gt_id=403, + pd_id=396, + score=0.805, + iou=0.8811645870469409, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[462.08,105.09],[493.74,105.09],[493.74,146.99],[462.08,146.99],[462.08,105.09]]]}', - gt_id=1339, - pd_id=2388, + gt_id=403, + pd_id=396, score=0.805, iou=0.8811645870469409, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[326.94,97.05],[340.49,97.05],[340.49,122.98],[326.94,122.98],[326.94,97.05]]]}', + gt_id=401, + pd_id=394, + score=0.611, + iou=0.742765273311898, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[326.94,97.05],[340.49,97.05],[340.49,122.98],[326.94,122.98],[326.94,97.05]]]}', - gt_id=1337, - pd_id=2386, + gt_id=401, + pd_id=394, score=0.611, iou=0.742765273311898, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + 
gt_geojson='{"type":"Polygon","coordinates":[[[295.55,93.96],[313.97,93.96],[313.97,152.79],[295.55,152.79],[295.55,93.96]]]}', + gt_id=400, + pd_id=393, + score=0.407, + iou=0.8970133882595271, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[295.55,93.96],[313.97,93.96],[313.97,152.79],[295.55,152.79],[295.55,93.96]]]}', - gt_id=1336, - pd_id=2385, + gt_id=400, + pd_id=393, score=0.407, iou=0.8970133882595271, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[356.62,95.47],[372.33,95.47],[372.33,147.55],[356.62,147.55],[356.62,95.47]]]}', - gt_id=1338, - pd_id=2387, + gt_id=402, + pd_id=395, score=0.335, iou=1.0000000000000002, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[356.62,95.47],[372.33,95.47],[372.33,147.55],[356.62,147.55],[356.62,95.47]]]}', + gt_id=402, + pd_id=395, + score=0.335, + iou=1.0000000000000002, + is_match=True, ), ], - 564624103770992353: [ + 8850376905924579852: [ + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}', + gt_id=409, + pd_id=402, + score=0.883, + iou=0.9999999999999992, + is_match=True, + ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}', - gt_id=1345, - pd_id=2394, + gt_id=409, + pd_id=402, score=0.883, iou=0.9999999999999992, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[81.28,47.04],[98.66,47.04],[98.66,78.5],[81.28,78.5],[81.28,47.04]]]}', + gt_id=407, + pd_id=400, + score=0.782, + iou=0.8911860718171924, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[81.28,47.04],[98.66,47.04],[98.66,78.5],[81.28,78.5],[81.28,47.04]]]}', - gt_id=1343, - pd_id=2392, + gt_id=407, + pd_id=400, score=0.782, iou=0.8911860718171924, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[62.34,55.23],[78.14,55.23],[78.14,79.57],[62.34,79.57],[62.34,55.23]]]}', + gt_id=412, + pd_id=404, + score=0.561, + iou=0.8809523809523806, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[62.34,55.23],[78.14,55.23],[78.14,79.57],[62.34,79.57],[62.34,55.23]]]}', - gt_id=1348, - pd_id=2396, + gt_id=412, + pd_id=404, score=0.561, iou=0.8809523809523806, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[72.92,45.96],[91.23,45.96],[91.23,80.57],[72.92,80.57],[72.92,45.96]]]}', - gt_id=1341, - pd_id=2390, + gt_id=405, + pd_id=398, score=0.532, iou=0.9999999999999998, + is_match=True, ), RankedPair( dataset_name="test_dataset", - 
gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[72.92,45.96],[91.23,45.96],[91.23,80.57],[72.92,80.57],[72.92,45.96]]]}', + gt_id=405, + pd_id=398, + score=0.532, + iou=0.9999999999999998, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[58.18,44.8],[66.42,44.8],[66.42,56.25],[58.18,56.25],[58.18,44.8]]]}', + gt_id=414, + pd_id=406, + score=0.349, + iou=0.6093750000000003, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[58.18,44.8],[66.42,44.8],[66.42,56.25],[58.18,56.25],[58.18,44.8]]]}', - gt_id=1350, - pd_id=2398, + gt_id=414, + pd_id=406, score=0.349, iou=0.6093750000000003, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[73.14,1.1],[98.96,1.1],[98.96,28.33],[73.14,28.33],[73.14,1.1]]]}', + gt_id=411, + pd_id=403, + score=0.271, + iou=0.8562185478073326, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[73.14,1.1],[98.96,1.1],[98.96,28.33],[73.14,28.33],[73.14,1.1]]]}', - gt_id=1347, - pd_id=2395, + gt_id=411, + pd_id=403, score=0.271, iou=0.8562185478073326, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[44.17,45.78],[63.99,45.78],[63.99,78.48],[44.17,78.48],[44.17,45.78]]]}', + gt_id=413, + pd_id=399, + score=0.204, + iou=0.8089209038203885, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[44.17,45.78],[63.99,45.78],[63.99,78.48],[44.17,78.48],[44.17,45.78]]]}', - gt_id=1349, - pd_id=2391, + gt_id=413, + pd_id=399, score=0.204, iou=0.8089209038203885, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", - gt_geojson='{"type":"Polygon","coordinates":[[[50.17,45.34],[71.28,45.34],[71.28,79.83],[50.17,79.83],[50.17,45.34]]]}', - gt_id=1342, - pd_id=2397, + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[44.17,45.78],[63.99,45.78],[63.99,78.48],[44.17,78.48],[44.17,45.78]]]}', + gt_id=413, + pd_id=405, + score=0.204, + iou=0.7370727432077125, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[44.17,45.78],[63.99,45.78],[63.99,78.48],[44.17,78.48],[44.17,45.78]]]}', + gt_id=413, + pd_id=405, score=0.204, - iou=0.3460676561905953, + iou=0.7370727432077125, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="3", + gt_datum_uid="3", + gt_geojson='{"type":"Polygon","coordinates":[[[63.96,46.17],[84.35,46.17],[84.35,80.48],[63.96,80.48],[63.96,46.17]]]}', + gt_id=408, + pd_id=401, + score=0.202, + iou=0.6719967199671995, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid527", + pd_datum_uid="3", + gt_datum_uid="3", gt_geojson='{"type":"Polygon","coordinates":[[[63.96,46.17],[84.35,46.17],[84.35,80.48],[63.96,80.48],[63.96,46.17]]]}', - gt_id=1344, - pd_id=2393, + gt_id=408, + pd_id=401, score=0.202, iou=0.6719967199671995, + 
is_match=True, ), ], - 7641129594263252302: [ + 7683992730431173493: [ RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid525", + pd_datum_uid="1", + gt_datum_uid="1", gt_geojson='{"type":"Polygon","coordinates":[[[1.66,3.32],[270.26,3.32],[270.26,275.23],[1.66,275.23],[1.66,3.32]]]}', - gt_id=1333, - pd_id=2382, + gt_id=397, + pd_id=390, score=0.726, iou=0.9213161659513592, - ) + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="1", + gt_datum_uid="1", + gt_geojson='{"type":"Polygon","coordinates":[[[1.66,3.32],[270.26,3.32],[270.26,275.23],[1.66,275.23],[1.66,3.32]]]}', + gt_id=397, + pd_id=390, + score=0.726, + iou=0.9213161659513592, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="1", + gt_datum_uid="1", + gt_geojson='{"type":"Polygon","coordinates":[[[13,22.75],[548.98,22.75],[548.98,632.42],[13,632.42],[13,22.75]]]}', + gt_id=396, + pd_id=389, + score=0.318, + iou=0.8840217391304347, + is_match=False, + ), ], - 7594118964129415143: [ + 1591437737079826217: [ + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[61.87,276.25],[358.29,276.25],[358.29,379.43],[61.87,379.43],[61.87,276.25]]]}', + gt_id=398, + pd_id=391, + score=0.546, + iou=0.8387196824018363, + is_match=True, + ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[61.87,276.25],[358.29,276.25],[358.29,379.43],[61.87,379.43],[61.87,276.25]]]}', - gt_id=1334, - pd_id=2383, + gt_id=398, + pd_id=391, score=0.546, iou=0.8387196824018363, + is_match=True, + ), + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="0", + gt_datum_uid="0", + gt_geojson='{"type":"Polygon","coordinates":[[[214.15,41.29],[562.41,41.29],[562.41,285.07],[214.15,285.07],[214.15,41.29]]]}', + gt_id=395, + pd_id=388, + score=0.236, + iou=0.7756590016825575, + is_match=True, ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid524", + pd_datum_uid="0", + gt_datum_uid="0", gt_geojson='{"type":"Polygon","coordinates":[[[214.15,41.29],[562.41,41.29],[562.41,285.07],[214.15,285.07],[214.15,41.29]]]}', - gt_id=1331, - pd_id=2380, + gt_id=395, + pd_id=388, score=0.236, iou=0.7756590016825575, + is_match=True, ), ], - 8707070029533313719: [ + -487256420494681688: [ + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="2", + gt_datum_uid="2", + gt_geojson='{"type":"Polygon","coordinates":[[[2.75,3.66],[162.15,3.66],[162.15,316.06],[2.75,316.06],[2.75,3.66]]]}', + gt_id=399, + pd_id=392, + score=0.3, + iou=0.8596978106691334, + is_match=True, + ), RankedPair( dataset_name="test_dataset", - gt_datum_uid="uid526", + pd_datum_uid="2", + gt_datum_uid="2", gt_geojson='{"type":"Polygon","coordinates":[[[2.75,3.66],[162.15,3.66],[162.15,316.06],[2.75,316.06],[2.75,3.66]]]}', - gt_id=1335, - pd_id=2384, + gt_id=399, + pd_id=392, score=0.3, iou=0.8596978106691334, + is_match=True, + ), + ], + -6111942735542320034: [ + RankedPair( + dataset_name="test_dataset", + pd_datum_uid="1", + gt_datum_uid="1", + gt_geojson='{"type":"Polygon","coordinates":[[[13,22.75],[548.98,22.75],[548.98,632.42],[13,632.42],[13,22.75]]]}', + gt_id=396, + pd_id=389, + score=0.318, + iou=0.8840217391304347, + is_match=False, ) ], } grouper_mappings = { "label_id_to_grouper_id_mapping": { - 752: 7594118964129415143, - 754: -1519138795911397979, - 753: 7641129594263252302, - 755: 8707070029533313719, - 757: 
1005277842145977801, - 756: 564624103770992353, + 512: 1591437737079826217, + 513: 7683992730431173493, + 519: -6111942735542320034, + 515: -487256420494681688, + 514: 3262893736873277849, + 517: 8850376905924579852, + }, + "label_id_to_grouper_key_mapping": { + 512: "class", + 513: "class", + 519: "class", + 515: "class", + 514: "class", + 517: "class", }, "grouper_id_to_label_ids_mapping": { - 7594118964129415143: [752], - -1519138795911397979: [754], - 7641129594263252302: [753], - 8707070029533313719: [755], - 1005277842145977801: [757], - 564624103770992353: [756], + 1591437737079826217: [512], + 7683992730431173493: [513], + -6111942735542320034: [519], + -487256420494681688: [515], + 3262893736873277849: [514], + 8850376905924579852: [517], }, "grouper_id_to_grouper_label_mapping": { - 7594118964129415143: schemas.Label( + 1591437737079826217: schemas.Label( key="class", value="4", score=None ), - -1519138795911397979: schemas.Label( - key="class", value="0", score=None - ), - 7641129594263252302: schemas.Label( + 7683992730431173493: schemas.Label( key="class", value="2", score=None ), - 8707070029533313719: schemas.Label( + -6111942735542320034: schemas.Label( + key="class", value="3", score=None + ), + -487256420494681688: schemas.Label( key="class", value="1", score=None ), - 1005277842145977801: schemas.Label( - key="class", value="3", score=None + 3262893736873277849: schemas.Label( + key="class", value="0", score=None ), - 564624103770992353: schemas.Label( + 8850376905924579852: schemas.Label( key="class", value="49", score=None ), }, } groundtruths_per_grouper = { - 7594118964129415143: [ + 1591437737079826217: [ ( "test_dataset", - "uid524", - 1331, + "0", + 395, '{"type":"Polygon","coordinates":[[[214.15,41.29],[562.41,41.29],[562.41,285.07],[214.15,285.07],[214.15,41.29]]]}', ), ( "test_dataset", - "uid526", - 1334, + "2", + 398, '{"type":"Polygon","coordinates":[[[61.87,276.25],[358.29,276.25],[358.29,379.43],[61.87,379.43],[61.87,276.25]]]}', ), ], - 7641129594263252302: [ + 7683992730431173493: [ ( "test_dataset", - "uid525", - 1332, + "1", + 396, '{"type":"Polygon","coordinates":[[[13,22.75],[548.98,22.75],[548.98,632.42],[13,632.42],[13,22.75]]]}', ), ( "test_dataset", - "uid525", - 1333, + "1", + 397, '{"type":"Polygon","coordinates":[[[1.66,3.32],[270.26,3.32],[270.26,275.23],[1.66,275.23],[1.66,3.32]]]}', ), ], - 8707070029533313719: [ + -487256420494681688: [ ( "test_dataset", - "uid526", - 1335, + "2", + 399, '{"type":"Polygon","coordinates":[[[2.75,3.66],[162.15,3.66],[162.15,316.06],[2.75,316.06],[2.75,3.66]]]}', ) ], - -1519138795911397979: [ + 3262893736873277849: [ ( "test_dataset", - "uid526", - 1336, + "2", + 400, '{"type":"Polygon","coordinates":[[[295.55,93.96],[313.97,93.96],[313.97,152.79],[295.55,152.79],[295.55,93.96]]]}', ), ( "test_dataset", - "uid526", - 1337, + "2", + 401, '{"type":"Polygon","coordinates":[[[326.94,97.05],[340.49,97.05],[340.49,122.98],[326.94,122.98],[326.94,97.05]]]}', ), ( "test_dataset", - "uid526", - 1338, + "2", + 402, '{"type":"Polygon","coordinates":[[[356.62,95.47],[372.33,95.47],[372.33,147.55],[356.62,147.55],[356.62,95.47]]]}', ), ( "test_dataset", - "uid526", - 1339, + "2", + 403, '{"type":"Polygon","coordinates":[[[462.08,105.09],[493.74,105.09],[493.74,146.99],[462.08,146.99],[462.08,105.09]]]}', ), ( "test_dataset", - "uid526", - 1340, + "2", + 404, '{"type":"Polygon","coordinates":[[[277.11,103.84],[292.44,103.84],[292.44,150.72],[277.11,150.72],[277.11,103.84]]]}', ), ], - 564624103770992353: [ + 
8850376905924579852: [ ( "test_dataset", - "uid527", - 1341, + "3", + 405, '{"type":"Polygon","coordinates":[[[72.92,45.96],[91.23,45.96],[91.23,80.57],[72.92,80.57],[72.92,45.96]]]}', ), ( "test_dataset", - "uid527", - 1342, + "3", + 406, '{"type":"Polygon","coordinates":[[[50.17,45.34],[71.28,45.34],[71.28,79.83],[50.17,79.83],[50.17,45.34]]]}', ), ( "test_dataset", - "uid527", - 1343, + "3", + 407, '{"type":"Polygon","coordinates":[[[81.28,47.04],[98.66,47.04],[98.66,78.5],[81.28,78.5],[81.28,47.04]]]}', ), ( "test_dataset", - "uid527", - 1344, + "3", + 408, '{"type":"Polygon","coordinates":[[[63.96,46.17],[84.35,46.17],[84.35,80.48],[63.96,80.48],[63.96,46.17]]]}', ), ( "test_dataset", - "uid527", - 1345, + "3", + 409, '{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}', ), ( "test_dataset", - "uid527", - 1346, + "3", + 410, '{"type":"Polygon","coordinates":[[[56.39,21.65],[75.66,21.65],[75.66,45.54],[56.39,45.54],[56.39,21.65]]]}', ), ( "test_dataset", - "uid527", - 1347, + "3", + 411, '{"type":"Polygon","coordinates":[[[73.14,1.1],[98.96,1.1],[98.96,28.33],[73.14,28.33],[73.14,1.1]]]}', ), ( "test_dataset", - "uid527", - 1348, + "3", + 412, '{"type":"Polygon","coordinates":[[[62.34,55.23],[78.14,55.23],[78.14,79.57],[62.34,79.57],[62.34,55.23]]]}', ), ( "test_dataset", - "uid527", - 1349, + "3", + 413, '{"type":"Polygon","coordinates":[[[44.17,45.78],[63.99,45.78],[63.99,78.48],[44.17,78.48],[44.17,45.78]]]}', ), ( "test_dataset", - "uid527", - 1350, + "3", + 414, '{"type":"Polygon","coordinates":[[[58.18,44.8],[66.42,44.8],[66.42,56.25],[58.18,56.25],[58.18,44.8]]]}', ), ], } + predictions_per_grouper = { + 1591437737079826217: [ + ( + "test_dataset", + "0", + 388, + '{"type":"Polygon","coordinates":[[[258.15,41.29],[606.41,41.29],[606.41,285.07],[258.15,285.07],[258.15,41.29]]]}', + ), + ( + "test_dataset", + "2", + 391, + '{"type":"Polygon","coordinates":[[[87.87,276.25],[384.29,276.25],[384.29,379.43],[87.87,379.43],[87.87,276.25]]]}', + ), + ], + -6111942735542320034: [ + ( + "test_dataset", + "1", + 389, + '{"type":"Polygon","coordinates":[[[61,22.75],[565,22.75],[565,632.42],[61,632.42],[61,22.75]]]}', + ) + ], + 7683992730431173493: [ + ( + "test_dataset", + "1", + 390, + '{"type":"Polygon","coordinates":[[[12.66,3.32],[281.26,3.32],[281.26,275.23],[12.66,275.23],[12.66,3.32]]]}', + ) + ], + -487256420494681688: [ + ( + "test_dataset", + "2", + 392, + '{"type":"Polygon","coordinates":[[[0,3.66],[142.15,3.66],[142.15,316.06],[0,316.06],[0,3.66]]]}', + ) + ], + 3262893736873277849: [ + ( + "test_dataset", + "2", + 393, + '{"type":"Polygon","coordinates":[[[296.55,93.96],[314.97,93.96],[314.97,152.79],[296.55,152.79],[296.55,93.96]]]}', + ), + ( + "test_dataset", + "2", + 394, + '{"type":"Polygon","coordinates":[[[328.94,97.05],[342.49,97.05],[342.49,122.98],[328.94,122.98],[328.94,97.05]]]}', + ), + ( + "test_dataset", + "2", + 395, + '{"type":"Polygon","coordinates":[[[356.62,95.47],[372.33,95.47],[372.33,147.55],[356.62,147.55],[356.62,95.47]]]}', + ), + ( + "test_dataset", + "2", + 396, + '{"type":"Polygon","coordinates":[[[464.08,105.09],[495.74,105.09],[495.74,146.99],[464.08,146.99],[464.08,105.09]]]}', + ), + ( + "test_dataset", + "2", + 397, + '{"type":"Polygon","coordinates":[[[276.11,103.84],[291.44,103.84],[291.44,150.72],[276.11,150.72],[276.11,103.84]]]}', + ), + ], + 8850376905924579852: [ + ( + "test_dataset", + "3", + 398, + 
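            # predictions_per_grouper mirrors groundtruths_per_grouper and replaces the
            # old false_positive_entries argument of _compute_curves; each tuple appears
            # to be (dataset_name, datum_uid, prediction annotation id, bounding-box
            # geojson), presumably the source of the "examples" returned in the
            # detailed curves.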
'{"type":"Polygon","coordinates":[[[72.92,45.96],[91.23,45.96],[91.23,80.57],[72.92,80.57],[72.92,45.96]]]}', + ), + ( + "test_dataset", + "3", + 399, + '{"type":"Polygon","coordinates":[[[45.17,45.34],[66.28,45.34],[66.28,79.83],[45.17,79.83],[45.17,45.34]]]}', + ), + ( + "test_dataset", + "3", + 400, + '{"type":"Polygon","coordinates":[[[82.28,47.04],[99.66,47.04],[99.66,78.5],[82.28,78.5],[82.28,47.04]]]}', + ), + ( + "test_dataset", + "3", + 401, + '{"type":"Polygon","coordinates":[[[59.96,46.17],[80.35,46.17],[80.35,80.48],[59.96,80.48],[59.96,46.17]]]}', + ), + ( + "test_dataset", + "3", + 402, + '{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}', + ), + ( + "test_dataset", + "3", + 403, + '{"type":"Polygon","coordinates":[[[71.14,1.1],[96.96,1.1],[96.96,28.33],[71.14,28.33],[71.14,1.1]]]}', + ), + ( + "test_dataset", + "3", + 404, + '{"type":"Polygon","coordinates":[[[61.34,55.23],[77.14,55.23],[77.14,79.57],[61.34,79.57],[61.34,55.23]]]}', + ), + ( + "test_dataset", + "3", + 405, + '{"type":"Polygon","coordinates":[[[41.17,45.78],[60.99,45.78],[60.99,78.48],[41.17,78.48],[41.17,45.78]]]}', + ), + ( + "test_dataset", + "3", + 406, + '{"type":"Polygon","coordinates":[[[56.18,44.8],[64.42,44.8],[64.42,56.25],[56.18,56.25],[56.18,44.8]]]}', + ), + ], + } + + output = _compute_detailed_curves( + sorted_ranked_pairs=sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + predictions_per_grouper=predictions_per_grouper, + pr_curve_iou_threshold=0.5, + pr_curve_max_examples=1, + ) + + pr_expected_answers = { + # (class, 4) + ("class", "4", 0.05, "tp"): 2, + ("class", "4", 0.05, "fn"): 0, + ("class", "4", 0.25, "tp"): 1, + ("class", "4", 0.25, "fn"): 1, + ("class", "4", 0.55, "tp"): 0, + ("class", "4", 0.55, "fn"): 2, + # (class, 2) + ("class", "2", 0.05, "tp"): 1, + ("class", "2", 0.05, "fn"): 1, + ("class", "2", 0.75, "tp"): 0, + ("class", "2", 0.75, "fn"): 2, + # (class, 49) + ("class", "49", 0.05, "tp"): 8, + ("class", "49", 0.3, "tp"): 5, + ("class", "49", 0.5, "tp"): 4, + ("class", "49", 0.85, "tp"): 1, + # (class, 3) + ("class", "3", 0.05, "tp"): 0, + ("class", "3", 0.05, "fp"): 1, + # (class, 1) + ("class", "1", 0.05, "tp"): 1, + ("class", "1", 0.35, "tp"): 0, + # (class, 0) + ("class", "0", 0.05, "tp"): 5, + ("class", "0", 0.5, "tp"): 3, + ("class", "0", 0.95, "tp"): 1, + ("class", "0", 0.95, "fn"): 4, + } + + for ( + key, + value, + threshold, + metric, + ), expected_count in pr_expected_answers.items(): + actual_count = output[0].value[value][threshold][metric] + assert actual_count == expected_count + + # check DetailedPrecisionRecallCurve + detailed_pr_expected_answers = { + # (class, 4) + ("4", 0.05, "tp"): {"all": 2, "total": 2}, + ("4", 0.05, "fn"): { + "missed_detections": 0, + "misclassifications": 0, + "total": 0, + }, + # (class, 2) + ("2", 0.05, "tp"): {"all": 1, "total": 1}, + ("2", 0.05, "fn"): { + "missed_detections": 0, + "misclassifications": 1, + "total": 1, + }, + ("2", 0.75, "tp"): {"all": 0, "total": 0}, + ("2", 0.75, "fn"): { + "missed_detections": 2, + "misclassifications": 0, + "total": 2, + }, + # (class, 49) + ("49", 0.05, "tp"): {"all": 8, "total": 8}, + # (class, 3) + ("3", 0.05, "tp"): {"all": 0, "total": 0}, + ("3", 0.05, "fp"): { + "hallucinations": 0, + "misclassifications": 1, + "total": 1, + }, + # (class, 1) + ("1", 0.05, "tp"): {"all": 1, "total": 1}, + ("1", 0.8, "fn"): { + "missed_detections": 1, + "misclassifications": 0, + 
"total": 1, + }, + # (class, 0) + ("0", 0.05, "tp"): {"all": 5, "total": 5}, + ("0", 0.95, "fn"): { + "missed_detections": 4, + "misclassifications": 0, + "total": 4, + }, + } + + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = output[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # spot check number of examples + assert ( + len( + output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + assert ( + len( + output[1].value["49"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + + # do a second test with a much higher iou_threshold + second_output = _compute_detailed_curves( + sorted_ranked_pairs=sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + predictions_per_grouper=predictions_per_grouper, + pr_curve_iou_threshold=0.9, + pr_curve_max_examples=1, + ) + + pr_expected_answers = { + # (class, 4) + ("class", "4", 0.05, "tp"): 0, + ("class", "4", 0.05, "fn"): 2, + # (class, 2) + ("class", "2", 0.05, "tp"): 1, + ("class", "2", 0.05, "fn"): 1, + ("class", "2", 0.75, "tp"): 0, + ("class", "2", 0.75, "fn"): 2, + # (class, 49) + ("class", "49", 0.05, "tp"): 2, + ("class", "49", 0.3, "tp"): 2, + ("class", "49", 0.5, "tp"): 2, + ("class", "49", 0.85, "tp"): 1, + # (class, 3) + ("class", "3", 0.05, "tp"): 0, + ("class", "3", 0.05, "fp"): 1, + # (class, 1) + ("class", "1", 0.05, "tp"): 0, + ("class", "1", 0.05, "fn"): 1, + # (class, 0) + ("class", "0", 0.05, "tp"): 1, + ("class", "0", 0.5, "tp"): 0, + ("class", "0", 0.95, "fn"): 5, + } + + for ( + key, + value, + threshold, + metric, + ), expected_count in pr_expected_answers.items(): + actual_count = second_output[0].value[value][threshold][metric] + assert actual_count == expected_count + + # check DetailedPrecisionRecallCurve + detailed_pr_expected_answers = { + # (class, 4) + ("4", 0.05, "tp"): {"all": 0, "total": 0}, + ("4", 0.05, "fn"): { + "missed_detections": 2, # below IOU threshold of .9 + "misclassifications": 0, + "total": 2, + }, + # (class, 2) + ("2", 0.05, "tp"): {"all": 1, "total": 1}, + ("2", 0.05, "fn"): { + "missed_detections": 1, + "misclassifications": 0, + "total": 1, + }, + ("2", 0.75, "tp"): {"all": 0, "total": 0}, + ("2", 0.75, "fn"): { + "missed_detections": 2, + "misclassifications": 0, + "total": 2, + }, + # (class, 49) + ("49", 0.05, "tp"): {"all": 2, "total": 2}, + # (class, 3) + ("3", 0.05, "tp"): {"all": 0, "total": 0}, + ("3", 0.05, "fp"): { + "hallucinations": 1, + "misclassifications": 0, + "total": 1, + }, + # (class, 1) + ("1", 0.05, "tp"): {"all": 0, "total": 0}, + ("1", 0.8, "fn"): { + "missed_detections": 1, + "misclassifications": 0, + "total": 1, + }, + # (class, 0) + ("0", 0.05, "tp"): {"all": 1, "total": 1}, + ("0", 0.95, "fn"): { + "missed_detections": 5, + "misclassifications": 0, + "total": 5, + }, + } + + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = second_output[1].value[value][threshold][metric] + assert 
isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # spot check number of examples + assert ( + len( + second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + assert ( + len( + second_output[1].value["49"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 + ) + + # repeat the above, but with a higher pr_max_curves_example + second_output = _compute_detailed_curves( + sorted_ranked_pairs=sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + predictions_per_grouper=predictions_per_grouper, + pr_curve_iou_threshold=0.9, + pr_curve_max_examples=3, + ) + + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = second_output[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # spot check number of examples + assert ( + len( + second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 3 + ) + assert ( + len( + second_output[1].value["49"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 2 + ) + + # test behavior if pr_curve_max_examples == 0 + second_output = _compute_detailed_curves( + sorted_ranked_pairs=sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + predictions_per_grouper=predictions_per_grouper, + pr_curve_iou_threshold=0.9, + pr_curve_max_examples=0, + ) + + for ( + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = second_output[1].value[value][threshold][metric] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # spot check number of examples + assert ( + len( + second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 0 + ) + assert ( + len( + second_output[1].value["49"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 0 + ) + + +def test__compute_detection( + db: Session, + groundtruths: list[list[GroundTruth]], + predictions: list[list[Prediction]], +): + iou_thresholds = set([round(0.5 + 0.05 * i, 2) for i in range(10)]) + + def _metric_to_dict(m) -> dict: + m = m.model_dump(exclude_none=True) + _round_dict(m, 3) + return m + + metrics = _compute_detection_metrics( + db=db, + parameters=schemas.EvaluationParameters( 
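            # Requesting "PrecisionRecallCurve" in metrics_to_return below makes the
            # curve the last element of the returned metrics (hence
            # pr_metrics = metrics[-1]); in the
            # _compute_detection_metrics_with_detailed_precision_recall_curve runs
            # further down, the curve instead sits at metrics[-2], presumably with the
            # DetailedPrecisionRecallCurve appended last.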
+ task_type=enums.TaskType.OBJECT_DETECTION, + convert_annotations_to_type=enums.AnnotationType.BOX, + iou_thresholds_to_compute=list(iou_thresholds), + iou_thresholds_to_return=[0.5, 0.75], + metrics_to_return=[ + "AP", + "AR", + "mAP", + "APAveragedOverIOUs", + "mAR", + "mAPAveragedOverIOUs", + "PrecisionRecallCurve", + ], + ), + prediction_filter=schemas.Filter( + model_names=["test_model"], + label_keys=["class"], + ), + groundtruth_filter=schemas.Filter( + dataset_names=["test_dataset"], + label_keys=["class"], + ), + target_type=enums.AnnotationType.BOX, + ) + + ap_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.APMetric) + ] + map_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.mAPMetric) + ] + ap_metrics_ave_over_ious = [ + _metric_to_dict(m) + for m in metrics + if isinstance(m, schemas.APMetricAveragedOverIOUs) + ] + map_metrics_ave_over_ious = [ + _metric_to_dict(m) + for m in metrics + if isinstance(m, schemas.mAPMetricAveragedOverIOUs) + ] + ar_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.ARMetric) + ] + mar_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.mARMetric) + ] + + # cf with torch metrics/pycocotools results listed here: + # https://github.com/Lightning-AI/metrics/blob/107dbfd5fb158b7ae6d76281df44bd94c836bfce/tests/unittests/detection/test_map.py#L231 + expected_ap_metrics = [ + {"iou": 0.5, "value": 0.505, "label": {"key": "class", "value": "2"}}, + {"iou": 0.75, "value": 0.505, "label": {"key": "class", "value": "2"}}, + {"iou": 0.5, "value": 0.79, "label": {"key": "class", "value": "49"}}, + { + "iou": 0.75, + "value": 0.576, + "label": {"key": "class", "value": "49"}, + }, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "0"}}, + {"iou": 0.75, "value": 0.723, "label": {"key": "class", "value": "0"}}, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "1"}}, + {"iou": 0.75, "value": 1.0, "label": {"key": "class", "value": "1"}}, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "4"}}, + {"iou": 0.75, "value": 1.0, "label": {"key": "class", "value": "4"}}, + ] + expected_map_metrics = [ + {"iou": 0.5, "value": 0.859, "label_key": "class"}, + {"iou": 0.75, "value": 0.761, "label_key": "class"}, + ] + expected_ap_metrics_ave_over_ious = [ + { + "ious": iou_thresholds, + "value": 0.454, + "label": {"key": "class", "value": "2"}, + }, + { + "ious": iou_thresholds, + "value": 0.555, # note COCO had 0.556 + "label": {"key": "class", "value": "49"}, + }, + { + "ious": iou_thresholds, + "value": 0.725, + "label": {"key": "class", "value": "0"}, + }, + { + "ious": iou_thresholds, + "value": 0.8, + "label": {"key": "class", "value": "1"}, + }, + { + "ious": iou_thresholds, + "value": 0.650, + "label": {"key": "class", "value": "4"}, + }, + ] + expected_map_metrics_ave_over_ious = [ + {"ious": iou_thresholds, "value": 0.637, "label_key": "class"} + ] + expected_ar_metrics = [ + { + "ious": iou_thresholds, + "value": 0.45, + "label": {"key": "class", "value": "2"}, + }, + { + "ious": iou_thresholds, + "value": -1, + "label": {"key": "class", "value": "3"}, + }, + { + "ious": iou_thresholds, + "value": 0.58, + "label": {"key": "class", "value": "49"}, + }, + { + "ious": iou_thresholds, + "value": 0.78, + "label": {"key": "class", "value": "0"}, + }, + { + "ious": iou_thresholds, + "value": 0.8, + "label": {"key": "class", "value": "1"}, + }, + { + "ious": iou_thresholds, + "value": 0.65, + "label": {"key": "class", "value": 
"4"}, + }, + ] + expected_mar_metrics = [ + {"ious": iou_thresholds, "value": 0.652, "label_key": "class"}, + ] + + for metric_type, actual_metrics, expected_metrics in [ + ("AP", ap_metrics, expected_ap_metrics), + ("mAP", map_metrics, expected_map_metrics), + ( + "APAveOverIOUs", + ap_metrics_ave_over_ious, + expected_ap_metrics_ave_over_ious, + ), + ( + "mAPAveOverIOUs", + map_metrics_ave_over_ious, + expected_map_metrics_ave_over_ious, + ), + ("AR", ar_metrics, expected_ar_metrics), + ("mAR", mar_metrics, expected_mar_metrics), + ]: + + for m in actual_metrics: + assert m in expected_metrics, f"{metric_type} {m} not in expected" + for m in expected_metrics: + assert m in actual_metrics, f"{metric_type} {m} not in actual" + + pr_metrics = metrics[-1].model_dump(exclude_none=True) + + pr_expected_answers = { + # (class, 4) + ("class", "4", 0.05, "tp"): 2, + ("class", "4", 0.05, "fn"): 0, + ("class", "4", 0.25, "tp"): 1, + ("class", "4", 0.25, "fn"): 1, + ("class", "4", 0.55, "tp"): 0, + ("class", "4", 0.55, "fn"): 2, + # (class, 2) + ("class", "2", 0.05, "tp"): 1, + ("class", "2", 0.05, "fn"): 1, + ("class", "2", 0.75, "tp"): 0, + ("class", "2", 0.75, "fn"): 2, + # (class, 49) + ("class", "49", 0.05, "tp"): 8, + ("class", "49", 0.3, "tp"): 5, + ("class", "49", 0.5, "tp"): 4, + ("class", "49", 0.85, "tp"): 1, + # (class, 3) + ("class", "3", 0.05, "tp"): 0, + ("class", "3", 0.05, "fp"): 1, + # (class, 1) + ("class", "1", 0.05, "tp"): 1, + ("class", "1", 0.35, "tp"): 0, + # (class, 0) + ("class", "0", 0.05, "tp"): 5, + ("class", "0", 0.5, "tp"): 3, + ("class", "0", 0.95, "tp"): 1, + ("class", "0", 0.95, "fn"): 4, + } + + for ( + _, + value, + threshold, + metric, + ), expected_value in pr_expected_answers.items(): + assert pr_metrics["value"][value][threshold][metric] == expected_value + + # now add PrecisionRecallCurve + metrics = _compute_detection_metrics( + db=db, + parameters=schemas.EvaluationParameters( + task_type=enums.TaskType.OBJECT_DETECTION, + convert_annotations_to_type=enums.AnnotationType.BOX, + iou_thresholds_to_compute=list(iou_thresholds), + iou_thresholds_to_return=[0.5, 0.75], + metrics_to_return=[ + "AP", + "AR", + "mAP", + "APAveragedOverIOUs", + "mAR", + "mAPAveragedOverIOUs", + "PrecisionRecallCurve", + ], + ), + prediction_filter=schemas.Filter( + model_names=["test_model"], + label_keys=["class"], + ), + groundtruth_filter=schemas.Filter( + dataset_names=["test_dataset"], + label_keys=["class"], + ), + target_type=enums.AnnotationType.BOX, + ) + + ap_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.APMetric) + ] + map_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.mAPMetric) + ] + ap_metrics_ave_over_ious = [ + _metric_to_dict(m) + for m in metrics + if isinstance(m, schemas.APMetricAveragedOverIOUs) + ] + map_metrics_ave_over_ious = [ + _metric_to_dict(m) + for m in metrics + if isinstance(m, schemas.mAPMetricAveragedOverIOUs) + ] + ar_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.ARMetric) + ] + mar_metrics = [ + _metric_to_dict(m) for m in metrics if isinstance(m, schemas.mARMetric) + ] + + # cf with torch metrics/pycocotools results listed here: + # https://github.com/Lightning-AI/metrics/blob/107dbfd5fb158b7ae6d76281df44bd94c836bfce/tests/unittests/detection/test_map.py#L231 + expected_ap_metrics = [ + {"iou": 0.5, "value": 0.505, "label": {"key": "class", "value": "2"}}, + {"iou": 0.75, "value": 0.505, "label": {"key": "class", "value": "2"}}, + {"iou": 0.5, "value": 0.79, 
"label": {"key": "class", "value": "49"}}, + { + "iou": 0.75, + "value": 0.576, + "label": {"key": "class", "value": "49"}, + }, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "0"}}, + {"iou": 0.75, "value": 0.723, "label": {"key": "class", "value": "0"}}, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "1"}}, + {"iou": 0.75, "value": 1.0, "label": {"key": "class", "value": "1"}}, + {"iou": 0.5, "value": 1.0, "label": {"key": "class", "value": "4"}}, + {"iou": 0.75, "value": 1.0, "label": {"key": "class", "value": "4"}}, + ] + expected_map_metrics = [ + {"iou": 0.5, "value": 0.859, "label_key": "class"}, + {"iou": 0.75, "value": 0.761, "label_key": "class"}, + ] + expected_ap_metrics_ave_over_ious = [ + { + "ious": iou_thresholds, + "value": 0.454, + "label": {"key": "class", "value": "2"}, + }, + { + "ious": iou_thresholds, + "value": 0.555, # note COCO had 0.556 + "label": {"key": "class", "value": "49"}, + }, + { + "ious": iou_thresholds, + "value": 0.725, + "label": {"key": "class", "value": "0"}, + }, + { + "ious": iou_thresholds, + "value": 0.8, + "label": {"key": "class", "value": "1"}, + }, + { + "ious": iou_thresholds, + "value": 0.650, + "label": {"key": "class", "value": "4"}, + }, + ] + expected_map_metrics_ave_over_ious = [ + {"ious": iou_thresholds, "value": 0.637, "label_key": "class"} + ] + expected_ar_metrics = [ + { + "ious": iou_thresholds, + "value": 0.45, + "label": {"key": "class", "value": "2"}, + }, + { + "ious": iou_thresholds, + "value": -1, + "label": {"key": "class", "value": "3"}, + }, + { + "ious": iou_thresholds, + "value": 0.58, + "label": {"key": "class", "value": "49"}, + }, + { + "ious": iou_thresholds, + "value": 0.78, + "label": {"key": "class", "value": "0"}, + }, + { + "ious": iou_thresholds, + "value": 0.8, + "label": {"key": "class", "value": "1"}, + }, + { + "ious": iou_thresholds, + "value": 0.65, + "label": {"key": "class", "value": "4"}, + }, + ] + expected_mar_metrics = [ + {"ious": iou_thresholds, "value": 0.652, "label_key": "class"}, + ] - false_positive_entries = [ + for metric_type, actual_metrics, expected_metrics in [ + ("AP", ap_metrics, expected_ap_metrics), + ("mAP", map_metrics, expected_map_metrics), ( - "test_dataset", - None, - "uid525", - None, - 1005277842145977801, - 0.318, - '{"type":"Polygon","coordinates":[[[61,22.75],[565,22.75],[565,632.42],[61,632.42],[61,22.75]]]}', - ) - ] + "APAveOverIOUs", + ap_metrics_ave_over_ious, + expected_ap_metrics_ave_over_ious, + ), + ( + "mAPAveOverIOUs", + map_metrics_ave_over_ious, + expected_map_metrics_ave_over_ious, + ), + ("AR", ar_metrics, expected_ar_metrics), + ("mAR", mar_metrics, expected_mar_metrics), + ]: - output = _compute_curves( - sorted_ranked_pairs=sorted_ranked_pairs, - grouper_mappings=grouper_mappings, - groundtruths_per_grouper=groundtruths_per_grouper, - false_positive_entries=false_positive_entries, - iou_threshold=0.5, - ) + for m in actual_metrics: + assert m in expected_metrics, f"{metric_type} {m} not in expected" + for m in expected_metrics: + assert m in actual_metrics, f"{metric_type} {m} not in actual" + + pr_metrics = metrics[-1].model_dump(exclude_none=True) pr_expected_answers = { # (class, 4) @@ -429,82 +1543,15 @@ def test__compute_curves(db: Session): } for ( - key, - value, - threshold, - metric, - ), expected_length in pr_expected_answers.items(): - datum_geojson_tuples = output[0].value[value][threshold][metric] - assert isinstance(datum_geojson_tuples, list) - assert len(datum_geojson_tuples) == expected_length - - # 
spot check a few geojson results - assert ( - output[0].value["4"][0.05]["tp"][0][2] # type: ignore - == '{"type":"Polygon","coordinates":[[[61.87,276.25],[358.29,276.25],[358.29,379.43],[61.87,379.43],[61.87,276.25]]]}' - ) - assert ( - output[0].value["49"][0.85]["tp"][0][2] # type: ignore - == '{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}' - ) - assert ( - output[0].value["3"][0.05]["fp"][0][2] # type: ignore - == '{"type":"Polygon","coordinates":[[[61,22.75],[565,22.75],[565,632.42],[61,632.42],[61,22.75]]]}' - ) - - # do a second test with a much higher iou_threshold - second_output = _compute_curves( - sorted_ranked_pairs=sorted_ranked_pairs, - grouper_mappings=grouper_mappings, - groundtruths_per_grouper=groundtruths_per_grouper, - false_positive_entries=false_positive_entries, - iou_threshold=0.9, - ) - - pr_expected_answers = { - # (class, 4) - ("class", "4", 0.05, "tp"): 0, - ("class", "4", 0.05, "fn"): 2, - # (class, 2) - ("class", "2", 0.05, "tp"): 1, - ("class", "2", 0.05, "fn"): 1, - ("class", "2", 0.75, "tp"): 0, - ("class", "2", 0.75, "fn"): 2, - # (class, 49) - ("class", "49", 0.05, "tp"): 2, - ("class", "49", 0.3, "tp"): 2, - ("class", "49", 0.5, "tp"): 2, - ("class", "49", 0.85, "tp"): 1, - # (class, 3) - ("class", "3", 0.05, "tp"): 0, - ("class", "3", 0.05, "fp"): 1, - # (class, 1) - ("class", "1", 0.05, "tp"): 0, - ("class", "1", 0.05, "fn"): 1, - # (class, 0) - ("class", "0", 0.05, "tp"): 1, - ("class", "0", 0.5, "tp"): 0, - ("class", "0", 0.95, "fn"): 5, - } - - for ( - key, + _, value, threshold, metric, - ), expected_length in pr_expected_answers.items(): - datum_geojson_tuples = second_output[0].value[value][threshold][metric] - assert isinstance(datum_geojson_tuples, list) - assert len(datum_geojson_tuples) == expected_length + ), expected_value in pr_expected_answers.items(): + assert pr_metrics["value"][value][threshold][metric] == expected_value - -def test__compute_detection_metrics( - db: Session, - groundtruths: list[list[GroundTruth]], - predictions: list[list[Prediction]], -): - iou_thresholds = set([round(0.5 + 0.05 * i, 2) for i in range(10)]) - metrics = _compute_detection_metrics( + # finally, test the DetailedPrecisionRecallCurve version + metrics = _compute_detection_metrics_with_detailed_precision_recall_curve( db=db, parameters=schemas.EvaluationParameters( task_type=enums.TaskType.OBJECT_DETECTION, @@ -532,11 +1579,6 @@ def test__compute_detection_metrics( target_type=enums.AnnotationType.BOX, ) - def _metric_to_dict(m) -> dict: - m = m.model_dump(exclude_none=True) - _round_dict(m, 3) - return m - ap_metrics = [ _metric_to_dict(m) for m in metrics if isinstance(m, schemas.APMetric) ] @@ -670,7 +1712,7 @@ def _metric_to_dict(m) -> dict: for m in expected_metrics: assert m in actual_metrics, f"{metric_type} {m} not in actual" - pr_metrics = metrics[-1].model_dump(exclude_none=True) + pr_metrics = metrics[-2].model_dump(exclude_none=True) pr_expected_answers = { # (class, 4) @@ -708,25 +1750,8 @@ def _metric_to_dict(m) -> dict: value, threshold, metric, - ), expected_length in pr_expected_answers.items(): - assert ( - len(pr_metrics["value"][value][threshold][metric]) - == expected_length - ) - - # spot check a few geojson results - assert ( - pr_metrics["value"]["4"][0.05]["tp"][0][2] - == '{"type":"Polygon","coordinates":[[[61.87,276.25],[358.29,276.25],[358.29,379.43],[61.87,379.43],[61.87,276.25]]]}' - ) - assert ( - pr_metrics["value"]["49"][0.85]["tp"][0][2] - == 
'{"type":"Polygon","coordinates":[[[75.29,23.01],[91.85,23.01],[91.85,50.85],[75.29,50.85],[75.29,23.01]]]}' - ) - assert ( - pr_metrics["value"]["3"][0.05]["fp"][0][2] - == '{"type":"Polygon","coordinates":[[[61,22.75],[565,22.75],[565,632.42],[61,632.42],[61,22.75]]]}' - ) + ), expected_value in pr_expected_answers.items(): + assert pr_metrics["value"][value][threshold][metric] == expected_value def test__compute_detection_metrics_with_rasters( @@ -871,23 +1896,148 @@ def test__compute_detection_metrics_with_rasters( value, threshold, metric, - ), expected_length in pr_expected_answers.items(): - assert ( - len(pr_metrics["value"][value][threshold][metric]) - == expected_length - ) + ), expected_value in pr_expected_answers.items(): + assert pr_metrics["value"][value][threshold][metric] == expected_value - # spot check a few geojson results - assert ( - pr_metrics["value"]["label1"][0.05]["tp"][0][2] - == '{"type":"Polygon","coordinates":[[[0,0],[0,80],[32,80],[32,0],[0,0]]]}' - ) - assert ( - pr_metrics["value"]["label2"][0.85]["tp"][0][2] - == '{"type":"Polygon","coordinates":[[[0,0],[0,80],[32,80],[32,0],[0,0]]]}' + # test DetailedPrecisionRecallCurve version + metrics = _compute_detection_metrics_with_detailed_precision_recall_curve( + db=db, + parameters=schemas.EvaluationParameters( + task_type=enums.TaskType.OBJECT_DETECTION, + convert_annotations_to_type=enums.AnnotationType.RASTER, + iou_thresholds_to_compute=list(iou_thresholds), + iou_thresholds_to_return=[0.5, 0.75], + metrics_to_return=[ + "AP", + "AR", + "mAP", + "APAveragedOverIOUs", + "mAR", + "mAPAveragedOverIOUs", + "PrecisionRecallCurve", + ], + ), + prediction_filter=schemas.Filter( + model_names=["test_model"], + label_keys=["class"], + ), + groundtruth_filter=schemas.Filter( + dataset_names=["test_dataset"], + label_keys=["class"], + ), + target_type=enums.AnnotationType.RASTER, ) - assert pr_metrics["value"]["label3"][0.85]["tp"] == [] + metrics = [m.model_dump(exclude_none=True) for m in metrics] + + for m in metrics: + _round_dict(m, 3) + + expected = [ + # AP METRICS + { + "iou": 0.5, + "value": 1.0, + "label": {"key": "class", "value": "label2"}, + }, + { + "iou": 0.75, + "value": 1.0, + "label": {"key": "class", "value": "label2"}, + }, + { + "iou": 0.5, + "value": 1.0, + "label": {"key": "class", "value": "label1"}, + }, + { + "iou": 0.75, + "value": 1.0, + "label": {"key": "class", "value": "label1"}, + }, + { + "iou": 0.5, + "value": 0.0, + "label": {"key": "class", "value": "label3"}, + }, + { + "iou": 0.75, + "value": 0.0, + "label": {"key": "class", "value": "label3"}, + }, + # AP METRICS AVERAGED OVER IOUS + { + "ious": iou_thresholds, + "value": 1.0, + "label": {"key": "class", "value": "label2"}, + }, + { + "ious": iou_thresholds, + "value": -1.0, + "label": {"key": "class", "value": "label4"}, + }, + { + "ious": iou_thresholds, + "value": 1.0, + "label": {"key": "class", "value": "label1"}, + }, + { + "ious": iou_thresholds, + "value": 0.0, + "label": {"key": "class", "value": "label3"}, + }, + # mAP METRICS + {"iou": 0.5, "value": 0.667, "label_key": "class"}, + {"iou": 0.75, "value": 0.667, "label_key": "class"}, + # mAP METRICS AVERAGED OVER IOUS + {"ious": iou_thresholds, "value": 0.667, "label_key": "class"}, + # AR METRICS + { + "ious": iou_thresholds, + "value": 1.0, + "label": {"key": "class", "value": "label2"}, + }, + { + "ious": iou_thresholds, + "value": 1.0, + "label": {"key": "class", "value": "label1"}, + }, + { + "ious": iou_thresholds, + "value": 0.0, + "label": {"key": "class", 
"value": "label3"}, + }, + # mAR METRICS + {"ious": iou_thresholds, "value": 0.667, "label_key": "class"}, + ] + + non_pr_metrics = metrics[:-2] + pr_metrics = metrics[-2] + for m in non_pr_metrics: + assert m in expected + + for m in expected: + assert m in non_pr_metrics + + pr_expected_answers = { + ("class", "label1", 0.05, "tp"): 1, + ("class", "label1", 0.35, "tp"): 0, + ("class", "label2", 0.05, "tp"): 1, + ("class", "label2", 0.05, "fp"): 0, + ("class", "label2", 0.95, "fp"): 0, + ("class", "label3", 0.05, "tp"): 0, + ("class", "label3", 0.05, "fn"): 1, + ("class", "label4", 0.05, "tp"): 0, + ("class", "label4", 0.05, "fp"): 1, + } + + for ( + _, + value, + threshold, + metric, + ), expected_value in pr_expected_answers.items(): + assert pr_metrics["value"][value][threshold][metric] == expected_value def test_detection_exceptions(db: Session): diff --git a/api/tests/functional-tests/crud/test_create_delete.py b/api/tests/functional-tests/crud/test_create_delete.py index 60bd5d553..35dc2624b 100644 --- a/api/tests/functional-tests/crud/test_create_delete.py +++ b/api/tests/functional-tests/crud/test_create_delete.py @@ -1236,6 +1236,8 @@ def method_to_test( "mAR", "mAP", "mAPAveragedOverIOUs", + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", } # test when min area and max area are specified @@ -1261,6 +1263,8 @@ def method_to_test( "mAR", "mAP", "mAPAveragedOverIOUs", + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", } # check we have the right evaluations @@ -1360,7 +1364,7 @@ def test_create_clf_metrics( model_names=[model_name], datum_filter=schemas.Filter(dataset_names=[dataset_name]), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) @@ -1464,7 +1468,7 @@ def test_create_clf_metrics( ) assert query metrics = query.metrics - assert len(metrics) == 2 + 2 + 6 + 6 + 6 + assert len(metrics) == 22 confusion_matrices = db.scalars( select(models.ConfusionMatrix).where( models.ConfusionMatrix.evaluation_id == evaluation_id diff --git a/api/tests/functional-tests/crud/test_evaluation_crud.py b/api/tests/functional-tests/crud/test_evaluation_crud.py index 89d4597a1..e5ecf72e0 100644 --- a/api/tests/functional-tests/crud/test_evaluation_crud.py +++ b/api/tests/functional-tests/crud/test_evaluation_crud.py @@ -162,7 +162,7 @@ def test_restart_failed_evaluation(db: Session): model_names=["model"], datum_filter=schemas.Filter(dataset_names=["dataset"]), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ), allow_retries=False, @@ -186,7 +186,7 @@ def test_restart_failed_evaluation(db: Session): model_names=["model"], datum_filter=schemas.Filter(dataset_names=["dataset"]), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ), allow_retries=False, @@ -202,7 +202,7 @@ def test_restart_failed_evaluation(db: Session): model_names=["model"], datum_filter=schemas.Filter(dataset_names=["dataset"]), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ), allow_retries=True, diff --git a/api/tests/unit-tests/backend/metrics/test_detection.py b/api/tests/unit-tests/backend/metrics/test_detection.py index 0faf7d908..99e06ab48 100644 --- a/api/tests/unit-tests/backend/metrics/test_detection.py +++ b/api/tests/unit-tests/backend/metrics/test_detection.py @@ -30,60 +30,72 @@ def 
test__calculate_ap_and_ar(): RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=1, pd_id=1, score=0.8, iou=0.6, gt_geojson="", + is_match=True, ), RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=2, pd_id=2, score=0.6, iou=0.8, gt_geojson="", + is_match=True, ), RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=3, pd_id=3, score=0.4, iou=1.0, gt_geojson="", + is_match=True, ), ], "1": [ RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=0, pd_id=0, score=0.0, iou=1.0, gt_geojson="", + is_match=True, ), RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=2, pd_id=2, score=0.0, iou=1.0, gt_geojson="", + is_match=True, ), ], "2": [ RankedPair( dataset_name="test_dataset", gt_datum_uid="1", + pd_datum_uid="1", gt_id=0, pd_id=0, score=1.0, iou=1.0, gt_geojson="", + is_match=True, ), ], } diff --git a/api/tests/unit-tests/schemas/test_evaluation.py b/api/tests/unit-tests/schemas/test_evaluation.py index 9dd0f9132..d6efb7dda 100644 --- a/api/tests/unit-tests/schemas/test_evaluation.py +++ b/api/tests/unit-tests/schemas/test_evaluation.py @@ -7,7 +7,9 @@ def test_EvaluationParameters(): - schemas.EvaluationParameters(task_type=enums.TaskType.CLASSIFICATION) + schemas.EvaluationParameters( + task_type=enums.TaskType.CLASSIFICATION, + ) schemas.EvaluationParameters( task_type=enums.TaskType.OBJECT_DETECTION, @@ -57,7 +59,7 @@ def test_EvaluationParameters(): schemas.EvaluationParameters( task_type=enums.TaskType.OBJECT_DETECTION, iou_thresholds_to_compute=None, - iou_thresholds_to_return=0.2, # type: ignore - purposefully throwing error + iou_thresholds_to_return=0.2, # type: ignore - purposefully throwing error, ) with pytest.raises(ValidationError): @@ -81,21 +83,21 @@ def test_EvaluationRequest(): model_names=["name"], datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) schemas.EvaluationRequest( model_names=["name"], datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) schemas.EvaluationRequest( model_names=["name", "other"], datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) @@ -105,7 +107,7 @@ def test_EvaluationRequest(): model_filter=None, # type: ignore - purposefully throwing error datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) with pytest.raises(ValidationError): @@ -113,7 +115,7 @@ def test_EvaluationRequest(): model_names=["name"], datum_filter=None, # type: ignore - purposefully throwing error parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) with pytest.raises(ValidationError): @@ -129,7 +131,7 @@ def test_EvaluationRequest(): model_names=[], datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) @@ -139,7 +141,7 @@ def test_EvaluationRequest(): model_filter=schemas.Filter(), # type: ignore - purposefully throwing error datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - 
task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), ) @@ -150,7 +152,7 @@ def test_EvaluationResponse(): model_name="test", datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), status=enums.EvaluationStatus.DONE, metrics=[], @@ -166,7 +168,7 @@ def test_EvaluationResponse(): model_name="test", datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), status=enums.EvaluationStatus.DONE, metrics=[], @@ -182,7 +184,7 @@ def test_EvaluationResponse(): model_name=None, # type: ignore - purposefully throwing error datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), status=enums.EvaluationStatus.DONE, metrics=[], @@ -212,7 +214,7 @@ def test_EvaluationResponse(): model_name="name", datum_filter=schemas.Filter(), parameters=schemas.EvaluationParameters( - task_type=enums.TaskType.CLASSIFICATION + task_type=enums.TaskType.CLASSIFICATION, ), status=None, # type: ignore - purposefully throwing error metrics=[], diff --git a/api/tests/unit-tests/test_main.py b/api/tests/unit-tests/test_main.py index ece658772..18af72c9a 100644 --- a/api/tests/unit-tests/test_main.py +++ b/api/tests/unit-tests/test_main.py @@ -903,7 +903,7 @@ def test_post_detection_metrics(client: TestClient): model_name="modelname", datum_filter=schemas.Filter(dataset_names=["dsetname"]), parameters=schemas.EvaluationParameters( - task_type=TaskType.OBJECT_DETECTION + task_type=TaskType.OBJECT_DETECTION, ), status=EvaluationStatus.PENDING, metrics=[], @@ -920,7 +920,7 @@ def test_post_detection_metrics(client: TestClient): dataset_names=["dsetname"], ), parameters=schemas.EvaluationParameters( - task_type=TaskType.OBJECT_DETECTION + task_type=TaskType.OBJECT_DETECTION, ), ).model_dump() @@ -952,7 +952,7 @@ def test_post_clf_metrics(client: TestClient): model_names=["modelname"], datum_filter=schemas.Filter(dataset_names=["dsetname"]), parameters=schemas.EvaluationParameters( - task_type=TaskType.CLASSIFICATION + task_type=TaskType.CLASSIFICATION, ), ).model_dump() @@ -971,7 +971,7 @@ def test_post_semenatic_segmentation_metrics(client: TestClient): model_name="modelname", datum_filter=schemas.Filter(dataset_names=["dsetname"]), parameters=schemas.EvaluationParameters( - task_type=TaskType.SEMANTIC_SEGMENTATION + task_type=TaskType.SEMANTIC_SEGMENTATION, ), status=EvaluationStatus.PENDING, metrics=[], @@ -986,7 +986,7 @@ def test_post_semenatic_segmentation_metrics(client: TestClient): model_names=["modelname"], datum_filter=schemas.Filter(dataset_names=["dsetname"]), parameters=schemas.EvaluationParameters( - task_type=TaskType.SEMANTIC_SEGMENTATION + task_type=TaskType.SEMANTIC_SEGMENTATION, ), ).model_dump() diff --git a/api/valor_api/backend/metrics/classification.py b/api/valor_api/backend/metrics/classification.py index 8b56b62a6..0a35e2183 100644 --- a/api/valor_api/backend/metrics/classification.py +++ b/api/valor_api/backend/metrics/classification.py @@ -1,3 +1,4 @@ +import random from collections import defaultdict from typing import Sequence @@ -29,20 +30,9 @@ def _compute_curves( grouper_key: str, grouper_mappings: dict[str, dict[str, dict]], unique_datums: set[tuple[str, str]], -) -> dict[ - str, - dict[ - float, - dict[ - str, - int - | float - | list[tuple[str, str]] - | 
list[tuple[str, str, str]] - | None, - ], - ], -]: + pr_curve_max_examples: int, + metrics_to_return: list[str], +) -> list[schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve]: """ Calculates precision-recall curves for each class. @@ -60,14 +50,19 @@ def _compute_curves( A dictionary of mappings that connect groupers to their related labels. unique_datums: list[tuple[str, str]] All of the unique datums associated with the ground truth and prediction filters. + pr_curve_max_examples: int + The maximum number of datum examples to store per true positive, false negative, etc. + metrics_to_return: list[str] + The list of metrics requested by the user. Returns ------- - dict[str,dict[float, dict[str, int | float | list[tuple[str, str]] | list[tuple[str, str, str]] | None]] - A nested dictionary where the first key is the class label, the second key is the confidence threshold (e.g., 0.05), the third key is the metric name (e.g., "precision"), and the final key is either the value itself (for precision, recall, etc.) or a list of tuples containing the (dataset_name, datum_id) for each observation. + list[schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve] + The PrecisionRecallCurve and/or DetailedPrecisionRecallCurve metrics. """ - output = defaultdict(lambda: defaultdict(dict)) + pr_output = defaultdict(lambda: defaultdict(dict)) + detailed_pr_output = defaultdict(lambda: defaultdict(dict)) for threshold in [x / 100 for x in range(5, 100, 5)]: # get predictions that are above the confidence threshold @@ -150,10 +145,25 @@ def _compute_curves( key=lambda x: ((x[1] is None, x[0][0] != x[0][1], x[1], x[2])) ) + # create sets of all datums for which there is a prediction / groundtruth + # used when separating hallucinations/misclassifications/missed_detections + gt_datums = set() + pd_datums = set() + + for row in res: + (pd_datum_uid, pd_dataset_name, gt_datum_uid, gt_dataset_name,) = ( + row[2], + row[3], + row[5], + row[6], + ) + gt_datums.add((gt_dataset_name, gt_datum_uid)) + pd_datums.add((pd_dataset_name, pd_datum_uid)) + for grouper_value in grouper_mappings["grouper_key_to_labels_mapping"][ grouper_key ].keys(): - tp, tn, fp, fn = [], [], [], [] + tp, tn, fp, fn = [], [], defaultdict(list), defaultdict(list) seen_datums = set() for row in res: @@ -177,26 +187,47 @@ def _compute_curves( tp += [(pd_dataset_name, pd_datum_uid)] seen_datums.add(gt_datum_uid) elif predicted_label == grouper_value: - fp += [(pd_dataset_name, pd_datum_uid)] + # if there was a groundtruth for a given datum, then it was a misclassification + if (pd_dataset_name, pd_datum_uid) in gt_datums: + fp["misclassifications"].append( + (pd_dataset_name, pd_datum_uid) + ) + else: + fp["hallucinations"].append( + (pd_dataset_name, pd_datum_uid) + ) seen_datums.add(gt_datum_uid) elif ( actual_label == grouper_value and gt_datum_uid not in seen_datums ): - fn += [(gt_dataset_name, gt_datum_uid)] + # if there was a prediction for a given datum, then it was a misclassification + if (gt_dataset_name, gt_datum_uid) in pd_datums: + fn["misclassifications"].append( + (gt_dataset_name, gt_datum_uid) + ) + else: + fn["missed_detections"].append( + (gt_dataset_name, gt_datum_uid) + ) seen_datums.add(gt_datum_uid) # calculate metrics tn = [ datum_uid_pair for datum_uid_pair in unique_datums - if datum_uid_pair not in tp + fp + fn + if datum_uid_pair + not in tp + + fp["hallucinations"] + + fp["misclassifications"] + + fn["misclassifications"] + + fn["missed_detections"] and None not in datum_uid_pair ] 
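                # A rough illustration of the four buckets populated above, using
                # hypothetical dataset/datum uids (not drawn from this test suite):
                #   fp["hallucinations"]     - the datum has no groundtruth at all under the
                #                              current filters (it never appears in gt_datums).
                #   fp["misclassifications"] - the datum does have a groundtruth, just not for
                #                              the predicted grouper value.
                #   fn["missed_detections"]  - the datum has a groundtruth for this grouper
                #                              value but no prediction at all (absent from pd_datums).
                #   fn["misclassifications"] - the datum has a prediction, but for a different
                #                              grouper value.
                # tn is then every pair in unique_datums that landed in none of the buckets above.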
tp_cnt, fp_cnt, fn_cnt, tn_cnt = ( len(tp), - len(fp), - len(fn), + len(fp["hallucinations"]) + len(fp["misclassifications"]), + len(fn["missed_detections"]) + len(fn["misclassifications"]), len(tn), ) @@ -217,18 +248,122 @@ def _compute_curves( else -1 ) - output[grouper_value][threshold] = { - "tp": tp, - "fp": fp, - "fn": fn, - "tn": tn, + pr_output[grouper_value][threshold] = { + "tp": tp_cnt, + "fp": fp_cnt, + "fn": fn_cnt, + "tn": tn_cnt, "accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1_score, } - return dict(output) + if "DetailedPrecisionRecallCurve" in metrics_to_return: + + detailed_pr_output[grouper_value][threshold] = { + "tp": { + "total": tp_cnt, + "observations": { + "all": { + "count": tp_cnt, + "examples": ( + random.sample(tp, pr_curve_max_examples) + if len(tp) >= pr_curve_max_examples + else tp + ), + } + }, + }, + "tn": { + "total": tn_cnt, + "observations": { + "all": { + "count": tn_cnt, + "examples": ( + random.sample(tn, pr_curve_max_examples) + if len(tn) >= pr_curve_max_examples + else tn + ), + } + }, + }, + "fn": { + "total": fn_cnt, + "observations": { + "misclassifications": { + "count": len(fn["misclassifications"]), + "examples": ( + random.sample( + fn["misclassifications"], + pr_curve_max_examples, + ) + if len(fn["misclassifications"]) + >= pr_curve_max_examples + else fn["misclassifications"] + ), + }, + "missed_detections": { + "count": len(fn["missed_detections"]), + "examples": ( + random.sample( + fn["missed_detections"], + pr_curve_max_examples, + ) + if len(fn["missed_detections"]) + >= pr_curve_max_examples + else fn["missed_detections"] + ), + }, + }, + }, + "fp": { + "total": fp_cnt, + "observations": { + "misclassifications": { + "count": len(fp["misclassifications"]), + "examples": ( + random.sample( + fp["misclassifications"], + pr_curve_max_examples, + ) + if len(fp["misclassifications"]) + >= pr_curve_max_examples + else fp["misclassifications"] + ), + }, + "hallucinations": { + "count": len(fp["hallucinations"]), + "examples": ( + random.sample( + fp["hallucinations"], + pr_curve_max_examples, + ) + if len(fp["hallucinations"]) + >= pr_curve_max_examples + else fp["hallucinations"] + ), + }, + }, + }, + } + + output = [] + + output.append( + schemas.PrecisionRecallCurve( + label_key=grouper_key, value=dict(pr_output) + ), + ) + + if "DetailedPrecisionRecallCurve" in metrics_to_return: + output += [ + schemas.DetailedPrecisionRecallCurve( + label_key=grouper_key, value=dict(detailed_pr_output) + ) + ] + + return output def _compute_binary_roc_auc( @@ -625,6 +760,7 @@ def _compute_confusion_matrix_and_metrics_at_grouper_key( groundtruth_filter: schemas.Filter, grouper_key: str, grouper_mappings: dict[str, dict[str, dict]], + pr_curve_max_examples: int, metrics_to_return: list[str], ) -> ( tuple[ @@ -652,6 +788,8 @@ def _compute_confusion_matrix_and_metrics_at_grouper_key( The filter to be used to query groundtruths. grouper_mappings: dict[str, dict[str, dict]] A dictionary of mappings that connect groupers to their related labels. + pr_curve_max_examples: int + The maximum number of datum examples to store per true positive, false negative, etc. metrics: list[str] The list of metrics to compute, store, and return to the user. 
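# For orientation, a minimal sketch (hypothetical label values, uids, and counts) of how the
# two payloads assembled above are indexed once they are returned as PrecisionRecallCurve.value
# and DetailedPrecisionRecallCurve.value:

pr_curve_value_sketch = {
    "label1": {                      # grouper value
        0.05: {                      # confidence threshold
            "tp": 3, "fp": 1, "fn": 0, "tn": 2,
            "precision": 0.75, "recall": 1.0,
            "accuracy": 5 / 6, "f1_score": 6 / 7,
        },
    },
}

detailed_pr_curve_value_sketch = {
    "label1": {
        0.05: {
            "fp": {
                "total": 1,
                "observations": {
                    "misclassifications": {
                        "count": 1,
                        "examples": [("some_dataset", "some_datum_uid")],
                    },
                    "hallucinations": {"count": 0, "examples": []},
                },
            },
            # "tp" and "tn" use a single "all" observation; "fn" splits into
            # "misclassifications" and "missed_detections". Each "examples" list
            # is capped at pr_curve_max_examples via random.sample.
        },
    },
}

# e.g. the misclassification count for "label1" at threshold 0.05 is read as:
# detailed_pr_curve_value_sketch["label1"][0.05]["fp"]["observations"]["misclassifications"]["count"]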
@@ -703,7 +841,7 @@ def _compute_confusion_matrix_and_metrics_at_grouper_key( return None # aggregate metrics (over all label values) - metrics = [ + output = [ schemas.AccuracyMetric( label_key=grouper_key, value=_compute_accuracy_from_cm(confusion_matrix), @@ -746,11 +884,10 @@ def _compute_confusion_matrix_and_metrics_at_grouper_key( grouper_key=grouper_key, grouper_mappings=grouper_mappings, unique_datums=unique_datums, + pr_curve_max_examples=pr_curve_max_examples, + metrics_to_return=metrics_to_return, ) - output = schemas.PrecisionRecallCurve( - label_key=grouper_key, value=pr_curves - ) - metrics.append(output) + output += pr_curves # metrics that are per label for grouper_value in grouper_mappings["grouper_key_to_labels_mapping"][ @@ -765,32 +902,30 @@ def _compute_confusion_matrix_and_metrics_at_grouper_key( ) pydantic_label = schemas.Label(key=grouper_key, value=grouper_value) - metrics.append( + + output += [ schemas.PrecisionMetric( label=pydantic_label, value=precision, - ) - ) - metrics.append( + ), schemas.RecallMetric( label=pydantic_label, value=recall, - ) - ) - metrics.append( + ), schemas.F1Metric( label=pydantic_label, value=f1, - ) - ) + ), + ] - return confusion_matrix, metrics + return confusion_matrix, output def _compute_clf_metrics( db: Session, prediction_filter: schemas.Filter, groundtruth_filter: schemas.Filter, + pr_curve_max_examples: int, metrics_to_return: list[str], label_map: LabelMapType | None = None, ) -> tuple[ @@ -819,6 +954,9 @@ def _compute_clf_metrics( The list of metrics to compute, store, and return to the user. label_map: LabelMapType, optional Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models. + pr_curve_max_examples: int + The maximum number of datum examples to store per true positive, false negative, etc. 
+ Returns ---------- @@ -839,7 +977,7 @@ def _compute_clf_metrics( ) # compute metrics and confusion matrix for each grouper id - confusion_matrices, metrics = [], [] + confusion_matrices, metrics_to_output = [], [] for grouper_key in grouper_mappings[ "grouper_key_to_labels_mapping" ].keys(): @@ -849,13 +987,14 @@ def _compute_clf_metrics( groundtruth_filter=groundtruth_filter, grouper_key=grouper_key, grouper_mappings=grouper_mappings, + pr_curve_max_examples=pr_curve_max_examples, metrics_to_return=metrics_to_return, ) if cm_and_metrics is not None: confusion_matrices.append(cm_and_metrics[0]) - metrics.extend(cm_and_metrics[1]) + metrics_to_output.extend(cm_and_metrics[1]) - return confusion_matrices, metrics + return confusion_matrices, metrics_to_output @validate_computation @@ -908,6 +1047,11 @@ def compute_clf_metrics( prediction_filter=prediction_filter, groundtruth_filter=groundtruth_filter, label_map=parameters.label_map, + pr_curve_max_examples=( + parameters.pr_curve_max_examples + if parameters.pr_curve_max_examples + else 0 + ), metrics_to_return=parameters.metrics_to_return, ) diff --git a/api/valor_api/backend/metrics/detection.py b/api/valor_api/backend/metrics/detection.py index 8f1d96846..8c917f42d 100644 --- a/api/valor_api/backend/metrics/detection.py +++ b/api/valor_api/backend/metrics/detection.py @@ -1,6 +1,7 @@ import bisect import heapq import math +import random from collections import defaultdict from dataclasses import dataclass from typing import Sequence, Tuple @@ -19,19 +20,21 @@ log_evaluation_item_counts, validate_computation, ) -from valor_api.backend.query import generate_query, generate_select +from valor_api.backend.query import generate_query from valor_api.enums import AnnotationType @dataclass class RankedPair: dataset_name: str + pd_datum_uid: str | None gt_datum_uid: str | None gt_geojson: str | None gt_id: int | None pd_id: int score: float iou: float + is_match: bool def _calculate_101_pt_interp(precisions, recalls) -> float: @@ -61,6 +64,148 @@ def _calculate_101_pt_interp(precisions, recalls) -> float: return ret / 101 +def _calculate_ap_and_ar( + sorted_ranked_pairs: dict[str, list[RankedPair]], + number_of_groundtruths_per_grouper: dict[str, int], + grouper_mappings: dict[str, dict[str, schemas.Label]], + iou_thresholds: list[float], + recall_score_threshold: float, +) -> Tuple[list[schemas.APMetric], list[schemas.ARMetric]]: + """ + Computes the average precision and average recall metrics. Returns a dict with keys + `f"IoU={iou_thres}"` for each `iou_thres` in `iou_thresholds` as well as + `f"IoU={min(iou_thresholds)}:{max(iou_thresholds)}", which is the average + of the scores across all of the IoU thresholds. + """ + if recall_score_threshold < 0 or recall_score_threshold > 1.0: + raise ValueError( + "recall_score_threshold should exist in the range 0 <= threshold <= 1." + ) + if min(iou_thresholds) <= 0 or max(iou_thresholds) > 1.0: + raise ValueError( + "IOU thresholds should exist in the range 0 < threshold <= 1." 
+ ) + + ap_metrics = [] + ar_metrics = [] + + for grouper_id, grouper_label in grouper_mappings[ + "grouper_id_to_grouper_label_mapping" + ].items(): + recalls_across_thresholds = [] + + for iou_threshold in iou_thresholds: + if grouper_id not in number_of_groundtruths_per_grouper.keys(): + continue + + precisions = [] + recalls = [] + # recall true positives require a confidence score above recall_score_threshold, while precision + # true positives only require a confidence score above 0 + recall_cnt_tp = 0 + recall_cnt_fp = 0 + recall_cnt_fn = 0 + precision_cnt_tp = 0 + precision_cnt_fp = 0 + + if grouper_id in sorted_ranked_pairs: + matched_gts_for_precision = set() + matched_gts_for_recall = set() + for row in sorted_ranked_pairs[grouper_id]: + + precision_score_conditional = row.score > 0 + + recall_score_conditional = ( + row.score > recall_score_threshold + or ( + math.isclose(row.score, recall_score_threshold) + and recall_score_threshold > 0 + ) + ) + + iou_conditional = ( + row.iou >= iou_threshold and iou_threshold > 0 + ) + + if ( + recall_score_conditional + and iou_conditional + and row.gt_id not in matched_gts_for_recall + ): + recall_cnt_tp += 1 + matched_gts_for_recall.add(row.gt_id) + else: + recall_cnt_fp += 1 + + if ( + precision_score_conditional + and iou_conditional + and row.gt_id not in matched_gts_for_precision + ): + matched_gts_for_precision.add(row.gt_id) + precision_cnt_tp += 1 + else: + precision_cnt_fp += 1 + + recall_cnt_fn = ( + number_of_groundtruths_per_grouper[grouper_id] + - recall_cnt_tp + ) + + precision_cnt_fn = ( + number_of_groundtruths_per_grouper[grouper_id] + - precision_cnt_tp + ) + + precisions.append( + precision_cnt_tp + / (precision_cnt_tp + precision_cnt_fp) + if (precision_cnt_tp + precision_cnt_fp) + else 0 + ) + recalls.append( + precision_cnt_tp + / (precision_cnt_tp + precision_cnt_fn) + if (precision_cnt_tp + precision_cnt_fn) + else 0 + ) + + recalls_across_thresholds.append( + recall_cnt_tp / (recall_cnt_tp + recall_cnt_fn) + if (recall_cnt_tp + recall_cnt_fn) + else 0 + ) + else: + precisions = [0] + recalls = [0] + recalls_across_thresholds.append(0) + + ap_metrics.append( + schemas.APMetric( + iou=iou_threshold, + value=_calculate_101_pt_interp( + precisions=precisions, recalls=recalls + ), + label=grouper_label, + ) + ) + + ar_metrics.append( + schemas.ARMetric( + ious=set(iou_thresholds), + value=( + sum(recalls_across_thresholds) + / len(recalls_across_thresholds) + if recalls_across_thresholds + else -1 + ), + label=grouper_label, + ) + ) + + return ap_metrics, ar_metrics + + def _compute_curves( sorted_ranked_pairs: dict[int, list[RankedPair]], grouper_mappings: dict[str, dict[str, schemas.Label]], @@ -86,8 +231,8 @@ def _compute_curves( Returns ------- - dict - A nested dictionary where the first key is the class label, the second key is the confidence threshold (e.g., 0.05), the third key is the metric name (e.g., "precision"), and the final key is either the value itself (for precision, recall, etc.) or a list of tuples containing the (dataset_name, datum_id, bounding boxes) for each observation. + list[schemas.PrecisionRecallCurve] + A list of PrecisionRecallCurve metrics. 
""" output = defaultdict(dict) @@ -103,20 +248,14 @@ def _compute_curves( for confidence_threshold in [x / 100 for x in range(5, 100, 5)]: if grouper_id not in sorted_ranked_pairs: - tp = [] - fn = ( - [ - (dataset_name, datum_uid, gt_geojson) - for dataset_name, datum_uid, _, gt_geojson in groundtruths_per_grouper[ - grouper_id - ] - ] - if grouper_id in groundtruths_per_grouper - else [] - ) + tp_cnt = 0 + if grouper_id in groundtruths_per_grouper: + fn_cnt = len(groundtruths_per_grouper[grouper_id]) + else: + fn_cnt = 0 else: - tp, fp, fn = [], [], [] + tp_cnt, fn_cnt = 0, 0 seen_gts = set() for row in sorted_ranked_pairs[grouper_id]: @@ -125,33 +264,34 @@ def _compute_curves( and row.iou >= iou_threshold and row.gt_id not in seen_gts ): - tp += [ - ( - row.dataset_name, - row.gt_datum_uid, - row.gt_geojson, - ) - ] + tp_cnt += 1 seen_gts.add(row.gt_id) - fn = [ - (dataset_name, datum_uid, gt_geojson) - for dataset_name, datum_uid, gt_id, gt_geojson in groundtruths_per_grouper[ - grouper_id - ] - if gt_id not in seen_gts - ] - - fp = [ - (dset_name, pd_datum_uid, pd_geojson) - for dset_name, _, pd_datum_uid, gt_label_id, pd_label_id, pd_score, pd_geojson in false_positive_entries - if pd_score >= confidence_threshold - and pd_label_id == grouper_id - and gt_label_id is None - ] + for ( + _, + _, + gt_id, + ) in groundtruths_per_grouper[grouper_id]: + if gt_id not in seen_gts: + fn_cnt += 1 + + fp_cnt = 0 + for ( + _, + _, + _, + gt_label_id_grouper, + pd_label_id_grouper, + pd_score, + ) in false_positive_entries: + if ( + pd_score >= confidence_threshold + and pd_label_id_grouper == grouper_id + and gt_label_id_grouper is None + ): + fp_cnt += 1 # calculate metrics - tp_cnt, fp_cnt, fn_cnt = len(tp), len(fp), len(fn) precision = ( tp_cnt / (tp_cnt + fp_cnt) if (tp_cnt + fp_cnt) > 0 else -1 ) @@ -165,9 +305,9 @@ def _compute_curves( ) curves[label_value][confidence_threshold] = { - "tp": tp, - "fp": fp, - "fn": fn, + "tp": tp_cnt, + "fp": fp_cnt, + "fn": fn_cnt, "tn": None, # tn and accuracy aren't applicable to detection tasks because there's an infinite number of true negatives "precision": precision, "recall": recall, @@ -187,149 +327,692 @@ def _compute_curves( ] -def _calculate_ap_and_ar( - sorted_ranked_pairs: dict[str, list[RankedPair]], - number_of_groundtruths_per_grouper: dict[str, int], +def _compute_detailed_curves( + sorted_ranked_pairs: dict[int, list[RankedPair]], grouper_mappings: dict[str, dict[str, schemas.Label]], - iou_thresholds: list[float], - recall_score_threshold: float, -) -> Tuple[list[schemas.APMetric], list[schemas.ARMetric]]: + groundtruths_per_grouper: dict[int, list], + predictions_per_grouper: dict[int, list], + pr_curve_iou_threshold: float, + pr_curve_max_examples: int, +) -> list[schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve]: """ - Computes the average precision. Return is a dict with keys - `f"IoU={iou_thres}"` for each `iou_thres` in `iou_thresholds` as well as - `f"IoU={min(iou_thresholds)}:{max(iou_thresholds)}"` which is the average - of the scores across all of the IoU thresholds. + Calculates precision-recall curves and detailed precision recall curves for each class. + + Parameters + ---------- + sorted_ranked_pairs: dict[int, list[RankedPair]] + The ground truth-prediction matches from psql, grouped by grouper_id. + grouper_mappings: dict[str, dict[str, schemas.Label]] + A dictionary of mappings that connect groupers to their related labels. 
+ groundtruths_per_grouper: dict[int, int] + A dictionary containing the (dataset_name, datum_id, gt_id) for all groundtruths associated with a grouper. + predictions_per_grouper: dict[int, int] + A dictionary containing the (dataset_name, datum_id, gt_id) for all predictions associated with a grouper. + pr_curve_iou_threshold: float + The IOU threshold to use as a cut-off for our predictions. + pr_curve_max_examples: int + The maximum number of datum examples to store per true positive, false negative, etc. + + Returns + ------- + list[schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve] + A list of PrecisionRecallCurve and DetailedPrecisionRecallCurve metrics. """ - if recall_score_threshold < 0 or recall_score_threshold > 1.0: + pr_output = defaultdict(dict) + detailed_pr_output = defaultdict(dict) + + # transform sorted_ranked_pairs into two sets (groundtruths and predictions) + # we'll use these dictionaries to look up the IOU overlap between specific groundtruths and predictions + # to separate misclassifications from hallucinations/missed_detections + pd_datums = defaultdict(lambda: defaultdict(list)) + gt_datums = defaultdict(lambda: defaultdict(list)) + + for grouper_id, ranked_pairs in sorted_ranked_pairs.items(): + for ranked_pair in ranked_pairs: + grouper_id_key = hash( + ( + ranked_pair.dataset_name, + ranked_pair.pd_datum_uid, + grouper_mappings["grouper_id_to_grouper_label_mapping"][ + grouper_id + ].key, # type: ignore + ) + ) + gt_key = hash( + ( + ranked_pair.dataset_name, + ranked_pair.gt_datum_uid, + ranked_pair.gt_id, + ) + ) + pd_key = hash( + ( + ranked_pair.dataset_name, + ranked_pair.pd_datum_uid, + ranked_pair.pd_id, + ) + ) + pd_datums[grouper_id_key][gt_key].append( + (ranked_pair.iou, ranked_pair.score) + ) + gt_datums[grouper_id_key][pd_key].append( + (ranked_pair.iou, ranked_pair.score) + ) + + for grouper_id, grouper_label in grouper_mappings[ + "grouper_id_to_grouper_label_mapping" + ].items(): + + pr_curves = defaultdict(lambda: defaultdict(dict)) + detailed_pr_curves = defaultdict(lambda: defaultdict(dict)) + + label_key = grouper_label.key + label_value = grouper_label.value + + for confidence_threshold in [x / 100 for x in range(5, 100, 5)]: + seen_pds = set() + seen_gts = set() + + tp, fp, fn = [], defaultdict(list), defaultdict(list) + + for row in sorted_ranked_pairs[int(grouper_id)]: + if ( + row.score >= confidence_threshold + and row.iou >= pr_curve_iou_threshold + and row.gt_id not in seen_gts + and row.is_match is True + ): + tp += [ + ( + row.dataset_name, + row.gt_datum_uid, + row.gt_geojson, + ) + ] + seen_gts.add(row.gt_id) + seen_pds.add(row.pd_id) + + if grouper_id in groundtruths_per_grouper: + for ( + dataset_name, + datum_uid, + gt_id, + gt_geojson, + ) in groundtruths_per_grouper[int(grouper_id)]: + if gt_id not in seen_gts: + grouper_id_key = hash( + ( + dataset_name, + datum_uid, + grouper_mappings[ + "grouper_id_to_grouper_label_mapping" + ][grouper_id].key, + ) + ) + gt_key = hash((dataset_name, datum_uid, gt_id)) + misclassification_detected = any( + [ + score >= confidence_threshold + and iou >= pr_curve_iou_threshold + for (iou, score) in pd_datums[grouper_id_key][ + gt_key + ] + ] + ) + # if there is at least one prediction overlapping the groundtruth with a sufficient score and iou threshold, then it's a misclassification + if misclassification_detected: + fn["misclassifications"].append( + (dataset_name, datum_uid, gt_geojson) + ) + else: + fn["missed_detections"].append( + (dataset_name, datum_uid, 
gt_geojson) + ) + + if grouper_id in predictions_per_grouper: + for ( + dataset_name, + datum_uid, + pd_id, + pd_geojson, + ) in predictions_per_grouper[int(grouper_id)]: + if pd_id not in seen_pds: + grouper_id_key = hash( + ( + dataset_name, + datum_uid, + grouper_mappings[ + "grouper_id_to_grouper_label_mapping" + ][ + grouper_id + ].key, # type: ignore + ) + ) + pd_key = hash((dataset_name, datum_uid, pd_id)) + misclassification_detected = any( + [ + iou >= pr_curve_iou_threshold + and score >= confidence_threshold + for (iou, score) in gt_datums[grouper_id_key][ + pd_key + ] + ] + ) + hallucination_detected = any( + [ + score >= confidence_threshold + for (_, score) in gt_datums[grouper_id_key][ + pd_key + ] + ] + ) + # if there is at least one groundtruth overlapping the prediction with a sufficient score and iou threshold, then it's a misclassification + if misclassification_detected: + fp["misclassifications"].append( + (dataset_name, datum_uid, pd_geojson) + ) + elif hallucination_detected: + fp["hallucinations"].append( + (dataset_name, datum_uid, pd_geojson) + ) + + # calculate metrics + tp_cnt, fp_cnt, fn_cnt = ( + len(tp), + len(fp["hallucinations"]) + len(fp["misclassifications"]), + len(fn["missed_detections"]) + len(fn["misclassifications"]), + ) + precision = ( + tp_cnt / (tp_cnt + fp_cnt) if (tp_cnt + fp_cnt) > 0 else -1 + ) + recall = ( + tp_cnt / (tp_cnt + fn_cnt) if (tp_cnt + fn_cnt) > 0 else -1 + ) + f1_score = ( + (2 * precision * recall) / (precision + recall) + if precision and recall + else -1 + ) + + pr_curves[label_value][confidence_threshold] = { + "tp": tp_cnt, + "fp": fp_cnt, + "fn": fn_cnt, + "tn": None, # tn and accuracy aren't applicable to detection tasks because there's an infinite number of true negatives + "precision": precision, + "recall": recall, + "accuracy": None, + "f1_score": f1_score, + } + + detailed_pr_curves[label_value][confidence_threshold] = { + "tp": { + "total": tp_cnt, + "observations": { + "all": { + "count": tp_cnt, + "examples": ( + random.sample(tp, pr_curve_max_examples) + if len(tp) >= pr_curve_max_examples + else tp + ), + } + }, + }, + "fn": { + "total": fn_cnt, + "observations": { + "misclassifications": { + "count": len(fn["misclassifications"]), + "examples": ( + random.sample( + fn["misclassifications"], + pr_curve_max_examples, + ) + if len(fn["misclassifications"]) + >= pr_curve_max_examples + else fn["misclassifications"] + ), + }, + "missed_detections": { + "count": len(fn["missed_detections"]), + "examples": ( + random.sample( + fn["missed_detections"], + pr_curve_max_examples, + ) + if len(fn["missed_detections"]) + >= pr_curve_max_examples + else fn["missed_detections"] + ), + }, + }, + }, + "fp": { + "total": fp_cnt, + "observations": { + "misclassifications": { + "count": len(fp["misclassifications"]), + "examples": ( + random.sample( + fp["misclassifications"], + pr_curve_max_examples, + ) + if len(fp["misclassifications"]) + >= pr_curve_max_examples + else fp["misclassifications"] + ), + }, + "hallucinations": { + "count": len(fp["hallucinations"]), + "examples": ( + random.sample( + fp["hallucinations"], + pr_curve_max_examples, + ) + if len(fp["hallucinations"]) + >= pr_curve_max_examples + else fp["hallucinations"] + ), + }, + }, + }, + } + + pr_output[label_key].update(dict(pr_curves)) + detailed_pr_output[label_key].update(dict(detailed_pr_curves)) + + output = [] + + output += [ + schemas.PrecisionRecallCurve( + label_key=key, + value=dict(value), + pr_curve_iou_threshold=pr_curve_iou_threshold, + ) + for 
key, value in pr_output.items() + ] + + output += [ + schemas.DetailedPrecisionRecallCurve( + label_key=key, + value=dict(value), + pr_curve_iou_threshold=pr_curve_iou_threshold, + ) + for key, value in detailed_pr_output.items() + ] + + return output + + +def _compute_detection_metrics( + db: Session, + parameters: schemas.EvaluationParameters, + prediction_filter: schemas.Filter, + groundtruth_filter: schemas.Filter, + target_type: enums.AnnotationType, +) -> Sequence[ + schemas.APMetric + | schemas.ARMetric + | schemas.APMetricAveragedOverIOUs + | schemas.mAPMetric + | schemas.mARMetric + | schemas.mAPMetricAveragedOverIOUs + | schemas.PrecisionRecallCurve +]: + """ + Compute detection metrics. This version of _compute_detection_metrics only does IOU calculations for every groundtruth-prediction pair that shares a common grouper id. It also runs _compute_curves to calculate the PrecisionRecallCurve. + + Parameters + ---------- + db : Session + The database Session to query against. + parameters : schemas.EvaluationParameters + Any user-defined parameters. + prediction_filter : schemas.Filter + The filter to be used to query predictions. + groundtruth_filter : schemas.Filter + The filter to be used to query groundtruths. + target_type: enums.AnnotationType + The annotation type to compute metrics for. + + + Returns + ---------- + List[schemas.APMetric | schemas.ARMetric | schemas.APMetricAveragedOverIOUs | schemas.mAPMetric | schemas.mARMetric | schemas.mAPMetricAveragedOverIOUs | schemas.PrecisionRecallCurve] + A list of metrics to return to the user. + + """ + + def _annotation_type_to_column( + annotation_type: AnnotationType, + table, + ): + match annotation_type: + case AnnotationType.BOX: + return table.box + case AnnotationType.POLYGON: + return table.polygon + case AnnotationType.RASTER: + return table.raster + case _: + raise RuntimeError + + def _annotation_type_to_geojson( + annotation_type: AnnotationType, + table, + ): + match annotation_type: + case AnnotationType.BOX: + box = table.box + case AnnotationType.POLYGON: + box = gfunc.ST_Envelope(table.polygon) + case AnnotationType.RASTER: + box = gfunc.ST_Envelope(gfunc.ST_MinConvexHull(table.raster)) + case _: + raise RuntimeError + return gfunc.ST_AsGeoJSON(box) + + if ( + parameters.iou_thresholds_to_return is None + or parameters.iou_thresholds_to_compute is None + or parameters.recall_score_threshold is None + or parameters.pr_curve_iou_threshold is None + ): raise ValueError( - "recall_score_threshold should exist in the range 0 <= threshold <= 1." + "iou_thresholds_to_return, iou_thresholds_to_compute, recall_score_threshold, and pr_curve_iou_threshold are required attributes of EvaluationParameters when evaluating detections." ) - if min(iou_thresholds) <= 0 or max(iou_thresholds) > 1.0: + + if ( + parameters.recall_score_threshold > 1 + or parameters.recall_score_threshold < 0 + ): raise ValueError( - "IOU thresholds should exist in the range 0 < threshold <= 1." + "recall_score_threshold should exist in the range 0 <= threshold <= 1." 
) - ap_metrics = [] - ar_metrics = [] + labels = core.fetch_union_of_labels( + db=db, + rhs=prediction_filter, + lhs=groundtruth_filter, + ) - for grouper_id, grouper_label in grouper_mappings[ - "grouper_id_to_grouper_label_mapping" - ].items(): - recalls_across_thresholds = [] + grouper_mappings = create_grouper_mappings( + labels=labels, + label_map=parameters.label_map, + evaluation_type=enums.TaskType.OBJECT_DETECTION, + ) - for iou_threshold in iou_thresholds: - if grouper_id not in number_of_groundtruths_per_grouper.keys(): - continue + gt = generate_query( + models.Dataset.name.label("dataset_name"), + models.GroundTruth.id.label("id"), + models.GroundTruth.annotation_id.label("annotation_id"), + models.Annotation.datum_id.label("datum_id"), + models.Datum.uid.label("datum_uid"), + case( + grouper_mappings["label_id_to_grouper_id_mapping"], + value=models.GroundTruth.label_id, + ).label("label_id_grouper"), + _annotation_type_to_geojson(target_type, models.Annotation).label( + "geojson" + ), + db=db, + filter_=groundtruth_filter, + label_source=models.GroundTruth, + ).subquery("groundtruths") - precisions = [] - recalls = [] - # recall true positives require a confidence score above recall_score_threshold, while precision - # true positives only require a confidence score above 0 - recall_cnt_tp = 0 - recall_cnt_fp = 0 - recall_cnt_fn = 0 - precision_cnt_tp = 0 - precision_cnt_fp = 0 + pd = generate_query( + models.Dataset.name.label("dataset_name"), + models.Prediction.id.label("id"), + models.Prediction.annotation_id.label("annotation_id"), + models.Prediction.score.label("score"), + models.Annotation.datum_id.label("datum_id"), + models.Datum.uid.label("datum_uid"), + case( + grouper_mappings["label_id_to_grouper_id_mapping"], + value=models.Prediction.label_id, + ).label("label_id_grouper"), + _annotation_type_to_geojson(target_type, models.Annotation).label( + "geojson" + ), + db=db, + filter_=prediction_filter, + label_source=models.Prediction, + ).subquery("predictions") + + joint = ( + select( + func.coalesce(pd.c.dataset_name, gt.c.dataset_name).label( + "dataset_name" + ), + gt.c.datum_uid.label("gt_datum_uid"), + pd.c.datum_uid.label("pd_datum_uid"), + gt.c.geojson.label("gt_geojson"), + gt.c.id.label("gt_id"), + pd.c.id.label("pd_id"), + gt.c.label_id_grouper.label("gt_label_id_grouper"), + pd.c.label_id_grouper.label("pd_label_id_grouper"), + gt.c.annotation_id.label("gt_ann_id"), + pd.c.annotation_id.label("pd_ann_id"), + pd.c.score.label("score"), + ) + .select_from(pd) + .outerjoin( + gt, + and_( + pd.c.datum_id == gt.c.datum_id, + pd.c.label_id_grouper == gt.c.label_id_grouper, + ), + ) + .subquery() + ) + + # Alias the annotation table (required for joining twice) + gt_annotation = aliased(models.Annotation) + pd_annotation = aliased(models.Annotation) + + # IOU Computation Block + if target_type == AnnotationType.RASTER: + gintersection = gfunc.ST_Count( + gfunc.ST_Intersection(gt_annotation.raster, pd_annotation.raster) + ) + gunion_gt = gfunc.ST_Count(gt_annotation.raster) + gunion_pd = gfunc.ST_Count(pd_annotation.raster) + gunion = gunion_gt + gunion_pd - gintersection + iou_computation = gintersection / gunion + else: + gt_geom = _annotation_type_to_column(target_type, gt_annotation) + pd_geom = _annotation_type_to_column(target_type, pd_annotation) + gintersection = gfunc.ST_Intersection(gt_geom, pd_geom) + gunion = gfunc.ST_Union(gt_geom, pd_geom) + iou_computation = gfunc.ST_Area(gintersection) / gfunc.ST_Area(gunion) + + ious = ( + select( + 
joint.c.dataset_name, + joint.c.pd_datum_uid, + joint.c.gt_datum_uid, + joint.c.gt_id.label("gt_id"), + joint.c.pd_id.label("pd_id"), + joint.c.gt_label_id_grouper, + joint.c.score, + func.coalesce(iou_computation, 0).label("iou"), + joint.c.gt_geojson, + ) + .select_from(joint) + .join(gt_annotation, gt_annotation.id == joint.c.gt_ann_id) + .join(pd_annotation, pd_annotation.id == joint.c.pd_ann_id) + .subquery() + ) + + ordered_ious = ( + db.query(ious).order_by(-ious.c.score, -ious.c.iou, ious.c.gt_id).all() + ) + + matched_pd_set = set() + matched_sorted_ranked_pairs = defaultdict(list) + + for row in ordered_ious: + ( + dataset_name, + pd_datum_uid, + gt_datum_uid, + gt_id, + pd_id, + gt_label_id_grouper, + score, + iou, + gt_geojson, + ) = row + + if pd_id not in matched_pd_set: + matched_pd_set.add(pd_id) + matched_sorted_ranked_pairs[gt_label_id_grouper].append( + RankedPair( + dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, + gt_datum_uid=gt_datum_uid, + gt_geojson=gt_geojson, + gt_id=gt_id, + pd_id=pd_id, + score=score, + iou=iou, + is_match=True, # we're joining on grouper IDs, so only matches are included in matched_sorted_ranked_pairs + ) + ) + + # get predictions that didn't make it into matched_sorted_ranked_pairs + # because they didn't have a corresponding groundtruth to pair with + predictions_not_in_sorted_ranked_pairs = ( + db.query( + pd.c.id, + pd.c.score, + pd.c.dataset_name, + pd.c.datum_uid, + pd.c.label_id_grouper, + ) + .filter(pd.c.id.notin_(matched_pd_set)) + .all() + ) + + for ( + pd_id, + score, + dataset_name, + pd_datum_uid, + grouper_id, + ) in predictions_not_in_sorted_ranked_pairs: + if ( + grouper_id in matched_sorted_ranked_pairs + and pd_id not in matched_pd_set + ): + # add to sorted_ranked_pairs in sorted order + bisect.insort( # type: ignore - bisect type issue + matched_sorted_ranked_pairs[grouper_id], + RankedPair( + dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, + gt_datum_uid=None, + gt_geojson=None, + gt_id=None, + pd_id=pd_id, + score=score, + iou=0, + is_match=False, + ), + key=lambda rp: -rp.score, # bisect assumes decreasing order + ) - if grouper_id in sorted_ranked_pairs: - matched_gts_for_precision = set() - matched_gts_for_recall = set() - for row in sorted_ranked_pairs[grouper_id]: + groundtruths_per_grouper = defaultdict(list) + number_of_groundtruths_per_grouper = defaultdict(int) - precision_score_conditional = row.score > 0 + groundtruths = db.query( + gt.c.id, + gt.c.label_id_grouper, + gt.c.datum_uid, + gt.c.dataset_name, + ) - recall_score_conditional = ( - row.score > recall_score_threshold - or ( - math.isclose(row.score, recall_score_threshold) - and recall_score_threshold > 0 - ) - ) + for gt_id, grouper_id, datum_uid, dset_name in groundtruths: + groundtruths_per_grouper[grouper_id].append( + (dset_name, datum_uid, gt_id) + ) + number_of_groundtruths_per_grouper[grouper_id] += 1 - iou_conditional = ( - row.iou >= iou_threshold and iou_threshold > 0 - ) + if ( + parameters.metrics_to_return + and "PrecisionRecallCurve" in parameters.metrics_to_return + ): + false_positive_entries = db.query( + select( + joint.c.dataset_name, + joint.c.gt_datum_uid, + joint.c.pd_datum_uid, + joint.c.gt_label_id_grouper, + joint.c.pd_label_id_grouper, + joint.c.score.label("score"), + ) + .select_from(joint) + .where( + or_( + joint.c.gt_id.is_(None), + joint.c.pd_id.is_(None), + ) + ) + .subquery() + ).all() - if ( - recall_score_conditional - and iou_conditional - and row.gt_id not in matched_gts_for_recall - ): - 
recall_cnt_tp += 1 - matched_gts_for_recall.add(row.gt_id) - else: - recall_cnt_fp += 1 + pr_curves = _compute_curves( + sorted_ranked_pairs=matched_sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + false_positive_entries=false_positive_entries, + iou_threshold=parameters.pr_curve_iou_threshold, + ) + else: + pr_curves = [] - if ( - precision_score_conditional - and iou_conditional - and row.gt_id not in matched_gts_for_precision - ): - matched_gts_for_precision.add(row.gt_id) - precision_cnt_tp += 1 - else: - precision_cnt_fp += 1 + ap_ar_output = [] - recall_cnt_fn = ( - number_of_groundtruths_per_grouper[grouper_id] - - recall_cnt_tp - ) + ap_metrics, ar_metrics = _calculate_ap_and_ar( + sorted_ranked_pairs=matched_sorted_ranked_pairs, + number_of_groundtruths_per_grouper=number_of_groundtruths_per_grouper, + iou_thresholds=parameters.iou_thresholds_to_compute, + grouper_mappings=grouper_mappings, + recall_score_threshold=parameters.recall_score_threshold, + ) - precision_cnt_fn = ( - number_of_groundtruths_per_grouper[grouper_id] - - precision_cnt_tp - ) + ap_ar_output += [ + m for m in ap_metrics if m.iou in parameters.iou_thresholds_to_return + ] + ap_ar_output += ar_metrics - precisions.append( - precision_cnt_tp - / (precision_cnt_tp + precision_cnt_fp) - if (precision_cnt_tp + precision_cnt_fp) - else 0 - ) - recalls.append( - precision_cnt_tp - / (precision_cnt_tp + precision_cnt_fn) - if (precision_cnt_tp + precision_cnt_fn) - else 0 - ) + # calculate averaged metrics + mean_ap_metrics = _compute_mean_detection_metrics_from_aps(ap_metrics) + mean_ar_metrics = _compute_mean_ar_metrics(ar_metrics) - recalls_across_thresholds.append( - recall_cnt_tp / (recall_cnt_tp + recall_cnt_fn) - if (recall_cnt_tp + recall_cnt_fn) - else 0 - ) - else: - precisions = [0] - recalls = [0] - recalls_across_thresholds.append(0) + ap_metrics_ave_over_ious = list( + _compute_detection_metrics_averaged_over_ious_from_aps(ap_metrics) + ) - ap_metrics.append( - schemas.APMetric( - iou=iou_threshold, - value=_calculate_101_pt_interp( - precisions=precisions, recalls=recalls - ), - label=grouper_label, - ) - ) + ap_ar_output += [ + m + for m in mean_ap_metrics + if isinstance(m, schemas.mAPMetric) + and m.iou in parameters.iou_thresholds_to_return + ] + ap_ar_output += mean_ar_metrics + ap_ar_output += ap_metrics_ave_over_ious - ar_metrics.append( - schemas.ARMetric( - ious=set(iou_thresholds), - value=( - sum(recalls_across_thresholds) - / len(recalls_across_thresholds) - if recalls_across_thresholds - else -1 - ), - label=grouper_label, - ) - ) + mean_ap_metrics_ave_over_ious = list( + _compute_mean_detection_metrics_from_aps(ap_metrics_ave_over_ious) + ) + ap_ar_output += mean_ap_metrics_ave_over_ious - return ap_metrics, ar_metrics + return ap_ar_output + pr_curves -def _compute_detection_metrics( +def _compute_detection_metrics_with_detailed_precision_recall_curve( db: Session, parameters: schemas.EvaluationParameters, prediction_filter: schemas.Filter, @@ -343,9 +1026,10 @@ def _compute_detection_metrics( | schemas.mARMetric | schemas.mAPMetricAveragedOverIOUs | schemas.PrecisionRecallCurve + | schemas.DetailedPrecisionRecallCurve ]: """ - Compute detection metrics. + Compute detection metrics via the heaviest possible calculation set. 
This version of _compute_detection_metrics does IOU calculations for every groundtruth-prediction pair that shares a common grouper key, which is necessary for calculating the DetailedPrecisionRecallCurve metric. Parameters ---------- @@ -363,8 +1047,8 @@ def _compute_detection_metrics( Returns ---------- - List[schemas.APMetric | schemas.ARMetric | schemas.APMetricAveragedOverIOUs | schemas.mAPMetric | schemas.mARMetric | schemas.mAPMetricAveragedOverIOUs | schemas.PrecisionRecallCurve] - A list of average precision metrics. + List[schemas.APMetric | schemas.ARMetric | schemas.APMetricAveragedOverIOUs | schemas.mAPMetric | schemas.mARMetric | schemas.mAPMetricAveragedOverIOUs | schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve] + A list of metrics to return to the user. """ @@ -427,31 +1111,32 @@ def _annotation_type_to_geojson( evaluation_type=enums.TaskType.OBJECT_DETECTION, ) - # Join gt, datum, annotation, label. Map grouper_ids to each label_id. - gt = generate_select( + gt = generate_query( models.Dataset.name.label("dataset_name"), models.GroundTruth.id.label("id"), models.GroundTruth.annotation_id.label("annotation_id"), - models.GroundTruth.label_id.label("label_id"), models.Annotation.datum_id.label("datum_id"), models.Datum.uid.label("datum_uid"), case( grouper_mappings["label_id_to_grouper_id_mapping"], value=models.GroundTruth.label_id, ).label("label_id_grouper"), + case( + grouper_mappings["label_id_to_grouper_key_mapping"], + value=models.GroundTruth.label_id, + ).label("label_key_grouper"), _annotation_type_to_geojson(target_type, models.Annotation).label( "geojson" ), + db=db, filter_=groundtruth_filter, label_source=models.GroundTruth, ).subquery("groundtruths") - # Join pd, datum, annotation, label - pd = generate_select( + pd = generate_query( models.Dataset.name.label("dataset_name"), models.Prediction.id.label("id"), models.Prediction.annotation_id.label("annotation_id"), - models.Prediction.label_id.label("label_id"), models.Prediction.score.label("score"), models.Annotation.datum_id.label("datum_id"), models.Datum.uid.label("datum_uid"), @@ -459,9 +1144,14 @@ def _annotation_type_to_geojson( grouper_mappings["label_id_to_grouper_id_mapping"], value=models.Prediction.label_id, ).label("label_id_grouper"), + case( + grouper_mappings["label_id_to_grouper_key_mapping"], + value=models.Prediction.label_id, + ).label("label_key_grouper"), _annotation_type_to_geojson(target_type, models.Annotation).label( "geojson" ), + db=db, filter_=prediction_filter, label_source=models.Prediction, ).subquery("predictions") @@ -471,18 +1161,15 @@ def _annotation_type_to_geojson( func.coalesce(pd.c.dataset_name, gt.c.dataset_name).label( "dataset_name" ), - gt.c.datum_id.label("gt_datum_id"), - pd.c.datum_id.label("pd_datum_id"), gt.c.datum_uid.label("gt_datum_uid"), pd.c.datum_uid.label("pd_datum_uid"), gt.c.geojson.label("gt_geojson"), - pd.c.geojson.label("pd_geojson"), gt.c.id.label("gt_id"), pd.c.id.label("pd_id"), - gt.c.label_id.label("gt_label_id"), - pd.c.label_id.label("pd_label_id"), gt.c.label_id_grouper.label("gt_label_id_grouper"), pd.c.label_id_grouper.label("pd_label_id_grouper"), + gt.c.label_key_grouper.label("gt_label_key_grouper"), + pd.c.label_key_grouper.label("pd_label_key_grouper"), gt.c.annotation_id.label("gt_ann_id"), pd.c.annotation_id.label("pd_ann_id"), pd.c.score.label("score"), @@ -492,13 +1179,12 @@ def _annotation_type_to_geojson( gt, and_( pd.c.datum_id == gt.c.datum_id, - pd.c.label_id_grouper == gt.c.label_id_grouper, + 
pd.c.label_key_grouper == gt.c.label_key_grouper, ), ) .subquery() ) - # Alias the annotation table (required for joining twice) gt_annotation = aliased(models.Annotation) pd_annotation = aliased(models.Annotation) @@ -518,190 +1204,220 @@ def _annotation_type_to_geojson( gunion = gfunc.ST_Union(gt_geom, pd_geom) iou_computation = gfunc.ST_Area(gintersection) / gfunc.ST_Area(gunion) - # Compute IOUs ious = ( select( joint.c.dataset_name, - joint.c.pd_datum_id, - joint.c.gt_datum_id, joint.c.pd_datum_uid, joint.c.gt_datum_uid, joint.c.gt_id.label("gt_id"), joint.c.pd_id.label("pd_id"), - joint.c.gt_label_id.label("gt_label_id"), - joint.c.pd_label_id.label("pd_label_id"), - joint.c.gt_label_id_grouper.label("gt_label_id_grouper"), - joint.c.pd_label_id_grouper.label("pd_label_id_grouper"), + joint.c.gt_label_id_grouper, + joint.c.pd_label_id_grouper, joint.c.score.label("score"), func.coalesce(iou_computation, 0).label("iou"), joint.c.gt_geojson, - joint.c.pd_geojson, + (joint.c.gt_label_id_grouper == joint.c.pd_label_id_grouper).label( + "is_match" + ), ) .select_from(joint) .join(gt_annotation, gt_annotation.id == joint.c.gt_ann_id) .join(pd_annotation, pd_annotation.id == joint.c.pd_ann_id) - .where( - and_( - joint.c.gt_id.isnot(None), - joint.c.pd_id.isnot(None), - ) - ) .subquery() ) - # Order by score, iou ordered_ious = ( - db.query(ious).order_by(-ious.c.score, -ious.c.iou, ious.c.gt_id).all() + db.query(ious) + .order_by( + ious.c.is_match.desc(), -ious.c.score, -ious.c.iou, ious.c.gt_id + ) + .all() ) - # Filter out repeated predictions pd_set = set() - ranking = {} + matched_pd_set = set() + sorted_ranked_pairs = defaultdict(list) + matched_sorted_ranked_pairs = defaultdict(list) + for row in ordered_ious: ( dataset_name, - _, - _, - _, + pd_datum_uid, gt_datum_uid, gt_id, pd_id, - _, - _, gt_label_id_grouper, - _, + pd_label_id_grouper, score, iou, gt_geojson, - _, + is_match, ) = row - # there should only be one rankedpair per prediction but - # there can be multiple rankedpairs per groundtruth at this point (i.e. 
before - # an iou threshold is specified) if pd_id not in pd_set: + # sorted_ranked_pairs will include all groundtruth-prediction pairs that meet filter criteria pd_set.add(pd_id) - if gt_label_id_grouper not in ranking: - ranking[gt_label_id_grouper] = [] + sorted_ranked_pairs[gt_label_id_grouper].append( + RankedPair( + dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, + gt_datum_uid=gt_datum_uid, + gt_geojson=gt_geojson, + gt_id=gt_id, + pd_id=pd_id, + score=score, + iou=iou, + is_match=is_match, + ) + ) + sorted_ranked_pairs[pd_label_id_grouper].append( + RankedPair( + dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, + gt_datum_uid=gt_datum_uid, + gt_geojson=gt_geojson, + gt_id=gt_id, + pd_id=pd_id, + score=score, + iou=iou, + is_match=is_match, + ) + ) - ranking[gt_label_id_grouper].append( + if pd_id not in matched_pd_set and is_match: + # matched_sorted_ranked_pairs only contains matched groundtruth-prediction pairs + matched_pd_set.add(pd_id) + matched_sorted_ranked_pairs[gt_label_id_grouper].append( RankedPair( dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, gt_datum_uid=gt_datum_uid, gt_geojson=gt_geojson, gt_id=gt_id, pd_id=pd_id, score=score, iou=iou, + is_match=True, ) ) - # get pds not appearing - predictions = ( - generate_query( - models.Prediction.id, - models.Prediction.score, - models.Dataset.name, - case( - grouper_mappings["label_id_to_grouper_id_mapping"], - value=models.Prediction.label_id, - ).label("label_id_grouper"), - db=db, - filter_=prediction_filter, - label_source=models.Prediction, + # get predictions that didn't make it into matched_sorted_ranked_pairs + # because they didn't have a corresponding groundtruth to pair with + predictions_not_in_sorted_ranked_pairs = ( + db.query( + pd.c.id, + pd.c.score, + pd.c.dataset_name, + pd.c.datum_uid, + pd.c.label_id_grouper, ) - .where(models.Prediction.id.notin_(pd_set)) + .filter(pd.c.id.notin_(matched_pd_set)) .all() ) - for pd_id, score, dataset_name, grouper_id in predictions: - if grouper_id in ranking and pd_id not in pd_set: - # add to ranking in sorted order + for ( + pd_id, + score, + dataset_name, + pd_datum_uid, + grouper_id, + ) in predictions_not_in_sorted_ranked_pairs: + if pd_id not in pd_set: + # add to sorted_ranked_pairs in sorted order bisect.insort( # type: ignore - bisect type issue - ranking[grouper_id], + sorted_ranked_pairs[grouper_id], RankedPair( dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, gt_datum_uid=None, gt_geojson=None, gt_id=None, pd_id=pd_id, score=score, iou=0, + is_match=False, ), key=lambda rp: -rp.score, # bisect assumes decreasing order ) + bisect.insort( + matched_sorted_ranked_pairs[grouper_id], + RankedPair( + dataset_name=dataset_name, + pd_datum_uid=pd_datum_uid, + gt_datum_uid=None, + gt_geojson=None, + gt_id=None, + pd_id=pd_id, + score=score, + iou=0, + is_match=False, + ), + key=lambda rp: -rp.score, # bisect assumes decreasing order + ) - # Get the groundtruths per grouper_id + # Get all groundtruths per grouper_id groundtruths_per_grouper = defaultdict(list) + predictions_per_grouper = defaultdict(list) number_of_groundtruths_per_grouper = defaultdict(int) - groundtruths = generate_query( - models.GroundTruth.id, - case( - grouper_mappings["label_id_to_grouper_id_mapping"], - value=models.GroundTruth.label_id, - ).label("label_id_grouper"), - models.Datum.uid.label("datum_uid"), - models.Dataset.name.label("dataset_name"), - _annotation_type_to_geojson(target_type, models.Annotation).label( - "gt_geojson" - ), - db=db, - 
filter_=groundtruth_filter, - label_source=models.GroundTruth, - ).all() + groundtruths = db.query( + gt.c.id, + gt.c.label_id_grouper, + gt.c.datum_uid, + gt.c.dataset_name, + gt.c.geojson, + ) + + predictions = db.query( + pd.c.id, + pd.c.label_id_grouper, + pd.c.datum_uid, + pd.c.dataset_name, + pd.c.geojson, + ) for gt_id, grouper_id, datum_uid, dset_name, gt_geojson in groundtruths: - # we're ok with duplicates since they indicate multiple groundtruths for a given dataset/datum_id + # we're ok with adding duplicates here since they indicate multiple groundtruths for a given dataset/datum_id groundtruths_per_grouper[grouper_id].append( (dset_name, datum_uid, gt_id, gt_geojson) ) number_of_groundtruths_per_grouper[grouper_id] += 1 + for pd_id, grouper_id, datum_uid, dset_name, pd_geojson in predictions: + predictions_per_grouper[grouper_id].append( + (dset_name, datum_uid, pd_id, pd_geojson) + ) if parameters.metrics_to_return is None: raise RuntimeError("Metrics to return should always contains values.") - # Optionally compute precision-recall curves - if "PrecisionRecallCurve" in parameters.metrics_to_return: - false_positive_entries = db.query( - select( - joint.c.dataset_name, - joint.c.gt_datum_uid, - joint.c.pd_datum_uid, - joint.c.gt_label_id_grouper.label("gt_label_id_grouper"), - joint.c.pd_label_id_grouper.label("pd_label_id_grouper"), - joint.c.score.label("score"), - joint.c.pd_geojson, - ) - .select_from(joint) - .where( - or_( - joint.c.gt_id.is_(None), - joint.c.pd_id.is_(None), - ) - ) - .subquery() - ).all() + pr_curves = _compute_detailed_curves( + sorted_ranked_pairs=sorted_ranked_pairs, + grouper_mappings=grouper_mappings, + groundtruths_per_grouper=groundtruths_per_grouper, + predictions_per_grouper=predictions_per_grouper, + pr_curve_iou_threshold=parameters.pr_curve_iou_threshold, + pr_curve_max_examples=( + parameters.pr_curve_max_examples + if parameters.pr_curve_max_examples + else 1 + ), + ) - pr_curves = _compute_curves( - sorted_ranked_pairs=ranking, - grouper_mappings=grouper_mappings, - groundtruths_per_grouper=groundtruths_per_grouper, - false_positive_entries=false_positive_entries, - iou_threshold=parameters.pr_curve_iou_threshold, - ) - else: - pr_curves = [] + ap_ar_output = [] - # Compute AP ap_metrics, ar_metrics = _calculate_ap_and_ar( - sorted_ranked_pairs=ranking, + sorted_ranked_pairs=matched_sorted_ranked_pairs, number_of_groundtruths_per_grouper=number_of_groundtruths_per_grouper, iou_thresholds=parameters.iou_thresholds_to_compute, grouper_mappings=grouper_mappings, recall_score_threshold=parameters.recall_score_threshold, ) + ap_ar_output += [ + m for m in ap_metrics if m.iou in parameters.iou_thresholds_to_return + ] + ap_ar_output += ar_metrics + # calculate averaged metrics mean_ap_metrics = _compute_mean_detection_metrics_from_aps(ap_metrics) mean_ar_metrics = _compute_mean_ar_metrics(ar_metrics) @@ -710,30 +1426,21 @@ def _annotation_type_to_geojson( _compute_detection_metrics_averaged_over_ious_from_aps(ap_metrics) ) - mean_ap_metrics_ave_over_ious = list( - _compute_mean_detection_metrics_from_aps(ap_metrics_ave_over_ious) - ) - - # filter out only specified ious - ap_metrics = [ - m for m in ap_metrics if m.iou in parameters.iou_thresholds_to_return - ] - mean_ap_metrics = [ + ap_ar_output += [ m for m in mean_ap_metrics if isinstance(m, schemas.mAPMetric) and m.iou in parameters.iou_thresholds_to_return ] + ap_ar_output += mean_ar_metrics + ap_ar_output += ap_metrics_ave_over_ious - return ( - ap_metrics - + ar_metrics - + 
mean_ap_metrics - + mean_ar_metrics - + ap_metrics_ave_over_ious - + mean_ap_metrics_ave_over_ious - + pr_curves + mean_ap_metrics_ave_over_ious = list( + _compute_mean_detection_metrics_from_aps(ap_metrics_ave_over_ious) ) + ap_ar_output += mean_ap_metrics_ave_over_ious + + return ap_ar_output + pr_curves def _compute_detection_metrics_averaged_over_ious_from_aps( @@ -996,13 +1703,29 @@ def compute_detection_metrics(*_, db: Session, evaluation_id: int): groundtruth_filter.require_raster = True prediction_filter.require_raster = True - metrics = _compute_detection_metrics( - db=db, - parameters=parameters, - prediction_filter=prediction_filter, - groundtruth_filter=groundtruth_filter, - target_type=target_type, - ) + if ( + parameters.metrics_to_return + and "DetailedPrecisionRecallCurve" in parameters.metrics_to_return + ): + # this function is more computationally expensive since it calculates IOUs for every groundtruth-prediction pair that shares a label key + metrics = ( + _compute_detection_metrics_with_detailed_precision_recall_curve( + db=db, + parameters=parameters, + prediction_filter=prediction_filter, + groundtruth_filter=groundtruth_filter, + target_type=target_type, + ) + ) + else: + # this function is much faster since it only calculates IOUs for every groundtruth-prediction pair that shares a label id + metrics = _compute_detection_metrics( + db=db, + parameters=parameters, + prediction_filter=prediction_filter, + groundtruth_filter=groundtruth_filter, + target_type=target_type, + ) metric_mappings = create_metric_mappings( db=db, diff --git a/api/valor_api/backend/metrics/metric_utils.py b/api/valor_api/backend/metrics/metric_utils.py index 6e78b2037..e07d6bde9 100644 --- a/api/valor_api/backend/metrics/metric_utils.py +++ b/api/valor_api/backend/metrics/metric_utils.py @@ -20,6 +20,7 @@ def _create_detection_grouper_mappings( """Create grouper mappings for use when evaluating detections.""" label_id_to_grouper_id_mapping = {} + label_id_to_grouper_key_mapping = {} grouper_id_to_grouper_label_mapping = {} grouper_id_to_label_ids_mapping = defaultdict(list) @@ -29,8 +30,12 @@ def _create_detection_grouper_mappings( ) # create an integer to track each group by grouper_id = hash((mapped_key, mapped_value)) + # create a separate grouper_key_id which is used to cross-join labels that share a given key + # when computing IOUs for PrecisionRecallCurve + grouper_key_id = mapped_key label_id_to_grouper_id_mapping[label.id] = grouper_id + label_id_to_grouper_key_mapping[label.id] = grouper_key_id grouper_id_to_grouper_label_mapping[grouper_id] = schemas.Label( key=mapped_key, value=mapped_value ) @@ -38,6 +43,7 @@ def _create_detection_grouper_mappings( return { "label_id_to_grouper_id_mapping": label_id_to_grouper_id_mapping, + "label_id_to_grouper_key_mapping": label_id_to_grouper_key_mapping, "grouper_id_to_label_ids_mapping": grouper_id_to_label_ids_mapping, "grouper_id_to_grouper_label_mapping": grouper_id_to_grouper_label_mapping, } @@ -208,6 +214,7 @@ def create_metric_mappings( | schemas.IOUMetric | schemas.mIOUMetric | schemas.PrecisionRecallCurve + | schemas.DetailedPrecisionRecallCurve ], evaluation_id: int, ) -> list[dict]: diff --git a/api/valor_api/backend/metrics/segmentation.py b/api/valor_api/backend/metrics/segmentation.py index a39c74c89..5d7728a58 100644 --- a/api/valor_api/backend/metrics/segmentation.py +++ b/api/valor_api/backend/metrics/segmentation.py @@ -205,12 +205,12 @@ def _compute_segmentation_metrics( "grouper_id_to_grouper_label_mapping" ][grouper_id] 
- ret.append( + ret += [ IOUMetric( label=grouper_label, value=computed_iou_score, ) - ) + ] ious_per_grouper_key[grouper_label.key].append(computed_iou_score) diff --git a/api/valor_api/schemas/__init__.py b/api/valor_api/schemas/__init__.py index 28501a48c..91b7ac33a 100644 --- a/api/valor_api/schemas/__init__.py +++ b/api/valor_api/schemas/__init__.py @@ -33,6 +33,7 @@ ConfusionMatrix, ConfusionMatrixEntry, ConfusionMatrixResponse, + DetailedPrecisionRecallCurve, F1Metric, IOUMetric, Metric, @@ -86,6 +87,7 @@ "RecallMetric", "ROCAUCMetric", "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", "ConfusionMatrixResponse", "APMetric", "ARMetric", diff --git a/api/valor_api/schemas/evaluation.py b/api/valor_api/schemas/evaluation.py index fa402fa5d..65512a3b1 100644 --- a/api/valor_api/schemas/evaluation.py +++ b/api/valor_api/schemas/evaluation.py @@ -16,7 +16,11 @@ class EvaluationParameters(BaseModel): Attributes ---------- - metrics: list[str], optional + task_type: TaskType + The task type of a given evaluation. + label_map: Optional[List[List[List[str]]]] + Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models. + metrics: List[str], optional The list of metrics to compute, store, and return to the user. convert_annotations_to_type: AnnotationType | None = None The type to convert all annotations to. @@ -24,31 +28,32 @@ class EvaluationParameters(BaseModel): A list of floats describing which Intersection over Unions (IoUs) to use when calculating metrics (i.e., mAP). iou_thresholds_to_return: List[float], optional A list of floats describing which Intersection over Union (IoUs) thresholds to calculate a metric for. Must be a subset of `iou_thresholds_to_compute`. - label_map: LabelMapType, optional - Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models. recall_score_threshold: float, default=0 The confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall. pr_curve_iou_threshold: float, optional The IOU threshold to use when calculating precision-recall curves for object detection tasks. Defaults to 0.5. + pr_curve_max_examples: int + The maximum number of datum examples to store when calculating PR curves. 
""" task_type: TaskType - metrics_to_return: list[str] | None = None + label_map: LabelMapType | None = None + convert_annotations_to_type: AnnotationType | None = None iou_thresholds_to_compute: list[float] | None = None iou_thresholds_to_return: list[float] | None = None - label_map: LabelMapType | None = None recall_score_threshold: float | None = 0 - pr_curve_iou_threshold: float | None = 0.5 + pr_curve_iou_threshold: float = 0.5 + pr_curve_max_examples: int = 1 # pydantic setting model_config = ConfigDict(extra="forbid") @model_validator(mode="after") @classmethod - def _validate_by_task_type(cls, values): - """Validate the IOU thresholds.""" + def _validate_parameters(cls, values): + """Validate EvaluationParameters via type-specific checks.""" # set default metrics for each task type if values.metrics_to_return is None: diff --git a/api/valor_api/schemas/metrics.py b/api/valor_api/schemas/metrics.py index d81d9a3e0..2300ca56c 100644 --- a/api/valor_api/schemas/metrics.py +++ b/api/valor_api/schemas/metrics.py @@ -418,7 +418,23 @@ def db_mapping(self, evaluation_id: int) -> dict: } -class PrecisionRecallCurve(BaseModel): +class _BasePrecisionRecallCurve(BaseModel): + """ + Describes the parent class of our precision-recall curve metrics. + + Attributes + ---------- + label_key: str + The label key associated with the metric. + pr_curve_iou_threshold: float, optional + The IOU threshold to use when calculating precision-recall curves. Defaults to 0.5. + """ + + label_key: str + pr_curve_iou_threshold: float | None = None + + +class PrecisionRecallCurve(_BasePrecisionRecallCurve): """ Describes a precision-recall curve. @@ -432,22 +448,76 @@ class PrecisionRecallCurve(BaseModel): The IOU threshold to use when calculating precision-recall curves. Defaults to 0.5. """ - label_key: str value: dict[ str, dict[ float, + dict[str, int | float | None], + ], + ] + + def db_mapping(self, evaluation_id: int) -> dict: + """ + Creates a mapping for use when uploading the curves to the database. + + Parameters + ---------- + evaluation_id : int + The evaluation id. + + Returns + ---------- + A mapping dictionary. + """ + + return { + "value": self.value, + "type": "PrecisionRecallCurve", + "evaluation_id": evaluation_id, + "parameters": { + "label_key": self.label_key, + "pr_curve_iou_threshold": self.pr_curve_iou_threshold, + }, + } + + +class DetailedPrecisionRecallCurve(_BasePrecisionRecallCurve): + """ + Describes a detailed precision-recall curve, which includes datum examples for each classification (e.g., true positive, false negative, etc.). + + Attributes + ---------- + label_key: str + The label key associated with the metric. + value: dict + A nested dictionary where the first key is the class label, the second key is the confidence threshold (e.g., 0.05), the third key is the metric name (e.g., "precision"), and the final key is either the value itself (for precision, recall, etc.) or a list of tuples containing data for each observation. + pr_curve_iou_threshold: float, optional + The IOU threshold to use when calculating precision-recall curves. Defaults to 0.5. 
+ """ + + value: dict[ + str, # the label value + dict[ + float, # the IOU threshold dict[ - str, - int - | float - | list[tuple[str, str]] # for classification tasks - | list[tuple[str, str, str]] # for object detection tasks - | None, + str, # the metric (e.g., "tp" for true positive) + dict[ + str, # the label for the next level of the dictionary (e.g., "observations" or "total") + int # the count of classifications + | dict[ + str, # the subclassification for the label (e.g., "misclassifications") + dict[ + str, # the label for the next level of the dictionary (e.g., "count" or "examples") + int # the count of subclassifications + | list[ + tuple[str, str] | tuple[str, str, str] + ], # a list containing examples + ], + ], + ], ], ], ] - pr_curve_iou_threshold: float | None = None def db_mapping(self, evaluation_id: int) -> dict: """ @@ -465,7 +535,7 @@ def db_mapping(self, evaluation_id: int) -> dict: return { "value": self.value, - "type": "PrecisionRecallCurve", + "type": "DetailedPrecisionRecallCurve", "evaluation_id": evaluation_id, "parameters": { "label_key": self.label_key, diff --git a/client/valor/coretypes.py b/client/valor/coretypes.py index a29fe1940..9d70ffa53 100644 --- a/client/valor/coretypes.py +++ b/client/valor/coretypes.py @@ -871,6 +871,7 @@ def evaluate_classification( datasets: Optional[Union[Dataset, List[Dataset]]] = None, filter_by: Optional[FilterType] = None, label_map: Optional[Dict[Label, Label]] = None, + pr_curve_max_examples: int = 1, metrics_to_return: Optional[List[str]] = None, allow_retries: bool = False, ) -> Evaluation: @@ -908,6 +909,7 @@ def evaluate_classification( parameters=EvaluationParameters( task_type=TaskType.CLASSIFICATION, label_map=self._create_label_map(label_map=label_map), + pr_curve_max_examples=pr_curve_max_examples, metrics_to_return=metrics_to_return, ), ) @@ -931,6 +933,7 @@ def evaluate_detection( recall_score_threshold: float = 0, metrics_to_return: Optional[List[str]] = None, pr_curve_iou_threshold: float = 0.5, + pr_curve_max_examples: int = 1, allow_retries: bool = False, ) -> Evaluation: """ @@ -952,10 +955,10 @@ def evaluate_detection( Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models. recall_score_threshold: float, default=0 The confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall. - metrics: List[str], optional - The list of metrics to compute, store, and return to the user. pr_curve_iou_threshold: float, optional The IOU threshold to use when calculating precision-recall curves. Defaults to 0.5. + pr_curve_max_examples: int, optional + The maximum number of datum examples to store when calculating PR curves. allow_retries : bool, default = False Option to retry previously failed evaluations. 
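A minimal sketch of how the new client-side parameters documented above fit together, assuming `Dataset` and `Model` are importable from the `valor` client as in the integration tests further down, and that `dataset` and `model` are already populated and finalized; the helper name and metric list are illustrative, not part of this patch:

```python
from valor import Dataset, Model  # client types exercised throughout this patch


def request_detailed_pr_curves(dataset: Dataset, model: Model) -> list[dict]:
    """Illustrative only: run a detection evaluation that opts into the detailed curves."""
    eval_job = model.evaluate_detection(
        dataset,
        metrics_to_return=[
            "mAP",
            "PrecisionRecallCurve",
            "DetailedPrecisionRecallCurve",  # not returned unless requested explicitly
        ],
        pr_curve_iou_threshold=0.5,  # IOU threshold used by both curve types
        pr_curve_max_examples=3,  # keep up to three example datums per observation type
    )
    eval_job.wait_for_completion(timeout=30)

    # each returned metric is a dict with a "type" field; pull out just the detailed curves
    return [
        m for m in eval_job.metrics if m["type"] == "DetailedPrecisionRecallCurve"
    ]
```

Requesting `DetailedPrecisionRecallCurve` is what routes the evaluation onto the more computationally expensive IOU path selected in `compute_detection_metrics` above.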
@@ -982,6 +985,7 @@ def evaluate_detection( recall_score_threshold=recall_score_threshold, metrics_to_return=metrics_to_return, pr_curve_iou_threshold=pr_curve_iou_threshold, + pr_curve_max_examples=pr_curve_max_examples, ) datum_filter = self._format_constraints(datasets, filter_by) request = EvaluationRequest( diff --git a/client/valor/schemas/evaluation.py b/client/valor/schemas/evaluation.py index fe189f88c..54f83f657 100644 --- a/client/valor/schemas/evaluation.py +++ b/client/valor/schemas/evaluation.py @@ -12,31 +12,36 @@ class EvaluationParameters: Attributes ---------- - iou_thresholds_to_compute : Optional[List[float]] - A list of floats describing which Intersection over Unions (IoUs) to use when calculating metrics (i.e., mAP). - iou_thresholds_to_return: Optional[List[float]] - A list of floats describing which Intersection over Union (IoUs) thresholds to calculate a metric for. Must be a subset of `iou_thresholds_to_compute`. + task_type: TaskType + The task type of a given evaluation. label_map: Optional[List[List[List[str]]]] Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models. - recall_score_threshold: float, default=0 - The confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall. metrics: List[str], optional The list of metrics to compute, store, and return to the user. + convert_annotations_to_type: AnnotationType | None = None + The type to convert all annotations to. + iou_thresholds_to_compute: List[float], optional + A list of floats describing which Intersection over Unions (IoUs) to use when calculating metrics (i.e., mAP). + iou_thresholds_to_return: List[float], optional + A list of floats describing which Intersection over Union (IoUs) thresholds to calculate a metric for. Must be a subset of `iou_thresholds_to_compute`. + recall_score_threshold: float, default=0 + The confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall. pr_curve_iou_threshold: float, optional The IOU threshold to use when calculating precision-recall curves for object detection tasks. Defaults to 0.5. - + pr_curve_max_examples: int + The maximum number of datum examples to store when calculating PR curves. """ task_type: TaskType + label_map: Optional[List[List[List[str]]]] = None + metrics_to_return: Optional[List[str]] = None - # object detection convert_annotations_to_type: Optional[AnnotationType] = None iou_thresholds_to_compute: Optional[List[float]] = None iou_thresholds_to_return: Optional[List[float]] = None - label_map: Optional[List[List[List[str]]]] = None recall_score_threshold: float = 0 - metrics_to_return: Optional[List[str]] = None pr_curve_iou_threshold: float = 0.5 + pr_curve_max_examples: int = 1 @dataclass diff --git a/docs/metrics.md b/docs/metrics.md index a3ff16898..c2aeff8e3 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -11,7 +11,8 @@ If we're missing an important metric for your particular use case, please [write | F1 | A weighted average of precision and recall. | $\frac{2 * Precision * Recall}{Precision + Recall}$ | | Accuracy | The number of true predictions divided by the total number of predictions. | $\dfrac{\|TP\|+\|TN\|}{\|TP\|+\|TN\|+\|FP\|+\|FN\|}$ | | ROC AUC | The area under the Receiver Operating Characteristic (ROC) curve for the predictions generated by a given model. 
| See [ROCAUC methods](#binary-roc-auc). | -| Precision-Recall Curves | Outputs a nested dictionary containing the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination. Computing this metric requires passing `PrecisionRecallCurve` into the list of `metrics_to_return` at evaluation time. | See [precision-recall curve methods](#precision-recall-curves)| +| Precision-Recall Curves | Outputs a nested dictionary containing the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination. | See [precision-recall curve methods](#precision-recall-curves)| +| Detailed Precision-Recall Curves | Similar to `PrecisionRecallCurve`, except this metric a) classifies false positives as `hallucinations` or `misclassifications`, b) classifies false negatives as `misclassifications` or `missed_detections`, and c) gives example datums for each observation, up to a maximum of `pr_curve_max_examples`. | See [detailed precision-recall curve methods](#detailedprecisionrecallcurve)| ## Object Detection and Instance Segmentation Metrics** @@ -23,7 +24,8 @@ If we're missing an important metric for your particular use case, please [write | mAP Averaged Over IOUs | The average of several mAP metrics grouped by label keys. | $\dfrac{1}{\text{number of thresholds}} \sum\limits_{iou \in thresholds} mAP_{iou}$ | | Average Recall (AR) | The average of several recall metrics across IOU thresholds, grouped by class labels. | See [AR methods](#average-recall-ar). | | Mean Average Recall (mAR) | The average of several AR metrics, grouped by label keys. | $\dfrac{1}{\text{number of labels}} \sum\limits_{label \in labels} AR_{class}$ | -| Precision-Recall Curves | Outputs a nested dictionary containing the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination. Computing this metric requires passing `PrecisionRecallCurve` into the list of `metrics_to_return` at evaluation time. These curves are calculated using a default IOU threshold of 0.5; you can set your own threshold by passing a float between 0 and 1 to the `pr_curve_iou_threshold` parameter at evaluation time. | See [precision-recall curve methods](#precision-recall-curves)| +| Precision-Recall Curves | Outputs a nested dictionary containing the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination. These curves are calculated using a default IOU threshold of 0.5; you can set your own threshold by passing a float between 0 and 1 to the `pr_curve_iou_threshold` parameter at evaluation time. | See [precision-recall curve methods](#precision-recall-curves)| +| Detailed Precision-Recall Curves | Similar to `PrecisionRecallCurve`, except this metric a) classifies false positives as `hallucinations` or `misclassifications`, b) classifies false negatives as `misclassifications` or `missed_detections`, and c) gives example datums and bounding boxes for each observation, up to a maximum of `pr_curve_max_examples`. | See [detailed precision-recall curve methods](#detailedprecisionrecallcurve)| **When calculating IOUs for object detection metrics, Valor handles the necessary conversion between different types of geometric annotations. 
For example, if your model prediction is a polygon and your groundtruth is a raster, then the raster will be converted to a polygon prior to calculating the IOU. @@ -161,13 +163,19 @@ To calculate Average Recall (AR), we: Note that this metric differs from COCO's calculation in two ways: - COCO averages across classes while calculating AR, while we calculate AR separately for each class. Our AR calculations matches the original FAIR definition of AR, while our mAR calculations match what COCO calls AR. -- COCO calculates three different AR metrics (AR@1, AR@5, AR@100)) by considering only the top 1/5/100 most confident predictions during the matching process. Valor, on the other hand, allows users to input a `recall_score_threshold` value that will prevent low-confidence predictions from being counted as true positives when calculating AR. +- COCO calculates three different AR metrics (AR@1, AR@5, AR@100) by considering only the top 1/5/100 most confident predictions during the matching process. Valor, on the other hand, allows users to input a `recall_score_threshold` value that will prevent low-confidence predictions from being counted as true positives when calculating AR. ## Precision-Recall Curves - -Precision-recall curves offer insight into which confidence threshold you should pick for your production pipeline. To compute these curves for your classification or object detection workflow, simply pass `PrecisionRecallCurve` into the list of `metrics_to_return` when initiating your evaluation. Valor will then tabulate the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination, and store them in a nested dictionary for your use. When using the Valor Python client, the output will be formatted as follows: +Precision-recall curves offer insight into which confidence threshold you should pick for your production pipeline. The `PrecisionRecallCurve` metric includes the true positives, false positives, true negatives, false negatives, precision, recall, and F1 score for each (label key, label value, confidence threshold) combination. When using the Valor Python client, the output will be formatted as follows: ```python + +pr_evaluation = evaluate_detection( + data=dataset, +) +print(pr_evaluation) + +[..., { "type": "PrecisionRecallCurve", "parameters": { @@ -177,20 +185,17 @@ Precision-recall curves offer insight into which confidence threshold you should "value": { "cat": { # The value of the label. "0.05": { # The confidence score threshold, ranging from 0.05 to 0.95 in increments of 0.05. - "fn": [], - "fp": [ - ( - 'test_dataset', - 1, - '{"type":"Polygon","coordinates":[[[464.08,105.09],[495.74,105.09],[495.74,146.99],[464.08,146.99],[464.08,105.09]]]}' - ) # There's one false positive for this (key, value, confidence threshold) combination as indicated by the one tuple shown here. This tuple contains that observation's dataset name, datum ID, and coordinates in the form of a GeoJSON string. For classification tasks, this tuple will only contain the given observation's dataset name and datum ID. - ], - "tp": [], - "recall": -1, - "f1_score": -1, - "precision": 0.0, + "fn": 0, + "fp": 1, + "tp": 3, + "recall": 1, + "precision": 0.75, + "f1_score": .857, }, ... 
+        },
+    }
+}]
 ```

 It's important to note that these curves are computed slightly differently from our other aggregate metrics above:
@@ -203,7 +208,78 @@ We think the approach above makes sense when calculating aggregate precision and

 ### Detection Tasks

-The `PrecisionRecallCurve` values differ from the precision-recall curves used to calculate [Average Precsion](#average-precision-ap) in two subtle ways:
+The `PrecisionRecallCurve` values differ from the precision-recall curves used to calculate [Average Precision](#average-precision-ap) in two subtle ways:

 - The `PrecisionRecallCurve` values visualize how precision and recall change as confidence thresholds vary from 0.05 to 0.95 in increments of 0.05. In contrast, the precision-recall curves used to calculate Average Precision are non-uniform; they vary over the actual confidence scores for each ground truth-prediction match.
-- If your pipeline predicts a label on an image, but that label doesn't exist on any ground truths in that particular image, then the `PrecisionRecallCurve` values will consider that prediction to be a false positive, whereas the other detection metrics will ignore that particular prediction.
\ No newline at end of file
+- If your pipeline predicts a label on an image, but that label doesn't exist on any ground truths in that particular image, then the `PrecisionRecallCurve` values will consider that prediction to be a false positive, whereas the other detection metrics will ignore that particular prediction.
+
+### DetailedPrecisionRecallCurve
+
+Valor also includes a more detailed version of `PrecisionRecallCurve` which can be useful for debugging your model's false positives and false negatives. When calculating `DetailedPrecisionRecallCurve`, Valor will classify false positives as either `hallucinations` or `misclassifications` and your false negatives as either `missed_detections` or `misclassifications` using the following logic:
+
+#### Classification Tasks
+  - A **false positive** is a `misclassification` if there is a qualified prediction (with `score >= score_threshold`) with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect. For example: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another label value (e.g., `Label(key='animal', value='cat')`) on that datum, we'd say it's a `misclassification` since the key was correct but the value was not. Any false positives that do not meet this criterion are considered to be `hallucinations`.
+  - Similarly, a **false negative** is a `misclassification` if there is a prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect. Any false negatives that do not meet this criterion are considered to be `missed_detections`.
+
+#### Object Detection Tasks
+  - A **false positive** is a `misclassification` if a) there is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect, and b) the qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`. For example: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another bounding box directly over that same object (e.g., `Label(key='animal', value='cat')`), we'd say it's a `misclassification`. Any false positives that do not meet these criteria are considered to be `hallucinations`.
+  - A **false negative** is determined to be a `misclassification` if the two criteria above are met: a) there is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect, and b) the qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`. Any false negatives that do not meet these criteria are considered to be `missed_detections`.
+
+The `DetailedPrecisionRecallCurve` output also includes up to `n` examples of each type of error, where `n` is set using `pr_curve_max_examples`. An example output is as follows:
+
+
+```python
+# To retrieve more detailed examples for each `fn`, `fp`, and `tp`, look at the `DetailedPrecisionRecallCurve` metric
+detailed_evaluation = evaluate_detection(
+    data=dataset,
+    pr_curve_max_examples=1, # The maximum number of examples to return for each observation type (e.g., hallucinations, misclassifications, etc.)
+    metrics_to_return=[..., 'DetailedPrecisionRecallCurve'] # DetailedPrecisionRecallCurve isn't returned by default; the user must ask for it explicitly
+)
+print(detailed_evaluation)
+
+[...,
+{
+    "type": "DetailedPrecisionRecallCurve",
+    "parameters": {
+        "label_key": "class", # The key of the label.
+        "pr_curve_iou_threshold": 0.5,
+    },
+    "value": {
+        "cat": { # The value of the label.
+            "0.05": { # The confidence score threshold, ranging from 0.05 to 0.95 in increments of 0.05.
+                "fp": {
+                    "total": 1,
+                    "observations": {
+                        'hallucinations': {
+                            "count": 1,
+                            "examples": [
+                                (
+                                    'test_dataset',
+                                    1,
+                                    '{"type":"Polygon","coordinates":[[[464.08,105.09],[495.74,105.09],[495.74,146.99],[464.08,146.99],[464.08,105.91]]]}'
+                                ) # There's one false positive for this (key, value, confidence threshold) combination as indicated by the one tuple shown here. This tuple contains that observation's dataset name, datum ID, and coordinates in the form of a GeoJSON string. For classification tasks, this tuple will only contain the given observation's dataset name and datum ID.
+                            ],
+                        }
+                    },
+                },
+                "tp": {
+                    "total": 3,
+                    "observations": {
+                        'all': {
+                            "count": 3,
+                            "examples": [
+                                (
+                                    'test_dataset',
+                                    2,
+                                    '{"type":"Polygon","coordinates":[[[464.08,105.09],[495.74,105.09],[495.74,146.99],[464.08,146.99],[464.08,105.91]]]}'
+                                ) # We only return one example since `pr_curve_max_examples` is set to 1 by default; update this argument at evaluation time to store and retrieve an arbitrary number of examples.
+ ], + }, + } + }, + "fn": {...}, + }, + }, + } +}] +``` \ No newline at end of file diff --git a/docs/technical_concepts.md b/docs/technical_concepts.md index a974fd20c..19bffb4e1 100644 --- a/docs/technical_concepts.md +++ b/docs/technical_concepts.md @@ -16,7 +16,7 @@ Note that Valor does _not_ store raw data (such as underlying images) or facilit ## Supported Task Types -As of January 2024, Valor supports the following types of supervised learning tasks and associated metrics: +As of May 2024, Valor supports the following types of supervised learning tasks and associated metrics: - Classification (including multi-label classification) - F1 @@ -24,11 +24,15 @@ As of January 2024, Valor supports the following types of supervised learning ta - Accuracy - Precision - Recall + - Precision Recall Curve + - Detailed Precision Recall Curve - Object detection - AP - mAP - AP Averaged Over IOUs - mAP Averaged Over IOUs + - Precision Recall Curve + - Detailed Precision Recall Curve - Segmentation (including both instance and semantic segmentation) - IOU - mIOU diff --git a/integration_tests/client/datatype/test_data_generation.py b/integration_tests/client/datatype/test_data_generation.py index a8178e718..cb7b9ff63 100644 --- a/integration_tests/client/datatype/test_data_generation.py +++ b/integration_tests/client/datatype/test_data_generation.py @@ -413,6 +413,7 @@ def test_generate_prediction_data(client: Client): "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, "meta": {}, } diff --git a/integration_tests/client/metrics/test_classification.py b/integration_tests/client/metrics/test_classification.py index f86a77d9d..d99abd15b 100644 --- a/integration_tests/client/metrics/test_classification.py +++ b/integration_tests/client/metrics/test_classification.py @@ -166,7 +166,11 @@ def test_evaluate_image_clf( ] for m in metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in metrics @@ -351,7 +355,11 @@ def test_evaluate_tabular_clf( }, ] for m in metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in metrics @@ -373,7 +381,11 @@ def test_evaluate_tabular_clf( assert len(bulk_evals) == 1 for metric in bulk_evals[0].metrics: - assert metric in expected_metrics + if metric["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert metric in expected_metrics assert len(bulk_evals[0].confusion_matrices[0]) == len( expected_confusion_matrix ) @@ -416,7 +428,11 @@ def test_evaluate_tabular_clf( metrics_from_eval_settings_id = results[0].metrics assert len(metrics_from_eval_settings_id) == len(expected_metrics) for m in metrics_from_eval_settings_id: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in metrics_from_eval_settings_id @@ -584,7 +600,11 @@ def test_stratify_clf_metrics( for metrics in [val2_metrics, val_bool_metrics]: assert len(metrics) == len(expected_metrics) for m in metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in metrics @@ -1003,6 +1023,7 @@ def 
test_evaluate_classification_with_label_maps( eval_job = model.evaluate_classification( dataset, label_map=label_mapping, + pr_curve_max_examples=3, metrics_to_return=[ "Precision", "Recall", @@ -1010,17 +1031,22 @@ def test_evaluate_classification_with_label_maps( "Accuracy", "ROCAUC", "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", ], ) assert eval_job.id assert eval_job.wait_for_completion(timeout=30) == EvaluationStatus.DONE - pr_expected_lengths = { + pr_expected_values = { # k3 (0, "k3", "v1", "0.1", "fp"): 1, (0, "k3", "v1", "0.1", "tn"): 2, (0, "k3", "v3", "0.1", "fn"): 1, (0, "k3", "v3", "0.1", "tn"): 2, + (0, "k3", "v3", "0.1", "accuracy"): 2 / 3, + (0, "k3", "v3", "0.1", "precision"): -1, + (0, "k3", "v3", "0.1", "recall"): 0, + (0, "k3", "v3", "0.1", "f1_score"): -1, # k4 (1, "k4", "v1", "0.1", "fp"): 1, (1, "k4", "v1", "0.1", "tn"): 2, @@ -1048,42 +1074,33 @@ def test_evaluate_classification_with_label_maps( "0.1", "tn", ): 2, - (3, "special_class", "cat_type1", "0.1", "tp"): 3, - (3, "special_class", "cat_type1", "0.1", "tn"): 0, - (3, "special_class", "cat_type1", "0.95", "tp"): 3, - } - - pr_expected_metrics = { - # k3, v3 - (0, "k3", "v3", "0.1", "accuracy"): 2 / 3, - (0, "k3", "v3", "0.1", "precision"): -1, - (0, "k3", "v3", "0.1", "recall"): 0, - (0, "k3", "v3", "0.1", "f1_score"): -1, - # k5, v1 (2, "k5", "v1", "0.1", "accuracy"): 2 / 3, (2, "k5", "v1", "0.1", "precision"): 0, (2, "k5", "v1", "0.1", "recall"): -1, (2, "k5", "v1", "0.1", "f1_score"): -1, - # special_class, cat_type1 - (3, "special_class", "cat_type1", "0.1", "accuracy"): 1, - (3, "special_class", "cat_type1", "0.1", "precision"): 1, - (3, "special_class", "cat_type1", "0.1", "recall"): 1, - (3, "special_class", "cat_type1", "0.1", "f1_score"): 1, + # special_class + (3, "special_class", "cat_type1", "0.1", "tp"): 3, + (3, "special_class", "cat_type1", "0.1", "tn"): 0, + (3, "special_class", "cat_type1", "0.95", "tp"): 3, } metrics = eval_job.metrics pr_metrics = [] + detailed_pr_metrics = [] for m in metrics: - if m["type"] != "PrecisionRecallCurve": - assert m in cat_expected_metrics - else: + if m["type"] == "PrecisionRecallCurve": pr_metrics.append(m) + elif m["type"] == "DetailedPrecisionRecallCurve": + detailed_pr_metrics.append(m) + else: + assert m in cat_expected_metrics for m in cat_expected_metrics: assert m in metrics pr_metrics.sort(key=lambda x: x["parameters"]["label_key"]) + detailed_pr_metrics.sort(key=lambda x: x["parameters"]["label_key"]) for ( index, @@ -1091,24 +1108,70 @@ def test_evaluate_classification_with_label_maps( value, threshold, metric, - ), expected_length in pr_expected_lengths.items(): + ), expected_value in pr_expected_values.items(): assert ( - len(pr_metrics[index]["value"][value][threshold][metric]) - == expected_length + pr_metrics[index]["value"][value][threshold][metric] + == expected_value ) + # check DetailedPrecisionRecallCurve + detailed_pr_expected_answers = { + # k3 + (0, "v1", "0.1", "tp"): {"all": 0, "total": 0}, + (0, "v1", "0.1", "fp"): { + "hallucinations": 0, + "misclassifications": 1, + "total": 1, + }, + (0, "v1", "0.1", "tn"): {"all": 2, "total": 2}, + (0, "v1", "0.1", "fn"): { + "missed_detections": 0, + "misclassifications": 0, + "total": 0, + }, + # k4 + (1, "v1", "0.1", "tp"): {"all": 0, "total": 0}, + (1, "v1", "0.1", "fp"): { + "hallucinations": 0, + "misclassifications": 1, + "total": 1, + }, + (1, "v1", "0.1", "tn"): {"all": 2, "total": 2}, + (1, "v1", "0.1", "fn"): { + "missed_detections": 0, + "misclassifications": 0, + "total": 
0, + }, + (1, "v4", "0.1", "fn"): { + "missed_detections": 0, + "misclassifications": 1, + "total": 1, + }, + (1, "v8", "0.1", "tn"): {"all": 2, "total": 2}, + } + for ( index, - key, value, threshold, metric, - ), expected_length in pr_expected_metrics.items(): - assert ( - pr_metrics[index]["value"][value][threshold][metric] - ) == expected_length - - confusion_matrix = eval_job.confusion_matrices + ), expected_output in detailed_pr_expected_answers.items(): + model_output = detailed_pr_metrics[index]["value"][value][threshold][ + metric + ] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) # check metadata assert eval_job.meta["datums"] == 3 @@ -1116,6 +1179,9 @@ def test_evaluate_classification_with_label_maps( assert eval_job.meta["annotations"] == 6 assert eval_job.meta["duration"] <= 10 # usually 2 + # check confusion matrix + confusion_matrix = eval_job.confusion_matrices + for row in confusion_matrix: if row["label_key"] == "special_class": for entry in cat_expected_cm[0]["entries"]: diff --git a/integration_tests/client/metrics/test_detection.py b/integration_tests/client/metrics/test_detection.py index 84ea5092b..ce11c85f7 100644 --- a/integration_tests/client/metrics/test_detection.py +++ b/integration_tests/client/metrics/test_detection.py @@ -161,6 +161,7 @@ def test_evaluate_detection( "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, "status": EvaluationStatus.DONE.value, "confusion_matrices": [], @@ -168,7 +169,11 @@ def test_evaluate_detection( "ignored_pred_labels": [], } for m in actual_metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in actual_metrics @@ -290,6 +295,7 @@ def test_evaluate_detection( "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, "status": EvaluationStatus.DONE.value, "confusion_matrices": [], @@ -298,7 +304,11 @@ def test_evaluate_detection( } for m in actual_metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in actual_metrics @@ -351,6 +361,7 @@ def test_evaluate_detection( "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, # check metrics below "status": EvaluationStatus.DONE.value, @@ -429,6 +440,7 @@ def test_evaluate_detection( "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, # check metrics below "status": EvaluationStatus.DONE.value, @@ -634,6 +646,7 @@ def test_evaluate_detection_with_json_filters( "mAPAveragedOverIOUs", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, }, # check metrics below "status": EvaluationStatus.DONE.value, @@ -785,7 +798,11 @@ def test_get_evaluations( assert len(evaluations) == 1 assert len(evaluations[0].metrics) for m in evaluations[0].metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in evaluations[0].metrics @@ -822,7 +839,11 @@ def 
test_get_evaluations( assert len(second_model_evaluations) == 1 for m in second_model_evaluations[0].metrics: - assert m in second_model_expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in second_model_expected_metrics for m in second_model_expected_metrics: assert m in second_model_evaluations[0].metrics @@ -837,12 +858,20 @@ def test_get_evaluations( ] if evaluation.model_name == model_name: for m in evaluation.metrics: - assert m in expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in expected_metrics for m in expected_metrics: assert m in evaluation.metrics elif evaluation.model_name == "second_model": for m in evaluation.metrics: - assert m in second_model_expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in second_model_expected_metrics for m in second_model_expected_metrics: assert m in evaluation.metrics @@ -877,7 +906,7 @@ def test_get_evaluations( metrics_to_sort_by={"mAPAveragedOverIOUs": "k1"}, ) - assert both_evaluations_from_evaluation_ids[0].metrics[-1]["value"] == 0 + assert both_evaluations_from_evaluation_ids[0].metrics[-2]["value"] == 0 # with sorting, the evaluation with the higher mAPAveragedOverIOUs is returned first assert ( @@ -1137,6 +1166,7 @@ def test_evaluate_detection_with_label_maps( dataset, iou_thresholds_to_compute=[0.1, 0.6], iou_thresholds_to_return=[0.1, 0.6], + pr_curve_max_examples=1, metrics_to_return=[ "AP", "AR", @@ -1145,6 +1175,7 @@ def test_evaluate_detection_with_label_maps( "mAR", "mAPAveragedOverIOUs", "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", ], ) @@ -1164,16 +1195,18 @@ def test_evaluate_detection_with_label_maps( metrics = eval_job.metrics pr_metrics = [] + pr_metrics = [] + detailed_pr_metrics = [] for m in metrics: - if m["type"] != "PrecisionRecallCurve": - assert m in baseline_expected_metrics - else: + if m["type"] == "PrecisionRecallCurve": pr_metrics.append(m) - - for m in baseline_expected_metrics: - assert m in metrics + elif m["type"] == "DetailedPrecisionRecallCurve": + detailed_pr_metrics.append(m) + else: + assert m in baseline_expected_metrics pr_metrics.sort(key=lambda x: x["parameters"]["label_key"]) + detailed_pr_metrics.sort(key=lambda x: x["parameters"]["label_key"]) pr_expected_answers = { # class @@ -1205,24 +1238,105 @@ def test_evaluate_detection_with_label_maps( value, threshold, metric, - ), expected_length in pr_expected_answers.items(): + ), expected_value in pr_expected_answers.items(): assert ( - len(pr_metrics[index]["value"][value][threshold][metric]) - == expected_length + pr_metrics[index]["value"][value][threshold][metric] + == expected_value ) - # spot check a few geojson results - assert ( - pr_metrics[0]["value"]["cat"]["0.1"]["fp"][0][2] - == '{"type":"Polygon","coordinates":[[[10,10],[60,10],[60,40],[10,40],[10,10]]]}' - ) + # check DetailedPrecisionRecallCurve + detailed_pr_expected_answers = { + # class + (0, "cat", "0.1", "fp"): { + "hallucinations": 1, + "misclassifications": 0, + "total": 1, + }, + (0, "cat", "0.4", "fp"): { + "hallucinations": 0, + "misclassifications": 0, + "total": 0, + }, + (0, "british shorthair", "0.1", "fn"): { + "missed_detections": 1, + "misclassifications": 0, + "total": 1, + }, + # class_name + (1, "cat", "0.4", "fp"): { + "hallucinations": 1, + "misclassifications": 0, + "total": 1, + }, + (1, "maine coon cat", "0.1", "fn"): { + "missed_detections": 1, + 
"misclassifications": 0, + "total": 1, + }, + # k1 + (2, "v1", "0.1", "fn"): { + "missed_detections": 1, + "misclassifications": 0, + "total": 1, + }, + (2, "v1", "0.4", "fn"): { + "missed_detections": 2, + "misclassifications": 0, + "total": 2, + }, + (2, "v1", "0.1", "tp"): {"all": 1, "total": 1}, + # k2 + (3, "v2", "0.1", "fn"): { + "missed_detections": 1, + "misclassifications": 0, + "total": 1, + }, + (3, "v2", "0.1", "fp"): { + "hallucinations": 1, + "misclassifications": 0, + "total": 1, + }, + } + + for ( + index, + value, + threshold, + metric, + ), expected_output in detailed_pr_expected_answers.items(): + model_output = detailed_pr_metrics[index]["value"][value][threshold][ + metric + ] + assert isinstance(model_output, dict) + assert model_output["total"] == expected_output["total"] + assert all( + [ + model_output["observations"][key]["count"] # type: ignore - we know this element is a dict + == expected_output[key] + for key in [ + key + for key in expected_output.keys() + if key not in ["total"] + ] + ] + ) + + # check that we get at most 1 example assert ( - pr_metrics[1]["value"]["maine coon cat"]["0.1"]["fn"][0][2] - == '{"type":"Polygon","coordinates":[[[10,10],[60,10],[60,40],[10,40],[10,10]]]}' + len( + detailed_pr_metrics[0]["value"]["cat"]["0.4"]["fp"]["observations"]["hallucinations"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 0 ) assert ( - pr_metrics[3]["value"]["v2"]["0.1"]["fp"][0][2] - == '{"type":"Polygon","coordinates":[[[15,0],[70,0],[70,20],[15,20],[15,0]]]}' + len( + detailed_pr_metrics[2]["value"]["v1"]["0.4"]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict + "examples" + ] + ) + == 1 ) # now, we correct most of the mismatched labels with a label map @@ -1432,7 +1546,11 @@ def test_evaluate_detection_with_label_maps( metrics = eval_job.metrics for m in metrics: - assert m in cat_expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in cat_expected_metrics for m in cat_expected_metrics: assert m in metrics @@ -1609,7 +1727,11 @@ def test_evaluate_detection_with_label_maps( metrics = eval_job.metrics for m in metrics: - assert m in foo_expected_metrics + if m["type"] not in [ + "PrecisionRecallCurve", + "DetailedPrecisionRecallCurve", + ]: + assert m in foo_expected_metrics for m in foo_expected_metrics: assert m in metrics @@ -1775,6 +1897,7 @@ def test_evaluate_detection_with_label_maps( "PrecisionRecallCurve", ], ) + assert ( eval_job.ignored_pred_labels is not None and eval_job.missing_pred_labels is not None @@ -1806,16 +1929,19 @@ def test_evaluate_detection_with_label_maps( "PrecisionRecallCurve", ], "pr_curve_iou_threshold": 0.5, + "pr_curve_max_examples": 1, } metrics = eval_job.metrics pr_metrics = [] for m in metrics: - if m["type"] != "PrecisionRecallCurve": - assert m in foo_expected_metrics_with_higher_score_threshold - else: + if m["type"] == "PrecisionRecallCurve": pr_metrics.append(m) + elif m["type"] == "DetailedPrecisionRecallCurve": + continue + else: + assert m in foo_expected_metrics_with_higher_score_threshold for m in foo_expected_metrics_with_higher_score_threshold: assert m in metrics @@ -1843,33 +1969,12 @@ def test_evaluate_detection_with_label_maps( value, threshold, metric, - ), expected_length in pr_expected_answers.items(): + ), expected_value in pr_expected_answers.items(): assert ( - len(pr_metrics[index]["value"][value][threshold][metric]) - == expected_length + 
pr_metrics[index]["value"][value][threshold][metric] + == expected_value ) - # spot check a few geojson results - pr_metric = [ - m for m in pr_metrics if m["parameters"]["label_key"] == "foo" - ][0] - assert ( - pr_metric["value"]["bar"]["0.4"]["fn"][0][2] - == '{"type":"Polygon","coordinates":[[[10,10],[60,10],[60,40],[10,40],[10,10]]]}' - ) - assert ( - pr_metric["value"]["bar"]["0.4"]["tp"][0][2] - == '{"type":"Polygon","coordinates":[[[15,0],[70,0],[70,20],[15,20],[15,0]]]}' - ) - - pr_metric = [ - m for m in pr_metrics if m["parameters"]["label_key"] == "k2" - ][0] - assert ( - pr_metric["value"]["v2"]["0.1"]["fp"][0][2] - == '{"type":"Polygon","coordinates":[[[15,0],[70,0],[70,20],[15,20],[15,0]]]}' - ) - assert eval_job.parameters.label_map == [ [["class_name", "maine coon cat"], ["foo", "bar"]], [["class", "siamese cat"], ["foo", "bar"]], @@ -2363,3 +2468,379 @@ def test_evaluate_detection_false_negatives_two_images_one_only_with_different_c "value": 0, "label": {"key": "key", "value": "other value"}, } + + +def test_detailed_precision_recall_curve( + db: Session, + model_name, + dataset_name, + img1, + img2, + rect1, + rect2, + rect3, + rect4, + rect5, +): + gts = [ + GroundTruth( + datum=img1, + annotations=[ + Annotation( + is_instance=True, + labels=[Label(key="k1", value="v1")], + bounding_box=Box([rect1]), + ), + Annotation( + is_instance=True, + labels=[Label(key="k1", value="missed_detection")], + bounding_box=Box([rect2]), + ), + Annotation( + is_instance=True, + labels=[Label(key="k1", value="v2")], + bounding_box=Box([rect3]), + ), + ], + ), + GroundTruth( + datum=img2, + annotations=[ + Annotation( + is_instance=True, + labels=[Label(key="k1", value="low_iou")], + bounding_box=Box([rect1]), + ), + ], + ), + ] + + pds = [ + Prediction( + datum=img1, + annotations=[ + Annotation( + is_instance=True, + labels=[Label(key="k1", value="v1", score=0.5)], + bounding_box=Box([rect1]), + ), + Annotation( + is_instance=True, + labels=[Label(key="k1", value="not_v2", score=0.3)], + bounding_box=Box([rect5]), + ), + Annotation( + is_instance=True, + labels=[Label(key="k1", value="hallucination", score=0.1)], + bounding_box=Box([rect4]), + ), + ], + ), + # prediction for img2 has the wrong bounding box, so it should count as a hallucination + Prediction( + datum=img2, + annotations=[ + Annotation( + is_instance=True, + labels=[Label(key="k1", value="low_iou", score=0.5)], + bounding_box=Box([rect2]), + ), + ], + ), + ] + + dataset = Dataset.create(dataset_name) + + for gt in gts: + dataset.add_groundtruth(gt) + + dataset.finalize() + + model = Model.create(model_name) + + for pd in pds: + model.add_prediction(dataset, pd) + + model.finalize_inferences(dataset) + + eval_job = model.evaluate_detection( + dataset, + pr_curve_max_examples=1, + metrics_to_return=[ + "DetailedPrecisionRecallCurve", + ], + ) + eval_job.wait_for_completion(timeout=30) + + # one true positive that becomes a false negative when score > .5 + assert eval_job.metrics[0]["value"]["v1"]["0.3"]["tp"]["total"] == 1 + assert eval_job.metrics[0]["value"]["v1"]["0.55"]["tp"]["total"] == 0 + assert eval_job.metrics[0]["value"]["v1"]["0.55"]["fn"]["total"] == 1 + assert ( + eval_job.metrics[0]["value"]["v1"]["0.55"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert eval_job.metrics[0]["value"]["v1"]["0.05"]["fn"]["total"] == 0 + assert eval_job.metrics[0]["value"]["v1"]["0.05"]["fp"]["total"] == 0 + + # one missed detection that never changes + assert ( + 
eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.95"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["tp"]["total"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["fp"]["total"] + == 0 + ) + + # one fn missed_dection that becomes a misclassification when pr_curve_iou_threshold <= .48 and score threshold <= .3 + assert ( + eval_job.metrics[0]["value"]["v2"]["0.3"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["v2"]["0.35"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert eval_job.metrics[0]["value"]["v2"]["0.05"]["tp"]["total"] == 0 + assert eval_job.metrics[0]["value"]["v2"]["0.05"]["fp"]["total"] == 0 + + # one fp hallucination that becomes a misclassification when pr_curve_iou_threshold <= .48 and score threshold <= .3 + assert ( + eval_job.metrics[0]["value"]["not_v2"]["0.05"]["fp"]["observations"][ + "hallucinations" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["not_v2"]["0.05"]["fp"]["observations"][ + "misclassifications" + ]["count"] + == 0 + ) + assert eval_job.metrics[0]["value"]["not_v2"]["0.05"]["tp"]["total"] == 0 + assert eval_job.metrics[0]["value"]["not_v2"]["0.05"]["fn"]["total"] == 0 + + # one fp hallucination that disappears when score threshold >.15 + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["fp"][ + "observations" + ]["hallucinations"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.35"]["fp"][ + "observations" + ]["hallucinations"]["count"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["tp"]["total"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["fn"]["total"] + == 0 + ) + + # one missed detection and one hallucination due to low iou overlap + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.3"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.95"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.3"]["fp"]["observations"][ + "hallucinations" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.55"]["fp"]["observations"][ + "hallucinations" + ]["count"] + == 0 + ) + + # repeat tests using a lower IOU threshold + eval_job_low_iou_threshold = model.evaluate_detection( + dataset, + pr_curve_max_examples=1, + metrics_to_return=[ + "DetailedPrecisionRecallCurve", + ], + pr_curve_iou_threshold=0.45, # actual IOU is .481 + ) + eval_job_low_iou_threshold.wait_for_completion(timeout=30) + + # one true positive that becomes a false negative when score > .5 + assert eval_job.metrics[0]["value"]["v1"]["0.3"]["tp"]["total"] == 1 + assert eval_job.metrics[0]["value"]["v1"]["0.55"]["tp"]["total"] == 0 + assert eval_job.metrics[0]["value"]["v1"]["0.55"]["fn"]["total"] == 1 + assert ( + eval_job.metrics[0]["value"]["v1"]["0.55"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert eval_job.metrics[0]["value"]["v1"]["0.05"]["fn"]["total"] == 0 + assert eval_job.metrics[0]["value"]["v1"]["0.05"]["fp"]["total"] == 0 + + # one missed detection that never changes + assert ( + 
eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.95"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["tp"]["total"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["missed_detection"]["0.05"]["fp"]["total"] + == 0 + ) + + # one fn missed_dection that becomes a misclassification when pr_curve_iou_threshold <= .48 and score threshold <= .3 + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.3"]["fn"][ + "observations" + ]["misclassifications"]["count"] + == 1 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.3"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 0 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.35"]["fn"][ + "observations" + ]["misclassifications"]["count"] + == 0 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.35"]["fn"][ + "observations" + ]["missed_detections"]["count"] + == 1 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.05"]["tp"][ + "total" + ] + == 0 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["v2"]["0.05"]["fp"][ + "total" + ] + == 0 + ) + + # one fp hallucination that becomes a misclassification when pr_curve_iou_threshold <= .48 and score threshold <= .3 + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["not_v2"]["0.05"]["fp"][ + "observations" + ]["hallucinations"]["count"] + == 0 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["not_v2"]["0.05"]["fp"][ + "observations" + ]["misclassifications"]["count"] + == 1 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["not_v2"]["0.05"]["tp"][ + "total" + ] + == 0 + ) + assert ( + eval_job_low_iou_threshold.metrics[0]["value"]["not_v2"]["0.05"]["fn"][ + "total" + ] + == 0 + ) + + # one fp hallucination that disappears when score threshold >.15 + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["fp"][ + "observations" + ]["hallucinations"]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.35"]["fp"][ + "observations" + ]["hallucinations"]["count"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["tp"]["total"] + == 0 + ) + assert ( + eval_job.metrics[0]["value"]["hallucination"]["0.05"]["fn"]["total"] + == 0 + ) + + # one missed detection and one hallucination due to low iou overlap + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.3"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.95"]["fn"]["observations"][ + "missed_detections" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.3"]["fp"]["observations"][ + "hallucinations" + ]["count"] + == 1 + ) + assert ( + eval_job.metrics[0]["value"]["low_iou"]["0.55"]["fp"]["observations"][ + "hallucinations" + ]["count"] + == 0 + ) diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index 4e53fc47e..0824febe3 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -291,6 +291,18 @@ def rect4() -> list[tuple[float, float]]: ] +@pytest.fixture +def rect5() -> list[tuple[float, float]]: + """Box with partial overlap to rect3.""" + return [ + (87, 10), + (158, 10), + (158, 400), + (87, 400), + (87, 10), + ] + + """GroundTruths""" 
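The comments in `test_detailed_precision_recall_curve` above lean on exact pairwise IOUs (e.g., `# actual IOU is .481` when justifying `pr_curve_iou_threshold=0.45`). A small, self-contained sketch for sanity-checking such numbers on axis-aligned box fixtures like `rect5` follows; `other_box` is a placeholder (the real `rect3` fixture isn't shown in this hunk), so the printed value is illustrative rather than the test's 0.481:

```python
# Standalone sanity check (not part of the test suite) for axis-aligned box IOUs.
# Boxes are written as closed rings of (x, y) tuples, matching the conftest fixtures.


def box_iou(a: list[tuple[float, float]], b: list[tuple[float, float]]) -> float:
    """Compute IOU for two axis-aligned boxes given as corner rings."""
    ax0, ax1 = min(x for x, _ in a), max(x for x, _ in a)
    ay0, ay1 = min(y for _, y in a), max(y for _, y in a)
    bx0, bx1 = min(x for x, _ in b), max(x for x, _ in b)
    by0, by1 = min(y for _, y in b), max(y for _, y in b)

    inter_w = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0.0, min(ay1, by1) - max(ay0, by0))
    intersection = inter_w * inter_h
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - intersection
    return intersection / union if union else 0.0


rect5 = [(87, 10), (158, 10), (158, 400), (87, 400), (87, 10)]
# placeholder box for illustration only; the real `rect3` fixture is defined elsewhere
other_box = [(87, 10), (158, 10), (158, 200), (87, 200), (87, 10)]

print(round(box_iou(rect5, other_box), 3))  # ~0.487 for this placeholder; compare against pr_curve_iou_threshold
```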
diff --git a/ts-client/src/ValorClient.ts b/ts-client/src/ValorClient.ts index f80bf8cf7..e9a127ab4 100644 --- a/ts-client/src/ValorClient.ts +++ b/ts-client/src/ValorClient.ts @@ -385,7 +385,8 @@ export class ValorClient { * @param [iouThresholdsToReturn] list of floats describing which Intersection over Union (IoUs) thresholds to calculate a metric for. Must be a subset of `iou_thresholds_to_compute` * @param [labelMap] mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models * @param [recallScoreThreshold] confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall - * @param [prCurveIouThreshold] the IOU threshold to use when calculating precision-recall curves for object detection tasks. + * @param [prCurveIouThreshold] the IOU threshold to use when calculating precision-recall curves for object detection tasks. Defaults to 0.5. + * @param [prCurveMaxExamples] the maximum number of datum examples to store for each error type when calculating PR curves. * * @returns {Promise} */ @@ -398,7 +399,8 @@ export class ValorClient { iouThresholdsToReturn?: number[], labelMap?: number[][][], recallScoreThreshold?: number, - prCurveIouThreshold?: number + prCurveIouThreshold?: number, + prCurveMaxExamples?: number ): Promise { const response = await this.client.post('/evaluations', { model_names: [model], @@ -410,7 +412,8 @@ export class ValorClient { label_map: labelMap, recall_score_threshold: recallScoreThreshold, metrics_to_return: metrics_to_return, - pr_curve_iou_threshold: prCurveIouThreshold + pr_curve_iou_threshold: prCurveIouThreshold, + pr_curve_max_examples: prCurveMaxExamples }, }); return this.unmarshalEvaluation(response.data[0]); @@ -428,7 +431,9 @@ export class ValorClient { * @param [iouThresholdsToReturn] list of floats describing which Intersection over Union (IoUs) thresholds to calculate a metric for. Must be a subset of `iou_thresholds_to_compute` * @param [labelMap] mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models * @param [recallScoreThreshold] confidence score threshold for use when determining whether to count a prediction as a true positive or not while calculating Average Recall - * @param [prCurveIouThreshold] the IOU threshold to use when calculating precision-recall curves for object detection tasks. Defaults to 0.5. + * @param [prCurveIouThreshold] the IOU threshold to use when calculating precision-recall curves for object detection tasks. Defaults to 0.5 + * @param [prCurveMaxExamples] the maximum number of datum examples to store for each error type when calculating PR curves. 
+ * * @returns {Promise} */ @@ -441,7 +446,8 @@ export class ValorClient { iouThresholdsToReturn?: number[], labelMap?: any[][][], recallScoreThreshold?: number, - prCurveIouThreshold?: number + prCurveIouThreshold?: number, + prCurveMaxExamples?: number ): Promise { const response = await this.client.post('/evaluations', { model_names: models, @@ -453,7 +459,8 @@ export class ValorClient { iou_thresholds_to_return: iouThresholdsToReturn, label_map: labelMap, recall_score_threshold: recallScoreThreshold, - pr_curve_iou_threshold: prCurveIouThreshold + pr_curve_iou_threshold: prCurveIouThreshold, + pr_curve_max_examples: prCurveMaxExamples }, }); return response.data.map(this.unmarshalEvaluation); diff --git a/ts-client/tests/ValorClient.test.ts b/ts-client/tests/ValorClient.test.ts index db330640c..160a47951 100644 --- a/ts-client/tests/ValorClient.test.ts +++ b/ts-client/tests/ValorClient.test.ts @@ -355,7 +355,6 @@ test('bulk create or get evaluations', async () => { // bulk create evaluations for each dataset for (const datasetName of datasetNames) { await client.finalizeDataset(datasetName); - let evaluations = await client.bulkCreateOrGetEvaluations( modelNames, datasetName, @@ -363,7 +362,6 @@ test('bulk create or get evaluations', async () => { ); expect(evaluations.length).toBe(2); // check all evaluations are pending - while (evaluations.every((evaluation) => evaluation.status !== 'done')) { await new Promise((resolve) => setTimeout(resolve, 1000)); evaluations = await client.getEvaluationsByIds(