In [None]:
import torch

from pytorch_pretrained_bert.modeling import BertConfig, BertModel

from allennlp.common.testing import ModelTestCase
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField, ListField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.word_splitter import BertBasicWordSplitter
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder
from allennlp.common.testing.test_case import AllenNlpTestCase

In [None]:
# Everything BERT-related in this notebook is read from the 'bert' directory
# under AllenNlpTestCase.FIXTURES_ROOT.
bert_fixtures_dir = AllenNlpTestCase.FIXTURES_ROOT / 'bert'

# Wordpiece indexer built from the fixture vocabulary file.
vocab_path = bert_fixtures_dir / 'vocab.txt'
print(vocab_path)
token_indexer = PretrainedBertIndexer(str(vocab_path))

# BERT model constructed from the fixture config file (no checkpoint is
# loaded in this cell), then wrapped so it can be called as a token embedder.
config_path = bert_fixtures_dir / 'config.json'
config = BertConfig(str(config_path))
bert_model = BertModel(config)
token_embedder = BertEmbedder(bert_model)

In [None]:
# Word-level tokenizer backed by BertBasicWordSplitter.
tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

# Two ten-word sentences. Expected wordpiece ids, word by word (a word may
# map to more than one id):
#   sentence1: the=2 quickest=3,4 quick=3 brown=5 fox=6 jumped=8 over=9
#              the=2 lazy=14 dog=12
#   sentence2: the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
#              laziest=15,10,11 lazy=14 elmo=1
sentence1 = "the quickest quick brown fox jumped over the lazy dog"
sentence2 = "the quick brown fox jumped over the laziest lazy elmo"
tokens1 = tokenizer.tokenize(sentence1)
tokens2 = tokenizer.tokenize(sentence2)

vocab = Vocabulary()

# One instance per sentence, each holding a single text field indexed under
# the "bert" key.
instance1 = Instance({
    "tokens": TextField(tokens1, {"bert": token_indexer}),
})
instance2 = Instance({
    "tokens": TextField(tokens2, {"bert": token_indexer}),
})

# Index the two instances together and convert them to padded tensors.
batch = Batch([instance1, instance2])
batch.index_instances(vocab)

padding_lengths = batch.get_padding_lengths()
tensor_dict = batch.as_tensor_dict(padding_lengths)
tokens = tensor_dict["tokens"]

# Every row starts with 16 ([CLS]) and its content ends with 17 ([SEP]).
# The first sentence is one wordpiece shorter, so its row ends in a 0.
expected_wordpiece_ids = [
    [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 14, 12, 17, 0],
    [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17],
]
assert tokens["bert"].tolist() == expected_wordpiece_ids

# One offset per word: the position, within the row above, of that word's
# last wordpiece (e.g. "quickest" occupies positions 2-3, so its offset is 3).
expected_offsets = [
    [1, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    [1, 2, 3, 4, 5, 6, 7, 10, 11, 12],
]
assert tokens["bert-offsets"].tolist() == expected_offsets

# Expected output shapes: 2 sentences x positions x 12 features
# (12 is presumably the hidden size set in the fixture config -- confirm).
per_wordpiece_shape = (2, 14, 12)  # one vector per column of the id tensor
per_word_shape = (2, 10, 12)       # one vector per original word

# Without offsets the embedder returns a vector for every wordpiece position.
bert_vectors = token_embedder(tokens["bert"])
assert tuple(bert_vectors.shape) == per_wordpiece_shape

# With offsets it returns a vector for each of the 10 words instead.
bert_vectors = token_embedder(tokens["bert"], offsets=tokens["bert-offsets"])
assert tuple(bert_vectors.shape) == per_word_shape

# The same two checks with an embedder created with top_layer_only=True.
tlo_embedder = BertEmbedder(bert_model, top_layer_only=True)
bert_vectors = tlo_embedder(tokens["bert"])
assert tuple(bert_vectors.shape) == per_wordpiece_shape

bert_vectors = tlo_embedder(tokens["bert"], offsets=tokens["bert-offsets"])
assert tuple(bert_vectors.shape) == per_word_shape