In [41]:
import torch
from transformers import BertTokenizer, BertModel,BertForMaskedLM

_BERT_CHECKPOINT = 'bert-base-uncased'
# checkpoint name -> (tokenizer, model); avoids re-loading the weights on every call
_BERT_CACHE = {}


def _load_bert_mlm(checkpoint=_BERT_CHECKPOINT):
    """Return a cached ``(tokenizer, model)`` pair for the given checkpoint."""
    if checkpoint not in _BERT_CACHE:
        tokenizer = BertTokenizer.from_pretrained(checkpoint)
        model = BertForMaskedLM.from_pretrained(checkpoint)
        # Inference only: make sure dropout is disabled.
        model.eval()
        _BERT_CACHE[checkpoint] = (tokenizer, model)
    return _BERT_CACHE[checkpoint]


def multiple_mask_tokens(input_text, n = 5):
    """
    Predict the most likely replacements for every [MASK] token in a sentence.

    :param input_text: string with MASK tokens
    :param n: the number of top-scoring candidate tokens to return per mask
    :return: one list per mask token (in order of appearance), each holding n
        ``(token, score)`` tuples sorted from highest to lowest score. The
        score is the raw model output (logit) for that token; no softmax is
        applied, so scores are not probabilities. Returns a blank list (and
        prints a notice) if no mask token is found.
    """
    tokenizer, model = _load_bert_mlm()
    inputs = tokenizer(input_text, return_tensors='pt')

    # Tuple of index tensors (one per dimension of input_ids) locating every
    # [MASK] token.
    masked_indices = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)

    # Bail out before the forward pass when there is nothing to predict.
    if masked_indices[0].numel() == 0:
        print("No masked tokens found")
        return []

    # No gradients are needed for prediction; skipping autograd saves memory.
    with torch.no_grad():
        # First model output: per-position prediction scores over the vocabulary.
        logits = model(**inputs)[0]

    # Keep only the rows belonging to masked positions -> (num_masks, vocab).
    masked_logits = logits[masked_indices]

    # Top n scores and their vocabulary ids for each masked position.
    top_n = torch.topk(masked_logits, n, dim=1, sorted=True)

    # Pair each predicted token string with its score, one list per mask.
    answers = []
    for scores, token_ids in zip(top_n.values.tolist(), top_n.indices.tolist()):
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        answers.append(list(zip(tokens, scores)))
    return answers

# Demo: predict candidates for three consecutive masked tokens and show them.
print(
    "Masked tokens in the sentence are:",
    multiple_mask_tokens("[MASK] [MASK] [MASK] of the US is public service"),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Masked tokens in the sentence are: [[('the', 10.744735717773438), ('a', 6.8373260498046875), ('all', 6.063361167907715), ('one', 5.773168087005615), ('another', 5.709738731384277)], [('federal', 6.544925212860107), ('national', 6.397103786468506), ('the', 6.17555570602417), ('first', 6.027252197265625), ('other', 5.97396183013916)], [('state', 7.726027965545654), ('policy', 7.337986946105957), ('government', 7.20404577255249), ('department', 6.940457344055176), ('part', 6.6692962646484375)]]
