refactor(#945): unify query to elasticsearch generation (#1230)
* refactor(#945): unify query to elasticsearch generation

* test: update tests

* chore: add TODOs
frascuchon committed Mar 8, 2022
1 parent 6899272 commit 630091f
Showing 12 changed files with 90 additions and 163 deletions.
69 changes: 23 additions & 46 deletions src/rubrix/server/commons/es_helpers.py
@@ -18,12 +18,7 @@

from pydantic import BaseModel

from rubrix.server.tasks.commons import (
PredictionStatus,
ScoreRange,
SortableField,
TaskStatus,
)
from rubrix.server.tasks.commons import SortableField, TaskStatus
from rubrix.server.tasks.commons.api import EsRecordDataFieldNames
from rubrix.server.tasks.commons.dao.es_config import mappings

@@ -216,42 +211,35 @@ def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]
]

@staticmethod
def predicted_as(predicted_as: List[str] = None) -> Optional[Dict[str, Any]]:
"""Filter records with given predicted as terms"""
if not predicted_as:
def terms_filter(field: str, values: List[Any]) -> Optional[Dict[str, Any]]:
if not values:
return None
return {
"terms": {
decode_field_name(EsRecordDataFieldNames.predicted_as): predicted_as
}
}
return {"terms": {field: values}}

@staticmethod
def annotated_as(annotated_as: List[str] = None) -> Optional[Dict[str, Any]]:
"""Filter records with given predicted as terms"""

if not annotated_as:
def term_filter(field: str, value: Any) -> Optional[Dict[str, Any]]:
if value is None:
return None
return {
"terms": {
decode_field_name(EsRecordDataFieldNames.annotated_as): annotated_as
}
}
return {"term": {field: value}}

@staticmethod
def predicted(predicted: PredictionStatus = None) -> Optional[Dict[str, Any]]:
"""Filter records with given predicted status"""
if predicted is None:
def range_filter(
field: str, value_from: Optional[Any] = None, value_to: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
filter_data = {}
if value_from is not None:
filter_data["gte"] = value_from
if value_to is not None:
filter_data["lte"] = value_to
if not filter_data:
return None
return {
"term": {decode_field_name(EsRecordDataFieldNames.predicted): predicted}
}
return {"range": {field: filter_data}}

@staticmethod
def text_query(text_query: Optional[str]) -> Dict[str, Any]:
"""Filter records matching text query"""
if text_query is None:
return {"match_all": {}}
return filters.match_all()
return filters.boolean_filter(
should_filters=[
{
@@ -274,24 +262,13 @@ def text_query(text_query: Optional[str]) -> Dict[str, Any]:
)

@staticmethod
def score(
score: Optional[ScoreRange],
) -> Optional[Dict[str, Any]]:
if score is None:
return None

score_filter = {}
if score.range_from is not None:
score_filter["gte"] = score.range_from
if score.range_to is not None:
score_filter["lte"] = score.range_to

return {"range": {EsRecordDataFieldNames.score: score_filter}}

@classmethod
def match_all(cls):
def match_all():
return {"match_all": {}}

@staticmethod
def ids_filter(ids: List[str]):
return {"ids": {"values": ids}}


class aggregations:
"""Group of functions related to elasticsearch aggregations requests"""
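The task-specific predicted_as, annotated_as, predicted and score helpers above are replaced by the generic terms_filter, term_filter and range_filter. A minimal usage sketch of the clauses they return, based on the implementations shown in this hunk; the field names and values are illustrative, not taken from the repository's tests:

from rubrix.server.commons.es_helpers import filters

# terms_filter: match any of several values for a field
filters.terms_filter("predicted_as", ["cat", "dog"])
# -> {"terms": {"predicted_as": ["cat", "dog"]}}

# term_filter: match a single exact value
filters.term_filter("status", "Validated")
# -> {"term": {"status": "Validated"}}

# range_filter: bounded range; either bound may be omitted
filters.range_filter("score", value_from=0.5, value_to=0.9)
# -> {"range": {"score": {"gte": 0.5, "lte": 0.9}}}

# All three return None for empty input, so callers can keep only the
# filters that actually apply.
filters.terms_filter("predicted_as", [])  # -> None

Returning None for empty input is what lets the new EsQueryBuilder.to_es_query below collect only the filters that are actually set on a query.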
6 changes: 5 additions & 1 deletion src/rubrix/server/tasks/commons/api/model.py
@@ -348,11 +348,15 @@ class BaseSearchResults(GenericModel, Generic[Record, Aggregations]):
aggregations: Aggregations = None


class ScoreRange(BaseModel):
class QueryRange(BaseModel):
"""Score range filter"""

range_from: float = Field(default=0.0, alias="from")
range_to: float = Field(default=None, alias="to")

class Config:
allow_population_by_field_name = True


class ScoreRange(QueryRange):
pass
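
Renaming ScoreRange to the more general QueryRange and enabling allow_population_by_field_name means a range can be built either from the ES-style "from"/"to" aliases or from the Python field names. A small sketch, assuming standard pydantic v1 alias behaviour and the import path used by the code removed from es_helpers.py above:

from rubrix.server.tasks.commons import ScoreRange

# Populate via the aliases, as API payloads do...
ScoreRange.parse_obj({"from": 0.5, "to": 0.9})
# ...or via the field names, now that allow_population_by_field_name is set
score = ScoreRange(range_from=0.5, range_to=0.9)
# Serializing with by_alias=True restores the "from"/"to" keys
score.dict(by_alias=True)  # -> {"from": 0.5, "to": 0.9}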
2 changes: 2 additions & 0 deletions src/rubrix/server/tasks/commons/metrics/commons.py
@@ -27,6 +27,7 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
description="Computes the input text length distribution",
field="metrics.text_length",
# TODO(@frascuchon): This won't work once words is excluded from _source
# TODO: Implement changes with backward compatibility
script="params._source.words.length()",
fixed_interval=1,
),
@@ -52,6 +53,7 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
name="Inputs words cloud",
description="The words cloud for dataset inputs",
# TODO(@frascuchon): This won't work once words is excluded from _source
# TODO: Implement changes with backward compatibility
default_field=EsRecordDataFieldNames.words,
),
MetadataAggregations(id="metadata", name="Metadata fields stats"),
7 changes: 1 addition & 6 deletions src/rubrix/server/tasks/search/model.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import Any, Dict, List, Optional, TypeVar, Union

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

from rubrix.server.tasks.commons import BaseRecord, SortableField, TaskStatus

@@ -45,10 +44,6 @@ class BaseSearchQuery(BaseModel):

metadata: Optional[Dict[str, Union[str, List[str]]]] = None

def as_elasticsearch(self) -> Dict[str, Any]:
# TODO: Hide transformations in DAO component
raise NotImplementedError()


class SortConfig(BaseModel):
shuffle: bool = False
49 changes: 47 additions & 2 deletions src/rubrix/server/tasks/search/query_builder.py
@@ -1,11 +1,14 @@
from enum import Enum
from typing import Any, Dict, Optional, TypeVar

from fastapi import Depends
from luqum.elasticsearch import ElasticsearchQueryBuilder, SchemaAnalyzer
from luqum.parser import parser

from rubrix.server.commons import es_helpers
from rubrix.server.commons.es_helpers import filters
from rubrix.server.datasets.model import BaseDatasetDB, Dataset
from rubrix.server.tasks.commons import QueryRange
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO
from rubrix.server.tasks.search.model import BaseSearchQuery

@@ -34,7 +37,7 @@ def __call__(
return es_helpers.filters.match_all()

if not query.advanced_query_dsl or not query.query_text:
return query.as_elasticsearch()
return self.to_es_query(query)

text_search = query.query_text
new_query = query.copy(update={"query_text": None})
@@ -52,5 +55,47 @@
query_text = es_query_builder(query_tree)

return es_helpers.filters.boolean_filter(
filter_query=new_query.as_elasticsearch(), must_query=query_text
filter_query=self.to_es_query(new_query), must_query=query_text
)

@staticmethod
def to_es_query(query: BaseSearchQuery) -> Dict[str, Any]:
if query.ids:
return filters.ids_filter(query.ids)

query_text = filters.text_query(query.query_text)
all_filters = filters.metadata(query.metadata)
query_data = query.dict(
exclude={"query_text", "metadata", "uncovered_by_rules"}
)
for key, value in query_data.items():
if value is None:
continue
key_filter = None
if isinstance(value, list):
key_filter = filters.terms_filter(key, value)
elif isinstance(value, (str, Enum)):
key_filter = filters.term_filter(key, value)
elif isinstance(value, QueryRange):
key_filter = filters.range_filter(
field=key, value_from=value.range_from, value_to=value.range_to
)
else:
print("Ups...", key, value)

if key_filter:
all_filters.append(key_filter)

return filters.boolean_filter(
must_query=query_text or filters.match_all(),
filter_query=filters.boolean_filter(
should_filters=all_filters, minimum_should_match=len(all_filters)
)
if all_filters
else None,
must_not_query=filters.boolean_filter(
should_filters=[filters.text_query(q) for q in query.uncovered_by_rules]
)
if hasattr(query, "uncovered_by_rules") and query.uncovered_by_rules
else None,
)
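
With the per-task as_elasticsearch methods removed, the translation can be exercised directly through the static to_es_query, as the updated test at the end of this diff does. A short sketch; only the first call is asserted by that test, and the exact nesting produced by the second depends on filters.boolean_filter, which is not shown in this diff:

from rubrix.server.tasks.search.query_builder import EsQueryBuilder
from rubrix.server.tasks.text2text import Text2TextQuery

# ids short-circuit everything else (this is what the updated test asserts)
EsQueryBuilder.to_es_query(Text2TextQuery(ids=[1, 2, 3]))
# -> {"ids": {"values": [1, 2, 3]}}

# Other fields are mapped by type in to_es_query: lists through terms_filter,
# scalars through term_filter, ranges through range_filter, all combined with
# the text query inside a boolean filter.
EsQueryBuilder.to_es_query(Text2TextQuery(query_text="some text"))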
34 changes: 0 additions & 34 deletions src/rubrix/server/tasks/text2text/api/model.py
@@ -19,7 +19,6 @@

from pydantic import BaseModel, Field, validator

from rubrix.server.commons.es_helpers import filters
from rubrix.server.datasets.model import DatasetDB, UpdateDatasetRequest
from rubrix.server.tasks.commons.api.model import (
BaseAnnotation,
@@ -198,39 +197,6 @@ class Text2TextQuery(BaseSearchQuery):
score: Optional[ScoreRange] = Field(default=None)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

def as_elasticsearch(self) -> Dict[str, Any]:
"""Build an elasticsearch query part from search query"""

if self.ids:
return {"ids": {"values": self.ids}}

all_filters = filters.metadata(self.metadata)
query_filters = [
query_filter
for query_filter in [
filters.predicted_by(self.predicted_by),
filters.annotated_by(self.annotated_by),
filters.status(self.status),
filters.predicted(self.predicted),
filters.score(self.score),
]
if query_filter
]
query_text = filters.text_query(self.query_text)
all_filters.extend(query_filters)

return {
"bool": {
"must": query_text or {"match_all": {}},
"filter": {
"bool": {
"should": all_filters,
"minimum_should_match": len(all_filters),
}
},
}
}


class Text2TextSearchRequest(BaseModel):
"""
40 changes: 1 addition & 39 deletions src/rubrix/server/tasks/text_classification/api/model.py
@@ -28,6 +28,7 @@
BaseSearchResults,
BaseSearchResultsAggregations,
PredictionStatus,
QueryRange,
ScoreRange,
SortableField,
TaskStatus,
@@ -514,45 +515,6 @@ class TextClassificationQuery(BaseSearchQuery):
description="List of rule queries that WILL NOT cover the resulting records",
)

def as_elasticsearch(self) -> Dict[str, Any]:
"""Build an elasticsearch query part from search query"""

if self.ids:
return {"ids": {"values": self.ids}}

all_filters = filters.metadata(self.metadata)
query_filters = [
query_filter
for query_filter in [
filters.predicted_as(self.predicted_as),
filters.predicted_by(self.predicted_by),
filters.annotated_as(self.annotated_as),
filters.annotated_by(self.annotated_by),
filters.status(self.status),
filters.predicted(self.predicted),
filters.score(self.score),
]
if query_filter
]
query_text = filters.text_query(self.query_text)
all_filters.extend(query_filters)

return filters.boolean_filter(
must_query=query_text or {"match_all": {}},
must_not_query=filters.boolean_filter(
should_filters=[
filters.text_query(query) for query in self.uncovered_by_rules
]
)
if self.uncovered_by_rules
else None,
filter_query=filters.boolean_filter(
should_filters=all_filters, minimum_should_match=len(all_filters)
)
if all_filters
else None,
)


class TextClassificationSearchRequest(BaseModel):
"""
@@ -69,7 +69,9 @@ def aggregation_request(

if labels is not None:
for label in labels:
rule_label_annotated_filter = filters.annotated_as([label])
rule_label_annotated_filter = filters.term_filter(
"annotated_as", value=label
)
encoded_label = self._encode_label_name(label)
aggr_filters.update(
{
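For reference, a sketch of the clause the term_filter call above produces, per the implementation earlier in this diff; the label value is illustrative:

from rubrix.server.commons.es_helpers import filters

filters.term_filter("annotated_as", value="POSITIVE")
# -> {"term": {"annotated_as": "POSITIVE"}}

Unlike the removed annotated_as helper, the field name is now passed in literally by the caller rather than resolved through decode_field_name.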
30 changes: 0 additions & 30 deletions src/rubrix/server/tasks/token_classification/api/model.py
@@ -19,7 +19,6 @@
from pydantic import BaseModel, Field, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.es_helpers import filters
from rubrix.server.datasets.model import DatasetDB, UpdateDatasetRequest
from rubrix.server.tasks.commons import (
BaseAnnotation,
@@ -388,35 +387,6 @@ class TokenClassificationQuery(BaseSearchQuery):
score: Optional[ScoreRange] = Field(default=None)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

def as_elasticsearch(self) -> Dict[str, Any]:
"""Build an elasticsearch query part from search query"""

if self.ids:
return {"ids": {"values": self.ids}}

all_filters = filters.metadata(self.metadata)
query_filters = [
query_filter
for query_filter in [
filters.predicted_as(self.predicted_as),
filters.predicted_by(self.predicted_by),
filters.annotated_as(self.annotated_as),
filters.annotated_by(self.annotated_by),
filters.status(self.status),
filters.predicted(self.predicted),
filters.score(self.score),
]
if query_filter
]
query_text = filters.text_query(self.query_text)
all_filters.extend(query_filters)

return filters.boolean_filter(
must_query=query_text,
should_filters=all_filters,
minimum_should_match=len(all_filters),
)


class TokenClassificationSearchRequest(BaseModel):

3 changes: 2 additions & 1 deletion tests/server/text2text/test_model.py
@@ -1,3 +1,4 @@
from rubrix.server.tasks.search.query_builder import EsQueryBuilder
from rubrix.server.tasks.text2text import (
Text2TextAnnotation,
Text2TextPrediction,
@@ -56,4 +57,4 @@ def test_model_dict():

def test_query_as_elasticsearch():
query = Text2TextQuery(ids=[1, 2, 3])
assert query.as_elasticsearch() == {"ids": {"values": query.ids}}
assert EsQueryBuilder.to_es_query(query) == {"ids": {"values": query.ids}}
