In [41]:
from transformers import BertTokenizerFast, BertModel
import torch

In [None]:
# Checkpoint shared by the tokenizer and the model. Keeping the name in a
# single constant guarantees both are always loaded from the same checkpoint.
MODEL_NAME = 'bert-base-uncased'

# Initialize tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

def _find_sublist_indices(big_list, sub_list):
    """Return the positions of the first occurrence of ``sub_list`` in ``big_list``.

    Args:
        big_list (list): The list to search in.
        sub_list (list): The contiguous run of items to look for.

    Returns:
        list[int]: Consecutive indices into ``big_list`` covering the first
        match, or an empty list when there is no match (or ``sub_list`` is
        empty).
    """
    width = len(sub_list)
    if width == 0:
        return []
    for start in range(len(big_list) - width + 1):
        if big_list[start:start + width] == sub_list:
            return list(range(start, start + width))
    return []


def get_context(sentence, target, tokenizer=tokenizer, model=model):
    """Get the contextual embedding for a target word, given its sentence.

    The sentence and the target are tokenized separately; the target's token
    ids are then located as a contiguous run inside the sentence's token ids.
    The model's last-layer hidden states at those positions are mean-pooled
    into a single vector. Only the first occurrence of the target is used.

    NOTE(review): matching is done on token ids, not on word boundaries, so a
    target whose tokens also start a longer word in the sentence would
    presumably match that longer word's leading subwords -- confirm whether
    that is acceptable for the intended use.

    Args:
        sentence (str): The sentence providing the context.
        target (str): The target word to embed.
        tokenizer: The tokenizer used for both the sentence and the target.
        model: The model; its output must expose ``last_hidden_state``.

    Returns:
        Tensor | None: The mean of the last hidden states over the matched
        subword positions, or ``None`` (after printing a message) when the
        target's tokens are not found in the sentence.
    """
    # Tokenize the sentence. Offsets are not requested: nothing below uses
    # them, and leaving them out means the encoding can be passed to the model
    # as-is and the function is not restricted to fast tokenizers.
    encoded = tokenizer(sentence, return_tensors='pt')
    sentence_token_ids = encoded['input_ids'][0].tolist()

    # Tokenize target (without special tokens) so its ids can be compared
    # directly against the ids inside the sentence.
    target_token_ids = tokenizer(target, add_special_tokens=False)["input_ids"]

    # Find the target's token indices in the sentence
    target_indices = _find_sublist_indices(sentence_token_ids, target_token_ids)

    if not target_indices:
        print(f"Target word '{target}' not found in sentence.")
        return None

    # Inference only: no gradients are needed for extracting embeddings.
    with torch.no_grad():
        outputs = model(**dict(encoded.items()))
        last_hidden = outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)

    # Mean pool embeddings across the matched subword token indices
    return last_hidden[0, target_indices, :].mean(dim=0)

In [None]:
# Example: embed the word "router" as it appears in a short phrase.
sentence = "internet router"
target = "router"

target_embedding = get_context(
    sentence=sentence,
    target=target,
    tokenizer=tokenizer,
    model=model,
)

Target Embedding Shape: torch.Size([768])
tensor([ 6.5329e-01, -5.9943e-01, -1.0861e-01,  1.7143e-01,  9.5945e-01,
        -4.6524e-01, -5.9754e-02,  6.4838e-01,  1.7152e-01,  2.3350e-02,
         7.9649e-02, -7.0422e-01,  1.2538e-01,  1.8586e-01, -8.1287e-01,
        -7.0066e-02, -1.3266e-01,  4.7322e-01,  4.3658e-01,  7.4859e-02,
        -5.6744e-02,  2.2535e-01,  1.2266e-01,  4.9263e-01,  2.9372e-01,
        -2.9797e-02, -1.5052e-01,  5.7854e-01, -3.1055e-02, -1.2978e-01,
         4.1654e-01, -3.6298e-02,  2.4265e-01, -1.1813e-01, -1.6506e-01,
        -5.7869e-01,  5.3950e-01,  1.4741e-01, -3.6675e-01,  3.2974e-02,
        -5.0385e-01, -4.0850e-01,  9.7585e-01,  1.7044e-01,  2.2634e-01,
        -2.2040e-01,  4.3005e-01,  1.9141e-01, -1.7639e-01, -3.8107e-01,
        -3.7672e-01,  3.9518e-01, -4.9043e-01,  6.8427e-01,  7.4023e-02,
         1.0097e+00,  3.3399e-01, -6.6654e-01,  3.1090e-01, -3.2876e-01,
         3.4661e-01, -8.6467e-02, -2.0116e-01, -1.7665e-01,  5.7967e-01,
        -

{'input_ids': tensor([[ 101, 4274, 2799, 2099,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'offset_mapping': tensor([[[ 0,  0],
         [ 0,  8],
         [ 9, 14],
         [14, 15],
         [ 0,  0]]])}
[2, 3]
Target Embedding Shape: torch.Size([768])
tensor([ 6.5329e-01, -5.9943e-01, -1.0861e-01,  1.7143e-01,  9.5945e-01,
        -4.6524e-01, -5.9754e-02,  6.4838e-01,  1.7152e-01,  2.3350e-02,
         7.9649e-02, -7.0422e-01,  1.2538e-01,  1.8586e-01, -8.1287e-01,
        -7.0066e-02, -1.3266e-01,  4.7322e-01,  4.3658e-01,  7.4859e-02,
        -5.6744e-02,  2.2535e-01,  1.2266e-01,  4.9263e-01,  2.9372e-01,
        -2.9797e-02, -1.5052e-01,  5.7854e-01, -3.1055e-02, -1.2978e-01,
         4.1654e-01, -3.6298e-02,  2.4265e-01, -1.1813e-01, -1.6506e-01,
        -5.7869e-01,  5.3950e-01,  1.4741e-01, -3.6675e-01,  3.2974e-02,
        -5.0385e-01, -4.0850e-01,  9.7585e-01,  1.7044e-01,  2.2634e-01,
        -2.2040e-01,  4.3005e-01