In [1]:
from transformers import BertTokenizer, BertForMaskedLM
import torch
import random

# Pretrained checkpoint shared by the tokenizer and the masked-LM model so the
# vocabulary always matches the model's embedding table.
MODEL_NAME = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
def mask_tokens(input_ids, mask_prob=0.15):
    """Randomly replace non-special tokens of the first sequence with [MASK].

    Every position of ``input_ids[0]`` that is not [CLS], [SEP] or [PAD] is
    selected independently with probability ``mask_prob``. If nothing is
    selected, one eligible position is picked uniformly at random, so at least
    one token is masked whenever the sequence has any eligible token.

    The caller's tensor is not modified; masking is applied to a copy.

    Args:
        input_ids: tensor of token ids indexed ``[row, position]``, e.g. the
            ``input_ids`` returned by ``tokenizer(..., return_tensors="pt")``.
            Only row 0 is masked.
        mask_prob: per-token probability of being masked.

    Returns:
        Tuple ``(masked_ids, labels, selection)``:
        ``masked_ids`` -- copy of ``input_ids`` with the selected positions of
        row 0 set to ``tokenizer.mask_token_id``;
        ``labels`` -- unmodified copy of ``input_ids``;
        ``selection`` -- list of masked positions in row 0 (empty only when
        the sequence contains nothing but special tokens).
    """
    labels = input_ids.clone()
    first_row = input_ids[0]

    # Special tokens are never masked: they carry nothing to predict.
    eligible = (
        (first_row != tokenizer.cls_token_id)
        & (first_row != tokenizer.sep_token_id)
        & (first_row != tokenizer.pad_token_id)
    )
    # One uniform draw per element of input_ids; row 0 decides the masking.
    rand = torch.rand(input_ids.shape)
    chosen = (rand[0] < mask_prob) & eligible
    selection = torch.nonzero(chosen).flatten().tolist()

    # Fallback: mask one eligible position. Choosing among eligible positions
    # (not a fixed index range) never picks [PAD] and does not crash when the
    # sequence has no maskable token.
    if not selection:
        candidates = torch.nonzero(eligible).flatten().tolist()
        if candidates:
            selection = [random.choice(candidates)]

    masked_ids = input_ids.clone()
    if selection:
        masked_ids[0, selection] = tokenizer.mask_token_id
    return masked_ids, labels, selection

In [3]:
def decode_tokens(token_ids):
    """Turn a sequence of vocabulary ids back into a readable string.

    Ids are mapped to WordPiece tokens, and ``convert_tokens_to_string``
    merges them back into full words.
    """
    return tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(token_ids)
    )

In [4]:
def predict_masked_tokens(sentence, mask_prob=0.15, top_k=1):
    """Mask random tokens of ``sentence`` and score BERT's top-k predictions.

    Args:
        sentence: input text.
        mask_prob: per-token masking probability forwarded to ``mask_tokens``.
        top_k: number of highest-scoring vocabulary entries considered per
            masked position; a position counts as correct if its original
            token is among them.

    Returns:
        dict with keys
        ``"original"`` -- the input sentence;
        ``"masked"`` -- the decoded masked sequence (special tokens included);
        ``"results"`` -- one dict per masked position with keys
        ``"masked_position"``, ``"predicted"`` (list of top-k decoded tokens),
        ``"true"`` (decoded original token) and ``"accuracy"`` (1.0 or 0.0);
        ``"accuracy"`` -- mean per-position accuracy (0 if nothing was masked).
    """
    # Truncate to the tokenizer's maximum length: longer inputs would exceed
    # BERT's position embeddings and fail at inference time.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].clone()

    # Mask some tokens (labels keeps the original ids for scoring).
    masked_input_ids, labels, selection = mask_tokens(input_ids, mask_prob)

    # Run inference on whatever device the model lives on.
    with torch.no_grad():
        outputs = model(
            masked_input_ids.to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
        )
    logits = outputs.logits

    results = []
    for idx in selection:
        topk_ids = torch.topk(logits[0, idx], k=top_k).indices.tolist()
        true_id = labels[0, idx].item()
        results.append({
            "masked_position": idx,
            "predicted": [decode_tokens([pred_id]) for pred_id in topk_ids],
            "true": decode_tokens([true_id]),
            # Compare token ids rather than decoded strings: exact and cheap.
            "accuracy": 1.0 if true_id in topk_ids else 0.0,
        })

    # Average accuracy over the masked positions of this sentence.
    avg_acc = sum(r["accuracy"] for r in results) / len(results) if results else 0

    return {
        "original": sentence,
        "masked": tokenizer.decode(masked_input_ids[0]),
        "results": results,
        "accuracy": avg_acc,
    }

In [5]:
def evaluate_results(sentences, mask_prob=0.15, top_k=1):
    """Run ``predict_masked_tokens`` on each sentence and average the scores.

    Args:
        sentences: iterable of input strings (list, tuple, or generator).
        mask_prob: per-token masking probability.
        top_k: number of predictions considered per masked token.

    Returns:
        ``(results, avg_accuracy)``. ``results`` is the list of per-sentence
        dicts returned by ``predict_masked_tokens``. ``avg_accuracy`` is the
        unweighted mean of their ``"accuracy"`` values: each sentence counts
        equally, however many tokens were masked in it. It is 0 when there
        are no sentences.
    """
    results = [predict_masked_tokens(s, mask_prob, top_k) for s in sentences]
    # Average over collected results so any iterable works (len() on a
    # generator would raise TypeError).
    avg_accuracy = (
        sum(r["accuracy"] for r in results) / len(results) if results else 0
    )
    return results, avg_accuracy

In [6]:
# Seed both RNGs used by mask_tokens (torch.rand for token selection,
# random.choice for the fallback) so Restart & Run All reproduces the output.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

documents = [
    "I study machine learning",
    "Natural language processing is important",
    "I enjoy working with data",
    "Text mining is an interesting field",
    "Data analysis is crucial for business",
]

# Evaluate with top-5 predictions per masked token
results, avg_acc = evaluate_results(documents, mask_prob=0.15, top_k=5)

for result in results:
    print("Original:", result["original"])
    print("Masked:", result["masked"])
    for res in result["results"]:
        print(f"Predicted: {res['predicted']} | True: {res['true']} | Correct: {res['accuracy']:.2f}")
    print(f"Sentence Accuracy: {result['accuracy']:.2f}\n")

print(f"Average Accuracy across all sentences: {avg_acc:.2f}")

Original: I study machine learning
Masked: [CLS] [MASK] study machine learning [SEP]
Predicted: ['to', 'they', 'students', 'i', 'children'] | True: i | Correct: 1.00
Sentence Accuracy: 1.00

Original: Natural language processing is important
Masked: [CLS] natural language [MASK] is important [SEP]
Predicted: ['acquisition', 'processing', 'learning', 'teaching', 'use'] | True: processing | Correct: 1.00
Sentence Accuracy: 1.00

Original: I enjoy working with data
Masked: [CLS] i enjoy working [MASK] data [SEP]
Predicted: ['with', 'on', 'through', 'out', 'from'] | True: with | Correct: 1.00
Sentence Accuracy: 1.00

Original: Text mining is an interesting field
Masked: [CLS] text mining is [MASK] interesting field [SEP]
Predicted: ['an', 'another', 'very', 'one', 'also'] | True: an | Correct: 1.00
Sentence Accuracy: 1.00

Original: Data analysis is crucial for business
Masked: [CLS] data [MASK] is crucial for [MASK] [SEP]
Predicted: ['mining', 'collection', 'processing', 'analysis', 'qual