#### Context-Aware Embeddings - BERT

In [1]:
from transformers import BertTokenizer, BertModel
import torch

# Load the pre-trained BERT tokenizer and encoder from one shared checkpoint name,
# so the two objects are always built from the same 'bert-base-uncased' weights/vocab.
MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Two sentences with "bank" in different contexts: the same surface word appears
# next to "river" in the first sentence and next to "deposited money" in the second.
sentences = [
    "He sat by the river bank.",  # river context
    "She deposited money in the bank."  # money context
]

def get_word_embedding(sentence, target_word):
    """Return the contextual embedding of ``target_word`` inside ``sentence``.

    The sentence is run through the module-level ``tokenizer`` and ``model``;
    the vector is read from the model's last hidden state at the position of
    the first whole-word occurrence of ``target_word``.

    Matching is done on tokens, not on substrings: ``target_word`` is tokenized
    with the same tokenizer (so casing and word-piece splitting agree with the
    sentence) and its piece sequence must appear as a complete word. A target
    that is split into several word pieces is represented by the mean of its
    piece vectors.

    Parameters
    ----------
    sentence : str
        Text to encode.
    target_word : str
        Word whose in-context vector is wanted.

    Returns
    -------
    tuple
        ``(vector, tokens)`` where ``vector`` is a 1-D numpy array of length
        ``hidden_size`` (or ``None`` when the word does not occur, or when
        ``target_word`` produces no tokens) and ``tokens`` is the list of
        token strings for the sentence, including special tokens.
    """
    # Tokenize and get input IDs
    inputs = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    # Last hidden state is (batch_size, seq_len, hidden_size); drop the batch axis
    last_hidden_state = outputs.last_hidden_state.squeeze(0)

    # Token strings aligned one-to-one with the rows of last_hidden_state
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Tokenize the target exactly like the sentence so that lower-casing and
    # word-piece splitting are applied consistently on both sides.
    target_pieces = tokenizer.tokenize(target_word)
    span = len(target_pieces)
    if span == 0:
        # Empty / whitespace-only target: nothing to look for.
        return None, tokens

    # Take the first occurrence for simplicity
    for start in range(len(tokens) - span + 1):
        end = start + span
        if tokens[start:end] != target_pieces:
            continue
        # A '##' continuation piece right after the match means we only matched
        # the beginning of a longer word, so keep searching.
        if end < len(tokens) and tokens[end].startswith('##'):
            continue
        # Mean over the matched pieces; for a single piece this is that row itself.
        return last_hidden_state[start:end].mean(dim=0).numpy(), tokens
    return None, tokens

# Look up the contextual vector for "bank" in sentence 0 and sentence 1, in that order
(vec1, tokens1), (vec2, tokens2) = (
    get_word_embedding(sentences[idx], "bank") for idx in (0, 1)
)

# Show how each sentence was tokenized
print("Tokens in sentence 1:", tokens1)
print("Tokens in sentence 2:", tokens2)

# Preview only the first 5 dimensions of each vector to keep the output short
print("Embedding for 'bank' in sentence 1 (river context):", vec1[:5])
print("Embedding for 'bank' in sentence 2 (money context):", vec2[:5])

Tokens in sentence 1: ['[CLS]', 'he', 'sat', 'by', 'the', 'river', 'bank', '.', '[SEP]']
Tokens in sentence 2: ['[CLS]', 'she', 'deposited', 'money', 'in', 'the', 'bank', '.', '[SEP]']
Embedding for 'bank' in sentence 1 (river context): [ 0.15994921 -0.33814338 -0.03246783 -0.08658472 -0.39891648]
Embedding for 'bank' in sentence 2 (money context): [ 0.3031039  -0.36687252 -0.35636595  0.1448596   1.0418966 ]


In [None]:
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Computed as ``dot(a, b) / (norm(a) * norm(b))``.

    Parameters
    ----------
    a, b : array-like
        1-D vectors of equal length.

    Returns
    -------
    float
        The cosine of the angle between the vectors. If either vector has zero
        norm the angle is undefined; ``0.0`` is returned in that case instead
        of performing a 0/0 division (which would produce ``nan`` and a
        RuntimeWarning).
    """
    denominator = norm(a) * norm(b)
    if denominator == 0:
        # At least one vector is all zeros: no direction to compare.
        return 0.0
    return dot(a, b) / denominator

# Compare the two "bank" vectors (vec1: river sentence, vec2: money sentence).
similarity = cosine_similarity(vec1, vec2)
print("Cosine similarity between 'bank' in different contexts:", similarity)
# The two vectors are different, so their cosine similarity will be less than 1.

Cosine similarity between 'bank' in different contexts: 0.5278751
