From 6aa109e1b76195d2385386201f91800a4e885dd9 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 31 May 2022 13:31:00 +0200 Subject: [PATCH 1/3] fix: check agents instead labels --- src/rubrix/server/apis/v0/models/text_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rubrix/server/apis/v0/models/text_classification.py b/src/rubrix/server/apis/v0/models/text_classification.py index 1030e90d42..2d65bde003 100644 --- a/src/rubrix/server/apis/v0/models/text_classification.py +++ b/src/rubrix/server/apis/v0/models/text_classification.py @@ -297,7 +297,7 @@ def task(cls) -> TaskType: @property def predicted(self) -> Optional[PredictionStatus]: - if self.predicted_as and self.annotated_as: + if self.predicted_by and self.predicted_by: return ( PredictionStatus.OK if set(self.predicted_as) == set(self.annotated_as) From f15f9bb4a83be29e08dcb1cf0814924c1901c78f Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 31 May 2022 13:31:11 +0200 Subject: [PATCH 2/3] test: add missing tests --- .../server/text_classification/test_model.py | 31 +++++++++++++++++++ .../server/token_classification/test_model.py | 17 ++++++++++ 2 files changed, 48 insertions(+) diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index 9a74df959f..2d727dac98 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -337,3 +337,34 @@ def test_query_with_uncovered_by_rules(): }, } } + + +def test_empty_labels_for_no_multilabel(): + with pytest.raises( + ValidationError, + match="Single label record must include only one annotation label", + ): + TextClassificationRecord( + inputs={"text": "The input text"}, + annotation=TextClassificationAnnotation(agent="ann.", labels=[]), + ) + + record = TextClassificationRecord( + inputs={"text": "The input text"}, + prediction=TextClassificationAnnotation(agent="ann.", labels=[]), + annotation=TextClassificationAnnotation( + agent="ann.", labels=[ClassPrediction(class_label="B")] + ), + ) + assert record.predicted == PredictionStatus.KO + + +def test_annotated_without_labels_for_multilabel(): + record = TextClassificationRecord( + inputs={"text": "The input text"}, + multi_label=True, + prediction=TextClassificationAnnotation(agent="pred.", labels=[]), + annotation=TextClassificationAnnotation(agent="ann.", labels=[]), + ) + + assert record.predicted == PredictionStatus.OK diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index 9987f10b0d..8b02d40a15 100644 --- a/tests/server/token_classification/test_model.py +++ b/tests/server/token_classification/test_model.py @@ -17,6 +17,7 @@ from pydantic import ValidationError from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.server.apis.v0.models.commons.model import PredictionStatus from rubrix.server.apis.v0.models.token_classification import ( EntitySpan, TokenClassificationAnnotation, @@ -220,3 +221,19 @@ def test_record_scores(): ), ) assert record.scores == [0.8, 0.1, 0.2] + + +def test_annotated_without_entities(): + text = "The text that i wrote" + record = TokenClassificationRecord( + text=text, + tokens=text.split(), + prediction=TokenClassificationAnnotation( + agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")] + ), + annotation=TokenClassificationAnnotation(agent="test", entities=[]), + ) + + assert record.annotated_by == [record.annotation.agent] + assert record.predicted_by == [record.prediction.agent] + assert record.predicted == PredictionStatus.KO From ff387a1f3badb989e1398a0203a912d9c3ae7017 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 31 May 2022 14:27:37 +0200 Subject: [PATCH 3/3] fix: wrong predicted condition --- src/rubrix/server/apis/v0/models/text_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rubrix/server/apis/v0/models/text_classification.py b/src/rubrix/server/apis/v0/models/text_classification.py index 2d65bde003..006d4efde5 100644 --- a/src/rubrix/server/apis/v0/models/text_classification.py +++ b/src/rubrix/server/apis/v0/models/text_classification.py @@ -297,7 +297,7 @@ def task(cls) -> TaskType: @property def predicted(self) -> Optional[PredictionStatus]: - if self.predicted_by and self.predicted_by: + if self.predicted_by and self.annotated_by: return ( PredictionStatus.OK if set(self.predicted_as) == set(self.annotated_as)