Fix division by zero when there are zero-length spans in MismatchedEmbedder. (#4615)

* Implement MattG's fix for NaN gradients in MismatchedEmbedder.

Fix `clamp_min` on embeddings.

Implement MattG's fix for NaN gradients in MismatchedEmbedder.

* Fix NaN gradients caused by weird tokens in MismatchedEmbedder.

Fixed division by zero error when there are zero-length spans in the input to a
mismatched embedder.

* Add changelog message.

* Re-run `black` to get code formatting right.

* Combine fixed sections after merging with master.

Co-authored-by: Matt Gardner <mattg@allenai.org>
David Wadden and matt-gardner committed Sep 1, 2020
1 parent be97943 commit 711afaa
Showing 3 changed files with 39 additions and 2 deletions.
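
Before the per-file diff, here is a minimal, self-contained PyTorch sketch of the failure mode and of the fix (illustrative only, not code from this commit; tensor shapes are simplified relative to the real embedder). An empty span makes the mean-pooling divide zero by zero, and the resulting NaN survives in the gradients even when the affected rows are later overwritten with zeros, unless the denominator is clamped first:

```python
import torch

# (num_spans, max_wordpieces, embedding_dim); the second span has no wordpieces at all.
embeddings = torch.randn(2, 3, 4, requires_grad=True)
span_mask = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])

span_sum = (embeddings * span_mask.unsqueeze(-1)).sum(1)
span_len = span_mask.sum(1, keepdim=True)

# Old behaviour: 0 / 0 gives NaN, and zeroing the output afterwards does not
# keep the NaN out of the gradients.
buggy = (span_sum / span_len).masked_fill(span_len == 0, 0.0)
buggy.sum().backward(retain_graph=True)
print(torch.isnan(embeddings.grad).any())  # tensor(True)

# New behaviour, mirroring the diff below: clamp the denominator to at least 1,
# then zero out the rows that correspond to empty spans.
embeddings.grad = None
fixed = (span_sum / torch.clamp_min(span_len, 1)).masked_fill(span_len == 0, 0.0)
fixed.sum().backward()
print(torch.isnan(embeddings.grad).any())  # tensor(False)
```
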
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -11,12 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed handling of some edge cases when constructing classes with `FromParams` where the class
accepts `**kwargs`.
- Fixed division by zero error when there are zero-length spans in the input to a
`PretrainedTransformerMismatchedIndexer`.

### Added

- `Predictor.capture_model_internals()` now accepts a regex specifying
which modules to capture


## [v1.1.0rc4](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc4) - 2020-08-20

### Added
@@ -63,7 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added the option to specify `requires_grad: false` within an optimizer's parameter groups.
- Added the `file-friendly-logging` flag back to the `train` command. Also added this flag to the `predict`, `evaluate`, and `find-learning-rate` commands.
-- Added an `EpochCallback` to track current epoch as a model class member.
+- Added an `EpochCallback` to track current epoch as a model class member.
- Added the option to enable or disable gradient checkpointing for transformer token embedders via boolean parameter `gradient_checkpointing`.

### Removed

2 changes: 1 addition & 1 deletion allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
@@ -105,7 +105,7 @@ def forward(
        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
-       orig_embeddings = span_embeddings_sum / span_embeddings_len
+       orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
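
Why the clamp is needed even though the zero-length rows are overwritten with zeros a couple of lines later: autograd differentiates the division itself, so in the backward pass the zeroed upstream gradient is divided by the zero span length, which is NaN. A tiny standalone illustration (not part of the commit):

```python
import torch

length = torch.tensor([0.0])                # an empty span
total = torch.zeros(1, requires_grad=True)  # sum over its (nonexistent) wordpieces

mean = total / length                                    # forward value is already NaN ...
masked = torch.where(length == 0, torch.zeros(1), mean)  # ... and overwriting it looks fine,
masked.sum().backward()                                  # but the division's backward computes
print(total.grad)                                        # 0 / 0 -> tensor([nan])
```
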

34 changes: 34 additions & 0 deletions tests/modules/token_embedders/pretrained_transformer_mismatched_embedder_test.py
@@ -8,6 +8,7 @@
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
from allennlp.common.testing import AllenNlpTestCase


@@ -143,3 +144,36 @@ def test_token_without_wordpieces(self):
        assert not torch.isnan(bert_vectors).any()
        assert all(bert_vectors[0, 1] == 0)
        assert all(bert_vectors[1, 1] == 0)

    def test_exotic_tokens_no_nan_grads(self):
        token_indexer = PretrainedTransformerMismatchedIndexer("bert-base-uncased")

        sentence1 = ["A", "", "AllenNLP", "sentence", "."]
        sentence2 = ["A", "\uf732\uf730\uf730\uf733", "AllenNLP", "sentence", "."]

        tokens1 = [Token(word) for word in sentence1]
        tokens2 = [Token(word) for word in sentence2]
        vocab = Vocabulary()

        token_embedder = BasicTextFieldEmbedder(
            {"bert": PretrainedTransformerMismatchedEmbedder("bert-base-uncased")}
        )

        instance1 = Instance({"tokens": TextField(tokens1, {"bert": token_indexer})})
        instance2 = Instance({"tokens": TextField(tokens2, {"bert": token_indexer})})

        batch = Batch([instance1, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        bert_vectors = token_embedder(tokens)
        test_loss = bert_vectors.mean()

        test_loss.backward()

        for name, param in token_embedder.named_parameters():
            grad = param.grad
            assert (grad is None) or (not torch.any(torch.isnan(grad)).item())
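
A note on why the test uses these particular strings: the empty string and the private-use characters produce zero wordpieces under the BERT tokenizer, which is exactly the zero-length-span case the fix targets. A quick standalone check with the HuggingFace tokenizer (illustrative, not part of the commit; assumes `transformers` is installed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

print(tokenizer.tokenize(""))  # [] -> the token contributes no wordpieces
# Expected to be [] as well: the basic tokenizer treats these private-use characters
# like control characters and strips them, so this token's span is also empty.
print(tokenizer.tokenize("\uf732\uf730\uf730\uf733"))
```
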
