This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
add BERT token embedder #2067
Merged
Changes from all commits (19 commits):
025cf1b  bert wip (joelgrus)
e90345b  update pylintrc (joelgrus)
a0cadbe  bert (joelgrus)
e6ff247  debugging test (joelgrus)
72ed878  fix bug in token indexer (joelgrus)
405795e  get end to end test working (joelgrus)
97fabda  remove print statements (joelgrus)
ccc3784  add back missing line (joelgrus)
93396a9  use pip installed bert version (joelgrus)
78724eb  clean up class hierarchy (joelgrus)
a0a59cc  keep working on BERT (joelgrus)
1ff2c7e  Merge branch 'master' into bert (joelgrus)
212bda1  fix bert tests (joelgrus)
ef4f02b  fix mypy + pylint (joelgrus)
c7df632  Merge branch 'master' into bert (joelgrus)
3a4c0bc  add comments (joelgrus)
69a32c1  address PR feedback (joelgrus)
6eb5cca  address PR feedback (joelgrus)
c23b67c  Merge branch 'master' into bert (joelgrus)
New file, 174 additions & 0 deletions:
@@ -0,0 +1,174 @@
# pylint: disable=no-self-use
from typing import Dict, List, Callable
import logging

from overrides import overrides

from pytorch_pretrained_bert.tokenization import BertTokenizer

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer

logger = logging.getLogger(__name__)

# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.


class WordpieceIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
    If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer``
    subclass rather than this base class.

    Parameters
    ----------
    vocab : ``Dict[str, int]``
        The mapping {wordpiece -> id}. Note this is not an AllenNLP ``Vocabulary``.
    wordpiece_tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    namespace : str, optional (default: "wordpiece")
        The namespace in the AllenNLP ``Vocabulary`` into which the wordpieces
        will be loaded.
    use_starting_offsets : bool, optional (default: False)
        By default, the "offsets" created by the token indexer correspond to the
        last wordpiece in each word. If ``use_starting_offsets`` is specified,
        they will instead correspond to the first wordpiece in each word.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Currently any inputs longer than this
        will be truncated. If this behavior is undesirable to you, you should
        consider filtering them out in your dataset reader.
    """
    def __init__(self,
                 vocab: Dict[str, int],
                 wordpiece_tokenizer: Callable[[str], List[str]],
                 namespace: str = "wordpiece",
                 use_starting_offsets: bool = False,
                 max_pieces: int = 512) -> None:
        self.vocab = vocab

        # The BERT code itself does a two-step tokenization:
        #    sentence -> [words], and then word -> [wordpieces]
        # In AllenNLP, the first step is implemented as the ``BertSimpleWordSplitter``,
        # and this token indexer handles the second.
        self.wordpiece_tokenizer = wordpiece_tokenizer

        self._namespace = namespace
        self._added_to_vocabulary = False
        self.max_pieces = max_pieces
        self.use_starting_offsets = use_starting_offsets

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # If we only use pretrained models, we don't need to do anything here.
        pass

    def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
        # pylint: disable=protected-access
        for word, idx in self.vocab.items():
            vocabulary._token_to_index[self._namespace][word] = idx
            vocabulary._index_to_token[self._namespace][idx] = word

    @overrides
    def tokens_to_indices(self,
                          tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        if not self._added_to_vocabulary:
            self._add_encoding_to_vocabulary(vocabulary)
            self._added_to_vocabulary = True

        text_tokens: List[int] = []
        offsets = []
        # For initial offsets, start at 0; otherwise, start at -1
        offset = 0 if self.use_starting_offsets else -1

        for token in tokens:
            wordpieces = self.wordpiece_tokenizer(token.text)
            wordpiece_ids = [self.vocab[token] for token in wordpieces]

            # truncate and pray
            if len(text_tokens) + len(wordpiece_ids) > self.max_pieces:
                # TODO(joelgrus): figure out a better way to handle this
                logger.warning(f"Too many wordpieces, truncating: {[token.text for token in tokens]}")
                break

            # For initial offsets, the current value of ``offset`` is the start of
            # the current wordpiece, so add it to ``offsets`` and then increment it.
            if self.use_starting_offsets:
                offsets.append(offset)
                offset += len(wordpiece_ids)
            # For final offsets, the current value of ``offset`` is the end of
            # the previous wordpiece, so increment it and then add it to ``offsets``.
            else:
                offset += len(wordpiece_ids)
                offsets.append(offset)
            text_tokens.extend(wordpiece_ids)

        # add mask according to the original tokens,
        # because calling util.get_text_field_mask on the
        # "byte pair" tokens will produce the wrong shape
        mask = [1 for _ in offsets]

        return {
                index_name: text_tokens,
                f"{index_name}-offsets": offsets,
                "mask": mask
        }

    @overrides
    def get_padding_token(self) -> int:
        return 0

    @overrides
    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        return {}

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        return {key: pad_sequence_to_length(val, desired_num_tokens[key])
                for key, val in tokens.items()}


@TokenIndexer.register("bert-pretrained")
class PretrainedBertIndexer(WordpieceIndexer):
    # pylint: disable=line-too-long
    """
    A ``TokenIndexer`` corresponding to a pretrained BERT model.

    Parameters
    ----------
    pretrained_model: ``str``, optional (default = None)
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .txt file with its vocabulary.

        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    use_starting_offsets: bool, optional (default: False)
        By default, the "offsets" created by the token indexer correspond to the
        last wordpiece in each word. If ``use_starting_offsets`` is specified,
        they will instead correspond to the first wordpiece in each word.
    do_lowercase: ``bool``, optional (default = True)
        Whether to lowercase the tokens before converting to wordpiece ids.
    max_pieces: int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Currently any inputs longer than this
        will be truncated. If this behavior is undesirable to you, you should
        consider filtering them out in your dataset reader.
    """
    def __init__(self,
                 pretrained_model: str,
                 use_starting_offsets: bool = False,
                 do_lowercase: bool = True,
                 max_pieces: int = 512) -> None:
        bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lowercase)
        super().__init__(vocab=bert_tokenizer.vocab,
                         wordpiece_tokenizer=bert_tokenizer.wordpiece_tokenizer.tokenize,
                         namespace="bert",
                         use_starting_offsets=use_starting_offsets,
                         max_pieces=max_pieces)
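To make the offset bookkeeping in ``tokens_to_indices`` concrete, here is a small standalone sketch (not part of the diff) that mirrors the loop above on a made-up two-word sentence; the toy wordpiece tokenizer and vocabulary are invented purely for illustration.

```python
# Illustration only: a toy wordpiece tokenizer and vocabulary, not BERT's.
toy_vocab = {"def": 0, "##in": 1, "##ite": 2, "##ly": 3, "not": 4}
toy_wordpieces = {"definitely": ["def", "##in", "##ite", "##ly"], "not": ["not"]}

def index_words(words, use_starting_offsets):
    """Mimics the offset loop in WordpieceIndexer.tokens_to_indices."""
    text_tokens, offsets = [], []
    offset = 0 if use_starting_offsets else -1
    for word in words:
        wordpiece_ids = [toy_vocab[wp] for wp in toy_wordpieces[word]]
        if use_starting_offsets:
            offsets.append(offset)            # position of the word's first wordpiece
            offset += len(wordpiece_ids)
        else:
            offset += len(wordpiece_ids)
            offsets.append(offset)            # position of the word's last wordpiece
        text_tokens.extend(wordpiece_ids)
    return text_tokens, offsets

print(index_words(["definitely", "not"], use_starting_offsets=True))
# -> ([0, 1, 2, 3, 4], [0, 4])
print(index_words(["definitely", "not"], use_starting_offsets=False))
# -> ([0, 1, 2, 3, 4], [3, 4])
```

The ending-offset case reproduces the [3, 4] example given in the embedder docstring below.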
allennlp/modules/token_embedders/bert_token_embedder.py (127 additions & 0 deletions):
@@ -0,0 +1,127 @@
""" | ||
A ``TokenEmbedder`` which uses one of the BERT models | ||
(https://github.com/google-research/bert) | ||
to produce embeddings. | ||
|
||
At its core it uses Hugging Face's PyTorch implementation | ||
(https://github.com/huggingface/pytorch-pretrained-BERT), | ||
so thanks to them! | ||
""" | ||
import logging | ||
|
||
import torch | ||
|
||
from pytorch_pretrained_bert.modeling import BertModel | ||
|
||
from allennlp.modules.scalar_mix import ScalarMix | ||
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder | ||
from allennlp.nn import util | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class BertEmbedder(TokenEmbedder): | ||
""" | ||
A ``TokenEmbedder`` that produces BERT embeddings for your tokens. | ||
Should be paired with a ``BertIndexer``, which produces wordpiece ids. | ||
|
||
Most likely you probably want to use ``PretrainedBertEmbedder`` | ||
for one of the named pretrained models, not this base class. | ||
|
||
Parameters | ||
---------- | ||
bert_model: ``BertModel`` | ||
The BERT model being wrapped. | ||
top_layer_only: ``bool``, optional (default = ``False``) | ||
If ``True``, then only return the top layer instead of apply the scalar mix. | ||
""" | ||
def __init__(self, bert_model: BertModel, top_layer_only: bool = False) -> None: | ||
super().__init__() | ||
self.bert_model = bert_model | ||
self.output_dim = bert_model.config.hidden_size | ||
if not top_layer_only: | ||
self._scalar_mix = ScalarMix(bert_model.config.num_hidden_layers, | ||
do_layer_norm=False) | ||
else: | ||
self._scalar_mix = None | ||
|
||
def get_output_dim(self) -> int: | ||
return self.output_dim | ||
|
||
def forward(self, | ||
input_ids: torch.LongTensor, | ||
offsets: torch.LongTensor = None, | ||
token_type_ids: torch.LongTensor = None) -> torch.Tensor: | ||
""" | ||
Parameters | ||
---------- | ||
input_ids : ``torch.LongTensor`` | ||
The (batch_size, max_sequence_length) tensor of wordpiece ids. | ||
offsets : ``torch.LongTensor``, optional | ||
The BERT embeddings are one per wordpiece. However it's possible/likely | ||
you might want one per original token. In that case, ``offsets`` | ||
represents the indices of the desired wordpiece for each original token. | ||
Depending on how your token indexer is configured, this could be the | ||
position of the last wordpiece for each token, or it could be the position | ||
of the first wordpiece for each token. | ||
|
||
For example, if you had the sentence "Definitely not", and if the corresponding | ||
wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids | ||
would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. | ||
If offsets are provided, the returned tensor will contain only the wordpiece | ||
embeddings at those positions, and (in particular) will contain one embedding | ||
per token. If offsets are not provided, the entire tensor of wordpiece embeddings | ||
will be returned. | ||
token_type_ids : ``torch.LongTensor``, optional | ||
If an input consists of two sentences (as in the BERT paper), | ||
tokens from the first sentence should have type 0 and tokens from | ||
the second sentence should have type 1. If you don't provide this | ||
(the default BertIndexer doesn't) then it's assumed to be all 0s. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to modify the indexer to provide this? It's fine to do it in another PR, but I'd at least open an issue to track adding that. Seems pretty important if you want to use this for SQuAD. |
||
""" | ||
# pylint: disable=arguments-differ | ||
if token_type_ids is None: | ||
token_type_ids = torch.zeros_like(input_ids) | ||
|
||
input_mask = (input_ids != 0).long() | ||
|
||
        # BertModel.forward takes (input_ids, token_type_ids, attention_mask), so pass these by
        # keyword to make sure the mask and the type ids end up in the right arguments.
        all_encoder_layers, _ = self.bert_model(input_ids=input_ids,
                                                token_type_ids=token_type_ids,
                                                attention_mask=input_mask)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        if offsets is None:
            return mix
        else:
            batch_size = input_ids.size(0)
            range_vector = util.get_range_vector(batch_size,
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            return mix[range_vector, offsets]


@TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):
    # pylint: disable=line-too-long
    """
    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .tar.gz file with the model weights.

        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    requires_grad : ``bool``, optional (default = False)
        If True, compute gradient of BERT parameters for fine tuning.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of applying the scalar mix.
    """
    def __init__(self, pretrained_model: str, requires_grad: bool = False, top_layer_only: bool = False) -> None:
        model = BertModel.from_pretrained(pretrained_model)

        for param in model.parameters():
            param.requires_grad = requires_grad

        super().__init__(bert_model=model, top_layer_only=top_layer_only)
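The advanced indexing at the end of ``forward``, ``mix[range_vector, offsets]``, picks out one wordpiece embedding per original token. Here is a minimal sketch of just that indexing step with made-up shapes; a plain ``torch.arange`` stands in for ``util.get_range_vector``.

```python
import torch

# Made-up shapes: batch of 2 sequences, 5 wordpieces each, hidden size 3.
mix = torch.arange(2 * 5 * 3, dtype=torch.float).view(2, 5, 3)  # (batch, num_wordpieces, dim)
offsets = torch.tensor([[3, 4, 4], [0, 2, 3]])                  # one wordpiece position per original token

# Stand-in for util.get_range_vector(batch_size, device).unsqueeze(1):
range_vector = torch.arange(mix.size(0)).unsqueeze(1)           # shape (2, 1)

per_token = mix[range_vector, offsets]                          # shape (2, 3, 3): one embedding per token
assert per_token.shape == (2, 3, 3)
assert torch.equal(per_token[0, 0], mix[0, 3])                  # first token of first sequence -> wordpiece 3
```

Because ``range_vector`` has shape (batch_size, 1), it broadcasts against ``offsets`` so that each row of offsets indexes into its own sequence.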
[Review comment] Hmm, I didn't realize you could do this. Nice find.
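To close, a hypothetical end-to-end sketch tying the two new ``bert-pretrained`` components together at the Python level. The import path for ``PretrainedBertIndexer`` is an assumption (the diff does not show which module it lives in), running it downloads ``bert-base-uncased``, and the dataset/model plumbing an AllenNLP config would normally provide is omitted.

```python
import torch

from allennlp.data.tokenizers.token import Token
from allennlp.data.vocabulary import Vocabulary
# Import path assumed for illustration; the diff does not show where the indexer is defined.
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder

# Both classes accept a pretrained model name (or a path/URL), per their docstrings.
indexer = PretrainedBertIndexer(pretrained_model="bert-base-uncased")
embedder = PretrainedBertEmbedder(pretrained_model="bert-base-uncased")

tokens = [Token("Definitely"), Token("not")]
indexed = indexer.tokens_to_indices(tokens, Vocabulary(), "bert")
# ``indexed`` has keys "bert" (wordpiece ids), "bert-offsets", and "mask".

input_ids = torch.LongTensor([indexed["bert"]])
offsets = torch.LongTensor([indexed["bert-offsets"]])

embeddings = embedder(input_ids, offsets)
print(embeddings.shape)  # (1, 2, 768) for bert-base: one 768-dim vector per original token
```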