In [448]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [451]:
model = BertForMaskedLM.from_pretrained("TurkuNLP/wikibert-base-es-cased")
tokenizer = BertTokenizer.from_pretrained("TurkuNLP/wikibert-base-es-cased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [454]:
def prepareInputs(init_text):
    # List of punctuation to determine where segments end
    punc_list = [".", "?", "!"]
    # Prepend the [CLS] tag
    prompt_text = "[CLS] " + init_text
    # Insert the [SEP] tags
    for i in range(0, len(prompt_text)):
        if prompt_text[i] in punc_list:
            prompt_text = prompt_text[:i + 1] + " [SEP]" + prompt_text[i + 1:]

    return prompt_text

In [457]:
def createSegIDs(tokenized_text):
    currentSeg = 0
    seg_ids = []
    for token in tokenized_text:
        seg_ids.append(currentSeg)
        if token == "[SEP]":
            currentSeg += 1

    return seg_ids

In [507]:
def computeLogProb(masked_text, original_text, segment_ids, index):  
    indexed_tokens = tokenizer.convert_tokens_to_ids(masked_text)
    print(masked_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segment_tensor = torch.tensor([segment_ids])
        
    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segment_tensor)
        next_token_logits = outputs[0][0, index, :]

    next_token_logprobs = next_token_logits - next_token_logits.logsumexp(0)
    print(next_token_logprobs)

    return next_token_logprobs[tokenizer.convert_tokens_to_ids(original_text[index])].item()

In [508]:
def bigContext(tokenized_text, segment_ids, index):
    text = tokenized_text.copy()
    text[index] = "[MASK]"
    return computeLogProb(text, tokenized_text, segment_ids, index)

In [509]:
def smallContext(tokenized_text, segment_ids, index):
    text = tokenized_text.copy()
    for i in range(1, len(text) - 1):
        if i != index - 1 and i != index + 1:
            text[i] = "[MASK]"
    return computeLogProb(text, tokenized_text, segment_ids, index)

In [510]:
input_text = "Es un día hermoso."
prepped_text = prepareInputs(input_text)
print(prepped_text)

[CLS] Es un día hermoso. [SEP]


In [511]:
tokenized_text = tokenizer.tokenize(prepped_text)
print(tokenized_text)

['[CLS]', 'es', 'un', 'dia', 'hermos', '##o', '.', '[SEP]']


In [512]:
segment_ids = createSegIDs(tokenized_text)
print(segment_ids)

[0, 0, 0, 0, 0, 0, 0, 0]


In [513]:
print(smallContext(tokenized_text, segment_ids, 3))
print(bigContext(tokenized_text, segment_ids, 3))

['[CLS]', '[MASK]', 'un', '[MASK]', 'hermos', '[MASK]', '[MASK]', '[SEP]']
tensor([-14.8545, -14.9305, -16.4350,  ..., -17.9626, -15.8582, -18.0102])
-8.252510070800781
['[CLS]', 'es', 'un', '[MASK]', 'hermos', '##o', '.', '[SEP]']
tensor([-16.3195, -17.2106, -19.6498,  ..., -19.7954, -19.6179, -18.5400])
-9.8970365524292


In [514]:
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

In [515]:
tokens_tensor = torch.tensor([indexed_tokens])
segment_tensor = torch.tensor([segment_ids])
    
with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segment_tensor)
    next_token_logits = outputs[0][0, 3, :]

In [516]:
next_token_logprobs = next_token_logits - next_token_logits.logsumexp(0)