diff --git a/src/rubrix/server/tasks/text_classification/service/service.py b/src/rubrix/server/tasks/text_classification/service/service.py index 319ec7a78c..7d16e67779 100644 --- a/src/rubrix/server/tasks/text_classification/service/service.py +++ b/src/rubrix/server/tasks/text_classification/service/service.py @@ -250,6 +250,7 @@ def add_labeling_rule( rule: The rule """ + self.__normalized_rule__(rule) self.__labeling__.add_rule(dataset, rule) def update_labeling_rule( @@ -278,6 +279,8 @@ def update_labeling_rule( found_rule.label = labels[0] if len(labels) == 1 else None if description is not None: found_rule.description = description + + self.__normalized_rule__(found_rule) self.__labeling__.replace_rule(dataset, found_rule) return found_rule @@ -387,3 +390,12 @@ def compute_overall_rules_metrics(self, dataset: TextClassificationDatasetDB): total_records=total, annotated_records=annotated, ) + + @staticmethod + def __normalized_rule__(rule: LabelingRule) -> LabelingRule: + if rule.labels and len(rule.labels) == 1: + rule.label = rule.labels[0] + elif rule.label and not rule.labels: + rule.labels = [rule.label] + + return rule