From 2f2ee2edfca6988c7ec9c48aa1a1a15b0793c39d Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Fri, 3 Jun 2022 17:39:41 +0200
Subject: [PATCH] fix(#1527): check agents instead labels for `predicted` computation (#1528)

(cherry picked from commit 2d0612d4aa602c2962703349edc25386e53b5606)
---
 .../tasks/text_classification/api/model.py    |  2 +-
 .../server/text_classification/test_model.py  | 31 +++++++++++++++++++
 .../server/token_classification/test_model.py | 20 ++++++++++--
 3 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/src/rubrix/server/tasks/text_classification/api/model.py b/src/rubrix/server/tasks/text_classification/api/model.py
index 739b2cf244..e5a94930ab 100644
--- a/src/rubrix/server/tasks/text_classification/api/model.py
+++ b/src/rubrix/server/tasks/text_classification/api/model.py
@@ -299,7 +299,7 @@ def task(cls) -> TaskType:
 
     @property
     def predicted(self) -> Optional[PredictionStatus]:
-        if self.predicted_as and self.annotated_as:
+        if self.predicted_by and self.annotated_by:
             return (
                 PredictionStatus.OK
                 if set(self.predicted_as) == set(self.annotated_as)
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index c4a724879f..0fdb3085c1 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -352,3 +352,34 @@ def test_query_with_uncovered_by_rules():
             },
         }
     }
+
+
+def test_empty_labels_for_no_multilabel():
+    with pytest.raises(
+        ValidationError,
+        match="Single label record must include only one annotation label",
+    ):
+        TextClassificationRecord(
+            inputs={"text": "The input text"},
+            annotation=TextClassificationAnnotation(agent="ann.", labels=[]),
+        )
+
+    record = TextClassificationRecord(
+        inputs={"text": "The input text"},
+        prediction=TextClassificationAnnotation(agent="ann.", labels=[]),
+        annotation=TextClassificationAnnotation(
+            agent="ann.", labels=[ClassPrediction(class_label="B")]
+        ),
+    )
+    assert record.predicted == PredictionStatus.KO
+
+
+def test_annotated_without_labels_for_multilabel():
+    record = TextClassificationRecord(
+        inputs={"text": "The input text"},
+        multi_label=True,
+        prediction=TextClassificationAnnotation(agent="pred.", labels=[]),
+        annotation=TextClassificationAnnotation(agent="ann.", labels=[]),
+    )
+
+    assert record.predicted == PredictionStatus.OK
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index 19202cf865..a8db30bc16 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -17,8 +17,8 @@
 from pydantic import ValidationError
 
 from rubrix._constants import MAX_KEYWORD_LENGTH
-from rubrix.server.tasks.search.query_builder import EsQueryBuilder
-from rubrix.server.tasks.token_classification.api.model import (
+from rubrix.server.apis.v0.models.commons.model import PredictionStatus
+from rubrix.server.apis.v0.models.token_classification import (
     EntitySpan,
     TokenClassificationAnnotation,
     TokenClassificationQuery,
@@ -220,3 +220,19 @@ def test_record_scores():
         ),
     )
     assert record.scores == [0.8, 0.1, 0.2]
+
+
+def test_annotated_without_entities():
+    text = "The text that i wrote"
+    record = TokenClassificationRecord(
+        text=text,
+        tokens=text.split(),
+        prediction=TokenClassificationAnnotation(
+            agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")]
+        ),
+        annotation=TokenClassificationAnnotation(agent="test", entities=[]),
+    )
+
+    assert record.annotated_by == [record.annotation.agent]
+    assert record.predicted_by == [record.prediction.agent]
+    assert record.predicted == PredictionStatus.KO