In [2]:
import random
import pandas as pd
import pickle
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

### Getting top 5 BERT predictions and probabilities

In [7]:
def unmask(tokens, checkpoint, ids=False, top_k=5):
    """Return the most probable fillers for the single mask token in a sequence.

    Loads the tokenizer and masked-LM weights for ``checkpoint``, runs one
    forward pass, and ranks the vocabulary by softmax probability at the
    masked position.

    Parameters
    ----------
    tokens : str or sequence of int
        The masked input. When ``ids`` is False it is passed to the tokenizer
        (no special tokens are added). When ``ids`` is True it is converted
        directly with ``torch.tensor``; a flat list of ids is treated as one
        sequence, a nested list as a batch.
    checkpoint : str
        Name or path handed to ``from_pretrained`` for both tokenizer and model.
    ids : bool, default False
        Whether ``tokens`` already holds token ids.
    top_k : int, default 5
        Number of predictions to return (capped at the vocabulary size).

    Returns
    -------
    pandas.DataFrame
        Columns ``Prediction`` (token strings) and ``Probability``, sorted by
        descending probability.

    Raises
    ------
    ValueError
        If the input does not contain exactly one mask token.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: tokenizer and model are reloaded on every call.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint).to(device)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    if ids:
        input_ids = torch.tensor(tokens).to(device)
        if input_ids.dim() == 1:
            # A flat id list is a single sequence: add the batch dimension so
            # both the model call and the (row, column) mask lookup work.
            input_ids = input_ids.unsqueeze(0)
    else:
        input_ids = tokenizer(tokens, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device)

    mask_rows, mask_cols = torch.where(input_ids == tokenizer.mask_token_id)
    if mask_cols.numel() != 1:
        raise ValueError(
            f'Expected exactly one {tokenizer.mask_token} token in the input, '
            f'found {mask_cols.numel()}.'
        )
    # Keep the row as well as the column so the logits are read from the
    # sequence that actually contains the mask.
    masked_row = mask_rows.item()
    masked_idx = mask_cols.item()

    model.eval()
    with torch.no_grad():
        logits = model(input_ids).logits # shape: [1, seq_len, vocab_size]

    probs = F.softmax(logits[masked_row, masked_idx, :], dim=-1) # shape: [vocab_size]

    # topk returns values already sorted in descending order, avoiding a full
    # sort of the vocabulary.
    k = min(top_k, probs.numel())
    top_token_probs, top_ids = torch.topk(probs, k)
    top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())

    return pd.DataFrame({'Prediction': top_tokens, 'Probability': top_token_probs.cpu().numpy()})

In [10]:
# Read the tab-separated simple-agreement results and draw one masked sentence at random.
simple_agrmt_results = pd.read_csv(
    '../results/syntax_results/simple_agrmt_results.csv',
    sep='\t',
)
sentence = random.choice(simple_agrmt_results['masked_sent'])
print('Masked sentence:', sentence + '\n', sep='\n')

# Query the step-0 checkpoint; the returned frame is the cell's displayed output.
print('Training step: 0 (random initialization)', 'Top 5 BERT predictions:', sep='\n')
unmask(sentence, 'google/multiberts-seed_0-step_0k')

Masked sentence:
the pilots [MASK] tall

Training step: 0 (random initialization)
Top 5 BERT predictions:


Unnamed: 0,Prediction,Probability
0,upgrades,0.000198
1,occult,0.000193
2,minorities,0.000182
3,jakob,0.000176
4,misconduct,0.000172


In [11]:
# Same sentence as above, now scored by the step-2000k checkpoint;
# the returned frame is the cell's displayed output.
print(
    'Training step: 2,000,000 (fully trained model)',
    'Top 5 BERT predictions:',
    sep='\n',
)
unmask(sentence, 'google/multiberts-seed_0-step_2000k')

Training step: 2,000,000 (fully trained model)
Top 5 BERT predictions:


Unnamed: 0,Prediction,Probability
0,are,0.137907
1,:,0.104613
2,-,0.082343
3,',0.046873
4,very,0.032368


### Checking the negative samples

In [12]:
# Load the pickled sample sentences and present them as a labelled table.
# NOTE(review): pickle.load can execute arbitrary code; only use it on files
# from a trusted source.
with open('../data/wikitext/neg_sample_sents.pickle', 'rb') as f:
    sample_sents = pd.DataFrame(pickle.load(f))

sample_sents = sample_sents.set_axis(
    ['Token', 'Token_id', 'PositiveSamples', 'NegativeSamples'],
    axis='columns',
)
sample_sents

Unnamed: 0,Token,Token_id,PositiveSamples,NegativeSamples
0,a,1037,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23...","[[101, 103, 8945, 11314, 2121, 2003, 2019, 239..."
1,by,2011,"[[101, 2023, 2001, 2628, 2011, 1037, 4626, 253...","[[101, 2023, 2001, 2628, 2011, 1037, 103, 2535..."
2,constraint,27142,"[[101, 10722, 5521, 6563, 3608, 1005, 1055, 37...","[[101, 10722, 5521, 6563, 3608, 1005, 1055, 37..."
3,exploring,11131,"[[101, 2122, 2048, 2020, 4699, 1999, 103, 3306...","[[101, 2122, 2048, 2020, 4699, 1999, 11131, 33..."
4,fined,16981,"[[101, 2002, 2001, 103, 1002, 1015, 1030, 1010...","[[101, 2002, 2001, 16981, 1002, 1015, 1030, 10..."
5,of,1997,"[[101, 1999, 2432, 8945, 11314, 2121, 5565, 10...","[[101, 1999, 2432, 8945, 11314, 2121, 5565, 10..."
6,sob,17540,"[[101, 2002, 2001, 3391, 7622, 2011, 1996, 159...","[[101, 2002, 2001, 3391, 7622, 2011, 1996, 159..."
7,the,1996,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23...","[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
8,variable,8023,"[[101, 2520, 5828, 3766, 2764, 2010, 13805, 22...","[[101, 2520, 5828, 3766, 2764, 2010, 13805, 22..."
9,with,2007,"[[101, 2002, 2001, 2856, 2011, 2198, 14381, 19...","[[101, 2002, 2001, 2856, 2011, 2198, 14381, 19..."


In [17]:
# Draw a single random row (one token with its sample sentences) to inspect.
sample_row = sample_sents.sample(n=1)
sample_row

Unnamed: 0,Token,Token_id,PositiveSamples,NegativeSamples
8,variable,8023,"[[101, 2520, 5828, 3766, 2764, 2010, 13805, 22...","[[101, 2520, 5828, 3766, 2764, 2010, 13805, 22..."


In [18]:
tokenizer = AutoTokenizer.from_pretrained('google/multiberts-seed_0')

# Take the first stored id sequence from each sample column of the drawn row.
pos_sent_ids, neg_sent_ids = (
    sample_row[column].iloc[0][0]
    for column in ('PositiveSamples', 'NegativeSamples')
)

# Map the ids back to token strings so the two sentences can be compared by eye.
pos_sent, neg_sent = map(tokenizer.convert_ids_to_tokens, (pos_sent_ids, neg_sent_ids))

print(pos_sent, neg_sent, sep='\n')

['[CLS]', 'william', 'carlos', 'williams', 'developed', 'his', 'poetic', 'along', 'distinctly', 'american', 'lines', 'with', 'his', '[MASK]', 'foot', 'and', 'a', 'di', '##ction', 'he', 'claimed', 'was', 'taken', '"', 'from', 'the', 'mouths', 'of', 'polish', 'mothers', '"', '.', '[SEP]', 'both', 'pound', 'and', 'h', '.', 'd', '.', 'turned', 'to', 'writing', 'long', 'poems', ',', 'but', 'retained', 'much', 'of', 'the', 'hard', 'edge', 'to', 'their', 'language', 'as', 'an', 'im', '##agi', '##st', 'legacy', '.', '[SEP]']
['[CLS]', 'william', 'carlos', 'williams', 'developed', 'his', 'poetic', 'along', 'distinctly', 'american', 'lines', 'with', 'his', 'variable', 'foot', 'and', 'a', 'di', '##ction', '[MASK]', 'claimed', 'was', 'taken', '"', 'from', 'the', 'mouths', 'of', 'polish', 'mothers', '"', '.', '[SEP]', 'both', 'pound', 'and', 'h', '.', 'd', '.', 'turned', 'to', 'writing', 'long', 'poems', ',', 'but', 'retained', 'much', 'of', 'the', 'hard', 'edge', 'to', 'their', 'language', 'as', '