From 0f6942612ae77f4b55237f1ae54148154e26de77 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Thu, 14 Mar 2024 00:09:15 +0200 Subject: [PATCH] fixed returned metric name Signed-off-by: dafnapension --- src/unitxt/standard_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/unitxt/standard_metrics.py b/src/unitxt/standard_metrics.py index ea678e038..6dc0fcfca 100644 --- a/src/unitxt/standard_metrics.py +++ b/src/unitxt/standard_metrics.py @@ -270,7 +270,7 @@ def compute_f1_macro_multi_label_from_confusion_matrix(self) -> Any: # e.g. from here: https://medium.com/synthesio-engineering/precision-accuracy-and-f1-score-for-multi-label-classification-34ac6bdfb404 # report only for the classes seen as references if len(self.references_seen_thus_far) == 0: - return {"f1_macro_multi_label": np.nan} + return {self.metric_name[8:]: np.nan} to_ret = {} for c in self.references_seen_thus_far: # report only on them num_as_pred = self.tp[c] + self.fp[c] @@ -289,7 +289,7 @@ def compute_f1_macro_multi_label_from_confusion_matrix(self) -> Any: avg_across_classes = ( sum(val for val in to_ret.values() if not np.isnan(val)) ) / len(self.references_seen_thus_far) - to_ret["f1_macro"] = round(avg_across_classes, 2) + to_ret[self.metric_name[8:]] = round(avg_across_classes, 2) return to_ret