Commit 5f702ef
Dictionary works with SQUAD
Abhishek-P committed Aug 11, 2021
1 parent e32c5b0 commit 5f702ef
Showing 2 changed files with 48 additions and 40 deletions.
47 changes: 26 additions & 21 deletions allennlp/data/dataset_readers/huggingface_datasets_reader.py
@@ -82,7 +82,9 @@ def _read(self, file_path: str) -> Iterable[Instance]:
yield self.text_to_instance(file_path, dataset_split[index])

def raise_feature_not_supported_value_error(feature_name, feature_type):
raise ValueError(f"Datasets feature {feature_name} type {feature_type} is not supported yet.")
raise ValueError(
f"Datasets feature {feature_name} type {feature_type} is not supported yet."
)

def text_to_instance(self, split: str, entry) -> Instance: # type: ignore
"""
@@ -114,7 +116,9 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore
field_list: list
feature_type = features[feature_name]

-        fields_to_be_added = _map_Feature(feature_name, entry[feature_name], feature_type, self.tokenizer)
+        fields_to_be_added = _map_Feature(
+            feature_name, entry[feature_name], feature_type, self.tokenizer
+        )
for field_key in fields_to_be_added:
fields[field_key] = fields_to_be_added[field_key]

@@ -130,22 +134,18 @@ def _map_Feature(
fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value)
# datasets Value can be of different types
elif isinstance(feature_type, Value):
-        fields_to_be_added[feature_name] = _map_Value(
-            feature_name, value, feature_type, tokenizer
-        )
+        fields_to_be_added[feature_name] = _map_Value(feature_name, value, feature_type, tokenizer)

elif isinstance(feature_type, Sequence):
-        if type(value) == dict:
-            fields_to_be_added = _map_Dict(feature_type, value, tokenizer)
+        if type(feature_type.feature) == dict:
+            fields_to_be_added[feature_name] = _map_Dict(feature_type.feature, value, tokenizer)
else:
fields_to_be_added[feature_name] = _map_Sequence(
feature_name, value, feature_type.feature, tokenizer
)

elif isinstance(feature_type, Translation):
-        fields_to_be_added = _map_Translation(
-            feature_name, value, feature_type, tokenizer
-        )
+        fields_to_be_added = _map_Translation(feature_name, value, feature_type, tokenizer)

elif isinstance(feature_type, TranslationVariableLanguages):
fields_to_be_added = _map_TranslationVariableLanguages(
@@ -166,8 +166,8 @@ def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field:

def _map_Value(
feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer]
-) -> Union[TextField, LabelField]:
-    field: Union[TextField, LabelField]
+) -> Union[TextField, LabelField, TensorField]:
+    field: Union[TextField, LabelField, TensorField]
if feature_type.dtype == "string":
# datasets.Value[string] maps to TextField
# If tokenizer is provided we will use it to split it to tokens
@@ -176,16 +176,17 @@ def _map_Value(

elif feature_type.dtype == "float32" or feature_type.dtype == "float64":
field = _map_Float(value)

else:
field = LabelField(value, label_namespace=feature_name, skip_indexing=True)
return field


def _map_Sequence(
feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer]
-) -> Union[ListField]:
+) -> ListField:
    field_list: List[Field] = list()
-    field: ListField = None
+    field: ListField
item_field: Field
# In HF Sequence and list are considered interchangeable, but there are some distinctions such as
if isinstance(item_feature_type, Value):
@@ -223,7 +224,7 @@ def _map_Sequence(
if len(field_list) > 0:
field = ListField(field_list)

-    # WIP for drop
+    # WIP for dropx`
elif isinstance(item_feature_type, dict):
for item in value:
item_field = _map_Dict(item_feature_type, value[item], tokenizer)
@@ -232,7 +233,9 @@ def _map_Sequence(
field = ListField(field_list)

else:
-        HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type)
+        HuggingfaceDatasetReader.raise_feature_not_supported_value_error(
+            feature_name, item_feature_type
+        )

return field

@@ -307,6 +310,7 @@ def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField:
field = TextField([Token(text)])
return field


def _map_Float(value: float) -> TensorField:
return TensorField(torch.tensor(value))
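
A minimal standalone sketch of the float mapping above (the value is made up; this mirrors what _map_Float returns for a datasets float feature):

    import torch
    from allennlp.data.fields import TensorField

    value = 0.75  # stand-in for a Value("float32") / Value("float64") entry
    field = TensorField(torch.tensor(value))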

@@ -315,11 +319,12 @@ def _map_Float(value: float) -> TensorField:
def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField:
return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing)

-def _map_Dict(feature_definition: dict, values: dict, tokenizer: Tokenizer) -> Dict[str, Field]:
+def _map_Dict(
+    feature_definition: dict, values: dict, tokenizer: Optional[Tokenizer]
+) -> Dict[str, Field]:
# Map it as a Dictionary of List
fields: Dict[str, Field] = dict()
for key in values:
-        fields[key] = _map_Feature(key, values[key], feature_definition[key], tokenizer)
+        fields[key] = _map_Sequence(key, values[key], feature_definition[key], tokenizer)
return fields
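
The switch from _map_Feature to _map_Sequence above is what makes SQuAD's answers feature work: answers is a dict of parallel lists, and each key is now mapped as a sequence in its own right. A standalone sketch of roughly the per-key fields this produces when no tokenizer is given; the example values are illustrative, not taken from the dataset:

    from typing import Dict

    from allennlp.data.fields import Field, LabelField, ListField, TextField
    from allennlp.data.tokenizers import Token

    # A SQuAD-style answers value: parallel lists keyed by name.
    answers = {"text": ["Denver Broncos"], "answer_start": [177]}

    fields: Dict[str, Field] = {
        # each answer string becomes a single-token TextField inside a ListField
        "text": ListField([TextField([Token(t)]) for t in answers["text"]]),
        # each int offset becomes a LabelField with skip_indexing=True inside a ListField
        "answer_start": ListField(
            [
                LabelField(a, label_namespace="answer_start", skip_indexing=True)
                for a in answers["answer_start"]
            ]
        ),
    }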



41 changes: 22 additions & 19 deletions tests/data/dataset_readers/huggingface_datasets_reader_test.py
@@ -7,7 +7,6 @@

# TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset
class HuggingfaceDatasetReaderTest:

"""
Test read for some lightweight datasets
"""
@@ -101,6 +100,7 @@ def test_read_with_invalid_split(self, split):
Test to help validate for the known supported datasets
Skipped by default, enable when required
"""

# TODO pab-vmware skip these once MR is ready to check-in
@pytest.mark.parametrize(
"dataset, config, split",
@@ -136,9 +136,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split)
"""

# TODO pab-vmware skip these once MR is ready to check-in
-    @pytest.mark.parametrize(
-        "dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion"))
-    )
+    @pytest.mark.parametrize("dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion")))
def test_read_known_supported_datasets_without_config(self, dataset):
split = "train"
huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset)
@@ -152,15 +150,14 @@ def test_read_known_supported_datasets_without_config(self, dataset):
# Confirm all features were mapped
assert len(instance.fields) == len(entry)


# def test_conll2003(self):
# instances = list(HuggingfaceDatasetReader("conll2003").read("test"))
# print(instances[0])


# @pytest.mark.skip("Requires implementation of Dict")
def test_squad(self):
instances = list(HuggingfaceDatasetReader("squad").read("train"))
tokenizer: Tokenizer = WhitespaceTokenizer()
instances = list(HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train"))
print(instances[0])

@pytest.mark.parametrize("config", (("default"), ("ptb")))
@@ -174,7 +171,7 @@ def test_open_web_text(self):

# @pytest.mark.skip("Requires mapping of dict type")
def test_mocha(self):
-        reader = HuggingfaceDatasetReader("mocha").read("test")
+        instances = list(HuggingfaceDatasetReader("mocha").read("test"))
print(instances[0])

@pytest.mark.skip("Requires implementation of Dict")
@@ -202,21 +199,27 @@ def test_super_glue(self):
instances = list(HuggingfaceDatasetReader("super_glue").read("test"))
print(instances[0])

@pytest.mark.parametrize("config", (("cola"), ("mnli"), ("ax"), ("mnli_matched"), ("mnli_mismatched"), ("mrpc"), ("qnli"),\
("qqp"), ("rte"), ("sst2"), ("stsb"), ("wnli")))
@pytest.mark.parametrize(
"config",
(
("cola"),
("mnli"),
("ax"),
("mnli_matched"),
("mnli_mismatched"),
("mrpc"),
("qnli"),
("qqp"),
("rte"),
("sst2"),
("stsb"),
("wnli"),
),
)
def test_glue(self, config):
instances = list(HuggingfaceDatasetReader("glue", config).read("test"))
print(instances[0])

def test_drop(self):
instances = list(HuggingfaceDatasetReader("drop").read("test"))
print(instances[0])
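
With these changes, reading SQuAD end to end mirrors the updated test_squad above. A usage sketch, assuming this branch of allennlp and the Hugging Face datasets package are installed (the dataset downloads on first use):

    from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader
    from allennlp.data.tokenizers import WhitespaceTokenizer

    reader = HuggingfaceDatasetReader("squad", tokenizer=WhitespaceTokenizer())
    for instance in reader.read("train"):
        print(instance)  # one field per dataset feature: id, title, context, question, answers
        break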








