Skip to content

Commit

Permalink
refactor: move all elasticsearch metrics to the elasticsearch layer
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Jul 29, 2022
1 parent 83a5965 commit 8bfcad7
Show file tree
Hide file tree
Showing 23 changed files with 445 additions and 1,091 deletions.
14 changes: 9 additions & 5 deletions src/rubrix/server/apis/v0/config/tasks_factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any, Dict, List, Optional, Set, Type
from typing import Any, Dict, List, Optional, Set, Type, Union

from pydantic import BaseModel

from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType
from rubrix.server.apis.v0.models.datasets import DatasetDB
from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics
from rubrix.server.apis.v0.models.metrics.base import (
BaseTaskMetrics,
Metric,
PythonMetric,
)
from rubrix.server.apis.v0.models.metrics.text_classification import (
TextClassificationMetrics,
)
Expand Down Expand Up @@ -106,16 +110,16 @@ def __get_task_config__(cls, task):
return config

@classmethod
def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[BaseMetric]:
def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[Metric]:
metrics = cls.find_task_metrics(task, {metric_id})
if metrics:
return metrics[0]
raise EntityNotFoundError(name=metric_id, type=BaseMetric)
raise EntityNotFoundError(name=metric_id, type=Metric)

@classmethod
def find_task_metrics(
cls, task: TaskType, metric_ids: Set[str]
) -> List[BaseMetric]:
) -> List[Union[Metric]]:

if not metric_ids:
return []
Expand Down
1 change: 0 additions & 1 deletion src/rubrix/server/apis/v0/handlers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_dataset_metrics(
teams_query: CommonTaskQueryParams = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
metrics: MetricsService = Depends(MetricsService.get_instance),
) -> List[MetricInfo]:
"""
List available metrics info for a given dataset
Expand Down
1 change: 1 addition & 0 deletions src/rubrix/server/apis/v0/models/commons/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class PaginationParams:
)


# TODO(@frascuchon): Move this class to the server.commons.models module
class BaseRecord(BaseRecordDB, GenericModel, Generic[Annotation]):
"""
Minimal dataset record information
Expand Down
255 changes: 13 additions & 242 deletions src/rubrix/server/apis/v0/models/metrics/base.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,18 @@
from typing import (
Any,
ClassVar,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
Union,
)
from typing import Any, ClassVar, Dict, Generic, List, Optional, Union

from pydantic import BaseModel, root_validator
from pydantic import BaseModel

from rubrix.server._helpers import unflatten_dict
from rubrix.server.apis.v0.models.commons.model import BaseRecord
from rubrix.server.apis.v0.models.datasets import Dataset
from rubrix.server.daos.records import DatasetRecordsDAO
from rubrix.server.elasticseach.query_helpers import aggregations
from rubrix.server.services.metrics import BaseMetric as _BaseMetric
from rubrix.server.services.metrics import GenericRecord
from rubrix.server.services.metrics import PythonMetric as _PythonMetric

GenericRecord = TypeVar("GenericRecord", bound=BaseRecord)

class Metric(_BaseMetric):
pass

class BaseMetric(BaseModel):
"""
Base model for rubrix dataset metrics summaries
"""

id: str
name: str
description: str = None


class PythonMetric(BaseMetric, Generic[GenericRecord]):
"""
A metric definition which will be calculated using raw queried data
"""

def apply(self, records: Iterable[GenericRecord]) -> Dict[str, Any]:
"""
Metric calculation method.
Parameters
----------
records:
The matched records
Returns
-------
The metric result
"""
raise NotImplementedError()


class ElasticsearchMetric(BaseMetric):
"""
A metric summarized by using one or several elasticsearch aggregations
"""

def aggregation_request(
self, *args, **kwargs
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
Configures the summary es aggregation definition
"""
raise NotImplementedError()

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse the es aggregation result. Override this method
for result customization
Parameters
----------
aggregation_result:
Retrieved es aggregation result
"""
return aggregation_result.get(self.id, aggregation_result)


class NestedPathElasticsearchMetric(ElasticsearchMetric):
"""
A ``ElasticsearchMetric`` which need nested fields for summary calculation.
Aggregations for nested fields need some extra configuration and this class
encapsulate these common logic.
Attributes:
-----------
nested_path:
The nested
"""

nested_path: str

def inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]:
"""The specific aggregation definition"""
raise NotImplementedError()

