In [3]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

# Load one pre-trained BERT checkpoint as a tokenizer plus two model heads:
# masked_model is used for masked-LM logits, embedding_model for attention maps.
# NOTE(review): on transformers versions whose BERT defaults to SDPA attention,
# output_attentions=True may require attn_implementation="eager" — confirm for
# the installed version.
BERT_CHECKPOINT = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
masked_model = BertForMaskedLM.from_pretrained(BERT_CHECKPOINT)
embedding_model = BertModel.from_pretrained(BERT_CHECKPOINT)

# Inference only: put both models in evaluation mode (disables dropout).
for _model in (masked_model, embedding_model):
    _model.eval()
del _model

# Function to calculate both attention-based and masked word weightage
def calculate_word_weightage(sentence, per_position=False):
    """Weight each WordPiece token of ``sentence`` by predictability and attention.

    Each position except ``[CLS]``/``[SEP]``/``[PAD]`` is scored separately.
    The token at that position is replaced by ``[MASK]``, and ``masked_model``
    gives the softmax probability of the original token at that slot. That
    probability is averaged 50/50 with the attention the position receives in
    ``embedding_model``, taken as the mean over batch, heads, query positions
    and layers.

    Relies on the module-level ``tokenizer``, ``masked_model`` and
    ``embedding_model``. Runs one masked-LM forward pass per scored token.

    Args:
        sentence: Text to score.
        per_position: If False (default), return a dict mapping token -> weight.
            A token that occurs several times keeps only its LAST occurrence's
            weight. If True, return a list of ``(position, token, weight)``
            tuples, one per scored position, so repeated tokens are not lost.

    Returns:
        The dict or list described above. Tokens are WordPiece units
        (e.g. ``'##night'``), not whole words.
    """
    encoding = tokenizer.encode_plus(sentence, return_tensors="pt", add_special_tokens=True)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    with torch.no_grad():
        outputs = embedding_model(input_ids, attention_mask=attention_mask, output_attentions=True)

    # outputs.attentions holds one tensor per layer, each shaped
    # (batch, heads, query, key). After stacking the shape is
    # (layers, batch, heads, query, key). Averaging every axis except `key`
    # gives, for each token, the mean attention it *receives*.
    # NOTE(review): each softmax row sums to 1, so these values sum to ~1
    # across positions. They are much smaller in scale than predicted_prob
    # below and contribute little to the combined weight.
    avg_attention = torch.stack(outputs.attentions).mean(dim=(0, 1, 2, 3)).cpu().numpy()

    # Hoisted out of the loop and compared as plain ints.
    mask_id = tokenizer.convert_tokens_to_ids('[MASK]')
    skip_ids = set(tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]', '[PAD]']))

    scored = []
    for idx, token_id in enumerate(input_ids[0].tolist()):
        if token_id in skip_ids:
            continue  # Skip special tokens

        # clone() already yields an independent tensor. Re-wrapping it in
        # torch.tensor() only copies again and emits a UserWarning.
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, idx] = mask_id

        with torch.no_grad():
            logits = masked_model(masked_input_ids, attention_mask=attention_mask).logits

        # Probability the MLM assigns to the original token at the masked slot.
        predicted_prob = torch.softmax(logits[0, idx], dim=-1)[token_id].item()
        combined_weight = (predicted_prob + avg_attention[idx]) / 2.0
        scored.append((idx, tokenizer.convert_ids_to_tokens(token_id), combined_weight))

    if per_position:
        return scored
    # Original dict form. Key order follows first occurrence; for repeated
    # tokens the value of the last occurrence wins.
    return {token: weight for _, token, weight in scored}

# Example input: several sentences with informal spelling ("dont", "wat").
sentence = (
    "So I took my highest dose of acid lastnight, went to go lay down and well "
    "now its the morning and I dont remember falling asleep. "
    "From my research you shouldnt be able to fall asleep on ket until 3 hours "
    "after your last line. So wat. Did I get anaesthetized or what?"
)
word_weights = calculate_word_weightage(sentence)

# Print one "<token>: <weight>" line per dict entry (weight to 4 decimals).
# The result is a dict keyed by token, so a repeated token prints only once.
for word, weight in word_weights.items():
    print(f"{word}: {weight:.4f}")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  masked_input_tensor = torch.tensor(masked_input_ids)


so: 0.0072
i: 0.1267
took: 0.3721
my: 0.3379
highest: 0.0043
dose: 0.4039
of: 0.5008
acid: 0.0059
last: 0.0251
##night: 0.0047
,: 0.0483
went: 0.0074
to: 0.4927
go: 0.0130
lay: 0.1010
down: 0.3423
and: 0.1111
well: 0.0054
now: 0.0070
its: 0.0050
the: 0.0702
morning: 0.0286
don: 0.2897
##t: 0.0076
remember: 0.3593
falling: 0.4992
asleep: 0.5014
.: 0.0126
from: 0.0548
research: 0.0099
you: 0.2747
shouldn: 0.0179
be: 0.5019
able: 0.4157
fall: 0.4084
on: 0.0279
ke: 0.0042
until: 0.0121
3: 0.0068
hours: 0.1840
after: 0.4687
your: 0.0793
line: 0.0057
wat: 0.0051
did: 0.3914
get: 0.4916
ana: 0.0489
##est: 0.5020
##het: 0.4980
##ized: 0.3835
or: 0.4938
what: 0.0267
?: 0.5013


SyntaxError: invalid syntax (3048194002.py, line 39)