From adcf1b14306c226806187de9506a5d7cc276b072 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 25 Mar 2022 13:19:34 +0100 Subject: [PATCH] fix(NER): create record annotation from tags (also in from_datasets) (#1283) * fix(ner): build record annotation from tags * fix(ner): parse tags in from_datasets method (cherry picked from commit 65da06fd6d1e4d1b2f4da28436068ef0e32e38fa) --- src/rubrix/client/datasets.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/rubrix/client/datasets.py b/src/rubrix/client/datasets.py index bed1913c32..c892dd7311 100644 --- a/src/rubrix/client/datasets.py +++ b/src/rubrix/client/datasets.py @@ -66,6 +66,14 @@ class DatasetBase: _RECORD_TYPE = None + @classmethod + def _record_init_args(cls) -> List[str]: + """ + Helper that returns the field list available for creation of inner records. + The ``_RECORD_TYPE.__fields__`` will be returned as default + """ + return [field for field in cls._RECORD_TYPE.__fields__] + def __init__(self, records: Optional[List[Record]] = None): if self._RECORD_TYPE is None: raise NotImplementedError( @@ -185,9 +193,7 @@ def from_datasets( ) not_supported_columns = [ - col - for col in dataset.column_names - if col not in cls._RECORD_TYPE.__fields__ + col for col in dataset.column_names if col not in cls._record_init_args() ] if not_supported_columns: _LOGGER.warning( @@ -251,11 +257,12 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset": The imported records in a Rubrix Dataset. 
""" not_supported_columns = [ - col for col in dataframe.columns if col not in cls._RECORD_TYPE.__fields__ + col for col in dataframe.columns if col not in cls._record_init_args() ] if not_supported_columns: _LOGGER.warning( - f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model and are ignored: {not_supported_columns}" + f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model " + f"and are ignored: {not_supported_columns}" ) dataframe = dataframe.drop(columns=not_supported_columns) @@ -638,6 +645,12 @@ class DatasetForTokenClassification(DatasetBase): _RECORD_TYPE = TokenClassificationRecord + @classmethod + def _record_init_args(cls) -> List[str]: + """Adds the `tags` argument to default record init arguments""" + parent_fields = super(DatasetForTokenClassification, cls)._record_init_args() + return parent_fields + ["tags"] # compute annotation from tags + def __init__(self, records: Optional[List[TokenClassificationRecord]] = None): # we implement this to have more specific type hints super().__init__(records=records) @@ -871,8 +884,8 @@ def _parse_tags_field( import datasets labels = dataset.features[field] - if isinstance(labels, list): - labels = labels[0] + if isinstance(labels, datasets.Sequence): + labels = labels.feature int2str = ( labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x )