In [15]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [196]:
# Choose the GPU when torch can see one, otherwise the CPU.
# NOTE(review): no visible cell uses `device`; the model and all input tensors stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and masked-LM head come from the same multilingual cased checkpoint.
PRETRAINED_NAME = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)
model = BertForMaskedLM.from_pretrained(PRETRAINED_NAME)

In [197]:
def prepareInputs(init_text, punc_list=(".", "?", "!")):
    """Wrap raw text in BERT special-token markers.

    Prepends "[CLS] " and inserts " [SEP]" immediately after every character
    found in ``punc_list``, e.g. "Hola. Adiós." -> "[CLS] Hola. [SEP] Adiós. [SEP]".

    Args:
        init_text: input text; may already contain literal "[MASK]" markers.
        punc_list: characters that end a segment (default ".", "?", "!").
            Every occurrence counts, including each dot of "..." or the dot
            inside a number such as "3.5".

    Returns:
        The marked-up string. No trailing "[SEP]" is added when the text does
        not end with one of the ``punc_list`` characters.
    """
    # Build the result in a single pass. Inserting into the string while looping
    # over a range fixed to the original length skipped punctuation that the
    # insertions had shifted past that length, so later sentences lost their [SEP].
    pieces = ["[CLS] "]
    for ch in init_text:
        pieces.append(ch)
        if ch in punc_list:
            pieces.append(" [SEP]")
    return "".join(pieces)

In [198]:
def createSegIDs(tokenized_text, max_segment_id=1):
    """Assign a token-type (segment) id to every token.

    Tokens start in segment 0. Each "[SEP]" closes the current segment: the
    "[SEP]" keeps the id of the segment it ends, and the next token starts the
    following segment.

    Args:
        tokenized_text: list of token strings.
        max_segment_id: largest id to emit; later segments are folded into it.
            Ids beyond the model's token-type vocabulary make the embedding
            lookup fail. This checkpoint presumably has two token types (0 and 1);
            confirm via ``model.config.type_vocab_size``. Pass None to disable
            the cap.

    Returns:
        List of ints, one per token.
    """
    current_seg = 0
    seg_ids = []
    for token in tokenized_text:
        if max_segment_id is None:
            seg_ids.append(current_seg)
        else:
            seg_ids.append(min(current_seg, max_segment_id))
        if token == "[SEP]":
            current_seg += 1

    return seg_ids

In [199]:
def addMask(tokenized_text, mask_word, word_tokenizer=None):
    """Replace every occurrence of ``mask_word`` in a token list with "[MASK]".

    ``mask_word`` is tokenized. Each place where its full token sequence appears
    contiguously in ``tokenized_text`` is masked, token by token. Word
    boundaries are not checked, so the sequence can also match the start of a
    longer word.

    Args:
        tokenized_text: list of token strings. Not modified.
        mask_word: word (or phrase) to mask.
        word_tokenizer: object with a ``tokenize(str) -> list`` method used to
            split ``mask_word``. Defaults to the global ``tokenizer`` so the
            pieces line up with ``tokenized_text``.

    Returns:
        (masked_tokens, mask_indices): a new token list with the matches
        replaced by "[MASK]", and the masked positions in ascending order.
    """
    splitter = tokenizer if word_tokenizer is None else word_tokenizer
    target = splitter.tokenize(mask_word)
    # Work on a copy. Writing into the caller's list would leave it masked,
    # which breaks anything that later re-derives ids from it.
    result_text = list(tokenized_text)
    mask_indices = []
    width = len(target)
    if width == 0:
        return (result_text, mask_indices)

    # Match the whole subword sequence, not each piece on its own. Matching
    # pieces individually also masks unrelated tokens that merely equal one of
    # the pieces (e.g. a common suffix piece).
    i = 0
    while i + width <= len(result_text):
        if result_text[i:i + width] == target:
            for j in range(i, i + width):
                result_text[j] = "[MASK]"
                mask_indices.append(j)
            i += width
        else:
            i += 1

    return (result_text, mask_indices)

In [200]:
def getSinglePred(indexed_tokens, indexed_masked_tokens, segment_ids, mask_index, top_k=5):
    """Run the masked LM once and score one position.

    Args:
        indexed_tokens: token ids of the unmasked sequence. The id at
            ``mask_index`` is the reference token whose probability is reported.
        indexed_masked_tokens: token ids fed to the model (containing [MASK]s).
        segment_ids: token-type ids, one per entry of ``indexed_masked_tokens``.
        mask_index: sequence position to predict.
        top_k: number of candidate tokens to return (default 5).

    Returns:
        (preds, prob): ``preds`` lists the ``top_k`` highest-scoring token
        strings at ``mask_index``, best first. ``prob`` is the softmax
        probability of ``indexed_tokens[mask_index]`` at that position.
    """
    # Create the batch on whatever device the model currently lives on, so a
    # later ``model.to(device)`` does not cause a CPU/GPU mismatch.
    model_device = model.device
    tokens_tensor = torch.tensor([indexed_masked_tokens], device=model_device)
    segment_tensor = torch.tensor([segment_ids], device=model_device)

    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segment_tensor)
        prediction_scores = outputs[0]

    # Vocabulary logits at the requested position of the single batch element.
    next_token_logits = prediction_scores[0, mask_index, :]
    top_ids = next_token_logits.topk(top_k).indices.tolist()
    preds = tokenizer.convert_ids_to_tokens(top_ids)
    prob = torch.softmax(next_token_logits, 0)[indexed_tokens[mask_index]].item()

    return (preds, prob)

