diff --git a/pyproject.toml b/pyproject.toml
index da1f7409a8..3e57161904 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,6 @@ server = [
"python-jose[cryptography]~=3.2.0",
"passlib[bcrypt]~=1.7.4",
# Info status
- "hurry.filesize", # TODO: remove
"psutil ~= 5.8.0",
]
listeners = [
diff --git a/src/rubrix/server/_helpers.py b/src/rubrix/server/_helpers.py
deleted file mode 100644
index c088a1566f..0000000000
--- a/src/rubrix/server/_helpers.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# coding=utf-8
-# Copyright 2021-present, the Recognai S.L. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Common helper functions
-"""
-from typing import Any, Dict, List, Optional
-
-
-def unflatten_dict(
- data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None
-) -> Dict[str, Any]:
- """
- Given a flat dictionary keys, build a hierarchical version by grouping keys
-
- Parameters
- ----------
- data:
- The data dictionary
- sep:
- The key separator. Default "."
- stop_keys
- List of dictionary first level keys where hierarchy will stop
-
- Returns
- -------
-
- """
- resultDict = {}
- stop_keys = stop_keys or []
- for key, value in data.items():
- if key is not None:
- parts = key.split(sep)
- if parts[0] in stop_keys:
- parts = [parts[0], sep.join(parts[1:])]
- d = resultDict
- for part in parts[:-1]:
- if part not in d:
- d[part] = {}
- d = d[part]
- d[parts[-1]] = value
- return resultDict
diff --git a/src/rubrix/server/apis/v0/config/tasks_factory.py b/src/rubrix/server/apis/v0/config/tasks_factory.py
deleted file mode 100644
index 5ad982bf4f..0000000000
--- a/src/rubrix/server/apis/v0/config/tasks_factory.py
+++ /dev/null
@@ -1,155 +0,0 @@
-from typing import Any, Dict, List, Optional, Set, Type
-
-from pydantic import BaseModel
-
-from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType
-from rubrix.server.apis.v0.models.datasets import DatasetDB
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics
-from rubrix.server.apis.v0.models.metrics.text_classification import (
- TextClassificationMetrics,
-)
-from rubrix.server.apis.v0.models.metrics.token_classification import (
- TokenClassificationMetrics,
-)
-from rubrix.server.apis.v0.models.text2text import (
- Text2TextDatasetDB,
- Text2TextMetrics,
- Text2TextQuery,
- Text2TextRecord,
-)
-from rubrix.server.apis.v0.models.text_classification import (
- TextClassificationDatasetDB,
- TextClassificationQuery,
- TextClassificationRecord,
-)
-from rubrix.server.apis.v0.models.token_classification import (
- TokenClassificationDatasetDB,
- TokenClassificationQuery,
- TokenClassificationRecord,
-)
-from rubrix.server.elasticseach.mappings.text2text import text2text_mappings
-from rubrix.server.elasticseach.mappings.text_classification import (
- text_classification_mappings,
-)
-from rubrix.server.elasticseach.mappings.token_classification import (
- token_classification_mappings,
-)
-from rubrix.server.errors import EntityNotFoundError, WrongTaskError
-
-
-class TaskConfig(BaseModel):
- task: TaskType
- query: Any
- dataset: Type[DatasetDB]
- record: Type[BaseRecord]
- metrics: Optional[Type[BaseTaskMetrics]]
- es_mappings: Dict[str, Any]
-
-
-class TaskFactory:
-
- _REGISTERED_TASKS = dict()
-
- @classmethod
- def register_task(
- cls,
- task_type: TaskType,
- dataset_class: Type[DatasetDB],
- query_request: Type[Any],
- es_mappings: Dict[str, Any],
- record_class: Type[BaseRecord],
- metrics: Optional[Type[BaseTaskMetrics]] = None,
- ):
- cls._REGISTERED_TASKS[task_type] = TaskConfig(
- task=task_type,
- dataset=dataset_class,
- es_mappings=es_mappings,
- query=query_request,
- record=record_class,
- metrics=metrics,
- )
-
- @classmethod
- def get_all_configs(cls) -> List[TaskConfig]:
- return [cfg for cfg in cls._REGISTERED_TASKS.values()]
-
- @classmethod
- def get_task_by_task_type(cls, task_type: TaskType) -> Optional[TaskConfig]:
- return cls._REGISTERED_TASKS.get(task_type)
-
- @classmethod
- def get_task_metrics(cls, task: TaskType) -> Optional[Type[BaseTaskMetrics]]:
- config = cls.get_task_by_task_type(task)
- if config:
- return config.metrics
-
- @classmethod
- def get_task_dataset(cls, task: TaskType) -> Type[DatasetDB]:
- config = cls.__get_task_config__(task)
- return config.dataset
-
- @classmethod
- def get_task_record(cls, task: TaskType) -> Type[BaseRecord]:
- config = cls.__get_task_config__(task)
- return config.record
-
- @classmethod
- def get_task_mappings(cls, task: TaskType) -> Dict[str, Any]:
- config = cls.__get_task_config__(task)
- return config.es_mappings
-
- @classmethod
- def __get_task_config__(cls, task):
- config = cls.get_task_by_task_type(task)
- if not config:
- raise WrongTaskError(f"No configuration found for task {task}")
- return config
-
- @classmethod
- def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[BaseMetric]:
- metrics = cls.find_task_metrics(task, {metric_id})
- if metrics:
- return metrics[0]
- raise EntityNotFoundError(name=metric_id, type=BaseMetric)
-
- @classmethod
- def find_task_metrics(
- cls, task: TaskType, metric_ids: Set[str]
- ) -> List[BaseMetric]:
-
- if not metric_ids:
- return []
-
- metrics = []
- for metric in cls.get_task_metrics(task).metrics:
- if metric.id in metric_ids:
- metrics.append(metric)
- return metrics
-
-
-TaskFactory.register_task(
- task_type=TaskType.token_classification,
- dataset_class=TokenClassificationDatasetDB,
- query_request=TokenClassificationQuery,
- record_class=TokenClassificationRecord,
- metrics=TokenClassificationMetrics,
- es_mappings=token_classification_mappings(),
-)
-
-TaskFactory.register_task(
- task_type=TaskType.text_classification,
- dataset_class=TextClassificationDatasetDB,
- query_request=TextClassificationQuery,
- record_class=TextClassificationRecord,
- metrics=TextClassificationMetrics,
- es_mappings=text_classification_mappings(),
-)
-
-TaskFactory.register_task(
- task_type=TaskType.text2text,
- dataset_class=Text2TextDatasetDB,
- query_request=Text2TextQuery,
- record_class=Text2TextRecord,
- metrics=Text2TextMetrics,
- es_mappings=text2text_mappings(),
-)
diff --git a/src/rubrix/server/apis/v0/handlers/datasets.py b/src/rubrix/server/apis/v0/handlers/datasets.py
index 753dc9775b..301a3a3cdb 100644
--- a/src/rubrix/server/apis/v0/handlers/datasets.py
+++ b/src/rubrix/server/apis/v0/handlers/datasets.py
@@ -17,14 +17,14 @@
from fastapi import APIRouter, Body, Depends, Security
-from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
+from rubrix.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
from rubrix.server.apis.v0.models.datasets import (
CopyDatasetRequest,
+ CreateDatasetRequest,
Dataset,
- DatasetCreate,
UpdateDatasetRequest,
)
+from rubrix.server.commons.config import TasksFactory
from rubrix.server.errors import EntityNotFoundError
from rubrix.server.security import auth
from rubrix.server.security.model import User
@@ -39,30 +39,16 @@
response_model_exclude_none=True,
operation_id="list_datasets",
)
-def list_datasets(
- ds_params: CommonTaskQueryParams = Depends(),
+async def list_datasets(
+ request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> List[Dataset]:
- """
- List accessible user datasets
-
- Parameters
- ----------
- ds_params:
- Common task query params
- service:
- The datasets service
- current_user:
- The request user
-
- Returns
- -------
- A list of datasets visible by current user
- """
return service.list(
user=current_user,
- workspaces=[ds_params.workspace] if ds_params.workspace is not None else None,
+ workspaces=[request_deps.workspace]
+ if request_deps.workspace is not None
+ else None,
)
@@ -75,24 +61,19 @@ def list_datasets(
description="Create a new dataset",
)
async def create_dataset(
- request: DatasetCreate = Body(..., description=f"The request dataset info"),
- ws_params: CommonTaskQueryParams = Depends(),
+ request: CreateDatasetRequest = Body(..., description=f"The request dataset info"),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
) -> Dataset:
owner = user.check_workspace(ws_params.workspace)
- dataset_class = TaskFactory.get_task_dataset(request.task)
- task_mappings = TaskFactory.get_task_mappings(request.task)
-
+ dataset_class = TasksFactory.get_task_dataset(request.task)
dataset = dataset_class.parse_obj({**request.dict()})
dataset.owner = owner
- response = datasets.create_dataset(
- user=user, dataset=dataset, mappings=task_mappings
- )
-
+ response = datasets.create_dataset(user=user, dataset=dataset)
return Dataset.parse_obj(response)
@@ -104,34 +85,15 @@ async def create_dataset(
)
def get_dataset(
name: str,
- ds_params: CommonTaskQueryParams = Depends(),
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
- """
- Find a dataset by name visible for current user
-
- Parameters
- ----------
- name:
- The dataset name
- ds_params:
- Common dataset query params
- service:
- Datasets service
- current_user:
- The current user
-
- Returns
- -------
- - The found dataset if accessible or exists.
- - EntityNotFoundError if not found.
- - NotAuthorizedError if user cannot access the found dataset
-
- """
return Dataset.parse_obj(
service.find_by_name(
- user=current_user, name=name, workspace=ds_params.workspace
+ user=current_user,
+ name=name,
+ workspace=ds_params.workspace,
)
)
@@ -144,35 +106,11 @@ def get_dataset(
)
def update_dataset(
name: str,
- update_request: UpdateDatasetRequest,
- ds_params: CommonTaskQueryParams = Depends(),
+ request: UpdateDatasetRequest,
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
- """
- Update a set of parameters for a dataset
-
- Parameters
- ----------
- name:
- The dataset name
- update_request:
- The fields to update
- ds_params:
- Common dataset query params
- service:
- The datasets service
- current_user:
- The current user
-
- Returns
- -------
-
- - The updated dataset if exists and user has access.
- - EntityNotFoundError if not found.
- - NotAuthorizedError if user cannot access the found dataset
-
- """
found_ds = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
@@ -181,8 +119,8 @@ def update_dataset(
return service.update(
user=current_user,
dataset=found_ds,
- tags=update_request.tags,
- metadata=update_request.metadata,
+ tags=request.tags,
+ metadata=request.metadata,
)
@@ -192,25 +130,10 @@ def update_dataset(
)
def delete_dataset(
name: str,
- ds_params: CommonTaskQueryParams = Depends(),
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
):
- """
- Deletes a dataset
-
- Parameters
- ----------
- name:
- The dataset name
- ds_params:
- Common dataset query params
- service:
- The datasets service
- current_user:
- The current user
-
- """
try:
found_ds = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
@@ -226,25 +149,10 @@ def delete_dataset(
)
def close_dataset(
name: str,
- ds_params: CommonTaskQueryParams = Depends(),
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
):
- """
- Closes a dataset. This operation will releases backend resources
-
- Parameters
- ----------
- name:
- The dataset name
- ds_params:
- Common dataset query params
- service:
- The datasets service
- current_user:
- The current user
-
- """
found_ds = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
)
@@ -257,25 +165,10 @@ def close_dataset(
)
def open_dataset(
name: str,
- ds_params: CommonTaskQueryParams = Depends(),
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
):
- """
- Closes a dataset. This operation will releases backend resources
-
- Parameters
- ----------
- name:
- The dataset name
- ds_params:
- Common dataset query params
- service:
- The datasets service
- current_user:
- The current user
-
- """
found_ds = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
)
@@ -291,27 +184,10 @@ def open_dataset(
def copy_dataset(
name: str,
copy_request: CopyDatasetRequest,
- ds_params: CommonTaskQueryParams = Depends(),
+ ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
- """
- Creates a dataset copy and its tags/metadata info
-
- Parameters
- ----------
- name:
- The dataset name
- copy_request:
- The copy request data
- ds_params:
- Common dataset query params
- service:
- The datasets service
- current_user:
- The current user
-
- """
found = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
)
diff --git a/src/rubrix/server/apis/v0/handlers/info.py b/src/rubrix/server/apis/v0/handlers/info.py
index 30fcfd8c28..69a7ddde75 100644
--- a/src/rubrix/server/apis/v0/handlers/info.py
+++ b/src/rubrix/server/apis/v0/handlers/info.py
@@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from fastapi import APIRouter, Depends, Security
+from fastapi import APIRouter, Depends
-from rubrix.server.apis.v0.models.info import ApiInfo, ApiStatus
-from rubrix.server.security import auth
-from rubrix.server.services.info import ApiInfoService, create_info_service
+from rubrix.server.services.info import ApiInfo, ApiInfoService, ApiStatus
router = APIRouter(tags=["status"])
@@ -26,29 +24,13 @@
"/_status",
operation_id="api_status",
response_model=ApiStatus,
- dependencies=[Security(auth.get_user, scopes=[])],
)
def api_status(
- service: ApiInfoService = Depends(create_info_service),
+ service: ApiInfoService = Depends(ApiInfoService.get_instance),
) -> ApiStatus:
- """
-
- Parameters
- ----------
- service:
- The Api info service
-
- Returns
- -------
-
- The detailed api status
-
- """
return service.api_status()
@router.get("/_info", operation_id="api_info", response_model=ApiInfo)
-def api_info(
- service: ApiInfoService = Depends(create_info_service),
-) -> ApiInfo:
+def api_info(service: ApiInfoService = Depends(ApiInfoService.get_instance)) -> ApiInfo:
return ApiInfo.parse_obj(service.api_status())
diff --git a/src/rubrix/server/apis/v0/handlers/metrics.py b/src/rubrix/server/apis/v0/handlers/metrics.py
index 8f4040add8..3de73bafdc 100644
--- a/src/rubrix/server/apis/v0/handlers/metrics.py
+++ b/src/rubrix/server/apis/v0/handlers/metrics.py
@@ -19,13 +19,13 @@
from fastapi import APIRouter, Depends, Query, Security
from pydantic import BaseModel, Field
-from rubrix.server.apis.v0.config.tasks_factory import TaskConfig, TaskFactory
from rubrix.server.apis.v0.handlers import (
text2text,
text_classification,
token_classification,
)
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
+from rubrix.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
+from rubrix.server.commons.config import TaskConfig, TasksFactory
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
@@ -33,7 +33,6 @@
class MetricInfo(BaseModel):
- """Metric info data model for retrieve dataset metrics information"""
id: str = Field(description="The metric id")
name: str = Field(description="The metric name")
@@ -44,19 +43,6 @@ class MetricInfo(BaseModel):
@dataclass
class MetricSummaryParams:
- """
- For metrics summary calculation, common summary parameters.
-
- Attributes:
- -----------
-
- interval:
- For histogram summaries, the bucket interval
-
- size:
- For terminological metrics, the number of terms to retrieve
-
- """
interval: Optional[float] = Query(
default=None,
@@ -71,17 +57,8 @@ class MetricSummaryParams:
def configure_metrics_endpoints(router: APIRouter, cfg: TaskConfig):
- """
- Configures an api router with the dataset task metrics endpoints.
- Parameters
- ----------
- router:
- The api router
- cfg:
- The task configuration model
-
- """
+ # TODO(@frascuchon): Use new api endpoint (/datasets/{name}/{task}/...
base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics"
@router.get(
@@ -91,41 +68,19 @@ def configure_metrics_endpoints(router: APIRouter, cfg: TaskConfig):
)
def get_dataset_metrics(
name: str,
- teams_query: CommonTaskQueryParams = Depends(),
+ request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
- metrics: MetricsService = Depends(MetricsService.get_instance),
) -> List[MetricInfo]:
- """
- List available metrics info for a given dataset
-
- Parameters
- ----------
- name:
- The dataset name
- teams_query:
- Team query param where dataset belongs to. Optional
- current_user:
- The current user
- datasets:
- The datasets service
- metrics:
- The metrics service
-
- Returns
- -------
- A list of metric info availables for given dataset
-
- """
dataset = datasets.find_by_name(
user=current_user,
name=name,
task=cfg.task,
- workspace=teams_query.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(cfg.task),
+ workspace=request_deps.workspace,
+ as_dataset_class=TasksFactory.get_task_dataset(cfg.task),
)
- metrics = TaskFactory.get_task_metrics(dataset.task)
+ metrics = TasksFactory.get_task_metrics(dataset.task)
metrics = metrics.metrics if metrics else []
return [MetricInfo.parse_obj(metric) for metric in metrics]
@@ -140,52 +95,25 @@ def metric_summary(
metric: str,
query: cfg.query,
metric_params: MetricSummaryParams = Depends(),
- teams_query: CommonTaskQueryParams = Depends(),
+ request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
metrics: MetricsService = Depends(MetricsService.get_instance),
):
- """
- Summarizes a given metric for a given dataset.
-
- Parameters
- ----------
- name:
- The dataset name
- metric:
- The metric id
- query:
- A query for records filtering. Optional
- metric_params:
- Metric parameters for result calculation
- teams_query:
- Team query param where dataset belongs to. Optional
- current_user:
- The current user
- datasets:
- The datasets service
- metrics:
- The metrics service
-
- Returns
- -------
- The metric summary for a given dataset
-
- """
dataset = datasets.find_by_name(
user=current_user,
name=name,
task=cfg.task,
- workspace=teams_query.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(cfg.task),
+ workspace=request_deps.workspace,
+ as_dataset_class=TasksFactory.get_task_dataset(cfg.task),
)
- metric_ = TaskFactory.find_task_metric(task=cfg.task, metric_id=metric)
- record_class = TaskFactory.get_task_record(cfg.task)
+ metric_ = TasksFactory.find_task_metric(task=cfg.task, metric_id=metric)
+ record_class = TasksFactory.get_task_record(cfg.task)
return metrics.summarize_metric(
dataset=dataset,
- owner=current_user.check_workspace(teams_query.workspace),
+ owner=current_user.check_workspace(request_deps.workspace),
metric=metric_,
record_class=record_class,
query=query,
@@ -196,8 +124,7 @@ def metric_summary(
router = APIRouter()
for task_api in [text_classification, token_classification, text2text]:
- cfg = TaskFactory.get_task_by_task_type(task_api.TASK_TYPE)
+ cfg = TasksFactory.get_task_by_task_type(task_api.TASK_TYPE)
if cfg:
configure_metrics_endpoints(task_api.router, cfg)
-
router.include_router(task_api.router)
diff --git a/src/rubrix/server/apis/v0/handlers/text2text.py b/src/rubrix/server/apis/v0/handlers/text2text.py
index a2edd5aaa7..29daec8e12 100644
--- a/src/rubrix/server/apis/v0/handlers/text2text.py
+++ b/src/rubrix/server/apis/v0/handlers/text2text.py
@@ -19,31 +19,47 @@
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
-from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
-from rubrix.server.apis.v0.helpers import takeuntil
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- PaginationParams,
- TaskType,
+from rubrix.server.apis.v0.models.commons.model import BulkResponse
+from rubrix.server.apis.v0.models.commons.params import (
+ CommonTaskHandlerDependencies,
+ RequestPagination,
)
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
from rubrix.server.apis.v0.models.text2text import (
- Text2TextBulkData,
+ Text2TextBulkRequest,
+ Text2TextDataset,
+ Text2TextMetrics,
Text2TextQuery,
Text2TextRecord,
+ Text2TextSearchAggregations,
Text2TextSearchRequest,
Text2TextSearchResults,
)
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.commons.models import TaskType
from rubrix.server.errors import EntityNotFoundError
+from rubrix.server.helpers import takeuntil
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
-from rubrix.server.services.text2text import Text2TextService, text2text_service
+from rubrix.server.services.tasks.text2text import Text2TextService
+from rubrix.server.services.tasks.text2text.models import (
+ ServiceText2TextQuery,
+ ServiceText2TextRecord,
+)
TASK_TYPE = TaskType.text2text
BASE_ENDPOINT = "/{name}/" + TASK_TYPE
+TasksFactory.register_task(
+ task_type=TaskType.text2text,
+ dataset_class=Text2TextDataset,
+ query_request=Text2TextQuery,
+ record_class=ServiceText2TextRecord,
+ metrics=Text2TextMetrics,
+)
+
+
router = APIRouter(tags=[TASK_TYPE], prefix="/datasets")
@@ -55,37 +71,14 @@
)
def bulk_records(
name: str,
- bulk: Text2TextBulkData,
- common_params: CommonTaskQueryParams = Depends(),
- service: Text2TextService = Depends(text2text_service),
+ bulk: Text2TextBulkRequest,
+ common_params: CommonTaskHandlerDependencies = Depends(),
+ service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
- """
- Includes a chunk of record data with provided dataset bulk information
-
- Parameters
- ----------
- name:
- The dataset name
- bulk:
- The bulk data
- common_params:
- Common task query params
- service:
- the Service
- datasets:
- The dataset service
- current_user:
- Current request user
-
- Returns
- -------
- Bulk response data
- """
task = TASK_TYPE
- task_mappings = TaskFactory.get_task_mappings(TASK_TYPE)
owner = current_user.check_workspace(common_params.workspace)
try:
dataset = datasets.find_by_name(
@@ -93,7 +86,7 @@ def bulk_records(
name=name,
task=task,
workspace=owner,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
datasets.update(
user=current_user,
@@ -102,19 +95,14 @@ def bulk_records(
metadata=bulk.metadata,
)
except EntityNotFoundError:
- dataset_class = TaskFactory.get_task_dataset(task)
+ dataset_class = TasksFactory.get_task_dataset(task)
dataset = dataset_class.parse_obj({**bulk.dict(), "name": name})
dataset.owner = owner
-
- datasets.create_dataset(
- user=current_user, dataset=dataset, mappings=task_mappings
- )
+ datasets.create_dataset(user=current_user, dataset=dataset)
result = service.add_records(
dataset=dataset,
- mappings=task_mappings,
- records=bulk.records,
- metrics=TaskFactory.get_task_metrics(TASK_TYPE),
+ records=[ServiceText2TextRecord.parse_obj(r) for r in bulk.records],
)
return BulkResponse(
dataset=name,
@@ -132,42 +120,15 @@ def bulk_records(
def search_records(
name: str,
search: Text2TextSearchRequest = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
include_metrics: bool = Query(
False, description="If enabled, return related record metrics"
),
- pagination: PaginationParams = Depends(),
- service: Text2TextService = Depends(text2text_service),
+ pagination: RequestPagination = Depends(),
+ service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Text2TextSearchResults:
- """
- Searches data from dataset
-
- Parameters
- ----------
- name:
- The dataset name
- common_params:
- The task common query params
- include_metrics:
- Flag to include metrics in results
- search:
- THe search query request
- pagination:
- The pagination params
- service:
- The dataset records service
- datasets:
- The dataset service
- current_user:
- The current request user
-
- Returns
- -------
- The search results data
-
- """
search = search or Text2TextSearchRequest()
query = search.query or Text2TextQuery()
@@ -176,29 +137,24 @@ def search_records(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
result = service.search(
dataset=dataset,
- query=query,
+ query=ServiceText2TextQuery.parse_obj(query),
sort_by=search.sort,
record_from=pagination.from_,
size=pagination.limit,
exclude_metrics=not include_metrics,
- metrics=TaskFactory.find_task_metrics(
- TASK_TYPE,
- metric_ids={
- "words_cloud",
- "predicted_by",
- "annotated_by",
- "status_distribution",
- "metadata",
- "score",
- },
- ),
)
- return result
+ return Text2TextSearchResults(
+ total=result.total,
+ records=[Text2TextRecord.parse_obj(r) for r in result.records],
+ aggregations=Text2TextSearchAggregations.parse_obj(result.metrics)
+ if result.metrics
+ else None,
+ )
def scan_data_response(
@@ -241,15 +197,15 @@ def grouper(n, iterable, fillvalue=None):
async def stream_data(
name: str,
query: Optional[Text2TextQuery] = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
- service: Text2TextService = Depends(text2text_service),
+ service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
id_from: Optional[str] = None
) -> StreamingResponse:
"""
- Creates a data stream over dataset records
+ Creates a data stream over dataset records
Parameters
----------
@@ -277,10 +233,12 @@ async def stream_data(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
+ )
+ data_stream = map(
+ Text2TextRecord.parse_obj,
+ service.read_dataset(dataset, query=ServiceText2TextQuery.parse_obj(query), id_from=id_from, limit=limit),
)
- data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit)
-
return scan_data_response(
data_stream=data_stream,
limit=limit,
diff --git a/src/rubrix/server/apis/v0/handlers/text_classification.py b/src/rubrix/server/apis/v0/handlers/text_classification.py
index 976646c1a3..35b0a134ab 100644
--- a/src/rubrix/server/apis/v0/handlers/text_classification.py
+++ b/src/rubrix/server/apis/v0/handlers/text_classification.py
@@ -19,39 +19,64 @@
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
-from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
from rubrix.server.apis.v0.handlers import text_classification_dataset_settings
-from rubrix.server.apis.v0.helpers import takeuntil
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- PaginationParams,
- TaskType,
+from rubrix.server.apis.v0.models.commons.model import BulkResponse
+from rubrix.server.apis.v0.models.commons.params import (
+ CommonTaskHandlerDependencies,
+ RequestPagination,
)
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
from rubrix.server.apis.v0.models.text_classification import (
CreateLabelingRule,
DatasetLabelingRulesMetricsSummary,
LabelingRule,
LabelingRuleMetricsSummary,
- TextClassificationBulkData,
+ TextClassificationBulkRequest,
TextClassificationQuery,
TextClassificationRecord,
+ TextClassificationSearchAggregations,
TextClassificationSearchRequest,
TextClassificationSearchResults,
UpdateLabelingRule,
)
+from rubrix.server.apis.v0.models.text_classification import (
+    TextClassificationDataset,
+    TextClassificationQuery,
+)
from rubrix.server.apis.v0.validators.text_classification import DatasetValidator
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.commons.models import TaskType
from rubrix.server.errors import EntityNotFoundError
+from rubrix.server.helpers import takeuntil
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
-from rubrix.server.services.text_classification import TextClassificationService
+from rubrix.server.services.tasks.text_classification import TextClassificationService
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceLabelingRule,
+ ServiceTextClassificationQuery,
+ ServiceTextClassificationRecord,
+)
+from rubrix.server.services.tasks.text_classification.metrics import (
+    TextClassificationMetrics,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+    ServiceTextClassificationRecord,
+)
TASK_TYPE = TaskType.text_classification
BASE_ENDPOINT = "/{name}/" + TASK_TYPE
NEW_BASE_ENDPOINT = f"/{TASK_TYPE}/{{name}}"
+TasksFactory.register_task(
+    task_type=TaskType.text_classification,
+    dataset_class=TextClassificationDataset,
+    query_request=TextClassificationQuery,
+    record_class=ServiceTextClassificationRecord,
+    metrics=TextClassificationMetrics,
+)
+
+
router = APIRouter(tags=[TASK_TYPE], prefix="/datasets")
@@ -63,8 +88,8 @@
)
async def bulk_records(
name: str,
- bulk: TextClassificationBulkData,
- common_params: CommonTaskQueryParams = Depends(),
+ bulk: TextClassificationBulkRequest,
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -72,33 +97,8 @@ async def bulk_records(
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
- """
- Includes a chunk of record data with provided dataset bulk information
-
- Parameters
- ----------
- name:
- The dataset name
- bulk:
- The bulk data
- common_params:
- Common query params
- service:
- the Service
- datasets:
- The dataset service
- validator:
- The dataset validator component
- current_user:
- Current request user
-
- Returns
- -------
- Bulk response data
- """
task = TASK_TYPE
- task_mappings = TaskFactory.get_task_mappings(TASK_TYPE)
owner = current_user.check_workspace(common_params.workspace)
try:
dataset = datasets.find_by_name(
@@ -106,7 +106,7 @@ async def bulk_records(
name=name,
task=task,
workspace=owner,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
datasets.update(
user=current_user,
@@ -115,22 +115,20 @@ async def bulk_records(
metadata=bulk.metadata,
)
except EntityNotFoundError:
- dataset_class = TaskFactory.get_task_dataset(task)
+ dataset_class = TasksFactory.get_task_dataset(task)
dataset = dataset_class.parse_obj({**bulk.dict(), "name": name})
dataset.owner = owner
- datasets.create_dataset(
- user=current_user, dataset=dataset, mappings=task_mappings
- )
+ datasets.create_dataset(user=current_user, dataset=dataset)
+ # TODO(@frascuchon): Validator should be applied in the service layer
+ records = [ServiceTextClassificationRecord.parse_obj(r) for r in bulk.records]
await validator.validate_dataset_records(
- user=current_user, dataset=dataset, records=bulk.records
+ user=current_user, dataset=dataset, records=records
)
result = service.add_records(
dataset=dataset,
- mappings=task_mappings,
- records=bulk.records,
- metrics=TaskFactory.get_task_metrics(TASK_TYPE),
+ records=records,
)
return BulkResponse(
dataset=name,
@@ -148,11 +146,11 @@ async def bulk_records(
def search_records(
name: str,
search: TextClassificationSearchRequest = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
include_metrics: bool = Query(
False, description="If enabled, return related record metrics"
),
- pagination: PaginationParams = Depends(),
+ pagination: RequestPagination = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -195,32 +193,24 @@ def search_records(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
result = service.search(
dataset=dataset,
- query=query,
+ query=ServiceTextClassificationQuery.parse_obj(query),
sort_by=search.sort,
record_from=pagination.from_,
size=pagination.limit,
exclude_metrics=not include_metrics,
- metrics=TaskFactory.find_task_metrics(
- TASK_TYPE,
- metric_ids={
- "words_cloud",
- "predicted_by",
- "predicted_as",
- "annotated_by",
- "annotated_as",
- "error_distribution",
- "status_distribution",
- "metadata",
- "score",
- },
- ),
)
- return result
+ return TextClassificationSearchResults(
+ total=result.total,
+        records=[TextClassificationRecord.parse_obj(r) for r in result.records],
+ aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics)
+ if result.metrics
+ else None,
+ )
def scan_data_response(
@@ -263,7 +253,7 @@ def grouper(n, iterable, fillvalue=None):
async def stream_data(
name: str,
query: Optional[TextClassificationQuery] = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
id_from: Optional[str] = None,
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
service: TextClassificationService = Depends(
@@ -273,7 +263,7 @@ async def stream_data(
current_user: User = Security(auth.get_user, scopes=[]),
) -> StreamingResponse:
"""
- Creates a data stream over dataset records
+ Creates a data stream over dataset records
Parameters
----------
@@ -302,10 +292,15 @@ async def stream_data(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
- data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit)
+ data_stream = map(
+ TextClassificationRecord.parse_obj,
+ service.read_dataset(
+ dataset, query=ServiceTextClassificationQuery.parse_obj(query), id_from=id_from, limit=limit
+ ),
+ )
return scan_data_response(
data_stream=data_stream,
limit=limit,
@@ -321,7 +316,7 @@ async def stream_data(
)
async def list_labeling_rules(
name: str,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
@@ -334,10 +329,12 @@ async def list_labeling_rules(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
- return list(service.get_labeling_rules(dataset))
+ return [
+ LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)
+ ]
@router.post(
@@ -350,7 +347,7 @@ async def list_labeling_rules(
async def create_rule(
name: str,
rule: CreateLabelingRule,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -363,10 +360,10 @@ async def create_rule(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
- rule = LabelingRule(
+ rule = ServiceLabelingRule(
**rule.dict(),
author=current_user.username,
)
@@ -374,8 +371,7 @@ async def create_rule(
dataset,
rule=rule,
)
-
- return rule
+ return LabelingRule.parse_obj(rule)
@router.get(
@@ -391,19 +387,20 @@ async def compute_rule_metrics(
labels: Optional[List[str]] = Query(
None, description="Label related to query rule", alias="label"
),
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRuleMetricsSummary:
+
dataset = datasets.find_by_name(
user=current_user,
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
return service.compute_rule_metrics(dataset, rule_query=query, labels=labels)
@@ -418,7 +415,7 @@ async def compute_rule_metrics(
)
async def compute_dataset_rules_metrics(
name: str,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -430,10 +427,10 @@ async def compute_dataset_rules_metrics(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
-
- return service.compute_overall_rules_metrics(dataset)
+ metrics = service.compute_overall_rules_metrics(dataset)
+ return DatasetLabelingRulesMetricsSummary.parse_obj(metrics)
@router.delete(
@@ -444,7 +441,7 @@ async def compute_dataset_rules_metrics(
async def delete_labeling_rule(
name: str,
query: str,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -457,7 +454,7 @@ async def delete_labeling_rule(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
service.delete_labeling_rule(dataset, rule_query=query)
@@ -473,7 +470,7 @@ async def delete_labeling_rule(
async def get_rule(
name: str,
query: str,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -486,14 +483,13 @@ async def get_rule(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
-
rule = service.find_labeling_rule(
dataset,
rule_query=query,
)
- return rule
+ return LabelingRule.parse_obj(rule)
@router.patch(
@@ -507,7 +503,7 @@ async def update_rule(
name: str,
query: str,
update: UpdateLabelingRule,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
),
@@ -520,7 +516,7 @@ async def update_rule(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
rule = service.update_labeling_rule(
@@ -529,7 +525,7 @@ async def update_rule(
labels=update.labels,
description=update.description,
)
- return rule
+ return LabelingRule.parse_obj(rule)
text_classification_dataset_settings.configure_router(router)
diff --git a/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py
index 88d0f6323a..a2d0b840f8 100644
--- a/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py
+++ b/src/rubrix/server/apis/v0/handlers/text_classification_dataset_settings.py
@@ -2,18 +2,20 @@
from fastapi import APIRouter, Body, Depends, Security
-from rubrix.server.apis.v0.models.commons.model import TaskType
-from rubrix.server.apis.v0.models.commons.params import DATASET_NAME_PATH_PARAM
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
+from rubrix.server.apis.v0.models.commons.params import (
+ DATASET_NAME_PATH_PARAM,
+ CommonTaskHandlerDependencies,
+)
from rubrix.server.apis.v0.models.dataset_settings import TextClassificationSettings
from rubrix.server.apis.v0.validators.text_classification import DatasetValidator
+from rubrix.server.commons.models import TaskType
from rubrix.server.security import auth
from rubrix.server.security.model import User
-from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings
+from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings
-__svc_settings_class__: Type[SVCDatasetSettings] = type(
+__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type(
f"{TaskType.text_classification}_DatasetSettings",
- (SVCDatasetSettings, TextClassificationSettings),
+ (ServiceBaseDatasetSettings, TextClassificationSettings),
{},
)
@@ -33,7 +35,7 @@ def configure_router(router: APIRouter):
)
async def get_dataset_settings(
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
) -> TextClassificationSettings:
@@ -63,7 +65,7 @@ async def save_settings(
..., description=f"The {task} dataset settings"
),
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
@@ -93,7 +95,7 @@ async def save_settings(
)
async def delete_settings(
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]),
) -> None:
diff --git a/src/rubrix/server/apis/v0/handlers/token_classification.py b/src/rubrix/server/apis/v0/handlers/token_classification.py
index ed56beff43..363b638c56 100644
--- a/src/rubrix/server/apis/v0/handlers/token_classification.py
+++ b/src/rubrix/server/apis/v0/handlers/token_classification.py
@@ -19,36 +19,57 @@
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
-from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
from rubrix.server.apis.v0.handlers import token_classification_dataset_settings
-from rubrix.server.apis.v0.helpers import takeuntil
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- PaginationParams,
- TaskType,
+from rubrix.server.apis.v0.models.commons.model import BulkResponse
+from rubrix.server.apis.v0.models.commons.params import (
+ CommonTaskHandlerDependencies,
+ RequestPagination,
+)
+from rubrix.server.apis.v0.models.token_classification import (
+    TokenClassificationDataset,
+    TokenClassificationQuery,
+)
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
from rubrix.server.apis.v0.models.token_classification import (
- TokenClassificationBulkData,
+ TokenClassificationAggregations,
+ TokenClassificationBulkRequest,
TokenClassificationQuery,
TokenClassificationRecord,
TokenClassificationSearchRequest,
TokenClassificationSearchResults,
)
from rubrix.server.apis.v0.validators.token_classification import DatasetValidator
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.commons.models import TaskType
from rubrix.server.errors import EntityNotFoundError
+from rubrix.server.helpers import takeuntil
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
-from rubrix.server.services.token_classification import (
- TokenClassificationService,
- token_classification_service,
+from rubrix.server.services.tasks.token_classification.metrics import (
+    TokenClassificationMetrics,
+)
+from rubrix.server.services.tasks.token_classification.model import (
+    ServiceTokenClassificationRecord,
+)
+from rubrix.server.services.tasks.token_classification import TokenClassificationService
+from rubrix.server.services.tasks.token_classification.model import (
+ ServiceTokenClassificationQuery,
+ ServiceTokenClassificationRecord,
)
TASK_TYPE = TaskType.token_classification
BASE_ENDPOINT = "/{name}/" + TASK_TYPE
+TasksFactory.register_task(
+    task_type=TaskType.token_classification,
+    dataset_class=TokenClassificationDataset,
+    query_request=TokenClassificationQuery,
+    record_class=ServiceTokenClassificationRecord,
+    metrics=TokenClassificationMetrics,
+)
+
+
router = APIRouter(tags=[TASK_TYPE], prefix="/datasets")
@@ -60,38 +81,17 @@
)
async def bulk_records(
name: str,
- bulk: TokenClassificationBulkData,
- common_params: CommonTaskQueryParams = Depends(),
- service: TokenClassificationService = Depends(token_classification_service),
+ bulk: TokenClassificationBulkRequest,
+ common_params: CommonTaskHandlerDependencies = Depends(),
+ service: TokenClassificationService = Depends(
+ TokenClassificationService.get_instance
+ ),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
- """
- Includes a chunk of record data with provided dataset bulk information
-
- Parameters
- ----------
- name:
- The dataset name
- bulk:
- The bulk data
- common_params:
- Common query params
- service:
- the Service
- datasets:
- The dataset service
- current_user:
- Current request user
-
- Returns
- -------
- Bulk response data
- """
task = TASK_TYPE
- task_mappings = TaskFactory.get_task_mappings(TASK_TYPE)
owner = current_user.check_workspace(common_params.workspace)
try:
dataset = datasets.find_by_name(
@@ -99,7 +99,7 @@ async def bulk_records(
name=name,
task=task,
workspace=owner,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
datasets.update(
user=current_user,
@@ -108,24 +108,22 @@ async def bulk_records(
metadata=bulk.metadata,
)
except EntityNotFoundError:
- dataset_class = TaskFactory.get_task_dataset(task)
+ dataset_class = TasksFactory.get_task_dataset(task)
dataset = dataset_class.parse_obj({**bulk.dict(), "name": name})
dataset.owner = owner
- datasets.create_dataset(
- user=current_user, dataset=dataset, mappings=task_mappings
- )
+ datasets.create_dataset(user=current_user, dataset=dataset)
+ records = [ServiceTokenClassificationRecord.parse_obj(r) for r in bulk.records]
+ # TODO(@frascuchon): validator can be applied in service layer
await validator.validate_dataset_records(
user=current_user,
dataset=dataset,
- records=bulk.records,
+ records=records,
)
result = service.add_records(
dataset=dataset,
- mappings=task_mappings,
- records=bulk.records,
- metrics=TaskFactory.get_task_metrics(TASK_TYPE),
+ records=records,
)
return BulkResponse(
dataset=name,
@@ -143,42 +141,17 @@ async def bulk_records(
def search_records(
name: str,
search: TokenClassificationSearchRequest = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
include_metrics: bool = Query(
False, description="If enabled, return related record metrics"
),
- pagination: PaginationParams = Depends(),
- service: TokenClassificationService = Depends(token_classification_service),
+ pagination: RequestPagination = Depends(),
+ service: TokenClassificationService = Depends(
+ TokenClassificationService.get_instance
+ ),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> TokenClassificationSearchResults:
- """
- Searches data from dataset
-
- Parameters
- ----------
- name:
- The dataset name
- search:
- THe search query request
- common_params:
- Common query params
- include_metrics:
- include metrics flag
- pagination:
- The pagination params
- service:
- The dataset records service
- datasets:
- The dataset service
- current_user:
- The current request user
-
- Returns
- -------
- The search results data
-
- """
search = search or TokenClassificationSearchRequest()
query = search.query or TokenClassificationQuery()
@@ -188,34 +161,24 @@ def search_records(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
)
- result = service.search(
+ results = service.search(
dataset=dataset,
- query=query,
+ query=ServiceTokenClassificationQuery.parse_obj(query),
sort_by=search.sort,
record_from=pagination.from_,
size=pagination.limit,
exclude_metrics=not include_metrics,
- metrics=TaskFactory.find_task_metrics(
- TASK_TYPE,
- metric_ids={
- "words_cloud",
- "predicted_by",
- "predicted_as",
- "annotated_by",
- "annotated_as",
- "error_distribution",
- "predicted_mentions_distribution",
- "annotated_mentions_distribution",
- "status_distribution",
- "metadata",
- "score",
- },
- ),
)
- return result
+ return TokenClassificationSearchResults(
+ total=results.total,
+ records=[TokenClassificationRecord.parse_obj(r) for r in results.records],
+ aggregations=TokenClassificationAggregations.parse_obj(results.metrics)
+ if results.metrics
+ else None,
+ )
def scan_data_response(
@@ -258,15 +221,17 @@ def grouper(n, iterable, fillvalue=None):
async def stream_data(
name: str,
query: Optional[TokenClassificationQuery] = None,
- common_params: CommonTaskQueryParams = Depends(),
+ common_params: CommonTaskHandlerDependencies = Depends(),
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
- service: TokenClassificationService = Depends(token_classification_service),
+ service: TokenClassificationService = Depends(
+ TokenClassificationService.get_instance
+ ),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
id_from: Optional[str] = None,
) -> StreamingResponse:
"""
- Creates a data stream over dataset records
+ Creates a data stream over dataset records
Parameters
----------
@@ -294,9 +259,14 @@ async def stream_data(
name=name,
task=TASK_TYPE,
workspace=common_params.workspace,
- as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
+ as_dataset_class=TasksFactory.get_task_dataset(TASK_TYPE),
+ )
+ data_stream = map(
+ TokenClassificationRecord.parse_obj,
+ service.read_dataset(
+ dataset=dataset, query=ServiceTokenClassificationQuery.parse_obj(query), id_from=id_from, limit=limit
+ ),
)
- data_stream = service.read_dataset(dataset=dataset, query=query, id_from=id_from, limit=limit)
return scan_data_response(
data_stream=data_stream,
diff --git a/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py
index 955f1aa41a..14cfb7d572 100644
--- a/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py
+++ b/src/rubrix/server/apis/v0/handlers/token_classification_dataset_settings.py
@@ -2,18 +2,20 @@
from fastapi import APIRouter, Body, Depends, Security
-from rubrix.server.apis.v0.models.commons.model import TaskType
-from rubrix.server.apis.v0.models.commons.params import DATASET_NAME_PATH_PARAM
-from rubrix.server.apis.v0.models.commons.workspace import CommonTaskQueryParams
+from rubrix.server.apis.v0.models.commons.params import (
+ DATASET_NAME_PATH_PARAM,
+ CommonTaskHandlerDependencies,
+)
from rubrix.server.apis.v0.models.dataset_settings import TokenClassificationSettings
from rubrix.server.apis.v0.validators.token_classification import DatasetValidator
+from rubrix.server.commons.models import TaskType
from rubrix.server.security import auth
from rubrix.server.security.model import User
-from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings
+from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings
-__svc_settings_class__: Type[SVCDatasetSettings] = type(
+__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type(
f"{TaskType.token_classification}_DatasetSettings",
- (SVCDatasetSettings, TokenClassificationSettings),
+ (ServiceBaseDatasetSettings, TokenClassificationSettings),
{},
)
@@ -32,7 +34,7 @@ def configure_router(router: APIRouter):
)
async def get_dataset_settings(
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
) -> TokenClassificationSettings:
@@ -62,7 +64,7 @@ async def save_settings(
..., description=f"The {task} dataset settings"
),
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
@@ -92,7 +94,7 @@ async def save_settings(
)
async def delete_settings(
name: str = DATASET_NAME_PATH_PARAM,
- ws_params: CommonTaskQueryParams = Depends(),
+ ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]),
) -> None:
diff --git a/src/rubrix/server/apis/v0/models/commons/model.py b/src/rubrix/server/apis/v0/models/commons/model.py
index d3ce747ae8..5eb98a3a80 100644
--- a/src/rubrix/server/apis/v0/models/commons/model.py
+++ b/src/rubrix/server/apis/v0/models/commons/model.py
@@ -17,98 +17,59 @@
Common model for task definitions
"""
-from dataclasses import dataclass
-from typing import Any, Dict, Generic, TypeVar
+from typing import Any, Dict, Generic, List, TypeVar
-from fastapi import Query
-from pydantic import validator
+from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
-from rubrix._constants import MAX_KEYWORD_LENGTH
-from rubrix.server.apis.v0.helpers import flatten_dict
from rubrix.server.services.search.model import (
- BaseSearchResults,
- BaseSearchResultsAggregations,
- QueryRange,
- SortableField,
+ ServiceQueryRange,
+ ServiceSearchResultsAggregations,
+ ServiceSortableField,
)
from rubrix.server.services.tasks.commons import (
- Annotation,
- BaseAnnotation,
- BaseRecordDB,
- BulkResponse,
- EsRecordDataFieldNames,
- PredictionStatus,
- TaskStatus,
- TaskType,
+ ServiceBaseAnnotation,
+ ServiceBaseRecord,
+ ServiceBaseRecordInputs,
)
-from rubrix.utils import limit_value_length
-@dataclass
-class PaginationParams:
- """Query pagination params"""
+class SortableField(ServiceSortableField):
+ pass
- limit: int = Query(50, gte=0, le=1000, description="Response records limit")
- from_: int = Query(
- 0, ge=0, le=10000, alias="from", description="Record sequence from"
- )
+class BaseAnnotation(ServiceBaseAnnotation):
+ pass
-class BaseRecord(BaseRecordDB, GenericModel, Generic[Annotation]):
- """
- Minimal dataset record information
- Attributes:
- -----------
+Annotation = TypeVar("Annotation", bound=BaseAnnotation)
- id:
- The record id
- metadata:
- The metadata related to record
- event_timestamp:
- The timestamp when record event was triggered
- """
+class BaseRecordInputs(ServiceBaseRecordInputs[Annotation], Generic[Annotation]):
+ def extended_fields(self) -> Dict[str, Any]:
+ return {}
- @validator("metadata", pre=True)
- def flatten_metadata(cls, metadata: Dict[str, Any]):
- """
- A fastapi validator for flatten metadata dictionary
- Parameters
- ----------
- metadata:
- The metadata dictionary
+class BaseRecord(ServiceBaseRecord[Annotation], Generic[Annotation]):
+ pass
- Returns
- -------
- A flatten version of metadata dictionary
- """
- if metadata:
- metadata = flatten_dict(metadata, drop_empty=True)
- metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
- return metadata
+class ScoreRange(ServiceQueryRange):
+ pass
-Record = TypeVar("Record", bound=BaseRecord)
+_Record = TypeVar("_Record", bound=BaseRecord)
-class ScoreRange(QueryRange):
- pass
+class BulkResponse(BaseModel):
+ dataset: str
+ processed: int
+ failed: int = 0
-__ALL__ = [
- QueryRange,
- SortableField,
- BaseSearchResults,
- BaseSearchResultsAggregations,
- Annotation,
- TaskStatus,
- TaskType,
- EsRecordDataFieldNames,
- BaseAnnotation,
- PredictionStatus,
- BulkResponse,
-]
+class BaseSearchResults(
+ GenericModel, Generic[_Record, ServiceSearchResultsAggregations]
+):
+ total: int = 0
+ records: List[_Record] = Field(default_factory=list)
+ aggregations: ServiceSearchResultsAggregations = None
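The commons layer is now a thin alias layer over the service models; the task modules later in this diff (text2text, text classification, token classification) compose it in the same way: an inputs model parameterized with the task annotation, plus a record class mixing the inputs with `BaseRecord`. A condensed sketch of that shape, with hypothetical `My*` names and an illustrative `words` extended field:

```python
from typing import Any, Dict

from rubrix.server.apis.v0.models.commons.model import (
    BaseAnnotation,
    BaseRecord,
    BaseRecordInputs,
)


class MyAnnotation(BaseAnnotation):
    pass


class MyRecordInputs(BaseRecordInputs[MyAnnotation]):
    text: str

    def extended_fields(self) -> Dict[str, Any]:
        # whatever is returned here is persisted alongside the plain record fields
        return {"words": self.text}


class MyRecord(MyRecordInputs, BaseRecord[MyAnnotation]):
    pass
```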
diff --git a/src/rubrix/server/apis/v0/models/commons/params.py b/src/rubrix/server/apis/v0/models/commons/params.py
index b224d92fa5..7700e8f134 100644
--- a/src/rubrix/server/apis/v0/models/commons/params.py
+++ b/src/rubrix/server/apis/v0/models/commons/params.py
@@ -1,7 +1,46 @@
-from fastapi import Path
+from dataclasses import dataclass
-from rubrix._constants import DATASET_NAME_REGEX_PATTERN
+from fastapi import Header, Path, Query
+
+from rubrix._constants import DATASET_NAME_REGEX_PATTERN, RUBRIX_WORKSPACE_HEADER_NAME
+from rubrix.server.security.model import WORKSPACE_NAME_PATTERN
DATASET_NAME_PATH_PARAM = Path(
..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name"
)
+
+
+@dataclass
+class RequestPagination:
+ """Query pagination params"""
+
+    limit: int = Query(50, ge=0, le=1000, description="Response records limit")
+ from_: int = Query(
+ 0, ge=0, le=10000, alias="from", description="Record sequence from"
+ )
+
+
+@dataclass
+class CommonTaskHandlerDependencies:
+ """Common task query dependencies"""
+
+ # TODO(@frascuchon): we could include the request user and parametrize the action scopes
+ # Depends(CommonTaskHandlerDependencies.create(scopes=[...])
+
+ __workspace_header__: str = Header(None, alias=RUBRIX_WORKSPACE_HEADER_NAME)
+ __workspace_param__: str = Query(
+ None,
+ alias="workspace",
+ description="The workspace where dataset belongs to. If not provided default user team will be used",
+ )
+
+ @property
+ def workspace(self) -> str:
+ """Return read workspace. Query param prior to header param"""
+ workspace = self.__workspace_param__ or self.__workspace_header__
+ if workspace:
+ assert WORKSPACE_NAME_PATTERN.match(workspace), (
+ "Wrong workspace format. "
+ f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
+ )
+ return workspace
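`CommonTaskHandlerDependencies` replaces the deleted `CommonTaskQueryParams` and is injected into handlers through `Depends()`, as in the settings endpoints above. A minimal sketch of a route reading the resolved workspace (the `/example` path is illustrative):

```python
from fastapi import Depends, FastAPI

from rubrix.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies

app = FastAPI()


@app.get("/example")  # illustrative route, not part of this change
async def example(ws_params: CommonTaskHandlerDependencies = Depends()) -> dict:
    # The `workspace` property prefers the ?workspace= query parameter and falls
    # back to the header named by RUBRIX_WORKSPACE_HEADER_NAME; a name that does
    # not match WORKSPACE_NAME_PATTERN trips the assertion.
    return {"workspace": ws_params.workspace}
```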
diff --git a/src/rubrix/server/apis/v0/models/commons/workspace.py b/src/rubrix/server/apis/v0/models/commons/workspace.py
deleted file mode 100644
index d2abf609f6..0000000000
--- a/src/rubrix/server/apis/v0/models/commons/workspace.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from dataclasses import dataclass
-
-from fastapi import Header, Query
-
-from rubrix._constants import RUBRIX_WORKSPACE_HEADER_NAME
-from rubrix.server.security.model import WORKSPACE_NAME_PATTERN
-
-
-@dataclass
-class CommonTaskQueryParams:
- """Common task query params"""
-
- __workspace_param__: str = Query(
- None,
- alias="workspace",
- description="The workspace where dataset belongs to. If not provided default user team will be used",
- )
-
- __workspace_header__: str = Header(None, alias=RUBRIX_WORKSPACE_HEADER_NAME)
-
- @property
- def workspace(self) -> str:
- """Return read workspace. Query param prior to header param"""
- workspace = self.__workspace_param__ or self.__workspace_header__
- if workspace:
- assert WORKSPACE_NAME_PATTERN.match(workspace), (
- "Wrong workspace format. "
- f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
- )
- return workspace
diff --git a/src/rubrix/server/apis/v0/models/datasets.py b/src/rubrix/server/apis/v0/models/datasets.py
index f9f5586e4d..427210059f 100644
--- a/src/rubrix/server/apis/v0/models/datasets.py
+++ b/src/rubrix/server/apis/v0/models/datasets.py
@@ -17,14 +17,13 @@
Dataset models definition
"""
-from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from rubrix._constants import DATASET_NAME_REGEX_PATTERN
-from rubrix.server.apis.v0.models.commons.model import TaskType
-from rubrix.server.services.datasets import DatasetDB as SVCDataset
+from rubrix.server.commons.models import TaskType
+from rubrix.server.services.datasets import ServiceBaseDataset
class UpdateDatasetRequest(BaseModel):
@@ -43,15 +42,15 @@ class UpdateDatasetRequest(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
-class CreationDatasetRequest(UpdateDatasetRequest):
+class _BaseDatasetRequest(UpdateDatasetRequest):
name: str = Field(regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name")
-class DatasetCreate(CreationDatasetRequest):
+class CreateDatasetRequest(_BaseDatasetRequest):
task: TaskType = Field(description="The dataset task")
-class CopyDatasetRequest(CreationDatasetRequest):
+class CopyDatasetRequest(_BaseDatasetRequest):
"""
Request body for copy dataset operation
"""
@@ -59,7 +58,7 @@ class CopyDatasetRequest(CreationDatasetRequest):
target_workspace: Optional[str] = None
-class BaseDatasetDB(CreationDatasetRequest, SVCDataset):
+class Dataset(_BaseDatasetRequest, ServiceBaseDataset):
"""
Low level dataset data model
@@ -76,13 +75,3 @@ class BaseDatasetDB(CreationDatasetRequest, SVCDataset):
"""
task: TaskType
-
-
-class DatasetDB(BaseDatasetDB):
- pass
-
-
-class Dataset(BaseDatasetDB):
- """Dataset used for response output"""
-
- pass
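With the rename, dataset creation payloads validate through `CreateDatasetRequest` (formerly `DatasetCreate`) while `Dataset` collapses the old `BaseDatasetDB`/`DatasetDB` chain. A small usage sketch, assuming the example name satisfies `DATASET_NAME_REGEX_PATTERN`:

```python
from rubrix.server.apis.v0.models.datasets import CreateDatasetRequest
from rubrix.server.commons.models import TaskType

# "my-dataset" is an illustrative name assumed to match DATASET_NAME_REGEX_PATTERN
request = CreateDatasetRequest(name="my-dataset", task=TaskType.text_classification)
print(request.task)  # TaskType.text_classification
```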
diff --git a/src/rubrix/server/apis/v0/models/info.py b/src/rubrix/server/apis/v0/models/info.py
deleted file mode 100644
index fb3a7485f9..0000000000
--- a/src/rubrix/server/apis/v0/models/info.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# coding=utf-8
-# Copyright 2021-present, the Recognai S.L. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict
-
-from pydantic import BaseModel
-
-
-class ApiInfo(BaseModel):
- """Basic api info"""
-
- rubrix_version: str
-
-
-class ApiStatus(ApiInfo):
- """The Rubrix api status model"""
-
- elasticsearch: Dict[str, Any]
- mem_info: Dict[str, Any]
diff --git a/src/rubrix/server/apis/v0/models/metrics/base.py b/src/rubrix/server/apis/v0/models/metrics/base.py
deleted file mode 100644
index 8bae2ab238..0000000000
--- a/src/rubrix/server/apis/v0/models/metrics/base.py
+++ /dev/null
@@ -1,304 +0,0 @@
-from typing import (
- Any,
- ClassVar,
- Dict,
- Generic,
- Iterable,
- List,
- Optional,
- TypeVar,
- Union,
-)
-
-from pydantic import BaseModel, root_validator
-
-from rubrix.server._helpers import unflatten_dict
-from rubrix.server.apis.v0.models.commons.model import BaseRecord
-from rubrix.server.apis.v0.models.datasets import Dataset
-from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.query_helpers import aggregations
-
-GenericRecord = TypeVar("GenericRecord", bound=BaseRecord)
-
-
-class BaseMetric(BaseModel):
- """
- Base model for rubrix dataset metrics summaries
- """
-
- id: str
- name: str
- description: str = None
-
-
-class PythonMetric(BaseMetric, Generic[GenericRecord]):
- """
- A metric definition which will be calculated using raw queried data
- """
-
- def apply(self, records: Iterable[GenericRecord]) -> Dict[str, Any]:
- """
- Metric calculation method.
-
- Parameters
- ----------
- records:
- The matched records
-
- Returns
- -------
- The metric result
- """
- raise NotImplementedError()
-
-
-class ElasticsearchMetric(BaseMetric):
- """
- A metric summarized by using one or several elasticsearch aggregations
- """
-
- def aggregation_request(
- self, *args, **kwargs
- ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
- """
- Configures the summary es aggregation definition
- """
- raise NotImplementedError()
-
- def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
- """
- Parse the es aggregation result. Override this method
- for result customization
-
- Parameters
- ----------
- aggregation_result:
- Retrieved es aggregation result
-
- """
- return aggregation_result.get(self.id, aggregation_result)
-
-
-class NestedPathElasticsearchMetric(ElasticsearchMetric):
- """
- A ``ElasticsearchMetric`` which need nested fields for summary calculation.
-
- Aggregations for nested fields need some extra configuration and this class
- encapsulate these common logic.
-
- Attributes:
- -----------
- nested_path:
- The nested
- """
-
- nested_path: str
-
- def inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]:
- """The specific aggregation definition"""
- raise NotImplementedError()
-
- def aggregation_request(self, *args, **kwargs) -> Dict[str, Any]:
- """Implements the common mechanism to define aggregations with nested fields"""
- return {
- self.id: aggregations.nested_aggregation(
- nested_path=self.nested_path,
- inner_aggregation=self.inner_aggregation(*args, **kwargs),
- )
- }
-
- def compound_nested_field(self, inner_field: str) -> str:
- return f"{self.nested_path}.{inner_field}"
-
-
-class BaseTaskMetrics(BaseModel):
- """
- Base class encapsulating related task metrics
-
- Attributes:
- -----------
-
- metrics:
- A list of configured metrics for task
- """
-
- metrics: ClassVar[List[BaseMetric]]
-
- @classmethod
- def configure_es_index(cls):
- """
- If some metrics require specific es field mapping definitions,
- include them here.
-
- """
- pass
-
- @classmethod
- def find_metric(cls, id: str) -> Optional[BaseMetric]:
- """
- Finds a metric by id
-
- Parameters
- ----------
- id:
- The metric id
-
- Returns
- -------
- Found metric if any, ``None`` otherwise
-
- """
- for metric in cls.metrics:
- if metric.id == id:
- return metric
-
- @classmethod
- def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
- """
- Use this method is some configured metric requires additional
- records fields.
-
- Generated records will be persisted under ``metrics`` record path.
- For example, if you define a field called ``sentence_length`` like
-
- >>> def record_metrics(cls, record)-> Dict[str, Any]:
- ... return { "sentence_length" : len(record.text) }
-
- The new field will be stored in elasticsearch in ``metrics.sentence_length``
-
- Parameters
- ----------
- record:
- The record used for calculate metrics fields
-
- Returns
- -------
- A dict with calculated metrics fields
- """
- return {}
-
-
-class HistogramAggregation(ElasticsearchMetric):
- """
- Base elasticsearch histogram aggregation metric
-
- Attributes
- ----------
- field:
- The histogram field
- script:
- If provided, it will be used as scripted field
- for aggregation
- fixed_interval:
- If provided, it will used ALWAYS as the histogram
- aggregation interval
- """
-
- field: str
- script: Optional[Union[str, Dict[str, Any]]] = None
- fixed_interval: Optional[float] = None
-
- def aggregation_request(self, interval: Optional[float] = None) -> Dict[str, Any]:
- if self.fixed_interval:
- interval = self.fixed_interval
- return {
- self.id: aggregations.histogram_aggregation(
- field_name=self.field, script=self.script, interval=interval
- )
- }
-
-
-class TermsAggregation(ElasticsearchMetric):
- """
- The base elasticsearch terms aggregation metric
-
- Attributes
- ----------
-
- field:
- The term field
- script:
- If provided, it will be used as scripted field
- for aggregation
- fixed_size:
- If provided, the size will use for terms aggregation
- missing:
- If provided, will use the value for docs results with missing value for field
-
- """
-
- field: str = None
- script: Union[str, Dict[str, Any]] = None
- fixed_size: Optional[int] = None
- missing: Optional[str] = None
-
- def aggregation_request(self, size: int = None) -> Dict[str, Any]:
- if self.fixed_size:
- size = self.fixed_size
- return {
- self.id: aggregations.terms_aggregation(
- self.field, script=self.script, size=size, missing=self.missing
- )
- }
-
-
-class NestedTermsAggregation(NestedPathElasticsearchMetric):
- terms: TermsAggregation
-
- @root_validator
- def normalize_terms_field(cls, values):
- terms = values["terms"]
- nested_path = values["nested_path"]
- terms.field = f"{nested_path}.{terms.field}"
-
- return values
-
- def inner_aggregation(self, size: int) -> Dict[str, Any]:
- return self.terms.aggregation_request(size)
-
-
-class NestedHistogramAggregation(NestedPathElasticsearchMetric):
- histogram: HistogramAggregation
-
- @root_validator
- def normalize_terms_field(cls, values):
- histogram = values["histogram"]
- nested_path = values["nested_path"]
- histogram.field = f"{nested_path}.{histogram.field}"
-
- return values
-
- def inner_aggregation(self, interval: float) -> Dict[str, Any]:
- return self.histogram.aggregation_request(interval)
-
-
-class WordCloudAggregation(ElasticsearchMetric):
- default_field: str
-
- def aggregation_request(
- self, text_field: str = None, size: int = None
- ) -> Dict[str, Any]:
- field = text_field or self.default_field
- return TermsAggregation(
- id=f"{self.id}_{field}" if text_field else self.id,
- name=f"Words cloud for field {field}",
- field=field,
- ).aggregation_request(size=size)
-
-
-class MetadataAggregations(ElasticsearchMetric):
- def aggregation_request(
- self,
- dataset: Dataset,
- dao: DatasetRecordsDAO,
- size: int = None,
- ) -> List[Dict[str, Any]]:
-
- metadata_aggs = aggregations.custom_fields(
- fields_definitions=dao.get_metadata_schema(dataset), size=size
- )
- return [{key: value} for key, value in metadata_aggs.items()]
-
- def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
- data = unflatten_dict(aggregation_result, stop_keys=["metadata"])
- return data.get("metadata", {})
diff --git a/src/rubrix/server/apis/v0/models/metrics/commons.py b/src/rubrix/server/apis/v0/models/metrics/commons.py
deleted file mode 100644
index 214bf5e15c..0000000000
--- a/src/rubrix/server/apis/v0/models/metrics/commons.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from typing import Any, ClassVar, Dict, Generic, List
-
-from rubrix.server.apis.v0.models.commons.model import (
- EsRecordDataFieldNames,
- TaskStatus,
-)
-from rubrix.server.apis.v0.models.metrics.base import (
- BaseMetric,
- BaseTaskMetrics,
- GenericRecord,
- HistogramAggregation,
- MetadataAggregations,
- TermsAggregation,
- WordCloudAggregation,
-)
-
-
-class CommonTasksMetrics(BaseTaskMetrics, Generic[GenericRecord]):
- """Common task metrics"""
-
- @classmethod
- def record_metrics(cls, record: GenericRecord) -> Dict[str, Any]:
- """Record metrics will persist the text_length"""
- return {"text_length": len(record.all_text())}
-
- metrics: ClassVar[List[BaseMetric]] = [
- HistogramAggregation(
- id="text_length",
- name="Text length distribution",
- description="Computes the input text length distribution",
- field="metrics.text_length",
- script="params._source.text.length()",
- fixed_interval=1,
- ),
- TermsAggregation(
- id="error_distribution",
- name="Error distribution",
- description="Computes the dataset error distribution. It's mean, records "
- "with correct predictions vs records with incorrect prediction "
- "vs records with unknown prediction result",
- field=EsRecordDataFieldNames.predicted,
- missing="unknown",
- fixed_size=3,
- ),
- TermsAggregation(
- id="status_distribution",
- name="Record status distribution",
- description="The dataset record status distribution",
- field=EsRecordDataFieldNames.status,
- fixed_size=len(TaskStatus),
- ),
- WordCloudAggregation(
- id="words_cloud",
- name="Inputs words cloud",
- description="The words cloud for dataset inputs",
- default_field="text.wordcloud",
- ),
- MetadataAggregations(id="metadata", name="Metadata fields stats"),
- TermsAggregation(
- id="predicted_by",
- name="Predicted by distribution",
- field="predicted_by",
- ),
- TermsAggregation(
- id="annotated_by",
- name="Annotated by distribution",
- field="annotated_by",
- ),
- HistogramAggregation(
- id="score",
- name="Score record distribution",
- field="score",
- fixed_interval=0.001,
- ),
- ]
diff --git a/src/rubrix/server/apis/v0/models/metrics/token_classification.py b/src/rubrix/server/apis/v0/models/metrics/token_classification.py
deleted file mode 100644
index 80c9abdee6..0000000000
--- a/src/rubrix/server/apis/v0/models/metrics/token_classification.py
+++ /dev/null
@@ -1,643 +0,0 @@
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set, Tuple
-
-from pydantic import BaseModel, Field
-
-from rubrix.server.apis.v0.models.metrics.base import (
- BaseMetric,
- ElasticsearchMetric,
- HistogramAggregation,
- NestedHistogramAggregation,
- NestedPathElasticsearchMetric,
- NestedTermsAggregation,
- PythonMetric,
- TermsAggregation,
-)
-from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics
-from rubrix.server.apis.v0.models.token_classification import (
- EntitySpan,
- TokenClassificationRecord,
-)
-from rubrix.server.elasticseach.query_helpers import aggregations
-
-
-class TokensLength(ElasticsearchMetric):
- """
- Summarizes the tokens length metric into an histogram
-
- Attributes:
- -----------
- length_field:
- The elasticsearch field where tokens length is stored
- """
-
- length_field: str
-
- def aggregation_request(self, interval: int) -> Dict[str, Any]:
- return {
- self.id: aggregations.histogram_aggregation(
- self.length_field, interval=interval or 1
- )
- }
-
-
-_DEFAULT_MAX_ENTITY_BUCKET = 1000
-
-
-class EntityLabels(NestedPathElasticsearchMetric):
- """
- Computes the entity labels distribution
-
- Attributes:
- -----------
- labels_field:
- The elasticsearch field where tags are stored
- """
-
- labels_field: str
-
- def inner_aggregation(self, size: int) -> Dict[str, Any]:
- return {
- "labels": aggregations.terms_aggregation(
- self.compound_nested_field(self.labels_field),
- size=size or _DEFAULT_MAX_ENTITY_BUCKET,
- )
- }
-
-
-class EntityDensity(NestedPathElasticsearchMetric):
- """Summarizes the entity density metric into an histogram"""
-
- density_field: str
-
- def inner_aggregation(self, interval: float) -> Dict[str, Any]:
- return {
- "density": aggregations.histogram_aggregation(
- field_name=self.compound_nested_field(self.density_field),
- interval=interval or 0.01,
- )
- }
-
-
-class MentionLength(NestedPathElasticsearchMetric):
- """Summarizes the mention length into an histogram"""
-
- length_field: str
-
- def inner_aggregation(self, interval: int) -> Dict[str, Any]:
- return {
- "mention_length": aggregations.histogram_aggregation(
- self.compound_nested_field(self.length_field), interval=interval or 1
- )
- }
-
-
-class EntityCapitalness(NestedPathElasticsearchMetric):
- """Computes the mention capitalness distribution"""
-
- capitalness_field: str
-
- def inner_aggregation(self) -> Dict[str, Any]:
- return {
- "capitalness": aggregations.terms_aggregation(
- self.compound_nested_field(self.capitalness_field),
- size=4, # The number of capitalness choices
- )
- }
-
-
-class MentionsByEntityDistribution(NestedPathElasticsearchMetric):
- def inner_aggregation(self):
- return {
- self.id: aggregations.bidimentional_terms_aggregations(
- field_name_x=f"{self.nested_path}.label",
- field_name_y=f"{self.nested_path}.value",
- )
- }
-
-
-class EntityConsistency(NestedPathElasticsearchMetric):
- """Computes the entity consistency distribution"""
-
- mention_field: str
- labels_field: str
-
- def inner_aggregation(
- self,
- size: int,
- interval: int = 2,
- entity_size: int = _DEFAULT_MAX_ENTITY_BUCKET,
- ) -> Dict[str, Any]:
- size = size or 50
- interval = int(max(interval or 2, 2))
- return {
- "consistency": {
- **aggregations.terms_aggregation(
- self.compound_nested_field(self.mention_field), size=size
- ),
- "aggs": {
- "entities": aggregations.terms_aggregation(
- self.compound_nested_field(self.labels_field), size=entity_size
- ),
- "count": {
- "cardinality": {
- "field": self.compound_nested_field(self.labels_field)
- }
- },
- "entities_variability_filter": {
- "bucket_selector": {
- "buckets_path": {"numLabels": "count"},
- "script": f"params.numLabels >= {interval}",
- }
- },
- "sortby_entities_count": {
- "bucket_sort": {
- "sort": [{"count": {"order": "desc"}}],
- "size": size,
- }
- },
- },
- }
- }
-
- def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
- """Simplifies the aggregation result sorting by worst mention consistency"""
- result = [
- {
- "mention": mention,
- "entities": [
- {"label": entity, "count": count}
- for entity, count in mention_aggs["entities"].items()
- ],
- }
- for mention, mention_aggs in aggregation_result.items()
- ]
- # TODO: filter by entities threshold
- result.sort(key=lambda m: len(m["entities"]), reverse=True)
- return {"mentions": result}
-
-
-class F1Metric(PythonMetric[TokenClassificationRecord]):
- """The F1 metric based on entity-level.
-
- We follow the convention of `CoNLL 2003 `_, where:
- `"precision is the percentage of named entities found by the learning system that are correct.
- Recall is the percentage of named entities present in the corpus that are found by the system.
- A named entity is correct only if it is an exact match (...).”`
- """
-
- def apply(self, records: Iterable[TokenClassificationRecord]) -> Dict[str, Any]:
- # store entities per label in dicts
- predicted_entities = {}
- annotated_entities = {}
-
- # extract entities per label to dicts
- for rec in records:
- if rec.prediction:
- self._add_entities_to_dict(rec.prediction.entities, predicted_entities)
- if rec.annotation:
- self._add_entities_to_dict(rec.annotation.entities, annotated_entities)
-
- # store precision, recall, and f1 per label
- per_label_metrics = {}
-
- annotated_total, predicted_total, correct_total = 0, 0, 0
- precision_macro, recall_macro = 0, 0
- for label, annotated in annotated_entities.items():
- predicted = predicted_entities.get(label, set())
- correct = len(annotated & predicted)
-
- # safe divides are used to cover the 0/0 cases
- precision = self._safe_divide(correct, len(predicted))
- recall = self._safe_divide(correct, len(annotated))
- per_label_metrics.update(
- {
- f"{label}_precision": precision,
- f"{label}_recall": recall,
- f"{label}_f1": self._safe_divide(
- 2 * precision * recall, precision + recall
- ),
- }
- )
-
- annotated_total += len(annotated)
- predicted_total += len(predicted)
- correct_total += correct
-
- precision_macro += precision / len(annotated_entities)
- recall_macro += recall / len(annotated_entities)
-
- # store macro and micro averaged precision, recall and f1
- averaged_metrics = {
- "precision_macro": precision_macro,
- "recall_macro": recall_macro,
- "f1_macro": self._safe_divide(
- 2 * precision_macro * recall_macro, precision_macro + recall_macro
- ),
- }
-
- precision_micro = self._safe_divide(correct_total, predicted_total)
- recall_micro = self._safe_divide(correct_total, annotated_total)
- averaged_metrics.update(
- {
- "precision_micro": precision_micro,
- "recall_micro": recall_micro,
- "f1_micro": self._safe_divide(
- 2 * precision_micro * recall_micro, precision_micro + recall_micro
- ),
- }
- )
-
- return {**averaged_metrics, **per_label_metrics}
-
- @staticmethod
- def _add_entities_to_dict(
- entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]]
- ):
- """Helper function for the apply method."""
- for ent in entities:
- try:
- dictionary[ent.label].add((ent.start, ent.end))
- except KeyError:
- dictionary[ent.label] = {(ent.start, ent.end)}
-
- @staticmethod
- def _safe_divide(numerator, denominator):
- """Helper function for the apply method."""
- try:
- return numerator / denominator
- except ZeroDivisionError:
- return 0
-
-
-class DatasetLabels(PythonMetric):
- id: str = Field("dataset_labels", const=True)
- name: str = Field("The dataset entity labels", const=True)
- max_processed_records: int = 10000
-
- def apply(self, records: Iterable[TokenClassificationRecord]) -> Dict[str, Any]:
- ds_labels = set()
-
- for _ in range(
- 0, self.max_processed_records
- ): # Only a few of records will be parsed
- record: TokenClassificationRecord = next(records, None)
- if record is None:
- break
-
- if record.annotation:
- ds_labels.update(
- [entity.label for entity in record.annotation.entities]
- )
- if record.prediction:
- ds_labels.update(
- [entity.label for entity in record.prediction.entities]
- )
- return {"labels": ds_labels or []}
-
-
-class MentionMetrics(BaseModel):
- """Mention metrics model"""
-
- value: str
- label: str
- score: float = Field(ge=0.0)
- capitalness: Optional[str] = Field(None)
- density: float = Field(ge=0.0)
- tokens_length: int = Field(g=0)
- chars_length: int = Field(g=0)
-
-
-class TokenTagMetrics(BaseModel):
- value: str
- tag: str
-
-
-class TokenMetrics(BaseModel):
- """
- Token metrics stored in elasticsearch for token classification
-
- Attributes
- idx: The token index in sentence
- value: The token textual value
- char_start: The token character start position in sentence
- char_end: The token character end position in sentence
- score: Token score info
- tag: Token tag info. Deprecated: Use metrics.predicted.tags or metrics.annotated.tags instead
- custom: extra token level info
- """
-
- idx: int
- value: str
- char_start: int
- char_end: int
- length: int
- capitalness: Optional[str] = None
- score: Optional[float] = None
- tag: Optional[str] = None # TODO: remove!
- custom: Dict[str, Any] = None
-
-
-class TokenClassificationMetrics(CommonTasksMetrics[TokenClassificationRecord]):
- """Configured metrics for token classification"""
-
- _PREDICTED_NAMESPACE = "metrics.predicted"
- _ANNOTATED_NAMESPACE = "metrics.annotated"
-
- _PREDICTED_MENTIONS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.mentions"
- _ANNOTATED_MENTIONS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.mentions"
-
- _PREDICTED_TAGS_NAMESPACE = f"{_PREDICTED_NAMESPACE}.tags"
- _ANNOTATED_TAGS_NAMESPACE = f"{_ANNOTATED_NAMESPACE}.tags"
-
- _TOKENS_NAMESPACE = "metrics.tokens"
-
- @staticmethod
- def density(value: int, sentence_length: int) -> float:
- """Compute the string density over a sentence"""
- return value / sentence_length
-
- @staticmethod
- def capitalness(value: str) -> Optional[str]:
- """Compute capitalness for a string value"""
- value = value.strip()
- if not value:
- return None
- if value.isupper():
- return "UPPER"
- if value.islower():
- return "LOWER"
- if value[0].isupper():
- return "FIRST"
- if any([c.isupper() for c in value[1:]]):
- return "MIDDLE"
- return None
-
- @staticmethod
- def mentions_metrics(
- record: TokenClassificationRecord, mentions: List[Tuple[str, EntitySpan]]
- ):
- def mention_tokens_length(entity: EntitySpan) -> int:
- """Compute mention tokens length"""
- return len(
- set(
- [
- token_idx
- for i in range(entity.start, entity.end)
- for token_idx in [record.char_id2token_id(i)]
- if token_idx is not None
- ]
- )
- )
-
- return [
- MentionMetrics(
- value=mention,
- label=entity.label,
- score=entity.score,
- capitalness=TokenClassificationMetrics.capitalness(mention),
- density=TokenClassificationMetrics.density(
- _tokens_length, sentence_length=len(record.tokens)
- ),
- tokens_length=_tokens_length,
- chars_length=len(mention),
- )
- for mention, entity in mentions
- for _tokens_length in [
- mention_tokens_length(entity),
- ]
- ]
-
- @classmethod
- def build_tokens_metrics(
- cls, record: TokenClassificationRecord, tags: Optional[List[str]] = None
- ) -> List[TokenMetrics]:
-
- return [
- TokenMetrics(
- idx=token_idx,
- value=token_value,
- char_start=char_start,
- char_end=char_end,
- capitalness=cls.capitalness(token_value),
- length=1 + (char_end - char_start),
- tag=tags[token_idx] if tags else None,
- )
- for token_idx, token_value in enumerate(record.tokens)
- for char_start, char_end in [record.token_span(token_idx)]
- ]
-
- @classmethod
- def record_metrics(cls, record: TokenClassificationRecord) -> Dict[str, Any]:
- """Compute metrics at record level"""
- base_metrics = super(TokenClassificationMetrics, cls).record_metrics(record)
-
- annotated_tags = record.annotated_iob_tags() or []
- predicted_tags = record.predicted_iob_tags() or []
-
- tokens_metrics = cls.build_tokens_metrics(
- record, predicted_tags or annotated_tags
- )
- return {
- **base_metrics,
- "tokens": tokens_metrics,
- "tokens_length": len(record.tokens),
- "predicted": {
- "mentions": cls.mentions_metrics(record, record.predicted_mentions()),
- "tags": [
- TokenTagMetrics(tag=tag, value=token)
- for tag, token in zip(predicted_tags, record.tokens)
- ],
- },
- "annotated": {
- "mentions": cls.mentions_metrics(record, record.annotated_mentions()),
- "tags": [
- TokenTagMetrics(tag=tag, value=token)
- for tag, token in zip(annotated_tags, record.tokens)
- ],
- },
- }
-
- _TOKENS_METRICS = [
- TokensLength(
- id="tokens_length",
- name="Tokens length",
- description="Computes the text length distribution measured in number of tokens",
- length_field="metrics.tokens_length",
- ),
- NestedTermsAggregation(
- id="token_frequency",
- name="Tokens frequency distribution",
- nested_path=_TOKENS_NAMESPACE,
- terms=TermsAggregation(
- id="frequency",
- field="value",
- name="",
- ),
- ),
- NestedHistogramAggregation(
- id="token_length",
- name="Token length distribution",
- nested_path=_TOKENS_NAMESPACE,
- description="Computes token length distribution in number of characters",
- histogram=HistogramAggregation(
- id="length",
- field="length",
- name="",
- fixed_interval=1,
- ),
- ),
- NestedTermsAggregation(
- id="token_capitalness",
- name="Token capitalness distribution",
- description="Computes capitalization information of tokens",
- nested_path=_TOKENS_NAMESPACE,
- terms=TermsAggregation(
- id="capitalness",
- field="capitalness",
- name="",
- # missing="OTHER",
- ),
- ),
- ]
- _PREDICTED_METRICS = [
- EntityDensity(
- id="predicted_entity_density",
- name="Mention entity density for predictions",
- description="Computes the ratio between the number of all entity tokens and tokens in the text",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- density_field="density",
- ),
- EntityLabels(
- id="predicted_entity_labels",
- name="Predicted entity labels",
- description="Predicted entity labels distribution",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- labels_field="label",
- ),
- EntityCapitalness(
- id="predicted_entity_capitalness",
- name="Mention entity capitalness for predictions",
- description="Computes capitalization information of predicted entity mentions",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- capitalness_field="capitalness",
- ),
- MentionLength(
- id="predicted_mention_token_length",
- name="Predicted mention tokens length",
- description="Computes the length of the predicted entity mention measured in number of tokens",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- length_field="tokens_length",
- ),
- MentionLength(
- id="predicted_mention_char_length",
- name="Predicted mention characters length",
- description="Computes the length of the predicted entity mention measured in number of tokens",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- length_field="chars_length",
- ),
- MentionsByEntityDistribution(
- id="predicted_mentions_distribution",
- name="Predicted mentions distribution by entity",
- description="Computes predicted mentions distribution against its labels",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- ),
- EntityConsistency(
- id="predicted_entity_consistency",
- name="Entity label consistency for predictions",
- description="Computes entity label variability for top-k predicted entity mentions",
- nested_path=_PREDICTED_MENTIONS_NAMESPACE,
- mention_field="value",
- labels_field="label",
- ),
- EntityConsistency(
- id="predicted_tag_consistency",
- name="Token tag consistency for predictions",
- description="Computes token tag variability for top-k predicted tags",
- nested_path=_PREDICTED_TAGS_NAMESPACE,
- mention_field="value",
- labels_field="tag",
- ),
- ]
-
- _ANNOTATED_METRICS = [
- EntityDensity(
- id="annotated_entity_density",
- name="Mention entity density for annotations",
- description="Computes the ratio between the number of all entity tokens and tokens in the text",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- density_field="density",
- ),
- EntityLabels(
- id="annotated_entity_labels",
- name="Annotated entity labels",
- description="Annotated Entity labels distribution",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- labels_field="label",
- ),
- EntityCapitalness(
- id="annotated_entity_capitalness",
- name="Mention entity capitalness for annotations",
- description="Compute capitalization information of annotated entity mentions",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- capitalness_field="capitalness",
- ),
- MentionLength(
- id="annotated_mention_token_length",
- name="Annotated mention tokens length",
- description="Computes the length of the entity mention measured in number of tokens",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- length_field="tokens_length",
- ),
- MentionLength(
- id="annotated_mention_char_length",
- name="Annotated mention characters length",
- description="Computes the length of the entity mention measured in number of tokens",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- length_field="chars_length",
- ),
- MentionsByEntityDistribution(
- id="annotated_mentions_distribution",
- name="Annotated mentions distribution by entity",
- description="Computes annotated mentions distribution against its labels",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- ),
- EntityConsistency(
- id="annotated_entity_consistency",
- name="Entity label consistency for annotations",
- description="Computes entity label variability for top-k annotated entity mentions",
- nested_path=_ANNOTATED_MENTIONS_NAMESPACE,
- mention_field="value",
- labels_field="label",
- ),
- EntityConsistency(
- id="annotated_tag_consistency",
- name="Token tag consistency for annotations",
- description="Computes token tag variability for top-k annotated tags",
- nested_path=_ANNOTATED_TAGS_NAMESPACE,
- mention_field="value",
- labels_field="tag",
- ),
- ]
-
- metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [
- TermsAggregation(
- id="predicted_as",
- name="Predicted labels distribution",
- field="predicted_as",
- ),
- TermsAggregation(
- id="annotated_as",
- name="Annotated labels distribution",
- field="annotated_as",
- ),
- *_TOKENS_METRICS,
- *_PREDICTED_METRICS,
- *_ANNOTATED_METRICS,
- DatasetLabels(),
- F1Metric(
- id="F1",
- name="F1 Metric based on entity-level",
- description="F1 metrics based on entity-level (averaged and per label), "
- "where only exact matches count (CoNNL 2003).",
- ),
- ]
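The removed `F1Metric` counts exact `(start, end)` span matches per label and derives per-label plus micro/macro-averaged scores with a zero-safe divide. A toy, self-contained walk-through of that counting (the spans and labels are made up):

```python
# Gold and predicted entity spans per label, keyed the same way the metric's
# helper dictionaries are built.
annotated = {"PER": {(0, 5), (10, 14)}, "ORG": {(20, 26)}}
predicted = {"PER": {(0, 5)}, "ORG": {(20, 26), (30, 35)}}


def safe_divide(num: float, den: float) -> float:
    """0/0-safe division, mirroring the metric's behaviour."""
    return num / den if den else 0.0


correct_total = predicted_total = annotated_total = 0
for label, gold in annotated.items():
    pred = predicted.get(label, set())
    correct = len(gold & pred)
    precision = safe_divide(correct, len(pred))
    recall = safe_divide(correct, len(gold))
    f1 = safe_divide(2 * precision * recall, precision + recall)
    print(label, precision, recall, round(f1, 2))  # PER 1.0 0.5 0.67 / ORG 0.5 1.0 0.67
    correct_total += correct
    predicted_total += len(pred)
    annotated_total += len(gold)

# Micro averages pool the raw counts across labels.
precision_micro = safe_divide(correct_total, predicted_total)  # 2/3
recall_micro = safe_divide(correct_total, annotated_total)  # 2/3
f1_micro = safe_divide(2 * precision_micro * recall_micro, precision_micro + recall_micro)
```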
diff --git a/src/rubrix/server/apis/v0/models/text2text.py b/src/rubrix/server/apis/v0/models/text2text.py
index 1761e9d1ef..5497c2c4eb 100644
--- a/src/rubrix/server/apis/v0/models/text2text.py
+++ b/src/rubrix/server/apis/v0/models/text2text.py
@@ -14,51 +14,34 @@
# limitations under the License.
from datetime import datetime
-from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
from pydantic import BaseModel, Field, validator
from rubrix.server.apis.v0.models.commons.model import (
BaseAnnotation,
BaseRecord,
+ BaseRecordInputs,
BaseSearchResults,
- BaseSearchResultsAggregations,
- EsRecordDataFieldNames,
- PredictionStatus,
ScoreRange,
SortableField,
- TaskType,
)
-from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest
-from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics
-from rubrix.server.services.search.model import BaseSearchQuery
-
-
-class ExtendedEsRecordDataFieldNames(str, Enum):
- text_predicted = "text_predicted"
- text_annotated = "text_annotated"
+from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest
+from rubrix.server.commons.models import PredictionStatus, TaskType
+from rubrix.server.services.metrics.models import CommonTasksMetrics
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceBaseSearchResultsAggregations,
+)
+from rubrix.server.services.tasks.text2text.models import ServiceText2TextDataset
class Text2TextPrediction(BaseModel):
- """Represents a text prediction/annotation and its score"""
-
text: str
score: float = Field(default=1.0, ge=0.0, le=1.0)
class Text2TextAnnotation(BaseAnnotation):
- """
- Annotation class for text2text tasks
-
- Attributes:
- -----------
-
- sentences: str
- List of sentence predictions/annotations
-
- """
-
@validator("sentences")
def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]):
"""Sort provided sentences by score desc"""
@@ -67,168 +50,25 @@ def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]):
sentences: List[Text2TextPrediction]
-class CreationText2TextRecord(BaseRecord[Text2TextAnnotation]):
- """
- Text2Text record
-
- Attributes:
- -----------
-
- text:
- The input data text
- """
+class Text2TextRecordInputs(BaseRecordInputs[Text2TextAnnotation]):
text: str
- @classmethod
- def task(cls) -> TaskType:
- """The task type"""
- return TaskType.text2text
-
- def all_text(self) -> str:
- return self.text
-
- @property
- def predicted_as(self) -> Optional[List[str]]:
- return (
- [sentence.text for sentence in self.prediction.sentences]
- if self.prediction
- else None
- )
-
- @property
- def annotated_as(self) -> Optional[List[str]]:
- return (
- [sentence.text for sentence in self.annotation.sentences]
- if self.annotation
- else None
- )
-
- @property
- def scores(self) -> List[float]:
- """Values of prediction scores"""
- if not self.prediction:
- return []
- return [sentence.score for sentence in self.prediction.sentences]
-
- @validator("text")
- def validate_text(cls, text: Dict[str, Any]):
- """Applies validation over input text"""
- assert len(text) > 0, "No text provided"
- return text
-
-
-class Text2TextRecord(CreationText2TextRecord):
- """
- The output text2text task record
-
- Attributes:
- -----------
-
- last_updated: datetime
- Last record update (read only)
- predicted: Optional[PredictionStatus]
- The record prediction status. Optional
- """
-
- last_updated: datetime = None
- _predicted: Optional[PredictionStatus] = Field(alias="predicted")
-
- def extended_fields(self):
- return {}
-
-
-class Text2TextRecordDB(Text2TextRecord):
- """The db text2text task record"""
-
- def extended_fields(self) -> Dict[str, Any]:
- return {
- EsRecordDataFieldNames.annotated_as: self.annotated_as,
- EsRecordDataFieldNames.predicted_as: self.predicted_as,
- EsRecordDataFieldNames.annotated_by: self.annotated_by,
- EsRecordDataFieldNames.predicted_by: self.predicted_by,
- EsRecordDataFieldNames.score: self.scores,
- EsRecordDataFieldNames.words: self.all_text(),
- }
-
-
-class Text2TextBulkData(UpdateDatasetRequest):
- """
- API bulk data for text2text
-
- Attributes:
- -----------
- records: List[CreationText2TextRecord]
- The text2text record list
-
- """
-
- records: List[CreationText2TextRecord]
-
-
-class Text2TextQuery(BaseSearchQuery):
- """
- API Filters for text2text
-
- Attributes:
- -----------
- ids: Optional[List[Union[str, int]]]
- Record ids list
-
- query_text: str
- Text query over input text
-
- annotated_by: List[str]
- List of annotation agents
- predicted_by: List[str]
- List of predicted agents
-
- status: List[TaskStatus]
- List of task status
+class Text2TextRecord(Text2TextRecordInputs, BaseRecord[Text2TextAnnotation]):
+ pass
- metadata: Optional[Dict[str, Union[str, List[str]]]]
- Text query over metadata fields. Default=None
- predicted: Optional[PredictionStatus]
- The task prediction status
+class Text2TextBulkRequest(UpdateDatasetRequest):
+ records: List[Text2TextRecordInputs]
- """
+class Text2TextQuery(ServiceBaseRecordsQuery):
score: Optional[ScoreRange] = Field(default=None)
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
-class Text2TextSearchRequest(BaseModel):
- """
- API SearchRequest request
-
- Attributes:
- -----------
-
- query: Text2TextQuery
- The search query configuration
-
- sort:
- The sort order list
- """
-
- query: Text2TextQuery = Field(default_factory=Text2TextQuery)
- sort: List[SortableField] = Field(default_factory=list)
-
-
-class Text2TextSearchAggregations(BaseSearchResultsAggregations):
- """
- Extends base aggregation with predicted and annotated text
-
- Attributes:
- -----------
- predicted_text: Dict[str, int]
- The word cloud aggregations for predicted text
- annotated_text: Dict[str, int]
- The word cloud aggregations for annotated text
- """
-
+class Text2TextSearchAggregations(ServiceBaseSearchResultsAggregations):
predicted_text: Dict[str, int] = Field(default_factory=dict)
annotated_text: Dict[str, int] = Field(default_factory=dict)
@@ -239,14 +79,14 @@ class Text2TextSearchResults(
pass
-class Text2TextDatasetDB(DatasetDB):
- task: TaskType = Field(default=TaskType.text2text, const=True)
+class Text2TextDataset(ServiceText2TextDataset):
pass
class Text2TextMetrics(CommonTasksMetrics[Text2TextRecord]):
- """
- Configured metrics for text2text task
- """
-
pass
+
+
+class Text2TextSearchRequest(BaseModel):
+ query: Text2TextQuery = Field(default_factory=Text2TextQuery)
+ sort: List[SortableField] = Field(default_factory=list)
diff --git a/src/rubrix/server/apis/v0/models/text_classification.py b/src/rubrix/server/apis/v0/models/text_classification.py
index e62d87474f..1c2d388a70 100644
--- a/src/rubrix/server/apis/v0/models/text_classification.py
+++ b/src/rubrix/server/apis/v0/models/text_classification.py
@@ -14,26 +14,39 @@
# limitations under the License.
from datetime import datetime
-from typing import Any, ClassVar, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator
-from rubrix._constants import MAX_KEYWORD_LENGTH
-from rubrix.server.apis.v0.helpers import flatten_dict
from rubrix.server.apis.v0.models.commons.model import (
- BaseAnnotation,
BaseRecord,
+ BaseRecordInputs,
BaseSearchResults,
- BaseSearchResultsAggregations,
- EsRecordDataFieldNames,
- PredictionStatus,
ScoreRange,
SortableField,
- TaskStatus,
- TaskType,
)
-from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest
-from rubrix.server.services.search.model import BaseSearchQuery
+from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest
+from rubrix.server.commons.models import PredictionStatus
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceBaseSearchResultsAggregations,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+ DatasetLabelingRulesMetricsSummary as _DatasetLabelingRulesMetricsSummary,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+ LabelingRuleMetricsSummary as _LabelingRuleMetricsSummary,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceTextClassificationDataset,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceTextClassificationQuery as _TextClassificationQuery,
+)
+from rubrix.server.services.tasks.text_classification.model import (
+ TextClassificationAnnotation as _TextClassificationAnnotation,
+)
+from rubrix.server.services.tasks.text_classification.model import TokenAttributions
class UpdateLabelingRule(BaseModel):
@@ -63,22 +76,6 @@ def initialize_labels(cls, values):
class CreateLabelingRule(UpdateLabelingRule):
- """
- Data model for labeling rules creation
-
- Attributes:
- -----------
-
- query:
- The ES query of the rule
-
- label: str
- The label associated with the rule
-
- description:
- A brief description of the rule
-
- """
query: str = Field(description="The es rule query")
@@ -89,387 +86,44 @@ def strip_query(cls, query: str) -> str:
class LabelingRule(CreateLabelingRule):
- """
- Adds read-only attributes to the labeling rule
-
- Attributes:
- -----------
-
- author:
- Who created the rule
-
- created_at:
- When was the rule created
-
- """
-
author: str = Field(description="User who created the rule")
created_at: Optional[datetime] = Field(
default_factory=datetime.utcnow, description="Rule creation timestamp"
)
-class LabelingRuleMetricsSummary(BaseModel):
- """Metrics generated for a labeling rule"""
-
- coverage: Optional[float] = None
- coverage_annotated: Optional[float] = None
- correct: Optional[float] = None
- incorrect: Optional[float] = None
- precision: Optional[float] = None
-
- total_records: int
- annotated_records: int
-
-
-class DatasetLabelingRulesMetricsSummary(BaseModel):
- coverage: Optional[float] = None
- coverage_annotated: Optional[float] = None
-
- total_records: int
- annotated_records: int
-
-
-class TextClassificationDatasetDB(DatasetDB):
- """
- A dataset class specialized for text classification task
-
- Attributes:
- -----------
-
- rules:
- A list of dataset labeling rules
- """
-
- task: TaskType = Field(default=TaskType.text_classification, const=True)
-
- rules: List[LabelingRule] = Field(default_factory=list)
-
-
-class ClassPrediction(BaseModel):
- """
- Single class prediction
-
- Attributes:
- -----------
-
- class_label: Union[str, int]
- the predicted class
-
- score: float
- the predicted class score. For human-supervised annotations,
- this probability should be 1.0
- """
-
- class_label: Union[str, int] = Field(alias="class")
- score: float = Field(default=1.0, ge=0.0, le=1.0)
-
- @validator("class_label")
- def check_label_length(cls, class_label):
- if isinstance(class_label, str):
- assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, (
- f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}"
- if len(class_label) > MAX_KEYWORD_LENGTH
- else f"Class name must not be empty"
- )
- return class_label
-
- # See
- class Config:
- allow_population_by_field_name = True
-
-
-class TextClassificationAnnotation(BaseAnnotation):
- """
- Annotation class for text classification tasks
-
- Attributes:
- -----------
-
- labels: List[LabelPrediction]
- list of annotated labels with score
- """
-
- labels: List[ClassPrediction]
-
- @validator("labels")
- def sort_labels(cls, labels: List[ClassPrediction]):
- """Sort provided labels by score"""
- return sorted(labels, key=lambda x: x.score, reverse=True)
-
-
-class TokenAttributions(BaseModel):
- """
- The token attributions explaining predicted labels
-
- Attributes:
- -----------
-
- token: str
- The input token
- attributions: Dict[str, float]
- A dictionary containing label class-attribution pairs
+class LabelingRuleMetricsSummary(_LabelingRuleMetricsSummary):
+ pass
- """
- token: str
- attributions: Dict[str, float] = Field(default_factory=dict)
+class DatasetLabelingRulesMetricsSummary(_DatasetLabelingRulesMetricsSummary):
+ pass
-class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation]):
- """
- Text classification record
+class TextClassificationDataset(ServiceTextClassificationDataset):
+ pass
- Attributes:
- -----------
- inputs: Dict[str, Union[str, List[str]]]
- The input data text
+class TextClassificationAnnotation(_TextClassificationAnnotation):
+ pass
- multi_label: bool
- Enable text classification with multiple predicted/annotated labels.
- Default=False
- explanation: Dict[str, List[TokenAttributions]]
- Token attribution list explaining predicted classes per token input.
- The dictionary key must be aligned with provided record text. Optional
- """
+class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotation]):
inputs: Dict[str, Union[str, List[str]]]
multi_label: bool = False
explanation: Optional[Dict[str, List[TokenAttributions]]] = None
- _SCORE_DEVIATION_ERROR: ClassVar[float] = 0.001
-
- @root_validator
- def validate_record(cls, values):
- """fastapi validator method"""
- prediction = values.get("prediction", None)
- annotation = values.get("annotation", None)
- status = values.get("status")
- multi_label = values.get("multi_label", False)
-
- cls._check_score_integrity(prediction, multi_label)
- cls._check_annotation_integrity(annotation, multi_label, status)
-
- return values
-
- @classmethod
- def _check_annotation_integrity(
- cls,
- annotation: TextClassificationAnnotation,
- multi_label: bool,
- status: TaskStatus,
- ):
- if status == TaskStatus.validated and not multi_label:
- assert (
- annotation and len(annotation.labels) > 0
- ), "Annotation must include some label for validated records"
-
- if not multi_label and annotation:
- assert (
- len(annotation.labels) == 1
- ), "Single label record must include only one annotation label"
-
- @classmethod
- def _check_score_integrity(
- cls, prediction: TextClassificationAnnotation, multi_label: bool
- ):
- """
- Checks the score value integrity
-
- Parameters
- ----------
- prediction:
- The prediction annotation
- multi_label:
- If multi label
-
- """
- if prediction and not multi_label:
- assert sum([label.score for label in prediction.labels]) <= (
- 1.0 + cls._SCORE_DEVIATION_ERROR
- ), f"Wrong score distributions: {prediction.labels}"
-
- @classmethod
- def task(cls) -> TaskType:
- """The task type"""
- return TaskType.text_classification
-
- @property
- def predicted(self) -> Optional[PredictionStatus]:
- if self.predicted_by and self.annotated_by:
- return (
- PredictionStatus.OK
- if set(self.predicted_as) == set(self.annotated_as)
- else PredictionStatus.KO
- )
- return None
-
- @property
- def predicted_as(self) -> List[str]:
- return self._labels_from_annotation(
- self.prediction, multi_label=self.multi_label
- )
-
- @property
- def annotated_as(self) -> List[str]:
- return self._labels_from_annotation(
- self.annotation, multi_label=self.multi_label
- )
-
- @property
- def scores(self) -> List[float]:
- """Values of prediction scores"""
- if not self.prediction:
- return []
- return (
- [label.score for label in self.prediction.labels]
- if self.multi_label
- else [
- prediction_class.score
- for prediction_class in [
- self._max_class_prediction(
- self.prediction, multi_label=self.multi_label
- )
- ]
- if prediction_class
- ]
- )
-
- def all_text(self) -> str:
- sentences = []
- for v in self.inputs.values():
- if isinstance(v, list):
- sentences.extend(v)
- else:
- sentences.append(v)
- return "\n".join(sentences)
-
- @validator("inputs")
- def validate_inputs(cls, text: Dict[str, Any]):
- """Applies validation over input text"""
- assert len(text) > 0, "No inputs provided"
-
- for t in text.values():
- assert t is not None, "Cannot include None fields"
-
- return text
-
- @validator("inputs")
- def flatten_text(cls, text: Dict[str, Any]):
- """Normalizes input text to dict of strings"""
- flat_dict = flatten_dict(text)
- return flat_dict
-
- @classmethod
- def _labels_from_annotation(
- cls, annotation: TextClassificationAnnotation, multi_label: bool
- ) -> Union[List[str], List[int]]:
- """
- Extracts labels values from annotation
-
- Parameters
- ----------
- annotation:
- The annotation
- multi_label
- Enable/Disable multi label model
-
- Returns
- -------
- Label values for a given annotation
-
- """
- if not annotation:
- return []
-
- if multi_label:
- return [
- label.class_label for label in annotation.labels if label.score > 0.5
- ]
-
- class_prediction = cls._max_class_prediction(
- annotation, multi_label=multi_label
- )
- if class_prediction is None:
- return []
-
- return [class_prediction.class_label]
-
- @staticmethod
- def _max_class_prediction(
- p: TextClassificationAnnotation, multi_label: bool
- ) -> Optional[ClassPrediction]:
- """
- Gets the max class prediction for annotation
-
- Parameters
- ----------
- p:
- The annotation
- multi_label:
- Enable/Disable multi_label mode
-
- Returns
- -------
-
- The max class prediction in terms of prediction score if
- prediction has labels and no multi label is enabled. None, otherwise
- """
- if multi_label or p is None or not p.labels:
- return None
- return p.labels[0]
-
- class Config:
- allow_population_by_field_name = True
-
-
-class TextClassificationRecordDB(CreationTextClassificationRecord):
- """
- The main text classification task record
-
- Attributes:
- -----------
-
- last_updated: datetime
- Last record update (read only)
- predicted: Optional[PredictionStatus]
- The record prediction status. Optional
- """
-
- last_updated: datetime = None
- _predicted: Optional[PredictionStatus] = Field(alias="predicted")
-
- def extended_fields(self) -> Dict[str, Any]:
- words = self.all_text()
- return {
- **super().extended_fields(),
- EsRecordDataFieldNames.words: words,
- # This allow query by text:.... or text.exact:....
- # Once words is remove we can normalize at record level
- "text": words,
- }
+class TextClassificationRecord(
+ TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation]
+):
+ pass
-class TextClassificationRecord(TextClassificationRecordDB):
- def extended_fields(self) -> Dict[str, Any]:
- return {}
-
-
-class TextClassificationBulkData(UpdateDatasetRequest):
- """
- API bulk data for text classification
-
- Attributes:
- -----------
-
- records: List[CreationTextClassificationRecord]
- The text classification record list
- """
+class TextClassificationBulkRequest(UpdateDatasetRequest):
- records: List[CreationTextClassificationRecord]
+ records: List[TextClassificationRecordInputs]
@validator("records")
def check_multi_label_integrity(cls, records: List[TextClassificationRecord]):
@@ -483,26 +137,7 @@ def check_multi_label_integrity(cls, records: List[TextClassificationRecord]):
return records
-class TextClassificationQuery(BaseSearchQuery):
- """
- API Filters for text classification
-
- Attributes:
- -----------
-
- predicted_as: List[str]
- List of predicted terms
-
- annotated_as: List[str]
- List of annotated terms
-
- predicted: Optional[PredictionStatus]
- The task prediction status
-
- uncovered_by_rules:
- Only return records that are NOT covered by these rules.
-
- """
+class TextClassificationQuery(ServiceBaseRecordsQuery):
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
@@ -515,25 +150,7 @@ class TextClassificationQuery(BaseSearchQuery):
)
-class TextClassificationSearchRequest(BaseModel):
- """
- API SearchRequest request
-
- Attributes:
- -----------
-
- query: TextClassificationQuery
- The search query configuration
-
- sort:
- The sort order list
- """
-
- query: TextClassificationQuery = Field(default_factory=TextClassificationQuery)
- sort: List[SortableField] = Field(default_factory=list)
-
-
-class TextClassificationSearchAggregations(BaseSearchResultsAggregations):
+class TextClassificationSearchAggregations(ServiceBaseSearchResultsAggregations):
pass
@@ -541,3 +158,8 @@ class TextClassificationSearchResults(
BaseSearchResults[TextClassificationRecord, TextClassificationSearchAggregations]
):
pass
+
+
+class TextClassificationSearchRequest(BaseModel):
+ query: TextClassificationQuery = Field(default_factory=TextClassificationQuery)
+ sort: List[SortableField] = Field(default_factory=list)
diff --git a/src/rubrix/server/apis/v0/models/token_classification.py b/src/rubrix/server/apis/v0/models/token_classification.py
index 0dac97527d..948fc559ab 100644
--- a/src/rubrix/server/apis/v0/models/token_classification.py
+++ b/src/rubrix/server/apis/v0/models/token_classification.py
@@ -12,104 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import defaultdict
-from datetime import datetime
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, root_validator, validator
-from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.apis.v0.models.commons.model import (
- BaseAnnotation,
BaseRecord,
+ BaseRecordInputs,
BaseSearchResults,
- BaseSearchResultsAggregations,
- EsRecordDataFieldNames,
- PredictionStatus,
ScoreRange,
- SortableField,
- TaskType,
)
-from rubrix.server.apis.v0.models.datasets import DatasetDB, UpdateDatasetRequest
-from rubrix.server.services.search.model import BaseSearchQuery
-
-PREDICTED_MENTIONS_ES_FIELD_NAME = "predicted_mentions"
-MENTIONS_ES_FIELD_NAME = "mentions"
-
-
-class EntitySpan(BaseModel):
- """
- The tokens span for a labeled text.
-
- Entity spans will be defined between from start to end - 1
-
- Attributes:
- -----------
-
- start: int
- character start position
- end: int
- character end position, must be higher than the starting character.
- label: str
- the label related to tokens that conforms the entity span
- score:
- A higher score means, the model/annotator is more confident about its predicted/annotated entity.
- """
-
- start: int
- end: int
- label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
- score: float = Field(default=1.0, ge=0.0, le=1.0)
-
- @validator("end")
- def check_span_offset(cls, end: int, values):
- """Validates span offset"""
- assert (
- end > values["start"]
- ), "End character cannot be placed before the starting character, it must be at least one character after."
- return end
-
- def __hash__(self):
- return hash(type(self)) + hash(self.__dict__.values())
-
-
-class TokenClassificationAnnotation(BaseAnnotation):
- """Annotation class for the Token classification task.
-
- Attributes:
- -----------
- entities: List[EntitiesSpan]
- a list of detected entities spans in tokenized text, if any.
- score: float
- score related to annotated entities. The higher is score value, the
- more likely is that entities were properly annotated.
- """
-
- entities: List[EntitySpan] = Field(default_factory=list)
- score: Optional[float] = None
-
+from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest
+from rubrix.server.commons.models import PredictionStatus
+from rubrix.server.daos.backend.search.model import SortableField
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceBaseSearchResultsAggregations,
+)
+from rubrix.server.services.tasks.token_classification.model import (
+ ServiceTokenClassificationAnnotation as _TokenClassificationAnnotation,
+)
+from rubrix.server.services.tasks.token_classification.model import (
+ ServiceTokenClassificationDataset,
+)
-class CreationTokenClassificationRecord(BaseRecord[TokenClassificationAnnotation]):
- """
- Dataset record for token classification task
- Attributes:
- -----------
+class TokenClassificationAnnotation(_TokenClassificationAnnotation):
+ pass
- tokens: List[str]
- The input tokens
- text: str
- Textual representation of token list
- """
+class TokenClassificationRecordInputs(BaseRecordInputs[TokenClassificationAnnotation]):
- tokens: List[str] = Field(min_items=1)
text: str = Field()
+ tokens: List[str] = Field(min_items=1)
+ # TODO(@frascuchon): Delete this field and all related logic
_raw_text: Optional[str] = Field(alias="raw_text")
- __chars2tokens__: Dict[int, int] = None
- __tokens2chars__: Dict[int, Tuple[int, int]] = None
-
@root_validator(pre=True)
def accept_old_fashion_text_field(cls, values):
text, raw_text = values.get("text"), values.get("raw_text")
@@ -118,290 +56,23 @@ def accept_old_fashion_text_field(cls, values):
return values
- def __init__(self, **data):
- super().__init__(**data)
-
- self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__()
-
- self.check_annotation(self.prediction)
- self.check_annotation(self.annotation)
-
- def char_id2token_id(self, char_idx: int) -> Optional[int]:
- return self.__chars2tokens__.get(char_idx)
-
- def token_span(self, token_idx: int) -> Tuple[int, int]:
- if token_idx not in self.__tokens2chars__:
- raise IndexError(f"Token id {token_idx} out of bounds")
- return self.__tokens2chars__[token_idx]
-
@validator("text")
def check_text_content(cls, text: str):
assert text and text.strip(), "No text or empty text provided"
return text
- def __build_indices_map__(
- self,
- ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]:
- """
- Build the indices mapping between text characters and tokens where belongs to,
- and vice versa.
-
- chars2tokens index contains is the token idx where i char is contained (if any).
-
- Out-of-token characters won't be included in this map,
- so access should be using ``chars2tokens_map.get(i)``
- instead of ``chars2tokens_map[i]``.
-
- """
-
- def chars2tokens_index():
- def is_space_after_token(char, idx: int, chars_map) -> str:
- return char == " " and idx - 1 in chars_map
-
- chars_map = {}
- current_token = 0
- current_token_char_start = 0
- for idx, char in enumerate(self.text):
- if is_space_after_token(char, idx, chars_map):
- continue
- relative_idx = idx - current_token_char_start
- if (
- relative_idx < len(self.tokens[current_token])
- and char == self.tokens[current_token][relative_idx]
- ):
- chars_map[idx] = current_token
- elif (
- current_token + 1 < len(self.tokens)
- and relative_idx >= len(self.tokens[current_token])
- and char == self.tokens[current_token + 1][0]
- ):
- current_token += 1
- current_token_char_start += relative_idx
- chars_map[idx] = current_token
-
- return chars_map
-
- def tokens2chars_index(
- chars2tokens: Dict[int, int]
- ) -> Dict[int, Tuple[int, int]]:
- tokens2chars_map = defaultdict(list)
- for c, t in chars2tokens.items():
- tokens2chars_map[t].append(c)
-
- return {
- token_idx: (min(chars), max(chars))
- for token_idx, chars in tokens2chars_map.items()
- }
-
- chars2tokens_idx = chars2tokens_index()
- return chars2tokens_idx, tokens2chars_index(chars2tokens_idx)
-
- def check_annotation(
- self,
- annotation: Optional[TokenClassificationAnnotation],
- ):
- """Validates entities in terms of offset spans"""
-
- def adjust_span_bounds(start, end):
- if start < 0:
- start = 0
- if entity.end > len(self.text):
- end = len(self.text)
- while start <= len(self.text) and not self.text[start].strip():
- start += 1
- while not self.text[end - 1].strip():
- end -= 1
- return start, end
-
- if annotation:
- for entity in annotation.entities:
- entity.start, entity.end = adjust_span_bounds(entity.start, entity.end)
- mention = self.text[entity.start : entity.end]
- assert len(mention) > 0, f"Empty offset defined for entity {entity}"
-
- token_start = self.char_id2token_id(entity.start)
- token_end = self.char_id2token_id(entity.end - 1)
-
- assert not (
- token_start is None or token_end is None
- ), f"Provided entity span {self.text[entity.start: entity.end]} is not aligned with provided tokens."
- "Some entity chars could be reference characters out of tokens"
-
- span_start, _ = self.token_span(token_start)
- _, span_end = self.token_span(token_end)
-
- assert (
- self.text[span_start : span_end + 1] == mention
- ), f"Defined offset [{self.text[entity.start: entity.end]}] is a misaligned entity mention"
-
- def task(cls) -> TaskType:
- """The record task type"""
- return TaskType.token_classification
-
- @property
- def predicted(self) -> Optional[PredictionStatus]:
- if self.annotation and self.prediction:
- return (
- PredictionStatus.OK
- if self.annotation.entities == self.prediction.entities
- else PredictionStatus.KO
- )
- return None
-
- @property
- def predicted_as(self) -> List[str]:
- return [ent.label for ent in self.predicted_entities()]
-
- @property
- def annotated_as(self) -> List[str]:
- return [ent.label for ent in self.annotated_entities()]
-
- @property
- def scores(self) -> List[float]:
- if not self.prediction:
- return []
- if self.prediction.score is not None:
- return [self.prediction.score]
- return [e.score for e in self.prediction.entities]
-
- def all_text(self) -> str:
- return self.text
-
- def predicted_iob_tags(self) -> Optional[List[str]]:
- if self.prediction is None:
- return None
- return self.spans2iob(self.prediction.entities)
-
- def annotated_iob_tags(self) -> Optional[List[str]]:
- if self.annotation is None:
- return None
- return self.spans2iob(self.annotation.entities)
-
- def spans2iob(self, spans: List[EntitySpan]) -> Optional[List[str]]:
- if spans is None:
- return None
- tags = ["O"] * len(self.tokens)
- for entity in spans:
- token_start = self.char_id2token_id(entity.start)
- token_end = self.char_id2token_id(entity.end - 1)
- tags[token_start] = f"B-{entity.label}"
- for idx in range(token_start + 1, token_end + 1):
- tags[idx] = f"I-{entity.label}"
-
- return tags
-
- def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]:
- return [
- (mention, entity)
- for mention, entity in self.__mentions_from_entities__(
- self.predicted_entities()
- ).items()
- ]
-
- def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]:
- return [
- (mention, entity)
- for mention, entity in self.__mentions_from_entities__(
- self.annotated_entities()
- ).items()
- ]
-
- def annotated_entities(self) -> Set[EntitySpan]:
- """Shortcut for real annotated entities, if provided"""
- if self.annotation is None:
- return set()
- return set(self.annotation.entities)
-
- def predicted_entities(self) -> Set[EntitySpan]:
- """Predicted entities"""
- if self.prediction is None:
- return set()
- return set(self.prediction.entities)
-
- def __mentions_from_entities__(
- self, entities: Set[EntitySpan]
- ) -> Dict[str, EntitySpan]:
- return {
- mention: entity
- for entity in entities
- for mention in [self.text[entity.start : entity.end]]
- }
-
- class Config:
- allow_population_by_field_name = True
- underscore_attrs_are_private = True
-
-
-class TokenClassificationRecordDB(CreationTokenClassificationRecord):
- """
- The main token classification task record
-
- Attributes:
- -----------
-
- last_updated: datetime
- Last record update (read only)
- predicted: Optional[PredictionStatus]
- The record prediction status. Optional
- """
-
- last_updated: datetime = None
- _predicted: Optional[PredictionStatus] = Field(alias="predicted")
-
- def extended_fields(self) -> Dict[str, Any]:
-
- return {
- **super().extended_fields(),
- # See ../service/service.py
- PREDICTED_MENTIONS_ES_FIELD_NAME: [
- {"mention": mention, "entity": entity.label, "score": entity.score}
- for mention, entity in self.predicted_mentions()
- ],
- MENTIONS_ES_FIELD_NAME: [
- {"mention": mention, "entity": entity.label}
- for mention, entity in self.annotated_mentions()
- ],
- EsRecordDataFieldNames.words: self.all_text(),
- }
-
-
-class TokenClassificationRecord(TokenClassificationRecordDB):
- def extended_fields(self) -> Dict[str, Any]:
- return {
- "raw_text": self.text, # Maintain results compatibility
- }
-
-
-class TokenClassificationBulkData(UpdateDatasetRequest):
- """
- API bulk data for text classification
-
- Attributes:
- -----------
-
- records: List[TextClassificationRecord]
- The text classification record list
-
- """
-
- records: List[CreationTokenClassificationRecord]
+class TokenClassificationRecord(
+ TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation]
+):
+ pass
-class TokenClassificationQuery(BaseSearchQuery):
- """
- API Filters for text classification
- Attributes:
- -----------
+class TokenClassificationBulkRequest(UpdateDatasetRequest):
+ records: List[TokenClassificationRecordInputs]
- predicted_as: List[str]
- List of predicted terms
- annotated_as: List[str]
- List of annotated terms
- predicted: Optional[PredictionStatus]
- The task prediction status
- """
+class TokenClassificationQuery(ServiceBaseRecordsQuery):
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
@@ -410,35 +81,11 @@ class TokenClassificationQuery(BaseSearchQuery):
class TokenClassificationSearchRequest(BaseModel):
-
- """
- API SearchRequest request
-
- Attributes:
- -----------
-
- query: TokenClassificationQuery
- The search query configuration
- sort:
- The sort by order in search results
- """
-
query: TokenClassificationQuery = Field(default_factory=TokenClassificationQuery)
sort: List[SortableField] = Field(default_factory=list)
-class TokenClassificationAggregations(BaseSearchResultsAggregations):
- """
- Extends base aggregation with mentions
-
- Attributes:
- -----------
- mentions: Dict[str,Dict[str,int]]
- The annotated entity spans
- predicted_mentions: Dict[str,Dict[str,int]]
- The prediction entity spans
- """
-
+class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations):
predicted_mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict)
mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict)
@@ -449,6 +96,5 @@ class TokenClassificationSearchResults(
pass
-class TokenClassificationDatasetDB(DatasetDB):
- task: TaskType = Field(default=TaskType.token_classification, const=True)
+class TokenClassificationDataset(ServiceTokenClassificationDataset):
pass
diff --git a/src/rubrix/server/apis/v0/validators/text_classification.py b/src/rubrix/server/apis/v0/validators/text_classification.py
index f2bee61bee..0bb4f5db97 100644
--- a/src/rubrix/server/apis/v0/validators/text_classification.py
+++ b/src/rubrix/server/apis/v0/validators/text_classification.py
@@ -2,27 +2,27 @@
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import TaskType
from rubrix.server.apis.v0.models.dataset_settings import TextClassificationSettings
from rubrix.server.apis.v0.models.datasets import Dataset
-from rubrix.server.apis.v0.models.metrics.text_classification import DatasetLabels
-from rubrix.server.apis.v0.models.text_classification import (
- CreationTextClassificationRecord,
- TextClassificationRecord,
-)
+from rubrix.server.commons.models import TaskType
from rubrix.server.errors import BadRequestError, EntityNotFoundError
from rubrix.server.security.model import User
-from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings
+from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings
+from rubrix.server.services.tasks.text_classification.metrics import DatasetLabels
-__svc_settings_class__: Type[SVCDatasetSettings] = type(
+__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type(
f"{TaskType.text_classification}_DatasetSettings",
- (SVCDatasetSettings, TextClassificationSettings),
+ (ServiceBaseDatasetSettings, TextClassificationSettings),
{},
)
from rubrix.server.services.metrics import MetricsService
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceTextClassificationRecord,
+)
+# TODO(@frascuchon): Move validator and its models to the service layer
class DatasetValidator:
_INSTANCE = None
@@ -48,7 +48,7 @@ async def validate_dataset_settings(
results = self.__metrics__.summarize_metric(
dataset=dataset,
metric=DatasetLabels(),
- record_class=TextClassificationRecord,
+ record_class=ServiceTextClassificationRecord,
query=None,
)
if results:
@@ -67,7 +67,7 @@ async def validate_dataset_records(
self,
user: User,
dataset: Dataset,
- records: Optional[List[CreationTextClassificationRecord]] = None,
+ records: Optional[List[ServiceTextClassificationRecord]] = None,
):
try:
settings: TextClassificationSettings = await self.__datasets__.get_settings(
diff --git a/src/rubrix/server/apis/v0/validators/token_classification.py b/src/rubrix/server/apis/v0/validators/token_classification.py
index 947c2ce1dc..5e49f68f3e 100644
--- a/src/rubrix/server/apis/v0/validators/token_classification.py
+++ b/src/rubrix/server/apis/v0/validators/token_classification.py
@@ -2,28 +2,28 @@
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import TaskType
from rubrix.server.apis.v0.models.dataset_settings import TokenClassificationSettings
from rubrix.server.apis.v0.models.datasets import Dataset
-from rubrix.server.apis.v0.models.metrics.token_classification import DatasetLabels
-from rubrix.server.apis.v0.models.token_classification import (
- CreationTokenClassificationRecord,
- TokenClassificationAnnotation,
- TokenClassificationRecord,
-)
+from rubrix.server.commons.models import TaskType
from rubrix.server.errors import BadRequestError, EntityNotFoundError
from rubrix.server.security.model import User
-from rubrix.server.services.datasets import DatasetsService, SVCDatasetSettings
+from rubrix.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings
+from rubrix.server.services.tasks.token_classification.metrics import DatasetLabels
-__svc_settings_class__: Type[SVCDatasetSettings] = type(
+__svc_settings_class__: Type[ServiceBaseDatasetSettings] = type(
f"{TaskType.token_classification}_DatasetSettings",
- (SVCDatasetSettings, TokenClassificationSettings),
+ (ServiceBaseDatasetSettings, TokenClassificationSettings),
{},
)
from rubrix.server.services.metrics import MetricsService
+from rubrix.server.services.tasks.token_classification.model import (
+ ServiceTokenClassificationAnnotation,
+ ServiceTokenClassificationRecord,
+)
+# TODO(@frascuchon): Move validator and its models to the service layer
class DatasetValidator:
_INSTANCE = None
@@ -48,7 +48,7 @@ async def validate_dataset_settings(
results = self.__metrics__.summarize_metric(
dataset=dataset,
metric=DatasetLabels(),
- record_class=TokenClassificationRecord,
+ record_class=ServiceTokenClassificationRecord,
query=None,
)
if results:
@@ -67,7 +67,7 @@ async def validate_dataset_records(
self,
user: User,
dataset: Dataset,
- records: List[CreationTokenClassificationRecord],
+ records: List[ServiceTokenClassificationRecord],
):
try:
settings: TokenClassificationSettings = (
@@ -90,7 +90,7 @@ async def validate_dataset_records(
@staticmethod
def __check_label_entities__(
- label_schema: Set[str], annotation: TokenClassificationAnnotation
+ label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation
):
if not annotation:
return
diff --git a/src/rubrix/server/apis/v0/config/__init__.py b/src/rubrix/server/commons/__init__.py
similarity index 100%
rename from src/rubrix/server/apis/v0/config/__init__.py
rename to src/rubrix/server/commons/__init__.py
diff --git a/src/rubrix/server/commons/config.py b/src/rubrix/server/commons/config.py
new file mode 100644
index 0000000000..6de562d06b
--- /dev/null
+++ b/src/rubrix/server/commons/config.py
@@ -0,0 +1,96 @@
+from typing import List, Optional, Set, Type
+
+from pydantic import BaseModel
+
+from rubrix.server.commons.models import TaskType
+from rubrix.server.errors import EntityNotFoundError, WrongTaskError
+from rubrix.server.services.datasets import ServiceDataset
+from rubrix.server.services.metrics import ServiceBaseMetric
+from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics
+from rubrix.server.services.search.model import ServiceRecordsQuery
+from rubrix.server.services.tasks.commons import ServiceRecord
+
+
+class TaskConfig(BaseModel):
+ task: TaskType
+ query: Type[ServiceRecordsQuery]
+ dataset: Type[ServiceDataset]
+ record: Type[ServiceRecord]
+ metrics: Optional[Type[ServiceBaseTaskMetrics]]
+
+
+class TasksFactory:
+
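+ # In-memory registry that maps each TaskType to its TaskConfig (dataset, query, record and metrics classes)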
+ __REGISTERED_TASKS__ = dict()
+
+ @classmethod
+ def register_task(
+ cls,
+ task_type: TaskType,
+ dataset_class: Type[ServiceDataset],
+ query_request: Type[ServiceRecordsQuery],
+ record_class: Type[ServiceRecord],
+ metrics: Optional[Type[ServiceBaseTaskMetrics]] = None,
+ ):
+
+ cls.__REGISTERED_TASKS__[task_type] = TaskConfig(
+ task=task_type,
+ dataset=dataset_class,
+ query=query_request,
+ record=record_class,
+ metrics=metrics,
+ )
+
+ @classmethod
+ def get_all_configs(cls) -> List[TaskConfig]:
+ return [cfg for cfg in cls.__REGISTERED_TASKS__.values()]
+
+ @classmethod
+ def get_task_by_task_type(cls, task_type: TaskType) -> Optional[TaskConfig]:
+ return cls.__REGISTERED_TASKS__.get(task_type)
+
+ @classmethod
+ def get_task_metrics(cls, task: TaskType) -> Optional[Type[ServiceBaseTaskMetrics]]:
+ config = cls.get_task_by_task_type(task)
+ if config:
+ return config.metrics
+
+ @classmethod
+ def get_task_dataset(cls, task: TaskType) -> Type[ServiceDataset]:
+ config = cls.__get_task_config__(task)
+ return config.dataset
+
+ @classmethod
+ def get_task_record(cls, task: TaskType) -> Type[ServiceRecord]:
+ config = cls.__get_task_config__(task)
+ return config.record
+
+ @classmethod
+ def __get_task_config__(cls, task):
+ config = cls.get_task_by_task_type(task)
+ if not config:
+ raise WrongTaskError(f"No configuration found for task {task}")
+ return config
+
+ @classmethod
+ def find_task_metric(
+ cls, task: TaskType, metric_id: str
+ ) -> Optional[ServiceBaseMetric]:
+ metrics = cls.find_task_metrics(task, {metric_id})
+ if metrics:
+ return metrics[0]
+ raise EntityNotFoundError(name=metric_id, type=ServiceBaseMetric)
+
+ @classmethod
+ def find_task_metrics(
+ cls, task: TaskType, metric_ids: Set[str]
+ ) -> List[ServiceBaseMetric]:
+
+ if not metric_ids:
+ return []
+
+ metrics = []
+ for metric in cls.get_task_metrics(task).metrics:
+ if metric.id in metric_ids:
+ metrics.append(metric)
+ return metrics
diff --git a/src/rubrix/server/commons/models.py b/src/rubrix/server/commons/models.py
new file mode 100644
index 0000000000..31e2444863
--- /dev/null
+++ b/src/rubrix/server/commons/models.py
@@ -0,0 +1,21 @@
+from enum import Enum
+
+
+class TaskStatus(str, Enum):
+ default = "Default"
+ edited = "Edited" # TODO: DEPRECATE
+ discarded = "Discarded"
+ validated = "Validated"
+
+
+class TaskType(str, Enum):
+
+ text_classification = "TextClassification"
+ token_classification = "TokenClassification"
+ text2text = "Text2Text"
+ multi_task_text_token_classification = "MultitaskTextTokenClassification"
+
+
+class PredictionStatus(str, Enum):
+ OK = "ok"
+ KO = "ko"
diff --git a/src/rubrix/server/apis/v0/models/metrics/__init__.py b/src/rubrix/server/daos/backend/__init__.py
similarity index 100%
rename from src/rubrix/server/apis/v0/models/metrics/__init__.py
rename to src/rubrix/server/daos/backend/__init__.py
diff --git a/src/rubrix/server/elasticseach/__init__.py b/src/rubrix/server/daos/backend/mappings/__init__.py
similarity index 100%
rename from src/rubrix/server/elasticseach/__init__.py
rename to src/rubrix/server/daos/backend/mappings/__init__.py
diff --git a/src/rubrix/server/elasticseach/mappings/datasets.py b/src/rubrix/server/daos/backend/mappings/datasets.py
similarity index 87%
rename from src/rubrix/server/elasticseach/mappings/datasets.py
rename to src/rubrix/server/daos/backend/mappings/datasets.py
index 125343f0fe..63fbc0ecda 100644
--- a/src/rubrix/server/elasticseach/mappings/datasets.py
+++ b/src/rubrix/server/daos/backend/mappings/datasets.py
@@ -1,9 +1,9 @@
from rubrix.server.settings import settings
DATASETS_INDEX_NAME = settings.dataset_index_name
-DATASETS_RECORDS_INDEX_NAME = settings.dataset_records_index_name
-
+# TODO(@frascuchon): Define a mapping definition instead and
+# use it when the datasets index is created
DATASETS_INDEX_TEMPLATE = {
"index_patterns": [DATASETS_INDEX_NAME],
"settings": {"number_of_shards": 1},
diff --git a/src/rubrix/server/elasticseach/mappings/helpers.py b/src/rubrix/server/daos/backend/mappings/helpers.py
similarity index 97%
rename from src/rubrix/server/elasticseach/mappings/helpers.py
rename to src/rubrix/server/daos/backend/mappings/helpers.py
index 4a0715cab2..0a7f75bc20 100644
--- a/src/rubrix/server/elasticseach/mappings/helpers.py
+++ b/src/rubrix/server/daos/backend/mappings/helpers.py
@@ -164,6 +164,8 @@ def tasks_common_mappings():
"id": mappings.keyword_field(),
"words": mappings.words_text_field(),
"text": mappings.text_field(),
+ # TODO(@frascuchon): Enable prediction and annotation
+ # so we can build extra metrics based on these fields
"prediction": {"type": "object", "enabled": False},
"annotation": {"type": "object", "enabled": False},
"status": mappings.keyword_field(),
diff --git a/src/rubrix/server/elasticseach/mappings/text2text.py b/src/rubrix/server/daos/backend/mappings/text2text.py
similarity index 69%
rename from src/rubrix/server/elasticseach/mappings/text2text.py
rename to src/rubrix/server/daos/backend/mappings/text2text.py
index 8f2ec4ffcf..91bfc6880d 100644
--- a/src/rubrix/server/elasticseach/mappings/text2text.py
+++ b/src/rubrix/server/daos/backend/mappings/text2text.py
@@ -1,4 +1,4 @@
-from rubrix.server.elasticseach.mappings.helpers import mappings
+from rubrix.server.daos.backend.mappings.helpers import mappings
def text2text_mappings():
@@ -16,10 +16,6 @@ def text2text_mappings():
]
),
"properties": {
- # TODO: we will include this breaking changes 2 releases after
- # PR https://github.com/recognai/rubrix/pull/1018
- # "annotated_as": mappings.text_field(),
- # "predicted_as": mappings.text_field(),
"annotated_as": mappings.keyword_field(),
"predicted_as": mappings.keyword_field(),
"text_predicted": mappings.words_text_field(),
diff --git a/src/rubrix/server/elasticseach/mappings/text_classification.py b/src/rubrix/server/daos/backend/mappings/text_classification.py
similarity index 95%
rename from src/rubrix/server/elasticseach/mappings/text_classification.py
rename to src/rubrix/server/daos/backend/mappings/text_classification.py
index 95dc19942e..bcce2023e1 100644
--- a/src/rubrix/server/elasticseach/mappings/text_classification.py
+++ b/src/rubrix/server/daos/backend/mappings/text_classification.py
@@ -1,4 +1,4 @@
-from rubrix.server.elasticseach.mappings.helpers import mappings
+from rubrix.server.daos.backend.mappings.helpers import mappings
def text_classification_mappings():
diff --git a/src/rubrix/server/elasticseach/mappings/token_classification.py b/src/rubrix/server/daos/backend/mappings/token_classification.py
similarity index 68%
rename from src/rubrix/server/elasticseach/mappings/token_classification.py
rename to src/rubrix/server/daos/backend/mappings/token_classification.py
index d72302ec3e..664c317e31 100644
--- a/src/rubrix/server/elasticseach/mappings/token_classification.py
+++ b/src/rubrix/server/daos/backend/mappings/token_classification.py
@@ -1,10 +1,39 @@
-from rubrix.server.apis.v0.models.metrics.token_classification import (
- MentionMetrics,
- TokenMetrics,
- TokenTagMetrics,
-)
-from rubrix.server.elasticseach.mappings.helpers import mappings
-from rubrix.server.elasticseach.query_helpers import nested_mappings_from_base_model
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
+
+from rubrix.server.daos.backend.mappings.helpers import mappings
+from rubrix.server.daos.backend.query_helpers import nested_mappings_from_base_model
+
+
+class MentionMetrics(BaseModel):
+ """Mention metrics model"""
+
+ value: str
+ label: str
+ score: float = Field(ge=0.0)
+ capitalness: Optional[str] = Field(None)
+ density: float = Field(ge=0.0)
+ tokens_length: int = Field(gt=0)
+ chars_length: int = Field(gt=0)
+
+
+class TokenTagMetrics(BaseModel):
+ value: str
+ tag: str
+
+
+class TokenMetrics(BaseModel):
+
+ idx: int
+ value: str
+ char_start: int
+ char_end: int
+ length: int
+ capitalness: Optional[str] = None
+ score: Optional[float] = None
+ tag: Optional[str] = None # TODO: remove!
+ custom: Optional[Dict[str, Any]] = None
def mentions_mappings():
diff --git a/src/rubrix/server/daos/backend/metrics/__init__.py b/src/rubrix/server/daos/backend/metrics/__init__.py
new file mode 100644
index 0000000000..3b02974875
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/__init__.py
@@ -0,0 +1,5 @@
+from .commons import METRICS as COMMON_METRICS
+from .text_classification import METRICS as TEXT_CLASSIFICATION
+from .token_classification import METRICS as TOKEN_CLASSIFICATION
+
+ALL_METRICS = {**COMMON_METRICS, **TEXT_CLASSIFICATION, **TOKEN_CLASSIFICATION}
diff --git a/src/rubrix/server/daos/backend/metrics/base.py b/src/rubrix/server/daos/backend/metrics/base.py
new file mode 100644
index 0000000000..bb469755ad
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/base.py
@@ -0,0 +1,213 @@
+import dataclasses
+from typing import Any, Dict, List, Optional, Union
+
+from rubrix.server.daos.backend.query_helpers import aggregations
+from rubrix.server.helpers import unflatten_dict
+
+
+@dataclasses.dataclass
+class ElasticsearchMetric:
+ id: str
+
+ @property
+ def metric_arg_names(self):
+ return self.__args__
+
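+ # Introspect _build_aggregation so metric_arg_names exposes the parameters this metric accepts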
+ def __post_init__(self):
+ self.__args__ = self.get_function_arg_names(self._build_aggregation)
+
+ @staticmethod
+ def get_function_arg_names(func):
+ return func.__code__.co_varnames
+
+ def aggregation_request(
+ self, *args, **kwargs
+ ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+ """
+ Configures the summary es aggregation definition
+ """
+ return {self.id: self._build_aggregation(*args, **kwargs)}
+
+ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Parse the es aggregation result. Override this method
+ for result customization
+
+ Parameters
+ ----------
+ aggregation_result:
+ Retrieved es aggregation result
+
+ """
+ return aggregation_result.get(self.id, aggregation_result)
+
+ def _build_aggregation(self, *args, **kwargs) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+
+@dataclasses.dataclass
+class NestedPathElasticsearchMetric(ElasticsearchMetric):
+ """
+ An ``ElasticsearchMetric`` which needs nested fields for summary calculation.
+
+ Aggregations over nested fields need some extra configuration and this class
+ encapsulates that common logic.
+
+ Attributes:
+ -----------
+ nested_path:
+ The path of the nested field the aggregation is built on
+ """
+
+ nested_path: str
+
+ def __post_init__(self):
+ self.__args__ = self.get_function_arg_names(self._inner_aggregation)
+
+ def _inner_aggregation(self, *args, **kwargs) -> Dict[str, Any]:
+ """The specific aggregation definition"""
+ raise NotImplementedError()
+
+ def _build_aggregation(self, *args, **kwargs) -> Dict[str, Any]:
+ """Implements the common mechanism to define aggregations with nested fields"""
+ return aggregations.nested_aggregation(
+ nested_path=self.nested_path,
+ inner_aggregation=self._inner_aggregation(*args, **kwargs),
+ )
+
+ def compound_nested_field(self, inner_field: str) -> str:
+ return f"{self.nested_path}.{inner_field}"
+
+
+@dataclasses.dataclass
+class HistogramAggregation(ElasticsearchMetric):
+ """
+ Base elasticsearch histogram aggregation metric
+
+ Attributes
+ ----------
+ field:
+ The histogram field
+ script:
+ If provided, it will be used as a scripted field
+ for the aggregation
+ fixed_interval:
+ If provided, it will ALWAYS be used as the histogram
+ aggregation interval
+ """
+
+ field: str
+ script: Optional[Union[str, Dict[str, Any]]] = None
+ fixed_interval: Optional[float] = None
+
+ def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any]:
+ if self.fixed_interval:
+ interval = self.fixed_interval
+
+ return aggregations.histogram_aggregation(
+ field_name=self.field, script=self.script, interval=interval
+ )
+
+
+@dataclasses.dataclass
+class TermsAggregation(ElasticsearchMetric):
+
+ field: str = None
+ script: Union[str, Dict[str, Any]] = None
+ fixed_size: Optional[int] = None
+ default_size: Optional[int] = None
+ missing: Optional[str] = None
+
+ def _build_aggregation(self, size: int = None) -> Dict[str, Any]:
+ if self.fixed_size:
+ size = self.fixed_size
+ return aggregations.terms_aggregation(
+ self.field,
+ script=self.script,
+ size=size or self.default_size,
+ missing=self.missing,
+ )
+
+
+@dataclasses.dataclass
+class BidimensionalTermsAggregation(ElasticsearchMetric):
+ field_x: str
+ field_y: str
+
+ def _build_aggregation(self, size: int = None) -> Dict[str, Any]:
+ return aggregations.bidimentional_terms_aggregations(
+ field_name_x=self.field_x,
+ field_name_y=self.field_y,
+ size=size,
+ )
+
+
+@dataclasses.dataclass
+class NestedTermsAggregation(NestedPathElasticsearchMetric):
+ terms: TermsAggregation
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.terms.field = f"{self.nested_path}.{self.terms.field}"
+
+ def _inner_aggregation(self, size: int) -> Dict[str, Any]:
+ return self.terms.aggregation_request(size)
+
+
+@dataclasses.dataclass
+class NestedBidimensionalTermsAggregation(NestedPathElasticsearchMetric):
+ biterms: BidimensionalTermsAggregation
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.biterms.field_x = f"{self.nested_path}.{self.biterms.field_x}"
+ self.biterms.field_y = f"{self.nested_path}.{self.biterms.field_y}"
+
+ def _inner_aggregation(self, size: int = None) -> Dict[str, Any]:
+ return self.biterms.aggregation_request(size)
+
+
+@dataclasses.dataclass
+class NestedHistogramAggregation(NestedPathElasticsearchMetric):
+ histogram: HistogramAggregation
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.histogram.field = f"{self.nested_path}.{self.histogram.field}"
+
+ def _inner_aggregation(self, interval: float) -> Dict[str, Any]:
+ return self.histogram.aggregation_request(interval)
+
+
+@dataclasses.dataclass
+class WordCloudAggregation(ElasticsearchMetric):
+ default_field: str
+
+ def _build_aggregation(
+ self, text_field: str = None, size: int = None
+ ) -> Dict[str, Any]:
+ field = text_field or self.default_field
+ terms_id = f"{self.id}_{field}" if text_field else self.id
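+ # Delegate to a plain terms aggregation over the chosen text field and unwrap its single entry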
+ return TermsAggregation(id=terms_id, field=field).aggregation_request(
+ size=size
+ )[terms_id]
+
+
+@dataclasses.dataclass
+class MetadataAggregations(ElasticsearchMetric):
+ def __post_init__(self):
+ super().__post_init__()
+ self.__args__ = self.get_function_arg_names(self.aggregation_request)
+
+ def aggregation_request(
+ self,
+ schema: Dict[str, Any],
+ size: int = None,
+ ) -> List[Dict[str, Any]]:
+
+ metadata_aggs = aggregations.custom_fields(fields_definitions=schema, size=size)
+ return [{key: value} for key, value in metadata_aggs.items()]
+
+ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
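+ # ES returns flat "metadata.field" keys; rebuild the hierarchy and keep only the metadata branch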
+ data = unflatten_dict(aggregation_result, stop_keys=["metadata"])
+ return data.get("metadata", {})
diff --git a/src/rubrix/server/daos/backend/metrics/commons.py b/src/rubrix/server/daos/backend/metrics/commons.py
new file mode 100644
index 0000000000..b91c712a94
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/commons.py
@@ -0,0 +1,47 @@
+from rubrix.server.commons.models import TaskStatus
+from rubrix.server.daos.backend.metrics.base import (
+ HistogramAggregation,
+ MetadataAggregations,
+ TermsAggregation,
+ WordCloudAggregation,
+)
+
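+# Aggregations common to all tasks, keyed by metric id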
+METRICS = {
+ "text_length": HistogramAggregation(
+ id="text_length",
+ field="metrics.text_length",
+ script="params._source.text.length()",
+ fixed_interval=1,
+ ),
+ "error_distribution": TermsAggregation(
+ id="error_distribution",
+ field="predicted",
+ missing="unknown",
+ fixed_size=3,
+ ),
+ "status_distribution": TermsAggregation(
+ id="status_distribution",
+ field="status",
+ fixed_size=len(TaskStatus),
+ ),
+ "words_cloud": WordCloudAggregation(
+ id="words_cloud",
+ default_field="text.wordcloud",
+ ),
+ "metadata": MetadataAggregations(
+ id="metadata",
+ ),
+ "predicted_by": TermsAggregation(
+ id="predicted_by",
+ field="predicted_by",
+ ),
+ "annotated_by": TermsAggregation(
+ id="annotated_by",
+ field="annotated_by",
+ ),
+ "score": HistogramAggregation(
+ id="score",
+ field="score",
+ fixed_interval=0.001,
+ ),
+}
diff --git a/src/rubrix/server/daos/backend/metrics/datasets.py b/src/rubrix/server/daos/backend/metrics/datasets.py
new file mode 100644
index 0000000000..a48bc73d9b
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/datasets.py
@@ -0,0 +1,8 @@
+# All metrics related to the datasets index
+from rubrix.server.daos.backend.metrics.base import TermsAggregation
+
+METRICS = {
+ "all_rubrix_workspaces": TermsAggregation(
+ id="all_rubrix_workspaces", field="owner.keyword"
+ )
+}
diff --git a/src/rubrix/server/daos/backend/metrics/text_classification.py b/src/rubrix/server/daos/backend/metrics/text_classification.py
new file mode 100644
index 0000000000..085400a342
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/text_classification.py
@@ -0,0 +1,130 @@
+import dataclasses
+from typing import Any, Dict, List, Optional
+
+from rubrix.server.daos.backend.metrics.base import (
+ ElasticsearchMetric,
+ TermsAggregation,
+)
+from rubrix.server.daos.backend.query_helpers import aggregations, filters
+from rubrix.server.helpers import unflatten_dict
+
+
+@dataclasses.dataclass
+class DatasetLabelingRulesMetric(ElasticsearchMetric):
+ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]:
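+ # Dataset-level rule coverage: records matched by ANY rule query, plus the subset that is also annotated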
+ rules_filters = [filters.text_query(rule_query) for rule_query in queries]
+ return aggregations.filters_aggregation(
+ filters={
+ "covered_records": filters.boolean_filter(
+ should_filters=rules_filters, minimum_should_match=1
+ ),
+ "annotated_covered_records": filters.boolean_filter(
+ filter_query=filters.exists_field("annotated_as"),
+ should_filters=rules_filters,
+ minimum_should_match=1,
+ ),
+ }
+ )
+
+
+@dataclasses.dataclass
+class LabelingRulesMetric(ElasticsearchMetric):
+ id: str
+
+ def _build_aggregation(
+ self, rule_query: str, labels: Optional[List[str]]
+ ) -> Dict[str, Any]:
+
+ annotated_records_filter = filters.exists_field("annotated_as")
+ rule_query_filter = filters.text_query(rule_query)
+ aggr_filters = {
+ "covered_records": rule_query_filter,
+ "annotated_covered_records": filters.boolean_filter(
+ filter_query=annotated_records_filter,
+ should_filters=[rule_query_filter],
+ ),
+ }
+
+ if labels is not None:
+ for label in labels:
+ rule_label_annotated_filter = filters.term_filter(
+ "annotated_as", value=label
+ )
+ encoded_label = self._encode_label_name(label)
+ aggr_filters.update(
+ {
+ f"{encoded_label}.correct_records": filters.boolean_filter(
+ filter_query=annotated_records_filter,
+ should_filters=[
+ rule_query_filter,
+ rule_label_annotated_filter,
+ ],
+ minimum_should_match=2,
+ ),
+ f"{encoded_label}.incorrect_records": filters.boolean_filter(
+ filter_query=annotated_records_filter,
+ must_query=rule_query_filter,
+ must_not_query=rule_label_annotated_filter,
+ ),
+ }
+ )
+
+ return aggregations.filters_aggregation(aggr_filters)
+
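+ # Labels may contain dots, which unflatten_dict would treat as nested keys; encode them while building bucket names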
+ @staticmethod
+ def _encode_label_name(label: str) -> str:
+ return label.replace(".", "@@@")
+
+ @staticmethod
+ def _decode_label_name(label: str) -> str:
+ return label.replace("@@@", ".")
+
+ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
+ if self.id in aggregation_result:
+ aggregation_result = aggregation_result[self.id]
+
+ aggregation_result = unflatten_dict(aggregation_result)
+ results = {
+ "covered_records": aggregation_result.pop("covered_records"),
+ "annotated_covered_records": aggregation_result.pop(
+ "annotated_covered_records"
+ ),
+ }
+
+ all_correct = []
+ all_incorrect = []
+ all_precision = []
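+ # Accumulate per-label correct/incorrect counts to derive precision and the dataset-level totals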
+ for label, metrics in aggregation_result.items():
+ correct = metrics.get("correct_records", 0)
+ incorrect = metrics.get("incorrect_records", 0)
+ annotated = correct + incorrect
+ metrics["annotated"] = annotated
+ if annotated > 0:
+ precision = correct / annotated
+ metrics["precision"] = precision
+ all_precision.append(precision)
+
+ all_correct.append(correct)
+ all_incorrect.append(incorrect)
+ results[self._decode_label_name(label)] = metrics
+
+ results["correct_records"] = sum(all_correct)
+ results["incorrect_records"] = sum(all_incorrect)
+ if len(all_precision) > 0:
+ results["precision"] = sum(all_precision) / len(all_precision)
+
+ return results
+
+
+METRICS = {
+ "predicted_as": TermsAggregation(
+ id="predicted_as",
+ field="predicted_as",
+ ),
+ "annotated_as": TermsAggregation(
+ id="annotated_as",
+ field="annotated_as",
+ ),
+ "labeling_rule": LabelingRulesMetric(id="labeling_rule"),
+ "dataset_labeling_rules": DatasetLabelingRulesMetric(id="dataset_labeling_rules"),
+}
diff --git a/src/rubrix/server/daos/backend/metrics/token_classification.py b/src/rubrix/server/daos/backend/metrics/token_classification.py
new file mode 100644
index 0000000000..166c03b969
--- /dev/null
+++ b/src/rubrix/server/daos/backend/metrics/token_classification.py
@@ -0,0 +1,214 @@
+import dataclasses
+from typing import Any, Dict
+
+from rubrix.server.daos.backend.metrics.base import (
+ BidimensionalTermsAggregation,
+ HistogramAggregation,
+ NestedBidimensionalTermsAggregation,
+ NestedHistogramAggregation,
+ NestedPathElasticsearchMetric,
+ NestedTermsAggregation,
+ TermsAggregation,
+)
+from rubrix.server.daos.backend.query_helpers import aggregations
+
+_DEFAULT_MAX_ENTITY_BUCKET = 1000
+
+
+@dataclasses.dataclass
+class EntityConsistency(NestedPathElasticsearchMetric):
+ """Computes the entity consistency distribution"""
+
+ mention_field: str
+ labels_field: str
+
+ def _inner_aggregation(
+ self,
+ size: int,
+ interval: int = 2,
+ entity_size: int = _DEFAULT_MAX_ENTITY_BUCKET,
+ ) -> Dict[str, Any]:
+ size = size or 50
+ interval = int(max(interval or 2, 2))
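+ # Bucket by mention value and keep only mentions tagged with at least `interval` different labels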
+ return {
+ "consistency": {
+ **aggregations.terms_aggregation(
+ self.compound_nested_field(self.mention_field), size=size
+ ),
+ "aggs": {
+ "entities": aggregations.terms_aggregation(
+ self.compound_nested_field(self.labels_field), size=entity_size
+ ),
+ "count": {
+ "cardinality": {
+ "field": self.compound_nested_field(self.labels_field)
+ }
+ },
+ "entities_variability_filter": {
+ "bucket_selector": {
+ "buckets_path": {"numLabels": "count"},
+ "script": f"params.numLabels >= {interval}",
+ }
+ },
+ "sortby_entities_count": {
+ "bucket_sort": {
+ "sort": [{"count": {"order": "desc"}}],
+ "size": size,
+ }
+ },
+ },
+ }
+ }
+
+ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
+ """Simplifies the aggregation result sorting by worst mention consistency"""
+ result = [
+ {
+ "mention": mention,
+ "entities": [
+ {"label": entity, "count": count}
+ for entity, count in mention_aggs["entities"].items()
+ ],
+ }
+ for mention, mention_aggs in aggregation_result.items()
+ ]
+ # TODO: filter by entities threshold
+ result.sort(key=lambda m: len(m["entities"]), reverse=True)
+ return {"mentions": result}
+
+
+METRICS = {
+ "tokens_length": HistogramAggregation(
+ "tokens_length",
+ field="metrics.tokens_length",
+ ),
+ "token_frequency": NestedTermsAggregation(
+ id="token_frequency",
+ nested_path="metrics.tokens",
+ terms=TermsAggregation(
+ id="frequency",
+ field="value",
+ ),
+ ),
+ "token_length": NestedHistogramAggregation(
+ id="token_length",
+ nested_path="metrics.tokens",
+ histogram=HistogramAggregation(
+ id="length",
+ field="length",
+ fixed_interval=1,
+ ),
+ ),
+ "token_capitalness": NestedTermsAggregation(
+ id="token_capitalness",
+ nested_path="metrics.tokens",
+ terms=TermsAggregation(
+ id="capitalness",
+ field="capitalness",
+ ),
+ ),
+ "predicted_mention_char_length": NestedHistogramAggregation(
+ id="predicted_mention_char_length",
+ nested_path="metrics.predicted.mentions",
+ histogram=HistogramAggregation(
+ id="length",
+ field="chars_length",
+ fixed_interval=1,
+ ),
+ ),
+ "annotated_mention_char_length": NestedHistogramAggregation(
+ id="annotated_mention_char_length",
+ nested_path="metrics.annotated.mentions",
+ histogram=HistogramAggregation(
+ id="length",
+ field="chars_length",
+ fixed_interval=1,
+ ),
+ ),
+ "predicted_entity_labels": NestedTermsAggregation(
+ id="predicted_entity_labels",
+ nested_path="metrics.predicted.mentions",
+ terms=TermsAggregation(
+ id="terms",
+ field="label",
+ default_size=_DEFAULT_MAX_ENTITY_BUCKET,
+ ),
+ ),
+ "annotated_entity_labels": NestedTermsAggregation(
+ id="annotated_entity_labels",
+ nested_path="metrics.annotated.mentions",
+ terms=TermsAggregation(
+ id="terms",
+ field="label",
+ default_size=_DEFAULT_MAX_ENTITY_BUCKET,
+ ),
+ ),
+ "predicted_entity_density": NestedHistogramAggregation(
+ id="predicted_entity_density",
+ nested_path="metrics.predicted.mentions",
+ histogram=HistogramAggregation(id="histogram", field="density"),
+ ),
+ "annotated_entity_density": NestedHistogramAggregation(
+ id="annotated_entity_density",
+ nested_path="metrics.annotated.mentions",
+ histogram=HistogramAggregation(id="histogram", field="density"),
+ ),
+ "predicted_mention_token_length": NestedHistogramAggregation(
+ id="predicted_mention_token_length",
+ nested_path="metrics.predicted.mentions",
+ histogram=HistogramAggregation(id="histogram", field="tokens_length"),
+ ),
+ "annotated_mention_token_length": NestedHistogramAggregation(
+ id="annotated_mention_token_length",
+ nested_path="metrics.annotated.mentions",
+ histogram=HistogramAggregation(id="histogram", field="tokens_length"),
+ ),
+ "predicted_entity_capitalness": NestedTermsAggregation(
+ id="predicted_entity_capitalness",
+ nested_path="metrics.predicted.mentions",
+ terms=TermsAggregation(id="terms", field="capitalness"),
+ ),
+ "annotated_entity_capitalness": NestedTermsAggregation(
+ id="annotated_entity_capitalness",
+ nested_path="metrics.annotated.mentions",
+ terms=TermsAggregation(id="terms", field="capitalness"),
+ ),
+ "predicted_mentions_distribution": NestedBidimensionalTermsAggregation(
+ id="predicted_mentions_distribution",
+ nested_path="metrics.predicted.mentions",
+ biterms=BidimensionalTermsAggregation(
+ id="bi-dimensional", field_x="label", field_y="value"
+ ),
+ ),
+ "annotated_mentions_distribution": NestedBidimensionalTermsAggregation(
+ id="predicted_mentions_distribution",
+ nested_path="metrics.annotated.mentions",
+ biterms=BidimensionalTermsAggregation(
+ id="bi-dimensional", field_x="label", field_y="value"
+ ),
+ ),
+ "predicted_entity_consistency": EntityConsistency(
+ id="predicted_entity_consistency",
+ nested_path="metrics.predicted.mentions",
+ mention_field="value",
+ labels_field="label",
+ ),
+ "annotated_entity_consistency": EntityConsistency(
+ id="annotated_entity_consistency",
+ nested_path="metrics.annotated.mentions",
+ mention_field="value",
+ labels_field="label",
+ ),
+ "predicted_tag_consistency": EntityConsistency(
+ id="predicted_tag_consistency",
+ nested_path="metrics.predicted.tags",
+ mention_field="value",
+ labels_field="tag",
+ ),
+ "annotated_tag_consistency": EntityConsistency(
+ id="annotated_tag_consistency",
+ nested_path="metrics.annotated.tags",
+ mention_field="value",
+ labels_field="tag",
+ ),
+}
diff --git a/src/rubrix/server/elasticseach/query_helpers.py b/src/rubrix/server/daos/backend/query_helpers.py
similarity index 90%
rename from src/rubrix/server/elasticseach/query_helpers.py
rename to src/rubrix/server/daos/backend/query_helpers.py
index 0830cf51c5..3386c299bb 100644
--- a/src/rubrix/server/elasticseach/query_helpers.py
+++ b/src/rubrix/server/daos/backend/query_helpers.py
@@ -13,17 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import math
from typing import Any, Dict, List, Optional, Type, Union
from pydantic import BaseModel
-from rubrix.server.apis.v0.models.commons.model import (
- EsRecordDataFieldNames,
- SortableField,
- TaskStatus,
-)
-from rubrix.server.elasticseach.mappings.helpers import mappings
+from rubrix.server.commons.models import TaskStatus
+from rubrix.server.daos.backend.mappings.helpers import mappings
def nested_mappings_from_base_model(model_class: Type[BaseModel]) -> Dict[str, Any]:
@@ -45,21 +40,6 @@ def resolve_mapping(info) -> Dict[str, Any]:
}
-def sort_by2elasticsearch(
- sort: List[SortableField], valid_fields: Optional[List[str]] = None
-) -> List[Dict[str, Any]]:
- valid_fields = valid_fields or []
- result = []
- for sortable_field in sort:
- if valid_fields:
- assert sortable_field.id.split(".")[0] in valid_fields, (
- f"Wrong sort id {sortable_field.id}. Valid values are: "
- f"{[str(v) for v in valid_fields]}"
- )
- result.append({sortable_field.id: {"order": sortable_field.order}})
- return result
-
-
def parse_aggregations(
es_aggregations: Dict[str, Any] = None
) -> Optional[Dict[str, Any]]:
@@ -126,10 +106,6 @@ def parse_buckets(buckets: List[Dict[str, Any]]) -> Dict[str, Any]:
return result
-def decode_field_name(field: EsRecordDataFieldNames) -> str:
- return field.value
-
-
class filters:
"""Group of functions related to elasticsearch filters"""
@@ -172,29 +148,21 @@ def predicted_by(predicted_by: List[str] = None) -> Optional[Dict[str, Any]]:
if not predicted_by:
return None
- return {
- "terms": {
- decode_field_name(EsRecordDataFieldNames.predicted_by): predicted_by
- }
- }
+ return {"terms": {"predicted_by": predicted_by}}
@staticmethod
def annotated_by(annotated_by: List[str] = None) -> Optional[Dict[str, Any]]:
"""Filter records with given predicted by terms"""
if not annotated_by:
return None
- return {
- "terms": {
- decode_field_name(EsRecordDataFieldNames.annotated_by): annotated_by
- }
- }
+ return {"terms": {"annotated_by": annotated_by}}
@staticmethod
def status(status: List[TaskStatus] = None) -> Optional[Dict[str, Any]]:
"""Filter records by status"""
if not status:
return None
- return {"terms": {decode_field_name(EsRecordDataFieldNames.status): status}}
+ return {"terms": {"status": status}}
@staticmethod
def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]:
@@ -269,7 +237,9 @@ class aggregations:
MAX_AGGREGATION_SIZE = 5000 # TODO: improve by setting env var
@staticmethod
- def nested_aggregation(nested_path: str, inner_aggregation: Dict[str, Any]):
+ def nested_aggregation(
+ nested_path: str, inner_aggregation: Dict[str, Any]
+ ) -> Dict[str, Any]:
inner_meta = list(inner_aggregation.values())[0].get("meta", {})
return {
"meta": {
@@ -374,6 +344,10 @@ def __resolve_aggregation_for_field_type(
if aggregation
}
+ @staticmethod
+ def filters_aggregation(filters: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+ return {"filters": {"filters": filters}}
+
def find_nested_field_path(
field_name: str, mapping_definition: Dict[str, Any]
diff --git a/src/rubrix/server/elasticseach/mappings/__init__.py b/src/rubrix/server/daos/backend/search/__init__.py
similarity index 100%
rename from src/rubrix/server/elasticseach/mappings/__init__.py
rename to src/rubrix/server/daos/backend/search/__init__.py
diff --git a/src/rubrix/server/daos/backend/search/model.py b/src/rubrix/server/daos/backend/search/model.py
new file mode 100644
index 0000000000..ff2598a1c0
--- /dev/null
+++ b/src/rubrix/server/daos/backend/search/model.py
@@ -0,0 +1,67 @@
+from enum import Enum
+from typing import Dict, List, Optional, TypeVar, Union
+
+from pydantic import BaseModel, Field
+
+from rubrix.server.commons.models import TaskStatus
+
+
+class SortOrder(str, Enum):
+ asc = "asc"
+ desc = "desc"
+
+
+class QueryRange(BaseModel):
+
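+ # "from" is a reserved keyword in Python, so the payload keys are mapped through field aliases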
+ range_from: float = Field(default=0.0, alias="from")
+ range_to: float = Field(default=None, alias="to")
+
+ class Config:
+ allow_population_by_field_name = True
+
+
+class SortableField(BaseModel):
+ """Sortable field structure"""
+
+ id: str
+ order: SortOrder = SortOrder.asc
+
+
+class SortConfig(BaseModel):
+ shuffle: bool = False
+
+ sort_by: List[SortableField] = Field(default_factory=list)
+ valid_fields: List[str] = Field(default_factory=list)
+
+
+class BaseQuery(BaseModel):
+ pass
+
+
+class BaseDatasetsQuery(BaseQuery):
+ tasks: Optional[List[str]] = None
+ owners: Optional[List[str]] = None
+ include_no_owner: Optional[bool] = None
+ name: Optional[str] = None
+
+
+class BaseRecordsQuery(BaseQuery):
+
+ query_text: Optional[str] = None
+ advanced_query_dsl: bool = False
+
+ ids: Optional[List[Union[str, int]]]
+
+ annotated_by: List[str] = Field(default_factory=list)
+ predicted_by: List[str] = Field(default_factory=list)
+
+ status: List[TaskStatus] = Field(default_factory=list)
+ metadata: Optional[Dict[str, Union[str, List[str]]]] = None
+
+ has_annotation: Optional[bool] = None
+ has_prediction: Optional[bool] = None
+
+
+BackendQuery = TypeVar("BackendQuery", bound=BaseQuery)
+BackendRecordsQuery = TypeVar("BackendRecordsQuery", bound=BaseRecordsQuery)
+BackendDatasetsQuery = TypeVar("BackendDatasetsQuery", bound=BaseDatasetsQuery)
diff --git a/src/rubrix/server/daos/backend/search/query_builder.py b/src/rubrix/server/daos/backend/search/query_builder.py
new file mode 100644
index 0000000000..d63c2a4a30
--- /dev/null
+++ b/src/rubrix/server/daos/backend/search/query_builder.py
@@ -0,0 +1,227 @@
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+from luqum.elasticsearch import ElasticsearchQueryBuilder, SchemaAnalyzer
+from luqum.parser import parser
+
+from rubrix.server.daos.backend.query_helpers import filters
+from rubrix.server.daos.backend.search.model import (
+ BackendDatasetsQuery,
+ BackendQuery,
+ BackendRecordsQuery,
+ BaseDatasetsQuery,
+ QueryRange,
+ SortableField,
+ SortConfig,
+)
+
+
+class EsQueryBuilder:
+ _INSTANCE: "EsQueryBuilder" = None
+ _LOGGER = logging.getLogger(__name__)
+
+ @classmethod
+ def get_instance(cls):
+ if not cls._INSTANCE:
+ cls._INSTANCE = cls()
+ return cls._INSTANCE
+
+ def _datasets_to_es_query(
+ self, query: Optional[BackendDatasetsQuery] = None
+ ) -> Dict[str, Any]:
+ if not query:
+ return filters.match_all()
+
+ query_filters = []
+ if query.owners:
+ owners_filter = filters.terms_filter("owner.keyword", query.owners)
+ if query.include_no_owner:
+ query_filters.append(
+ filters.boolean_filter(
+ minimum_should_match=1, # OR Condition
+ should_filters=[
+ owners_filter,
+ filters.boolean_filter(
+ must_not_query=filters.exists_field("owner")
+ ),
+ ],
+ )
+ )
+ else:
+ query_filters.append(owners_filter)
+
+ if query.tasks:
+ query_filters.append(
+ filters.terms_filter(field="task.keyword", values=query.tasks)
+ )
+ if query.name:
+ query_filters.append(
+ filters.term_filter(field="name.keyword", value=query.name)
+ )
+ if not query_filters:
+ return filters.match_all()
+ return filters.boolean_filter(
+ should_filters=query_filters, minimum_should_match=len(query_filters)
+ )
+
+ def _search_to_es_query(
+ self,
+ schema: Optional[Dict[str, Any]] = None,
+ query: Optional[BackendRecordsQuery] = None,
+ sort: Optional[SortConfig] = None,
+ ):
+ if not query:
+ return filters.match_all()
+
+ if not query.advanced_query_dsl or not query.query_text:
+ return self._to_es_query(query)
+
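+ # Advanced DSL: parse the query text with luqum into an ES query and combine it with the structured filters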
+ text_search = query.query_text
+ new_query = query.copy(update={"query_text": None})
+
+ schema = SchemaAnalyzer(schema)
+ es_query_builder = ElasticsearchQueryBuilder(
+ **{
+ **schema.query_builder_options(),
+ "default_field": "text",
+ }
+ )
+
+ query_tree = parser.parse(text_search)
+ query_text = es_query_builder(query_tree)
+
+ return filters.boolean_filter(
+ filter_query=self._to_es_query(new_query), must_query=query_text
+ )
+
+ def map_2_es_query(
+ self,
+ schema: Optional[Dict[str, Any]] = None,
+ query: Optional[BackendQuery] = None,
+ sort: Optional[SortConfig] = None,
+ id_from: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ es_query: Dict[str, Any] = (
+ {"query": self._datasets_to_es_query(query)}
+ if isinstance(query, BaseDatasetsQuery)
+ else {"query": self._search_to_es_query(schema, query)}
+ )
+
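+ # Deep pagination: when an id_from cursor is given, resume after that record using search_after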
+ if id_from:
+ es_query["search_after"] = [id_from]
+ sort = SortConfig() # sort by id as default
+
+ es_sort = self.map_2_es_sort_configuration(schema=schema, sort=sort)
+ if es_sort:
+ es_query["sort"] = es_sort
+
+ return es_query
+
+ def map_2_es_sort_configuration(
+ self, schema: Optional[Dict[str, Any]] = None, sort: Optional[SortConfig] = None
+ ) -> Optional[List[Dict[str, Any]]]:
+
+ if not sort:
+ return None
+
+ # TODO(@frascuchon): compute valid list from the schema
+ valid_fields = sort.valid_fields or [
+ "id",
+ "metadata",
+ "score",
+ "predicted",
+ "predicted_as",
+ "predicted_by",
+ "annotated_as",
+ "annotated_by",
+ "status",
+ "last_updated",
+ "event_timestamp",
+ ]
+
+ id_field = "id"
+ id_keyword_field = "id.keyword"
+ schema = schema or {}
+ mappings = self._clean_mappings(schema.get("mappings", {}))
+ use_id_keyword = "text" == mappings.get("id")
+
+ es_sort = []
+ for sortable_field in sort.sort_by or [SortableField(id="id")]:
+ if valid_fields:
+ if sortable_field.id.split(".")[0] not in valid_fields:
+ raise AssertionError(
+ f"Wrong sort id {sortable_field.id}. Valid values are: "
+ f"{[str(v) for v in valid_fields]}"
+ )
+ field = sortable_field.id
+ if field == id_field and use_id_keyword:
+ field = id_keyword_field
+ es_sort.append({field: {"order": sortable_field.order}})
+
+ return es_sort
+
+ @classmethod
+ def _to_es_query(cls, query: BackendRecordsQuery) -> Dict[str, Any]:
+ if query.ids:
+ return filters.ids_filter(query.ids)
+
+ query_text = filters.text_query(query.query_text)
+ all_filters = filters.metadata(query.metadata)
+ if query.has_annotation:
+ all_filters.append(filters.exists_field("annotated_by"))
+ if query.has_prediction:
+ all_filters.append(filters.exists_field("predicted_by"))
+
+ query_data = query.dict(
+ exclude={
+ "advanced_query_dsl",
+ "query_text",
+ "metadata",
+ "uncovered_by_rules",
+ "has_annotation",
+ "has_prediction",
+ }
+ )
+ for key, value in query_data.items():
+ if value is None:
+ continue
+ key_filter = None
+ if isinstance(value, dict):
+ value = getattr(query, key) # check the original field type
+ if isinstance(value, list):
+ key_filter = filters.terms_filter(key, value)
+ elif isinstance(value, (str, Enum)):
+ key_filter = filters.term_filter(key, value)
+ elif isinstance(value, QueryRange):
+ key_filter = filters.range_filter(
+ field=key, value_from=value.range_from, value_to=value.range_to
+ )
+
+ else:
+ cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}")
+ if key_filter:
+ all_filters.append(key_filter)
+
+ return filters.boolean_filter(
+ must_query=query_text or filters.match_all(),
+ filter_query=filters.boolean_filter(
+ should_filters=all_filters, minimum_should_match=len(all_filters)
+ )
+ if all_filters
+ else None,
+ must_not_query=filters.boolean_filter(
+ should_filters=[filters.text_query(q) for q in query.uncovered_by_rules]
+ )
+ if hasattr(query, "uncovered_by_rules") and query.uncovered_by_rules
+ else None,
+ )
+
+ def _clean_mappings(self, mappings: Dict[str, Any]):
+ if not mappings:
+ return {}
+
+ return {
+ key: definition.get("type") or self._clean_mappings(definition)
+ for key, definition in mappings["properties"].items()
+ }
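# --- Editor sketch (illustration only, not part of this diff) ---------------
# Rough shape of the Elasticsearch body that `EsQueryBuilder.map_2_es_query` is
# expected to emit for a datasets listing query. The literal DSL below assumes
# the `filters` helpers build standard `terms`/`bool` clauses; the field values
# are placeholders rather than exact runtime output.
example_datasets_es_body = {
    "query": {
        "bool": {
            "should": [
                {"terms": {"owner.keyword": ["recognai"]}},
                {"terms": {"task.keyword": ["TextClassification"]}},
            ],
            # With minimum_should_match == len(should), the should clauses act as an AND
            "minimum_should_match": 2,
        }
    },
    # Default sort configuration applied when paginating with `search_after`
    "sort": [{"id": {"order": "asc"}}],
}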
diff --git a/src/rubrix/server/daos/datasets.py b/src/rubrix/server/daos/datasets.py
index eaab6919cc..d476edd1eb 100644
--- a/src/rubrix/server/daos/datasets.py
+++ b/src/rubrix/server/daos/datasets.py
@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
from typing import Any, Dict, List, Optional, Type
from fastapi import Depends
-from rubrix.server.daos.models.datasets import BaseDatasetDB, DatasetDB, SettingsDB
-from rubrix.server.daos.records import DatasetRecordsDAO, dataset_records_index
-from rubrix.server.elasticseach import query_helpers
-from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper
-from rubrix.server.elasticseach.mappings.datasets import (
- DATASETS_INDEX_NAME,
- DATASETS_INDEX_TEMPLATE,
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
+from rubrix.server.daos.backend.search.model import BaseDatasetsQuery
+from rubrix.server.daos.models.datasets import (
+ BaseDatasetDB,
+ BaseDatasetSettingsDB,
+ DatasetDB,
+ DatasetSettingsDB,
)
+from rubrix.server.daos.records import DatasetRecordsDAO
from rubrix.server.errors import WrongTaskError
NO_WORKSPACE = ""
-MAX_NUMBER_OF_LISTED_DATASETS = 2500
class DatasetsDAO:
@@ -40,7 +39,7 @@ class DatasetsDAO:
@classmethod
def get_instance(
cls,
- es: ElasticsearchWrapper = Depends(ElasticsearchWrapper.get_instance),
+ es: ElasticsearchBackend = Depends(ElasticsearchBackend.get_instance),
records_dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
) -> "DatasetsDAO":
"""
@@ -63,207 +62,91 @@ def get_instance(
cls._INSTANCE = cls(es, records_dao)
return cls._INSTANCE
- def __init__(self, es: ElasticsearchWrapper, records_dao: DatasetRecordsDAO):
+ def __init__(self, es: ElasticsearchBackend, records_dao: DatasetRecordsDAO):
self._es = es
self.__records_dao__ = records_dao
self.init()
def init(self):
"""Initializes dataset dao. Used on app startup"""
- self._es.create_index_template(
- name=DATASETS_INDEX_NAME,
- template=DATASETS_INDEX_TEMPLATE,
- force_recreate=True,
- )
- self._es.create_index(DATASETS_INDEX_NAME)
+ self._es.create_datasets_index(force_recreate=True)
def list_datasets(
self,
owner_list: List[str] = None,
- task2dataset_map: Dict[str, Type[BaseDatasetDB]] = None,
- ) -> List[BaseDatasetDB]:
- filters = []
- if owner_list:
- owners_filter = query_helpers.filters.terms_filter(
- "owner.keyword", owner_list
- )
- if NO_WORKSPACE in owner_list:
- filters.append(
- query_helpers.filters.boolean_filter(
- minimum_should_match=1, # OR Condition
- should_filters=[
- query_helpers.filters.boolean_filter(
- must_not_query=query_helpers.filters.exists_field(
- "owner"
- )
- ),
- owners_filter,
- ],
- )
- )
- else:
- filters.append(owners_filter)
-
- if task2dataset_map:
- filters.append(
- query_helpers.filters.terms_filter(
- field="task.keyword", values=[task for task in task2dataset_map]
- )
- )
-
- docs = self._es.list_documents(
- index=DATASETS_INDEX_NAME,
- fetch_once=True,
- # TODO(@frascuchon): include id as part of the document as keyword to enable sorting by id
- size=MAX_NUMBER_OF_LISTED_DATASETS,
- query={
- "query": query_helpers.filters.boolean_filter(
- should_filters=filters, minimum_should_match=len(filters)
- )
- }
- if filters
- else None,
+ task2dataset_map: Dict[str, Type[DatasetDB]] = None,
+ name: Optional[str] = None,
+ ) -> List[DatasetDB]:
+ owner_list = owner_list or []
+ query = BaseDatasetsQuery(
+ owners=owner_list,
+ include_no_owner=NO_WORKSPACE in owner_list,
+ tasks=[task for task in task2dataset_map] if task2dataset_map else None,
+ name=name,
)
+ docs = self._es.list_datasets(query)
task2dataset_map = task2dataset_map or {}
return [
self._es_doc_to_instance(
- doc, ds_class=task2dataset_map.get(task, DatasetDB)
+ doc, ds_class=task2dataset_map.get(task, BaseDatasetDB)
)
for doc in docs
for task in [self.__get_doc_field__(doc, "task")]
]
- def create_dataset(
- self, dataset: BaseDatasetDB, mappings: Dict[str, Any]
- ) -> BaseDatasetDB:
- """
- Stores a dataset in elasticsearch and creates corresponding dataset records index
-
- Parameters
- ----------
- dataset:
- The dataset
-
- Returns
- -------
- Created dataset
- """
-
- self._es.add_document(
- index=DATASETS_INDEX_NAME,
- doc_id=dataset.id,
- document=self._dataset_to_es_doc(dataset),
+ def create_dataset(self, dataset: DatasetDB) -> DatasetDB:
+ self._es.add_dataset_document(
+ id=dataset.id, document=self._dataset_to_es_doc(dataset)
)
- self.__records_dao__.create_dataset_index(
- dataset, mappings=mappings, force_recreate=True
+ self._es.create_dataset_index(
+ id=dataset.id,
+ task=dataset.task,
+ force_recreate=True,
)
return dataset
def update_dataset(
self,
- dataset: BaseDatasetDB,
- ) -> BaseDatasetDB:
- """
- Updates an stored dataset
-
- Parameters
- ----------
- dataset:
- The dataset
-
- Returns
- -------
- The updated dataset
-
- """
- dataset_id = dataset.id
+ dataset: DatasetDB,
+ ) -> DatasetDB:
- self._es.update_document(
- index=DATASETS_INDEX_NAME,
- doc_id=dataset_id,
- document=self._dataset_to_es_doc(dataset),
- partial_update=True,
+ self._es.update_dataset_document(
+ id=dataset.id, document=self._dataset_to_es_doc(dataset)
)
return dataset
- def delete_dataset(self, dataset: BaseDatasetDB):
- """
- Deletes indices related to provided dataset
-
- Parameters
- ----------
- dataset:
- The dataset
-
- """
- try:
- self._es.delete_index(dataset_records_index(dataset.id))
- finally:
- self._es.delete_document(index=DATASETS_INDEX_NAME, doc_id=dataset.id)
+ def delete_dataset(self, dataset: DatasetDB):
+ self._es.delete(dataset.id)
def find_by_name(
self,
name: str,
owner: Optional[str],
- as_dataset_class: Type[BaseDatasetDB] = DatasetDB,
+ as_dataset_class: Type[DatasetDB] = BaseDatasetDB,
task: Optional[str] = None,
- ) -> Optional[BaseDatasetDB]:
- """
- Finds a dataset by name
+ ) -> Optional[DatasetDB]:
- Parameters
- ----------
- name: The dataset name
- owner: The dataset owner
- as_dataset_class: The dataset class used to return data
- task: The dataset task string definition
-
- Returns
- -------
- The found dataset if any. None otherwise
- """
- dataset_id = DatasetDB.build_dataset_id(
+ dataset_id = BaseDatasetDB.build_dataset_id(
name=name,
owner=owner,
)
- document = self._es.get_document_by_id(
- index=DATASETS_INDEX_NAME, doc_id=dataset_id
- )
- if not document and owner is None:
- # We must search by name since we have no owner
- results = self._es.list_documents(
- index=DATASETS_INDEX_NAME,
- query={"query": {"term": {"name.keyword": name}}},
- fetch_once=True,
- )
- results = list(results)
- if len(results) == 0:
- return None
-
- if len(results) > 1:
- raise ValueError(
- f"Ambiguous dataset info found for name {name}. Please provide a valid owner"
- )
-
- document = results[0]
-
+ document = self._es.find_dataset(id=dataset_id, name=name, owner=owner)
if document is None:
return None
-
base_ds = self._es_doc_to_instance(document)
if task and task != base_ds.task:
raise WrongTaskError(
detail=f"Provided task {task} cannot be applied to dataset"
)
- dataset_type = as_dataset_class or DatasetDB
+ dataset_type = as_dataset_class or BaseDatasetDB
return self._es_doc_to_instance(document, ds_class=dataset_type)
@staticmethod
def _es_doc_to_instance(
- doc: Dict[str, Any], ds_class: Type[BaseDatasetDB] = DatasetDB
- ) -> BaseDatasetDB:
- """Transforms a stored elasticsearch document into a `DatasetDB`"""
+ doc: Dict[str, Any], ds_class: Type[DatasetDB] = BaseDatasetDB
+ ) -> DatasetDB:
+ """Transforms a stored elasticsearch document into a `BaseDatasetDB`"""
def __key_value_list_to_dict__(
key_value_list: List[Dict[str, Any]]
@@ -300,67 +183,47 @@ def __dict_to_key_value_list__(data: Dict[str, Any]) -> List[Dict[str, Any]]:
}
def copy(self, source: DatasetDB, target: DatasetDB):
- source_doc = self._es.get_document_by_id(
- index=DATASETS_INDEX_NAME, doc_id=source.id
- )
- self._es.add_document(
- index=DATASETS_INDEX_NAME,
- doc_id=target.id,
+ source_doc = self._es.find_dataset(id=source.id)
+ self._es.add_dataset_document(
+ id=target.id,
document={
**source_doc["_source"], # we copy extended fields from source document
**self._dataset_to_es_doc(target),
},
)
- index_from = dataset_records_index(source.id)
- index_to = dataset_records_index(target.id)
- self._es.clone_index(index=index_from, clone_to=index_to)
+ self._es.copy(id_from=source.id, id_to=target.id)
def close(self, dataset: DatasetDB):
"""Close a dataset. It's mean that release all related resources, like elasticsearch index"""
- self._es.close_index(dataset_records_index(dataset.id))
+ self._es.close(dataset.id)
def open(self, dataset: DatasetDB):
"""Make available a dataset"""
- self._es.open_index(dataset_records_index(dataset.id))
+ self._es.open(dataset.id)
def get_all_workspaces(self) -> List[str]:
"""Get all datasets (Only for super users)"""
-
- workspaces_dict = self._es.aggregate(
- index=DATASETS_INDEX_NAME,
- aggregation=query_helpers.aggregations.terms_aggregation(
- "owner.keyword",
- missing=NO_WORKSPACE,
- size=500, # TODO: A max number of workspaces env var could be leveraged by this.
- ),
- )
-
- return [k for k in workspaces_dict]
-
- def save_settings(self, dataset: DatasetDB, settings: SettingsDB) -> SettingsDB:
- self._es.update_document(
- index=DATASETS_INDEX_NAME,
- doc_id=dataset.id,
- document={"settings": settings.dict(exclude_none=True)},
- partial_update=True,
+ metric_data = self._es.compute_rubrix_metric(metric_id="all_rubrix_workspaces")
+ return [k for k in metric_data]
+
+ def save_settings(
+ self, dataset: DatasetDB, settings: DatasetSettingsDB
+    ) -> DatasetSettingsDB:
+ self._es.update_dataset_document(
+ id=dataset.id, document={"settings": settings.dict(exclude_none=True)}
)
return settings
def load_settings(
- self, dataset: DatasetDB, as_class: Type[SettingsDB]
- ) -> Optional[SettingsDB]:
- doc = self._es.get_document_by_id(index=DATASETS_INDEX_NAME, doc_id=dataset.id)
+ self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB]
+ ) -> Optional[DatasetSettingsDB]:
+ doc = self._es.find_dataset(id=dataset.id)
if doc:
settings = self.__get_doc_field__(doc, field="settings")
return as_class.parse_obj(settings) if settings else None
def delete_settings(self, dataset: DatasetDB):
- self._es.update_document(
- index=DATASETS_INDEX_NAME,
- doc_id=dataset.id,
- script='ctx._source.remove("settings")',
- partial_update=True,
- )
+ self._es.remove_dataset_field(dataset.id, field="settings")
def __get_doc_field__(self, doc: Dict[str, Any], field: str) -> Optional[Any]:
return doc["_source"].get(field)
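# --- Editor sketch (illustration only, not part of this diff) ---------------
# Wiring the refactored DatasetsDAO as a FastAPI dependency. The route path is
# hypothetical; the dependency chain mirrors the `get_instance` classmethods
# used throughout this change.
from fastapi import APIRouter, Depends

from rubrix.server.daos.datasets import DatasetsDAO

sketch_router = APIRouter()


@sketch_router.get("/_sketch/datasets")
def sketch_list_datasets(dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)):
    # Owner/task/name filtering is delegated to the backend via BaseDatasetsQuery
    return [dataset.dict() for dataset in dao.list_datasets(owner_list=None)]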
diff --git a/src/rubrix/server/daos/models/datasets.py b/src/rubrix/server/daos/models/datasets.py
index 9d4a1e9f8c..9d06f9f8e4 100644
--- a/src/rubrix/server/daos/models/datasets.py
+++ b/src/rubrix/server/daos/models/datasets.py
@@ -3,10 +3,13 @@
from pydantic import BaseModel, Field
+from rubrix._constants import DATASET_NAME_REGEX_PATTERN
+from rubrix.server.commons.models import TaskType
-class DatasetDB(BaseModel):
- name: str
- task: str
+
+class BaseDatasetDB(BaseModel):
+ name: str = Field(regex=DATASET_NAME_REGEX_PATTERN)
+ task: TaskType
owner: Optional[str] = None
tags: Dict[str, str] = Field(default_factory=dict)
metadata: Dict[str, Any] = Field(default_factory=dict)
@@ -29,8 +32,9 @@ def id(self) -> str:
return self.build_dataset_id(self.name, self.owner)
-class SettingsDB(BaseModel):
+class BaseDatasetSettingsDB(BaseModel):
pass
-BaseDatasetDB = TypeVar("BaseDatasetDB", bound=DatasetDB)
+DatasetDB = TypeVar("DatasetDB", bound=BaseDatasetDB)
+DatasetSettingsDB = TypeVar("DatasetSettingsDB", bound=BaseDatasetSettingsDB)
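# --- Editor sketch (illustration only, not part of this diff) ---------------
# Constructing the renamed base model. The exact id string produced by
# `build_dataset_id` is not shown in this change, so the final comment is only
# indicative; the name value assumes it satisfies DATASET_NAME_REGEX_PATTERN.
from rubrix.server.commons.models import TaskType
from rubrix.server.daos.models.datasets import BaseDatasetDB

sketch_dataset = BaseDatasetDB(
    name="my-dataset",
    task=TaskType.text_classification,
    owner="recognai",
)
print(sketch_dataset.id)  # derived via build_dataset_id(name, owner)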
diff --git a/src/rubrix/server/daos/models/records.py b/src/rubrix/server/daos/models/records.py
index b221588b62..8bff581261 100644
--- a/src/rubrix/server/daos/models/records.py
+++ b/src/rubrix/server/daos/models/records.py
@@ -12,56 +12,159 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from datetime import datetime
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from uuid import uuid4
-from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, Field, validator
+from pydantic.generics import GenericModel
-from pydantic import BaseModel, Field
+from rubrix._constants import MAX_KEYWORD_LENGTH
+from rubrix.server.commons.models import PredictionStatus, TaskStatus, TaskType
+from rubrix.server.daos.backend.search.model import BackendRecordsQuery, SortConfig
+from rubrix.server.helpers import flatten_dict
+from rubrix.utils import limit_value_length
-class RecordSearch(BaseModel):
- """
- Dao search
+class DaoRecordsSearch(BaseModel):
- Attributes:
- -----------
+ query: Optional[BackendRecordsQuery] = None
+ sort: SortConfig = Field(default_factory=SortConfig)
- query:
- The elasticsearch search query portion
- sort:
- The elasticsearch sort order
- aggregations:
- The elasticsearch search aggregations
- """
+class DaoRecordsSearchResults(BaseModel):
+ total: int
+ records: List[Dict[str, Any]]
- query: Optional[Dict[str, Any]] = None
- sort: List[Dict[str, Any]] = Field(default_factory=list)
- aggregations: Optional[Dict[str, Any]]
+class BaseAnnotationDB(BaseModel):
+ agent: str = Field(max_length=64)
-class RecordSearchResults(BaseModel):
- """
- Dao search results
- Attributes:
- -----------
+AnnotationDB = TypeVar("AnnotationDB", bound=BaseAnnotationDB)
- total: int
- The total of query results
- records: List[T]
- List of records retrieved for the pagination configuration
- aggregations: Optional[Dict[str, Dict[str, Any]]]
- The query aggregations grouped by task. Optional
- words: Optional[Dict[str, int]]
- The words cloud aggregations
- metadata: Optional[Dict[str, int]]
- Metadata fields aggregations
- metrics: Optional[List[DatasetMetricResults]]
- Calculated metrics for search
- """
- total: int
- records: List[Dict[str, Any]]
- aggregations: Optional[Dict[str, Dict[str, Any]]] = Field(default_factory=dict)
- words: Optional[Dict[str, int]] = None
- metadata: Optional[Dict[str, int]] = None
+class BaseRecordInDB(GenericModel, Generic[AnnotationDB]):
+ id: Optional[Union[int, str]] = Field(default=None)
+ metadata: Dict[str, Any] = Field(default=None)
+ event_timestamp: Optional[datetime] = None
+ status: Optional[TaskStatus] = None
+ prediction: Optional[AnnotationDB] = None
+ annotation: Optional[AnnotationDB] = None
+
+ @validator("id", always=True, pre=True)
+ def default_id_if_none_provided(cls, id: Optional[str]) -> str:
+ """Validates id info and sets a random uuid if not provided"""
+ if id is None:
+ return str(uuid4())
+ return id
+
+ @validator("status", always=True)
+ def fill_default_value(cls, status: TaskStatus):
+        """Pydantic validator that sets the default task status"""
+ return TaskStatus.default if status is None else status
+
+ @validator("metadata", pre=True)
+ def flatten_metadata(cls, metadata: Dict[str, Any]):
+ """
+        A pydantic validator that flattens the metadata dictionary
+
+ Parameters
+ ----------
+ metadata:
+ The metadata dictionary
+
+ Returns
+ -------
+        A flattened version of the metadata dictionary
+
+ """
+ if metadata:
+ metadata = flatten_dict(metadata, drop_empty=True)
+ metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
+ return metadata
+
+ @classmethod
+ def task(cls) -> TaskType:
+ """The task type related to this task info"""
+ raise NotImplementedError
+
+ @property
+ def predicted(self) -> Optional[PredictionStatus]:
+ """The task record prediction status (if any)"""
+ return None
+
+ @property
+ def predicted_as(self) -> Optional[List[str]]:
+ """Predictions strings representation"""
+ return None
+
+ @property
+ def annotated_as(self) -> Optional[List[str]]:
+ """Annotations strings representation"""
+ return None
+
+ @property
+ def scores(self) -> Optional[List[float]]:
+ """Prediction scores"""
+ return None
+
+ def all_text(self) -> str:
+        """All textual information related to the record"""
+ raise NotImplementedError
+
+ @property
+ def predicted_by(self) -> List[str]:
+ """The prediction agents"""
+ if self.prediction:
+ return [self.prediction.agent]
+ return []
+
+ @property
+ def annotated_by(self) -> List[str]:
+ """The annotation agents"""
+ if self.annotation:
+ return [self.annotation.agent]
+ return []
+
+ def extended_fields(self) -> Dict[str, Any]:
+ """
+        Extends the set of fields stored in the db. Tasks that expose extra
+        properties beyond the common ones (predicted, annotated_as, ...) can
+        override this method.
+ """
+ return {
+ "predicted": self.predicted,
+ "annotated_as": self.annotated_as,
+ "predicted_as": self.predicted_as,
+ "annotated_by": self.annotated_by,
+ "predicted_by": self.predicted_by,
+ "score": self.scores,
+ }
+
+ def dict(self, *args, **kwargs) -> "DictStrAny":
+ """
+        Extends the base ``dict()`` output with the object's computed properties
+        and user-defined extended fields
+ """
+ return {
+ **super().dict(*args, **kwargs),
+ **self.extended_fields(),
+ }
+
+
+class BaseRecordDB(BaseRecordInDB, Generic[AnnotationDB]):
+
+ # Read only ones
+ metrics: Dict[str, Any] = Field(default_factory=dict)
+ search_keywords: Optional[List[str]] = None
+ last_updated: datetime = None
+
+ @validator("search_keywords")
+ def remove_duplicated_keywords(cls, value) -> List[str]:
+ """Remove duplicated keywords"""
+ if value:
+ return list(set(value))
+
+
+RecordDB = TypeVar("RecordDB", bound=BaseRecordDB)
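# --- Editor sketch (illustration only, not part of this diff) ---------------
# A minimal concrete record built on the new generic bases, just to exercise
# the id/status/metadata validators added here. The Sketch* classes are
# hypothetical and exist only for illustration.
from rubrix.server.daos.models.records import BaseAnnotationDB, BaseRecordDB


class SketchAnnotation(BaseAnnotationDB):
    pass


class SketchRecord(BaseRecordDB[SketchAnnotation]):
    text: str = ""


record = SketchRecord(text="hello", metadata={"source": {"split": "train"}})
print(record.id)        # a uuid4 string is filled in when no id is provided
print(record.status)    # TaskStatus.default
print(record.metadata)  # flattened, e.g. {"source.split": "train"}, assuming the "." separator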
diff --git a/src/rubrix/server/daos/records.py b/src/rubrix/server/daos/records.py
index d7ae4d0666..1793f83a76 100644
--- a/src/rubrix/server/daos/records.py
+++ b/src/rubrix/server/daos/records.py
@@ -13,66 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import dataclasses
import datetime
-import re
-from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
+from typing import Any, Dict, Iterable, List, Optional, Type
-import deprecated
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import BaseRecord, TaskType
-from rubrix.server.daos.models.datasets import BaseDatasetDB
-from rubrix.server.daos.models.records import RecordSearch, RecordSearchResults
-from rubrix.server.elasticseach.client_wrapper import (
+from rubrix.server.daos.backend.elasticsearch import (
ClosedIndexError,
- ElasticsearchWrapper,
+ ElasticsearchBackend,
IndexNotFoundError,
- create_es_wrapper,
)
-from rubrix.server.elasticseach.mappings.datasets import DATASETS_RECORDS_INDEX_NAME
-from rubrix.server.elasticseach.mappings.helpers import (
- mappings,
- tasks_common_mappings,
- tasks_common_settings,
+from rubrix.server.daos.backend.search.model import BackendRecordsQuery
+from rubrix.server.daos.models.datasets import DatasetDB
+from rubrix.server.daos.models.records import (
+ DaoRecordsSearch,
+ DaoRecordsSearchResults,
+ RecordDB,
)
-from rubrix.server.elasticseach.query_helpers import parse_aggregations
from rubrix.server.errors import ClosedDatasetError, MissingDatasetRecordsError
-from rubrix.server.errors.task_errors import MetadataLimitExceededError
-from rubrix.server.settings import settings
-
-DBRecord = TypeVar("DBRecord", bound=BaseRecord)
-
-
-@dataclasses.dataclass
-class _IndexTemplateExtensions:
-
- analyzers: List[Dict[str, Any]] = dataclasses.field(default_factory=list)
- properties: List[Dict[str, Any]] = dataclasses.field(default_factory=list)
- dynamic_templates: List[Dict[str, Any]] = dataclasses.field(default_factory=list)
-
-
-def dataset_records_index(dataset_id: str) -> str:
- """
- Returns dataset records index for a given dataset id
-
- The dataset info is stored in two elasticsearch indices. The main
- index where all datasets definition are stored and
- an specific dataset index for data records.
-
- This function calculates the corresponding dataset records index
- for a given dataset id.
-
- Parameters
- ----------
- dataset_id
-
- Returns
- -------
- The dataset records index name
-
- """
- return DATASETS_RECORDS_INDEX_NAME.format(dataset_id)
class DatasetRecordsDAO:
@@ -80,24 +38,10 @@ class DatasetRecordsDAO:
_INSTANCE = None
- # Keep info about elasticsearch mappings per task
- # This info must be provided by each task using dao.register_task_mappings method
- _MAPPINGS_BY_TASKS = {}
-
- __HIGHLIGHT_PRE_TAG__ = "<@@-rb-key>"
- __HIGHLIGHT_POST_TAG__ = "@@-rb-key>"
- __HIGHLIGHT_VALUES_REGEX__ = re.compile(
- rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}"
- )
-
- __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile(
- rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}"
- )
-
@classmethod
def get_instance(
cls,
- es: ElasticsearchWrapper = Depends(ElasticsearchWrapper.get_instance),
+ es: ElasticsearchBackend = Depends(ElasticsearchBackend.get_instance),
) -> "DatasetRecordsDAO":
"""
Creates a dataset records dao instance
@@ -112,9 +56,8 @@ def get_instance(
cls._INSTANCE = cls(es)
return cls._INSTANCE
- def __init__(self, es: ElasticsearchWrapper):
+ def __init__(self, es: ElasticsearchBackend):
self._es = es
- self.init()
def init(self):
"""Initializes dataset records dao. Used on app startup"""
@@ -122,10 +65,9 @@ def init(self):
def add_records(
self,
- dataset: BaseDatasetDB,
- mappings: Dict[str, Any],
- records: List[DBRecord],
- record_class: Type[DBRecord],
+ dataset: DatasetDB,
+ records: List[RecordDB],
+ record_class: Type[RecordDB],
) -> int:
"""
Add records to dataset
@@ -160,71 +102,60 @@ def add_records(
db_record.dict(exclude_none=False, exclude={"search_keywords"})
)
- index_name = self.create_dataset_index(dataset, mappings=mappings)
- self._configure_metadata_fields(index_name, metadata_values)
+ self._es.create_dataset_index(
+ dataset.id,
+ task=dataset.task,
+ metadata_values=metadata_values,
+ )
+
return self._es.add_documents(
- index=index_name,
+ id=dataset.id,
documents=documents,
- doc_id=lambda _record: _record.get("id"),
)
- def get_metadata_schema(self, dataset: BaseDatasetDB) -> Dict[str, str]:
+ def get_metadata_schema(self, dataset: DatasetDB) -> Dict[str, str]:
"""Get metadata fields schema for provided dataset"""
- records_index = dataset_records_index(dataset.id)
- return self._es.get_field_mapping(index=records_index, field_name="metadata.*")
+
+ return self._es.get_metadata_mappings(id=dataset.id)
+
+ def compute_metric(
+ self,
+ dataset: DatasetDB,
+ metric_id: str,
+ metric_params: Dict[str, Any] = None,
+ query: Optional[BackendRecordsQuery] = None,
+ ):
+
+ return self._es.compute_metric(
+ id=dataset.id,
+ metric_id=metric_id,
+ query=query,
+ params=metric_params,
+ )
def search_records(
self,
- dataset: BaseDatasetDB,
- search: Optional[RecordSearch] = None,
+ dataset: DatasetDB,
+ search: Optional[DaoRecordsSearch] = None,
size: int = 100,
record_from: int = 0,
exclude_fields: List[str] = None,
highligth_results: bool = True,
- ) -> RecordSearchResults:
- """
- SearchRequest records under a dataset given a search parameters.
-
- Parameters
- ----------
- dataset:
- The dataset
- search:
- The search params
- size:
- Number of records to retrieve (for pagination)
- record_from:
- Record from which to retrieve the records (for pagination)
- exclude_fields:
- a list of fields to exclude from the result source. Wildcards are accepted
- Returns
- -------
- The search result
-
- """
- search = search or RecordSearch()
- records_index = dataset_records_index(dataset.id)
- compute_aggregations = record_from == 0
- aggregation_requests = (
- {**(search.aggregations or {})} if compute_aggregations else {}
- )
-
- sort_config = self.__normalize_sort_config__(records_index, sort=search.sort)
-
- es_query = {
- "_source": {"excludes": exclude_fields or []},
- "from": record_from,
- "query": search.query or {"match_all": {}},
- "sort": sort_config,
- "aggs": aggregation_requests,
- }
- if highligth_results:
- es_query["highlight"] = self.__configure_query_highlight__(
- task=dataset.task
- )
+ ) -> DaoRecordsSearchResults:
try:
- results = self._es.search(index=records_index, query=es_query, size=size)
+ search = search or DaoRecordsSearch()
+
+ total, records = self._es.search_records(
+ id=dataset.id,
+ query=search.query,
+ sort=search.sort,
+ record_from=record_from,
+ size=size,
+ exclude_fields=exclude_fields,
+ enable_highlight=highligth_results,
+ )
+ return DaoRecordsSearchResults(total=total, records=records)
except ClosedIndexError:
raise ClosedDatasetError(dataset.name)
except IndexNotFoundError:
@@ -232,43 +163,11 @@ def search_records(
f"No records index found for dataset {dataset.name}"
)
- hits = results["hits"]
- total = hits["total"]
- docs = hits["hits"]
- search_aggregations = results.get("aggregations", {})
-
- result = RecordSearchResults(
- total=total,
- records=list(map(self.__esdoc2record__, docs)),
- )
- if search_aggregations:
- parsed_aggregations = parse_aggregations(search_aggregations)
- result.aggregations = parsed_aggregations
-
- return result
-
- def __normalize_sort_config__(
- self, index: str, sort: Optional[List[Dict[str, Any]]] = None
- ) -> List[Dict[str, Any]]:
- id_field = "id"
- id_keyword_field = "id.keyword"
- sort_config = []
-
- for sort_field in sort or [{id_field: {"order": "asc"}}]:
- for field in sort_field:
- if field == id_field and self._es.get_field_mapping(
- index=index, field_name=id_keyword_field
- ):
- sort_config.append({id_keyword_field: sort_field[field]})
- else:
- sort_config.append(sort_field)
- return sort_config
-
def scan_dataset(
self,
- dataset: BaseDatasetDB,
+ dataset: DatasetDB,
+ search: Optional[DaoRecordsSearch] = None,
limit: int = 1000,
- search: Optional[RecordSearch] = None,
id_from: Optional[str] = None,
) -> Iterable[Dict[str, Any]]:
"""
@@ -289,178 +188,12 @@ def scan_dataset(
-------
An iterable over found dataset records
"""
- index = dataset_records_index(dataset.id)
- search = search or RecordSearch()
-
- sort_cfg = self.__normalize_sort_config__(
- index=index, sort=[{"id": {"order": "asc"}}]
- )
- es_query = {
- "query": search.query or {"match_all": {}},
- "highlight": self.__configure_query_highlight__(task=dataset.task),
- "sort": sort_cfg, # Sort the search so the consistency is maintained in every search
- }
-
- if id_from:
- # Scroll method does not accept read_after, thus, this case is handled as a search
- es_query["search_after"] = [id_from]
- results = self._es.search(index=index, query=es_query, size=limit)
- hits = results["hits"]
- docs = hits["hits"]
-
- else:
- docs = self._es.list_documents(
- index,
- query=es_query,
- sort_cfg=sort_cfg,
- )
- for doc in docs:
- yield self.__esdoc2record__(doc)
-
- def __esdoc2record__(
- self,
- doc: Dict[str, Any],
- query: Optional[str] = None,
- is_phrase_query: bool = True,
- ):
- return {
- **doc["_source"],
- "id": doc["_id"],
- "search_keywords": self.__parse_highlight_results__(
- doc, query=query, is_phrase_query=is_phrase_query
- ),
- }
-
- @classmethod
- def __parse_highlight_results__(
- cls,
- doc: Dict[str, Any],
- query: Optional[str] = None,
- is_phrase_query: bool = False,
- ) -> Optional[List[str]]:
- highlight_info = doc.get("highlight")
- if not highlight_info:
- return None
-
- search_keywords = []
- for content in highlight_info.values():
- if not isinstance(content, list):
- content = [content]
- text = " ".join(content)
-
- if is_phrase_query:
- text = re.sub(cls.__HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__, " ", text)
- search_keywords.extend(re.findall(cls.__HIGHLIGHT_VALUES_REGEX__, text))
- return list(set(search_keywords))
-
- def _configure_metadata_fields(self, index: str, metadata_values: Dict[str, Any]):
- def check_metadata_length(metadata_length: int = 0):
- if metadata_length > settings.metadata_fields_limit:
- raise MetadataLimitExceededError(
- length=metadata_length, limit=settings.metadata_fields_limit
- )
-
- def detect_nested_type(v: Any) -> bool:
- """Returns True if value match as nested value"""
- return isinstance(v, list) and isinstance(v[0], dict)
-
- check_metadata_length(len(metadata_values))
- check_metadata_length(
- len(
- {
- *self._es.get_field_mapping(
- index, "metadata.*", exclude_subfields=True
- ),
- *[f"metadata.{k}" for k in metadata_values.keys()],
- }
- )
- )
- for field, value in metadata_values.items():
- if detect_nested_type(value):
- self._es.create_field_mapping(
- index,
- field_name=f"metadata.{field}",
- mapping=mappings.nested_field(),
- )
-
- def create_dataset_index(
- self,
- dataset: BaseDatasetDB,
- mappings: Dict[str, Any],
- force_recreate: bool = False,
- ) -> str:
- """
- Creates a dataset records elasticsearch index based on dataset task type
-
- Args:
- dataset:
- The dataset
- force_recreate:
- If True, the index will be deleted and recreated
-
- Returns:
- The generated index name.
- """
- _mappings = tasks_common_mappings()
- task_mappings = mappings.copy()
- for k in task_mappings:
- if isinstance(task_mappings[k], list):
- _mappings[k] = [*_mappings.get(k, []), *task_mappings[k]]
- else:
- _mappings[k] = {**_mappings.get(k, {}), **task_mappings[k]}
-
- index_name = dataset_records_index(dataset.id)
- self._es.create_index(
- index=index_name,
- settings=tasks_common_settings(),
- mappings={**tasks_common_mappings(), **_mappings},
- force_recreate=force_recreate,
+ search = search or DaoRecordsSearch()
+ return self._es.scan_records(
+ id=dataset.id, query=search.query, limit=limit, id_from=id_from
)
- return index_name
- def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]:
+ def get_dataset_schema(self, dataset: DatasetDB) -> Dict[str, Any]:
"""Return inner elasticsearch index configuration"""
- index_name = dataset_records_index(dataset.id)
- response = self._es.__client__.indices.get_mapping(index=index_name)
-
- if index_name in response:
- response = response.get(index_name)
-
- return response
-
- @classmethod
- def __configure_query_highlight__(cls, task: TaskType):
-
- return {
- "pre_tags": [cls.__HIGHLIGHT_PRE_TAG__],
- "post_tags": [cls.__HIGHLIGHT_POST_TAG__],
- "require_field_match": True,
- "fields": {
- "text": {},
- "text.*": {},
- # TODO(@frascuchon): `words` will be removed in version 0.16.0
- **({"inputs.*": {}} if task == TaskType.text_classification else {}),
- },
- }
-
-
-_instance: Optional[DatasetRecordsDAO] = None
-
-
-@deprecated.deprecated(reason="Use `DatasetRecordsDAO.get_instance` instead")
-def dataset_records_dao(
- es: ElasticsearchWrapper = Depends(create_es_wrapper),
-) -> DatasetRecordsDAO:
- """
- Creates a dataset records dao instance
-
- Parameters
- ----------
- es:
- The elasticserach wrapper dependency
-
- """
- global _instance
- if not _instance:
- _instance = DatasetRecordsDAO(es)
- return _instance
+ schema = self._es.get_mappings(id=dataset.id)
+ return schema
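# --- Editor sketch (illustration only, not part of this diff) ---------------
# The slimmed-down records DAO now expresses searches with DAO-level models and
# delegates query building, highlighting and pagination to ElasticsearchBackend.
# `some_dataset` below is a placeholder for any DatasetDB instance.
from rubrix.server.daos.backend.search.model import SortConfig
from rubrix.server.daos.models.records import DaoRecordsSearch

sketch_search = DaoRecordsSearch(
    query=None,         # interpreted as match_all by the backend
    sort=SortConfig(),  # defaults to sorting by record id
)
# results = dao.search_records(some_dataset, search=sketch_search, size=50)
# print(results.total, len(results.records))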
diff --git a/src/rubrix/server/elasticseach/client_wrapper.py b/src/rubrix/server/elasticseach/client_wrapper.py
deleted file mode 100644
index fcad32ab29..0000000000
--- a/src/rubrix/server/elasticseach/client_wrapper.py
+++ /dev/null
@@ -1,650 +0,0 @@
-# coding=utf-8
-# Copyright 2021-present, the Recognai S.L. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Callable, Dict, Iterable, List, Optional
-
-import deprecated
-from opensearchpy import NotFoundError, OpenSearch, OpenSearchException, RequestError
-from opensearchpy.helpers import bulk as es_bulk
-from opensearchpy.helpers import scan as es_scan
-
-from rubrix.logging import LoggingMixin
-from rubrix.server.elasticseach import query_helpers
-from rubrix.server.errors import InvalidTextSearchError
-
-try:
- import ujson as json
-except ModuleNotFoundError:
- import json
-
-from rubrix.server.settings import settings
-
-
-class ClosedIndexError(Exception):
- pass
-
-
-class IndexNotFoundError(Exception):
- pass
-
-
-class GenericSearchError(Exception):
- def __init__(self, origin_error: Exception):
- self.origin_error = origin_error
-
-
-class ElasticsearchWrapper(LoggingMixin):
- """A simple elasticsearch client wrapper for atomize some repetitive operations"""
-
- _INSTANCE = None
-
- @classmethod
- def get_instance(cls) -> "ElasticsearchWrapper":
- """
- Creates an instance of ElasticsearchWrapper.
-
- This function is used in fastapi for resolve component dependencies.
-
- See
-
- Returns
- -------
-
- """
-
- if cls._INSTANCE is None:
- es_client = OpenSearch(
- hosts=settings.elasticsearch,
- verify_certs=settings.elasticsearch_ssl_verify,
- ca_certs=settings.elasticsearch_ca_path,
- # Extra args to es configuration -> TODO: extensible by settings
- retry_on_timeout=True,
- max_retries=5,
- )
- cls._INSTANCE = cls(es_client)
-
- return cls._INSTANCE
-
- def __init__(self, es_client: OpenSearch):
- self.__client__ = es_client
-
- @property
- def client(self):
- """The elasticsearch client"""
- return self.__client__
-
- def list_documents(
- self,
- index: str,
- query: Dict[str, Any] = None,
- sort_cfg: Optional[List[Dict[str, Any]]] = None,
- size: Optional[int] = None,
- fetch_once: bool = False,
- ) -> Iterable[Dict[str, Any]]:
- """
- List ALL documents of an elasticsearch index
- Parameters
- ----------
- index:
- The index name
- sor_id:
- The sort id configuration
- query:
- The es query for filter results. Default: None
- sort_cfg:
- Customized configuration for sort-by id
- size:
- Amount of samples to retrieve per iteration, 1000 by default
- fetch_once:
- If enabled, will return only the `size` first records found. Default to: ``False``
-
- Returns
- -------
- A sequence of documents resulting from applying the query on the index
-
- """
- size = size or 1000
- query = query.copy() or {}
- if sort_cfg:
- query["sort"] = sort_cfg
- query["track_total_hits"] = False # Speedup pagination
- response = self.__client__.search(index=index, body=query, size=size)
- while response["hits"]["hits"]:
- for hit in response["hits"]["hits"]:
- yield hit
- if fetch_once:
- break
-
- last_id = hit["_id"]
- query["search_after"] = [last_id]
- response = self.__client__.search(index=index, body=query, size=size)
-
- def index_exists(self, index: str) -> bool:
- """
- Checks if provided index exists
-
- Parameters
- ----------
- index:
- The index name
-
- Returns
- -------
- True if index exists. False otherwise
- """
- return self.__client__.indices.exists(index)
-
- def search(
- self,
- index: str,
- routing: str = None,
- size: int = 100,
- query: Dict[str, Any] = None,
- ) -> Dict[str, Any]:
- """
- Apply a search over an index.
- See
-
- Parameters
- ----------
- index:
- The index name
- routing:
- The routing key. Optional
- size:
- Number of results to return. Default=100
- query:
- The elasticsearch query. Optional
-
- Returns
- -------
-
- """
- try:
- return self.__client__.search(
- index=index,
- body=query or {},
- routing=routing,
- track_total_hits=True,
- rest_total_hits_as_int=True,
- size=size,
- )
- except RequestError as rex:
-
- if rex.error == "search_phase_execution_exception":
- detail = rex.info["error"]
- detail = detail.get("root_cause")
- detail = detail[0].get("reason") if detail else rex.info["error"]
-
- raise InvalidTextSearchError(detail)
-
- if rex.error == "index_closed_exception":
- raise ClosedIndexError(index)
- raise GenericSearchError(rex)
- except NotFoundError as nex:
- raise IndexNotFoundError(nex)
- except OpenSearchException as ex:
- raise GenericSearchError(ex)
-
- def create_index(
- self,
- index: str,
- force_recreate: bool = False,
- settings: Dict[str, Any] = None,
- mappings: Dict[str, Any] = None,
- ):
- """
- Applies a index creation with provided mapping configuration.
-
- See
-
- Parameters
- ----------
- index:
- The index name
- force_recreate:
- If True, the index will be recreated (if exists). Default=False
- settings:
- The index settings configuration
- mappings:
- The mapping configuration. Optional.
-
- """
- if force_recreate:
- self.delete_index(index)
- if not self.index_exists(index):
- self.__client__.indices.create(
- index=index,
- body={"settings": settings or {}, "mappings": mappings or {}},
- ignore=400,
- )
-
- def create_index_template(
- self, name: str, template: Dict[str, Any], force_recreate: bool = False
- ):
- """
- Applies a index template creation with provided template definition.
-
- Parameters
- ----------
- name:
- The template index name
- template:
- The template definition
- force_recreate:
- If True, the template will be recreated (if exists). Default=False
-
- """
- if force_recreate or not self.__client__.indices.exists_template(name):
- self.__client__.indices.put_template(name=name, body=template)
-
- def delete_index_template(self, index_template: str):
- """Deletes an index template"""
- if self.__client__.indices.exists_index_template(index_template):
- self.__client__.indices.delete_template(
- name=index_template, ignore=[400, 404]
- )
-
- def delete_index(self, index: str):
- """Deletes an elasticsearch index"""
- if self.index_exists(index):
- self.__client__.indices.delete(index, ignore=[400, 404])
-
- def add_document(self, index: str, doc_id: str, document: Dict[str, Any]):
- """
- Creates/updates a document in an index
-
- See
-
- Parameters
- ----------
- index:
- The index name
- doc_id:
- The document id
- document:
- The document source
-
- """
- self.__client__.index(index=index, body=document, id=doc_id, refresh="wait_for")
-
- def get_document_by_id(self, index: str, doc_id: str) -> Optional[Dict[str, Any]]:
- """
- Get a document by its id
-
- See
-
- Parameters
- ----------
- index:
- The index name
- doc_id:
- The document id
-
- Returns
- -------
- The elasticsearch document if found, None otherwise
- """
-
- try:
- if self.__client__.exists(index=index, id=doc_id):
- return self.__client__.get(index=index, id=doc_id)
- except NotFoundError:
- return None
-
- def delete_document(self, index: str, doc_id: str):
- """
- Deletes a document from an index.
-
- See
-
- Parameters
- ----------
- index:
- The index name
- doc_id:
- The document id
-
- Returns
- -------
-
- """
- if self.__client__.exists(index=index, id=doc_id):
- self.__client__.delete(index=index, id=doc_id, refresh=True)
-
- def add_documents(
- self,
- index: str,
- documents: List[Dict[str, Any]],
- routing: Callable[[Dict[str, Any]], str] = None,
- doc_id: Callable[[Dict[str, Any]], str] = None,
- ) -> int:
- """
- Adds or updated a set of documents to an index. Documents can contains
- partial information of document.
-
- See
-
- Parameters
- ----------
- index:
- The index name
- documents:
- The set of documents
- routing:
- The routing key
- doc_id
-
- Returns
- -------
- The number of failed documents
- """
-
- def map_doc_2_action(doc: Dict[str, Any]) -> Dict[str, Any]:
- """Configures bulk action"""
- data = {
- "_op_type": "index",
- "_index": index,
- "_routing": routing(doc) if routing else None,
- **doc,
- }
-
- _id = doc_id(doc) if doc_id else None
- if _id is not None:
- data["_id"] = _id
-
- return data
-
- success, failed = es_bulk(
- self.__client__,
- index=index,
- actions=map(map_doc_2_action, documents),
- raise_on_error=True,
- refresh="wait_for",
- )
- return len(failed)
-
- def get_mapping(self, index: str) -> Dict[str, Any]:
- """
- Return the configured index mapping
-
- See ``
-
- """
- try:
- response = self.__client__.indices.get_mapping(
- index=index,
- ignore_unavailable=False,
- include_type_name=True,
- )
- return list(response[index]["mappings"].values())[0]["properties"]
- except NotFoundError:
- return {}
-
- def get_field_mapping(
- self,
- index: str,
- field_name: Optional[str] = None,
- exclude_subfields: bool = False,
- ) -> Dict[str, str]:
- """
- Returns the mapping for a given field name (can be as wildcard notation). The result
- consist on a dictionary with full field name as key and its type as value
-
- See
-
- Parameters
- ----------
- index:
- The index name
- field_name:
- The field name pattern
- exclude_subfields:
- If True, exclude extra subfields from mappings definition
-
- Returns
- -------
- A dictionary with full field name as key and its type as value
- """
- try:
- response = self.__client__.indices.get_field_mapping(
- fields=field_name or "*",
- index=index,
- ignore_unavailable=False,
- )
- data = {
- key: list(definition["mapping"].values())[0]["type"]
- for key, definition in response[index]["mappings"].items()
- }
-
- if exclude_subfields:
- # Remove `text`, `exact` and `wordcloud` fields
- def is_subfield(key: str):
- for suffix in ["exact", "text", "wordcloud"]:
- if suffix in key:
- return True
- return False
-
- data = {
- key: value for key, value in data.items() if not is_subfield(key)
- }
-
- return data
- except NotFoundError:
- # No mapping data
- return {}
-
- def update_document(
- self,
- index: str,
- doc_id: str,
- document: Optional[Dict[str, Any]] = None,
- script: Optional[str] = None,
- partial_update: bool = False,
- ):
- """
- Updates a document in a given index
-
- Parameters
- ----------
- index:
- The index name
- doc_id:
- The document id
- document:
- The document data. Could be partial document info
- partial_update:
- If True, document contains partial info, and will be
- merged with stored document. If false, the stored document
- will be overwritten. Default=False
-
- Returns
- -------
-
- """
- # TODO: validate either doc or script are provided
- if partial_update:
- body = {"script": script} if script else {"doc": document}
-
- self.__client__.update(
- index=index,
- id=doc_id,
- body=body,
- refresh=True,
- retry_on_conflict=500, # TODO: configurable
- )
- else:
- self.__client__.index(index=index, id=doc_id, body=document, refresh=True)
-
- def open_index(self, index: str):
- """
- Open an elasticsearch index. If index is already open, this operation will do nothing
-
- See ``_
-
- Parameters
- ----------
- index:
- The index name
- """
- self.__client__.indices.open(
- index=index, wait_for_active_shards=settings.es_records_index_shards
- )
-
- def close_index(self, index: str):
- """
- Closes an elasticsearch index. If index is already closed, this operation will do nothing.
-
- See ``_
-
- Parameters
- ----------
- index:
- The index name
- """
- self.__client__.indices.close(
- index=index,
- ignore_unavailable=True,
- wait_for_active_shards=settings.es_records_index_shards,
- )
-
- def clone_index(self, index: str, clone_to: str, override: bool = True):
- """
- Clone an existing index. During index clone, source must be setup as read-only index. Then, changes can be
- applied
-
- See ``_
-
- Parameters
- ----------
- index:
- The source index name
- clone_to:
- The destination index name
- override:
- If True, destination index will be removed if exists
- """
- index_read_only = self.is_index_read_only(index)
- try:
- if not index_read_only:
- self.index_read_only(index, read_only=True)
- if override:
- self.delete_index(clone_to)
- self.__client__.indices.clone(
- index=index,
- target=clone_to,
- wait_for_active_shards=settings.es_records_index_shards,
- )
- finally:
- self.index_read_only(index, read_only=index_read_only)
- self.index_read_only(clone_to, read_only=index_read_only)
-
- def is_index_read_only(self, index: str) -> bool:
- """
- Fetch info about read-only configuration index
-
- Parameters
- ----------
- index:
- The index name
-
- Returns
- -------
- True if queried index is read-only, False otherwise
-
- """
- response = self.__client__.indices.get_settings(
- index=index,
- name="index.blocks.write",
- allow_no_indices=True,
- flat_settings=True,
- )
- return (
- response[index]["settings"]["index.blocks.write"] == "true"
- if response
- else False
- )
-
- def index_read_only(self, index: str, read_only: bool):
- """
- Enable/disable index read only
-
- Parameters
- ----------
- index:
- The index name
- read_only:
- True for enable read-only, False otherwise
-
- """
- self.__client__.indices.put_settings(
- index=index,
- body={"settings": {"index.blocks.write": read_only}},
- ignore=404,
- )
-
- def create_field_mapping(
- self,
- index: str,
- field_name: str,
- mapping: Dict[str, Any],
- ):
- """Creates or updates an index field mapping configuration"""
- self.__client__.indices.put_mapping(
- index=index,
- body={"properties": {field_name: mapping}},
- )
-
- def get_cluster_info(self) -> Dict[str, Any]:
- """Returns basic about es cluster"""
- try:
- return self.__client__.info()
- except OpenSearchException as ex:
- return {"error": ex}
-
- def aggregate(self, index: str, aggregation: Dict[str, Any]) -> Dict[str, Any]:
- """Apply an aggregation over the index returning ONLY the agg results"""
- aggregation_name = "aggregation"
- results = self.search(
- index=index, size=0, query={"aggs": {aggregation_name: aggregation}}
- )
-
- return query_helpers.parse_aggregations(results["aggregations"]).get(
- aggregation_name
- )
-
-
-_instance = None # The singleton instance
-
-
-@deprecated.deprecated(reason="Use `ElasticsearchWrapper.get_instance` instead")
-def create_es_wrapper() -> ElasticsearchWrapper:
- """
- Creates an instance of ElasticsearchWrapper.
-
- This function is used in fastapi for resolve component dependencies.
-
- See
-
- Returns
- -------
-
- """
-
- global _instance
- if _instance is None:
- _instance = ElasticsearchWrapper.get_instance()
-
- return _instance
diff --git a/src/rubrix/server/apis/v0/helpers.py b/src/rubrix/server/helpers.py
similarity index 68%
rename from src/rubrix/server/apis/v0/helpers.py
rename to src/rubrix/server/helpers.py
index 2a17f0bef0..959b4ab103 100644
--- a/src/rubrix/server/apis/v0/helpers.py
+++ b/src/rubrix/server/helpers.py
@@ -12,31 +12,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+"""
+Common helper functions
+"""
+from typing import Any, Dict, List, Optional
-def takeuntil(iterable, limit: int):
+
+def unflatten_dict(
+ data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None
+) -> Dict[str, Any]:
"""
- Iterate over inner iterable until a count limit
+    Given a flat dictionary, build a hierarchical version by grouping its keys
Parameters
----------
- iterable:
- The inner iterable
- limit:
- The limit
+ data:
+ The data dictionary
+ sep:
+ The key separator. Default "."
+    stop_keys:
+ List of dictionary first level keys where hierarchy will stop
Returns
-------
"""
- count = 0
- for e in iterable:
- if count < limit:
- yield e
- count += 1
- else:
- break
+ resultDict = {}
+ stop_keys = stop_keys or []
+ for key, value in data.items():
+ if key is not None:
+ parts = key.split(sep)
+ if parts[0] in stop_keys:
+ parts = [parts[0], sep.join(parts[1:])]
+ d = resultDict
+ for part in parts[:-1]:
+ if part not in d:
+ d[part] = {}
+ d = d[part]
+ d[parts[-1]] = value
+ return resultDict
def flatten_dict(
@@ -81,3 +96,27 @@ def _flatten_internal_(_data: Dict[str, Any], _parent_key="", _sep="."):
return dict(items)
return _flatten_internal_(data, _sep=sep)
+
+
+def takeuntil(iterable, limit: int):
+ """
+ Iterate over inner iterable until a count limit
+
+ Parameters
+ ----------
+ iterable:
+ The inner iterable
+ limit:
+ The limit
+
+ Returns
+ -------
+
+ """
+ count = 0
+ for e in iterable:
+ if count < limit:
+ yield e
+ count += 1
+ else:
+ break
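# --- Editor sketch (illustration only, not part of this diff) ---------------
# Round-tripping a nested dict through the relocated helpers. The outputs shown
# in the comments assume the default "." separator.
from rubrix.server.helpers import flatten_dict, takeuntil, unflatten_dict

flat = flatten_dict({"metrics": {"tokens": 120, "labels": {"ok": 1}}})
# -> {"metrics.tokens": 120, "metrics.labels.ok": 1}
nested = unflatten_dict(flat)
# -> {"metrics": {"tokens": 120, "labels": {"ok": 1}}}
first_two = list(takeuntil(range(10), limit=2))
# -> [0, 1]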
diff --git a/src/rubrix/server/server.py b/src/rubrix/server/server.py
index c2cfcf3f26..d5f316cfc0 100644
--- a/src/rubrix/server/server.py
+++ b/src/rubrix/server/server.py
@@ -27,9 +27,9 @@
from rubrix import __version__ as rubrix_version
from rubrix.logging import configure_logging
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
from rubrix.server.daos.datasets import DatasetsDAO
from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.client_wrapper import create_es_wrapper
from rubrix.server.errors import APIErrorHandler, EntityNotFoundError
from rubrix.server.routes import api_router
from rubrix.server.security import auth
@@ -89,7 +89,7 @@ async def configure_elasticsearch():
import opensearchpy
try:
- es_wrapper = create_es_wrapper()
+ es_wrapper = ElasticsearchBackend.get_instance()
dataset_records: DatasetRecordsDAO = DatasetRecordsDAO(es_wrapper)
datasets: DatasetsDAO = DatasetsDAO.get_instance(
es_wrapper, records_dao=dataset_records
diff --git a/src/rubrix/server/services/datasets.py b/src/rubrix/server/services/datasets.py
index 51b09fd371..5d2d67264e 100644
--- a/src/rubrix/server/services/datasets.py
+++ b/src/rubrix/server/services/datasets.py
@@ -18,8 +18,8 @@
from fastapi import Depends
-from rubrix.server.daos.datasets import BaseDatasetDB, DatasetsDAO, SettingsDB
-from rubrix.server.daos.models.datasets import DatasetDB
+from rubrix.server.daos.datasets import BaseDatasetSettingsDB, DatasetsDAO
+from rubrix.server.daos.models.datasets import BaseDatasetDB
from rubrix.server.errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
@@ -28,13 +28,21 @@
)
from rubrix.server.security.model import User
-Dataset = TypeVar("Dataset", bound=BaseDatasetDB)
+
+class ServiceBaseDataset(BaseDatasetDB):
+ pass
-class SVCDatasetSettings(SettingsDB):
+class ServiceBaseDatasetSettings(BaseDatasetSettingsDB):
pass
+ServiceDataset = TypeVar("ServiceDataset", bound=ServiceBaseDataset)
+ServiceDatasetSettings = TypeVar(
+ "ServiceDatasetSettings", bound=ServiceBaseDatasetSettings
+)
+
+
class DatasetsService:
_INSTANCE: "DatasetsService" = None
@@ -50,9 +58,7 @@ def get_instance(
def __init__(self, dao: DatasetsDAO):
self.__dao__ = dao
- def create_dataset(
- self, user: User, dataset: Dataset, mappings: Dict[str, Any]
- ) -> Dataset:
+ def create_dataset(self, user: User, dataset: ServiceDataset) -> ServiceDataset:
user.check_workspace(dataset.owner)
try:
@@ -60,11 +66,11 @@ def create_dataset(
user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner
)
raise EntityAlreadyExistsError(
- name=dataset.name, type=Dataset, workspace=dataset.owner
+ name=dataset.name, type=ServiceDataset, workspace=dataset.owner
)
except WrongTaskError: # Found a dataset with same name but different task
raise EntityAlreadyExistsError(
- name=dataset.name, type=Dataset, workspace=dataset.owner
+ name=dataset.name, type=ServiceDataset, workspace=dataset.owner
)
except EntityNotFoundError:
# The dataset does not exist -> create it !
@@ -72,16 +78,16 @@ def create_dataset(
dataset.created_by = user.username
dataset.created_at = date_now
dataset.last_updated = date_now
- return self.__dao__.create_dataset(dataset, mappings=mappings)
+ return self.__dao__.create_dataset(dataset)
def find_by_name(
self,
user: User,
name: str,
- as_dataset_class: Type[BaseDatasetDB] = DatasetDB,
+ as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset,
task: Optional[str] = None,
workspace: Optional[str] = None,
- ) -> Dataset:
+ ) -> ServiceDataset:
owner = user.check_workspace(workspace)
if task is None:
@@ -96,20 +102,20 @@ def find_by_name(
)
if found_ds is None:
- raise EntityNotFoundError(name=name, type=Dataset)
+ raise EntityNotFoundError(name=name, type=ServiceDataset)
if found_ds.owner and owner and found_ds.owner != owner:
raise EntityNotFoundError(
- name=name, type=Dataset
+ name=name, type=ServiceDataset
) if user.is_superuser() else ForbiddenOperationError()
- return cast(Dataset, found_ds)
+ return cast(ServiceDataset, found_ds)
def __find_by_name_with_superuser_fallback__(
self,
user: User,
name: str,
owner: Optional[str],
- as_dataset_class: Optional[Type[DatasetDB]],
+ as_dataset_class: Optional[Type[ServiceDataset]],
task: Optional[str] = None,
):
found_ds = self.__dao__.find_by_name(
@@ -125,7 +131,7 @@ def __find_by_name_with_superuser_fallback__(
pass
return found_ds
- def delete(self, user: User, dataset: Dataset):
+ def delete(self, user: User, dataset: ServiceDataset):
user.check_workspace(dataset.owner)
found = self.__find_by_name_with_superuser_fallback__(
user=user,
@@ -140,10 +146,10 @@ def delete(self, user: User, dataset: Dataset):
def update(
self,
user: User,
- dataset: Dataset,
+ dataset: ServiceDataset,
tags: Dict[str, str],
metadata: Dict[str, Any],
- ) -> Dataset:
+ ) -> ServiceDataset:
found = self.find_by_name(
user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner
)
@@ -159,30 +165,30 @@ def list(
self,
user: User,
workspaces: Optional[List[str]],
- task2dataset_map: Dict[str, Type[BaseDatasetDB]] = None,
- ) -> List[Dataset]:
+ task2dataset_map: Dict[str, Type[ServiceDataset]] = None,
+ ) -> List[ServiceDataset]:
owners = user.check_workspaces(workspaces)
return self.__dao__.list_datasets(
owner_list=owners, task2dataset_map=task2dataset_map
)
- def close(self, user: User, dataset: Dataset):
+ def close(self, user: User, dataset: ServiceDataset):
found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner)
self.__dao__.close(found)
- def open(self, user: User, dataset: Dataset):
+ def open(self, user: User, dataset: ServiceDataset):
found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner)
self.__dao__.open(found)
def copy_dataset(
self,
user: User,
- dataset: Dataset,
+ dataset: ServiceDataset,
copy_name: str,
copy_workspace: Optional[str] = None,
copy_tags: Dict[str, Any] = None,
copy_metadata: Dict[str, Any] = None,
- ) -> Dataset:
+ ) -> ServiceDataset:
dataset_workspace = copy_workspace or dataset.owner
dataset_workspace = user.check_workspace(dataset_workspace)
@@ -221,58 +227,23 @@ def copy_dataset(
return copy_dataset
- def all_workspaces(self) -> List[str]:
- """Retrieve all dataset workspaces"""
-
- workspaces = self.__dao__.get_all_workspaces()
- # include the non-workspace workspace?
- return workspaces
-
async def get_settings(
- self, user: User, dataset: Dataset, class_type: Type[SVCDatasetSettings]
- ) -> SVCDatasetSettings:
- """
- Get the configured settings for dataset
-
- Args:
- user: the connected user
- dataset: the target dataset
- class_type: the settings class
-
- Returns:
- An instance of class_type settings configured for provided dataset
-
- """
+ self,
+ user: User,
+ dataset: ServiceDataset,
+ class_type: Type[ServiceDatasetSettings],
+ ) -> ServiceDatasetSettings:
settings = self.__dao__.load_settings(dataset=dataset, as_class=class_type)
if not settings:
raise EntityNotFoundError(name=dataset.name, type=class_type)
return class_type.parse_obj(settings.dict())
async def save_settings(
- self, user: User, dataset: Dataset, settings: SVCDatasetSettings
- ) -> SVCDatasetSettings:
- """
- Save a set of settings for a dataset
+ self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings
+ ) -> ServiceDatasetSettings:
- Args:
- user: The user executing the command
- dataset: The dataset
- settings: The dataset settings
-
- Returns:
- Stored dataset settings
-
- """
self.__dao__.save_settings(dataset=dataset, settings=settings)
return settings
- async def delete_settings(self, user: User, dataset: Dataset) -> None:
- """
- Deletes the dataset settings
-
- Args:
- user: The user executing the command
- dataset: The dataset
-
- """
+ async def delete_settings(self, user: User, dataset: ServiceDataset) -> None:
self.__dao__.delete_settings(dataset=dataset)
diff --git a/src/rubrix/server/services/info.py b/src/rubrix/server/services/info.py
index 2fe96d51ba..716b8d7a4e 100644
--- a/src/rubrix/server/services/info.py
+++ b/src/rubrix/server/services/info.py
@@ -14,15 +14,52 @@
# limitations under the License.
import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict
import psutil
from fastapi import Depends
-from hurry.filesize import size
+from pydantic import BaseModel
from rubrix import __version__ as rubrix_version
-from rubrix.server.apis.v0.models.info import ApiStatus
-from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
+
+
+def size(bytes):
+ """Return a human-readable size string; vendored replacement for ``hurry.filesize.size``."""
+ system = [
+ (1024 ** 5, "P"),
+ (1024 ** 4, "T"),
+ (1024 ** 3, "G"),
+ (1024 ** 2, "M"),
+ (1024 ** 1, "K"),
+ (1024 ** 0, "B"),
+ ]
+
+ factor, suffix = None, None
+ for factor, suffix in system:
+ if bytes >= factor:
+ break
+
+ amount = int(bytes / factor)
+ if isinstance(suffix, tuple):
+ singular, multiple = suffix
+ if amount == 1:
+ suffix = singular
+ else:
+ suffix = multiple
+ return str(amount) + suffix
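+
+# Illustrative note (editor's sketch, not part of this change), assuming the vendored
+# helper above mirrors hurry.filesize.size for its default unit system:
+#   size(512)         -> "512B"
+#   size(2048)        -> "2K"
+#   size(3 * 1024**3) -> "3G"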
+
+
+class ApiInfo(BaseModel):
+ """Basic api info"""
+
+ rubrix_version: str
+
+
+class ApiStatus(ApiInfo):
+ """The Rubrix api status model"""
+
+ elasticsearch: Dict[str, Any]
+ mem_info: Dict[str, Any]
class ApiInfoService:
@@ -30,7 +67,22 @@ class ApiInfoService:
The api info service
"""
- def __init__(self, es: ElasticsearchWrapper):
+ _INSTANCE = None
+
+ @classmethod
+ def get_instance(
+ cls,
+ backend: ElasticsearchBackend = Depends(ElasticsearchBackend.get_instance),
+ ) -> "ApiInfoService":
+ """
+ Creates an api info service
+ """
+
+ if not cls._INSTANCE:
+ cls._INSTANCE = ApiInfoService(backend)
+ return cls._INSTANCE
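+ # Usage sketch (editor's illustration, not part of this change): route handlers
+ # are expected to receive this singleton via FastAPI dependency injection, e.g.
+ #   def api_status(service: ApiInfoService = Depends(ApiInfoService.get_instance)): ...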
+
+ def __init__(self, es: ElasticsearchBackend):
self.__es__ = es
def api_status(self) -> ApiStatus:
@@ -50,19 +102,3 @@ def _api_memory_info() -> Dict[str, Any]:
"""Fetch the api process memory usage"""
process = psutil.Process(os.getpid())
return {k: size(v) for k, v in process.memory_info()._asdict().items()}
-
-
-_instance: Optional[ApiInfoService] = None
-
-
-def create_info_service(
- es_wrapper: ElasticsearchWrapper = Depends(ElasticsearchWrapper.get_instance),
-) -> ApiInfoService:
- """
- Creates an api info service
- """
-
- global _instance
- if not _instance:
- _instance = ApiInfoService(es_wrapper)
- return _instance
diff --git a/src/rubrix/server/services/metrics.py b/src/rubrix/server/services/metrics.py
deleted file mode 100644
index 4b69dd60ca..0000000000
--- a/src/rubrix/server/services/metrics.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from typing import Callable, Optional, Type, TypeVar, Union
-
-from fastapi import Depends
-
-from rubrix.server.apis.v0.models.metrics.base import (
- ElasticsearchMetric,
- NestedPathElasticsearchMetric,
- PythonMetric,
-)
-from rubrix.server.apis.v0.models.metrics.commons import *
-from rubrix.server.daos.models.records import RecordSearch
-from rubrix.server.daos.records import DatasetRecordsDAO, dataset_records_dao
-from rubrix.server.errors import WrongInputParamError
-from rubrix.server.services.datasets import Dataset
-from rubrix.server.services.search.query_builder import EsQueryBuilder
-from rubrix.server.services.tasks.commons.record import BaseRecordDB
-
-GenericQuery = TypeVar("GenericQuery")
-
-
-class MetricsService:
- """The dataset metrics service singleton"""
-
- _INSTANCE = None
-
- @classmethod
- def get_instance(
- cls,
- dao: DatasetRecordsDAO = Depends(dataset_records_dao),
- query_builder: EsQueryBuilder = Depends(EsQueryBuilder.get_instance),
- ) -> "MetricsService":
- """
- Creates the service instance.
-
- Parameters
- ----------
- dao:
- The dataset records dao
-
- Returns
- -------
- The metrics service instance
-
- """
- if not cls._INSTANCE:
- cls._INSTANCE = cls(dao, query_builder=query_builder)
- return cls._INSTANCE
-
- def __init__(self, dao: DatasetRecordsDAO, query_builder: EsQueryBuilder):
- """
- Creates a service instance
-
- Parameters
- ----------
- dao:
- The dataset records dao
- """
- self.__dao__ = dao
- self.__query_builder__ = query_builder
-
- def summarize_metric(
- self,
- dataset: Dataset,
- metric: BaseMetric,
- record_class: Optional[Type[BaseRecordDB]] = None,
- query: Optional[GenericQuery] = None,
- **metric_params,
- ) -> Dict[str, Any]:
- """
- Applies a metric summarization.
-
- Parameters
- ----------
- dataset:
- The records dataset
- metric:
- The selected metric
- query:
- An optional query passed for records filtering
- metric_params:
- Related metrics parameters
-
- Returns
- -------
- The metric summarization info
- """
-
- if isinstance(metric, ElasticsearchMetric):
- return self._handle_elasticsearch_metric(
- metric, metric_params, dataset=dataset, query=query
- )
- elif isinstance(metric, PythonMetric):
- records = self.__dao__.scan_dataset(
- dataset,
- search=RecordSearch(query=self.__query_builder__(dataset, query=query)),
- )
- return metric.apply(map(record_class.parse_obj, records))
-
- raise WrongInputParamError(f"Cannot process {metric} of type {type(metric)}")
-
- def _handle_elasticsearch_metric(
- self,
- metric: ElasticsearchMetric,
- metric_params: Dict[str, Any],
- dataset: Dataset,
- query: GenericQuery,
- ) -> Dict[str, Any]:
- """
- Parameters
- ----------
- metric:
- The elasticsearch metric summary
- metric_params:
- The summary params
- dataset:
- The records dataset
- query:
- The filter to apply to dataset
-
- Returns
- -------
- The metric summary result
-
- """
- params = self.__compute_metric_params__(
- dataset=dataset, metric=metric, query=query, provided_params=metric_params
- )
- results = self.__metric_results__(
- dataset=dataset,
- query=query,
- metric_aggregation=metric.aggregation_request(**params),
- )
- return metric.aggregation_result(
- aggregation_result=results.get(metric.id, results)
- )
-
- def __compute_metric_params__(
- self,
- dataset: Dataset,
- metric: ElasticsearchMetric,
- query: Optional[GenericQuery],
- provided_params: Dict[str, Any],
- ) -> Dict[str, Any]:
-
- return self._filter_metric_params(
- metric=metric,
- function=metric.aggregation_request,
- metric_params={
- **provided_params, # in case of param conflict, provided metric params will be preserved
- "dataset": dataset,
- "dao": self.__dao__,
- },
- )
-
- def __metric_results__(
- self,
- dataset: Dataset,
- query: Optional[GenericQuery],
- metric_aggregation: Union[Dict[str, Any], List[Dict[str, Any]]],
- ) -> Dict[str, Any]:
-
- if not metric_aggregation:
- return {}
-
- if not isinstance(metric_aggregation, list):
- metric_aggregation = [metric_aggregation]
-
- results = {}
- for agg in metric_aggregation:
- results_ = self.__dao__.search_records(
- dataset,
- size=0, # No records at all
- search=RecordSearch(
- query=self.__query_builder__(dataset, query=query),
- aggregations=agg,
- ),
- )
- results.update(results_.aggregations)
- return results
-
- @staticmethod
- def _filter_metric_params(
- metric: ElasticsearchMetric, function: Callable, metric_params: Dict[str, Any]
- ):
- """
- Select from provided metric parameter those who can be applied to given metric
-
- Parameters
- ----------
- metric:
- The target metric
- metric_params:
- A dict of metric parameters
-
- """
-
- if isinstance(metric, NestedPathElasticsearchMetric):
- function = metric.inner_aggregation
-
- return {
- argument: metric_params[argument]
- for argument in function.__code__.co_varnames
- if argument in metric_params
- }
diff --git a/src/rubrix/server/services/metrics/__init__.py b/src/rubrix/server/services/metrics/__init__.py
new file mode 100644
index 0000000000..341f097d55
--- /dev/null
+++ b/src/rubrix/server/services/metrics/__init__.py
@@ -0,0 +1,2 @@
+from .models import ServiceBaseMetric, ServicePythonMetric
+from .service import MetricsService
diff --git a/src/rubrix/server/services/metrics/models.py b/src/rubrix/server/services/metrics/models.py
new file mode 100644
index 0000000000..e282cb5919
--- /dev/null
+++ b/src/rubrix/server/services/metrics/models.py
@@ -0,0 +1,156 @@
+from typing import (
+ Any,
+ ClassVar,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+)
+
+from pydantic import BaseModel
+
+from rubrix.server.services.tasks.commons import ServiceRecord
+
+
+class ServiceBaseMetric(BaseModel):
+ """
+ Base model for rubrix dataset metrics summaries
+ """
+
+ id: str
+ name: str
+ description: Optional[str] = None
+
+
+class ServicePythonMetric(ServiceBaseMetric, Generic[ServiceRecord]):
+ """
+ A metric definition which will be calculated using raw queried data
+ """
+
+ def apply(self, records: Iterable[ServiceRecord]) -> Dict[str, Any]:
+ """
+ Metric calculation method.
+
+ Parameters
+ ----------
+ records:
+ The matched records
+
+ Returns
+ -------
+ The metric result
+ """
+ raise NotImplementedError()
+
+
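+# Illustrative sketch (editor's example, not part of this change): a concrete
+# ServicePythonMetric only needs an id/name and an `apply` implementation.
+# It assumes records expose `all_text()`, as used by CommonTasksMetrics below:
+#
+#   class AverageTextLength(ServicePythonMetric):
+#       id: str = "avg_text_length"
+#       name: str = "Average text length"
+#
+#       def apply(self, records: Iterable[ServiceRecord]) -> Dict[str, Any]:
+#           lengths = [len(r.all_text()) for r in records]
+#           return {"value": sum(lengths) / len(lengths) if lengths else 0.0}
+
+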
+ServiceMetric = TypeVar("ServiceMetric", bound=ServiceBaseMetric)
+
+
+class ServiceBaseTaskMetrics(BaseModel):
+ """
+ Base class encapsulating related task metrics
+
+ Attributes:
+ -----------
+
+ metrics:
+ A list of configured metrics for the task
+ """
+
+ metrics: ClassVar[List[Union[ServicePythonMetric, str]]]
+
+ @classmethod
+ def find_metric(cls, id: str) -> Optional[Union[ServicePythonMetric, str]]:
+ """
+ Finds a metric by id
+
+ Parameters
+ ----------
+ id:
+ The metric id
+
+ Returns
+ -------
+ Found metric if any, ``None`` otherwise
+
+ """
+ for metric in cls.metrics:
+ if isinstance(metric, str) and metric == id:
+ return metric
+ if not isinstance(metric, str) and metric.id == id:
+ return metric
+
+ @classmethod
+ def record_metrics(cls, record: ServiceRecord) -> Dict[str, Any]:
+ """
+ Use this method if some configured metric requires additional
+ record fields.
+
+ Generated fields will be persisted under the record's ``metrics`` path.
+ For example, if you define a field called ``sentence_length`` like
+
+ >>> def record_metrics(cls, record) -> Dict[str, Any]:
+ ... return { "sentence_length" : len(record.text) }
+
+ The new field will be stored in elasticsearch in ``metrics.sentence_length``
+
+ Parameters
+ ----------
+ record:
+ The record used to calculate the metrics fields
+
+ Returns
+ -------
+ A dict with calculated metrics fields
+ """
+ return {}
+
+
+class CommonTasksMetrics(ServiceBaseTaskMetrics, Generic[ServiceRecord]):
+ """Common task metrics"""
+
+ @classmethod
+ def record_metrics(cls, record: ServiceRecord) -> Dict[str, Any]:
+ """Record metrics will persist the text_length"""
+ return {"text_length": len(record.all_text())}
+
+ metrics: ClassVar[List[ServiceBaseMetric]] = [
+ ServiceBaseMetric(
+ id="text_length",
+ name="Text length distribution",
+ description="Computes the input text length distribution",
+ ),
+ ServiceBaseMetric(
+ id="error_distribution",
+ name="Error distribution",
+ description="Computes the dataset error distribution. It's mean, records "
+ "with correct predictions vs records with incorrect prediction "
+ "vs records with unknown prediction result",
+ ),
+ ServiceBaseMetric(
+ id="status_distribution",
+ name="Record status distribution",
+ description="The dataset record status distribution",
+ ),
+ ServiceBaseMetric(
+ id="words_cloud",
+ name="Inputs words cloud",
+ description="The words cloud for dataset inputs",
+ ),
+ ServiceBaseMetric(id="metadata", name="Metadata fields stats"),
+ ServiceBaseMetric(
+ id="predicted_by",
+ name="Predicted by distribution",
+ ),
+ ServiceBaseMetric(
+ id="annotated_by",
+ name="Annotated by distribution",
+ ),
+ ServiceBaseMetric(
+ id="score",
+ name="Score record distribution",
+ ),
+ ]
diff --git a/src/rubrix/server/services/metrics/service.py b/src/rubrix/server/services/metrics/service.py
new file mode 100644
index 0000000000..23b544dc4d
--- /dev/null
+++ b/src/rubrix/server/services/metrics/service.py
@@ -0,0 +1,91 @@
+from typing import Any, Dict, Optional, Type
+
+from fastapi import Depends
+
+from rubrix.server.daos.models.records import DaoRecordsSearch
+from rubrix.server.daos.records import DatasetRecordsDAO
+from rubrix.server.services.datasets import ServiceDataset
+from rubrix.server.services.metrics.models import ServiceMetric, ServicePythonMetric
+from rubrix.server.services.search.model import ServiceRecordsQuery
+from rubrix.server.services.tasks.commons import ServiceRecord
+
+
+class MetricsService:
+ """The dataset metrics service singleton"""
+
+ _INSTANCE = None
+
+ @classmethod
+ def get_instance(
+ cls,
+ dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
+ ) -> "MetricsService":
+ """
+ Creates the service instance.
+
+ Parameters
+ ----------
+ dao:
+ The dataset records dao
+
+ Returns
+ -------
+ The metrics service instance
+
+ """
+ if not cls._INSTANCE:
+ cls._INSTANCE = cls(dao)
+ return cls._INSTANCE
+
+ def __init__(self, dao: DatasetRecordsDAO):
+ """
+ Creates a service instance
+
+ Parameters
+ ----------
+ dao:
+ The dataset records dao
+ """
+ self.__dao__ = dao
+
+ def summarize_metric(
+ self,
+ dataset: ServiceDataset,
+ metric: ServiceMetric,
+ record_class: Optional[Type[ServiceRecord]] = None,
+ query: Optional[ServiceRecordsQuery] = None,
+ **metric_params,
+ ) -> Dict[str, Any]:
+ """
+ Applies a metric summarization.
+
+ Parameters
+ ----------
+ dataset:
+ The records dataset
+ metric:
+ The selected metric
+ record_class:
+ The record class type for python metrics computation
+ query:
+ An optional query passed for records filtering
+ metric_params:
+ Related metrics parameters
+
+ Returns
+ -------
+ The metric summarization info
+ """
+
+ if isinstance(metric, ServicePythonMetric):
+ records = self.__dao__.scan_dataset(
+ dataset, search=DaoRecordsSearch(query=query)
+ )
+ return metric.apply(map(record_class.parse_obj, records))
+
+ return self.__dao__.compute_metric(
+ metric_id=metric.id,
+ metric_params=metric_params,
+ dataset=dataset,
+ query=query,
+ )
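+
+# Usage sketch (editor's illustration, not part of this change):
+#   metrics_service.summarize_metric(
+#       dataset=dataset,
+#       metric=TextClassificationMetrics.find_metric("text_length"),
+#   )
+# ServicePythonMetric instances are computed in Python over scanned records
+# (record_class is required to parse them); any other ServiceBaseMetric id is
+# delegated to the records DAO via compute_metric.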
diff --git a/src/rubrix/server/services/search/model.py b/src/rubrix/server/services/search/model.py
index f206809849..3878d437d1 100644
--- a/src/rubrix/server/services/search/model.py
+++ b/src/rubrix/server/services/search/model.py
@@ -1,55 +1,39 @@
-from enum import Enum
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import Any, Dict, List, TypeVar
from pydantic import BaseModel, Field
-from pydantic.generics import GenericModel
-from rubrix.server.services.tasks.commons.record import Record, TaskStatus
+from rubrix.server.daos.backend.search.model import (
+ BaseRecordsQuery,
+ QueryRange,
+ SortableField,
+ SortConfig,
+)
+from rubrix.server.services.tasks.commons import ServiceRecord
-class SortOrder(str, Enum):
- asc = "asc"
- desc = "desc"
+class ServiceBaseRecordsQuery(BaseRecordsQuery):
+ pass
-class SortableField(BaseModel):
- """Sortable field structure"""
-
- id: str
- order: SortOrder = SortOrder.asc
-
-
-class BaseSearchQuery(BaseModel):
-
- query_text: Optional[str] = None
- advanced_query_dsl: bool = False
-
- ids: Optional[List[Union[str, int]]]
-
- annotated_by: List[str] = Field(default_factory=list)
- predicted_by: List[str] = Field(default_factory=list)
-
- status: List[TaskStatus] = Field(default_factory=list)
- metadata: Optional[Dict[str, Union[str, List[str]]]] = None
+class ServiceSortConfig(SortConfig):
+ pass
-class QueryRange(BaseModel):
+class ServiceSortableField(SortableField):
+ """Sortable field structure"""
- range_from: float = Field(default=0.0, alias="from")
- range_to: float = Field(default=None, alias="to")
+ pass
- class Config:
- allow_population_by_field_name = True
+class ServiceQueryRange(QueryRange):
+ pass
-class SortConfig(BaseModel):
- shuffle: bool = False
- sort_by: List[SortableField] = Field(default_factory=list)
- valid_fields: List[str] = Field(default_factory=list)
+class ServiceScoreRange(ServiceQueryRange):
+ pass
-class BaseSearchResultsAggregations(BaseModel):
+class ServiceBaseSearchResultsAggregations(BaseModel):
predicted_as: Dict[str, int] = Field(default_factory=dict)
annotated_as: Dict[str, int] = Field(default_factory=dict)
@@ -62,30 +46,15 @@ class BaseSearchResultsAggregations(BaseModel):
metadata: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
-Aggregations = TypeVar("Aggregations", bound=BaseSearchResultsAggregations)
-
-
-class BaseSearchResults(GenericModel, Generic[Record, Aggregations]):
- """
- API search results
-
- Attributes:
- -----------
+ServiceSearchResultsAggregations = TypeVar(
+ "ServiceSearchResultsAggregations", bound=ServiceBaseSearchResultsAggregations
+)
- total:
- The total number of records
- records:
- The selected records to return
- aggregations:
- Requested aggregations
- """
- total: int = 0
- records: List[Record] = Field(default_factory=list)
- aggregations: Aggregations = None
-
-
-class SearchResults(BaseModel):
+class ServiceSearchResults(BaseModel):
total: int
- records: List[Record]
+ records: List[ServiceRecord]
metrics: Dict[str, Any] = Field(default_factory=dict)
+
+
+ServiceRecordsQuery = TypeVar("ServiceRecordsQuery", bound=ServiceBaseRecordsQuery)
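+
+
+# Design note (editor's comment, not part of this change): the Service* classes
+# above are thin pass-through aliases over the DAO/backend search models, so the
+# service layer keeps its own names without duplicating query/sort definitions.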
diff --git a/src/rubrix/server/services/search/query_builder.py b/src/rubrix/server/services/search/query_builder.py
deleted file mode 100644
index 2f1a982e2c..0000000000
--- a/src/rubrix/server/services/search/query_builder.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import logging
-from enum import Enum
-from typing import Any, Dict, Optional, TypeVar
-
-from fastapi import Depends
-from luqum.elasticsearch import ElasticsearchQueryBuilder, SchemaAnalyzer
-from luqum.parser import parser
-
-from rubrix.server.daos.models.datasets import BaseDatasetDB
-from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.query_helpers import filters
-from rubrix.server.services.search.model import BaseSearchQuery, QueryRange
-
-SearchQuery = TypeVar("SearchQuery", bound=BaseSearchQuery)
-
-
-class EsQueryBuilder:
- _INSTANCE: "EsQueryBuilder" = None
- _LOGGER = logging.getLogger(__name__)
-
- @classmethod
- def get_instance(
- cls, dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance)
- ):
- if not cls._INSTANCE:
- cls._INSTANCE = cls(dao=dao)
- return cls._INSTANCE
-
- def __init__(self, dao: DatasetRecordsDAO):
- self.__dao__ = dao
-
- def __call__(
- self, dataset: BaseDatasetDB, query: Optional[SearchQuery] = None
- ) -> Dict[str, Any]:
-
- if not query:
- return filters.match_all()
-
- if not query.advanced_query_dsl or not query.query_text:
- return self.to_es_query(query)
-
- text_search = query.query_text
- new_query = query.copy(update={"query_text": None})
-
- schema = self.__dao__.get_dataset_schema(dataset)
- schema = SchemaAnalyzer(schema)
- es_query_builder = ElasticsearchQueryBuilder(
- **{
- **schema.query_builder_options(),
- "default_field": "text",
- } # TODO: This will change
- )
-
- query_tree = parser.parse(text_search)
- query_text = es_query_builder(query_tree)
-
- return filters.boolean_filter(
- filter_query=self.to_es_query(new_query), must_query=query_text
- )
-
- @classmethod
- def to_es_query(cls, query: BaseSearchQuery) -> Dict[str, Any]:
- if query.ids:
- return filters.ids_filter(query.ids)
-
- query_text = filters.text_query(query.query_text)
- all_filters = filters.metadata(query.metadata)
- query_data = query.dict(
- exclude={
- "advanced_query_dsl",
- "query_text",
- "metadata",
- "uncovered_by_rules",
- }
- )
- for key, value in query_data.items():
- if value is None:
- continue
- key_filter = None
- if isinstance(value, dict):
- value = getattr(query, key) # check the original field type
- if isinstance(value, list):
- key_filter = filters.terms_filter(key, value)
- elif isinstance(value, (str, Enum)):
- key_filter = filters.term_filter(key, value)
- elif isinstance(value, QueryRange):
- key_filter = filters.range_filter(
- field=key, value_from=value.range_from, value_to=value.range_to
- )
- else:
- cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}")
-
- if key_filter:
- all_filters.append(key_filter)
-
- return filters.boolean_filter(
- must_query=query_text or filters.match_all(),
- filter_query=filters.boolean_filter(
- should_filters=all_filters, minimum_should_match=len(all_filters)
- )
- if all_filters
- else None,
- must_not_query=filters.boolean_filter(
- should_filters=[filters.text_query(q) for q in query.uncovered_by_rules]
- )
- if hasattr(query, "uncovered_by_rules") and query.uncovered_by_rules
- else None,
- )
diff --git a/src/rubrix/server/services/search/service.py b/src/rubrix/server/services/search/service.py
index 891f2d4bfd..862150ed89 100644
--- a/src/rubrix/server/services/search/service.py
+++ b/src/rubrix/server/services/search/service.py
@@ -3,22 +3,16 @@
from fastapi import Depends
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric
-from rubrix.server.daos.models.records import RecordSearch
+from rubrix.server.daos.models.records import DaoRecordsSearch
from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.query_helpers import sort_by2elasticsearch
-from rubrix.server.services.datasets import Dataset
+from rubrix.server.services.datasets import ServiceDataset
from rubrix.server.services.metrics import MetricsService
+from rubrix.server.services.metrics.models import ServiceMetric
from rubrix.server.services.search.model import (
- BaseSearchQuery,
- Record,
- SearchResults,
- SortConfig,
-)
-from rubrix.server.services.search.query_builder import EsQueryBuilder
-from rubrix.server.services.tasks.commons.record import (
- BaseRecordDB,
- EsRecordDataFieldNames,
+ ServiceRecord,
+ ServiceRecordsQuery,
+ ServiceSearchResults,
+ ServiceSortConfig,
)
@@ -34,59 +28,39 @@ def get_instance(
cls,
dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
metrics: MetricsService = Depends(MetricsService.get_instance),
- query_builder: EsQueryBuilder = Depends(EsQueryBuilder.get_instance),
):
if not cls._INSTANCE:
- cls._INSTANCE = cls(dao=dao, metrics=metrics, query_builder=query_builder)
+ cls._INSTANCE = cls(dao=dao, metrics=metrics)
return cls._INSTANCE
def __init__(
self,
dao: DatasetRecordsDAO,
metrics: MetricsService,
- query_builder: EsQueryBuilder,
):
self.__dao__ = dao
self.__metrics__ = metrics
- self.__query_builder__ = query_builder
def search(
self,
- dataset: Dataset,
- record_type: Type[BaseRecordDB],
- query: Optional[BaseSearchQuery] = None,
- sort_config: Optional[SortConfig] = None,
+ dataset: ServiceDataset,
+ record_type: Type[ServiceRecord],
+ query: Optional[ServiceRecordsQuery] = None,
+ sort_config: Optional[ServiceSortConfig] = None,
record_from: int = 0,
size: int = 100,
exclude_metrics: bool = True,
- metrics: Optional[List[BaseMetric]] = None,
- ) -> SearchResults:
+ metrics: Optional[List[ServiceMetric]] = None,
+ ) -> ServiceSearchResults:
if record_from > 0:
metrics = None
- sort_config = sort_config or SortConfig()
+ sort_config = sort_config or ServiceSortConfig()
exclude_fields = ["metrics.*"] if exclude_metrics else None
results = self.__dao__.search_records(
dataset,
- search=RecordSearch(
- query=self.__query_builder__(dataset, query),
- sort=sort_by2elasticsearch(
- sort_config.sort_by,
- valid_fields=[
- "metadata",
- EsRecordDataFieldNames.last_updated,
- EsRecordDataFieldNames.score,
- EsRecordDataFieldNames.predicted,
- EsRecordDataFieldNames.predicted_as,
- EsRecordDataFieldNames.predicted_by,
- EsRecordDataFieldNames.annotated_as,
- EsRecordDataFieldNames.annotated_by,
- EsRecordDataFieldNames.status,
- EsRecordDataFieldNames.event_timestamp,
- ],
- ),
- ),
+ search=DaoRecordsSearch(query=query, sort=sort_config),
size=size,
record_from=record_from,
exclude_fields=exclude_fields,
@@ -110,7 +84,7 @@ def search(
)
metrics_results[metric.id] = {}
- return SearchResults(
+ return ServiceSearchResults(
total=results.total,
records=[record_type.parse_obj(r) for r in results.records],
metrics=metrics_results if metrics_results else {},
@@ -118,14 +92,15 @@ def search(
def scan_records(
self,
- dataset: Dataset,
- record_type: Type[BaseRecordDB],
- query: Optional[BaseSearchQuery] = None,
+ dataset: ServiceDataset,
+ record_type: Type[ServiceRecord],
+ query: Optional[ServiceRecordsQuery] = None,
id_from: Optional[str] = None,
limit: int = 1000
- ) -> Iterable[Record]:
+ ) -> Iterable[ServiceRecord]:
"""Scan records for a queried"""
+ search = DaoRecordsSearch(query=query)
for doc in self.__dao__.scan_dataset(
- dataset, search=RecordSearch(query=self.__query_builder__(dataset, query)), id_from=id_from, limit=limit
+ dataset, id_from=id_from, limit=limit, search=search
):
yield record_type.parse_obj(doc)
diff --git a/src/rubrix/server/services/storage/service.py b/src/rubrix/server/services/storage/service.py
index 1d89d20d1c..72007fcc56 100644
--- a/src/rubrix/server/services/storage/service.py
+++ b/src/rubrix/server/services/storage/service.py
@@ -1,11 +1,11 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import List, Type
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import Record
-from rubrix.server.apis.v0.models.datasets import BaseDatasetDB
-from rubrix.server.apis.v0.models.metrics.base import BaseTaskMetrics
+from rubrix.server.commons.config import TasksFactory
from rubrix.server.daos.records import DatasetRecordsDAO
+from rubrix.server.services.datasets import ServiceDataset
+from rubrix.server.services.tasks.commons import ServiceRecord
class RecordsStorageService:
@@ -26,20 +26,18 @@ def __init__(self, dao: DatasetRecordsDAO):
def store_records(
self,
- dataset: BaseDatasetDB,
- mappings: Dict[str, Any],
- records: List[Record],
- record_type: Type[Record],
- metrics: Optional[Type[BaseTaskMetrics]] = None,
+ dataset: ServiceDataset,
+ records: List[ServiceRecord],
+ record_type: Type[ServiceRecord],
) -> int:
"""Store a set of records"""
+ metrics = TasksFactory.get_task_metrics(dataset.task)
if metrics:
for record in records:
record.metrics = metrics.record_metrics(record)
return self.__dao__.add_records(
dataset=dataset,
- mappings=mappings,
records=records,
record_class=record_type,
)
diff --git a/src/rubrix/server/services/tasks/commons/__init__.py b/src/rubrix/server/services/tasks/commons/__init__.py
index 4958b62db9..aed4fa323c 100644
--- a/src/rubrix/server/services/tasks/commons/__init__.py
+++ b/src/rubrix/server/services/tasks/commons/__init__.py
@@ -1,2 +1 @@
-from .logging import *
-from .record import *
+from .models import *
diff --git a/src/rubrix/server/services/tasks/commons/logging.py b/src/rubrix/server/services/tasks/commons/logging.py
deleted file mode 100644
index 71bbc8b19c..0000000000
--- a/src/rubrix/server/services/tasks/commons/logging.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from pydantic import BaseModel
-
-
-class BulkResponse(BaseModel):
- """
- Data info for bulk results
-
- Attributes
- ----------
-
- dataset:
- The dataset name
- processed:
- Number of records in bulk
- failed:
- Number of failed records
- """
-
- dataset: str
- processed: int
- failed: int = 0
diff --git a/src/rubrix/server/services/tasks/commons/models.py b/src/rubrix/server/services/tasks/commons/models.py
new file mode 100644
index 0000000000..da691433c6
--- /dev/null
+++ b/src/rubrix/server/services/tasks/commons/models.py
@@ -0,0 +1,35 @@
+from typing import Generic, TypeVar
+
+from pydantic import BaseModel
+
+from rubrix.server.daos.models.records import (
+ BaseAnnotationDB,
+ BaseRecordDB,
+ BaseRecordInDB,
+)
+
+
+class ServiceBaseAnnotation(BaseAnnotationDB):
+ pass
+
+
+class BulkResponse(BaseModel):
+ dataset: str
+ processed: int
+ failed: int = 0
+
+
+ServiceAnnotation = TypeVar("ServiceAnnotation", bound=ServiceBaseAnnotation)
+
+
+class ServiceBaseRecordInputs(
+ BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation]
+):
+ pass
+
+
+class ServiceBaseRecord(BaseRecordDB[ServiceAnnotation], Generic[ServiceAnnotation]):
+ pass
+
+
+ServiceRecord = TypeVar("ServiceRecord", bound=ServiceBaseRecord)
diff --git a/src/rubrix/server/services/tasks/commons/record.py b/src/rubrix/server/services/tasks/commons/record.py
deleted file mode 100644
index 04d2d91108..0000000000
--- a/src/rubrix/server/services/tasks/commons/record.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from datetime import datetime
-from enum import Enum
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
-from uuid import uuid4
-
-from pydantic import BaseModel, Field, validator
-from pydantic.generics import GenericModel
-
-
-class EsRecordDataFieldNames(str, Enum):
-
- predicted_as = "predicted_as"
- annotated_as = "annotated_as"
- annotated_by = "annotated_by"
- predicted_by = "predicted_by"
- status = "status"
- predicted = "predicted"
- score = "score"
- words = "words"
- event_timestamp = "event_timestamp"
- last_updated = "last_updated"
-
- def __str__(self):
- return self.value
-
-
-class BaseAnnotation(BaseModel):
- agent: str = Field(max_length=64)
-
-
-class TaskType(str, Enum):
-
- text_classification = "TextClassification"
- token_classification = "TokenClassification"
- text2text = "Text2Text"
- multi_task_text_token_classification = "MultitaskTextTokenClassification"
-
-
-class TaskStatus(str, Enum):
- default = "Default"
- edited = "Edited" # TODO: DEPRECATE
- discarded = "Discarded"
- validated = "Validated"
-
-
-class PredictionStatus(str, Enum):
- OK = "ok"
- KO = "ko"
-
-
-Annotation = TypeVar("Annotation", bound=BaseAnnotation)
-
-
-class BaseRecordDB(GenericModel, Generic[Annotation]):
-
- id: Optional[Union[int, str]] = Field(default=None)
- metadata: Dict[str, Any] = Field(default=None)
- event_timestamp: Optional[datetime] = None
- status: Optional[TaskStatus] = None
- prediction: Optional[Annotation] = None
- annotation: Optional[Annotation] = None
- metrics: Dict[str, Any] = Field(default_factory=dict)
- search_keywords: Optional[List[str]] = None
-
- @validator("id", always=True, pre=True)
- def default_id_if_none_provided(cls, id: Optional[str]) -> str:
- """Validates id info and sets a random uuid if not provided"""
- if id is None:
- return str(uuid4())
- return id
-
- @validator("status", always=True)
- def fill_default_value(cls, status: TaskStatus):
- """Fastapi validator for set default task status"""
- return TaskStatus.default if status is None else status
-
- @validator("search_keywords")
- def remove_duplicated_keywords(cls, value) -> List[str]:
- """Remove duplicated keywords"""
- if value:
- return list(set(value))
-
- @classmethod
- def task(cls) -> TaskType:
- """The task type related to this task info"""
- raise NotImplementedError
-
- @property
- def predicted(self) -> Optional[PredictionStatus]:
- """The task record prediction status (if any)"""
- return None
-
- @property
- def predicted_as(self) -> Optional[List[str]]:
- """Predictions strings representation"""
- return None
-
- @property
- def annotated_as(self) -> Optional[List[str]]:
- """Annotations strings representation"""
- return None
-
- @property
- def scores(self) -> Optional[List[float]]:
- """Prediction scores"""
- return None
-
- def all_text(self) -> str:
- """All textual information related to record"""
- raise NotImplementedError
-
- @property
- def predicted_by(self) -> List[str]:
- """The prediction agents"""
- if self.prediction:
- return [self.prediction.agent]
- return []
-
- @property
- def annotated_by(self) -> List[str]:
- """The annotation agents"""
- if self.annotation:
- return [self.annotation.agent]
- return []
-
- def extended_fields(self) -> Dict[str, Any]:
- """
- Used for extends fields to store in db. Tasks that would include extra
- properties than commons (predicted, annotated_as,....) could implement
- this method.
- """
- return {
- EsRecordDataFieldNames.predicted: self.predicted,
- EsRecordDataFieldNames.annotated_as: self.annotated_as,
- EsRecordDataFieldNames.predicted_as: self.predicted_as,
- EsRecordDataFieldNames.annotated_by: self.annotated_by,
- EsRecordDataFieldNames.predicted_by: self.predicted_by,
- EsRecordDataFieldNames.score: self.scores,
- }
-
- def dict(self, *args, **kwargs) -> "DictStrAny":
- """
- Extends base component dict extending object properties
- and user defined extended fields
- """
- return {
- **super().dict(*args, **kwargs),
- **self.extended_fields(),
- }
-
-
-Record = TypeVar("Record", bound=BaseRecordDB)
diff --git a/src/rubrix/server/services/tasks/text2text/__init__.py b/src/rubrix/server/services/tasks/text2text/__init__.py
new file mode 100644
index 0000000000..79de1ed5a3
--- /dev/null
+++ b/src/rubrix/server/services/tasks/text2text/__init__.py
@@ -0,0 +1 @@
+from .service import Text2TextService
diff --git a/src/rubrix/server/services/tasks/text2text/models.py b/src/rubrix/server/services/tasks/text2text/models.py
new file mode 100644
index 0000000000..8b21787008
--- /dev/null
+++ b/src/rubrix/server/services/tasks/text2text/models.py
@@ -0,0 +1,81 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+from rubrix.server.commons.models import PredictionStatus, TaskType
+from rubrix.server.services.datasets import ServiceBaseDataset
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceBaseSearchResultsAggregations,
+ ServiceScoreRange,
+ ServiceSearchResults,
+)
+from rubrix.server.services.tasks.commons import (
+ ServiceBaseAnnotation,
+ ServiceBaseRecord,
+)
+
+
+class ServiceText2TextPrediction(BaseModel):
+ text: str
+ score: float
+
+
+class ServiceText2TextAnnotation(ServiceBaseAnnotation):
+ sentences: List[ServiceText2TextPrediction]
+
+
+class ServiceText2TextRecord(ServiceBaseRecord[ServiceText2TextAnnotation]):
+ text: str
+
+ @classmethod
+ def task(cls) -> TaskType:
+ """The task type"""
+ return TaskType.text2text
+
+ def all_text(self) -> str:
+ return self.text
+
+ @property
+ def predicted_as(self) -> Optional[List[str]]:
+ return (
+ [sentence.text for sentence in self.prediction.sentences]
+ if self.prediction
+ else None
+ )
+
+ @property
+ def annotated_as(self) -> Optional[List[str]]:
+ return (
+ [sentence.text for sentence in self.annotation.sentences]
+ if self.annotation
+ else None
+ )
+
+ @property
+ def scores(self) -> List[float]:
+ """Values of prediction scores"""
+ if not self.prediction:
+ return []
+ return [sentence.score for sentence in self.prediction.sentences]
+
+ def extended_fields(self) -> Dict[str, Any]:
+ return {
+ "annotated_as": self.annotated_as,
+ "predicted_as": self.predicted_as,
+ "annotated_by": self.annotated_by,
+ "predicted_by": self.predicted_by,
+ "score": self.scores,
+ "words": self.all_text(),
+ }
+
+
+class ServiceText2TextQuery(ServiceBaseRecordsQuery):
+ score: Optional[ServiceScoreRange] = Field(default=None)
+ predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
+
+
+class ServiceText2TextDataset(ServiceBaseDataset):
+ task: TaskType = Field(default=TaskType.text2text, const=True)
+ pass
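+
+
+# Illustrative sketch (editor's example, not part of this change), assuming
+# ServiceBaseAnnotation keeps the `agent` field of the former BaseAnnotation:
+#
+#   record = ServiceText2TextRecord(
+#       text="My name is Sarah and I love my dog.",
+#       prediction=ServiceText2TextAnnotation(
+#           agent="my-seq2seq-model",
+#           sentences=[ServiceText2TextPrediction(text="Me llamo Sarah", score=0.87)],
+#       ),
+#   )
+#   record.predicted_as  # -> ["Me llamo Sarah"]
+#   record.scores        # -> [0.87]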
diff --git a/src/rubrix/server/services/text2text.py b/src/rubrix/server/services/tasks/text2text/service.py
similarity index 52%
rename from src/rubrix/server/services/text2text.py
rename to src/rubrix/server/services/tasks/text2text/service.py
index f19e3e926e..0c8f552367 100644
--- a/src/rubrix/server/services/text2text.py
+++ b/src/rubrix/server/services/tasks/text2text/service.py
@@ -13,28 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterable, List, Optional, Type
+from typing import Iterable, List, Optional, Type
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- EsRecordDataFieldNames,
- SortableField,
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics
+from rubrix.server.services.search.model import (
+ ServiceSearchResults,
+ ServiceSortableField,
+ ServiceSortConfig,
)
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics
-from rubrix.server.apis.v0.models.text2text import (
- CreationText2TextRecord,
- Text2TextDatasetDB,
- Text2TextQuery,
- Text2TextRecord,
- Text2TextRecordDB,
- Text2TextSearchAggregations,
- Text2TextSearchResults,
-)
-from rubrix.server.services.search.model import SortConfig
from rubrix.server.services.search.service import SearchRecordsService
from rubrix.server.services.storage.service import RecordsStorageService
+from rubrix.server.services.tasks.commons import BulkResponse
+from rubrix.server.services.tasks.text2text.models import (
+ ServiceText2TextDataset,
+ ServiceText2TextQuery,
+ ServiceText2TextRecord,
+)
class Text2TextService:
@@ -65,70 +62,46 @@ def __init__(
def add_records(
self,
- dataset: Text2TextDatasetDB,
- mappings: Dict[str, Any],
- records: List[CreationText2TextRecord],
- metrics: Type[BaseTaskMetrics],
+ dataset: ServiceText2TextDataset,
+ records: List[ServiceText2TextRecord],
):
failed = self.__storage__.store_records(
dataset=dataset,
- mappings=mappings,
records=records,
- record_type=Text2TextRecordDB,
- metrics=metrics,
+ record_type=ServiceText2TextRecord,
)
return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed)
def search(
self,
- dataset: Text2TextDatasetDB,
- query: Text2TextQuery,
- sort_by: List[SortableField],
+ dataset: ServiceText2TextDataset,
+ query: ServiceText2TextQuery,
+ sort_by: List[ServiceSortableField],
record_from: int = 0,
size: int = 100,
exclude_metrics: bool = True,
- metrics: Optional[List[BaseMetric]] = None,
- ) -> Text2TextSearchResults:
- """
- Run a search in a dataset
-
- Parameters
- ----------
- dataset:
- The records dataset
- query:
- The search parameters
- sort_by:
- The sort by list
- record_from:
- The record from return results
- size:
- The max number of records to return
-
- Returns
- -------
- The matched records with aggregation info for specified task_meta.py
-
- """
+ ) -> ServiceSearchResults:
+
+ metrics = TasksFactory.find_task_metrics(
+ dataset.task,
+ metric_ids={
+ "words_cloud",
+ "predicted_by",
+ "annotated_by",
+ "status_distribution",
+ "metadata",
+ "score",
+ },
+ )
results = self.__search__.search(
dataset,
query=query,
size=size,
record_from=record_from,
- record_type=Text2TextRecord,
- sort_config=SortConfig(
+ record_type=ServiceText2TextRecord,
+ sort_config=ServiceSortConfig(
sort_by=sort_by,
- valid_fields=[
- "metadata",
- EsRecordDataFieldNames.predicted_as,
- EsRecordDataFieldNames.annotated_as,
- EsRecordDataFieldNames.predicted_by,
- EsRecordDataFieldNames.annotated_by,
- EsRecordDataFieldNames.status,
- EsRecordDataFieldNames.last_updated,
- EsRecordDataFieldNames.event_timestamp,
- ],
),
exclude_metrics=exclude_metrics,
metrics=metrics,
@@ -138,21 +111,15 @@ def search(
results.metrics["words"] = results.metrics["words_cloud"]
results.metrics["status"] = results.metrics["status_distribution"]
- return Text2TextSearchResults(
- total=results.total,
- records=results.records,
- aggregations=Text2TextSearchAggregations.parse_obj(results.metrics)
- if results.metrics
- else None,
- )
+ return results
def read_dataset(
self,
- dataset: Text2TextDatasetDB,
- query: Optional[Text2TextQuery] = None,
+ dataset: ServiceText2TextDataset,
+ query: Optional[ServiceText2TextQuery] = None,
id_from: Optional[str] = None,
limit: int = 1000
- ) -> Iterable[Text2TextRecord]:
+ ) -> Iterable[ServiceText2TextRecord]:
"""
Scan a dataset records
@@ -170,8 +137,5 @@ def read_dataset(
"""
yield from self.__search__.scan_records(
- dataset, query=query, record_type=Text2TextRecord, id_from=id_from, limit=limit,
+ dataset, query=query, record_type=ServiceText2TextRecord, id_from=id_from, limit=limit
)
-
-
-text2text_service = Text2TextService.get_instance
diff --git a/src/rubrix/server/services/tasks/text_classification/__init__.py b/src/rubrix/server/services/tasks/text_classification/__init__.py
new file mode 100644
index 0000000000..dbc9d953c5
--- /dev/null
+++ b/src/rubrix/server/services/tasks/text_classification/__init__.py
@@ -0,0 +1,2 @@
+from .labeling_rules_service import LabelingService
+from .service import TextClassificationService
diff --git a/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py b/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py
new file mode 100644
index 0000000000..0f5cc3a0ee
--- /dev/null
+++ b/src/rubrix/server/services/tasks/text_classification/labeling_rules_service.py
@@ -0,0 +1,138 @@
+from typing import List, Optional, Tuple
+
+from fastapi import Depends
+from pydantic import BaseModel, Field
+
+from rubrix.server.daos.datasets import DatasetsDAO
+from rubrix.server.daos.models.records import DaoRecordsSearch
+from rubrix.server.daos.records import DatasetRecordsDAO
+from rubrix.server.errors import EntityAlreadyExistsError, EntityNotFoundError
+from rubrix.server.services.search.model import ServiceBaseRecordsQuery
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceLabelingRule,
+ ServiceTextClassificationDataset,
+)
+
+
+class DatasetLabelingRulesSummary(BaseModel):
+ covered_records: int
+ annotated_covered_records: int
+
+
+class LabelingRuleSummary(BaseModel):
+ covered_records: int
+ annotated_covered_records: int
+ correct_records: int = Field(default=0)
+ incorrect_records: int = Field(default=0)
+ precision: Optional[float] = None
+
+
+class LabelingService:
+
+ _INSTANCE = None
+
+ @classmethod
+ def get_instance(
+ cls,
+ datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance),
+ records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
+ ):
+ if cls._INSTANCE is None:
+ cls._INSTANCE = cls(datasets, records)
+ return cls._INSTANCE
+
+ def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO):
+ self.__datasets__ = datasets
+ self.__records__ = records
+
+ # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings
+ def list_rules(
+ self, dataset: ServiceTextClassificationDataset
+ ) -> List[ServiceLabelingRule]:
+ """List a set of rules for a given dataset"""
+ return dataset.rules
+
+ def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str):
+ """Delete a rule from a dataset by its defined query string"""
+ new_rules_set = [r for r in dataset.rules if r.query != rule_query]
+ if len(dataset.rules) != len(new_rules_set):
+ dataset.rules = new_rules_set
+ self.__datasets__.update_dataset(dataset)
+
+ def add_rule(
+ self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
+ ) -> ServiceLabelingRule:
+ """Adds a rule to a dataset"""
+ for r in dataset.rules:
+ if r.query == rule.query:
+ raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule)
+ dataset.rules.append(rule)
+ self.__datasets__.update_dataset(dataset)
+ return rule
+
+ def compute_rule_metrics(
+ self,
+ dataset: ServiceTextClassificationDataset,
+ rule_query: str,
+ labels: Optional[List[str]] = None,
+ ) -> Tuple[int, int, LabelingRuleSummary]:
+ """Computes metrics for given rule query and optional label against a set of rules"""
+
+ annotated_records = self._count_annotated_records(dataset)
+ dataset_records = self.__records__.search_records(dataset, size=0).total
+ metric_data = self.__records__.compute_metric(
+ dataset=dataset,
+ metric_id="labeling_rule",
+ metric_params=dict(rule_query=rule_query, labels=labels),
+ )
+
+ return (
+ dataset_records,
+ annotated_records,
+ LabelingRuleSummary.parse_obj(metric_data),
+ )
+
+ def _count_annotated_records(
+ self, dataset: ServiceTextClassificationDataset
+ ) -> int:
+ results = self.__records__.search_records(
+ dataset,
+ size=0,
+ search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)),
+ )
+ return results.total
+
+ def all_rules_metrics(
+ self, dataset: ServiceTextClassificationDataset
+ ) -> Tuple[int, int, DatasetLabelingRulesSummary]:
+ annotated_records = self._count_annotated_records(dataset)
+ dataset_records = self.__records__.search_records(dataset, size=0).total
+ metric_data = self.__records__.compute_metric(
+ dataset=dataset,
+ metric_id="dataset_labeling_rules",
+ metric_params=dict(queries=[r.query for r in dataset.rules]),
+ )
+
+ return (
+ dataset_records,
+ annotated_records,
+ DatasetLabelingRulesSummary.parse_obj(metric_data),
+ )
+
+ def find_rule_by_query(
+ self, dataset: ServiceTextClassificationDataset, rule_query: str
+ ) -> ServiceLabelingRule:
+ rule_query = rule_query.strip()
+ for rule in dataset.rules:
+ if rule.query == rule_query:
+ return rule
+ raise EntityNotFoundError(rule_query, type=ServiceLabelingRule)
+
+ def replace_rule(
+ self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
+ ):
+ for idx, r in enumerate(dataset.rules):
+ if r.query == rule.query:
+ dataset.rules[idx] = rule
+ break
+ self.__datasets__.update_dataset(dataset)
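+
+ # Usage note (editor's sketch, not part of this change): the API layer can derive
+ # the ratios reported in LabelingRuleMetricsSummary from these raw counts, e.g.
+ #   total, annotated, summary = labeling.compute_rule_metrics(dataset, rule_query)
+ #   coverage = summary.covered_records / total if total else None
+ #   coverage_annotated = summary.annotated_covered_records / annotated if annotated else None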
diff --git a/src/rubrix/server/apis/v0/models/metrics/text_classification.py b/src/rubrix/server/services/tasks/text_classification/metrics.py
similarity index 69%
rename from src/rubrix/server/apis/v0/models/metrics/text_classification.py
rename to src/rubrix/server/services/tasks/text_classification/metrics.py
index 266fae6ce5..5153cda301 100644
--- a/src/rubrix/server/apis/v0/models/metrics/text_classification.py
+++ b/src/rubrix/server/services/tasks/text_classification/metrics.py
@@ -1,19 +1,17 @@
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set
+from typing import Any, ClassVar, Dict, Iterable, List
from pydantic import Field
from sklearn.metrics import precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer
-from rubrix.server.apis.v0.models.metrics.base import (
- BaseMetric,
- PythonMetric,
- TermsAggregation,
+from rubrix.server.services.metrics import ServiceBaseMetric, ServicePythonMetric
+from rubrix.server.services.metrics.models import CommonTasksMetrics
+from rubrix.server.services.tasks.text_classification.model import (
+ ServiceTextClassificationRecord,
)
-from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics
-from rubrix.server.apis.v0.models.text_classification import TextClassificationRecord
-class F1Metric(PythonMetric):
+class F1Metric(ServicePythonMetric):
"""
A basic f1 computation for text classification
@@ -25,7 +23,7 @@ class F1Metric(PythonMetric):
multi_label: bool = False
- def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
+ def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any:
filtered_records = list(filter(lambda r: r.predicted is not None, records))
# TODO: This must be precomputed by using a global dataset metric
ds_labels = {
@@ -92,12 +90,14 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
}
-class DatasetLabels(PythonMetric):
+class DatasetLabels(ServicePythonMetric):
id: str = Field("dataset_labels", const=True)
name: str = Field("The dataset labels", const=True)
max_processed_records: int = 10000
- def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]:
+ def apply(
+ self, records: Iterable[ServiceTextClassificationRecord]
+ ) -> Dict[str, Any]:
ds_labels = set()
for _ in range(
0, self.max_processed_records
@@ -117,30 +117,33 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]:
return {"labels": ds_labels or []}
-class TextClassificationMetrics(CommonTasksMetrics[TextClassificationRecord]):
+class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationRecord]):
"""Configured metrics for text classification task"""
- metrics: ClassVar[List[BaseMetric]] = CommonTasksMetrics.metrics + [
- TermsAggregation(
- id="predicted_as",
- name="Predicted labels distribution",
- field="predicted_as",
- ),
- TermsAggregation(
- id="annotated_as",
- name="Annotated labels distribution",
- field="annotated_as",
- ),
- F1Metric(
- id="F1",
- name="F1 Metrics for single-label",
- description="F1 Metrics for single-label (averaged and per label)",
- ),
- F1Metric(
- id="MultiLabelF1",
- name="F1 Metrics for multi-label",
- description="F1 Metrics for multi-label (averaged and per label)",
- multi_label=True,
- ),
- DatasetLabels(),
- ]
+ metrics: ClassVar[List[ServiceBaseMetric]] = (
+ CommonTasksMetrics.metrics
+ + [
+ F1Metric(
+ id="F1",
+ name="F1 Metrics for single-label",
+ description="F1 Metrics for single-label (averaged and per label)",
+ ),
+ F1Metric(
+ id="MultiLabelF1",
+ name="F1 Metrics for multi-label",
+ description="F1 Metrics for multi-label (averaged and per label)",
+ multi_label=True,
+ ),
+ DatasetLabels(),
+ ]
+ + [
+ ServiceBaseMetric(
+ id="predicted_as",
+ name="Predicted labels distribution",
+ ),
+ ServiceBaseMetric(
+ id="annotated_as",
+ name="Annotated labels distribution",
+ ),
+ ]
+ )
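+
+
+# Illustrative usage (editor's example, not part of this change):
+#   TextClassificationMetrics.find_metric("F1")           # -> single-label F1Metric
+#   TextClassificationMetrics.find_metric("text_length")  # -> inherited common metric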
diff --git a/src/rubrix/server/services/tasks/text_classification/model.py b/src/rubrix/server/services/tasks/text_classification/model.py
new file mode 100644
index 0000000000..3bd8ff3927
--- /dev/null
+++ b/src/rubrix/server/services/tasks/text_classification/model.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+from typing import Any, ClassVar, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field, root_validator, validator
+
+from rubrix._constants import MAX_KEYWORD_LENGTH
+from rubrix.server.commons.models import PredictionStatus, TaskStatus, TaskType
+from rubrix.server.helpers import flatten_dict
+from rubrix.server.services.datasets import ServiceBaseDataset
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceScoreRange,
+)
+from rubrix.server.services.tasks.commons import (
+ ServiceBaseAnnotation,
+ ServiceBaseRecord,
+)
+
+
+class ServiceLabelingRule(BaseModel):
+ query: str = Field(description="The es rule query")
+
+ author: str = Field(description="User who created the rule")
+ created_at: Optional[datetime] = Field(
+ default_factory=datetime.utcnow, description="Rule creation timestamp"
+ )
+
+ label: Optional[str] = Field(
+ default=None, description="@Deprecated::The label associated with the rule."
+ )
+ labels: List[str] = Field(
+ default_factory=list,
+ description="For multi label problems, a list of labels. "
+ "It will replace the `label` field",
+ )
+ description: Optional[str] = Field(
+ None, description="A brief description of the rule"
+ )
+
+ @root_validator
+ def initialize_labels(cls, values):
+ label = values.get("label", None)
+ labels = values.get("labels", [])
+
+ if label:
+ labels.append(label)
+ values["labels"] = list(set(labels))
+
+ assert len(labels) >= 1, f"No labels were provided in rule {values}"
+ return values
+
+ @validator("query")
+ def strip_query(cls, query: str) -> str:
+ """Remove blank spaces for query"""
+ return query.strip()
+
+
+class ServiceTextClassificationDataset(ServiceBaseDataset):
+
+ task: TaskType = Field(default=TaskType.text_classification, const=True)
+ rules: List[ServiceLabelingRule] = Field(default_factory=list)
+
+
+class ClassPrediction(BaseModel):
+ """
+ Single class prediction
+
+ Attributes:
+ -----------
+
+ class_label: Union[str, int]
+ the predicted class
+
+ score: float
+ the predicted class score. For human-supervised annotations,
+ this probability should be 1.0
+ """
+
+ class_label: Union[str, int] = Field(alias="class")
+ score: float = Field(default=1.0, ge=0.0, le=1.0)
+
+ @validator("class_label")
+ def check_label_length(cls, class_label):
+ if isinstance(class_label, str):
+ assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, (
+ f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}"
+ if len(class_label) > MAX_KEYWORD_LENGTH
+ else f"Class name must not be empty"
+ )
+ return class_label
+
+ class Config:
+ allow_population_by_field_name = True
+
+
+class LabelingRuleMetricsSummary(BaseModel):
+ """Metrics generated for a labeling rule"""
+
+ coverage: Optional[float] = None
+ coverage_annotated: Optional[float] = None
+ correct: Optional[float] = None
+ incorrect: Optional[float] = None
+ precision: Optional[float] = None
+
+ total_records: int
+ annotated_records: int
+
+
+class DatasetLabelingRulesMetricsSummary(BaseModel):
+ coverage: Optional[float] = None
+ coverage_annotated: Optional[float] = None
+
+ total_records: int
+ annotated_records: int
+
+
+class TextClassificationAnnotation(ServiceBaseAnnotation):
+ """
+ Annotation class for text classification tasks
+
+ Attributes:
+ -----------
+
+ labels: List[LabelPrediction]
+ list of annotated labels with score
+ """
+
+ # TODO(@frascuchon): labels must be a dict (to avoid repeat labels)
+ labels: List[ClassPrediction]
+
+ @validator("labels")
+ def sort_labels(cls, labels: List[ClassPrediction]):
+ """Sort provided labels by score"""
+ return sorted(labels, key=lambda x: x.score, reverse=True)
+
+
+class TokenAttributions(BaseModel):
+ """
+ The token attributions explaining predicted labels
+
+ Attributes:
+ -----------
+
+ token: str
+ The input token
+ attributions: Dict[str, float]
+ A dictionary containing label class-attribution pairs
+
+ """
+
+ token: str
+ attributions: Dict[str, float] = Field(default_factory=dict)
+
+
+class ServiceTextClassificationRecord(ServiceBaseRecord[TextClassificationAnnotation]):
+ inputs: Dict[str, Union[str, List[str]]]
+ multi_label: bool = False
+ explanation: Optional[Dict[str, List[TokenAttributions]]] = None
+
+ class Config:
+ allow_population_by_field_name = True
+
+ _SCORE_DEVIATION_ERROR: ClassVar[float] = 0.001
+
+ @root_validator
+ def validate_record(cls, values):
+ """fastapi validator method"""
+ prediction = values.get("prediction", None)
+ annotation = values.get("annotation", None)
+ status = values.get("status")
+ multi_label = values.get("multi_label", False)
+
+ cls._check_score_integrity(prediction, multi_label)
+ cls._check_annotation_integrity(annotation, multi_label, status)
+
+ return values
+
+ @classmethod
+ def _check_annotation_integrity(
+ cls,
+ annotation: TextClassificationAnnotation,
+ multi_label: bool,
+ status: TaskStatus,
+ ):
+ if status == TaskStatus.validated and not multi_label:
+ assert (
+ annotation and len(annotation.labels) > 0
+ ), "Annotation must include some label for validated records"
+
+ if not multi_label and annotation:
+ assert (
+ len(annotation.labels) == 1
+ ), "Single label record must include only one annotation label"
+
+ @classmethod
+ def _check_score_integrity(
+ cls, prediction: TextClassificationAnnotation, multi_label: bool
+ ):
+ """
+ Checks the score value integrity
+
+ Parameters
+ ----------
+ prediction:
+ The prediction annotation
+ multi_label:
+ If multi label
+
+ """
+ if prediction and not multi_label:
+ assert sum([label.score for label in prediction.labels]) <= (
+ 1.0 + cls._SCORE_DEVIATION_ERROR
+ ), f"Wrong score distributions: {prediction.labels}"
+
+ @classmethod
+ def task(cls) -> TaskType:
+ """The task type"""
+ return TaskType.text_classification
+
+ @property
+ def predicted(self) -> Optional[PredictionStatus]:
+ if self.predicted_by and self.annotated_by:
+ return (
+ PredictionStatus.OK
+ if set(self.predicted_as) == set(self.annotated_as)
+ else PredictionStatus.KO
+ )
+ return None
+
+ @property
+ def predicted_as(self) -> List[str]:
+ return self._labels_from_annotation(
+ self.prediction, multi_label=self.multi_label
+ )
+
+ @property
+ def annotated_as(self) -> List[str]:
+ return self._labels_from_annotation(
+ self.annotation, multi_label=self.multi_label
+ )
+
+ @property
+ def scores(self) -> List[float]:
+ if not self.prediction:
+ return []
+ return (
+ [label.score for label in self.prediction.labels]
+ if self.multi_label
+ else [
+ prediction_class.score
+ for prediction_class in [
+ self._max_class_prediction(
+ self.prediction, multi_label=self.multi_label
+ )
+ ]
+ if prediction_class
+ ]
+ )
+
+ def all_text(self) -> str:
+ sentences = []
+ for v in self.inputs.values():
+ if isinstance(v, list):
+ sentences.extend(v)
+ else:
+ sentences.append(v)
+ return "\n".join(sentences)
+
+ @validator("inputs")
+ def validate_inputs(cls, text: Dict[str, Any]):
+ assert len(text) > 0, "No inputs provided"
+
+ for t in text.values():
+ assert t is not None, "Cannot include None fields"
+
+ return text
+
+ @validator("inputs")
+ def flatten_text(cls, text: Dict[str, Any]):
+ flat_dict = flatten_dict(text)
+ return flat_dict
+
+ @classmethod
+ def _labels_from_annotation(
+ cls, annotation: TextClassificationAnnotation, multi_label: bool
+ ) -> Union[List[str], List[int]]:
+
+ if not annotation:
+ return []
+
+ if multi_label:
+ return [
+ label.class_label for label in annotation.labels if label.score > 0.5
+ ]
+
+ class_prediction = cls._max_class_prediction(
+ annotation, multi_label=multi_label
+ )
+ if class_prediction is None:
+ return []
+
+ return [class_prediction.class_label]
+
+ @staticmethod
+ def _max_class_prediction(
+ p: TextClassificationAnnotation, multi_label: bool
+ ) -> Optional[ClassPrediction]:
+ if multi_label or p is None or not p.labels:
+ return None
+ return p.labels[0]
+
+ def extended_fields(self) -> Dict[str, Any]:
+ words = self.all_text()
+ return {
+ **super().extended_fields(),
+ "words": words,
+ "text": words,
+ }
+
+
+class ServiceTextClassificationQuery(ServiceBaseRecordsQuery):
+
+ predicted_as: List[str] = Field(default_factory=list)
+ annotated_as: List[str] = Field(default_factory=list)
+ score: Optional[ServiceScoreRange] = Field(default=None)
+ predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
+
+ uncovered_by_rules: List[str] = Field(default_factory=list)
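
The single-label score check above is easiest to see with concrete numbers. A minimal standalone sketch of the same rule (plain Python with illustrative names, not the Rubrix classes themselves):

    # Illustrative version of _check_score_integrity: for single-label records,
    # prediction scores must behave like a probability distribution (sum <= 1 + tolerance).
    from typing import List, NamedTuple

    SCORE_DEVIATION_ERROR = 0.001

    class Label(NamedTuple):
        class_label: str
        score: float

    def check_score_integrity(labels: List[Label], multi_label: bool) -> None:
        if labels and not multi_label:
            total = sum(label.score for label in labels)
            assert total <= 1.0 + SCORE_DEVIATION_ERROR, f"Wrong score distribution: {labels}"

    check_score_integrity([Label("spam", 0.7), Label("ham", 0.3)], multi_label=False)  # passes
    try:
        check_score_integrity([Label("spam", 0.7), Label("ham", 0.5)], multi_label=False)
    except AssertionError as error:
        print(error)  # scores sum to 1.2, so the record is rejected
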
diff --git a/src/rubrix/server/services/text_classification.py b/src/rubrix/server/services/tasks/text_classification/service.py
similarity index 67%
rename from src/rubrix/server/services/text_classification.py
rename to src/rubrix/server/services/tasks/text_classification/service.py
index 1c45ff244e..3a372ae443 100644
--- a/src/rubrix/server/services/text_classification.py
+++ b/src/rubrix/server/services/tasks/text_classification/service.py
@@ -13,33 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterable, List, Optional, Type
+from typing import Iterable, List, Optional
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- EsRecordDataFieldNames,
- SortableField,
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.errors.base_errors import MissingDatasetRecordsError
+from rubrix.server.services.search.model import (
+ ServiceSearchResults,
+ ServiceSortableField,
+ ServiceSortConfig,
)
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics
-from rubrix.server.apis.v0.models.text_classification import (
- CreationTextClassificationRecord,
+from rubrix.server.services.search.service import SearchRecordsService
+from rubrix.server.services.storage.service import RecordsStorageService
+from rubrix.server.services.tasks.commons import BulkResponse
+from rubrix.server.services.tasks.text_classification import LabelingService
+from rubrix.server.services.tasks.text_classification.model import (
DatasetLabelingRulesMetricsSummary,
- LabelingRule,
LabelingRuleMetricsSummary,
- TextClassificationDatasetDB,
- TextClassificationQuery,
- TextClassificationRecord,
- TextClassificationRecordDB,
- TextClassificationSearchAggregations,
- TextClassificationSearchResults,
+ ServiceLabelingRule,
+ ServiceTextClassificationDataset,
+ ServiceTextClassificationQuery,
+ ServiceTextClassificationRecord,
)
-from rubrix.server.errors.base_errors import MissingDatasetRecordsError
-from rubrix.server.services.search.model import SortConfig
-from rubrix.server.services.search.service import SearchRecordsService
-from rubrix.server.services.storage.service import RecordsStorageService
-from rubrix.server.services.text_classification_labelling_rules import LabelingService
class TextClassificationService:
@@ -73,33 +69,28 @@ def __init__(
def add_records(
self,
- dataset: TextClassificationDatasetDB,
- mappings: Dict[str, Any],
- records: List[CreationTextClassificationRecord],
- metrics: Type[BaseTaskMetrics],
+ dataset: ServiceTextClassificationDataset,
+ records: List[ServiceTextClassificationRecord],
):
# TODO(@frascuchon): This will moved to dataset settings validation once DatasetSettings join the game!
self._check_multi_label_integrity(dataset, records)
failed = self.__storage__.store_records(
dataset=dataset,
- mappings=mappings,
records=records,
- record_type=TextClassificationRecordDB,
- metrics=metrics,
+ record_type=ServiceTextClassificationRecord,
)
return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed)
def search(
self,
- dataset: TextClassificationDatasetDB,
- query: TextClassificationQuery,
- sort_by: List[SortableField],
+ dataset: ServiceTextClassificationDataset,
+ query: ServiceTextClassificationQuery,
+ sort_by: List[ServiceSortableField],
record_from: int = 0,
size: int = 100,
exclude_metrics: bool = True,
- metrics: Optional[List[BaseMetric]] = None,
- ) -> TextClassificationSearchResults:
+ ) -> ServiceSearchResults:
"""
Run a search in a dataset
@@ -121,28 +112,31 @@ def search(
"""
+ metrics = TasksFactory.find_task_metrics(
+ dataset.task,
+ metric_ids={
+ "words_cloud",
+ "predicted_by",
+ "predicted_as",
+ "annotated_by",
+ "annotated_as",
+ "error_distribution",
+ "status_distribution",
+ "metadata",
+ "score",
+ },
+ )
+
results = self.__search__.search(
dataset,
query=query,
- record_type=TextClassificationRecord,
+ record_type=ServiceTextClassificationRecord,
record_from=record_from,
size=size,
exclude_metrics=exclude_metrics,
metrics=metrics,
- sort_config=SortConfig(
+ sort_config=ServiceSortConfig(
sort_by=sort_by,
- valid_fields=[
- "metadata",
- EsRecordDataFieldNames.last_updated,
- EsRecordDataFieldNames.score,
- EsRecordDataFieldNames.predicted,
- EsRecordDataFieldNames.predicted_as,
- EsRecordDataFieldNames.predicted_by,
- EsRecordDataFieldNames.annotated_as,
- EsRecordDataFieldNames.annotated_by,
- EsRecordDataFieldNames.status,
- EsRecordDataFieldNames.event_timestamp,
- ],
),
)
@@ -152,21 +146,15 @@ def search(
results.metrics["predicted"] = results.metrics["error_distribution"]
results.metrics["predicted"].pop("unknown", None)
- return TextClassificationSearchResults(
- total=results.total,
- records=results.records,
- aggregations=TextClassificationSearchAggregations.parse_obj(results.metrics)
- if results.metrics
- else None,
- )
+ return results
def read_dataset(
self,
- dataset: TextClassificationDatasetDB,
- query: Optional[TextClassificationQuery] = None,
+ dataset: ServiceTextClassificationDataset,
+ query: Optional[ServiceTextClassificationQuery] = None,
id_from: Optional[str] = None,
limit: int = 1000
- ) -> Iterable[TextClassificationRecord]:
+ ) -> Iterable[ServiceTextClassificationRecord]:
"""
Scan a dataset records
@@ -184,13 +172,13 @@ def read_dataset(
"""
yield from self.__search__.scan_records(
- dataset, query=query, record_type=TextClassificationRecord, id_from=id_from, limit=limit
+ dataset, query=query, record_type=ServiceTextClassificationRecord, id_from=id_from, limit=limit
)
def _check_multi_label_integrity(
self,
- dataset: TextClassificationDatasetDB,
- records: List[CreationTextClassificationRecord],
+ dataset: ServiceTextClassificationDataset,
+ records: List[ServiceTextClassificationRecord],
):
is_multi_label_dataset = self._is_dataset_multi_label(dataset)
if is_multi_label_dataset is not None:
@@ -203,12 +191,12 @@ def _check_multi_label_integrity(
)
def _is_dataset_multi_label(
- self, dataset: TextClassificationDatasetDB
+ self, dataset: ServiceTextClassificationDataset
) -> Optional[bool]:
try:
results = self.__search__.search(
dataset,
- record_type=TextClassificationRecord,
+ record_type=ServiceTextClassificationRecord,
size=1,
)
except MissingDatasetRecordsError: # No records index yet
@@ -217,25 +205,13 @@ def _is_dataset_multi_label(
return results.records[0].multi_label
def get_labeling_rules(
- self, dataset: TextClassificationDatasetDB
- ) -> Iterable[LabelingRule]:
- """
- Gets rules for a given dataset
+ self, dataset: ServiceTextClassificationDataset
+ ) -> Iterable[ServiceLabelingRule]:
- Parameters
- ----------
- dataset:
- The dataset
-
- Returns
- -------
- A list of labeling rules for a given dataset
-
- """
return self.__labeling__.list_rules(dataset)
def add_labeling_rule(
- self, dataset: TextClassificationDatasetDB, rule: LabelingRule
+ self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
) -> None:
"""
Adds a labeling rule
@@ -253,24 +229,11 @@ def add_labeling_rule(
def update_labeling_rule(
self,
- dataset: TextClassificationDatasetDB,
+ dataset: ServiceTextClassificationDataset,
rule_query: str,
labels: List[str],
description: Optional[str] = None,
- ) -> LabelingRule:
- """
- Update a labeling rule. Updatable fields are label and/or description
-
- Args:
- dataset: The dataset
- rule_query: The labeling rule
- label: The new rule label
- description: If provided, the new rule description
-
- Returns:
- Updated labeling rule
-
- """
+ ) -> ServiceLabelingRule:
found_rule = self.__labeling__.find_rule_by_query(dataset, rule_query)
found_rule.labels = labels
@@ -282,44 +245,20 @@ def update_labeling_rule(
self.__labeling__.replace_rule(dataset, found_rule)
return found_rule
- def find_labeling_rule(self, dataset: TextClassificationDatasetDB, rule_query: str):
- """
- Find a labeling rule given a rule query string
-
- Args:
- dataset: The dataset
- rule_query: The query string
-
- Returns:
- Found labeling rule.
- If rule was not found EntityNotFoundError is raised
- """
+ def find_labeling_rule(
+ self, dataset: ServiceTextClassificationDataset, rule_query: str
+ ) -> ServiceLabelingRule:
return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query)
def delete_labeling_rule(
- self, dataset: TextClassificationDatasetDB, rule_query: str
+ self, dataset: ServiceTextClassificationDataset, rule_query: str
):
- """
- Deletes a rule from a dataset.
-
- Nothing happens if the rule does not exist in dataset.
-
- Parameters
- ----------
-
- dataset:
- The dataset
-
- rule_query:
- The rule query
-
- """
if rule_query.strip():
return self.__labeling__.delete_rule(dataset, rule_query)
def compute_rule_metrics(
self,
- dataset: TextClassificationDatasetDB,
+ dataset: ServiceTextClassificationDataset,
rule_query: str,
labels: Optional[List[str]] = None,
) -> LabelingRuleMetricsSummary:
@@ -366,6 +305,7 @@ def compute_rule_metrics(
coverage_annotated = (
metrics.annotated_covered_records / annotated if annotated > 0 else None
)
+
return LabelingRuleMetricsSummary(
total_records=total,
annotated_records=annotated,
@@ -376,7 +316,7 @@ def compute_rule_metrics(
precision=metrics.precision if annotated > 0 else None,
)
- def compute_overall_rules_metrics(self, dataset: TextClassificationDatasetDB):
+ def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset):
total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset)
coverage = metrics.covered_records / total if total else None
coverage_annotated = (
@@ -390,7 +330,7 @@ def compute_overall_rules_metrics(self, dataset: TextClassificationDatasetDB):
)
@staticmethod
- def __normalized_rule__(rule: LabelingRule) -> LabelingRule:
+ def __normalized_rule__(rule: ServiceLabelingRule) -> ServiceLabelingRule:
if rule.labels and len(rule.labels) == 1:
rule.label = rule.labels[0]
elif rule.label and not rule.labels:
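
The coverage figures returned by compute_rule_metrics and compute_overall_rules_metrics above reduce to simple ratios over record counts. A standalone sketch with made-up counts (illustrative only, not real data):

    total_records = 1_000            # records in the dataset
    annotated_records = 200          # records that have an annotation
    covered_records = 150            # records matched by the rule(s)
    annotated_covered_records = 60   # matched records that are also annotated

    coverage = covered_records / total_records if total_records else None  # 0.15
    coverage_annotated = (
        annotated_covered_records / annotated_records if annotated_records else None  # 0.30
    )
    print(coverage, coverage_annotated)
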
diff --git a/src/rubrix/server/services/tasks/token_classification/__init__.py b/src/rubrix/server/services/tasks/token_classification/__init__.py
new file mode 100644
index 0000000000..db4b4a5510
--- /dev/null
+++ b/src/rubrix/server/services/tasks/token_classification/__init__.py
@@ -0,0 +1 @@
+from .service import TokenClassificationService
diff --git a/src/rubrix/server/services/tasks/token_classification/metrics.py b/src/rubrix/server/services/tasks/token_classification/metrics.py
new file mode 100644
index 0000000000..f574cde4c5
--- /dev/null
+++ b/src/rubrix/server/services/tasks/token_classification/metrics.py
@@ -0,0 +1,407 @@
+from typing import Any, ClassVar, Dict, Iterable, List, Optional, Set, Tuple
+
+from pydantic import BaseModel, Field
+
+from rubrix.server.services.metrics import ServiceBaseMetric, ServicePythonMetric
+from rubrix.server.services.metrics.models import CommonTasksMetrics
+from rubrix.server.services.tasks.token_classification.model import (
+ EntitySpan,
+ ServiceTokenClassificationRecord,
+)
+
+
+class F1Metric(ServicePythonMetric[ServiceTokenClassificationRecord]):
+ """The F1 metric based on entity-level.
+
+ We follow the convention of CoNLL 2003, where:
+ "precision is the percentage of named entities found by the learning system that are correct.
+ Recall is the percentage of named entities present in the corpus that are found by the system.
+ A named entity is correct only if it is an exact match (...)."
+ """
+
+ def apply(
+ self, records: Iterable[ServiceTokenClassificationRecord]
+ ) -> Dict[str, Any]:
+ # store entities per label in dicts
+ predicted_entities = {}
+ annotated_entities = {}
+
+ # extract entities per label to dicts
+ for rec in records:
+ if rec.prediction:
+ self._add_entities_to_dict(rec.prediction.entities, predicted_entities)
+ if rec.annotation:
+ self._add_entities_to_dict(rec.annotation.entities, annotated_entities)
+
+ # store precision, recall, and f1 per label
+ per_label_metrics = {}
+
+ annotated_total, predicted_total, correct_total = 0, 0, 0
+ precision_macro, recall_macro = 0, 0
+ for label, annotated in annotated_entities.items():
+ predicted = predicted_entities.get(label, set())
+ correct = len(annotated & predicted)
+
+ # safe divides are used to cover the 0/0 cases
+ precision = self._safe_divide(correct, len(predicted))
+ recall = self._safe_divide(correct, len(annotated))
+ per_label_metrics.update(
+ {
+ f"{label}_precision": precision,
+ f"{label}_recall": recall,
+ f"{label}_f1": self._safe_divide(
+ 2 * precision * recall, precision + recall
+ ),
+ }
+ )
+
+ annotated_total += len(annotated)
+ predicted_total += len(predicted)
+ correct_total += correct
+
+ precision_macro += precision / len(annotated_entities)
+ recall_macro += recall / len(annotated_entities)
+
+ # store macro and micro averaged precision, recall and f1
+ averaged_metrics = {
+ "precision_macro": precision_macro,
+ "recall_macro": recall_macro,
+ "f1_macro": self._safe_divide(
+ 2 * precision_macro * recall_macro, precision_macro + recall_macro
+ ),
+ }
+
+ precision_micro = self._safe_divide(correct_total, predicted_total)
+ recall_micro = self._safe_divide(correct_total, annotated_total)
+ averaged_metrics.update(
+ {
+ "precision_micro": precision_micro,
+ "recall_micro": recall_micro,
+ "f1_micro": self._safe_divide(
+ 2 * precision_micro * recall_micro, precision_micro + recall_micro
+ ),
+ }
+ )
+
+ return {**averaged_metrics, **per_label_metrics}
+
+ @staticmethod
+ def _add_entities_to_dict(
+ entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]]
+ ):
+ """Helper function for the apply method."""
+ for ent in entities:
+ try:
+ dictionary[ent.label].add((ent.start, ent.end))
+ except KeyError:
+ dictionary[ent.label] = {(ent.start, ent.end)}
+
+ @staticmethod
+ def _safe_divide(numerator, denominator):
+ """Helper function for the apply method."""
+ try:
+ return numerator / denominator
+ except ZeroDivisionError:
+ return 0
+
+
+class DatasetLabels(ServicePythonMetric):
+ id: str = Field("dataset_labels", const=True)
+ name: str = Field("The dataset entity labels", const=True)
+ max_processed_records: int = 10000
+
+ def apply(
+ self, records: Iterable[ServiceTokenClassificationRecord]
+ ) -> Dict[str, Any]:
+ ds_labels = set()
+
+ for _ in range(
+ 0, self.max_processed_records
+ ): # only the first max_processed_records records are processed
+ record: ServiceTokenClassificationRecord = next(records, None)
+ if record is None:
+ break
+
+ if record.annotation:
+ ds_labels.update(
+ [entity.label for entity in record.annotation.entities]
+ )
+ if record.prediction:
+ ds_labels.update(
+ [entity.label for entity in record.prediction.entities]
+ )
+ return {"labels": ds_labels or []}
+
+
+class MentionMetrics(BaseModel):
+ """Mention metrics model"""
+
+ value: str
+ label: str
+ score: float = Field(ge=0.0)
+ capitalness: Optional[str] = Field(None)
+ density: float = Field(ge=0.0)
+ tokens_length: int = Field(gt=0)
+ chars_length: int = Field(gt=0)
+
+
+class TokenTagMetrics(BaseModel):
+ value: str
+ tag: str
+
+
+class TokenMetrics(BaseModel):
+ """
+ Token metrics stored in elasticsearch for token classification
+
+ Attributes
+ idx: The token index in sentence
+ value: The token textual value
+ char_start: The token character start position in sentence
+ char_end: The token character end position in sentence
+ score: Token score info
+ tag: Token tag info. Deprecated: Use metrics.predicted.tags or metrics.annotated.tags instead
+ custom: extra token level info
+ """
+
+ idx: int
+ value: str
+ char_start: int
+ char_end: int
+ length: int
+ capitalness: Optional[str] = None
+ score: Optional[float] = None
+ tag: Optional[str] = None # TODO: remove!
+ custom: Dict[str, Any] = None
+
+
+class TokenClassificationMetrics(CommonTasksMetrics[ServiceTokenClassificationRecord]):
+ """Configured metrics for token classification"""
+
+ @staticmethod
+ def density(value: int, sentence_length: int) -> float:
+ """Compute the string density over a sentence"""
+ return value / sentence_length
+
+ @staticmethod
+ def capitalness(value: str) -> Optional[str]:
+ """Compute capitalness for a string value"""
+ value = value.strip()
+ if not value:
+ return None
+ if value.isupper():
+ return "UPPER"
+ if value.islower():
+ return "LOWER"
+ if value[0].isupper():
+ return "FIRST"
+ if any([c.isupper() for c in value[1:]]):
+ return "MIDDLE"
+ return None
+
+ @staticmethod
+ def mentions_metrics(
+ record: ServiceTokenClassificationRecord, mentions: List[Tuple[str, EntitySpan]]
+ ):
+ def mention_tokens_length(entity: EntitySpan) -> int:
+ """Compute mention tokens length"""
+ return len(
+ set(
+ [
+ token_idx
+ for i in range(entity.start, entity.end)
+ for token_idx in [record.char_id2token_id(i)]
+ if token_idx is not None
+ ]
+ )
+ )
+
+ return [
+ MentionMetrics(
+ value=mention,
+ label=entity.label,
+ score=entity.score,
+ capitalness=TokenClassificationMetrics.capitalness(mention),
+ density=TokenClassificationMetrics.density(
+ _tokens_length, sentence_length=len(record.tokens)
+ ),
+ tokens_length=_tokens_length,
+ chars_length=len(mention),
+ )
+ for mention, entity in mentions
+ for _tokens_length in [
+ mention_tokens_length(entity),
+ ]
+ ]
+
+ @classmethod
+ def build_tokens_metrics(
+ cls, record: ServiceTokenClassificationRecord, tags: Optional[List[str]] = None
+ ) -> List[TokenMetrics]:
+
+ return [
+ TokenMetrics(
+ idx=token_idx,
+ value=token_value,
+ char_start=char_start,
+ char_end=char_end,
+ capitalness=cls.capitalness(token_value),
+ length=1 + (char_end - char_start),
+ tag=tags[token_idx] if tags else None,
+ )
+ for token_idx, token_value in enumerate(record.tokens)
+ for char_start, char_end in [record.token_span(token_idx)]
+ ]
+
+ @classmethod
+ def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, Any]:
+ """Compute metrics at record level"""
+ base_metrics = super(TokenClassificationMetrics, cls).record_metrics(record)
+
+ annotated_tags = record.annotated_iob_tags() or []
+ predicted_tags = record.predicted_iob_tags() or []
+
+ tokens_metrics = cls.build_tokens_metrics(
+ record, predicted_tags or annotated_tags
+ )
+ return {
+ **base_metrics,
+ "tokens": tokens_metrics,
+ "tokens_length": len(record.tokens),
+ "predicted": {
+ "mentions": cls.mentions_metrics(record, record.predicted_mentions()),
+ "tags": [
+ TokenTagMetrics(tag=tag, value=token)
+ for tag, token in zip(predicted_tags, record.tokens)
+ ],
+ },
+ "annotated": {
+ "mentions": cls.mentions_metrics(record, record.annotated_mentions()),
+ "tags": [
+ TokenTagMetrics(tag=tag, value=token)
+ for tag, token in zip(annotated_tags, record.tokens)
+ ],
+ },
+ }
+
+ metrics: ClassVar[List[ServiceBaseMetric]] = (
+ CommonTasksMetrics.metrics
+ + [
+ DatasetLabels(),
+ F1Metric(
+ id="F1",
+ name="F1 ServiceBaseMetric based on entity-level",
+ description="F1 metrics based on entity-level (averaged and per label), "
+ "where only exact matches count (CoNNL 2003).",
+ ),
+ ]
+ + [
+ ServiceBaseMetric(
+ id="predicted_as",
+ name="Predicted labels distribution",
+ ),
+ ServiceBaseMetric(
+ id="annotated_as",
+ name="Annotated labels distribution",
+ ),
+ ServiceBaseMetric(
+ id="tokens_length",
+ name="Tokens length",
+ description="Computes the text length distribution measured in number of tokens",
+ ),
+ ServiceBaseMetric(
+ id="token_frequency",
+ name="Tokens frequency distribution",
+ ),
+ ServiceBaseMetric(
+ id="token_length",
+ name="Token length distribution",
+ description="Computes token length distribution in number of characters",
+ ),
+ ServiceBaseMetric(
+ id="token_capitalness",
+ name="Token capitalness distribution",
+ description="Computes capitalization information of tokens",
+ ),
+ ServiceBaseMetric(
+ id="predicted_entity_density",
+ name="Mention entity density for predictions",
+ description="Computes the ratio between the number of all entity tokens and tokens in the text",
+ ),
+ ServiceBaseMetric(
+ id="predicted_entity_labels",
+ name="Predicted entity labels",
+ description="Predicted entity labels distribution",
+ ),
+ ServiceBaseMetric(
+ id="predicted_entity_capitalness",
+ name="Mention entity capitalness for predictions",
+ description="Computes capitalization information of predicted entity mentions",
+ ),
+ ServiceBaseMetric(
+ id="predicted_mention_token_length",
+ name="Predicted mention tokens length",
+ description="Computes the length of the predicted entity mention measured in number of tokens",
+ ),
+ ServiceBaseMetric(
+ id="predicted_mention_char_length",
+ name="Predicted mention characters length",
+ description="Computes the length of the predicted entity mention measured in number of tokens",
+ ),
+ ServiceBaseMetric(
+ id="predicted_mentions_distribution",
+ name="Predicted mentions distribution by entity",
+ description="Computes predicted mentions distribution against its labels",
+ ),
+ ServiceBaseMetric(
+ id="predicted_entity_consistency",
+ name="Entity label consistency for predictions",
+ description="Computes entity label variability for top-k predicted entity mentions",
+ ),
+ ServiceBaseMetric(
+ id="predicted_tag_consistency",
+ name="Token tag consistency for predictions",
+ description="Computes token tag variability for top-k predicted tags",
+ ),
+ ServiceBaseMetric(
+ id="annotated_entity_density",
+ name="Mention entity density for annotations",
+ description="Computes the ratio between the number of all entity tokens and tokens in the text",
+ ),
+ ServiceBaseMetric(
+ id="annotated_entity_labels",
+ name="Annotated entity labels",
+ description="Annotated Entity labels distribution",
+ ),
+ ServiceBaseMetric(
+ id="annotated_entity_capitalness",
+ name="Mention entity capitalness for annotations",
+ description="Compute capitalization information of annotated entity mentions",
+ ),
+ ServiceBaseMetric(
+ id="annotated_mention_token_length",
+ name="Annotated mention tokens length",
+ description="Computes the length of the entity mention measured in number of tokens",
+ ),
+ ServiceBaseMetric(
+ id="annotated_mention_char_length",
+ name="Annotated mention characters length",
+ description="Computes the length of the entity mention measured in number of tokens",
+ ),
+ ServiceBaseMetric(
+ id="annotated_mentions_distribution",
+ name="Annotated mentions distribution by entity",
+ description="Computes annotated mentions distribution against its labels",
+ ),
+ ServiceBaseMetric(
+ id="annotated_entity_consistency",
+ name="Entity label consistency for annotations",
+ description="Computes entity label variability for top-k annotated entity mentions",
+ ),
+ ServiceBaseMetric(
+ id="annotated_tag_consistency",
+ name="Token tag consistency for annotations",
+ description="Computes token tag variability for top-k annotated tags",
+ ),
+ ]
+ )
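
As a sanity check on the entity-level convention used by F1Metric above, here is the micro-averaged part of the computation on two tiny, invented span sets (keys are labels, values are sets of (start, end) character spans, as built by _add_entities_to_dict):

    # Micro-averaged entity-level precision/recall/F1 over made-up spans
    annotated = {"PER": {(0, 5), (10, 15)}, "ORG": {(20, 28)}}
    predicted = {"PER": {(0, 5)}, "ORG": {(20, 28), (30, 35)}}

    def safe_divide(numerator: float, denominator: float) -> float:
        return numerator / denominator if denominator else 0.0

    correct_total = sum(len(annotated[label] & predicted.get(label, set())) for label in annotated)
    annotated_total = sum(len(spans) for spans in annotated.values())
    predicted_total = sum(len(spans) for spans in predicted.values())

    precision_micro = safe_divide(correct_total, predicted_total)  # 2/3: one wrong ORG prediction
    recall_micro = safe_divide(correct_total, annotated_total)     # 2/3: one missed PER entity
    f1_micro = safe_divide(2 * precision_micro * recall_micro, precision_micro + recall_micro)
    print(precision_micro, recall_micro, f1_micro)
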
diff --git a/src/rubrix/server/services/tasks/token_classification/model.py b/src/rubrix/server/services/tasks/token_classification/model.py
new file mode 100644
index 0000000000..2f307f1348
--- /dev/null
+++ b/src/rubrix/server/services/tasks/token_classification/model.py
@@ -0,0 +1,327 @@
+# coding=utf-8
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import defaultdict
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from pydantic import BaseModel, Field, validator
+
+from rubrix._constants import MAX_KEYWORD_LENGTH
+from rubrix.server.commons.models import PredictionStatus, TaskType
+from rubrix.server.services.datasets import ServiceBaseDataset
+from rubrix.server.services.search.model import (
+ ServiceBaseRecordsQuery,
+ ServiceScoreRange,
+)
+from rubrix.server.services.tasks.commons import (
+ ServiceBaseAnnotation,
+ ServiceBaseRecord,
+)
+
+PREDICTED_MENTIONS_ES_FIELD_NAME = "predicted_mentions"
+MENTIONS_ES_FIELD_NAME = "mentions"
+
+
+class EntitySpan(BaseModel):
+ """
+ The tokens span for a labeled text.
+
+ Entity spans cover the characters from start (inclusive) to end (exclusive), i.e. from start to end - 1
+
+ Attributes:
+ -----------
+
+ start: int
+ character start position
+ end: int
+ character end position; must be greater than the start position.
+ label: str
+ the label related to the tokens that form the entity span
+ score:
+ A higher score means the model/annotator is more confident about its predicted/annotated entity.
+ """
+
+ start: int
+ end: int
+ label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
+ score: float = Field(default=1.0, ge=0.0, le=1.0)
+
+ @validator("end")
+ def check_span_offset(cls, end: int, values):
+ """Validates span offset"""
+ assert (
+ end > values["start"]
+ ), "End character cannot be placed before the starting character, it must be at least one character after."
+ return end
+
+ def __hash__(self):
+ return hash(type(self)) + hash(tuple(self.__dict__.values()))  # tuple() makes the hash value-based and stable
+
+
+class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation):
+ entities: List[EntitySpan] = Field(default_factory=list)
+ score: Optional[float] = None
+
+
+class ServiceTokenClassificationRecord(
+ ServiceBaseRecord[ServiceTokenClassificationAnnotation]
+):
+
+ tokens: List[str] = Field(min_items=1)
+ text: str = Field()
+ _raw_text: Optional[str] = Field(alias="raw_text")
+
+ __chars2tokens__: Dict[int, int] = None
+ __tokens2chars__: Dict[int, Tuple[int, int]] = None
+
+ # TODO: review this.
+ _predicted: Optional[PredictionStatus] = Field(alias="predicted")
+
+ def extended_fields(self) -> Dict[str, Any]:
+
+ return {
+ **super().extended_fields(),
+ # See ../service/service.py
+ PREDICTED_MENTIONS_ES_FIELD_NAME: [
+ {"mention": mention, "entity": entity.label, "score": entity.score}
+ for mention, entity in self.predicted_mentions()
+ ],
+ MENTIONS_ES_FIELD_NAME: [
+ {"mention": mention, "entity": entity.label}
+ for mention, entity in self.annotated_mentions()
+ ],
+ "words": self.all_text(),
+ }
+
+ def __init__(self, **data):
+ super().__init__(**data)
+
+ self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__()
+
+ self.check_annotation(self.prediction)
+ self.check_annotation(self.annotation)
+
+ def char_id2token_id(self, char_idx: int) -> Optional[int]:
+ return self.__chars2tokens__.get(char_idx)
+
+ def token_span(self, token_idx: int) -> Tuple[int, int]:
+ if token_idx not in self.__tokens2chars__:
+ raise IndexError(f"Token id {token_idx} out of bounds")
+ return self.__tokens2chars__[token_idx]
+
+ def __build_indices_map__(
+ self,
+ ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]:
+ """
+ Build the index maps between text characters and the tokens they belong to,
+ and vice versa.
+
+ The chars2tokens index maps each character position i to the idx of the token that contains it (if any).
+
+ Out-of-token characters won't be included in this map,
+ so access should use ``chars2tokens_map.get(i)``
+ instead of ``chars2tokens_map[i]``.
+
+ """
+
+ def chars2tokens_index():
+ def is_space_after_token(char, idx: int, chars_map) -> bool:
+ return char == " " and idx - 1 in chars_map
+
+ chars_map = {}
+ current_token = 0
+ current_token_char_start = 0
+ for idx, char in enumerate(self.text):
+ if is_space_after_token(char, idx, chars_map):
+ continue
+ relative_idx = idx - current_token_char_start
+ if (
+ relative_idx < len(self.tokens[current_token])
+ and char == self.tokens[current_token][relative_idx]
+ ):
+ chars_map[idx] = current_token
+ elif (
+ current_token + 1 < len(self.tokens)
+ and relative_idx >= len(self.tokens[current_token])
+ and char == self.tokens[current_token + 1][0]
+ ):
+ current_token += 1
+ current_token_char_start += relative_idx
+ chars_map[idx] = current_token
+
+ return chars_map
+
+ def tokens2chars_index(
+ chars2tokens: Dict[int, int]
+ ) -> Dict[int, Tuple[int, int]]:
+ tokens2chars_map = defaultdict(list)
+ for c, t in chars2tokens.items():
+ tokens2chars_map[t].append(c)
+
+ return {
+ token_idx: (min(chars), max(chars))
+ for token_idx, chars in tokens2chars_map.items()
+ }
+
+ chars2tokens_idx = chars2tokens_index()
+ return chars2tokens_idx, tokens2chars_index(chars2tokens_idx)
+
+ def check_annotation(
+ self,
+ annotation: Optional[ServiceTokenClassificationAnnotation],
+ ):
+ """Validates entities in terms of offset spans"""
+
+ def adjust_span_bounds(start, end):
+ if start < 0:
+ start = 0
+ if end > len(self.text):
+ end = len(self.text)
+ while start < len(self.text) and not self.text[start].strip():
+ start += 1
+ while not self.text[end - 1].strip():
+ end -= 1
+ return start, end
+
+ if annotation:
+ for entity in annotation.entities:
+ entity.start, entity.end = adjust_span_bounds(entity.start, entity.end)
+ mention = self.text[entity.start : entity.end]
+ assert len(mention) > 0, f"Empty offset defined for entity {entity}"
+
+ token_start = self.char_id2token_id(entity.start)
+ token_end = self.char_id2token_id(entity.end - 1)
+
+ assert not (
+ token_start is None or token_end is None
+ ), f"Provided entity span {self.text[entity.start: entity.end]} is not aligned with provided tokens."
+ "Some entity chars could be reference characters out of tokens"
+
+ span_start, _ = self.token_span(token_start)
+ _, span_end = self.token_span(token_end)
+
+ assert (
+ self.text[span_start : span_end + 1] == mention
+ ), f"Defined offset [{self.text[entity.start: entity.end]}] is a misaligned entity mention"
+
+ @classmethod
+ def task(cls) -> TaskType:
+ """The record task type"""
+ return TaskType.token_classification
+
+ @property
+ def predicted(self) -> Optional[PredictionStatus]:
+ if self.annotation and self.prediction:
+ return (
+ PredictionStatus.OK
+ if self.annotation.entities == self.prediction.entities
+ else PredictionStatus.KO
+ )
+ return None
+
+ @property
+ def predicted_as(self) -> List[str]:
+ return [ent.label for ent in self.predicted_entities()]
+
+ @property
+ def annotated_as(self) -> List[str]:
+ return [ent.label for ent in self.annotated_entities()]
+
+ @property
+ def scores(self) -> List[float]:
+ if not self.prediction:
+ return []
+ if self.prediction.score is not None:
+ return [self.prediction.score]
+ return [e.score for e in self.prediction.entities]
+
+ def all_text(self) -> str:
+ return self.text
+
+ def predicted_iob_tags(self) -> Optional[List[str]]:
+ if self.prediction is None:
+ return None
+ return self.spans2iob(self.prediction.entities)
+
+ def annotated_iob_tags(self) -> Optional[List[str]]:
+ if self.annotation is None:
+ return None
+ return self.spans2iob(self.annotation.entities)
+
+ def spans2iob(self, spans: List[EntitySpan]) -> Optional[List[str]]:
+ if spans is None:
+ return None
+ tags = ["O"] * len(self.tokens)
+ for entity in spans:
+ token_start = self.char_id2token_id(entity.start)
+ token_end = self.char_id2token_id(entity.end - 1)
+ tags[token_start] = f"B-{entity.label}"
+ for idx in range(token_start + 1, token_end + 1):
+ tags[idx] = f"I-{entity.label}"
+
+ return tags
+
+ def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]:
+ return [
+ (mention, entity)
+ for mention, entity in self.__mentions_from_entities__(
+ self.predicted_entities()
+ ).items()
+ ]
+
+ def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]:
+ return [
+ (mention, entity)
+ for mention, entity in self.__mentions_from_entities__(
+ self.annotated_entities()
+ ).items()
+ ]
+
+ def annotated_entities(self) -> Set[EntitySpan]:
+ """Shortcut for real annotated entities, if provided"""
+ if self.annotation is None:
+ return set()
+ return set(self.annotation.entities)
+
+ def predicted_entities(self) -> Set[EntitySpan]:
+ """Predicted entities"""
+ if self.prediction is None:
+ return set()
+ return set(self.prediction.entities)
+
+ def __mentions_from_entities__(
+ self, entities: Set[EntitySpan]
+ ) -> Dict[str, EntitySpan]:
+ return {
+ mention: entity
+ for entity in entities
+ for mention in [self.text[entity.start : entity.end]]
+ }
+
+ class Config:
+ allow_population_by_field_name = True
+ underscore_attrs_are_private = True
+
+
+class ServiceTokenClassificationQuery(ServiceBaseRecordsQuery):
+
+ predicted_as: List[str] = Field(default_factory=list)
+ annotated_as: List[str] = Field(default_factory=list)
+ score: Optional[ServiceScoreRange] = Field(default=None)
+ predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
+
+
+class ServiceTokenClassificationDataset(ServiceBaseDataset):
+ task: TaskType = Field(default=TaskType.token_classification, const=True)
+ pass
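
The character/token index maps and spans2iob above are easier to follow on a toy sentence. A simplified standalone sketch (whitespace tokens located with str.index, not the record class itself):

    # Toy chars2tokens index plus IOB tagging for one entity span
    text = "Paris is nice"
    tokens = ["Paris", "is", "nice"]

    chars2tokens = {}
    cursor = 0
    for token_idx, token in enumerate(tokens):
        start = text.index(token, cursor)
        for char_idx in range(start, start + len(token)):
            chars2tokens[char_idx] = token_idx
        cursor = start + len(token)

    entity = {"start": 0, "end": 5, "label": "LOC"}  # characters [0, 5) -> "Paris"
    token_start = chars2tokens[entity["start"]]
    token_end = chars2tokens[entity["end"] - 1]

    tags = ["O"] * len(tokens)
    tags[token_start] = f"B-{entity['label']}"
    for idx in range(token_start + 1, token_end + 1):
        tags[idx] = f"I-{entity['label']}"

    print(tags)  # ['B-LOC', 'O', 'O']
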
diff --git a/src/rubrix/server/services/token_classification.py b/src/rubrix/server/services/tasks/token_classification/service.py
similarity index 53%
rename from src/rubrix/server/services/token_classification.py
rename to src/rubrix/server/services/tasks/token_classification/service.py
index 9c40e852a0..2bdd63afdf 100644
--- a/src/rubrix/server/services/token_classification.py
+++ b/src/rubrix/server/services/tasks/token_classification/service.py
@@ -13,28 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterable, List, Optional, Type
+from typing import Iterable, List, Optional, Type
from fastapi import Depends
-from rubrix.server.apis.v0.models.commons.model import (
- BulkResponse,
- EsRecordDataFieldNames,
- SortableField,
-)
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric, BaseTaskMetrics
-from rubrix.server.apis.v0.models.token_classification import (
- CreationTokenClassificationRecord,
- TokenClassificationAggregations,
- TokenClassificationDatasetDB,
- TokenClassificationQuery,
- TokenClassificationRecord,
- TokenClassificationRecordDB,
- TokenClassificationSearchResults,
-)
-from rubrix.server.services.search.model import SortConfig
+from rubrix.server.commons.config import TasksFactory
+from rubrix.server.daos.backend.search.model import SortableField
+from rubrix.server.services.metrics.models import ServiceBaseTaskMetrics
+from rubrix.server.services.search.model import ServiceSearchResults, ServiceSortConfig
from rubrix.server.services.search.service import SearchRecordsService
from rubrix.server.services.storage.service import RecordsStorageService
+from rubrix.server.services.tasks.commons import BulkResponse
+from rubrix.server.services.tasks.token_classification.model import (
+ ServiceTokenClassificationDataset,
+ ServiceTokenClassificationQuery,
+ ServiceTokenClassificationRecord,
+)
class TokenClassificationService:
@@ -65,74 +59,73 @@ def __init__(
def add_records(
self,
- dataset: TokenClassificationDatasetDB,
- mappings: Dict[str, Any],
- records: List[CreationTokenClassificationRecord],
- metrics: Type[BaseTaskMetrics],
+ dataset: ServiceTokenClassificationDataset,
+ records: List[ServiceTokenClassificationRecord],
):
failed = self.__storage__.store_records(
dataset=dataset,
- mappings=mappings,
records=records,
- record_type=TokenClassificationRecordDB,
- metrics=metrics,
+ record_type=ServiceTokenClassificationRecord,
)
return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed)
def search(
self,
- dataset: TokenClassificationDatasetDB,
- query: TokenClassificationQuery,
+ dataset: ServiceTokenClassificationDataset,
+ query: ServiceTokenClassificationQuery,
sort_by: List[SortableField],
record_from: int = 0,
size: int = 100,
exclude_metrics: bool = True,
- metrics: Optional[List[BaseMetric]] = None,
- ) -> TokenClassificationSearchResults:
- """
- Run a search in a dataset
-
- Parameters
- ----------
- dataset:
- The records dataset
- query:
- The search parameters
- sort_by:
- The sort by order list
- record_from:
- The record from return results
- size:
- The max number of records to return
-
- Returns
- -------
- The matched records with aggregation info for specified task_meta.py
+ ) -> ServiceSearchResults:
"""
+ Run a search in a dataset
+
+ Parameters
+ ----------
+ dataset:
+ The records dataset
+ query:
+ The search parameters
+ sort_by:
+ The sort by order list
+ record_from:
+ The record from return results
+ size:
+ The max number of records to return
+
+ Returns
+ -------
+ The matched records with aggregation info for specified task_meta.py
+
+ """
+ metrics = TasksFactory.find_task_metrics(
+ dataset.task,
+ metric_ids={
+ "words_cloud",
+ "predicted_by",
+ "predicted_as",
+ "annotated_by",
+ "annotated_as",
+ "error_distribution",
+ "predicted_mentions_distribution",
+ "annotated_mentions_distribution",
+ "status_distribution",
+ "metadata",
+ "score",
+ },
+ )
+
results = self.__search__.search(
dataset,
query=query,
- record_type=TokenClassificationRecord,
+ record_type=ServiceTokenClassificationRecord,
size=size,
+ metrics=metrics,
record_from=record_from,
exclude_metrics=exclude_metrics,
- metrics=metrics,
- sort_config=SortConfig(
- sort_by=sort_by,
- valid_fields=[
- "metadata",
- EsRecordDataFieldNames.last_updated,
- EsRecordDataFieldNames.score,
- EsRecordDataFieldNames.predicted,
- EsRecordDataFieldNames.predicted_as,
- EsRecordDataFieldNames.predicted_by,
- EsRecordDataFieldNames.annotated_as,
- EsRecordDataFieldNames.annotated_by,
- EsRecordDataFieldNames.status,
- EsRecordDataFieldNames.event_timestamp,
- ],
- ),
+ sort_config=ServiceSortConfig(sort_by=sort_by),
)
if results.metrics:
@@ -147,21 +140,15 @@ def search(
"predicted_mentions_distribution"
]
- return TokenClassificationSearchResults(
- total=results.total,
- records=results.records,
- aggregations=TokenClassificationAggregations.parse_obj(results.metrics)
- if results.metrics
- else None,
- )
+ return results
def read_dataset(
self,
- dataset: TokenClassificationDatasetDB,
- query: TokenClassificationQuery,
+ dataset: ServiceTokenClassificationDataset,
+ query: ServiceTokenClassificationQuery,
id_from: Optional[str] = None,
limit: int = 1000
- ) -> Iterable[TokenClassificationRecord]:
+ ) -> Iterable[ServiceTokenClassificationRecord]:
"""
Scan a dataset records
@@ -181,8 +168,5 @@ def read_dataset(
"""
yield from self.__search__.scan_records(
- dataset, query=query, record_type=TokenClassificationRecord, id_from=id_from, limit=limit
+ dataset, query=query, record_type=ServiceTokenClassificationRecord, id_from=id_from, limit=limit
)
-
-
-token_classification_service = TokenClassificationService.get_instance
diff --git a/src/rubrix/server/services/text_classification_labelling_rules.py b/src/rubrix/server/services/text_classification_labelling_rules.py
deleted file mode 100644
index 1418def5bb..0000000000
--- a/src/rubrix/server/services/text_classification_labelling_rules.py
+++ /dev/null
@@ -1,271 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-from fastapi import Depends
-from pydantic import BaseModel, Field
-
-from rubrix.server._helpers import unflatten_dict
-from rubrix.server.apis.v0.models.commons.model import EsRecordDataFieldNames
-from rubrix.server.apis.v0.models.metrics.base import ElasticsearchMetric
-from rubrix.server.apis.v0.models.text_classification import (
- LabelingRule,
- TextClassificationDatasetDB,
-)
-from rubrix.server.daos.datasets import DatasetsDAO
-from rubrix.server.daos.models.records import RecordSearch
-from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.query_helpers import filters
-from rubrix.server.errors import EntityAlreadyExistsError, EntityNotFoundError
-
-
-class DatasetLabelingRulesMetric(ElasticsearchMetric):
- id: str = Field("dataset_labeling_rules", const=True)
- name: str = Field(
- "Computes overall metrics for defined rules in dataset", const=True
- )
-
- def aggregation_request(self, all_rules: List[LabelingRule]) -> Dict[str, Any]:
- rules_filters = [filters.text_query(rule.query) for rule in all_rules]
- return {
- self.id: {
- "filters": {
- "filters": {
- "covered_records": filters.boolean_filter(
- should_filters=rules_filters, minimum_should_match=1
- ),
- "annotated_covered_records": filters.boolean_filter(
- filter_query=filters.exists_field(
- EsRecordDataFieldNames.annotated_as
- ),
- should_filters=rules_filters,
- minimum_should_match=1,
- ),
- }
- }
- }
- }
-
-
-class LabelingRulesMetric(ElasticsearchMetric):
- id: str = Field("labeling_rule", const=True)
- name: str = Field("Computes metrics for a labeling rule", const=True)
-
- def aggregation_request(
- self,
- rule_query: str,
- labels: Optional[List[str]],
- ) -> Dict[str, Any]:
-
- annotated_records_filter = filters.exists_field(
- EsRecordDataFieldNames.annotated_as
- )
- rule_query_filter = filters.text_query(rule_query)
- aggr_filters = {
- "covered_records": rule_query_filter,
- "annotated_covered_records": filters.boolean_filter(
- filter_query=annotated_records_filter,
- should_filters=[rule_query_filter],
- ),
- }
-
- if labels is not None:
- for label in labels:
- rule_label_annotated_filter = filters.term_filter(
- "annotated_as", value=label
- )
- encoded_label = self._encode_label_name(label)
- aggr_filters.update(
- {
- f"{encoded_label}.correct_records": filters.boolean_filter(
- filter_query=annotated_records_filter,
- should_filters=[
- rule_query_filter,
- rule_label_annotated_filter,
- ],
- minimum_should_match=2,
- ),
- f"{encoded_label}.incorrect_records": filters.boolean_filter(
- filter_query=annotated_records_filter,
- must_query=rule_query_filter,
- must_not_query=rule_label_annotated_filter,
- ),
- }
- )
-
- return {self.id: {"filters": {"filters": aggr_filters}}}
-
- @staticmethod
- def _encode_label_name(label: str) -> str:
- return label.replace(".", "@@@")
-
- @staticmethod
- def _decode_label_name(label: str) -> str:
- return label.replace("@@@", ".")
-
- def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
- if self.id in aggregation_result:
- aggregation_result = aggregation_result[self.id]
-
- aggregation_result = unflatten_dict(aggregation_result)
- results = {
- "covered_records": aggregation_result.pop("covered_records"),
- "annotated_covered_records": aggregation_result.pop(
- "annotated_covered_records"
- ),
- }
-
- all_correct = []
- all_incorrect = []
- all_precision = []
- for label, metrics in aggregation_result.items():
- correct = metrics.get("correct_records", 0)
- incorrect = metrics.get("incorrect_records", 0)
- annotated = correct + incorrect
- metrics["annotated"] = annotated
- if annotated > 0:
- precision = correct / annotated
- metrics["precision"] = precision
- all_precision.append(precision)
-
- all_correct.append(correct)
- all_incorrect.append(incorrect)
- results[self._decode_label_name(label)] = metrics
-
- results["correct_records"] = sum(all_correct)
- results["incorrect_records"] = sum(all_incorrect)
- if len(all_precision) > 0:
- results["precision"] = sum(all_precision) / len(all_precision)
-
- return results
-
-
-class DatasetLabelingRulesSummary(BaseModel):
- covered_records: int
- annotated_covered_records: int
-
-
-class LabelingRuleSummary(BaseModel):
- covered_records: int
- annotated_covered_records: int
- correct_records: int = Field(default=0)
- incorrect_records: int = Field(default=0)
- precision: Optional[float] = None
-
-
-class LabelingService:
-
- _INSTANCE = None
-
- __rule_metrics__ = LabelingRulesMetric()
- __dataset_rules_metrics__ = DatasetLabelingRulesMetric()
-
- @classmethod
- def get_instance(
- cls,
- datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance),
- records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
- ):
- if cls._INSTANCE is None:
- cls._INSTANCE = cls(datasets, records)
- return cls._INSTANCE
-
- def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO):
- self.__datasets__ = datasets
- self.__records__ = records
-
- def list_rules(self, dataset: TextClassificationDatasetDB) -> List[LabelingRule]:
- """List a set of rules for a given dataset"""
- return dataset.rules
-
- def delete_rule(self, dataset: TextClassificationDatasetDB, rule_query: str):
- """Delete a rule from a dataset by its defined query string"""
- new_rules_set = [r for r in dataset.rules if r.query != rule_query]
- if len(dataset.rules) != new_rules_set:
- dataset.rules = new_rules_set
- self.__datasets__.update_dataset(dataset)
-
- def add_rule(
- self, dataset: TextClassificationDatasetDB, rule: LabelingRule
- ) -> LabelingRule:
- """Adds a rule to a dataset"""
- for r in dataset.rules:
- if r.query == rule.query:
- raise EntityAlreadyExistsError(rule.query, type=LabelingRule)
- dataset.rules.append(rule)
- self.__datasets__.update_dataset(dataset)
- return rule
-
- def compute_rule_metrics(
- self,
- dataset: TextClassificationDatasetDB,
- rule_query: str,
- labels: Optional[List[str]] = None,
- ) -> Tuple[int, int, LabelingRuleSummary]:
- """Computes metrics for given rule query and optional label against a set of rules"""
-
- annotated_records = self._count_annotated_records(dataset)
- results = self.__records__.search_records(
- dataset,
- size=0,
- search=RecordSearch(
- aggregations=self.__rule_metrics__.aggregation_request(
- rule_query=rule_query, labels=labels
- ),
- ),
- )
-
- rule_metrics_summary = self.__rule_metrics__.aggregation_result(
- results.aggregations
- )
-
- metrics = LabelingRuleSummary.parse_obj(rule_metrics_summary)
- return results.total, annotated_records, metrics
-
- def _count_annotated_records(self, dataset: TextClassificationDatasetDB) -> int:
- results = self.__records__.search_records(
- dataset,
- size=0,
- search=RecordSearch(
- query=filters.exists_field(EsRecordDataFieldNames.annotated_as),
- ),
- )
- return results.total
-
- def all_rules_metrics(
- self, dataset: TextClassificationDatasetDB
- ) -> Tuple[int, int, DatasetLabelingRulesSummary]:
- annotated_records = self._count_annotated_records(dataset)
- results = self.__records__.search_records(
- dataset,
- size=0,
- search=RecordSearch(
- aggregations=self.__dataset_rules_metrics__.aggregation_request(
- all_rules=dataset.rules
- ),
- ),
- )
-
- rule_metrics_summary = self.__dataset_rules_metrics__.aggregation_result(
- results.aggregations
- )
-
- return (
- results.total,
- annotated_records,
- DatasetLabelingRulesSummary.parse_obj(rule_metrics_summary),
- )
-
- def find_rule_by_query(
- self, dataset: TextClassificationDatasetDB, rule_query: str
- ) -> LabelingRule:
- rule_query = rule_query.strip()
- for rule in dataset.rules:
- if rule.query == rule_query:
- return rule
- raise EntityNotFoundError(rule_query, type=LabelingRule)
-
- def replace_rule(self, dataset: TextClassificationDatasetDB, rule: LabelingRule):
- for idx, r in enumerate(dataset.rules):
- if r.query == rule.query:
- dataset.rules[idx] = rule
- break
- self.__datasets__.update_dataset(dataset)
diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py
index a62e324337..8874094dcf 100644
--- a/tests/client/sdk/conftest.py
+++ b/tests/client/sdk/conftest.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from datetime import datetime
+from typing import Any, Dict, List
import pytest
@@ -32,6 +34,8 @@
TokenClassificationBulkData,
)
+LOGGER = logging.getLogger(__name__)
+
class Helpers:
def remove_key(self, schema: dict, key: str):
@@ -50,6 +54,55 @@ def remove_description(self, schema: dict):
def remove_pattern(self, schema: dict):
return self.remove_key(schema, key="pattern")
+ def are_compatible_api_schemas(self, client_schema: dict, server_schema: dict):
+ def check_schema_props(client_props, server_props):
+ different_props = []
+ for name, definition in client_props.items():
+ if name not in server_props:
+ LOGGER.warning(
+ f"Client property {name} not found in server properties. "
+ "Make sure your API compatibility"
+ )
+ different_props.append(name)
+ continue
+ elif definition != server_props[name]:
+ if not check_schema_props(definition, server_props[name]):
+ return False
+ return len(different_props) < len(client_props) / 2
+
+ client_props = self._expands_schema(
+ client_schema["properties"], client_schema["definitions"]
+ )
+ server_props = self._expands_schema(
+ server_schema["properties"], server_schema["definitions"]
+ )
+
+ if client_props == server_props:
+ return True
+ return check_schema_props(client_props, server_props)
+
+ def _expands_schema(
+ self, props: Dict[str, Any], definitions: List[Dict[str, Any]]
+ ) -> Dict[str, Any]:
+ new_schema = {}
+ for name, definition in props.items():
+ if "$ref" in definition:
+ ref = definition["$ref"]
+ ref_def = definitions[ref.replace("#/definitions/", "")]
+ field_props = ref_def.get("properties", ref_def)
+ expanded_props = self._expands_schema(field_props, definitions)
+ new_schema[name] = expanded_props.get("properties", expanded_props)
+ elif "items" in definition and "$ref" in definition["items"]:
+ ref = definition["items"]["$ref"]
+ ref_def = definitions[ref.replace("#/definitions/", "")]
+ field_props = ref_def.get("properties", ref_def)
+ expanded_props = self._expands_schema(field_props, definitions)
+ definition["items"] = expanded_props.get("properties", expanded_props)
+ new_schema[name] = definition
+ else:
+ new_schema[name] = definition
+ return new_schema
+
@pytest.fixture(scope="session")
def helpers():
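
The _expands_schema helper above inlines JSON-schema "$ref" entries against "definitions" before comparing client and server models. A simplified standalone illustration of that expansion, covering only the list-of-$ref case and using a toy schema rather than the real Rubrix models:

    # Resolve an items/$ref against definitions, as the schema comparison does
    definitions = {
        "ClassPrediction": {
            "properties": {"class_label": {"type": "string"}, "score": {"type": "number"}}
        }
    }
    props = {
        "labels": {"type": "array", "items": {"$ref": "#/definitions/ClassPrediction"}},
        "multi_label": {"type": "boolean"},
    }

    def expand(props, definitions):
        expanded = {}
        for name, definition in props.items():
            if "items" in definition and "$ref" in definition["items"]:
                ref_name = definition["items"]["$ref"].replace("#/definitions/", "")
                definition = {**definition, "items": definitions[ref_name]["properties"]}
            expanded[name] = definition
        return expanded

    print(expand(props, definitions)["labels"]["items"])
    # {'class_label': {'type': 'string'}, 'score': {'type': 'number'}}
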
diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py
index 236c025c98..b9a9fd4ac1 100644
--- a/tests/client/sdk/text2text/test_models.py
+++ b/tests/client/sdk/text2text/test_models.py
@@ -27,7 +27,7 @@
)
from rubrix.client.sdk.text2text.models import Text2TextRecord as SdkText2TextRecord
from rubrix.server.apis.v0.models.text2text import (
- Text2TextBulkData as ServerText2TextBulkData,
+ Text2TextBulkRequest as ServerText2TextBulkData,
)
from rubrix.server.apis.v0.models.text2text import (
Text2TextQuery as ServerText2TextQuery,
@@ -37,19 +37,14 @@
def test_bulk_data_schema(helpers):
client_schema = Text2TextBulkData.schema()
server_schema = ServerText2TextBulkData.schema()
-
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
def test_query_schema(helpers):
client_schema = Text2TextQuery.schema()
server_schema = ServerText2TextQuery.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
@pytest.mark.parametrize(
diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py
index 16429a811e..d0d72260c1 100644
--- a/tests/client/sdk/text_classification/test_models.py
+++ b/tests/client/sdk/text_classification/test_models.py
@@ -37,7 +37,7 @@
LabelingRuleMetricsSummary as ServerLabelingRuleMetricsSummary,
)
from rubrix.server.apis.v0.models.text_classification import (
- TextClassificationBulkData as ServerTextClassificationBulkData,
+ TextClassificationBulkRequest as ServerTextClassificationBulkData,
)
from rubrix.server.apis.v0.models.text_classification import (
TextClassificationQuery as ServerTextClassificationQuery,
@@ -48,18 +48,14 @@ def test_bulk_data_schema(helpers):
client_schema = TextClassificationBulkData.schema()
server_schema = ServerTextClassificationBulkData.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
def test_query_schema(helpers):
client_schema = TextClassificationQuery.schema()
server_schema = ServerTextClassificationQuery.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
def test_labeling_rule_schema(helpers):
diff --git a/tests/client/sdk/token_classification/test_models.py b/tests/client/sdk/token_classification/test_models.py
index 9dfe5e848c..6cd5e05e88 100644
--- a/tests/client/sdk/token_classification/test_models.py
+++ b/tests/client/sdk/token_classification/test_models.py
@@ -29,7 +29,7 @@
TokenClassificationRecord as SdkTokenClassificationRecord,
)
from rubrix.server.apis.v0.models.token_classification import (
- TokenClassificationBulkData as ServerTokenClassificationBulkData,
+ TokenClassificationBulkRequest as ServerTokenClassificationBulkData,
)
from rubrix.server.apis.v0.models.token_classification import (
TokenClassificationQuery as ServerTokenClassificationQuery,
@@ -40,18 +40,14 @@ def test_bulk_data_schema(helpers):
client_schema = TokenClassificationBulkData.schema()
server_schema = ServerTokenClassificationBulkData.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
def test_query_schema(helpers):
client_schema = TokenClassificationQuery.schema()
server_schema = ServerTokenClassificationQuery.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.are_compatible_api_schemas(client_schema, server_schema)
@pytest.mark.parametrize(
diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py
index 7c3d3b8853..5d9872bf8a 100644
--- a/tests/functional_tests/search/test_search_service.py
+++ b/tests/functional_tests/search/test_search_service.py
@@ -1,69 +1,62 @@
import pytest
import rubrix
-from rubrix.server.apis.v0.models.commons.model import ScoreRange, TaskType
+from rubrix.server.apis.v0.models.commons.model import ScoreRange
from rubrix.server.apis.v0.models.datasets import Dataset
-from rubrix.server.apis.v0.models.metrics.base import BaseMetric
from rubrix.server.apis.v0.models.text_classification import (
TextClassificationQuery,
TextClassificationRecord,
)
from rubrix.server.apis.v0.models.token_classification import TokenClassificationQuery
+from rubrix.server.commons.models import TaskType
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper
-from rubrix.server.services.metrics import MetricsService
-from rubrix.server.services.search.model import SortConfig
-from rubrix.server.services.search.query_builder import EsQueryBuilder
+from rubrix.server.services.metrics import MetricsService, ServicePythonMetric
+from rubrix.server.services.search.model import ServiceSortConfig
from rubrix.server.services.search.service import SearchRecordsService
@pytest.fixture
-def es_wrapper():
- return ElasticsearchWrapper.get_instance()
+def backend():
+ return ElasticsearchBackend.get_instance()
@pytest.fixture
-def dao(es_wrapper: ElasticsearchWrapper):
- return DatasetRecordsDAO.get_instance(es=es_wrapper)
+def dao(backend: ElasticsearchBackend):
+ return DatasetRecordsDAO.get_instance(es=backend)
@pytest.fixture
-def query_builder(dao: DatasetRecordsDAO):
- return EsQueryBuilder.get_instance(dao=dao)
+def metrics(dao: DatasetRecordsDAO):
+ return MetricsService.get_instance(dao=dao)
@pytest.fixture
-def metrics(dao: DatasetRecordsDAO, query_builder: EsQueryBuilder):
- return MetricsService.get_instance(dao=dao, query_builder=query_builder)
+def service(dao: DatasetRecordsDAO, metrics: MetricsService):
+ return SearchRecordsService.get_instance(dao=dao, metrics=metrics)
-@pytest.fixture
-def service(
- dao: DatasetRecordsDAO, metrics: MetricsService, query_builder: EsQueryBuilder
-):
- return SearchRecordsService.get_instance(
- dao=dao, metrics=metrics, query_builder=query_builder
- )
-
-
-def test_query_builder_with_query_range(query_builder):
- es_query = query_builder(
- "ds", query=TextClassificationQuery(score=ScoreRange(range_from=10))
+def test_query_builder_with_query_range(backend: ElasticsearchBackend):
+ es_query = backend.query_builder.map_2_es_query(
+ schema=None,
+ query=TextClassificationQuery(score=ScoreRange(range_from=10)),
)
assert es_query == {
- "bool": {
- "filter": {
- "bool": {
- "minimum_should_match": 1,
- "should": [{"range": {"score": {"gte": 10.0}}}],
- }
- },
- "must": {"match_all": {}},
+ "query": {
+ "bool": {
+ "filter": {
+ "bool": {
+ "minimum_should_match": 1,
+ "should": [{"range": {"score": {"gte": 10.0}}}],
+ }
+ },
+ "must": {"match_all": {}},
+ }
}
}
-def test_query_builder_with_nested(query_builder, mocked_client):
+def test_query_builder_with_nested(mocked_client, dao, backend: ElasticsearchBackend):
dataset = Dataset(
name="test_query_builder_with_nested",
owner=rubrix.get_workspace(),
@@ -79,8 +72,8 @@ def test_query_builder_with_nested(query_builder, mocked_client):
),
)
- es_query = query_builder(
- dataset=dataset,
+ es_query = backend.query_builder.map_2_es_query(
+ schema=dao.get_dataset_schema(dataset),
query=TokenClassificationQuery(
advanced_query_dsl=True,
query_text="metrics.predicted.mentions:(label:NAME AND score:[* TO 0.1])",
@@ -88,33 +81,35 @@ def test_query_builder_with_nested(query_builder, mocked_client):
)
assert es_query == {
- "bool": {
- "filter": {"bool": {"must": {"match_all": {}}}},
- "must": {
- "nested": {
- "path": "metrics.predicted.mentions",
- "query": {
- "bool": {
- "must": [
- {
- "term": {
- "metrics.predicted.mentions.label": {
- "value": "NAME"
+ "query": {
+ "bool": {
+ "filter": {"bool": {"must": {"match_all": {}}}},
+ "must": {
+ "nested": {
+ "path": "metrics.predicted.mentions",
+ "query": {
+ "bool": {
+ "must": [
+ {
+ "term": {
+ "metrics.predicted.mentions.label": {
+ "value": "NAME"
+ }
}
- }
- },
- {
- "range": {
- "metrics.predicted.mentions.score": {
- "lte": "0.1"
+ },
+ {
+ "range": {
+ "metrics.predicted.mentions.score": {
+ "lte": "0.1"
+ }
}
- }
- },
- ]
- }
- },
- }
- },
+ },
+ ]
+ }
+ },
+ }
+ },
+ }
}
}
@@ -135,8 +130,8 @@ def test_failing_metrics(service, mocked_client):
results = service.search(
dataset=dataset,
query=TextClassificationQuery(),
- sort_config=SortConfig(),
- metrics=[BaseMetric(id="missing-metric", name="Missing metric")],
+ sort_config=ServiceSortConfig(),
+ metrics=[ServicePythonMetric(id="missing-metric", name="Missing metric")],
size=0,
record_type=TextClassificationRecord,
)
diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py
index da933b93bd..9c8c86c4cf 100644
--- a/tests/labeling/text_classification/test_rule.py
+++ b/tests/labeling/text_classification/test_rule.py
@@ -237,7 +237,7 @@ def test_rule_metrics_without_annotated(
)
metrics = rule.metrics(log_dataset_without_annotations)
- assert metrics == expected_metrics
+ assert expected_metrics == metrics
def delete_rule_silently(client, dataset: str, rule: Rule):
diff --git a/tests/metrics/test_text_classification.py b/tests/metrics/test_text_classification.py
index 9ad3b91cee..f523707410 100644
--- a/tests/metrics/test_text_classification.py
+++ b/tests/metrics/test_text_classification.py
@@ -38,8 +38,7 @@ def test_metrics_for_text_classification(mocked_client):
)
results = f1(dataset)
- assert results
- assert results.data == {
+ assert results and results.data == {
"f1_macro": 1.0,
"f1_micro": 1.0,
"ham_f1": 1.0,
@@ -58,8 +57,7 @@ def test_metrics_for_text_classification(mocked_client):
results.visualize()
results = f1_multilabel(dataset)
- assert results
- assert results.data == {
+ assert results and results.data == {
"f1_macro": 1.0,
"f1_micro": 1.0,
"ham_f1": 1.0,
diff --git a/tests/server/backend/__init__.py b/tests/server/backend/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py
new file mode 100644
index 0000000000..cfbe0f1a0c
--- /dev/null
+++ b/tests/server/backend/test_query_builder.py
@@ -0,0 +1,66 @@
+import pytest
+
+from rubrix.server.daos.backend.search.model import SortableField, SortConfig, SortOrder
+from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder
+
+
+@pytest.mark.parametrize(
+ ["index_schema", "sort_cfg", "expected_sort"],
+ [
+ (
+ {
+ "mappings": {
+ "properties": {
+ "id": {"type": "text"},
+ }
+ }
+ },
+ [SortableField(id="id")],
+ [{"id.keyword": {"order": SortOrder.asc}}],
+ ),
+ (
+ {
+ "mappings": {
+ "properties": {
+ "id": {"type": "keyword"},
+ }
+ }
+ },
+ [SortableField(id="id")],
+ [{"id": {"order": SortOrder.asc}}],
+ ),
+ (
+ {
+ "mappings": {
+ "properties": {
+ "id": {"type": "keyword"},
+ }
+ }
+ },
+ [SortableField(id="metadata.black", order=SortOrder.desc)],
+ [{"metadata.black": {"order": SortOrder.desc}}],
+ ),
+ ],
+)
+def test_build_sort_configuration(index_schema, sort_cfg, expected_sort):
+
+ builder = EsQueryBuilder()
+
+ es_sort = builder.map_2_es_sort_configuration(
+ sort=SortConfig(sort_by=sort_cfg), schema=index_schema
+ )
+ assert es_sort == expected_sort
+
+
+def test_build_sort_with_wrong_field_name():
+ builder = EsQueryBuilder()
+
+ with pytest.raises(Exception):
+ builder.map_2_es_sort_configuration(
+ sort=SortConfig(sort_by=[SortableField(id="wat?!")])
+ )
+
+
+def test_build_sort_without_sort_config():
+ builder = EsQueryBuilder()
+ assert builder.map_2_es_sort_configuration() is None
diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py
index db8dc4b104..92de61cb8f 100644
--- a/tests/server/commons/test_records_dao.py
+++ b/tests/server/commons/test_records_dao.py
@@ -1,15 +1,17 @@
import pytest
-from rubrix.server.apis.v0.models.commons.model import TaskType
-from rubrix.server.apis.v0.models.datasets import DatasetDB
+from rubrix.server.commons.models import TaskType
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
+from rubrix.server.daos.models.datasets import BaseDatasetDB
from rubrix.server.daos.records import DatasetRecordsDAO
-from rubrix.server.elasticseach.client_wrapper import ElasticsearchWrapper
from rubrix.server.errors import MissingDatasetRecordsError
def test_raise_proper_error():
- dao = DatasetRecordsDAO.get_instance(ElasticsearchWrapper.get_instance())
+ dao = DatasetRecordsDAO.get_instance(ElasticsearchBackend.get_instance())
with pytest.raises(MissingDatasetRecordsError):
dao.search_records(
- dataset=DatasetDB(name="mock-notfound", task=TaskType.text_classification)
+ dataset=BaseDatasetDB(
+ name="mock-notfound", task=TaskType.text_classification
+ )
)
diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py
index 37a04cdbd8..fd99debe8b 100644
--- a/tests/server/datasets/test_api.py
+++ b/tests/server/datasets/test_api.py
@@ -14,9 +14,11 @@
# limitations under the License.
from typing import Optional
-from rubrix.server.apis.v0.models.commons.model import TaskType
from rubrix.server.apis.v0.models.datasets import Dataset
-from rubrix.server.apis.v0.models.text_classification import TextClassificationBulkData
+from rubrix.server.apis.v0.models.text_classification import (
+ TextClassificationBulkRequest,
+)
+from rubrix.server.commons.models import TaskType
from tests.helpers import SecuredClient
@@ -31,7 +33,7 @@ def test_delete_dataset(mocked_client):
assert response.json() == {
"detail": {
"code": "rubrix.api.errors::EntityNotFoundError",
- "params": {"name": "test_delete_dataset", "type": "Dataset"},
+ "params": {"name": "test_delete_dataset", "type": "ServiceDataset"},
}
}
@@ -108,7 +110,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient):
def test_dataset_naming_validation(mocked_client):
- request = TextClassificationBulkData(records=[])
+ request = TextClassificationBulkRequest(records=[])
dataset = "Wrong dataset name"
response = mocked_client.post(
@@ -128,7 +130,7 @@ def test_dataset_naming_validation(mocked_client):
"type": "value_error.str.regex",
}
],
- "model": "TextClassificationDatasetDB",
+ "model": "TextClassificationDataset",
},
}
}
@@ -150,7 +152,7 @@ def test_dataset_naming_validation(mocked_client):
"type": "value_error.str.regex",
}
],
- "model": "TokenClassificationDatasetDB",
+ "model": "TokenClassificationDataset",
},
}
}
@@ -221,7 +223,7 @@ def delete_dataset(client, dataset, workspace: Optional[str] = None):
def create_mock_dataset(client, dataset):
client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=[],
diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py
index 19c200e13f..5364eae1bd 100644
--- a/tests/server/datasets/test_dao.py
+++ b/tests/server/datasets/test_dao.py
@@ -15,27 +15,22 @@
import pytest
-from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
-from rubrix.server.apis.v0.models.commons.model import TaskType
-from rubrix.server.apis.v0.models.datasets import DatasetDB
+from rubrix.server.commons.models import TaskType
+from rubrix.server.daos.backend.elasticsearch import ElasticsearchBackend
from rubrix.server.daos.datasets import DatasetsDAO
-from rubrix.server.daos.records import dataset_records_dao
-from rubrix.server.elasticseach.client_wrapper import create_es_wrapper
-from rubrix.server.elasticseach.mappings.text_classification import (
- text_classification_mappings,
-)
+from rubrix.server.daos.models.datasets import BaseDatasetDB
+from rubrix.server.daos.records import DatasetRecordsDAO
from rubrix.server.errors import ClosedDatasetError
-es_wrapper = create_es_wrapper()
-records = dataset_records_dao(es_wrapper)
+es_wrapper = ElasticsearchBackend.get_instance()
+records = DatasetRecordsDAO.get_instance(es_wrapper)
dao = DatasetsDAO.get_instance(es_wrapper, records)
def test_retrieve_ownered_dataset_for_no_owner_user():
dataset = "test_retrieve_owned_dataset_for_no_owner_user"
created = dao.create_dataset(
- DatasetDB(name=dataset, owner="other", task=TaskType.text_classification),
- mappings=TaskFactory.get_task_mappings(TaskType.text_classification),
+ BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification),
)
assert dao.find_by_name(created.name, owner=created.owner) == created
assert dao.find_by_name(created.name, owner=None) == created
@@ -46,8 +41,7 @@ def test_close_dataset():
dataset = "test_close_dataset"
created = dao.create_dataset(
- DatasetDB(name=dataset, owner="other", task=TaskType.text_classification),
- mappings=TaskFactory.get_task_mappings(TaskType.text_classification),
+ BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification),
)
dao.close(created)
diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py
index 2e2234a60c..ffa558001e 100644
--- a/tests/server/datasets/test_model.py
+++ b/tests/server/datasets/test_model.py
@@ -16,7 +16,8 @@
import pytest
from pydantic import ValidationError
-from rubrix.server.apis.v0.models.datasets import CreationDatasetRequest
+from rubrix.server.apis.v0.models.datasets import CreateDatasetRequest
+from rubrix.server.commons.models import TaskType
@pytest.mark.parametrize(
@@ -24,7 +25,7 @@
["fine", "fine33", "fine_33", "fine-3-3"],
)
def test_dataset_naming_ok(name):
- request = CreationDatasetRequest(name=name)
+ request = CreateDatasetRequest(name=name, task=TaskType.token_classification)
assert request.name == name
@@ -42,4 +43,4 @@ def test_dataset_naming_ok(name):
)
def test_dataset_naming_ko(name):
with pytest.raises(ValidationError, match="string does not match regex"):
- CreationDatasetRequest(name=name)
+ CreateDatasetRequest(name=name, task=TaskType.token_classification)
diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py
index da099c457f..f21c40baed 100644
--- a/tests/server/info/test_api.py
+++ b/tests/server/info/test_api.py
@@ -14,7 +14,7 @@
# limitations under the License.
from rubrix import __version__ as rubrix_version
-from rubrix.server.apis.v0.models.info import ApiInfo, ApiStatus
+from rubrix.server.services.info import ApiInfo, ApiStatus
def test_api_info(mocked_client):
diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py
index 14637defb5..c944c16757 100644
--- a/tests/server/metrics/test_api.py
+++ b/tests/server/metrics/test_api.py
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from rubrix.server.apis.v0.models.metrics.commons import CommonTasksMetrics
-from rubrix.server.apis.v0.models.metrics.token_classification import (
- TokenClassificationMetrics,
-)
-from rubrix.server.apis.v0.models.text2text import Text2TextBulkData, Text2TextRecord
+from rubrix.server.apis.v0.models.text2text import Text2TextBulkRequest, Text2TextRecord
from rubrix.server.apis.v0.models.text_classification import (
- TextClassificationBulkData,
+ TextClassificationBulkRequest,
TextClassificationRecord,
)
from rubrix.server.apis.v0.models.token_classification import (
- TokenClassificationBulkData,
+ TokenClassificationBulkRequest,
TokenClassificationRecord,
)
+from rubrix.server.services.metrics.models import CommonTasksMetrics
+from rubrix.server.services.tasks.token_classification.metrics import (
+ TokenClassificationMetrics,
+)
COMMON_METRICS_LENGTH = len(CommonTasksMetrics.metrics)
@@ -41,7 +41,7 @@ def test_wrong_dataset_metrics(mocked_client):
{"text": text},
]
]
- request = Text2TextBulkData(records=records)
+ request = Text2TextBulkRequest(records=records)
dataset = "test_wrong_dataset_metrics"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
@@ -91,7 +91,7 @@ def test_dataset_for_text2text(mocked_client):
{"text": text},
]
]
- request = Text2TextBulkData(records=records)
+ request = Text2TextBulkRequest(records=records)
dataset = "test_dataset_for_text2text"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
@@ -121,7 +121,7 @@ def test_dataset_for_token_classification(mocked_client):
]
]
- request = TokenClassificationBulkData(records=records)
+ request = TokenClassificationBulkRequest(records=records)
dataset = "test_dataset_for_token_classification"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
@@ -146,7 +146,7 @@ def test_dataset_for_token_classification(mocked_client):
json={},
)
- assert response.status_code == 200, response.json()
+ assert response.status_code == 200, f"{metric} :: {response.json()}"
summary = response.json()
if not ("predicted" in metric_id or "annotated" in metric_id):
@@ -171,7 +171,7 @@ def test_dataset_metrics(mocked_client):
},
]
]
- request = TextClassificationBulkData(records=records)
+ request = TextClassificationBulkRequest(records=records)
dataset = "test_get_dataset_metrics"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
@@ -199,7 +199,7 @@ def test_dataset_metrics(mocked_client):
assert response.json() == {
"detail": {
"code": "rubrix.api.errors::EntityNotFoundError",
- "params": {"name": "missing_metric", "type": "BaseMetric"},
+ "params": {"name": "missing_metric", "type": "ServiceBaseMetric"},
}
}
@@ -242,7 +242,7 @@ def test_dataset_labels_for_text_classification(mocked_client):
},
]
]
- request = TextClassificationBulkData(records=records)
+ request = TextClassificationBulkRequest(records=records)
dataset = "test_dataset_labels_for_text_classification"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
diff --git a/tests/server/test_api.py b/tests/server/test_api.py
index 5cbd00c0b7..972e3f7134 100644
--- a/tests/server/test_api.py
+++ b/tests/server/test_api.py
@@ -15,12 +15,11 @@
import os
-from rubrix.server.apis.v0.models.commons.model import TaskStatus
from rubrix.server.apis.v0.models.text_classification import (
- TaskType,
- TextClassificationBulkData,
+ TextClassificationBulkRequest,
TextClassificationRecord,
)
+from rubrix.server.commons.models import TaskStatus, TaskType
def create_some_data_for_text_classification(client, name: str, n: int):
@@ -65,7 +64,7 @@ def create_some_data_for_text_classification(client, name: str, n: int):
]
client.post(
f"/api/datasets/{name}/{TaskType.text_classification}:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py
index a02b44966b..40667143fd 100644
--- a/tests/server/text2text/test_api.py
+++ b/tests/server/text2text/test_api.py
@@ -1,7 +1,7 @@
from rubrix.client.sdk.text2text.models import Text2TextBulkData
from rubrix.server.apis.v0.models.commons.model import BulkResponse
from rubrix.server.apis.v0.models.text2text import (
- CreationText2TextRecord,
+ Text2TextRecordInputs,
Text2TextSearchResults,
)
@@ -11,7 +11,7 @@ def test_search_records(mocked_client):
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
records = [
- CreationText2TextRecord.parse_obj(data)
+ Text2TextRecordInputs.parse_obj(data)
for data in [
{
"id": 0,
diff --git a/tests/server/text2text/test_model.py b/tests/server/text2text/test_model.py
index 4e9ee2f106..d8e0f56498 100644
--- a/tests/server/text2text/test_model.py
+++ b/tests/server/text2text/test_model.py
@@ -4,7 +4,7 @@
Text2TextQuery,
Text2TextRecord,
)
-from rubrix.server.services.search.query_builder import EsQueryBuilder
+from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder
def test_sentences_sorted_by_score():
@@ -57,4 +57,4 @@ def test_model_dict():
def test_query_as_elasticsearch():
query = Text2TextQuery(ids=[1, 2, 3])
- assert EsQueryBuilder.to_es_query(query) == {"ids": {"values": query.ids}}
+ assert EsQueryBuilder._to_es_query(query) == {"ids": {"values": query.ids}}
diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py
index dc3381bf90..1e8a12248b 100644
--- a/tests/server/text_classification/test_api.py
+++ b/tests/server/text_classification/test_api.py
@@ -15,16 +15,17 @@
from datetime import datetime
-from rubrix.server.apis.v0.models.commons.model import BulkResponse, PredictionStatus
+from rubrix.server.apis.v0.models.commons.model import BulkResponse
from rubrix.server.apis.v0.models.datasets import Dataset
from rubrix.server.apis.v0.models.text_classification import (
TextClassificationAnnotation,
- TextClassificationBulkData,
+ TextClassificationBulkRequest,
TextClassificationQuery,
TextClassificationRecord,
TextClassificationSearchRequest,
TextClassificationSearchResults,
)
+from rubrix.server.commons.models import PredictionStatus
def test_create_records_for_text_classification_with_multi_label(mocked_client):
@@ -70,7 +71,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client):
]
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -85,7 +86,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
tags={"new": "tag"},
metadata={"new": {"metadata": "value"}},
records=records,
@@ -123,7 +124,7 @@ def test_create_records_for_text_classification(mocked_client):
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
tags = {"env": "test", "class": "text classification"}
metadata = {"config": {"the": "config"}}
- classification_bulk = TextClassificationBulkData(
+ classification_bulk = TextClassificationBulkRequest(
tags=tags,
metadata=metadata,
records=[
@@ -197,7 +198,7 @@ def test_partial_record_update(mocked_client):
}
)
- bulk = TextClassificationBulkData(
+ bulk = TextClassificationBulkRequest(
records=[record],
)
@@ -268,7 +269,7 @@ def test_sort_by_last_updated(mocked_client):
for i in range(0, 10):
mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -294,7 +295,7 @@ def test_sort_by_id_as_default(mocked_client):
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -335,7 +336,7 @@ def test_some_sort_by(mocked_client):
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -366,10 +367,10 @@ def test_some_sort_by(mocked_client):
"code": "rubrix.api.errors::BadRequestError",
"params": {
"message": "Wrong sort id wrong_field. Valid values "
- "are: ['metadata', 'last_updated', 'score', "
+ "are: ['id', 'metadata', 'score', "
"'predicted', 'predicted_as', "
"'predicted_by', 'annotated_as', "
- "'annotated_by', 'status', "
+ "'annotated_by', 'status', 'last_updated', "
"'event_timestamp']"
},
}
@@ -407,7 +408,7 @@ def test_disable_aggregations_when_scroll(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- json=TextClassificationBulkData(
+ json=TextClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=[
@@ -447,7 +448,7 @@ def test_include_event_timestamp(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=[
@@ -488,7 +489,7 @@ def test_words_cloud(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -528,7 +529,7 @@ def test_metadata_with_point_in_field_name(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -565,7 +566,7 @@ def test_wrong_text_query(mocked_client):
mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
@@ -599,7 +600,7 @@ def test_search_using_text(mocked_client):
mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
records=[
TextClassificationRecord(
**{
diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py
index 5a24a1770e..ce16623f15 100644
--- a/tests/server/text_classification/test_api_rules.py
+++ b/tests/server/text_classification/test_api_rules.py
@@ -4,7 +4,7 @@
CreateLabelingRule,
LabelingRule,
LabelingRuleMetricsSummary,
- TextClassificationBulkData,
+ TextClassificationBulkRequest,
TextClassificationRecord,
)
@@ -31,7 +31,7 @@ def log_some_records(
response = client.post(
f"/api/datasets/{dataset}/TextClassification:bulk",
- data=TextClassificationBulkData(
+ data=TextClassificationBulkRequest(
records=[
TextClassificationRecord(**record),
],
@@ -202,7 +202,7 @@ def test_duplicated_dataset_rules(mocked_client):
assert response.json() == {
"detail": {
"code": "rubrix.api.errors::EntityAlreadyExistsError",
- "params": {"name": "a query", "type": "LabelingRule"},
+ "params": {"name": "a query", "type": "ServiceLabelingRule"},
}
}
diff --git a/tests/server/text_classification/test_api_settings.py b/tests/server/text_classification/test_api_settings.py
index 585183f0d1..4dd170896d 100644
--- a/tests/server/text_classification/test_api_settings.py
+++ b/tests/server/text_classification/test_api_settings.py
@@ -1,5 +1,5 @@
import rubrix as rb
-from rubrix.server.apis.v0.models.commons.model import TaskType
+from rubrix.server.commons.models import TaskType
def create_dataset(client, name: str):
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index c41bf0fbf4..a4f4edec73 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -16,15 +16,17 @@
from pydantic import ValidationError
from rubrix._constants import MAX_KEYWORD_LENGTH
-from rubrix.server.apis.v0.models.commons.model import TaskStatus
from rubrix.server.apis.v0.models.text_classification import (
- ClassPrediction,
- PredictionStatus,
TextClassificationAnnotation,
TextClassificationQuery,
TextClassificationRecord,
)
-from rubrix.server.services.search.query_builder import EsQueryBuilder
+from rubrix.server.commons.models import PredictionStatus, TaskStatus
+from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder
+from rubrix.server.services.tasks.text_classification.model import (
+ ClassPrediction,
+ ServiceTextClassificationRecord,
+)
def test_flatten_metadata():
@@ -34,7 +36,7 @@ def test_flatten_metadata():
"mail": {"subject": "The mail subject", "body": "This is a large text body"}
},
}
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert list(record.metadata.keys()) == ["mail.subject", "mail.body"]
@@ -48,7 +50,7 @@ def test_metadata_with_object_list():
]
},
}
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert list(record.metadata.keys()) == ["mails"]
@@ -87,7 +89,7 @@ def test_single_label_with_multiple_annotation():
ValidationError,
match="Single label record must include only one annotation label",
):
- TextClassificationRecord.parse_obj(
+ ServiceTextClassificationRecord.parse_obj(
{
"inputs": {"text": "This is a text"},
"annotation": {
@@ -100,7 +102,7 @@ def test_single_label_with_multiple_annotation():
def test_too_long_metadata():
- record = TextClassificationRecord.parse_obj(
+ record = ServiceTextClassificationRecord.parse_obj(
{
"inputs": {"text": "bogh"},
"metadata": {"too_long": "a" * 1000},
@@ -112,7 +114,7 @@ def test_too_long_metadata():
def test_too_long_label():
with pytest.raises(ValidationError, match="exceeds max length"):
- TextClassificationRecord.parse_obj(
+ ServiceTextClassificationRecord.parse_obj(
{
"inputs": {"text": "bogh"},
"prediction": {
@@ -137,26 +139,26 @@ def test_score_integrity():
}
try:
- TextClassificationRecord.parse_obj(data)
+ ServiceTextClassificationRecord.parse_obj(data)
except ValidationError as e:
assert "Wrong score distributions" in e.json()
data["multi_label"] = True
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert record is not None
data["multi_label"] = False
data["prediction"]["labels"] = [
{"class": "B", "score": 0.9},
]
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert record is not None
data["prediction"]["labels"] = [
{"class": "B", "score": 0.10000000012},
{"class": "B", "score": 0.90000000002},
]
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert record is not None
@@ -174,7 +176,7 @@ def test_prediction_ok_cases():
},
}
- record = TextClassificationRecord(**data)
+ record = ServiceTextClassificationRecord(**data)
assert record.predicted is None
record.annotation = TextClassificationAnnotation(
**{
@@ -215,7 +217,7 @@ def test_predicted_as_with_no_labels():
"inputs": {"text": "The input text"},
"prediction": {"agent": "test", "labels": []},
}
- record = TextClassificationRecord(**data)
+ record = ServiceTextClassificationRecord(**data)
assert record.predicted_as == []
@@ -224,12 +226,12 @@ def test_created_record_with_default_status():
"inputs": {"data": "My cool data"},
}
- record = TextClassificationRecord.parse_obj(data)
+ record = ServiceTextClassificationRecord.parse_obj(data)
assert record.status == TaskStatus.default
def test_predicted_ok_for_multilabel_unordered():
- record = TextClassificationRecord(
+ record = ServiceTextClassificationRecord(
inputs={"text": "The text"},
prediction=TextClassificationAnnotation(
agent="test",
@@ -264,7 +266,7 @@ def test_validate_without_labels_for_single_label(annotation):
ValidationError,
match="Annotation must include some label for validated records",
):
- TextClassificationRecord(
+ ServiceTextClassificationRecord(
inputs={"text": "The text"},
status=TaskStatus.validated,
prediction=TextClassificationAnnotation(
@@ -281,7 +283,7 @@ def test_query_with_uncovered_by_rules():
query = TextClassificationQuery(uncovered_by_rules=["query", "other*"])
- assert EsQueryBuilder.to_es_query(query) == {
+ assert EsQueryBuilder._to_es_query(query) == {
"bool": {
"must": {"match_all": {}},
"must_not": {
@@ -322,12 +324,12 @@ def test_empty_labels_for_no_multilabel():
ValidationError,
match="Single label record must include only one annotation label",
):
- TextClassificationRecord(
+ ServiceTextClassificationRecord(
inputs={"text": "The input text"},
annotation=TextClassificationAnnotation(agent="ann.", labels=[]),
)
- record = TextClassificationRecord(
+ record = ServiceTextClassificationRecord(
inputs={"text": "The input text"},
prediction=TextClassificationAnnotation(agent="ann.", labels=[]),
annotation=TextClassificationAnnotation(
@@ -338,7 +340,7 @@ def test_empty_labels_for_no_multilabel():
def test_annotated_without_labels_for_multilabel():
- record = TextClassificationRecord(
+ record = ServiceTextClassificationRecord(
inputs={"text": "The input text"},
multi_label=True,
prediction=TextClassificationAnnotation(agent="pred.", labels=[]),
diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py
index c5637848d5..20b75b6108 100644
--- a/tests/server/token_classification/test_api.py
+++ b/tests/server/token_classification/test_api.py
@@ -18,7 +18,7 @@
from rubrix.server.apis.v0.models.commons.model import BulkResponse, SortableField
from rubrix.server.apis.v0.models.token_classification import (
- TokenClassificationBulkData,
+ TokenClassificationBulkRequest,
TokenClassificationQuery,
TokenClassificationRecord,
TokenClassificationSearchRequest,
@@ -41,7 +41,7 @@ def test_load_as_different_task(mocked_client):
]
mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -80,7 +80,7 @@ def test_search_special_characters(mocked_client):
]
mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -123,7 +123,7 @@ def test_some_sort(mocked_client):
]
mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -143,11 +143,10 @@ def test_some_sort(mocked_client):
"code": "rubrix.api.errors::BadRequestError",
"params": {
"message": "Wrong sort id babba. Valid values are: "
- "['metadata', 'last_updated', 'score', "
- "'predicted', 'predicted_as', "
- "'predicted_by', 'annotated_as', "
- "'annotated_by', 'status', "
- "'event_timestamp']"
+ "['id', 'metadata', 'score', 'predicted', "
+ "'predicted_as', 'predicted_by', "
+ "'annotated_as', 'annotated_by', 'status', "
+ "'last_updated', 'event_timestamp']"
},
}
}
@@ -185,7 +184,7 @@ def test_create_records_for_token_classification(
response = mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -267,7 +266,7 @@ def test_multiple_mentions_in_same_record(mocked_client):
]
response = mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
@@ -295,7 +294,7 @@ def test_show_not_aggregable_metadata_fields(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:bulk",
- json=TokenClassificationBulkData(
+ json=TokenClassificationBulkRequest(
records=[
TokenClassificationRecord.parse_obj(
{
diff --git a/tests/server/token_classification/test_api_settings.py b/tests/server/token_classification/test_api_settings.py
index 9e05ada666..f517a45e76 100644
--- a/tests/server/token_classification/test_api_settings.py
+++ b/tests/server/token_classification/test_api_settings.py
@@ -1,5 +1,5 @@
import rubrix as rb
-from rubrix.server.apis.v0.models.commons.model import TaskType
+from rubrix.server.commons.models import TaskType
def create_dataset(client, name: str):
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index f40fba6d7e..3e8309fd3e 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -18,14 +18,16 @@
from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.apis.v0.models.token_classification import (
- CreationTokenClassificationRecord,
- EntitySpan,
- PredictionStatus,
TokenClassificationAnnotation,
TokenClassificationQuery,
TokenClassificationRecord,
)
-from rubrix.server.services.search.query_builder import EsQueryBuilder
+from rubrix.server.commons.models import PredictionStatus
+from rubrix.server.daos.backend.search.query_builder import EsQueryBuilder
+from rubrix.server.services.tasks.token_classification.model import (
+ EntitySpan,
+ ServiceTokenClassificationRecord,
+)
def test_char_position():
@@ -38,7 +40,7 @@ def test_char_position():
EntitySpan(start=1, end=1, label="label")
text = "I am Maxi"
- TokenClassificationRecord(
+ ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
@@ -53,7 +55,7 @@ def test_char_position():
def test_fix_substrings():
text = "On one ones o no"
- TokenClassificationRecord(
+ ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
@@ -68,7 +70,7 @@ def test_fix_substrings():
def test_entities_with_spaces():
text = "This is a great space"
- TokenClassificationRecord(
+ ServiceTokenClassificationRecord(
text=text,
tokens=["This", "is", " ", "a", " ", "great", " ", "space"],
prediction=TokenClassificationAnnotation(
@@ -102,7 +104,6 @@ def test_model_dict():
"agent": "test",
"entities": [{"end": 24, "label": "test", "score": 1.0, "start": 9}],
},
- "raw_text": text,
"text": text,
"tokens": tokens,
"status": "Default",
@@ -111,7 +112,7 @@ def test_model_dict():
def test_too_long_metadata():
text = "On one ones o no"
- record = TokenClassificationRecord.parse_obj(
+ record = ServiceTokenClassificationRecord.parse_obj(
{
"text": text,
"tokens": text.split(),
@@ -127,7 +128,7 @@ def test_entity_label_too_long():
with pytest.raises(
ValidationError, match="ensure this value has at most 128 character"
):
- TokenClassificationRecord(
+ ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
@@ -145,11 +146,11 @@ def test_entity_label_too_long():
def test_to_es_query():
query = TokenClassificationQuery(ids=[1, 2, 3])
- assert EsQueryBuilder.to_es_query(query) == {"ids": {"values": query.ids}}
+ assert EsQueryBuilder._to_es_query(query) == {"ids": {"values": query.ids}}
def test_misaligned_entity_mentions_with_spaces_left():
- assert TokenClassificationRecord(
+ assert ServiceTokenClassificationRecord(
text="according to analysts.\n Dart Group Corp was not",
tokens=[
"according",
@@ -172,7 +173,7 @@ def test_misaligned_entity_mentions_with_spaces_left():
def test_misaligned_entity_mentions_with_spaces_right():
- assert TokenClassificationRecord(
+ assert ServiceTokenClassificationRecord(
text="\nvs 9.91 billion\n Note\n REUTER\n",
tokens=["\n", "vs", "9.91", "billion", "\n ", "Note", "\n ", "REUTER", "\n"],
annotation=TokenClassificationAnnotation(
@@ -184,7 +185,7 @@ def test_misaligned_entity_mentions_with_spaces_right():
def test_custom_tokens_splitting():
- TokenClassificationRecord(
+ ServiceTokenClassificationRecord(
text="ThisisMr.Bean, a character playedby actor RowanAtkinson",
tokens=[
"This",
@@ -211,7 +212,7 @@ def test_custom_tokens_splitting():
def test_record_scores():
- record = TokenClassificationRecord(
+ record = ServiceTokenClassificationRecord(
text="\nvs 9.91 billion\n Note\n REUTER\n",
tokens=["\n", "vs", "9.91", "billion", "\n ", "Note", "\n ", "REUTER", "\n"],
prediction=TokenClassificationAnnotation(
@@ -228,7 +229,7 @@ def test_record_scores():
def test_annotated_without_entities():
text = "The text that i wrote"
- record = TokenClassificationRecord(
+ record = ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
@@ -245,8 +246,7 @@ def test_annotated_without_entities():
def test_adjust_spans():
text = "A text with some empty spaces that could bring not cleany annotated spans"
-
- record = TokenClassificationRecord(
+ record = ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
prediction=TokenClassificationAnnotation(
@@ -275,6 +275,7 @@ def test_adjust_spans():
EntitySpan(start=70, end=85, label="DET"),
]
+
def test_whitespace_in_tokens():
from spacy import load
@@ -291,7 +292,6 @@ def test_whitespace_in_tokens():
},
}
- record = CreationTokenClassificationRecord.parse_obj(record)
+ record = ServiceTokenClassificationRecord.parse_obj(record)
assert record
assert record.tokens == ["every", "four", "(", "4", ")", " "]
-