Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

add BERT token embedder #2067

Merged
merged 19 commits into from
Nov 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ expected-line-ending-format=
[BASIC]

# Good variable names which should always be accepted, separated by a comma
good-names=i,j,k,ex,Run,_
good-names=i,j,k,ex,Run,_,logger
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I didn't realize you could do this. Nice find.


# Bad variable names which should always be refused, separated by a comma
bad-names=foo,bar,baz,toto,tutu,tata
Expand Down
21 changes: 19 additions & 2 deletions allennlp/common/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,30 @@ def _check_is_dict(self, new_history, value):
return value

@staticmethod
def from_file(params_file: str, params_overrides: str = "") -> 'Params':
def from_file(params_file: str, params_overrides: str = "", ext_vars: dict = None) -> 'Params':
"""
Load a `Params` object from a configuration file.

Parameters
----------
params_file : ``str``
The path to the configuration file to load.
params_overrides : ``str``, optional
A dict of overrides that can be applied to final object.
e.g. {"model.embedding_dim": 10}
ext_vars : ``dict``, optional
Our config files are Jsonnet, which allows specifying external variables
for later substitution. Typically we substitute these using environment
variables; however, you can also specify them here, in which case they
take priority over environment variables.
e.g. {"HOME_DIR": "/Users/allennlp/home"}
"""
if ext_vars is None:
ext_vars = {}

# redirect to cache, if necessary
params_file = cached_path(params_file)
ext_vars = dict(os.environ)
ext_vars = {**dict(os.environ), **ext_vars}
DeNeutoy marked this conversation as resolved.
Show resolved Hide resolved

file_dict = json.loads(evaluate_file(params_file, ext_vars=ext_vars))

Expand Down
1 change: 1 addition & 0 deletions allennlp/data/token_indexers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import OpenaiTransformerBytePairIndexer
from allennlp.data.token_indexers.wordpiece_indexer import WordpieceIndexer, PretrainedBertIndexer
174 changes: 174 additions & 0 deletions allennlp/data/token_indexers/wordpiece_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# pylint: disable=no-self-use
from typing import Dict, List, Callable
import logging

from overrides import overrides

from pytorch_pretrained_bert.tokenization import BertTokenizer

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer

logger = logging.getLogger(__name__)

# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.

class WordpieceIndexer(TokenIndexer[int]):
"""
A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer``
subclass rather than this base class.

Parameters
----------
vocab : ``Dict[str, int]``
The mapping {wordpiece -> id}. Note this is not an AllenNLP ``Vocabulary``.
wordpiece_tokenizer : ``Callable[[str], List[str]]``
A function that does the actual tokenization.
namespace : str, optional (default: "wordpiece")
The namespace in the AllenNLP ``Vocabulary`` into which the wordpieces
will be loaded.
use_starting_offsets : bool, optional (default: False)
By default, the "offsets" created by the token indexer correspond to the
last wordpiece in each word. If ``use_starting_offsets`` is specified,
they will instead correspond to the first wordpiece in each word.
max_pieces : int, optional (default: 512)
The BERT embedder uses positional embeddings and so has a corresponding
maximum length for its input ids. Currently any inputs longer than this
will be truncated. If this behavior is undesirable to you, you should
consider filtering them out in your dataset reader.
"""
def __init__(self,
vocab: Dict[str, int],
wordpiece_tokenizer: Callable[[str], List[str]],
namespace: str = "wordpiece",
use_starting_offsets: bool = False,
max_pieces: int = 512) -> None:
self.vocab = vocab

# The BERT code itself does a two-step tokenization:
# sentence -> [words], and then word -> [wordpieces]
# In AllenNLP, the first step is implemented as the ``BertSimpleWordSplitter``,
# and this token indexer handles the second.
self.wordpiece_tokenizer = wordpiece_tokenizer

self._namespace = namespace
self._added_to_vocabulary = False
self.max_pieces = max_pieces
self.use_starting_offsets = use_starting_offsets

@overrides
def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
# If we only use pretrained models, we don't need to do anything here.
pass

def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
# pylint: disable=protected-access
for word, idx in self.vocab.items():
vocabulary._token_to_index[self._namespace][word] = idx
vocabulary._index_to_token[self._namespace][idx] = word

@overrides
def tokens_to_indices(self,
tokens: List[Token],
vocabulary: Vocabulary,
index_name: str) -> Dict[str, List[int]]:
if not self._added_to_vocabulary:
self._add_encoding_to_vocabulary(vocabulary)
self._added_to_vocabulary = True

