In [4]:
import random
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM

### Getting top 5 BERT predictions and probabilities

In [2]:
def unmask(tokens, checkpoint, ids=False, k=5, add_special_tokens=False):
    """Return the top-``k`` masked-LM predictions for the single mask token in ``tokens``.

    Args:
        tokens: The masked input. A string when ``ids`` is False; otherwise a
            sequence (or tensor) of token ids, either flat ``[seq_len]`` or
            batched ``[1, seq_len]``.
        checkpoint: Model name or path passed to ``from_pretrained`` for both
            the tokenizer and the masked-LM model.
        ids: If True, ``tokens`` already holds token ids and is not tokenized.
        k: Number of predictions to return (default 5). Clipped to the
            vocabulary size.
        add_special_tokens: Forwarded to the tokenizer when ``ids`` is False.
            Defaults to False so previously computed results are unchanged.
            NOTE(review): BERT-style models are normally fed [CLS]/[SEP];
            confirm whether omitting them is intended for this experiment.

    Returns:
        pandas.DataFrame with columns ``Prediction`` (token strings) and
        ``Probability`` (softmax probability at the masked position), sorted
        by descending probability.

    Raises:
        ValueError: if the input is not a single sequence, or does not contain
            exactly one mask token.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint).to(device)
    model.eval()

    if ids:
        input_ids = torch.as_tensor(tokens, dtype=torch.long)
    else:
        input_ids = tokenizer(
            tokens, add_special_tokens=add_special_tokens, return_tensors='pt'
        )['input_ids']

    # The model expects [batch, seq_len]; accept a flat id list as well.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    # Logits are read from row 0 below, so only a single sequence is supported.
    if input_ids.dim() != 2 or input_ids.shape[0] != 1:
        raise ValueError(
            f'expected a single sequence, got input_ids of shape {tuple(input_ids.shape)}'
        )
    input_ids = input_ids.to(device)

    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    if mask_positions.numel() != 1:
        raise ValueError(
            f'expected exactly one mask token, found {mask_positions.numel()}'
        )
    masked_idx = mask_positions.item()

    with torch.no_grad():
        logits = model(input_ids).logits  # [1, seq_len, vocab_size]

    probs = F.softmax(logits[0, masked_idx], dim=-1)  # [vocab_size]
    top_probs, top_ids = torch.topk(probs, min(k, probs.shape[-1]))
    top_tokens = [tokenizer.convert_ids_to_tokens(token_id) for token_id in top_ids.tolist()]

    return pd.DataFrame({'Prediction': top_tokens, 'Probability': top_probs.cpu().tolist()})

In [5]:
# Pick one masked sentence from the simple-agreement results and show the
# predictions of the randomly initialised (step 0) checkpoint.
# A seeded local RNG makes the choice, and therefore every output below,
# reproducible on Restart & Run All; change RANDOM_SEED to sample another sentence.
RANDOM_SEED = 0
simple_agrmt_results = pd.read_csv('../results/syntax_results/simple_agrmt_results.csv', sep='\t')
# .tolist() makes the choice positional instead of a label lookup on the Series index.
sentence = random.Random(RANDOM_SEED).choice(simple_agrmt_results['masked_sent'].tolist())
print('Masked sentence:')
print(sentence + '\n')

print('Training step: 0 (random initialization)')
print('Top 5 BERT predictions:')
unmask(sentence, 'google/multiberts-seed_0-step_0k')

Masked sentence:
the farmers [MASK]

Training step: 0 (random initialization)
Top 5 BERT predictions:


  return torch._C._cuda_getDeviceCount() > 0


Unnamed: 0,Prediction,Probability
0,corrosion,0.000185
1,flights,0.000182
2,gabriel,0.00018
3,##þ,0.000179
4,yellowstone,0.000179


In [6]:
# Same sentence as above, now scored by the final (2,000k-step) checkpoint.
print('Training step: 2,000,000 (fully trained model)',
      'Top 5 BERT predictions:', sep='\n')
unmask(sentence, 'google/multiberts-seed_0-step_2000k')

Training step: 2,000,000 (fully trained model)
Top 5 BERT predictions:


Unnamed: 0,Prediction,Probability
0,who,0.060694
1,in,0.057403
2,-,0.044014
3,teams,0.037971
4,',0.035636
