In [13]:
import ast
from collections import defaultdict
import random
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

### Getting top 5 BERT predictions and probabilities

In [7]:
def unmask(tokens, checkpoint, ids=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint).to(device)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    if ids:
        input_ids = torch.tensor(tokens).to(device)
    else:
        input_ids = tokenizer(tokens, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device)
    
    masked_idx = torch.where(input_ids == tokenizer.mask_token_id)
    masked_idx = masked_idx[1].item()
    
    model.eval()
    with torch.no_grad():
        logits = model(input_ids).logits # shape: [1, seq_len, vocab_size]
    
    probs = F.softmax(logits[0, masked_idx, :], dim=-1) # shape: [vocab_size]
    top_ids = torch.argsort(probs, axis=-1, descending=True)[:5]
    top_tokens = [tokenizer.convert_ids_to_tokens(id.item()) for id in top_ids]
    top_token_probs = probs[top_ids].cpu()

    return pd.DataFrame({'Prediction': top_tokens, 'Probability': top_token_probs})

In [10]:
simple_agrmt_results = pd.read_csv('../results/syntax_results/simple_agrmt_results.csv', sep='\t')
sentence = random.choice(simple_agrmt_results['masked_sent'])
print('Masked sentence:')
print(sentence + '\n')

print('Training step: 0 (random initialization)')
print('Top 5 BERT predictions:')
unmask(sentence, 'google/multiberts-seed_0-step_0k')

Masked sentence:
the pilots [MASK] tall

Training step: 0 (random initialization)
Top 5 BERT predictions:


Unnamed: 0,Prediction,Probability
0,upgrades,0.000198
1,occult,0.000193
2,minorities,0.000182
3,jakob,0.000176
4,misconduct,0.000172


In [11]:
print('Training step: 2,000,000 (fully trained model)')
print('Top 5 BERT predictions:')
unmask(sentence, 'google/multiberts-seed_0-step_2000k')

Training step: 2,000,000 (fully trained model)
Top 5 BERT predictions:


Unnamed: 0,Prediction,Probability
0,are,0.137907
1,:,0.104613
2,-,0.082343
3,',0.046873
4,very,0.032368


### Checking the random negative samples created automatically in the pipeline

(negative samples: same sentence as positive samples, but with another random word masked.)

In [5]:
with open('../data/wikitext/sample_sents_m1.pickle', 'rb') as f:
    sample_sents = pickle.load(f)

sample_sents = pd.DataFrame(sample_sents)
sample_sents.columns = ['Token', 'Token_id', 'PositiveSamples', 'NegativeSamples']
sample_sents

