diff --git a/.gitignore b/.gitignore
index 6917232047e..2c0ee5edca6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,6 +45,7 @@ __pycache__
 .coverage
 .pytest_cache/
 .benchmarks
+htmlcov/
 
 # documentation build artifacts
diff --git a/Makefile b/Makefile
index c6c27887cda..fe28b8d4463 100644
--- a/Makefile
+++ b/Makefile
@@ -66,6 +66,13 @@ test-with-cov :
 		--cov=$(SRC) \
 		--cov-report=xml
 
+.PHONY : test-with-cov-html
+test-with-cov-html :
+	pytest --color=yes -rf --durations=40 \
+		--cov-config=.coveragerc \
+		--cov=$(SRC) \
+		--cov-report=html
+
 .PHONY : gpu-test
 gpu-test : check-for-cuda
 	pytest --color=yes -v -rf -m gpu
diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py
new file mode 100644
index 00000000000..7c98995af50
--- /dev/null
+++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py
@@ -0,0 +1,337 @@
+from allennlp.data import DatasetReader, Token, Field, Tokenizer
+from allennlp.data.fields import TextField, LabelField, ListField, TensorField
+from allennlp.data.instance import Instance
+from datasets import load_dataset, DatasetDict, list_datasets
+from datasets.features import (
+    ClassLabel,
+    Sequence,
+    Translation,
+    TranslationVariableLanguages,
+    Value,
+    FeatureType,
+)
+
+import torch
+from typing import Iterable, Optional, Dict, List, Union
+
+
+@DatasetReader.register("huggingface-datasets")
+class HuggingfaceDatasetReader(DatasetReader):
+    """
+    Reads instances from any dataset available through the huggingface `datasets` package.
+
+    This reader implementation wraps the huggingface `datasets` package.
+
+    Registered as a `DatasetReader` with name `huggingface-datasets`.
+
+    # Parameters
+    dataset_name : `str`
+        Name of the huggingface dataset this reader will load.
+    config_name : `str`, optional (default=`None`)
+        Configuration of the dataset (mandatory for some datasets).
+    tokenizer : `Tokenizer`, optional (default=`None`)
+        If specified, it is used to tokenize string and text fields from the dataset.
+        This is useful since text in allennlp is dealt with as a series of tokens.
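+
+    # Examples
+
+    An illustrative usage sketch (the dataset, config, and split names simply mirror
+    what this reader's tests use):
+
+    ```python
+    reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola")
+    for instance in reader.read("train"):
+        print(instance)
+    ```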
+    """
+
+    def __init__(
+        self,
+        dataset_name: Optional[str] = None,
+        config_name: Optional[str] = None,
+        tokenizer: Optional[Tokenizer] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            manual_distributed_sharding=True,
+            manual_multiprocess_sharding=True,
+            **kwargs,
+        )
+
+        # It would be cleaner to create a separate reader object for each different dataset
+        if dataset_name not in list_datasets():
+            raise ValueError(f"Dataset {dataset_name} not available in huggingface datasets")
+        self.dataset: DatasetDict = DatasetDict()
+        self.dataset_name = dataset_name
+        self.config_name = config_name
+        self.tokenizer = tokenizer
+
+        self.features = None
+
+    def load_dataset_split(self, split: str):
+        if self.config_name is not None:
+            self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split)
+        else:
+            self.dataset[split] = load_dataset(self.dataset_name, split=split)
+
+    def _read(self, file_path: str) -> Iterable[Instance]:
+        """
+        Reads the requested split of the dataset and converts each entry into an
+        AllenNLP-friendly `Instance`.
+        """
+        if file_path is None:
+            raise ValueError("parameter split cannot be None")
+
+        # If the split has not been loaded yet, load it now
+        if file_path not in self.dataset:
+            self.load_dataset_split(file_path)
+            if self.features is None:
+                self.features = self.dataset[file_path].features
+
+        # TODO see if use of Dataset.select() is better
+        dataset_split = self.dataset[file_path]
+        for index in self.shard_iterable(range(len(dataset_split))):
+            yield self.text_to_instance(file_path, dataset_split[index])
+
+    @staticmethod
+    def raise_feature_not_supported_value_error(feature_name, feature_type):
+        raise ValueError(
+            f"Datasets feature {feature_name} type {feature_type} is not supported yet."
+        )
+
+    def text_to_instance(self, split: str, entry) -> Instance:  # type: ignore
+        """
+        Converts a dataset entry into an AllenNLP-friendly `Instance`.
+
+        Currently this is how `datasets.features` types are mapped to AllenNLP fields:
+
+        dataset.feature type             allennlp.data.fields
+        `ClassLabel`                     `LabelField` in the feature-name namespace
+        `Value.string`                   `TextField` with the value as a `Token`
+        `Value.*`                        `LabelField` with the value as the label in the feature-name namespace
+        `Translation`                    `ListField` of 2 `ListField`s (`ClassLabel` and `TextField`)
+        `TranslationVariableLanguages`   `ListField` of 2 `ListField`s (`ClassLabel` and `TextField`)
+        `Sequence`                       `ListField` of the sub-type
+        """
+
+        # Features indicate the different pieces of information available in each entry of the dataset,
+        # and feature types describe what kind of information each piece holds.
+        # e.g. in a sentiment dataset an entry could have one feature (of type text/string) holding the text
+        # and another (of type int32/ClassLabel) holding the sentiment
+
+        features: Dict[str, FeatureType] = self.dataset[split].features
+        fields: Dict[str, Field] = dict()
+
+        # TODO we need to support all different datasets features described
+        # in https://huggingface.co/docs/datasets/features.html
+        for feature_name in features:
+            feature_type = features[feature_name]
+
+            fields_to_be_added = _map_Feature(
+                feature_name, entry[feature_name], feature_type, self.tokenizer
+            )
+            for field_key in fields_to_be_added:
+                fields[field_key] = fields_to_be_added[field_key]
+
+        return Instance(fields)
+
+
+# Feature Mappers - These functions map a FeatureType into Fields
+def _map_Feature(
+    feature_name: str, value, feature_type, tokenizer: Optional[Tokenizer]
+) -> Dict[str, Field]:
+    fields_to_be_added: Dict[str, Field] = dict()
+    if isinstance(feature_type, ClassLabel):
+        fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value)
+    # datasets Value can be of different types
+    elif isinstance(feature_type, Value):
+        fields_to_be_added[feature_name] = _map_Value(feature_name, value, feature_type, tokenizer)
+
+    elif isinstance(feature_type, Sequence):
+        if type(feature_type.feature) == dict:
+            fields_to_be_added = _map_Dict(feature_type.feature, value, tokenizer, feature_name)
+        else:
+            fields_to_be_added[feature_name] = _map_Sequence(
+                feature_name, value, feature_type.feature, tokenizer
+            )
+
+    elif isinstance(feature_type, Translation):
+        fields_to_be_added = _map_Translation(feature_name, value, feature_type, tokenizer)
+
+    elif isinstance(feature_type, TranslationVariableLanguages):
+        fields_to_be_added = _map_TranslationVariableLanguages(
+            feature_name, value, feature_type, tokenizer
+        )
+
+    elif isinstance(feature_type, dict):
+        fields_to_be_added = _map_Dict(feature_type, value, tokenizer)
+    else:
+        raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.")
+    return fields_to_be_added
+
+
+def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field:
+    field: Field = _map_to_Label(feature_name, value, skip_indexing=True)
+    return field
+
+
+def _map_Value(
+    feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer]
+) -> Union[TextField, LabelField, TensorField]:
+    field: Union[TextField, LabelField, TensorField]
+    if feature_type.dtype == "string":
+        # datasets.Value[string] maps to TextField
+        # If a tokenizer is provided, we use it to split the text into tokens,
+        # else the whole text is kept as a single token
+        field = _map_String(value, tokenizer)
+
+    elif feature_type.dtype == "float32" or feature_type.dtype == "float64":
+        field = _map_Float(value)
+
+    else:
+        field = LabelField(value, label_namespace=feature_name, skip_indexing=True)
+    return field
+
+
+def _map_Sequence(
+    feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer]
+) -> ListField:
+    field_list: List[Field] = list()
+    field: ListField = list()
+    item_field: Field
+    # In HF, Sequence and list are considered interchangeable, but there are some distinctions.
+    if isinstance(item_feature_type, Value):
+        for item in value:
+            # If a tokenizer is provided, we use it to split the text into tokens,
+            # else the whole text is kept as a single token
+            item_field = _map_Value(feature_name, item, item_feature_type, tokenizer)
+            field_list.append(item_field)
+        if len(field_list) > 0:
+            field = ListField(field_list)
+
+    #
datasets Sequence of strings to ListField of LabelField + elif isinstance(item_feature_type, str): + for item in value: + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + item_field = _map_Value(feature_name, item, item_feature_type, tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + field = ListField(field_list) + + elif isinstance(item_feature_type, ClassLabel): + for item in value: + item_field = _map_to_Label(feature_name, item, skip_indexing=True) + field_list.append(item_field) + + if len(field_list) > 0: + field = ListField(field_list) + + elif isinstance(item_feature_type, Sequence): + for item in value: + item_field = _map_Sequence(value.feature, item, item_feature_type.feature, tokenizer) + field_list.append(item_field) + + if len(field_list) > 0: + field = ListField(field_list) + + # # WIP for dropx` + # elif isinstance(item_feature_type, dict): + # for item in value: + # item_field = _map_Dict(item_feature_type, value[item], tokenizer) + # field_list.append(item_field) + # if len(field_list) > 0: + # field = ListField(field_list) + + else: + HuggingfaceDatasetReader.raise_feature_not_supported_value_error( + feature_name, item_feature_type + ) + + return field + + +def _map_Translation( + feature_name: str, value: Translation, feature_type, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() + if feature_type.dtype == "dict": + input_dict = value + langs = list(input_dict.keys()) + texts = list() + for lang in langs: + if tokenizer is not None: + tokens = tokenizer.tokenize(input_dict[lang]) + + else: + tokens = [Token(input_dict[lang])] + texts.append(TextField(tokens)) + + fields[feature_name + "-languages"] = ListField( + [ + _map_to_Label(feature_name + "-languages", lang, skip_indexing=False) + for lang in langs + ] + ) + fields[feature_name + "-texts"] = ListField(texts) + + else: + raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") + + return fields + + +def _map_TranslationVariableLanguages( + feature_name: str, + value: TranslationVariableLanguages, + feature_type, + tokenizer: Optional[Tokenizer], +) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() + if feature_type.dtype == "dict": + input_dict = value + fields[feature_name + "-language"] = ListField( + [ + _map_to_Label(feature_name + "-languages", lang, skip_indexing=False) + for lang in input_dict["language"] + ] + ) + + if tokenizer is not None: + fields[feature_name + "-translation"] = ListField( + [TextField(tokenizer.tokenize(text)) for text in input_dict["translation"]] + ) + else: + fields[feature_name + "-translation"] = ListField( + [TextField([Token(text)]) for text in input_dict["translation"]] + ) + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + return fields + + +# value mapper - Maps a single text value to TextField +def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: + field: TextField + if tokenizer is not None: + field = TextField(tokenizer.tokenize(text)) + else: + field = TextField([Token(text)]) + return field + + +def _map_Float(value: float) -> TensorField: + return TensorField(torch.tensor(value)) + + +# value mapper - Maps a single value to a LabelField +def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: + return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) + + +def _map_Dict( + feature_definition: dict, + values: 
dict, + tokenizer: Optional[Tokenizer] = None, + feature_name: Optional[str] = None, +) -> Dict[str, Field]: + # TODO abhishek-p expand this to more generic based on metadata checks + # Map it as a Dictionary of List + fields: Dict[str, Field] = dict() + for key in values: + key_name: str = key + if feature_name is not None: + key_name = feature_name + "-" + key + fields[key_name] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) + return fields \ No newline at end of file diff --git a/setup.py b/setup.py index 8acce160aa0..589b4e210dd 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ "wandb>=0.10.0,<0.11.0", "huggingface_hub>=0.0.8", "google-cloud-storage>=1.38.0,<1.39.0", + "datasets>=1.5.0,<1.6.0", ], entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]}, include_package_data=True, diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py new file mode 100644 index 00000000000..dba15d14f0d --- /dev/null +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -0,0 +1,217 @@ +import pytest +from allennlp.data import Tokenizer + +from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader +from allennlp.data.tokenizers import WhitespaceTokenizer + + +# TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset +class HuggingfaceDatasetReaderTest: + """ + Test read for some lightweight datasets + """ + + @pytest.mark.parametrize( + "dataset, config, split", + (("glue", "cola", "train"), ("glue", "cola", "test")), + ) + def test_read(self, dataset, config, split): + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + def test_read_with_tokenizer(self): + dataset = "glue" + config = "cola" + split = "train" + tokenizer: Tokenizer = WhitespaceTokenizer() + huggingface_reader = HuggingfaceDatasetReader( + dataset_name=dataset, config_name=config, tokenizer=tokenizer + ) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + # Confirm it was tokenized + assert len(instance["sentence"]) > 1 + + def test_read_without_config(self): + dataset = "urdu_fake_news" + split = "train" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + """ + Test mapping of the datasets.feature.Translation and datasets.feature.TranslationVariableLanguages + """ + + def test_read_xnli_all_languages(self): + dataset = "xnli" + config = "all_languages" + split = "validation" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + 
instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + instance = instances[0] + # We are splitting datasets.features.Translation and + # datasets.features.TranslationVariableLanguages into two fields each + # For XNLI that means 3 fields become 5 + assert len(instance.fields) == 5 + + def test_non_available_dataset(self): + with pytest.raises(ValueError): + HuggingfaceDatasetReader(dataset_name="surely-such-a-dataset-does-not-exist") + + @pytest.mark.parametrize("split", (None, "surely-such-a-split-does-not-exist")) + def test_read_with_invalid_split(self, split): + with pytest.raises(ValueError): + next(HuggingfaceDatasetReader(dataset_name="glue", config_name="cola").read(split)) + + """ + Test to help validate for the known supported datasets + Skipped by default, enable when required + """ + + # TODO abhishek-p skip these once MR is ready to check-in + @pytest.mark.parametrize( + "dataset, config, split", + ( + ("xnli", "ar", "train"), + ("xnli", "en", "train"), + ("xnli", "de", "train"), + ("glue", "mrpc", "train"), + ("glue", "sst2", "train"), + ("glue", "qqp", "train"), + ("glue", "mnli", "train"), + ("glue", "mnli_matched", "validation"), + ("universal_dependencies", "en_lines", "train"), + ("universal_dependencies", "ko_kaist", "train"), + ("universal_dependencies", "af_afribooms", "train"), + ), + ) + def test_read_known_supported_datasets_with_config(self, dataset, config, split): + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + """ + Test to help validate for the known supported datasets without config + Skipped by default, enable when required + """ + + # TODO abhishek-p skip these once MR is ready to check-in + @pytest.mark.parametrize("dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion"))) + def test_read_known_supported_datasets_without_config(self, dataset): + split = "train" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + # def test_conll2003(self): + # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) + # print(instances[0]) + + def test_squad(self): + tokenizer: Tokenizer = WhitespaceTokenizer() + instance_gen = HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train") + print(next(instance_gen)) + + @pytest.mark.parametrize("config", (("default"), ("ptb"))) + def test_sst(self, config): + instances = list(HuggingfaceDatasetReader("sst", config).read("test")) + print(instances[0]) + + def test_open_web_text(self): + instances = list(HuggingfaceDatasetReader("openwebtext").read("plain_text")) + print(instances[0]) + + # @pytest.mark.skip("Requires mapping of dict type") + def test_mocha(self): + instances = list(HuggingfaceDatasetReader("mocha").read("test")) + print(instances[0]) + + @pytest.mark.skip("Requires implementation of 
Dict") + def test_commonsense_qa(self): + instances = list(HuggingfaceDatasetReader("commonsense_qa").read("test")) + print(instances[0]) + + def test_piqa(self): + instances = list(HuggingfaceDatasetReader("piqa").read("test")) + print(instances[0]) + + def test_swag(self): + instances = list(HuggingfaceDatasetReader("swag").read("test")) + print(instances[0]) + + def test_snli(self): + instances = list(HuggingfaceDatasetReader("snli").read("test")) + print(instances[0]) + + def test_multi_nli(self): + instances = list(HuggingfaceDatasetReader("multi_nli").read("test")) + print(instances[0]) + + def test_super_glue(self): + instances = list(HuggingfaceDatasetReader("super_glue").read("test")) + print(instances[0]) + + @pytest.mark.parametrize( + "config", + ( + ("cola"), + ("mnli"), + ("ax"), + ("mnli_matched"), + ("mnli_mismatched"), + ("mrpc"), + ("qnli"), + ("qqp"), + ("rte"), + ("sst2"), + ("stsb"), + ("wnli"), + ), + ) + def test_glue(self, config): + instances = list(HuggingfaceDatasetReader("glue", config).read("test")) + print(instances[0]) + + def test_drop(self): + instances = list(HuggingfaceDatasetReader("drop").read("test")) + print(instances[0])