# coding=utf-8
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, root_validator, validator
from rubrix.server.apis.v0.models.commons.model import (
BaseRecord,
BaseRecordInputs,
BaseSearchResults,
ScoreRange,
)
from rubrix.server.apis.v0.models.datasets import UpdateDatasetRequest
from rubrix.server.commons.models import PredictionStatus
from rubrix.server.daos.backend.search.model import SortableField
from rubrix.server.services.search.model import (
ServiceBaseRecordsQuery,
ServiceBaseSearchResultsAggregations,
)
from rubrix.server.services.tasks.token_classification.model import (
ServiceTokenClassificationAnnotation as _TokenClassificationAnnotation,
)
from rubrix.server.services.tasks.token_classification.model import (
ServiceTokenClassificationDataset,
)
from rubrix.utils import SpanUtils
class TokenClassificationAnnotation(_TokenClassificationAnnotation):
    """API-facing alias of the service-layer token classification annotation model."""
class TokenClassificationRecordInputs(BaseRecordInputs[TokenClassificationAnnotation]):
    """Input payload for a token classification record.

    Carries the raw ``text`` and its pre-computed ``tokens``; validation
    guarantees non-empty, non-whitespace text and at least one token.
    """

    text: str = Field()
    tokens: List[str] = Field(min_items=1)
    # TODO(@frascuchon): Delete this field and all related logic
    # Legacy alias kept so old clients sending "raw_text" still work.
    _raw_text: Optional[str] = Field(alias="raw_text")

    @root_validator(pre=True)
    def accept_old_fashion_text_field(cls, values):
        # Runs before per-field validation (pre=True): accept the deprecated
        # "raw_text" key by copying it into "text". "text" wins when both are
        # present and truthy.
        text, raw_text = values.get("text"), values.get("raw_text")
        text = text or raw_text
        # Reuse the field validator so both entry paths enforce the same rule.
        values["text"] = cls.check_text_content(text)
        return values

    @validator("text")
    def check_text_content(cls, text: str):
        # NOTE(review): assert statements are stripped under `python -O`; when
        # active, pydantic surfaces the AssertionError as a validation error.
        assert text and text.strip(), "No text or empty text provided"
        return text
class TokenClassificationRecord(
    TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation]
):
    """A stored token classification record: the validated inputs plus the
    common record fields contributed by ``BaseRecord``."""
class TokenClassificationBulkRequest(UpdateDatasetRequest):
    """Bulk-ingestion request body: dataset update settings together with the
    batch of token classification records to log."""

    records: List[TokenClassificationRecordInputs]
class TokenClassificationQuery(ServiceBaseRecordsQuery):
    """Search filters for token classification records.

    Extends the base records query with task-specific filters; empty lists /
    ``None`` mean "no filter" for the corresponding facet.
    """

    # Filter by predicted entity labels.
    predicted_as: List[str] = Field(default_factory=list)
    # Filter by annotated (gold) entity labels.
    annotated_as: List[str] = Field(default_factory=list)
    # Restrict results to a prediction-score range.
    score: Optional[ScoreRange] = Field(default=None)
    # Filter by prediction status (e.g. ok/ko); `nullable` is forwarded as
    # extra schema metadata by pydantic's Field.
    predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
class TokenClassificationSearchRequest(BaseModel):
    """Search endpoint request body: a query plus an optional sort spec."""

    # Defaults to an unfiltered query when the client omits it.
    query: TokenClassificationQuery = Field(default_factory=TokenClassificationQuery)
    # Sort criteria applied in order; empty list keeps the backend default.
    sort: List[SortableField] = Field(default_factory=list)
class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations):
    """Search aggregations extended with mention counts.

    Both mappings appear to be nested count tables keyed first by label and
    then by mention text (or vice versa) — TODO confirm against the backend
    aggregation builder.
    """

    # Mention counts computed from predictions.
    predicted_mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict)
    # Mention counts computed from annotations.
    mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict)
class TokenClassificationSearchResults(
    BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations]
):
    """Search response: matching records plus their aggregations, specialized
    for the token classification task."""
class TokenClassificationDataset(ServiceTokenClassificationDataset):
    """API-facing alias of the service-layer token classification dataset model."""