-
Notifications
You must be signed in to change notification settings - Fork 316
/
metrics.py
133 lines (112 loc) · 4.42 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Any, ClassVar, Dict, Iterable, List
from pydantic import Field
from sklearn.metrics import precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer
from rubrix.server.tasks.commons.metrics import CommonTasksMetrics, GenericRecord
from rubrix.server.tasks.commons.metrics.model.base import (
BaseMetric,
PythonMetric,
TermsAggregation,
)
from rubrix.server.tasks.text_classification.api.model import TextClassificationRecord
class F1Metric(PythonMetric):
    """
    A basic f1 computation for text classification

    Computes micro/macro averaged precision, recall and f1, plus the same
    three scores per label, over records that have a prediction.

    Attributes:
    -----------
    multi_label:
        If True, F1 will be calculated assuming multi class task. Default False
    """

    multi_label: bool = False

    def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
        """Compute the metric over ``records``.

        Returns an empty dict when no labeled data is available; otherwise a
        flat dict with ``{precision,recall,f1}_{micro,macro}`` entries and one
        ``<label>_{precision,recall,f1}`` triple per dataset label.
        """
        # Only records with a prediction can contribute to the score.
        filtered_records = list(filter(lambda r: r.predicted is not None, records))
        # TODO: This must be precomputed with using a global dataset metric
        # Collect labels from BOTH annotations and predictions: a label seen
        # only in predictions would otherwise raise a KeyError when mapped below.
        ds_labels = {
            label
            for record in filtered_records
            for label in list(record.annotated_as) + list(record.predicted_as)
        }

        if not ds_labels:
            return {}

        labels_mapping = {label: i for i, label in enumerate(ds_labels)}
        y_true, y_pred = [], []
        for record in filtered_records:
            # BUG FIX: the original code assigned ``predicted_as`` to
            # ``annotations`` and ``annotated_as`` to ``predictions``, which
            # swapped y_true/y_pred and therefore swapped every reported
            # precision with its recall (f1 is symmetric, hiding the bug).
            annotations = record.annotated_as
            predictions = record.predicted_as
            if not self.multi_label:
                # Single-label: take the top annotation/prediction only.
                y_true.append(labels_mapping[annotations[0]])
                y_pred.append(labels_mapping[predictions[0]])
            else:
                y_true.append([labels_mapping[label] for label in annotations])
                y_pred.append([labels_mapping[label] for label in predictions])

        if self.multi_label:
            # Binarize label-id lists into indicator matrices for sklearn.
            mlb = MultiLabelBinarizer(classes=list(labels_mapping.values()))
            y_true = mlb.fit_transform(y_true)
            y_pred = mlb.fit_transform(y_pred)

        micro_p, micro_r, micro_f, _ = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="micro"
        )
        macro_p, macro_r, macro_f, _ = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="macro"
        )

        # Per-label scores: average=None yields one value per entry of `labels`.
        per_label = {}
        for label, p, r, f, _ in zip(
            labels_mapping.keys(),
            *precision_recall_fscore_support(
                y_true=y_true,
                y_pred=y_pred,
                labels=list(labels_mapping.values()),
                average=None,
            ),
        ):
            per_label.update(
                {f"{label}_precision": p, f"{label}_recall": r, f"{label}_f1": f}
            )

        return {
            "precision_macro": macro_p,
            "recall_macro": macro_r,
            "f1_macro": macro_f,
            "precision_micro": micro_p,
            "recall_micro": micro_r,
            "f1_micro": micro_f,
            **per_label,
        }
class DatasetLabels(PythonMetric):
    """Collects the distinct labels found in the dataset's annotations and predictions."""

    id: str = Field("dataset_labels", const=True)
    name: str = Field("The dataset labels", const=True)

    def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]:
        """Return ``{"labels": [...]}`` with every class label seen in ``records``."""
        ds_labels = set()
        for record in records:
            if record.annotation:
                ds_labels.update(
                    [label.class_label for label in record.annotation.labels]
                )
            if record.prediction:
                ds_labels.update(
                    [label.class_label for label in record.prediction.labels]
                )
        # BUG FIX: the original returned the raw set, which is not
        # JSON-serializable and iterates in nondeterministic order. Sort with
        # str() as key in case labels mix types (e.g. int and str class ids).
        return {"labels": sorted(ds_labels, key=str)}
class TextClassificationMetrics(CommonTasksMetrics[TextClassificationRecord]):
    """Configured metrics for text classification task"""

    # Registry of metrics for this task: the common cross-task metrics plus
    # the text-classification-specific ones declared below.
    metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [
        # Distribution of labels over the predicted_as field
        # (presumably a backend terms aggregation — see TermsAggregation).
        TermsAggregation(
            id="predicted_as",
            name="Predicted labels distribution",
            field="predicted_as",
        ),
        # Distribution of labels over the annotated_as field.
        TermsAggregation(
            id="annotated_as",
            name="Annotated labels distribution",
            field="annotated_as",
        ),
        # Single-label F1 (averaged and per label); multi_label defaults to False.
        F1Metric(
            id="F1",
            name="F1 Metrics for single-label",
            description="F1 Metrics for single-label (averaged and per label)",
        ),
        # Multi-label variant of the same F1 computation.
        F1Metric(
            id="MultiLabelF1",
            name="F1 Metrics for multi-label",
            description="F1 Metrics for multi-label (averaged and per label)",
            multi_label=True,
        ),
        # Distinct labels present in the dataset's annotations/predictions.
        DatasetLabels(),
    ]