# BERT Word Embeddings

Three ways to obtain word embeddings from BERT:
- context-free
- context-averaged
- 

In [3]:
#!pip install torch
#!pip install transformers

## import

In [1]:
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
import nltk
import torch

In [2]:
# Loading the pre-trained BERT model
###################################
# Embeddings will be derived from
# the outputs of this model
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True,
                                  )

# Setting up the tokenizer
###################################
# This is the same tokenizer that
# was used in the model to generate 
# embeddings to ensure consistency
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Text corpus
##############
# These sentences show the different
# forms of the word 'bank' to show the
# value of contextualized embeddings

texts = ["bank",
         "The river bank was flooded.",
         "The bank vault was robust.",
         "He had to bank on her for support.",
         "The bank was out of money.",
         "The bank teller was a man."]

## data preprocessing

In [4]:
def bert_text_preparation(text, tokenizer):
    """Preparing the input for BERT
    
    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to seg-
    ment id = 1.
    
    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-re-
            adable tokens and ids
        
    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor segment ids
    
    
    """
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
#     print( tokenized_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
#     print(indexed_tokens)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
#     print(segments_tensors)

    return tokenized_text, tokens_tensor, segments_tensors
    

## Run Bert 
 **(Just extract Last layer of Bert for embedding)**

In [5]:
def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model
    
    Args:
        tokens_tensor (obj): Torch tensor size [n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor size [n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids
    
    Returns:
        list: List of list of floats of size
            [n_tokens, n_embedding_dimensions]
            containing embeddings for each token
    
    """
    
    # Gradient calculation id disabled
    # Model is in inference mode
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        # Removing the first hidden state
        # The first state is the input state
        hidden_states = outputs[2][1:]
        #print(outputs)

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
#     print(token_embeddings.shape)
#     print(token_embeddings)
    # Collapsing the tensor into 1-dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=0)
    # Converting torchtensors to lists
#     print(token_embeddings.shape)
#     print(token_embeddings)
    list_token_embeddings = [token_embed.tolist() for token_embed in token_embeddings]
#     print(len(list_token_embeddings))

    return list_token_embeddings

## Embedding comparison

In [16]:


# Getting embeddings for the target
# word in all given contexts
target_word_embeddings = []
count = 0
for text in texts:
#     if count==1:
#         break
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    
    # Find the position 'bank' in list of tokens
    word_index = tokenized_text.index('bank')
#     print(word_index)
    # Get the embedding for bank
    word_embedding = list_token_embeddings[word_index]
#     print(len(list_token_embeddings))
    target_word_embeddings.append(word_embedding)
#     count+=1
# target_word_embeddings

In [7]:
from scipy.spatial.distance import cosine

# Calculating the distance between the
# embeddings of 'bank' in all the
# given contexts of the word

list_of_distances = []
for text1, embed1 in zip(texts, target_word_embeddings):
    print(text1)
    for text2, embed2 in zip(texts, target_word_embeddings):
        print('\t',text2)
        cos_dist = 1 - cosine(embed1, embed2)
        list_of_distances.append([text1, text2, cos_dist])

distances_df = pd.DataFrame(list_of_distances, columns=['text1', 'text2', 'distance'])
distances_df

bank
	 bank
	 The river bank was flooded.
	 The bank vault was robust.
	 He had to bank on her for support.
	 The bank was out of money.
	 The bank teller was a man.
The river bank was flooded.
	 bank
	 The river bank was flooded.
	 The bank vault was robust.
	 He had to bank on her for support.
	 The bank was out of money.
	 The bank teller was a man.
The bank vault was robust.
	 bank
	 The river bank was flooded.
	 The bank vault was robust.
	 He had to bank on her for support.
	 The bank was out of money.
	 The bank teller was a man.
He had to bank on her for support.
	 bank
	 The river bank was flooded.
	 The bank vault was robust.
	 He had to bank on her for support.
	 The bank was out of money.
	 The bank teller was a man.
The bank was out of money.
	 bank
	 The river bank was flooded.
	 The bank vault was robust.
	 He had to bank on her for support.
	 The bank was out of money.
	 The bank teller was a man.
The bank teller was a man.
	 bank
	 The river bank was flooded.
	 The ban

Unnamed: 0,text1,text2,distance
0,bank,bank,1.0
1,bank,The river bank was flooded.,0.338063
2,bank,The bank vault was robust.,0.494099
3,bank,He had to bank on her for support.,0.25614
4,bank,The bank was out of money.,0.469942
5,bank,The bank teller was a man.,0.466021
6,The river bank was flooded.,bank,0.338063
7,The river bank was flooded.,The river bank was flooded.,1.0
8,The river bank was flooded.,The bank vault was robust.,0.523325
9,The river bank was flooded.,He had to bank on her for support.,0.331584


In [8]:
distances_df[distances_df.text1 == 'bank']

Unnamed: 0,text1,text2,distance
0,bank,bank,1.0
1,bank,The river bank was flooded.,0.338063
2,bank,The bank vault was robust.,0.494099
3,bank,He had to bank on her for support.,0.25614
4,bank,The bank was out of money.,0.469942
5,bank,The bank teller was a man.,0.466021


In [9]:
distances_df[distances_df.text1 == 'The bank vault was robust.']

Unnamed: 0,text1,text2,distance
12,The bank vault was robust.,bank,0.494099
13,The bank vault was robust.,The river bank was flooded.,0.523325
14,The bank vault was robust.,The bank vault was robust.,1.0
15,The bank vault was robust.,He had to bank on her for support.,0.416074
16,The bank vault was robust.,The bank was out of money.,0.759213
17,The bank vault was robust.,The bank teller was a man.,0.867661


In [10]:
cos_dist = 1 - cosine(target_word_embeddings[0], np.sum(target_word_embeddings, axis=0))
print(f'Distance between context-free and context-averaged = {cos_dist}')

Distance between context-free and context-averaged = 0.6590346345539333


In [11]:
texts

['bank',
 'The river bank was flooded.',
 'The bank vault was robust.',
 'He had to bank on her for support.',
 'The bank was out of money.',
 'The bank teller was a man.']

In [12]:
len(target_word_embeddings)

6

In [13]:
len(target_word_embeddings[0])

768