From 4273388034e9b4cadfd4614c33e8c6c3cf88893d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 25 Aug 2022 13:15:14 +0200 Subject: [PATCH] refactor: encapsulate server component layers (#1647) * chore: rename class to `ElasticsearchBackend` * chore: moving the query builder to the elasticsearch module * refactor: query builder accesible through the backend instance * fix: moving and rename models * revert: configs * refactor: move TaskStatus class to commons module * refactor: move TaskStatus class to commons module * refactor: allow filter records with annotations/predictions * test: update tests * refactor: move all elasticsearch metrics to the elasticsearch layer (#1651) * perf: handle all backend errors in calls * refactor: remove all filters references outside the elasticsearch module * refactor: move query generation to backend module (#1652) * refactor: move es-mappings to backend module (#1653) * refactor: search and scan to the backend component * chore: cleaning source * chore: renaming module elasticsearch -> backend * fix: sort by id and default sorting * Refactor/normalize dao layer (#1654) * refactor: normalize daos layer * chore: remove refs to EsRecordDataFieldNames * refactor: keep dao models and extends in service layer * revert: rollback wrong documentation * refactor: keep service models and extends in api layer (#1655) * chore: moving helpers and class renaming * revert: renaming * refactor: clean and normalize api layer (#1657) * refactor: more cleaning models * refactor: fetch metrics for search on services * refactor: resolve TODOs * chore: add TODOs * refactor: keep elasticsearch details inside the backend module * chore: cleaning code * fix: create index properly * test: fix tests * refactor: clean API models * fix: `id_from` changes (cherry picked from commit b9e756297a7229d6a56703bd49266b1195519a52) --- pyproject.toml | 1 - src/rubrix/server/_helpers.py | 54 -- .../server/apis/v0/config/tasks_factory.py | 155 ----- 
.../server/apis/v0/handlers/datasets.py | 172 +---- src/rubrix/server/apis/v0/handlers/info.py | 26 +- src/rubrix/server/apis/v0/handlers/metrics.py | 101 +-- .../server/apis/v0/handlers/text2text.py | 146 ++-- .../apis/v0/handlers/text_classification.py | 178 +++-- .../text_classification_dataset_settings.py | 20 +- .../apis/v0/handlers/token_classification.py | 172 ++--- .../token_classification_dataset_settings.py | 20 +- .../server/apis/v0/models/commons/model.py | 101 +-- .../server/apis/v0/models/commons/params.py | 43 +- .../apis/v0/models/commons/workspace.py | 30 - src/rubrix/server/apis/v0/models/datasets.py | 23 +- src/rubrix/server/apis/v0/models/info.py | 31 - .../server/apis/v0/models/metrics/base.py | 304 -------- .../server/apis/v0/models/metrics/commons.py | 75 -- .../v0/models/metrics/token_classification.py | 643 ----------------- src/rubrix/server/apis/v0/models/text2text.py | 206 +----- .../apis/v0/models/text_classification.py | 470 ++----------- .../apis/v0/models/token_classification.py | 412 +---------- .../apis/v0/validators/text_classification.py | 22 +- .../v0/validators/token_classification.py | 26 +- .../{apis/v0/config => commons}/__init__.py | 0 src/rubrix/server/commons/config.py | 96 +++ src/rubrix/server/commons/models.py | 21 + .../metrics => daos/backend}/__init__.py | 0 .../backend/mappings}/__init__.py | 0 .../backend}/mappings/datasets.py | 4 +- .../backend}/mappings/helpers.py | 2 + .../backend}/mappings/text2text.py | 6 +- .../backend}/mappings/text_classification.py | 2 +- .../backend}/mappings/token_classification.py | 43 +- .../server/daos/backend/metrics/__init__.py | 5 + .../server/daos/backend/metrics/base.py | 213 ++++++ .../server/daos/backend/metrics/commons.py | 47 ++ .../server/daos/backend/metrics/datasets.py | 8 + .../backend/metrics/text_classification.py | 130 ++++ .../backend/metrics/token_classification.py | 214 ++++++ .../backend}/query_helpers.py | 50 +- .../backend/search}/__init__.py | 0 
.../server/daos/backend/search/model.py | 67 ++ .../daos/backend/search/query_builder.py | 227 ++++++ src/rubrix/server/daos/datasets.py | 259 ++----- src/rubrix/server/daos/models/datasets.py | 14 +- src/rubrix/server/daos/models/records.py | 183 +++-- src/rubrix/server/daos/records.py | 391 ++--------- .../server/elasticseach/client_wrapper.py | 650 ------------------ src/rubrix/server/{apis/v0 => }/helpers.py | 67 +- src/rubrix/server/server.py | 4 +- src/rubrix/server/services/datasets.py | 107 ++- src/rubrix/server/services/info.py | 78 ++- src/rubrix/server/services/metrics.py | 204 ------ .../server/services/metrics/__init__.py | 2 + src/rubrix/server/services/metrics/models.py | 156 +++++ src/rubrix/server/services/metrics/service.py | 91 +++ src/rubrix/server/services/search/model.py | 87 +-- .../server/services/search/query_builder.py | 108 --- src/rubrix/server/services/search/service.py | 71 +- src/rubrix/server/services/storage/service.py | 18 +- .../server/services/tasks/commons/__init__.py | 3 +- .../server/services/tasks/commons/logging.py | 21 - .../server/services/tasks/commons/models.py | 35 + .../server/services/tasks/commons/record.py | 152 ---- .../services/tasks/text2text/__init__.py | 1 + .../server/services/tasks/text2text/models.py | 81 +++ .../text2text/service.py} | 114 ++- .../tasks/text_classification/__init__.py | 2 + .../labeling_rules_service.py | 138 ++++ .../tasks/text_classification/metrics.py} | 75 +- .../tasks/text_classification/model.py | 344 +++++++++ .../text_classification/service.py} | 184 ++--- .../tasks/token_classification/__init__.py | 1 + .../tasks/token_classification/metrics.py | 407 +++++++++++ .../tasks/token_classification/model.py | 327 +++++++++ .../token_classification/service.py} | 140 ++-- .../text_classification_labelling_rules.py | 271 -------- tests/client/sdk/conftest.py | 53 ++ tests/client/sdk/text2text/test_models.py | 11 +- .../sdk/text_classification/test_models.py | 10 +- 
.../sdk/token_classification/test_models.py | 10 +- .../search/test_search_service.py | 123 ++-- .../labeling/text_classification/test_rule.py | 2 +- tests/metrics/test_text_classification.py | 6 +- tests/server/backend/__init__.py | 0 tests/server/backend/test_query_builder.py | 66 ++ tests/server/commons/test_records_dao.py | 12 +- tests/server/datasets/test_api.py | 16 +- tests/server/datasets/test_dao.py | 22 +- tests/server/datasets/test_model.py | 7 +- tests/server/info/test_api.py | 2 +- tests/server/metrics/test_api.py | 28 +- tests/server/test_api.py | 7 +- tests/server/text2text/test_api.py | 4 +- tests/server/text2text/test_model.py | 4 +- tests/server/text_classification/test_api.py | 35 +- .../text_classification/test_api_rules.py | 6 +- .../text_classification/test_api_settings.py | 2 +- .../server/text_classification/test_model.py | 46 +- tests/server/token_classification/test_api.py | 23 +- .../token_classification/test_api_settings.py | 2 +- .../server/token_classification/test_model.py | 40 +- 103 files changed, 4147 insertions(+), 5742 deletions(-) delete mode 100644 src/rubrix/server/_helpers.py delete mode 100644 src/rubrix/server/apis/v0/config/tasks_factory.py delete mode 100644 src/rubrix/server/apis/v0/models/commons/workspace.py delete mode 100644 src/rubrix/server/apis/v0/models/info.py delete mode 100644 src/rubrix/server/apis/v0/models/metrics/base.py delete mode 100644 src/rubrix/server/apis/v0/models/metrics/commons.py delete mode 100644 src/rubrix/server/apis/v0/models/metrics/token_classification.py rename src/rubrix/server/{apis/v0/config => commons}/__init__.py (100%) create mode 100644 src/rubrix/server/commons/config.py create mode 100644 src/rubrix/server/commons/models.py rename src/rubrix/server/{apis/v0/models/metrics => daos/backend}/__init__.py (100%) rename src/rubrix/server/{elasticseach => daos/backend/mappings}/__init__.py (100%) rename src/rubrix/server/{elasticseach => daos/backend}/mappings/datasets.py (87%) rename 
src/rubrix/server/{elasticseach => daos/backend}/mappings/helpers.py (97%) rename src/rubrix/server/{elasticseach => daos/backend}/mappings/text2text.py (69%) rename src/rubrix/server/{elasticseach => daos/backend}/mappings/text_classification.py (95%) rename src/rubrix/server/{elasticseach => daos/backend}/mappings/token_classification.py (68%) create mode 100644 src/rubrix/server/daos/backend/metrics/__init__.py create mode 100644 src/rubrix/server/daos/backend/metrics/base.py create mode 100644 src/rubrix/server/daos/backend/metrics/commons.py create mode 100644 src/rubrix/server/daos/backend/metrics/datasets.py create mode 100644 src/rubrix/server/daos/backend/metrics/text_classification.py create mode 100644 src/rubrix/server/daos/backend/metrics/token_classification.py rename src/rubrix/server/{elasticseach => daos/backend}/query_helpers.py (90%) rename src/rubrix/server/{elasticseach/mappings => daos/backend/search}/__init__.py (100%) create mode 100644 src/rubrix/server/daos/backend/search/model.py create mode 100644 src/rubrix/server/daos/backend/search/query_builder.py delete mode 100644 src/rubrix/server/elasticseach/client_wrapper.py rename src/rubrix/server/{apis/v0 => }/helpers.py (68%) delete mode 100644 src/rubrix/server/services/metrics.py create mode 100644 src/rubrix/server/services/metrics/__init__.py create mode 100644 src/rubrix/server/services/metrics/models.py create mode 100644 src/rubrix/server/services/metrics/service.py delete mode 100644 src/rubrix/server/services/search/query_builder.py delete mode 100644 src/rubrix/server/services/tasks/commons/logging.py create mode 100644 src/rubrix/server/services/tasks/commons/models.py delete mode 100644 src/rubrix/server/services/tasks/commons/record.py create mode 100644 src/rubrix/server/services/tasks/text2text/__init__.py create mode 100644 src/rubrix/server/services/tasks/text2text/models.py rename src/rubrix/server/services/{text2text.py => tasks/text2text/service.py} (52%) create mode 
100644 src/rubrix/server/services/tasks/text_classification/__init__.py create mode 100644 src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py rename src/rubrix/server/{apis/v0/models/metrics/text_classification.py => services/tasks/text_classification/metrics.py} (69%) create mode 100644 src/rubrix/server/services/tasks/text_classification/model.py rename src/rubrix/server/services/{text_classification.py => tasks/text_classification/service.py} (67%) create mode 100644 src/rubrix/server/services/tasks/token_classification/__init__.py create mode 100644 src/rubrix/server/services/tasks/token_classification/metrics.py create mode 100644 src/rubrix/server/services/tasks/token_classification/model.py rename src/rubrix/server/services/{token_classification.py => tasks/token_classification/service.py} (53%) delete mode 100644 src/rubrix/server/services/text_classification_labelling_rules.py create mode 100644 tests/server/backend/__init__.py create mode 100644 tests/server/backend/test_query_builder.py diff --git a/pyproject.toml b/pyproject.toml index da1f7409a8..3e57161904 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ server = [ "python-jose[cryptography]~=3.2.0", "passlib[bcrypt]~=1.7.4", # Info status - "hurry.filesize", # TODO: remove "psutil ~= 5.8.0", ] listeners = [ diff --git a/src/rubrix/server/_helpers.py b/src/rubrix/server/_helpers.py deleted file mode 100644 index c088a1566f..0000000000 --- a/src/rubrix/server/_helpers.py +++ /dev/null @@ -1,54 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Common helper functions -""" -from typing import Any, Dict, List, Optional - - -def unflatten_dict( - data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None -) -> Dict[str, Any]: - """ - Given a flat dictionary keys, build a hierarchical version by grouping keys - - Parameters - ---------- - data: - The data dictionary - sep: - The key separator. Default "." - stop_keys - List of dictionary first level keys where hierarchy will stop - - Returns - ------- - - """ - resultDict = {} - stop_keys = stop_keys or [] - for key, value in data.items(): - if key is not None: - parts = key.split(sep) - if parts[0] in stop_keys: - parts = [parts[0], sep.join(parts[1:])] - d = resultDict - for part in parts[:-1]: - if part not in d: - d[part] = {} - d = d[part] - d[parts[-1]] = value - return resultDict diff --git a/src/rubrix/server/apis/v0/config/tasks_factory.py b/src/rubrix/server/apis/v0/config/tasks_factory.py deleted file mode 100644 index 5ad982bf4f..0000000000 --- a/src/rubrix/server/apis/v0/config/tasks_factory.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Any, Dict, List, Optional, Set, Type - -from pydantic import BaseModel - -from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType -from rubrix.server.apis.v0.models.datasets import DatasetDB -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics -from rubrix.server.apis.v0.models.metrics.text_classification import ( - TextClassificationMetrics, -) -from rubrix.server.apis.v0.models.metrics.token_classification import ( - 
TokenClassificationMetrics, -) -from rubrix.server.apis.v0.models.text2text import ( - Text2TextDatasetDB, - Text2TextMetrics, - Text2TextQuery, - Text2TextRecord, -) -from rubrix.server.apis.v0.models.text_classification import ( - TextClassificationDatasetDB, - TextClassificationQuery, - TextClassificationRecord, -) -from rubrix.server.apis.v0.models.token_classification import ( - TokenClassificationDatasetDB, - TokenClassificationQuery, - TokenClassificationRecord, -) -from rubrix.server.elasticseach.mappings.text2text import text2text_mappings -from rubrix.server.elasticseach.mappings.text_classification import ( - text_classification_mappings, -) -from rubrix.server.elasticseach.mappings.token_classification import ( - token_classification_mappings, -) -from rubrix.server.errors import EntityNotFoundError, WrongTaskError - - -class TaskConfig(BaseModel): - task: TaskType - query: Any - dataset: Type[DatasetDB] - record: Type[BaseRecord] - metrics: Optional[Type[BaseTaskMetrics]] - es_mappings: Dict[str, Any] - - -class TaskFactory: - - _REGISTERED_TASKS = dict() - - @classmethod - def register_task( - cls, - task_type: TaskType, - dataset_class: Type[DatasetDB], - query_request: Type[Any], - es_mappings: Dict[str, Any], - record_class: Type[BaseRecord], - metrics: Optional[Type[BaseTaskMetrics]] = None, - ): - cls._REGISTERED_TASKS[task_type] = TaskConfig( - task=task_type, - dataset=dataset_class, - es_mappings=es_mappings, - query=query_request, - record=record_class, - metrics=metrics, - ) - - @classmethod - def get_all_configs(cls) -> List[TaskConfig]: - return [cfg for cfg in cls._REGISTERED_TASKS.values()] - - @classmethod - def get_task_by_task_type(cls, task_type: TaskType) -> Optional[TaskConfig]: - return cls._REGISTERED_TASKS.get(task_type) - - @classmethod - def get_task_metrics(cls, task: TaskType) -> Optional[Type[BaseTaskMetrics]]: - config = cls.get_task_by_task_type(task) - if config: - return config.metrics - - @classmethod - def 
get_task_dataset(cls, task: TaskType) -> Type[DatasetDB]: - config = cls.__get_task_config__(task) - return config.dataset - - @classmethod - def get_task_record(cls, task: TaskType) -> Type[BaseRecord]: - config = cls.__get_task_config__(task) - return config.record - - @classmethod - def get_task_mappings(cls, task: TaskType) -> Dict[str, Any]: - config = cls.__get_task_config__(task) - return config.es_mappings - - @classmethod - def __get_task_config__(cls, task): - config = cls.get_task_by_task_type(task) - if not config: - raise WrongTaskError(f"No configuration found for task {task}") - return config - - @classmethod - def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[BaseMetric]: - metrics = cls.find_task_metrics(task, {metric_id}) - if metrics: - return metrics[0] - raise EntityNotFoundError(name=metric_id, type=BaseMetric) - - @classmethod - def find_task_metrics( - cls, task: TaskType, metric_ids: Set[str] - ) -> List[BaseMetric]: - - if not metric_ids: - return [] - - metrics = [] - for metric in cls.get_task_metrics(task).metrics: - if metric.id in metric_ids: - metrics.append(metric) - return metrics - - -TaskFactory.register_task( - task_type=TaskType.token_classification, - dataset_class=TokenClassificationDatasetDB, - query_request=TokenClassificationQuery, - record_class=TokenClassificationRecord, - metrics=TokenClassificationMetrics, - es_mappings=token_classification_mappings(), -) - -TaskFactory.register_task( - task_type=TaskType.text_classification, - dataset_class=TextClassificationDatasetDB, - query_request=TextClassificationQuery, - record_class=TextClassificationRecord, - metrics=TextClassificationMetrics, - es_mappings=text_classification_mappings(), -) - -TaskFactory.register_task( - task_type=TaskType.text2text, - dataset_class=Text2TextDatasetDB, - query_request=Text2TextQuery, - record_class=Text2TextRecord, - metrics=Text2TextMetrics, - es_mappings=text2text_mappings(), -) diff --git 
a/src/rubrix/server/apis/v0/handlers/datasets.py b/src/rubrix/server/apis/v0/handlers/datasets.py index 753dc9775b..301a3a3cdb 100644 --- a/src/rubrix/server/apis/v0/handlers/datasets.py +++ b/src/rubrix/server/apis/v0/handlers/datasets.py @@ -17,14 +17,14 @@ from fastapi import APIRouter, Body, Depends, Security -from rubrix.server.apis.v0.config.tasks_factory import TaskFactory -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams +from rubrix.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies from rubrix.server.apis.v0.models.datasets import ( CopyDatasetRequest, + CreateDatasetRequest, Dataset, - DatasetCreate, UpdateDatasetRequest, ) +from rubrix.server.commons.config import TasksFactory from rubrix.server.errors import EntityNotFoundError from rubrix.server.security import auth from rubrix.server.security.model import User @@ -39,30 +39,16 @@ response_model_exclude_none=True, operation_id="list_datasets", ) -def list_datasets( - ds_params: CommonTaskQueryParams = Depends(), +async def list_datasets( + request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> List[Dataset]: - """ - List accessible user datasets - - Parameters - ---------- - ds_params: - Common task query params - service: - The datasets service - current_user: - The request user - - Returns - ------- - A list of datasets visible by current user - """ return service.list( user=current_user, - workspaces=[ds_params.workspace] if ds_params.workspace is not None else None, + workspaces=[request_deps.workspace] + if request_deps.workspace is not None + else None, ) @@ -75,24 +61,19 @@ def list_datasets( description="Create a new dataset", ) async def create_dataset( - request: DatasetCreate = Body(..., description=f"The request dataset info"), - ws_params: CommonTaskQueryParams = Depends(), + request: CreateDatasetRequest = 
Body(..., description=f"The request dataset info"), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["create:datasets"]), ) -> Dataset: owner = user.check_workspace(ws_params.workspace) - dataset_class = TaskFactory.get_task_dataset(request.task) - task_mappings = TaskFactory.get_task_mappings(request.task) - + dataset_class = TasksFactory.get_task_dataset(request.task) dataset = dataset_class.parse_obj({**request.dict()}) dataset.owner = owner - response = datasets.create_dataset( - user=user, dataset=dataset, mappings=task_mappings - ) - + response = datasets.create_dataset(user=user, dataset=dataset) return Dataset.parse_obj(response) @@ -104,34 +85,15 @@ async def create_dataset( ) def get_dataset( name: str, - ds_params: CommonTaskQueryParams = Depends(), + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - """ - Find a dataset by name visible for current user - - Parameters - ---------- - name: - The dataset name - ds_params: - Common dataset query params - service: - Datasets service - current_user: - The current user - - Returns - ------- - - The found dataset if accessible or exists. - - EntityNotFoundError if not found. 
- - NotAuthorizedError if user cannot access the found dataset - - """ return Dataset.parse_obj( service.find_by_name( - user=current_user, name=name, workspace=ds_params.workspace + user=current_user, + name=name, + workspace=ds_params.workspace, ) ) @@ -144,35 +106,11 @@ def get_dataset( ) def update_dataset( name: str, - update_request: UpdateDatasetRequest, - ds_params: CommonTaskQueryParams = Depends(), + request: UpdateDatasetRequest, + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - """ - Update a set of parameters for a dataset - - Parameters - ---------- - name: - The dataset name - update_request: - The fields to update - ds_params: - Common dataset query params - service: - The datasets service - current_user: - The current user - - Returns - ------- - - - The updated dataset if exists and user has access. - - EntityNotFoundError if not found. 
- - NotAuthorizedError if user cannot access the found dataset - - """ found_ds = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace @@ -181,8 +119,8 @@ def update_dataset( return service.update( user=current_user, dataset=found_ds, - tags=update_request.tags, - metadata=update_request.metadata, + tags=request.tags, + metadata=request.metadata, ) @@ -192,25 +130,10 @@ def update_dataset( ) def delete_dataset( name: str, - ds_params: CommonTaskQueryParams = Depends(), + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - """ - Deletes a dataset - - Parameters - ---------- - name: - The dataset name - ds_params: - Common dataset query params - service: - The datasets service - current_user: - The current user - - """ try: found_ds = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace @@ -226,25 +149,10 @@ def delete_dataset( ) def close_dataset( name: str, - ds_params: CommonTaskQueryParams = Depends(), + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - """ - Closes a dataset. This operation will releases backend resources - - Parameters - ---------- - name: - The dataset name - ds_params: - Common dataset query params - service: - The datasets service - current_user: - The current user - - """ found_ds = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace ) @@ -257,25 +165,10 @@ def close_dataset( ) def open_dataset( name: str, - ds_params: CommonTaskQueryParams = Depends(), + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - """ - Closes a dataset. 
This operation will releases backend resources - - Parameters - ---------- - name: - The dataset name - ds_params: - Common dataset query params - service: - The datasets service - current_user: - The current user - - """ found_ds = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace ) @@ -291,27 +184,10 @@ def open_dataset( def copy_dataset( name: str, copy_request: CopyDatasetRequest, - ds_params: CommonTaskQueryParams = Depends(), + ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - """ - Creates a dataset copy and its tags/metadata info - - Parameters - ---------- - name: - The dataset name - copy_request: - The copy request data - ds_params: - Common dataset query params - service: - The datasets service - current_user: - The current user - - """ found = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace ) diff --git a/src/rubrix/server/apis/v0/handlers/info.py b/src/rubrix/server/apis/v0/handlers/info.py index 30fcfd8c28..69a7ddde75 100644 --- a/src/rubrix/server/apis/v0/handlers/info.py +++ b/src/rubrix/server/apis/v0/handlers/info.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from fastapi import APIRouter, Depends, Security +from fastapi import APIRouter, Depends -from rubrix.server.apis.v0.models.info import ApiInfo, ApiStatus -from rubrix.server.security import auth -from rubrix.server.services.info import ApiInfoService, create_info_service +from rubrix.server.services.info import ApiInfo, ApiInfoService, ApiStatus router = APIRouter(tags=["status"]) @@ -26,29 +24,13 @@ "/_status", operation_id="api_status", response_model=ApiStatus, - dependencies=[Security(auth.get_user, scopes=[])], ) def api_status( - service: ApiInfoService = Depends(create_info_service), + service: ApiInfoService = Depends(ApiInfoService.get_instance), ) -> ApiStatus: - """ - - Parameters - ---------- - service: - The Api info service - - Returns - ------- - - The detailed api status - - """ return service.api_status() @router.get("/_info", operation_id="api_info", response_model=ApiInfo) -def api_info( - service: ApiInfoService = Depends(create_info_service), -) -> ApiInfo: +def api_info(service: ApiInfoService = Depends(ApiInfoService.get_instance)) -> ApiInfo: return ApiInfo.parse_obj(service.api_status()) diff --git a/src/rubrix/server/apis/v0/handlers/metrics.py b/src/rubrix/server/apis/v0/handlers/metrics.py index 8f4040add8..3de73bafdc 100644 --- a/src/rubrix/server/apis/v0/handlers/metrics.py +++ b/src/rubrix/server/apis/v0/handlers/metrics.py @@ -19,13 +19,13 @@ from fastapi import APIRouter, Depends, Query, Security from pydantic import BaseModel, Field -from rubrix.server.apis.v0.config.tasks_factory import TaskConfig, TaskFactory from rubrix.server.apis.v0.handlers import ( text2text, text_classification, token_classification, ) -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams +from rubrix.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies +from rubrix.server.commons.config import TaskConfig, TasksFactory from rubrix.server.security import auth from rubrix.server.security.model import User 
from rubrix.server.services.datasets import DatasetsService @@ -33,7 +33,6 @@ class MetricInfo(BaseModel): - """Metric info data model for retrieve dataset metrics information""" id: str = Field(description="The metric id") name: str = Field(description="The metric name") @@ -44,19 +43,6 @@ class MetricInfo(BaseModel): @dataclass class MetricSummaryParams: - """ - For metrics summary calculation, common summary parameters. - - Attributes: - ----------- - - interval: - For histogram summaries, the bucket interval - - size: - For terminological metrics, the number of terms to retrieve - - """ interval: Optional[float] = Query( default=None, @@ -71,17 +57,8 @@ class MetricSummaryParams: def configure_metrics_endpoints(router: APIRouter, cfg: TaskConfig): - """ - Configures an api router with the dataset task metrics endpoints. - Parameters - ---------- - router: - The api router - cfg: - The task configuration model - - """ + # TODO(@frascuchon): Use new api endpoint (/datasets/{name}/{task}/... base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics" @router.get( @@ -91,41 +68,19 @@ def configure_metrics_endpoints(router: APIRouter, cfg: TaskConfig): ) def get_dataset_metrics( name: str, - teams_query: CommonTaskQueryParams = Depends(), + request_deps: CommonTaskHandlerDependencies = Depends(), current_user: User = Security(auth.get_user, scopes=[]), datasets: DatasetsService = Depends(DatasetsService.get_instance), - metrics: MetricsService = Depends(MetricsService.get_instance), ) -> List[MetricInfo]: - """ - List available metrics info for a given dataset - - Parameters - ---------- - name: - The dataset name - teams_query: - Team query param where dataset belongs to. 
Optional - current_user: - The current user - datasets: - The datasets service - metrics: - The metrics service - - Returns - ------- - A list of metric info availables for given dataset - - """ dataset = datasets.find_by_name( user=current_user, name=name, task=cfg.task, - workspace=teams_query.workspace, - as_dataset_class=TaskFactory.get_task_dataset(cfg.task), + workspace=request_deps.workspace, + as_dataset_class=TasksFactory.get_task_dataset(cfg.task), ) - metrics = TaskFactory.get_task_metrics(dataset.task) + metrics = TasksFactory.get_task_metrics(dataset.task) metrics = metrics.metrics if metrics else [] return [MetricInfo.parse_obj(metric) for metric in metrics] @@ -140,52 +95,25 @@ def metric_summary( metric: str, query: cfg.query, metric_params: MetricSummaryParams = Depends(), - teams_query: CommonTaskQueryParams = Depends(), + request_deps: CommonTaskHandlerDependencies = Depends(), current_user: User = Security(auth.get_user, scopes=[]), datasets: DatasetsService = Depends(DatasetsService.get_instance), metrics: MetricsService = Depends(MetricsService.get_instance), ): - """ - Summarizes a given metric for a given dataset. - - Parameters - ---------- - name: - The dataset name - metric: - The metric id - query: - A query for records filtering. Optional - metric_params: - Metric parameters for result calculation - teams_query: - Team query param where dataset belongs to. 
Optional - current_user: - The current user - datasets: - The datasets service - metrics: - The metrics service - - Returns - ------- - The metric summary for a given dataset - - """ dataset = datasets.find_by_name( user=current_user, name=name, task=cfg.task, - workspace=teams_query.workspace, - as_dataset_class=TaskFactory.get_task_dataset(cfg.task), + workspace=request_deps.workspace, + as_dataset_class=TasksFactory.get_task_dataset(cfg.task), ) - metric_ = TaskFactory.find_task_metric(task=cfg.task, metric_id=metric) - record_class = TaskFactory.get_task_record(cfg.task) + metric_ = TasksFactory.find_task_metric(task=cfg.task, metric_id=metric) + record_class = TasksFactory.get_task_record(cfg.task) return metrics.summarize_metric( dataset=dataset, - owner=current_user.check_workspace(teams_query.workspace), + owner=current_user.check_workspace(request_deps.workspace), metric=metric_, record_class=record_class, query=query, @@ -196,8 +124,7 @@ def metric_summary( router = APIRouter() for task_api in [text_classification, token_classification, text2text]: - cfg = TaskFactory.get_task_by_task_type(task_api.TASK_TYPE) + cfg = TasksFactory.get_task_by_task_type(task_api.TASK_TYPE) if cfg: configure_metrics_endpoints(task_api.router, cfg) - router.include_router(task_api.router) diff --git a/src/rubrix/server/apis/v0/handlers/text2text.py b/src/rubrix/server/apis/v0/handlers/text2text.py index a2edd5aaa7..29daec8e12 100644 --- a/src/rubrix/server/apis/v0/handlers/text2text.py +++ b/src/rubrix/server/apis/v0/handlers/text2text.py @@ -19,31 +19,47 @@ from fastapi import APIRouter, Depends, Query, Security from fastapi.responses import StreamingResponse -from rubrix.server.apis.v0.config.tasks_factory import TaskFactory -from rubrix.server.apis.v0.helpers import takeuntil -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - PaginationParams, - TaskType, +from rubrix.server.apis.v0.models.commons.model import BulkResponse +from 
rubrix.server.apis.v0.models.commons.params import ( + CommonTaskHandlerDependencies, + RequestPagination, ) -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams from rubrix.server.apis.v0.models.text2text import ( - Text2TextBulkData, + Text2TextBulkRequest, + Text2TextDataset, + Text2TextMetrics, Text2TextQuery, Text2TextRecord, + Text2TextSearchAggregations, Text2TextSearchRequest, Text2TextSearchResults, ) +from rubrix.server.commons.config import TasksFactory +from rubrix.server.commons.models import TaskType from rubrix.server.errors import EntityNotFoundError +from rubrix.server.helpers import takeuntil from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService -from rubrix.server.services.text2text import Text2TextService, text2text_service +from rubrix.server.services.tasks.text2text import Text2TextService +from rubrix.server.services.tasks.text2text.models import ( + ServiceText2TextQuery, + ServiceText2TextRecord, +) TASK_TYPE = TaskType.text2text BASE_ENDPOINT = "/{name}/" + TASK_TYPE +TasksFactory.register_task( + task_type=TaskType.text2text, + dataset_class=Text2TextDataset, + query_request=Text2TextQuery, + record_class=ServiceText2TextRecord, + metrics=Text2TextMetrics, +) + + router = APIRouter(tags=[TASK_TYPE], prefix="/datasets") @@ -55,37 +71,14 @@ ) def bulk_records( name: str, - bulk: Text2TextBulkData, - common_params: CommonTaskQueryParams = Depends(), - service: Text2TextService = Depends(text2text_service), + bulk: Text2TextBulkRequest, + common_params: CommonTaskHandlerDependencies = Depends(), + service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - """ - Includes a chunk of record data with 
provided dataset bulk information - - Parameters - ---------- - name: - The dataset name - bulk: - The bulk data - common_params: - Common task query params - service: - the Service - datasets: - The dataset service - current_user: - Current request user - - Returns - ------- - Bulk response data - """ task = TASK_TYPE - task_mappings = TaskFactory.get_task_mappings(TASK_TYPE) owner = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( @@ -93,7 +86,7 @@ def bulk_records( name=name, task=task, workspace=owner, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) datasets.update( user=current_user, @@ -102,19 +95,14 @@ def bulk_records( metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TaskFactory.get_task_dataset(task) + dataset_class = TasksFactory.get_task_dataset(task) dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) dataset.owner = owner - - datasets.create_dataset( - user=current_user, dataset=dataset, mappings=task_mappings - ) + datasets.create_dataset(user=current_user, dataset=dataset) result = service.add_records( dataset=dataset, - mappings=task_mappings, - records=bulk.records, - metrics=TaskFactory.get_task_metrics(TASK_TYPE), + records=[ServiceText2TextRecord.parse_obj(r) for r in bulk.records], ) return BulkResponse( dataset=name, @@ -132,42 +120,15 @@ def bulk_records( def search_records( name: str, search: Text2TextSearchRequest = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), include_metrics: bool = Query( False, description="If enabled, return related record metrics" ), - pagination: PaginationParams = Depends(), - service: Text2TextService = Depends(text2text_service), + pagination: RequestPagination = Depends(), + service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = 
Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Text2TextSearchResults: - """ - Searches data from dataset - - Parameters - ---------- - name: - The dataset name - common_params: - The task common query params - include_metrics: - Flag to include metrics in results - search: - THe search query request - pagination: - The pagination params - service: - The dataset records service - datasets: - The dataset service - current_user: - The current request user - - Returns - ------- - The search results data - - """ search = search or Text2TextSearchRequest() query = search.query or Text2TextQuery() @@ -176,29 +137,24 @@ def search_records( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) result = service.search( dataset=dataset, - query=query, + query=ServiceText2TextQuery.parse_obj(query), sort_by=search.sort, record_from=pagination.from_, size=pagination.limit, exclude_metrics=not include_metrics, - metrics=TaskFactory.find_task_metrics( - TASK_TYPE, - metric_ids={ - "words_cloud", - "predicted_by", - "annotated_by", - "status_distribution", - "metadata", - "score", - }, - ), ) - return result + return Text2TextSearchResults( + total=result.total, + records=[Text2TextRecord.parse_obj(r) for r in result.records], + aggregations=Text2TextSearchAggregations.parse_obj(result.metrics) + if result.metrics + else None, + ) def scan_data_response( @@ -241,15 +197,15 @@ def grouper(n, iterable, fillvalue=None): async def stream_data( name: str, query: Optional[Text2TextQuery] = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), - service: Text2TextService = Depends(text2text_service), + service: Text2TextService = 
Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), id_from: Optional[str] = None ) -> StreamingResponse: """ - Creates a data stream over dataset records + Creates a data stream over dataset records Parameters ---------- @@ -277,10 +233,12 @@ async def stream_data( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), + ) + data_stream = map( + Text2TextRecord.parse_obj, + service.read_dataset(dataset, query=ServiceText2TextQuery.parse_obj(query), id_from=id_from, limit=limit), ) - data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit) - return scan_data_response( data_stream=data_stream, limit=limit, diff --git a/src/rubrix/server/apis/v0/handlers/text_classification.py b/src/rubrix/server/apis/v0/handlers/text_classification.py index 976646c1a3..35b0a134ab 100644 --- a/src/rubrix/server/apis/v0/handlers/text_classification.py +++ b/src/rubrix/server/apis/v0/handlers/text_classification.py @@ -19,39 +19,64 @@ from fastapi import APIRouter, Depends, Query, Security from fastapi.responses import StreamingResponse -from rubrix.server.apis.v0.config.tasks_factory import TaskFactory from rubrix.server.apis.v0.handlers import text_classification_dataset_settings -from rubrix.server.apis.v0.helpers import takeuntil -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - PaginationParams, - TaskType, +from rubrix.server.apis.v0.models.commons.model import BulkResponse +from rubrix.server.apis.v0.models.commons.params import ( + CommonTaskHandlerDependencies, + RequestPagination, ) -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams from rubrix.server.apis.v0.models.text_classification import ( CreateLabelingRule, DatasetLabelingRulesMetricsSummary, 
LabelingRule, LabelingRuleMetricsSummary, - TextClassificationBulkData, + TextClassificationBulkRequest, TextClassificationQuery, TextClassificationRecord, + TextClassificationSearchAggregations, TextClassificationSearchRequest, TextClassificationSearchResults, UpdateLabelingRule, ) +from rubrix.server.apis.v0.models.token_classification import ( + TokenClassificationDataset, + TokenClassificationQuery, +) from rubrix.server.apis.v0.validators.text_classification import DatasetValidator +from rubrix.server.commons.config import TasksFactory +from rubrix.server.commons.models import TaskType from rubrix.server.errors import EntityNotFoundError +from rubrix.server.helpers import takeuntil from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService -from rubrix.server.services.text_classification import TextClassificationService +from rubrix.server.services.tasks.text_classification import TextClassificationService +from rubrix.server.services.tasks.text_classification.model import ( + ServiceLabelingRule, + ServiceTextClassificationQuery, + ServiceTextClassificationRecord, +) +from rubrix.server.services.tasks.token_classification.metrics import ( + TokenClassificationMetrics, +) +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationRecord, +) TASK_TYPE = TaskType.text_classification BASE_ENDPOINT = "/{name}/" + TASK_TYPE NEW_BASE_ENDPOINT = f"/{TASK_TYPE}/{{name}}" +TasksFactory.register_task( + task_type=TaskType.token_classification, + dataset_class=TokenClassificationDataset, + query_request=TokenClassificationQuery, + record_class=ServiceTokenClassificationRecord, + metrics=TokenClassificationMetrics, +) + + router = APIRouter(tags=[TASK_TYPE], prefix="/datasets") @@ -63,8 +88,8 @@ ) async def bulk_records( name: str, - bulk: TextClassificationBulkData, - 
common_params: CommonTaskQueryParams = Depends(), + bulk: TextClassificationBulkRequest, + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -72,33 +97,8 @@ async def bulk_records( validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - """ - Includes a chunk of record data with provided dataset bulk information - - Parameters - ---------- - name: - The dataset name - bulk: - The bulk data - common_params: - Common query params - service: - the Service - datasets: - The dataset service - validator: - The dataset validator component - current_user: - Current request user - - Returns - ------- - Bulk response data - """ task = TASK_TYPE - task_mappings = TaskFactory.get_task_mappings(TASK_TYPE) owner = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( @@ -106,7 +106,7 @@ async def bulk_records( name=name, task=task, workspace=owner, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) datasets.update( user=current_user, @@ -115,22 +115,20 @@ async def bulk_records( metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TaskFactory.get_task_dataset(task) + dataset_class = TasksFactory.get_task_dataset(task) dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) dataset.owner = owner - datasets.create_dataset( - user=current_user, dataset=dataset, mappings=task_mappings - ) + datasets.create_dataset(user=current_user, dataset=dataset) + # TODO(@frascuchon): Validator should be applied in the service layer + records = [ServiceTextClassificationRecord.parse_obj(r) for r in bulk.records] await validator.validate_dataset_records( - user=current_user, dataset=dataset, records=bulk.records + user=current_user, dataset=dataset, records=records ) result 
= service.add_records( dataset=dataset, - mappings=task_mappings, - records=bulk.records, - metrics=TaskFactory.get_task_metrics(TASK_TYPE), + records=records, ) return BulkResponse( dataset=name, @@ -148,11 +146,11 @@ async def bulk_records( def search_records( name: str, search: TextClassificationSearchRequest = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), include_metrics: bool = Query( False, description="If enabled, return related record metrics" ), - pagination: PaginationParams = Depends(), + pagination: RequestPagination = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -195,32 +193,24 @@ def search_records( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) result = service.search( dataset=dataset, - query=query, + query=ServiceTextClassificationQuery.parse_obj(query), sort_by=search.sort, record_from=pagination.from_, size=pagination.limit, exclude_metrics=not include_metrics, - metrics=TaskFactory.find_task_metrics( - TASK_TYPE, - metric_ids={ - "words_cloud", - "predicted_by", - "predicted_as", - "annotated_by", - "annotated_as", - "error_distribution", - "status_distribution", - "metadata", - "score", - }, - ), ) - return result + return TextClassificationSearchResults( + total=result.total, + records=result.records, + aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics) + if result.metrics + else None, + ) def scan_data_response( @@ -263,7 +253,7 @@ def grouper(n, iterable, fillvalue=None): async def stream_data( name: str, query: Optional[TextClassificationQuery] = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), id_from: Optional[str] = None, limit: Optional[int] = Query(None, description="Limit loaded 
records", gt=0), service: TextClassificationService = Depends( @@ -273,7 +263,7 @@ async def stream_data( current_user: User = Security(auth.get_user, scopes=[]), ) -> StreamingResponse: """ - Creates a data stream over dataset records + Creates a data stream over dataset records Parameters ---------- @@ -302,10 +292,15 @@ async def stream_data( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit) + data_stream = map( + TextClassificationRecord.parse_obj, + service.read_dataset( + dataset, query=ServiceTextClassificationQuery.parse_obj(query), id_from=id_from, limit=limit + ), + ) return scan_data_response( data_stream=data_stream, limit=limit, @@ -321,7 +316,7 @@ async def stream_data( ) async def list_labeling_rules( name: str, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), service: TextClassificationService = Depends( TextClassificationService.get_instance @@ -334,10 +329,12 @@ async def list_labeling_rules( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - return list(service.get_labeling_rules(dataset)) + return [ + LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset) + ] @router.post( @@ -350,7 +347,7 @@ async def list_labeling_rules( async def create_rule( name: str, rule: CreateLabelingRule, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -363,10 +360,10 @@ async def create_rule( name=name, 
task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - rule = LabelingRule( + rule = ServiceLabelingRule( **rule.dict(), author=current_user.username, ) @@ -374,8 +371,7 @@ async def create_rule( dataset, rule=rule, ) - - return rule + return LabelingRule.parse_obj(rule) @router.get( @@ -391,19 +387,20 @@ async def compute_rule_metrics( labels: Optional[List[str]] = Query( None, description="Label related to query rule", alias="label" ), - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRuleMetricsSummary: + dataset = datasets.find_by_name( user=current_user, name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) return service.compute_rule_metrics(dataset, rule_query=query, labels=labels) @@ -418,7 +415,7 @@ async def compute_rule_metrics( ) async def compute_dataset_rules_metrics( name: str, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -430,10 +427,10 @@ async def compute_dataset_rules_metrics( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - - return service.compute_overall_rules_metrics(dataset) + metrics = service.compute_overall_rules_metrics(dataset) + return DatasetLabelingRulesMetricsSummary.parse_obj(metrics) @router.delete( @@ -444,7 
+441,7 @@ async def compute_dataset_rules_metrics( async def delete_labeling_rule( name: str, query: str, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -457,7 +454,7 @@ async def delete_labeling_rule( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) service.delete_labeling_rule(dataset, rule_query=query) @@ -473,7 +470,7 @@ async def delete_labeling_rule( async def get_rule( name: str, query: str, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -486,14 +483,13 @@ async def get_rule( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - rule = service.find_labeling_rule( dataset, rule_query=query, ) - return rule + return LabelingRule.parse_obj(rule) @router.patch( @@ -507,7 +503,7 @@ async def update_rule( name: str, query: str, update: UpdateLabelingRule, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends( TextClassificationService.get_instance ), @@ -520,7 +516,7 @@ async def update_rule( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) rule = service.update_labeling_rule( @@ -529,7 +525,7 @@ async def update_rule( labels=update.labels, description=update.description, ) - return rule + return LabelingRule.parse_obj(rule) 
text_classification_dataset_settings.configure_router(router) diff --git a/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py index 88d0f6323a..a2d0b840f8 100644 --- a/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -2,18 +2,20 @@ from fastapi import APIRouter, Body, Depends, Security -from rubrix.server.apis.v0.models.commons.model import TaskType -from rubrix.server.apis.v0.models.commons.params import DATASET_NAME_PATH_PARAM -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams +from rubrix.server.apis.v0.models.commons.params import ( + DATASET_NAME_PATH_PARAM, + CommonTaskHandlerDependencies, +) from rubrix.server.apis.v0.models.dataset_settings import TextClassificationSettings from rubrix.server.apis.v0.validators.text_classification import DatasetValidator +from rubrix.server.commons.models import TaskType from rubrix.server.security import auth from rubrix.server.security.model import User -from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings +from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings -__svc_settings_class__: Type[SVCDatasetSettings] = type( +__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type( f"{TaskType.text_classification}_DatasetSettings", - (SVCDatasetSettings, TextClassificationSettings), + (ServiceBaseDatasetSettings, TextClassificationSettings), {}, ) @@ -33,7 +35,7 @@ def configure_router(router: APIRouter): ) async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), ) -> 
TextClassificationSettings: @@ -63,7 +65,7 @@ async def save_settings( ..., description=f"The {task} dataset settings" ), name: str = DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), @@ -93,7 +95,7 @@ async def save_settings( ) async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), ) -> None: diff --git a/src/rubrix/server/apis/v0/handlers/token_classification.py b/src/rubrix/server/apis/v0/handlers/token_classification.py index ed56beff43..363b638c56 100644 --- a/src/rubrix/server/apis/v0/handlers/token_classification.py +++ b/src/rubrix/server/apis/v0/handlers/token_classification.py @@ -19,36 +19,57 @@ from fastapi import APIRouter, Depends, Query, Security from fastapi.responses import StreamingResponse -from rubrix.server.apis.v0.config.tasks_factory import TaskFactory from rubrix.server.apis.v0.handlers import token_classification_dataset_settings -from rubrix.server.apis.v0.helpers import takeuntil -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - PaginationParams, - TaskType, +from rubrix.server.apis.v0.models.commons.model import BulkResponse +from rubrix.server.apis.v0.models.commons.params import ( + CommonTaskHandlerDependencies, + RequestPagination, +) +from rubrix.server.apis.v0.models.text_classification import ( + TextClassificationDataset, + TextClassificationQuery, ) -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams from 
rubrix.server.apis.v0.models.token_classification import ( - TokenClassificationBulkData, + TokenClassificationAggregations, + TokenClassificationBulkRequest, TokenClassificationQuery, TokenClassificationRecord, TokenClassificationSearchRequest, TokenClassificationSearchResults, ) from rubrix.server.apis.v0.validators.token_classification import DatasetValidator +from rubrix.server.commons.config import TasksFactory +from rubrix.server.commons.models import TaskType from rubrix.server.errors import EntityNotFoundError +from rubrix.server.helpers import takeuntil from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService -from rubrix.server.services.token_classification import ( - TokenClassificationService, - token_classification_service, +from rubrix.server.services.tasks.text_classification.metrics import ( + TextClassificationMetrics, +) +from rubrix.server.services.tasks.text_classification.model import ( + ServiceTextClassificationRecord, +) +from rubrix.server.services.tasks.token_classification import TokenClassificationService +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationQuery, + ServiceTokenClassificationRecord, ) TASK_TYPE = TaskType.token_classification BASE_ENDPOINT = "/{name}/" + TASK_TYPE +TasksFactory.register_task( + task_type=TaskType.text_classification, + dataset_class=TextClassificationDataset, + query_request=TextClassificationQuery, + record_class=ServiceTextClassificationRecord, + metrics=TextClassificationMetrics, +) + + router = APIRouter(tags=[TASK_TYPE], prefix="/datasets") @@ -60,38 +81,17 @@ ) async def bulk_records( name: str, - bulk: TokenClassificationBulkData, - common_params: CommonTaskQueryParams = Depends(), - service: TokenClassificationService = Depends(token_classification_service), + bulk: 
TokenClassificationBulkRequest, + common_params: CommonTaskHandlerDependencies = Depends(), + service: TokenClassificationService = Depends( + TokenClassificationService.get_instance + ), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - """ - Includes a chunk of record data with provided dataset bulk information - - Parameters - ---------- - name: - The dataset name - bulk: - The bulk data - common_params: - Common query params - service: - the Service - datasets: - The dataset service - current_user: - Current request user - - Returns - ------- - Bulk response data - """ task = TASK_TYPE - task_mappings = TaskFactory.get_task_mappings(TASK_TYPE) owner = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( @@ -99,7 +99,7 @@ async def bulk_records( name=name, task=task, workspace=owner, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) datasets.update( user=current_user, @@ -108,24 +108,22 @@ async def bulk_records( metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TaskFactory.get_task_dataset(task) + dataset_class = TasksFactory.get_task_dataset(task) dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) dataset.owner = owner - datasets.create_dataset( - user=current_user, dataset=dataset, mappings=task_mappings - ) + datasets.create_dataset(user=current_user, dataset=dataset) + records = [ServiceTokenClassificationRecord.parse_obj(r) for r in bulk.records] + # TODO(@frascuchon): validator can be applied in service layer await validator.validate_dataset_records( user=current_user, dataset=dataset, - records=bulk.records, + records=records, ) result = service.add_records( dataset=dataset, - mappings=task_mappings, - records=bulk.records, - 
metrics=TaskFactory.get_task_metrics(TASK_TYPE), + records=records, ) return BulkResponse( dataset=name, @@ -143,42 +141,17 @@ async def bulk_records( def search_records( name: str, search: TokenClassificationSearchRequest = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), include_metrics: bool = Query( False, description="If enabled, return related record metrics" ), - pagination: PaginationParams = Depends(), - service: TokenClassificationService = Depends(token_classification_service), + pagination: RequestPagination = Depends(), + service: TokenClassificationService = Depends( + TokenClassificationService.get_instance + ), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> TokenClassificationSearchResults: - """ - Searches data from dataset - - Parameters - ---------- - name: - The dataset name - search: - THe search query request - common_params: - Common query params - include_metrics: - include metrics flag - pagination: - The pagination params - service: - The dataset records service - datasets: - The dataset service - current_user: - The current request user - - Returns - ------- - The search results data - - """ search = search or TokenClassificationSearchRequest() query = search.query or TokenClassificationQuery() @@ -188,34 +161,24 @@ def search_records( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), ) - result = service.search( + results = service.search( dataset=dataset, - query=query, + query=ServiceTokenClassificationQuery.parse_obj(query), sort_by=search.sort, record_from=pagination.from_, size=pagination.limit, exclude_metrics=not include_metrics, - metrics=TaskFactory.find_task_metrics( - TASK_TYPE, - metric_ids={ - "words_cloud", - "predicted_by", - "predicted_as", - 
"annotated_by", - "annotated_as", - "error_distribution", - "predicted_mentions_distribution", - "annotated_mentions_distribution", - "status_distribution", - "metadata", - "score", - }, - ), ) - return result + return TokenClassificationSearchResults( + total=results.total, + records=[TokenClassificationRecord.parse_obj(r) for r in results.records], + aggregations=TokenClassificationAggregations.parse_obj(results.metrics) + if results.metrics + else None, + ) def scan_data_response( @@ -258,15 +221,17 @@ def grouper(n, iterable, fillvalue=None): async def stream_data( name: str, query: Optional[TokenClassificationQuery] = None, - common_params: CommonTaskQueryParams = Depends(), + common_params: CommonTaskHandlerDependencies = Depends(), limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), - service: TokenClassificationService = Depends(token_classification_service), + service: TokenClassificationService = Depends( + TokenClassificationService.get_instance + ), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), id_from: Optional[str] = None, ) -> StreamingResponse: """ - Creates a data stream over dataset records + Creates a data stream over dataset records Parameters ---------- @@ -294,9 +259,14 @@ async def stream_data( name=name, task=TASK_TYPE, workspace=common_params.workspace, - as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE), + as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE), + ) + data_stream = map( + TokenClassificationRecord.parse_obj, + service.read_dataset( + dataset=dataset, query=ServiceTokenClassificationQuery.parse_obj(query), id_from=id_from, limit=limit + ), ) - data_stream = service.read_dataset(dataset=dataset, query=query, id_from=id_from, limit=limit) return scan_data_response( data_stream=data_stream, diff --git a/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py 
b/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py index 955f1aa41a..14cfb7d572 100644 --- a/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -2,18 +2,20 @@ from fastapi import APIRouter, Body, Depends, Security -from rubrix.server.apis.v0.models.commons.model import TaskType -from rubrix.server.apis.v0.models.commons.params import DATASET_NAME_PATH_PARAM -from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams +from rubrix.server.apis.v0.models.commons.params import ( + DATASET_NAME_PATH_PARAM, + CommonTaskHandlerDependencies, +) from rubrix.server.apis.v0.models.dataset_settings import TokenClassificationSettings from rubrix.server.apis.v0.validators.token_classification import DatasetValidator +from rubrix.server.commons.models import TaskType from rubrix.server.security import auth from rubrix.server.security.model import User -from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings +from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings -__svc_settings_class__: Type[SVCDatasetSettings] = type( +__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type( f"{TaskType.token_classification}_DatasetSettings", - (SVCDatasetSettings, TokenClassificationSettings), + (ServiceBaseDatasetSettings, TokenClassificationSettings), {}, ) @@ -32,7 +34,7 @@ def configure_router(router: APIRouter): ) async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), ) -> TokenClassificationSettings: @@ -62,7 +64,7 @@ async def save_settings( ..., description=f"The {task} dataset settings" ), name: str = 
DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), @@ -92,7 +94,7 @@ async def save_settings( ) async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, - ws_params: CommonTaskQueryParams = Depends(), + ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), ) -> None: diff --git a/src/rubrix/server/apis/v0/models/commons/model.py b/src/rubrix/server/apis/v0/models/commons/model.py index d3ce747ae8..5eb98a3a80 100644 --- a/src/rubrix/server/apis/v0/models/commons/model.py +++ b/src/rubrix/server/apis/v0/models/commons/model.py @@ -17,98 +17,59 @@ Common model for task definitions """ -from dataclasses import dataclass -from typing import Any, Dict, Generic, TypeVar +from typing import Any, Dict, Generic, List, TypeVar -from fastapi import Query -from pydantic import validator +from pydantic import BaseModel, Field from pydantic.generics import GenericModel -from rubrix._constants import MAX_KEYWORD_LENGTH -from rubrix.server.apis.v0.helpers import flatten_dict from rubrix.server.services.search.model import ( - BaseSearchResults, - BaseSearchResultsAggregations, - QueryRange, - SortableField, + ServiceQueryRange, + ServiceSearchResultsAggregations, + ServiceSortableField, ) from rubrix.server.services.tasks.commons import ( - Annotation, - BaseAnnotation, - BaseRecordDB, - BulkResponse, - EsRecordDataFieldNames, - PredictionStatus, - TaskStatus, - TaskType, + ServiceBaseAnnotation, + ServiceBaseRecord, + ServiceBaseRecordInputs, ) -from rubrix.utils import limit_value_length -@dataclass -class PaginationParams: - """Query pagination 
params""" +class SortableField(ServiceSortableField): + pass - limit: int = Query(50, gte=0, le=1000, description="Response records limit") - from_: int = Query( - 0, ge=0, le=10000, alias="from", description="Record sequence from" - ) +class BaseAnnotation(ServiceBaseAnnotation): + pass -class BaseRecord(BaseRecordDB, GenericModel, Generic[Annotation]): - """ - Minimal dataset record information - Attributes: - ----------- +Annotation = TypeVar("Annotation", bound=BaseAnnotation) - id: - The record id - metadata: - The metadata related to record - event_timestamp: - The timestamp when record event was triggered - """ +class BaseRecordInputs(ServiceBaseRecordInputs[Annotation], Generic[Annotation]): + def extended_fields(self) -> Dict[str, Any]: + return {} - @validator("metadata", pre=True) - def flatten_metadata(cls, metadata: Dict[str, Any]): - """ - A fastapi validator for flatten metadata dictionary - Parameters - ---------- - metadata: - The metadata dictionary +class BaseRecord(ServiceBaseRecord[Annotation], Generic[Annotation]): + pass - Returns - ------- - A flatten version of metadata dictionary - """ - if metadata: - metadata = flatten_dict(metadata, drop_empty=True) - metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH) - return metadata +class ScoreRange(ServiceQueryRange): + pass -Record = TypeVar("Record", bound=BaseRecord) +_Record = TypeVar("_Record", bound=BaseRecord) -class ScoreRange(QueryRange): - pass +class BulkResponse(BaseModel): + dataset: str + processed: int + failed: int = 0 -__ALL__ = [ - QueryRange, - SortableField, - BaseSearchResults, - BaseSearchResultsAggregations, - Annotation, - TaskStatus, - TaskType, - EsRecordDataFieldNames, - BaseAnnotation, - PredictionStatus, - BulkResponse, -] +class BaseSearchResults( + GenericModel, Generic[_Record, ServiceSearchResultsAggregations] +): + total: int = 0 + records: List[_Record] = Field(default_factory=list) + aggregations: ServiceSearchResultsAggregations = None diff 
--git a/src/rubrix/server/apis/v0/models/commons/params.py b/src/rubrix/server/apis/v0/models/commons/params.py index b224d92fa5..7700e8f134 100644 --- a/src/rubrix/server/apis/v0/models/commons/params.py +++ b/src/rubrix/server/apis/v0/models/commons/params.py @@ -1,7 +1,46 @@ -from fastapi import Path +from dataclasses import dataclass -from rubrix._constants import DATASET_NAME_REGEX_PATTERN +from fastapi import Header, Path, Query + +from rubrix._constants import DATASET_NAME_REGEX_PATTERN, RUBRIX_WORKSPACE_HEADER_NAME +from rubrix.server.security.model import WORKSPACE_NAME_PATTERN DATASET_NAME_PATH_PARAM = Path( ..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name" ) + + +@dataclass +class RequestPagination: + """Query pagination params""" + + limit: int = Query(50, gte=0, le=1000, description="Response records limit") + from_: int = Query( + 0, ge=0, le=10000, alias="from", description="Record sequence from" + ) + + +@dataclass +class CommonTaskHandlerDependencies: + """Common task query dependencies""" + + # TODO(@frascuchon): we could include the request user and parametrize the action scopes + # Depends(CommonTaskHandlerDependencies.create(scopes=[...]) + + __workspace_header__: str = Header(None, alias=RUBRIX_WORKSPACE_HEADER_NAME) + __workspace_param__: str = Query( + None, + alias="workspace", + description="The workspace where dataset belongs to. If not provided default user team will be used", + ) + + @property + def workspace(self) -> str: + """Return read workspace. Query param prior to header param""" + workspace = self.__workspace_param__ or self.__workspace_header__ + if workspace: + assert WORKSPACE_NAME_PATTERN.match(workspace), ( + "Wrong workspace format. 
" + f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" + ) + return workspace diff --git a/src/rubrix/server/apis/v0/models/commons/workspace.py b/src/rubrix/server/apis/v0/models/commons/workspace.py deleted file mode 100644 index d2abf609f6..0000000000 --- a/src/rubrix/server/apis/v0/models/commons/workspace.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass - -from fastapi import Header, Query - -from rubrix._constants import RUBRIX_WORKSPACE_HEADER_NAME -from rubrix.server.security.model import WORKSPACE_NAME_PATTERN - - -@dataclass -class CommonTaskQueryParams: - """Common task query params""" - - __workspace_param__: str = Query( - None, - alias="workspace", - description="The workspace where dataset belongs to. If not provided default user team will be used", - ) - - __workspace_header__: str = Header(None, alias=RUBRIX_WORKSPACE_HEADER_NAME) - - @property - def workspace(self) -> str: - """Return read workspace. Query param prior to header param""" - workspace = self.__workspace_param__ or self.__workspace_header__ - if workspace: - assert WORKSPACE_NAME_PATTERN.match(workspace), ( - "Wrong workspace format. 
" - f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" - ) - return workspace diff --git a/src/rubrix/server/apis/v0/models/datasets.py b/src/rubrix/server/apis/v0/models/datasets.py index f9f5586e4d..427210059f 100644 --- a/src/rubrix/server/apis/v0/models/datasets.py +++ b/src/rubrix/server/apis/v0/models/datasets.py @@ -17,14 +17,13 @@ Dataset models definition """ -from datetime import datetime from typing import Any, Dict, Optional from pydantic import BaseModel, Field from rubrix._constants import DATASET_NAME_REGEX_PATTERN -from rubrix.server.apis.v0.models.commons.model import TaskType -from rubrix.server.services.datasets import DatasetDB as SVCDataset +from rubrix.server.commons.models import TaskType +from rubrix.server.services.datasets import ServiceBaseDataset class UpdateDatasetRequest(BaseModel): @@ -43,15 +42,15 @@ class UpdateDatasetRequest(BaseModel): metadata: Dict[str, Any] = Field(default_factory=dict) -class CreationDatasetRequest(UpdateDatasetRequest): +class _BaseDatasetRequest(UpdateDatasetRequest): name: str = Field(regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") -class DatasetCreate(CreationDatasetRequest): +class CreateDatasetRequest(_BaseDatasetRequest): task: TaskType = Field(description="The dataset task") -class CopyDatasetRequest(CreationDatasetRequest): +class CopyDatasetRequest(_BaseDatasetRequest): """ Request body for copy dataset operation """ @@ -59,7 +58,7 @@ class CopyDatasetRequest(CreationDatasetRequest): target_workspace: Optional[str] = None -class BaseDatasetDB(CreationDatasetRequest, SVCDataset): +class Dataset(_BaseDatasetRequest, ServiceBaseDataset): """ Low level dataset data model @@ -76,13 +75,3 @@ class BaseDatasetDB(CreationDatasetRequest, SVCDataset): """ task: TaskType - - -class DatasetDB(BaseDatasetDB): - pass - - -class Dataset(BaseDatasetDB): - """Dataset used for response output""" - - pass diff --git a/src/rubrix/server/apis/v0/models/info.py 
b/src/rubrix/server/apis/v0/models/info.py deleted file mode 100644 index fb3a7485f9..0000000000 --- a/src/rubrix/server/apis/v0/models/info.py +++ /dev/null @@ -1,31 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict - -from pydantic import BaseModel - - -class ApiInfo(BaseModel): - """Basic api info""" - - rubrix_version: str - - -class ApiStatus(ApiInfo): - """The Rubrix api status model""" - - elasticsearch: Dict[str, Any] - mem_info: Dict[str, Any] diff --git a/src/rubrix/server/apis/v0/models/metrics/base.py b/src/rubrix/server/apis/v0/models/metrics/base.py deleted file mode 100644 index 8bae2ab238..0000000000 --- a/src/rubrix/server/apis/v0/models/metrics/base.py +++ /dev/null @@ -1,304 +0,0 @@ -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Iterable, - List, - Optional, - TypeVar, - Union, -) - -from pydantic import BaseModel, root_validator - -from rubrix.server._helpers import unflatten_dict -from rubrix.server.apis.v0.models.commons.model import BaseRecord -from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import aggregations - -GenericRecord = TypeVar("GenericRecord", bound=BaseRecord) - - -class BaseMetric(BaseModel): - """ - Base model for rubrix dataset metrics summaries - """ - - id: str - name: str - description: str = None 
- - -class PythonMetric(BaseMetric, Generic[GenericRecord]): - """ - A metric definition which will be calculated using raw queried data - """ - - def apply(self, records: Iterable[GenericRecord]) -> Dict[str, Any]: - """ - Metric calculation method. - - Parameters - ---------- - records: - The matched records - - Returns - ------- - The metric result - """ - raise NotImplementedError() - - -class ElasticsearchMetric(BaseMetric): - """ - A metric summarized by using one or several elasticsearch aggregations - """ - - def aggregation_request( - self, *args, **kwargs - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - """ - Configures the summary es aggregation definition - """ - raise NotImplementedError() - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - """ - Parse the es aggregation result. Override this method - for result customization - - Parameters - ---------- - aggregation_result: - Retrieved es aggregation result - - """ - return aggregation_result.get(self.id, aggregation_result) - - -class NestedPathElasticsearchMetric(ElasticsearchMetric): - """ - A ``ElasticsearchMetric`` which need nested fields for summary calculation. - - Aggregations for nested fields need some extra configuration and this class - encapsulate these common logic. 
- - Attributes: - ----------- - nested_path: - The nested - """ - - nested_path: str - - def inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]: - """The specific aggregation definition""" - raise NotImplementedError() - - def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]: - """Implements the common mechanism to define aggregations with nested fields""" - return { - self.id: aggregations.nested_aggregation( - nested_path=self.nested_path, - inner_aggregation=self.inner_aggregation(*args, **kwargs), - ) - } - - def compound_nested_field(self, inner_field: str) -> str: - return f"{self.nested_path}.{inner_field}" - - -class BaseTaskMetrics(BaseModel): - """ - Base class encapsulating related task metrics - - Attributes: - ----------- - - metrics: - A list of configured metrics for task - """ - - metrics: ClassVar[List[BaseMetric]] - - @classmethod - def configure_es_index(cls): - """ - If some metrics require specific es field mapping definitions, - include them here. - - """ - pass - - @classmethod - def find_metric(cls, id: str) -> Optional[BaseMetric]: - """ - Finds a metric by id - - Parameters - ---------- - id: - The metric id - - Returns - ------- - Found metric if any, ``None`` otherwise - - """ - for metric in cls.metrics: - if metric.id == id: - return metric - - @classmethod - def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]: - """ - Use this method is some configured metric requires additional - records fields. - - Generated records will be persisted under ``metrics`` record path. - For example, if you define a field called ``sentence_length`` like - - >>> def record_metrics(cls, record)-> Dict[str, Any]: - ... 
return { "sentence_length" : len(record.text) } - - The new field will be stored in elasticsearch in ``metrics.sentence_length`` - - Parameters - ---------- - record: - The record used for calculate metrics fields - - Returns - ------- - A dict with calculated metrics fields - """ - return {} - - -class HistogramAggregation(ElasticsearchMetric): - """ - Base elasticsearch histogram aggregation metric - - Attributes - ---------- - field: - The histogram field - script: - If provided, it will be used as scripted field - for aggregation - fixed_interval: - If provided, it will used ALWAYS as the histogram - aggregation interval - """ - - field: str - script: Optional[Union[str, Dict[str, Any]]] = None - fixed_interval: Optional[float] = None - - def aggregation_request(self, interval: Optional[float] = None) -> Dict[str, Any]: - if self.fixed_interval: - interval = self.fixed_interval - return { - self.id: aggregations.histogram_aggregation( - field_name=self.field, script=self.script, interval=interval - ) - } - - -class TermsAggregation(ElasticsearchMetric): - """ - The base elasticsearch terms aggregation metric - - Attributes - ---------- - - field: - The term field - script: - If provided, it will be used as scripted field - for aggregation - fixed_size: - If provided, the size will use for terms aggregation - missing: - If provided, will use the value for docs results with missing value for field - - """ - - field: str = None - script: Union[str, Dict[str, Any]] = None - fixed_size: Optional[int] = None - missing: Optional[str] = None - - def aggregation_request(self, size: int = None) -> Dict[str, Any]: - if self.fixed_size: - size = self.fixed_size - return { - self.id: aggregations.terms_aggregation( - self.field, script=self.script, size=size, missing=self.missing - ) - } - - -class NestedTermsAggregation(NestedPathElasticsearchMetric): - terms: TermsAggregation - - @root_validator - def normalize_terms_field(cls, values): - terms = values["terms"] - 
nested_path = values["nested_path"] - terms.field = f"{nested_path}.{terms.field}" - - return values - - def inner_aggregation(self, size: int) -> Dict[str, Any]: - return self.terms.aggregation_request(size) - - -class NestedHistogramAggregation(NestedPathElasticsearchMetric): - histogram: HistogramAggregation - - @root_validator - def normalize_terms_field(cls, values): - histogram = values["histogram"] - nested_path = values["nested_path"] - histogram.field = f"{nested_path}.{histogram.field}" - - return values - - def inner_aggregation(self, interval: float) -> Dict[str, Any]: - return self.histogram.aggregation_request(interval) - - -class WordCloudAggregation(ElasticsearchMetric): - default_field: str - - def aggregation_request( - self, text_field: str = None, size: int = None - ) -> Dict[str, Any]: - field = text_field or self.default_field - return TermsAggregation( - id=f"{self.id}_{field}" if text_field else self.id, - name=f"Words cloud for field {field}", - field=field, - ).aggregation_request(size=size) - - -class MetadataAggregations(ElasticsearchMetric): - def aggregation_request( - self, - dataset: Dataset, - dao: DatasetRecordsDAO, - size: int = None, - ) -> List[Dict[str, Any]]: - - metadata_aggs = aggregations.custom_fields( - fields_definitions=dao.get_metadata_schema(dataset), size=size - ) - return [{key: value} for key, value in metadata_aggs.items()] - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - data = unflatten_dict(aggregation_result, stop_keys=["metadata"]) - return data.get("metadata", {}) diff --git a/src/rubrix/server/apis/v0/models/metrics/commons.py b/src/rubrix/server/apis/v0/models/metrics/commons.py deleted file mode 100644 index 214bf5e15c..0000000000 --- a/src/rubrix/server/apis/v0/models/metrics/commons.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any, ClassVar, Dict, Generic, List - -from rubrix.server.apis.v0.models.commons.model import ( - EsRecordDataFieldNames, - 
TaskStatus, -) -from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, - BaseTaskMetrics, - GenericRecord, - HistogramAggregation, - MetadataAggregations, - TermsAggregation, - WordCloudAggregation, -) - - -class CommonTasksMetrics(BaseTaskMetrics, Generic[GenericRecord]): - """Common task metrics""" - - @classmethod - def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]: - """Record metrics will persist the text_length""" - return {"text_length": len(record.all_text())} - - metrics: ClassVar[List[BaseMetric]] = [ - HistogramAggregation( - id="text_length", - name="Text length distribution", - description="Computes the input text length distribution", - field="metrics.text_length", - script="params._source.text.length()", - fixed_interval=1, - ), - TermsAggregation( - id="error_distribution", - name="Error distribution", - description="Computes the dataset error distribution. It's mean, records " - "with correct predictions vs records with incorrect prediction " - "vs records with unknown prediction result", - field=EsRecordDataFieldNames.predicted, - missing="unknown", - fixed_size=3, - ), - TermsAggregation( - id="status_distribution", - name="Record status distribution", - description="The dataset record status distribution", - field=EsRecordDataFieldNames.status, - fixed_size=len(TaskStatus), - ), - WordCloudAggregation( - id="words_cloud", - name="Inputs words cloud", - description="The words cloud for dataset inputs", - default_field="text.wordcloud", - ), - MetadataAggregations(id="metadata", name="Metadata fields stats"), - TermsAggregation( - id="predicted_by", - name="Predicted by distribution", - field="predicted_by", - ), - TermsAggregation( - id="annotated_by", - name="Annotated by distribution", - field="annotated_by", - ), - HistogramAggregation( - id="score", - name="Score record distribution", - field="score", - fixed_interval=0.001, - ), - ] diff --git a/src/rubrix/server/apis/v0/models/metrics/token_classification.py 
b/src/rubrix/server/apis/v0/models/metrics/token_classification.py deleted file mode 100644 index 80c9abdee6..0000000000 --- a/src/rubrix/server/apis/v0/models/metrics/token_classification.py +++ /dev/null @@ -1,643 +0,0 @@ -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set, Tuple - -from pydantic import BaseModel, Field - -from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, - ElasticsearchMetric, - HistogramAggregation, - NestedHistogramAggregation, - NestedPathElasticsearchMetric, - NestedTermsAggregation, - PythonMetric, - TermsAggregation, -) -from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics -from rubrix.server.apis.v0.models.token_classification import ( - EntitySpan, - TokenClassificationRecord, -) -from rubrix.server.elasticseach.query_helpers import aggregations - - -class TokensLength(ElasticsearchMetric): - """ - Summarizes the tokens length metric into an histogram - - Attributes: - ----------- - length_field: - The elasticsearch field where tokens length is stored - """ - - length_field: str - - def aggregation_request(self, interval: int) -> Dict[str, Any]: - return { - self.id: aggregations.histogram_aggregation( - self.length_field, interval=interval or 1 - ) - } - - -_DEFAULT_MAX_ENTITY_BUCKET = 1000 - - -class EntityLabels(NestedPathElasticsearchMetric): - """ - Computes the entity labels distribution - - Attributes: - ----------- - labels_field: - The elasticsearch field where tags are stored - """ - - labels_field: str - - def inner_aggregation(self, size: int) -> Dict[str, Any]: - return { - "labels": aggregations.terms_aggregation( - self.compound_nested_field(self.labels_field), - size=size or _DEFAULT_MAX_ENTITY_BUCKET, - ) - } - - -class EntityDensity(NestedPathElasticsearchMetric): - """Summarizes the entity density metric into an histogram""" - - density_field: str - - def inner_aggregation(self, interval: float) -> Dict[str, Any]: - return { - "density": 
aggregations.histogram_aggregation( - field_name=self.compound_nested_field(self.density_field), - interval=interval or 0.01, - ) - } - - -class MentionLength(NestedPathElasticsearchMetric): - """Summarizes the mention length into an histogram""" - - length_field: str - - def inner_aggregation(self, interval: int) -> Dict[str, Any]: - return { - "mention_length": aggregations.histogram_aggregation( - self.compound_nested_field(self.length_field), interval=interval or 1 - ) - } - - -class EntityCapitalness(NestedPathElasticsearchMetric): - """Computes the mention capitalness distribution""" - - capitalness_field: str - - def inner_aggregation(self) -> Dict[str, Any]: - return { - "capitalness": aggregations.terms_aggregation( - self.compound_nested_field(self.capitalness_field), - size=4, # The number of capitalness choices - ) - } - - -class MentionsByEntityDistribution(NestedPathElasticsearchMetric): - def inner_aggregation(self): - return { - self.id: aggregations.bidimentional_terms_aggregations( - field_name_x=f"{self.nested_path}.label", - field_name_y=f"{self.nested_path}.value", - ) - } - - -class EntityConsistency(NestedPathElasticsearchMetric): - """Computes the entity consistency distribution""" - - mention_field: str - labels_field: str - - def inner_aggregation( - self, - size: int, - interval: int = 2, - entity_size: int = _DEFAULT_MAX_ENTITY_BUCKET, - ) -> Dict[str, Any]: - size = size or 50 - interval = int(max(interval or 2, 2)) - return { - "consistency": { - **aggregations.terms_aggregation( - self.compound_nested_field(self.mention_field), size=size - ), - "aggs": { - "entities": aggregations.terms_aggregation( - self.compound_nested_field(self.labels_field), size=entity_size - ), - "count": { - "cardinality": { - "field": self.compound_nested_field(self.labels_field) - } - }, - "entities_variability_filter": { - "bucket_selector": { - "buckets_path": {"numLabels": "count"}, - "script": f"params.numLabels >= {interval}", - } - }, - 
"sortby_entities_count": { - "bucket_sort": { - "sort": [{"count": {"order": "desc"}}], - "size": size, - } - }, - }, - } - } - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - """Simplifies the aggregation result sorting by worst mention consistency""" - result = [ - { - "mention": mention, - "entities": [ - {"label": entity, "count": count} - for entity, count in mention_aggs["entities"].items() - ], - } - for mention, mention_aggs in aggregation_result.items() - ] - # TODO: filter by entities threshold - result.sort(key=lambda m: len(m["entities"]), reverse=True) - return {"mentions": result} - - -class F1Metric(PythonMetric[TokenClassificationRecord]): - """The F1 metric based on entity-level. - - We follow the convention of `CoNLL 2003 `_, where: - `"precision is the percentage of named entities found by the learning system that are correct. - Recall is the percentage of named entities present in the corpus that are found by the system. - A named entity is correct only if it is an exact match (...).”` - """ - - def apply(self, records: Iterable[TokenClassificationRecord]) -> Dict[str, Any]: - # store entities per label in dicts - predicted_entities = {} - annotated_entities = {} - - # extract entities per label to dicts - for rec in records: - if rec.prediction: - self._add_entities_to_dict(rec.prediction.entities, predicted_entities) - if rec.annotation: - self._add_entities_to_dict(rec.annotation.entities, annotated_entities) - - # store precision, recall, and f1 per label - per_label_metrics = {} - - annotated_total, predicted_total, correct_total = 0, 0, 0 - precision_macro, recall_macro = 0, 0 - for label, annotated in annotated_entities.items(): - predicted = predicted_entities.get(label, set()) - correct = len(annotated & predicted) - - # safe divides are used to cover the 0/0 cases - precision = self._safe_divide(correct, len(predicted)) - recall = self._safe_divide(correct, len(annotated)) - 
per_label_metrics.update( - { - f"{label}_precision": precision, - f"{label}_recall": recall, - f"{label}_f1": self._safe_divide( - 2 * precision * recall, precision + recall - ), - } - ) - - annotated_total += len(annotated) - predicted_total += len(predicted) - correct_total += correct - - precision_macro += precision / len(annotated_entities) - recall_macro += recall / len(annotated_entities) - - # store macro and micro averaged precision, recall and f1 - averaged_metrics = { - "precision_macro": precision_macro, - "recall_macro": recall_macro, - "f1_macro": self._safe_divide( - 2 * precision_macro * recall_macro, precision_macro + recall_macro - ), - } - - precision_micro = self._safe_divide(correct_total, predicted_total) - recall_micro = self._safe_divide(correct_total, annotated_total) - averaged_metrics.update( - { - "precision_micro": precision_micro, - "recall_micro": recall_micro, - "f1_micro": self._safe_divide( - 2 * precision_micro * recall_micro, precision_micro + recall_micro - ), - } - ) - - return {**averaged_metrics, **per_label_metrics} - - @staticmethod - def _add_entities_to_dict( - entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]] - ): - """Helper function for the apply method.""" - for ent in entities: - try: - dictionary[ent.label].add((ent.start, ent.end)) - except KeyError: - dictionary[ent.label] = {(ent.start, ent.end)} - - @staticmethod - def _safe_divide(numerator, denominator): - """Helper function for the apply method.""" - try: - return numerator / denominator - except ZeroDivisionError: - return 0 - - -class DatasetLabels(PythonMetric): - id: str = Field("dataset_labels", const=True) - name: str = Field("The dataset entity labels", const=True) - max_processed_records: int = 10000 - - def apply(self, records: Iterable[TokenClassificationRecord]) -> Dict[str, Any]: - ds_labels = set() - - for _ in range( - 0, self.max_processed_records - ): # Only a few of records will be parsed - record: 
TokenClassificationRecord = next(records, None) - if record is None: - break - - if record.annotation: - ds_labels.update( - [entity.label for entity in record.annotation.entities] - ) - if record.prediction: - ds_labels.update( - [entity.label for entity in record.prediction.entities] - ) - return {"labels": ds_labels or []} - - -class MentionMetrics(BaseModel): - """Mention metrics model""" - - value: str - label: str - score: float = Field(ge=0.0) - capitalness: Optional[str] = Field(None) - density: float = Field(ge=0.0) - tokens_length: int = Field(g=0) - chars_length: int = Field(g=0) - - -class TokenTagMetrics(BaseModel): - value: str - tag: str - - -class TokenMetrics(BaseModel): - """ - Token metrics stored in elasticsearch for token classification - - Attributes - idx: The token index in sentence - value: The token textual value - char_start: The token character start position in sentence - char_end: The token character end position in sentence - score: Token score info - tag: Token tag info. Deprecated: Use metrics.predicted.tags or metrics.annotated.tags instead - custom: extra token level info - """ - - idx: int - value: str - char_start: int - char_end: int - length: int - capitalness: Optional[str] = None - score: Optional[float] = None - tag: Optional[str] = None # TODO: remove! 
- custom: Dict[str, Any] = None - - -class TokenClassificationMetrics(CommonTasksMetrics[TokenClassificationRecord]): - """Configured metrics for token classification""" - - _PREDICTED_NAMESPACE = "metrics.predicted" - _ANNOTATED_NAMESPACE = "metrics.annotated" - - _PREDICTED_MENTIONS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.mentions" - _ANNOTATED_MENTIONS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.mentions" - - _PREDICTED_TAGS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.tags" - _ANNOTATED_TAGS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.tags" - - _TOKENS_NAMESPACE = "metrics.tokens" - - @staticmethod - def density(value: int, sentence_length: int) -> float: - """Compute the string density over a sentence""" - return value / sentence_length - - @staticmethod - def capitalness(value: str) -> Optional[str]: - """Compute capitalness for a string value""" - value = value.strip() - if not value: - return None - if value.isupper(): - return "UPPER" - if value.islower(): - return "LOWER" - if value[0].isupper(): - return "FIRST" - if any([c.isupper() for c in value[1:]]): - return "MIDDLE" - return None - - @staticmethod - def mentions_metrics( - record: TokenClassificationRecord, mentions: List[Tuple[str, EntitySpan]] - ): - def mention_tokens_length(entity: EntitySpan) -> int: - """Compute mention tokens length""" - return len( - set( - [ - token_idx - for i in range(entity.start, entity.end) - for token_idx in [record.char_id2token_id(i)] - if token_idx is not None - ] - ) - ) - - return [ - MentionMetrics( - value=mention, - label=entity.label, - score=entity.score, - capitalness=TokenClassificationMetrics.capitalness(mention), - density=TokenClassificationMetrics.density( - _tokens_length, sentence_length=len(record.tokens) - ), - tokens_length=_tokens_length, - chars_length=len(mention), - ) - for mention, entity in mentions - for _tokens_length in [ - mention_tokens_length(entity), - ] - ] - - @classmethod - def build_tokens_metrics( - cls, record: TokenClassificationRecord, tags: 
Optional[List[str]] = None - ) -> List[TokenMetrics]: - - return [ - TokenMetrics( - idx=token_idx, - value=token_value, - char_start=char_start, - char_end=char_end, - capitalness=cls.capitalness(token_value), - length=1 + (char_end - char_start), - tag=tags[token_idx] if tags else None, - ) - for token_idx, token_value in enumerate(record.tokens) - for char_start, char_end in [record.token_span(token_idx)] - ] - - @classmethod - def record_metrics(cls, record: TokenClassificationRecord) -> Dict[str, Any]: - """Compute metrics at record level""" - base_metrics = super(TokenClassificationMetrics, cls).record_metrics(record) - - annotated_tags = record.annotated_iob_tags() or [] - predicted_tags = record.predicted_iob_tags() or [] - - tokens_metrics = cls.build_tokens_metrics( - record, predicted_tags or annotated_tags - ) - return { - **base_metrics, - "tokens": tokens_metrics, - "tokens_length": len(record.tokens), - "predicted": { - "mentions": cls.mentions_metrics(record, record.predicted_mentions()), - "tags": [ - TokenTagMetrics(tag=tag, value=token) - for tag, token in zip(predicted_tags, record.tokens) - ], - }, - "annotated": { - "mentions": cls.mentions_metrics(record, record.annotated_mentions()), - "tags": [ - TokenTagMetrics(tag=tag, value=token) - for tag, token in zip(annotated_tags, record.tokens) - ], - }, - } - - _TOKENS_METRICS = [ - TokensLength( - id="tokens_length", - name="Tokens length", - description="Computes the text length distribution measured in number of tokens", - length_field="metrics.tokens_length", - ), - NestedTermsAggregation( - id="token_frequency", - name="Tokens frequency distribution", - nested_path=_TOKENS_NAMESPACE, - terms=TermsAggregation( - id="frequency", - field="value", - name="", - ), - ), - NestedHistogramAggregation( - id="token_length", - name="Token length distribution", - nested_path=_TOKENS_NAMESPACE, - description="Computes token length distribution in number of characters", - histogram=HistogramAggregation( - 
id="length", - field="length", - name="", - fixed_interval=1, - ), - ), - NestedTermsAggregation( - id="token_capitalness", - name="Token capitalness distribution", - description="Computes capitalization information of tokens", - nested_path=_TOKENS_NAMESPACE, - terms=TermsAggregation( - id="capitalness", - field="capitalness", - name="", - # missing="OTHER", - ), - ), - ] - _PREDICTED_METRICS = [ - EntityDensity( - id="predicted_entity_density", - name="Mention entity density for predictions", - description="Computes the ratio between the number of all entity tokens and tokens in the text", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - density_field="density", - ), - EntityLabels( - id="predicted_entity_labels", - name="Predicted entity labels", - description="Predicted entity labels distribution", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - labels_field="label", - ), - EntityCapitalness( - id="predicted_entity_capitalness", - name="Mention entity capitalness for predictions", - description="Computes capitalization information of predicted entity mentions", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - capitalness_field="capitalness", - ), - MentionLength( - id="predicted_mention_token_length", - name="Predicted mention tokens length", - description="Computes the length of the predicted entity mention measured in number of tokens", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - length_field="tokens_length", - ), - MentionLength( - id="predicted_mention_char_length", - name="Predicted mention characters length", - description="Computes the length of the predicted entity mention measured in number of tokens", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - length_field="chars_length", - ), - MentionsByEntityDistribution( - id="predicted_mentions_distribution", - name="Predicted mentions distribution by entity", - description="Computes predicted mentions distribution against its labels", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - ), - EntityConsistency( - 
id="predicted_entity_consistency", - name="Entity label consistency for predictions", - description="Computes entity label variability for top-k predicted entity mentions", - nested_path=_PREDICTED_MENTIONS_NAMESPACE, - mention_field="value", - labels_field="label", - ), - EntityConsistency( - id="predicted_tag_consistency", - name="Token tag consistency for predictions", - description="Computes token tag variability for top-k predicted tags", - nested_path=_PREDICTED_TAGS_NAMESPACE, - mention_field="value", - labels_field="tag", - ), - ] - - _ANNOTATED_METRICS = [ - EntityDensity( - id="annotated_entity_density", - name="Mention entity density for annotations", - description="Computes the ratio between the number of all entity tokens and tokens in the text", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - density_field="density", - ), - EntityLabels( - id="annotated_entity_labels", - name="Annotated entity labels", - description="Annotated Entity labels distribution", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - labels_field="label", - ), - EntityCapitalness( - id="annotated_entity_capitalness", - name="Mention entity capitalness for annotations", - description="Compute capitalization information of annotated entity mentions", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - capitalness_field="capitalness", - ), - MentionLength( - id="annotated_mention_token_length", - name="Annotated mention tokens length", - description="Computes the length of the entity mention measured in number of tokens", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - length_field="tokens_length", - ), - MentionLength( - id="annotated_mention_char_length", - name="Annotated mention characters length", - description="Computes the length of the entity mention measured in number of tokens", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - length_field="chars_length", - ), - MentionsByEntityDistribution( - id="annotated_mentions_distribution", - name="Annotated mentions distribution by entity", - 
description="Computes annotated mentions distribution against its labels", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - ), - EntityConsistency( - id="annotated_entity_consistency", - name="Entity label consistency for annotations", - description="Computes entity label variability for top-k annotated entity mentions", - nested_path=_ANNOTATED_MENTIONS_NAMESPACE, - mention_field="value", - labels_field="label", - ), - EntityConsistency( - id="annotated_tag_consistency", - name="Token tag consistency for annotations", - description="Computes token tag variability for top-k annotated tags", - nested_path=_ANNOTATED_TAGS_NAMESPACE, - mention_field="value", - labels_field="tag", - ), - ] - - metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [ - TermsAggregation( - id="predicted_as", - name="Predicted labels distribution", - field="predicted_as", - ), - TermsAggregation( - id="annotated_as", - name="Annotated labels distribution", - field="annotated_as", - ), - *_TOKENS_METRICS, - *_PREDICTED_METRICS, - *_ANNOTATED_METRICS, - DatasetLabels(), - F1Metric( - id="F1", - name="F1 Metric based on entity-level", - description="F1 metrics based on entity-level (averaged and per label), " - "where only exact matches count (CoNNL 2003).", - ), - ] diff --git a/src/rubrix/server/apis/v0/models/text2text.py b/src/rubrix/server/apis/v0/models/text2text.py index 1761e9d1ef..5497c2c4eb 100644 --- a/src/rubrix/server/apis/v0/models/text2text.py +++ b/src/rubrix/server/apis/v0/models/text2text.py @@ -14,51 +14,34 @@ # limitations under the License. 
from datetime import datetime -from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from pydantic import BaseModel, Field, validator from rubrix.server.apis.v0.models.commons.model import ( BaseAnnotation, BaseRecord, + BaseRecordInputs, BaseSearchResults, - BaseSearchResultsAggregations, - EsRecordDataFieldNames, - PredictionStatus, ScoreRange, SortableField, - TaskType, ) -from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest -from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics -from rubrix.server.services.search.model import BaseSearchQuery - - -class ExtendedEsRecordDataFieldNames(str, Enum): - text_predicted = "text_predicted" - text_annotated = "text_annotated" +from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest +from rubrix.server.commons.models import PredictionStatus, TaskType +from rubrix.server.services.metrics.models import CommonTasksMetrics +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + ServiceBaseSearchResultsAggregations, +) +from rubrix.server.services.tasks.text2text.models import ServiceText2TextDataset class Text2TextPrediction(BaseModel): - """Represents a text prediction/annotation and its score""" - text: str score: float = Field(default=1.0, ge=0.0, le=1.0) class Text2TextAnnotation(BaseAnnotation): - """ - Annotation class for text2text tasks - - Attributes: - ----------- - - sentences: str - List of sentence predictions/annotations - - """ - @validator("sentences") def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]): """Sort provided sentences by score desc""" @@ -67,168 +50,25 @@ def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]): sentences: List[Text2TextPrediction] -class CreationText2TextRecord(BaseRecord[Text2TextAnnotation]): - """ - Text2Text record - - Attributes: - ----------- - - text: - The input data text - """ +class 
Text2TextRecordInputs(BaseRecordInputs[Text2TextAnnotation]): text: str - @classmethod - def task(cls) -> TaskType: - """The task type""" - return TaskType.text2text - - def all_text(self) -> str: - return self.text - - @property - def predicted_as(self) -> Optional[List[str]]: - return ( - [sentence.text for sentence in self.prediction.sentences] - if self.prediction - else None - ) - - @property - def annotated_as(self) -> Optional[List[str]]: - return ( - [sentence.text for sentence in self.annotation.sentences] - if self.annotation - else None - ) - - @property - def scores(self) -> List[float]: - """Values of prediction scores""" - if not self.prediction: - return [] - return [sentence.score for sentence in self.prediction.sentences] - - @validator("text") - def validate_text(cls, text: Dict[str, Any]): - """Applies validation over input text""" - assert len(text) > 0, "No text provided" - return text - - -class Text2TextRecord(CreationText2TextRecord): - """ - The output text2text task record - - Attributes: - ----------- - - last_updated: datetime - Last record update (read only) - predicted: Optional[PredictionStatus] - The record prediction status. 
Optional - """ - - last_updated: datetime = None - _predicted: Optional[PredictionStatus] = Field(alias="predicted") - - def extended_fields(self): - return {} - - -class Text2TextRecordDB(Text2TextRecord): - """The db text2text task record""" - - def extended_fields(self) -> Dict[str, Any]: - return { - EsRecordDataFieldNames.annotated_as: self.annotated_as, - EsRecordDataFieldNames.predicted_as: self.predicted_as, - EsRecordDataFieldNames.annotated_by: self.annotated_by, - EsRecordDataFieldNames.predicted_by: self.predicted_by, - EsRecordDataFieldNames.score: self.scores, - EsRecordDataFieldNames.words: self.all_text(), - } - - -class Text2TextBulkData(UpdateDatasetRequest): - """ - API bulk data for text2text - - Attributes: - ----------- - records: List[CreationText2TextRecord] - The text2text record list - - """ - - records: List[CreationText2TextRecord] - - -class Text2TextQuery(BaseSearchQuery): - """ - API Filters for text2text - - Attributes: - ----------- - ids: Optional[List[Union[str, int]]] - Record ids list - - query_text: str - Text query over input text - - annotated_by: List[str] - List of annotation agents - predicted_by: List[str] - List of predicted agents - - status: List[TaskStatus] - List of task status +class Text2TextRecord(Text2TextRecordInputs, BaseRecord[Text2TextAnnotation]): + pass - metadata: Optional[Dict[str, Union[str, List[str]]]] - Text query over metadata fields. 
Default=None - predicted: Optional[PredictionStatus] - The task prediction status +class Text2TextBulkRequest(UpdateDatasetRequest): + records: List[Text2TextRecordInputs] - """ +class Text2TextQuery(ServiceBaseRecordsQuery): score: Optional[ScoreRange] = Field(default=None) predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) -class Text2TextSearchRequest(BaseModel): - """ - API SearchRequest request - - Attributes: - ----------- - - query: Text2TextQuery - The search query configuration - - sort: - The sort order list - """ - - query: Text2TextQuery = Field(default_factory=Text2TextQuery) - sort: List[SortableField] = Field(default_factory=list) - - -class Text2TextSearchAggregations(BaseSearchResultsAggregations): - """ - Extends base aggregation with predicted and annotated text - - Attributes: - ----------- - predicted_text: Dict[str, int] - The word cloud aggregations for predicted text - annotated_text: Dict[str, int] - The word cloud aggregations for annotated text - """ - +class Text2TextSearchAggregations(ServiceBaseSearchResultsAggregations): predicted_text: Dict[str, int] = Field(default_factory=dict) annotated_text: Dict[str, int] = Field(default_factory=dict) @@ -239,14 +79,14 @@ class Text2TextSearchResults( pass -class Text2TextDatasetDB(DatasetDB): - task: TaskType = Field(default=TaskType.text2text, const=True) +class Text2TextDataset(ServiceText2TextDataset): pass class Text2TextMetrics(CommonTasksMetrics[Text2TextRecord]): - """ - Configured metrics for text2text task - """ - pass + + +class Text2TextSearchRequest(BaseModel): + query: Text2TextQuery = Field(default_factory=Text2TextQuery) + sort: List[SortableField] = Field(default_factory=list) diff --git a/src/rubrix/server/apis/v0/models/text_classification.py b/src/rubrix/server/apis/v0/models/text_classification.py index e62d87474f..1c2d388a70 100644 --- a/src/rubrix/server/apis/v0/models/text_classification.py +++ 
b/src/rubrix/server/apis/v0/models/text_classification.py @@ -14,26 +14,39 @@ # limitations under the License. from datetime import datetime -from typing import Any, ClassVar, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field, root_validator, validator -from rubrix._constants import MAX_KEYWORD_LENGTH -from rubrix.server.apis.v0.helpers import flatten_dict from rubrix.server.apis.v0.models.commons.model import ( - BaseAnnotation, BaseRecord, + BaseRecordInputs, BaseSearchResults, - BaseSearchResultsAggregations, - EsRecordDataFieldNames, - PredictionStatus, ScoreRange, SortableField, - TaskStatus, - TaskType, ) -from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest -from rubrix.server.services.search.model import BaseSearchQuery +from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest +from rubrix.server.commons.models import PredictionStatus +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + ServiceBaseSearchResultsAggregations, +) +from rubrix.server.services.tasks.text_classification.model import ( + DatasetLabelingRulesMetricsSummary as _DatasetLabelingRulesMetricsSummary, +) +from rubrix.server.services.tasks.text_classification.model import ( + LabelingRuleMetricsSummary as _LabelingRuleMetricsSummary, +) +from rubrix.server.services.tasks.text_classification.model import ( + ServiceTextClassificationDataset, +) +from rubrix.server.services.tasks.text_classification.model import ( + ServiceTextClassificationQuery as _TextClassificationQuery, +) +from rubrix.server.services.tasks.text_classification.model import ( + TextClassificationAnnotation as _TextClassificationAnnotation, +) +from rubrix.server.services.tasks.text_classification.model import TokenAttributions class UpdateLabelingRule(BaseModel): @@ -63,22 +76,6 @@ def initialize_labels(cls, values): class CreateLabelingRule(UpdateLabelingRule): - """ - Data model 
for labeling rules creation - - Attributes: - ----------- - - query: - The ES query of the rule - - label: str - The label associated with the rule - - description: - A brief description of the rule - - """ query: str = Field(description="The es rule query") @@ -89,387 +86,44 @@ def strip_query(cls, query: str) -> str: class LabelingRule(CreateLabelingRule): - """ - Adds read-only attributes to the labeling rule - - Attributes: - ----------- - - author: - Who created the rule - - created_at: - When was the rule created - - """ - author: str = Field(description="User who created the rule") created_at: Optional[datetime] = Field( default_factory=datetime.utcnow, description="Rule creation timestamp" ) -class LabelingRuleMetricsSummary(BaseModel): - """Metrics generated for a labeling rule""" - - coverage: Optional[float] = None - coverage_annotated: Optional[float] = None - correct: Optional[float] = None - incorrect: Optional[float] = None - precision: Optional[float] = None - - total_records: int - annotated_records: int - - -class DatasetLabelingRulesMetricsSummary(BaseModel): - coverage: Optional[float] = None - coverage_annotated: Optional[float] = None - - total_records: int - annotated_records: int - - -class TextClassificationDatasetDB(DatasetDB): - """ - A dataset class specialized for text classification task - - Attributes: - ----------- - - rules: - A list of dataset labeling rules - """ - - task: TaskType = Field(default=TaskType.text_classification, const=True) - - rules: List[LabelingRule] = Field(default_factory=list) - - -class ClassPrediction(BaseModel): - """ - Single class prediction - - Attributes: - ----------- - - class_label: Union[str, int] - the predicted class - - score: float - the predicted class score. 
For human-supervised annotations, - this probability should be 1.0 - """ - - class_label: Union[str, int] = Field(alias="class") - score: float = Field(default=1.0, ge=0.0, le=1.0) - - @validator("class_label") - def check_label_length(cls, class_label): - if isinstance(class_label, str): - assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, ( - f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}" - if len(class_label) > MAX_KEYWORD_LENGTH - else f"Class name must not be empty" - ) - return class_label - - # See - class Config: - allow_population_by_field_name = True - - -class TextClassificationAnnotation(BaseAnnotation): - """ - Annotation class for text classification tasks - - Attributes: - ----------- - - labels: List[LabelPrediction] - list of annotated labels with score - """ - - labels: List[ClassPrediction] - - @validator("labels") - def sort_labels(cls, labels: List[ClassPrediction]): - """Sort provided labels by score""" - return sorted(labels, key=lambda x: x.score, reverse=True) - - -class TokenAttributions(BaseModel): - """ - The token attributions explaining predicted labels - - Attributes: - ----------- - - token: str - The input token - attributions: Dict[str, float] - A dictionary containing label class-attribution pairs +class LabelingRuleMetricsSummary(_LabelingRuleMetricsSummary): + pass - """ - token: str - attributions: Dict[str, float] = Field(default_factory=dict) +class DatasetLabelingRulesMetricsSummary(_DatasetLabelingRulesMetricsSummary): + pass -class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation]): - """ - Text classification record +class TextClassificationDataset(ServiceTextClassificationDataset): + pass - Attributes: - ----------- - inputs: Dict[str, Union[str, List[str]]] - The input data text +class TextClassificationAnnotation(_TextClassificationAnnotation): + pass - multi_label: bool - Enable text classification with multiple predicted/annotated labels. 
- Default=False - explanation: Dict[str, List[TokenAttributions]] - Token attribution list explaining predicted classes per token input. - The dictionary key must be aligned with provided record text. Optional - """ +class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotation]): inputs: Dict[str, Union[str, List[str]]] multi_label: bool = False explanation: Optional[Dict[str, List[TokenAttributions]]] = None - _SCORE_DEVIATION_ERROR: ClassVar[float] = 0.001 - - @root_validator - def validate_record(cls, values): - """fastapi validator method""" - prediction = values.get("prediction", None) - annotation = values.get("annotation", None) - status = values.get("status") - multi_label = values.get("multi_label", False) - - cls._check_score_integrity(prediction, multi_label) - cls._check_annotation_integrity(annotation, multi_label, status) - - return values - - @classmethod - def _check_annotation_integrity( - cls, - annotation: TextClassificationAnnotation, - multi_label: bool, - status: TaskStatus, - ): - if status == TaskStatus.validated and not multi_label: - assert ( - annotation and len(annotation.labels) > 0 - ), "Annotation must include some label for validated records" - - if not multi_label and annotation: - assert ( - len(annotation.labels) == 1 - ), "Single label record must include only one annotation label" - - @classmethod - def _check_score_integrity( - cls, prediction: TextClassificationAnnotation, multi_label: bool - ): - """ - Checks the score value integrity - - Parameters - ---------- - prediction: - The prediction annotation - multi_label: - If multi label - - """ - if prediction and not multi_label: - assert sum([label.score for label in prediction.labels]) <= ( - 1.0 + cls._SCORE_DEVIATION_ERROR - ), f"Wrong score distributions: {prediction.labels}" - - @classmethod - def task(cls) -> TaskType: - """The task type""" - return TaskType.text_classification - - @property - def predicted(self) -> Optional[PredictionStatus]: - if 
self.predicted_by and self.annotated_by: - return ( - PredictionStatus.OK - if set(self.predicted_as) == set(self.annotated_as) - else PredictionStatus.KO - ) - return None - - @property - def predicted_as(self) -> List[str]: - return self._labels_from_annotation( - self.prediction, multi_label=self.multi_label - ) - - @property - def annotated_as(self) -> List[str]: - return self._labels_from_annotation( - self.annotation, multi_label=self.multi_label - ) - - @property - def scores(self) -> List[float]: - """Values of prediction scores""" - if not self.prediction: - return [] - return ( - [label.score for label in self.prediction.labels] - if self.multi_label - else [ - prediction_class.score - for prediction_class in [ - self._max_class_prediction( - self.prediction, multi_label=self.multi_label - ) - ] - if prediction_class - ] - ) - - def all_text(self) -> str: - sentences = [] - for v in self.inputs.values(): - if isinstance(v, list): - sentences.extend(v) - else: - sentences.append(v) - return "\n".join(sentences) - - @validator("inputs") - def validate_inputs(cls, text: Dict[str, Any]): - """Applies validation over input text""" - assert len(text) > 0, "No inputs provided" - - for t in text.values(): - assert t is not None, "Cannot include None fields" - - return text - - @validator("inputs") - def flatten_text(cls, text: Dict[str, Any]): - """Normalizes input text to dict of strings""" - flat_dict = flatten_dict(text) - return flat_dict - - @classmethod - def _labels_from_annotation( - cls, annotation: TextClassificationAnnotation, multi_label: bool - ) -> Union[List[str], List[int]]: - """ - Extracts labels values from annotation - - Parameters - ---------- - annotation: - The annotation - multi_label - Enable/Disable multi label model - - Returns - ------- - Label values for a given annotation - - """ - if not annotation: - return [] - - if multi_label: - return [ - label.class_label for label in annotation.labels if label.score > 0.5 - ] - - 
class_prediction = cls._max_class_prediction( - annotation, multi_label=multi_label - ) - if class_prediction is None: - return [] - - return [class_prediction.class_label] - - @staticmethod - def _max_class_prediction( - p: TextClassificationAnnotation, multi_label: bool - ) -> Optional[ClassPrediction]: - """ - Gets the max class prediction for annotation - - Parameters - ---------- - p: - The annotation - multi_label: - Enable/Disable multi_label mode - - Returns - ------- - - The max class prediction in terms of prediction score if - prediction has labels and no multi label is enabled. None, otherwise - """ - if multi_label or p is None or not p.labels: - return None - return p.labels[0] - - class Config: - allow_population_by_field_name = True - - -class TextClassificationRecordDB(CreationTextClassificationRecord): - """ - The main text classification task record - - Attributes: - ----------- - - last_updated: datetime - Last record update (read only) - predicted: Optional[PredictionStatus] - The record prediction status. Optional - """ - - last_updated: datetime = None - _predicted: Optional[PredictionStatus] = Field(alias="predicted") - - def extended_fields(self) -> Dict[str, Any]: - words = self.all_text() - return { - **super().extended_fields(), - EsRecordDataFieldNames.words: words, - # This allow query by text:.... or text.exact:.... 
- # Once words is remove we can normalize at record level - "text": words, - } +class TextClassificationRecord( + TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation] +): + pass -class TextClassificationRecord(TextClassificationRecordDB): - def extended_fields(self) -> Dict[str, Any]: - return {} - - -class TextClassificationBulkData(UpdateDatasetRequest): - """ - API bulk data for text classification - - Attributes: - ----------- - - records: List[CreationTextClassificationRecord] - The text classification record list - """ +class TextClassificationBulkRequest(UpdateDatasetRequest): - records: List[CreationTextClassificationRecord] + records: List[TextClassificationRecordInputs] @validator("records") def check_multi_label_integrity(cls, records: List[TextClassificationRecord]): @@ -483,26 +137,7 @@ def check_multi_label_integrity(cls, records: List[TextClassificationRecord]): return records -class TextClassificationQuery(BaseSearchQuery): - """ - API Filters for text classification - - Attributes: - ----------- - - predicted_as: List[str] - List of predicted terms - - annotated_as: List[str] - List of annotated terms - - predicted: Optional[PredictionStatus] - The task prediction status - - uncovered_by_rules: - Only return records that are NOT covered by these rules. 
- - """ +class TextClassificationQuery(ServiceBaseRecordsQuery): predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) @@ -515,25 +150,7 @@ class TextClassificationQuery(BaseSearchQuery): ) -class TextClassificationSearchRequest(BaseModel): - """ - API SearchRequest request - - Attributes: - ----------- - - query: TextClassificationQuery - The search query configuration - - sort: - The sort order list - """ - - query: TextClassificationQuery = Field(default_factory=TextClassificationQuery) - sort: List[SortableField] = Field(default_factory=list) - - -class TextClassificationSearchAggregations(BaseSearchResultsAggregations): +class TextClassificationSearchAggregations(ServiceBaseSearchResultsAggregations): pass @@ -541,3 +158,8 @@ class TextClassificationSearchResults( BaseSearchResults[TextClassificationRecord, TextClassificationSearchAggregations] ): pass + + +class TextClassificationSearchRequest(BaseModel): + query: TextClassificationQuery = Field(default_factory=TextClassificationQuery) + sort: List[SortableField] = Field(default_factory=list) diff --git a/src/rubrix/server/apis/v0/models/token_classification.py b/src/rubrix/server/apis/v0/models/token_classification.py index 0dac97527d..948fc559ab 100644 --- a/src/rubrix/server/apis/v0/models/token_classification.py +++ b/src/rubrix/server/apis/v0/models/token_classification.py @@ -12,104 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import defaultdict -from datetime import datetime -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, root_validator, validator -from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.server.apis.v0.models.commons.model import ( - BaseAnnotation, BaseRecord, + BaseRecordInputs, BaseSearchResults, - BaseSearchResultsAggregations, - EsRecordDataFieldNames, - PredictionStatus, ScoreRange, - SortableField, - TaskType, ) -from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest -from rubrix.server.services.search.model import BaseSearchQuery - -PREDICTED_MENTIONS_ES_FIELD_NAME = "predicted_mentions" -MENTIONS_ES_FIELD_NAME = "mentions" - - -class EntitySpan(BaseModel): - """ - The tokens span for a labeled text. - - Entity spans will be defined between from start to end - 1 - - Attributes: - ----------- - - start: int - character start position - end: int - character end position, must be higher than the starting character. - label: str - the label related to tokens that conforms the entity span - score: - A higher score means, the model/annotator is more confident about its predicted/annotated entity. - """ - - start: int - end: int - label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH) - score: float = Field(default=1.0, ge=0.0, le=1.0) - - @validator("end") - def check_span_offset(cls, end: int, values): - """Validates span offset""" - assert ( - end > values["start"] - ), "End character cannot be placed before the starting character, it must be at least one character after." - return end - - def __hash__(self): - return hash(type(self)) + hash(self.__dict__.values()) - - -class TokenClassificationAnnotation(BaseAnnotation): - """Annotation class for the Token classification task. - - Attributes: - ----------- - entities: List[EntitiesSpan] - a list of detected entities spans in tokenized text, if any. 
- score: float - score related to annotated entities. The higher is score value, the - more likely is that entities were properly annotated. - """ - - entities: List[EntitySpan] = Field(default_factory=list) - score: Optional[float] = None - +from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest +from rubrix.server.commons.models import PredictionStatus +from rubrix.server.daos.backend.search.model import SortableField +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + ServiceBaseSearchResultsAggregations, +) +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationAnnotation as _TokenClassificationAnnotation, +) +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationDataset, +) -class CreationTokenClassificationRecord(BaseRecord[TokenClassificationAnnotation]): - """ - Dataset record for token classification task - Attributes: - ----------- +class TokenClassificationAnnotation(_TokenClassificationAnnotation): + pass - tokens: List[str] - The input tokens - text: str - Textual representation of token list - """ +class TokenClassificationRecordInputs(BaseRecordInputs[TokenClassificationAnnotation]): - tokens: List[str] = Field(min_items=1) text: str = Field() + tokens: List[str] = Field(min_items=1) + # TODO(@frascuchon): Delete this field and all related logic _raw_text: Optional[str] = Field(alias="raw_text") - __chars2tokens__: Dict[int, int] = None - __tokens2chars__: Dict[int, Tuple[int, int]] = None - @root_validator(pre=True) def accept_old_fashion_text_field(cls, values): text, raw_text = values.get("text"), values.get("raw_text") @@ -118,290 +56,23 @@ def accept_old_fashion_text_field(cls, values): return values - def __init__(self, **data): - super().__init__(**data) - - self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__() - - self.check_annotation(self.prediction) - 
self.check_annotation(self.annotation) - - def char_id2token_id(self, char_idx: int) -> Optional[int]: - return self.__chars2tokens__.get(char_idx) - - def token_span(self, token_idx: int) -> Tuple[int, int]: - if token_idx not in self.__tokens2chars__: - raise IndexError(f"Token id {token_idx} out of bounds") - return self.__tokens2chars__[token_idx] - @validator("text") def check_text_content(cls, text: str): assert text and text.strip(), "No text or empty text provided" return text - def __build_indices_map__( - self, - ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]: - """ - Build the indices mapping between text characters and tokens where belongs to, - and vice versa. - - chars2tokens index contains is the token idx where i char is contained (if any). - - Out-of-token characters won't be included in this map, - so access should be using ``chars2tokens_map.get(i)`` - instead of ``chars2tokens_map[i]``. - - """ - - def chars2tokens_index(): - def is_space_after_token(char, idx: int, chars_map) -> str: - return char == " " and idx - 1 in chars_map - - chars_map = {} - current_token = 0 - current_token_char_start = 0 - for idx, char in enumerate(self.text): - if is_space_after_token(char, idx, chars_map): - continue - relative_idx = idx - current_token_char_start - if ( - relative_idx < len(self.tokens[current_token]) - and char == self.tokens[current_token][relative_idx] - ): - chars_map[idx] = current_token - elif ( - current_token + 1 < len(self.tokens) - and relative_idx >= len(self.tokens[current_token]) - and char == self.tokens[current_token + 1][0] - ): - current_token += 1 - current_token_char_start += relative_idx - chars_map[idx] = current_token - - return chars_map - - def tokens2chars_index( - chars2tokens: Dict[int, int] - ) -> Dict[int, Tuple[int, int]]: - tokens2chars_map = defaultdict(list) - for c, t in chars2tokens.items(): - tokens2chars_map[t].append(c) - - return { - token_idx: (min(chars), max(chars)) - for token_idx, chars in 
tokens2chars_map.items() - } - - chars2tokens_idx = chars2tokens_index() - return chars2tokens_idx, tokens2chars_index(chars2tokens_idx) - - def check_annotation( - self, - annotation: Optional[TokenClassificationAnnotation], - ): - """Validates entities in terms of offset spans""" - - def adjust_span_bounds(start, end): - if start < 0: - start = 0 - if entity.end > len(self.text): - end = len(self.text) - while start <= len(self.text) and not self.text[start].strip(): - start += 1 - while not self.text[end - 1].strip(): - end -= 1 - return start, end - - if annotation: - for entity in annotation.entities: - entity.start, entity.end = adjust_span_bounds(entity.start, entity.end) - mention = self.text[entity.start : entity.end] - assert len(mention) > 0, f"Empty offset defined for entity {entity}" - - token_start = self.char_id2token_id(entity.start) - token_end = self.char_id2token_id(entity.end - 1) - - assert not ( - token_start is None or token_end is None - ), f"Provided entity span {self.text[entity.start: entity.end]} is not aligned with provided tokens." 
- "Some entity chars could be reference characters out of tokens" - - span_start, _ = self.token_span(token_start) - _, span_end = self.token_span(token_end) - - assert ( - self.text[span_start : span_end + 1] == mention - ), f"Defined offset [{self.text[entity.start: entity.end]}] is a misaligned entity mention" - - def task(cls) -> TaskType: - """The record task type""" - return TaskType.token_classification - - @property - def predicted(self) -> Optional[PredictionStatus]: - if self.annotation and self.prediction: - return ( - PredictionStatus.OK - if self.annotation.entities == self.prediction.entities - else PredictionStatus.KO - ) - return None - - @property - def predicted_as(self) -> List[str]: - return [ent.label for ent in self.predicted_entities()] - - @property - def annotated_as(self) -> List[str]: - return [ent.label for ent in self.annotated_entities()] - - @property - def scores(self) -> List[float]: - if not self.prediction: - return [] - if self.prediction.score is not None: - return [self.prediction.score] - return [e.score for e in self.prediction.entities] - - def all_text(self) -> str: - return self.text - - def predicted_iob_tags(self) -> Optional[List[str]]: - if self.prediction is None: - return None - return self.spans2iob(self.prediction.entities) - - def annotated_iob_tags(self) -> Optional[List[str]]: - if self.annotation is None: - return None - return self.spans2iob(self.annotation.entities) - - def spans2iob(self, spans: List[EntitySpan]) -> Optional[List[str]]: - if spans is None: - return None - tags = ["O"] * len(self.tokens) - for entity in spans: - token_start = self.char_id2token_id(entity.start) - token_end = self.char_id2token_id(entity.end - 1) - tags[token_start] = f"B-{entity.label}" - for idx in range(token_start + 1, token_end + 1): - tags[idx] = f"I-{entity.label}" - - return tags - - def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]: - return [ - (mention, entity) - for mention, entity in 
self.__mentions_from_entities__( - self.predicted_entities() - ).items() - ] - - def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]: - return [ - (mention, entity) - for mention, entity in self.__mentions_from_entities__( - self.annotated_entities() - ).items() - ] - - def annotated_entities(self) -> Set[EntitySpan]: - """Shortcut for real annotated entities, if provided""" - if self.annotation is None: - return set() - return set(self.annotation.entities) - - def predicted_entities(self) -> Set[EntitySpan]: - """Predicted entities""" - if self.prediction is None: - return set() - return set(self.prediction.entities) - - def __mentions_from_entities__( - self, entities: Set[EntitySpan] - ) -> Dict[str, EntitySpan]: - return { - mention: entity - for entity in entities - for mention in [self.text[entity.start : entity.end]] - } - - class Config: - allow_population_by_field_name = True - underscore_attrs_are_private = True - - -class TokenClassificationRecordDB(CreationTokenClassificationRecord): - """ - The main token classification task record - - Attributes: - ----------- - - last_updated: datetime - Last record update (read only) - predicted: Optional[PredictionStatus] - The record prediction status. 
Optional - """ - - last_updated: datetime = None - _predicted: Optional[PredictionStatus] = Field(alias="predicted") - - def extended_fields(self) -> Dict[str, Any]: - - return { - **super().extended_fields(), - # See ../service/service.py - PREDICTED_MENTIONS_ES_FIELD_NAME: [ - {"mention": mention, "entity": entity.label, "score": entity.score} - for mention, entity in self.predicted_mentions() - ], - MENTIONS_ES_FIELD_NAME: [ - {"mention": mention, "entity": entity.label} - for mention, entity in self.annotated_mentions() - ], - EsRecordDataFieldNames.words: self.all_text(), - } - - -class TokenClassificationRecord(TokenClassificationRecordDB): - def extended_fields(self) -> Dict[str, Any]: - return { - "raw_text": self.text, # Maintain results compatibility - } - - -class TokenClassificationBulkData(UpdateDatasetRequest): - """ - API bulk data for text classification - - Attributes: - ----------- - - records: List[TextClassificationRecord] - The text classification record list - - """ - - records: List[CreationTokenClassificationRecord] +class TokenClassificationRecord( + TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation] +): + pass -class TokenClassificationQuery(BaseSearchQuery): - """ - API Filters for text classification - Attributes: - ----------- +class TokenClassificationBulkRequest(UpdateDatasetRequest): + records: List[TokenClassificationRecordInputs] - predicted_as: List[str] - List of predicted terms - annotated_as: List[str] - List of annotated terms - predicted: Optional[PredictionStatus] - The task prediction status - """ +class TokenClassificationQuery(ServiceBaseRecordsQuery): predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) @@ -410,35 +81,11 @@ class TokenClassificationQuery(BaseSearchQuery): class TokenClassificationSearchRequest(BaseModel): - - """ - API SearchRequest request - - Attributes: - ----------- - - query: TokenClassificationQuery - The search query 
configuration - sort: - The sort by order in search results - """ - query: TokenClassificationQuery = Field(default_factory=TokenClassificationQuery) sort: List[SortableField] = Field(default_factory=list) -class TokenClassificationAggregations(BaseSearchResultsAggregations): - """ - Extends base aggregation with mentions - - Attributes: - ----------- - mentions: Dict[str,Dict[str,int]] - The annotated entity spans - predicted_mentions: Dict[str,Dict[str,int]] - The prediction entity spans - """ - +class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations): predicted_mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict) mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict) @@ -449,6 +96,5 @@ class TokenClassificationSearchResults( pass -class TokenClassificationDatasetDB(DatasetDB): - task: TaskType = Field(default=TaskType.token_classification, const=True) +class TokenClassificationDataset(ServiceTokenClassificationDataset): pass diff --git a/src/rubrix/server/apis/v0/validators/text_classification.py b/src/rubrix/server/apis/v0/validators/text_classification.py index f2bee61bee..0bb4f5db97 100644 --- a/src/rubrix/server/apis/v0/validators/text_classification.py +++ b/src/rubrix/server/apis/v0/validators/text_classification.py @@ -2,27 +2,27 @@ from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import TaskType from rubrix.server.apis.v0.models.dataset_settings import TextClassificationSettings from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.apis.v0.models.metrics.text_classification import DatasetLabels -from rubrix.server.apis.v0.models.text_classification import ( - CreationTextClassificationRecord, - TextClassificationRecord, -) +from rubrix.server.commons.models import TaskType from rubrix.server.errors import BadRequestError, EntityNotFoundError from rubrix.server.security.model import User -from rubrix.server.services.datasets import DatasetsService, 
SVCDatasetSettings +from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings +from rubrix.server.services.tasks.text_classification.metrics import DatasetLabels -__svc_settings_class__: Type[SVCDatasetSettings] = type( +__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type( f"{TaskType.text_classification}_DatasetSettings", - (SVCDatasetSettings, TextClassificationSettings), + (ServiceBaseDatasetSettings, TextClassificationSettings), {}, ) from rubrix.server.services.metrics import MetricsService +from rubrix.server.services.tasks.text_classification.model import ( + ServiceTextClassificationRecord, +) +# TODO(@frascuchon): Move validator and its models to the service layer class DatasetValidator: _INSTANCE = None @@ -48,7 +48,7 @@ async def validate_dataset_settings( results = self.__metrics__.summarize_metric( dataset=dataset, metric=DatasetLabels(), - record_class=TextClassificationRecord, + record_class=ServiceTextClassificationRecord, query=None, ) if results: @@ -67,7 +67,7 @@ async def validate_dataset_records( self, user: User, dataset: Dataset, - records: Optional[List[CreationTextClassificationRecord]] = None, + records: Optional[List[ServiceTextClassificationRecord]] = None, ): try: settings: TextClassificationSettings = await self.__datasets__.get_settings( diff --git a/src/rubrix/server/apis/v0/validators/token_classification.py b/src/rubrix/server/apis/v0/validators/token_classification.py index 947c2ce1dc..5e49f68f3e 100644 --- a/src/rubrix/server/apis/v0/validators/token_classification.py +++ b/src/rubrix/server/apis/v0/validators/token_classification.py @@ -2,28 +2,28 @@ from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import TaskType from rubrix.server.apis.v0.models.dataset_settings import TokenClassificationSettings from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.apis.v0.models.metrics.token_classification import DatasetLabels -from 
rubrix.server.apis.v0.models.token_classification import ( - CreationTokenClassificationRecord, - TokenClassificationAnnotation, - TokenClassificationRecord, -) +from rubrix.server.commons.models import TaskType from rubrix.server.errors import BadRequestError, EntityNotFoundError from rubrix.server.security.model import User -from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings +from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings +from rubrix.server.services.tasks.token_classification.metrics import DatasetLabels -__svc_settings_class__: Type[SVCDatasetSettings] = type( +__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type( f"{TaskType.token_classification}_DatasetSettings", - (SVCDatasetSettings, TokenClassificationSettings), + (ServiceBaseDatasetSettings, TokenClassificationSettings), {}, ) from rubrix.server.services.metrics import MetricsService +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationAnnotation, + ServiceTokenClassificationRecord, +) +# TODO(@frascuchon): Move validator and its models to the service layer class DatasetValidator: _INSTANCE = None @@ -48,7 +48,7 @@ async def validate_dataset_settings( results = self.__metrics__.summarize_metric( dataset=dataset, metric=DatasetLabels(), - record_class=TokenClassificationRecord, + record_class=ServiceTokenClassificationRecord, query=None, ) if results: @@ -67,7 +67,7 @@ async def validate_dataset_records( self, user: User, dataset: Dataset, - records: List[CreationTokenClassificationRecord], + records: List[ServiceTokenClassificationRecord], ): try: settings: TokenClassificationSettings = ( @@ -90,7 +90,7 @@ async def validate_dataset_records( @staticmethod def __check_label_entities__( - label_schema: Set[str], annotation: TokenClassificationAnnotation + label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation ): if not annotation: return diff --git 
a/src/rubrix/server/apis/v0/config/__init__.py b/src/rubrix/server/commons/__init__.py similarity index 100% rename from src/rubrix/server/apis/v0/config/__init__.py rename to src/rubrix/server/commons/__init__.py diff --git a/src/rubrix/server/commons/config.py b/src/rubrix/server/commons/config.py new file mode 100644 index 0000000000..6de562d06b --- /dev/null +++ b/src/rubrix/server/commons/config.py @@ -0,0 +1,96 @@ +from typing import List, Optional, Set, Type + +from pydantic import BaseModel + +from rubrix.server.commons.models import TaskType +from rubrix.server.errors import EntityNotFoundError, WrongTaskError +from rubrix.server.services.datasets import ServiceDataset +from rubrix.server.services.metrics import ServiceBaseMetric +from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics +from rubrix.server.services.search.model import ServiceRecordsQuery +from rubrix.server.services.tasks.commons import ServiceRecord + + +class TaskConfig(BaseModel): + task: TaskType + query: Type[ServiceRecordsQuery] + dataset: Type[ServiceDataset] + record: Type[ServiceRecord] + metrics: Optional[Type[ServiceBaseTaskMetrics]] + + +class TasksFactory: + + __REGISTERED_TASKS__ = dict() + + @classmethod + def register_task( + cls, + task_type: TaskType, + dataset_class: Type[ServiceDataset], + query_request: Type[ServiceRecordsQuery], + record_class: Type[ServiceRecord], + metrics: Optional[Type[ServiceBaseTaskMetrics]] = None, + ): + + cls.__REGISTERED_TASKS__[task_type] = TaskConfig( + task=task_type, + dataset=dataset_class, + query=query_request, + record=record_class, + metrics=metrics, + ) + + @classmethod + def get_all_configs(cls) -> List[TaskConfig]: + return [cfg for cfg in cls.__REGISTERED_TASKS__.values()] + + @classmethod + def get_task_by_task_type(cls, task_type: TaskType) -> Optional[TaskConfig]: + return cls.__REGISTERED_TASKS__.get(task_type) + + @classmethod + def get_task_metrics(cls, task: TaskType) -> 
Optional[Type[ServiceBaseTaskMetrics]]: + config = cls.get_task_by_task_type(task) + if config: + return config.metrics + + @classmethod + def get_task_dataset(cls, task: TaskType) -> Type[ServiceDataset]: + config = cls.__get_task_config__(task) + return config.dataset + + @classmethod + def get_task_record(cls, task: TaskType) -> Type[ServiceRecord]: + config = cls.__get_task_config__(task) + return config.record + + @classmethod + def __get_task_config__(cls, task): + config = cls.get_task_by_task_type(task) + if not config: + raise WrongTaskError(f"No configuration found for task {task}") + return config + + @classmethod + def find_task_metric( + cls, task: TaskType, metric_id: str + ) -> Optional[ServiceBaseMetric]: + metrics = cls.find_task_metrics(task, {metric_id}) + if metrics: + return metrics[0] + raise EntityNotFoundError(name=metric_id, type=ServiceBaseMetric) + + @classmethod + def find_task_metrics( + cls, task: TaskType, metric_ids: Set[str] + ) -> List[ServiceBaseMetric]: + + if not metric_ids: + return [] + + metrics = [] + for metric in cls.get_task_metrics(task).metrics: + if metric.id in metric_ids: + metrics.append(metric) + return metrics diff --git a/src/rubrix/server/commons/models.py b/src/rubrix/server/commons/models.py new file mode 100644 index 0000000000..31e2444863 --- /dev/null +++ b/src/rubrix/server/commons/models.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class TaskStatus(str, Enum): + default = "Default" + edited = "Edited" # TODO: DEPRECATE + discarded = "Discarded" + validated = "Validated" + + +class TaskType(str, Enum): + + text_classification = "TextClassification" + token_classification = "TokenClassification" + text2text = "Text2Text" + multi_task_text_token_classification = "MultitaskTextTokenClassification" + + +class PredictionStatus(str, Enum): + OK = "ok" + KO = "ko" diff --git a/src/rubrix/server/apis/v0/models/metrics/__init__.py b/src/rubrix/server/daos/backend/__init__.py similarity index 100% rename from 
src/rubrix/server/apis/v0/models/metrics/__init__.py rename to src/rubrix/server/daos/backend/__init__.py diff --git a/src/rubrix/server/elasticseach/__init__.py b/src/rubrix/server/daos/backend/mappings/__init__.py similarity index 100% rename from src/rubrix/server/elasticseach/__init__.py rename to src/rubrix/server/daos/backend/mappings/__init__.py diff --git a/src/rubrix/server/elasticseach/mappings/datasets.py b/src/rubrix/server/daos/backend/mappings/datasets.py similarity index 87% rename from src/rubrix/server/elasticseach/mappings/datasets.py rename to src/rubrix/server/daos/backend/mappings/datasets.py index 125343f0fe..63fbc0ecda 100644 --- a/src/rubrix/server/elasticseach/mappings/datasets.py +++ b/src/rubrix/server/daos/backend/mappings/datasets.py @@ -1,9 +1,9 @@ from rubrix.server.settings import settings DATASETS_INDEX_NAME = settings.dataset_index_name -DATASETS_RECORDS_INDEX_NAME = settings.dataset_records_index_name - +# TODO(@frascuchon): Define an mapping definition instead and +# use it when datasets index is created DATASETS_INDEX_TEMPLATE = { "index_patterns": [DATASETS_INDEX_NAME], "settings": {"number_of_shards": 1}, diff --git a/src/rubrix/server/elasticseach/mappings/helpers.py b/src/rubrix/server/daos/backend/mappings/helpers.py similarity index 97% rename from src/rubrix/server/elasticseach/mappings/helpers.py rename to src/rubrix/server/daos/backend/mappings/helpers.py index 4a0715cab2..0a7f75bc20 100644 --- a/src/rubrix/server/elasticseach/mappings/helpers.py +++ b/src/rubrix/server/daos/backend/mappings/helpers.py @@ -164,6 +164,8 @@ def tasks_common_mappings(): "id": mappings.keyword_field(), "words": mappings.words_text_field(), "text": mappings.text_field(), + # TODO(@frascuchon): Enable prediction and annotation + # so we can build extra metrics based on these fields "prediction": {"type": "object", "enabled": False}, "annotation": {"type": "object", "enabled": False}, "status": mappings.keyword_field(), diff --git 
a/src/rubrix/server/elasticseach/mappings/text2text.py b/src/rubrix/server/daos/backend/mappings/text2text.py similarity index 69% rename from src/rubrix/server/elasticseach/mappings/text2text.py rename to src/rubrix/server/daos/backend/mappings/text2text.py index 8f2ec4ffcf..91bfc6880d 100644 --- a/src/rubrix/server/elasticseach/mappings/text2text.py +++ b/src/rubrix/server/daos/backend/mappings/text2text.py @@ -1,4 +1,4 @@ -from rubrix.server.elasticseach.mappings.helpers import mappings +from rubrix.server.daos.backend.mappings.helpers import mappings def text2text_mappings(): @@ -16,10 +16,6 @@ def text2text_mappings(): ] ), "properties": { - # TODO: we will include this breaking changes 2 releases after - # PR https://github.com/recognai/rubrix/pull/1018 - # "annotated_as": mappings.text_field(), - # "predicted_as": mappings.text_field(), "annotated_as": mappings.keyword_field(), "predicted_as": mappings.keyword_field(), "text_predicted": mappings.words_text_field(), diff --git a/src/rubrix/server/elasticseach/mappings/text_classification.py b/src/rubrix/server/daos/backend/mappings/text_classification.py similarity index 95% rename from src/rubrix/server/elasticseach/mappings/text_classification.py rename to src/rubrix/server/daos/backend/mappings/text_classification.py index 95dc19942e..bcce2023e1 100644 --- a/src/rubrix/server/elasticseach/mappings/text_classification.py +++ b/src/rubrix/server/daos/backend/mappings/text_classification.py @@ -1,4 +1,4 @@ -from rubrix.server.elasticseach.mappings.helpers import mappings +from rubrix.server.daos.backend.mappings.helpers import mappings def text_classification_mappings(): diff --git a/src/rubrix/server/elasticseach/mappings/token_classification.py b/src/rubrix/server/daos/backend/mappings/token_classification.py similarity index 68% rename from src/rubrix/server/elasticseach/mappings/token_classification.py rename to src/rubrix/server/daos/backend/mappings/token_classification.py index d72302ec3e..664c317e31 
100644 --- a/src/rubrix/server/elasticseach/mappings/token_classification.py +++ b/src/rubrix/server/daos/backend/mappings/token_classification.py @@ -1,10 +1,39 @@ -from rubrix.server.apis.v0.models.metrics.token_classification import ( - MentionMetrics, - TokenMetrics, - TokenTagMetrics, -) -from rubrix.server.elasticseach.mappings.helpers import mappings -from rubrix.server.elasticseach.query_helpers import nested_mappings_from_base_model +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from rubrix.server.daos.backend.mappings.helpers import mappings +from rubrix.server.daos.backend.query_helpers import nested_mappings_from_base_model + + +class MentionMetrics(BaseModel): + """Mention metrics model""" + + value: str + label: str + score: float = Field(ge=0.0) + capitalness: Optional[str] = Field(None) + density: float = Field(ge=0.0) + tokens_length: int = Field(gt=0) + chars_length: int = Field(gt=0) + + +class TokenTagMetrics(BaseModel): + value: str + tag: str + + +class TokenMetrics(BaseModel): + + idx: int + value: str + char_start: int + char_end: int + length: int + capitalness: Optional[str] = None + score: Optional[float] = None + tag: Optional[str] = None # TODO: remove! 
+ custom: Dict[str, Any] = None def mentions_mappings(): diff --git a/src/rubrix/server/daos/backend/metrics/__init__.py b/src/rubrix/server/daos/backend/metrics/__init__.py new file mode 100644 index 0000000000..3b02974875 --- /dev/null +++ b/src/rubrix/server/daos/backend/metrics/__init__.py @@ -0,0 +1,5 @@ +from .commons import METRICS as COMMON_METRICS +from .text_classification import METRICS as TEXT_CLASSIFICATION +from .token_classification import METRICS as TOKEN_CLASSIFICATION + +ALL_METRICS = {**COMMON_METRICS, **TEXT_CLASSIFICATION, **TOKEN_CLASSIFICATION} diff --git a/src/rubrix/server/daos/backend/metrics/base.py b/src/rubrix/server/daos/backend/metrics/base.py new file mode 100644 index 0000000000..bb469755ad --- /dev/null +++ b/src/rubrix/server/daos/backend/metrics/base.py @@ -0,0 +1,213 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Union + +from rubrix.server.daos.backend.query_helpers import aggregations +from rubrix.server.helpers import unflatten_dict + + +@dataclasses.dataclass +class ElasticsearchMetric: + id: str + + @property + def metric_arg_names(self): + return self.__args__ + + def __post_init__(self): + self.__args__ = self.get_function_arg_names(self._build_aggregation) + + @staticmethod + def get_function_arg_names(func): + return func.__code__.co_varnames + + def aggregation_request( + self, *args, **kwargs + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """ + Configures the summary es aggregation definition + """ + return {self.id: self._build_aggregation(*args, **kwargs)} + + def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: + """ + Parse the es aggregation result. 
Override this method + for result customization + + Parameters + ---------- + aggregation_result: + Retrieved es aggregation result + + """ + return aggregation_result.get(self.id, aggregation_result) + + def _build_aggregation(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError() + + +@dataclasses.dataclass +class NestedPathElasticsearchMetric(ElasticsearchMetric): + """ + A ``ElasticsearchMetric`` which need nested fields for summary calculation. + + Aggregations for nested fields need some extra configuration and this class + encapsulate these common logic. + + Attributes: + ----------- + nested_path: + The nested + """ + + nested_path: str + + def __post_init__(self): + self.__args__ = self.get_function_arg_names(self._inner_aggregation) + + def _inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]: + """The specific aggregation definition""" + raise NotImplementedError() + + def _build_aggregation(self, *args, **kwargs) -> Dict[str, Any]: + """Implements the common mechanism to define aggregations with nested fields""" + return aggregations.nested_aggregation( + nested_path=self.nested_path, + inner_aggregation=self._inner_aggregation(*args, **kwargs), + ) + + def compound_nested_field(self, inner_field: str) -> str: + return f"{self.nested_path}.{inner_field}" + + +@dataclasses.dataclass +class HistogramAggregation(ElasticsearchMetric): + """ + Base elasticsearch histogram aggregation metric + + Attributes + ---------- + field: + The histogram field + script: + If provided, it will be used as scripted field + for aggregation + fixed_interval: + If provided, it will used ALWAYS as the histogram + aggregation interval + """ + + field: str + script: Optional[Union[str, Dict[str, Any]]] = None + fixed_interval: Optional[float] = None + + def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any]: + if self.fixed_interval: + interval = self.fixed_interval + + return aggregations.histogram_aggregation( + 
field_name=self.field, script=self.script, interval=interval + ) + + +@dataclasses.dataclass +class TermsAggregation(ElasticsearchMetric): + + field: str = None + script: Union[str, Dict[str, Any]] = None + fixed_size: Optional[int] = None + default_size: Optional[int] = None + missing: Optional[str] = None + + def _build_aggregation(self, size: int = None) -> Dict[str, Any]: + if self.fixed_size: + size = self.fixed_size + return aggregations.terms_aggregation( + self.field, + script=self.script, + size=size or self.default_size, + missing=self.missing, + ) + + +@dataclasses.dataclass +class BidimensionalTermsAggregation(ElasticsearchMetric): + field_x: str + field_y: str + + def _build_aggregation(self, size: int = None) -> Dict[str, Any]: + return aggregations.bidimentional_terms_aggregations( + field_name_x=self.field_x, + field_name_y=self.field_y, + size=size, + ) + + +@dataclasses.dataclass +class NestedTermsAggregation(NestedPathElasticsearchMetric): + terms: TermsAggregation + + def __post_init__(self): + super().__post_init__() + self.terms.field = f"{self.nested_path}.{self.terms.field}" + + def _inner_aggregation(self, size: int) -> Dict[str, Any]: + return self.terms.aggregation_request(size) + + +@dataclasses.dataclass +class NestedBidimensionalTermsAggregation(NestedPathElasticsearchMetric): + biterms: BidimensionalTermsAggregation + + def __post_init__(self): + super().__post_init__() + self.biterms.field_x = f"{self.nested_path}.{self.biterms.field_x}" + self.biterms.field_y = f"{self.nested_path}.{self.biterms.field_y}" + + def _inner_aggregation(self, size: int = None) -> Dict[str, Any]: + return self.biterms.aggregation_request(size) + + +@dataclasses.dataclass +class NestedHistogramAggregation(NestedPathElasticsearchMetric): + histogram: HistogramAggregation + + def __post_init__(self): + super().__post_init__() + self.histogram.field = f"{self.nested_path}.{self.histogram.field}" + + def _inner_aggregation(self, interval: float) -> Dict[str, 
Any]: + return self.histogram.aggregation_request(interval) + + +@dataclasses.dataclass +class WordCloudAggregation(ElasticsearchMetric): + default_field: str + + def _build_aggregation( + self, text_field: str = None, size: int = None + ) -> Dict[str, Any]: + field = text_field or self.default_field + terms_id = f"{self.id}_{field}" if text_field else self.id + return TermsAggregation(id=terms_id, field=field,).aggregation_request( + size=size + )[terms_id] + + +@dataclasses.dataclass +class MetadataAggregations(ElasticsearchMetric): + def __post_init__(self): + super().__post_init__() + self.__args__ = self.get_function_arg_names(self.aggregation_request) + + def aggregation_request( + self, + schema: Dict[str, Any], + size: int = None, + ) -> List[Dict[str, Any]]: + + metadata_aggs = aggregations.custom_fields(fields_definitions=schema, size=size) + return [{key: value} for key, value in metadata_aggs.items()] + + def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: + data = unflatten_dict(aggregation_result, stop_keys=["metadata"]) + return data.get("metadata", {}) diff --git a/src/rubrix/server/daos/backend/metrics/commons.py b/src/rubrix/server/daos/backend/metrics/commons.py new file mode 100644 index 0000000000..b91c712a94 --- /dev/null +++ b/src/rubrix/server/daos/backend/metrics/commons.py @@ -0,0 +1,47 @@ +from rubrix.server.commons.models import TaskStatus +from rubrix.server.daos.backend.metrics.base import ( + HistogramAggregation, + MetadataAggregations, + TermsAggregation, + WordCloudAggregation, +) + +METRICS = { + "text_length": HistogramAggregation( + id="text_length", + field="metrics.text_length", + script="params._source.text.length()", + fixed_interval=1, + ), + "error_distribution": TermsAggregation( + id="error_distribution", + field="predicted", + missing="unknown", + fixed_size=3, + ), + "status_distribution": TermsAggregation( + id="status_distribution", + field="status", + fixed_size=len(TaskStatus), + ), 
+ "words_cloud": WordCloudAggregation( + id="words_cloud", + default_field="text.wordcloud", + ), + "metadata": MetadataAggregations( + id="metadata", + ), + "predicted_by": TermsAggregation( + id="predicted_by", + field="predicted_by", + ), + "annotated_by": TermsAggregation( + id="annotated_by", + field="annotated_by", + ), + "score": HistogramAggregation( + id="score", + field="score", + fixed_interval=0.001, + ), +} diff --git a/src/rubrix/server/daos/backend/metrics/datasets.py b/src/rubrix/server/daos/backend/metrics/datasets.py new file mode 100644 index 0000000000..a48bc73d9b --- /dev/null +++ b/src/rubrix/server/daos/backend/metrics/datasets.py @@ -0,0 +1,8 @@ +# All metrics related to the datasets index +from rubrix.server.daos.backend.metrics.base import TermsAggregation + +METRICS = { + "all_rubrix_workspaces": TermsAggregation( + id="all_rubrix_workspaces", field="owner.keyword" + ) +} diff --git a/src/rubrix/server/daos/backend/metrics/text_classification.py b/src/rubrix/server/daos/backend/metrics/text_classification.py new file mode 100644 index 0000000000..085400a342 --- /dev/null +++ b/src/rubrix/server/daos/backend/metrics/text_classification.py @@ -0,0 +1,130 @@ +import dataclasses +from typing import Any, Dict, List, Optional + +from rubrix.server.daos.backend.metrics.base import ( + ElasticsearchMetric, + TermsAggregation, +) +from rubrix.server.daos.backend.query_helpers import aggregations, filters +from rubrix.server.helpers import unflatten_dict + + +@dataclasses.dataclass +class DatasetLabelingRulesMetric(ElasticsearchMetric): + def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]: + rules_filters = [filters.text_query(rule_query) for rule_query in queries] + return aggregations.filters_aggregation( + filters={ + "covered_records": filters.boolean_filter( + should_filters=rules_filters, minimum_should_match=1 + ), + "annotated_covered_records": filters.boolean_filter( + filter_query=filters.exists_field("annotated_as"), + 
@dataclasses.dataclass
class LabelingRulesMetric(ElasticsearchMetric):
    """Computes coverage and per-label correctness stats for one labeling rule."""

    id: str

    def _build_aggregation(
        self, rule_query: str, labels: Optional[List[str]]
    ) -> Dict[str, Any]:
        annotated_filter = filters.exists_field("annotated_as")
        rule_filter = filters.text_query(rule_query)

        aggr_filters = {
            # Records matched by the rule query.
            "covered_records": rule_filter,
            # Records matched by the rule query that also have an annotation.
            "annotated_covered_records": filters.boolean_filter(
                filter_query=annotated_filter,
                should_filters=[rule_filter],
            ),
        }

        for label in labels if labels is not None else []:
            label_filter = filters.term_filter("annotated_as", value=label)
            encoded = self._encode_label_name(label)
            # Correct: rule fired AND the record is annotated with the label.
            aggr_filters[f"{encoded}.correct_records"] = filters.boolean_filter(
                filter_query=annotated_filter,
                should_filters=[rule_filter, label_filter],
                minimum_should_match=2,
            )
            # Incorrect: rule fired but the annotation is a different label.
            aggr_filters[f"{encoded}.incorrect_records"] = filters.boolean_filter(
                filter_query=annotated_filter,
                must_query=rule_filter,
                must_not_query=label_filter,
            )

        return aggregations.filters_aggregation(aggr_filters)

    @staticmethod
    def _encode_label_name(label: str) -> str:
        # Dots would be split into nesting by ``unflatten_dict``; escape them.
        return label.replace(".", "@@@")

    @staticmethod
    def _decode_label_name(label: str) -> str:
        return label.replace("@@@", ".")

    def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
        """Unflattens the filters-aggregation output and adds derived metrics
        (annotated count, per-label precision and a macro-averaged precision)."""
        if self.id in aggregation_result:
            aggregation_result = aggregation_result[self.id]

        aggregation_result = unflatten_dict(aggregation_result)
        results = {
            "covered_records": aggregation_result.pop("covered_records"),
            "annotated_covered_records": aggregation_result.pop(
                "annotated_covered_records"
            ),
        }

        total_correct = 0
        total_incorrect = 0
        precisions = []
        for label, metrics in aggregation_result.items():
            correct = metrics.get("correct_records", 0)
            incorrect = metrics.get("incorrect_records", 0)
            annotated = correct + incorrect
            metrics["annotated"] = annotated
            if annotated > 0:
                precision = correct / annotated
                metrics["precision"] = precision
                precisions.append(precision)
            total_correct += correct
            total_incorrect += incorrect
            results[self._decode_label_name(label)] = metrics

        results["correct_records"] = total_correct
        results["incorrect_records"] = total_incorrect
        # Macro average over labels that have at least one annotation.
        if precisions:
            results["precision"] = sum(precisions) / len(precisions)

        return results
# All backend metrics for the token classification task, keyed by metric id.
# Fix: the "annotated_mentions_distribution" entry previously reused the id
# "predicted_mentions_distribution", so the annotated metric carried the
# predicted metric's aggregation id.
METRICS = {
    "tokens_length": HistogramAggregation(
        "tokens_length",
        field="metrics.tokens_length",
    ),
    "token_frequency": NestedTermsAggregation(
        id="token_frequency",
        nested_path="metrics.tokens",
        terms=TermsAggregation(
            id="frequency",
            field="value",
        ),
    ),
    "token_length": NestedHistogramAggregation(
        id="token_length",
        nested_path="metrics.tokens",
        histogram=HistogramAggregation(
            id="length",
            field="length",
            fixed_interval=1,
        ),
    ),
    "token_capitalness": NestedTermsAggregation(
        id="token_capitalness",
        nested_path="metrics.tokens",
        terms=TermsAggregation(
            id="capitalness",
            field="capitalness",
        ),
    ),
    "predicted_mention_char_length": NestedHistogramAggregation(
        id="predicted_mention_char_length",
        nested_path="metrics.predicted.mentions",
        histogram=HistogramAggregation(
            id="length",
            field="chars_length",
            fixed_interval=1,
        ),
    ),
    "annotated_mention_char_length": NestedHistogramAggregation(
        id="annotated_mention_char_length",
        nested_path="metrics.annotated.mentions",
        histogram=HistogramAggregation(
            id="length",
            field="chars_length",
            fixed_interval=1,
        ),
    ),
    "predicted_entity_labels": NestedTermsAggregation(
        id="predicted_entity_labels",
        nested_path="metrics.predicted.mentions",
        terms=TermsAggregation(
            id="terms",
            field="label",
            default_size=_DEFAULT_MAX_ENTITY_BUCKET,
        ),
    ),
    "annotated_entity_labels": NestedTermsAggregation(
        id="annotated_entity_labels",
        nested_path="metrics.annotated.mentions",
        terms=TermsAggregation(
            id="terms",
            field="label",
            default_size=_DEFAULT_MAX_ENTITY_BUCKET,
        ),
    ),
    "predicted_entity_density": NestedHistogramAggregation(
        id="predicted_entity_density",
        nested_path="metrics.predicted.mentions",
        histogram=HistogramAggregation(id="histogram", field="density"),
    ),
    "annotated_entity_density": NestedHistogramAggregation(
        id="annotated_entity_density",
        nested_path="metrics.annotated.mentions",
        histogram=HistogramAggregation(id="histogram", field="density"),
    ),
    "predicted_mention_token_length": NestedHistogramAggregation(
        id="predicted_mention_token_length",
        nested_path="metrics.predicted.mentions",
        histogram=HistogramAggregation(id="histogram", field="tokens_length"),
    ),
    "annotated_mention_token_length": NestedHistogramAggregation(
        id="annotated_mention_token_length",
        nested_path="metrics.annotated.mentions",
        histogram=HistogramAggregation(id="histogram", field="tokens_length"),
    ),
    "predicted_entity_capitalness": NestedTermsAggregation(
        id="predicted_entity_capitalness",
        nested_path="metrics.predicted.mentions",
        terms=TermsAggregation(id="terms", field="capitalness"),
    ),
    "annotated_entity_capitalness": NestedTermsAggregation(
        id="annotated_entity_capitalness",
        nested_path="metrics.annotated.mentions",
        terms=TermsAggregation(id="terms", field="capitalness"),
    ),
    "predicted_mentions_distribution": NestedBidimensionalTermsAggregation(
        id="predicted_mentions_distribution",
        nested_path="metrics.predicted.mentions",
        biterms=BidimensionalTermsAggregation(
            id="bi-dimensional", field_x="label", field_y="value"
        ),
    ),
    "annotated_mentions_distribution": NestedBidimensionalTermsAggregation(
        id="annotated_mentions_distribution",  # was wrongly "predicted_..."
        nested_path="metrics.annotated.mentions",
        biterms=BidimensionalTermsAggregation(
            id="bi-dimensional", field_x="label", field_y="value"
        ),
    ),
    "predicted_entity_consistency": EntityConsistency(
        id="predicted_entity_consistency",
        nested_path="metrics.predicted.mentions",
        mention_field="value",
        labels_field="label",
    ),
    "annotated_entity_consistency": EntityConsistency(
        id="annotated_entity_consistency",
        nested_path="metrics.annotated.mentions",
        mention_field="value",
        labels_field="label",
    ),
    "predicted_tag_consistency": EntityConsistency(
        id="predicted_tag_consistency",
        nested_path="metrics.predicted.tags",
        mention_field="value",
        labels_field="tag",
    ),
    "annotated_tag_consistency": EntityConsistency(
        id="annotated_tag_consistency",
        nested_path="metrics.annotated.tags",
        mention_field="value",
        labels_field="tag",
    ),
}
-import math from typing import Any, Dict, List, Optional, Type, Union from pydantic import BaseModel -from rubrix.server.apis.v0.models.commons.model import ( - EsRecordDataFieldNames, - SortableField, - TaskStatus, -) -from rubrix.server.elasticseach.mappings.helpers import mappings +from rubrix.server.commons.models import TaskStatus +from rubrix.server.daos.backend.mappings.helpers import mappings def nested_mappings_from_base_model(model_class: Type[BaseModel]) -> Dict[str, Any]: @@ -45,21 +40,6 @@ def resolve_mapping(info) -> Dict[str, Any]: } -def sort_by2elasticsearch( - sort: List[SortableField], valid_fields: Optional[List[str]] = None -) -> List[Dict[str, Any]]: - valid_fields = valid_fields or [] - result = [] - for sortable_field in sort: - if valid_fields: - assert sortable_field.id.split(".")[0] in valid_fields, ( - f"Wrong sort id {sortable_field.id}. Valid values are: " - f"{[str(v) for v in valid_fields]}" - ) - result.append({sortable_field.id: {"order": sortable_field.order}}) - return result - - def parse_aggregations( es_aggregations: Dict[str, Any] = None ) -> Optional[Dict[str, Any]]: @@ -126,10 +106,6 @@ def parse_buckets(buckets: List[Dict[str, Any]]) -> Dict[str, Any]: return result -def decode_field_name(field: EsRecordDataFieldNames) -> str: - return field.value - - class filters: """Group of functions related to elasticsearch filters""" @@ -172,29 +148,21 @@ def predicted_by(predicted_by: List[str] = None) -> Optional[Dict[str, Any]]: if not predicted_by: return None - return { - "terms": { - decode_field_name(EsRecordDataFieldNames.predicted_by): predicted_by - } - } + return {"terms": {"predicted_by": predicted_by}} @staticmethod def annotated_by(annotated_by: List[str] = None) -> Optional[Dict[str, Any]]: """Filter records with given predicted by terms""" if not annotated_by: return None - return { - "terms": { - decode_field_name(EsRecordDataFieldNames.annotated_by): annotated_by - } - } + return {"terms": {"annotated_by": 
annotated_by}} @staticmethod def status(status: List[TaskStatus] = None) -> Optional[Dict[str, Any]]: """Filter records by status""" if not status: return None - return {"terms": {decode_field_name(EsRecordDataFieldNames.status): status}} + return {"terms": {"status": status}} @staticmethod def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]: @@ -269,7 +237,9 @@ class aggregations: MAX_AGGREGATION_SIZE = 5000 # TODO: improve by setting env var @staticmethod - def nested_aggregation(nested_path: str, inner_aggregation: Dict[str, Any]): + def nested_aggregation( + nested_path: str, inner_aggregation: Dict[str, Any] + ) -> Dict[str, Any]: inner_meta = list(inner_aggregation.values())[0].get("meta", {}) return { "meta": { @@ -374,6 +344,10 @@ def __resolve_aggregation_for_field_type( if aggregation } + @staticmethod + def filters_aggregation(filters: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: + return {"filters": {"filters": filters}} + def find_nested_field_path( field_name: str, mapping_definition: Dict[str, Any] diff --git a/src/rubrix/server/elasticseach/mappings/__init__.py b/src/rubrix/server/daos/backend/search/__init__.py similarity index 100% rename from src/rubrix/server/elasticseach/mappings/__init__.py rename to src/rubrix/server/daos/backend/search/__init__.py diff --git a/src/rubrix/server/daos/backend/search/model.py b/src/rubrix/server/daos/backend/search/model.py new file mode 100644 index 0000000000..ff2598a1c0 --- /dev/null +++ b/src/rubrix/server/daos/backend/search/model.py @@ -0,0 +1,67 @@ +from enum import Enum +from typing import Dict, List, Optional, TypeVar, Union + +from pydantic import BaseModel, Field + +from rubrix.server.commons.models import TaskStatus + + +class SortOrder(str, Enum): + asc = "asc" + desc = "desc" + + +class QueryRange(BaseModel): + + range_from: float = Field(default=0.0, alias="from") + range_to: float = Field(default=None, alias="to") + + class Config: + allow_population_by_field_name = 
class EsQueryBuilder:
    """Translates backend query/sort models into Elasticsearch request bodies."""

    _INSTANCE: "EsQueryBuilder" = None
    _LOGGER = logging.getLogger(__name__)

    @classmethod
    def get_instance(cls):
        """Lazily creates and returns the process-wide singleton builder."""
        if not cls._INSTANCE:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def _datasets_to_es_query(
        self, query: Optional[BackendDatasetsQuery] = None
    ) -> Dict[str, Any]:
        """Builds the query body used against the datasets index."""
        if not query:
            return filters.match_all()

        query_filters = []
        if query.owners:
            owners_filter = filters.terms_filter("owner.keyword", query.owners)
            if query.include_no_owner:
                # Match datasets owned by one of the given owners OR having
                # no owner at all.
                query_filters.append(
                    filters.boolean_filter(
                        minimum_should_match=1,  # OR condition
                        should_filters=[
                            owners_filter,
                            filters.boolean_filter(
                                must_not_query=filters.exists_field("owner")
                            ),
                        ],
                    )
                )
            else:
                query_filters.append(owners_filter)

        if query.tasks:
            query_filters.append(
                filters.terms_filter(field="task.keyword", values=query.tasks)
            )
        if query.name:
            query_filters.append(
                filters.term_filter(field="name.keyword", value=query.name)
            )
        if not query_filters:
            return filters.match_all()
        # AND semantics: every built filter must match.
        return filters.boolean_filter(
            should_filters=query_filters, minimum_should_match=len(query_filters)
        )

    def _search_to_es_query(
        self,
        schema: Optional[Dict[str, Any]] = None,
        query: Optional[BackendRecordsQuery] = None,
        sort: Optional[SortConfig] = None,
    ):
        """Builds the query body for a records search.

        When ``advanced_query_dsl`` is enabled, ``query_text`` is parsed as a
        Lucene expression (via luqum) against the provided index schema and
        combined with the remaining structured filters.
        """
        if not query:
            return filters.match_all()

        if not query.advanced_query_dsl or not query.query_text:
            return self._to_es_query(query)

        text_search = query.query_text
        # Strip the free-text part; it is handled by the DSL parser below.
        new_query = query.copy(update={"query_text": None})

        schema = SchemaAnalyzer(schema)
        es_query_builder = ElasticsearchQueryBuilder(
            **{
                **schema.query_builder_options(),
                "default_field": "text",
            }
        )

        query_tree = parser.parse(text_search)
        query_text = es_query_builder(query_tree)

        return filters.boolean_filter(
            filter_query=self._to_es_query(new_query), must_query=query_text
        )

    def map_2_es_query(
        self,
        schema: Optional[Dict[str, Any]] = None,
        query: Optional[BackendQuery] = None,
        sort: Optional[SortConfig] = None,
        id_from: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Maps a backend query, sort config and pagination cursor to an es body.

        ``id_from`` enables ``search_after`` pagination; it forces a default
        sort-by-id configuration, overriding any caller-provided ``sort``.
        """
        es_query: Dict[str, Any] = (
            {"query": self._datasets_to_es_query(query)}
            if isinstance(query, BaseDatasetsQuery)
            else {"query": self._search_to_es_query(schema, query)}
        )

        if id_from:
            # Cursor-based pagination needs a deterministic sort order.
            es_query["search_after"] = [id_from]
            sort = SortConfig()  # sort by id as default

        es_sort = self.map_2_es_sort_configuration(schema=schema, sort=sort)
        if es_sort:
            es_query["sort"] = es_sort

        return es_query

    def map_2_es_sort_configuration(
        self, schema: Optional[Dict[str, Any]] = None, sort: Optional[SortConfig] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Maps a ``SortConfig`` into the es ``sort`` section.

        Raises ``AssertionError`` for sort fields outside the valid set.
        """
        if not sort:
            return None

        # TODO(@frascuchon): compute valid list from the schema
        valid_fields = sort.valid_fields or [
            "id",
            "metadata",
            "score",
            "predicted",
            "predicted_as",
            "predicted_by",
            "annotated_as",
            "annotated_by",
            "status",
            "last_updated",
            "event_timestamp",
        ]

        id_field = "id"
        id_keyword_field = "id.keyword"
        schema = schema or {}
        mappings = self._clean_mappings(schema.get("mappings", {}))
        # Text-typed id fields cannot be sorted on directly; use the keyword
        # sub-field instead.
        use_id_keyword = "text" == mappings.get("id")

        es_sort = []
        for sortable_field in sort.sort_by or [SortableField(id="id")]:
            if valid_fields:
                if not sortable_field.id.split(".")[0] in valid_fields:
                    raise AssertionError(
                        f"Wrong sort id {sortable_field.id}. Valid values are: "
                        f"{[str(v) for v in valid_fields]}"
                    )
            field = sortable_field.id
            if field == id_field and use_id_keyword:
                field = id_keyword_field
            es_sort.append({field: {"order": sortable_field.order}})

        return es_sort

    @classmethod
    def _to_es_query(cls, query: BackendRecordsQuery) -> Dict[str, Any]:
        """Builds the filter-based es query from the structured query fields."""
        if query.ids:
            return filters.ids_filter(query.ids)

        query_text = filters.text_query(query.query_text)
        all_filters = filters.metadata(query.metadata)
        if query.has_annotation:
            all_filters.append(filters.exists_field("annotated_by"))
        if query.has_prediction:
            all_filters.append(filters.exists_field("predicted_by"))

        query_data = query.dict(
            exclude={
                "advanced_query_dsl",
                "query_text",
                "metadata",
                "uncovered_by_rules",
                "has_annotation",
                "has_prediction",
            }
        )
        for key, value in query_data.items():
            if value is None:
                continue
            key_filter = None
            if isinstance(value, dict):
                value = getattr(query, key)  # check the original field type
            if isinstance(value, list):
                key_filter = filters.terms_filter(key, value)
            elif isinstance(value, (str, Enum)):
                key_filter = filters.term_filter(key, value)
            elif isinstance(value, QueryRange):
                key_filter = filters.range_filter(
                    field=key, value_from=value.range_from, value_to=value.range_to
                )
            else:
                cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}")
            if key_filter:
                all_filters.append(key_filter)

        return filters.boolean_filter(
            must_query=query_text or filters.match_all(),
            filter_query=filters.boolean_filter(
                should_filters=all_filters, minimum_should_match=len(all_filters)
            )
            if all_filters
            else None,
            # Rule-based queries can exclude records already covered by rules.
            must_not_query=filters.boolean_filter(
                should_filters=[filters.text_query(q) for q in query.uncovered_by_rules]
            )
            if hasattr(query, "uncovered_by_rules") and query.uncovered_by_rules
            else None,
        )

    def _clean_mappings(self, mappings: Dict[str, Any]) -> Dict[str, Any]:
        """Flattens an es mappings definition into ``{field: type-or-subtree}``."""
        if not mappings:
            return {}

        # Fix: use .get() — a non-empty definition without a "properties" key
        # (e.g. a leaf definition that also lacks "type") previously raised
        # KeyError here during recursion.
        return {
            key: definition.get("type") or self._clean_mappings(definition)
            for key, definition in mappings.get("properties", {}).items()
        }
ElasticsearchBackend, records_dao: DatasetRecordsDAO): self._es = es self.__records_dao__ = records_dao self.init() def init(self): """Initializes dataset dao. Used on app startup""" - self._es.create_index_template( - name=DATASETS_INDEX_NAME, - template=DATASETS_INDEX_TEMPLATE, - force_recreate=True, - ) - self._es.create_index(DATASETS_INDEX_NAME) + self._es.create_datasets_index(force_recreate=True) def list_datasets( self, owner_list: List[str] = None, - task2dataset_map: Dict[str, Type[BaseDatasetDB]] = None, - ) -> List[BaseDatasetDB]: - filters = [] - if owner_list: - owners_filter = query_helpers.filters.terms_filter( - "owner.keyword", owner_list - ) - if NO_WORKSPACE in owner_list: - filters.append( - query_helpers.filters.boolean_filter( - minimum_should_match=1, # OR Condition - should_filters=[ - query_helpers.filters.boolean_filter( - must_not_query=query_helpers.filters.exists_field( - "owner" - ) - ), - owners_filter, - ], - ) - ) - else: - filters.append(owners_filter) - - if task2dataset_map: - filters.append( - query_helpers.filters.terms_filter( - field="task.keyword", values=[task for task in task2dataset_map] - ) - ) - - docs = self._es.list_documents( - index=DATASETS_INDEX_NAME, - fetch_once=True, - # TODO(@frascuchon): include id as part of the document as keyword to enable sorting by id - size=MAX_NUMBER_OF_LISTED_DATASETS, - query={ - "query": query_helpers.filters.boolean_filter( - should_filters=filters, minimum_should_match=len(filters) - ) - } - if filters - else None, + task2dataset_map: Dict[str, Type[DatasetDB]] = None, + name: Optional[str] = None, + ) -> List[DatasetDB]: + owner_list = owner_list or [] + query = BaseDatasetsQuery( + owners=owner_list, + include_no_owner=NO_WORKSPACE in owner_list, + tasks=[task for task in task2dataset_map] if task2dataset_map else None, + name=name, ) + docs = self._es.list_datasets(query) task2dataset_map = task2dataset_map or {} return [ self._es_doc_to_instance( - doc, 
ds_class=task2dataset_map.get(task, DatasetDB) + doc, ds_class=task2dataset_map.get(task, BaseDatasetDB) ) for doc in docs for task in [self.__get_doc_field__(doc, "task")] ] - def create_dataset( - self, dataset: BaseDatasetDB, mappings: Dict[str, Any] - ) -> BaseDatasetDB: - """ - Stores a dataset in elasticsearch and creates corresponding dataset records index - - Parameters - ---------- - dataset: - The dataset - - Returns - ------- - Created dataset - """ - - self._es.add_document( - index=DATASETS_INDEX_NAME, - doc_id=dataset.id, - document=self._dataset_to_es_doc(dataset), + def create_dataset(self, dataset: DatasetDB) -> DatasetDB: + self._es.add_dataset_document( + id=dataset.id, document=self._dataset_to_es_doc(dataset) ) - self.__records_dao__.create_dataset_index( - dataset, mappings=mappings, force_recreate=True + self._es.create_dataset_index( + id=dataset.id, + task=dataset.task, + force_recreate=True, ) return dataset def update_dataset( self, - dataset: BaseDatasetDB, - ) -> BaseDatasetDB: - """ - Updates an stored dataset - - Parameters - ---------- - dataset: - The dataset - - Returns - ------- - The updated dataset - - """ - dataset_id = dataset.id + dataset: DatasetDB, + ) -> DatasetDB: - self._es.update_document( - index=DATASETS_INDEX_NAME, - doc_id=dataset_id, - document=self._dataset_to_es_doc(dataset), - partial_update=True, + self._es.update_dataset_document( + id=dataset.id, document=self._dataset_to_es_doc(dataset) ) return dataset - def delete_dataset(self, dataset: BaseDatasetDB): - """ - Deletes indices related to provided dataset - - Parameters - ---------- - dataset: - The dataset - - """ - try: - self._es.delete_index(dataset_records_index(dataset.id)) - finally: - self._es.delete_document(index=DATASETS_INDEX_NAME, doc_id=dataset.id) + def delete_dataset(self, dataset: DatasetDB): + self._es.delete(dataset.id) def find_by_name( self, name: str, owner: Optional[str], - as_dataset_class: Type[BaseDatasetDB] = DatasetDB, + 
as_dataset_class: Type[DatasetDB] = BaseDatasetDB, task: Optional[str] = None, - ) -> Optional[BaseDatasetDB]: - """ - Finds a dataset by name + ) -> Optional[DatasetDB]: - Parameters - ---------- - name: The dataset name - owner: The dataset owner - as_dataset_class: The dataset class used to return data - task: The dataset task string definition - - Returns - ------- - The found dataset if any. None otherwise - """ - dataset_id = DatasetDB.build_dataset_id( + dataset_id = BaseDatasetDB.build_dataset_id( name=name, owner=owner, ) - document = self._es.get_document_by_id( - index=DATASETS_INDEX_NAME, doc_id=dataset_id - ) - if not document and owner is None: - # We must search by name since we have no owner - results = self._es.list_documents( - index=DATASETS_INDEX_NAME, - query={"query": {"term": {"name.keyword": name}}}, - fetch_once=True, - ) - results = list(results) - if len(results) == 0: - return None - - if len(results) > 1: - raise ValueError( - f"Ambiguous dataset info found for name {name}. 
Please provide a valid owner" - ) - - document = results[0] - + document = self._es.find_dataset(id=dataset_id, name=name, owner=owner) if document is None: return None - base_ds = self._es_doc_to_instance(document) if task and task != base_ds.task: raise WrongTaskError( detail=f"Provided task {task} cannot be applied to dataset" ) - dataset_type = as_dataset_class or DatasetDB + dataset_type = as_dataset_class or BaseDatasetDB return self._es_doc_to_instance(document, ds_class=dataset_type) @staticmethod def _es_doc_to_instance( - doc: Dict[str, Any], ds_class: Type[BaseDatasetDB] = DatasetDB - ) -> BaseDatasetDB: - """Transforms a stored elasticsearch document into a `DatasetDB`""" + doc: Dict[str, Any], ds_class: Type[DatasetDB] = BaseDatasetDB + ) -> DatasetDB: + """Transforms a stored elasticsearch document into a `BaseDatasetDB`""" def __key_value_list_to_dict__( key_value_list: List[Dict[str, Any]] @@ -300,67 +183,47 @@ def __dict_to_key_value_list__(data: Dict[str, Any]) -> List[Dict[str, Any]]: } def copy(self, source: DatasetDB, target: DatasetDB): - source_doc = self._es.get_document_by_id( - index=DATASETS_INDEX_NAME, doc_id=source.id - ) - self._es.add_document( - index=DATASETS_INDEX_NAME, - doc_id=target.id, + source_doc = self._es.find_dataset(id=source.id) + self._es.add_dataset_document( + id=target.id, document={ **source_doc["_source"], # we copy extended fields from source document **self._dataset_to_es_doc(target), }, ) - index_from = dataset_records_index(source.id) - index_to = dataset_records_index(target.id) - self._es.clone_index(index=index_from, clone_to=index_to) + self._es.copy(id_from=source.id, id_to=target.id) def close(self, dataset: DatasetDB): """Close a dataset. 
It's mean that release all related resources, like elasticsearch index""" - self._es.close_index(dataset_records_index(dataset.id)) + self._es.close(dataset.id) def open(self, dataset: DatasetDB): """Make available a dataset""" - self._es.open_index(dataset_records_index(dataset.id)) + self._es.open(dataset.id) def get_all_workspaces(self) -> List[str]: """Get all datasets (Only for super users)""" - - workspaces_dict = self._es.aggregate( - index=DATASETS_INDEX_NAME, - aggregation=query_helpers.aggregations.terms_aggregation( - "owner.keyword", - missing=NO_WORKSPACE, - size=500, # TODO: A max number of workspaces env var could be leveraged by this. - ), - ) - - return [k for k in workspaces_dict] - - def save_settings(self, dataset: DatasetDB, settings: SettingsDB) -> SettingsDB: - self._es.update_document( - index=DATASETS_INDEX_NAME, - doc_id=dataset.id, - document={"settings": settings.dict(exclude_none=True)}, - partial_update=True, + metric_data = self._es.compute_rubrix_metric(metric_id="all_rubrix_workspaces") + return [k for k in metric_data] + + def save_settings( + self, dataset: DatasetDB, settings: DatasetSettingsDB + ) -> BaseDatasetSettingsDB: + self._es.update_dataset_document( + id=dataset.id, document={"settings": settings.dict(exclude_none=True)} ) return settings def load_settings( - self, dataset: DatasetDB, as_class: Type[SettingsDB] - ) -> Optional[SettingsDB]: - doc = self._es.get_document_by_id(index=DATASETS_INDEX_NAME, doc_id=dataset.id) + self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB] + ) -> Optional[DatasetSettingsDB]: + doc = self._es.find_dataset(id=dataset.id) if doc: settings = self.__get_doc_field__(doc, field="settings") return as_class.parse_obj(settings) if settings else None def delete_settings(self, dataset: DatasetDB): - self._es.update_document( - index=DATASETS_INDEX_NAME, - doc_id=dataset.id, - script='ctx._source.remove("settings")', - partial_update=True, - ) + self._es.remove_dataset_field(dataset.id, 
field="settings") def __get_doc_field__(self, doc: Dict[str, Any], field: str) -> Optional[Any]: return doc["_source"].get(field) diff --git a/src/rubrix/server/daos/models/datasets.py b/src/rubrix/server/daos/models/datasets.py index 9d4a1e9f8c..9d06f9f8e4 100644 --- a/src/rubrix/server/daos/models/datasets.py +++ b/src/rubrix/server/daos/models/datasets.py @@ -3,10 +3,13 @@ from pydantic import BaseModel, Field +from rubrix._constants import DATASET_NAME_REGEX_PATTERN +from rubrix.server.commons.models import TaskType -class DatasetDB(BaseModel): - name: str - task: str + +class BaseDatasetDB(BaseModel): + name: str = Field(regex=DATASET_NAME_REGEX_PATTERN) + task: TaskType owner: Optional[str] = None tags: Dict[str, str] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict) @@ -29,8 +32,9 @@ def id(self) -> str: return self.build_dataset_id(self.name, self.owner) -class SettingsDB(BaseModel): +class BaseDatasetSettingsDB(BaseModel): pass -BaseDatasetDB = TypeVar("BaseDatasetDB", bound=DatasetDB) +DatasetDB = TypeVar("DatasetDB", bound=BaseDatasetDB) +DatasetSettingsDB = TypeVar("DatasetSettingsDB", bound=BaseDatasetSettingsDB) diff --git a/src/rubrix/server/daos/models/records.py b/src/rubrix/server/daos/models/records.py index b221588b62..8bff581261 100644 --- a/src/rubrix/server/daos/models/records.py +++ b/src/rubrix/server/daos/models/records.py @@ -12,56 +12,159 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from datetime import datetime +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union +from uuid import uuid4 -from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field, validator +from pydantic.generics import GenericModel -from pydantic import BaseModel, Field +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.server.commons.models import PredictionStatus, TaskStatus, TaskType +from rubrix.server.daos.backend.search.model import BackendRecordsQuery, SortConfig +from rubrix.server.helpers import flatten_dict +from rubrix.utils import limit_value_length -class RecordSearch(BaseModel): - """ - Dao search +class DaoRecordsSearch(BaseModel): - Attributes: - ----------- + query: Optional[BackendRecordsQuery] = None + sort: SortConfig = Field(default_factory=SortConfig) - query: - The elasticsearch search query portion - sort: - The elasticsearch sort order - aggregations: - The elasticsearch search aggregations - """ +class DaoRecordsSearchResults(BaseModel): + total: int + records: List[Dict[str, Any]] - query: Optional[Dict[str, Any]] = None - sort: List[Dict[str, Any]] = Field(default_factory=list) - aggregations: Optional[Dict[str, Any]] +class BaseAnnotationDB(BaseModel): + agent: str = Field(max_length=64) -class RecordSearchResults(BaseModel): - """ - Dao search results - Attributes: - ----------- +AnnotationDB = TypeVar("AnnotationDB", bound=BaseAnnotationDB) - total: int - The total of query results - records: List[T] - List of records retrieved for the pagination configuration - aggregations: Optional[Dict[str, Dict[str, Any]]] - The query aggregations grouped by task. 
Optional - words: Optional[Dict[str, int]] - The words cloud aggregations - metadata: Optional[Dict[str, int]] - Metadata fields aggregations - metrics: Optional[List[DatasetMetricResults]] - Calculated metrics for search - """ - total: int - records: List[Dict[str, Any]] - aggregations: Optional[Dict[str, Dict[str, Any]]] = Field(default_factory=dict) - words: Optional[Dict[str, int]] = None - metadata: Optional[Dict[str, int]] = None +class BaseRecordInDB(GenericModel, Generic[AnnotationDB]): + id: Optional[Union[int, str]] = Field(default=None) + metadata: Dict[str, Any] = Field(default=None) + event_timestamp: Optional[datetime] = None + status: Optional[TaskStatus] = None + prediction: Optional[AnnotationDB] = None + annotation: Optional[AnnotationDB] = None + + @validator("id", always=True, pre=True) + def default_id_if_none_provided(cls, id: Optional[str]) -> str: + """Validates id info and sets a random uuid if not provided""" + if id is None: + return str(uuid4()) + return id + + @validator("status", always=True) + def fill_default_value(cls, status: TaskStatus): + """Fastapi validator for set default task status""" + return TaskStatus.default if status is None else status + + @validator("metadata", pre=True) + def flatten_metadata(cls, metadata: Dict[str, Any]): + """ + A fastapi validator for flatten metadata dictionary + + Parameters + ---------- + metadata: + The metadata dictionary + + Returns + ------- + A flatten version of metadata dictionary + + """ + if metadata: + metadata = flatten_dict(metadata, drop_empty=True) + metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH) + return metadata + + @classmethod + def task(cls) -> TaskType: + """The task type related to this task info""" + raise NotImplementedError + + @property + def predicted(self) -> Optional[PredictionStatus]: + """The task record prediction status (if any)""" + return None + + @property + def predicted_as(self) -> Optional[List[str]]: + """Predictions strings 
representation""" + return None + + @property + def annotated_as(self) -> Optional[List[str]]: + """Annotations strings representation""" + return None + + @property + def scores(self) -> Optional[List[float]]: + """Prediction scores""" + return None + + def all_text(self) -> str: + """All textual information related to record""" + raise NotImplementedError + + @property + def predicted_by(self) -> List[str]: + """The prediction agents""" + if self.prediction: + return [self.prediction.agent] + return [] + + @property + def annotated_by(self) -> List[str]: + """The annotation agents""" + if self.annotation: + return [self.annotation.agent] + return [] + + def extended_fields(self) -> Dict[str, Any]: + """ + Used for extends fields to store in db. Tasks that would include extra + properties than commons (predicted, annotated_as,....) could implement + this method. + """ + return { + "predicted": self.predicted, + "annotated_as": self.annotated_as, + "predicted_as": self.predicted_as, + "annotated_by": self.annotated_by, + "predicted_by": self.predicted_by, + "score": self.scores, + } + + def dict(self, *args, **kwargs) -> "DictStrAny": + """ + Extends base component dict extending object properties + and user defined extended fields + """ + return { + **super().dict(*args, **kwargs), + **self.extended_fields(), + } + + +class BaseRecordDB(BaseRecordInDB, Generic[AnnotationDB]): + + # Read only ones + metrics: Dict[str, Any] = Field(default_factory=dict) + search_keywords: Optional[List[str]] = None + last_updated: datetime = None + + @validator("search_keywords") + def remove_duplicated_keywords(cls, value) -> List[str]: + """Remove duplicated keywords""" + if value: + return list(set(value)) + + +RecordDB = TypeVar("RecordDB", bound=BaseRecordDB) diff --git a/src/rubrix/server/daos/records.py b/src/rubrix/server/daos/records.py index d7ae4d0666..1793f83a76 100644 --- a/src/rubrix/server/daos/records.py +++ b/src/rubrix/server/daos/records.py @@ -13,66 +13,24 @@ # 
See the License for the specific language governing permissions and # limitations under the License. -import dataclasses import datetime -import re -from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Type -import deprecated from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType -from rubrix.server.daos.models.datasets import BaseDatasetDB -from rubrix.server.daos.models.records import RecordSearch, RecordSearchResults -from rubrix.server.elasticseach.client_wrapper import ( +from rubrix.server.daos.backend.elasticsearch import ( ClosedIndexError, - ElasticsearchWrapper, + ElasticsearchBackend, IndexNotFoundError, - create_es_wrapper, ) -from rubrix.server.elasticseach.mappings.datasets import DATASETS_RECORDS_INDEX_NAME -from rubrix.server.elasticseach.mappings.helpers import ( - mappings, - tasks_common_mappings, - tasks_common_settings, +from rubrix.server.daos.backend.search.model import BackendRecordsQuery +from rubrix.server.daos.models.datasets import DatasetDB +from rubrix.server.daos.models.records import ( + DaoRecordsSearch, + DaoRecordsSearchResults, + RecordDB, ) -from rubrix.server.elasticseach.query_helpers import parse_aggregations from rubrix.server.errors import ClosedDatasetError, MissingDatasetRecordsError -from rubrix.server.errors.task_errors import MetadataLimitExceededError -from rubrix.server.settings import settings - -DBRecord = TypeVar("DBRecord", bound=BaseRecord) - - -@dataclasses.dataclass -class _IndexTemplateExtensions: - - analyzers: List[Dict[str, Any]] = dataclasses.field(default_factory=list) - properties: List[Dict[str, Any]] = dataclasses.field(default_factory=list) - dynamic_templates: List[Dict[str, Any]] = dataclasses.field(default_factory=list) - - -def dataset_records_index(dataset_id: str) -> str: - """ - Returns dataset records index for a given dataset id - - The dataset info is stored in two 
elasticsearch indices. The main - index where all datasets definition are stored and - an specific dataset index for data records. - - This function calculates the corresponding dataset records index - for a given dataset id. - - Parameters - ---------- - dataset_id - - Returns - ------- - The dataset records index name - - """ - return DATASETS_RECORDS_INDEX_NAME.format(dataset_id) class DatasetRecordsDAO: @@ -80,24 +38,10 @@ class DatasetRecordsDAO: _INSTANCE = None - # Keep info about elasticsearch mappings per task - # This info must be provided by each task using dao.register_task_mappings method - _MAPPINGS_BY_TASKS = {} - - __HIGHLIGHT_PRE_TAG__ = "<@@-rb-key>" - __HIGHLIGHT_POST_TAG__ = "" - __HIGHLIGHT_VALUES_REGEX__ = re.compile( - rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}" - ) - - __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile( - rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}" - ) - @classmethod def get_instance( cls, - es: ElasticsearchWrapper = Depends(ElasticsearchWrapper.get_instance), + es: ElasticsearchBackend = Depends(ElasticsearchBackend.get_instance), ) -> "DatasetRecordsDAO": """ Creates a dataset records dao instance @@ -112,9 +56,8 @@ def get_instance( cls._INSTANCE = cls(es) return cls._INSTANCE - def __init__(self, es: ElasticsearchWrapper): + def __init__(self, es: ElasticsearchBackend): self._es = es - self.init() def init(self): """Initializes dataset records dao. 
Used on app startup""" @@ -122,10 +65,9 @@ def init(self): def add_records( self, - dataset: BaseDatasetDB, - mappings: Dict[str, Any], - records: List[DBRecord], - record_class: Type[DBRecord], + dataset: DatasetDB, + records: List[RecordDB], + record_class: Type[RecordDB], ) -> int: """ Add records to dataset @@ -160,71 +102,60 @@ def add_records( db_record.dict(exclude_none=False, exclude={"search_keywords"}) ) - index_name = self.create_dataset_index(dataset, mappings=mappings) - self._configure_metadata_fields(index_name, metadata_values) + self._es.create_dataset_index( + dataset.id, + task=dataset.task, + metadata_values=metadata_values, + ) + return self._es.add_documents( - index=index_name, + id=dataset.id, documents=documents, - doc_id=lambda _record: _record.get("id"), ) - def get_metadata_schema(self, dataset: BaseDatasetDB) -> Dict[str, str]: + def get_metadata_schema(self, dataset: DatasetDB) -> Dict[str, str]: """Get metadata fields schema for provided dataset""" - records_index = dataset_records_index(dataset.id) - return self._es.get_field_mapping(index=records_index, field_name="metadata.*") + + return self._es.get_metadata_mappings(id=dataset.id) + + def compute_metric( + self, + dataset: DatasetDB, + metric_id: str, + metric_params: Dict[str, Any] = None, + query: Optional[BackendRecordsQuery] = None, + ): + + return self._es.compute_metric( + id=dataset.id, + metric_id=metric_id, + query=query, + params=metric_params, + ) def search_records( self, - dataset: BaseDatasetDB, - search: Optional[RecordSearch] = None, + dataset: DatasetDB, + search: Optional[DaoRecordsSearch] = None, size: int = 100, record_from: int = 0, exclude_fields: List[str] = None, highligth_results: bool = True, - ) -> RecordSearchResults: - """ - SearchRequest records under a dataset given a search parameters. 
- - Parameters - ---------- - dataset: - The dataset - search: - The search params - size: - Number of records to retrieve (for pagination) - record_from: - Record from which to retrieve the records (for pagination) - exclude_fields: - a list of fields to exclude from the result source. Wildcards are accepted - Returns - ------- - The search result - - """ - search = search or RecordSearch() - records_index = dataset_records_index(dataset.id) - compute_aggregations = record_from == 0 - aggregation_requests = ( - {**(search.aggregations or {})} if compute_aggregations else {} - ) - - sort_config = self.__normalize_sort_config__(records_index, sort=search.sort) - - es_query = { - "_source": {"excludes": exclude_fields or []}, - "from": record_from, - "query": search.query or {"match_all": {}}, - "sort": sort_config, - "aggs": aggregation_requests, - } - if highligth_results: - es_query["highlight"] = self.__configure_query_highlight__( - task=dataset.task - ) + ) -> DaoRecordsSearchResults: try: - results = self._es.search(index=records_index, query=es_query, size=size) + search = search or DaoRecordsSearch() + + total, records = self._es.search_records( + id=dataset.id, + query=search.query, + sort=search.sort, + record_from=record_from, + size=size, + exclude_fields=exclude_fields, + enable_highlight=highligth_results, + ) + return DaoRecordsSearchResults(total=total, records=records) except ClosedIndexError: raise ClosedDatasetError(dataset.name) except IndexNotFoundError: @@ -232,43 +163,11 @@ def search_records( f"No records index found for dataset {dataset.name}" ) - hits = results["hits"] - total = hits["total"] - docs = hits["hits"] - search_aggregations = results.get("aggregations", {}) - - result = RecordSearchResults( - total=total, - records=list(map(self.__esdoc2record__, docs)), - ) - if search_aggregations: - parsed_aggregations = parse_aggregations(search_aggregations) - result.aggregations = parsed_aggregations - - return result - - def 
__normalize_sort_config__( - self, index: str, sort: Optional[List[Dict[str, Any]]] = None - ) -> List[Dict[str, Any]]: - id_field = "id" - id_keyword_field = "id.keyword" - sort_config = [] - - for sort_field in sort or [{id_field: {"order": "asc"}}]: - for field in sort_field: - if field == id_field and self._es.get_field_mapping( - index=index, field_name=id_keyword_field - ): - sort_config.append({id_keyword_field: sort_field[field]}) - else: - sort_config.append(sort_field) - return sort_config - def scan_dataset( self, - dataset: BaseDatasetDB, + dataset: DatasetDB, + search: Optional[DaoRecordsSearch] = None, limit: int = 1000, - search: Optional[RecordSearch] = None, id_from: Optional[str] = None, ) -> Iterable[Dict[str, Any]]: """ @@ -289,178 +188,12 @@ def scan_dataset( ------- An iterable over found dataset records """ - index = dataset_records_index(dataset.id) - search = search or RecordSearch() - - sort_cfg = self.__normalize_sort_config__( - index=index, sort=[{"id": {"order": "asc"}}] - ) - es_query = { - "query": search.query or {"match_all": {}}, - "highlight": self.__configure_query_highlight__(task=dataset.task), - "sort": sort_cfg, # Sort the search so the consistency is maintained in every search - } - - if id_from: - # Scroll method does not accept read_after, thus, this case is handled as a search - es_query["search_after"] = [id_from] - results = self._es.search(index=index, query=es_query, size=limit) - hits = results["hits"] - docs = hits["hits"] - - else: - docs = self._es.list_documents( - index, - query=es_query, - sort_cfg=sort_cfg, - ) - for doc in docs: - yield self.__esdoc2record__(doc) - - def __esdoc2record__( - self, - doc: Dict[str, Any], - query: Optional[str] = None, - is_phrase_query: bool = True, - ): - return { - **doc["_source"], - "id": doc["_id"], - "search_keywords": self.__parse_highlight_results__( - doc, query=query, is_phrase_query=is_phrase_query - ), - } - - @classmethod - def __parse_highlight_results__( - cls, 
- doc: Dict[str, Any], - query: Optional[str] = None, - is_phrase_query: bool = False, - ) -> Optional[List[str]]: - highlight_info = doc.get("highlight") - if not highlight_info: - return None - - search_keywords = [] - for content in highlight_info.values(): - if not isinstance(content, list): - content = [content] - text = " ".join(content) - - if is_phrase_query: - text = re.sub(cls.__HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__, " ", text) - search_keywords.extend(re.findall(cls.__HIGHLIGHT_VALUES_REGEX__, text)) - return list(set(search_keywords)) - - def _configure_metadata_fields(self, index: str, metadata_values: Dict[str, Any]): - def check_metadata_length(metadata_length: int = 0): - if metadata_length > settings.metadata_fields_limit: - raise MetadataLimitExceededError( - length=metadata_length, limit=settings.metadata_fields_limit - ) - - def detect_nested_type(v: Any) -> bool: - """Returns True if value match as nested value""" - return isinstance(v, list) and isinstance(v[0], dict) - - check_metadata_length(len(metadata_values)) - check_metadata_length( - len( - { - *self._es.get_field_mapping( - index, "metadata.*", exclude_subfields=True - ), - *[f"metadata.{k}" for k in metadata_values.keys()], - } - ) - ) - for field, value in metadata_values.items(): - if detect_nested_type(value): - self._es.create_field_mapping( - index, - field_name=f"metadata.{field}", - mapping=mappings.nested_field(), - ) - - def create_dataset_index( - self, - dataset: BaseDatasetDB, - mappings: Dict[str, Any], - force_recreate: bool = False, - ) -> str: - """ - Creates a dataset records elasticsearch index based on dataset task type - - Args: - dataset: - The dataset - force_recreate: - If True, the index will be deleted and recreated - - Returns: - The generated index name. 
- """ - _mappings = tasks_common_mappings() - task_mappings = mappings.copy() - for k in task_mappings: - if isinstance(task_mappings[k], list): - _mappings[k] = [*_mappings.get(k, []), *task_mappings[k]] - else: - _mappings[k] = {**_mappings.get(k, {}), **task_mappings[k]} - - index_name = dataset_records_index(dataset.id) - self._es.create_index( - index=index_name, - settings=tasks_common_settings(), - mappings={**tasks_common_mappings(), **_mappings}, - force_recreate=force_recreate, + search = search or DaoRecordsSearch() + return self._es.scan_records( + id=dataset.id, query=search.query, limit=limit, id_from=id_from ) - return index_name - def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]: + def get_dataset_schema(self, dataset: DatasetDB) -> Dict[str, Any]: """Return inner elasticsearch index configuration""" - index_name = dataset_records_index(dataset.id) - response = self._es.__client__.indices.get_mapping(index=index_name) - - if index_name in response: - response = response.get(index_name) - - return response - - @classmethod - def __configure_query_highlight__(cls, task: TaskType): - - return { - "pre_tags": [cls.__HIGHLIGHT_PRE_TAG__], - "post_tags": [cls.__HIGHLIGHT_POST_TAG__], - "require_field_match": True, - "fields": { - "text": {}, - "text.*": {}, - # TODO(@frascuchon): `words` will be removed in version 0.16.0 - **({"inputs.*": {}} if task == TaskType.text_classification else {}), - }, - } - - -_instance: Optional[DatasetRecordsDAO] = None - - -@deprecated.deprecated(reason="Use `DatasetRecordsDAO.get_instance` instead") -def dataset_records_dao( - es: ElasticsearchWrapper = Depends(create_es_wrapper), -) -> DatasetRecordsDAO: - """ - Creates a dataset records dao instance - - Parameters - ---------- - es: - The elasticserach wrapper dependency - - """ - global _instance - if not _instance: - _instance = DatasetRecordsDAO(es) - return _instance + schema = self._es.get_mappings(id=dataset.id) + return schema diff --git 
a/src/rubrix/server/elasticseach/client_wrapper.py b/src/rubrix/server/elasticseach/client_wrapper.py deleted file mode 100644 index fcad32ab29..0000000000 --- a/src/rubrix/server/elasticseach/client_wrapper.py +++ /dev/null @@ -1,650 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Callable, Dict, Iterable, List, Optional - -import deprecated -from opensearchpy import NotFoundError, OpenSearch, OpenSearchException, RequestError -from opensearchpy.helpers import bulk as es_bulk -from opensearchpy.helpers import scan as es_scan - -from rubrix.logging import LoggingMixin -from rubrix.server.elasticseach import query_helpers -from rubrix.server.errors import InvalidTextSearchError - -try: - import ujson as json -except ModuleNotFoundError: - import json - -from rubrix.server.settings import settings - - -class ClosedIndexError(Exception): - pass - - -class IndexNotFoundError(Exception): - pass - - -class GenericSearchError(Exception): - def __init__(self, origin_error: Exception): - self.origin_error = origin_error - - -class ElasticsearchWrapper(LoggingMixin): - """A simple elasticsearch client wrapper for atomize some repetitive operations""" - - _INSTANCE = None - - @classmethod - def get_instance(cls) -> "ElasticsearchWrapper": - """ - Creates an instance of ElasticsearchWrapper. - - This function is used in fastapi for resolve component dependencies. 
- - See - - Returns - ------- - - """ - - if cls._INSTANCE is None: - es_client = OpenSearch( - hosts=settings.elasticsearch, - verify_certs=settings.elasticsearch_ssl_verify, - ca_certs=settings.elasticsearch_ca_path, - # Extra args to es configuration -> TODO: extensible by settings - retry_on_timeout=True, - max_retries=5, - ) - cls._INSTANCE = cls(es_client) - - return cls._INSTANCE - - def __init__(self, es_client: OpenSearch): - self.__client__ = es_client - - @property - def client(self): - """The elasticsearch client""" - return self.__client__ - - def list_documents( - self, - index: str, - query: Dict[str, Any] = None, - sort_cfg: Optional[List[Dict[str, Any]]] = None, - size: Optional[int] = None, - fetch_once: bool = False, - ) -> Iterable[Dict[str, Any]]: - """ - List ALL documents of an elasticsearch index - Parameters - ---------- - index: - The index name - sor_id: - The sort id configuration - query: - The es query for filter results. Default: None - sort_cfg: - Customized configuration for sort-by id - size: - Amount of samples to retrieve per iteration, 1000 by default - fetch_once: - If enabled, will return only the `size` first records found. Default to: ``False`` - - Returns - ------- - A sequence of documents resulting from applying the query on the index - - """ - size = size or 1000 - query = query.copy() or {} - if sort_cfg: - query["sort"] = sort_cfg - query["track_total_hits"] = False # Speedup pagination - response = self.__client__.search(index=index, body=query, size=size) - while response["hits"]["hits"]: - for hit in response["hits"]["hits"]: - yield hit - if fetch_once: - break - - last_id = hit["_id"] - query["search_after"] = [last_id] - response = self.__client__.search(index=index, body=query, size=size) - - def index_exists(self, index: str) -> bool: - """ - Checks if provided index exists - - Parameters - ---------- - index: - The index name - - Returns - ------- - True if index exists. 
False otherwise - """ - return self.__client__.indices.exists(index) - - def search( - self, - index: str, - routing: str = None, - size: int = 100, - query: Dict[str, Any] = None, - ) -> Dict[str, Any]: - """ - Apply a search over an index. - See - - Parameters - ---------- - index: - The index name - routing: - The routing key. Optional - size: - Number of results to return. Default=100 - query: - The elasticsearch query. Optional - - Returns - ------- - - """ - try: - return self.__client__.search( - index=index, - body=query or {}, - routing=routing, - track_total_hits=True, - rest_total_hits_as_int=True, - size=size, - ) - except RequestError as rex: - - if rex.error == "search_phase_execution_exception": - detail = rex.info["error"] - detail = detail.get("root_cause") - detail = detail[0].get("reason") if detail else rex.info["error"] - - raise InvalidTextSearchError(detail) - - if rex.error == "index_closed_exception": - raise ClosedIndexError(index) - raise GenericSearchError(rex) - except NotFoundError as nex: - raise IndexNotFoundError(nex) - except OpenSearchException as ex: - raise GenericSearchError(ex) - - def create_index( - self, - index: str, - force_recreate: bool = False, - settings: Dict[str, Any] = None, - mappings: Dict[str, Any] = None, - ): - """ - Applies a index creation with provided mapping configuration. - - See - - Parameters - ---------- - index: - The index name - force_recreate: - If True, the index will be recreated (if exists). Default=False - settings: - The index settings configuration - mappings: - The mapping configuration. Optional. 
- - """ - if force_recreate: - self.delete_index(index) - if not self.index_exists(index): - self.__client__.indices.create( - index=index, - body={"settings": settings or {}, "mappings": mappings or {}}, - ignore=400, - ) - - def create_index_template( - self, name: str, template: Dict[str, Any], force_recreate: bool = False - ): - """ - Applies a index template creation with provided template definition. - - Parameters - ---------- - name: - The template index name - template: - The template definition - force_recreate: - If True, the template will be recreated (if exists). Default=False - - """ - if force_recreate or not self.__client__.indices.exists_template(name): - self.__client__.indices.put_template(name=name, body=template) - - def delete_index_template(self, index_template: str): - """Deletes an index template""" - if self.__client__.indices.exists_index_template(index_template): - self.__client__.indices.delete_template( - name=index_template, ignore=[400, 404] - ) - - def delete_index(self, index: str): - """Deletes an elasticsearch index""" - if self.index_exists(index): - self.__client__.indices.delete(index, ignore=[400, 404]) - - def add_document(self, index: str, doc_id: str, document: Dict[str, Any]): - """ - Creates/updates a document in an index - - See - - Parameters - ---------- - index: - The index name - doc_id: - The document id - document: - The document source - - """ - self.__client__.index(index=index, body=document, id=doc_id, refresh="wait_for") - - def get_document_by_id(self, index: str, doc_id: str) -> Optional[Dict[str, Any]]: - """ - Get a document by its id - - See - - Parameters - ---------- - index: - The index name - doc_id: - The document id - - Returns - ------- - The elasticsearch document if found, None otherwise - """ - - try: - if self.__client__.exists(index=index, id=doc_id): - return self.__client__.get(index=index, id=doc_id) - except NotFoundError: - return None - - def delete_document(self, index: str, doc_id: 
str): - """ - Deletes a document from an index. - - See - - Parameters - ---------- - index: - The index name - doc_id: - The document id - - Returns - ------- - - """ - if self.__client__.exists(index=index, id=doc_id): - self.__client__.delete(index=index, id=doc_id, refresh=True) - - def add_documents( - self, - index: str, - documents: List[Dict[str, Any]], - routing: Callable[[Dict[str, Any]], str] = None, - doc_id: Callable[[Dict[str, Any]], str] = None, - ) -> int: - """ - Adds or updated a set of documents to an index. Documents can contains - partial information of document. - - See - - Parameters - ---------- - index: - The index name - documents: - The set of documents - routing: - The routing key - doc_id - - Returns - ------- - The number of failed documents - """ - - def map_doc_2_action(doc: Dict[str, Any]) -> Dict[str, Any]: - """Configures bulk action""" - data = { - "_op_type": "index", - "_index": index, - "_routing": routing(doc) if routing else None, - **doc, - } - - _id = doc_id(doc) if doc_id else None - if _id is not None: - data["_id"] = _id - - return data - - success, failed = es_bulk( - self.__client__, - index=index, - actions=map(map_doc_2_action, documents), - raise_on_error=True, - refresh="wait_for", - ) - return len(failed) - - def get_mapping(self, index: str) -> Dict[str, Any]: - """ - Return the configured index mapping - - See `` - - """ - try: - response = self.__client__.indices.get_mapping( - index=index, - ignore_unavailable=False, - include_type_name=True, - ) - return list(response[index]["mappings"].values())[0]["properties"] - except NotFoundError: - return {} - - def get_field_mapping( - self, - index: str, - field_name: Optional[str] = None, - exclude_subfields: bool = False, - ) -> Dict[str, str]: - """ - Returns the mapping for a given field name (can be as wildcard notation). 
The result - consist on a dictionary with full field name as key and its type as value - - See - - Parameters - ---------- - index: - The index name - field_name: - The field name pattern - exclude_subfields: - If True, exclude extra subfields from mappings definition - - Returns - ------- - A dictionary with full field name as key and its type as value - """ - try: - response = self.__client__.indices.get_field_mapping( - fields=field_name or "*", - index=index, - ignore_unavailable=False, - ) - data = { - key: list(definition["mapping"].values())[0]["type"] - for key, definition in response[index]["mappings"].items() - } - - if exclude_subfields: - # Remove `text`, `exact` and `wordcloud` fields - def is_subfield(key: str): - for suffix in ["exact", "text", "wordcloud"]: - if suffix in key: - return True - return False - - data = { - key: value for key, value in data.items() if not is_subfield(key) - } - - return data - except NotFoundError: - # No mapping data - return {} - - def update_document( - self, - index: str, - doc_id: str, - document: Optional[Dict[str, Any]] = None, - script: Optional[str] = None, - partial_update: bool = False, - ): - """ - Updates a document in a given index - - Parameters - ---------- - index: - The index name - doc_id: - The document id - document: - The document data. Could be partial document info - partial_update: - If True, document contains partial info, and will be - merged with stored document. If false, the stored document - will be overwritten. Default=False - - Returns - ------- - - """ - # TODO: validate either doc or script are provided - if partial_update: - body = {"script": script} if script else {"doc": document} - - self.__client__.update( - index=index, - id=doc_id, - body=body, - refresh=True, - retry_on_conflict=500, # TODO: configurable - ) - else: - self.__client__.index(index=index, id=doc_id, body=document, refresh=True) - - def open_index(self, index: str): - """ - Open an elasticsearch index. 
If index is already open, this operation will do nothing - - See ``_ - - Parameters - ---------- - index: - The index name - """ - self.__client__.indices.open( - index=index, wait_for_active_shards=settings.es_records_index_shards - ) - - def close_index(self, index: str): - """ - Closes an elasticsearch index. If index is already closed, this operation will do nothing. - - See ``_ - - Parameters - ---------- - index: - The index name - """ - self.__client__.indices.close( - index=index, - ignore_unavailable=True, - wait_for_active_shards=settings.es_records_index_shards, - ) - - def clone_index(self, index: str, clone_to: str, override: bool = True): - """ - Clone an existing index. During index clone, source must be setup as read-only index. Then, changes can be - applied - - See ``_ - - Parameters - ---------- - index: - The source index name - clone_to: - The destination index name - override: - If True, destination index will be removed if exists - """ - index_read_only = self.is_index_read_only(index) - try: - if not index_read_only: - self.index_read_only(index, read_only=True) - if override: - self.delete_index(clone_to) - self.__client__.indices.clone( - index=index, - target=clone_to, - wait_for_active_shards=settings.es_records_index_shards, - ) - finally: - self.index_read_only(index, read_only=index_read_only) - self.index_read_only(clone_to, read_only=index_read_only) - - def is_index_read_only(self, index: str) -> bool: - """ - Fetch info about read-only configuration index - - Parameters - ---------- - index: - The index name - - Returns - ------- - True if queried index is read-only, False otherwise - - """ - response = self.__client__.indices.get_settings( - index=index, - name="index.blocks.write", - allow_no_indices=True, - flat_settings=True, - ) - return ( - response[index]["settings"]["index.blocks.write"] == "true" - if response - else False - ) - - def index_read_only(self, index: str, read_only: bool): - """ - Enable/disable index read 
only - - Parameters - ---------- - index: - The index name - read_only: - True for enable read-only, False otherwise - - """ - self.__client__.indices.put_settings( - index=index, - body={"settings": {"index.blocks.write": read_only}}, - ignore=404, - ) - - def create_field_mapping( - self, - index: str, - field_name: str, - mapping: Dict[str, Any], - ): - """Creates or updates an index field mapping configuration""" - self.__client__.indices.put_mapping( - index=index, - body={"properties": {field_name: mapping}}, - ) - - def get_cluster_info(self) -> Dict[str, Any]: - """Returns basic about es cluster""" - try: - return self.__client__.info() - except OpenSearchException as ex: - return {"error": ex} - - def aggregate(self, index: str, aggregation: Dict[str, Any]) -> Dict[str, Any]: - """Apply an aggregation over the index returning ONLY the agg results""" - aggregation_name = "aggregation" - results = self.search( - index=index, size=0, query={"aggs": {aggregation_name: aggregation}} - ) - - return query_helpers.parse_aggregations(results["aggregations"]).get( - aggregation_name - ) - - -_instance = None # The singleton instance - - -@deprecated.deprecated(reason="Use `ElasticsearchWrapper.get_instance` instead") -def create_es_wrapper() -> ElasticsearchWrapper: - """ - Creates an instance of ElasticsearchWrapper. - - This function is used in fastapi for resolve component dependencies. - - See - - Returns - ------- - - """ - - global _instance - if _instance is None: - _instance = ElasticsearchWrapper.get_instance() - - return _instance diff --git a/src/rubrix/server/apis/v0/helpers.py b/src/rubrix/server/helpers.py similarity index 68% rename from src/rubrix/server/apis/v0/helpers.py rename to src/rubrix/server/helpers.py index 2a17f0bef0..959b4ab103 100644 --- a/src/rubrix/server/apis/v0/helpers.py +++ b/src/rubrix/server/helpers.py @@ -12,31 +12,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +""" +Common helper functions +""" +from typing import Any, Dict, List, Optional -def takeuntil(iterable, limit: int): + +def unflatten_dict( + data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None +) -> Dict[str, Any]: """ - Iterate over inner iterable until a count limit + Given a flat dictionary keys, build a hierarchical version by grouping keys Parameters ---------- - iterable: - The inner iterable - limit: - The limit + data: + The data dictionary + sep: + The key separator. Default "." + stop_keys + List of dictionary first level keys where hierarchy will stop Returns ------- """ - count = 0 - for e in iterable: - if count < limit: - yield e - count += 1 - else: - break + resultDict = {} + stop_keys = stop_keys or [] + for key, value in data.items(): + if key is not None: + parts = key.split(sep) + if parts[0] in stop_keys: + parts = [parts[0], sep.join(parts[1:])] + d = resultDict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return resultDict def flatten_dict( @@ -81,3 +96,27 @@ def _flatten_internal_(_data: Dict[str, Any], _parent_key="", _sep="."): return dict(items) return _flatten_internal_(data, _sep=sep) + + +def takeuntil(iterable, limit: int): + """ + Iterate over inner iterable until a count limit + + Parameters + ---------- + iterable: + The inner iterable + limit: + The limit + + Returns + ------- + + """ + count = 0 + for e in iterable: + if count < limit: + yield e + count += 1 + else: + break diff --git a/src/rubrix/server/server.py b/src/rubrix/server/server.py index c2cfcf3f26..d5f316cfc0 100644 --- a/src/rubrix/server/server.py +++ b/src/rubrix/server/server.py @@ -27,9 +27,9 @@ from rubrix import __version__ as rubrix_version from rubrix.logging import configure_logging +from rubrix.server.daos.backend.elasticsearch import 
ElasticsearchBackend from rubrix.server.daos.datasets import DatasetsDAO from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.client_wrapper import create_es_wrapper from rubrix.server.errors import APIErrorHandler, EntityNotFoundError from rubrix.server.routes import api_router from rubrix.server.security import auth @@ -89,7 +89,7 @@ async def configure_elasticsearch(): import opensearchpy try: - es_wrapper = create_es_wrapper() + es_wrapper = ElasticsearchBackend.get_instance() dataset_records: DatasetRecordsDAO = DatasetRecordsDAO(es_wrapper) datasets: DatasetsDAO = DatasetsDAO.get_instance( es_wrapper, records_dao=dataset_records diff --git a/src/rubrix/server/services/datasets.py b/src/rubrix/server/services/datasets.py index 51b09fd371..5d2d67264e 100644 --- a/src/rubrix/server/services/datasets.py +++ b/src/rubrix/server/services/datasets.py @@ -18,8 +18,8 @@ from fastapi import Depends -from rubrix.server.daos.datasets import BaseDatasetDB, DatasetsDAO, SettingsDB -from rubrix.server.daos.models.datasets import DatasetDB +from rubrix.server.daos.datasets import BaseDatasetSettingsDB, DatasetsDAO +from rubrix.server.daos.models.datasets import BaseDatasetDB from rubrix.server.errors import ( EntityAlreadyExistsError, EntityNotFoundError, @@ -28,13 +28,21 @@ ) from rubrix.server.security.model import User -Dataset = TypeVar("Dataset", bound=BaseDatasetDB) + +class ServiceBaseDataset(BaseDatasetDB): + pass -class SVCDatasetSettings(SettingsDB): +class ServiceBaseDatasetSettings(BaseDatasetSettingsDB): pass +ServiceDataset = TypeVar("ServiceDataset", bound=ServiceBaseDataset) +ServiceDatasetSettings = TypeVar( + "ServiceDatasetSettings", bound=ServiceBaseDatasetSettings +) + + class DatasetsService: _INSTANCE: "DatasetsService" = None @@ -50,9 +58,7 @@ def get_instance( def __init__(self, dao: DatasetsDAO): self.__dao__ = dao - def create_dataset( - self, user: User, dataset: Dataset, mappings: Dict[str, Any] - ) -> 
Dataset: + def create_dataset(self, user: User, dataset: ServiceDataset) -> ServiceDataset: user.check_workspace(dataset.owner) try: @@ -60,11 +66,11 @@ def create_dataset( user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner ) raise EntityAlreadyExistsError( - name=dataset.name, type=Dataset, workspace=dataset.owner + name=dataset.name, type=ServiceDataset, workspace=dataset.owner ) except WrongTaskError: # Found a dataset with same name but different task raise EntityAlreadyExistsError( - name=dataset.name, type=Dataset, workspace=dataset.owner + name=dataset.name, type=ServiceDataset, workspace=dataset.owner ) except EntityNotFoundError: # The dataset does not exist -> create it ! @@ -72,16 +78,16 @@ def create_dataset( dataset.created_by = user.username dataset.created_at = date_now dataset.last_updated = date_now - return self.__dao__.create_dataset(dataset, mappings=mappings) + return self.__dao__.create_dataset(dataset) def find_by_name( self, user: User, name: str, - as_dataset_class: Type[BaseDatasetDB] = DatasetDB, + as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset, task: Optional[str] = None, workspace: Optional[str] = None, - ) -> Dataset: + ) -> ServiceDataset: owner = user.check_workspace(workspace) if task is None: @@ -96,20 +102,20 @@ def find_by_name( ) if found_ds is None: - raise EntityNotFoundError(name=name, type=Dataset) + raise EntityNotFoundError(name=name, type=ServiceDataset) if found_ds.owner and owner and found_ds.owner != owner: raise EntityNotFoundError( - name=name, type=Dataset + name=name, type=ServiceDataset ) if user.is_superuser() else ForbiddenOperationError() - return cast(Dataset, found_ds) + return cast(ServiceDataset, found_ds) def __find_by_name_with_superuser_fallback__( self, user: User, name: str, owner: Optional[str], - as_dataset_class: Optional[Type[DatasetDB]], + as_dataset_class: Optional[Type[ServiceDataset]], task: Optional[str] = None, ): found_ds = self.__dao__.find_by_name( @@ 
-125,7 +131,7 @@ def __find_by_name_with_superuser_fallback__( pass return found_ds - def delete(self, user: User, dataset: Dataset): + def delete(self, user: User, dataset: ServiceDataset): user.check_workspace(dataset.owner) found = self.__find_by_name_with_superuser_fallback__( user=user, @@ -140,10 +146,10 @@ def delete(self, user: User, dataset: Dataset): def update( self, user: User, - dataset: Dataset, + dataset: ServiceDataset, tags: Dict[str, str], metadata: Dict[str, Any], - ) -> Dataset: + ) -> ServiceDataset: found = self.find_by_name( user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner ) @@ -159,30 +165,30 @@ def list( self, user: User, workspaces: Optional[List[str]], - task2dataset_map: Dict[str, Type[BaseDatasetDB]] = None, - ) -> List[Dataset]: + task2dataset_map: Dict[str, Type[ServiceDataset]] = None, + ) -> List[ServiceDataset]: owners = user.check_workspaces(workspaces) return self.__dao__.list_datasets( owner_list=owners, task2dataset_map=task2dataset_map ) - def close(self, user: User, dataset: Dataset): + def close(self, user: User, dataset: ServiceDataset): found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner) self.__dao__.close(found) - def open(self, user: User, dataset: Dataset): + def open(self, user: User, dataset: ServiceDataset): found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner) self.__dao__.open(found) def copy_dataset( self, user: User, - dataset: Dataset, + dataset: ServiceDataset, copy_name: str, copy_workspace: Optional[str] = None, copy_tags: Dict[str, Any] = None, copy_metadata: Dict[str, Any] = None, - ) -> Dataset: + ) -> ServiceDataset: dataset_workspace = copy_workspace or dataset.owner dataset_workspace = user.check_workspace(dataset_workspace) @@ -221,58 +227,23 @@ def copy_dataset( return copy_dataset - def all_workspaces(self) -> List[str]: - """Retrieve all dataset workspaces""" - - workspaces = self.__dao__.get_all_workspaces() - # include 
the non-workspace workspace? - return workspaces - async def get_settings( - self, user: User, dataset: Dataset, class_type: Type[SVCDatasetSettings] - ) -> SVCDatasetSettings: - """ - Get the configured settings for dataset - - Args: - user: the connected user - dataset: the target dataset - class_type: the settings class - - Returns: - An instance of class_type settings configured for provided dataset - - """ + self, + user: User, + dataset: ServiceDataset, + class_type: Type[ServiceDatasetSettings], + ) -> ServiceDatasetSettings: settings = self.__dao__.load_settings(dataset=dataset, as_class=class_type) if not settings: raise EntityNotFoundError(name=dataset.name, type=class_type) return class_type.parse_obj(settings.dict()) async def save_settings( - self, user: User, dataset: Dataset, settings: SVCDatasetSettings - ) -> SVCDatasetSettings: - """ - Save a set of settings for a dataset + self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings + ) -> ServiceDatasetSettings: - Args: - user: The user executing the command - dataset: The dataset - settings: The dataset settings - - Returns: - Stored dataset settings - - """ self.__dao__.save_settings(dataset=dataset, settings=settings) return settings - async def delete_settings(self, user: User, dataset: Dataset) -> None: - """ - Deletes the dataset settings - - Args: - user: The user executing the command - dataset: The dataset - - """ + async def delete_settings(self, user: User, dataset: ServiceDataset) -> None: self.__dao__.delete_settings(dataset=dataset) diff --git a/src/rubrix/server/services/info.py b/src/rubrix/server/services/info.py index 2fe96d51ba..716b8d7a4e 100644 --- a/src/rubrix/server/services/info.py +++ b/src/rubrix/server/services/info.py @@ -14,15 +14,52 @@ # limitations under the License. 
import os -from typing import Any, Dict, Optional +from typing import Any, Dict import psutil from fastapi import Depends -from hurry.filesize import size +from pydantic import BaseModel from rubrix import __version__ as rubrix_version -from rubrix.server.apis.v0.models.info import ApiStatus -from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper +from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend + + +def size(bytes): + system = [ + (1024 ** 5, "P"), + (1024 ** 4, "T"), + (1024 ** 3, "G"), + (1024 ** 2, "M"), + (1024 ** 1, "K"), + (1024 ** 0, "B"), + ] + + factor, suffix = None, None + for factor, suffix in system: + if bytes >= factor: + break + + amount = int(bytes / factor) + if isinstance(suffix, tuple): + singular, multiple = suffix + if amount == 1: + suffix = singular + else: + suffix = multiple + return str(amount) + suffix + + +class ApiInfo(BaseModel): + """Basic api info""" + + rubrix_version: str + + +class ApiStatus(ApiInfo): + """The Rubrix api status model""" + + elasticsearch: Dict[str, Any] + mem_info: Dict[str, Any] class ApiInfoService: @@ -30,7 +67,22 @@ class ApiInfoService: The api info service """ - def __init__(self, es: ElasticsearchWrapper): + _INSTANCE = None + + @classmethod + def get_instance( + cls, + backend: ElasticsearchBackend = Depends(ElasticsearchBackend.get_instance), + ) -> "ApiInfoService": + """ + Creates an api info service + """ + + if not cls._INSTANCE: + cls._INSTANCE = ApiInfoService(backend) + return cls._INSTANCE + + def __init__(self, es: ElasticsearchBackend): self.__es__ = es def api_status(self) -> ApiStatus: @@ -50,19 +102,3 @@ def _api_memory_info() -> Dict[str, Any]: """Fetch the api process memory usage""" process = psutil.Process(os.getpid()) return {k: size(v) for k, v in process.memory_info()._asdict().items()} - - -_instance: Optional[ApiInfoService] = None - - -def create_info_service( - es_wrapper: ElasticsearchWrapper = 
Depends(ElasticsearchWrapper.get_instance), -) -> ApiInfoService: - """ - Creates an api info service - """ - - global _instance - if not _instance: - _instance = ApiInfoService(es_wrapper) - return _instance diff --git a/src/rubrix/server/services/metrics.py b/src/rubrix/server/services/metrics.py deleted file mode 100644 index 4b69dd60ca..0000000000 --- a/src/rubrix/server/services/metrics.py +++ /dev/null @@ -1,204 +0,0 @@ -from typing import Callable, Optional, Type, TypeVar, Union - -from fastapi import Depends - -from rubrix.server.apis.v0.models.metrics.base import ( - ElasticsearchMetric, - NestedPathElasticsearchMetric, - PythonMetric, -) -from rubrix.server.apis.v0.models.metrics.commons import * -from rubrix.server.daos.models.records import RecordSearch -from rubrix.server.daos.records import DatasetRecordsDAO, dataset_records_dao -from rubrix.server.errors import WrongInputParamError -from rubrix.server.services.datasets import Dataset -from rubrix.server.services.search.query_builder import EsQueryBuilder -from rubrix.server.services.tasks.commons.record import BaseRecordDB - -GenericQuery = TypeVar("GenericQuery") - - -class MetricsService: - """The dataset metrics service singleton""" - - _INSTANCE = None - - @classmethod - def get_instance( - cls, - dao: DatasetRecordsDAO = Depends(dataset_records_dao), - query_builder: EsQueryBuilder = Depends(EsQueryBuilder.get_instance), - ) -> "MetricsService": - """ - Creates the service instance. 
- - Parameters - ---------- - dao: - The dataset records dao - - Returns - ------- - The metrics service instance - - """ - if not cls._INSTANCE: - cls._INSTANCE = cls(dao, query_builder=query_builder) - return cls._INSTANCE - - def __init__(self, dao: DatasetRecordsDAO, query_builder: EsQueryBuilder): - """ - Creates a service instance - - Parameters - ---------- - dao: - The dataset records dao - """ - self.__dao__ = dao - self.__query_builder__ = query_builder - - def summarize_metric( - self, - dataset: Dataset, - metric: BaseMetric, - record_class: Optional[Type[BaseRecordDB]] = None, - query: Optional[GenericQuery] = None, - **metric_params, - ) -> Dict[str, Any]: - """ - Applies a metric summarization. - - Parameters - ---------- - dataset: - The records dataset - metric: - The selected metric - query: - An optional query passed for records filtering - metric_params: - Related metrics parameters - - Returns - ------- - The metric summarization info - """ - - if isinstance(metric, ElasticsearchMetric): - return self._handle_elasticsearch_metric( - metric, metric_params, dataset=dataset, query=query - ) - elif isinstance(metric, PythonMetric): - records = self.__dao__.scan_dataset( - dataset, - search=RecordSearch(query=self.__query_builder__(dataset, query=query)), - ) - return metric.apply(map(record_class.parse_obj, records)) - - raise WrongInputParamError(f"Cannot process {metric} of type {type(metric)}") - - def _handle_elasticsearch_metric( - self, - metric: ElasticsearchMetric, - metric_params: Dict[str, Any], - dataset: Dataset, - query: GenericQuery, - ) -> Dict[str, Any]: - """ - Parameters - ---------- - metric: - The elasticsearch metric summary - metric_params: - The summary params - dataset: - The records dataset - query: - The filter to apply to dataset - - Returns - ------- - The metric summary result - - """ - params = self.__compute_metric_params__( - dataset=dataset, metric=metric, query=query, provided_params=metric_params - ) - results = 
self.__metric_results__( - dataset=dataset, - query=query, - metric_aggregation=metric.aggregation_request(**params), - ) - return metric.aggregation_result( - aggregation_result=results.get(metric.id, results) - ) - - def __compute_metric_params__( - self, - dataset: Dataset, - metric: ElasticsearchMetric, - query: Optional[GenericQuery], - provided_params: Dict[str, Any], - ) -> Dict[str, Any]: - - return self._filter_metric_params( - metric=metric, - function=metric.aggregation_request, - metric_params={ - **provided_params, # in case of param conflict, provided metric params will be preserved - "dataset": dataset, - "dao": self.__dao__, - }, - ) - - def __metric_results__( - self, - dataset: Dataset, - query: Optional[GenericQuery], - metric_aggregation: Union[Dict[str, Any], List[Dict[str, Any]]], - ) -> Dict[str, Any]: - - if not metric_aggregation: - return {} - - if not isinstance(metric_aggregation, list): - metric_aggregation = [metric_aggregation] - - results = {} - for agg in metric_aggregation: - results_ = self.__dao__.search_records( - dataset, - size=0, # No records at all - search=RecordSearch( - query=self.__query_builder__(dataset, query=query), - aggregations=agg, - ), - ) - results.update(results_.aggregations) - return results - - @staticmethod - def _filter_metric_params( - metric: ElasticsearchMetric, function: Callable, metric_params: Dict[str, Any] - ): - """ - Select from provided metric parameter those who can be applied to given metric - - Parameters - ---------- - metric: - The target metric - metric_params: - A dict of metric parameters - - """ - - if isinstance(metric, NestedPathElasticsearchMetric): - function = metric.inner_aggregation - - return { - argument: metric_params[argument] - for argument in function.__code__.co_varnames - if argument in metric_params - } diff --git a/src/rubrix/server/services/metrics/__init__.py b/src/rubrix/server/services/metrics/__init__.py new file mode 100644 index 0000000000..341f097d55 --- 
/dev/null +++ b/src/rubrix/server/services/metrics/__init__.py @@ -0,0 +1,2 @@ +from .models import ServiceBaseMetric, ServicePythonMetric +from .service import MetricsService diff --git a/src/rubrix/server/services/metrics/models.py b/src/rubrix/server/services/metrics/models.py new file mode 100644 index 0000000000..e282cb5919 --- /dev/null +++ b/src/rubrix/server/services/metrics/models.py @@ -0,0 +1,156 @@ +from typing import ( + Any, + ClassVar, + Dict, + Generic, + Iterable, + List, + Optional, + TypeVar, + Union, +) + +from pydantic import BaseModel + +from rubrix.server.services.tasks.commons import ServiceRecord + + +class ServiceBaseMetric(BaseModel): + """ + Base model for rubrix dataset metrics summaries + """ + + id: str + name: str + description: str = None + + +class ServicePythonMetric(ServiceBaseMetric, Generic[ServiceRecord]): + """ + A metric definition which will be calculated using raw queried data + """ + + def apply(self, records: Iterable[ServiceRecord]) -> Dict[str, Any]: + """ + ServiceBaseMetric calculation method. 
+ + Parameters + ---------- + records: + The matched records + + Returns + ------- + The metric result + """ + raise NotImplementedError() + + +ServiceMetric = TypeVar("ServiceMetric", bound=ServiceBaseMetric) + + +class ServiceBaseTaskMetrics(BaseModel): + """ + Base class encapsulating related task metrics + + Attributes: + ----------- + + metrics: + A list of configured metrics for task + """ + + metrics: ClassVar[List[Union[ServicePythonMetric, str]]] + + @classmethod + def find_metric(cls, id: str) -> Optional[Union[ServicePythonMetric, str]]: + """ + Finds a metric by id + + Parameters + ---------- + id: + The metric id + + Returns + ------- + Found metric if any, ``None`` otherwise + + """ + for metric in cls.metrics: + if isinstance(metric, str) and metric == id: + return metric + if metric.id == id: + return metric + + @classmethod + def record_metrics(cls, record: ServiceRecord) -> Dict[str, Any]: + """ + Use this method is some configured metric requires additional + records fields. + + Generated records will be persisted under ``metrics`` record path. + For example, if you define a field called ``sentence_length`` like + + >>> def record_metrics(cls, record)-> Dict[str, Any]: + ... 
return { "sentence_length" : len(record.text) } + + The new field will be stored in elasticsearch in ``metrics.sentence_length`` + + Parameters + ---------- + record: + The record used for calculate metrics fields + + Returns + ------- + A dict with calculated metrics fields + """ + return {} + + +class CommonTasksMetrics(ServiceBaseTaskMetrics, Generic[ServiceRecord]): + """Common task metrics""" + + @classmethod + def record_metrics(cls, record: ServiceRecord) -> Dict[str, Any]: + """Record metrics will persist the text_length""" + return {"text_length": len(record.all_text())} + + metrics: ClassVar[List[ServiceBaseMetric]] = [ + ServiceBaseMetric( + id="text_length", + name="Text length distribution", + description="Computes the input text length distribution", + ), + ServiceBaseMetric( + id="error_distribution", + name="Error distribution", + description="Computes the dataset error distribution. It's mean, records " + "with correct predictions vs records with incorrect prediction " + "vs records with unknown prediction result", + ), + ServiceBaseMetric( + id="status_distribution", + name="Record status distribution", + description="The dataset record status distribution", + ), + ServiceBaseMetric( + id="words_cloud", + name="Inputs words cloud", + description="The words cloud for dataset inputs", + ), + ServiceBaseMetric(id="metadata", name="Metadata fields stats"), + ServiceBaseMetric( + id="predicted_by", + name="Predicted by distribution", + ), + ServiceBaseMetric( + id="annotated_by", + name="Annotated by distribution", + ), + ServiceBaseMetric( + id="score", + name="Score record distribution", + ), + ] diff --git a/src/rubrix/server/services/metrics/service.py b/src/rubrix/server/services/metrics/service.py new file mode 100644 index 0000000000..23b544dc4d --- /dev/null +++ b/src/rubrix/server/services/metrics/service.py @@ -0,0 +1,91 @@ +from typing import Any, Dict, Optional, Type + +from fastapi import Depends + +from rubrix.server.daos.models.records 
import DaoRecordsSearch +from rubrix.server.daos.records import DatasetRecordsDAO +from rubrix.server.services.datasets import ServiceDataset +from rubrix.server.services.metrics.models import ServiceMetric, ServicePythonMetric +from rubrix.server.services.search.model import ServiceRecordsQuery +from rubrix.server.services.tasks.commons import ServiceRecord + + +class MetricsService: + """The dataset metrics service singleton""" + + _INSTANCE = None + + @classmethod + def get_instance( + cls, + dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), + ) -> "MetricsService": + """ + Creates the service instance. + + Parameters + ---------- + dao: + The dataset records dao + + Returns + ------- + The metrics service instance + + """ + if not cls._INSTANCE: + cls._INSTANCE = cls(dao) + return cls._INSTANCE + + def __init__(self, dao: DatasetRecordsDAO): + """ + Creates a service instance + + Parameters + ---------- + dao: + The dataset records dao + """ + self.__dao__ = dao + + def summarize_metric( + self, + dataset: ServiceDataset, + metric: ServiceMetric, + record_class: Optional[Type[ServiceRecord]] = None, + query: Optional[ServiceRecordsQuery] = None, + **metric_params, + ) -> Dict[str, Any]: + """ + Applies a metric summarization. 
+ + Parameters + ---------- + dataset: + The records dataset + metric: + The selected metric + record_class: + The record class type for python metrics computation + query: + An optional query passed for records filtering + metric_params: + Related metrics parameters + + Returns + ------- + The metric summarization info + """ + + if isinstance(metric, ServicePythonMetric): + records = self.__dao__.scan_dataset( + dataset, search=DaoRecordsSearch(query=query) + ) + return metric.apply(map(record_class.parse_obj, records)) + + return self.__dao__.compute_metric( + metric_id=metric.id, + metric_params=metric_params, + dataset=dataset, + query=query, + ) diff --git a/src/rubrix/server/services/search/model.py b/src/rubrix/server/services/search/model.py index f206809849..3878d437d1 100644 --- a/src/rubrix/server/services/search/model.py +++ b/src/rubrix/server/services/search/model.py @@ -1,55 +1,39 @@ -from enum import Enum -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Dict, List, TypeVar from pydantic import BaseModel, Field -from pydantic.generics import GenericModel -from rubrix.server.services.tasks.commons.record import Record, TaskStatus +from rubrix.server.daos.backend.search.model import ( + BaseRecordsQuery, + QueryRange, + SortableField, + SortConfig, +) +from rubrix.server.services.tasks.commons import ServiceRecord -class SortOrder(str, Enum): - asc = "asc" - desc = "desc" +class ServiceBaseRecordsQuery(BaseRecordsQuery): + pass -class SortableField(BaseModel): - """Sortable field structure""" - - id: str - order: SortOrder = SortOrder.asc - - -class BaseSearchQuery(BaseModel): - - query_text: Optional[str] = None - advanced_query_dsl: bool = False - - ids: Optional[List[Union[str, int]]] - - annotated_by: List[str] = Field(default_factory=list) - predicted_by: List[str] = Field(default_factory=list) - - status: List[TaskStatus] = Field(default_factory=list) - metadata: Optional[Dict[str, Union[str, 
List[str]]]] = None +class ServiceSortConfig(SortConfig): + pass -class QueryRange(BaseModel): +class ServiceSortableField(SortableField): + """Sortable field structure""" - range_from: float = Field(default=0.0, alias="from") - range_to: float = Field(default=None, alias="to") + pass - class Config: - allow_population_by_field_name = True +class ServiceQueryRange(QueryRange): + pass -class SortConfig(BaseModel): - shuffle: bool = False - sort_by: List[SortableField] = Field(default_factory=list) - valid_fields: List[str] = Field(default_factory=list) +class ServiceScoreRange(ServiceQueryRange): + pass -class BaseSearchResultsAggregations(BaseModel): +class ServiceBaseSearchResultsAggregations(BaseModel): predicted_as: Dict[str, int] = Field(default_factory=dict) annotated_as: Dict[str, int] = Field(default_factory=dict) @@ -62,30 +46,15 @@ class BaseSearchResultsAggregations(BaseModel): metadata: Dict[str, Dict[str, Any]] = Field(default_factory=dict) -Aggregations = TypeVar("Aggregations", bound=BaseSearchResultsAggregations) - - -class BaseSearchResults(GenericModel, Generic[Record, Aggregations]): - """ - API search results - - Attributes: - ----------- +ServiceSearchResultsAggregations = TypeVar( + "ServiceSearchResultsAggregations", bound=ServiceBaseSearchResultsAggregations +) - total: - The total number of records - records: - The selected records to return - aggregations: - Requested aggregations - """ - total: int = 0 - records: List[Record] = Field(default_factory=list) - aggregations: Aggregations = None - - -class SearchResults(BaseModel): +class ServiceSearchResults(BaseModel): total: int - records: List[Record] + records: List[ServiceRecord] metrics: Dict[str, Any] = Field(default_factory=dict) + + +ServiceRecordsQuery = TypeVar("ServiceRecordsQuery", bound=ServiceBaseRecordsQuery) diff --git a/src/rubrix/server/services/search/query_builder.py b/src/rubrix/server/services/search/query_builder.py deleted file mode 100644 index 2f1a982e2c..0000000000 
--- a/src/rubrix/server/services/search/query_builder.py +++ /dev/null @@ -1,108 +0,0 @@ -import logging -from enum import Enum -from typing import Any, Dict, Optional, TypeVar - -from fastapi import Depends -from luqum.elasticsearch import ElasticsearchQueryBuilder, SchemaAnalyzer -from luqum.parser import parser - -from rubrix.server.daos.models.datasets import BaseDatasetDB -from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import filters -from rubrix.server.services.search.model import BaseSearchQuery, QueryRange - -SearchQuery = TypeVar("SearchQuery", bound=BaseSearchQuery) - - -class EsQueryBuilder: - _INSTANCE: "EsQueryBuilder" = None - _LOGGER = logging.getLogger(__name__) - - @classmethod - def get_instance( - cls, dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance) - ): - if not cls._INSTANCE: - cls._INSTANCE = cls(dao=dao) - return cls._INSTANCE - - def __init__(self, dao: DatasetRecordsDAO): - self.__dao__ = dao - - def __call__( - self, dataset: BaseDatasetDB, query: Optional[SearchQuery] = None - ) -> Dict[str, Any]: - - if not query: - return filters.match_all() - - if not query.advanced_query_dsl or not query.query_text: - return self.to_es_query(query) - - text_search = query.query_text - new_query = query.copy(update={"query_text": None}) - - schema = self.__dao__.get_dataset_schema(dataset) - schema = SchemaAnalyzer(schema) - es_query_builder = ElasticsearchQueryBuilder( - **{ - **schema.query_builder_options(), - "default_field": "text", - } # TODO: This will change - ) - - query_tree = parser.parse(text_search) - query_text = es_query_builder(query_tree) - - return filters.boolean_filter( - filter_query=self.to_es_query(new_query), must_query=query_text - ) - - @classmethod - def to_es_query(cls, query: BaseSearchQuery) -> Dict[str, Any]: - if query.ids: - return filters.ids_filter(query.ids) - - query_text = filters.text_query(query.query_text) - all_filters = 
filters.metadata(query.metadata) - query_data = query.dict( - exclude={ - "advanced_query_dsl", - "query_text", - "metadata", - "uncovered_by_rules", - } - ) - for key, value in query_data.items(): - if value is None: - continue - key_filter = None - if isinstance(value, dict): - value = getattr(query, key) # check the original field type - if isinstance(value, list): - key_filter = filters.terms_filter(key, value) - elif isinstance(value, (str, Enum)): - key_filter = filters.term_filter(key, value) - elif isinstance(value, QueryRange): - key_filter = filters.range_filter( - field=key, value_from=value.range_from, value_to=value.range_to - ) - else: - cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}") - - if key_filter: - all_filters.append(key_filter) - - return filters.boolean_filter( - must_query=query_text or filters.match_all(), - filter_query=filters.boolean_filter( - should_filters=all_filters, minimum_should_match=len(all_filters) - ) - if all_filters - else None, - must_not_query=filters.boolean_filter( - should_filters=[filters.text_query(q) for q in query.uncovered_by_rules] - ) - if hasattr(query, "uncovered_by_rules") and query.uncovered_by_rules - else None, - ) diff --git a/src/rubrix/server/services/search/service.py b/src/rubrix/server/services/search/service.py index 891f2d4bfd..862150ed89 100644 --- a/src/rubrix/server/services/search/service.py +++ b/src/rubrix/server/services/search/service.py @@ -3,22 +3,16 @@ from fastapi import Depends -from rubrix.server.apis.v0.models.metrics.base import BaseMetric -from rubrix.server.daos.models.records import RecordSearch +from rubrix.server.daos.models.records import DaoRecordsSearch from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import sort_by2elasticsearch -from rubrix.server.services.datasets import Dataset +from rubrix.server.services.datasets import ServiceDataset from rubrix.server.services.metrics import MetricsService 
+from rubrix.server.services.metrics.models import ServiceMetric from rubrix.server.services.search.model import ( - BaseSearchQuery, - Record, - SearchResults, - SortConfig, -) -from rubrix.server.services.search.query_builder import EsQueryBuilder -from rubrix.server.services.tasks.commons.record import ( - BaseRecordDB, - EsRecordDataFieldNames, + ServiceRecord, + ServiceRecordsQuery, + ServiceSearchResults, + ServiceSortConfig, ) @@ -34,59 +28,39 @@ def get_instance( cls, dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), metrics: MetricsService = Depends(MetricsService.get_instance), - query_builder: EsQueryBuilder = Depends(EsQueryBuilder.get_instance), ): if not cls._INSTANCE: - cls._INSTANCE = cls(dao=dao, metrics=metrics, query_builder=query_builder) + cls._INSTANCE = cls(dao=dao, metrics=metrics) return cls._INSTANCE def __init__( self, dao: DatasetRecordsDAO, metrics: MetricsService, - query_builder: EsQueryBuilder, ): self.__dao__ = dao self.__metrics__ = metrics - self.__query_builder__ = query_builder def search( self, - dataset: Dataset, - record_type: Type[BaseRecordDB], - query: Optional[BaseSearchQuery] = None, - sort_config: Optional[SortConfig] = None, + dataset: ServiceDataset, + record_type: Type[ServiceRecord], + query: Optional[ServiceRecordsQuery] = None, + sort_config: Optional[ServiceSortConfig] = None, record_from: int = 0, size: int = 100, exclude_metrics: bool = True, - metrics: Optional[List[BaseMetric]] = None, - ) -> SearchResults: + metrics: Optional[List[ServiceMetric]] = None, + ) -> ServiceSearchResults: if record_from > 0: metrics = None - sort_config = sort_config or SortConfig() + sort_config = sort_config or ServiceSortConfig() exclude_fields = ["metrics.*"] if exclude_metrics else None results = self.__dao__.search_records( dataset, - search=RecordSearch( - query=self.__query_builder__(dataset, query), - sort=sort_by2elasticsearch( - sort_config.sort_by, - valid_fields=[ - "metadata", - 
EsRecordDataFieldNames.last_updated, - EsRecordDataFieldNames.score, - EsRecordDataFieldNames.predicted, - EsRecordDataFieldNames.predicted_as, - EsRecordDataFieldNames.predicted_by, - EsRecordDataFieldNames.annotated_as, - EsRecordDataFieldNames.annotated_by, - EsRecordDataFieldNames.status, - EsRecordDataFieldNames.event_timestamp, - ], - ), - ), + search=DaoRecordsSearch(query=query, sort=sort_config), size=size, record_from=record_from, exclude_fields=exclude_fields, @@ -110,7 +84,7 @@ def search( ) metrics_results[metric.id] = {} - return SearchResults( + return ServiceSearchResults( total=results.total, records=[record_type.parse_obj(r) for r in results.records], metrics=metrics_results if metrics_results else {}, @@ -118,14 +92,15 @@ def search( def scan_records( self, - dataset: Dataset, - record_type: Type[BaseRecordDB], - query: Optional[BaseSearchQuery] = None, + dataset: ServiceDataset, + record_type: Type[ServiceRecord], + query: Optional[ServiceRecordsQuery] = None, id_from: Optional[str] = None, limit: int = 1000 - ) -> Iterable[Record]: + ) -> Iterable[ServiceRecord]: """Scan records for a queried""" + search = DaoRecordsSearch(query=query) for doc in self.__dao__.scan_dataset( - dataset, search=RecordSearch(query=self.__query_builder__(dataset, query)), id_from=id_from, limit=limit + dataset, id_from=id_from, limit=limit, search=search ): yield record_type.parse_obj(doc) diff --git a/src/rubrix/server/services/storage/service.py b/src/rubrix/server/services/storage/service.py index 1d89d20d1c..72007fcc56 100644 --- a/src/rubrix/server/services/storage/service.py +++ b/src/rubrix/server/services/storage/service.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional, Type +from typing import List, Type from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import Record -from rubrix.server.apis.v0.models.datasets import BaseDatasetDB -from rubrix.server.apis.v0.models.metrics.base import BaseTaskMetrics +from 
rubrix.server.commons.config import TasksFactory from rubrix.server.daos.records import DatasetRecordsDAO +from rubrix.server.services.datasets import ServiceDataset +from rubrix.server.services.tasks.commons import ServiceRecord class RecordsStorageService: @@ -26,20 +26,18 @@ def __init__(self, dao: DatasetRecordsDAO): def store_records( self, - dataset: BaseDatasetDB, - mappings: Dict[str, Any], - records: List[Record], - record_type: Type[Record], - metrics: Optional[Type[BaseTaskMetrics]] = None, + dataset: ServiceDataset, + records: List[ServiceRecord], + record_type: Type[ServiceRecord], ) -> int: """Store a set of records""" + metrics = TasksFactory.get_task_metrics(dataset.task) if metrics: for record in records: record.metrics = metrics.record_metrics(record) return self.__dao__.add_records( dataset=dataset, - mappings=mappings, records=records, record_class=record_type, ) diff --git a/src/rubrix/server/services/tasks/commons/__init__.py b/src/rubrix/server/services/tasks/commons/__init__.py index 4958b62db9..aed4fa323c 100644 --- a/src/rubrix/server/services/tasks/commons/__init__.py +++ b/src/rubrix/server/services/tasks/commons/__init__.py @@ -1,2 +1 @@ -from .logging import * -from .record import * +from .models import * diff --git a/src/rubrix/server/services/tasks/commons/logging.py b/src/rubrix/server/services/tasks/commons/logging.py deleted file mode 100644 index 71bbc8b19c..0000000000 --- a/src/rubrix/server/services/tasks/commons/logging.py +++ /dev/null @@ -1,21 +0,0 @@ -from pydantic import BaseModel - - -class BulkResponse(BaseModel): - """ - Data info for bulk results - - Attributes - ---------- - - dataset: - The dataset name - processed: - Number of records in bulk - failed: - Number of failed records - """ - - dataset: str - processed: int - failed: int = 0 diff --git a/src/rubrix/server/services/tasks/commons/models.py b/src/rubrix/server/services/tasks/commons/models.py new file mode 100644 index 0000000000..da691433c6 --- /dev/null 
+++ b/src/rubrix/server/services/tasks/commons/models.py @@ -0,0 +1,35 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +from rubrix.server.daos.models.records import ( + BaseAnnotationDB, + BaseRecordDB, + BaseRecordInDB, +) + + +class ServiceBaseAnnotation(BaseAnnotationDB): + pass + + +class BulkResponse(BaseModel): + dataset: str + processed: int + failed: int = 0 + + +ServiceAnnotation = TypeVar("ServiceAnnotation", bound=ServiceBaseAnnotation) + + +class ServiceBaseRecordInputs( + BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation] +): + pass + + +class ServiceBaseRecord(BaseRecordDB[ServiceAnnotation], Generic[ServiceAnnotation]): + pass + + +ServiceRecord = TypeVar("ServiceRecord", bound=ServiceBaseRecord) diff --git a/src/rubrix/server/services/tasks/commons/record.py b/src/rubrix/server/services/tasks/commons/record.py deleted file mode 100644 index 04d2d91108..0000000000 --- a/src/rubrix/server/services/tasks/commons/record.py +++ /dev/null @@ -1,152 +0,0 @@ -from datetime import datetime -from enum import Enum -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union -from uuid import uuid4 - -from pydantic import BaseModel, Field, validator -from pydantic.generics import GenericModel - - -class EsRecordDataFieldNames(str, Enum): - - predicted_as = "predicted_as" - annotated_as = "annotated_as" - annotated_by = "annotated_by" - predicted_by = "predicted_by" - status = "status" - predicted = "predicted" - score = "score" - words = "words" - event_timestamp = "event_timestamp" - last_updated = "last_updated" - - def __str__(self): - return self.value - - -class BaseAnnotation(BaseModel): - agent: str = Field(max_length=64) - - -class TaskType(str, Enum): - - text_classification = "TextClassification" - token_classification = "TokenClassification" - text2text = "Text2Text" - multi_task_text_token_classification = "MultitaskTextTokenClassification" - - -class TaskStatus(str, Enum): - default = "Default" - 
edited = "Edited" # TODO: DEPRECATE - discarded = "Discarded" - validated = "Validated" - - -class PredictionStatus(str, Enum): - OK = "ok" - KO = "ko" - - -Annotation = TypeVar("Annotation", bound=BaseAnnotation) - - -class BaseRecordDB(GenericModel, Generic[Annotation]): - - id: Optional[Union[int, str]] = Field(default=None) - metadata: Dict[str, Any] = Field(default=None) - event_timestamp: Optional[datetime] = None - status: Optional[TaskStatus] = None - prediction: Optional[Annotation] = None - annotation: Optional[Annotation] = None - metrics: Dict[str, Any] = Field(default_factory=dict) - search_keywords: Optional[List[str]] = None - - @validator("id", always=True, pre=True) - def default_id_if_none_provided(cls, id: Optional[str]) -> str: - """Validates id info and sets a random uuid if not provided""" - if id is None: - return str(uuid4()) - return id - - @validator("status", always=True) - def fill_default_value(cls, status: TaskStatus): - """Fastapi validator for set default task status""" - return TaskStatus.default if status is None else status - - @validator("search_keywords") - def remove_duplicated_keywords(cls, value) -> List[str]: - """Remove duplicated keywords""" - if value: - return list(set(value)) - - @classmethod - def task(cls) -> TaskType: - """The task type related to this task info""" - raise NotImplementedError - - @property - def predicted(self) -> Optional[PredictionStatus]: - """The task record prediction status (if any)""" - return None - - @property - def predicted_as(self) -> Optional[List[str]]: - """Predictions strings representation""" - return None - - @property - def annotated_as(self) -> Optional[List[str]]: - """Annotations strings representation""" - return None - - @property - def scores(self) -> Optional[List[float]]: - """Prediction scores""" - return None - - def all_text(self) -> str: - """All textual information related to record""" - raise NotImplementedError - - @property - def predicted_by(self) -> List[str]: - 
"""The prediction agents""" - if self.prediction: - return [self.prediction.agent] - return [] - - @property - def annotated_by(self) -> List[str]: - """The annotation agents""" - if self.annotation: - return [self.annotation.agent] - return [] - - def extended_fields(self) -> Dict[str, Any]: - """ - Used for extends fields to store in db. Tasks that would include extra - properties than commons (predicted, annotated_as,....) could implement - this method. - """ - return { - EsRecordDataFieldNames.predicted: self.predicted, - EsRecordDataFieldNames.annotated_as: self.annotated_as, - EsRecordDataFieldNames.predicted_as: self.predicted_as, - EsRecordDataFieldNames.annotated_by: self.annotated_by, - EsRecordDataFieldNames.predicted_by: self.predicted_by, - EsRecordDataFieldNames.score: self.scores, - } - - def dict(self, *args, **kwargs) -> "DictStrAny": - """ - Extends base component dict extending object properties - and user defined extended fields - """ - return { - **super().dict(*args, **kwargs), - **self.extended_fields(), - } - - -Record = TypeVar("Record", bound=BaseRecordDB) diff --git a/src/rubrix/server/services/tasks/text2text/__init__.py b/src/rubrix/server/services/tasks/text2text/__init__.py new file mode 100644 index 0000000000..79de1ed5a3 --- /dev/null +++ b/src/rubrix/server/services/tasks/text2text/__init__.py @@ -0,0 +1 @@ +from .service import Text2TextService diff --git a/src/rubrix/server/services/tasks/text2text/models.py b/src/rubrix/server/services/tasks/text2text/models.py new file mode 100644 index 0000000000..8b21787008 --- /dev/null +++ b/src/rubrix/server/services/tasks/text2text/models.py @@ -0,0 +1,81 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from rubrix.server.commons.models import PredictionStatus, TaskType +from rubrix.server.services.datasets import ServiceBaseDataset +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + 
ServiceBaseSearchResultsAggregations, + ServiceScoreRange, + ServiceSearchResults, +) +from rubrix.server.services.tasks.commons import ( + ServiceBaseAnnotation, + ServiceBaseRecord, +) + + +class ServiceText2TextPrediction(BaseModel): + text: str + score: float + + +class ServiceText2TextAnnotation(ServiceBaseAnnotation): + sentences: List[ServiceText2TextPrediction] + + +class ServiceText2TextRecord(ServiceBaseRecord[ServiceText2TextAnnotation]): + text: str + + @classmethod + def task(cls) -> TaskType: + """The task type""" + return TaskType.text2text + + def all_text(self) -> str: + return self.text + + @property + def predicted_as(self) -> Optional[List[str]]: + return ( + [sentence.text for sentence in self.prediction.sentences] + if self.prediction + else None + ) + + @property + def annotated_as(self) -> Optional[List[str]]: + return ( + [sentence.text for sentence in self.annotation.sentences] + if self.annotation + else None + ) + + @property + def scores(self) -> List[float]: + """Values of prediction scores""" + if not self.prediction: + return [] + return [sentence.score for sentence in self.prediction.sentences] + + def extended_fields(self) -> Dict[str, Any]: + return { + "annotated_as": self.annotated_as, + "predicted_as": self.predicted_as, + "annotated_by": self.annotated_by, + "predicted_by": self.predicted_by, + "score": self.scores, + "words": self.all_text(), + } + + +class ServiceText2TextQuery(ServiceBaseRecordsQuery): + score: Optional[ServiceScoreRange] = Field(default=None) + predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) + + +class ServiceText2TextDataset(ServiceBaseDataset): + task: TaskType = Field(default=TaskType.text2text, const=True) + pass diff --git a/src/rubrix/server/services/text2text.py b/src/rubrix/server/services/tasks/text2text/service.py similarity index 52% rename from src/rubrix/server/services/text2text.py rename to src/rubrix/server/services/tasks/text2text/service.py index 
f19e3e926e..0c8f552367 100644 --- a/src/rubrix/server/services/text2text.py +++ b/src/rubrix/server/services/tasks/text2text/service.py @@ -13,28 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, List, Optional, Type +from typing import Iterable, List, Optional, Type from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - EsRecordDataFieldNames, - SortableField, +from rubrix.server.commons.config import TasksFactory +from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics +from rubrix.server.services.search.model import ( + ServiceSearchResults, + ServiceSortableField, + ServiceSortConfig, ) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics -from rubrix.server.apis.v0.models.text2text import ( - CreationText2TextRecord, - Text2TextDatasetDB, - Text2TextQuery, - Text2TextRecord, - Text2TextRecordDB, - Text2TextSearchAggregations, - Text2TextSearchResults, -) -from rubrix.server.services.search.model import SortConfig from rubrix.server.services.search.service import SearchRecordsService from rubrix.server.services.storage.service import RecordsStorageService +from rubrix.server.services.tasks.commons import BulkResponse +from rubrix.server.services.tasks.text2text.models import ( + ServiceText2TextDataset, + ServiceText2TextQuery, + ServiceText2TextRecord, +) class Text2TextService: @@ -65,70 +62,46 @@ def __init__( def add_records( self, - dataset: Text2TextDatasetDB, - mappings: Dict[str, Any], - records: List[CreationText2TextRecord], - metrics: Type[BaseTaskMetrics], + dataset: ServiceText2TextDataset, + records: List[ServiceText2TextRecord], ): failed = self.__storage__.store_records( dataset=dataset, - mappings=mappings, records=records, - record_type=Text2TextRecordDB, - metrics=metrics, + record_type=ServiceText2TextRecord, ) return 
BulkResponse(dataset=dataset.name, processed=len(records), failed=failed) def search( self, - dataset: Text2TextDatasetDB, - query: Text2TextQuery, - sort_by: List[SortableField], + dataset: ServiceText2TextDataset, + query: ServiceText2TextQuery, + sort_by: List[ServiceSortableField], record_from: int = 0, size: int = 100, exclude_metrics: bool = True, - metrics: Optional[List[BaseMetric]] = None, - ) -> Text2TextSearchResults: - """ - Run a search in a dataset - - Parameters - ---------- - dataset: - The records dataset - query: - The search parameters - sort_by: - The sort by list - record_from: - The record from return results - size: - The max number of records to return - - Returns - ------- - The matched records with aggregation info for specified task_meta.py - - """ + ) -> ServiceSearchResults: + + metrics = TasksFactory.find_task_metrics( + dataset.task, + metric_ids={ + "words_cloud", + "predicted_by", + "annotated_by", + "status_distribution", + "metadata", + "score", + }, + ) results = self.__search__.search( dataset, query=query, size=size, record_from=record_from, - record_type=Text2TextRecord, - sort_config=SortConfig( + record_type=ServiceText2TextRecord, + sort_config=ServiceSortConfig( sort_by=sort_by, - valid_fields=[ - "metadata", - EsRecordDataFieldNames.predicted_as, - EsRecordDataFieldNames.annotated_as, - EsRecordDataFieldNames.predicted_by, - EsRecordDataFieldNames.annotated_by, - EsRecordDataFieldNames.status, - EsRecordDataFieldNames.last_updated, - EsRecordDataFieldNames.event_timestamp, - ], ), exclude_metrics=exclude_metrics, metrics=metrics, @@ -138,21 +111,15 @@ def search( results.metrics["words"] = results.metrics["words_cloud"] results.metrics["status"] = results.metrics["status_distribution"] - return Text2TextSearchResults( - total=results.total, - records=results.records, - aggregations=Text2TextSearchAggregations.parse_obj(results.metrics) - if results.metrics - else None, - ) + return results def read_dataset( self, - 
dataset: Text2TextDatasetDB, - query: Optional[Text2TextQuery] = None, + dataset: ServiceText2TextDataset, + query: Optional[ServiceText2TextQuery] = None, id_from: Optional[str] = None, limit: int = 1000 - ) -> Iterable[Text2TextRecord]: + ) -> Iterable[ServiceText2TextRecord]: """ Scan a dataset records @@ -170,8 +137,5 @@ def read_dataset( """ yield from self.__search__.scan_records( - dataset, query=query, record_type=Text2TextRecord, id_from=id_from, limit=limit, + dataset, query=query, record_type=ServiceText2TextRecord,id_from=id_from, limit=limit ) - - -text2text_service = Text2TextService.get_instance diff --git a/src/rubrix/server/services/tasks/text_classification/__init__.py b/src/rubrix/server/services/tasks/text_classification/__init__.py new file mode 100644 index 0000000000..dbc9d953c5 --- /dev/null +++ b/src/rubrix/server/services/tasks/text_classification/__init__.py @@ -0,0 +1,2 @@ +from .labeling_rules_service import LabelingService +from .service import TextClassificationService diff --git a/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py b/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py new file mode 100644 index 0000000000..0f5cc3a0ee --- /dev/null +++ b/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py @@ -0,0 +1,138 @@ +from typing import List, Optional, Tuple + +from fastapi import Depends +from pydantic import BaseModel, Field + +from rubrix.server.daos.datasets import DatasetsDAO +from rubrix.server.daos.models.records import DaoRecordsSearch +from rubrix.server.daos.records import DatasetRecordsDAO +from rubrix.server.errors import EntityAlreadyExistsError, EntityNotFoundError +from rubrix.server.services.search.model import ServiceBaseRecordsQuery +from rubrix.server.services.tasks.text_classification.model import ( + ServiceLabelingRule, + ServiceTextClassificationDataset, +) + + +class DatasetLabelingRulesSummary(BaseModel): + covered_records: 
int + annotated_covered_records: int + + +class LabelingRuleSummary(BaseModel): + covered_records: int + annotated_covered_records: int + correct_records: int = Field(default=0) + incorrect_records: int = Field(default=0) + precision: Optional[float] = None + + +class LabelingService: + + _INSTANCE = None + + @classmethod + def get_instance( + cls, + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), + ): + if cls._INSTANCE is None: + cls._INSTANCE = cls(datasets, records) + return cls._INSTANCE + + def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO): + self.__datasets__ = datasets + self.__records__ = records + + # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings + def list_rules( + self, dataset: ServiceTextClassificationDataset + ) -> List[ServiceLabelingRule]: + """List a set of rules for a given dataset""" + return dataset.rules + + def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): + """Delete a rule from a dataset by its defined query string""" + new_rules_set = [r for r in dataset.rules if r.query != rule_query] + if len(dataset.rules) != new_rules_set: + dataset.rules = new_rules_set + self.__datasets__.update_dataset(dataset) + + def add_rule( + self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule + ) -> ServiceLabelingRule: + """Adds a rule to a dataset""" + for r in dataset.rules: + if r.query == rule.query: + raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule) + dataset.rules.append(rule) + self.__datasets__.update_dataset(dataset) + return rule + + def compute_rule_metrics( + self, + dataset: ServiceTextClassificationDataset, + rule_query: str, + labels: Optional[List[str]] = None, + ) -> Tuple[int, int, LabelingRuleSummary]: + """Computes metrics for given rule query and optional label against a set of rules""" + + annotated_records = 
self._count_annotated_records(dataset) + dataset_records = self.__records__.search_records(dataset, size=0).total + metric_data = self.__records__.compute_metric( + dataset=dataset, + metric_id="labeling_rule", + metric_params=dict(rule_query=rule_query, labels=labels), + ) + + return ( + dataset_records, + annotated_records, + LabelingRuleSummary.parse_obj(metric_data), + ) + + def _count_annotated_records( + self, dataset: ServiceTextClassificationDataset + ) -> int: + results = self.__records__.search_records( + dataset, + size=0, + search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)), + ) + return results.total + + def all_rules_metrics( + self, dataset: ServiceTextClassificationDataset + ) -> Tuple[int, int, DatasetLabelingRulesSummary]: + annotated_records = self._count_annotated_records(dataset) + dataset_records = self.__records__.search_records(dataset, size=0).total + metric_data = self.__records__.compute_metric( + dataset=dataset, + metric_id="dataset_labeling_rules", + metric_params=dict(queries=[r.query for r in dataset.rules]), + ) + + return ( + dataset_records, + annotated_records, + DatasetLabelingRulesSummary.parse_obj(metric_data), + ) + + def find_rule_by_query( + self, dataset: ServiceTextClassificationDataset, rule_query: str + ) -> ServiceLabelingRule: + rule_query = rule_query.strip() + for rule in dataset.rules: + if rule.query == rule_query: + return rule + raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) + + def replace_rule( + self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule + ): + for idx, r in enumerate(dataset.rules): + if r.query == rule.query: + dataset.rules[idx] = rule + break + self.__datasets__.update_dataset(dataset) diff --git a/src/rubrix/server/apis/v0/models/metrics/text_classification.py b/src/rubrix/server/services/tasks/text_classification/metrics.py similarity index 69% rename from src/rubrix/server/apis/v0/models/metrics/text_classification.py rename to 
src/rubrix/server/services/tasks/text_classification/metrics.py index 266fae6ce5..5153cda301 100644 --- a/src/rubrix/server/apis/v0/models/metrics/text_classification.py +++ b/src/rubrix/server/services/tasks/text_classification/metrics.py @@ -1,19 +1,17 @@ -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set +from typing import Any, ClassVar, Dict, Iterable, List from pydantic import Field from sklearn.metrics import precision_recall_fscore_support from sklearn.preprocessing import MultiLabelBinarizer -from rubrix.server.apis.v0.models.metrics.base import ( - BaseMetric, - PythonMetric, - TermsAggregation, +from rubrix.server.services.metrics import ServiceBaseMetric, ServicePythonMetric +from rubrix.server.services.metrics.models import CommonTasksMetrics +from rubrix.server.services.tasks.text_classification.model import ( + ServiceTextClassificationRecord, ) -from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics -from rubrix.server.apis.v0.models.text_classification import TextClassificationRecord -class F1Metric(PythonMetric): +class F1Metric(ServicePythonMetric): """ A basic f1 computation for text classification @@ -25,7 +23,7 @@ class F1Metric(PythonMetric): multi_label: bool = False - def apply(self, records: Iterable[TextClassificationRecord]) -> Any: + def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any: filtered_records = list(filter(lambda r: r.predicted is not None, records)) # TODO: This must be precomputed with using a global dataset metric ds_labels = { @@ -92,12 +90,14 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any: } -class DatasetLabels(PythonMetric): +class DatasetLabels(ServicePythonMetric): id: str = Field("dataset_labels", const=True) name: str = Field("The dataset labels", const=True) max_processed_records: int = 10000 - def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]: + def apply( + self, records: 
Iterable[ServiceTextClassificationRecord] + ) -> Dict[str, Any]: ds_labels = set() for _ in range( 0, self.max_processed_records @@ -117,30 +117,33 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]: return {"labels": ds_labels or []} -class TextClassificationMetrics(CommonTasksMetrics[TextClassificationRecord]): +class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationRecord]): """Configured metrics for text classification task""" - metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [ - TermsAggregation( - id="predicted_as", - name="Predicted labels distribution", - field="predicted_as", - ), - TermsAggregation( - id="annotated_as", - name="Annotated labels distribution", - field="annotated_as", - ), - F1Metric( - id="F1", - name="F1 Metrics for single-label", - description="F1 Metrics for single-label (averaged and per label)", - ), - F1Metric( - id="MultiLabelF1", - name="F1 Metrics for multi-label", - description="F1 Metrics for multi-label (averaged and per label)", - multi_label=True, - ), - DatasetLabels(), - ] + metrics: ClassVar[List[ServiceBaseMetric]] = ( + CommonTasksMetrics.metrics + + [ + F1Metric( + id="F1", + name="F1 Metrics for single-label", + description="F1 Metrics for single-label (averaged and per label)", + ), + F1Metric( + id="MultiLabelF1", + name="F1 Metrics for multi-label", + description="F1 Metrics for multi-label (averaged and per label)", + multi_label=True, + ), + DatasetLabels(), + ] + + [ + ServiceBaseMetric( + id="predicted_as", + name="Predicted labels distribution", + ), + ServiceBaseMetric( + id="annotated_as", + name="Annotated labels distribution", + ), + ] + ) diff --git a/src/rubrix/server/services/tasks/text_classification/model.py b/src/rubrix/server/services/tasks/text_classification/model.py new file mode 100644 index 0000000000..3bd8ff3927 --- /dev/null +++ b/src/rubrix/server/services/tasks/text_classification/model.py @@ -0,0 +1,344 @@ +# 
coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import Any, ClassVar, Dict, List, Optional, Union + +from pydantic import BaseModel, Field, root_validator, validator + +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.server.commons.models import PredictionStatus, TaskStatus, TaskType +from rubrix.server.helpers import flatten_dict +from rubrix.server.services.datasets import ServiceBaseDataset +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + ServiceScoreRange, +) +from rubrix.server.services.tasks.commons import ( + ServiceBaseAnnotation, + ServiceBaseRecord, +) + + +class ServiceLabelingRule(BaseModel): + query: str = Field(description="The es rule query") + + author: str = Field(description="User who created the rule") + created_at: Optional[datetime] = Field( + default_factory=datetime.utcnow, description="Rule creation timestamp" + ) + + label: Optional[str] = Field( + default=None, description="@Deprecated::The label associated with the rule." + ) + labels: List[str] = Field( + default_factory=list, + description="For multi label problems, a list of labels. 
" + "It will replace the `label` field", + ) + description: Optional[str] = Field( + None, description="A brief description of the rule" + ) + + @root_validator + def initialize_labels(cls, values): + label = values.get("label", None) + labels = values.get("labels", []) + + if label: + labels.append(label) + values["labels"] = list(set(labels)) + + assert len(labels) >= 1, f"No labels was provided in rule {values}" + return values + + @validator("query") + def strip_query(cls, query: str) -> str: + """Remove blank spaces for query""" + return query.strip() + + +class ServiceTextClassificationDataset(ServiceBaseDataset): + + task: TaskType = Field(default=TaskType.text_classification, const=True) + rules: List[ServiceLabelingRule] = Field(default_factory=list) + + +class ClassPrediction(BaseModel): + """ + Single class prediction + + Attributes: + ----------- + + class_label: Union[str, int] + the predicted class + + score: float + the predicted class score. For human-supervised annotations, + this probability should be 1.0 + """ + + class_label: Union[str, int] = Field(alias="class") + score: float = Field(default=1.0, ge=0.0, le=1.0) + + @validator("class_label") + def check_label_length(cls, class_label): + if isinstance(class_label, str): + assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, ( + f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}" + if len(class_label) > MAX_KEYWORD_LENGTH + else f"Class name must not be empty" + ) + return class_label + + # See + class Config: + allow_population_by_field_name = True + + +class LabelingRuleMetricsSummary(BaseModel): + """Metrics generated for a labeling rule""" + + coverage: Optional[float] = None + coverage_annotated: Optional[float] = None + correct: Optional[float] = None + incorrect: Optional[float] = None + precision: Optional[float] = None + + total_records: int + annotated_records: int + + +class DatasetLabelingRulesMetricsSummary(BaseModel): + coverage: Optional[float] = None + 
coverage_annotated: Optional[float] = None + + total_records: int + annotated_records: int + + +class TextClassificationAnnotation(ServiceBaseAnnotation): + """ + Annotation class for text classification tasks + + Attributes: + ----------- + + labels: List[LabelPrediction] + list of annotated labels with score + """ + + # TODO(@frascuchon): labels must be a dict (to avoid repeat labels) + labels: List[ClassPrediction] + + @validator("labels") + def sort_labels(cls, labels: List[ClassPrediction]): + """Sort provided labels by score""" + return sorted(labels, key=lambda x: x.score, reverse=True) + + +class TokenAttributions(BaseModel): + """ + The token attributions explaining predicted labels + + Attributes: + ----------- + + token: str + The input token + attributions: Dict[str, float] + A dictionary containing label class-attribution pairs + + """ + + token: str + attributions: Dict[str, float] = Field(default_factory=dict) + + +class ServiceTextClassificationRecord(ServiceBaseRecord[TextClassificationAnnotation]): + inputs: Dict[str, Union[str, List[str]]] + multi_label: bool = False + explanation: Optional[Dict[str, List[TokenAttributions]]] = None + + class Config: + allow_population_by_field_name = True + + _SCORE_DEVIATION_ERROR: ClassVar[float] = 0.001 + + @root_validator + def validate_record(cls, values): + """fastapi validator method""" + prediction = values.get("prediction", None) + annotation = values.get("annotation", None) + status = values.get("status") + multi_label = values.get("multi_label", False) + + cls._check_score_integrity(prediction, multi_label) + cls._check_annotation_integrity(annotation, multi_label, status) + + return values + + @classmethod + def _check_annotation_integrity( + cls, + annotation: TextClassificationAnnotation, + multi_label: bool, + status: TaskStatus, + ): + if status == TaskStatus.validated and not multi_label: + assert ( + annotation and len(annotation.labels) > 0 + ), "Annotation must include some label for 
validated records" + + if not multi_label and annotation: + assert ( + len(annotation.labels) == 1 + ), "Single label record must include only one annotation label" + + @classmethod + def _check_score_integrity( + cls, prediction: TextClassificationAnnotation, multi_label: bool + ): + """ + Checks the score value integrity + + Parameters + ---------- + prediction: + The prediction annotation + multi_label: + If multi label + + """ + if prediction and not multi_label: + assert sum([label.score for label in prediction.labels]) <= ( + 1.0 + cls._SCORE_DEVIATION_ERROR + ), f"Wrong score distributions: {prediction.labels}" + + @classmethod + def task(cls) -> TaskType: + """The task type""" + return TaskType.text_classification + + @property + def predicted(self) -> Optional[PredictionStatus]: + if self.predicted_by and self.annotated_by: + return ( + PredictionStatus.OK + if set(self.predicted_as) == set(self.annotated_as) + else PredictionStatus.KO + ) + return None + + @property + def predicted_as(self) -> List[str]: + return self._labels_from_annotation( + self.prediction, multi_label=self.multi_label + ) + + @property + def annotated_as(self) -> List[str]: + return self._labels_from_annotation( + self.annotation, multi_label=self.multi_label + ) + + @property + def scores(self) -> List[float]: + if not self.prediction: + return [] + return ( + [label.score for label in self.prediction.labels] + if self.multi_label + else [ + prediction_class.score + for prediction_class in [ + self._max_class_prediction( + self.prediction, multi_label=self.multi_label + ) + ] + if prediction_class + ] + ) + + def all_text(self) -> str: + sentences = [] + for v in self.inputs.values(): + if isinstance(v, list): + sentences.extend(v) + else: + sentences.append(v) + return "\n".join(sentences) + + @validator("inputs") + def validate_inputs(cls, text: Dict[str, Any]): + assert len(text) > 0, "No inputs provided" + + for t in text.values(): + assert t is not None, "Cannot include None 
fields" + + return text + + @validator("inputs") + def flatten_text(cls, text: Dict[str, Any]): + flat_dict = flatten_dict(text) + return flat_dict + + @classmethod + def _labels_from_annotation( + cls, annotation: TextClassificationAnnotation, multi_label: bool + ) -> Union[List[str], List[int]]: + + if not annotation: + return [] + + if multi_label: + return [ + label.class_label for label in annotation.labels if label.score > 0.5 + ] + + class_prediction = cls._max_class_prediction( + annotation, multi_label=multi_label + ) + if class_prediction is None: + return [] + + return [class_prediction.class_label] + + @staticmethod + def _max_class_prediction( + p: TextClassificationAnnotation, multi_label: bool + ) -> Optional[ClassPrediction]: + if multi_label or p is None or not p.labels: + return None + return p.labels[0] + + def extended_fields(self) -> Dict[str, Any]: + words = self.all_text() + return { + **super().extended_fields(), + "words": words, + "text": words, + } + + +class ServiceTextClassificationQuery(ServiceBaseRecordsQuery): + + predicted_as: List[str] = Field(default_factory=list) + annotated_as: List[str] = Field(default_factory=list) + score: Optional[ServiceScoreRange] = Field(default=None) + predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) + + uncovered_by_rules: List[str] = Field(default_factory=list) diff --git a/src/rubrix/server/services/text_classification.py b/src/rubrix/server/services/tasks/text_classification/service.py similarity index 67% rename from src/rubrix/server/services/text_classification.py rename to src/rubrix/server/services/tasks/text_classification/service.py index 1c45ff244e..3a372ae443 100644 --- a/src/rubrix/server/services/text_classification.py +++ b/src/rubrix/server/services/tasks/text_classification/service.py @@ -13,33 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Iterable, List, Optional, Type +from typing import Iterable, List, Optional from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - EsRecordDataFieldNames, - SortableField, +from rubrix.server.commons.config import TasksFactory +from rubrix.server.errors.base_errors import MissingDatasetRecordsError +from rubrix.server.services.search.model import ( + ServiceSearchResults, + ServiceSortableField, + ServiceSortConfig, ) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics -from rubrix.server.apis.v0.models.text_classification import ( - CreationTextClassificationRecord, +from rubrix.server.services.search.service import SearchRecordsService +from rubrix.server.services.storage.service import RecordsStorageService +from rubrix.server.services.tasks.commons import BulkResponse +from rubrix.server.services.tasks.text_classification import LabelingService +from rubrix.server.services.tasks.text_classification.model import ( DatasetLabelingRulesMetricsSummary, - LabelingRule, LabelingRuleMetricsSummary, - TextClassificationDatasetDB, - TextClassificationQuery, - TextClassificationRecord, - TextClassificationRecordDB, - TextClassificationSearchAggregations, - TextClassificationSearchResults, + ServiceLabelingRule, + ServiceTextClassificationDataset, + ServiceTextClassificationQuery, + ServiceTextClassificationRecord, ) -from rubrix.server.errors.base_errors import MissingDatasetRecordsError -from rubrix.server.services.search.model import SortConfig -from rubrix.server.services.search.service import SearchRecordsService -from rubrix.server.services.storage.service import RecordsStorageService -from rubrix.server.services.text_classification_labelling_rules import LabelingService class TextClassificationService: @@ -73,33 +69,28 @@ def __init__( def add_records( self, - dataset: TextClassificationDatasetDB, - mappings: Dict[str, Any], - records: 
List[CreationTextClassificationRecord], - metrics: Type[BaseTaskMetrics], + dataset: ServiceTextClassificationDataset, + records: List[ServiceTextClassificationRecord], ): # TODO(@frascuchon): This will moved to dataset settings validation once DatasetSettings join the game! self._check_multi_label_integrity(dataset, records) failed = self.__storage__.store_records( dataset=dataset, - mappings=mappings, records=records, - record_type=TextClassificationRecordDB, - metrics=metrics, + record_type=ServiceTextClassificationRecord, ) return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed) def search( self, - dataset: TextClassificationDatasetDB, - query: TextClassificationQuery, - sort_by: List[SortableField], + dataset: ServiceTextClassificationDataset, + query: ServiceTextClassificationQuery, + sort_by: List[ServiceSortableField], record_from: int = 0, size: int = 100, exclude_metrics: bool = True, - metrics: Optional[List[BaseMetric]] = None, - ) -> TextClassificationSearchResults: + ) -> ServiceSearchResults: """ Run a search in a dataset @@ -121,28 +112,31 @@ def search( """ + metrics = TasksFactory.find_task_metrics( + dataset.task, + metric_ids={ + "words_cloud", + "predicted_by", + "predicted_as", + "annotated_by", + "annotated_as", + "error_distribution", + "status_distribution", + "metadata", + "score", + }, + ) + results = self.__search__.search( dataset, query=query, - record_type=TextClassificationRecord, + record_type=ServiceTextClassificationRecord, record_from=record_from, size=size, exclude_metrics=exclude_metrics, metrics=metrics, - sort_config=SortConfig( + sort_config=ServiceSortConfig( sort_by=sort_by, - valid_fields=[ - "metadata", - EsRecordDataFieldNames.last_updated, - EsRecordDataFieldNames.score, - EsRecordDataFieldNames.predicted, - EsRecordDataFieldNames.predicted_as, - EsRecordDataFieldNames.predicted_by, - EsRecordDataFieldNames.annotated_as, - EsRecordDataFieldNames.annotated_by, - EsRecordDataFieldNames.status, - 
EsRecordDataFieldNames.event_timestamp, - ], ), ) @@ -152,21 +146,15 @@ def search( results.metrics["predicted"] = results.metrics["error_distribution"] results.metrics["predicted"].pop("unknown", None) - return TextClassificationSearchResults( - total=results.total, - records=results.records, - aggregations=TextClassificationSearchAggregations.parse_obj(results.metrics) - if results.metrics - else None, - ) + return results def read_dataset( self, - dataset: TextClassificationDatasetDB, - query: Optional[TextClassificationQuery] = None, + dataset: ServiceTextClassificationDataset, + query: Optional[ServiceTextClassificationQuery] = None, id_from: Optional[str] = None, limit: int = 1000 - ) -> Iterable[TextClassificationRecord]: + ) -> Iterable[ServiceTextClassificationRecord]: """ Scan a dataset records @@ -184,13 +172,13 @@ def read_dataset( """ yield from self.__search__.scan_records( - dataset, query=query, record_type=TextClassificationRecord, id_from=id_from, limit=limit + dataset, query=query, record_type=ServiceTextClassificationRecord, id_from=id_from, limit=limit ) def _check_multi_label_integrity( self, - dataset: TextClassificationDatasetDB, - records: List[CreationTextClassificationRecord], + dataset: ServiceTextClassificationDataset, + records: List[ServiceTextClassificationRecord], ): is_multi_label_dataset = self._is_dataset_multi_label(dataset) if is_multi_label_dataset is not None: @@ -203,12 +191,12 @@ def _check_multi_label_integrity( ) def _is_dataset_multi_label( - self, dataset: TextClassificationDatasetDB + self, dataset: ServiceTextClassificationDataset ) -> Optional[bool]: try: results = self.__search__.search( dataset, - record_type=TextClassificationRecord, + record_type=ServiceTextClassificationRecord, size=1, ) except MissingDatasetRecordsError: # No records index yet @@ -217,25 +205,13 @@ def _is_dataset_multi_label( return results.records[0].multi_label def get_labeling_rules( - self, dataset: TextClassificationDatasetDB - ) -> 
Iterable[LabelingRule]: - """ - Gets rules for a given dataset + self, dataset: ServiceTextClassificationDataset + ) -> Iterable[ServiceLabelingRule]: - Parameters - ---------- - dataset: - The dataset - - Returns - ------- - A list of labeling rules for a given dataset - - """ return self.__labeling__.list_rules(dataset) def add_labeling_rule( - self, dataset: TextClassificationDatasetDB, rule: LabelingRule + self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule ) -> None: """ Adds a labeling rule @@ -253,24 +229,11 @@ def add_labeling_rule( def update_labeling_rule( self, - dataset: TextClassificationDatasetDB, + dataset: ServiceTextClassificationDataset, rule_query: str, labels: List[str], description: Optional[str] = None, - ) -> LabelingRule: - """ - Update a labeling rule. Updatable fields are label and/or description - - Args: - dataset: The dataset - rule_query: The labeling rule - label: The new rule label - description: If provided, the new rule description - - Returns: - Updated labeling rule - - """ + ) -> ServiceLabelingRule: found_rule = self.__labeling__.find_rule_by_query(dataset, rule_query) found_rule.labels = labels @@ -282,44 +245,20 @@ def update_labeling_rule( self.__labeling__.replace_rule(dataset, found_rule) return found_rule - def find_labeling_rule(self, dataset: TextClassificationDatasetDB, rule_query: str): - """ - Find a labeling rule given a rule query string - - Args: - dataset: The dataset - rule_query: The query string - - Returns: - Found labeling rule. - If rule was not found EntityNotFoundError is raised - """ + def find_labeling_rule( + self, dataset: ServiceTextClassificationDataset, rule_query: str + ) -> ServiceLabelingRule: return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query) def delete_labeling_rule( - self, dataset: TextClassificationDatasetDB, rule_query: str + self, dataset: ServiceTextClassificationDataset, rule_query: str ): - """ - Deletes a rule from a dataset. 
- - Nothing happens if the rule does not exist in dataset. - - Parameters - ---------- - - dataset: - The dataset - - rule_query: - The rule query - - """ if rule_query.strip(): return self.__labeling__.delete_rule(dataset, rule_query) def compute_rule_metrics( self, - dataset: TextClassificationDatasetDB, + dataset: ServiceTextClassificationDataset, rule_query: str, labels: Optional[List[str]] = None, ) -> LabelingRuleMetricsSummary: @@ -366,6 +305,7 @@ def compute_rule_metrics( coverage_annotated = ( metrics.annotated_covered_records / annotated if annotated > 0 else None ) + return LabelingRuleMetricsSummary( total_records=total, annotated_records=annotated, @@ -376,7 +316,7 @@ def compute_rule_metrics( precision=metrics.precision if annotated > 0 else None, ) - def compute_overall_rules_metrics(self, dataset: TextClassificationDatasetDB): + def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset): total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset) coverage = metrics.covered_records / total if total else None coverage_annotated = ( @@ -390,7 +330,7 @@ def compute_overall_rules_metrics(self, dataset: TextClassificationDatasetDB): ) @staticmethod - def __normalized_rule__(rule: LabelingRule) -> LabelingRule: + def __normalized_rule__(rule: ServiceLabelingRule) -> ServiceLabelingRule: if rule.labels and len(rule.labels) == 1: rule.label = rule.labels[0] elif rule.label and not rule.labels: diff --git a/src/rubrix/server/services/tasks/token_classification/__init__.py b/src/rubrix/server/services/tasks/token_classification/__init__.py new file mode 100644 index 0000000000..db4b4a5510 --- /dev/null +++ b/src/rubrix/server/services/tasks/token_classification/__init__.py @@ -0,0 +1 @@ +from .service import TokenClassificationService diff --git a/src/rubrix/server/services/tasks/token_classification/metrics.py b/src/rubrix/server/services/tasks/token_classification/metrics.py new file mode 100644 index 
0000000000..f574cde4c5 --- /dev/null +++ b/src/rubrix/server/services/tasks/token_classification/metrics.py @@ -0,0 +1,407 @@ +from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set, Tuple + +from pydantic import BaseModel, Field + +from rubrix.server.services.metrics import ServiceBaseMetric, ServicePythonMetric +from rubrix.server.services.metrics.models import CommonTasksMetrics +from rubrix.server.services.tasks.token_classification.model import ( + EntitySpan, + ServiceTokenClassificationRecord, +) + + +class F1Metric(ServicePythonMetric[ServiceTokenClassificationRecord]): + """The F1 metric based on entity-level. + + We follow the convention of `CoNLL 2003 `_, where: + `"precision is the percentage of named entities found by the learning system that are correct. + Recall is the percentage of named entities present in the corpus that are found by the system. + A named entity is correct only if it is an exact match (...).”` + """ + + def apply( + self, records: Iterable[ServiceTokenClassificationRecord] + ) -> Dict[str, Any]: + # store entities per label in dicts + predicted_entities = {} + annotated_entities = {} + + # extract entities per label to dicts + for rec in records: + if rec.prediction: + self._add_entities_to_dict(rec.prediction.entities, predicted_entities) + if rec.annotation: + self._add_entities_to_dict(rec.annotation.entities, annotated_entities) + + # store precision, recall, and f1 per label + per_label_metrics = {} + + annotated_total, predicted_total, correct_total = 0, 0, 0 + precision_macro, recall_macro = 0, 0 + for label, annotated in annotated_entities.items(): + predicted = predicted_entities.get(label, set()) + correct = len(annotated & predicted) + + # safe divides are used to cover the 0/0 cases + precision = self._safe_divide(correct, len(predicted)) + recall = self._safe_divide(correct, len(annotated)) + per_label_metrics.update( + { + f"{label}_precision": precision, + f"{label}_recall": recall, + f"{label}_f1": 
self._safe_divide( + 2 * precision * recall, precision + recall + ), + } + ) + + annotated_total += len(annotated) + predicted_total += len(predicted) + correct_total += correct + + precision_macro += precision / len(annotated_entities) + recall_macro += recall / len(annotated_entities) + + # store macro and micro averaged precision, recall and f1 + averaged_metrics = { + "precision_macro": precision_macro, + "recall_macro": recall_macro, + "f1_macro": self._safe_divide( + 2 * precision_macro * recall_macro, precision_macro + recall_macro + ), + } + + precision_micro = self._safe_divide(correct_total, predicted_total) + recall_micro = self._safe_divide(correct_total, annotated_total) + averaged_metrics.update( + { + "precision_micro": precision_micro, + "recall_micro": recall_micro, + "f1_micro": self._safe_divide( + 2 * precision_micro * recall_micro, precision_micro + recall_micro + ), + } + ) + + return {**averaged_metrics, **per_label_metrics} + + @staticmethod + def _add_entities_to_dict( + entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]] + ): + """Helper function for the apply method.""" + for ent in entities: + try: + dictionary[ent.label].add((ent.start, ent.end)) + except KeyError: + dictionary[ent.label] = {(ent.start, ent.end)} + + @staticmethod + def _safe_divide(numerator, denominator): + """Helper function for the apply method.""" + try: + return numerator / denominator + except ZeroDivisionError: + return 0 + + +class DatasetLabels(ServicePythonMetric): + id: str = Field("dataset_labels", const=True) + name: str = Field("The dataset entity labels", const=True) + max_processed_records: int = 10000 + + def apply( + self, records: Iterable[ServiceTokenClassificationRecord] + ) -> Dict[str, Any]: + ds_labels = set() + + for _ in range( + 0, self.max_processed_records + ): # Only a few of records will be parsed + record: ServiceTokenClassificationRecord = next(records, None) + if record is None: + break + + if record.annotation: + 
ds_labels.update( + [entity.label for entity in record.annotation.entities] + ) + if record.prediction: + ds_labels.update( + [entity.label for entity in record.prediction.entities] + ) + return {"labels": ds_labels or []} + + +class MentionMetrics(BaseModel): + """Mention metrics model""" + + value: str + label: str + score: float = Field(ge=0.0) + capitalness: Optional[str] = Field(None) + density: float = Field(ge=0.0) + tokens_length: int = Field(g=0) + chars_length: int = Field(g=0) + + +class TokenTagMetrics(BaseModel): + value: str + tag: str + + +class TokenMetrics(BaseModel): + """ + Token metrics stored in elasticsearch for token classification + + Attributes + idx: The token index in sentence + value: The token textual value + char_start: The token character start position in sentence + char_end: The token character end position in sentence + score: Token score info + tag: Token tag info. Deprecated: Use metrics.predicted.tags or metrics.annotated.tags instead + custom: extra token level info + """ + + idx: int + value: str + char_start: int + char_end: int + length: int + capitalness: Optional[str] = None + score: Optional[float] = None + tag: Optional[str] = None # TODO: remove! 
+ custom: Dict[str, Any] = None + + +class TokenClassificationMetrics(CommonTasksMetrics[ServiceTokenClassificationRecord]): + """Configured metrics for token classification""" + + @staticmethod + def density(value: int, sentence_length: int) -> float: + """Compute the string density over a sentence""" + return value / sentence_length + + @staticmethod + def capitalness(value: str) -> Optional[str]: + """Compute capitalness for a string value""" + value = value.strip() + if not value: + return None + if value.isupper(): + return "UPPER" + if value.islower(): + return "LOWER" + if value[0].isupper(): + return "FIRST" + if any([c.isupper() for c in value[1:]]): + return "MIDDLE" + return None + + @staticmethod + def mentions_metrics( + record: ServiceTokenClassificationRecord, mentions: List[Tuple[str, EntitySpan]] + ): + def mention_tokens_length(entity: EntitySpan) -> int: + """Compute mention tokens length""" + return len( + set( + [ + token_idx + for i in range(entity.start, entity.end) + for token_idx in [record.char_id2token_id(i)] + if token_idx is not None + ] + ) + ) + + return [ + MentionMetrics( + value=mention, + label=entity.label, + score=entity.score, + capitalness=TokenClassificationMetrics.capitalness(mention), + density=TokenClassificationMetrics.density( + _tokens_length, sentence_length=len(record.tokens) + ), + tokens_length=_tokens_length, + chars_length=len(mention), + ) + for mention, entity in mentions + for _tokens_length in [ + mention_tokens_length(entity), + ] + ] + + @classmethod + def build_tokens_metrics( + cls, record: ServiceTokenClassificationRecord, tags: Optional[List[str]] = None + ) -> List[TokenMetrics]: + + return [ + TokenMetrics( + idx=token_idx, + value=token_value, + char_start=char_start, + char_end=char_end, + capitalness=cls.capitalness(token_value), + length=1 + (char_end - char_start), + tag=tags[token_idx] if tags else None, + ) + for token_idx, token_value in enumerate(record.tokens) + for char_start, char_end in 
[record.token_span(token_idx)] + ] + + @classmethod + def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, Any]: + """Compute metrics at record level""" + base_metrics = super(TokenClassificationMetrics, cls).record_metrics(record) + + annotated_tags = record.annotated_iob_tags() or [] + predicted_tags = record.predicted_iob_tags() or [] + + tokens_metrics = cls.build_tokens_metrics( + record, predicted_tags or annotated_tags + ) + return { + **base_metrics, + "tokens": tokens_metrics, + "tokens_length": len(record.tokens), + "predicted": { + "mentions": cls.mentions_metrics(record, record.predicted_mentions()), + "tags": [ + TokenTagMetrics(tag=tag, value=token) + for tag, token in zip(predicted_tags, record.tokens) + ], + }, + "annotated": { + "mentions": cls.mentions_metrics(record, record.annotated_mentions()), + "tags": [ + TokenTagMetrics(tag=tag, value=token) + for tag, token in zip(annotated_tags, record.tokens) + ], + }, + } + + metrics: ClassVar[List[ServiceBaseMetric]] = ( + CommonTasksMetrics.metrics + + [ + DatasetLabels(), + F1Metric( + id="F1", + name="F1 ServiceBaseMetric based on entity-level", + description="F1 metrics based on entity-level (averaged and per label), " + "where only exact matches count (CoNNL 2003).", + ), + ] + + [ + ServiceBaseMetric( + id="predicted_as", + name="Predicted labels distribution", + ), + ServiceBaseMetric( + id="annotated_as", + name="Annotated labels distribution", + ), + ServiceBaseMetric( + id="tokens_length", + name="Tokens length", + description="Computes the text length distribution measured in number of tokens", + ), + ServiceBaseMetric( + id="token_frequency", + name="Tokens frequency distribution", + ), + ServiceBaseMetric( + id="token_length", + name="Token length distribution", + description="Computes token length distribution in number of characters", + ), + ServiceBaseMetric( + id="token_capitalness", + name="Token capitalness distribution", + description="Computes 
capitalization information of tokens", + ), + ServiceBaseMetric( + id="predicted_entity_density", + name="Mention entity density for predictions", + description="Computes the ratio between the number of all entity tokens and tokens in the text", + ), + ServiceBaseMetric( + id="predicted_entity_labels", + name="Predicted entity labels", + description="Predicted entity labels distribution", + ), + ServiceBaseMetric( + id="predicted_entity_capitalness", + name="Mention entity capitalness for predictions", + description="Computes capitalization information of predicted entity mentions", + ), + ServiceBaseMetric( + id="predicted_mention_token_length", + name="Predicted mention tokens length", + description="Computes the length of the predicted entity mention measured in number of tokens", + ), + ServiceBaseMetric( + id="predicted_mention_char_length", + name="Predicted mention characters length", + description="Computes the length of the predicted entity mention measured in number of tokens", + ), + ServiceBaseMetric( + id="predicted_mentions_distribution", + name="Predicted mentions distribution by entity", + description="Computes predicted mentions distribution against its labels", + ), + ServiceBaseMetric( + id="predicted_entity_consistency", + name="Entity label consistency for predictions", + description="Computes entity label variability for top-k predicted entity mentions", + ), + ServiceBaseMetric( + id="predicted_tag_consistency", + name="Token tag consistency for predictions", + description="Computes token tag variability for top-k predicted tags", + ), + ServiceBaseMetric( + id="annotated_entity_density", + name="Mention entity density for annotations", + description="Computes the ratio between the number of all entity tokens and tokens in the text", + ), + ServiceBaseMetric( + id="annotated_entity_labels", + name="Annotated entity labels", + description="Annotated Entity labels distribution", + ), + ServiceBaseMetric( + id="annotated_entity_capitalness", + 
name="Mention entity capitalness for annotations", + description="Compute capitalization information of annotated entity mentions", + ), + ServiceBaseMetric( + id="annotated_mention_token_length", + name="Annotated mention tokens length", + description="Computes the length of the entity mention measured in number of tokens", + ), + ServiceBaseMetric( + id="annotated_mention_char_length", + name="Annotated mention characters length", + description="Computes the length of the entity mention measured in number of tokens", + ), + ServiceBaseMetric( + id="annotated_mentions_distribution", + name="Annotated mentions distribution by entity", + description="Computes annotated mentions distribution against its labels", + ), + ServiceBaseMetric( + id="annotated_entity_consistency", + name="Entity label consistency for annotations", + description="Computes entity label variability for top-k annotated entity mentions", + ), + ServiceBaseMetric( + id="annotated_tag_consistency", + name="Token tag consistency for annotations", + description="Computes token tag variability for top-k annotated tags", + ), + ] + ) diff --git a/src/rubrix/server/services/tasks/token_classification/model.py b/src/rubrix/server/services/tasks/token_classification/model.py new file mode 100644 index 0000000000..2f307f1348 --- /dev/null +++ b/src/rubrix/server/services/tasks/token_classification/model.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from datetime import datetime +from typing import Any, Dict, List, Optional, Set, Tuple + +from pydantic import BaseModel, Field, validator + +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.server.commons.models import PredictionStatus, TaskType +from rubrix.server.services.datasets import ServiceBaseDataset +from rubrix.server.services.search.model import ( + ServiceBaseRecordsQuery, + ServiceScoreRange, +) +from rubrix.server.services.tasks.commons import ( + ServiceBaseAnnotation, + ServiceBaseRecord, +) + +PREDICTED_MENTIONS_ES_FIELD_NAME = "predicted_mentions" +MENTIONS_ES_FIELD_NAME = "mentions" + + +class EntitySpan(BaseModel): + """ + The tokens span for a labeled text. + + Entity spans will be defined between from start to end - 1 + + Attributes: + ----------- + + start: int + character start position + end: int + character end position, must be higher than the starting character. + label: str + the label related to tokens that conforms the entity span + score: + A higher score means, the model/annotator is more confident about its predicted/annotated entity. + """ + + start: int + end: int + label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH) + score: float = Field(default=1.0, ge=0.0, le=1.0) + + @validator("end") + def check_span_offset(cls, end: int, values): + """Validates span offset""" + assert ( + end > values["start"] + ), "End character cannot be placed before the starting character, it must be at least one character after." 
+ return end + + def __hash__(self): + return hash(type(self)) + hash(self.__dict__.values()) + + +class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation): + entities: List[EntitySpan] = Field(default_factory=list) + score: Optional[float] = None + + +class ServiceTokenClassificationRecord( + ServiceBaseRecord[ServiceTokenClassificationAnnotation] +): + + tokens: List[str] = Field(min_items=1) + text: str = Field() + _raw_text: Optional[str] = Field(alias="raw_text") + + __chars2tokens__: Dict[int, int] = None + __tokens2chars__: Dict[int, Tuple[int, int]] = None + + # TODO: review this. + _predicted: Optional[PredictionStatus] = Field(alias="predicted") + + def extended_fields(self) -> Dict[str, Any]: + + return { + **super().extended_fields(), + # See ../service/service.py + PREDICTED_MENTIONS_ES_FIELD_NAME: [ + {"mention": mention, "entity": entity.label, "score": entity.score} + for mention, entity in self.predicted_mentions() + ], + MENTIONS_ES_FIELD_NAME: [ + {"mention": mention, "entity": entity.label} + for mention, entity in self.annotated_mentions() + ], + "words": self.all_text(), + } + + def __init__(self, **data): + super().__init__(**data) + + self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__() + + self.check_annotation(self.prediction) + self.check_annotation(self.annotation) + + def char_id2token_id(self, char_idx: int) -> Optional[int]: + return self.__chars2tokens__.get(char_idx) + + def token_span(self, token_idx: int) -> Tuple[int, int]: + if token_idx not in self.__tokens2chars__: + raise IndexError(f"Token id {token_idx} out of bounds") + return self.__tokens2chars__[token_idx] + + def __build_indices_map__( + self, + ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]: + """ + Build the indices mapping between text characters and tokens where belongs to, + and vice versa. + + chars2tokens index contains is the token idx where i char is contained (if any). 
+ + Out-of-token characters won't be included in this map, + so access should be using ``chars2tokens_map.get(i)`` + instead of ``chars2tokens_map[i]``. + + """ + + def chars2tokens_index(): + def is_space_after_token(char, idx: int, chars_map) -> str: + return char == " " and idx - 1 in chars_map + + chars_map = {} + current_token = 0 + current_token_char_start = 0 + for idx, char in enumerate(self.text): + if is_space_after_token(char, idx, chars_map): + continue + relative_idx = idx - current_token_char_start + if ( + relative_idx < len(self.tokens[current_token]) + and char == self.tokens[current_token][relative_idx] + ): + chars_map[idx] = current_token + elif ( + current_token + 1 < len(self.tokens) + and relative_idx >= len(self.tokens[current_token]) + and char == self.tokens[current_token + 1][0] + ): + current_token += 1 + current_token_char_start += relative_idx + chars_map[idx] = current_token + + return chars_map + + def tokens2chars_index( + chars2tokens: Dict[int, int] + ) -> Dict[int, Tuple[int, int]]: + tokens2chars_map = defaultdict(list) + for c, t in chars2tokens.items(): + tokens2chars_map[t].append(c) + + return { + token_idx: (min(chars), max(chars)) + for token_idx, chars in tokens2chars_map.items() + } + + chars2tokens_idx = chars2tokens_index() + return chars2tokens_idx, tokens2chars_index(chars2tokens_idx) + + def check_annotation( + self, + annotation: Optional[ServiceTokenClassificationAnnotation], + ): + """Validates entities in terms of offset spans""" + + def adjust_span_bounds(start, end): + if start < 0: + start = 0 + if entity.end > len(self.text): + end = len(self.text) + while start <= len(self.text) and not self.text[start].strip(): + start += 1 + while not self.text[end - 1].strip(): + end -= 1 + return start, end + + if annotation: + for entity in annotation.entities: + entity.start, entity.end = adjust_span_bounds(entity.start, entity.end) + mention = self.text[entity.start : entity.end] + assert len(mention) > 0, f"Empty 
offset defined for entity {entity}" + + token_start = self.char_id2token_id(entity.start) + token_end = self.char_id2token_id(entity.end - 1) + + assert not ( + token_start is None or token_end is None + ), f"Provided entity span {self.text[entity.start: entity.end]} is not aligned with provided tokens." + "Some entity chars could be reference characters out of tokens" + + span_start, _ = self.token_span(token_start) + _, span_end = self.token_span(token_end) + + assert ( + self.text[span_start : span_end + 1] == mention + ), f"Defined offset [{self.text[entity.start: entity.end]}] is a misaligned entity mention" + + def task(cls) -> TaskType: + """The record task type""" + return TaskType.token_classification + + @property + def predicted(self) -> Optional[PredictionStatus]: + if self.annotation and self.prediction: + return ( + PredictionStatus.OK + if self.annotation.entities == self.prediction.entities + else PredictionStatus.KO + ) + return None + + @property + def predicted_as(self) -> List[str]: + return [ent.label for ent in self.predicted_entities()] + + @property + def annotated_as(self) -> List[str]: + return [ent.label for ent in self.annotated_entities()] + + @property + def scores(self) -> List[float]: + if not self.prediction: + return [] + if self.prediction.score is not None: + return [self.prediction.score] + return [e.score for e in self.prediction.entities] + + def all_text(self) -> str: + return self.text + + def predicted_iob_tags(self) -> Optional[List[str]]: + if self.prediction is None: + return None + return self.spans2iob(self.prediction.entities) + + def annotated_iob_tags(self) -> Optional[List[str]]: + if self.annotation is None: + return None + return self.spans2iob(self.annotation.entities) + + def spans2iob(self, spans: List[EntitySpan]) -> Optional[List[str]]: + if spans is None: + return None + tags = ["O"] * len(self.tokens) + for entity in spans: + token_start = self.char_id2token_id(entity.start) + token_end = 
self.char_id2token_id(entity.end - 1) + tags[token_start] = f"B-{entity.label}" + for idx in range(token_start + 1, token_end + 1): + tags[idx] = f"I-{entity.label}" + + return tags + + def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]: + return [ + (mention, entity) + for mention, entity in self.__mentions_from_entities__( + self.predicted_entities() + ).items() + ] + + def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]: + return [ + (mention, entity) + for mention, entity in self.__mentions_from_entities__( + self.annotated_entities() + ).items() + ] + + def annotated_entities(self) -> Set[EntitySpan]: + """Shortcut for real annotated entities, if provided""" + if self.annotation is None: + return set() + return set(self.annotation.entities) + + def predicted_entities(self) -> Set[EntitySpan]: + """Predicted entities""" + if self.prediction is None: + return set() + return set(self.prediction.entities) + + def __mentions_from_entities__( + self, entities: Set[EntitySpan] + ) -> Dict[str, EntitySpan]: + return { + mention: entity + for entity in entities + for mention in [self.text[entity.start : entity.end]] + } + + class Config: + allow_population_by_field_name = True + underscore_attrs_are_private = True + + +class ServiceTokenClassificationQuery(ServiceBaseRecordsQuery): + + predicted_as: List[str] = Field(default_factory=list) + annotated_as: List[str] = Field(default_factory=list) + score: Optional[ServiceScoreRange] = Field(default=None) + predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) + + +class ServiceTokenClassificationDataset(ServiceBaseDataset): + task: TaskType = Field(default=TaskType.token_classification, const=True) + pass diff --git a/src/rubrix/server/services/token_classification.py b/src/rubrix/server/services/tasks/token_classification/service.py similarity index 53% rename from src/rubrix/server/services/token_classification.py rename to 
src/rubrix/server/services/tasks/token_classification/service.py index 9c40e852a0..2bdd63afdf 100644 --- a/src/rubrix/server/services/token_classification.py +++ b/src/rubrix/server/services/tasks/token_classification/service.py @@ -13,28 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, List, Optional, Type +from typing import Iterable, List, Optional, Type from fastapi import Depends -from rubrix.server.apis.v0.models.commons.model import ( - BulkResponse, - EsRecordDataFieldNames, - SortableField, -) -from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics -from rubrix.server.apis.v0.models.token_classification import ( - CreationTokenClassificationRecord, - TokenClassificationAggregations, - TokenClassificationDatasetDB, - TokenClassificationQuery, - TokenClassificationRecord, - TokenClassificationRecordDB, - TokenClassificationSearchResults, -) -from rubrix.server.services.search.model import SortConfig +from rubrix.server.commons.config import TasksFactory +from rubrix.server.daos.backend.search.model import SortableField +from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics +from rubrix.server.services.search.model import ServiceSearchResults, ServiceSortConfig from rubrix.server.services.search.service import SearchRecordsService from rubrix.server.services.storage.service import RecordsStorageService +from rubrix.server.services.tasks.commons import BulkResponse +from rubrix.server.services.tasks.token_classification.model import ( + ServiceTokenClassificationDataset, + ServiceTokenClassificationQuery, + ServiceTokenClassificationRecord, +) class TokenClassificationService: @@ -65,74 +59,73 @@ def __init__( def add_records( self, - dataset: TokenClassificationDatasetDB, - mappings: Dict[str, Any], - records: List[CreationTokenClassificationRecord], - metrics: Type[BaseTaskMetrics], + dataset: 
ServiceTokenClassificationDataset, + records: List[ServiceTokenClassificationRecord], ): failed = self.__storage__.store_records( dataset=dataset, - mappings=mappings, records=records, - record_type=TokenClassificationRecordDB, - metrics=metrics, + record_type=ServiceTokenClassificationRecord, ) return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed) def search( self, - dataset: TokenClassificationDatasetDB, - query: TokenClassificationQuery, + dataset: ServiceTokenClassificationDataset, + query: ServiceTokenClassificationQuery, sort_by: List[SortableField], record_from: int = 0, size: int = 100, exclude_metrics: bool = True, - metrics: Optional[List[BaseMetric]] = None, - ) -> TokenClassificationSearchResults: - """ - Run a search in a dataset - - Parameters - ---------- - dataset: - The records dataset - query: - The search parameters - sort_by: - The sort by order list - record_from: - The record from return results - size: - The max number of records to return - - Returns - ------- - The matched records with aggregation info for specified task_meta.py + ) -> ServiceSearchResults: """ + Run a search in a dataset + + Parameters + ---------- + dataset: + The records dataset + query: + The search parameters + sort_by: + The sort by order list + record_from: + The record from return results + size: + The max number of records to return + + Returns + ------- + The matched records with aggregation info for specified task_meta.py + + """ + metrics = TasksFactory.find_task_metrics( + dataset.task, + metric_ids={ + "words_cloud", + "predicted_by", + "predicted_as", + "annotated_by", + "annotated_as", + "error_distribution", + "predicted_mentions_distribution", + "annotated_mentions_distribution", + "status_distribution", + "metadata", + "score", + }, + ) + results = self.__search__.search( dataset, query=query, - record_type=TokenClassificationRecord, + record_type=ServiceTokenClassificationRecord, size=size, + metrics=metrics, 
record_from=record_from, exclude_metrics=exclude_metrics, - metrics=metrics, - sort_config=SortConfig( - sort_by=sort_by, - valid_fields=[ - "metadata", - EsRecordDataFieldNames.last_updated, - EsRecordDataFieldNames.score, - EsRecordDataFieldNames.predicted, - EsRecordDataFieldNames.predicted_as, - EsRecordDataFieldNames.predicted_by, - EsRecordDataFieldNames.annotated_as, - EsRecordDataFieldNames.annotated_by, - EsRecordDataFieldNames.status, - EsRecordDataFieldNames.event_timestamp, - ], - ), + sort_config=ServiceSortConfig(sort_by=sort_by), ) if results.metrics: @@ -147,21 +140,15 @@ def search( "predicted_mentions_distribution" ] - return TokenClassificationSearchResults( - total=results.total, - records=results.records, - aggregations=TokenClassificationAggregations.parse_obj(results.metrics) - if results.metrics - else None, - ) + return results def read_dataset( self, - dataset: TokenClassificationDatasetDB, - query: TokenClassificationQuery, + dataset: ServiceTokenClassificationDataset, + query: ServiceTokenClassificationQuery, id_from: Optional[str] = None, limit: int = 1000 - ) -> Iterable[TokenClassificationRecord]: + ) -> Iterable[ServiceTokenClassificationRecord]: """ Scan a dataset records @@ -181,8 +168,5 @@ def read_dataset( """ yield from self.__search__.scan_records( - dataset, query=query, record_type=TokenClassificationRecord, id_from=id_from, limit=limit + dataset, query=query, record_type=ServiceTokenClassificationRecord,id_from=id_from, limit=limit ) - - -token_classification_service = TokenClassificationService.get_instance diff --git a/src/rubrix/server/services/text_classification_labelling_rules.py b/src/rubrix/server/services/text_classification_labelling_rules.py deleted file mode 100644 index 1418def5bb..0000000000 --- a/src/rubrix/server/services/text_classification_labelling_rules.py +++ /dev/null @@ -1,271 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -from fastapi import Depends -from pydantic import BaseModel, 
Field - -from rubrix.server._helpers import unflatten_dict -from rubrix.server.apis.v0.models.commons.model import EsRecordDataFieldNames -from rubrix.server.apis.v0.models.metrics.base import ElasticsearchMetric -from rubrix.server.apis.v0.models.text_classification import ( - LabelingRule, - TextClassificationDatasetDB, -) -from rubrix.server.daos.datasets import DatasetsDAO -from rubrix.server.daos.models.records import RecordSearch -from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.query_helpers import filters -from rubrix.server.errors import EntityAlreadyExistsError, EntityNotFoundError - - -class DatasetLabelingRulesMetric(ElasticsearchMetric): - id: str = Field("dataset_labeling_rules", const=True) - name: str = Field( - "Computes overall metrics for defined rules in dataset", const=True - ) - - def aggregation_request(self, all_rules: List[LabelingRule]) -> Dict[str, Any]: - rules_filters = [filters.text_query(rule.query) for rule in all_rules] - return { - self.id: { - "filters": { - "filters": { - "covered_records": filters.boolean_filter( - should_filters=rules_filters, minimum_should_match=1 - ), - "annotated_covered_records": filters.boolean_filter( - filter_query=filters.exists_field( - EsRecordDataFieldNames.annotated_as - ), - should_filters=rules_filters, - minimum_should_match=1, - ), - } - } - } - } - - -class LabelingRulesMetric(ElasticsearchMetric): - id: str = Field("labeling_rule", const=True) - name: str = Field("Computes metrics for a labeling rule", const=True) - - def aggregation_request( - self, - rule_query: str, - labels: Optional[List[str]], - ) -> Dict[str, Any]: - - annotated_records_filter = filters.exists_field( - EsRecordDataFieldNames.annotated_as - ) - rule_query_filter = filters.text_query(rule_query) - aggr_filters = { - "covered_records": rule_query_filter, - "annotated_covered_records": filters.boolean_filter( - filter_query=annotated_records_filter, - 
should_filters=[rule_query_filter], - ), - } - - if labels is not None: - for label in labels: - rule_label_annotated_filter = filters.term_filter( - "annotated_as", value=label - ) - encoded_label = self._encode_label_name(label) - aggr_filters.update( - { - f"{encoded_label}.correct_records": filters.boolean_filter( - filter_query=annotated_records_filter, - should_filters=[ - rule_query_filter, - rule_label_annotated_filter, - ], - minimum_should_match=2, - ), - f"{encoded_label}.incorrect_records": filters.boolean_filter( - filter_query=annotated_records_filter, - must_query=rule_query_filter, - must_not_query=rule_label_annotated_filter, - ), - } - ) - - return {self.id: {"filters": {"filters": aggr_filters}}} - - @staticmethod - def _encode_label_name(label: str) -> str: - return label.replace(".", "@@@") - - @staticmethod - def _decode_label_name(label: str) -> str: - return label.replace("@@@", ".") - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - if self.id in aggregation_result: - aggregation_result = aggregation_result[self.id] - - aggregation_result = unflatten_dict(aggregation_result) - results = { - "covered_records": aggregation_result.pop("covered_records"), - "annotated_covered_records": aggregation_result.pop( - "annotated_covered_records" - ), - } - - all_correct = [] - all_incorrect = [] - all_precision = [] - for label, metrics in aggregation_result.items(): - correct = metrics.get("correct_records", 0) - incorrect = metrics.get("incorrect_records", 0) - annotated = correct + incorrect - metrics["annotated"] = annotated - if annotated > 0: - precision = correct / annotated - metrics["precision"] = precision - all_precision.append(precision) - - all_correct.append(correct) - all_incorrect.append(incorrect) - results[self._decode_label_name(label)] = metrics - - results["correct_records"] = sum(all_correct) - results["incorrect_records"] = sum(all_incorrect) - if len(all_precision) > 0: - 
results["precision"] = sum(all_precision) / len(all_precision) - - return results - - -class DatasetLabelingRulesSummary(BaseModel): - covered_records: int - annotated_covered_records: int - - -class LabelingRuleSummary(BaseModel): - covered_records: int - annotated_covered_records: int - correct_records: int = Field(default=0) - incorrect_records: int = Field(default=0) - precision: Optional[float] = None - - -class LabelingService: - - _INSTANCE = None - - __rule_metrics__ = LabelingRulesMetric() - __dataset_rules_metrics__ = DatasetLabelingRulesMetric() - - @classmethod - def get_instance( - cls, - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), - ): - if cls._INSTANCE is None: - cls._INSTANCE = cls(datasets, records) - return cls._INSTANCE - - def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO): - self.__datasets__ = datasets - self.__records__ = records - - def list_rules(self, dataset: TextClassificationDatasetDB) -> List[LabelingRule]: - """List a set of rules for a given dataset""" - return dataset.rules - - def delete_rule(self, dataset: TextClassificationDatasetDB, rule_query: str): - """Delete a rule from a dataset by its defined query string""" - new_rules_set = [r for r in dataset.rules if r.query != rule_query] - if len(dataset.rules) != new_rules_set: - dataset.rules = new_rules_set - self.__datasets__.update_dataset(dataset) - - def add_rule( - self, dataset: TextClassificationDatasetDB, rule: LabelingRule - ) -> LabelingRule: - """Adds a rule to a dataset""" - for r in dataset.rules: - if r.query == rule.query: - raise EntityAlreadyExistsError(rule.query, type=LabelingRule) - dataset.rules.append(rule) - self.__datasets__.update_dataset(dataset) - return rule - - def compute_rule_metrics( - self, - dataset: TextClassificationDatasetDB, - rule_query: str, - labels: Optional[List[str]] = None, - ) -> Tuple[int, int, LabelingRuleSummary]: - """Computes 
metrics for given rule query and optional label against a set of rules""" - - annotated_records = self._count_annotated_records(dataset) - results = self.__records__.search_records( - dataset, - size=0, - search=RecordSearch( - aggregations=self.__rule_metrics__.aggregation_request( - rule_query=rule_query, labels=labels - ), - ), - ) - - rule_metrics_summary = self.__rule_metrics__.aggregation_result( - results.aggregations - ) - - metrics = LabelingRuleSummary.parse_obj(rule_metrics_summary) - return results.total, annotated_records, metrics - - def _count_annotated_records(self, dataset: TextClassificationDatasetDB) -> int: - results = self.__records__.search_records( - dataset, - size=0, - search=RecordSearch( - query=filters.exists_field(EsRecordDataFieldNames.annotated_as), - ), - ) - return results.total - - def all_rules_metrics( - self, dataset: TextClassificationDatasetDB - ) -> Tuple[int, int, DatasetLabelingRulesSummary]: - annotated_records = self._count_annotated_records(dataset) - results = self.__records__.search_records( - dataset, - size=0, - search=RecordSearch( - aggregations=self.__dataset_rules_metrics__.aggregation_request( - all_rules=dataset.rules - ), - ), - ) - - rule_metrics_summary = self.__dataset_rules_metrics__.aggregation_result( - results.aggregations - ) - - return ( - results.total, - annotated_records, - DatasetLabelingRulesSummary.parse_obj(rule_metrics_summary), - ) - - def find_rule_by_query( - self, dataset: TextClassificationDatasetDB, rule_query: str - ) -> LabelingRule: - rule_query = rule_query.strip() - for rule in dataset.rules: - if rule.query == rule_query: - return rule - raise EntityNotFoundError(rule_query, type=LabelingRule) - - def replace_rule(self, dataset: TextClassificationDatasetDB, rule: LabelingRule): - for idx, r in enumerate(dataset.rules): - if r.query == rule.query: - dataset.rules[idx] = rule - break - self.__datasets__.update_dataset(dataset) diff --git a/tests/client/sdk/conftest.py 
b/tests/client/sdk/conftest.py index a62e324337..8874094dcf 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from datetime import datetime +from typing import Any, Dict, List import pytest @@ -32,6 +34,8 @@ TokenClassificationBulkData, ) +LOGGER = logging.getLogger(__name__) + class Helpers: def remove_key(self, schema: dict, key: str): @@ -50,6 +54,55 @@ def remove_description(self, schema: dict): def remove_pattern(self, schema: dict): return self.remove_key(schema, key="pattern") + def are_compatible_api_schemas(self, client_schema: dict, server_schema: dict): + def check_schema_props(client_props, server_props): + different_props = [] + for name, definition in client_props.items(): + if name not in server_props: + LOGGER.warning( + f"Client property {name} not found in server properties. 
" + "Make sure your API compatibility" + ) + different_props.append(name) + continue + elif definition != server_props[name]: + if not check_schema_props(definition, server_props[name]): + return False + return len(different_props) < len(client_props) / 2 + + client_props = self._expands_schema( + client_schema["properties"], client_schema["definitions"] + ) + server_props = self._expands_schema( + server_schema["properties"], server_schema["definitions"] + ) + + if client_props == server_props: + return True + return check_schema_props(client_props, server_props) + + def _expands_schema( + self, props: Dict[str, Any], definitions: List[Dict[str, Any]] + ) -> Dict[str, Any]: + new_schema = {} + for name, definition in props.items(): + if "$ref" in definition: + ref = definition["$ref"] + ref_def = definitions[ref.replace("#/definitions/", "")] + field_props = ref_def.get("properties", ref_def) + expanded_props = self._expands_schema(field_props, definitions) + new_schema[name] = expanded_props.get("properties", expanded_props) + elif "items" in definition and "$ref" in definition["items"]: + ref = definition["items"]["$ref"] + ref_def = definitions[ref.replace("#/definitions/", "")] + field_props = ref_def.get("properties", ref_def) + expanded_props = self._expands_schema(field_props, definitions) + definition["items"] = expanded_props.get("properties", expanded_props) + new_schema[name] = definition + else: + new_schema[name] = definition + return new_schema + @pytest.fixture(scope="session") def helpers(): diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py index 236c025c98..b9a9fd4ac1 100644 --- a/tests/client/sdk/text2text/test_models.py +++ b/tests/client/sdk/text2text/test_models.py @@ -27,7 +27,7 @@ ) from rubrix.client.sdk.text2text.models import Text2TextRecord as SdkText2TextRecord from rubrix.server.apis.v0.models.text2text import ( - Text2TextBulkData as ServerText2TextBulkData, + Text2TextBulkRequest as 
ServerText2TextBulkData, ) from rubrix.server.apis.v0.models.text2text import ( Text2TextQuery as ServerText2TextQuery, @@ -37,19 +37,14 @@ def test_bulk_data_schema(helpers): client_schema = Text2TextBulkData.schema() server_schema = ServerText2TextBulkData.schema() - - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_query_schema(helpers): client_schema = Text2TextQuery.schema() server_schema = ServerText2TextQuery.schema() - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) @pytest.mark.parametrize( diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py index 16429a811e..d0d72260c1 100644 --- a/tests/client/sdk/text_classification/test_models.py +++ b/tests/client/sdk/text_classification/test_models.py @@ -37,7 +37,7 @@ LabelingRuleMetricsSummary as ServerLabelingRuleMetricsSummary, ) from rubrix.server.apis.v0.models.text_classification import ( - TextClassificationBulkData as ServerTextClassificationBulkData, + TextClassificationBulkRequest as ServerTextClassificationBulkData, ) from rubrix.server.apis.v0.models.text_classification import ( TextClassificationQuery as ServerTextClassificationQuery, @@ -48,18 +48,14 @@ def test_bulk_data_schema(helpers): client_schema = TextClassificationBulkData.schema() server_schema = ServerTextClassificationBulkData.schema() - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_query_schema(helpers): client_schema = TextClassificationQuery.schema() server_schema = ServerTextClassificationQuery.schema() - assert helpers.remove_description(client_schema) == 
helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_labeling_rule_schema(helpers): diff --git a/tests/client/sdk/token_classification/test_models.py b/tests/client/sdk/token_classification/test_models.py index 9dfe5e848c..6cd5e05e88 100644 --- a/tests/client/sdk/token_classification/test_models.py +++ b/tests/client/sdk/token_classification/test_models.py @@ -29,7 +29,7 @@ TokenClassificationRecord as SdkTokenClassificationRecord, ) from rubrix.server.apis.v0.models.token_classification import ( - TokenClassificationBulkData as ServerTokenClassificationBulkData, + TokenClassificationBulkRequest as ServerTokenClassificationBulkData, ) from rubrix.server.apis.v0.models.token_classification import ( TokenClassificationQuery as ServerTokenClassificationQuery, @@ -40,18 +40,14 @@ def test_bulk_data_schema(helpers): client_schema = TokenClassificationBulkData.schema() server_schema = ServerTokenClassificationBulkData.schema() - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_query_schema(helpers): client_schema = TokenClassificationQuery.schema() server_schema = ServerTokenClassificationQuery.schema() - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) @pytest.mark.parametrize( diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 7c3d3b8853..5d9872bf8a 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -1,69 +1,62 @@ import pytest import rubrix -from rubrix.server.apis.v0.models.commons.model import ScoreRange, TaskType +from rubrix.server.apis.v0.models.commons.model import ScoreRange from 
rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.apis.v0.models.metrics.base import BaseMetric from rubrix.server.apis.v0.models.text_classification import ( TextClassificationQuery, TextClassificationRecord, ) from rubrix.server.apis.v0.models.token_classification import TokenClassificationQuery +from rubrix.server.commons.models import TaskType +from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper -from rubrix.server.services.metrics import MetricsService -from rubrix.server.services.search.model import SortConfig -from rubrix.server.services.search.query_builder import EsQueryBuilder +from rubrix.server.services.metrics import MetricsService, ServicePythonMetric +from rubrix.server.services.search.model import ServiceSortConfig from rubrix.server.services.search.service import SearchRecordsService @pytest.fixture -def es_wrapper(): - return ElasticsearchWrapper.get_instance() +def backend(): + return ElasticsearchBackend.get_instance() @pytest.fixture -def dao(es_wrapper: ElasticsearchWrapper): - return DatasetRecordsDAO.get_instance(es=es_wrapper) +def dao(backend: ElasticsearchBackend): + return DatasetRecordsDAO.get_instance(es=backend) @pytest.fixture -def query_builder(dao: DatasetRecordsDAO): - return EsQueryBuilder.get_instance(dao=dao) +def metrics(dao: DatasetRecordsDAO): + return MetricsService.get_instance(dao=dao) @pytest.fixture -def metrics(dao: DatasetRecordsDAO, query_builder: EsQueryBuilder): - return MetricsService.get_instance(dao=dao, query_builder=query_builder) +def service(dao: DatasetRecordsDAO, metrics: MetricsService): + return SearchRecordsService.get_instance(dao=dao, metrics=metrics) -@pytest.fixture -def service( - dao: DatasetRecordsDAO, metrics: MetricsService, query_builder: EsQueryBuilder -): - return SearchRecordsService.get_instance( - dao=dao, 
metrics=metrics, query_builder=query_builder - ) - - -def test_query_builder_with_query_range(query_builder): - es_query = query_builder( - "ds", query=TextClassificationQuery(score=ScoreRange(range_from=10)) +def test_query_builder_with_query_range(backend: ElasticsearchBackend): + es_query = backend.query_builder.map_2_es_query( + schema=None, + query=TextClassificationQuery(score=ScoreRange(range_from=10)), ) assert es_query == { - "bool": { - "filter": { - "bool": { - "minimum_should_match": 1, - "should": [{"range": {"score": {"gte": 10.0}}}], - } - }, - "must": {"match_all": {}}, + "query": { + "bool": { + "filter": { + "bool": { + "minimum_should_match": 1, + "should": [{"range": {"score": {"gte": 10.0}}}], + } + }, + "must": {"match_all": {}}, + } } } -def test_query_builder_with_nested(query_builder, mocked_client): +def test_query_builder_with_nested(mocked_client, dao, backend: ElasticsearchBackend): dataset = Dataset( name="test_query_builder_with_nested", owner=rubrix.get_workspace(), @@ -79,8 +72,8 @@ def test_query_builder_with_nested(query_builder, mocked_client): ), ) - es_query = query_builder( - dataset=dataset, + es_query = backend.query_builder.map_2_es_query( + schema=dao.get_dataset_schema(dataset), query=TokenClassificationQuery( advanced_query_dsl=True, query_text="metrics.predicted.mentions:(label:NAME AND score:[* TO 0.1])", @@ -88,33 +81,35 @@ def test_query_builder_with_nested(query_builder, mocked_client): ) assert es_query == { - "bool": { - "filter": {"bool": {"must": {"match_all": {}}}}, - "must": { - "nested": { - "path": "metrics.predicted.mentions", - "query": { - "bool": { - "must": [ - { - "term": { - "metrics.predicted.mentions.label": { - "value": "NAME" + "query": { + "bool": { + "filter": {"bool": {"must": {"match_all": {}}}}, + "must": { + "nested": { + "path": "metrics.predicted.mentions", + "query": { + "bool": { + "must": [ + { + "term": { + "metrics.predicted.mentions.label": { + "value": "NAME" + } } - } - }, - { - 
"range": { - "metrics.predicted.mentions.score": { - "lte": "0.1" + }, + { + "range": { + "metrics.predicted.mentions.score": { + "lte": "0.1" + } } - } - }, - ] - } - }, - } - }, + }, + ] + } + }, + } + }, + } } } @@ -135,8 +130,8 @@ def test_failing_metrics(service, mocked_client): results = service.search( dataset=dataset, query=TextClassificationQuery(), - sort_config=SortConfig(), - metrics=[BaseMetric(id="missing-metric", name="Missing metric")], + sort_config=ServiceSortConfig(), + metrics=[ServicePythonMetric(id="missing-metric", name="Missing metric")], size=0, record_type=TextClassificationRecord, ) diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index da933b93bd..9c8c86c4cf 100644 --- a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -237,7 +237,7 @@ def test_rule_metrics_without_annotated( ) metrics = rule.metrics(log_dataset_without_annotations) - assert metrics == expected_metrics + assert expected_metrics == metrics def delete_rule_silently(client, dataset: str, rule: Rule): diff --git a/tests/metrics/test_text_classification.py b/tests/metrics/test_text_classification.py index 9ad3b91cee..f523707410 100644 --- a/tests/metrics/test_text_classification.py +++ b/tests/metrics/test_text_classification.py @@ -38,8 +38,7 @@ def test_metrics_for_text_classification(mocked_client): ) results = f1(dataset) - assert results - assert results.data == { + assert results and results.data == { "f1_macro": 1.0, "f1_micro": 1.0, "ham_f1": 1.0, @@ -58,8 +57,7 @@ def test_metrics_for_text_classification(mocked_client): results.visualize() results = f1_multilabel(dataset) - assert results - assert results.data == { + assert results and results.data == { "f1_macro": 1.0, "f1_micro": 1.0, "ham_f1": 1.0, diff --git a/tests/server/backend/__init__.py b/tests/server/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py new file mode 100644 index 0000000000..cfbe0f1a0c --- /dev/null +++ b/tests/server/backend/test_query_builder.py @@ -0,0 +1,66 @@ +import pytest + +from rubrix.server.daos.backend.search.model import SortableField, SortConfig, SortOrder +from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder + + +@pytest.mark.parametrize( + ["index_schema", "sort_cfg", "expected_sort"], + [ + ( + { + "mappings": { + "properties": { + "id": {"type": "text"}, + } + } + }, + [SortableField(id="id")], + [{"id.keyword": {"order": SortOrder.asc}}], + ), + ( + { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + } + } + }, + [SortableField(id="id")], + [{"id": {"order": SortOrder.asc}}], + ), + ( + { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + } + } + }, + [SortableField(id="metadata.black", order=SortOrder.desc)], + [{"metadata.black": {"order": SortOrder.desc}}], + ), + ], +) +def test_build_sort_configuration(index_schema, sort_cfg, expected_sort): + + builder = EsQueryBuilder() + + es_sort = builder.map_2_es_sort_configuration( + sort=SortConfig(sort_by=sort_cfg), schema=index_schema + ) + assert es_sort == expected_sort + + +def test_build_sort_with_wrong_field_name(): + builder = EsQueryBuilder() + + with pytest.raises(Exception): + builder.map_2_es_sort_configuration( + sort=SortConfig(sort_by=[SortableField(id="wat?!")]) + ) + + +def test_build_sort_without_sort_config(): + builder = EsQueryBuilder() + assert builder.map_2_es_sort_configuration() is None diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py index db8dc4b104..92de61cb8f 100644 --- a/tests/server/commons/test_records_dao.py +++ b/tests/server/commons/test_records_dao.py @@ -1,15 +1,17 @@ import pytest -from rubrix.server.apis.v0.models.commons.model import TaskType -from rubrix.server.apis.v0.models.datasets import DatasetDB +from 
rubrix.server.commons.models import TaskType +from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend +from rubrix.server.daos.models.datasets import BaseDatasetDB from rubrix.server.daos.records import DatasetRecordsDAO -from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper from rubrix.server.errors import MissingDatasetRecordsError def test_raise_proper_error(): - dao = DatasetRecordsDAO.get_instance(ElasticsearchWrapper.get_instance()) + dao = DatasetRecordsDAO.get_instance(ElasticsearchBackend.get_instance()) with pytest.raises(MissingDatasetRecordsError): dao.search_records( - dataset=DatasetDB(name="mock-notfound", task=TaskType.text_classification) + dataset=BaseDatasetDB( + name="mock-notfound", task=TaskType.text_classification + ) ) diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py index 37a04cdbd8..fd99debe8b 100644 --- a/tests/server/datasets/test_api.py +++ b/tests/server/datasets/test_api.py @@ -14,9 +14,11 @@ # limitations under the License. 
from typing import Optional -from rubrix.server.apis.v0.models.commons.model import TaskType from rubrix.server.apis.v0.models.datasets import Dataset -from rubrix.server.apis.v0.models.text_classification import TextClassificationBulkData +from rubrix.server.apis.v0.models.text_classification import ( + TextClassificationBulkRequest, +) +from rubrix.server.commons.models import TaskType from tests.helpers import SecuredClient @@ -31,7 +33,7 @@ def test_delete_dataset(mocked_client): assert response.json() == { "detail": { "code": "rubrix.api.errors::EntityNotFoundError", - "params": {"name": "test_delete_dataset", "type": "Dataset"}, + "params": {"name": "test_delete_dataset", "type": "ServiceDataset"}, } } @@ -108,7 +110,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient): def test_dataset_naming_validation(mocked_client): - request = TextClassificationBulkData(records=[]) + request = TextClassificationBulkRequest(records=[]) dataset = "Wrong dataset name" response = mocked_client.post( @@ -128,7 +130,7 @@ def test_dataset_naming_validation(mocked_client): "type": "value_error.str.regex", } ], - "model": "TextClassificationDatasetDB", + "model": "TextClassificationDataset", }, } } @@ -150,7 +152,7 @@ def test_dataset_naming_validation(mocked_client): "type": "value_error.str.regex", } ], - "model": "TokenClassificationDatasetDB", + "model": "TokenClassificationDataset", }, } } @@ -221,7 +223,7 @@ def delete_dataset(client, dataset, workspace: Optional[str] = None): def create_mock_dataset(client, dataset): client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=[], diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py index 19c200e13f..5364eae1bd 100644 --- a/tests/server/datasets/test_dao.py +++ b/tests/server/datasets/test_dao.py 
@@ -15,27 +15,22 @@ import pytest -from rubrix.server.apis.v0.config.tasks_factory import TaskFactory -from rubrix.server.apis.v0.models.commons.model import TaskType -from rubrix.server.apis.v0.models.datasets import DatasetDB +from rubrix.server.commons.models import TaskType +from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend from rubrix.server.daos.datasets import DatasetsDAO -from rubrix.server.daos.records import dataset_records_dao -from rubrix.server.elasticseach.client_wrapper import create_es_wrapper -from rubrix.server.elasticseach.mappings.text_classification import ( - text_classification_mappings, -) +from rubrix.server.daos.models.datasets import BaseDatasetDB +from rubrix.server.daos.records import DatasetRecordsDAO from rubrix.server.errors import ClosedDatasetError -es_wrapper = create_es_wrapper() -records = dataset_records_dao(es_wrapper) +es_wrapper = ElasticsearchBackend.get_instance() +records = DatasetRecordsDAO.get_instance(es_wrapper) dao = DatasetsDAO.get_instance(es_wrapper, records) def test_retrieve_ownered_dataset_for_no_owner_user(): dataset = "test_retrieve_owned_dataset_for_no_owner_user" created = dao.create_dataset( - DatasetDB(name=dataset, owner="other", task=TaskType.text_classification), - mappings=TaskFactory.get_task_mappings(TaskType.text_classification), + BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification), ) assert dao.find_by_name(created.name, owner=created.owner) == created assert dao.find_by_name(created.name, owner=None) == created @@ -46,8 +41,7 @@ def test_close_dataset(): dataset = "test_close_dataset" created = dao.create_dataset( - DatasetDB(name=dataset, owner="other", task=TaskType.text_classification), - mappings=TaskFactory.get_task_mappings(TaskType.text_classification), + BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification), ) dao.close(created) diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py 
index 2e2234a60c..ffa558001e 100644 --- a/tests/server/datasets/test_model.py +++ b/tests/server/datasets/test_model.py @@ -16,7 +16,8 @@ import pytest from pydantic import ValidationError -from rubrix.server.apis.v0.models.datasets import CreationDatasetRequest +from rubrix.server.apis.v0.models.datasets import CreateDatasetRequest +from rubrix.server.commons.models import TaskType @pytest.mark.parametrize( @@ -24,7 +25,7 @@ ["fine", "fine33", "fine_33", "fine-3-3"], ) def test_dataset_naming_ok(name): - request = CreationDatasetRequest(name=name) + request = CreateDatasetRequest(name=name, task=TaskType.token_classification) assert request.name == name @@ -42,4 +43,4 @@ def test_dataset_naming_ok(name): ) def test_dataset_naming_ko(name): with pytest.raises(ValidationError, match="string does not match regex"): - CreationDatasetRequest(name=name) + CreateDatasetRequest(name=name, task=TaskType.token_classification) diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py index da099c457f..f21c40baed 100644 --- a/tests/server/info/test_api.py +++ b/tests/server/info/test_api.py @@ -14,7 +14,7 @@ # limitations under the License. from rubrix import __version__ as rubrix_version -from rubrix.server.apis.v0.models.info import ApiInfo, ApiStatus +from rubrix.server.services.info import ApiInfo, ApiStatus def test_api_info(mocked_client): diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 14637defb5..c944c16757 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics -from rubrix.server.apis.v0.models.metrics.token_classification import ( - TokenClassificationMetrics, -) -from rubrix.server.apis.v0.models.text2text import Text2TextBulkData, Text2TextRecord +from rubrix.server.apis.v0.models.text2text import Text2TextBulkRequest, Text2TextRecord from rubrix.server.apis.v0.models.text_classification import ( - TextClassificationBulkData, + TextClassificationBulkRequest, TextClassificationRecord, ) from rubrix.server.apis.v0.models.token_classification import ( - TokenClassificationBulkData, + TokenClassificationBulkRequest, TokenClassificationRecord, ) +from rubrix.server.services.metrics.models import CommonTasksMetrics +from rubrix.server.services.tasks.token_classification.metrics import ( + TokenClassificationMetrics, +) COMMON_METRICS_LENGTH = len(CommonTasksMetrics.metrics) @@ -41,7 +41,7 @@ def test_wrong_dataset_metrics(mocked_client): {"text": text}, ] ] - request = Text2TextBulkData(records=records) + request = Text2TextBulkRequest(records=records) dataset = "test_wrong_dataset_metrics" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 @@ -91,7 +91,7 @@ def test_dataset_for_text2text(mocked_client): {"text": text}, ] ] - request = Text2TextBulkData(records=records) + request = Text2TextBulkRequest(records=records) dataset = "test_dataset_for_text2text" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 @@ -121,7 +121,7 @@ def test_dataset_for_token_classification(mocked_client): ] ] - request = TokenClassificationBulkData(records=records) + request = TokenClassificationBulkRequest(records=records) dataset = "test_dataset_for_token_classification" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 @@ -146,7 +146,7 @@ def test_dataset_for_token_classification(mocked_client): json={}, ) - assert response.status_code == 200, response.json() + assert response.status_code == 
200, f"{metric} :: {response.json()}" summary = response.json() if not ("predicted" in metric_id or "annotated" in metric_id): @@ -171,7 +171,7 @@ def test_dataset_metrics(mocked_client): }, ] ] - request = TextClassificationBulkData(records=records) + request = TextClassificationBulkRequest(records=records) dataset = "test_get_dataset_metrics" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 @@ -199,7 +199,7 @@ def test_dataset_metrics(mocked_client): assert response.json() == { "detail": { "code": "rubrix.api.errors::EntityNotFoundError", - "params": {"name": "missing_metric", "type": "BaseMetric"}, + "params": {"name": "missing_metric", "type": "ServiceBaseMetric"}, } } @@ -242,7 +242,7 @@ def test_dataset_labels_for_text_classification(mocked_client): }, ] ] - request = TextClassificationBulkData(records=records) + request = TextClassificationBulkRequest(records=records) dataset = "test_dataset_labels_for_text_classification" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 diff --git a/tests/server/test_api.py b/tests/server/test_api.py index 5cbd00c0b7..972e3f7134 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -15,12 +15,11 @@ import os -from rubrix.server.apis.v0.models.commons.model import TaskStatus from rubrix.server.apis.v0.models.text_classification import ( - TaskType, - TextClassificationBulkData, + TextClassificationBulkRequest, TextClassificationRecord, ) +from rubrix.server.commons.models import TaskStatus, TaskType def create_some_data_for_text_classification(client, name: str, n: int): @@ -65,7 +64,7 @@ def create_some_data_for_text_classification(client, name: str, n: int): ] client.post( f"/api/datasets/{name}/{TaskType.text_classification}:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, diff --git 
a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py index a02b44966b..40667143fd 100644 --- a/tests/server/text2text/test_api.py +++ b/tests/server/text2text/test_api.py @@ -1,7 +1,7 @@ from rubrix.client.sdk.text2text.models import Text2TextBulkData from rubrix.server.apis.v0.models.commons.model import BulkResponse from rubrix.server.apis.v0.models.text2text import ( - CreationText2TextRecord, + Text2TextRecordInputs, Text2TextSearchResults, ) @@ -11,7 +11,7 @@ def test_search_records(mocked_client): assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 records = [ - CreationText2TextRecord.parse_obj(data) + Text2TextRecordInputs.parse_obj(data) for data in [ { "id": 0, diff --git a/tests/server/text2text/test_model.py b/tests/server/text2text/test_model.py index 4e9ee2f106..d8e0f56498 100644 --- a/tests/server/text2text/test_model.py +++ b/tests/server/text2text/test_model.py @@ -4,7 +4,7 @@ Text2TextQuery, Text2TextRecord, ) -from rubrix.server.services.search.query_builder import EsQueryBuilder +from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder def test_sentences_sorted_by_score(): @@ -57,4 +57,4 @@ def test_model_dict(): def test_query_as_elasticsearch(): query = Text2TextQuery(ids=[1, 2, 3]) - assert EsQueryBuilder.to_es_query(query) == {"ids": {"values": query.ids}} + assert EsQueryBuilder._to_es_query(query) == {"ids": {"values": query.ids}} diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py index dc3381bf90..1e8a12248b 100644 --- a/tests/server/text_classification/test_api.py +++ b/tests/server/text_classification/test_api.py @@ -15,16 +15,17 @@ from datetime import datetime -from rubrix.server.apis.v0.models.commons.model import BulkResponse, PredictionStatus +from rubrix.server.apis.v0.models.commons.model import BulkResponse from rubrix.server.apis.v0.models.datasets import Dataset from 
rubrix.server.apis.v0.models.text_classification import ( TextClassificationAnnotation, - TextClassificationBulkData, + TextClassificationBulkRequest, TextClassificationQuery, TextClassificationRecord, TextClassificationSearchRequest, TextClassificationSearchResults, ) +from rubrix.server.commons.models import PredictionStatus def test_create_records_for_text_classification_with_multi_label(mocked_client): @@ -70,7 +71,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client): ] response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -85,7 +86,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( tags={"new": "tag"}, metadata={"new": {"metadata": "value"}}, records=records, @@ -123,7 +124,7 @@ def test_create_records_for_text_classification(mocked_client): assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 tags = {"env": "test", "class": "text classification"} metadata = {"config": {"the": "config"}} - classification_bulk = TextClassificationBulkData( + classification_bulk = TextClassificationBulkRequest( tags=tags, metadata=metadata, records=[ @@ -197,7 +198,7 @@ def test_partial_record_update(mocked_client): } ) - bulk = TextClassificationBulkData( + bulk = TextClassificationBulkRequest( records=[record], ) @@ -268,7 +269,7 @@ def test_sort_by_last_updated(mocked_client): for i in range(0, 10): mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -294,7 +295,7 @@ def 
test_sort_by_id_as_default(mocked_client): assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -335,7 +336,7 @@ def test_some_sort_by(mocked_client): assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -366,10 +367,10 @@ def test_some_sort_by(mocked_client): "code": "rubrix.api.errors::BadRequestError", "params": { "message": "Wrong sort id wrong_field. Valid values " - "are: ['metadata', 'last_updated', 'score', " + "are: ['id', 'metadata', 'score', " "'predicted', 'predicted_as', " "'predicted_by', 'annotated_as', " - "'annotated_by', 'status', " + "'annotated_by', 'status', 'last_updated', " "'event_timestamp']" }, } @@ -407,7 +408,7 @@ def test_disable_aggregations_when_scroll(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - json=TextClassificationBulkData( + json=TextClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=[ @@ -447,7 +448,7 @@ def test_include_event_timestamp(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=[ @@ -488,7 +489,7 @@ def test_words_cloud(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -528,7 +529,7 
@@ def test_metadata_with_point_in_field_name(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -565,7 +566,7 @@ def test_wrong_text_query(mocked_client): mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ @@ -599,7 +600,7 @@ def test_search_using_text(mocked_client): mocked_client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( records=[ TextClassificationRecord( **{ diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py index 5a24a1770e..ce16623f15 100644 --- a/tests/server/text_classification/test_api_rules.py +++ b/tests/server/text_classification/test_api_rules.py @@ -4,7 +4,7 @@ CreateLabelingRule, LabelingRule, LabelingRuleMetricsSummary, - TextClassificationBulkData, + TextClassificationBulkRequest, TextClassificationRecord, ) @@ -31,7 +31,7 @@ def log_some_records( response = client.post( f"/api/datasets/{dataset}/TextClassification:bulk", - data=TextClassificationBulkData( + data=TextClassificationBulkRequest( records=[ TextClassificationRecord(**record), ], @@ -202,7 +202,7 @@ def test_duplicated_dataset_rules(mocked_client): assert response.json() == { "detail": { "code": "rubrix.api.errors::EntityAlreadyExistsError", - "params": {"name": "a query", "type": "LabelingRule"}, + "params": {"name": "a query", "type": "ServiceLabelingRule"}, } } diff --git a/tests/server/text_classification/test_api_settings.py b/tests/server/text_classification/test_api_settings.py index 585183f0d1..4dd170896d 100644 --- a/tests/server/text_classification/test_api_settings.py +++ b/tests/server/text_classification/test_api_settings.py @@ 
-1,5 +1,5 @@ import rubrix as rb -from rubrix.server.apis.v0.models.commons.model import TaskType +from rubrix.server.commons.models import TaskType def create_dataset(client, name: str): diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index c41bf0fbf4..a4f4edec73 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -16,15 +16,17 @@ from pydantic import ValidationError from rubrix._constants import MAX_KEYWORD_LENGTH -from rubrix.server.apis.v0.models.commons.model import TaskStatus from rubrix.server.apis.v0.models.text_classification import ( - ClassPrediction, - PredictionStatus, TextClassificationAnnotation, TextClassificationQuery, TextClassificationRecord, ) -from rubrix.server.services.search.query_builder import EsQueryBuilder +from rubrix.server.commons.models import PredictionStatus, TaskStatus +from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder +from rubrix.server.services.tasks.text_classification.model import ( + ClassPrediction, + ServiceTextClassificationRecord, +) def test_flatten_metadata(): @@ -34,7 +36,7 @@ def test_flatten_metadata(): "mail": {"subject": "The mail subject", "body": "This is a large text body"} }, } - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert list(record.metadata.keys()) == ["mail.subject", "mail.body"] @@ -48,7 +50,7 @@ def test_metadata_with_object_list(): ] }, } - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert list(record.metadata.keys()) == ["mails"] @@ -87,7 +89,7 @@ def test_single_label_with_multiple_annotation(): ValidationError, match="Single label record must include only one annotation label", ): - TextClassificationRecord.parse_obj( + ServiceTextClassificationRecord.parse_obj( { "inputs": {"text": "This is a text"}, "annotation": { 
@@ -100,7 +102,7 @@ def test_single_label_with_multiple_annotation(): def test_too_long_metadata(): - record = TextClassificationRecord.parse_obj( + record = ServiceTextClassificationRecord.parse_obj( { "inputs": {"text": "bogh"}, "metadata": {"too_long": "a" * 1000}, @@ -112,7 +114,7 @@ def test_too_long_metadata(): def test_too_long_label(): with pytest.raises(ValidationError, match="exceeds max length"): - TextClassificationRecord.parse_obj( + ServiceTextClassificationRecord.parse_obj( { "inputs": {"text": "bogh"}, "prediction": { @@ -137,26 +139,26 @@ def test_score_integrity(): } try: - TextClassificationRecord.parse_obj(data) + ServiceTextClassificationRecord.parse_obj(data) except ValidationError as e: assert "Wrong score distributions" in e.json() data["multi_label"] = True - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert record is not None data["multi_label"] = False data["prediction"]["labels"] = [ {"class": "B", "score": 0.9}, ] - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert record is not None data["prediction"]["labels"] = [ {"class": "B", "score": 0.10000000012}, {"class": "B", "score": 0.90000000002}, ] - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert record is not None @@ -174,7 +176,7 @@ def test_prediction_ok_cases(): }, } - record = TextClassificationRecord(**data) + record = ServiceTextClassificationRecord(**data) assert record.predicted is None record.annotation = TextClassificationAnnotation( **{ @@ -215,7 +217,7 @@ def test_predicted_as_with_no_labels(): "inputs": {"text": "The input text"}, "prediction": {"agent": "test", "labels": []}, } - record = TextClassificationRecord(**data) + record = ServiceTextClassificationRecord(**data) assert record.predicted_as == [] @@ -224,12 +226,12 @@ def test_created_record_with_default_status(): 
"inputs": {"data": "My cool data"}, } - record = TextClassificationRecord.parse_obj(data) + record = ServiceTextClassificationRecord.parse_obj(data) assert record.status == TaskStatus.default def test_predicted_ok_for_multilabel_unordered(): - record = TextClassificationRecord( + record = ServiceTextClassificationRecord( inputs={"text": "The text"}, prediction=TextClassificationAnnotation( agent="test", @@ -264,7 +266,7 @@ def test_validate_without_labels_for_single_label(annotation): ValidationError, match="Annotation must include some label for validated records", ): - TextClassificationRecord( + ServiceTextClassificationRecord( inputs={"text": "The text"}, status=TaskStatus.validated, prediction=TextClassificationAnnotation( @@ -281,7 +283,7 @@ def test_query_with_uncovered_by_rules(): query = TextClassificationQuery(uncovered_by_rules=["query", "other*"]) - assert EsQueryBuilder.to_es_query(query) == { + assert EsQueryBuilder._to_es_query(query) == { "bool": { "must": {"match_all": {}}, "must_not": { @@ -322,12 +324,12 @@ def test_empty_labels_for_no_multilabel(): ValidationError, match="Single label record must include only one annotation label", ): - TextClassificationRecord( + ServiceTextClassificationRecord( inputs={"text": "The input text"}, annotation=TextClassificationAnnotation(agent="ann.", labels=[]), ) - record = TextClassificationRecord( + record = ServiceTextClassificationRecord( inputs={"text": "The input text"}, prediction=TextClassificationAnnotation(agent="ann.", labels=[]), annotation=TextClassificationAnnotation( @@ -338,7 +340,7 @@ def test_empty_labels_for_no_multilabel(): def test_annotated_without_labels_for_multilabel(): - record = TextClassificationRecord( + record = ServiceTextClassificationRecord( inputs={"text": "The input text"}, multi_label=True, prediction=TextClassificationAnnotation(agent="pred.", labels=[]), diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py index 
c5637848d5..20b75b6108 100644 --- a/tests/server/token_classification/test_api.py +++ b/tests/server/token_classification/test_api.py @@ -18,7 +18,7 @@ from rubrix.server.apis.v0.models.commons.model import BulkResponse, SortableField from rubrix.server.apis.v0.models.token_classification import ( - TokenClassificationBulkData, + TokenClassificationBulkRequest, TokenClassificationQuery, TokenClassificationRecord, TokenClassificationSearchRequest, @@ -41,7 +41,7 @@ def test_load_as_different_task(mocked_client): ] mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -80,7 +80,7 @@ def test_search_special_characters(mocked_client): ] mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -123,7 +123,7 @@ def test_some_sort(mocked_client): ] mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -143,11 +143,10 @@ def test_some_sort(mocked_client): "code": "rubrix.api.errors::BadRequestError", "params": { "message": "Wrong sort id babba. 
Valid values are: " - "['metadata', 'last_updated', 'score', " - "'predicted', 'predicted_as', " - "'predicted_by', 'annotated_as', " - "'annotated_by', 'status', " - "'event_timestamp']" + "['id', 'metadata', 'score', 'predicted', " + "'predicted_as', 'predicted_by', " + "'annotated_as', 'annotated_by', 'status', " + "'last_updated', 'event_timestamp']" }, } } @@ -185,7 +184,7 @@ def test_create_records_for_token_classification( response = mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -267,7 +266,7 @@ def test_multiple_mentions_in_same_record(mocked_client): ] response = mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( tags={"env": "test", "class": "text classification"}, metadata={"config": {"the": "config"}}, records=records, @@ -295,7 +294,7 @@ def test_show_not_aggregable_metadata_fields(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:bulk", - json=TokenClassificationBulkData( + json=TokenClassificationBulkRequest( records=[ TokenClassificationRecord.parse_obj( { diff --git a/tests/server/token_classification/test_api_settings.py b/tests/server/token_classification/test_api_settings.py index 9e05ada666..f517a45e76 100644 --- a/tests/server/token_classification/test_api_settings.py +++ b/tests/server/token_classification/test_api_settings.py @@ -1,5 +1,5 @@ import rubrix as rb -from rubrix.server.apis.v0.models.commons.model import TaskType +from rubrix.server.commons.models import TaskType def create_dataset(client, name: str): diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index f40fba6d7e..3e8309fd3e 100644 --- 
a/tests/server/token_classification/test_model.py +++ b/tests/server/token_classification/test_model.py @@ -18,14 +18,16 @@ from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.server.apis.v0.models.token_classification import ( - CreationTokenClassificationRecord, - EntitySpan, - PredictionStatus, TokenClassificationAnnotation, TokenClassificationQuery, TokenClassificationRecord, ) -from rubrix.server.services.search.query_builder import EsQueryBuilder +from rubrix.server.commons.models import PredictionStatus +from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder +from rubrix.server.services.tasks.token_classification.model import ( + EntitySpan, + ServiceTokenClassificationRecord, +) def test_char_position(): @@ -38,7 +40,7 @@ def test_char_position(): EntitySpan(start=1, end=1, label="label") text = "I am Maxi" - TokenClassificationRecord( + ServiceTokenClassificationRecord( text=text, tokens=text.split(), prediction=TokenClassificationAnnotation( @@ -53,7 +55,7 @@ def test_char_position(): def test_fix_substrings(): text = "On one ones o no" - TokenClassificationRecord( + ServiceTokenClassificationRecord( text=text, tokens=text.split(), prediction=TokenClassificationAnnotation( @@ -68,7 +70,7 @@ def test_fix_substrings(): def test_entities_with_spaces(): text = "This is a great space" - TokenClassificationRecord( + ServiceTokenClassificationRecord( text=text, tokens=["This", "is", " ", "a", " ", "great", " ", "space"], prediction=TokenClassificationAnnotation( @@ -102,7 +104,6 @@ def test_model_dict(): "agent": "test", "entities": [{"end": 24, "label": "test", "score": 1.0, "start": 9}], }, - "raw_text": text, "text": text, "tokens": tokens, "status": "Default", @@ -111,7 +112,7 @@ def test_model_dict(): def test_too_long_metadata(): text = "On one ones o no" - record = TokenClassificationRecord.parse_obj( + record = ServiceTokenClassificationRecord.parse_obj( { "text": text, "tokens": text.split(), @@ -127,7 +128,7 @@ def 
test_entity_label_too_long(): with pytest.raises( ValidationError, match="ensure this value has at most 128 character" ): - TokenClassificationRecord( + ServiceTokenClassificationRecord( text=text, tokens=text.split(), prediction=TokenClassificationAnnotation( @@ -145,11 +146,11 @@ def test_entity_label_too_long(): def test_to_es_query(): query = TokenClassificationQuery(ids=[1, 2, 3]) - assert EsQueryBuilder.to_es_query(query) == {"ids": {"values": query.ids}} + assert EsQueryBuilder._to_es_query(query) == {"ids": {"values": query.ids}} def test_misaligned_entity_mentions_with_spaces_left(): - assert TokenClassificationRecord( + assert ServiceTokenClassificationRecord( text="according to analysts.\n Dart Group Corp was not", tokens=[ "according", @@ -172,7 +173,7 @@ def test_misaligned_entity_mentions_with_spaces_left(): def test_misaligned_entity_mentions_with_spaces_right(): - assert TokenClassificationRecord( + assert ServiceTokenClassificationRecord( text="\nvs 9.91 billion\n Note\n REUTER\n", tokens=["\n", "vs", "9.91", "billion", "\n ", "Note", "\n ", "REUTER", "\n"], annotation=TokenClassificationAnnotation( @@ -184,7 +185,7 @@ def test_misaligned_entity_mentions_with_spaces_right(): def test_custom_tokens_splitting(): - TokenClassificationRecord( + ServiceTokenClassificationRecord( text="ThisisMr.Bean, a character playedby actor RowanAtkinson", tokens=[ "This", @@ -211,7 +212,7 @@ def test_custom_tokens_splitting(): def test_record_scores(): - record = TokenClassificationRecord( + record = ServiceTokenClassificationRecord( text="\nvs 9.91 billion\n Note\n REUTER\n", tokens=["\n", "vs", "9.91", "billion", "\n ", "Note", "\n ", "REUTER", "\n"], prediction=TokenClassificationAnnotation( @@ -228,7 +229,7 @@ def test_record_scores(): def test_annotated_without_entities(): text = "The text that i wrote" - record = TokenClassificationRecord( + record = ServiceTokenClassificationRecord( text=text, tokens=text.split(), prediction=TokenClassificationAnnotation( @@ 
-245,8 +246,7 @@ def test_annotated_without_entities(): def test_adjust_spans(): text = "A text with some empty spaces that could bring not cleany annotated spans" - - record = TokenClassificationRecord( + record = ServiceTokenClassificationRecord( text=text, tokens=text.split(), prediction=TokenClassificationAnnotation( @@ -275,6 +275,7 @@ def test_adjust_spans(): EntitySpan(start=70, end=85, label="DET"), ] + def test_whitespace_in_tokens(): from spacy import load @@ -291,7 +292,6 @@ def test_whitespace_in_tokens(): }, } - record = CreationTokenClassificationRecord.parse_obj(record) + record = ServiceTokenClassificationRecord.parse_obj(record) assert record assert record.tokens == ["every", "four", "(", "4", ")", " "] -