text_tokens: List[int] = []
offsets = []
# For initial offsets, start at 0; otherwise, start at -1
offset = 0 if self.use_starting_offsets else -1

for token in tokens:
wordpieces = self.wordpiece_tokenizer(token.text)
wordpiece_ids = [self.vocab[token] for token in wordpieces]

# truncate and pray
if len(text_tokens) + len(wordpiece_ids) > self.max_pieces:
# TODO(joelgrus): figure out a better way to handle this
logger.warning(f"Too many wordpieces, truncating: {[token.text for token in tokens]}")
break

# For initial offsets, the current value of ``offset`` is the start of
# the current wordpiece, so add it to ``offsets`` and then increment it.
if self.use_starting_offsets:
offsets.append(offset)
offset += len(wordpiece_ids)
# For final offsets, the current value of ``offset`` is the end of
# the previous wordpiece, so increment it and then add it to ``offsets``.
else:
offset += len(wordpiece_ids)
offsets.append(offset)
text_tokens.extend(wordpiece_ids)

# add mask according to the original tokens,
# because calling util.get_text_field_mask on the
# "byte pair" tokens will produce the wrong shape
mask = [1 for _ in offsets]

return {
index_name: text_tokens,
f"{index_name}-offsets": offsets,
"mask": mask
}

@overrides
def get_padding_token(self) -> int:
return 0

@overrides
def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument
return {}

@overrides
def pad_token_sequence(self,
tokens: Dict[str, List[int]],
desired_num_tokens: Dict[str, int],
padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument
return {key: pad_sequence_to_length(val, desired_num_tokens[key])
for key, val in tokens.items()}


@TokenIndexer.register("bert-pretrained")
class PretrainedBertIndexer(WordpieceIndexer):
# pylint: disable=line-too-long
"""
A ``TokenIndexer`` corresponding to a pretrained BERT model.

Parameters
----------
pretrained_model: ``str``, optional (default = None)
Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
or the path to the .txt file with its vocabulary.

If the name is a key in the list of pretrained models at
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
the corresponding path will be used; otherwise it will be interpreted as a path or URL.
use_starting_offsets: bool, optional (default: False)
By default, the "offsets" created by the token indexer correspond to the
last wordpiece in each word. If ``use_starting_offsets`` is specified,
they will instead correspond to the first wordpiece in each word.
do_lowercase: ``bool``, optional (default = True)
Whether to lowercase the tokens before converting to wordpiece ids.
max_pieces: int, optional (default: 512)
The BERT embedder uses positional embeddings and so has a corresponding
maximum length for its input ids. Currently any inputs longer than this
will be truncated. If this behavior is undesirable to you, you should
consider filtering them out in your dataset reader.
"""
def __init__(self,
pretrained_model: str,
use_starting_offsets: bool = False,
do_lowercase: bool = True,
max_pieces: int = 512) -> None:
bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lowercase)
super().__init__(vocab=bert_tokenizer.vocab,
wordpiece_tokenizer=bert_tokenizer.wordpiece_tokenizer.tokenize,
namespace="bert",
use_starting_offsets=use_starting_offsets,
max_pieces=max_pieces)
17 changes: 17 additions & 0 deletions allennlp/data/tokenizers/word_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import spacy
import ftfy

from pytorch_pretrained_bert.tokenization import BasicTokenizer as BertTokenizer

from allennlp.common import Registrable
from allennlp.common.util import get_spacy_model
from allennlp.data.tokenizers.token import Token
Expand Down Expand Up @@ -178,3 +180,18 @@ def batch_split_words(self, sentences: List[str]) -> List[List[Token]]:
def split_words(self, sentence: str) -> List[Token]:
# This works because our Token class matches spacy's.
return _remove_spaces(self.spacy(self._standardize(sentence)))


@WordSplitter.register("bert-basic")
class BertBasicWordSplitter(WordSplitter):
"""
The ``BasicWordSplitter`` from the BERT implementation.
This is used to split a sentence into words.
Then the ``BertTokenIndexer`` converts each word into wordpieces.
"""
def __init__(self, do_lower_case: bool = True) -> None:
self.basic_tokenizer = BertTokenizer(do_lower_case)

