In [3]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
#logging.basicConfig(level=logging.INFO)

import matplotlib.pyplot as plt
% matplotlib inline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


100%|██████████| 231508/231508 [00:00<00:00, 258612.08B/s]


In [4]:
text = "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank."
marked_text = "[CLS] " + text + " [SEP]"

In [5]:
tokenized_text = tokenizer.tokenize(marked_text)
print (tokenized_text)

['[CLS]', 'after', 'stealing', 'money', 'from', 'the', 'bank', 'vault', ',', 'the', 'bank', 'robber', 'was', 'seen', 'fishing', 'on', 'the', 'mississippi', 'river', 'bank', '.', '[SEP]']


In [6]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [1] * len(tokenized_text)

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

100%|██████████| 407873900/407873900 [05:57<00:00, 1139537.59B/s]


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=

In [7]:
# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)

In [12]:
encoded_layers[2].shape

torch.Size([1, 22, 768])

In [13]:
indexed_tokens

[101,
 2044,
 11065,
 2769,
 2013,
 1996,
 2924,
 11632,
 1010,
 1996,
 2924,
 27307,
 2001,
 2464,
 5645,
 2006,
 1996,
 5900,
 2314,
 2924,
 1012,
 102]

In [14]:
print ("Number of layers:", len(encoded_layers))
layer_i = 0

print ("Number of batches:", len(encoded_layers[layer_i]))
batch_i = 0

print ("Number of tokens:", len(encoded_layers[layer_i][batch_i]))
token_i = 0

print ("Number of hidden units:", len(encoded_layers[layer_i][batch_i][token_i]))


Number of layers: 12
Number of batches: 1
Number of tokens: 22
Number of hidden units: 768


In [15]:
# Convert the hidden state embeddings into single token vectors

# Holds the list of 12 layer embeddings for each token
# Will have the shape: [# tokens, # layers, # features]
token_embeddings = [] 

# For each token in the sentence...
for token_i in range(len(tokenized_text)):
  
  # Holds 12 layers of hidden states for each token 
  hidden_layers = [] 
  
  # For each of the 12 layers...
  for layer_i in range(len(encoded_layers)):
    
    # Lookup the vector for `token_i` in `layer_i`
    vec = encoded_layers[layer_i][batch_i][token_i]
    
    hidden_layers.append(vec)
    
  token_embeddings.append(hidden_layers)

# Sanity check the dimensions:
print ("Number of tokens in sequence:", len(token_embeddings))
print ("Number of layers per token:", len(token_embeddings[0]))

Number of tokens in sequence: 22
Number of layers per token: 12


In [18]:
token_embeddings[0]

[tensor([ 5.6950e-02,  5.0759e-02, -2.1456e-01, -2.8740e-01,  2.5529e-01,
          1.0143e-01, -3.6744e-01,  4.6512e-02, -5.1119e-02, -4.3769e-01,
         -1.9877e-03, -8.1580e-02,  1.8228e-01, -2.3386e-02, -9.4420e-02,
          4.5619e-01,  5.3530e-01,  3.9253e-01, -1.5481e-02, -7.4215e-01,
         -6.3978e-01, -3.7518e-01, -7.7995e-02, -4.2731e-01, -1.4132e-01,
         -4.5248e-01, -3.8090e-01,  3.4314e-02,  5.4960e-02,  2.1373e-02,
         -1.8991e-01, -1.7218e-01,  3.8683e-01,  1.6420e-01, -4.8821e-01,
          8.2416e-02, -4.9245e-01,  1.6358e-01, -3.2847e-01,  2.8821e-01,
          3.9869e-01, -2.7791e-01,  5.0795e-02,  3.5417e-01,  1.0370e-01,
          1.2390e-01, -1.5609e+00,  1.6210e-02,  6.1083e-01, -1.4186e-01,
          2.6334e-01,  1.7971e-01, -3.1492e-01,  9.8814e-01,  1.1508e-01,
         -1.6590e-01, -2.4925e-01,  1.1797e-01,  4.2549e-01, -1.1502e-01,
          4.2918e-02, -1.6814e-01,  4.0785e-01,  2.1148e-01, -3.2390e-01,
          3.9540e-01,  7.5330e-02, -2.