Skip to content

Commit

Permalink
chore: reuse label2index map
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Jun 2, 2022
1 parent fb592eb commit 7f78c28
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/rubrix/server/apis/v0/models/metrics/text_classification.py
Expand Up @@ -37,21 +37,22 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
if not len(ds_labels):
return {}

labels_mapping = {label: i for i, label in enumerate(ds_labels)}
y_true, y_pred = ([], [])
for record in filtered_records:
annotations = record.annotated_as
predictions = record.predicted_as

if not self.multi_label:
y_true.append(annotations[0])
y_pred.append(predictions[0])
y_true.append(labels_mapping[annotations[0]])
y_pred.append(labels_mapping[predictions[0]])

else:
y_true.append(annotations)
y_pred.append(predictions)
y_true.append([labels_mapping[label] for label in annotations])
y_pred.append([labels_mapping[label] for label in predictions])

if self.multi_label:
mlb = MultiLabelBinarizer(classes=[label for label in ds_labels])
mlb = MultiLabelBinarizer(classes=list(labels_mapping.values()))
y_true = mlb.fit_transform(y_true)
y_pred = mlb.fit_transform(y_pred)

Expand All @@ -64,7 +65,7 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:

per_label = {}
for label, p, r, f, s in zip(
ds_labels,
labels_mapping.keys(),
*precision_recall_fscore_support(
y_true=y_true,
y_pred=y_pred,
Expand Down

0 comments on commit 7f78c28

Please sign in to comment.