In [1]:
from datasets import load_from_disk
import json
import torch
from nltk.tokenize import sent_tokenize
from torch.nn.functional import softmax
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import BertTokenizer,BertForNextSentencePrediction

from torch.utils.data import Dataset
import pickle

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda
GeForce RTX 3060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
#dataset = load_from_disk('/home/adong/School/NLUProject/data/trivia_qa_rc_tiny')
dataset = load_from_disk(r'\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\data\trivia_qa_rc_tiny')

In [None]:
dataset

In [None]:
#dataloader = DataLoader(dataset,batch_size=256,shuffle=False)

# Need to make a dict that is:

    entry_id -> [search_context_idx -> splits,  
             entity_pages_idx -> splits]  
             
             

In [7]:
nsp_model = BertForNextSentencePrediction.from_pretrained('prajjwal1/bert-small')
nsp_model.eval()
nsp_model.to(device)
tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-small')


Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [14]:
class ContextDataset(Dataset):
    def __init__(self, sentence_pair_list):
        self.sentence_pair_list = sentence_pair_list

    def __len__(self):
        return len(self.sentence_pair_list)
 
    def __getitem__(self,idx):
        return(self.sentence_pair_list[idx])

def get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer, device):
    '''
    Returns a sequence of probabilities which represent confidence that the next sentence is part of the same segment
    
    If text has n sentences, then prob_seq has n-1 probabilities. (If text has 1 sentence, prob_seq is [], the empty list.)
    The ii index of prob seq represents the NSP confidence of the ii and ii+1 sentences in text.
    Probabilities closer to 1 indicate confidence, Probabilities closer to 0 indicate no confidence.
    
    '''
    #print(text)
    text = text.lower()
    #print(text)
    sentence_list = sent_tokenize(text)
    over_length_indices = []
    sentence_pair_list=[]
    indices_to_be_processed=[]
    #Create Sentence pair list
    if len(sentence_list)==1:
        return [],sentence_list # Return empty list for probs
    
    for ii in range(0,len(sentence_list)-1):
        sentence_1 = sentence_list[ii]
        sentence_2 = sentence_list[ii+1]

        #Encode temporarily, just to count
        encoded = tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')
        if encoded['input_ids'].shape[1] > 512: # If two sentences are too long, just split them
            over_length_indices.append(ii)
        else:# add to list to be processed
            indices_to_be_processed.append(ii)
            sentence_pair_list.append([sentence_1,sentence_2])
    # Now, begin calculating probabilities
    with torch.no_grad():
        # Load into a dataset and dataloader. We do this for speed
        context_dataset = ContextDataset(sentence_pair_list)
        context_loader = DataLoader(context_dataset, batch_size=64, shuffle=False, pin_memory=True)
        probs_list = [] #Will be list of tensors
        for batch_idx,batch in enumerate(context_loader):
            if len(batch)==0:
                continue
            # I swear to god this is a legitimate pytorch bug, but we need to reorganize the batch from the dataloader. Whatever
            batch_fixed = [(s1,s2) for s1,s2 in zip(batch[0], batch[1])] 
            # Batch encode
            sentence_pairs = tokenizer.batch_encode_plus(batch_fixed, return_tensors='pt',padding=True)
            # Run through the model
            sentence_pairs.to(device)
            logits = nsp_model(**sentence_pairs)[0]
            # Get Probability of next sentence.
            probs = softmax(logits, dim=1)
            probs = probs[:,0]
            # Add to list of tensors
            probs_list.append(probs)
            
        #Cat the list of tensors to get a bsize x sequence_length tensors
        if len(probs_list)  == 0:
            all_probs = []
        else:
            all_probs = list(torch.cat(probs_list))
    # Now, we need to sort the probabilities. Some of probabilities are coming from over_length_indices, some of them are coming from indices_to_be_processed
    # We'll zip, then sort, then take the sorted probs.
    indices = over_length_indices + indices_to_be_processed
    one_probs = [1]*len(over_length_indices)
    probs = one_probs + all_probs
    probs = [x for _, x in sorted(zip(indices, probs))]
    # Return probabilities, and also return sentence list for use later as well
    return probs, sentence_list


def get_tokens_per_sentence_list(tokenizer,sentence_list):
    tokens_per_sentence_list = [len(tokenizer.encode(sentence)) for sentence in sentence_list]
    return tokens_per_sentence_list

