In [15]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [16]:
model = BertForMaskedLM.from_pretrained('bert-base-multilingual-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
def prepareInputs(init_text):
    # List of punctuation to determine where segments end
    punc_list = [".", "?", "!"]
    # Prepend the [CLS] tag
    prompt_text = "[CLS] " + init_text
    # Insert the [SEP] tags
    for i in range(0, len(prompt_text)):
        if prompt_text[i] in punc_list:
            prompt_text = prompt_text[:i + 1] + " [SEP]" + prompt_text[i + 1:]

    return prompt_text

In [18]:
def createSegIDs(tokenized_text):
    currentSeg = 0
    seg_ids = []
    for token in tokenized_text:
        seg_ids.append(currentSeg)
        if token == "[SEP]":
            currentSeg += 1

    return seg_ids

In [19]:
def addMask(tokenized_text, mask_word):
    result_text = tokenized_text
    mask_word_tokens = tokenizer.tokenize(mask_word)
    mask_indices = []
    for i in range(0, len(result_text)):
        if result_text[i] in mask_word_tokens:
            result_text[i] = "[MASK]"
            mask_indices.append(i)

    return (result_text, mask_indices)

In [80]:
def getSinglePred(indexed_tokens, indexed_masked_tokens, segment_ids, mask_index):
    tokens_tensor = torch.tensor([indexed_masked_tokens])
    segment_tensor = torch.tensor([segment_ids])

    probs = []
    preds = []

    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segment_tensor)
        prediction_scores = outputs[0]

    next_token_logits = prediction_scores[0, mask_index, :]
    preds = ([tokenizer.convert_ids_to_tokens(index.item()) for index in next_token_logits.topk(5).indices])
    prob = torch.softmax(next_token_logits, 0)[indexed_tokens[mask_index]].item()

    return (preds, prob)

In [164]:
def getPreds(indexed_tokens,
             indexed_masked_tokens,
             masked_text,
             segment_ids,
             mask_indices,
             totalPreds,
             totalProbs,
             nextSentences,
             index):
    preds, prob = getSinglePred(indexed_tokens, indexed_masked_tokens, segment_ids, mask_indices[index])
    totalPreds.append(preds)
    totalProbs.append(prob)

    for next_word in preds:
        masked_text[mask_indices[index]] = next_word
        indexed_masked_tokens = tokenizer.convert_tokens_to_ids(masked_text)
        if (index == len(mask_indices) - 1):
            nextSentences.append(tokenizer.decode(indexed_masked_tokens))
        else:
            getPreds(indexed_tokens,
                        indexed_masked_tokens,
                        masked_text,
                        segment_ids,
                        mask_indices,
                        totalPreds,
                        totalProbs,
                        nextSentences,
                        index + 1)

In [165]:
input_text = "Es un día hermoso."
word_to_mask = "hermoso"
prepped_text = prepareInputs(input_text)
print(prepped_text)

[CLS] Es un día hermoso. [SEP]


In [166]:
tokenized_text = tokenizer.tokenize(prepped_text)
print(tokenized_text)

['[CLS]', 'Es', 'un', 'día', 'her', '##mos', '##o', '.', '[SEP]']


In [167]:
segment_ids = createSegIDs(tokenized_text)
print(segment_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 0]


In [168]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
print(indexed_tokens)

[101, 10912, 10119, 14184, 10485, 13386, 10133, 119, 102]


In [169]:
masked_text, mask_indices = addMask(tokenized_text, word_to_mask)
print(masked_text)
print(mask_indices)

['[CLS]', 'Es', 'un', 'día', '[MASK]', '[MASK]', '[MASK]', '.', '[SEP]']
[4, 5, 6]


In [170]:
indexed_masked_tokens = tokenizer.convert_tokens_to_ids(masked_text)
print(indexed_masked_tokens)

[101, 10912, 10119, 14184, 103, 103, 103, 119, 102]


In [171]:
totalPreds = []
totalProbs = []
nextSentences = []
getPreds(indexed_tokens, indexed_masked_tokens, masked_text, segment_ids, mask_indices, totalPreds, totalProbs, nextSentences, 0)

In [172]:
totalPreds