In [201]:
def getPreds(indexed_tokens,
             indexed_masked_tokens,
             masked_text,
             segment_ids,
             mask_indices,
             totalPreds,
             totalProbs,
             nextSentences,
             index):
    """Depth-first search over the top predictions for each masked position.

    Positions in ``mask_indices`` are filled one at a time, in list order. For
    position ``mask_indices[index]`` the model is queried once (via
    getSinglePred) with the current context. Each returned candidate is then
    placed at that position, and the search recurses to the next position. Once
    the last position is filled, the tokens at all ``mask_indices`` are decoded
    and appended to ``nextSentences``.

    The number of model calls grows geometrically: 1 + k + k**2 + ... for k
    candidates per call.

    Args:
        indexed_tokens: ids of the unmasked sequence (reference tokens for the
            reported probabilities).
        indexed_masked_tokens: ids of ``masked_text``; fed to the model.
        masked_text: token strings, with "[MASK]" at positions still to fill.
            Not modified, because each branch works on its own copy.
        segment_ids: token-type ids for the sequence.
        mask_indices: positions to fill, in the order they should be filled.
        totalPreds: output list; receives the candidate list from every model call.
        totalProbs: output list; receives the reference-token probability from
            every model call.
        nextSentences: output list; receives one decoded string per completed
            combination. The string covers the mask positions only, not the
            full sentence.
        index: which entry of ``mask_indices`` to fill in this call (start at 0).

    Returns:
        None; results are accumulated into the three output lists.
    """
    # Nothing (left) to fill. Also avoids an IndexError when mask_indices is empty.
    if index >= len(mask_indices):
        return

    position = mask_indices[index]
    preds, prob = getSinglePred(indexed_tokens, indexed_masked_tokens, segment_ids, position)
    totalPreds.append(preds)
    totalProbs.append(prob)

    is_last = index == len(mask_indices) - 1
    for next_word in preds:
        # Copy per branch so deeper positions are still "[MASK]" for every
        # sibling. Writing into masked_text directly leaked the previous
        # branch's fills into the context the model was asked to predict.
        branch_text = list(masked_text)
        branch_text[position] = next_word
        branch_ids = tokenizer.convert_tokens_to_ids(branch_text)
        if is_last:
            filled_ids = [branch_ids[i] for i in mask_indices]
            nextSentences.append(tokenizer.decode(filled_ids))
        else:
            getPreds(indexed_tokens,
                     branch_ids,
                     branch_text,
                     segment_ids,
                     mask_indices,
                     totalPreds,
                     totalProbs,
                     nextSentences,
                     index + 1)

In [282]:
# Sentence to run through the masked LM. The literal [MASK]s here stay masked but are NOT
# recorded in mask_indices: addMask only records positions of word_to_mask's tokens.
# As a result, getPreds never produces predictions for them.
input_text = "[MASK] [MASK] día [MASK]."
# Word whose token(s) get masked and then predicted.
word_to_mask = "día"
prepped_text = prepareInputs(input_text)
print(prepped_text)

[CLS] [MASK] [MASK] día [MASK]. [SEP]


In [283]:
# Tokenize the marked-up text. As the output shows, the literal "[CLS]", "[MASK]" and "[SEP]"
# markers survive as single tokens.
tokenized_text = tokenizer.tokenize(prepped_text)
print(tokenized_text)

['[CLS]', '[MASK]', '[MASK]', 'día', '[MASK]', '.', '[SEP]']


In [284]:
# Token-type ids. Everything up to and including the first [SEP] is segment 0.
segment_ids = createSegIDs(tokenized_text)
print(segment_ids)

[0, 0, 0, 0, 0, 0, 0]


In [285]:
# Ids of the unmasked tokens. getSinglePred uses the id at each scored position as the
# reference token whose probability it reports.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
print(indexed_tokens)

[101, 103, 103, 14184, 103, 119, 102]


In [286]:
# Mask every token of word_to_mask. A copy is passed so tokenized_text stays unmasked even
# if addMask writes into the list it receives; otherwise re-running the cells above would
# compute indexed_tokens from an already-masked list.
masked_text, mask_indices = addMask(list(tokenized_text), word_to_mask)
print(masked_text)
print(mask_indices)

['[CLS]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '.', '[SEP]']
[3]


In [287]:
# Ids actually fed to the model (103 is the [MASK] id, per the output below).
indexed_masked_tokens = tokenizer.convert_tokens_to_ids(masked_text)
print(indexed_masked_tokens)

[101, 103, 103, 103, 103, 119, 102]


In [288]:
# Fresh accumulators: getPreds appends into these, so re-running this cell starts clean.
totalPreds, totalProbs, nextSentences = [], [], []

# Fill the masked positions left to right, starting at the first entry of mask_indices.
getPreds(
    indexed_tokens=indexed_tokens,
    indexed_masked_tokens=indexed_masked_tokens,
    masked_text=masked_text,
    segment_ids=segment_ids,
    mask_indices=mask_indices,
    totalPreds=totalPreds,
    totalProbs=totalProbs,
    nextSentences=nextSentences,
    index=0,
)

In [289]:
# Top-5 candidate tokens from each getSinglePred call, in call order.
totalPreds

[['.', '।', ',', ':', '。']]

In [290]:
# Probability the model gave the original token at each scored position, one entry per getSinglePred call.
totalProbs

[4.460775926418137e-06]

In [291]:
# One decoded string per completed combination; it holds only the tokens at mask_indices,
# not the whole sentence.
nextSentences

['.', '।', ',', ':', '。']

In [None]:
"Es un dia hermoso."
"Es un dia [MASK] [MASK] [MASK]."
"[MASK] [MASK] dia [MASK]."