Skip to content

Commit

Permalink
fixed returned metric name
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed Mar 13, 2024
1 parent d90273e commit 0f69426
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/unitxt/standard_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def compute_f1_macro_multi_label_from_confusion_matrix(self) -> Any:
# e.g. from here: https://medium.com/synthesio-engineering/precision-accuracy-and-f1-score-for-multi-label-classification-34ac6bdfb404
# report only for the classes seen as references
if len(self.references_seen_thus_far) == 0:
return {"f1_macro_multi_label": np.nan}
return {self.metric_name[8:]: np.nan}
to_ret = {}
for c in self.references_seen_thus_far: # report only on them
num_as_pred = self.tp[c] + self.fp[c]
Expand All @@ -289,7 +289,7 @@ def compute_f1_macro_multi_label_from_confusion_matrix(self) -> Any:
avg_across_classes = (

Check warning on line 289 in src/unitxt/standard_metrics.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/standard_metrics.py#L288-L289

Added lines #L288 - L289 were not covered by tests
sum(val for val in to_ret.values() if not np.isnan(val))
) / len(self.references_seen_thus_far)
to_ret["f1_macro"] = round(avg_across_classes, 2)
to_ret[self.metric_name[8:]] = round(avg_across_classes, 2)
return to_ret

Check warning on line 293 in src/unitxt/standard_metrics.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/standard_metrics.py#L292-L293

Added lines #L292 - L293 were not covered by tests


Expand Down

0 comments on commit 0f69426

Please sign in to comment.