**Install Transformers**

In [None]:
!pip install transformers

**Import the necessary libraries**

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import string
from torch.nn import functional as F
import torch

**Define the necessary functions for encoding, decoding, and getting next-word predictions**

In [3]:
def load_model(model_name: str):
    """Load a pretrained masked-language-model checkpoint and its tokenizer.

    Args:
        model_name: Short model name, matched case-insensitively. Supported
            values are "bert" and "distilbert".

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been switched to
        evaluation mode with ``.eval()``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    checkpoints = {
        "bert": "bert-base-uncased",
        "distilbert": "distilbert-base-uncased",
    }

    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let any unknown name silently fall
    # through to the DistilBERT checkpoint.
    key = model_name.lower()
    if key not in checkpoints:
        supported = ", ".join(sorted(checkpoints))
        raise ValueError(
            f"Unsupported model name {model_name!r}; expected one of: {supported}"
        )
    checkpoint = checkpoints[key]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint).eval()

    return tokenizer, model


def decode(tokenizer, pred_idx):
    """Turn predicted token ids into a list of cleaned-up token strings.

    Each id is decoded with ``tokenizer.decode``, all whitespace is removed
    from the decoded text, and any ``##`` markers are stripped from the
    tokens that are kept.

    Args:
        tokenizer: Object exposing ``decode(token_id) -> str``.
        pred_idx: Iterable of token ids to decode.

    Returns:
        A list of token strings, in the order of ``pred_idx``, excluding
        tokens that are a single punctuation character, the ``[PAD]`` token,
        or empty once whitespace is removed.
    """
    # Use a set for exact-match filtering. Testing membership against the
    # concatenated string ``string.punctuation + '[PAD]'`` is a substring
    # test, which also discards real tokens such as "A", "P", "D" or "PAD".
    # The empty string is listed explicitly because the substring test
    # dropped it too.
    ignore_tokens = set(string.punctuation) | {'[PAD]', ''}

    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))
    return tokens


def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence containing a ``<mask>`` placeholder.

    The ``<mask>`` placeholder is replaced by the tokenizer's own mask token
    before encoding.

    Args:
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: Input text; expected to contain ``<mask>``.
        add_special_tokens: Forwarded unchanged to ``tokenizer.encode``.

    Returns:
        A tuple ``(input_ids, mask_idx)``: ``input_ids`` is a tensor holding
        the encoded ids as a single row, and ``mask_idx`` is the position of
        the first mask token within that row.

    Raises:
        IndexError: If the text has no whitespace-separated words, or if no
            mask token id is found in the encoded ids.
    """
    mask_token = tokenizer.mask_token
    sentence = text_sentence.replace('<mask>', mask_token)

    # When the mask is the final word, close the sentence with " ." so that
    # the model doesn't predict punctuation for the masked slot.
    final_word = sentence.split()[-1]
    if final_word == mask_token:
        sentence = sentence + ' .'

    token_ids = tokenizer.encode(sentence, add_special_tokens=add_special_tokens)
    input_ids = torch.tensor([token_ids])

    # torch.where yields one index tensor per dimension; element [1] holds
    # the positions along the row. Only the first match is used.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1]
    mask_idx = mask_positions.tolist()[0]
    return input_ids, mask_idx


def get_all_predictions(model, tokenizer, text_sentence, top_k_words):
    """Return the top candidate words for the ``<mask>`` slot in a sentence.

    Args:
        model: Callable applied to the encoded ids; the first element of its
            output is used as the score tensor (presumably logits indexed as
            [batch, position, vocabulary] -- confirm against the model used).
        tokenizer: Tokenizer passed to ``encode`` and ``decode``.
        text_sentence: Input text containing the ``<mask>`` placeholder.
        top_k_words: Number of highest-scoring token ids to take at the mask
            position.

    Returns:
        The list produced by ``decode`` for the selected ids. It can be
        shorter than ``top_k_words`` because ``decode`` filters some tokens.
    """
    input_ids, mask_idx = encode(tokenizer, text_sentence)

    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        scores = model(input_ids)[0]

    mask_scores = scores[0, mask_idx, :]
    best_ids = mask_scores.topk(top_k_words).indices.tolist()
    return decode(tokenizer, best_ids)


**Define the input text and prediction settings**

In [4]:
# Run settings: which model to load and how many top-scoring candidates to
# request for the masked position.
model_name = 'BERT'
top_n_predictions = 5

# Sentence to complete. The '<mask>' placeholder appended at the end marks
# the word to predict.
input_text = "Hello my"
input_text = input_text + ' <mask>'

**Load models**

In [None]:
# Load the tokenizer and model selected by `model_name`.
tokenizer, model  = load_model(model_name)

**Tokenize data**

In [7]:
# Encode the input sentence and locate the position of the mask token.
input_ids, mask_idx = encode(tokenizer, input_text)

**Get predictions**

In [10]:
# Run the model without gradient tracking; the first element of its output is
# kept as the score tensor.
with torch.no_grad():
    predictions = model(input_ids)[0]

# Take the highest-scoring token ids at the mask position and decode them
# into words.
top_predictions_words = decode(
    tokenizer,
    predictions[0, mask_idx, :].topk(top_n_predictions).indices.tolist(),
)

In [11]:
# Show the decoded candidate words for the masked position.
print(top_predictions_words)

['dear', 'friend', 'love', 'friends', 'darling']
