From 2dd5853152ff5051e16acc7f714eb4dc1552beb0 Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Fri, 25 Feb 2022 09:52:36 +0100
Subject: [PATCH] feat(#950): include search keywords as part of record results (#1201)

* chore: include search_keywords in client records

* chore: signatures

* feat: include search_records as part of client records

* fix: add highlight on dataset scan

* test: add missing tests

* test: stabilize tests

* Apply suggestions from code review

Co-authored-by: David Fidalgo

* test: try to fix push to hf hub

Co-authored-by: David Fidalgo

(cherry picked from commit 0678043e33344996d798c8cb5d80a176e57a7275)
---
 src/rubrix/client/models.py                    | 14 +++++-
 src/rubrix/client/sdk/commons/models.py        |  1 +
 src/rubrix/client/sdk/text2text/models.py      |  1 +
 .../client/sdk/text_classification/models.py   |  1 +
 .../client/sdk/token_classification/models.py  |  3 +-
 src/rubrix/server/tasks/commons/api/model.py   |  7 +++
 src/rubrix/server/tasks/commons/dao/dao.py     | 50 ++++++++++++++++---
 .../test_log_for_text_classification.py        | 25 ++++++++++
 .../test_log_for_token_classification.py       | 25 ++++++++++
 9 files changed, 118 insertions(+), 9 deletions(-)

diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py
index 8de0d9dacd..4527c8e29c 100644
--- a/src/rubrix/client/models.py
+++ b/src/rubrix/client/models.py
@@ -147,7 +147,9 @@ class TextClassificationRecord(_Validators):
         metrics: READ ONLY! Metrics at record level provided by the server when using `rb.load`.
             This attribute will be ignored when using `rb.log`.
-
+        search_keywords:
+            READ ONLY! Relevant record keywords/terms for the provided query when using `rb.load`.
+            This attribute will be ignored when using `rb.log`.

     Examples:
         >>> import rubrix as rb
         >>> record = rb.TextClassificationRecord(
@@ -172,6 +174,7 @@ class TextClassificationRecord(_Validators):
     event_timestamp: Optional[datetime.datetime] = None

     metrics: Optional[Dict[str, Any]] = None
+    search_keywords: Optional[List[str]] = None

     @validator("inputs", pre=True)
     def input_as_dict(cls, inputs):
@@ -213,7 +216,9 @@ class TokenClassificationRecord(_Validators):
         metrics: READ ONLY! Metrics at record level provided by the server when using `rb.load`.
             This attribute will be ignored when using `rb.log`.
-
+        search_keywords:
+            READ ONLY! Relevant record keywords/terms for the provided query when using `rb.load`.
+            This attribute will be ignored when using `rb.log`.

     Examples:
         >>> import rubrix as rb
         >>> record = rb.TokenClassificationRecord(
@@ -239,6 +244,7 @@ class TokenClassificationRecord(_Validators):
     event_timestamp: Optional[datetime.datetime] = None

     metrics: Optional[Dict[str, Any]] = None
+    search_keywords: Optional[List[str]] = None

     @validator("prediction")
     def add_default_score(
@@ -283,6 +289,9 @@ class Text2TextRecord(_Validators):
         metrics: READ ONLY! Metrics at record level provided by the server when using `rb.load`.
             This attribute will be ignored when using `rb.log`.
+        search_keywords:
+            READ ONLY! Relevant record keywords/terms for the provided query when using `rb.load`.
+            This attribute will be ignored when using `rb.log`.

     Examples:
         >>> import rubrix as rb
@@ -305,6 +314,7 @@ class Text2TextRecord(_Validators):
     event_timestamp: Optional[datetime.datetime] = None

     metrics: Optional[Dict[str, Any]] = None
+    search_keywords: Optional[List[str]] = None

     @validator("prediction")
     def prediction_as_tuples(
diff --git a/src/rubrix/client/sdk/commons/models.py b/src/rubrix/client/sdk/commons/models.py
index 918f5e901a..a5d3783537 100644
--- a/src/rubrix/client/sdk/commons/models.py
+++ b/src/rubrix/client/sdk/commons/models.py
@@ -46,6 +46,7 @@ class BaseRecord(GenericModel, Generic[T]):
     prediction: Optional[T] = None
     annotation: Optional[T] = None
     metrics: Dict[str, Any] = Field(default_factory=dict)
+    search_keywords: Optional[List[str]] = None

     # this is a small hack to get a json-compatible serialization on cls.dict(), which we use for the httpx calls.
     # they want to build this feature into pydantic, see https://github.com/samuelcolvin/pydantic/issues/1409
diff --git a/src/rubrix/client/sdk/text2text/models.py b/src/rubrix/client/sdk/text2text/models.py
index 257a1d03b4..40b996edbf 100644
--- a/src/rubrix/client/sdk/text2text/models.py
+++ b/src/rubrix/client/sdk/text2text/models.py
@@ -94,6 +94,7 @@ def to_client(self) -> ClientText2TextRecord:
             id=self.id,
             event_timestamp=self.event_timestamp,
             metrics=self.metrics or None,
+            search_keywords=self.search_keywords or None,
         )
diff --git a/src/rubrix/client/sdk/text_classification/models.py b/src/rubrix/client/sdk/text_classification/models.py
index eb551a67d5..8c8a25deb7 100644
--- a/src/rubrix/client/sdk/text_classification/models.py
+++ b/src/rubrix/client/sdk/text_classification/models.py
@@ -129,6 +129,7 @@ def to_client(self) -> ClientTextClassificationRecord:
             if self.explanation
             else None,
             metrics=self.metrics or None,
+            search_keywords=self.search_keywords or None,
         )
diff --git a/src/rubrix/client/sdk/token_classification/models.py b/src/rubrix/client/sdk/token_classification/models.py
index 059da0ad54..c0564c8bea 100644
--- a/src/rubrix/client/sdk/token_classification/models.py
+++ b/src/rubrix/client/sdk/token_classification/models.py
@@ -117,7 +117,8 @@ def to_client(self) -> ClientTokenClassificationRecord:
             event_timestamp=self.event_timestamp,
             status=self.status,
             metadata=self.metadata or {},
-            metrics=self.metrics or {},
+            metrics=self.metrics or None,
+            search_keywords=self.search_keywords or None,
         )
diff --git a/src/rubrix/server/tasks/commons/api/model.py b/src/rubrix/server/tasks/commons/api/model.py
index 573c074a30..3c0070448e 100644
--- a/src/rubrix/server/tasks/commons/api/model.py
+++ b/src/rubrix/server/tasks/commons/api/model.py
@@ -179,6 +179,13 @@ class BaseRecord(GenericModel, Generic[Annotation]):
     prediction: Optional[Annotation] = None
     annotation: Optional[Annotation] = None
     metrics: Dict[str, Any] = Field(default_factory=dict)
+    search_keywords: Optional[List[str]] = None
+
+    @validator("search_keywords")
+    def remove_duplicated_keywords(cls, value) -> List[str]:
+        """Remove duplicated keywords"""
+        if value:
+            return list(set(value))

     @validator("id", always=True)
     def default_id_if_none_provided(cls, id: Optional[str]) -> str:
diff --git a/src/rubrix/server/tasks/commons/dao/dao.py b/src/rubrix/server/tasks/commons/dao/dao.py
index a3b00fc1f1..d05af71eae 100644
--- a/src/rubrix/server/tasks/commons/dao/dao.py
+++ b/src/rubrix/server/tasks/commons/dao/dao.py
@@ -15,6 +15,7 @@

 import dataclasses
 import datetime
+import re
 from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar

 import deprecated
@@ -129,6 +130,12 @@ class DatasetRecordsDAO:
     # This info must be provided by each task using dao.register_task_mappings method
     _MAPPINGS_BY_TASKS = {}

+    __HIGHLIGHT_PRE_TAG__ = "<@@-rb-key>"
+    __HIGHLIGHT_POST_TAG__ = "</@@-rb-key>"
+    __HIGHLIGHT_VALUES_REGEX__ = re.compile(
+        rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}"
+    )
+
     @classmethod
     def get_instance(
         cls,
@@ -158,7 +165,7 @@ def init(self):
     def add_records(
         self,
         dataset: BaseDatasetDB,
-        records: List[BaseRecord],
+        records: List[DBRecord],
         record_class: Type[DBRecord],
     ) -> int:
         """
@@ -190,7 +197,9 @@ def add_records(
             db_record = record_class.parse_obj(r)
             if now:
                 db_record.last_updated = now
-            documents.append(db_record.dict(exclude_none=False))
+            documents.append(
+                db_record.dict(exclude_none=False, exclude={"search_keywords"})
+            )

         index_name = self.create_dataset_index(dataset)
         self._configure_metadata_fields(index_name, metadata_values)
@@ -246,6 +255,7 @@ def search_records(
             "query": search.query or {"match_all": {}},
             "sort": search.sort or [{"_id": {"order": "asc"}}],
             "aggs": aggregation_requests,
+            "highlight": self.__configure_query_highlight__(),
         }

         try:
@@ -282,7 +292,7 @@ def search_records(

         result = RecordSearchResults(
             total=total,
-            records=list(map(self.esdoc2record, docs)),
+            records=list(map(self.__esdoc2record__, docs)),
         )
         if search_aggregations:
             parsed_aggregations = parse_aggregations(search_aggregations)
@@ -319,15 +329,34 @@ def scan_dataset(
         search = search or RecordSearch()
         es_query = {
             "query": search.query,
+            "highlight": self.__configure_query_highlight__(),
         }
         docs = self._es.list_documents(
             dataset_records_index(dataset.id), query=es_query
         )
         for doc in docs:
-            yield self.esdoc2record(doc)
+            yield self.__esdoc2record__(doc)
+
+    def __esdoc2record__(self, doc: Dict[str, Any]):
+        return {
+            **doc["_source"],
+            "id": doc["_id"],
+            "search_keywords": self.__parse_highlight_results__(doc),
+        }

-    def esdoc2record(self, doc):
-        return {**doc["_source"], "id": doc["_id"]}
+    @classmethod
+    def __parse_highlight_results__(cls, doc: Dict[str, Any]) -> Optional[List[str]]:
+        highlight_info = doc.get("highlight")
+        if not highlight_info:
+            return None
+
+        search_keywords = []
+        for content in highlight_info.values():
+            if not isinstance(content, list):
+                content = [content]
+            for text in content:
+                search_keywords.extend(re.findall(cls.__HIGHLIGHT_VALUES_REGEX__, text))
+        return list(set(search_keywords))

     def _configure_metadata_fields(self, index: str, metadata_values: Dict[str, Any]):
         def check_metadata_length(metadata_length: int = 0):
@@ -406,6 +435,15 @@ def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]:
         index_name = dataset_records_index(dataset.id)
         return self._es.__client__.indices.get_mapping(index=index_name)

+    @classmethod
+    def __configure_query_highlight__(cls):
+        return {
+            "pre_tags": [cls.__HIGHLIGHT_PRE_TAG__],
+            "post_tags": [cls.__HIGHLIGHT_POST_TAG__],
+            "require_field_match": False,
+            "fields": {"text": {}},
+        }
+

 _instance: Optional[DatasetRecordsDAO] = None

diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py
index ddb42f5440..f09a57c40b 100644
--- a/tests/functional_tests/test_log_for_text_classification.py
+++ b/tests/functional_tests/test_log_for_text_classification.py
@@ -50,6 +50,31 @@ def test_delete_and_create_for_different_task(mocked_client):
     rubrix.load(dataset)


+def test_search_keywords(mocked_client):
+    dataset = "test_search_keywords"
+    from datasets import load_dataset
+
+    dataset_ds = load_dataset("Recognai/sentiment-banking", split="train")
load_dataset("Recognai/sentiment-banking", split="train") + dataset_rb = rubrix.read_datasets(dataset_ds, task="TextClassification") + + rubrix.delete(dataset) + rubrix.log(name=dataset, records=dataset_rb) + + df = rubrix.load(dataset, query="lim*") + assert not df.empty + assert "search_keywords" in df.columns + top_keywords = set( + [ + keyword + for keywords in df.search_keywords.value_counts(sort=True, ascending=False) + .index[:3] + .tolist() + for keyword in keywords + ] + ) + assert {"limit", "limits", "limited"} == top_keywords, top_keywords + + def test_log_records_with_empty_metadata_list(mocked_client): dataset = "test_log_records_with_empty_metadata_list" diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index c41db03f65..bf095df314 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -446,3 +446,28 @@ def test_log_record_that_makes_me_cry(mocked_client): }, "annotated": {"mentions": []}, } + + +def test_search_keywords(mocked_client): + dataset = "test_search_keywords" + from datasets import load_dataset + + dataset_ds = load_dataset("rubrix/gutenberg_spacy-ner", split="train") + dataset_rb = rubrix.read_datasets(dataset_ds, task="TokenClassification") + + rubrix.delete(dataset) + rubrix.log(name=dataset, records=dataset_rb) + + df = rubrix.load(dataset, query="lis*") + assert not df.empty + assert "search_keywords" in df.columns + top_keywords = set( + [ + keyword + for keywords in df.search_keywords.value_counts(sort=True, ascending=False) + .index[:3] + .tolist() + for keyword in keywords + ] + ) + assert {"listened", "listen"} == top_keywords, top_keywords