In [2]:
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Name of the pretrained checkpoint, given once so the tokenizer and the
# model are always loaded from the same weights/vocabulary.
MODEL_NAME = 'bert-base-uncased'

# Tokenizer for the checkpoint above.
tokenizer = BertTokenizer.from_pretrained(
    MODEL_NAME,
    clean_up_tokenization_spaces=True,
)

# Encoder whose last hidden state is used as contextual word embeddings.
model = BertModel.from_pretrained(MODEL_NAME)

In [3]:
def collect_word_embeddings(word, sentences, tokenizer, model):
    """Collect one contextual embedding of ``word`` per sentence that contains it.

    Each sentence is encoded with ``tokenizer`` and passed through ``model``.
    The rows of ``outputs.last_hidden_state[0]`` at the positions where
    ``word`` occurs are averaged into a single vector.

    Matching is done on whole tokens, not substrings: ``word`` is tokenized
    with the same tokenizer and its token sequence is searched for in the
    tokens recovered from the encoded ``input_ids``. Positions therefore line
    up with the hidden states without assuming a fixed special-token offset,
    and the tokenizer's own normalisation (e.g. lower-casing) is applied to
    ``word`` as well. A match that is immediately followed by a ``##``
    continuation piece is rejected, so a word is not matched as the prefix of
    a longer word (this assumes WordPiece-style ``##`` continuation markers).

    Args:
        word: Target word to look up in every sentence.
        sentences: Iterable of sentence strings.
        tokenizer: Callable tokenizer that also provides ``tokenize`` and
            ``convert_ids_to_tokens`` and returns a mapping with an
            ``'input_ids'`` tensor when called with ``return_tensors='pt'``.
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``last_hidden_state`` tensor.

    Returns:
        A list of NumPy vectors, one per sentence in which ``word`` was found,
        in sentence order. Sentences without the word are skipped, so the list
        may be shorter than ``sentences``. Returns an empty list when ``word``
        tokenizes to nothing.
    """
    # Local import keeps this cell self-contained; the notebook's import cell
    # does not import torch directly.
    import torch

    word_pieces = tokenizer.tokenize(word)
    if not word_pieces:
        # Nothing to search for (e.g. empty string).
        return []
    span = len(word_pieces)

    embeddings = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors='pt')
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state[0]

        # Recover tokens from the ids that were actually fed to the model so
        # that indices align with the hidden states (special tokens included).
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())

        target_indices = []
        start = 0
        while start <= len(tokens) - span:
            end = start + span
            is_match = tokens[start:end] == word_pieces
            # A following '##' piece means the match is only the prefix of a
            # longer word, not the target word itself.
            is_prefix_of_longer_word = end < len(tokens) and tokens[end].startswith('##')
            if is_match and not is_prefix_of_longer_word:
                target_indices.extend(range(start, end))
                start = end  # jump past the match to avoid overlapping hits
            else:
                start += 1

        if not target_indices:
            continue

        # Average over all pieces and all occurrences within the sentence.
        word_embedding = last_hidden_state[target_indices].mean(dim=0)
        embeddings.append(word_embedding.detach().numpy())
    return embeddings

In [4]:
# Set up sentences that all use the same word with the same meaning
# (the financial institution sense of "bank").
sentences = [
    "He went to the bank to deposit money.",
    "He went to the bank to withdraw money.",
    "He went to the bank to open an account.",
    "He went to the bank to steal money.",
    "He works as a teller at the bank.",
]

embeddings = collect_word_embeddings("bank", sentences, tokenizer, model)

# collect_word_embeddings skips sentences in which the word is not found, so
# verify that row i of any later similarity matrix still refers to sentences[i].
assert len(embeddings) == len(sentences), (
    "some sentences produced no embedding for the target word"
)

Now let's check whether the same word ("bank") receives similar contextual embeddings when it is used in slightly different sentences, by comparing the embeddings pairwise with cosine similarity.

In [5]:
# Stack the per-sentence vectors into a 2-D array with one row per sentence.
embeddings_array = np.vstack(embeddings)

# Pairwise cosine similarity between all rows; entry [i, j] compares the
# embeddings collected from sentence i and sentence j.
similarity_matrix = cosine_similarity(embeddings_array)

# Header and matrix on separate lines, exactly as two consecutive prints would.
print("Cosine Similarity Matrix:", similarity_matrix, sep="\n")

Cosine Similarity Matrix:
[[0.9999999  0.9737757  0.9500245  0.9442824  0.80067784]
 [0.9737757  0.9999997  0.95471483 0.9555222  0.82595307]
 [0.9500245  0.95471483 1.         0.93418175 0.8358381 ]
 [0.9442824  0.9555222  0.93418175 0.9999999  0.8418055 ]
 [0.80067784 0.82595307 0.8358381  0.8418055  0.99999964]]
