In [None]:
# https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/

In [3]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
#logging.basicConfig(level=logging.INFO)

import matplotlib.pyplot as plt
% matplotlib inline

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [4]:
# text = "Here is the sentence I want embeddings for."
text = "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank."
marked_text = "[CLS] " + text + " [SEP]"

print (marked_text)

[CLS] After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank. [SEP]


In [5]:
tokenized_text = tokenizer.tokenize(marked_text)
print (tokenized_text)

['[CLS]', 'after', 'stealing', 'money', 'from', 'the', 'bank', 'vault', ',', 'the', 'bank', 'robber', 'was', 'seen', 'fishing', 'on', 'the', 'mississippi', 'river', 'bank', '.', '[SEP]']


In [6]:
list(tokenizer.vocab.keys())[5000:5020]


['knight',
 'lap',
 'survey',
 'ma',
 '##ow',
 'noise',
 'billy',
 '##ium',
 'shooting',
 'guide',
 'bedroom',
 'priest',
 'resistance',
 'motor',
 'homes',
 'sounded',
 'giant',
 '##mer',
 '150',
 'scenes']

In [7]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

for tup in zip(tokenized_text, indexed_tokens):
  print (tup)


('[CLS]', 101)
('after', 2044)
('stealing', 11065)
('money', 2769)
('from', 2013)
('the', 1996)
('bank', 2924)
('vault', 11632)
(',', 1010)
('the', 1996)
('bank', 2924)
('robber', 27307)
('was', 2001)
('seen', 2464)
('fishing', 5645)
('on', 2006)
('the', 1996)
('mississippi', 5900)
('river', 2314)
('bank', 2924)
('.', 1012)
('[SEP]', 102)


In [8]:
segments_ids = [1] * len(tokenized_text)
print (segments_ids)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [9]:
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Lin

In [10]:
# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)

In [11]:
print ("Number of layers:", len(encoded_layers))
layer_i = 0

print ("Number of batches:", len(encoded_layers[layer_i]))
batch_i = 0

print ("Number of tokens:", len(encoded_layers[layer_i][batch_i]))
token_i = 0

print ("Number of hidden units:", len(encoded_layers[layer_i][batch_i][token_i]))

Number of layers: 12
Number of batches: 1
Number of tokens: 22
Number of hidden units: 768


In [12]:
# Convert the hidden state embeddings into single token vectors

# Holds the list of 12 layer embeddings for each token
# Will have the shape: [# tokens, # layers, # features]
token_embeddings = [] 

# For each token in the sentence...
for token_i in range(len(tokenized_text)):
  
  # Holds 12 layers of hidden states for each token 
  hidden_layers = [] 
  
  # For each of the 12 layers...
  for layer_i in range(len(encoded_layers)):
    
    # Lookup the vector for `token_i` in `layer_i`
    vec = encoded_layers[layer_i][batch_i][token_i]
    
    hidden_layers.append(vec)
    
  token_embeddings.append(hidden_layers)

# Sanity check the dimensions:
print ("Number of tokens in sequence:", len(token_embeddings))
print ("Number of layers per token:", len(token_embeddings[0]))

Number of tokens in sequence: 22
Number of layers per token: 12


In [13]:
concatenated_last_4_layers = [torch.cat((layer[-1], layer[-2], layer[-3], layer[-4]), 0) for layer in token_embeddings] # [number_of_tokens, 3072]

summed_last_4_layers = [torch.sum(torch.stack(layer)[-4:], 0) for layer in token_embeddings] # [number_of_tokens, 768]

In [14]:
sentence_embedding = torch.mean(encoded_layers[11], 1)


In [15]:
print ("Our final sentence embedding vector of shape:"), sentence_embedding[0].shape[0]


Our final sentence embedding vector of shape:


(None, 768)

In [16]:
sentence_embedding

tensor([[ 3.2873e-02, -2.3461e-01, -7.9924e-02,  3.8905e-01,  8.8793e-01,
          2.1375e-01, -7.8063e-03,  6.2694e-01, -3.2631e-02, -3.4703e-01,
          1.2331e-01, -9.4771e-02, -7.4421e-02,  4.5523e-01, -4.7217e-01,
          1.0344e-01,  3.4667e-01,  1.0410e-01,  5.4372e-01,  6.9113e-02,
         -8.2880e-02,  6.7866e-02,  1.2871e-01,  2.3312e-01,  4.2933e-01,
         -1.2449e-02, -2.1360e-01,  2.2697e-01, -1.2763e-01,  2.8080e-01,
          5.4766e-01, -1.0080e-01,  7.5650e-02, -2.7152e-01, -1.4124e-01,
         -4.0931e-01, -1.9237e-01, -3.9813e-02, -2.3175e-01,  3.3040e-01,
         -3.8484e-01, -3.7469e-01, -2.4978e-01,  3.2707e-01, -9.0565e-04,
         -4.4285e-01,  7.6833e-02, -4.4617e-02,  1.9807e-02,  7.4679e-02,
         -3.2986e-01,  8.2114e-01, -7.6256e-01, -4.1613e-04,  8.9090e-02,
          5.3472e-01, -3.8244e-01, -6.0184e-01, -8.3941e-02, -9.6010e-02,
          3.7689e-01, -2.5287e-01,  4.6429e-01, -5.4761e-01, -2.1230e-02,
          2.0828e-02,  4.8191e-01,  2.

In [17]:
print (text)

After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank.


In [18]:
for i,x in enumerate(tokenized_text):
  print (i,x)

0 [CLS]
1 after
2 stealing
3 money
4 from
5 the
6 bank
7 vault
8 ,
9 the
10 bank
11 robber
12 was
13 seen
14 fishing
15 on
16 the
17 mississippi
18 river
19 bank
20 .
21 [SEP]


In [19]:
print ("First fifteen values of 'bank' as in 'bank robber':")
summed_last_4_layers[10][:15]

First fifteen values of 'bank' as in 'bank robber':


tensor([ 1.1868, -1.5298, -1.3770,  1.0648,  3.1446,  1.4003, -4.2407,  1.3946,
        -0.1170, -1.8777,  0.1091, -0.3862,  0.6744,  2.1924, -4.5306])

In [20]:
print ("First fifteen values of 'bank' as in 'bank vault':")
summed_last_4_layers[6][:15]

First fifteen values of 'bank' as in 'bank vault':


tensor([ 2.1319, -2.1413, -1.6260,  0.8638,  3.3173,  0.1796, -4.4853,  3.1215,
        -0.9740, -3.1780,  0.1045, -1.5481,  0.4758,  1.1703, -4.4859])

In [21]:
print ("First fifteen values of 'bank' as in 'river bank':")
summed_last_4_layers[19][:15]

First fifteen values of 'bank' as in 'river bank':


tensor([ 1.1295, -1.4725, -0.7296, -0.0901,  2.4970,  0.5330,  0.9742,  5.1834,
        -1.0692, -1.5941,  1.9261,  0.7119, -0.9809,  1.2127, -2.9812])

In [22]:
from sklearn.metrics.pairwise import cosine_similarity

# Compare "bank" as in "bank robber" to "bank" as in "river bank"
different_bank = cosine_similarity(summed_last_4_layers[10].reshape(1,-1), summed_last_4_layers[19].reshape(1,-1))[0][0]

# Compare "bank" as in "bank robber" to "bank" as in "bank vault" 
same_bank = cosine_similarity(summed_last_4_layers[10].reshape(1,-1), summed_last_4_layers[6].reshape(1,-1))[0][0]

In [23]:
import spacy
nlp = spacy.load("en_core_web_sm")
doc = nlp(text)

In [24]:
doc.vector

array([-0.20136365,  0.8558488 , -0.90959704, -2.0022335 ,  0.05285798,
       -0.09729461, -0.12291355, -0.08967322, -0.5559663 ,  0.55344105,
       -0.09683708,  1.7974255 , -0.23645914,  0.9595979 , -0.355906  ,
       -0.47161824,  0.81053317,  0.16329156, -1.3595612 ,  0.6663488 ,
        0.54851407, -0.17415449,  0.2818833 ,  2.2935798 ,  1.4775044 ,
       -0.876686  , -0.4828207 ,  1.2497665 , -1.250145  ,  0.8328441 ,
       -1.0026721 , -1.267679  ,  0.5168272 , -0.04681165,  2.0672045 ,
       -1.1562798 ,  0.9741308 ,  1.2332971 , -0.72004116,  0.16166396,
        0.11475471, -0.05980494, -0.5617358 , -0.12064459, -0.06153221,
        0.7364133 ,  1.625066  , -0.06889956, -0.8148624 , -1.1105716 ,
       -0.28389627, -0.6885798 , -1.0539786 , -1.0883272 , -0.5819933 ,
        0.28235465, -1.4406666 ,  1.4301897 ,  2.377942  , -1.6349208 ,
        0.16169724,  0.25153303,  0.7924581 ,  1.1010284 , -0.3240807 ,
       -0.22770417,  0.04354021, -0.37510276, -0.10873371, -0.86

In [46]:
doc[9].vector

array([-1.3527856 ,  1.5364034 , -0.34534848, -2.6574855 ,  0.17459357,
       -0.07838053,  0.16398804,  2.949361  ,  0.7446209 , -2.8426688 ,
        1.01109   , -1.4075065 , -0.80491817,  3.8867738 , -2.4497879 ,
       -0.676257  ,  0.1795485 , -3.1324651 , -2.3045454 , -1.9656184 ,
        1.2784417 ,  1.3653405 ,  0.4712348 ,  3.6835637 ,  1.4607272 ,
       -0.26173374,  2.2768283 ,  3.3023174 , -2.5614347 , -1.5722078 ,
       -3.7269034 , -1.5898702 ,  0.1103521 , -1.7075908 ,  3.2840946 ,
       -4.006767  ,  0.45580792,  1.9496741 ,  0.8073076 , -2.0288486 ,
        0.75987005, -0.54913837, -0.28429893,  0.3500445 ,  3.6737561 ,
        1.6106148 ,  4.5643983 ,  0.42886662, -2.7789893 , -2.2541978 ,
       -0.08394778, -0.7164993 , -0.7259246 ,  0.19097483, -0.4519309 ,
        1.3490716 , -3.535676  ,  4.051957  ,  4.832258  , -3.3579922 ,
        0.2230306 , -1.102554  ,  0.10533589,  3.5841658 ,  0.09987053,
        0.70994174, -1.8589908 , -0.17476669,  2.9334877 , -1.31

In [47]:
doc[18].vector

array([ 0.7249882 , -0.24712256,  1.3216904 , -6.458606  ,  0.5020842 ,
        0.63012844, -0.27721125,  3.7578526 ,  0.6039228 , -1.9971237 ,
        0.39006793, -0.94834006, -0.44606525, -0.23227745,  0.49765533,
       -0.37558216,  0.676269  , -0.879135  , -1.7579178 ,  0.1229308 ,
        2.5017128 ,  0.5192864 ,  3.0233808 ,  5.643265  ,  3.9918735 ,
       -3.912091  ,  0.69894326,  1.448861  , -1.7716358 , -0.5042784 ,
       -2.6361802 , -1.0682964 , -4.7989597 , -0.11987513,  4.414299  ,
       -4.0782657 ,  0.7600807 ,  2.4989939 , -2.6806858 , -0.35546955,
        1.7425576 ,  2.3621912 , -0.7879547 , -0.25962973,  2.8525693 ,
       -0.9905433 ,  2.0743265 ,  2.7730978 , -0.49942034, -0.6372614 ,
        0.32685396, -3.493502  , -2.041121  , -2.8379018 , -1.8471127 ,
       -1.547224  , -2.6406364 ,  1.9618655 ,  4.1340623 , -2.9840367 ,
        1.1072215 ,  2.7083545 ,  0.5886741 ,  3.8353302 ,  1.0379041 ,
        2.2892525 , -0.6877442 , -0.7302381 ,  1.23956   , -0.23

In [48]:
doc[5].vector

array([-3.0978775e-01,  1.5162129e+00,  8.6049581e-01, -1.5532432e+00,
       -6.8119526e-01,  9.8694760e-01,  9.3608916e-01,  1.9538394e+00,
       -1.4366362e+00, -1.9292169e+00,  5.7436848e-01,  2.7897605e-01,
       -2.2740227e-01,  5.2807245e+00, -2.5897789e+00,  5.1210976e-01,
       -4.6914852e-01, -3.4868803e+00, -1.8712564e+00, -4.0094495e-02,
        2.2476783e+00,  1.5026965e+00,  8.0708194e-01,  3.0180011e+00,
        1.8983470e+00, -5.1863074e-01,  5.0027832e-02,  4.0467367e+00,
       -2.6064267e+00, -5.2322793e-01, -2.2480693e+00, -2.4013357e+00,
        1.2991195e+00, -1.8650111e+00,  2.0911574e+00, -3.5860128e+00,
        1.2236983e+00,  8.3517373e-01, -1.2338824e+00, -1.6384890e+00,
        9.1354299e-01,  1.3220488e+00, -2.4393263e-01,  4.7999501e-02,
        2.5791550e+00,  1.5372406e+00,  3.7844346e+00,  1.0160033e+00,
       -1.1462696e+00, -2.4606314e+00,  3.0954963e-01, -8.7353891e-01,
        5.9076715e-01, -1.5679519e+00, -1.4517182e+00,  3.6638421e-01,
      

In [50]:

# Compare "bank" as in "bank robber" to "bank" as in "river bank"
different_bank = cosine_similarity(doc[9].vector.reshape(-1, 1), doc[18].vector.reshape(-1, 1))[0][0]

# Compare "bank" as in "bank robber" to "bank" as in "bank vault" 
same_bank = cosine_similarity(doc[9].vector.reshape(-1, 1), doc[5].vector.reshape(-1, 1))[0][0]

In [58]:
cosine_similarity(doc[9].vector.reshape(-1, 1), doc[18].vector.reshape(-1, 1))

array([[-1.,  1., -1., ...,  1.,  1.,  1.],
       [ 1., -1.,  1., ..., -1., -1., -1.],
       [-1.,  1., -1., ...,  1.,  1.,  1.],
       ...,
       [ 1., -1.,  1., ..., -1., -1., -1.],
       [-1.,  1., -1., ...,  1.,  1.,  1.],
       [-1.,  1., -1., ...,  1.,  1.,  1.]], dtype=float32)

In [59]:
 cosine_similarity(doc[9].vector.reshape(-1, 1), doc[5].vector.reshape(-1, 1))

array([[ 1., -1., -1., ..., -1.,  1., -1.],
       [-1.,  1.,  1., ...,  1., -1.,  1.],
       [ 1., -1., -1., ..., -1.,  1., -1.],
       ...,
       [-1.,  1.,  1., ...,  1., -1.,  1.],
       [ 1., -1., -1., ..., -1.,  1., -1.],
       [ 1., -1., -1., ..., -1.,  1., -1.]], dtype=float32)

In [57]:
doc[18]

bank

In [53]:
doc[9].vector.reshape(-1, 1)

array([[-1.3527856 ],
       [ 1.5364034 ],
       [-0.34534848],
       [-2.6574855 ],
       [ 0.17459357],
       [-0.07838053],
       [ 0.16398804],
       [ 2.949361  ],
       [ 0.7446209 ],
       [-2.8426688 ],
       [ 1.01109   ],
       [-1.4075065 ],
       [-0.80491817],
       [ 3.8867738 ],
       [-2.4497879 ],
       [-0.676257  ],
       [ 0.1795485 ],
       [-3.1324651 ],
       [-2.3045454 ],
       [-1.9656184 ],
       [ 1.2784417 ],
       [ 1.3653405 ],
       [ 0.4712348 ],
       [ 3.6835637 ],
       [ 1.4607272 ],
       [-0.26173374],
       [ 2.2768283 ],
       [ 3.3023174 ],
       [-2.5614347 ],
       [-1.5722078 ],
       [-3.7269034 ],
       [-1.5898702 ],
       [ 0.1103521 ],
       [-1.7075908 ],
       [ 3.2840946 ],
       [-4.006767  ],
       [ 0.45580792],
       [ 1.9496741 ],
       [ 0.8073076 ],
       [-2.0288486 ],
       [ 0.75987005],
       [-0.54913837],
       [-0.28429893],
       [ 0.3500445 ],
       [ 3.6737561 ],
       [ 1

In [54]:
doc[18].vector.reshape(-1, 1)

array([[ 0.7249882 ],
       [-0.24712256],
       [ 1.3216904 ],
       [-6.458606  ],
       [ 0.5020842 ],
       [ 0.63012844],
       [-0.27721125],
       [ 3.7578526 ],
       [ 0.6039228 ],
       [-1.9971237 ],
       [ 0.39006793],
       [-0.94834006],
       [-0.44606525],
       [-0.23227745],
       [ 0.49765533],
       [-0.37558216],
       [ 0.676269  ],
       [-0.879135  ],
       [-1.7579178 ],
       [ 0.1229308 ],
       [ 2.5017128 ],
       [ 0.5192864 ],
       [ 3.0233808 ],
       [ 5.643265  ],
       [ 3.9918735 ],
       [-3.912091  ],
       [ 0.69894326],
       [ 1.448861  ],
       [-1.7716358 ],
       [-0.5042784 ],
       [-2.6361802 ],
       [-1.0682964 ],
       [-4.7989597 ],
       [-0.11987513],
       [ 4.414299  ],
       [-4.0782657 ],
       [ 0.7600807 ],
       [ 2.4989939 ],
       [-2.6806858 ],
       [-0.35546955],
       [ 1.7425576 ],
       [ 2.3621912 ],
       [-0.7879547 ],
       [-0.25962973],
       [ 2.8525693 ],
       [-0