In [None]:
import pandas as pd
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForMaskedLM
from tqdm import tqdm
from dijkstra.predictions import create_functions
import re

In [None]:
# Letters of the Brazilian Portuguese alphabet, accented forms included.
# These strings are interpolated into regex character classes further down.
lower_case = r'abcdefghijklmnopqrstuvwxyzáàâãéêíóôõúç'
# Every letter above has a one-to-one uppercase form, so the uppercase
# alphabet is derived rather than typed out a second time.
upper_case = lower_case.upper()

In [None]:
# The tokenizer and the masked-LM model are both initialised from the same
# pretrained checkpoint, so its name is written once and reused.
pretrained_name = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
model = BertForMaskedLM.from_pretrained(pretrained_name)

In [None]:
# Overwrite the pretrained weights with the ones stored in the checkpoint file.
# The checkpoint is deserialized onto the CPU: `load_state_dict` copies the
# values into the model's existing parameters, so the loaded tensors never need
# to sit on a GPU, and the `state_dict` global that stays alive afterwards no
# longer pins an extra copy of the weights in GPU memory.
state_dict = torch.load('../../models/carolina/end_token/best_model.pt', map_location='cpu')
model.load_state_dict(state_dict)

In [None]:
# Load the 'test' split of the sentence dataset and keep its 'correct_text'
# column; these sentences are the input to the last-word extraction below.
# NOTE(review): despite its name, `df` is whatever `load_dataset(...)['test']`
# returns — presumably a `datasets.Dataset`, not a pandas DataFrame; confirm.
df = load_dataset('carolmou/random-sentences')['test']
sentences = df['correct_text']

# Get predictions

In [None]:
# Bind the prediction helper to this tokenizer/model pair and to the 'cuda:1'
# device. `create_functions` comes from the project-local
# `dijkstra.predictions` module; the evaluation loop later calls the result as
# `get_all_predictions(sentence, top_k, False)`.
get_all_predictions = create_functions(tokenizer, model, 'cuda:1')

In [None]:
# Word pattern used to locate the last word of each sentence. It matches either
#   * one uppercase letter followed by zero or more lowercase letters
#     (e.g. "Brasil", or a lone "A"), or
#   * an all-lowercase word, optionally hyphenated (e.g. "guarda-chuva").
# A former third alternative, `[lower]*[UPPER](?=[lower])`, was removed: its
# lookahead demands a lowercase letter right after the uppercase one, while the
# closing `\b` demands a word boundary at that very position, so it could never
# match. The set of matches is therefore unchanged.
reg = rf'\b(?:[{upper_case}][{lower_case}]*|[{lower_case}]+(?:-[{lower_case}]+)*)\b'

In [None]:
# List of (prefix, last_word) pairs: for every sentence, the text that precedes
# its final word match and the matched word itself.
data_pairs = []

for sentence in sentences:
    # Walk through the matches, remembering only the final one.
    last_match = None
    for last_match in re.finditer(reg, sentence):
        pass

    # Sentences in which the pattern finds no word contribute nothing.
    if last_match is not None:
        prefix = sentence[:last_match.start()]
        data_pairs.append((prefix, last_match.group()))

In [None]:
# Number of (prefix, last_word) pairs; used as the denominator of the accuracy
# figures and as the progress-bar total in `run_for_model`.
total_samples = len(data_pairs)

In [None]:
# NOTE(review): this module-level dict is not read anywhere in the visible
# notebook; `run_for_model` creates its own local list that is also named
# `corrects`. Confirm whether this cell is still needed.
corrects = {}

In [None]:
# Passed as the second argument to `get_all_predictions` and used as the number
# of rank buckets ("Top 1" .. "Top top_k") reported by `run_for_model`.
top_k = 5

In [None]:
def run_for_model():
    """Evaluate last-word prediction and print cumulative top-k hit rates.

    Reads the notebook globals ``data_pairs`` (``(prefix, last_word)`` tuples),
    ``total_samples``, ``top_k`` and ``get_all_predictions``. For every pair,
    the suggestions returned for the prefix are searched for the expected word
    and the position at which it is found (if any) is tallied. One line per
    rank is then printed, giving the cumulative number of hits up to that rank
    and its ratio to ``total_samples``.

    Returns:
        None. Results are written to stdout.
    """
    # Hits per rank: corrects[i] counts the samples whose expected word was
    # found at position i of the suggestions. The cumulative ("Top N") totals
    # are derived from these counts when printing.
    corrects = [0] * top_k
    loop = tqdm(data_pairs, total=total_samples, leave=True)

    for sent, word in loop:
        # The final element of the returned sequence is discarded; what it
        # holds is decided by `create_functions`, which is not visible here.
        suggestions = get_all_predictions(sent, top_k, False)[:-1]
        try:
            # gotta add this extra space because
            # the prediction of a new word also
            # predicts a preceding whitespace
            ix = suggestions.index(' ' + word)
        except ValueError:
            # Expected word is not among the suggestions: count as a miss.
            # Only this "not found" case is ignored; any other error (or a
            # keyboard interrupt) is allowed to surface.
            continue

        corrects[ix] += 1

    if total_samples == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('No samples to evaluate.')
        return

    tot = 0

    # Accumulate the per-rank counts into "hit within the first N" totals.
    for ix, val in enumerate(corrects):
        tot += val
        print(f'Top {ix+1}: {tot}/{total_samples} = {tot/total_samples}')

In [None]:
# Run the evaluation over all collected pairs and print the Top-1..Top-k table.
run_for_model()