In [1]:
exec(open("../../header.py").read())

In [2]:
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
from nltk.tokenize import sent_tokenize
from transformers import BertTokenizer,BertForNextSentencePrediction

In [3]:
class ContextDataset(Dataset):
    def __init__(self, sentence_pair_list):
        self.sentence_pair_list = sentence_pair_list

    def __len__(self):
        return len(self.sentence_pair_list)
 
    def __getitem__(self,idx):
        return(self.sentence_pair_list[idx])

In [76]:
def get_probabilities_on_text_w_NSP(nsp_model, text, tokenizer, device):
    '''
    Returns a sequence of probabilities which represent confidence that the next sentence is part of the same segment
    
    If text has n sentences, then prob_seq has n-1 probabilities. (If text has 1 sentence, prob_seq is [], the empty list.)
    The ii index of prob seq represents the NSP confidence of the ii and ii+1 sentences in text.
    Probabilities closer to 1 indicate confidence, Probabilities closer to 0 indicate no confidence.
    
    '''
    sentence_list = sent_tokenize(text)
    over_length_indices = []
    sentence_pair_list=[]
    indices_to_be_processed=[]
    #Create Sentence pair list
    if len(sentence_list)==1:
        return [],sentence_list # Return empty list for probs
    
    for ii in range(0,len(sentence_list)-1):
        sentence_1 = sentence_list[ii]
        sentence_2 = sentence_list[ii+1]

        #Encode temporarily, just to count
        encoded = tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')
        if encoded['input_ids'].shape[1] > 512: # If two sentences are too long, just split them
            over_length_indices.append(ii)
        else:# add to list to be processed
            indices_to_be_processed.append(ii)
            sentence_pair_list.append([sentence_1,sentence_2])
    # Now, begin calculating probabilities
    with torch.no_grad():
        # Load into a dataset and dataloader. We do this for speed
        context_dataset = ContextDataset(sentence_pair_list)
        context_loader = DataLoader(context_dataset, batch_size=128, shuffle=False, pin_memory=True)
        probs_list = [] #Will be list of tensors
        for batch_idx,batch in enumerate(context_loader):
            if len(batch)==0:
                continue
            # I swear to god this is a legitimate pytorch bug, but we need to reorganize the batch from the dataloader. Whatever
            
            batch_fixed = [(s1,s2) for s1,s2 in zip(batch[0], batch[1])]
            return batch, batch_fixed
            # Batch encode
            sentence_pairs = tokenizer.batch_encode_plus(batch_fixed, return_tensors='pt',padding=True)
            # Run through the model
            sentence_pairs.to(device)
            logits = nsp_model(**sentence_pairs)[0]
            # Get Probability of next sentence.
            probs = softmax(logits, dim=1)
            probs = probs[:,0]
            # Add to list of tensors
            probs_list.append(probs.cpu())
            
        #Cat the list of tensors to get a bsize x sequence_length tensors
        if len(probs_list)  == 0:
            all_probs = []
        else:
            all_probs = list(torch.cat(probs_list))
    # Now, we need to sort the probabilities. Some of probabilities are coming from over_length_indices, some of them are coming from indices_to_be_processed
    # We'll zip, then sort, then take the sorted probs.
    indices = over_length_indices + indices_to_be_processed
    one_probs = [1]*len(over_length_indices)
    probs = one_probs + all_probs
    probs = [x for _, x in sorted(zip(indices, probs))]
    # Return probabilities, and also return sentence list for use later as well
    return probs, sentence_list

In [61]:
# ArgParse
splits = ['train','test']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print_cuda_info(device)

Using device: cuda
Quadro RTX 8000
Memory Usage:
Allocated: 0.2 GB
Cached:    2.3 GB


In [62]:
nsp_model = BertForNextSentencePrediction.from_pretrained('prajjwal1/bert-small')
nsp_model.eval()
nsp_model.to(device)
tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-small')

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Train on bydate_alt.atheism

In [63]:
split = "train"

dataset_list = []

for config in newsgroup_configs: #loop over labels
        subset_path = RAW_DIR(f'20news/{split}/{config}')
        dataset_list.append((config, load_from_disk(subset_path)))

In [64]:
dataset_list[0]

('bydate_alt.atheism',
 Dataset({
     features: ['text'],
     num_rows: 480
 }))

In [65]:
from time import time

In [77]:
a, b = get_probabilities_on_text_w_NSP(nsp_model, context, tokenizer, device)

In [97]:
len(a[0])

9

In [90]:
b[2]

("> \n\nWow, you're quicker to point out heresy than the Church in the\nMiddle ages.",
 "Seriously though, even the Sheiks at Al-Azhar don't\nclaim that the Shi'ites are heretics.")

In [66]:
start = time()
label, sub_dataset = dataset_list[0]

qid_struct = {}
for ii, entry in enumerate(sub_dataset):# Loop over data entries with the same label
    context = entry['text']
    prob_seq , _ = get_probabilities_on_text_w_NSP(nsp_model, context, tokenizer, device)
    qid_struct[ii] = prob_seq

print(f"Time for {label}: {time() - start}")

