diff --git a/src/rubrix/client/sdk/text_classification/models.py b/src/rubrix/client/sdk/text_classification/models.py index 634c992d41..019178431b 100644 --- a/src/rubrix/client/sdk/text_classification/models.py +++ b/src/rubrix/client/sdk/text_classification/models.py @@ -19,12 +19,12 @@ from rubrix.client.models import ( TextClassificationRecord as ClientTextClassificationRecord, - TokenAttributions as ClientTokenAttributions, ) +from rubrix.client.models import TokenAttributions as ClientTokenAttributions from rubrix.client.sdk.commons.models import ( + MACHINE_NAME, BaseAnnotation, BaseRecord, - MACHINE_NAME, PredictionStatus, ScoreRange, TaskStatus, @@ -150,6 +150,11 @@ class TextClassificationQuery(BaseModel): status: List[TaskStatus] = Field(default_factory=list) predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) + uncovered_by_rules: List[str] = Field( + default_factory=list, + description="List of rule queries that WILL NOT cover the resulting records", + ) + class LabelingRule(BaseModel): """ diff --git a/src/rubrix/server/tasks/text_classification/api/model.py b/src/rubrix/server/tasks/text_classification/api/model.py index d45b8f3cc0..057a52b327 100644 --- a/src/rubrix/server/tasks/text_classification/api/model.py +++ b/src/rubrix/server/tasks/text_classification/api/model.py @@ -473,6 +473,8 @@ class TextClassificationQuery(BaseModel): List of task status predicted: Optional[PredictionStatus] The task prediction status + uncovered_by_rules: + Only return records that are NOT covered by these rules. """ @@ -489,6 +491,11 @@ class TextClassificationQuery(BaseModel): status: List[TaskStatus] = Field(default_factory=list) predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) + uncovered_by_rules: List[str] = Field( + default_factory=list, + description="List of rule queries that WILL NOT cover the resulting records", + ) + def as_elasticsearch(self) -> Dict[str, Any]: """Build an elasticsearch query part from search query""" @@ -512,17 +519,21 @@ def as_elasticsearch(self) -> Dict[str, Any]: query_text = filters.text_query(self.query_text) all_filters.extend(query_filters) - return { - "bool": { - "must": query_text or {"match_all": {}}, - "filter": { - "bool": { - "should": all_filters, - "minimum_should_match": len(all_filters), - } - }, - } - } + return filters.boolean_filter( + must_query=query_text or {"match_all": {}}, + must_not_query=filters.boolean_filter( + should_filters=[ + filters.text_query(query) for query in self.uncovered_by_rules + ] + ) + if self.uncovered_by_rules + else None, + filter_query=filters.boolean_filter( + should_filters=all_filters, minimum_should_match=len(all_filters) + ) + if all_filters + else None, + ) class TextClassificationSearchRequest(BaseModel): diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py index a9f77e59db..ae5a6a2e16 100644 --- a/tests/server/text_classification/test_api_rules.py +++ b/tests/server/text_classification/test_api_rules.py @@ -352,3 +352,19 @@ def test_rule_metric(): assert metrics.correct == 0 assert metrics.incorrect == 0 assert metrics.precision is None + + +def test_search_records_with_uncovered_rules(): + dataset = "test_search_records_with_uncovered_rules" + log_some_records(dataset, annotation="OK") + + response = client.post( + f"/api/datasets/{dataset}/TextClassification:search", + ) + assert len(response.json()["records"]) == 1 + + response = client.post( + f"/api/datasets/{dataset}/TextClassification:search", + json={"query": {"uncovered_by_rules": ["texto"]}}, + ) + assert len(response.json()["records"]) == 0 diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index 3d5e755b4c..658e77fbc2 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -16,8 +16,11 @@ from pydantic import ValidationError from rubrix._constants import MAX_KEYWORD_LENGTH + from rubrix.server.commons.settings import settings from rubrix.server.tasks.commons import TaskStatus +from rubrix.server.tasks.text_classification import TextClassificationQuery + from rubrix.server.tasks.text_classification.api import ( ClassPrediction, PredictionStatus, @@ -245,3 +248,64 @@ def test_validate_without_labels_for_single_label(annotation): ), annotation=annotation, ) + + +def test_query_with_uncovered_by_rules(): + + query = TextClassificationQuery(uncovered_by_rules=["query", "other*"]) + assert query.as_elasticsearch() == { + "bool": { + "must": {"match_all": {}}, + "must_not": { + "bool": { + "should": [ + { + "bool": { + "should": [ + { + "query_string": { + "default_field": "words", + "default_operator": "AND", + "query": "query", + "boost": "2.0", + } + }, + { + "query_string": { + "default_field": "words.extended", + "default_operator": "AND", + "query": "query", + } + }, + ], + "minimum_should_match": "50%", + } + }, + { + "bool": { + "should": [ + { + "query_string": { + "default_field": "words", + "default_operator": "AND", + "query": "other*", + "boost": "2.0", + } + }, + { + "query_string": { + "default_field": "words.extended", + "default_operator": "AND", + "query": "other*", + } + }, + ], + "minimum_should_match": "50%", + } + }, + ], + "minimum_should_match": 1, + } + }, + } + }