def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]:
"""Implements the common mechanism to define aggregations with nested fields"""
return {
self.id: aggregations.nested_aggregation(
nested_path=self.nested_path,
inner_aggregation=self.inner_aggregation(*args, **kwargs),
)
}

def compound_nested_field(self, inner_field: str) -> str:
return f"{self.nested_path}.{inner_field}"
class PythonMetric(Metric, _PythonMetric, Generic[GenericRecord]):
pass


class BaseTaskMetrics(BaseModel):
Expand All @@ -122,19 +26,10 @@ class BaseTaskMetrics(BaseModel):
A list of configured metrics for task
"""

metrics: ClassVar[List[BaseMetric]]

@classmethod
def configure_es_index(cls):
"""
If some metrics require specific es field mapping definitions,
include them here.
"""
pass
metrics: ClassVar[List[Union[PythonMetric, str]]]

@classmethod
def find_metric(cls, id: str) -> Optional[BaseMetric]:
def find_metric(cls, id: str) -> Optional[Union[PythonMetric, str]]:
"""
Finds a metric by id
Expand All @@ -149,6 +44,8 @@ def find_metric(cls, id: str) -> Optional[BaseMetric]:
"""
for metric in cls.metrics:
if isinstance(metric, str) and metric == id:
return metric
if metric.id == id:
return metric

Expand Down Expand Up @@ -176,129 +73,3 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
A dict with calculated metrics fields
"""
return {}


class HistogramAggregation(ElasticsearchMetric):
"""
Base elasticsearch histogram aggregation metric
Attributes
----------
field:
The histogram field
script:
If provided, it will be used as scripted field
for aggregation
fixed_interval:
If provided, it will used ALWAYS as the histogram
aggregation interval
"""

field: str
script: Optional[Union[str, Dict[str, Any]]] = None
fixed_interval: Optional[float] = None

def aggregation_request(self, interval: Optional[float] = None) -> Dict[str, Any]:
if self.fixed_interval:
interval = self.fixed_interval
return {
self.id: aggregations.histogram_aggregation(
field_name=self.field, script=self.script, interval=interval
)
}


class TermsAggregation(ElasticsearchMetric):
"""
The base elasticsearch terms aggregation metric
Attributes
----------
field:
The term field
script:
If provided, it will be used as scripted field
for aggregation
fixed_size:
If provided, the size will use for terms aggregation
missing:
If provided, will use the value for docs results with missing value for field
"""

field: str = None
script: Union[str, Dict[str, Any]] = None
fixed_size: Optional[int] = None
missing: Optional[str] = None

def aggregation_request(self, size: int = None) -> Dict[str, Any]:
if self.fixed_size:
size = self.fixed_size
return {
self.id: aggregations.terms_aggregation(
self.field, script=self.script, size=size, missing=self.missing
)
}


class NestedTermsAggregation(NestedPathElasticsearchMetric):
terms: TermsAggregation

@root_validator
def normalize_terms_field(cls, values):
terms = values["terms"]
nested_path = values["nested_path"]
terms.field = f"{nested_path}.{terms.field}"

return values

def inner_aggregation(self, size: int) -> Dict[str, Any]:
return self.terms.aggregation_request(size)


class NestedHistogramAggregation(NestedPathElasticsearchMetric):
histogram: HistogramAggregation

@root_validator
def normalize_terms_field(cls, values):
histogram = values["histogram"]
nested_path = values["nested_path"]
histogram.field = f"{nested_path}.{histogram.field}"

return values

def inner_aggregation(self, interval: float) -> Dict[str, Any]:
return self.histogram.aggregation_request(interval)


class WordCloudAggregation(ElasticsearchMetric):
default_field: str

def aggregation_request(
self, text_field: str = None, size: int = None
) -> Dict[str, Any]:
field = text_field or self.default_field
return TermsAggregation(
id=f"{self.id}_{field}" if text_field else self.id,
name=f"Words cloud for field {field}",
field=field,
).aggregation_request(size=size)


class MetadataAggregations(ElasticsearchMetric):
def aggregation_request(
self,
dataset: Dataset,
dao: DatasetRecordsDAO,
size: int = None,
) -> List[Dict[str, Any]]:

metadata_aggs = aggregations.custom_fields(
fields_definitions=dao.get_metadata_schema(dataset), size=size
)
return [{key: value} for key, value in metadata_aggs.items()]

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
data = unflatten_dict(aggregation_result, stop_keys=["metadata"])
return data.get("metadata", {})

0 comments on commit 8bfcad7

Please sign in to comment.