In [14]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForMaskedLM
from tqdm import tqdm
from dijkstra.predictions import create_functions

In [15]:
test_dataset_path = '../../data/test.csv'
model_checkpoints = []

datasets = ['carolina', 'psa_small', 'psa_full']
models = ['end_token', 'random_tokens']

for dataset in datasets:
    for model in models:
        model_checkpoints.append(f'../../models/{dataset}/{model}/best_model.pt')

# Load tokenizer and model

In [16]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [17]:
model = BertForMaskedLM.from_pretrained('bert-base-multilingual-cased')
model.load_state_dict(torch.load('../../models/carolina/random_tokens/best_model.pt'))

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

# Load test dataset

In [18]:
df = pd.read_csv(test_dataset_path)

In [19]:
df

Unnamed: 0,sentence,last_word
0,"Colossal é um filme de comédia, ação-thriller ...",Vigalondo
1,"O filme é protagonizado por Anne Hathaway, Dan...",Nelson
2,O filme teve estreia no Festival Internacional...,2016
3,Está programado para ser lançado pela NEON em ...,2017
4,Enredo Depois de perder seu emprego e namorado...,Seul
...,...,...
47021,"Nosso compromisso, por meio do Ibross, é contr...",país
47022,"RENILSON REHEM DE SOUZA, médico, é presidente ...",Saúde
47023,br Os artigos publicados com assinatura não tr...,jornal
47024,Sua publicação obedece ao propósito de estimul...,contemporâneo


# Get predictions

In [20]:
get_all_predictions = create_functions(tokenizer, model, 'cuda:0')

In [21]:
import random
sentences = df['sentence']
random.shuffle(sentences)
sentences = sentences[:20]
last_words = ['temporada']
corrects = {}

In [25]:
import datetime

start_time = datetime.datetime.now()

top_k = 5
average_calls = 0

for s in sentences:
    average_calls += get_all_predictions(s,top_k, False,)[-1]

end_time = datetime.datetime.now()

print(end_time-start_time)
print(average_calls/len(sentences))

0:06:45.847408
10.8


In [9]:
total_samples = len(sentences)

In [10]:
def run_for_model(model_name):
    corrects[model_name] = [0] * top_k
    model.load_state_dict(torch.load(model_name))
    loop = tqdm(zip(sentences, last_words), total=total_samples, leave=True)

    for sent, word in loop:
        suggestions = get_all_predictions(sent, tokenizer, model, top_k, False, 'cuda:1')

        try:
            # gotta add this extra space because
            # the prediction of a new word also
            # predicts a preceding whitespace
            ix = suggestions.index(' '+word)
        except:
            continue

        corrects[model_name][ix] += 1

    tot = 0

    for ix, val in enumerate(corrects[model_name]):
        tot += val
        print(f'Top {ix+1}: {tot}/{total_samples} = {tot/total_samples}')

In [12]:
for model_checkpoint in model_checkpoints:
    print(f'Running {model_checkpoint}:')
    run_for_model(model_checkpoint)

Running ../../models/carolina/end_token/best_model.pt:


100%|██████████| 1/1 [00:14<00:00, 14.97s/it]


Top 1: 0/1 = 0.0
Top 2: 0/1 = 0.0
Top 3: 0/1 = 0.0
Top 4: 0/1 = 0.0
Top 5: 1/1 = 1.0
Running ../../models/carolina/random_tokens/best_model.pt:


100%|██████████| 1/1 [00:10<00:00, 10.70s/it]


Top 1: 0/1 = 0.0
Top 2: 0/1 = 0.0
Top 3: 0/1 = 0.0
Top 4: 1/1 = 1.0
Top 5: 1/1 = 1.0
Running ../../models/psa_small/end_token/best_model.pt:


  0%|          | 0/1 [00:01<?, ?it/s]


KeyboardInterrupt: 