Skip to content

Commit

Permalink
fix(#1527): check agents instead labels for predicted computation (#…
Browse files Browse the repository at this point in the history
…1528)

(cherry picked from commit 2d0612d)
  • Loading branch information
frascuchon committed Jun 7, 2022
1 parent 147d38a commit 2f2ee2e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/rubrix/server/tasks/text_classification/api/model.py
Expand Up @@ -299,7 +299,7 @@ def task(cls) -> TaskType:

@property
def predicted(self) -> Optional[PredictionStatus]:
if self.predicted_as and self.annotated_as:
if self.predicted_by and self.annotated_by:
return (
PredictionStatus.OK
if set(self.predicted_as) == set(self.annotated_as)
Expand Down
31 changes: 31 additions & 0 deletions tests/server/text_classification/test_model.py
Expand Up @@ -352,3 +352,34 @@ def test_query_with_uncovered_by_rules():
},
}
}


def test_empty_labels_for_no_multilabel():
with pytest.raises(
ValidationError,
match="Single label record must include only one annotation label",
):
TextClassificationRecord(
inputs={"text": "The input text"},
annotation=TextClassificationAnnotation(agent="ann.", labels=[]),
)

record = TextClassificationRecord(
inputs={"text": "The input text"},
prediction=TextClassificationAnnotation(agent="ann.", labels=[]),
annotation=TextClassificationAnnotation(
agent="ann.", labels=[ClassPrediction(class_label="B")]
),
)
assert record.predicted == PredictionStatus.KO


def test_annotated_without_labels_for_multilabel():
record = TextClassificationRecord(
inputs={"text": "The input text"},
multi_label=True,
prediction=TextClassificationAnnotation(agent="pred.", labels=[]),
annotation=TextClassificationAnnotation(agent="ann.", labels=[]),
)

assert record.predicted == PredictionStatus.OK
20 changes: 18 additions & 2 deletions tests/server/token_classification/test_model.py
Expand Up @@ -17,8 +17,8 @@
from pydantic import ValidationError

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.tasks.search.query_builder import EsQueryBuilder
from rubrix.server.tasks.token_classification.api.model import (
from rubrix.server.apis.v0.models.commons.model import PredictionStatus
from rubrix.server.apis.v0.models.token_classification import (
EntitySpan,
TokenClassificationAnnotation,
TokenClassificationQuery,
Expand Down Expand Up @@ -220,3 +220,19 @@ def test_record_scores():
),
)
assert record.scores == [0.8, 0.1, 0.2]


def test_annotated_without_entities():
text = "The text that i wrote"
record = TokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")]
),
annotation=TokenClassificationAnnotation(agent="test", entities=[]),
)

assert record.annotated_by == [record.annotation.agent]
assert record.predicted_by == [record.prediction.agent]
assert record.predicted == PredictionStatus.KO

0 comments on commit 2f2ee2e

Please sign in to comment.