def apply_threshold(prob_seq,tokens_per_sentence_list,threshold):
    '''
    If prob_seq is empty, we will return and empty list.
    '''
    # Initialize
    cutoff_indices = []
    running_length = tokens_per_sentence_list[0]
    # 
    for ii,prob in enumerate(prob_seq):
        if prob <= threshold:
            cutoff_indices.append(ii)
            running_length = tokens_per_sentence_list[ii+1]
            
        elif running_length + tokens_per_sentence_list[ii+1] > 512:
            cutoff_indices.append(ii)
            running_length = tokens_per_sentence_list[ii+1]
            
        else:
            running_length += tokens_per_sentence_list[ii+1]
        
    return cutoff_indices

def get_cutoff_indices(text, threshold, nsp_model,tokenizer, device):
    
    prob_seq, sentence_list = get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer,device)
    tokens_per_sentence_list = get_tokens_per_sentence_list(tokenizer, sentence_list)
    cutoff_indices = apply_threshold(prob_seq, tokens_per_sentence_list, threshold=.5)
    
    return cutoff_indices


In [16]:

threshold=.5
qid_struct = {}
key='test'
for ii, entry  in tqdm(enumerate(dataset)):
    if ii ==100:
        break
    print('started: ',str(ii))

    if len(entry['entity_pages']['wiki_context'])==0:
        wiki_context_probs = None
    else:
        wiki_context_probs = []
        for context in entry['entity_pages']['wiki_context']:
            prob_seq , _ = get_probabilities_on_text_w_NSP(nsp_model, context, tokenizer, device)
            wiki_context_probs.append(prob_seq)
            
    if len(entry['search_results']['search_context']) == 0:
         search_context_probs = None
    else:
        search_context_probs = []
        for context in entry['search_results']['search_context']:

            prob_seq , _ = get_probabilities_on_text_w_NSP(nsp_model, context, tokenizer, device)
            search_context_probs.append(prob_seq)
    
    qid_struct[entry['question_id']] = (wiki_context_probs,search_context_probs)
file_name = key + '_qid_struct.pkl'

with open(r"\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\\" + file_name, 'wb') as handle:
    pickle.dump(qid_struct, handle, protocol=pickle.HIGHEST_PROTOCOL)


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

started:  0
started:  1
started:  2
started:  3
started:  4
started:  5
started:  6
started:  7
started:  8
started:  9
started:  10
started:  11
started:  12
started:  13
started:  14
started:  15
started:  16
started:  17
started:  18
started:  19
started:  20
started:  21
started:  22
started:  23
started:  24
started:  25
started:  26
started:  27
started:  28
started:  29
started:  30
started:  31
started:  32
started:  33
started:  34
started:  35
started:  36
started:  37
started:  38
started:  39
started:  40
started:  41
started:  42
started:  43
started:  44
started:  45
started:  46
started:  47
started:  48
started:  49
started:  50
started:  51
started:  52
started:  53
started:  54
started:  55
started:  56
started:  57
started:  58
started:  59
started:  60
started:  61
started:  62
started:  63
started:  64
started:  65
started:  66
started:  67
started:  68
started:  69
started:  70



KeyboardInterrupt: 

# Cutoff index Sequence version of for loop

In [59]:

threshold=.5
qid_struct = {}
key='test'
for ii, entry  in tqdm(enumerate(dataset)):
    print('started: ',str(ii))

    if len(entry['entity_pages']['wiki_context'])==0:
        wiki_context_splits = None
    else:
        wiki_context_splits = []
        for context in entry['entity_pages']['wiki_context']:
            cutoff_indices = get_cutoff_indices(context, threshold, nsp_model, tokenizer,device)
            wiki_context_splits.append(cutoff_indices)
            
    if len(entry['search_results']['search_context']) == 0:
         search_context_splits = None
    else:
        search_context_splits = []
        for context in entry['search_results']['search_context']:

            cutoff_indices = get_cutoff_indices(context, threshold, nsp_model, tokenizer,device)
            search_context_splits.append(cutoff_indices)
    
    qid_struct[entry['question_id']] = (wiki_context_splits,search_context_splits)
file_name = key + '_qid_struct.pkl'

with open(r"\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\\" + file_name, 'wb') as handle:
    pickle.dump(qid_struct, handle, protocol=pickle.HIGHEST_PROTOCOL)   
     

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

started:  0
started:  1
started:  2
started:  3



KeyboardInterrupt: 

In [None]:
   with open(r"\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\\" + file_name, 'wb') as handle:
            pickle.dump(qid_struct, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
dataset[44]

In [None]:
class LSTM_Over_BERT(nn.Module):
    '''
    Input is a entry in trivia_qa dataset
    '''
    def __init__(self):
        super(LSTM_Over_BERT, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Input is a data entry in trivia_qa dataset.
entry contains the question, possible answers, correct answer, and possibly, multiple spans of text?
