Fix division by zero when there are zero-length spans in MismatchedEmbedder. #4615
Conversation
Fix `clamp_min` on embeddings. Implement MattG's fix for NaN gradients in MismatchedEmbedder.
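For context, the failure mode behind this commit, as a minimal self-contained sketch (tensor names and shapes here are illustrative, not the exact diff): the mismatched embedder mean-pools wordpiece embeddings over each original token's span, so a zero-length span puts a zero in the denominator and produces `nan` activations and `nan` gradients.

```python
import torch

# Illustrative shapes, not the actual PR diff:
#   span_embeddings: (batch, num_tokens, max_span_len, dim)
#   span_lengths:    (batch, num_tokens, 1); zero where a token has no wordpieces
# (In the real embedder, wordpiece positions beyond each span are masked to
# zero before summing; for this sketch only the denominator matters.)
span_embeddings = torch.randn(2, 3, 4, 8, requires_grad=True)
span_lengths = torch.tensor([[[2.0], [0.0], [1.0]],
                             [[4.0], [3.0], [0.0]]])

span_sum = span_embeddings.sum(2)  # (batch, num_tokens, dim)

# Naive mean pooling, span_sum / span_lengths, divides by zero for the
# empty spans and produces nan in both the forward and backward pass.
# Clamping the denominator to at least 1 keeps the division finite:
pooled = span_sum / torch.clamp_min(span_lengths, 1)

# Tokens whose span was empty should come out as exactly zero.
pooled = pooled.masked_fill(span_lengths == 0, 0.0)

pooled.mean().backward()
assert not torch.isnan(span_embeddings.grad).any()
```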
Thanks for doing this, this looks great! There are a couple of minor things to clean up in the test, and can you add a simple note to the changelog saying something like "Fixed division by zero error when there are zero-length spans in the input to a mismatched embedder."
```python
params = Params(
    {
        "token_embedders": {
            "bert": {
                "type": "pretrained_transformer_mismatched",
                "model_name": "bert-base-uncased",
            }
        }
    }
)
token_embedder = BasicTextFieldEmbedder.from_params(vocab=vocab, params=params)
```
Can you just make this:
```diff
- params = Params(
-     {
-         "token_embedders": {
-             "bert": {
-                 "type": "pretrained_transformer_mismatched",
-                 "model_name": "bert-base-uncased",
-             }
-         }
-     }
- )
- token_embedder = BasicTextFieldEmbedder.from_params(vocab=vocab, params=params)
+ token_embedder = BasicTextFieldEmbedder({"bert": PretrainedTransformerMismatchedEmbedder("bert-base-uncased")})
```
No need to use `Params` here. (Be sure to run black on that line, since it might be too long, and this might require adding some imports above.)
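Applied here, the black-formatted version with the imports it needs would look roughly like this (assuming the standard allennlp module paths):

```python
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

token_embedder = BasicTextFieldEmbedder(
    {"bert": PretrainedTransformerMismatchedEmbedder("bert-base-uncased")}
)
```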
Changed.
tests/modules/token_embedders/pretrained_transformer_mismatched_embedder_test.py
Fixed division by zero error when there are zero-length spans in the input to a mismatched embedder.
Changes have been made. There are two transformer parameters that have …
Thanks @dwadden! I tried pushing a couple of small final fixes to get this to pass CI, but for some reason it didn't let me (maybe because this is from your master branch). I was able to do one of them from the web UI, but I can't add the changelog statement that I mentioned in my comment above. Can you do that? Then I'll merge this.
OK, I added a message to the changelog. I also got confused and did …
Thanks again!
Fixes #4612 and adds a unit test to confirm that missing or exotic tokens do not lead to `nan` gradients in the BERT-style embedder.
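A sketch of what such a test can look like, assuming allennlp's standard data pipeline (the exact test added in this PR may differ in details):

```python
import torch

from allennlp.data import Batch, Instance, Token, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder


def test_exotic_tokens_no_nan_grads():
    indexer = PretrainedTransformerMismatchedIndexer("bert-base-uncased")
    # An empty token produces a zero-length wordpiece span.
    tokens = [Token(t) for t in ["A", "", "AllenNLP", "sentence", "."]]

    batch = Batch([Instance({"tokens": TextField(tokens, {"bert": indexer})})])
    batch.index_instances(Vocabulary())
    tensors = batch.as_tensor_dict(batch.get_padding_lengths())

    embedder = BasicTextFieldEmbedder(
        {"bert": PretrainedTransformerMismatchedEmbedder("bert-base-uncased")}
    )
    loss = embedder(tensors["tokens"]).mean()
    loss.backward()

    # Some transformer parameters may legitimately have no gradient;
    # everything else should be nan-free after the fix.
    for param in embedder.parameters():
        assert param.grad is None or not torch.isnan(param.grad).any()
```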