[['de', 'en', 'para', '.', 'a'],
 ['la', 'lu', 'las', 'los', 'lo'],
 ['vida', 'cultura', 'religión', 'educación', 'guerra'],
 ['guerra', 'guerre', '##te', '##ta', '##r'],
 ['##r', '##s', '##rs', '##t', '##ra'],
 ['##ra', '##ras', '##ro', '##ros', '##a'],
 ['##a', '##o', '##as', '##ar', '##an'],
 ['lo', 'Lo', 'li', 'po', 'll'],
 ['##an', '##án', '##ano', '##ar', '##a'],
 ['##a', '##o', '##e', '##as', '##an'],
 ['##an', '##án', '##ano', '##ar', '##a'],
 ['##a', '##o', '##as', '##u', '##ar'],
 ['##ar', '##r', '##al', '##ara', '##are'],
 ['ll', 'pal', 'l', 'lo', 'call'],
 ['##are', '##ar', '##ate', '##ari', '##aar'],
 ['##aar', '##lo', '##oo', '##ar', '##la'],
 ['##la', '##lo', '##le', '##l', '##los'],
 ['##los', '##lo', '##os', '##les', '##las'],
 ['##las', '##los', '##as', '##la', '##les'],
 ['call', 'pal', 'sal', 'bol', 'rosa'],
 ['##les', '##los', '##las', '##le', '##os'],
 ['##os', '##o', '##los', '##s', '##mos'],
 ['##mos', '##os', '##mo', 'sal', '##món'],
 ['##món', '##ón', '##lón',

In [173]:
totalProbs

[7.71815211919602e-06,
 1.567037725180853e-05,
 0.0001654300285736099,
 0.0008031927864067256,
 9.124159987550229e-05,
 1.581715878273826e-05,
 0.0041012815199792385,
 8.052931932400753e-11,
 3.5911118175135925e-05,
 0.008041121065616608,
 8.465244718536269e-06,
 0.005998819135129452,
 2.528444120741824e-08,
 2.6212879089548835e-11,
 8.283332135761157e-06,
 0.000747158657759428,
 2.174424116674345e-05,
 1.202324256155407e-06,
 4.314163959406869e-08,
 1.2995370752832969e-06,
 0.00010645577276591212,
 0.0007507798727601767,
 0.00010410712275188416,
 0.0003608488186728209,
 2.19695721170865e-05,
 4.12734202370757e-09,
 0.0003392874787095934,
 1.2097406397515442e-05,
 0.003711400553584099,
 8.368590897589456e-06,
 9.201301872963086e-05]

In [174]:
nextSentences

['[CLS] Es un día de la vida. [SEP]',
 '[CLS] Es un día de la cultura. [SEP]',
 '[CLS] Es un día de la religión. [SEP]',
 '[CLS] Es un día de la educación. [SEP]',
 '[CLS] Es un día de la guerra. [SEP]',
 '[CLS] Es un día de lu guerra. [SEP]',
 '[CLS] Es un día de lu guerre. [SEP]',
 '[CLS] Es un día de lute. [SEP]',
 '[CLS] Es un día de luta. [SEP]',
 '[CLS] Es un día de lur. [SEP]',
 '[CLS] Es un día de lasr. [SEP]',
 '[CLS] Es un día de lass. [SEP]',
 '[CLS] Es un día de lasrs. [SEP]',
 '[CLS] Es un día de last. [SEP]',
 '[CLS] Es un día de lasra. [SEP]',
 '[CLS] Es un día de losra. [SEP]',
 '[CLS] Es un día de losras. [SEP]',
 '[CLS] Es un día de losro. [SEP]',
 '[CLS] Es un día de losros. [SEP]',
 '[CLS] Es un día de losa. [SEP]',
 '[CLS] Es un día de loa. [SEP]',
 '[CLS] Es un día de loo. [SEP]',
 '[CLS] Es un día de loas. [SEP]',
 '[CLS] Es un día de loar. [SEP]',
 '[CLS] Es un día de loan. [SEP]',
 '[CLS] Es un día en loan. [SEP]',
 '[CLS] Es un día en loán. [SEP]',
 '[CLS] Es 