Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Improve dict missing key code (#3071)
Browse files Browse the repository at this point in the history
* Improve dict missing key code

* remove trailing whitespace
  • Loading branch information
hawkeoni authored and joelgrus committed Jul 17, 2019
1 parent 9166c18 commit 014fe31
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 9 deletions.
4 changes: 2 additions & 2 deletions allennlp/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
all_field_lengths[field_name].append(instance_field_lengths)
for field_name, field_lengths in all_field_lengths.items():
for padding_key in field_lengths[0].keys():
- max_value = max(x[padding_key] if padding_key in x else 0 for x in field_lengths)
+ max_value = max(x.get(padding_key, 0) for x in field_lengths)
padding_lengths[field_name][padding_key] = max_value
return {**padding_lengths}

Expand Down Expand Up @@ -124,7 +124,7 @@ def as_tensor_dict(self,
lengths_to_use: Dict[str, Dict[str, int]] = defaultdict(dict)
for field_name, instance_field_lengths in instance_padding_lengths.items():
for padding_key in instance_field_lengths.keys():
- if padding_lengths[field_name].get(padding_key) is not None:
+ if padding_key in padding_lengths[field_name]:
lengths_to_use[field_name][padding_key] = padding_lengths[field_name][padding_key]
else:
lengths_to_use[field_name][padding_key] = instance_field_lengths[padding_key]
Expand Down
5 changes: 1 addition & 4 deletions allennlp/data/dataset_readers/penn_tree_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,7 @@ def text_to_instance(self, # type: ignore
spans.append(SpanField(start, end, text_field))

if gold_spans is not None:
- if (start, end) in gold_spans.keys():
- gold_labels.append(gold_spans[(start, end)])
- else:
- gold_labels.append("NO-LABEL")
+ gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

metadata = {"tokens": tokens}
if gold_tree:
Expand Down
4 changes: 2 additions & 2 deletions allennlp/data/fields/knowledge_graph_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def get_padding_lengths(self) -> Dict[str, int]:
# Iterate over the keys in the first element of the list. This is fine as for a given
# indexer, all entities will return the same keys, so we can just use the first one.
for key in entity_lengths[0].keys():
- indexer_lengths[key] = max(x[key] if key in x else 0 for x in entity_lengths)
+ indexer_lengths[key] = max(x.get(key, 0) for x in entity_lengths)
lengths.append(indexer_lengths)

# Get all the keys which have been used for padding.
padding_keys = {key for d in lengths for key in d.keys()}
for padding_key in padding_keys:
- padding_lengths[padding_key] = max(x[padding_key] if padding_key in x else 0 for x in lengths)
+ padding_lengths[padding_key] = max(x.get(padding_key, 0) for x in lengths)
return padding_lengths

@overrides
Expand Down
2 changes: 1 addition & 1 deletion allennlp/data/fields/text_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_padding_lengths(self) -> Dict[str, int]:
# Get all keys which have been used for padding for each indexer and take the max if there are duplicates.
padding_keys = {key for d in lengths for key in d.keys()}
for padding_key in padding_keys:
- padding_lengths[padding_key] = max(x[padding_key] if padding_key in x else 0 for x in lengths)
+ padding_lengths[padding_key] = max(x.get(padding_key, 0) for x in lengths)
return padding_lengths

@overrides
Expand Down

0 comments on commit 014fe31

Please sign in to comment.