From 27b04020b9d18e82799d74d1041844220031f112 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 11 Apr 2022 12:05:20 +0200 Subject: [PATCH] feat: build aggregation results splitting in multiple requests --- .../tasks/commons/metrics/model/base.py | 9 ++- .../server/tasks/commons/metrics/service.py | 81 ++++++++++++++----- 2 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/rubrix/server/tasks/commons/metrics/model/base.py b/src/rubrix/server/tasks/commons/metrics/model/base.py index ed92f61ed4..2a8c533166 100644 --- a/src/rubrix/server/tasks/commons/metrics/model/base.py +++ b/src/rubrix/server/tasks/commons/metrics/model/base.py @@ -58,7 +58,9 @@ class ElasticsearchMetric(BaseMetric): A metric summarized by using one or several elasticsearch aggregations """ - def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]: + def aggregation_request( + self, *args, **kwargs + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: """ Configures the summary es aggregation definition """ @@ -291,11 +293,12 @@ def aggregation_request( dataset: Dataset, dao: DatasetRecordsDAO, size: int = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: - return aggregations.custom_fields( + metadata_aggs = aggregations.custom_fields( fields_definitions=dao.get_metadata_schema(dataset), size=size ) + return [{key: value} for key, value in metadata_aggs.items()] def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: data = unflatten_dict(aggregation_result, stop_keys=["metadata"]) diff --git a/src/rubrix/server/tasks/commons/metrics/service.py b/src/rubrix/server/tasks/commons/metrics/service.py index 1cfaa87ed4..bb93a83a43 100644 --- a/src/rubrix/server/tasks/commons/metrics/service.py +++ b/src/rubrix/server/tasks/commons/metrics/service.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from fastapi import Depends from 
rubrix.server.commons.errors import EntityNotFoundError, WrongInputParamError from rubrix.server.datasets.model import BaseDatasetDB -from rubrix.server.tasks.commons import BaseRecord, TaskType +from rubrix.server.tasks.commons import TaskType from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO, dataset_records_dao from rubrix.server.tasks.commons.dao.model import RecordSearch from rubrix.server.tasks.commons.metrics.model.base import ( @@ -136,7 +136,6 @@ def _handle_elasticsearch_metric( query: GenericQuery, ) -> Dict[str, Any]: """ - Parameters ---------- metric: @@ -153,23 +152,65 @@ def _handle_elasticsearch_metric( The metric summary result """ - metric_params = self._filter_metric_params( - metric, {**metric_params, "dataset": dataset, "dao": self.__dao__} + params = self.__compute_metric_params__( + dataset=dataset, metric=metric, query=query, provided_params=metric_params ) - metric_aggregation = metric.aggregation_request(**metric_params) - results = self.__dao__.search_records( - dataset, - size=0, # No records at all - search=RecordSearch( - query=self.__query_builder__(dataset, query=query) if query else None, - aggregations=metric_aggregation, - include_default_aggregations=False, - ), + results = self.__metric_results__( + dataset=dataset, + query=query, + metric_aggregation=metric.aggregation_request(**params), ) return metric.aggregation_result( - results.aggregations.get(metric.id, results.aggregations) + aggregation_result=results.get(metric.id, results) ) + def __compute_metric_params__( + self, + dataset: BaseDatasetDB, + metric: ElasticsearchMetric, + query: Optional[GenericQuery], + provided_params: Dict[str, Any], + ) -> Dict[str, Any]: + + return self._filter_metric_params( + metric=metric, + function=metric.aggregation_request, + metric_params={ + **provided_params, # in case of param conflict, the "dataset" and "dao" entries below override provided metric params + "dataset": dataset, + "dao": self.__dao__, + }, + ) + + def __metric_results__( + self, 
dataset: BaseDatasetDB, + query: Optional[GenericQuery], + metric_aggregation: Union[Dict[str, Any], List[Dict[str, Any]]], + ) -> Dict[str, Any]: + + if not metric_aggregation: + return {} + + if not isinstance(metric_aggregation, list): + metric_aggregation = [metric_aggregation] + + results = {} + for agg in metric_aggregation: + results_ = self.__dao__.search_records( + dataset, + size=0, # No records at all + search=RecordSearch( + query=self.__query_builder__(dataset, query=query) + if query + else None, + aggregations=agg, + include_default_aggregations=False, + ), + ) + results.update(results_.aggregations) + return results + @staticmethod def get_dataset_metrics(dataset: BaseDatasetDB) -> List[BaseMetric]: """ @@ -185,22 +226,22 @@ def get_dataset_metrics(dataset: BaseDatasetDB) -> List[BaseMetric]: @staticmethod def _filter_metric_params( - _metric: ElasticsearchMetric, metric_params: Dict[str, Any] + metric: ElasticsearchMetric, function: Callable, metric_params: Dict[str, Any] ): """ Select from provided metric parameter those who can be applied to given metric Parameters ---------- - _metric: + metric: The target metric metric_params: A dict of metric parameters """ - function = _metric.aggregation_request - if isinstance(_metric, NestedPathElasticsearchMetric): - function = _metric.inner_aggregation + + if isinstance(metric, NestedPathElasticsearchMetric): + function = metric.inner_aggregation return { argument: metric_params[argument]