In [1]:
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [66]:
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np

In [31]:
# Initialize the model and tokenizer from the same pretrained checkpoint.
# The checkpoint name lives in one constant so the model weights and the
# tokenizer vocabulary cannot accidentally drift apart.
MODEL_NAME = "bert-base-cased"
model = BertModel.from_pretrained(MODEL_NAME).eval()  # evaluation mode: inference only, no training
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)



In [171]:
# Sentences that all contain the word "jam", used in different contexts.
example_sentences = [
    "I went to the store to buy some strawberry jam",
    "There was a huge traffic jam this morning",
    "Last night, the band had a great jam",
    "She spread the jam evenly on the bread",
    "I love eating jam",
]

# Tokenize the sentences as one batch. Calling the tokenizer directly is the
# supported replacement for the deprecated `batch_encode_plus`.
# padding=True pads every sentence to the longest one in the batch, and
# return_tensors='pt' returns PyTorch tensors ready to feed to the model.
tokenized = tokenizer(example_sentences, padding=True, return_tensors='pt', add_special_tokens=True)

In [135]:
# Inspect the full encoding (input_ids, token_type_ids, attention_mask).
# NOTE(review): the saved output below is 9 tokens wide, while the current
# sentences encode to 13 tokens, so it appears to come from an earlier run
# with different sentences -- re-run this cell to refresh it.
print(tokenized)

{'input_ids': tensor([[  101,  1109,  6363,  1127,  1304,  2816,   102,     0,     0],
        [  101,  1109, 11771,  1127,  1304,  6782,   102,     0,     0],
        [  101,   146,  1355,  1106,   170,  3240,  1314,  1480,   102],
        [  101,  1109,  2298,  3885,  1103,  3240,   102,     0,     0],
        [  101,  1109,  1873,  5866,  1103,  3240,   102,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0]])}


In [136]:
# These are the vocabulary ids of each token: one row per sentence, with
# trailing 0s at the padded positions (where attention_mask is 0).
# NOTE(review): the saved output below is 9 tokens wide and does not match the
# current 13-token encoding -- it looks stale; re-run this cell to refresh it.
tokenized['input_ids']

tensor([[  101,  1109,  6363,  1127,  1304,  2816,   102,     0,     0],
        [  101,  1109, 11771,  1127,  1304,  6782,   102,     0,     0],
        [  101,   146,  1355,  1106,   170,  3240,  1314,  1480,   102],
        [  101,  1109,  2298,  3885,  1103,  3240,   102,     0,     0],
        [  101,  1109,  1873,  5866,  1103,  3240,   102,     0,     0]])

In [172]:
# Get contextual word embeddings for all sentences in a single forward pass.
# Gradients are not needed for feature extraction, so autograd is disabled.
with torch.no_grad():
    model_output = model(**tokenized)
    # Keep the hidden states of the last layer: one vector per token.
    embeddings = model_output.last_hidden_state

In [153]:
# Shape of the embedding tensor. The first axis indexes the sentence and the
# second the (padded) token position, which is how the lookup loop later in
# the notebook indexes it; the last axis is the per-token embedding vector.
embeddings.shape

torch.Size([5, 13, 768])

In [173]:
# Compare the embeddings of the word "jam" in different sentences.
# For every sentence, locate the token and keep its contextual embedding as a
# single-row 2-D tensor, the layout cosine_similarity expects.
interesting_embeddings = []
interesting_token = "jam"

for sentence_ids, sentence_embeddings in zip(tokenized['input_ids'], embeddings):
    # .tolist() yields plain Python ints, which the vocabulary lookup requires.
    tokens = tokenizer.convert_ids_to_tokens(sentence_ids.tolist())
    print(tokens)
    if interesting_token not in tokens:
        # The word may be missing or may have been split into word pieces.
        raise ValueError(
            f"token {interesting_token!r} not found in tokenized sentence: {tokens}"
        )
    index = tokens.index(interesting_token)  # position of the first occurrence
    print(index)
    interesting_embeddings.append(sentence_embeddings[index].reshape(1, -1))
    

['[CLS]', 'I', 'went', 'to', 'the', 'store', 'to', 'buy', 'some', 'straw', '##berry', 'jam', '[SEP]']
11
['[CLS]', 'There', 'was', 'a', 'huge', 'traffic', 'jam', 'this', 'morning', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
6
['[CLS]', 'Last', 'night', ',', 'the', 'band', 'had', 'a', 'great', 'jam', '[SEP]', '[PAD]', '[PAD]']
9
['[CLS]', 'She', 'spread', 'the', 'jam', 'evenly', 'on', 'the', 'bread', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
4
['[CLS]', 'I', 'love', 'eating', 'jam', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
4


In [174]:
# Look at the pairwise cosine similarities between all instances of "jam".
# Stack the single-row embeddings into one (n, d) matrix so that a single
# cosine_similarity call returns the full n x n matrix; row i holds the
# similarity of occurrence i to every occurrence (including itself).
similarity_matrix = cosine_similarity(torch.cat(interesting_embeddings).numpy())
for similarity_row in similarity_matrix:
    print(" ".join(f"{value:.3f}" for value in similarity_row))

1.000 0.536 0.840 0.656 0.893
0.536 1.000 0.527 0.684 0.513
0.840 0.527 1.000 0.589 0.820
0.656 0.684 0.589 1.000 0.595
0.893 0.513 0.820 0.595 1.000
