In [1]:
import pandas as pd
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForNextSentencePrediction 
import spacy

In [2]:
# spaCy pipeline: only used for sentence segmentation in doc_split().
nlp = spacy.load('en_core_web_sm')

# Pretrained BERT with the next-sentence-prediction head.
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
# The model comes back from from_pretrained() in training mode, i.e. with
# dropout active. Switch to evaluation mode so that the loss computed in
# predict_follows() is deterministic and reproducible across runs.
model.eval()

# Word-piece tokenizer matching the checkpoint above (kept last so the cell
# ends on an assignment and does not echo the model repr).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
def predict_follows(sentence1, sentence2, max_len=512):
    """Return the next-sentence-prediction loss for ``sentence2`` following ``sentence1``.

    The pair is encoded as ``[CLS] sentence1 [SEP] sentence2 [SEP]`` and passed
    to the module-level ``model`` together with the label 0 ("sentence2 is the
    continuation of sentence1"). The scalar the model returns for that label is
    handed back as a Python float: the less plausible the continuation, the
    larger the value.

    NOTE: the result is only deterministic when ``model`` is in eval mode
    (dropout disabled).

    Args:
        sentence1: Text of the first sentence.
        sentence2: Text of the candidate following sentence.
        max_len: Maximum total sequence length, special tokens included.
            Defaults to 512, the number of positions bert-base-uncased
            supports. Longer pairs are truncated instead of crashing the model.

    Returns:
        float: Loss of the model under the assumption that ``sentence2``
        follows ``sentence1``.

    Raises:
        ValueError: If ``max_len`` leaves no room for the three special tokens.
    """
    if max_len < 3:
        raise ValueError('max_len must be at least 3 to fit [CLS] and two [SEP] tokens')

    tokens_a = tokenizer.tokenize(sentence1)
    tokens_b = tokenizer.tokenize(sentence2)

    # Trim the pair to the available budget, always shortening the currently
    # longer side by one trailing token so both sentences keep as much content
    # as possible.
    budget = max_len - 3
    while len(tokens_a) + len(tokens_b) > budget:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()

    # Assemble the sequence explicitly so the segment boundary is exact by
    # construction: segment 0 covers [CLS] + sentence1 + [SEP], segment 1
    # covers sentence2 + the final [SEP].
    tokenized_text = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
    segments_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Label 0 asks the model to score the "is the next sentence" hypothesis.
    next_sentence_label = torch.tensor([[0]])

    # Inference only: no gradients needed.
    with torch.no_grad():
        loss = model(tokens_tensor, segments_tensors, next_sentence_label=next_sentence_label)

    return loss.item()
    
def doc_split(text, thresh):
    """Split a document into blocks of sentences that plausibly follow each other.

    The text is segmented into sentences with the module-level ``nlp``
    pipeline. Each sentence is compared with its predecessor through
    ``predict_follows``; whenever the returned loss exceeds ``thresh`` a new
    block is started, otherwise the sentence is appended (space-separated) to
    the current block.

    Args:
        text: Document text.
        thresh: Largest loss still treated as "follows the previous sentence".
            Smaller values make splits more likely.

    Returns:
        list[str]: The blocks in document order. Empty when ``text`` contains
        no sentences.
    """
    sentences = [str(sent) for sent in nlp(text).sents]
    if not sentences:
        # Nothing to split; avoids indexing into an empty sentence list.
        return []

    docs = []
    current = sentences[0]
    for previous, following in zip(sentences, sentences[1:]):
        if predict_follows(previous, following) > thresh:
            docs.append(current)
            current = following
        else:
            current += ' ' + following

    # Always emit the trailing block, even if its text duplicates an earlier
    # block; a membership check here would silently drop repeated content.
    docs.append(current)
    return docs

In [5]:
# Demo text: three sentences on a single line. Adjacent string literals are
# concatenated, so the value has one space after each sentence and no newlines.
sentences = (
    "One of these sentences is not like the other. "
    "This sentence could logically follow the first as it pretains to the task at hand. "
    "A final sentence similar to the first two in that it discuesses being a sentence in a sequnece. "
)

# Try playing with the threshold: doc_split starts a new block whenever the
# loss for a sentence pair exceeds it, so smaller values split more readily.
for sentence_block in doc_split(sentences, 1e-5):
    print(sentence_block)

One of these sentences is not like the other. This sentence could logically follow the first as it pretains to the task at hand. A final sentence similar to the first two in that it discuesses being a sentence in a sequnece.
