Skip to content

Commit

Permalink
feat(#951): new uncovered_by_rules records filter (#991)
Browse files Browse the repository at this point in the history
* feat(api): new only_uncovered records filter

* feat: add uncovered_by_rules filter param

* feat(text-class): update query request model

* refactor: configure must_not filters

* test: add missing tests

* revert: list rules prior to generate query

* Apply suggestions from code review

Co-authored-by: David Fidalgo <david@recogn.ai>

* test: add functional test

Co-authored-by: David Fidalgo <david@recogn.ai>
(cherry picked from commit 164440b)
  • Loading branch information
frascuchon committed Feb 7, 2022
1 parent a1d6444 commit 162afa0
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 13 deletions.
9 changes: 7 additions & 2 deletions src/rubrix/client/sdk/text_classification/models.py
Expand Up @@ -19,12 +19,12 @@

from rubrix.client.models import (
TextClassificationRecord as ClientTextClassificationRecord,
TokenAttributions as ClientTokenAttributions,
)
from rubrix.client.models import TokenAttributions as ClientTokenAttributions
from rubrix.client.sdk.commons.models import (
MACHINE_NAME,
BaseAnnotation,
BaseRecord,
MACHINE_NAME,
PredictionStatus,
ScoreRange,
TaskStatus,
Expand Down Expand Up @@ -150,6 +150,11 @@ class TextClassificationQuery(BaseModel):
status: List[TaskStatus] = Field(default_factory=list)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

uncovered_by_rules: List[str] = Field(
default_factory=list,
description="List of rule queries that WILL NOT cover the resulting records",
)


class LabelingRule(BaseModel):
"""
Expand Down
33 changes: 22 additions & 11 deletions src/rubrix/server/tasks/text_classification/api/model.py
Expand Up @@ -473,6 +473,8 @@ class TextClassificationQuery(BaseModel):
List of task status
predicted: Optional[PredictionStatus]
The task prediction status
uncovered_by_rules: List[str]
Only return records that are NOT covered by any of these rule queries.
"""

Expand All @@ -489,6 +491,11 @@ class TextClassificationQuery(BaseModel):
status: List[TaskStatus] = Field(default_factory=list)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

uncovered_by_rules: List[str] = Field(
default_factory=list,
description="List of rule queries that WILL NOT cover the resulting records",
)

def as_elasticsearch(self) -> Dict[str, Any]:
"""Build an elasticsearch query part from search query"""

Expand All @@ -512,17 +519,21 @@ def as_elasticsearch(self) -> Dict[str, Any]:
query_text = filters.text_query(self.query_text)
all_filters.extend(query_filters)

return {
"bool": {
"must": query_text or {"match_all": {}},
"filter": {
"bool": {
"should": all_filters,
"minimum_should_match": len(all_filters),
}
},
}
}
return filters.boolean_filter(
must_query=query_text or {"match_all": {}},
must_not_query=filters.boolean_filter(
should_filters=[
filters.text_query(query) for query in self.uncovered_by_rules
]
)
if self.uncovered_by_rules
else None,
filter_query=filters.boolean_filter(
should_filters=all_filters, minimum_should_match=len(all_filters)
)
if all_filters
else None,
)


class TextClassificationSearchRequest(BaseModel):
Expand Down
16 changes: 16 additions & 0 deletions tests/server/text_classification/test_api_rules.py
Expand Up @@ -352,3 +352,19 @@ def test_rule_metric():
assert metrics.correct == 0
assert metrics.incorrect == 0
assert metrics.precision is None


def test_search_records_with_uncovered_rules():
    """Check the ``uncovered_by_rules`` filter on the search endpoint.

    Records are logged into a dedicated dataset, then the
    ``TextClassification:search`` endpoint is called twice: once with no
    request body (one record is expected back) and once with the rule
    query ``"texto"`` listed in ``uncovered_by_rules`` (no records are
    expected back).
    """
    dataset = "test_search_records_with_uncovered_rules"
    search_url = f"/api/datasets/{dataset}/TextClassification:search"
    log_some_records(dataset, annotation="OK")

    # Baseline: a search without any query returns the logged record.
    unfiltered = client.post(search_url)
    assert len(unfiltered.json()["records"]) == 1

    # Presumably the logged record matches "texto", so excluding records
    # covered by that rule leaves nothing -- verify against log_some_records.
    filtered = client.post(
        search_url,
        json={"query": {"uncovered_by_rules": ["texto"]}},
    )
    assert len(filtered.json()["records"]) == 0
64 changes: 64 additions & 0 deletions tests/server/text_classification/test_model.py
Expand Up @@ -16,8 +16,11 @@
from pydantic import ValidationError

from rubrix._constants import MAX_KEYWORD_LENGTH

from rubrix.server.commons.settings import settings
from rubrix.server.tasks.commons import TaskStatus
from rubrix.server.tasks.text_classification import TextClassificationQuery

from rubrix.server.tasks.text_classification.api import (
ClassPrediction,
PredictionStatus,
Expand Down Expand Up @@ -245,3 +248,64 @@ def test_validate_without_labels_for_single_label(annotation):
),
annotation=annotation,
)


def test_query_with_uncovered_by_rules():
    """Check the elasticsearch query built from ``uncovered_by_rules``.

    A query holding only ``uncovered_by_rules`` is expected to render as a
    ``match_all`` ``must`` clause plus a ``must_not`` boolean clause whose
    ``should`` list carries one text-query clause per rule, with
    ``minimum_should_match`` set to 1.
    """

    def expected_rule_clause(rule_query):
        # Expected clause for a single rule query: the same query string
        # against "words" (boosted) and "words.extended".
        return {
            "bool": {
                "should": [
                    {
                        "query_string": {
                            "default_field": "words",
                            "default_operator": "AND",
                            "query": rule_query,
                            "boost": "2.0",
                        }
                    },
                    {
                        "query_string": {
                            "default_field": "words.extended",
                            "default_operator": "AND",
                            "query": rule_query,
                        }
                    },
                ],
                "minimum_should_match": "50%",
            }
        }

    rules = ["query", "other*"]
    expected = {
        "bool": {
            "must": {"match_all": {}},
            "must_not": {
                "bool": {
                    "should": [expected_rule_clause(rule) for rule in rules],
                    "minimum_should_match": 1,
                }
            },
        }
    }

    query = TextClassificationQuery(uncovered_by_rules=rules)
    assert query.as_elasticsearch() == expected

0 comments on commit 162afa0

Please sign in to comment.