refactor: remove words references in searches (#1571)
Close #945
frascuchon committed Jun 20, 2022
1 parent 0ceda28 commit 604e24e
Showing 13 changed files with 44 additions and 248 deletions.
10 changes: 6 additions & 4 deletions docs/guides/queries.md
@@ -13,12 +13,14 @@ An important concept when searching with Elasticsearch is the *field* concept.
Every search term in Rubrix is directed to a specific field of the record's underlying data model.
For example, writing `text:fox` in the search bar will search for records with the word `fox` in the field `text`.

If you do not provide any fields in your query string, by default Rubrix will search in the fields `word` and `word.extended`.
If you do not provide any fields in your query string, by default Rubrix will search in the `text` field.
For a complete list of available fields and their content, have a look at the field glossary below.

```{note}
The default behavior when not specifying any fields in the search, will likely change in the near future.
We recommend emulating the future behavior by using the `text` field for your default searches, that is change `brown fox` to `text:(brown fox)`, for example.
The default behavior when not specifying any fields in the query string changed in version `>=0.16.0`.
Before this version, Rubrix searched in a mixture of the deprecated `word` and `word.extended` fields, which allowed searches for special characters like `!` and `.`.
If you want to search for special characters now, you have to specify the `text.exact` field.
For example, this is the query to search for words ending in an exclamation mark: `text.exact:*\!`
```

## `text` and `text.exact`
@@ -40,7 +42,7 @@ Now consider these queries:
- `text:dog.` or `text:fox`: matches both of the records.
- `text.exact:dog` or `text.exact:FOX`: matches none of the records.
- `text.exact:dog.` or `text.exact:fox`: matches only the first record.
- `text.exact:DOG` or `text.exact:FOX!`: matches only the second record.
- `text.exact:DOG` or `text.exact:FOX\!`: matches only the second record.

You can see how the `text.exact` field can be used to search in a more fine-grained manner.

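As a quick illustration of the new default, the sketch below shows how such field-qualified queries might be issued through the Python client; the dataset name is hypothetical, and the `query` argument of `rb.load` is assumed to accept the Lucene query strings discussed above.

```python
import rubrix as rb

# No field given: with this change the query runs against the `text` field.
records = rb.load("my_dataset", query="brown fox")

# Case- and punctuation-sensitive search on the non-analyzed field;
# special characters such as "!" must be escaped.
exact_records = rb.load("my_dataset", query="text.exact:FOX\\!")
```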
8 changes: 2 additions & 6 deletions src/rubrix/server/apis/v0/models/metrics/commons.py
@@ -29,9 +29,7 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
name="Text length distribution",
description="Computes the input text length distribution",
field="metrics.text_length",
# TODO(@frascuchon): This won't work once words is excluded from _source
# TODO: Implement changes with backward compatibility
script="params._source.words.length()",
script="params._source.text.length()",
fixed_interval=1,
),
TermsAggregation(
@@ -55,9 +53,7 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
id="words_cloud",
name="Inputs words cloud",
description="The words cloud for dataset inputs",
# TODO(@frascuchon): This won't work once words is excluded from _source
# TODO: Implement changes with backward compatibility
default_field=EsRecordDataFieldNames.words,
default_field="text.wordcloud",
),
MetadataAggregations(id="metadata", name="Metadata fields stats"),
TermsAggregation(
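For orientation, the words cloud metric now points at `text.wordcloud`. Based on the (removed) `words_cloud` helper further down in this commit, the resulting terms aggregation should look roughly like the sketch below; the size value is an assumption.

```python
# Sketch only: approximate terms aggregation for the updated words cloud metric.
words_cloud_aggregation = {
    "words_cloud": {
        "terms": {
            "field": "text.wordcloud",  # replaces the removed `words` field
            "size": 100,                # assumed default aggregation size
            "order": {"_count": "desc"},
        }
    }
}
```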
2 changes: 1 addition & 1 deletion src/rubrix/server/apis/v0/models/text2text.py
@@ -148,7 +148,7 @@ def extended_fields(self) -> Dict[str, Any]:
EsRecordDataFieldNames.annotated_by: self.annotated_by,
EsRecordDataFieldNames.predicted_by: self.predicted_by,
EsRecordDataFieldNames.score: self.scores,
"words": self.all_text(),
EsRecordDataFieldNames.words: self.all_text(),
}


3 changes: 2 additions & 1 deletion src/rubrix/server/apis/v0/models/text_classification.py
@@ -25,6 +25,7 @@
BaseRecord,
BaseSearchResults,
BaseSearchResultsAggregations,
EsRecordDataFieldNames,
PredictionStatus,
ScoreRange,
SortableField,
@@ -444,7 +445,7 @@ def extended_fields(self) -> Dict[str, Any]:
words = self.all_text()
return {
**super().extended_fields(),
"words": words,
EsRecordDataFieldNames.words: words,
# This allows queries by text:... or text.exact:...
# Once `words` is removed we can normalize at record level
"text": words,
3 changes: 2 additions & 1 deletion src/rubrix/server/apis/v0/models/token_classification.py
@@ -24,6 +24,7 @@
BaseRecord,
BaseSearchResults,
BaseSearchResultsAggregations,
EsRecordDataFieldNames,
PredictionStatus,
ScoreRange,
SortableField,
@@ -342,7 +343,7 @@ def extended_fields(self) -> Dict[str, Any]:
{"mention": mention, "entity": entity.label}
for mention, entity in self.annotated_mentions()
],
"words": self.all_text(),
EsRecordDataFieldNames.words: self.all_text(),
}


1 change: 0 additions & 1 deletion src/rubrix/server/daos/models/records.py
@@ -36,7 +36,6 @@ class RecordSearch(BaseModel):
query: Optional[Dict[str, Any]] = None
sort: List[Dict[str, Any]] = Field(default_factory=list)
aggregations: Optional[Dict[str, Any]]
include_default_aggregations: bool = True


class RecordSearchResults(BaseModel):
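With `include_default_aggregations` gone from `RecordSearch`, only explicitly requested aggregations are computed. A minimal sketch of building a search request under that assumption (the query and aggregation bodies are illustrative only):

```python
from rubrix.server.daos.models.records import RecordSearch

# Only the aggregations passed here will be computed; there is no longer a
# flag that pulls in the old default aggregation set.
search = RecordSearch(
    query={"match_all": {}},
    aggregations={"status": {"terms": {"field": "status"}}},
)
```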
27 changes: 0 additions & 27 deletions src/rubrix/server/daos/records.py
@@ -233,25 +233,6 @@ def search_records(
f"No records index found for dataset {dataset.name}"
)

if compute_aggregations and search.include_default_aggregations:
current_aggrs = results.get("aggregations", {})
for aggr in [
aggregations.predicted_by(),
aggregations.annotated_by(),
aggregations.status(),
aggregations.predicted(),
aggregations.words_cloud(),
aggregations.score(),
aggregations.custom_fields(self.get_metadata_schema(dataset)),
]:
if aggr:
aggr_results = self._es.search(
index=records_index,
query={"query": es_query["query"], "aggs": aggr},
)
current_aggrs.update(aggr_results["aggregations"])
results["aggregations"] = current_aggrs

hits = results["hits"]
total = hits["total"]
docs = hits["hits"]
@@ -263,13 +244,6 @@
)
if search_aggregations:
parsed_aggregations = parse_aggregations(search_aggregations)

if search.include_default_aggregations:
parsed_aggregations = unflatten_dict(
parsed_aggregations, stop_keys=["metadata"]
)
result.words = parsed_aggregations.pop("words", {})
result.metadata = parsed_aggregations.pop("metadata", {})
result.aggregations = parsed_aggregations

return result
@@ -441,7 +415,6 @@ def __configure_query_highlight__(cls, task: TaskType):
"text": {},
"text.*": {},
# TODO(@frascuchon): `words` will be removed in version 0.16.0
"words": {},
**({"inputs.*": {}} if task == TaskType.text_classification else {}),
},
}
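The query highlight configuration above also drops its `words` entry; the remaining `fields` section should end up looking roughly like this sketch (the surrounding highlight options are omitted):

```python
# Sketch of the highlight fields after the change; an "inputs.*" entry is
# only added for text classification tasks, as in the snippet above.
highlight_fields = {
    "text": {},
    "text.*": {},
}
```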
159 changes: 7 additions & 152 deletions src/rubrix/server/elasticseach/query_helpers.py
@@ -244,31 +244,13 @@ def text_query(text_query: Optional[str]) -> Dict[str, Any]:
if text_query is None:
return filters.match_all()
return filters.boolean_filter(
should_filters=[
{
"query_string": {
"default_field": EsRecordDataFieldNames.words,
"default_operator": "AND",
"query": text_query,
"boost": "2.0",
}
},
{
"query_string": {
"default_field": f"{EsRecordDataFieldNames.words}.extended",
"default_operator": "AND",
"query": text_query,
}
},
{
"query_string": {
"default_field": "text",
"default_operator": "AND",
"query": text_query,
}
},
],
minimum_should_match="30%",
must_query={
"query_string": {
"default_field": "text",
"default_operator": "AND",
"query": text_query,
}
},
)

@staticmethod
@@ -363,45 +345,6 @@ def histogram_aggregation(
},
}

@staticmethod
def predicted_by(size: int = DEFAULT_AGGREGATION_SIZE):
"""Predicted by aggregation"""
return {
"predicted_by": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.predicted_by),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def annotated_by(size: int = DEFAULT_AGGREGATION_SIZE):
"""Annotated by aggregation"""
return {
"annotated_by": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.annotated_by),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def status(size: int = DEFAULT_AGGREGATION_SIZE):
"""Status aggregation"""
return {
"status": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.status),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def custom_fields(
fields_definitions: Dict[str, str],
@@ -431,94 +374,6 @@ def __resolve_aggregation_for_field_type(
if aggregation
}

@staticmethod
def words_cloud(size: int = DEFAULT_AGGREGATION_SIZE):
"""Words cloud aggregation"""
return {
"words": {
"terms": {
"field": EsRecordDataFieldNames.words,
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def predicted_as(size: int = DEFAULT_AGGREGATION_SIZE):
"""Predicted as aggregation"""
return {
"predicted_as": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.predicted_as),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def annotated_as(size: int = DEFAULT_AGGREGATION_SIZE):
"""Annotated as aggregation"""

return {
"annotated_as": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.annotated_as),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def predicted(size: int = DEFAULT_AGGREGATION_SIZE):
"""Predicted aggregation"""
return {
"predicted": {
"terms": {
"field": decode_field_name(EsRecordDataFieldNames.predicted),
"size": size,
"order": {"_count": "desc"},
}
}
}

@staticmethod
def score(range_from: float = 0.0, range_to: float = 1.0, interval: float = 0.05):
decimals = 0
_interval = interval
while _interval < 1:
_interval *= 10
decimals += 1

ten_decimals = math.pow(10, decimals)

int_from = math.floor(range_from * ten_decimals)
int_to = math.floor(range_to * ten_decimals)
int_interval = math.floor(interval * ten_decimals)

return {
"score": {
"range": {
"field": EsRecordDataFieldNames.score,
"keyed": True,
"ranges": [
{"from": _from / ten_decimals, "to": _to / ten_decimals}
for _from, _to in zip(
range(int_from, int_to, int_interval),
range(
int_from + int_interval,
int_to + int_interval,
int_interval,
),
)
]
+ [{"from": range_to}],
}
}
}


def find_nested_field_path(
field_name: str, mapping_definition: Dict[str, Any]
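The reworked `text_query` above now emits a single `query_string` clause against `text` instead of the old `words`-boosted `should` list. The raw query it builds is expected to look roughly like the sketch below; the exact wrapper produced by `filters.boolean_filter` is an assumption.

```python
# Sketch only: approximate Elasticsearch query built by text_query("brown fox")
# after this change, assuming boolean_filter wraps its must_query in bool/must.
es_query = {
    "bool": {
        "must": {
            "query_string": {
                "default_field": "text",
                "default_operator": "AND",
                "query": "brown fox",
            }
        }
    }
}
```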
1 change: 0 additions & 1 deletion src/rubrix/server/services/metrics.py
@@ -173,7 +173,6 @@ def __metric_results__(
search=RecordSearch(
query=self.__query_builder__(dataset, query=query),
aggregations=agg,
include_default_aggregations=False,
),
)
results.update(results_.aggregations)
1 change: 0 additions & 1 deletion src/rubrix/server/services/search/service.py
@@ -86,7 +86,6 @@ def search(
EsRecordDataFieldNames.event_timestamp,
],
),
include_default_aggregations=False,
),
size=size,
record_from=record_from,
@@ -207,7 +207,6 @@ def compute_rule_metrics(
dataset,
size=0,
search=RecordSearch(
include_default_aggregations=False,
aggregations=self.__rule_metrics__.aggregation_request(
rule_query=rule_query, labels=labels
),
@@ -227,7 +226,6 @@ def _count_annotated_records(self, dataset: TextClassificationDatasetDB) -> int:
size=0,
search=RecordSearch(
query=filters.exists_field(EsRecordDataFieldNames.annotated_as),
include_default_aggregations=False,
),
)
return results.total
@@ -240,7 +238,6 @@ def all_rules_metrics(
dataset,
size=0,
search=RecordSearch(
include_default_aggregations=False,
aggregations=self.__dataset_rules_metrics__.aggregation_request(
all_rules=dataset.rules
),
