In [None]:
%%capture
# Import torch, the regular-expression module, and the pretrained BERT masked-LM model and tokenizer
import torch
import torchvision
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM;
import re

In [None]:
# Load the pretrained tokenizer and masked-language model.
# The checkpoint name is kept in one place so tokenizer and model always match.
MODEL_NAME = 'bert-large-cased'
# This is a cased checkpoint, so lower-casing is disabled explicitly rather
# than relying on the tokenizer's default.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
# Put the model in evaluation mode; this notebook only runs inference.
model.eval()
print("Model loading finished")

In [None]:
def bertPredictMaskedWord(text):
    """
        Predict the token hidden behind '[MASK]' in ``text`` using BERT.

        Use the '[MASK]' token to mark the position to predict; if it appears
        several times, only the first occurrence is predicted.
        You can use two sentences if you want to precondition the model.
        With more than two sentences, the first one is fed as segment A (id 0)
        and all following ones as segment B (id 1), because BERT takes at most
        a sentence pair (A, B) as input.

        Relies on the module-level ``tokenizer`` and ``model`` objects.

        Args:
            text: Input string containing a '[MASK]' token.

        Returns:
            The highest-scoring vocabulary token for the masked position.

        Raises:
            ValueError: If ``text`` contains no '[MASK]' token.
    """
    # Split on whitespace that follows '.' or '?'. The two negative lookbehinds
    # avoid splitting after abbreviation-like patterns such as "e.g." or "Mr.".
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

    tokenized_text = []
    segments_ids = []
    for sentence in sentences:
        # Trailing whitespace after the last sentence yields an empty chunk;
        # skipping it avoids emitting a stray '[SEP]'.
        if not sentence.strip():
            continue
        is_first = not tokenized_text
        # Detach the mask and full stops from neighbouring characters so the
        # tokenizer sees '[MASK]' as a standalone special token.
        spaced = sentence.replace('[MASK]', ' [MASK] ').replace('.', ' .')
        sentence_tokens = tokenizer.tokenize(spaced) + ['[SEP]']
        if is_first:
            sentence_tokens.insert(0, '[CLS]')
        # Segment ids are assigned per produced token (not per whitespace word)
        # so they stay aligned even when a word is split into several pieces.
        segment = 0 if is_first else 1
        tokenized_text.extend(sentence_tokens)
        segments_ids.extend([segment] * len(sentence_tokens))

    if '[MASK]' not in tokenized_text:
        raise ValueError("No '[MASK]' token found in the input text.")
    masked_index = tokenized_text.index('[MASK]')

    # Convert tokens to vocabulary indices and build batch-of-one tensors
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Run prediction without tracking gradients
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)

    # Take the argmax over the scores at the masked position and map it back to a token
    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    return predicted_token

In [None]:
# Try the model on a two-sentence prompt.
# BERT accepts either a single sentence (A) or a sentence pair (A, B) as input;
# the paper has the details: https://arxiv.org/abs/1810.04805
text = (
    "This is a machine learning demo. "
    "This is how machines [MASK]."
)
maskedWord = bertPredictMaskedWord(text)
print(maskedWord)

In [None]:
# Second prompt: the first sentence gives context for the masked word.
text = (
    "I like cooking. "
    "My favourite dish is [MASK]."
)
maskedWord = bertPredictMaskedWord(text)
print(maskedWord)