Allow disabling per-group F1 scores in CustomF1 (#719)
Signed-off-by: Yoav Katz <katz@il.ibm.com>
Co-authored-by: Elron Bandel <elronbandel@gmail.com>
yoavkatz and elronbandel committed Mar 28, 2024
1 parent e352dd4 commit 74899e4
Showing 3 changed files with 134 additions and 15 deletions.
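In short: CustomF1 (and subclasses such as NER) gains a report_per_group_scores flag, defaulting to True. When set to False, the per-group keys (f1_Person, f1_Location, ...) are removed from the computed scores, while the aggregate scores (f1_micro, f1_macro, the precision/recall variants, and in_classes_support) are kept. A minimal usage sketch based on the classes added in this diff; the unitxt.metrics import path is an assumption and may differ from the repository-internal src.unitxt.metrics path used in the tests:

from unitxt.metrics import NER

# Subclass pattern added in this commit: same NER metric, but the per-entity-type
# f1_<group> keys (e.g. f1_Person, f1_Location) are omitted from the result.
class NERWithoutClassReporting(NER):
    report_per_group_scores = False

metric = NERWithoutClassReporting()

# The flag can also be toggled on an existing instance, as the new test below does:
# metric = NER()
# metric.report_per_group_scores = False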
67 changes: 67 additions & 0 deletions prepare/metrics/custom_f1.py
@@ -358,4 +358,71 @@
    global_target=global_target,
)


class NERWithoutClassReporting(NER):
    report_per_group_scores = False


metric_without_class_reporting = NERWithoutClassReporting()
# 1.4 multiple classes, multiple examples
predictions = [
    [
        ("Dalia", "Person"),
        ("Amir", "Person"),
        ("Yaron", "Person"),
        ("Ramat-Gan", "Location"),
        ("Ramat-Gan", "Location"),
        ("IBM", "Org"),
        ("CIA", "Org"),
        ("FBI", "Org"),
    ]
]
references = [
    [
        [
            ("Amir", "Person"),
            ("Yaron", "Person"),
            ("Dalia", "Person"),
            ("Naftali", "Person"),
            ("Ramat-Gan", "Location"),
            ("Givataaim", "Location"),
        ]
    ]
]
# Person: Precision = 3/3, Recall = 3/4, F1 = 2 * 1 * 0.75 / (1 + 0.75) = 0.8571
# Location: Precision = 1/2, Recall = 1/2, F1 = 0.5
# Org (OOD): Precision = 0/3, Recall = 0/0 = 1(!), F1 = 0
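# Pooled over all classes: 4 of the 8 predictions are correct (precision_micro = 0.5) and
# 4 of the 6 references are found (recall_micro ≈ 0.67), so f1_micro = 2 * 0.5 * 0.67 / (0.5 + 0.67) ≈ 0.57.
# Macro scores average the in-domain classes (Person, Location) only:
# precision_macro = (1 + 0.5) / 2 = 0.75, recall_macro = (0.75 + 0.5) / 2 ≈ 0.62, f1_macro = (0.8571 + 0.5) / 2 ≈ 0.68.
# in_classes_support = 1 - 3/8 ≈ 0.62, since 3 of the 8 predictions belong to the out-of-domain Org class.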
instance_targets = [
    {
        "recall_micro": 0.67,
        "recall_macro": 0.62,
        "precision_micro": 0.5,
        "precision_macro": 0.75,  # computed only over in-domain classes
        "f1_macro": 0.68,
        "in_classes_support": 0.62,
        "f1_micro": 0.57,
        "score": 0.57,
        "score_name": "f1_micro",
    },
]
global_target = {
    "recall_micro": 0.67,
    "recall_macro": 0.62,
    "precision_micro": 0.5,
    "precision_macro": 0.75,
    "f1_macro": 0.68,
    "in_classes_support": 0.62,
    "f1_micro": 0.57,
    "score": 0.57,
    "score_name": "f1_micro",
}
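# With a single example in the stream, the global scores coincide with the instance scores above.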

outputs = test_metric(
    metric=metric_without_class_reporting,
    predictions=predictions,
    references=references,
    instance_targets=instance_targets,
    global_target=global_target,
)

add_to_catalog(metric, "metrics.ner", overwrite=True)
46 changes: 31 additions & 15 deletions src/unitxt/metrics.py
@@ -1606,7 +1606,8 @@ class CustomF1(GlobalMetric):
    prediction_type = "Any"
    single_reference_per_prediction = True
    groups = None
    zero_division = 0.0
    zero_division: float = 0.0
    report_per_group_scores: bool = True  # when False, per-group f1_<group> keys are omitted from the result

    @abstractmethod
    def get_element_group(self, element, additional_input):
@@ -1737,6 +1738,35 @@ def compute(
num_of_unknown_class_predictions += pd

        result = f1_result
        self.add_macro_scores(f1_result, recall_result, precision_result, result)
        self.add_in_class_support_scores(
            num_of_unknown_class_predictions, pd_total, result
        )
        self.add_micro_scores(rd_total, rn_total, pd_total, pn_total, result)
        if not self.report_per_group_scores:
            for group in groups:
                del result[f"f1_{group}"]
        return result

    def add_micro_scores(self, rd_total, rn_total, pd_total, pn_total, result):
        result["f1_micro"] = self.f1(pn_total, pd_total, rn_total, rd_total)
        result["recall_micro"] = self.recall(pn_total, pd_total, rn_total, rd_total)
        result["precision_micro"] = self.precision(
            pn_total, pd_total, rn_total, rd_total
        )

    def add_in_class_support_scores(
        self, num_of_unknown_class_predictions, pd_total, result
    ):
        amount_of_predictions = pd_total
        if amount_of_predictions == 0:
            result["in_classes_support"] = 1.0
        else:
            result["in_classes_support"] = (
                1.0 - num_of_unknown_class_predictions / amount_of_predictions
            )

    def add_macro_scores(self, f1_result, recall_result, precision_result, result):
        try:
            result["f1_macro"] = sum(f1_result.values()) / len(result.keys())
            result["recall_macro"] = sum(recall_result.values()) / len(
@@ -1750,20 +1780,6 @@ def compute(
result["recall_macro"] = self.zero_division
result["precision_macro"] = self.zero_division

        amount_of_predictions = pd_total
        if amount_of_predictions == 0:
            result["in_classes_support"] = 1.0
        else:
            result["in_classes_support"] = (
                1.0 - num_of_unknown_class_predictions / amount_of_predictions
            )
        result["f1_micro"] = self.f1(pn_total, pd_total, rn_total, rd_total)
        result["recall_micro"] = self.recall(pn_total, pd_total, rn_total, rd_total)
        result["precision_micro"] = self.precision(
            pn_total, pd_total, rn_total, rd_total
        )
        return result


class NER(CustomF1):
    prediction_type = "List[Tuple[str,str]]"
36 changes: 36 additions & 0 deletions tests/library/test_metrics.py
@@ -2,6 +2,7 @@

from src.unitxt.logging_utils import get_logger
from src.unitxt.metrics import (
    NER,
    Accuracy,
    BinaryAccuracy,
    BinaryMaxAccuracy,
@@ -718,6 +719,41 @@ def test_normalized_sacrebleu(self):
        global_target = 1.0
        self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])

    def test_ner(self):
        metric = NER()
        predictions = [
            [
                ("Dalia", "Person"),
                ("Ramat-Gan", "Location"),
                ("IBM", "Org"),
            ]
        ]
        references = [
            [
                [
                    ("Dalia", "Person"),
                    ("Givataaim", "Location"),
                ]
            ]
        ]
        outputs = apply_metric(
            metric=metric, predictions=predictions, references=references
        )
        global_target = 1.0
        self.assertAlmostEqual(
            global_target, outputs[0]["score"]["global"]["f1_Person"]
        )
        global_target = 0.0
        self.assertAlmostEqual(
            global_target, outputs[0]["score"]["global"]["f1_Location"]
        )
        metric.report_per_group_scores = False
        outputs = apply_metric(
            metric=metric, predictions=predictions, references=references
        )
        self.assertTrue("f1_Person" not in outputs[0]["score"]["global"])
        self.assertTrue("f1_Location" not in outputs[0]["score"]["global"])

    def test_llama_index_correctness(self):
        metric = LlamaIndexCorrectness(model_name="mock")
        predictions = ["1976"]
