From 8bfcad74198c1b5398a167343ccc43d65c2c4007 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 29 Jul 2022 11:45:50 +0200 Subject: [PATCH] refactor: move all elasticsearch metrics to the elasticsearch layer --- .../server/apis/v0/config/tasks_factory.py | 14 +- src/rubrix/server/apis/v0/handlers/metrics.py | 1 - .../server/apis/v0/models/commons/model.py | 1 + .../server/apis/v0/models/metrics/base.py | 255 +--------- .../server/apis/v0/models/metrics/commons.py | 39 +- .../v0/models/metrics/text_classification.py | 59 ++- .../v0/models/metrics/token_classification.py | 480 +++++------------- src/rubrix/server/daos/datasets.py | 13 +- src/rubrix/server/daos/records.py | 56 +- src/rubrix/server/elasticseach/backend.py | 72 ++- .../server/elasticseach/query_helpers.py | 31 +- .../server/elasticseach/search/model.py | 13 + .../elasticseach/search/query_builder.py | 1 - src/rubrix/server/services/metrics.py | 201 -------- src/rubrix/server/services/metrics/service.py | 92 ++++ src/rubrix/server/services/search/model.py | 12 +- src/rubrix/server/services/search/service.py | 3 +- src/rubrix/server/services/text2text.py | 3 +- .../server/services/text_classification.py | 3 +- .../text_classification_labelling_rules.py | 176 +------ .../server/services/token_classification.py | 3 +- .../search/test_search_service.py | 4 +- tests/server/metrics/test_api.py | 4 +- 23 files changed, 445 insertions(+), 1091 deletions(-) delete mode 100644 src/rubrix/server/services/metrics.py create mode 100644 src/rubrix/server/services/metrics/service.py diff --git a/src/rubrix/server/apis/v0/config/tasks_factory.py b/src/rubrix/server/apis/v0/config/tasks_factory.py index 5ad982bf4f..c575a251b7 100644 --- a/src/rubrix/server/apis/v0/config/tasks_factory.py +++ b/src/rubrix/server/apis/v0/config/tasks_factory.py @@ -1,10 +1,14 @@ -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Dict, List, Optional, Set, Type, Union from pydantic import BaseModel from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType from rubrix.server.apis.v0.models.datasets import DatasetDB -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics +from rubrix.server.apis.v0.models.metrics.base import ( + BaseTaskMetrics, + Metric, + PythonMetric, +) from rubrix.server.apis.v0.models.metrics.text_classification import ( TextClassificationMetrics, ) @@ -106,16 +110,16 @@ def __get_task_config__(cls, task): return config @classmethod - def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[BaseMetric]: + def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[Metric]: metrics = cls.find_task_metrics(task, {metric_id}) if metrics: return metrics[0] - raise EntityNotFoundError(name=metric_id, type=BaseMetric) + raise EntityNotFoundError(name=metric_id, type=Metric) @classmethod def find_task_metrics( cls, task: TaskType, metric_ids: Set[str] - ) -> List[BaseMetric]: + ) -> List[Union[Metric]]: if not metric_ids: return [] diff --git a/src/rubrix/server/apis/v0/handlers/metrics.py b/src/rubrix/server/apis/v0/handlers/metrics.py index 8f4040add8..5064ae67f5 100644 --- a/src/rubrix/server/apis/v0/handlers/metrics.py +++ b/src/rubrix/server/apis/v0/handlers/metrics.py @@ -94,7 +94,6 @@ def get_dataset_metrics( teams_query: CommonTaskQueryParams = Depends(), current_user: User = Security(auth.get_user, scopes=[]), datasets: DatasetsService = Depends(DatasetsService.get_instance), - metrics: MetricsService = Depends(MetricsService.get_instance), ) -> List[MetricInfo]: """ List available metrics info for a given dataset diff --git a/src/rubrix/server/apis/v0/models/commons/model.py b/src/rubrix/server/apis/v0/models/commons/model.py index e8443fc347..3703f3a7e2 100644 --- a/src/rubrix/server/apis/v0/models/commons/model.py +++ b/src/rubrix/server/apis/v0/models/commons/model.py @@ -55,6 +55,7 @@ class PaginationParams: ) +# TODO(@frascuchon): Move this shit to the server.commons.models module class BaseRecord(BaseRecordDB, GenericModel, Generic[Annotation]): """ Minimal dataset record information diff --git a/src/rubrix/server/apis/v0/models/metrics/base.py b/src/rubrix/server/apis/v0/models/metrics/base.py index 8bae2ab238..0e42dda5fa 100644 --- a/src/rubrix/server/apis/v0/models/metrics/base.py +++ b/src/rubrix/server/apis/v0/models/metrics/base.py @@ -1,114 +1,18 @@ -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Iterable, - List, - Optional, - TypeVar, - Union, -) +from typing import Any, ClassVar, Dict, Generic, List, Optional, Union -from pydantic import BaseModel, root_validator +from pydantic import BaseModel -from rubrix.server._helpers import unflatten_dict -from rubrix.server.apis.v0.models.commons.model import BaseRecord -from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import aggregations +from rubrix.server.services.metrics import BaseMetric as _BaseMetric +from rubrix.server.services.metrics import GenericRecord +from rubrix.server.services.metrics import PythonMetric as _PythonMetric -GenericRecord = TypeVar("GenericRecord", bound=BaseRecord) +class Metric(_BaseMetric): + pass -class BaseMetric(BaseModel): - """ - Base model for rubrix dataset metrics summaries - """ - - id: str - name: str - description: str = None - - -class PythonMetric(BaseMetric, Generic[GenericRecord]): - """ - A metric definition which will be calculated using raw queried data - """ - - def apply(self, records: Iterable[GenericRecord]) -> Dict[str, Any]: - """ - Metric calculation method. - - Parameters - ---------- - records: - The matched records - - Returns - ------- - The metric result - """ - raise NotImplementedError() - - -class ElasticsearchMetric(BaseMetric): - """ - A metric summarized by using one or several elasticsearch aggregations - """ - - def aggregation_request( - self, *args, **kwargs - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - """ - Configures the summary es aggregation definition - """ - raise NotImplementedError() - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - """ - Parse the es aggregation result. Override this method - for result customization - - Parameters - ---------- - aggregation_result: - Retrieved es aggregation result - - """ - return aggregation_result.get(self.id, aggregation_result) - - -class NestedPathElasticsearchMetric(ElasticsearchMetric): - """ - A ``ElasticsearchMetric`` which need nested fields for summary calculation. - - Aggregations for nested fields need some extra configuration and this class - encapsulate these common logic. - - Attributes: - ----------- - nested_path: - The nested - """ - - nested_path: str - - def inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]: - """The specific aggregation definition""" - raise NotImplementedError() - def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]: - """Implements the common mechanism to define aggregations with nested fields""" - return { - self.id: aggregations.nested_aggregation( - nested_path=self.nested_path, - inner_aggregation=self.inner_aggregation(*args, **kwargs), - ) - } - - def compound_nested_field(self, inner_field: str) -> str: - return f"{self.nested_path}.{inner_field}" +class PythonMetric(Metric, _PythonMetric, Generic[GenericRecord]): + pass class BaseTaskMetrics(BaseModel): @@ -122,19 +26,10 @@ class BaseTaskMetrics(BaseModel): A list of configured metrics for task """ - metrics: ClassVar[List[BaseMetric]] - - @classmethod - def configure_es_index(cls): - """ - If some metrics require specific es field mapping definitions, - include them here. - - """ - pass + metrics: ClassVar[List[Union[PythonMetric, str]]] @classmethod - def find_metric(cls, id: str) -> Optional[BaseMetric]: + def find_metric(cls, id: str) -> Optional[Union[PythonMetric, str]]: """ Finds a metric by id @@ -149,6 +44,8 @@ def find_metric(cls, id: str) -> Optional[BaseMetric]: """ for metric in cls.metrics: + if isinstance(metric, str) and metric == id: + return metric if metric.id == id: return metric @@ -176,129 +73,3 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]: A dict with calculated metrics fields """ return {} - - -class HistogramAggregation(ElasticsearchMetric): - """ - Base elasticsearch histogram aggregation metric - - Attributes - ---------- - field: - The histogram field - script: - If provided, it will be used as scripted field - for aggregation - fixed_interval: - If provided, it will used ALWAYS as the histogram - aggregation interval - """ - - field: str - script: Optional[Union[str, Dict[str, Any]]] = None - fixed_interval: Optional[float] = None - - def aggregation_request(self, interval: Optional[float] = None) -> Dict[str, Any]: - if self.fixed_interval: - interval = self.fixed_interval - return { - self.id: aggregations.histogram_aggregation( - field_name=self.field, script=self.script, interval=interval - ) - } - - -class TermsAggregation(ElasticsearchMetric): - """ - The base elasticsearch terms aggregation metric - - Attributes - ---------- - - field: - The term field - script: - If provided, it will be used as scripted field - for aggregation - fixed_size: - If provided, the size will use for terms aggregation - missing: - If provided, will use the value for docs results with missing value for field - - """ - - field: str = None - script: Union[str, Dict[str, Any]] = None - fixed_size: Optional[int] = None - missing: Optional[str] = None - - def aggregation_request(self, size: int = None) -> Dict[str, Any]: - if self.fixed_size: - size = self.fixed_size - return { - self.id: aggregations.terms_aggregation( - self.field, script=self.script, size=size, missing=self.missing - ) - } - - -class NestedTermsAggregation(NestedPathElasticsearchMetric): - terms: TermsAggregation - - @root_validator - def normalize_terms_field(cls, values): - terms = values["terms"] - nested_path = values["nested_path"] - terms.field = f"{nested_path}.{terms.field}" - - return values - - def inner_aggregation(self, size: int) -> Dict[str, Any]: - return self.terms.aggregation_request(size) - - -class NestedHistogramAggregation(NestedPathElasticsearchMetric): - histogram: HistogramAggregation - - @root_validator - def normalize_terms_field(cls, values): - histogram = values["histogram"] - nested_path = values["nested_path"] - histogram.field = f"{nested_path}.{histogram.field}" - - return values - - def inner_aggregation(self, interval: float) -> Dict[str, Any]: - return self.histogram.aggregation_request(interval) - - -class WordCloudAggregation(ElasticsearchMetric): - default_field: str - - def aggregation_request( - self, text_field: str = None, size: int = None - ) -> Dict[str, Any]: - field = text_field or self.default_field - return TermsAggregation( - id=f"{self.id}_{field}" if text_field else self.id, - name=f"Words cloud for field {field}", - field=field, - ).aggregation_request(size=size) - - -class MetadataAggregations(ElasticsearchMetric): - def aggregation_request( - self, - dataset: Dataset, - dao: DatasetRecordsDAO, - size: int = None, - ) -> List[Dict[str, Any]]: - - metadata_aggs = aggregations.custom_fields( - fields_definitions=dao.get_metadata_schema(dataset), size=size - ) - return [{key: value} for key, value in metadata_aggs.items()] - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - data = unflatten_dict(aggregation_result, stop_keys=["metadata"]) - return data.get("metadata", {}) diff --git a/src/rubrix/server/apis/v0/models/metrics/commons.py b/src/rubrix/server/apis/v0/models/metrics/commons.py index fbe3b4474f..827d51b75e 100644 --- a/src/rubrix/server/apis/v0/models/metrics/commons.py +++ b/src/rubrix/server/apis/v0/models/metrics/commons.py @@ -1,16 +1,10 @@ from typing import Any, ClassVar, Dict, Generic, List -from rubrix.server.apis.v0.models.commons.model import EsRecordDataFieldNames from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, BaseTaskMetrics, GenericRecord, - HistogramAggregation, - MetadataAggregations, - TermsAggregation, - WordCloudAggregation, + Metric, ) -from rubrix.server.commons.models import TaskStatus class CommonTasksMetrics(BaseTaskMetrics, Generic[GenericRecord]): @@ -21,53 +15,40 @@ def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]: """Record metrics will persist the text_length""" return {"text_length": len(record.all_text())} - metrics: ClassVar[List[BaseMetric]] = [ - HistogramAggregation( + metrics: ClassVar[List[Metric]] = [ + Metric( id="text_length", name="Text length distribution", description="Computes the input text length distribution", - field="metrics.text_length", - script="params._source.text.length()", - fixed_interval=1, ), - TermsAggregation( + Metric( id="error_distribution", name="Error distribution", description="Computes the dataset error distribution. It's mean, records " "with correct predictions vs records with incorrect prediction " "vs records with unknown prediction result", - field=EsRecordDataFieldNames.predicted, - missing="unknown", - fixed_size=3, ), - TermsAggregation( + Metric( id="status_distribution", name="Record status distribution", description="The dataset record status distribution", - field=EsRecordDataFieldNames.status, - fixed_size=len(TaskStatus), ), - WordCloudAggregation( + Metric( id="words_cloud", name="Inputs words cloud", description="The words cloud for dataset inputs", - default_field="text.wordcloud", ), - MetadataAggregations(id="metadata", name="Metadata fields stats"), - TermsAggregation( + Metric(id="metadata", name="Metadata fields stats"), + Metric( id="predicted_by", name="Predicted by distribution", - field="predicted_by", ), - TermsAggregation( + Metric( id="annotated_by", name="Annotated by distribution", - field="annotated_by", ), - HistogramAggregation( + Metric( id="score", name="Score record distribution", - field="score", - fixed_interval=0.001, ), ] diff --git a/src/rubrix/server/apis/v0/models/metrics/text_classification.py b/src/rubrix/server/apis/v0/models/metrics/text_classification.py index 266fae6ce5..b15a90fa2a 100644 --- a/src/rubrix/server/apis/v0/models/metrics/text_classification.py +++ b/src/rubrix/server/apis/v0/models/metrics/text_classification.py @@ -1,14 +1,10 @@ -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set +from typing import Any, ClassVar, Dict, Iterable, List, Union from pydantic import Field from sklearn.metrics import precision_recall_fscore_support from sklearn.preprocessing import MultiLabelBinarizer -from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, - PythonMetric, - TermsAggregation, -) +from rubrix.server.apis.v0.models.metrics.base import Metric, PythonMetric from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics from rubrix.server.apis.v0.models.text_classification import TextClassificationRecord @@ -120,27 +116,30 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]: class TextClassificationMetrics(CommonTasksMetrics[TextClassificationRecord]): """Configured metrics for text classification task""" - metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [ - TermsAggregation( - id="predicted_as", - name="Predicted labels distribution", - field="predicted_as", - ), - TermsAggregation( - id="annotated_as", - name="Annotated labels distribution", - field="annotated_as", - ), - F1Metric( - id="F1", - name="F1 Metrics for single-label", - description="F1 Metrics for single-label (averaged and per label)", - ), - F1Metric( - id="MultiLabelF1", - name="F1 Metrics for multi-label", - description="F1 Metrics for multi-label (averaged and per label)", - multi_label=True, - ), - DatasetLabels(), - ] + metrics: ClassVar[List[Union[PythonMetric, str]]] = ( + CommonTasksMetrics.metrics + + [ + F1Metric( + id="F1", + name="F1 Metrics for single-label", + description="F1 Metrics for single-label (averaged and per label)", + ), + F1Metric( + id="MultiLabelF1", + name="F1 Metrics for multi-label", + description="F1 Metrics for multi-label (averaged and per label)", + multi_label=True, + ), + DatasetLabels(), + ] + + [ + Metric( + id="predicted_as", + name="Predicted labels distribution", + ), + Metric( + id="annotated_as", + name="Annotated labels distribution", + ), + ] + ) diff --git a/src/rubrix/server/apis/v0/models/metrics/token_classification.py b/src/rubrix/server/apis/v0/models/metrics/token_classification.py index 80c9abdee6..6a728ea7d9 100644 --- a/src/rubrix/server/apis/v0/models/metrics/token_classification.py +++ b/src/rubrix/server/apis/v0/models/metrics/token_classification.py @@ -2,178 +2,12 @@ from pydantic import BaseModel, Field -from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, - ElasticsearchMetric, - HistogramAggregation, - NestedHistogramAggregation, - NestedPathElasticsearchMetric, - NestedTermsAggregation, - PythonMetric, - TermsAggregation, -) -from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics +from rubrix.server.apis.v0.models.metrics.base import PythonMetric +from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics, Metric from rubrix.server.apis.v0.models.token_classification import ( EntitySpan, TokenClassificationRecord, ) -from rubrix.server.elasticseach.query_helpers import aggregations - - -class TokensLength(ElasticsearchMetric): - """ - Summarizes the tokens length metric into an histogram - - Attributes: - ----------- - length_field: - The elasticsearch field where tokens length is stored - """ - - length_field: str - - def aggregation_request(self, interval: int) -> Dict[str, Any]: - return { - self.id: aggregations.histogram_aggregation( - self.length_field, interval=interval or 1 - ) - } - - -_DEFAULT_MAX_ENTITY_BUCKET = 1000 - - -class EntityLabels(NestedPathElasticsearchMetric): - """ - Computes the entity labels distribution - - Attributes: - ----------- - labels_field: - The elasticsearch field where tags are stored - """ - - labels_field: str - - def inner_aggregation(self, size: int) -> Dict[str, Any]: - return { - "labels": aggregations.terms_aggregation( - self.compound_nested_field(self.labels_field), - size=size or _DEFAULT_MAX_ENTITY_BUCKET, - ) - } - - -class EntityDensity(NestedPathElasticsearchMetric): - """Summarizes the entity density metric into an histogram""" - - density_field: str - - def inner_aggregation(self, interval: float) -> Dict[str, Any]: - return { - "density": aggregations.histogram_aggregation( - field_name=self.compound_nested_field(self.density_field), - interval=interval or 0.01, - ) - } - - -class MentionLength(NestedPathElasticsearchMetric): - """Summarizes the mention length into an histogram""" - - length_field: str - - def inner_aggregation(self, interval: int) -> Dict[str, Any]: - return { - "mention_length": aggregations.histogram_aggregation( - self.compound_nested_field(self.length_field), interval=interval or 1 - ) - } - - -class EntityCapitalness(NestedPathElasticsearchMetric): - """Computes the mention capitalness distribution""" - - capitalness_field: str - - def inner_aggregation(self) -> Dict[str, Any]: - return { - "capitalness": aggregations.terms_aggregation( - self.compound_nested_field(self.capitalness_field), - size=4, # The number of capitalness choices - ) - } - - -class MentionsByEntityDistribution(NestedPathElasticsearchMetric): - def inner_aggregation(self): - return { - self.id: aggregations.bidimentional_terms_aggregations( - field_name_x=f"{self.nested_path}.label", - field_name_y=f"{self.nested_path}.value", - ) - } - - -class EntityConsistency(NestedPathElasticsearchMetric): - """Computes the entity consistency distribution""" - - mention_field: str - labels_field: str - - def inner_aggregation( - self, - size: int, - interval: int = 2, - entity_size: int = _DEFAULT_MAX_ENTITY_BUCKET, - ) -> Dict[str, Any]: - size = size or 50 - interval = int(max(interval or 2, 2)) - return { - "consistency": { - **aggregations.terms_aggregation( - self.compound_nested_field(self.mention_field), size=size - ), - "aggs": { - "entities": aggregations.terms_aggregation( - self.compound_nested_field(self.labels_field), size=entity_size - ), - "count": { - "cardinality": { - "field": self.compound_nested_field(self.labels_field) - } - }, - "entities_variability_filter": { - "bucket_selector": { - "buckets_path": {"numLabels": "count"}, - "script": f"params.numLabels >= {interval}", - } - }, - "sortby_entities_count": { - "bucket_sort": { - "sort": [{"count": {"order": "desc"}}], - "size": size, - } - }, - }, - } - } - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - """Simplifies the aggregation result sorting by worst mention consistency""" - result = [ - { - "mention": mention, - "entities": [ - {"label": entity, "count": count} - for entity, count in mention_aggs["entities"].items() - ], - } - for mention, mention_aggs in aggregation_result.items() - ] - # TODO: filter by entities threshold - result.sort(key=lambda m: len(m["entities"]), reverse=True) - return {"mentions": result} class F1Metric(PythonMetric[TokenClassificationRecord]): @@ -340,17 +174,6 @@ class TokenMetrics(BaseModel): class TokenClassificationMetrics(CommonTasksMetrics[TokenClassificationRecord]): """Configured metrics for token classification""" - _PREDICTED_NAMESPACE = "metrics.predicted" - _ANNOTATED_NAMESPACE = "metrics.annotated" - - _PREDICTED_MENTIONS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.mentions" - _ANNOTATED_MENTIONS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.mentions" - - _PREDICTED_TAGS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.tags" - _ANNOTATED_TAGS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.tags" - - _TOKENS_NAMESPACE = "metrics.tokens" - @staticmethod def density(value: int, sentence_length: int) -> float: """Compute the string density over a sentence""" @@ -457,187 +280,124 @@ def record_metrics(cls, record: TokenClassificationRecord) -> Dict[str, Any]: }, } - _TOKENS_METRICS = [ - TokensLength( - id="tokens_length", - name="Tokens length", - description="Computes the text length distribution measured in number of tokens", - length_field="metrics.tokens_length", - ), - NestedTermsAggregation( - id="token_frequency", - name="Tokens frequency distribution", - nested_path=_TOKENS_NAMESPACE, - terms=TermsAggregation( - id="frequency", - field="value", - name="", + metrics: ClassVar[List[Metric]] = ( + CommonTasksMetrics.metrics + + [ + DatasetLabels(), + F1Metric( + id="F1", + name="F1 Metric based on entity-level", + description="F1 metrics based on entity-level (averaged and per label), " + "where only exact matches count (CoNNL 2003).", ), - ), - NestedHistogramAggregation( - id="token_length", - name="Token length distribution", - nested_path=_TOKENS_NAMESPACE, - description="Computes token length distribution in number of characters", - histogram=HistogramAggregation( - id="length", - field="length", - name="", - fixed_interval=1, + ] + + [ + Metric( + id="predicted_as", + name="Predicted labels distribution", + ), + Metric( + id="annotated_as", + name="Annotated labels distribution", + ), + Metric( + id="tokens_length", + name="Tokens length", + description="Computes the text length distribution measured in number of tokens", + ), + Metric( + id="token_frequency", + name="Tokens frequency distribution", + ), + Metric( + id="token_length", + name="Token length distribution", + description="Computes token length distribution in number of characters", + ), + Metric( + id="token_capitalness", + name="Token capitalness distribution", + description="Computes capitalization information of tokens", + ), + Metric( + id="predicted_entity_density", + name="Mention entity density for predictions", + description="Computes the ratio between the number of all entity tokens and tokens in the text", + ), + Metric( + id="predicted_entity_labels", + name="Predicted entity labels", + description="Predicted entity labels distribution", + ), + Metric( + id="predicted_entity_capitalness", + name="Mention entity capitalness for predictions", + description="Computes capitalization information of predicted entity mentions", + ), + Metric( + id="predicted_mention_token_length", + name="Predicted mention tokens length", + description="Computes the length of the predicted entity mention measured in number of tokens", + ), + Metric( + id="predicted_mention_char_length", + name="Predicted mention characters length", + description="Computes the length of the predicted entity mention measured in number of tokens", ), - ), - NestedTermsAggregation( - id="token_capitalness", - name="Token capitalness distribution", - description="Computes capitalization information of tokens", - nested_path=_TOKENS_NAMESPACE, - terms=TermsAggregation( - id="capitalness", - field="capitalness", - name="", - # missing="OTHER", + Metric( + id="predicted_mentions_distribution", + name="Predicted mentions distribution by entity", + description="Computes predicted mentions distribution against its labels", ), - ), - ] - _PREDICTED_METRICS = [ - EntityDensity( - id="predicted_entity_density", - name="Mention entity density for predictions", - description="Computes the ratio between the number of all entity tokens and tokens in the text", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - density_field="density", - ), - EntityLabels( - id="predicted_entity_labels", - name="Predicted entity labels", - description="Predicted entity labels distribution", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - labels_field="label", - ), - EntityCapitalness( - id="predicted_entity_capitalness", - name="Mention entity capitalness for predictions", - description="Computes capitalization information of predicted entity mentions", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - capitalness_field="capitalness", - ), - MentionLength( - id="predicted_mention_token_length", - name="Predicted mention tokens length", - description="Computes the length of the predicted entity mention measured in number of tokens", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - length_field="tokens_length", - ), - MentionLength( - id="predicted_mention_char_length", - name="Predicted mention characters length", - description="Computes the length of the predicted entity mention measured in number of tokens", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - length_field="chars_length", - ), - MentionsByEntityDistribution( - id="predicted_mentions_distribution", - name="Predicted mentions distribution by entity", - description="Computes predicted mentions distribution against its labels", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - ), - EntityConsistency( - id="predicted_entity_consistency", - name="Entity label consistency for predictions", - description="Computes entity label variability for top-k predicted entity mentions", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - mention_field="value", - labels_field="label", - ), - EntityConsistency( - id="predicted_tag_consistency", - name="Token tag consistency for predictions", - description="Computes token tag variability for top-k predicted tags", - nested_path=_PREDICTED_TAGS_NAMESPACE, - mention_field="value", - labels_field="tag", - ), - ] - - _ANNOTATED_METRICS = [ - EntityDensity( - id="annotated_entity_density", - name="Mention entity density for annotations", - description="Computes the ratio between the number of all entity tokens and tokens in the text", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - density_field="density", - ), - EntityLabels( - id="annotated_entity_labels", - name="Annotated entity labels", - description="Annotated Entity labels distribution", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - labels_field="label", - ), - EntityCapitalness( - id="annotated_entity_capitalness", - name="Mention entity capitalness for annotations", - description="Compute capitalization information of annotated entity mentions", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - capitalness_field="capitalness", - ), - MentionLength( - id="annotated_mention_token_length", - name="Annotated mention tokens length", - description="Computes the length of the entity mention measured in number of tokens", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - length_field="tokens_length", - ), - MentionLength( - id="annotated_mention_char_length", - name="Annotated mention characters length", - description="Computes the length of the entity mention measured in number of tokens", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - length_field="chars_length", - ), - MentionsByEntityDistribution( - id="annotated_mentions_distribution", - name="Annotated mentions distribution by entity", - description="Computes annotated mentions distribution against its labels", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - ), - EntityConsistency( - id="annotated_entity_consistency", - name="Entity label consistency for annotations", - description="Computes entity label variability for top-k annotated entity mentions", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - mention_field="value", - labels_field="label", - ), - EntityConsistency( - id="annotated_tag_consistency", - name="Token tag consistency for annotations", - description="Computes token tag variability for top-k annotated tags", - nested_path=_ANNOTATED_TAGS_NAMESPACE, - mention_field="value", - labels_field="tag", - ), - ] - - metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [ - TermsAggregation( - id="predicted_as", - name="Predicted labels distribution", - field="predicted_as", - ), - TermsAggregation( - id="annotated_as", - name="Annotated labels distribution", - field="annotated_as", - ), - *_TOKENS_METRICS, - *_PREDICTED_METRICS, - *_ANNOTATED_METRICS, - DatasetLabels(), - F1Metric( - id="F1", - name="F1 Metric based on entity-level", - description="F1 metrics based on entity-level (averaged and per label), " - "where only exact matches count (CoNNL 2003).", - ), - ] + Metric( + id="predicted_entity_consistency", + name="Entity label consistency for predictions", + description="Computes entity label variability for top-k predicted entity mentions", + ), + Metric( + id="predicted_tag_consistency", + name="Token tag consistency for predictions", + description="Computes token tag variability for top-k predicted tags", + ), + Metric( + id="annotated_entity_density", + name="Mention entity density for annotations", + description="Computes the ratio between the number of all entity tokens and tokens in the text", + ), + Metric( + id="annotated_entity_labels", + name="Annotated entity labels", + description="Annotated Entity labels distribution", + ), + Metric( + id="annotated_entity_capitalness", + name="Mention entity capitalness for annotations", + description="Compute capitalization information of annotated entity mentions", + ), + Metric( + id="annotated_mention_token_length", + name="Annotated mention tokens length", + description="Computes the length of the entity mention measured in number of tokens", + ), + Metric( + id="annotated_mention_char_length", + name="Annotated mention characters length", + description="Computes the length of the entity mention measured in number of tokens", + ), + Metric( + id="annotated_mentions_distribution", + name="Annotated mentions distribution by entity", + description="Computes annotated mentions distribution against its labels", + ), + Metric( + id="annotated_entity_consistency", + name="Entity label consistency for annotations", + description="Computes entity label variability for top-k annotated entity mentions", + ), + Metric( + id="annotated_tag_consistency", + name="Token tag consistency for annotations", + description="Computes token tag variability for top-k annotated tags", + ), + ] + ) diff --git a/src/rubrix/server/daos/datasets.py b/src/rubrix/server/daos/datasets.py index 12eb57fc70..bf85f1a370 100644 --- a/src/rubrix/server/daos/datasets.py +++ b/src/rubrix/server/daos/datasets.py @@ -321,16 +321,11 @@ def open(self, dataset: DatasetDB): def get_all_workspaces(self) -> List[str]: """Get all datasets (Only for super users)""" - workspaces_dict = self._es.aggregate( - index=DATASETS_INDEX_NAME, - aggregation=query_helpers.aggregations.terms_aggregation( - "owner.keyword", - missing=NO_WORKSPACE, - size=500, # TODO: A max number of workspaces env var could be leveraged by this. - ), + metric_data = self._es.compute_metric( + DATASETS_INDEX_NAME, + metric_id="all_rubrix_workspaces", ) - - return [k for k in workspaces_dict] + return [k for k in metric_data] def save_settings(self, dataset: DatasetDB, settings: SettingsDB) -> SettingsDB: self._es.update_document( diff --git a/src/rubrix/server/daos/records.py b/src/rubrix/server/daos/records.py index 6e7d767d3a..36c04ff4e2 100644 --- a/src/rubrix/server/daos/records.py +++ b/src/rubrix/server/daos/records.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses import datetime import re from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar @@ -37,13 +36,14 @@ tasks_common_settings, ) from rubrix.server.elasticseach.query_helpers import parse_aggregations +from rubrix.server.elasticseach.search.query_builder import SearchQuery from rubrix.server.errors import ClosedDatasetError, MissingDatasetRecordsError from rubrix.server.errors.task_errors import MetadataLimitExceededError from rubrix.server.settings import settings DBRecord = TypeVar("DBRecord", bound=BaseRecord) -# TODO(@frascuchon): this should be defined in the dataset class, wat!? +# TODO(@frascuchon): Move to the backend and accept the dataset id as parameter def dataset_records_index(dataset_id: str) -> str: """ Returns dataset records index for a given dataset id @@ -164,6 +164,37 @@ def get_metadata_schema(self, dataset: BaseDatasetDB) -> Dict[str, str]: records_index = dataset_records_index(dataset.id) return self._es.get_field_mapping(index=records_index, field_name="metadata.*") + def compute_metric( + self, + dataset: BaseDatasetDB, + metric_id: str, + metric_params: Dict[str, Any] = None, + query: Optional[SearchQuery] = None, + ): + """ + Parameters + ---------- + metric_id: + The backend metric id + metric_params: + The summary params + dataset: + The records dataset + query: + The filter to apply to dataset + + Returns + ------- + The metric summary result + + """ + return self._es.compute_metric( + index=dataset_records_index(dataset.id), + metric_id=metric_id, + query=query, + params=metric_params, + ) + def search_records( self, dataset: BaseDatasetDB, @@ -193,6 +224,7 @@ def search_records( The search result """ + # TODO(@frascuchon): Move this logic to the backend class search = search or RecordSearch() records_index = dataset_records_index(dataset.id) compute_aggregations = record_from == 0 @@ -207,7 +239,7 @@ def search_records( "from": record_from, "query": self._es.query_builder( dataset=dataset, - schema=self.get_dataset_schema(dataset), + schema=self._es.get_index_mapping(records_index), query=search.query, ), "sort": sort_config, @@ -278,18 +310,19 @@ def scan_dataset( ------- An iterable over found dataset records """ + # TODO(@frascuchon): Move this logic inside the backend component + index = dataset_records_index(dataset.id) search = search or RecordSearch() es_query = { "query": self._es.query_builder( dataset=dataset, - schema=self.get_dataset_schema(dataset), + schema=self._es.get_index_mapping(index), query=search.query, ), "highlight": self.__configure_query_highlight__(task=dataset.task), } - docs = self._es.list_documents( - dataset_records_index(dataset.id), query=es_query - ) + + docs = self._es.list_documents(index, query=es_query) for doc in docs: yield self.__esdoc2record__(doc) @@ -394,13 +427,8 @@ def create_dataset_index( def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]: """Return inner elasticsearch index configuration""" - index_name = dataset_records_index(dataset.id) - response = self._es.__client__.indices.get_mapping(index=index_name) - - if index_name in response: - response = response.get(index_name) - - return response + schema = self._es.get_index_mapping(dataset_records_index(dataset.id)) + return schema @classmethod def __configure_query_highlight__(cls, task: TaskType): diff --git a/src/rubrix/server/elasticseach/backend.py b/src/rubrix/server/elasticseach/backend.py index 8f610a214b..f28b2dbc8b 100644 --- a/src/rubrix/server/elasticseach/backend.py +++ b/src/rubrix/server/elasticseach/backend.py @@ -22,8 +22,10 @@ from rubrix.logging import LoggingMixin from rubrix.server.elasticseach import query_helpers +from rubrix.server.elasticseach.metrics import ALL_METRICS +from rubrix.server.elasticseach.metrics.base import ElasticsearchMetric from rubrix.server.elasticseach.search.query_builder import EsQueryBuilder -from rubrix.server.errors import InvalidTextSearchError +from rubrix.server.errors import EntityNotFoundError, InvalidTextSearchError try: import ujson as json @@ -77,13 +79,21 @@ def get_instance(cls) -> "ElasticsearchBackend": retry_on_timeout=True, max_retries=5, ) - cls._INSTANCE = cls(es_client, query_builder=EsQueryBuilder()) + cls._INSTANCE = cls( + es_client, query_builder=EsQueryBuilder(), metrics={**ALL_METRICS} + ) return cls._INSTANCE - def __init__(self, es_client: OpenSearch, query_builder: EsQueryBuilder): + def __init__( + self, + es_client: OpenSearch, + query_builder: EsQueryBuilder, + metrics: Dict[str, ElasticsearchMetric] = None, + ): self.__client__ = es_client self.__query_builder__ = query_builder + self.__defined_metrics__ = metrics or {} @property def client(self): @@ -590,6 +600,62 @@ def aggregate(self, index: str, aggregation: Dict[str, Any]) -> Dict[str, Any]: aggregation_name ) + def find_metric_by_id(self, metric_id: str) -> Optional[ElasticsearchMetric]: + metric = self.__defined_metrics__.get(metric_id) + if not metric: + raise EntityNotFoundError(name=metric_id, type="Metric") + return metric + + def compute_metric( + self, + index: str, + metric_id: str, + query: Optional[Any] = None, + params: Optional[Dict[str, Any]] = None, + ): + metric = self.find_metric_by_id(metric_id) + # Only for metadata aggregation. In a future could be nice to provide the whole index schema + params.update( + {"schema": self.get_field_mapping(index=index, field_name="metadata.*")} + ) + + filtered_params = { + argument: params[argument] + for argument in metric.metric_arg_names + if argument in params + } + + aggs = metric.aggregation_request(**filtered_params) + if not aggs: + return {} + if not isinstance(aggs, list): + aggs = [aggs] + results = {} + for agg in aggs: + es_query = { + "query": self.query_builder( + schema=self.get_index_mapping(index), + query=query, + ), + "aggs": agg, + } + search_result = self.search(index=index, query=es_query, size=0) + search_aggregations = search_result.get("aggregations", {}) + + if search_aggregations: + parsed_aggregations = query_helpers.parse_aggregations( + search_aggregations + ) + results.update(parsed_aggregations) + + return metric.aggregation_result(results.get(metric_id, results)) + + def get_index_mapping(self, index: str) -> Dict[str, Any]: + response = self.__client__.indices.get_mapping(index=index) + if index in response: + response = response.get(index) + return response + _instance = None # The singleton instance diff --git a/src/rubrix/server/elasticseach/query_helpers.py b/src/rubrix/server/elasticseach/query_helpers.py index 33abca2529..242a4b959b 100644 --- a/src/rubrix/server/elasticseach/query_helpers.py +++ b/src/rubrix/server/elasticseach/query_helpers.py @@ -13,20 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Any, Dict, List, Optional, Type, Union from pydantic import BaseModel -from rubrix.server.apis.v0.models.commons.model import ( - EsRecordDataFieldNames, - SortableField, -) from rubrix.server.commons.models import TaskStatus from rubrix.server.elasticseach.mappings.helpers import mappings - # TODO(@frascuchon): this should be move to the ElasticsearchBackend context +from rubrix.server.elasticseach.search.model import SortableField + + def nested_mappings_from_base_model(model_class: Type[BaseModel]) -> Dict[str, Any]: def resolve_mapping(info) -> Dict[str, Any]: the_type = info.get("type") @@ -127,10 +124,6 @@ def parse_buckets(buckets: List[Dict[str, Any]]) -> Dict[str, Any]: return result -def decode_field_name(field: EsRecordDataFieldNames) -> str: - return field.value - - class filters: """Group of functions related to elasticsearch filters""" @@ -173,29 +166,21 @@ def predicted_by(predicted_by: List[str] = None) -> Optional[Dict[str, Any]]: if not predicted_by: return None - return { - "terms": { - decode_field_name(EsRecordDataFieldNames.predicted_by): predicted_by - } - } + return {"terms": {"predicted_by": predicted_by}} @staticmethod def annotated_by(annotated_by: List[str] = None) -> Optional[Dict[str, Any]]: """Filter records with given predicted by terms""" if not annotated_by: return None - return { - "terms": { - decode_field_name(EsRecordDataFieldNames.annotated_by): annotated_by - } - } + return {"terms": {"annotated_by": annotated_by}} @staticmethod def status(status: List[TaskStatus] = None) -> Optional[Dict[str, Any]]: """Filter records by status""" if not status: return None - return {"terms": {decode_field_name(EsRecordDataFieldNames.status): status}} + return {"terms": {"status": status}} @staticmethod def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]: @@ -270,7 +255,9 @@ class aggregations: MAX_AGGREGATION_SIZE = 5000 # TODO: improve by setting env var @staticmethod - def nested_aggregation(nested_path: str, inner_aggregation: Dict[str, Any]): + def nested_aggregation( + nested_path: str, inner_aggregation: Dict[str, Any] + ) -> Dict[str, Any]: inner_meta = list(inner_aggregation.values())[0].get("meta", {}) return { "meta": { diff --git a/src/rubrix/server/elasticseach/search/model.py b/src/rubrix/server/elasticseach/search/model.py index e44c37d383..cefec27464 100644 --- a/src/rubrix/server/elasticseach/search/model.py +++ b/src/rubrix/server/elasticseach/search/model.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -5,6 +6,18 @@ from rubrix.server.commons.models import TaskStatus +class SortOrder(str, Enum): + asc = "asc" + desc = "desc" + + +class SortableField(BaseModel): + """Sortable field structure""" + + id: str + order: SortOrder = SortOrder.asc + + class BaseSearchQuery(BaseModel): query_text: Optional[str] = None diff --git a/src/rubrix/server/elasticseach/search/query_builder.py b/src/rubrix/server/elasticseach/search/query_builder.py index f06d759262..c29e59cb23 100644 --- a/src/rubrix/server/elasticseach/search/query_builder.py +++ b/src/rubrix/server/elasticseach/search/query_builder.py @@ -25,7 +25,6 @@ def get_instance(cls): def __call__( self, - dataset: BaseDatasetDB, schema: Dict[str, Any], query: Optional[SearchQuery] = None, ) -> Dict[str, Any]: diff --git a/src/rubrix/server/services/metrics.py b/src/rubrix/server/services/metrics.py deleted file mode 100644 index 65d8cdb7d1..0000000000 --- a/src/rubrix/server/services/metrics.py +++ /dev/null @@ -1,201 +0,0 @@ -from typing import Callable, Optional, Type, TypeVar, Union - -from fastapi import Depends - -from rubrix.server.apis.v0.models.metrics.base import ( - ElasticsearchMetric, - NestedPathElasticsearchMetric, - PythonMetric, -) -from rubrix.server.apis.v0.models.metrics.commons import * -from rubrix.server.daos.models.records import RecordSearch -from rubrix.server.daos.records import DatasetRecordsDAO, dataset_records_dao -from rubrix.server.elasticseach.search.query_builder import EsQueryBuilder -from rubrix.server.errors import WrongInputParamError -from rubrix.server.services.datasets import Dataset -from rubrix.server.services.tasks.commons.record import BaseRecordDB - -GenericQuery = TypeVar("GenericQuery") - - -class MetricsService: - """The dataset metrics service singleton""" - - _INSTANCE = None - - @classmethod - def get_instance( - cls, - dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), - ) -> "MetricsService": - """ - Creates the service instance. - - Parameters - ---------- - dao: - The dataset records dao - - Returns - ------- - The metrics service instance - - """ - if not cls._INSTANCE: - cls._INSTANCE = cls(dao) - return cls._INSTANCE - - def __init__(self, dao: DatasetRecordsDAO): - """ - Creates a service instance - - Parameters - ---------- - dao: - The dataset records dao - """ - self.__dao__ = dao - - def summarize_metric( - self, - dataset: Dataset, - metric: BaseMetric, - record_class: Optional[Type[BaseRecordDB]] = None, - query: Optional[GenericQuery] = None, - **metric_params, - ) -> Dict[str, Any]: - """ - Applies a metric summarization. - - Parameters - ---------- - dataset: - The records dataset - metric: - The selected metric - query: - An optional query passed for records filtering - metric_params: - Related metrics parameters - - Returns - ------- - The metric summarization info - """ - - if isinstance(metric, ElasticsearchMetric): - return self._handle_elasticsearch_metric( - metric, metric_params, dataset=dataset, query=query - ) - elif isinstance(metric, PythonMetric): - records = self.__dao__.scan_dataset( - dataset, search=RecordSearch(query=query) - ) - return metric.apply(map(record_class.parse_obj, records)) - - raise WrongInputParamError(f"Cannot process {metric} of type {type(metric)}") - - def _handle_elasticsearch_metric( - self, - metric: ElasticsearchMetric, - metric_params: Dict[str, Any], - dataset: Dataset, - query: GenericQuery, - ) -> Dict[str, Any]: - """ - Parameters - ---------- - metric: - The elasticsearch metric summary - metric_params: - The summary params - dataset: - The records dataset - query: - The filter to apply to dataset - - Returns - ------- - The metric summary result - - """ - params = self.__compute_metric_params__( - dataset=dataset, metric=metric, query=query, provided_params=metric_params - ) - results = self.__metric_results__( - dataset=dataset, - query=query, - metric_aggregation=metric.aggregation_request(**params), - ) - return metric.aggregation_result( - aggregation_result=results.get(metric.id, results) - ) - - def __compute_metric_params__( - self, - dataset: Dataset, - metric: ElasticsearchMetric, - query: Optional[GenericQuery], - provided_params: Dict[str, Any], - ) -> Dict[str, Any]: - - return self._filter_metric_params( - metric=metric, - function=metric.aggregation_request, - metric_params={ - **provided_params, # in case of param conflict, provided metric params will be preserved - "dataset": dataset, - "dao": self.__dao__, - }, - ) - - def __metric_results__( - self, - dataset: Dataset, - query: Optional[GenericQuery], - metric_aggregation: Union[Dict[str, Any], List[Dict[str, Any]]], - ) -> Dict[str, Any]: - - if not metric_aggregation: - return {} - - if not isinstance(metric_aggregation, list): - metric_aggregation = [metric_aggregation] - - results = {} - for agg in metric_aggregation: - results_ = self.__dao__.search_records( - dataset, - size=0, # No records at all - search=RecordSearch( - query=query, - aggregations=agg, - ), - ) - results.update(results_.aggregations) - return results - - @staticmethod - def _filter_metric_params( - metric: ElasticsearchMetric, function: Callable, metric_params: Dict[str, Any] - ): - """ - Select from provided metric parameter those who can be applied to given metric - - Parameters - ---------- - metric: - The target metric - metric_params: - A dict of metric parameters - - """ - - if isinstance(metric, NestedPathElasticsearchMetric): - function = metric.inner_aggregation - - return { - argument: metric_params[argument] - for argument in function.__code__.co_varnames - if argument in metric_params - } diff --git a/src/rubrix/server/services/metrics/service.py b/src/rubrix/server/services/metrics/service.py new file mode 100644 index 0000000000..1b3655a62e --- /dev/null +++ b/src/rubrix/server/services/metrics/service.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, Optional, Type, TypeVar + +from fastapi import Depends + +from rubrix.server.daos.models.records import RecordSearch +from rubrix.server.daos.records import DatasetRecordsDAO +from rubrix.server.services.datasets import Dataset +from rubrix.server.services.metrics.models import BaseMetric, PythonMetric +from rubrix.server.services.tasks.commons.record import BaseRecordDB + +GenericQuery = TypeVar("GenericQuery") + + +class MetricsService: + """The dataset metrics service singleton""" + + _INSTANCE = None + + @classmethod + def get_instance( + cls, + dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), + ) -> "MetricsService": + """ + Creates the service instance. + + Parameters + ---------- + dao: + The dataset records dao + + Returns + ------- + The metrics service instance + + """ + if not cls._INSTANCE: + cls._INSTANCE = cls(dao) + return cls._INSTANCE + + def __init__(self, dao: DatasetRecordsDAO): + """ + Creates a service instance + + Parameters + ---------- + dao: + The dataset records dao + """ + self.__dao__ = dao + + def summarize_metric( + self, + dataset: Dataset, + metric: BaseMetric, + record_class: Optional[Type[BaseRecordDB]] = None, + query: Optional[GenericQuery] = None, + **metric_params, + ) -> Dict[str, Any]: + """ + Applies a metric summarization. + + Parameters + ---------- + dataset: + The records dataset + metric: + The selected metric + record_class: + The record class type for python metrics computation + query: + An optional query passed for records filtering + metric_params: + Related metrics parameters + + Returns + ------- + The metric summarization info + """ + + if isinstance(metric, PythonMetric): + records = self.__dao__.scan_dataset( + dataset, search=RecordSearch(query=query) + ) + return metric.apply(map(record_class.parse_obj, records)) + + return self.__dao__.compute_metric( + metric_id=metric.id, + metric_params=metric_params, + dataset=dataset, + query=query, + ) diff --git a/src/rubrix/server/services/search/model.py b/src/rubrix/server/services/search/model.py index 87a693b93b..d19c0e6be4 100644 --- a/src/rubrix/server/services/search/model.py +++ b/src/rubrix/server/services/search/model.py @@ -5,6 +5,8 @@ from pydantic.generics import GenericModel from rubrix.server.elasticseach.search.model import BaseSearchQuery as _BaseSearchQuery +from rubrix.server.elasticseach.search.model import SortableField as _SortableField +from rubrix.server.elasticseach.search.model import SortOrder from rubrix.server.services.tasks.commons.record import Record @@ -12,16 +14,10 @@ class BaseSVCSearchQuery(_BaseSearchQuery): pass -class SortOrder(str, Enum): - asc = "asc" - desc = "desc" - - -class SortableField(BaseModel): +class SortableField(_SortableField): """Sortable field structure""" - id: str - order: SortOrder = SortOrder.asc + pass class QueryRange(BaseModel): diff --git a/src/rubrix/server/services/search/service.py b/src/rubrix/server/services/search/service.py index d065aa153f..4d9c58d5e8 100644 --- a/src/rubrix/server/services/search/service.py +++ b/src/rubrix/server/services/search/service.py @@ -3,12 +3,11 @@ from fastapi import Depends -from rubrix.server.apis.v0.models.metrics.base import BaseMetric from rubrix.server.daos.models.records import RecordSearch from rubrix.server.daos.records import DatasetRecordsDAO from rubrix.server.elasticseach.query_helpers import sort_by2elasticsearch from rubrix.server.services.datasets import Dataset -from rubrix.server.services.metrics import MetricsService +from rubrix.server.services.metrics import BaseMetric, MetricsService from rubrix.server.services.search.model import ( BaseSVCSearchQuery, Record, diff --git a/src/rubrix/server/services/text2text.py b/src/rubrix/server/services/text2text.py index 03d5c8c5d4..a02ee402ff 100644 --- a/src/rubrix/server/services/text2text.py +++ b/src/rubrix/server/services/text2text.py @@ -22,7 +22,7 @@ EsRecordDataFieldNames, SortableField, ) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics +from rubrix.server.apis.v0.models.metrics.base import BaseTaskMetrics from rubrix.server.apis.v0.models.text2text import ( CreationText2TextRecord, Text2TextDatasetDB, @@ -32,6 +32,7 @@ Text2TextSearchAggregations, Text2TextSearchResults, ) +from rubrix.server.services.metrics import BaseMetric from rubrix.server.services.search.model import SortConfig from rubrix.server.services.search.service import SearchRecordsService from rubrix.server.services.storage.service import RecordsStorageService diff --git a/src/rubrix/server/services/text_classification.py b/src/rubrix/server/services/text_classification.py index 113562e9b5..d86e2590ab 100644 --- a/src/rubrix/server/services/text_classification.py +++ b/src/rubrix/server/services/text_classification.py @@ -22,7 +22,7 @@ EsRecordDataFieldNames, SortableField, ) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics +from rubrix.server.apis.v0.models.metrics.base import BaseTaskMetrics from rubrix.server.apis.v0.models.text_classification import ( CreationTextClassificationRecord, DatasetLabelingRulesMetricsSummary, @@ -36,6 +36,7 @@ TextClassificationSearchResults, ) from rubrix.server.errors.base_errors import MissingDatasetRecordsError +from rubrix.server.services.metrics import BaseMetric from rubrix.server.services.search.model import SortConfig from rubrix.server.services.search.service import SearchRecordsService from rubrix.server.services.storage.service import RecordsStorageService diff --git a/src/rubrix/server/services/text_classification_labelling_rules.py b/src/rubrix/server/services/text_classification_labelling_rules.py index 1204e53368..87b3c1e306 100644 --- a/src/rubrix/server/services/text_classification_labelling_rules.py +++ b/src/rubrix/server/services/text_classification_labelling_rules.py @@ -1,11 +1,9 @@ -from typing import Any, Dict, List, Optional, Tuple +import dataclasses +from typing import List, Optional, Tuple from fastapi import Depends from pydantic import BaseModel, Field -from rubrix.server._helpers import unflatten_dict -from rubrix.server.apis.v0.models.commons.model import EsRecordDataFieldNames -from rubrix.server.apis.v0.models.metrics.base import ElasticsearchMetric from rubrix.server.apis.v0.models.text_classification import ( LabelingRule, TextClassificationDatasetDB, @@ -13,131 +11,10 @@ from rubrix.server.daos.datasets import DatasetsDAO from rubrix.server.daos.models.records import BaseSearchQuery, RecordSearch from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import filters from rubrix.server.errors import EntityAlreadyExistsError, EntityNotFoundError -class DatasetLabelingRulesMetric(ElasticsearchMetric): - id: str = Field("dataset_labeling_rules", const=True) - name: str = Field( - "Computes overall metrics for defined rules in dataset", const=True - ) - - def aggregation_request(self, all_rules: List[LabelingRule]) -> Dict[str, Any]: - rules_filters = [filters.text_query(rule.query) for rule in all_rules] - return { - self.id: { - "filters": { - "filters": { - "covered_records": filters.boolean_filter( - should_filters=rules_filters, minimum_should_match=1 - ), - "annotated_covered_records": filters.boolean_filter( - filter_query=filters.exists_field( - EsRecordDataFieldNames.annotated_as - ), - should_filters=rules_filters, - minimum_should_match=1, - ), - } - } - } - } - - -class LabelingRulesMetric(ElasticsearchMetric): - id: str = Field("labeling_rule", const=True) - name: str = Field("Computes metrics for a labeling rule", const=True) - - def aggregation_request( - self, - rule_query: str, - labels: Optional[List[str]], - ) -> Dict[str, Any]: - - annotated_records_filter = filters.exists_field( - EsRecordDataFieldNames.annotated_as - ) - rule_query_filter = filters.text_query(rule_query) - aggr_filters = { - "covered_records": rule_query_filter, - "annotated_covered_records": filters.boolean_filter( - filter_query=annotated_records_filter, - should_filters=[rule_query_filter], - ), - } - - if labels is not None: - for label in labels: - rule_label_annotated_filter = filters.term_filter( - "annotated_as", value=label - ) - encoded_label = self._encode_label_name(label) - aggr_filters.update( - { - f"{encoded_label}.correct_records": filters.boolean_filter( - filter_query=annotated_records_filter, - should_filters=[ - rule_query_filter, - rule_label_annotated_filter, - ], - minimum_should_match=2, - ), - f"{encoded_label}.incorrect_records": filters.boolean_filter( - filter_query=annotated_records_filter, - must_query=rule_query_filter, - must_not_query=rule_label_annotated_filter, - ), - } - ) - - return {self.id: {"filters": {"filters": aggr_filters}}} - - @staticmethod - def _encode_label_name(label: str) -> str: - return label.replace(".", "@@@") - - @staticmethod - def _decode_label_name(label: str) -> str: - return label.replace("@@@", ".") - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - if self.id in aggregation_result: - aggregation_result = aggregation_result[self.id] - - aggregation_result = unflatten_dict(aggregation_result) - results = { - "covered_records": aggregation_result.pop("covered_records"), - "annotated_covered_records": aggregation_result.pop( - "annotated_covered_records" - ), - } - - all_correct = [] - all_incorrect = [] - all_precision = [] - for label, metrics in aggregation_result.items(): - correct = metrics.get("correct_records", 0) - incorrect = metrics.get("incorrect_records", 0) - annotated = correct + incorrect - metrics["annotated"] = annotated - if annotated > 0: - precision = correct / annotated - metrics["precision"] = precision - all_precision.append(precision) - - all_correct.append(correct) - all_incorrect.append(incorrect) - results[self._decode_label_name(label)] = metrics - - results["correct_records"] = sum(all_correct) - results["incorrect_records"] = sum(all_incorrect) - if len(all_precision) > 0: - results["precision"] = sum(all_precision) / len(all_precision) - - return results - - +@dataclasses.dataclass class DatasetLabelingRulesSummary(BaseModel): covered_records: int annotated_covered_records: int @@ -155,9 +32,6 @@ class LabelingService: _INSTANCE = None - __rule_metrics__ = LabelingRulesMetric() - __dataset_rules_metrics__ = DatasetLabelingRulesMetric() - @classmethod def get_instance( cls, @@ -203,23 +77,19 @@ def compute_rule_metrics( """Computes metrics for given rule query and optional label against a set of rules""" annotated_records = self._count_annotated_records(dataset) - results = self.__records__.search_records( - dataset, - size=0, - search=RecordSearch( - aggregations=self.__rule_metrics__.aggregation_request( - rule_query=rule_query, labels=labels - ), - ), + dataset_records = self.__records__.search_records(dataset, size=0).total + metric_data = self.__records__.compute_metric( + dataset=dataset, + metric_id="labeling_rule", + metric_params=dict(rule_query=rule_query, labels=labels), ) - rule_metrics_summary = self.__rule_metrics__.aggregation_result( - results.aggregations + return ( + dataset_records, + annotated_records, + LabelingRuleSummary.parse_obj(metric_data), ) - metrics = LabelingRuleSummary.parse_obj(rule_metrics_summary) - return results.total, annotated_records, metrics - def _count_annotated_records(self, dataset: TextClassificationDatasetDB) -> int: results = self.__records__.search_records( dataset, @@ -232,25 +102,17 @@ def all_rules_metrics( self, dataset: TextClassificationDatasetDB ) -> Tuple[int, int, DatasetLabelingRulesSummary]: annotated_records = self._count_annotated_records(dataset) - results = self.__records__.search_records( - dataset, - size=0, - search=RecordSearch( - # TODO(@frascuchon): elasticsearch metrics should be managed by the backend component - aggregations=self.__dataset_rules_metrics__.aggregation_request( - all_rules=dataset.rules - ), - ), - ) - - rule_metrics_summary = self.__dataset_rules_metrics__.aggregation_result( - results.aggregations + dataset_records = self.__records__.search_records(dataset, size=0).total + metric_data = self.__records__.compute_metric( + dataset=dataset, + metric_id="dataset_labeling_rules", + metric_params=dict(queries=dataset.rules), ) return ( - results.total, + dataset_records, annotated_records, - DatasetLabelingRulesSummary.parse_obj(rule_metrics_summary), + DatasetLabelingRulesSummary.parse_obj(metric_data), ) def find_rule_by_query( diff --git a/src/rubrix/server/services/token_classification.py b/src/rubrix/server/services/token_classification.py index 1d66a617b2..b33cbadc6c 100644 --- a/src/rubrix/server/services/token_classification.py +++ b/src/rubrix/server/services/token_classification.py @@ -22,7 +22,7 @@ EsRecordDataFieldNames, SortableField, ) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics +from rubrix.server.apis.v0.models.metrics.base import BaseTaskMetrics from rubrix.server.apis.v0.models.token_classification import ( CreationTokenClassificationRecord, TokenClassificationAggregations, @@ -32,6 +32,7 @@ TokenClassificationRecordDB, TokenClassificationSearchResults, ) +from rubrix.server.services.metrics import BaseMetric from rubrix.server.services.search.model import SortConfig from rubrix.server.services.search.service import SearchRecordsService from rubrix.server.services.storage.service import RecordsStorageService diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 9e29abb2d0..dd30883b3d 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -3,7 +3,7 @@ import rubrix from rubrix.server.apis.v0.models.commons.model import ScoreRange, TaskType from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.apis.v0.models.metrics.base import BaseMetric +from rubrix.server.apis.v0.models.metrics.base import PythonMetric from rubrix.server.apis.v0.models.text_classification import ( TextClassificationQuery, TextClassificationRecord, @@ -129,7 +129,7 @@ def test_failing_metrics(service, mocked_client): dataset=dataset, query=TextClassificationQuery(), sort_config=SortConfig(), - metrics=[BaseMetric(id="missing-metric", name="Missing metric")], + metrics=[PythonMetric(id="missing-metric", name="Missing metric")], size=0, record_type=TextClassificationRecord, ) diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 14637defb5..2516c7f12c 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -146,7 +146,7 @@ def test_dataset_for_token_classification(mocked_client): json={}, ) - assert response.status_code == 200, response.json() + assert response.status_code == 200, f"{metric} :: {response.json()}" summary = response.json() if not ("predicted" in metric_id or "annotated" in metric_id): @@ -199,7 +199,7 @@ def test_dataset_metrics(mocked_client): assert response.json() == { "detail": { "code": "rubrix.api.errors::EntityNotFoundError", - "params": {"name": "missing_metric", "type": "BaseMetric"}, + "params": {"name": "missing_metric", "type": "Metric"}, } }