fix(NER): create record annotation from tags (also in DatasetForTokenClassification.from_datasets) (#1283)

* fix(ner): build record annotation from tags

* fix(ner): parse tags in from_datasets method
frascuchon committed Mar 21, 2022
1 parent cc34c32 · commit 65da06f
Showing 3 changed files with 45 additions and 16 deletions.
src/rubrix/client/datasets.py · 20 additions & 7 deletions
```diff
@@ -66,6 +66,14 @@ class DatasetBase:
 
     _RECORD_TYPE = None
 
+    @classmethod
+    def _record_init_args(cls) -> List[str]:
+        """
+        Helper that returns the list of fields available for creating inner records.
+        ``_RECORD_TYPE.__fields__`` is returned by default.
+        """
+        return [field for field in cls._RECORD_TYPE.__fields__]
+
     def __init__(self, records: Optional[List[Record]] = None):
         if self._RECORD_TYPE is None:
             raise NotImplementedError(
```
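The new `_record_init_args` hook decouples "columns a dataset may carry" from "fields the record model declares", so subclasses can whitelist extra init-only arguments. A minimal, self-contained sketch of the pattern; `FakeRecord` and `accepted_columns` are hypothetical stand-ins, not Rubrix API:

```python
from typing import List


class FakeRecord:
    """Stand-in for a pydantic model exposing ``__fields__``."""

    __fields__ = {"text": None, "annotation": None}


class DatasetBase:
    _RECORD_TYPE = FakeRecord

    @classmethod
    def _record_init_args(cls) -> List[str]:
        # Default: accept exactly the record model's fields
        return [field for field in cls._RECORD_TYPE.__fields__]

    @classmethod
    def accepted_columns(cls, columns: List[str]) -> List[str]:
        # Mirrors the filtering done in from_datasets/from_pandas below
        return [col for col in columns if col in cls._record_init_args()]


class DatasetForTokenClassification(DatasetBase):
    @classmethod
    def _record_init_args(cls) -> List[str]:
        # "tags" is not a record field, but is accepted at init time
        return super()._record_init_args() + ["tags"]


print(DatasetBase.accepted_columns(["text", "tags"]))                    # ['text']
print(DatasetForTokenClassification.accepted_columns(["text", "tags"]))  # ['text', 'tags']
```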
```diff
@@ -185,9 +193,7 @@ def from_datasets(
         )
 
         not_supported_columns = [
-            col
-            for col in dataset.column_names
-            if col not in cls._RECORD_TYPE.__fields__
+            col for col in dataset.column_names if col not in cls._record_init_args()
         ]
         if not_supported_columns:
             _LOGGER.warning(
```
```diff
@@ -251,11 +257,12 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset":
             The imported records in a Rubrix Dataset.
         """
         not_supported_columns = [
-            col for col in dataframe.columns if col not in cls._RECORD_TYPE.__fields__
+            col for col in dataframe.columns if col not in cls._record_init_args()
         ]
         if not_supported_columns:
             _LOGGER.warning(
-                f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model and are ignored: {not_supported_columns}"
+                f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model "
+                f"and are ignored: {not_supported_columns}"
             )
         dataframe = dataframe.drop(columns=not_supported_columns)
 
```
```diff
@@ -638,6 +645,12 @@ class DatasetForTokenClassification(DatasetBase):
 
     _RECORD_TYPE = TokenClassificationRecord
 
+    @classmethod
+    def _record_init_args(cls) -> List[str]:
+        """Adds the ``tags`` argument to the default record init arguments"""
+        parent_fields = super(DatasetForTokenClassification, cls)._record_init_args()
+        return parent_fields + ["tags"]  # compute annotation from tags
+
     def __init__(self, records: Optional[List[TokenClassificationRecord]] = None):
         # we implement this to have more specific type hints
         super().__init__(records=records)
```
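With `"tags"` whitelisted, `from_datasets` can forward a tag column straight into each record, and the record (see the models.py change below) turns those tags into an annotation at init time. A sketch mirroring the new test at the bottom of this diff:

```python
import rubrix as rb

# tags is an init-only argument: it is consumed to build the annotation
record = rb.TokenClassificationRecord(
    tokens=["Matias", "Aguado"],
    tags=["B-PER", "I-PER"],
)
print(record.annotation)  # [("PER", 0, 12)], as asserted in the new test
```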
```diff
@@ -871,8 +884,8 @@ def _parse_tags_field(
         import datasets
 
         labels = dataset.features[field]
-        if isinstance(labels, list):
-            labels = labels[0]
+        if isinstance(labels, datasets.Sequence):
+            labels = labels.feature
         int2str = (
             labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x
         )
```
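The old branch rarely fired for Hub datasets: 🤗 `datasets` types a token-level tag column, such as conll2003's `ner_tags`, as `Sequence(ClassLabel(...))`, which is a feature object rather than a plain Python list. A quick demonstration (feature names are illustrative):

```python
import datasets

features = datasets.Features(
    {
        "tokens": datasets.Sequence(datasets.Value("string")),
        "ner_tags": datasets.Sequence(
            datasets.ClassLabel(names=["O", "B-PER", "I-PER"])
        ),
    }
)

labels = features["ner_tags"]
print(isinstance(labels, list))               # False: the old check missed this
print(isinstance(labels, datasets.Sequence))  # True: the new check unwraps it
print(labels.feature.int2str(1))              # 'B-PER'
```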
src/rubrix/client/models.py · 10 additions & 9 deletions
```diff
@@ -331,15 +331,16 @@ def __tags2entities__(self, tags: List[str]) -> List[Tuple[str, int, int]]:
         entities = []
         while idx < len(tags):
             tag = tags[idx]
-            prefix, entity = tag.split("-")
-            if tag == "B":
-                char_start, char_end = self.token_span(token_idx=idx)
-                entities.append(
-                    {"entity": entity, "start": char_start, "end": char_end}
-                )
-            elif prefix in ["I", "L"]:
-                _, char_end = self.token_span(token_idx=idx)
-                entities[-1]["end"] = char_end
+            if tag != "O":
+                prefix, entity = tag.split("-")
+                if prefix == "B":
+                    char_start, char_end = self.token_span(token_idx=idx)
+                    entities.append(
+                        {"entity": entity, "start": char_start, "end": char_end}
+                    )
+                elif prefix in ["I", "L"]:
+                    _, char_end = self.token_span(token_idx=idx)
+                    entities[-1]["end"] = char_end
             idx += 1
         return [(value["entity"], value["start"], value["end"]) for value in entities]
```
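To trace what the fixed loop computes, here is a self-contained sketch. `token_span` is a hypothetical stand-in for `TokenClassificationRecord.token_span`, assuming whitespace-joined tokens and inclusive end offsets (consistent with the new tests below):

```python
from typing import List, Tuple


def tags2entities(tokens: List[str], tags: List[str]) -> List[Tuple[str, int, int]]:
    def token_span(token_idx: int) -> Tuple[int, int]:
        # Assumed: tokens joined by single spaces, inclusive char offsets
        start = sum(len(tok) + 1 for tok in tokens[:token_idx])
        return start, start + len(tokens[token_idx]) - 1

    entities, idx = [], 0
    while idx < len(tags):
        tag = tags[idx]
        # Old code split "O" (ValueError: no "-") and compared the whole
        # tag to "B", which never matched prefixed tags like "B-PER".
        if tag != "O":
            prefix, entity = tag.split("-")
            if prefix == "B":
                char_start, char_end = token_span(idx)
                entities.append({"entity": entity, "start": char_start, "end": char_end})
            elif prefix in ["I", "L"]:
                _, char_end = token_span(idx)
                entities[-1]["end"] = char_end
        idx += 1
    return [(e["entity"], e["start"], e["end"]) for e in entities]


print(tags2entities(["Una", "casa"], ["O", "B-OBJ"]))           # [('OBJ', 4, 7)]
print(tags2entities(["Matias", "Aguado"], ["B-PER", "I-PER"]))  # [('PER', 0, 12)]
```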

tests/client/test_models.py · 15 additions & 0 deletions
```diff
@@ -114,6 +114,21 @@ def test_token_classification_record(annotation, status, expected_status, expect
     assert record.spans2iob(record.annotation) == expected_iob
 
 
+@pytest.mark.parametrize(
+    ("tokens", "tags", "annotation"),
+    [
+        (["Una", "casa"], ["O", "B-OBJ"], [("OBJ", 4, 7)]),
+        (["Matias", "Aguado"], ["B-PER", "I-PER"], [("PER", 0, 12)]),
+        (["Todo", "Todo", "Todo"], ["B-T", "I-T", "L-T"], [("T", 0, 13)]),
+        (["Una", "casa"], ["O", "U-OBJ"], []),
+    ],
+)
+def test_token_classification_with_tokens_and_tags(tokens, tags, annotation):
+    record = TokenClassificationRecord(tokens=tokens, tags=tags)
+    assert record.annotation is not None
+    assert record.annotation == annotation
+
+
 def test_token_classification_validations():
     with pytest.raises(
         AssertionError,
```
