diff --git a/src/rubrix/server/commons/es_helpers.py b/src/rubrix/server/commons/es_helpers.py index 8b66671ee5..ba91537aee 100644 --- a/src/rubrix/server/commons/es_helpers.py +++ b/src/rubrix/server/commons/es_helpers.py @@ -25,22 +25,23 @@ TaskStatus, ) from rubrix.server.tasks.commons.api import EsRecordDataFieldNames +from rubrix.server.tasks.commons.dao.es_config import mappings def nested_mappings_from_base_model(model_class: Type[BaseModel]) -> Dict[str, Any]: - def resolve_type(info): + def resolve_mapping(info) -> Dict[str, Any]: the_type = info.get("type") if the_type == "number": - return "float" + return {"type": "float"} if the_type == "integer": - return "integer" - return "keyword" + return {"type": "integer"} + return mappings.keyword_field(enable_text_search=True) return { "type": "nested", "include_in_root": True, "properties": { - key: {"type": resolve_type(info)} + key: resolve_mapping(info) for key, info in model_class.schema()["properties"].items() }, } diff --git a/src/rubrix/server/tasks/commons/dao/es_config.py b/src/rubrix/server/tasks/commons/dao/es_config.py index 656284856b..f4885bed65 100644 --- a/src/rubrix/server/tasks/commons/dao/es_config.py +++ b/src/rubrix/server/tasks/commons/dao/es_config.py @@ -12,21 +12,28 @@ class mappings: @staticmethod - def keyword_field(): + def keyword_field(enable_text_search: bool = False): """Mappings config for keyword field""" - return { + mapping = { "type": "keyword", # TODO: Use environment var and align with fields validators "ignore_above": MAX_KEYWORD_LENGTH, } + if enable_text_search: + mapping["fields"] = {"text": mappings.text_field()} + return mapping @staticmethod - def path_match_keyword_template(path: str): + def path_match_keyword_template( + path: str, enable_text_search_in_keywords: bool = False + ): """Dynamic template mappings config for keyword field based on path match""" return { "path_match": path, "match_mapping_type": "string", - "mapping": mappings.keyword_field(), + "mapping": mappings.keyword_field( + enable_text_search=enable_text_search_in_keywords + ), } @staticmethod @@ -134,7 +141,11 @@ def dynamic_metrics_text(): def dynamic_metadata_text(): - return {"metadata.*": mappings.path_match_keyword_template(path="metadata.*")} + return { + "metadata.*": mappings.path_match_keyword_template( + path="metadata.*", enable_text_search_in_keywords=True + ) + } def tasks_common_mappings(): @@ -152,8 +163,8 @@ def tasks_common_mappings(): "status": mappings.keyword_field(), "event_timestamp": {"type": "date"}, "last_updated": {"type": "date"}, - "annotated_by": mappings.keyword_field(), - "predicted_by": mappings.keyword_field(), + "annotated_by": mappings.keyword_field(enable_text_search=True), + "predicted_by": mappings.keyword_field(enable_text_search=True), "metrics": {"dynamic": True, "type": "object"}, "metadata": {"dynamic": True, "type": "object"}, }, diff --git a/src/rubrix/server/tasks/text_classification/dao/es_config.py b/src/rubrix/server/tasks/text_classification/dao/es_config.py index 11785a724c..6500dc2c9e 100644 --- a/src/rubrix/server/tasks/text_classification/dao/es_config.py +++ b/src/rubrix/server/tasks/text_classification/dao/es_config.py @@ -27,8 +27,8 @@ def text_classification_mappings(): }, "predicted": mappings.keyword_field(), "multi_label": {"type": "boolean"}, - "annotated_as": mappings.keyword_field(), - "predicted_as": mappings.keyword_field(), + "annotated_as": mappings.keyword_field(enable_text_search=True), + "predicted_as": mappings.keyword_field(enable_text_search=True), "score": mappings.decimal_field(), }, "dynamic_templates": [ diff --git a/src/rubrix/server/tasks/token_classification/dao/es_config.py b/src/rubrix/server/tasks/token_classification/dao/es_config.py index 8bc96affd1..87322257b0 100644 --- a/src/rubrix/server/tasks/token_classification/dao/es_config.py +++ b/src/rubrix/server/tasks/token_classification/dao/es_config.py @@ -19,7 +19,7 @@ def mentions_mappings(): def token_classification_mappings(): metrics_mentions_mappings = nested_mappings_from_base_model(MentionMetrics) - _mentions_mappings = mentions_mappings() + _mentions_mappings = mentions_mappings() # TODO: remove return { "_source": mappings.source( excludes=[ @@ -36,11 +36,11 @@ def token_classification_mappings(): ), "properties": { "predicted": mappings.keyword_field(), - "annotated_as": mappings.keyword_field(), - "predicted_as": mappings.keyword_field(), + "annotated_as": mappings.keyword_field(enable_text_search=True), + "predicted_as": mappings.keyword_field(enable_text_search=True), "score": {"type": "float"}, - "predicted_mentions": _mentions_mappings, - "mentions": _mentions_mappings, + "predicted_mentions": _mentions_mappings, # TODO: remove + "mentions": _mentions_mappings, # TODO: remove "tokens": mappings.keyword_field(), "metrics.tokens": nested_mappings_from_base_model(TokenMetrics), "metrics.predicted.mentions": metrics_mentions_mappings,