
Commit

Simplify PrecisionRecallCurve and add `DetailedPrecisionRecallCurve` for advanced debugging (#584)
ntlind committed Jun 11, 2024
1 parent 79928c3 commit c90e1af
Showing 25 changed files with 3,750 additions and 780 deletions.
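
For orientation before the file-by-file diff: the simplified `PrecisionRecallCurve` now stores a plain count for each (label value, score threshold, metric) cell, while the new `DetailedPrecisionRecallCurve` breaks each count down by reason and can attach up to `pr_curve_max_examples` example datums. The sketch below is inferred from the assertions in the updated tests; the literal numbers and the `("dataset1", "datum2")` tuple are illustrative placeholders, not output from the commit.

# Illustrative shapes only, mirroring what the tests below index into.

# Simplified PrecisionRecallCurve: label value -> threshold -> metric -> count
pr_curve_value = {
    "bird": {
        0.05: {"tp": 3, "fp": 1, "tn": 2, "fn": 0},
    },
}

# DetailedPrecisionRecallCurve: each cell carries a total plus a per-reason
# breakdown, and each reason can hold up to `pr_curve_max_examples` examples.
detailed_pr_curve_value = {
    "bird": {
        0.05: {
            "fp": {
                "total": 1,
                "observations": {
                    "hallucinations": {"count": 0, "examples": []},
                    "misclassifications": {
                        "count": 1,
                        "examples": [("dataset1", "datum2")],  # placeholder datum
                    },
                },
            },
        },
    },
}

assert pr_curve_value["bird"][0.05]["tp"] == 3
assert detailed_pr_curve_value["bird"][0.05]["fp"]["total"] == 1
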
2 changes: 1 addition & 1 deletion api/tests/functional-tests/backend/core/test_evaluation.py
@@ -282,7 +282,7 @@ def test_create_evaluation(
assert (
rows[0].parameters
== schemas.EvaluationParameters(
task_type=enums.TaskType.CLASSIFICATION
task_type=enums.TaskType.CLASSIFICATION,
).model_dump()
)

207 changes: 202 additions & 5 deletions api/tests/functional-tests/backend/metrics/test_classification.py
@@ -683,9 +683,10 @@ def test_compute_classification(

confusion, metrics = _compute_clf_metrics(
db,
model_filter,
datum_filter,
prediction_filter=model_filter,
groundtruth_filter=datum_filter,
label_map=None,
pr_curve_max_examples=0,
metrics_to_return=[
"Precision",
"Recall",
@@ -928,8 +929,14 @@ def test__compute_curves(
grouper_key="animal",
grouper_mappings=grouper_mappings,
unique_datums=unique_datums,
pr_curve_max_examples=1,
metrics_to_return=[
"PrecisionRecallCurve",
"DetailedPrecisionRecallCurve",
],
)

# check PrecisionRecallCurve
pr_expected_answers = {
# bird
("bird", 0.05, "tp"): 3,
@@ -973,6 +980,196 @@
threshold,
metric,
), expected_length in pr_expected_answers.items():
list_of_datums = curves[value][threshold][metric]
assert isinstance(list_of_datums, list)
assert len(list_of_datums) == expected_length
classification = curves[0].value[value][threshold][metric]
assert classification == expected_length

# check DetailedPrecisionRecallCurve
detailed_pr_expected_answers = {
# bird
("bird", 0.05, "tp"): {"all": 3, "total": 3},
("bird", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 1,
"total": 1,
},
("bird", 0.05, "tn"): {"all": 2, "total": 2},
("bird", 0.05, "fn"): {
"missed_detections": 0,
"misclassifications": 0,
"total": 0,
},
# dog
("dog", 0.05, "tp"): {"all": 2, "total": 2},
("dog", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 3,
"total": 3,
},
("dog", 0.05, "tn"): {"all": 1, "total": 1},
("dog", 0.8, "fn"): {
"missed_detections": 1,
"misclassifications": 1,
"total": 2,
},
# cat
("cat", 0.05, "tp"): {"all": 1, "total": 1},
("cat", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 5,
"total": 5,
},
("cat", 0.05, "tn"): {"all": 0, "total": 0},
("cat", 0.8, "fn"): {
"missed_detections": 0,
"misclassifications": 0,
"total": 0,
},
}

for (
value,
threshold,
metric,
), expected_output in detailed_pr_expected_answers.items():
model_output = curves[1].value[value][threshold][metric]
assert isinstance(model_output, dict)
assert model_output["total"] == expected_output["total"]
assert all(
[
model_output["observations"][key]["count"] # type: ignore - we know this element is a dict
== expected_output[key]
for key in [
key
for key in expected_output.keys()
if key not in ["total"]
]
]
)

# spot check number of examples
assert (
len(
curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
== 1
)
assert (
len(
curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
== 1
)

# repeat the above, but with a higher pr_curve_max_examples
curves = _compute_curves(
db=db,
predictions=predictions,
groundtruths=groundtruths,
grouper_key="animal",
grouper_mappings=grouper_mappings,
unique_datums=unique_datums,
pr_curve_max_examples=3,
metrics_to_return=[
"PrecisionRecallCurve",
"DetailedPrecisionRecallCurve",
],
)

# these outputs shouldn't have changed
for (
value,
threshold,
metric,
), expected_output in detailed_pr_expected_answers.items():
model_output = curves[1].value[value][threshold][metric]
assert isinstance(model_output, dict)
assert model_output["total"] == expected_output["total"]
assert all(
[
model_output["observations"][key]["count"] # type: ignore - we know this element is a dict
== expected_output[key]
for key in [
key
for key in expected_output.keys()
if key not in ["total"]
]
]
)

assert (
len(
curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
== 3
)
assert (
len(
(
curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
)
== 2 # only two examples exist
)

# test behavior if pr_curve_max_examples == 0
curves = _compute_curves(
db=db,
predictions=predictions,
groundtruths=groundtruths,
grouper_key="animal",
grouper_mappings=grouper_mappings,
unique_datums=unique_datums,
pr_curve_max_examples=0,
metrics_to_return=[
"PrecisionRecallCurve",
"DetailedPrecisionRecallCurve",
],
)

# these outputs shouldn't have changed
for (
value,
threshold,
metric,
), expected_output in detailed_pr_expected_answers.items():
model_output = curves[1].value[value][threshold][metric]
assert isinstance(model_output, dict)
assert model_output["total"] == expected_output["total"]
assert all(
[
model_output["observations"][key]["count"] # type: ignore - we know this element is a dict
== expected_output[key]
for key in [
key
for key in expected_output.keys()
if key not in ["total"]
]
]
)

assert (
len(
curves[1].value["bird"][0.05]["tp"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
== 0
)
assert (
len(
(
curves[1].value["bird"][0.05]["tn"]["observations"]["all"][ # type: ignore - we know this element is a dict
"examples"
]
)
)
== 0
)
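
The three `_compute_curves` calls above differ only in `pr_curve_max_examples`, and the assertions pin down the capping behavior: a cap of 1 keeps one example, a cap of 3 keeps at most three but only as many as actually exist (two for the bird true negatives), and a cap of 0 keeps none. A hypothetical helper that reproduces just that truncation rule, with placeholder datum tuples:

# Hypothetical sketch of the capping rule exercised above; the commit's
# internal implementation may differ.
def cap_examples(examples: list, pr_curve_max_examples: int) -> list:
    """Keep at most `pr_curve_max_examples` example datums; zero disables examples."""
    if pr_curve_max_examples <= 0:
        return []
    return examples[:pr_curve_max_examples]


bird_tn_examples = [("dataset1", "datum1"), ("dataset1", "datum4")]  # only two exist
assert len(cap_examples(bird_tn_examples, 1)) == 1
assert len(cap_examples(bird_tn_examples, 3)) == 2  # capped by availability, not the limit
assert len(cap_examples(bird_tn_examples, 0)) == 0
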
