# Extract BERT

Notebook to experiment with libraries to extract parts of BERT embeddings

In [1]:
import transformers
import torch

I1103 01:00:27.026603 140463182931712 file_utils.py:39] PyTorch version 1.2.0 available.
I1103 01:00:28.313264 140463182931712 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [None]:
bert_model = transformers.BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
bert_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
def convert_to_bert_input(sentences):
  def pad_to_length(tokens, desired_len):
    return tokens + (['[PAD]'] * (desired_len - len(tokens)))
  bert_tokens = [bert_tokenizer.tokenize(sentence) for sentence in sentences]
  max_len = max([len(tokens) for tokens in bert_tokens])
  padded_tokens = [pad_to_length(tokens, max_len) for tokens in bert_tokens]
  padded_ids = [bert_tokenizer.encode(tokens) for tokens in padded_tokens]
  attn_mask = [[1 if token != '[PAD]' else 0 for token in tokens] for tokens in padded_tokens]
  return padded_tokens, padded_ids, attn_mask

In [4]:
STR1 = "My cat is called Xiaonuanhuo and she is warm and fluffy"
STR2 = "this is a much shorter sentence"
padded_tokens, padded_ids, attn_mask = convert_to_bert_input([STR1, STR2])
print(padded_tokens)
print(padded_ids)
print(attn_mask)

[['my', 'cat', 'is', 'called', 'xiao', '##nu', '##an', '##hu', '##o', 'and', 'she', 'is', 'warm', 'and', 'fluffy'], ['this', 'is', 'a', 'much', 'shorter', 'sentence', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']]
[[2026, 4937, 2003, 2170, 19523, 11231, 2319, 6979, 2080, 1998, 2016, 2003, 4010, 1998, 27036], [2023, 2003, 1037, 2172, 7820, 6251, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [5]:
bert_embeddings = bert_model(torch.tensor(padded_ids), attention_mask=torch.tensor(attn_mask))
bert_embeddings[0]

tensor([[[-0.4095,  0.0730,  0.0156,  ...,  0.0193,  0.1457,  0.6264],
         [-0.6220, -0.3098,  0.4594,  ..., -0.2029,  0.3284,  1.0646],
         [-1.0167, -0.3767,  0.4307,  ..., -0.0539,  0.4236,  0.9048],
         ...,
         [-0.7988,  0.2165,  0.7978,  ..., -0.0076,  0.0286,  0.4517],
         [-0.7674, -0.0174,  0.5640,  ...,  0.0461,  0.0452,  0.5110],
         [-0.7671,  0.1120,  0.8036,  ...,  0.1290,  0.1001,  0.5393]],

        [[-0.2225,  0.0793, -0.2679,  ...,  0.0747,  0.3448,  0.2930],
         [-0.3072, -0.0128, -0.4142,  ...,  0.2680,  0.3103,  0.5950],
         [-0.6995,  0.0327, -0.3212,  ...,  0.5684,  0.3032,  0.5805],
         ...,
         [-0.2036,  0.1585, -0.1651,  ...,  0.0816,  0.2754,  0.1891],
         [-0.2093,  0.1336, -0.1721,  ...,  0.0810,  0.2852,  0.1886],
         [-0.1874,  0.1357, -0.1805,  ...,  0.0776,  0.2798,  0.1795]]],
       grad_fn=<NativeLayerNormBackward>)

In [6]:
bert_embeddings[0].shape

torch.Size([2, 15, 768])

In [7]:
# Final hidden layer, one for each token
bert_embeddings[0].shape

torch.Size([2, 15, 768])

In [8]:
# Pooled layer
bert_embeddings[1].shape

torch.Size([2, 768])

In [9]:
len(bert_embeddings[2])

13

In [10]:
# Nth hidden layer of BERT
layer = 7
bert_embeddings[2][layer].shape

torch.Size([2, 15, 768])