In [30]:
from datasets import load_from_disk
import json
import torch
from nltk.tokenize import sent_tokenize
from torch.nn.functional import softmax
from torch.utils.data import DataLoader
from transformers import BertTokenizer,BertForNextSentencePrediction
import pickle
from torch.utils.data import Dataset
import os

In [11]:
class ContextDataset(Dataset):
    """Map-style dataset that exposes an in-memory list of sentence pairs.

    Each item is returned exactly as stored in ``sentence_pair_list``; no
    tokenisation or copying happens here.
    """

    def __init__(self, sentence_pair_list):
        # Keep a reference to the caller's list (not a copy).
        self.sentence_pair_list = sentence_pair_list

    def __len__(self):
        """Number of stored sentence pairs."""
        pairs = self.sentence_pair_list
        return len(pairs)

    def __getitem__(self, idx):
        """Return the sentence pair stored at position ``idx``."""
        pair = self.sentence_pair_list[idx]
        return pair

In [12]:
def get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer, device, max_length=512, batch_size=128):
    '''
    Returns a sequence of probabilities which represent confidence that the next sentence is part of the same segment

    If text has n sentences, then prob_seq has n-1 probabilities. (If text has 0 or 1 sentences, prob_seq is [], the empty list.)
    The ii index of prob seq represents the NSP confidence of the ii and ii+1 sentences in text.
    Probabilities closer to 1 indicate confidence, Probabilities closer to 0 indicate no confidence.

    Parameters
    ----------
    nsp_model : callable
        Called as ``nsp_model(**encoded_batch)``; element ``[0]`` of the result is
        treated as a (batch, classes) logits tensor and column 0 of its softmax is
        used as the "is next sentence" probability.
    text : str
        Raw text, split into sentences with ``sent_tokenize``.
    tokenizer
        Must provide ``encode_plus`` (used only to count tokens of a pair) and
        ``batch_encode_plus`` (used to build the model inputs).
    device
        Target passed to ``.to(device)`` on the encoded batch.
    max_length : int, optional
        Pairs whose encoded length exceeds this many tokens are not scored.
    batch_size : int, optional
        Number of sentence pairs sent through the model at once.

    Returns
    -------
    (probs, sentence_list)
        ``probs`` is a list of 0-dim CPU tensors, one per adjacent sentence pair,
        in sentence order. ``sentence_list`` is the output of ``sent_tokenize``.

    NOTE(review): pairs longer than ``max_length`` are assigned confidence 1
    (i.e. "same segment"), although the original inline comment talked about
    "splitting" them. The original value is preserved here -- confirm intent.
    '''
    sentence_list = sent_tokenize(text)
    num_pairs = len(sentence_list) - 1
    if num_pairs < 1:
        # Zero or one sentence: there is no adjacent pair to score
        return [], sentence_list

    # Every slot starts at the over-length default (confidence 1); slots that can
    # be scored are overwritten below. Using tensors here keeps the returned list
    # homogeneous instead of mixing Python ints with tensors.
    probs = [torch.tensor(1.0) for _ in range(num_pairs)]

    indices_to_be_processed = []
    sentence_pair_list = []
    for ii in range(num_pairs):
        sentence_1 = sentence_list[ii]
        sentence_2 = sentence_list[ii + 1]

        # Encode temporarily, just to count tokens
        encoded = tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')
        if encoded['input_ids'].shape[1] > max_length:
            continue  # too long for the model; keep the default value
        indices_to_be_processed.append(ii)
        sentence_pair_list.append((sentence_1, sentence_2))

    # Score the remaining pairs in plain list slices. A DataLoader brings no
    # benefit for in-memory strings and its default collate transposes the pairs.
    with torch.no_grad():
        for start in range(0, len(sentence_pair_list), batch_size):
            stop = start + batch_size
            encoded_batch = tokenizer.batch_encode_plus(
                sentence_pair_list[start:stop], return_tensors='pt', padding=True
            )
            encoded_batch = encoded_batch.to(device)
            logits = nsp_model(**encoded_batch)[0]
            # Column 0 of the softmax is read as P(sentence 2 follows sentence 1)
            batch_probs = softmax(logits, dim=1)[:, 0].cpu()
            # Write each probability straight into its original position
            for ii, prob in zip(indices_to_be_processed[start:stop], batch_probs):
                probs[ii] = prob

    # Return probabilities, and also return sentence list for use later as well
    return probs, sentence_list

In [13]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# Extra diagnostics, only meaningful on CUDA (all queries target device index 0).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        # Bytes -> GiB, rounded to one decimal
        print(label, round(query(0)/1024**3,1), 'GB')

Using device: cuda
Quadro RTX 8000
Memory Usage:
Allocated: 0.1 GB
Cached:    0.1 GB


In [14]:
# The tokenizer and the next-sentence-prediction model are loaded from the same checkpoint.
tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-small')
nsp_model = BertForNextSentencePrediction.from_pretrained('prajjwal1/bert-small')
# Inference only: place the model on the selected device and put it in evaluation mode.
nsp_model = nsp_model.to(device).eval()

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [25]:
import nltk

