feat(datasets): simplify load flow from hf datasets with no rb format (#1234)

* fix: optional search_keywords

* feat(datasets): simplify load flow from hf datasets with no rb format

* feat(token-class): allow create record with tags list

* feat: mapping shortcut

* chore: adjust datasets mappings

* chore: better messages

* feat: parse shortcut for text2text

* test: skip dataset

* refactor: build text from tokens if possible for NER records

* test: fix tests
frascuchon committed Mar 14, 2022
1 parent e217d31 commit a64476b
Showing 4 changed files with 435 additions and 20 deletions.
229 changes: 216 additions & 13 deletions src/rubrix/client/datasets.py
@@ -145,17 +145,44 @@ def _to_datasets_dict(self) -> Dict:
raise NotImplementedError

@classmethod
def from_datasets(cls, dataset: "datasets.Dataset") -> "Dataset":
def from_datasets(
cls,
dataset: "datasets.Dataset",
id: Optional[str] = None,
text: Optional[str] = None,
annotation: Optional[str] = None,
metadata: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> "Dataset":
"""Imports records from a `datasets.Dataset`.
Columns that are not supported are ignored.
Args:
dataset: A datasets Dataset from which to import the records.
id: The field name used as record id. Default: `None`
text: The field name used as record text. Default: `None`
annotation: The field name used as record annotation. Default: `None`
metadata: The field name used as record metadata. Default: `None`
Returns:
The imported records in a Rubrix Dataset.
"""
import datasets

assert not isinstance(dataset, datasets.DatasetDict), (
"ERROR: `datasets.DatasetDict` are not supported. "
"Please, select the dataset split before"
)

dataset = cls._prepare_hf_dataset(
dataset,
id=id,
text=text,
annotation=annotation,
metadata=metadata,
**kwargs,
)

not_supported_columns = [
col
@@ -164,12 +191,31 @@ def from_datasets(cls, dataset: "datasets.Dataset") -> "Dataset":
]
if not_supported_columns:
_LOGGER.warning(
f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model and are ignored: {not_supported_columns}"
f"Following columns are not supported by the {cls._RECORD_TYPE.__name__}"
f" model and are ignored: {not_supported_columns}"
)
dataset = dataset.remove_columns(not_supported_columns)

return cls._from_datasets(dataset)

@classmethod
def _prepare_hf_dataset(
cls,
dataset: "datasets.Dataset",
id: Optional[str] = None,
text: Optional[str] = None,
annotation: Optional[str] = None,
metadata: Optional[Union[str, List[str]]] = None,
) -> "datasets.Dataset":
for field, parser in [
(id, cls._parse_id_field),
(text, cls._parse_text_field),
(metadata, cls._parse_metadata_field),
(annotation, cls._parse_annotation_field),
]:
if field:
dataset = parser(dataset, field)
return dataset
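
For illustration, a minimal sketch of what this base mapping flow amounts to (the column names "doc_id" and "body" are hypothetical): id, text and annotation are plain column renames applied before the records are built.

import datasets

ds = datasets.Dataset.from_dict({"doc_id": ["1"], "body": ["some text"]})
ds = ds.rename_column("doc_id", "id")  # what _parse_id_field does
ds = ds.rename_column("body", "text")  # what _parse_text_field does
# ds.column_names -> ["id", "text"]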

@classmethod
def _from_datasets(cls, dataset: "datasets.Dataset") -> "Dataset":
"""Helper method to create a Rubrix Dataset from a datasets Dataset.
@@ -241,6 +287,37 @@ def prepare_for_training(self, **kwargs) -> "datasets.Dataset":
"""
raise NotImplementedError

@classmethod
def _parse_id_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
return dataset.rename_column(field, "id")

@classmethod
def _parse_text_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
return dataset.rename_column(field, "text")

@classmethod
def _parse_metadata_field(
cls, dataset: "datasets.Dataset", fields: Union[str, List[str]]
) -> "datasets.Dataset":

if isinstance(fields, str):
fields = [fields]

def parse_metadata_from_dataset(example):
return {"metadata": {k: example[k] for k in fields}}

return dataset.map(parse_metadata_from_dataset).remove_columns(fields)
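
A minimal sketch of the metadata aggregation above (the "split" and "lang" columns are hypothetical): the listed fields are packed into a single "metadata" dict column and the original columns are dropped.

import datasets

ds = datasets.Dataset.from_dict({"text": ["hi"], "split": ["train"], "lang": ["en"]})
ds = ds.map(lambda ex: {"metadata": {k: ex[k] for k in ["split", "lang"]}})
ds = ds.remove_columns(["split", "lang"])
# ds[0] -> {"text": "hi", "metadata": {"split": "train", "lang": "en"}}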

@classmethod
def _parse_annotation_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
return dataset.rename_column(field, "annotation")


def _prepend_docstring(record_type: Type[Record]):
docstring = f"""This Dataset contains {record_type.__name__} records.
@@ -301,13 +378,21 @@ def from_datasets(
# we implement this to have more specific type hints
cls,
dataset: "datasets.Dataset",
id: Optional[str] = None,
inputs: Optional[Union[str, List[str]]] = None,
annotation: Optional[str] = None,
metadata: Optional[Union[str, List[str]]] = None,
) -> "DatasetForTextClassification":
"""Imports records from a `datasets.Dataset`.
Columns that are not supported are ignored.
Args:
dataset: A datasets Dataset from which to import the records.
id: The field name used as record id. Default: `None`
inputs: A list of field names used for record inputs. Default: `None`
annotation: The field name used as record annotation. Default: `None`
metadata: The field name used as record metadata. Default: `None`
Returns:
The imported records in a Rubrix Dataset.
@@ -322,7 +407,10 @@ def from_datasets(
... })
>>> DatasetForTextClassification.from_datasets(ds)
"""
return super().from_datasets(dataset)

return super().from_datasets(
dataset, id=id, annotation=annotation, metadata=metadata, inputs=inputs
)
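
A hypothetical end-to-end use of the new arguments ("title", "body" and "label" are made-up column names; the label column holds plain strings here):

import datasets
from rubrix.client.datasets import DatasetForTextClassification

ds = datasets.Dataset.from_dict(
    {"title": ["On NER"], "body": ["Some abstract"], "label": ["science"]}
)
rb_ds = DatasetForTextClassification.from_datasets(
    ds, inputs=["title", "body"], annotation="label"
)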

@classmethod
def from_pandas(
@@ -364,10 +452,41 @@ def _to_datasets_dict(self) -> Dict:

return ds_dict

@classmethod
def _parse_annotation_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
import datasets

labels = dataset.features[field]
if isinstance(labels, datasets.Sequence):
labels = labels.feature
int2str = (
labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x
)

def parse_annotation(example):
return {"annotation": int2str(example["annotation"])}

return dataset.rename_column(field, "annotation").map(parse_annotation)
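
The int2str indirection above simply decodes integer class ids back to label names when the column is a datasets.ClassLabel; roughly:

import datasets

label = datasets.ClassLabel(names=["neg", "pos"])
assert label.int2str(1) == "pos"
assert label.int2str(0) == "neg"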

@classmethod
def _prepare_hf_dataset(
cls,
dataset: "datasets.Dataset",
inputs: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> "datasets.Dataset":
dataset = super()._prepare_hf_dataset(dataset, **kwargs)
if inputs:
dataset = cls._parse_inputs_field(dataset, fields=inputs)
return dataset

@classmethod
def _from_datasets(
cls, dataset: "datasets.Dataset"
) -> "DatasetForTextClassification":

records = []
for row in dataset:
if row.get("inputs") and isinstance(row["inputs"], dict):
@@ -399,10 +518,25 @@ def _from_datasets(
else None
)

records.append(TextClassificationRecord(**row))

records.append(TextClassificationRecord.parse_obj(row))
return cls(records)

@classmethod
def _parse_inputs_field(
cls, dataset: "datasets.Dataset", fields: Optional[Union[str, List[str]]]
) -> "datasets.Dataset":
if isinstance(fields, str):
fields = [fields]

def parse_inputs_from_dataset(example):
return {
"inputs": example[fields[0]]
if len(fields) == 1
else {k: example[k] for k in fields}
}

return dataset.map(parse_inputs_from_dataset)
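
A small sketch of the two shapes produced above (hypothetical columns): a single field keeps inputs as a plain string, several fields turn it into a dict.

import datasets

ds = datasets.Dataset.from_dict({"title": ["A"], "body": ["B"]})
one = ds.map(lambda ex: {"inputs": ex["title"]})
many = ds.map(lambda ex: {"inputs": {k: ex[k] for k in ["title", "body"]}})
# one[0]["inputs"] -> "A"; many[0]["inputs"] -> {"title": "A", "body": "B"}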

@classmethod
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification":
return cls(
@@ -504,14 +638,23 @@ def __init__(self, records: Optional[List[TokenClassificationRecord]] = None):

@classmethod
def from_datasets(
cls, dataset: "datasets.Dataset"
cls,
dataset: "datasets.Dataset",
text: Optional[str] = None,
tokens: Optional[str] = None,
tags: Optional[str] = None,
metadata: Optional[Union[str, List[str]]] = None,
) -> "DatasetForTokenClassification":
"""Imports records from a `datasets.Dataset`.
Columns that are not supported are ignored.
Args:
dataset: A datasets Dataset from which to import the records.
text: The field name used as record text. Default: `None`
tokens: The field name used as record tokens. Default: `None`
tags: The field name used as record tags. Default: `None`
metadata: The field name used as record metadata. Default: `None`
Returns:
The imported records in a Rubrix Dataset.
@@ -528,7 +671,9 @@ def from_datasets(
>>> DatasetForTokenClassification.from_datasets(ds)
"""
# we implement this to have more specific type hints
return super().from_datasets(dataset)
return super().from_datasets(
dataset, text=text, tokens=tokens, tags=tags, metadata=metadata
)
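
A hypothetical use with a CoNLL-style dataset ("tokens" and "ner_tags" are the usual column names of conll2003 on the Hugging Face Hub; any token/tag columns should map the same way):

import datasets
from rubrix.client.datasets import DatasetForTokenClassification

conll = datasets.load_dataset("conll2003", split="train")
rb_ds = DatasetForTokenClassification.from_datasets(
    conll, tokens="tokens", tags="ner_tags"
)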

@classmethod
def from_pandas(
@@ -669,9 +814,25 @@ def __entities_to_tuple__(
for ent in entities
]

@classmethod
def _prepare_hf_dataset(
cls,
dataset: "datasets.Dataset",
tokens: Optional[str] = None,
tags: Optional[str] = None,
**kwargs,
) -> "datasets.Dataset":
dataset = super()._prepare_hf_dataset(dataset, **kwargs)
if tokens:
dataset = cls._parse_tokens_field(dataset, field=tokens)
if tags:
dataset = cls._parse_tags_field(dataset, field=tags)
return dataset

@classmethod
def _from_datasets(
cls, dataset: "datasets.Dataset"
cls,
dataset: "datasets.Dataset",
) -> "DatasetForTokenClassification":

records = []
@@ -680,10 +841,40 @@
row["prediction"] = cls.__entities_to_tuple__(row["prediction"])
if row.get("annotation"):
row["annotation"] = cls.__entities_to_tuple__(row["annotation"])
records.append(TokenClassificationRecord.parse_obj(row))
return cls(records)

records.append(TokenClassificationRecord(**row))
@classmethod
def _parse_tokens_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
def parse_tokens_from_example(example):
tokens: List[str] = example[field]
data = {"tokens": tokens}

return cls(records)
if "text" not in example:
data["text"] = " ".join(tokens)
return data

return dataset.map(parse_tokens_from_example)
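
The fallback above implements the "build text from tokens if possible" note from the commit message: when the dataset has no "text" column, a crude text is reconstructed by joining the tokens with spaces.

tokens = ["Rubrix", "supports", "NER"]
text = " ".join(tokens)  # -> "Rubrix supports NER"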

@classmethod
def _parse_tags_field(
cls, dataset: "datasets.Dataset", field: str
) -> "datasets.Dataset":
import datasets

labels = dataset.features[field]
if isinstance(labels, list):
labels = labels[0]
elif isinstance(labels, datasets.Sequence):
labels = labels.feature
int2str = (
labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x
)

def parse_tags_from_example(example):
return {"tags": [int2str(t) for t in example[field] or []]}

return dataset.map(parse_tags_from_example)
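
For example, with a ClassLabel tag feature, integer tag ids decode to their names; plain string tags pass through the identity fallback unchanged:

import datasets

tag_feature = datasets.ClassLabel(names=["O", "B-PER", "I-PER", "B-ORG"])
assert [tag_feature.int2str(t) for t in [3, 0]] == ["B-ORG", "O"]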

@classmethod
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification":
@@ -726,13 +917,22 @@ def __init__(self, records: Optional[List[Text2TextRecord]] = None):
super().__init__(records=records)

@classmethod
def from_datasets(cls, dataset: "datasets.Dataset") -> "DatasetForText2Text":
def from_datasets(
cls,
dataset: "datasets.Dataset",
text: Optional[str] = None,
annotation: Optional[str] = None,
metadata: Optional[Union[str, List[str]]] = None,
) -> "DatasetForText2Text":
"""Imports records from a `datasets.Dataset`.
Columns that are not supported are ignored.
Args:
dataset: A datasets Dataset from which to import the records.
text: The field name used as record text. Default: `None`
annotation: The field name used as record annotation. Default: `None`
metadata: The field name used as record metadata. Default: `None`
Returns:
The imported records in a Rubrix Dataset.
@@ -750,8 +950,11 @@ def from_datasets(cls, dataset: "datasets.Dataset") -> "DatasetForText2Text":
... })
>>> DatasetForText2Text.from_datasets(ds)
"""

# we implement this to have more specific type hints
return super().from_datasets(dataset)
return super().from_datasets(
dataset, text=text, annotation=annotation, metadata=metadata
)
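
A hypothetical Text2Text use ("source" and "target" are made-up column names):

import datasets
from rubrix.client.datasets import DatasetForText2Text

ds = datasets.Dataset.from_dict({"source": ["hello"], "target": ["hola"]})
rb_ds = DatasetForText2Text.from_datasets(ds, text="source", annotation="target")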

@classmethod
def from_pandas(
Expand Down

0 comments on commit a64476b

Please sign in to comment.