In [12]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForMaskedLM
from tqdm import tqdm
from dijkstra.predictions import get_all_predictions

In [13]:
test_dataset_path = '../../data/test.csv'
model_checkpoints = []

datasets = ['carolina', 'psa_small', 'psa_full']
models = ['end_token', 'random_tokens']

for dataset in datasets:
    for model in models:
        model_checkpoints.append(f'../../models/{dataset}/{model}/best_model.pt')

# Load tokenizer and model

In [14]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [15]:
model = BertForMaskedLM.from_pretrained('bert-base-multilingual-cased')

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Load test dataset

In [16]:
df = pd.read_csv(test_dataset_path)

In [17]:
df

Unnamed: 0,sentence,last_word
0,"Colossal é um filme de comédia, ação-thriller ...",Vigalondo
1,"O filme é protagonizado por Anne Hathaway, Dan...",Nelson
2,O filme teve estreia no Festival Internacional...,2016
3,Está programado para ser lançado pela NEON em ...,2017
4,Enredo Depois de perder seu emprego e namorado...,Seul
...,...,...
47021,"Nosso compromisso, por meio do Ibross, é contr...",país
47022,"RENILSON REHEM DE SOUZA, médico, é presidente ...",Saúde
47023,br Os artigos publicados com assinatura não tr...,jornal
47024,Sua publicação obedece ao propósito de estimul...,contemporâneo


# Get predictions

In [18]:
top_k = 5

In [19]:
sentences = df['sentence'][:5]
last_words = df['last_word'][:5]
corrects = {}

In [20]:
total_samples = len(sentences)

In [21]:
def run_for_model(model_name):
    corrects[model_name] = [0] * top_k
    model.load_state_dict(torch.load(model_name))
    loop = tqdm(zip(sentences, last_words), total=total_samples, leave=True)

    for sent, word in loop:
        suggestions = get_all_predictions(sent, tokenizer, model, top_k, False, 'cuda:0')

        try:
            # gotta add this extra space because
            # the prediction of a new word also
            # predicts a preceding whitespace
            ix = suggestions.index(' '+word)
        except:
            continue

        corrects[model_name][ix] += 1

    tot = 0

    for ix, val in enumerate(corrects[model_name]):
        tot += val
        print(f'Top {ix+1}: {tot}/{total_samples} = {tot/total_samples}')

In [22]:
for model_checkpoint in model_checkpoints:
    print(f'Running {model_checkpoint}:')
    run_for_model(model_checkpoint)

Running ../../models/carolina/end_token/best_model.pt:


  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:39<00:00,  7.88s/it]


Top 1: 0/5 = 0.0
Top 2: 0/5 = 0.0
Top 3: 0/5 = 0.0
Top 4: 0/5 = 0.0
Top 5: 0/5 = 0.0
Running ../../models/carolina/random_tokens/best_model.pt:


100%|██████████| 5/5 [00:29<00:00,  5.82s/it]


Top 1: 0/5 = 0.0
Top 2: 0/5 = 0.0
Top 3: 0/5 = 0.0
Top 4: 0/5 = 0.0
Top 5: 0/5 = 0.0
Running ../../models/psa_small/end_token/best_model.pt:


100%|██████████| 5/5 [00:35<00:00,  7.00s/it]


Top 1: 0/5 = 0.0
Top 2: 0/5 = 0.0
Top 3: 1/5 = 0.2
Top 4: 1/5 = 0.2
Top 5: 2/5 = 0.4
Running ../../models/psa_small/random_tokens/best_model.pt:


100%|██████████| 5/5 [09:37<00:00, 115.60s/it]


Top 1: 0/5 = 0.0
Top 2: 0/5 = 0.0
Top 3: 0/5 = 0.0
Top 4: 0/5 = 0.0
Top 5: 0/5 = 0.0
Running ../../models/psa_full/end_token/best_model.pt:


100%|██████████| 5/5 [00:55<00:00, 11.13s/it]


Top 1: 0/5 = 0.0
Top 2: 1/5 = 0.2
Top 3: 2/5 = 0.4
Top 4: 2/5 = 0.4
Top 5: 2/5 = 0.4
Running ../../models/psa_full/random_tokens/best_model.pt:


100%|██████████| 5/5 [00:41<00:00,  8.26s/it]

Top 1: 0/5 = 0.0
Top 2: 0/5 = 0.0
Top 3: 1/5 = 0.2
Top 4: 1/5 = 0.2
Top 5: 1/5 = 0.2



