Pretrained transformer indexer (#3146)
* wip

* Added an indexer for pretrained transformers

* doc

* pylint, mypy

* Missing import...
matt-gardner committed Aug 13, 2019
1 parent fa1ff67 commit f111d8a
Showing 6 changed files with 133 additions and 1 deletion.
1 change: 1 addition & 0 deletions allennlp/data/token_indexers/__init__.py
@@ -12,3 +12,4 @@
from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import OpenaiTransformerBytePairIndexer
from allennlp.data.token_indexers.wordpiece_indexer import WordpieceIndexer, PretrainedBertIndexer
from allennlp.data.token_indexers.spacy_indexer import SpacyTokenIndexer
from allennlp.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
89 changes: 89 additions & 0 deletions allennlp/data/token_indexers/pretrained_transformer_indexer.py
@@ -0,0 +1,89 @@
from typing import Dict, List
import logging

from overrides import overrides
from pytorch_transformers.tokenization_auto import AutoTokenizer
import torch

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@TokenIndexer.register("pretrained_transformer")
class PretrainedTransformerIndexer(TokenIndexer[int]):
"""
This :class:`TokenIndexer` uses a tokenizer from the ``pytorch_transformers`` repository to
index tokens. This ``Indexer`` is only really appropriate to use if you've also used a
corresponding :class:`PretrainedTransformerTokenizer` to tokenize your input. Otherwise you'll
have a mismatch between your tokens and your vocabulary, and you'll get a lot of UNK tokens.
Parameters
----------
model_name : ``str``
The name of the ``pytorch_transformers`` model to use.
do_lowercase : ``str``
Whether to lowercase the tokens (this should match the casing of the model name that you
pass)
namespace : ``str``, optional (default=``tags``)
We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace.
We use a somewhat confusing default value of ``tags`` so that we do not add padding or UNK
tokens to this namespace, which would break on loading because we wouldn't find our default
OOV token.
"""
# pylint: disable=no-self-use
def __init__(self,
model_name: str,
do_lowercase: bool,
namespace: str = "tags",
token_min_padding_length: int = 0) -> None:
super().__init__(token_min_padding_length)
if model_name.endswith("-cased") and do_lowercase:
logger.warning("Your pretrained model appears to be cased, "
"but your indexer is lowercasing tokens.")
elif model_name.endswith("-uncased") and not do_lowercase:
logger.warning("Your pretrained model appears to be uncased, "
"but your indexer is not lowercasing tokens.")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lowercase)
self._namespace = namespace
self._added_to_vocabulary = False

@overrides
def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
# If we only use pretrained models, we don't need to do anything here.
pass

def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
# pylint: disable=protected-access
for word, idx in self.tokenizer.vocab.items():
vocabulary._token_to_index[self._namespace][word] = idx
vocabulary._index_to_token[self._namespace][idx] = word

@overrides
def tokens_to_indices(self,
tokens: List[Token],
vocabulary: Vocabulary,
index_name: str) -> Dict[str, List[int]]:
if not self._added_to_vocabulary:
self._add_encoding_to_vocabulary(vocabulary)
self._added_to_vocabulary = True
token_text = [token.text for token in tokens]
indices = self.tokenizer.convert_tokens_to_ids(token_text)

return {index_name: indices}

@overrides
def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument
return {}

@overrides
def as_padded_tensor(self,
tokens: Dict[str, List[int]],
desired_num_tokens: Dict[str, int],
padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]: # pylint: disable=unused-argument
return {key: torch.LongTensor(pad_sequence_to_length(val, desired_num_tokens[key]))
for key, val in tokens.items()}
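
A minimal usage sketch of the new indexer (not part of this commit). It assumes ``pytorch_transformers`` is installed and the ``bert-base-uncased`` weights can be loaded; the ``'transformer'`` index name below is arbitrary.

# Usage sketch (hypothetical, not from this diff).
from pytorch_transformers.tokenization_auto import AutoTokenizer

from allennlp.data import Token, Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
indexer = PretrainedTransformerIndexer(model_name='bert-base-uncased', do_lowercase=True)

wordpieces = tokenizer.tokenize('AllenNLP is great')
tokens = [Token(piece) for piece in wordpieces]

vocab = Vocabulary()
indexed = indexer.tokens_to_indices(tokens, vocab, 'transformer')
# indexed['transformer'] now holds ids from the pretrained tokenizer's vocabulary.

# Side effect of the first call: the transformer vocabulary is copied into the
# indexer's namespace (default "tags"), so the ids round-trip through the Vocabulary.
assert vocab.get_token_from_index(indexed['transformer'][0], namespace='tags') == wordpieces[0]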
13 changes: 12 additions & 1 deletion allennlp/data/token_indexers/wordpiece_indexer.py
@@ -283,7 +283,6 @@ def as_padded_tensor(self,
        return {key: torch.LongTensor(pad_sequence_to_length(val, desired_num_tokens[key]))
                for key, val in tokens.items()}


    @overrides
    def get_keys(self, index_name: str) -> List[str]:
        """
@@ -355,6 +354,18 @@ def __init__(self,
separator_token="[SEP]",
truncate_long_sequences=truncate_long_sequences)

    def __eq__(self, other):
        if isinstance(other, PretrainedBertIndexer):
            for key in self.__dict__:
                if key == 'wordpiece_tokenizer':
                    # This is a reference to a function in the huggingface code, which we can't
                    # really modify to make this clean. So we special-case it.
                    continue
                if self.__dict__[key] != other.__dict__[key]:
                    return False
            return True
        return NotImplemented


def _get_token_type_ids(wordpiece_ids: List[int],
                        separator_ids: List[int]) -> List[int]:
6 changes: 6 additions & 0 deletions allennlp/tests/data/token_indexers/bert_indexer_test.py
@@ -31,6 +31,12 @@ def test_starting_ending_offsets(self):
assert indexed_tokens["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
assert indexed_tokens["bert-offsets"] == [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]

    def test_eq(self):
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        indexer1 = PretrainedBertIndexer(str(vocab_path))
        indexer2 = PretrainedBertIndexer(str(vocab_path))
        assert indexer1 == indexer2

    def test_do_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()
18 changes: 18 additions & 0 deletions allennlp/tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -0,0 +1,18 @@
# pylint: disable=no-self-use,invalid-name
from pytorch_transformers.tokenization_auto import AutoTokenizer

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Token, Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer


class TestPretrainedTransformerIndexer(AllenNlpTestCase):
    def test_as_array_produces_token_sequence(self):
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lowercase=True)
        indexer = PretrainedTransformerIndexer(model_name='bert-base-uncased', do_lowercase=True)
        tokens = tokenizer.tokenize('AllenNLP is great')
        expected_ids = tokenizer.convert_tokens_to_ids(tokens)
        allennlp_tokens = [Token(token) for token in tokens]
        vocab = Vocabulary()
        indexed = indexer.tokens_to_indices(allennlp_tokens, vocab, 'key')
        assert indexed['key'] == expected_ids
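
A further sketch, not part of this commit: because the class is registered as ``pretrained_transformer``, it can also be built through the registry, the same way a training config's ``token_indexers`` block would construct it. This assumes the standard AllenNLP ``Params`` machinery.

from allennlp.common import Params
from allennlp.data.token_indexers import TokenIndexer

# Hypothetical config-driven construction; the keys mirror the __init__ arguments above.
indexer = TokenIndexer.from_params(Params({
        "type": "pretrained_transformer",
        "model_name": "bert-base-uncased",
        "do_lowercase": True,
}))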
7 changes: 7 additions & 0 deletions doc/api/allennlp.data.token_indexers.rst
@@ -15,6 +15,7 @@ allennlp.data.token_indexers
* :ref:`ELMoTokenCharactersIndexer<elmo-indexer>`
* :ref:`OpenaiTransformerBytePairIndexer<openai-transformer-byte-pair-indexer>`
* :ref:`WordpieceIndexer<wordpiece-indexer>`
* :ref:`PretrainedTransformerIndexer<pretrained-transformer-indexer>`
* :ref:`SpacyTokenIndexer<spacy-token-indexer>`


@@ -72,6 +73,12 @@ allennlp.data.token_indexers
   :undoc-members:
   :show-inheritance:

.. _pretrained-transformer-indexer:
.. automodule:: allennlp.data.token_indexers.pretrained_transformer_indexer
   :members:
   :undoc-members:
   :show-inheritance:

.. _spacy-token-indexer:
.. automodule:: allennlp.data.token_indexers.spacy_indexer
   :members:
