**Understanding BERT embeddings based on the code shared by Chris McCormick**

https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/


In [1]:
!pip install -q transformers
!pip install -q sentence_transformers

In [2]:
import torch
from transformers import BertTokenizerFast, BertModel # Use BertTokenizerFast so that the encoding's word_ids() method can be used
from sentence_transformers import util

In [3]:
import numpy as np
import nltk

In [4]:
# Download the Punkt tokenizer data that NLTK's word tokenizer relies on
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [5]:
# Initialize the tokenizer and model from the same "bert-base-cased" checkpoint
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# output_hidden_states=True makes the forward pass also return the hidden states of every layer
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)
# Put the model in evaluation mode (disables dropout); being the last expression,
# this also displays the model architecture
model.eval()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

#### Step 1: Map a word of interest to the set of tokens corresponding to it

In [6]:
# Scenario 1: A word containing a punctuation mark is split into multiple tokens and
# these tokens are treated as separate words
trial_sent_1 = "I like cookie's"

current_token_list = tokenizer.tokenize(trial_sent_1)
print(current_token_list)
# Call the tokenizer directly: encode_plus() is deprecated in favour of __call__,
# which returns the same encoding (input_ids, token_type_ids, attention_mask)
current_encoding = tokenizer(trial_sent_1)
print(current_encoding)
# word_ids() maps every token position to the index of the word it came from
# (None for the special [CLS] / [SEP] tokens)
current_encoding.word_ids()

['I', 'like', 'cookie', "'", 's']
{'input_ids': [101, 146, 1176, 25413, 112, 188, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


[None, 0, 1, 2, 3, 4, None]

In [7]:
# Scenario 2: A word without punctuation is split into multiple sub-word tokens and
# these tokens are treated as a single word (they share one word id)
trial_sent_2 = "I like cookys"

current_token_list = tokenizer.tokenize(trial_sent_2)
print(current_token_list)
# Call the tokenizer directly: encode_plus() is deprecated in favour of __call__
current_encoding = tokenizer(trial_sent_2)
# Sub-word pieces of the same word are reported with the same word id
current_encoding.word_ids()

['I', 'like', 'cook', '##ys']


[None, 0, 1, 2, 2, None]

In [8]:
# Scenario 3: The index of the 'word of interest' has to be identified from the position of the
# word (or its sub-words) after tokenization; here "bank" occurs three times in one sentence
trial_sent_3 = "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank"

current_token_list = tokenizer.tokenize(trial_sent_3)
current_token_list

['After',
 'stealing',
 'money',
 'from',
 'the',
 'bank',
 'vault',
 ',',
 'the',
 'bank',
 'r',
 '##ob',
 '##ber',
 'was',
 'seen',
 'fishing',
 'on',
 'the',
 'Mississippi',
 'river',
 'bank']

In [9]:
# Python's default string split gives a different result from the BERT tokenizer
# (e.g. 'vault,' keeps its comma and 'robber' is not broken into sub-words),
# so evaluate the use of NLTK's tokenizer as an alternative
trial_sent_3.split(' ')

['After',
 'stealing',
 'money',
 'from',
 'the',
 'bank',
 'vault,',
 'the',
 'bank',
 'robber',
 'was',
 'seen',
 'fishing',
 'on',
 'the',
 'Mississippi',
 'river',
 'bank']

In [10]:
# NLTK's word tokenizer separates the comma from 'vault' but keeps 'robber' as one word
nltk.tokenize.word_tokenize(trial_sent_3)

['After',
 'stealing',
 'money',
 'from',
 'the',
 'bank',
 'vault',
 ',',
 'the',
 'bank',
 'robber',
 'was',
 'seen',
 'fishing',
 'on',
 'the',
 'Mississippi',
 'river',
 'bank']

#### Step 2: Obtain token embeddings

In [11]:
# A short sentence in which every word maps to exactly one token
actual_sent = "I love cookies"
current_token_list = tokenizer.tokenize(actual_sent)
print(current_token_list)

# Call the tokenizer directly: encode_plus() is deprecated in favour of __call__,
# which returns the same encoding (input_ids, token_type_ids, attention_mask)
current_encoding = tokenizer(actual_sent)
print(current_encoding)
# Word index for every token position (None for [CLS] / [SEP])
current_encoding.word_ids()

['I', 'love', 'cookies']
{'input_ids': [101, 146, 1567, 18621, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}


[None, 0, 1, 2, None]

In [12]:
# Vocabulary ids of the sentence, including the special [CLS] and [SEP] tokens
indexed_tokens = current_encoding['input_ids']
# One entry per token, all set to 1.
# NOTE(review): the forward call passes this as the second positional argument, which
# Hugging Face's BertModel.forward interprets as attention_mask rather than segment
# (token type) ids - confirm that this is the intent
segments_ids = [1] * len(indexed_tokens)

# Convert inputs to PyTorch tensors
# (wrapping each list in another list adds the batch dimension, batch size = 1)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

In [13]:
# Both inputs are expected to be Batch x Number_Of_Tokens
print(tokens_tensor.shape)
segments_tensors.shape

torch.Size([1, 5])


torch.Size([1, 5])

In [14]:
# Inference only, so gradient tracking is switched off
with torch.no_grad():
  # Pass the inputs by keyword: the second positional argument of BertModel.forward
  # is attention_mask (not segment / token type ids). The all-ones tensor marks every
  # token as attended to, which is exactly what the positional call did.
  outputs = model(input_ids=tokens_tensor, attention_mask=segments_tensors)
  # Per-layer hidden states; present because the model was loaded with
  # output_hidden_states=True
  hidden_states = outputs.hidden_states

In [15]:
# Tuple length = number of hidden (encoder) layers (default = 12)
# + 1 for the token embeddings (i.e. the output of the input/embedding layer)
len(hidden_states)

13

In [16]:
# Every layer yields a tensor shaped Batch x Number_Of_Tokens (including [CLS] and [SEP]) x Dimension
hidden_states[0].shape

torch.Size([1, 5, 768])

In [17]:
# Join the per-layer tensors along a new leading "layer" axis
# (torch.stack inserts the new axis at position 0 by default)
token_embeddings = torch.stack(hidden_states)
token_embeddings.shape

torch.Size([13, 1, 5, 768])

In [18]:
# Drop the singleton "batches" axis (dimension 1)
token_embeddings = token_embeddings.squeeze(dim=1)
token_embeddings.shape

torch.Size([13, 5, 768])

In [19]:
# Exchange dimensions 0 and 1 so that tokens come first, then layers, then features
token_embeddings = token_embeddings.transpose(0, 1)
token_embeddings.shape

torch.Size([5, 13, 768])

In [20]:
# Pooling strategy -- add up the vectors of the last four layers.
# Iterating over token_embeddings yields one (layers x features) slice per token;
# summing over dim 0 collapses the four selected layers into a single vector.
token_vecs_sum = [torch.sum(token[-4:], dim=0) for token in token_embeddings]

print('Shape is: %d x %d' % (len(token_vecs_sum), len(token_vecs_sum[0])))

Shape is: 5 x 768


In [21]:
# Vectorised form of the same pooling strategy: work on an independent copy,
# keep only the last four layers of every token, then sum across the layer axis
rev_token_embeddings = token_embeddings.detach().clone()[:, -4:, :]
print(rev_token_embeddings.shape)

rev_token_sum = rev_token_embeddings.sum(dim=1)
rev_token_sum.shape

torch.Size([5, 4, 768])


torch.Size([5, 768])

In [22]:
# Sanity check: the vectorised sum equals the loop-based sum.
# Only the first token (index 0, i.e. [CLS]) is compared here.
torch.equal(rev_token_sum[0], token_vecs_sum[0])

True

#### Step 3: Obtain the word embedding by pooling over the relevant token embeddings


In [23]:
# If multiple tokens map to a single word, then obtain the mean of the relevant token embeddings
# sentence_embedding = torch.mean(token_vecs, dim=0)

#### Step 4: Understand the contextual difference in the use of the word "bank"

In [24]:
# The word "bank" appears three times: twice in a financial sense, once as a river bank
multiple_context_sent = "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank"

# Call the tokenizer directly: encode_plus() is deprecated in favour of __call__
current_encoding = tokenizer(multiple_context_sent)
indexed_tokens = current_encoding['input_ids']
# All-ones vector with one entry per token; used below as the attention mask
segments_ids = [1] * len(indexed_tokens)

# Convert inputs to PyTorch tensors (the extra list level adds the batch dimension)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# Inference only, so gradient tracking is switched off
with torch.no_grad():
  # Pass the inputs by keyword: the second positional argument of BertModel.forward
  # is attention_mask (not segment / token type ids). The all-ones tensor marks every
  # token as attended to, which is exactly what the positional call did.
  outputs = model(input_ids=tokens_tensor, attention_mask=segments_tensors)
  # Per-layer hidden states; present because the model was loaded with
  # output_hidden_states=True
  hidden_states = outputs.hidden_states

# (layers, batch, tokens, features) -> (layers, tokens, features) -> (tokens, layers, features)
token_embeddings = torch.stack(hidden_states, dim=0)
token_embeddings = torch.squeeze(token_embeddings, dim=1)
token_embeddings = token_embeddings.permute(1,0,2)

# Pooling strategy -- sum together the last four layers of every token
token_vecs_sum = [torch.sum(token[-4:], dim=0) for token in token_embeddings]

In [25]:
# Positions of 'bank' in the tokenized sentence. tokenizer.tokenize() does not emit
# the [CLS] token, so add 1 to the values below to index into the model inputs/outputs
np.where(np.array(tokenizer.tokenize(multiple_context_sent)) == 'bank')[0]

array([ 5,  9, 20])

In [26]:
# Each occurrence of the word "bank" maps to the same vocabulary id
# (positions 6, 10 and 21 = positions in the tokenized sentence shifted by 1 for [CLS])
for idx in [6, 10, 21]:
  print(indexed_tokens[idx])

3085
3085
3085


In [27]:
# One pooled vector per input position: [CLS] + 21 sentence tokens + [SEP]
len(token_vecs_sum)

23

In [28]:
# Gather the pooled vectors of the three "bank" tokens (positions 6, 10, 21)
# as the rows of a single matrix
bank_tensors = torch.stack([token_vecs_sum[pos] for pos in (6, 10, 21)], dim=0)
bank_tensors.shape

torch.Size([3, 768])

In [29]:
# Rows and columns are ordered <bank vault, bank robber, river bank>
# Context changes the similarity between the embeddings of the same word
util.cos_sim(bank_tensors, bank_tensors)

tensor([[1.0000, 0.9078, 0.5361],
        [0.9078, 1.0000, 0.5107],
        [0.5361, 0.5107, 1.0000]])

#### Step 5: Non-context-specific word embedding

In [30]:
# The word on its own, without any surrounding context
single_word_sent = "bank"

# Call the tokenizer directly: encode_plus() is deprecated in favour of __call__
current_encoding = tokenizer(single_word_sent)
indexed_tokens = current_encoding['input_ids']
# All-ones vector with one entry per token; used below as the attention mask
segments_ids = [1] * len(indexed_tokens)

# Convert inputs to PyTorch tensors (the extra list level adds the batch dimension)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# Inference only, so gradient tracking is switched off
with torch.no_grad():
  # Pass the inputs by keyword: the second positional argument of BertModel.forward
  # is attention_mask (not segment / token type ids). The all-ones tensor marks every
  # token as attended to, which is exactly what the positional call did.
  outputs = model(input_ids=tokens_tensor, attention_mask=segments_tensors)
  # Per-layer hidden states; present because the model was loaded with
  # output_hidden_states=True
  hidden_states = outputs.hidden_states

In [31]:
# Three ids: [CLS] (101) + the word "bank" (3085) + [SEP] (102)
indexed_tokens

[101, 3085, 102]

In [32]:
# Stack the layers, drop the singleton batch axis and move tokens to the front,
# giving a tokens x layers x features tensor
token_embeddings = torch.stack(hidden_states).squeeze(dim=1).transpose(0, 1)
token_embeddings.shape

torch.Size([3, 13, 768])