diff --git a/src/rubrix/client/datasets.py b/src/rubrix/client/datasets.py index bed1913c32..c892dd7311 100644 --- a/src/rubrix/client/datasets.py +++ b/src/rubrix/client/datasets.py @@ -66,6 +66,14 @@ class DatasetBase: _RECORD_TYPE = None + @classmethod + def _record_init_args(cls) -> List[str]: + """ + Helper that returns the field list available for creation of inner records. + The ``_RECORD_TYPE.__fields__`` will be returned by default + """ + return [field for field in cls._RECORD_TYPE.__fields__] + def __init__(self, records: Optional[List[Record]] = None): if self._RECORD_TYPE is None: raise NotImplementedError( @@ -185,9 +193,7 @@ def from_datasets( ) not_supported_columns = [ - col - for col in dataset.column_names - if col not in cls._RECORD_TYPE.__fields__ + col for col in dataset.column_names if col not in cls._record_init_args() ] if not_supported_columns: _LOGGER.warning( @@ -251,11 +257,12 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset": The imported records in a Rubrix Dataset. 
""" not_supported_columns = [ - col for col in dataframe.columns if col not in cls._RECORD_TYPE.__fields__ + col for col in dataframe.columns if col not in cls._record_init_args() ] if not_supported_columns: _LOGGER.warning( - f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model and are ignored: {not_supported_columns}" + f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model " + f"and are ignored: {not_supported_columns}" ) dataframe = dataframe.drop(columns=not_supported_columns) @@ -638,6 +645,12 @@ class DatasetForTokenClassification(DatasetBase): _RECORD_TYPE = TokenClassificationRecord + @classmethod + def _record_init_args(cls) -> List[str]: + """Adds the `tags` argument to default record init arguments""" + parent_fields = super(DatasetForTokenClassification, cls)._record_init_args() + return parent_fields + ["tags"] # compute annotation from tags + def __init__(self, records: Optional[List[TokenClassificationRecord]] = None): # we implement this to have more specific type hints super().__init__(records=records) @@ -871,8 +884,8 @@ def _parse_tags_field( import datasets labels = dataset.features[field] - if isinstance(labels, list): - labels = labels[0] + if isinstance(labels, datasets.Sequence): + labels = labels.feature int2str = ( labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x ) diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index 77e8a3c7e6..0e3e749c3c 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -331,15 +331,16 @@ def __tags2entities__(self, tags: List[str]) -> List[Tuple[str, int, int]]: entities = [] while idx < len(tags): tag = tags[idx] - prefix, entity = tag.split("-") - if tag == "B": - char_start, char_end = self.token_span(token_idx=idx) - entities.append( - {"entity": entity, "start": char_start, "end": char_end} - ) - elif prefix in ["I", "L"]: - _, char_end = self.token_span(token_idx=idx) - 
entities[-1]["end"] = char_end + if tag != "O": + prefix, entity = tag.split("-") + if prefix == "B": + char_start, char_end = self.token_span(token_idx=idx) + entities.append( + {"entity": entity, "start": char_start, "end": char_end} + ) + elif prefix in ["I", "L"]: + _, char_end = self.token_span(token_idx=idx) + entities[-1]["end"] = char_end idx += 1 return [(value["entity"], value["start"], value["end"]) for value in entities] diff --git a/tests/client/test_models.py b/tests/client/test_models.py index 3f09fee969..2363742c6a 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -114,6 +114,21 @@ def test_token_classification_record(annotation, status, expected_status, expect assert record.spans2iob(record.annotation) == expected_iob +@pytest.mark.parametrize( + ("tokens", "tags", "annotation"), + [ + (["Una", "casa"], ["O", "B-OBJ"], [("OBJ", 4, 7)]), + (["Matias", "Aguado"], ["B-PER", "I-PER"], [("PER", 0, 12)]), + (["Todo", "Todo", "Todo"], ["B-T", "I-T", "L-T"], [("T", 0, 13)]), + (["Una", "casa"], ["O", "U-OBJ"], []), + ], +) +def test_token_classification_with_tokens_and_tags(tokens, tags, annotation): + record = TokenClassificationRecord(tokens=tokens, tags=tags) + assert record.annotation is not None + assert record.annotation == annotation + + def test_token_classification_validations(): with pytest.raises( AssertionError,