In [1]:
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from nltk.tokenize import sent_tokenize
from torch.nn.functional import softmax


In [2]:
from transformers import BertTokenizer,BertForNextSentencePrediction

In [7]:
def get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer, max_length=512):
    '''
    Score every pair of adjacent sentences in `text` with a next-sentence-prediction model.

    If text has n sentences, then prob_seq has n-1 probabilities.
    prob_seq[ii] is the NSP confidence that sentence ii+1 follows sentence ii
    (softmax probability of class index 0 of the model's logits).
    Probabilities closer to 1 indicate confidence; probabilities closer to 0 indicate no confidence.

    Args:
        nsp_model: model called as nsp_model(**encoded); its first output is taken as
            the logits, softmaxed over dim 1. Presumably a BertForNextSentencePrediction
            in eval mode -- this function only disables gradient tracking.
        text: raw document text, split into sentences with nltk's sent_tokenize.
        tokenizer: tokenizer providing encode_plus(sentence_1, text_pair=..., return_tensors='pt').
        max_length: longest encoded pair (in tokens) that is run through the model.
            Longer pairs are skipped and get probability 0.0, which forces a split there.

    Returns:
        (prob_seq, sentence_list): a list of n-1 Python floats and the list of n sentences.
    '''
    sentence_list = sent_tokenize(text)
    prob_seq = []
    # Iterate over all sequential pairs (sentence ii, sentence ii+1).
    for sentence_1, sentence_2 in zip(sentence_list, sentence_list[1:]):
        encoded = tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')

        if encoded['input_ids'].shape[1] > max_length:
            # The pair is too long for the model: don't score it, just split here.
            prob_seq.append(0.0)
            continue

        with torch.no_grad():
            logits = nsp_model(**encoded)[0]
        probs = softmax(logits, dim=1)
        # .item() yields a plain float, so the list holds one numeric type (the skip
        # branch appends 0.0) instead of a mix of ints and 0-d tensors.
        prob_seq.append(probs[0][0].item())
    return prob_seq, sentence_list

def get_tokens_per_sentence_list(tokenizer,sentence_list):
    '''
    Count how many tokens the tokenizer produces for each sentence.

    Counts come from tokenizer.encode, so they include any special tokens it adds.
    Returns a list of ints aligned index-for-index with sentence_list.
    '''
    counts = []
    for sentence in sentence_list:
        encoded_ids = tokenizer.encode(sentence)
        counts.append(len(encoded_ids))
    return counts

def apply_threshold(prob_seq,tokens_per_sentence_list,threshold, max_length=512):
    '''
    Choose the split points of a sequence of sentences.

    Index ii in the result means "cut between sentence ii and sentence ii+1". A cut is
    recorded when either
      * prob_seq[ii] <= threshold (the pair does not look like a continuation), or
      * appending sentence ii+1 would push the current segment past max_length tokens.

    Args:
        prob_seq: n-1 adjacent-pair probabilities for n sentences.
        tokens_per_sentence_list: n token counts, one per sentence.
        threshold: cut when a pair's probability is at or below this value.
        max_length: token budget for a single segment (512 by default).

    Returns:
        Increasing list of cutoff indices; empty when there are fewer than two sentences.
    '''
    # No adjacent pairs means nothing to cut. This also covers an empty document, where
    # indexing tokens_per_sentence_list[0] below would raise IndexError.
    if not prob_seq:
        return []

    cutoff_indices = []
    running_length = tokens_per_sentence_list[0]
    for ii, prob in enumerate(prob_seq):
        next_length = tokens_per_sentence_list[ii+1]
        if prob <= threshold or running_length + next_length > max_length:
            # Start a new segment at sentence ii+1.
            cutoff_indices.append(ii)
            running_length = next_length
        else:
            running_length += next_length
    return cutoff_indices

def get_cutoff_indices(text, threshold, nsp_model,tokenizer):
    '''
    Return the sentence indices after which `text` should be split into segments.

    Scores adjacent sentence pairs with the NSP model, counts tokens per sentence, and
    places a cut after sentence ii when the pair (ii, ii+1) has probability <= threshold
    or when adding sentence ii+1 would make the current segment too long.

    Args:
        text: raw document text.
        threshold: NSP probability at or below which adjacent sentences are separated.
        nsp_model: next-sentence-prediction model used to score sentence pairs.
        tokenizer: tokenizer matching nsp_model.

    Returns:
        List of cutoff sentence indices.
    '''
    prob_seq, sentence_list = get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer)
    tokens_per_sentence_list = get_tokens_per_sentence_list(tokenizer, sentence_list)
    # Pass the caller's threshold through instead of a fixed value.
    return apply_threshold(prob_seq, tokens_per_sentence_list, threshold=threshold)

In [116]:
# Debug inspection of the last NSP input. `encoded` is a local variable of
# get_probabilities_on_text_w_NSP, so it only exists at notebook scope if it leaked from
# an earlier interactive session; guard the lookup so Restart & Run All does not raise
# NameError. TODO(review): delete this cell once it is no longer needed.
encoded['input_ids'].shape if 'encoded' in globals() else None

torch.Size([1, 51])

In [4]:
# Data location: override with the TRIVIA_QA_DATA_DIR environment variable instead of
# editing this machine-specific default.
DATA_DIR = os.environ.get('TRIVIA_QA_DATA_DIR', '/home/adong/School/NLUProject/data')
dataset = load_from_disk(os.path.join(DATA_DIR, 'trivia_qa_rc_tiny'))

# Need to make a dict that is:

    question_id -> (entity_pages splits, search_context splits)

where each element holds one list of cutoff sentence indices per document.
             
             

In [5]:
# Pretrained cased BERT with its next-sentence-prediction head, plus the matching tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# Module.eval() returns the module itself, so this loads the model and switches it to
# inference mode (dropout off) in a single assignment.
nsp_model = BertForNextSentencePrediction.from_pretrained('bert-base-cased').eval()



Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:

# Compute split points for every context document of the first MAX_ENTRIES questions.
threshold = .5
MAX_ENTRIES = 6  # number of dataset entries to process; set to None for a full run

qid_struct = {}
for ii, entry in enumerate(dataset):
    if MAX_ENTRIES is not None and ii >= MAX_ENTRIES:
        break
    print(entry['question_id'])
    print('num entity pages, num search context', len(entry['entity_pages']['wiki_context']),len(entry['search_results']['search_context']))

    # One list of cutoff sentence indices per document, in document order.
    wiki_context_splits = [
        get_cutoff_indices(context, threshold, nsp_model, tokenizer)
        for context in entry['entity_pages']['wiki_context']
    ]
    search_context_splits = [
        get_cutoff_indices(context, threshold, nsp_model, tokenizer)
        for context in entry['search_results']['search_context']
    ]

    qid_struct[entry['question_id']] = (wiki_context_splits, search_context_splits)

tc_69
num entity pages, num search context 0 1
tc_261
num entity pages, num search context 0 1
tc_280
num entity pages, num search context 0 1


Token indices sequence length is longer than the specified maximum sequence length for this model (605 > 512). Running this sequence through the model will result in indexing errors


tc_586
num entity pages, num search context 1 0
tc_1007
num entity pages, num search context 0 1
tc_1020
num entity pages, num search context 0 1


In [9]:
# Display the computed split points: question_id -> (entity-page splits, search-context splits),
# where each element is a list with one list of cutoff sentence indices per document.
qid_struct

{'tc_69': ([], [[]]),
 'tc_261': ([], [[16]]),
 'tc_280': ([],
  [[11,
    18,
    30,
    46,
    52,
    68,
    75,
    79,
    81,
    87,
    92,
    108,
    122,
    142,
    143,
    145,
    151,
    156,
    173]]),
 'tc_586': ([[13,
    15,
    27,
    33,
    49,
    65,
    76,
    89,
    105,
    120,
    135,
    153,
    166,
    173,
    189,
    208,
    226,
    240,
    254,
    257,
    259,
    269,
    284,
    301,
    320,
    322,
    323,
    338,
    355,
    374,
    391,
    406,
    418,
    432,
    444,
    462,
    474,
    485,
    501,
    517,
    535,
    554,
    563,
    575,
    591]],
  []),
 'tc_1007': ([], [[5, 22, 28, 41, 47]]),
 'tc_1020': ([], [[10, 15]])}

In [10]:
# Inspect one raw dataset entry to see its fields (answer, entity_pages, question, question_id, ...).
dataset[1]

{'answer': {'aliases': ['My Fair Lady (2010 film)',
   'Enry Iggins',
   "Why Can't the English%3F",
   'My Fair Lady',
   'My Fair Lady (upcoming film)',
   'My Fair Lady (musical)',
   'My fair lady',
   "I'm an Ordinary Man",
   'My Fair Lady (2014 film)',
   'My Fair Lady (2012 film)',
   'My Fair Lady (2015 film)'],
  'matched_wiki_entity_name': '',
  'normalized_aliases': ['my fair lady musical',
   'my fair lady',
   'my fair lady 2010 film',
   'why can t english 3f',
   'my fair lady upcoming film',
   'my fair lady 2012 film',
   'my fair lady 2014 film',
   'my fair lady 2015 film',
   'i m ordinary man',
   'enry iggins'],
  'normalized_matched_wiki_entity_name': '',
  'normalized_value': 'my fair lady',
  'type': 'WikipediaEntity',
  'value': 'My Fair Lady'},
 'entity_pages': {'doc_source': [],
  'filename': [],
  'title': [],
  'wiki_context': []},
 'question': 'Which musical featured the song The Street Where You Live?',
 'question_id': 'tc_261',
 'question_source': 'htt

In [None]:
class LSTM_Over_BERT(nn.Module):
    '''
    Placeholder for a model whose input is an entry in the trivia_qa dataset.

    TODO(review): despite its name, the body is currently a small LeNet-style CNN
    (it appears to be copied from the PyTorch CIFAR-10 classifier tutorial), not an LSTM
    over BERT. It expects image-like tensors of shape (batch, 3, 32, 32): two
    conv(kernel 5) + max-pool(2) stages reduce 32x32 to the 5x5 maps that forward()
    flattens as 16 * 5 * 5. It returns (batch, 10) scores and cannot consume a
    trivia_qa entry yet.
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # (batch, 3, 32, 32) -> (batch, 6, 14, 14) -> (batch, 16, 5, 5)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

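Input is a single entry of the trivia_qa dataset.
Each entry contains the question, the correct answer together with its aliases, and zero or more context documents (Wikipedia pages in `entity_pages['wiki_context']` and web results in `search_results['search_context']`), each of which may be split into multiple text segments.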
