In [1]:
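# Get contextual BERT embeddings with Hugging Face Transformers and compare the vector for "bank" in two different senses
# %pip install transformers torch   (uncomment on first run; %pip installs into this kernel's environment, unlike !pip)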

In [2]:
# %pip install --upgrade transformers

from transformers import AutoTokenizer, AutoModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single checkpoint name shared by tokenizer and encoder so the two can never drift apart.
MODEL_NAME = "bert-base-uncased"

# Pre-trained tokenizer and encoder weights for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# "bank" appears once per sentence: first in its financial sense, then as a riverside.
sentences = [
    "He deposited cash in the bank.",
    "She sat by the river bank.",
]

In [None]:
# Word whose contextual embedding is extracted from every sentence.
TARGET_WORD = "bank"

# One vector per entry in `sentences`, in the same order.
bank_vectors = []

for sent in sentences:
    # Tokenize the sentence as a batch of one, returning PyTorch tensors.
    inputs = tokenizer(sent, return_tensors="pt")

    # Inference only: no_grad() stops autograd from recording the forward pass.
    # This saves memory and is what allows `.numpy()` below, because a tensor
    # that requires grad cannot be converted to NumPy directly.
    # (no_grad does not switch train/eval mode; that is model.eval().)
    with torch.no_grad():
        outputs = model(**inputs)

    # Map the token IDs of the single batch item back to token strings.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Locate the first occurrence of the target word. A word the tokenizer
    # splits into sub-word pieces will not appear as one token and is rejected.
    if TARGET_WORD not in tokens:
        raise ValueError(f"'{TARGET_WORD}' not found in tokens : {tokens}")
    bank_idx = tokens.index(TARGET_WORD)

    # Take batch item 0 at the target token's position from the last hidden state.
    bank_vec = outputs.last_hidden_state[0, bank_idx].numpy()
    bank_vectors.append(bank_vec)


In [None]:
# Show the contextual "bank" embedding from the first (finance) sentence.
finance_bank_vec = bank_vectors[0]
print("bank (finance) vector : ", finance_bank_vec)


bank (finance) vector :  [ 3.38825375e-01 -4.81747955e-01 -2.08177924e-01  1.66003853e-01
  9.79713976e-01  1.74405500e-01 -5.12155354e-01  7.76752591e-01
 -9.10256207e-02 -1.84646234e-01  4.41605151e-01 -2.76178002e-01
 -3.61095428e-01  1.60315096e-01 -6.13400936e-01 -2.81177849e-01
  3.98038149e-01  1.37111664e-01  1.14339459e+00  4.41541150e-02
 -5.98253250e-01  4.72673737e-02  4.52181697e-01  1.83069438e-01
  1.11062318e-01  4.82494831e-01  1.09142549e-01  3.67951512e-01
 -4.16943461e-01 -3.93803298e-01  6.22381210e-01  8.63100469e-01
  2.09493935e-01  1.76880360e-01  7.98800066e-02 -1.34073883e-01
  1.28841743e-01 -1.37331218e-01 -1.42057109e+00 -7.58760273e-02
 -1.40124813e-01 -7.00906396e-01 -4.63447183e-01  2.35610962e-01
 -1.31730482e-01 -4.82020020e-01  4.42482352e-01  2.38655046e-01
 -6.54097378e-01 -4.76935863e-01 -3.52408409e-01  6.82572246e-01
  3.14004898e-01 -4.13273603e-01  9.06801298e-02  6.28960967e-01
 -8.22423577e-01 -5.72398782e-01 -6.71776891e-01  3.46921310e-02


In [6]:
print("bank (river) vector : ", bank_vectors[1])

# Quantify how different the two contextual "bank" embeddings are.
# Cosine similarity is 1.0 for vectors pointing the same direction; a lower
# value shows BERT gives the same word different vectors in different contexts.
bank_similarity = torch.nn.functional.cosine_similarity(
    torch.from_numpy(bank_vectors[0]),
    torch.from_numpy(bank_vectors[1]),
    dim=0,
).item()
print(f"cosine similarity (finance vs river) : {bank_similarity:.4f}")

bank (river) vector :  [ 1.85169175e-01 -4.78418440e-01 -1.36632040e-01 -2.81865876e-02
 -4.40413326e-01  1.64047018e-01  4.20375347e-01  1.23817003e+00
 -9.74004120e-02 -5.22433639e-01  6.52060390e-01 -1.92709323e-02
  2.88933069e-01  1.47832394e-01 -3.30544412e-01  3.34104806e-01
 -8.19631517e-02  4.29413989e-02  9.54276562e-01 -1.03063412e-01
  5.96620500e-01  3.23470980e-01  1.32806525e-01  1.09550036e-01
  3.61304611e-01 -5.45356423e-03  6.81889892e-01 -4.43433553e-01
 -2.10680813e-01  3.26703638e-01  1.28644824e+00  4.64069992e-01
 -1.46930903e-01 -5.05255163e-03 -3.84421676e-01  2.17052892e-01
  7.03482702e-02 -4.56064522e-01 -2.48021394e-01  7.12162018e-01
 -7.26874590e-01 -9.84282732e-01 -6.76786363e-01  1.05464613e+00
  4.94448543e-01 -5.46030179e-02 -6.33705184e-02 -8.09840485e-02
 -4.75389719e-01  4.67367411e-01 -3.69114280e-01  8.96504700e-01
 -2.51260877e-01 -5.95071971e-01  2.31615722e-01  8.79063427e-01
 -6.47189558e-01 -8.90389085e-01 -1.24822922e-01  1.22467265e-01
  