# sent_tokenize relies on the Punkt sentence-tokenizer data. Only hit the
# network when the resource is not installed yet, so that re-running the
# notebook is cheap and works offline once the data is present.
# NOTE(review): recent NLTK releases look up 'punkt_tab' instead -- confirm
# against the installed NLTK version.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/ay1626/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [15]:
# One config name per 20 Newsgroups label ("bydate" variants), in the order
# in which the sub-datasets are loaded and processed.
newsgroup_configs = [
    # alt.*
    'bydate_alt.atheism',
    # comp.*
    'bydate_comp.graphics',
    'bydate_comp.os.ms-windows.misc',
    'bydate_comp.sys.ibm.pc.hardware',
    'bydate_comp.sys.mac.hardware',
    'bydate_comp.windows.x',
    # misc.*
    'bydate_misc.forsale',
    # rec.*
    'bydate_rec.autos',
    'bydate_rec.motorcycles',
    'bydate_rec.sport.baseball',
    'bydate_rec.sport.hockey',
    # sci.*
    'bydate_sci.crypt',
    'bydate_sci.electronics',
    'bydate_sci.med',
    'bydate_sci.space',
    # soc.*
    'bydate_soc.religion.christian',
    # talk.*
    'bydate_talk.politics.guns',
    'bydate_talk.politics.mideast',
    'bydate_talk.politics.misc',
    'bydate_talk.religion.misc',
]

In [32]:
# For every split and every newsgroup label: compute the NSP probability
# sequence of each text entry and pickle the per-label results to disk.
mode = '20news'
splits = ['train', 'test']
for split in splits: # Loop over train test
    # Load every label's sub-dataset for this split up front
    dataset_list = []
    for config in newsgroup_configs: #loop over labels
        subset_path = os.path.expanduser(f'~/NLU_data/raw/{mode}/{split}/{config}')
        print(f"Loading {subset_path}")
        dataset_list.append((config,load_from_disk(subset_path)))

    # Create the output folder (including missing parents) BEFORE the expensive
    # model pass, so a bad path fails fast instead of after the computation.
    processed_folder = os.path.expanduser(f"~/NLU_data/processed/{mode}/{split}/")
    os.makedirs(processed_folder, exist_ok=True)

    for label, sub_dataset in dataset_list: #Loop over labels
        # Maps the entry's position within the sub-dataset -> its probability sequence
        qid_struct = {}
        for ii, entry in enumerate(sub_dataset):# Loop over data entries with the same label
            context = entry['text']
            prob_seq , _ = get_probabilities_on_text_w_NSP(nsp_model, context, tokenizer, device)
            qid_struct[ii] = prob_seq

        # Save the per-entry probability sequences for this label
        file_name = f'{label}_qid_struct.pkl'
        output_path = os.path.join(processed_folder, file_name)
        print(f"Saving {output_path}")
        with open(output_path, 'wb') as handle:
            pickle.dump(qid_struct, handle, protocol=pickle.HIGHEST_PROTOCOL)

Loading /home/ay1626/NLU_data/raw/20news/train/bydate_alt.atheism
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_comp.graphics
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_comp.os.ms-windows.misc
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_comp.sys.ibm.pc.hardware
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_comp.sys.mac.hardware
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_comp.windows.x
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_misc.forsale
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_rec.autos
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_rec.motorcycles
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_rec.sport.baseball
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_rec.sport.hockey
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_sci.crypt
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_sci.electronics
Loading /home/ay1626/NLU_data/raw/20news/train/bydate_sci.med
Loading /home/ay1626/NL

In [34]:
# Sanity check on the container left over from the last iteration of the processing loop.
type(qid_struct)

dict

In [36]:
# Keys are the enumerate() positions of the entries within the last processed sub-dataset.
qid_struct.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,

In [38]:
# Inspect the stored NSP probability sequence for entry 27 (one value per adjacent sentence pair).
qid_struct[27]

[tensor(1.0000),
 tensor(0.9102),
 tensor(0.8659),
 tensor(0.8923),
 tensor(0.5648),
 tensor(1.0000),
 tensor(1.0000),
 tensor(1.0000),
 tensor(1.0000),
 tensor(0.6457),
 tensor(0.9032),
 tensor(0.9475),
 tensor(0.9843),
 tensor(1.0000),
 tensor(1.0000),
 tensor(0.8006),
 tensor(0.9445),
 tensor(0.8945),
 tensor(0.0432),
 tensor(0.9989),
 tensor(0.5431),
 tensor(0.9335),
 tensor(0.9983),
 tensor(0.0308),
 tensor(1.0000),
 tensor(0.7119),
 tensor(0.9659),
 tensor(0.8694),
 tensor(0.0757),
 tensor(0.7168),
 tensor(0.9637),
 tensor(1.0000),
 tensor(1.0000),
 tensor(0.1371),
 tensor(0.2259),
 tensor(0.9864),
 tensor(1.0000),
 tensor(0.9693),
 tensor(0.9950),
 tensor(0.9994),
 tensor(1.0000),
 tensor(0.4812),
 tensor(0.9168),
 tensor(0.9427),
 tensor(0.9712),
 tensor(0.9930),
 tensor(0.8222),
 tensor(0.8959),
 tensor(0.8693),
 tensor(0.2803)]