Skip to content

Commit

Permalink
fix(search): compute dataset schema properly for advanced query dsl (#…
Browse files Browse the repository at this point in the history
…1380)

* fix(search): compute dataset schema properly

* test: prevent loguru not installed error

* test: fix huggingface_hub version
  • Loading branch information
frascuchon committed Apr 6, 2022
1 parent 1b03ebb commit 670ab7d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
2 changes: 2 additions & 0 deletions environment_dev.yml
Expand Up @@ -32,6 +32,7 @@ dependencies:
# extra test dependencies
- cleanlab
- datasets>1.17.0
- huggingface_hub==0.4.0 # resolve problems with 0.5.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz
- flair==0.10
- flyingsquid
Expand All @@ -40,5 +41,6 @@ dependencies:
- snorkel>=0.9.7
- spacy==3.1.0
- transformers[torch]
- loguru
# install Rubrix in editable mode
- -e .[server]
10 changes: 8 additions & 2 deletions src/rubrix/server/tasks/commons/dao/dao.py
Expand Up @@ -33,7 +33,8 @@
from rubrix.server.commons.helpers import unflatten_dict
from rubrix.server.commons.settings import settings
from rubrix.server.datasets.model import BaseDatasetDB
from rubrix.server.tasks.commons import BaseRecord, MetadataLimitExceededError, TaskType
from rubrix.server.tasks.commons import BaseRecord, TaskType
from rubrix.server.tasks.commons.api.errors import MetadataLimitExceededError
from rubrix.server.tasks.commons.dao.es_config import (
mappings,
tasks_common_mappings,
Expand Down Expand Up @@ -426,7 +427,12 @@ def create_dataset_index(
def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]:
    """Return the inner elasticsearch index configuration for ``dataset``.

    Fetches the mapping of the dataset's records index. Some elasticsearch
    client versions wrap the mapping under the index name; in that case the
    inner mapping dict is returned instead of the wrapper.
    """
    index_name = dataset_records_index(dataset.id)
    mapping_response = self._es.__client__.indices.get_mapping(index=index_name)

    # Unwrap {"<index_name>": {...}} envelopes; otherwise return as-is.
    return mapping_response.get(index_name, mapping_response)

@classmethod
def __configure_query_highlight__(cls, task: TaskType):
Expand Down
57 changes: 57 additions & 0 deletions tests/functional_tests/search/test_search_service.py
Expand Up @@ -13,6 +13,7 @@
TextClassificationQuery,
TextClassificationRecord,
)
from rubrix.server.tasks.token_classification import TokenClassificationQuery


@pytest.fixture
Expand Down Expand Up @@ -61,6 +62,62 @@ def test_query_builder_with_query_range(query_builder):
}


def test_query_builder_with_nested(query_builder, mocked_client):
    """The advanced query DSL over nested mention metrics must compile to a
    nested elasticsearch query scoped to ``metrics.predicted.mentions``."""
    dataset = Dataset(
        name="test_query_builder_with_nested",
        owner=rubrix.get_workspace(),
        task=TaskType.token_classification,
    )
    rubrix.delete(dataset.name)

    record = rubrix.TokenClassificationRecord(
        text="Michael is a professor at Harvard",
        tokens=["Michael", "is", "a", "professor", "at", "Harvard"],
        prediction=[("NAME", 0, 7, 0.9), ("LOC", 26, 33, 0.12)],
    )
    rubrix.log(name=dataset.name, records=record)

    query = TokenClassificationQuery(
        advanced_query_dsl=True,
        query_text="metrics.predicted.mentions:(label:NAME AND score:[* TO 0.1])",
    )
    es_query = query_builder(dataset=dataset, query=query)

    # Expected clauses for the two conditions in the DSL expression.
    label_clause = {
        "term": {"metrics.predicted.mentions.label": {"value": "NAME"}}
    }
    score_clause = {
        "range": {"metrics.predicted.mentions.score": {"lte": "0.1"}}
    }
    expected = {
        "bool": {
            "filter": {"bool": {"must": {"match_all": {}}}},
            "must": {
                "nested": {
                    "path": "metrics.predicted.mentions",
                    "query": {"bool": {"must": [label_clause, score_clause]}},
                }
            },
        }
    }
    assert es_query == expected


def test_failing_metrics(service, mocked_client):

dataset = Dataset(
Expand Down

0 comments on commit 670ab7d

Please sign in to comment.