Unnamed: 0,Token,Token_id,PositiveSamples,NegativeSamples
0,a,1037,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23...","[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
1,aa,9779,"[[101, 3424, 1030, 1011, 1030, 2948, 4721, 200...","[[101, 3424, 1030, 1011, 1030, 2948, 4721, 200..."
2,abandon,10824,"[[101, 2004, 1996, 2154, 2979, 1010, 2007, 929...","[[101, 2004, 1996, 2154, 2979, 1010, 2007, 929..."
3,abandoned,4704,"[[101, 2014, 3535, 2001, 7736, 1998, 9610, 767...","[[101, 2014, 3535, 2001, 7736, 1998, 9610, 767..."
4,abandoning,19816,"[[101, 2174, 1010, 1999, 1996, 6234, 3134, 380...","[[101, 2174, 1010, 103, 1996, 6234, 3134, 3805..."
...,...,...,...,...
9075,zone,4224,"[[101, 2006, 1996, 20198, 3483, 1005, 1055, 21...","[[101, 2006, 1996, 20198, 3483, 1005, 1055, 21..."
9076,zones,10019,"[[101, 6059, 3387, 2435, 1996, 3036, 6987, 315...","[[101, 6059, 3387, 2435, 1996, 3036, 6987, 315..."
9077,zoo,9201,"[[101, 2350, 14345, 2421, 1996, 17692, 2103, 1...","[[101, 2350, 14345, 2421, 1996, 17692, 2103, 1..."
9078,zoom,24095,"[[101, 27916, 1005, 7444, 2806, 2001, 15063, 1...","[[101, 27916, 1005, 7444, 2806, 2001, 15063, 1..."


In [6]:
sample_row = sample_sents.sample(1)
sample_row

Unnamed: 0,Token,Token_id,PositiveSamples,NegativeSamples
4120,inclusion,10502,"[[101, 1996, 103, 1997, 1037, 5961, 2011, 1183...","[[101, 1996, 10502, 1997, 1037, 5961, 2011, 11..."


In [7]:
tokenizer = AutoTokenizer.from_pretrained('google/multiberts-seed_0')

pos_sent_ids = sample_row['PositiveSamples'].tolist()[0][0]
neg_sent_ids = sample_row['NegativeSamples'].tolist()[0][0]

pos_sent = tokenizer.convert_ids_to_tokens(pos_sent_ids)
neg_sent = tokenizer.convert_ids_to_tokens(neg_sent_ids)

print(pos_sent)
print(neg_sent)

['[CLS]', 'the', '[MASK]', 'of', 'a', 'poem', 'by', 'joyce', ',', 'i', 'hear', 'an', 'army', ',', 'which', 'was', 'sent', 'to', 'pound', 'by', 'w', '.', 'b', '.', 'ye', '##ats', ',', 'took', 'on', 'a', 'wider', 'importance', 'in', 'the', 'history', 'of', 'literary', 'modernism', ',', 'as', 'the', 'subsequent', 'correspondence', 'between', 'the', 'two', 'led', 'to', 'the', 'serial', 'publication', ',', 'at', 'pound', "'", 's', 'be', '##hes', '##t', ',', 'of', 'a', 'portrait', 'of', 'the', 'artist', 'as', 'a', 'young', 'man', 'in', 'the', 'ego', '##ist', '.', '[SEP]', 'joyce', "'", 's', 'poem', 'is', 'not', 'written', 'in', 'free', 'verse', ',', 'but', 'in', 'r', '##hy', '##ming', 'qu', '##at', '##rain', '##s', '.', '[SEP]']
['[CLS]', 'the', 'inclusion', 'of', 'a', 'poem', 'by', 'joyce', ',', 'i', 'hear', 'an', 'army', ',', 'which', 'was', 'sent', 'to', 'pound', 'by', 'w', '.', 'b', '.', 'ye', '##ats', ',', 'took', 'on', 'a', 'wider', 'importance', 'in', '[MASK]', 'history', 'of', 'liter

### Creating negative samples within the same POS category

In [19]:
def disarrange_within_pos(df, token_column, pos_column):
    # dictionary holding lists of indices for each POS tag
    pos_to_indices = defaultdict(list)

    for idx, pos_list in df[pos_column].items():
        for pos in pos_list:
            pos_to_indices[pos].append(idx)
    
    # Shuffle tokens within the same POS category
    dissarranged_df = df.copy()
    for pos, indices in pos_to_indices.items():
        if len(indices) > 1:
            values = dissarranged_df.loc[indices, token_column].values
            shuffled = np.random.permutation(len(values))
            # np.random.shuffle(shuffled)

            not_deranged = values == values[shuffled]
            while np.any(not_deranged):
                to_shuffle = shuffled[not_deranged]
                if len(to_shuffle) > 1:
                    np.random.shuffle(to_shuffle)
                    shuffled[not_deranged] = to_shuffle
                else:
                    idx = to_shuffle[0]
                    possible_indices = np.delete(shuffled, idx)
                    new_idx = np.random.choice(possible_indices)
                    shuffled[idx], shuffled[new_idx] = shuffled[new_idx], shuffled[idx]
                not_deranged = values == values[shuffled]

            dissarranged_df.loc[indices, token_column] = values[shuffled]

    return dissarranged_df.rename(columns={'positive_samples': 'negative_samples'})

In [15]:
with open('../data/wikitext/sample_sents_m1.pickle', 'rb') as f:
    samples = pd.DataFrame(pickle.load(f), columns=['token', 'token_id', 'positive_samples', 'negative_samples']).sort_values('token').reset_index(drop=True)

wordbank = pd.read_csv('../data/wikitext/wikitext_wordbank.tsv', sep='\t')

samples = pd.concat([samples.drop(columns=['negative_samples']), wordbank['POS'].apply(ast.literal_eval)], axis=1)
samples

Unnamed: 0,token,token_id,positive_samples,POS
0,a,1037,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23...",[DET]
1,aa,9779,"[[101, 3424, 1030, 1011, 1030, 2948, 4721, 200...","[PROPN, NOUN]"
2,abandon,10824,"[[101, 2004, 1996, 2154, 2979, 1010, 2007, 929...",[VERB]
3,abandoned,4704,"[[101, 2014, 3535, 2001, 7736, 1998, 9610, 767...",[VERB]
4,abandoning,19816,"[[101, 2174, 1010, 1999, 1996, 6234, 3134, 380...",[VERB]
...,...,...,...,...
9075,zone,4224,"[[101, 2006, 1996, 20198, 3483, 1005, 1055, 21...",[NOUN]
9076,zones,10019,"[[101, 6059, 3387, 2435, 1996, 3036, 6987, 315...",[NOUN]
9077,zoo,9201,"[[101, 2350, 14345, 2421, 1996, 17692, 2103, 1...","[PROPN, NOUN]"
9078,zoom,24095,"[[101, 27916, 1005, 7444, 2806, 2001, 15063, 1...",[NOUN]


In [20]:
shuffled_df = disarrange_within_pos(samples, 'token', 'POS').sort_values('token').reset_index(drop=True)
shuffled_df

Unnamed: 0,token,token_id,negative_samples,POS
0,a,2019,"[[101, 2728, 8945, 11314, 2121, 2003, 103, 239...",[DET]
1,aa,9087,"[[101, 14379, 17748, 2039, 2058, 2010, 17945, ...",[NOUN]
2,abandon,10897,"[[101, 2027, 2018, 1037, 8467, 4433, 1010, 435...",[VERB]
3,abandoned,4663,"[[101, 2379, 1996, 2203, 1997, 1996, 3204, 204...",[VERB]
4,abandoning,10951,"[[101, 2076, 1996, 2458, 1997, 2087, 8915, 649...",[VERB]
...,...,...,...,...
9075,zone,3311,"[[101, 8945, 11314, 2121, 5652, 1999, 2048, 31...",[NOUN]
9076,zones,25323,"[[101, 10254, 1005, 1055, 4955, 2000, 1996, 16...",[NOUN]
9077,zoo,22892,"[[101, 4649, 11802, 5147, 6025, 2010, 2516, 20...",[ADJ]
9078,zoom,7984,"[[101, 2137, 5053, 4076, 1037, 3151, 103, 3252...",[NOUN]


In [21]:
# checking the results
a_pos_sample = samples[samples.token == 'a'].positive_samples[0][1]
a_neg_sample = shuffled_df[shuffled_df.token == 'a'].negative_samples[0][1]
print(tokenizer.decode(a_pos_sample))
print(tokenizer.decode(a_neg_sample))

[CLS] this was followed by [MASK] starring role in the play herons written by simon stephens, which was performed in 2001 at the royal court theatre. [SEP] he had a guest role in the television series judge john deed in 2002. [SEP]
[CLS] boulter starred in two films in 2008, daylight robbery by filmmaker paris leonti, and donkey punch directed by olly blackburn. [SEP] in may 2008, boulter made a guest appearance on a two @ - @ part episode arc of the television series waking the dead, followed by [MASK] appearance on the television series survivors in november 2008. [SEP]


In [22]:
# concatenating the negative samples to the samples df
pos_neg_samples = pd.concat([samples, shuffled_df.negative_samples], axis=1).drop(columns=['POS'])
# pos_neg_samples.columns = ['token', 'token_id', 'postive_samples', 'negative_samples']
pos_neg_samples

Unnamed: 0,token,token_id,positive_samples,negative_samples
0,a,1037,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23...","[[101, 2728, 8945, 11314, 2121, 2003, 103, 239..."
1,aa,9779,"[[101, 3424, 1030, 1011, 1030, 2948, 4721, 200...","[[101, 14379, 17748, 2039, 2058, 2010, 17945, ..."
2,abandon,10824,"[[101, 2004, 1996, 2154, 2979, 1010, 2007, 929...","[[101, 2027, 2018, 1037, 8467, 4433, 1010, 435..."
3,abandoned,4704,"[[101, 2014, 3535, 2001, 7736, 1998, 9610, 767...","[[101, 2379, 1996, 2203, 1997, 1996, 3204, 204..."
4,abandoning,19816,"[[101, 2174, 1010, 1999, 1996, 6234, 3134, 380...","[[101, 2076, 1996, 2458, 1997, 2087, 8915, 649..."
...,...,...,...,...
9075,zone,4224,"[[101, 2006, 1996, 20198, 3483, 1005, 1055, 21...","[[101, 8945, 11314, 2121, 5652, 1999, 2048, 31..."
9076,zones,10019,"[[101, 6059, 3387, 2435, 1996, 3036, 6987, 315...","[[101, 10254, 1005, 1055, 4955, 2000, 1996, 16..."
9077,zoo,9201,"[[101, 2350, 14345, 2421, 1996, 17692, 2103, 1...","[[101, 4649, 11802, 5147, 6025, 2010, 2516, 20..."
9078,zoom,24095,"[[101, 27916, 1005, 7444, 2806, 2001, 15063, 1...","[[101, 2137, 5053, 4076, 1037, 3151, 103, 3252..."


In [None]:
list_of_tuples = [tuple(row) for row in pos_neg_samples.itertuples(index=False, name=None)]
# with open('../data/wikitext/shuffled_sample_sents.pickle', 'wb') as f: 
#     pickle.dump(list_of_tuples, f)