Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(#951): new uncovered_by_rules records filter #991

Merged
merged 9 commits on Jan 19, 2022
16 changes: 15 additions & 1 deletion src/rubrix/server/tasks/text_classification/api/model.py
Expand Up @@ -489,7 +489,12 @@ class TextClassificationQuery(BaseModel):
status: List[TaskStatus] = Field(default_factory=list)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

def as_elasticsearch(self) -> Dict[str, Any]:
only_uncovered: bool = Field(
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
default=False,
description="If enabled, filter records that are not affected by defined rules",
)

def as_elasticsearch(self, rules: List[LabelingRule]) -> Dict[str, Any]:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
"""Build an elasticsearch query part from search query"""

if self.ids:
Expand All @@ -506,6 +511,15 @@ def as_elasticsearch(self) -> Dict[str, Any]:
filters.status(self.status),
filters.predicted(self.predicted),
filters.score(self.score),
filters.boolean_filter(
must_not_query=filters.boolean_filter(
should_filters=[
filters.text_query(rule.query) for rule in rules
]
)
)
if self.only_uncovered and rules
else None,
]
if query_filter
]
Expand Down
Expand Up @@ -135,10 +135,11 @@ def search(
The matched records with aggregation info for specified task_meta.py

"""
rules = self.__labeling__.list_rules(dataset)
results = self.__dao__.search_records(
dataset,
search=RecordSearch(
query=query.as_elasticsearch(),
query=query.as_elasticsearch(rules),
sort=sort_by2elasticsearch(
sort_by,
valid_fields=[
Expand Down Expand Up @@ -187,8 +188,9 @@ def read_dataset(
the provided query filters. Optional

"""
rules = self.__labeling__.list_rules(dataset)
for db_record in self.__dao__.scan_dataset(
dataset, search=RecordSearch(query=query.as_elasticsearch())
dataset, search=RecordSearch(query=query.as_elasticsearch(rules))
):
yield TextClassificationRecord.parse_obj(db_record)

Expand Down