Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standard metrics #658

Closed
wants to merge 11 commits into from
55 changes: 55 additions & 0 deletions prepare/metrics/standard_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Register the "standard" metric implementations in the unitxt catalog.

Each artifact is stored under ``standard_metrics.<name>`` and is constructed
with a ``metric_name`` identical to that catalog path, so the stored object
and its lookup key always agree.
"""

from src.unitxt.catalog import add_to_catalog
from src.unitxt.standard_metrics import (
    StandardAccuracy,
    StandardAccuracyMultiLabel,
    StandardF1Macro,
    StandardF1MacroMultiLabel,
    StandardF1Micro,
    StandardF1MicroMultiLabel,
    StandardMatthewsCorrelation,
)

# Single source of truth: catalog path -> metric class.  Building every
# artifact from its own entry fixes the original copy-paste bug where
# standard_f1_micro_multi_label was registered under
# "standard_metrics.matthews_correlation".
_STANDARD_METRICS = {
    "standard_metrics.accuracy": StandardAccuracy,
    "standard_metrics.accuracy_multi_label": StandardAccuracyMultiLabel,
    "standard_metrics.f1_macro": StandardF1Macro,
    "standard_metrics.f1_micro": StandardF1Micro,
    "standard_metrics.f1_macro_multi_label": StandardF1MacroMultiLabel,
    "standard_metrics.f1_micro_multi_label": StandardF1MicroMultiLabel,
    "standard_metrics.matthews_correlation": StandardMatthewsCorrelation,
}

for catalog_name, metric_cls in _STANDARD_METRICS.items():
    # overwrite=True keeps re-runs of this prepare script idempotent.
    add_to_catalog(
        metric_cls(metric_name=catalog_name),
        catalog_name,
        overwrite=True,
    )
1 change: 1 addition & 0 deletions requirements/base.rqr
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ mecab-python3
absl-py
dpath
ipadic
llama-index-core
scipy
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/accuracy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_accuracy",
"metric_name": "standard_metrics.accuracy"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/accuracy_multi_label.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_accuracy_multi_label",
"metric_name": "standard_metrics.accuracy_multi_label"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/f1_macro.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_f1_macro",
"metric_name": "standard_metrics.f1_macro"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/f1_macro_multi_label.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_f1_macro_multi_label",
"metric_name": "standard_metrics.f1_macro_multi_label"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/f1_micro.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_f1_micro",
"metric_name": "standard_metrics.f1_micro"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/f1_micro_multi_label.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "standard_f1_micro_multi_label",
"metric_name": "standard_metrics.f1_micro_multi_label"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/standard_metrics/matthews_correlation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
    "type": "standard_matthews_correlation",
    "metric_name": "standard_metrics.matthews_correlation"
}
1 change: 1 addition & 0 deletions src/unitxt/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .split_utils import __file__ as _
from .splitters import __file__ as _
from .standard import __file__ as _
from .standard_metrics import __file__ as _
from .stream import __file__ as _
from .struct_data_operators import __file__ as _
from .system_prompts import __file__ as _
Expand Down
1 change: 1 addition & 0 deletions src/unitxt/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .split_utils import __file__ as _
from .splitters import __file__ as _
from .standard import __file__ as _
from .standard_metrics import __file__ as _
from .stream import __file__ as _
from .struct_data_operators import __file__ as _
from .system_prompts import __file__ as _
Expand Down
35 changes: 27 additions & 8 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@
operator_names = [operator_names]
# otherwise , operator_names is already a list

# we now have a list of nanes of operators, each is equipped with process_instance method.
# we now have a list of names of operators, each is equipped with process_instance method.
operator = SequentialOperator(steps=operator_names)
return operator.process_instance(instance)

Expand Down Expand Up @@ -1596,6 +1596,17 @@
metric_field: str
calc_confidence_intervals: bool

def prepare(self):
    """Build the whitelist of catalog metrics eligible for the thin (standard) evaluation path."""
    super().prepare()
    # Base metric names that have a faster "standard_metrics.*" twin;
    # process() substitutes those twins when it encounters these names.
    thin_eligible = (
        "accuracy",
        "f1_macro",
        "f1_micro",
        "f1_micro_multi_label",
        "f1_macro_multi_label",
        "matthews_correlation",
    )
    self.metrics_for_thin_evaluation = [
        f"metrics.{name}" for name in thin_eligible
    ]

def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
from .metrics import Metric

Expand All @@ -1621,28 +1632,36 @@
# Here we keep all the fields besides the score, and restore them after the metric finishes.
first_instance = stream.peek()
keys_to_restore = set(first_instance.keys()).difference({"score"})

multi_stream = MultiStream({"tmp": stream})
multi_stream = CopyFields(
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
)(multi_stream)

for metric_name in metric_names:
metric = self.get_artifact(metric_name)
assert isinstance(
metric, Metric
), f"Operator {metric_name} must be a Metric"
if metric_name in self.metrics_for_thin_evaluation:
if "accuracy" in metric_name and isoftype(
first_instance["prediction"], list
):
metric, _ = fetch_artifact("standard_metrics.accuracy_multi_label")

Check warning on line 1646 in src/unitxt/operators.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/operators.py#L1646

Added line #L1646 was not covered by tests
else:
metric, _ = fetch_artifact(f"standard_{metric_name}")

else:
metric = self.get_artifact(metric_name)
assert isinstance(
metric, Metric
), f"Operator {metric_name} must be a Metric"

if not self.calc_confidence_intervals:
metric.disable_confidence_interval_calculation()

multi_stream = metric(multi_stream)
multi_stream = CopyFields(
field_to_field={f"{k}_orig": k for k in keys_to_restore}
)(multi_stream)

multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
multi_stream
)

stream = multi_stream["tmp"]
yield from stream

Expand Down
Loading
Loading