Skip to content

Commit

Permalink
refactor(server): moving common base models to services layer (#1598)
Browse files Browse the repository at this point in the history
Also update module imports
  • Loading branch information
frascuchon committed Jul 4, 2022
1 parent ec03b43 commit ec6104d
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 329 deletions.
315 changes: 34 additions & 281 deletions src/rubrix/server/apis/v0/models/commons/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,33 @@
"""

from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import uuid4
from typing import Any, Dict, Generic, TypeVar

from fastapi import Query
from pydantic import BaseModel, Field, validator
from pydantic import validator
from pydantic.generics import GenericModel

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.apis.v0.helpers import flatten_dict
from rubrix.server.services.search.model import (
BaseSearchResults,
BaseSearchResultsAggregations,
QueryRange,
SortableField,
)
from rubrix.server.services.tasks.commons import (
Annotation,
BaseAnnotation,
BaseRecordDB,
BulkResponse,
EsRecordDataFieldNames,
PredictionStatus,
TaskStatus,
TaskType,
)
from rubrix.utils import limit_value_length


class EsRecordDataFieldNames(str, Enum):
"""Common elasticsearch field names"""

predicted_as = "predicted_as"
annotated_as = "annotated_as"
annotated_by = "annotated_by"
predicted_by = "predicted_by"
status = "status"
predicted = "predicted"
score = "score"
words = "words"
event_timestamp = "event_timestamp"
last_updated = "last_updated"

def __str__(self):
return self.value


class SortOrder(str, Enum):
asc = "asc"
desc = "desc"


class SortableField(BaseModel):
"""Sortable field structure"""

id: str
order: SortOrder = SortOrder.asc


class BulkResponse(BaseModel):
"""
Data info for bulk results
Attributes
----------
dataset:
The dataset name
processed:
Number of records in bulk
failed:
Number of failed records
"""

dataset: str
processed: int
failed: int = 0


@dataclass
class PaginationParams:
"""Query pagination params"""
Expand All @@ -92,70 +55,7 @@ class PaginationParams:
)


class BaseAnnotation(BaseModel):
"""
Annotation class base
Attributes:
-----------
agent:
Which agent or component makes the annotation. We should find model annotations, user annotations,
or some other human-supervised automatic process.
"""

agent: str = Field(max_length=64)


class TaskType(str, Enum):
"""
The available task types:
**text_classification**, for text classification tasks
**token_classification**, for token classification tasks
"""

text_classification = "TextClassification"
token_classification = "TokenClassification"
text2text = "Text2Text"
multi_task_text_token_classification = "MultitaskTextTokenClassification"


class TaskStatus(str, Enum):
"""
Task data status:
**Default**, default status, for no provided status records.
**Edited**, normally used when original annotation was modified but not yet validated (confirmed).
**Discarded**, for records that will be excluded for analysis.
**Validated**, when annotation was confirmed as ok.
"""

default = "Default"
edited = "Edited" # TODO: DEPRECATE
discarded = "Discarded"
validated = "Validated"


class PredictionStatus(str, Enum):
"""
The prediction status:
**OK**, for record containing a success prediction
**KO**, for record containing a wrong prediction
"""

OK = "ok"
KO = "ko"


Annotation = TypeVar("Annotation", bound=BaseAnnotation)


class BaseRecord(GenericModel, Generic[Annotation]):
class BaseRecord(BaseRecordDB, GenericModel, Generic[Annotation]):
"""
Minimal dataset record information
Expand All @@ -171,28 +71,6 @@ class BaseRecord(GenericModel, Generic[Annotation]):
"""

id: Optional[Union[int, str]] = Field(None)
metadata: Dict[str, Any] = Field(default=None)
event_timestamp: Optional[datetime] = None
status: Optional[TaskStatus] = None
prediction: Optional[Annotation] = None
annotation: Optional[Annotation] = None
metrics: Dict[str, Any] = Field(default_factory=dict)
search_keywords: Optional[List[str]] = None

@validator("search_keywords")
def remove_duplicated_keywords(cls, value) -> List[str]:
"""Remove duplicated keywords"""
if value:
return list(set(value))

@validator("id", always=True)
def default_id_if_none_provided(cls, id: Optional[str]) -> str:
"""Validates id info and sets a random uuid if not provided"""
if id is None:
return str(uuid4())
return id

@validator("metadata", pre=True)
def flatten_metadata(cls, metadata: Dict[str, Any]):
"""
Expand All @@ -213,149 +91,24 @@ def flatten_metadata(cls, metadata: Dict[str, Any]):
metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
return metadata

@validator("status", always=True)
def fill_default_value(cls, status: TaskStatus):
"""Fastapi validator for set default task status"""
return TaskStatus.default if status is None else status

@classmethod
def task(cls) -> TaskType:
"""The task type related to this task info"""
raise NotImplementedError

@property
def predicted(self) -> Optional[PredictionStatus]:
"""The task record prediction status (if any)"""
return None

@property
def predicted_as(self) -> Optional[List[str]]:
"""Predictions strings representation"""
return None

@property
def annotated_as(self) -> Optional[List[str]]:
"""Annotations strings representation"""
return None

@property
def scores(self) -> Optional[List[float]]:
"""Prediction scores"""
return None

def all_text(self) -> str:
"""All textual information related to record"""
raise NotImplementedError

@property
def predicted_by(self) -> List[str]:
"""The prediction agents"""
if self.prediction:
return [self.prediction.agent]
return []

@property
def annotated_by(self) -> List[str]:
"""The annotation agents"""
if self.annotation:
return [self.annotation.agent]
return []

def extended_fields(self) -> Dict[str, Any]:
"""
Used for extends fields to store in db. Tasks that would include extra
properties than commons (predicted, annotated_as,....) could implement
this method.
"""
return {
EsRecordDataFieldNames.predicted: self.predicted,
EsRecordDataFieldNames.annotated_as: self.annotated_as,
EsRecordDataFieldNames.predicted_as: self.predicted_as,
EsRecordDataFieldNames.annotated_by: self.annotated_by,
EsRecordDataFieldNames.predicted_by: self.predicted_by,
EsRecordDataFieldNames.score: self.scores,
}

def dict(self, *args, **kwargs) -> "DictStrAny":
"""
Extends base component dict extending object properties
and user defined extended fields
"""
return {
**super().dict(*args, **kwargs),
**self.extended_fields(),
}


class BaseSearchResultsAggregations(BaseModel):

"""
API for result aggregations
Attributes:
-----------
predicted_as: Dict[str, int]
Occurrence info about more relevant predicted terms
annotated_as: Dict[str, int]
Occurrence info about more relevant annotated terms
annotated_by: Dict[str, int]
Occurrence info about more relevant annotation agent terms
predicted_by: Dict[str, int]
Occurrence info about more relevant prediction agent terms
status: Dict[str, int]
Occurrence info about task status
predicted: Dict[str, int]
Occurrence info about task prediction status
words: Dict[str, int]
The word cloud aggregations
metadata: Dict[str, Dict[str, Any]]
The metadata fields aggregations
"""

predicted_as: Dict[str, int] = Field(default_factory=dict)
annotated_as: Dict[str, int] = Field(default_factory=dict)
annotated_by: Dict[str, int] = Field(default_factory=dict)
predicted_by: Dict[str, int] = Field(default_factory=dict)
status: Dict[str, int] = Field(default_factory=dict)
predicted: Dict[str, int] = Field(default_factory=dict)
score: Dict[str, int] = Field(default_factory=dict)
words: Dict[str, int] = Field(default_factory=dict)
metadata: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


Record = TypeVar("Record", bound=BaseRecord)
Aggregations = TypeVar("Aggregations", bound=BaseSearchResultsAggregations)


class BaseSearchResults(GenericModel, Generic[Record, Aggregations]):
"""
API search results
Attributes:
-----------
total:
The total number of records
records:
The selected records to return
aggregations:
Requested aggregations
"""

total: int = 0
records: List[Record] = Field(default_factory=list)
aggregations: Aggregations = None


class QueryRange(BaseModel):
"""Score range filter"""

range_from: float = Field(default=0.0, alias="from")
range_to: float = Field(default=None, alias="to")

class Config:
allow_population_by_field_name = True


class ScoreRange(QueryRange):
pass


__ALL__ = [
QueryRange,
SortableField,
BaseSearchResults,
BaseSearchResultsAggregations,
Annotation,
TaskStatus,
TaskType,
EsRecordDataFieldNames,
BaseAnnotation,
PredictionStatus,
BulkResponse,
]
2 changes: 1 addition & 1 deletion src/rubrix/server/services/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from rubrix.server.security.model import User

Dataset = TypeVar("Dataset", bound=DatasetDB)
Dataset = TypeVar("Dataset", bound=BaseDatasetDB)


class SVCDatasetSettings(SettingsDB):
Expand Down

0 comments on commit ec6104d

Please sign in to comment.