[['From: mathew <mathew@mantis.co.uk>\nSubject: Alt.Atheism FAQ: Atheist Resources\nSummary: Books, addresses, music -- anything related to atheism\nKeywords: FAQ, atheism, books, music, fiction, addresses, contacts\nExpires: Thu, 29 Apr 1993 11:57:19 GMT\nDistribution: world\nOrganization: Mantis Consultants, Cambridge.', 'UK.', 'Supersedes: <19930301143317@mantis.co.uk>\nLines: 290\n\nArchive-name: atheism/resources\nAlt-atheism-archive-name: resources\nLast-modified: 11 December 1992\nVersion: 1.0\n\n                              Atheist Resources\n\n                      Addresses of Atheist Organizations\n\n                                     USA\n\nFREEDOM FROM RELIGION FOUNDATION\n\nDarwin fish bumper stickers and assorted other atheist paraphernalia are\navailable from the Freedom From Religion Foundation in the US.', 'Write to:  FFRF, P.O.', 'Box 750, Madison, WI 53701.', 'Telephone: (608) 256-8900\n\nEVOLUTION DESIGNS\n\nEvolution Designs sell the "Darwin fish".', 'It\'s a f

KeyboardInterrupt: 

# Try running the get `get_probabilities_on_text_w_NSP` functin

In [12]:
nsp_model = nsp_model
text = context
tokenizer = tokenizer
device = device

In [13]:
sentence_list = sent_tokenize(text)
over_length_indices = []
sentence_pair_list=[]
indices_to_be_processed=[]
#Create Sentence pair list
if len(sentence_list)==1:
    output =  [], sentence_list # Return empty list for probs

In [14]:
for ii in range(0,len(sentence_list)-1):
    sentence_1 = sentence_list[ii]
    sentence_2 = sentence_list[ii+1]

    #Encode temporarily, just to count
    encoded = tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')
    if encoded['input_ids'].shape[1] > 512: # If two sentences are too long, just split them
        over_length_indices.append(ii)
    else:# add to list to be processed
        indices_to_be_processed.append(ii)
        sentence_pair_list.append([sentence_1,sentence_2])

In [56]:
for a, b in context_loader:
    print(f'a:{a}')
    print(f'b:{b}')

a:['From: I3150101@dbstu1.rz.tu-bs.de (Benedikt Rosenau)\nSubject: Re: Wholly Babble (Was Re: free moral agency)\nOrganization: Technical University Braunschweig, Germany\nLines: 10\n\nIn article <2944159064.5.p00261@psilink.com>\n"Robert Knowles" <p00261@psilink.com> writes:\n \n(Deletion)\n>Of course, there is also the\n>Book of the SubGenius and that whole collection of writings as well.', 'Does someone know a FTP site with it?']
b:['Does someone know a FTP site with it?', 'Benedikt']


In [59]:
a[1]

'Does someone know a FTP site with it?'

In [36]:
# Now, begin calculating probabilities
with torch.no_grad():
    # Load into a dataset and dataloader. We do this for speed
    context_dataset = ContextDataset(sentence_pair_list)
    context_loader = DataLoader(context_dataset, batch_size=128, shuffle=False, pin_memory=True)
    probs_list = [] #Will be list of tensors
    for batch_idx,batch in enumerate(context_loader):
        if len(batch)==0:
            continue
        # Batch encode
        sentence_pairs = tokenizer.batch_encode_plus(batch, return_tensors='pt',padding=True)
        # Run through the model
        sentence_pairs.to(device)
        logits = nsp_model(**sentence_pairs)[0]
        # Get Probability of next sentence.
        probs = softmax(logits, dim=1)
        probs = probs[:,0]
        # Add to list of tensors
        probs_list.append(probs.cpu())

    #Cat the list of tensors to get a bsize x sequence_length tensors
    if len(probs_list)  == 0:
        all_probs = []
    else:
        all_probs = list(torch.cat(probs_list))
# Now, we need to sort the probabilities. Some of probabilities are coming from over_length_indices, some of them are coming from indices_to_be_processed
# We'll zip, then sort, then take the sorted probs.
indices = over_length_indices + indices_to_be_processed
one_probs = [1]*len(over_length_indices)
probs = one_probs + all_probs
probs = [x for _, x in sorted(zip(indices, probs))]

In [34]:
sentence_pairs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [35]:
sentence_pairs['input_ids'].shape

torch.Size([2, 136])

In [28]:
print(f"Type: {type(batch_fixed)}")
print(f"Length: {len(batch_fixed)}")
print(f"First element: {batch_fixed[0]}")

Type: <class 'list'>
Length: 2
First element: ('From: I3150101@dbstu1.rz.tu-bs.de (Benedikt Rosenau)\nSubject: Re: Wholly Babble (Was Re: free moral agency)\nOrganization: Technical University Braunschweig, Germany\nLines: 10\n\nIn article <2944159064.5.p00261@psilink.com>\n"Robert Knowles" <p00261@psilink.com> writes:\n \n(Deletion)\n>Of course, there is also the\n>Book of the SubGenius and that whole collection of writings as well.', 'Does someone know a FTP site with it?')


In [32]:
type(batch_fixed[0])

tuple

In [31]:
type(batch[0])

list