Skip to content

Commit

Permalink
fix(search): prevent metrics computation from breaking searches (#1175)
Browse files Browse the repository at this point in the history
* fix(search): handle metrics computation errors

* chore: printable rubrix server exceptions

* test: include missing test
  • Loading branch information
frascuchon committed Feb 17, 2022
1 parent e8b52ac commit 9f2adc9
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/rubrix/server/commons/errors/base_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def arguments(self):
else None
)

def __str__(self):
    """Return a printable form of the error: ``<code>(<key>=<value>,...)``.

    The pairs come from ``self.arguments``; a falsy value (``None`` or an
    empty mapping) is rendered as an empty argument list, i.e. ``<code>()``.
    """
    details = self.arguments or {}
    rendered = ",".join(f"{name}={value}" for name, value in details.items())
    return f"{self.code}({rendered})"


class ValidationError(RubrixServerError):
"""Generic data validation error out of request"""
Expand Down
23 changes: 16 additions & 7 deletions src/rubrix/server/tasks/search/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Set, Type
import logging
from typing import Iterable, List, Optional, Set, Type

from fastapi import Depends

Expand All @@ -18,6 +19,8 @@ def __init__(self, dao: DatasetRecordsDAO, metrics: MetricsService):
self.__dao__ = dao
self.__metrics__ = metrics

__LOGGER__ = logging.getLogger(__name__)

@classmethod
def get_instance(
cls,
Expand Down Expand Up @@ -68,12 +71,18 @@ def search(
record_from=record_from,
exclude_fields=exclude_fields,
)
metrics_results = {
metric: self.__metrics__.summarize_metric(
dataset=dataset, metric=metric, query=query
)
for metric in metrics or []
}
metrics_results = {}
for metric_id in metrics or []:
try:
metrics = self.__metrics__.summarize_metric(
dataset=dataset, metric=metric_id, query=query
)
metrics_results[metric_id] = metrics
except Exception as ex:
self.__LOGGER__.warning(
"Cannot compute metric [%s]. Error: %s", metric_id, ex
)
metrics_results[metric_id] = {}

return SearchResults(
total=results.total,
Expand Down
Empty file added tests/server/search/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions tests/server/search/test_search_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest

import rubrix
from rubrix.server.commons.es_wrapper import ElasticsearchWrapper
from rubrix.server.datasets.model import Dataset
from rubrix.server.tasks.commons import TaskType
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO
from rubrix.server.tasks.commons.metrics.service import MetricsService
from rubrix.server.tasks.search.model import SortConfig
from rubrix.server.tasks.search.service import SearchRecordsService
from rubrix.server.tasks.text_classification import (
TextClassificationQuery,
TextClassificationRecord,
)
from tests.server.test_helpers import client, mocking_client


@pytest.fixture
def es_wrapper():
    """Provide the object returned by ``ElasticsearchWrapper.get_instance()``."""
    wrapper = ElasticsearchWrapper.get_instance()
    return wrapper


@pytest.fixture
def dao(es_wrapper: ElasticsearchWrapper):
    """Provide a ``DatasetRecordsDAO`` built on the ``es_wrapper`` fixture."""
    records_dao = DatasetRecordsDAO.get_instance(es=es_wrapper)
    return records_dao


@pytest.fixture
def metrics(dao: DatasetRecordsDAO):
    """Provide a ``MetricsService`` built on the ``dao`` fixture."""
    metrics_service = MetricsService.get_instance(dao=dao)
    return metrics_service


@pytest.fixture
def service(dao: DatasetRecordsDAO, metrics: MetricsService):
    """Provide the ``SearchRecordsService`` under test.

    It is wired with the ``dao`` and ``metrics`` fixtures.
    """
    search_service = SearchRecordsService.get_instance(dao=dao, metrics=metrics)
    return search_service


def test_failing_metrics(service, monkeypatch):
    """Check that a search requesting the metric id ``"missing-metric"`` still
    returns results, with an empty entry for that metric.

    The dataset is deleted and re-created with a single record first, so the
    expected ``total`` is exactly 1; ``size=0`` means no records are returned.
    """
    dataset = Dataset(name="test_failing_metrics", task=TaskType.text_classification)
    mocking_client(monkeypatch, client)

    # Reset the dataset, then log exactly one record into it.
    rubrix.delete(dataset.name)
    record = rubrix.TextClassificationRecord(inputs="This is a text, yeah!")
    rubrix.log(record, name=dataset.name)

    results = service.search(
        dataset=dataset,
        query=TextClassificationQuery(),
        sort_config=SortConfig(),
        metrics=["missing-metric"],
        size=0,
        record_type=TextClassificationRecord,
    )

    expected = {
        "metrics": {"missing-metric": {}},
        "records": [],
        "total": 1,
    }
    assert results.dict() == expected

0 comments on commit 9f2adc9

Please sign in to comment.