This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add HuggingfaceDatasetReader for using Huggingface datasets #5095

Open · wants to merge 67 commits into main
Commits (67)
af9661e
Add `HuggingfaceDatasetReader` for using Huggingface `datasets`
Abhishek-P Mar 31, 2021
8370803
Move mapping to funcs, remove preload support
Abhishek-P May 9, 2021
d5b8f3f
Support for Sequence Nesting
Abhishek-P May 9, 2021
49fa0bc
Misc Fixes
Abhishek-P May 9, 2021
17cd4ac
Misc check
Abhishek-P May 9, 2021
5159f69
Comments
Abhishek-P May 10, 2021
7155a32
map funcs _ prefix
Abhishek-P May 11, 2021
eb4b573
Parameters rename and cleanup
Abhishek-P May 12, 2021
a9ef475
Apply suggestions from code review by Dirk - comment text
May 12, 2021
92d95f5
Merge branch 'main' into datasets_feature
dirkgr May 13, 2021
0e441da
Merge branch 'main' into datasets_feature
dirkgr May 19, 2021
2610df8
Formatting
dirkgr May 20, 2021
a0d1408
Comments addressed
Abhishek-P May 20, 2021
57b6f9e
Formatting
Abhishek-P May 23, 2021
e841b6e
removed invalid conll test
Abhishek-P May 23, 2021
2497b24
Regression Fix
Abhishek-P May 25, 2021
a6718f4
Merge branch 'allenai:main' into datasets_feature
Jun 15, 2021
74931dc
Add float mapping to TensorField
Abhishek-P Jun 24, 2021
10dd3e6
Verification tests
Abhishek-P Jun 29, 2021
f3e54dd
Attempt to Support Dict
Abhishek-P Jun 30, 2021
d0f31c1
Quick changes
Abhishek-P Aug 4, 2021
b277534
Dictionary works with SQUAD
Abhishek-P Aug 11, 2021
a1d9bca
Bias Mitigation and Direction Methods (#5130)
ArjunSubramonian May 11, 2021
5dce9f5
Bias Metrics (#5139)
ArjunSubramonian May 13, 2021
dfed580
Update transformers requirement from <4.6,>=4.1 to >=4.1,<4.7 (#5199)
dependabot[bot] May 14, 2021
f1a1adc
Rename sanity_checks to confidence_checks (#5201)
AkshitaB May 14, 2021
047ae34
Changes and improvements to how we initialize transformer modules fro…
epwalsh May 17, 2021
0ea9225
Add a `min_steps` parameter to `BeamSearch` (#5207)
danieldeutsch May 17, 2021
9de5b4e
Implementing abstraction to score final sequences in `BeamSearch` (#5…
danieldeutsch May 18, 2021
5660670
added shuffle disable option in BucketBatchSampler (#5212)
ArjunSubramonian May 19, 2021
73e570b
save meta data with model archives (#5209)
epwalsh May 19, 2021
f3aeeeb
Formatting
dirkgr May 20, 2021
d6c7769
Comments addressed
Abhishek-P May 20, 2021
79f58a8
Formatting
Abhishek-P May 23, 2021
a55a7ba
removed invalid conll test
Abhishek-P May 23, 2021
81d0409
Regression Fix
Abhishek-P May 25, 2021
5b9e0c2
Bump black from 20.8b1 to 21.5b1 (#5195)
dependabot[bot] May 25, 2021
66f226b
Update nr-interface requirement from <0.0.4 to <0.0.6 (#5213)
dependabot[bot] May 25, 2021
3295bd5
Fix W&B callback for distributed training (#5223)
epwalsh May 26, 2021
19d2a87
cancel redundant GH Actions workflows (#5226)
epwalsh May 26, 2021
51a01fe
fix race condition when extracting files with cached_path (#5227)
epwalsh May 27, 2021
7727af5
Bump checklist from 0.0.10 to 0.0.11 (#5222)
dependabot[bot] May 27, 2021
0d5b88f
Added `DataCollator` for dynamic operations for each batch. (#5221)
wlhgtc May 27, 2021
b75c60c
Roll backbone (#5229)
jacob-morrison May 28, 2021
fd0981c
Fixes Checkpointing (#5220)
dirkgr May 29, 2021
804fd59
Emergency fix. I forgot to take this out.
dirkgr May 29, 2021
deeec84
Add constraints to beam search (#5216)
danieldeutsch Jun 1, 2021
0bdee9d
Make BeamSearch Registrable (#5231)
JohnGiorgi Jun 1, 2021
8e10f69
tick version for nightly release
epwalsh Jun 2, 2021
7b8e9e9
Generalize T5 modules (#5166)
AkshitaB Jun 2, 2021
3916cf3
Fix tqdm logging into multiple files with allennlp-optuna (#5235)
MagiaSN Jun 2, 2021
4753906
Checklist fixes (#5239)
AkshitaB Jun 2, 2021
b7a62fa
Contextualized bias mitigation (#5176)
ArjunSubramonian Jun 2, 2021
1159432
Prepare for release v2.5.0
epwalsh Jun 3, 2021
5f76b59
tick version for nightly release
epwalsh Jun 3, 2021
044e0ff
Bump black from 21.5b1 to 21.5b2 (#5236)
dependabot[bot] Jun 4, 2021
b7fd842
[Docs] Fixes broken link in Fairness_Metrics (#5245)
bhadreshpsavani Jun 7, 2021
38c930b
Ensure all relevant allennlp submodules are imported with `import_plu…
epwalsh Jun 7, 2021
0e3a225
added `on_backward` trainer callback (#5249)
ArjunSubramonian Jun 11, 2021
69d05ff
Add float mapping to TensorField
Abhishek-P Jun 24, 2021
356b383
Verification tests
Abhishek-P Jun 29, 2021
3192d70
Attempt to Support Dict
Abhishek-P Jun 30, 2021
e32c5b0
Quick changes
Abhishek-P Aug 4, 2021
5f702ef
Dictionary works with SQUAD
Abhishek-P Aug 11, 2021
fd95128
Merge branch 'datasets_feature' of github.com:Abhishek-P/allennlp int…
Abhishek-P Aug 11, 2021
af029b3
Fix typing issues
Abhishek-P Aug 11, 2021
41b7034
Works for Mocha, although may need to add specific handling for SQUAD…
Abhishek-P Aug 11, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -45,6 +45,7 @@ __pycache__
.coverage
.pytest_cache/
.benchmarks
htmlcov/

# documentation build artifacts

7 changes: 7 additions & 0 deletions Makefile
@@ -66,6 +66,13 @@ test-with-cov :
--cov=$(SRC) \
--cov-report=xml

.PHONY : test-with-cov-html
test-with-cov-html :
pytest --color=yes -rf --durations=40 \
--cov-config=.coveragerc \
--cov=$(SRC) \
--cov-report=html

.PHONY : gpu-test
gpu-test : check-for-cuda
pytest --color=yes -v -rf -m gpu
337 changes: 337 additions & 0 deletions allennlp/data/dataset_readers/huggingface_datasets_reader.py
@@ -0,0 +1,337 @@
from allennlp.data import DatasetReader, Token, Field, Tokenizer
from allennlp.data.fields import TextField, LabelField, ListField, TensorField
from allennlp.data.instance import Instance
from datasets import load_dataset, DatasetDict, list_datasets
from datasets.features import (
ClassLabel,
Sequence,
Translation,
TranslationVariableLanguages,
Value,
FeatureType,
)

import torch
from typing import Iterable, Optional, Dict, List, Union


@DatasetReader.register("huggingface-datasets")
class HuggingfaceDatasetReader(DatasetReader):
"""
    Reads instances from a dataset supported by the Huggingface `datasets` library.

    This reader implementation wraps the Huggingface `datasets` package.

    Registered as a `DatasetReader` with name `huggingface-datasets`.

# Parameters
    dataset_name : `str`
        Name of the dataset in Huggingface `datasets` that the reader will load.
    config_name : `str`, optional (default=`None`)
        Configuration (mandatory for some datasets) of the dataset.
    tokenizer : `Tokenizer`, optional (default=`None`)
        If specified, it is used to tokenize string and text fields from the dataset.
        This is useful since text in AllenNLP is handled as a sequence of tokens.
"""

def __init__(
self,
dataset_name: str = None,
config_name: Optional[str] = None,
tokenizer: Optional[Tokenizer] = None,
**kwargs,
) -> None:
super().__init__(
manual_distributed_sharding=True,
manual_multiprocess_sharding=True,
**kwargs,
)

# It would be cleaner to create a separate reader object for each different dataset
if dataset_name not in list_datasets():
raise ValueError(f"Dataset {dataset_name} not available in huggingface datasets")
self.dataset: DatasetDict = DatasetDict()
Member: This is basically a cache, so if you load the same dataset twice it doesn't load it twice?

Member: I am a bit confused by this, because Huggingface already has their own cache. So we're caching it twice. Calling `datasets.load_dataset("squad", split="train")` takes about 200ms on my machine once all the files are downloaded. That's not a lot of time to save with a cache.

Author (@ghost, May 12, 2021): From a programming perspective, having this dict is cleaner; even the `datasets` lib gives you a `DatasetDict`. And this is not a cache, since the reference is still to the same dataset object given by the `datasets` lib. When a split is loaded it is a dataset; for the reader to maintain the organization of splits, I am using this `DatasetDict`.

self.dataset_name = dataset_name
self.config_name = config_name
self.tokenizer = tokenizer

self.features = None

def load_dataset_split(self, split: str):
if self.config_name is not None:
self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split)
else:
self.dataset[split] = load_dataset(self.dataset_name, split=split)

def _read(self, file_path: str) -> Iterable[Instance]:
"""
        Reads the dataset and converts each entry into an AllenNLP-friendly `Instance`.
"""
if file_path is None:
raise ValueError("parameter split cannot be None")

# If split is not loaded, load the specific split
if file_path not in self.dataset:
self.load_dataset_split(file_path)
if self.features is None:
self.features = self.dataset[file_path].features

# TODO see if use of Dataset.select() is better
dataset_split = self.dataset[file_path]
for index in self.shard_iterable(range(len(dataset_split))):
yield self.text_to_instance(file_path, dataset_split[index])

    @staticmethod
    def raise_feature_not_supported_value_error(feature_name, feature_type):
raise ValueError(
f"Datasets feature {feature_name} type {feature_type} is not supported yet."
)

def text_to_instance(self, split: str, entry) -> Instance: # type: ignore
"""
        Takes care of converting a dataset entry into an AllenNLP-friendly `Instance`.

Currently this is how datasets.features types are mapped to AllenNLP Fields

dataset.feature type allennlp.data.fields
`ClassLabel` `LabelField` in feature name namespace
`Value.string` `TextField` with value as Token
`Value.*` `LabelField` with value being label in feature name namespace
`Translation` `ListField` of 2 ListField (ClassLabel and TextField)
`TranslationVariableLanguages` `ListField` of 2 ListField (ClassLabel and TextField)
`Sequence` `ListField` of sub-types
Comment on lines +95 to +101

Member: There is a proper Markdown table syntax: https://www.markdownguide.org/extended-syntax/

@epwalsh, do we support that syntax when the docs are built?
"""

        # Features indicate the different information available in each entry from the dataset,
        # and feature types describe what kind of information each one is,
        # e.g. in a sentiment dataset an entry could have one feature (of type text/string) holding the text
        # and another indicating the sentiment (of type int32/ClassLabel).

features: Dict[str, FeatureType] = self.dataset[split].features
fields: Dict[str, Field] = dict()

# TODO we need to support all different datasets features described
# in https://huggingface.co/docs/datasets/features.html
for feature_name in features:
item_field: Field
field_list: list
Comment on lines +115 to +116

Member: These are unused?

feature_type = features[feature_name]

fields_to_be_added = _map_Feature(
feature_name, entry[feature_name], feature_type, self.tokenizer
)
for field_key in fields_to_be_added:
fields[field_key] = fields_to_be_added[field_key]

return Instance(fields)
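For reference, a minimal usage sketch of the reader added in this PR (the dataset name `squad` is illustrative; the Huggingface split name is passed through the usual `file_path` argument of `read`):

```python
from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader

# The Huggingface split name ("train", "validation", ...) is passed where
# other readers expect a file path.
reader = HuggingfaceDatasetReader(dataset_name="squad")
instances = reader.read("train")  # lazily yields Instance objects
first = next(iter(instances))
print(first.fields.keys())  # field names mirror the dataset's features
```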


# Feature Mappers - These functions map a FeatureType into Fields
def _map_Feature(
feature_name: str, value, feature_type, tokenizer: Optional[Tokenizer]
) -> Dict[str, Field]:
fields_to_be_added: Dict[str, Field] = dict()
if isinstance(feature_type, ClassLabel):
fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value)
# datasets Value can be of different types
elif isinstance(feature_type, Value):
fields_to_be_added[feature_name] = _map_Value(feature_name, value, feature_type, tokenizer)

elif isinstance(feature_type, Sequence):
if type(feature_type.feature) == dict:
fields_to_be_added = _map_Dict(feature_type.feature, value, tokenizer, feature_name)
else:
fields_to_be_added[feature_name] = _map_Sequence(
feature_name, value, feature_type.feature, tokenizer
)

elif isinstance(feature_type, Translation):
fields_to_be_added = _map_Translation(feature_name, value, feature_type, tokenizer)

elif isinstance(feature_type, TranslationVariableLanguages):
fields_to_be_added = _map_TranslationVariableLanguages(
feature_name, value, feature_type, tokenizer
)

elif isinstance(feature_type, dict):
fields_to_be_added = _map_Dict(feature_type, value, tokenizer)
else:
raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.")
return fields_to_be_added
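To illustrate the dispatch in `_map_Feature`, here is a small sketch with hypothetical values (a sentiment-style string feature plus a `ClassLabel`), run in the context of this module and assuming the `datasets` feature classes imported at the top of the file:

```python
from datasets.features import ClassLabel, Value

label_type = ClassLabel(names=["negative", "positive"])
text_type = Value("string")

fields = {}
fields.update(_map_Feature("label", 1, label_type, tokenizer=None))
fields.update(_map_Feature("text", "a fine film", text_type, tokenizer=None))
# fields["label"] -> LabelField(1, label_namespace="label", skip_indexing=True)
# fields["text"]  -> TextField([Token("a fine film")]) since no tokenizer was given
```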


def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field:
field: Field = _map_to_Label(feature_name, value, skip_indexing=True)
return field


def _map_Value(
feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer]
) -> Union[TextField, LabelField, TensorField]:
field: Union[TextField, LabelField, TensorField]
if feature_type.dtype == "string":
# datasets.Value[string] maps to TextField
# If tokenizer is provided we will use it to split it to tokens
# Else put whole text as a single token
field = _map_String(value, tokenizer)

elif feature_type.dtype == "float32" or feature_type.dtype == "float64":
field = _map_Float(value)

else:
field = LabelField(value, label_namespace=feature_name, skip_indexing=True)
return field


def _map_Sequence(
feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer]
) -> ListField:
field_list: List[Field] = list()
field: ListField = list()
item_field: Field
    # In HF, Sequence and list are considered interchangeable, but there are some distinctions
if isinstance(item_feature_type, Value):
for item in value:
# If tokenizer is provided we will use it to split it to tokens
# Else put whole text as a single token
item_field = _map_Value(feature_name, item, item_feature_type, tokenizer)
field_list.append(item_field)
if len(field_list) > 0:
field = ListField(field_list)

# datasets Sequence of strings to ListField of LabelField
elif isinstance(item_feature_type, str):
for item in value:
# If tokenizer is provided we will use it to split it to tokens
# Else put whole text as a single token
item_field = _map_Value(feature_name, item, item_feature_type, tokenizer)
field_list.append(item_field)
if len(field_list) > 0:
field = ListField(field_list)

elif isinstance(item_feature_type, ClassLabel):
for item in value:
item_field = _map_to_Label(feature_name, item, skip_indexing=True)
field_list.append(item_field)

if len(field_list) > 0:
field = ListField(field_list)

elif isinstance(item_feature_type, Sequence):
for item in value:
item_field = _map_Sequence(value.feature, item, item_feature_type.feature, tokenizer)
field_list.append(item_field)

if len(field_list) > 0:
field = ListField(field_list)

    # WIP for drop
# elif isinstance(item_feature_type, dict):
# for item in value:
# item_field = _map_Dict(item_feature_type, value[item], tokenizer)
# field_list.append(item_field)
# if len(field_list) > 0:
# field = ListField(field_list)

else:
HuggingfaceDatasetReader.raise_feature_not_supported_value_error(
feature_name, item_feature_type
)

return field
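A sketch of the `ClassLabel` branch of `_map_Sequence`, using hypothetical NER-style tags (this mirrors how `_map_Feature` passes `feature_type.feature` for a `Sequence`):

```python
from datasets.features import ClassLabel, Sequence

tag_type = Sequence(ClassLabel(names=["O", "B-PER", "I-PER"]))
tags = [0, 1, 2]  # class indices, as stored by `datasets`

field = _map_Sequence("ner_tags", tags, tag_type.feature, tokenizer=None)
# field -> ListField of three LabelField values in the "ner_tags" namespace
```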


def _map_Translation(
feature_name: str, value: Translation, feature_type, tokenizer: Optional[Tokenizer]
) -> Dict[str, Field]:
fields: Dict[str, Field] = dict()
if feature_type.dtype == "dict":
input_dict = value
langs = list(input_dict.keys())
texts = list()
for lang in langs:
if tokenizer is not None:
tokens = tokenizer.tokenize(input_dict[lang])

else:
tokens = [Token(input_dict[lang])]
texts.append(TextField(tokens))

fields[feature_name + "-languages"] = ListField(
[
_map_to_Label(feature_name + "-languages", lang, skip_indexing=False)
for lang in langs
]
)
fields[feature_name + "-texts"] = ListField(texts)

else:
raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.")

return fields
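And a sketch of what `_map_Translation` yields for a single hypothetical entry when no tokenizer is supplied:

```python
from datasets.features import Translation

feature_type = Translation(languages=["en", "fr"])
value = {"en": "the cat", "fr": "le chat"}

fields = _map_Translation("translation", value, feature_type, tokenizer=None)
# fields["translation-languages"] -> ListField of LabelField("en"), LabelField("fr")
# fields["translation-texts"]     -> ListField of TextField([Token("the cat")]), TextField([Token("le chat")])
```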


def _map_TranslationVariableLanguages(
feature_name: str,
value: TranslationVariableLanguages,
feature_type,
tokenizer: Optional[Tokenizer],
) -> Dict[str, Field]:
fields: Dict[str, Field] = dict()
if feature_type.dtype == "dict":
input_dict = value
fields[feature_name + "-language"] = ListField(
[
_map_to_Label(feature_name + "-languages", lang, skip_indexing=False)
for lang in input_dict["language"]
]
)

if tokenizer is not None:
fields[feature_name + "-translation"] = ListField(
[TextField(tokenizer.tokenize(text)) for text in input_dict["translation"]]
)
else:
fields[feature_name + "-translation"] = ListField(
[TextField([Token(text)]) for text in input_dict["translation"]]
)

else:
        raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.")

return fields


# value mapper - Maps a single text value to TextField
def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField:
field: TextField
if tokenizer is not None:
field = TextField(tokenizer.tokenize(text))
else:
field = TextField([Token(text)])
return field


def _map_Float(value: float) -> TensorField:
return TensorField(torch.tensor(value))


# value mapper - Maps a single value to a LabelField
def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField:
return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing)


def _map_Dict(
feature_definition: dict,
values: dict,
tokenizer: Optional[Tokenizer] = None,
feature_name: Optional[str] = None,
) -> Dict[str, Field]:
# TODO abhishek-p expand this to more generic based on metadata checks
# Map it as a Dictionary of List
fields: Dict[str, Field] = dict()
for key in values:
key_name: str = key
if feature_name is not None:
key_name = feature_name + "-" + key
fields[key_name] = _map_Sequence(key, values[key], feature_definition[key], tokenizer)
return fields
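Finally, a rough illustration of the dict handling that the SQuAD `answers` feature exercises (`datasets` exposes its value as a dict of `text` and `answer_start` lists; the values below are hypothetical):

```python
from datasets.features import Value

answers_definition = {"text": Value("string"), "answer_start": Value("int32")}
answers_value = {"text": ["Denver Broncos"], "answer_start": [177]}

fields = _map_Dict(answers_definition, answers_value, tokenizer=None, feature_name="answers")
# fields["answers-text"]         -> ListField of TextField([Token("Denver Broncos")])
# fields["answers-answer_start"] -> ListField of LabelField(177, label_namespace="answer_start", skip_indexing=True)
```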
1 change: 1 addition & 0 deletions setup.py
@@ -76,6 +76,7 @@
"wandb>=0.10.0,<0.11.0",
"huggingface_hub>=0.0.8",
"google-cloud-storage>=1.38.0,<1.39.0",
"datasets>=1.5.0,<1.6.0",
],
entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
include_package_data=True,