In [1]:
from transformers import BertTokenizerFast, AutoModel
import torch
import numpy as np
import pandas as pd

In [2]:
def get_word_embeddings(sentence, model, tokenizer, layers=(-4, -3, -2, -1)):
    """Return one contextualized embedding per whitespace-separated word of ``sentence``.

    The sentence is split on whitespace and passed to the tokenizer as
    pre-split words (``is_split_into_words=True``). This makes
    ``word_ids()`` index exactly this word list. Otherwise, the tokenizer's
    own pre-tokenizer also splits on punctuation and the indices no longer
    line up with a plain whitespace split.

    For each word, the hidden states of ``layers`` are summed and the vectors
    of the word's sub-tokens are averaged.

    Args:
        sentence: Input text.
        model: Model called with the tokenizer's output; its result must expose
            ``hidden_states`` (e.g. loaded with ``output_hidden_states=True``).
        tokenizer: A fast tokenizer (``word_ids()`` requires one).
        layers: Indices into ``hidden_states`` to sum; any sequence of ints.

    Returns:
        dict mapping each word to a 1-D tensor (hidden size). If a word occurs
        more than once, the embedding of its last occurrence is kept. A word
        that yields no sub-tokens maps to a NaN vector. An empty or
        whitespace-only sentence returns ``{}``.
    """
    words = sentence.split()
    if not words:
        return {}

    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt")
    # Keep inputs on the same device as the model (no-op for a CPU model).
    encoded = encoded.to(model.device)

    with torch.no_grad():
        hidden_states = model(**encoded).hidden_states

    # Stack -> (n_layers, batch=1, seq_len, hidden); sum layers, drop batch -> (seq_len, hidden).
    summed = torch.stack([hidden_states[i] for i in layers]).sum(dim=0)[0]

    # Group token positions by word in one pass; special tokens have word id None.
    positions_by_word = {}
    for position, word_id in enumerate(encoded.word_ids()):
        if word_id is not None:
            positions_by_word.setdefault(word_id, []).append(position)

    word_embeddings = {}
    for word_index, word in enumerate(words):
        index = torch.tensor(
            positions_by_word.get(word_index, []), dtype=torch.long, device=summed.device
        )
        word_embeddings[word] = summed[index].mean(dim=0)
    return word_embeddings

In [3]:
# Example input and a cased BERT checkpoint. A single MODEL_NAME is used for both
# tokenizer and model so they always come from the same checkpoint.
MODEL_NAME = "bert-base-cased"
sentence = "this is a simple example on how to extract contextualized word embeddings from BERT"
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
# output_hidden_states=True is required: get_word_embeddings reads `.hidden_states`.
# The "Some weights ... were not used" warning lists only the checkpoint's `cls.*`
# pretraining-head weights, which a bare BertModel does not include.
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
word_embeddings = get_word_embeddings(sentence, model, tokenizer)
# One column per word, one row per hidden dimension. Convert explicitly to numpy
# (via .cpu(), so this also works for a GPU model) instead of relying on pandas
# coercing torch tensors implicitly.
embeddings_df = pd.DataFrame(
    {word: vector.cpu().numpy() for word, vector in word_embeddings.items()}
)
embeddings_df

Unnamed: 0,this,is,a,simple,example,on,how,to,extract,contextualized,word,embeddings,from,BERT
0,-1.630641,1.787610,0.753495,2.547040,-1.056210,-2.126183,0.320213,3.557543,6.466884,4.723215,2.969266,4.320745,2.094373,-0.355209
1,-0.258968,-0.165800,0.170220,0.700909,-1.225829,-1.486979,-2.497697,0.120161,0.303110,-0.011509,3.591769,3.829406,2.694431,-0.765758
2,0.651023,3.086640,-0.404416,-0.260927,-3.786962,-1.136896,-0.490549,-2.265811,-2.068238,-0.995770,2.217768,-0.862766,3.068004,-0.993397
3,-0.337982,0.871563,0.323382,4.165645,0.171016,-1.365587,1.145535,2.383952,-1.065548,-0.060255,0.949844,-0.793853,-1.799135,1.006229
4,-0.971201,0.120285,0.240250,-2.331689,-7.664180,0.233749,-1.454072,-1.417325,0.052033,-1.007352,0.203638,-0.725703,1.268796,-2.467154
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
763,-5.255383,-0.034469,2.124026,-3.197527,-1.442207,-0.233140,0.564277,1.046748,-2.018145,-1.885844,-0.099606,-1.626473,-0.844005,-3.023158
764,-2.037163,-0.218114,0.629302,4.901371,4.873151,-2.875613,-3.444485,-1.458117,1.384867,-2.920961,-3.455287,0.114403,-4.447982,-1.153096
765,4.384311,3.310249,4.281366,4.679168,4.127016,-0.161553,-0.413020,4.250200,5.344140,1.910372,2.727714,-0.515223,2.061347,3.227530
766,-1.041151,0.651437,2.766595,-0.317495,-2.777891,-0.887264,-2.311985,-0.770246,-4.797533,-0.092786,0.209306,0.315204,0.626048,0.653203