@overrides
def split_words(self, sentence: str) -> List[Token]:
return [Token(text) for text in self.basic_tokenizer.tokenize(sentence)]
1 change: 1 addition & 0 deletions allennlp/modules/token_embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from allennlp.modules.token_embedders.token_characters_encoder import TokenCharactersEncoder
from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
from allennlp.modules.token_embedders.openai_transformer_embedder import OpenaiTransformerEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
127 changes: 127 additions & 0 deletions allennlp/modules/token_embedders/bert_token_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
A ``TokenEmbedder`` which uses one of the BERT models
(https://github.com/google-research/bert)
to produce embeddings.

At its core it uses Hugging Face's PyTorch implementation
(https://github.com/huggingface/pytorch-pretrained-BERT),
so thanks to them!
"""
import logging

import torch

from pytorch_pretrained_bert.modeling import BertModel

from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.nn import util

logger = logging.getLogger(__name__)


class BertEmbedder(TokenEmbedder):
"""
A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
Should be paired with a ``BertIndexer``, which produces wordpiece ids.

Most likely you probably want to use ``PretrainedBertEmbedder``
for one of the named pretrained models, not this base class.

Parameters
----------
bert_model: ``BertModel``
The BERT model being wrapped.
top_layer_only: ``bool``, optional (default = ``False``)
If ``True``, then only return the top layer instead of apply the scalar mix.
"""
def __init__(self, bert_model: BertModel, top_layer_only: bool = False) -> None:
super().__init__()
self.bert_model = bert_model
self.output_dim = bert_model.config.hidden_size
if not top_layer_only:
self._scalar_mix = ScalarMix(bert_model.config.num_hidden_layers,
do_layer_norm=False)
else:
self._scalar_mix = None

def get_output_dim(self) -> int:
return self.output_dim

def forward(self,
input_ids: torch.LongTensor,
offsets: torch.LongTensor = None,
token_type_ids: torch.LongTensor = None) -> torch.Tensor:
"""
Parameters
----------
input_ids : ``torch.LongTensor``
The (batch_size, max_sequence_length) tensor of wordpiece ids.
offsets : ``torch.LongTensor``, optional
The BERT embeddings are one per wordpiece. However it's possible/likely
you might want one per original token. In that case, ``offsets``
represents the indices of the desired wordpiece for each original token.
Depending on how your token indexer is configured, this could be the
position of the last wordpiece for each token, or it could be the position
of the first wordpiece for each token.

For example, if you had the sentence "Definitely not", and if the corresponding
wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
If offsets are provided, the returned tensor will contain only the wordpiece
embeddings at those positions, and (in particular) will contain one embedding
per token. If offsets are not provided, the entire tensor of wordpiece embeddings
will be returned.
token_type_ids : ``torch.LongTensor``, optional
If an input consists of two sentences (as in the BERT paper),
tokens from the first sentence should have type 0 and tokens from
the second sentence should have type 1. If you don't provide this
(the default BertIndexer doesn't) then it's assumed to be all 0s.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to modify the indexer to provide this? It's fine to do it in another PR, but I'd at least open an issue to track adding that. Seems pretty important if you want to use this for SQuAD.

"""
# pylint: disable=arguments-differ
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

input_mask = (input_ids != 0).long()

all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
if self._scalar_mix is not None:
mix = self._scalar_mix(all_encoder_layers, input_mask)
else:
mix = all_encoder_layers[-1]


if offsets is None:
return mix
else:
batch_size = input_ids.size(0)
range_vector = util.get_range_vector(batch_size,
device=util.get_device_of(mix)).unsqueeze(1)
return mix[range_vector, offsets]


@TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):
# pylint: disable=line-too-long
"""
Parameters
----------
pretrained_model_name: ``str``
Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
or the path to the .tar.gz file with the model weights.

If the name is a key in the list of pretrained models at
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
the corresponding path will be used; otherwise it will be interpreted as a path or URL.
requires_grad : ``bool``, optional (default = False)
If True, compute gradient of BERT parameters for fine tuning.
top_layer_only: ``bool``, optional (default = ``False``)
If ``True``, then only return the top layer instead of apply the scalar mix.
"""
def __init__(self, pretrained_model: str, requires_grad: bool = False, top_layer_only: bool = False) -> None:
model = BertModel.from_pretrained(pretrained_model)

for param in model.parameters():
param.requires_grad = requires_grad

super().__init__(bert_model=model, top_layer_only=top_layer_only)
Loading