Skip to content

Commit

Permalink
feat: build aggregation results splitting in multiple requests
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Apr 11, 2022
1 parent 9a411a9 commit 27b0402
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 23 deletions.
9 changes: 6 additions & 3 deletions src/rubrix/server/tasks/commons/metrics/model/base.py
Expand Up @@ -58,7 +58,9 @@ class ElasticsearchMetric(BaseMetric):
A metric summarized by using one or several elasticsearch aggregations
"""

def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]:
def aggregation_request(
self, *args, **kwargs
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
Configures the summary es aggregation definition
"""
Expand Down Expand Up @@ -291,11 +293,12 @@ def aggregation_request(
dataset: Dataset,
dao: DatasetRecordsDAO,
size: int = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:

return aggregations.custom_fields(
metadata_aggs = aggregations.custom_fields(
fields_definitions=dao.get_metadata_schema(dataset), size=size
)
return [{key: value} for key, value in metadata_aggs.items()]

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
data = unflatten_dict(aggregation_result, stop_keys=["metadata"])
Expand Down
81 changes: 61 additions & 20 deletions src/rubrix/server/tasks/commons/metrics/service.py
@@ -1,10 +1,10 @@
from typing import Any, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from fastapi import Depends

from rubrix.server.commons.errors import EntityNotFoundError, WrongInputParamError
from rubrix.server.datasets.model import BaseDatasetDB
from rubrix.server.tasks.commons import BaseRecord, TaskType
from rubrix.server.tasks.commons import TaskType
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO, dataset_records_dao
from rubrix.server.tasks.commons.dao.model import RecordSearch
from rubrix.server.tasks.commons.metrics.model.base import (
Expand Down Expand Up @@ -136,7 +136,6 @@ def _handle_elasticsearch_metric(
query: GenericQuery,
) -> Dict[str, Any]:
"""
Parameters
----------
metric:
Expand All @@ -153,23 +152,65 @@ def _handle_elasticsearch_metric(
The metric summary result
"""
metric_params = self._filter_metric_params(
metric, {**metric_params, "dataset": dataset, "dao": self.__dao__}
params = self.__compute_metric_params__(
dataset=dataset, metric=metric, query=query, provided_params=metric_params
)
metric_aggregation = metric.aggregation_request(**metric_params)
results = self.__dao__.search_records(
dataset,
size=0, # No records at all
search=RecordSearch(
query=self.__query_builder__(dataset, query=query) if query else None,
aggregations=metric_aggregation,
include_default_aggregations=False,
),
results = self.__metric_results__(
dataset=dataset,
query=query,
metric_aggregation=metric.aggregation_request(**params),
)
return metric.aggregation_result(
results.aggregations.get(metric.id, results.aggregations)
aggregation_result=results.get(metric.id, results)
)

def __compute_metric_params__(
    self,
    dataset: BaseDatasetDB,
    metric: ElasticsearchMetric,
    query: Optional[GenericQuery],
    provided_params: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Builds the keyword arguments used to call ``metric.aggregation_request``

    Parameters
    ----------
    dataset:
        The dataset the metric is computed for
    metric:
        The target metric
    query:
        The query provided for the metric computation.
        Currently not used to compute the params
    provided_params:
        The metric parameters provided by the caller

    Returns
    -------
        The result of ``_filter_metric_params`` applied to the provided
        params extended with the ``dataset`` and ``dao`` entries
    """
    aggregation_function = metric.aggregation_request

    candidate_params = {**provided_params}
    # "dataset" and "dao" are set after copying the provided params, so in
    # case of a name conflict the values set here replace the provided ones
    candidate_params["dataset"] = dataset
    candidate_params["dao"] = self.__dao__

    return self._filter_metric_params(
        metric=metric,
        function=aggregation_function,
        metric_params=candidate_params,
    )

def __metric_results__(
    self,
    dataset: BaseDatasetDB,
    query: Optional[GenericQuery],
    metric_aggregation: Union[Dict[str, Any], List[Dict[str, Any]]],
) -> Dict[str, Any]:
    """
    Runs the metric aggregation(s) for the dataset and merges the results

    Parameters
    ----------
    dataset:
        The dataset whose records are aggregated
    query:
        An optional query. When provided, it is converted with
        ``self.__query_builder__`` and sent with every request
    metric_aggregation:
        A single aggregation definition or a list of them.
        Each item of a list is sent in its own search request

    Returns
    -------
        A dict merging the aggregations returned by every request.
        An empty dict when no aggregation is provided
    """
    if not metric_aggregation:
        return {}

    if not isinstance(metric_aggregation, list):
        metric_aggregation = [metric_aggregation]

    # The query is the same for every request: build it once, outside the loop
    built_query = self.__query_builder__(dataset, query=query) if query else None

    results = {}
    for agg in metric_aggregation:
        response = self.__dao__.search_records(
            dataset,
            size=0,  # No records at all
            search=RecordSearch(
                query=built_query,
                aggregations=agg,
                include_default_aggregations=False,
            ),
        )
        # On key collision, a later response overwrites an earlier one
        results.update(response.aggregations)
    return results

@staticmethod
def get_dataset_metrics(dataset: BaseDatasetDB) -> List[BaseMetric]:
"""
Expand All @@ -185,22 +226,22 @@ def get_dataset_metrics(dataset: BaseDatasetDB) -> List[BaseMetric]:

@staticmethod
def _filter_metric_params(
_metric: ElasticsearchMetric, metric_params: Dict[str, Any]
metric: ElasticsearchMetric, function: Callable, metric_params: Dict[str, Any]
):
"""
Select from provided metric parameter those who can be applied to given metric
Parameters
----------
_metric:
metric:
The target metric
metric_params:
A dict of metric parameters
"""
function = _metric.aggregation_request
if isinstance(_metric, NestedPathElasticsearchMetric):
function = _metric.inner_aggregation

if isinstance(metric, NestedPathElasticsearchMetric):
function = metric.inner_aggregation

return {
argument: metric_params[argument]
Expand Down

0 comments on commit 27b0402

Please sign in to comment.