In [None]:
from sklearn.metrics.cluster.supervised import contingency_matrix
from sklearn.metrics import precision_score, recall_score
from nltk.metrics.segmentation import pk, windowdiff
from segeval.window.windowdiff import window_diff
from scipy import sparse as sp
from harvesttext import HarvestText

import scipy
import random
import nltk
import numpy as np
from functools import lru_cache
from nltk.tokenize import  RegexpTokenizer
from nltk.stem import WordNetLemmatizer
from keybert import KeyBERT

import json
from tqdm import tqdm
import re
import jsonlines
import pandas as pd
import itertools as it
import copy
import time

from scipy.sparse import csr_matrix

import ast
from nltk import sent_tokenize, word_tokenize

import string

from nltk.corpus import stopwords
stop_words = stopwords.words('english')

import matplotlib.pyplot as plt
from collections import defaultdict

from sqlitedict import SqliteDict
from sentence_splitter import SentenceSplitter, split_text_into_sentences
from sklearn.feature_extraction.text import CountVectorizer

import spacy
from scipy import stats
import os
import pickle
from numba import jit
import itertools
import concurrent.futures

from rouge_score import rouge_scorer

In [None]:
from sentence_transformers import SentenceTransformer
from transformers import *

#### Define abstract processing regex

In [None]:
# Section-extraction patterns for structured abstracts, keyed by section name.
# Each captures the text between one section heading and the next; the conclusion
# pattern captures everything after the CONCLUSIONS heading.
regex_dict = {
    'background': r'BACKGROUND(.*?)METHODS',
    'method': r'METHODS(.*?)RESULTS',
    'result': r'RESULTS(.*?)CONCLUSIONS',
    'conclusion': r'(?<=CONCLUSIONS)(.*)',
}
regex_background = regex_dict['background']
regex_method = regex_dict['method']
regex_result = regex_dict['result']
regex_conclusion = regex_dict['conclusion']

# English sentence splitter and a hand-picked punctuation set. Compared with
# string.punctuation, `puncs` omits '.', '!' and '?' — presumably so sentence-final
# punctuation survives stripping (TODO confirm). Neither name is used in the visible cells.
splitter = SentenceSplitter(language='en')
puncs = '"#$%&\'()*+,-/:;<=>@[\\]^_`{|}~'

#### Define Rouge scorer

In [None]:
# ROUGE-1 / ROUGE-2 / ROUGE-Lsum with stemming enabled; the evaluation cells below average
# element [2] of each of these three scores (the value they name `*_fm`).
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)

#### Load the automatically labelled CAS data (`CAS-auto-clean.json`, one JSON record per line)

In [None]:
# One JSON object per line (JSON Lines). Later cells index records by position and read
# their 'abstract' and 'conclusion' fields.
auto_data = []
# JSON text is UTF-8; don't depend on the platform's default locale encoding.
with open('data/CAS-human/CAS-auto-clean.json', 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        if line.strip():
            auto_data.append(json.loads(line))

#### Enumerate candidate premise/conclusion splits for each abstract

In [None]:
def get_all_splits_auto(paper_ids, auto_data):
    """Enumerate candidate premise/conclusion splits for each abstract and pick a random start.

    For an abstract of ``n`` sentences, every boundary pair ``(start, end)`` with
    ``0 <= start < end <= n`` is considered. Three shapes are kept:

    * ``end == n`` and ``start`` in ``{n-3, n-2, n-1}``: the trailing sentences
      ``abstract[start:]`` are the conclusion. For ``n <= 3`` this includes ``start == 0``,
      i.e. the whole abstract as conclusion.
    * ``start == 1`` and ``end`` in ``{n-2, n-1}``: ``abstract[1:end]`` is the premise.
    * ``start == 2`` and ``end == n-1``: ``abstract[2:n-1]`` is the premise.

    The complementary part is every sentence *not contained* in the sliced part, so
    duplicate sentences can disappear. Such candidates are dropped when a boundary exceeds
    the remaining sentence count. Candidates equal to an earlier one are also dropped.

    Args:
        paper_ids: iterable of indices into ``auto_data``.
        auto_data: records whose ``'abstract'`` is a list of sentences.

    Returns:
        ``(all_splits_batch, init_splits_batch)``. ``all_splits_batch`` maps
        paper_id -> {str([start, end]): {'premise', 'conclusion'}}. ``init_splits_batch``
        maps paper_id -> one uniformly chosen candidate plus its ``'sep_points'`` key.
        Papers with no candidate are absent from both.
    """
    all_splits_batch = {}
    init_splits_batch = {}

    for paper_id in paper_ids:
        abstract = auto_data[paper_id]['abstract']
        n_sent = len(abstract)
        candidates = {}

        for start, end in itertools.combinations(range(n_sent + 1), 2):
            if end == n_sent and start in (n_sent - 3, n_sent - 2, n_sent - 1):
                conclusion = abstract[start:]
                premise = [s for s in abstract if s not in conclusion]
            elif (start == 1 and end in (n_sent - 2, n_sent - 1)) or (start == 2 and end == n_sent - 1):
                premise = abstract[start:end]
                conclusion = [s for s in abstract if s not in premise]
            else:
                continue

            candidate = {'premise': premise, 'conclusion': conclusion}
            if candidate in candidates.values():
                continue
            total = len(premise) + len(conclusion)
            if start <= total and end <= total:
                candidates[str([start, end])] = candidate

        if not candidates:
            continue

        keys = list(candidates.keys())
        chosen = keys[np.random.choice(len(keys), 1)[0]]
        init_splits_batch[paper_id] = {'sep_points': chosen,
                                       'premise': candidates[chosen]['premise'],
                                       'conclusion': candidates[chosen]['conclusion']}
        all_splits_batch[paper_id] = candidates

    return all_splits_batch, init_splits_batch

In [None]:
def get_word_prob(word, tokenized_corpus, min_prob=1e-10):
    """Relative frequency of ``word`` in ``tokenized_corpus``, floored at ``min_prob``.

    Args:
        word: token to count.
        tokenized_corpus: sequence supporting ``len`` and ``.count`` (e.g. a list of tokens).
        min_prob: lower bound on the result, so downstream logs stay finite
            (defaults to the previous hard-coded 1e-10).

    Returns:
        ``count(word) / len(corpus)`` if that exceeds ``min_prob``, else ``min_prob``.
        An empty corpus returns ``min_prob`` (it previously raised ZeroDivisionError).
    """
    if len(tokenized_corpus) == 0:
        return min_prob
    prob = tokenized_corpus.count(word) / len(tokenized_corpus)
    return prob if prob > min_prob else min_prob

In [None]:
class Vocab:
    """Incrementally built word <-> integer-id mapping.

    Ids are assigned densely in order of first appearance. ``vocab_size`` is the number
    of distinct words seen so far.
    """

    def __init__(self ):
        self.word_to_index = {}
        self.index_to_word = {}
        self.vocab_size = 0

    def sent2seq( self, sent ):
        """Map a whitespace-tokenised sentence to word ids, registering unseen words.

        Tokens equal to '.' are ignored entirely (neither registered nor emitted).
        """
        indices = []
        for token in sent.split():
            if token == '.':
                continue
            if token not in self.word_to_index:
                new_id = self.vocab_size
                self.word_to_index[token] = new_id
                self.index_to_word[new_id] = token
                self.vocab_size = new_id + 1
            indices.append(self.word_to_index[token])
        return indices

In [None]:
def index_splits(splits, vocab):
    """Convert each paper's premise/conclusion sentences into lists of word ids.

    Args:
        splits: paper_id -> dict with 'premise' and 'conclusion' sentence lists
            (other keys are ignored).
        vocab: object with ``sent2seq(sentence) -> list[int]`` (e.g. ``Vocab``). Premise
            sentences are encoded before conclusion sentences, paper by paper.

    Returns:
        paper_id -> {'premise': [ids, ...], 'conclusion': [ids, ...]}.
    """
    def _encode(sentences):
        return [vocab.sent2seq(sentence) for sentence in sentences]

    return {
        paper_id: {"premise": _encode(split['premise']),
                   "conclusion": _encode(split['conclusion'])}
        for paper_id, split in splits.items()
    }

In [None]:
def get_doc_list(indexed_splits):
    """Separate indexed splits into parallel premise and conclusion document lists.

    Args:
        indexed_splits: paper_id -> {'premise': ..., 'conclusion': ...}, as produced by
            ``index_splits``.

    Returns:
        ``(premise_doc_list, conclusion_doc_list)``, both in the mapping's iteration order.
    """
    premise_doc_list = [split['premise'] for split in indexed_splits.values()]
    conclusion_doc_list = [split['conclusion'] for split in indexed_splits.values()]
    return premise_doc_list, conclusion_doc_list

In [None]:
def compute_mi(pre_doc_list, con_doc_list, vocab_size, alpha="-inf", normalized=False): 
    
    num_docs = len(pre_doc_list)
    inv_idx_premise  = np.zeros( (num_docs, vocab_size), dtype = np.int32 )
    inv_idx_conclusion  = np.zeros( (num_docs, vocab_size), dtype = np.int32 )
    
    for doc_id, doc in enumerate(pre_doc_list):
        all_seqs = []
        for seq in doc:
            all_seqs += seq
        tok_ids, counts = np.unique( all_seqs, return_counts=True )
        if len(tok_ids)>0:
            inv_idx_premise[doc_id][ tok_ids ] = counts
            
    for doc_id, doc in enumerate(con_doc_list):
        all_seqs = []
        for seq in doc:
            all_seqs += seq
        tok_ids, counts = np.unique( all_seqs, return_counts=True )
        if len(tok_ids)>0:
            inv_idx_conclusion[doc_id][ tok_ids ] = counts
                
    wp_wc_count_prod_per_doc = inv_idx_premise[:,:, np.newaxis] * inv_idx_conclusion[:,np.newaxis,:]
    wp_wc_count_prod_sum_over_docs = wp_wc_count_prod_per_doc.sum(axis = 0)
        
    num_words_per_premise = inv_idx_premise.sum(axis = 1)
    num_words_per_conclusion = inv_idx_conclusion.sum(axis = 1)
    
    num_words_prod_per_doc = num_words_per_premise * num_words_per_conclusion
    
    wp_count_sum_over_docs = inv_idx_premise.sum(axis = 0)
    wc_count_sum_over_docs = inv_idx_conclusion.sum(axis = 0)
    
    assert wp_count_sum_over_docs.sum() == num_words_per_premise.sum()
    assert wc_count_sum_over_docs.sum() == num_words_per_conclusion.sum()
    assert wp_wc_count_prod_sum_over_docs.sum() == num_words_prod_per_doc.sum()

    unique_counts = np.unique(wp_wc_count_prod_sum_over_docs)
    
    count_histo = dict()
    for unique in unique_counts:
        unique_count = len(np.where(wp_wc_count_prod_sum_over_docs[wp_wc_count_prod_sum_over_docs == unique])[0])
        count_histo[unique] = unique_count
            
    P_wp_wc = wp_wc_count_prod_sum_over_docs /  num_words_prod_per_doc.sum()
    P_wp = wp_count_sum_over_docs / num_words_per_premise.sum()
    P_wc = wc_count_sum_over_docs / num_words_per_conclusion.sum()
    P_wp_x_P_wc = P_wp[:,np.newaxis] * P_wc[ np.newaxis, : ]
   
    if normalized == False:
        
        mi = P_wp_wc * ( np.log( P_wp_wc + 1e-9  ) - np.log( P_wp_x_P_wc + 1e-9  ) )
        mi = mi.sum()

        return mi
    
    else:
        mi = P_wp_wc * ( np.log( P_wp_wc + 1e-9  ) - np.log( P_wp_x_P_wc + 1e-9  ) )        
                
        mi = mi.sum()

        U_wp = - np.sum( P_wp * np.log(P_wp + 1e-9) )
        U_wc = - np.sum( P_wc * np.log(P_wc + 1e-9) )


        if alpha == "-inf":
            U_alpha = np.min([U_wc, U_wp])

        elif alpha == -1:
            U_alpha =  (2 * U_wc * U_wp) / (U_wc + U_wp )

        elif alpha == 0:
            U_alpha = np.sqrt(U_wc * U_wp)

        elif alpha == 1:
            U_alpha = (U_wp + U_wc) / 2

        elif alpha == 2:
            U_alpha = np.sqrt( ((U_wc * U_wc + U_wp * U_wp) / 2) )

        elif alpha == "inf":
            U_alpha = np.max([U_wc, U_wp])

        normalized_mi = mi / (U_alpha + 1e-9)

        return normalized_mi, count_histo

In [None]:
def get_greedy_batches(doc_list, batch_size):
    """Split ``doc_list`` into consecutive slices of ``batch_size`` items (the last may be shorter)."""
    starts = range(0, len(doc_list), batch_size)
    return [doc_list[start:start + batch_size] for start in starts]

In [None]:
def get_mini_batches(paper_list, batch_size=100):
    """Shuffle a copy of ``paper_list`` (numpy global RNG) and cut it into ``batch_size`` slices.

    The caller's sequence is no longer shuffled in place. Read-only sequences such as
    ``range`` or tuples, which used to raise, are accepted and batched as lists.
    For a list or ndarray input the batches equal those of shuffling the input itself.
    """
    if isinstance(paper_list, np.ndarray):
        shuffled = paper_list.copy()
    else:
        shuffled = list(paper_list)
    np.random.shuffle(shuffled)
    return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]

In [None]:
def compute_mi_single_split(doc_id, sep_point, split, curr_splits, vocab, alpha, normalized):
    """Score the batch configuration obtained by replacing ``doc_id``'s split with ``split``.

    Args:
        doc_id: paper whose split is swapped.
        sep_point: key of the candidate split (stored as its 'sep_points').
        split: candidate dict with 'premise' and 'conclusion' sentence lists.
        curr_splits: current paper_id -> split mapping; not modified (a shallow copy is
            made, so the other papers' split dicts are shared with it).
        vocab: vocabulary passed to ``index_splits``.
        alpha, normalized: forwarded to ``compute_mi``.

    Returns:
        dict with 'mi' (score), 'split' (the modified configuration) and 'histo' (the count
        histogram, or None when ``compute_mi`` returns a bare score).
    """
    temp_splits = curr_splits.copy()
    temp_splits[doc_id] = {'sep_points': sep_point,
                           'premise': split['premise'], 'conclusion': split['conclusion']}
    indexed_splits = index_splits(temp_splits, vocab)
    pre_doc_list, con_doc_list = get_doc_list(indexed_splits)
    result = compute_mi(pre_doc_list, con_doc_list, vocab.vocab_size, alpha, normalized)

    # compute_mi returns (score, histogram) when normalised and a bare score otherwise.
    # Unconditional unpacking crashed for normalized=False.
    if isinstance(result, tuple):
        mi, count_histo = result
    else:
        mi, count_histo = result, None

    return {'mi': mi, 'split': temp_splits, 'histo': count_histo}

In [None]:
def circled_greedy_split_parallel(all_splits, curr_splits, alpha, normalized):
    """One greedy sweep over the papers of a batch that maximises the (normalised) MI.

    Papers are visited in random order. For each paper, every candidate split in
    ``all_splits[paper]`` is scored in a thread pool while the other papers keep their
    current split. The best candidate is adopted if it beats the best score so far.

    Args:
        all_splits: paper_id -> {sep_points_str: {'premise', 'conclusion'}}, as from
            ``get_all_splits_auto``. Must contain every key of ``curr_splits``.
        curr_splits: starting configuration, paper_id -> split dict (not mutated).
        alpha, normalized: forwarded to ``compute_mi`` via ``compute_mi_single_split``.

    Returns:
        ``(best_splits, best_mi, best_histo)``:
        * the best configuration found ({} if no candidate scored above 0);
        * its score;
        * the 'histo' entries of the candidates evaluated in the last improving step.
    """
    vocab = Vocab()
    # Register every sentence the workers can encounter before starting threads. Inside
    # the pool, Vocab.sent2seq then only reads the shared dicts. Concurrent insertion
    # into the shared vocab was a data race.
    index_splits(curr_splits, vocab)
    for paper in curr_splits:
        index_splits(all_splits[paper], vocab)

    best_mi = 0
    best_splits = dict()
    best_histo = []

    paper_ids = list(curr_splits.keys())
    # The result of np.random.permutation used to be discarded, so papers were never
    # shuffled. Permuting indices keeps the original key objects.
    paper_ids = [paper_ids[i] for i in np.random.permutation(len(paper_ids))]

    for paper in paper_ids:
        sep_points = list(all_splits[paper].keys())
        splits = list(all_splits[paper].values())
        n_candidates = len(splits)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(compute_mi_single_split,
                                        [paper] * n_candidates,
                                        sep_points,
                                        splits,
                                        [curr_splits] * n_candidates,
                                        [vocab] * n_candidates,
                                        [alpha] * n_candidates,
                                        [normalized] * n_candidates))
        results = [result for result in results if result]
        if not results:
            continue

        local_mis = [result['mi'] for result in results]
        best_local_idx = int(np.argmax(local_mis))
        best_local_mi = local_mis[best_local_idx]

        if best_local_mi > best_mi:
            best_mi = best_local_mi
            curr_splits = results[best_local_idx]['split']
            best_splits = curr_splits
            best_histo = [result['histo'] for result in results]

    return best_splits, best_mi, best_histo

In [None]:
def generate_random_seeds(seed_range, num_seeds):
    """Draw ``num_seeds`` integers from ``range(seed_range)`` with numpy's global RNG.

    Each draw is independent, so the returned seeds may repeat.
    """
    return [np.random.choice(range(seed_range), 1, replace=False)[0] for _ in range(num_seeds)]

In [None]:
class Cartesian():
    """Odometer-style counter over the Cartesian product of the sequences in ``data_group``.

    ``counter[i]`` is the current index into ``data_group[i]``. ``handle()`` advances the
    counter by one step, and the last group changes fastest. The class is not referenced
    by the visible notebook cells.
    """
    def __init__(self, data_group):
        self.data_group = data_group
        # Position being incremented; starts at, and is reset to, the last group.
        self.counter_idx = len(data_group) - 1
        # One index per group, all starting at 0.
        self.counter = [0 for i in range(0, len(self.data_group))]
        
    def count_length(self):
        """Return the number of combinations (product of all group lengths)."""
        i = 0
        length = 1
        while i < len(self.data_group):
            length *= len(self.data_group[i])
            i += 1
        return length
    
    def handle(self):
        """Advance the counter by one, carrying into earlier groups on overflow.

        After the last combination every position wraps back to 0.
        """
        self.counter[self.counter_idx] += 1
        if self.counter[self.counter_idx] >= len(self.data_group[self.counter_idx]):
            # Overflow: reset this position and carry into the previous group (recursively).
            self.counter[self.counter_idx] = 0
            self.counter_idx -= 1
            if self.counter_idx >= 0:
                self.handle()
            # Whatever happened during the carry, the next step starts at the last group again.
            self.counter_idx = len(self.data_group) - 1

#### Sentence-BERT nearest-neighbour search for the most relevant abstracts

In [None]:
class SentenceTransformersNNSearch:
    """Rank candidate texts by cosine similarity to a query with a SentenceTransformer model."""

    def __init__( self, model_name="all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def encode( self, sentences):
        """Embed ``sentences`` with the wrapped model."""
        return self.model.encode(sentences)

    def normalize_embeddings(self, embeddings ):
        """L2-normalise each row of a 2-D embedding matrix (1e-12 guards all-zero rows)."""
        assert len( embeddings.shape ) == 2
        normalized_embeddings = embeddings /(np.linalg.norm( embeddings, axis =1, keepdims=True )+1e-12)
        return normalized_embeddings

    def rank_sentences( self, source_sentence, target_sentences, top_n=None, reverse=False  ):
        """Rank ``target_sentences`` by cosine similarity to ``source_sentence``.

        Args:
            source_sentence: query text.
            target_sentences: candidate texts.
            top_n: number of results to keep (all when None).
            reverse: if truthy, return the least similar candidates first; otherwise the
                most similar first.

        Returns:
            ``(D, I)``: similarity scores and the matching indices into ``target_sentences``.
        """
        source_embedding = self.normalize_embeddings(self.encode([source_sentence]))[0]
        target_embeddings = self.normalize_embeddings(self.encode(target_sentences))
        sim_scores = np.dot(target_embeddings, source_embedding)
        if top_n is None:
            top_n = len(target_sentences)
        # Branch on truthiness: with `is True` / `is False` checks, any other value
        # (0, None, np.bool_, ...) left I and D unbound.
        order = np.argsort(sim_scores) if reverse else np.argsort(-sim_scores)
        I = order[:top_n]
        D = sim_scores[I]
        return D, I

In [None]:
# Shared Sentence-BERT ranker (default model "all-MiniLM-L6-v2"), used by the GreedyCAS-plus
# cell to retrieve each seed paper's nearest neighbours within a chunk.
SentBert_model = SentenceTransformersNNSearch()

#### Exhaustive abstract slicing on the entire dataset

In [None]:
from sklearn.metrics import jaccard_score

In [None]:
def jaccard_distance(A, B):
    """Jaccard distance ``1 - |A & B| / |A | B|`` between two collections treated as sets.

    Two empty collections are identical, so their distance is 0.0 (this previously
    raised ZeroDivisionError).
    """
    set_a, set_b = set(A), set(B)
    union = set_a | set_b
    if not union:
        return 0.0
    return 1 - len(set_a & set_b) / len(union)

In [None]:
def get_segment_boundary(input_seq):
    """Convert per-sentence segment labels into boundary markers.

    Input: labels, 1 for conclusion and 0 for premise. For example, [1, 0, 0, 0, 1] means
    the first and last sentences are conclusions.
    Output: float array where position ``i`` is 1 when label ``i`` differs from label
    ``i + 1``. The last position compares the last label with the first (wrap-around),
    e.g. [1, 0, 0, 0, 1] -> [1, 0, 0, 1, 0].
    Empty input yields an empty array (this previously raised IndexError).
    """
    labels = np.asarray(input_seq)
    output_seq = np.zeros(len(labels))
    if len(labels) == 0:
        return output_seq
    output_seq[:-1] = labels[1:] != labels[:-1]
    output_seq[-1] = labels[-1] != labels[0]
    return output_seq

#### GreedyCAS-plus (with NN search)

In [None]:
paper_list = range(0, len(auto_data))

chunk_size = 100
batch_size = 12
num_epochs = 5
seed_size = int(chunk_size / batch_size)
pk_window = 2   # window size k passed to Pk / WindowDiff
SEED = 42

# Seed numpy's global RNG so that seed-paper sampling and the random (re)initialisations
# in get_all_splits_auto / circled_greedy_split_parallel are reproducible.
np.random.seed(SEED)

all_chunks = [paper_list[i:i + chunk_size] for i in range(0, len(paper_list), chunk_size)]

print("chunking finished! \n")

# Flat per-paper metric lists and one best NMI per batch. The old nested per-chunk/per-batch
# lists made np.mean fail whenever batches had different sizes.
jaccard_dist_all = []
pk_score_all = []
rouge_score_all = []
windiff_score_all = []
mi_score_all = []

# Each chunk is processed exactly once; the old `while paper_list` wrapper always made a single pass.
remaining = len(paper_list)
for chunk_counter, current_chunk in enumerate(all_chunks):
    print('===========================================================================')
    print(f'remaining abstracts to slice:{remaining}')
    print(f"current chunk processing: {chunk_counter}")

    # --- Build batches: each seed paper followed by its nearest neighbours in the chunk. ---
    chunk_ids = list(current_chunk)
    # The last chunk can be smaller than seed_size; sampling without replacement would fail.
    seed_papers = np.random.choice(chunk_ids, min(seed_size, len(chunk_ids)), replace=False)
    # NOTE(review): neighbours are retrieved by similarity to each paper's gold 'conclusion'
    # rather than its abstract. Confirm this is intended, since it lets labels influence batching.
    target_abstracts = [" ".join(auto_data[paper]['conclusion']) for paper in chunk_ids]

    chunk_paper_ids = []
    for paper in seed_papers:
        paper = int(paper)
        source_abstract = " ".join(auto_data[paper]['abstract'])
        _, I = SentBert_model.rank_sentences(source_abstract, target_abstracts, top_n=batch_size, reverse=False)
        neighbours = [chunk_ids[idx] for idx in I]
        # Put the seed first without duplicating it. The former
        # `chunk_idx.tolist().insert(0, paper)` mutated a temporary list and had no effect.
        chunk_paper_ids.append([paper] + [p for p in neighbours if p != paper])

    remaining -= len(chunk_ids)

    chunk_pk_scores = []
    chunk_windiff_scores = []
    chunk_jaccard_scores = []
    chunk_rouge_scores = []

    for batch_idx, batch_paper_ids in enumerate(chunk_paper_ids):

        # --- Reference boundaries and conclusions, keyed by paper id. ---
        ref_bounds_batch = dict()
        ref_conclusions = dict()
        for paper_id in batch_paper_ids:
            try:
                true_conclusion = auto_data[paper_id]['conclusion']
                num_sent = len(auto_data[paper_id]['abstract'])
            except KeyError:
                continue
            # Pk / WindowDiff need at least `pk_window` sentences.
            if num_sent < pk_window:
                continue
            true_bound = np.zeros(num_sent)
            # The gold conclusion is the trailing sentences. Clamping the start means an empty
            # conclusion marks nothing; a bare `[-0:]` slice would mark the whole abstract.
            true_bound[max(num_sent - len(true_conclusion), 0):] = 1
            ref_bounds_batch[paper_id] = get_segment_boundary(true_bound)
            ref_conclusions[paper_id] = true_conclusion

        # --- Greedy MI search, with a random restart whenever an epoch fails to improve. ---
        all_splits, curr_splits = get_all_splits_auto(batch_paper_ids, auto_data)
        best_mi = 0
        best_splits = dict()
        for epoch in range(num_epochs):
            splits, mi, _ = circled_greedy_split_parallel(all_splits, curr_splits, alpha="-inf", normalized=True)
            if mi > best_mi:
                best_mi = mi
                best_splits = splits
                curr_splits = splits
            else:
                all_splits, curr_splits = get_all_splits_auto(batch_paper_ids, auto_data)
        # Record the batch's final score once; every intermediate improvement used to be appended.
        mi_score_all.append(best_mi)

        # --- Evaluate papers that have both a hypothesis split and a reference. ---
        # hyp_bound: 1 marks a boundary sentence of the best configuration.
        # Keying by paper id keeps hypotheses and references aligned.
        batch_pk_scores = []
        batch_windiff_scores = []
        batch_jaccard_scores = []
        batch_rouge_scores = []
        for key in batch_paper_ids:
            if key not in best_splits or key not in ref_bounds_batch:
                continue
            split = best_splits[key]
            start, end = ast.literal_eval(split['sep_points'])
            hyp_bound = np.zeros(len(auto_data[key]['abstract']))
            hyp_bound[start - 1] = 1
            hyp_bound[end - 1] = 1
            hyp_str = "".join(str(int(item)) for item in hyp_bound)
            ref_str = "".join(str(int(item)) for item in ref_bounds_batch[key])
            # nltk's signature is pk(ref, hyp, k). With equal lengths and a fixed k the score
            # is symmetric, so this matches the earlier (hyp, ref) call order.
            batch_pk_scores.append(pk(ref_str, hyp_str, pk_window))
            batch_windiff_scores.append(windowdiff(ref_str, hyp_str, pk_window))

            ref_conclusion = ref_conclusions[key]
            hyp_conclusion = split['conclusion']
            batch_jaccard_scores.append(jaccard_distance(ref_conclusion, hyp_conclusion))
            rouge = scorer.score(" ".join(ref_conclusion), " ".join(hyp_conclusion))
            batch_rouge_scores.append((rouge['rouge1'][2] + rouge['rouge2'][2] + rouge['rougeLsum'][2]) / 3)

        print(f"averaged Pk on the {batch_idx}.th batch: {np.mean(batch_pk_scores)}")
        chunk_pk_scores.extend(batch_pk_scores)
        print(f"averaged wd on the {batch_idx}.th batch: {np.mean(batch_windiff_scores)}")
        chunk_windiff_scores.extend(batch_windiff_scores)
        print(f"averaged Jaccard on the {batch_idx}.th batch: {np.mean(batch_jaccard_scores)}")
        chunk_jaccard_scores.extend(batch_jaccard_scores)
        print(f"averaged Rouge on the {batch_idx}.th batch: {np.mean(batch_rouge_scores)}")
        chunk_rouge_scores.extend(batch_rouge_scores)

    print(f"averaged Pk of the {chunk_counter}.th chunk: {np.mean(chunk_pk_scores)}")
    print(f"averaged wd of the {chunk_counter}.th chunk: {np.mean(chunk_windiff_scores)}")
    print(f"averaged Jaccard of the {chunk_counter}.th chunk: {np.mean(chunk_jaccard_scores)}")
    print(f"averaged Rouge of the {chunk_counter}.th chunk: {np.mean(chunk_rouge_scores)}")
    pk_score_all.extend(chunk_pk_scores)
    windiff_score_all.extend(chunk_windiff_scores)
    jaccard_dist_all.extend(chunk_jaccard_scores)
    rouge_score_all.extend(chunk_rouge_scores)

print(f"overall Pk score for GreedyCAS-plus: {np.mean(pk_score_all)}")
print(f"overall wd scores for GreedyCAS-plus: {np.mean(windiff_score_all)}")
print(f"overall Jaccard index for GreedyCAS-plus: {1 - np.mean(jaccard_dist_all)}")
print(f"overall Rouge scores for GreedyCAS-plus: {np.mean(rouge_score_all)}")
print(f"overall NMI scores for GreedyCAS-plus: {np.mean(mi_score_all)}")

#### GreedyCAS-base (no NN search)

In [None]:
paper_list = range(0, len(auto_data))

chunk_size = 100
batch_size = 11
num_epochs = 5
pk_window = 2   # window size k passed to Pk / WindowDiff
SEED = 42

# Seed numpy's global RNG so the random (re)initialisations in get_all_splits_auto /
# circled_greedy_split_parallel are reproducible.
np.random.seed(SEED)

# (`list()c` here was a syntax error that stopped the cell from running at all.)
all_chunks = [paper_list[i:i + chunk_size] for i in range(0, len(paper_list), chunk_size)]

print("chunking finished! \n")

# One entry per chunk: the mean of that chunk's per-batch means. count_histo_all collects
# the histograms from each batch's best search step.
jaccard_dist_all = []
pk_score_all = []
rouge_score_all = []
windiff_score_all = []
mi_score_all = []
count_histo_all = []

# Each chunk is processed exactly once; the old `while paper_list` wrapper always made a single pass.
remaining = len(paper_list)
for chunk_counter, current_chunk in enumerate(all_chunks):
    print('===========================================================================')
    print(f'remaining abstracts to slice:{remaining}')
    print(f"current chunk processing: {chunk_counter}")

    # Consecutive fixed-size batches (no nearest-neighbour search).
    chunk_paper_ids = [list(current_chunk[i:i + batch_size]) for i in range(0, len(current_chunk), batch_size)]
    remaining -= len(current_chunk)

    chunk_pk_scores = []
    chunk_windiff_scores = []
    chunk_jaccard_scores = []
    chunk_rouge_scores = []
    chunk_mi_scores = []

    for batch_idx, batch_paper_ids in enumerate(chunk_paper_ids):

        # --- Reference boundaries and conclusions, keyed by paper id. ---
        ref_bounds_batch = dict()
        ref_conclusions = dict()
        for paper_id in batch_paper_ids:
            try:
                true_conclusion = auto_data[paper_id]['conclusion']
                num_sent = len(auto_data[paper_id]['abstract'])
            except KeyError:
                continue
            # Pk / WindowDiff need at least `pk_window` sentences.
            if num_sent < pk_window:
                continue
            true_bound = np.zeros(num_sent)
            # The gold conclusion is the trailing sentences. Clamping the start means an empty
            # conclusion marks nothing; a bare `[-0:]` slice would mark the whole abstract.
            true_bound[max(num_sent - len(true_conclusion), 0):] = 1
            ref_bounds_batch[paper_id] = get_segment_boundary(true_bound)
            ref_conclusions[paper_id] = true_conclusion

        # --- Greedy MI search, with a random restart whenever an epoch fails to improve. ---
        all_splits, curr_splits = get_all_splits_auto(batch_paper_ids, auto_data)
        best_mi = 0
        best_splits = dict()
        best_histo = []
        for epoch in range(num_epochs):
            splits, mi, histo = circled_greedy_split_parallel(all_splits, curr_splits, alpha="-inf", normalized=True)
            if mi > best_mi:
                best_mi = mi
                best_histo = histo
                best_splits = splits
                curr_splits = splits
            else:
                all_splits, curr_splits = get_all_splits_auto(batch_paper_ids, auto_data)

        # --- Evaluate papers that have both a hypothesis split and a reference. ---
        # hyp_bound: 1 marks a boundary sentence of the best configuration.
        # Keying by paper id keeps hypotheses and references aligned (the old zip could pair
        # one paper's hypothesis with another paper's reference).
        batch_pk_scores = []
        batch_windiff_scores = []
        batch_jaccard_scores = []
        batch_rouge_scores = []
        for key in batch_paper_ids:
            if key not in best_splits or key not in ref_bounds_batch:
                continue
            split = best_splits[key]
            start, end = ast.literal_eval(split['sep_points'])
            hyp_bound = np.zeros(len(auto_data[key]['abstract']))
            hyp_bound[start - 1] = 1
            hyp_bound[end - 1] = 1
            hyp_str = "".join(str(int(item)) for item in hyp_bound)
            ref_str = "".join(str(int(item)) for item in ref_bounds_batch[key])
            # nltk's signature is pk(ref, hyp, k); with equal lengths and a fixed k it is symmetric.
            batch_pk_scores.append(pk(ref_str, hyp_str, pk_window))
            batch_windiff_scores.append(windowdiff(ref_str, hyp_str, pk_window))

            ref_conclusion = ref_conclusions[key]
            hyp_conclusion = split['conclusion']
            batch_jaccard_scores.append(jaccard_distance(ref_conclusion, hyp_conclusion))
            rouge = scorer.score(" ".join(ref_conclusion), " ".join(hyp_conclusion))
            batch_rouge_scores.append((rouge['rouge1'][2] + rouge['rouge2'][2] + rouge['rougeLsum'][2]) / 3)

        avg_batch_pk = np.mean(batch_pk_scores)
        avg_batch_windiff = np.mean(batch_windiff_scores)
        avg_batch_jaccard = np.mean(batch_jaccard_scores)
        avg_batch_rouge = np.mean(batch_rouge_scores)

        print(f"averaged Pk on the {batch_idx}.th batch: {avg_batch_pk}")
        chunk_pk_scores.append(avg_batch_pk)
        print(f"averaged wd on the {batch_idx}.th batch: {avg_batch_windiff}")
        chunk_windiff_scores.append(avg_batch_windiff)
        print(f"averaged Jaccard on the {batch_idx}.th batch: {avg_batch_jaccard}")
        chunk_jaccard_scores.append(avg_batch_jaccard)
        print(f"averaged Rouge on the {batch_idx}.th batch: {avg_batch_rouge}")
        chunk_rouge_scores.append(avg_batch_rouge)
        print(f"averaged MI on the {batch_idx}.th batch: {best_mi}")
        chunk_mi_scores.append(best_mi)

        count_histo_all.extend(best_histo)

    print(f"averaged Pk of the {chunk_counter}.th chunk: {np.mean(chunk_pk_scores)}")
    print(f"averaged wd of the {chunk_counter}.th chunk: {np.mean(chunk_windiff_scores)}")
    print(f"averaged Jaccard of the {chunk_counter}.th chunk: {np.mean(chunk_jaccard_scores)}")
    print(f"averaged Rouge of the {chunk_counter}.th chunk: {np.mean(chunk_rouge_scores)}")
    print(f"averaged MI of the {chunk_counter}.th chunk: {np.mean(chunk_mi_scores)}")

    pk_score_all.append(np.mean(chunk_pk_scores))
    windiff_score_all.append(np.mean(chunk_windiff_scores))
    jaccard_dist_all.append(np.mean(chunk_jaccard_scores))
    rouge_score_all.append(np.mean(chunk_rouge_scores))
    # Only chunk means go here. Per-improvement scores used to be appended as well, mixing scales.
    mi_score_all.append(np.mean(chunk_mi_scores))

print(f"overall pk scores for GreedyCAS-base: {np.mean(pk_score_all)}")
print(f"overall wd scores for GreedyCAS-base: {np.mean(windiff_score_all)}")
print(f"overall Jaccard index for GreedyCAS-base: {1 - np.mean(jaccard_dist_all)}")
print(f"overall Rouge scores for GreedyCAS-base: {np.mean(rouge_score_all)}")
print(f"overall NMI scores for GreedyCAS-base: {np.mean(mi_score_all)}")

#### Random-base

In [None]:
# Random-base: cut each abstract at two random sentence positions; the middle
# segment is labelled "conclusion" if it overlaps the abstract's last few
# sentences, otherwise everything outside it is the conclusion.
paper_list = range(0, len(auto_data))

all_splits, _ = get_all_splits_auto(paper_list, auto_data)

ref_bounds_all = dict()
hyp_bounds_all = dict()

jaccard_dist_all = []
rouge_score_all = []
pk_score_all = []
windiff_score_all = []

# Number of trailing sentences a segment must touch to be called the conclusion.
n_tail_sents = 3

for paper_id in tqdm(paper_list):
    try:
        true_conclusion = auto_data[paper_id]['conclusion']
        abstract_sents = auto_data[paper_id]['abstract']
        true_premise = [sent for sent in abstract_sents if sent not in true_conclusion]

        # Reference: the trailing len(conclusion) sentences are the conclusion.
        # An explicit, clamped start avoids `[-0:]` marking everything when the
        # conclusion list is empty.
        true_bound = np.zeros(len(abstract_sents))
        true_bound[max(0, len(abstract_sents) - len(true_conclusion)):] = 1
        ref_bound = get_segment_boundary(true_bound)

        # Keep document order (premises first, conclusion last) so the split
        # indices line up with positions in `hyp_bound` and the tail heuristic
        # below really looks at the end of the abstract.
        true_abstract = true_premise + true_conclusion

        # Two distinct cut points are required; np.random.choice with
        # replace=False raises ValueError otherwise.
        if len(true_abstract) < 2:
            continue

        # Re-seeded per paper so each paper's split is reproducible on its own.
        np.random.seed(1)
        random_split_boundary = sorted(
            int(idx)
            for idx in np.random.choice(np.arange(1, len(true_abstract) + 1), size=2, replace=False)
        )

        middle_segment = true_abstract[random_split_boundary[0]:random_split_boundary[1]]
        random_split_segment = dict()
        if any(sent in true_abstract[-n_tail_sents:] for sent in middle_segment):
            random_split_segment['conclusion'] = middle_segment
            random_split_segment['premise'] = [sent for sent in true_abstract if sent not in middle_segment]
        else:
            random_split_segment['premise'] = middle_segment
            random_split_segment['conclusion'] = [sent for sent in true_abstract if sent not in middle_segment]
        hyp_conclusion = random_split_segment['conclusion']

        hyp_bound = np.zeros(len(abstract_sents))
        try:
            hyp_bound[random_split_boundary[0] - 1] = 1
            hyp_bound[random_split_boundary[1] - 1] = 1
        except IndexError:
            # Presumably `true_abstract` ended up longer than the stored
            # abstract; report and skip so that every metric below is computed
            # over the same set of papers.
            print(all_splits[paper_id])
            continue

        jaccard_dist_all.append(jaccard_distance(hyp_conclusion, true_conclusion))

        # Mean of the ROUGE-1 / ROUGE-2 / ROUGE-Lsum F-measures.
        rouge_result = scorer.score(" ".join(true_conclusion), " ".join(hyp_conclusion))
        rouge_score_all.append(
            (rouge_result['rouge1'].fmeasure
             + rouge_result['rouge2'].fmeasure
             + rouge_result['rougeLsum'].fmeasure) / 3
        )

        # Store reference and hypothesis together so the Pk/WD loop never sees
        # a reference without its hypothesis.
        ref_bounds_all[paper_id] = ref_bound
        hyp_bounds_all[paper_id] = hyp_bound

    except KeyError:
        continue


# Pk / WindowDiff with window k=2 over boundary strings such as "00101".
for key, ref_seg in ref_bounds_all.items():
    ref_str = "".join(str(int(item)) for item in ref_seg)
    hyp_str = "".join(str(int(item)) for item in hyp_bounds_all[key])
    # NLTK signatures: pk(ref, hyp, k) and windowdiff(seg1, seg2, k).
    pk_score_all.append(pk(ref_str, hyp_str, 2))
    windiff_score_all.append(windowdiff(ref_str, hyp_str, 2))


print(f"avg. pk scores for random fetch: {np.mean(pk_score_all)}")
print(f"avg. wd scores for random fetch: {np.mean(windiff_score_all)}")
print(f"avg. Jaccard index for random fetch: {1 - np.mean(jaccard_dist_all)}")
print(f"avg. Rouge scores for random fetch: {np.mean(rouge_score_all)}")

#### Random-plus

In [None]:
# Random-plus: for each paper, pick one of the candidate (premise, conclusion)
# configurations produced by `get_all_splits_auto` at random.
paper_list = range(0, len(auto_data))

all_splits, _ = get_all_splits_auto(paper_list, auto_data)

ref_bounds_all = dict()
hyp_bounds_all = dict()

jaccard_dist_all = []
rouge_score_all = []
pk_score_all = []
windiff_score_all = []

indexed_splits = dict()

for paper_id in tqdm(paper_list):
    try:
        true_conclusion = auto_data[paper_id]['conclusion']
        abstract_sents = auto_data[paper_id]['abstract']

        # Reference: trailing conclusion sentences; clamped explicit start
        # avoids `[-0:]` marking everything for an empty conclusion.
        true_bound = np.zeros(len(abstract_sents))
        true_bound[max(0, len(abstract_sents) - len(true_conclusion)):] = 1
        ref_bound = get_segment_boundary(true_bound)

        candidate_splits = list(all_splits[paper_id].items())
        if not candidate_splits:
            # random.choice would raise IndexError on an empty list.
            continue

        # Re-seeded per paper so the pick does not depend on iteration order.
        random.seed(1)
        random_split_boundary, random_split_segment = random.choice(candidate_splits)

        hyp_conclusion = random_split_segment['conclusion']
        indexed_splits[paper_id] = {'premise': list(random_split_segment['premise']),
                                    'conclusion': list(hyp_conclusion)}

        hyp_bound = np.zeros(len(abstract_sents))
        try:
            # Split keys are string literals; parse once and read the two cut
            # points (1-based sentence positions, hence the -1).
            cut_points = ast.literal_eval(random_split_boundary)
            hyp_bound[cut_points[0] - 1] = 1
            hyp_bound[cut_points[1] - 1] = 1
        except IndexError:
            print(all_splits[paper_id])
            continue

        jaccard_dist_all.append(jaccard_distance(hyp_conclusion, true_conclusion))

        # Mean of the ROUGE-1 / ROUGE-2 / ROUGE-Lsum F-measures.
        rouge_result = scorer.score(" ".join(true_conclusion), " ".join(hyp_conclusion))
        rouge_score_all.append(
            (rouge_result['rouge1'].fmeasure
             + rouge_result['rouge2'].fmeasure
             + rouge_result['rougeLsum'].fmeasure) / 3
        )

        # Store reference and hypothesis together: previously a KeyError from
        # `all_splits` left a reference without hypothesis and crashed below.
        ref_bounds_all[paper_id] = ref_bound
        hyp_bounds_all[paper_id] = hyp_bound

    except KeyError:
        continue

# Pk / WindowDiff with window k=2 over boundary strings such as "00101".
for key, ref_seg in ref_bounds_all.items():
    ref_str = "".join(str(int(item)) for item in ref_seg)
    hyp_str = "".join(str(int(item)) for item in hyp_bounds_all[key])
    # NLTK signatures: pk(ref, hyp, k) and windowdiff(seg1, seg2, k).
    pk_score_all.append(pk(ref_str, hyp_str, 2))
    windiff_score_all.append(windowdiff(ref_str, hyp_str, 2))

print(f"avg. pk scores for random fetch: {np.mean(pk_score_all)}")
print(f"avg. wd scores for random fetch: {np.mean(windiff_score_all)}")
print(f"avg. Jaccard index for random fetch: {1 - np.mean(jaccard_dist_all)}")
print(f"avg. Rouge scores for random fetch: {np.mean(rouge_score_all)}")

#### SBERT-sim

In [None]:
### SBERT-sim baseline: choose the configuration where the semantic similarity between premises and conclusions is the greatest

paper_list = range(0, len(auto_data))

all_splits, _ = get_all_splits_auto(paper_list, auto_data)

ref_bounds_all = dict()
hyp_bounds_all = dict()

# Like the other baselines, this holds Jaccard *distances*; the index is
# reported as 1 - mean at the end.
jaccard_dist_all = []
rouge_score_all = []
pk_score_all = []
windiff_score_all = []

for paper_id in tqdm(paper_list):
    try:
        true_conclusion = auto_data[paper_id]['conclusion']
        abstract_sents = auto_data[paper_id]['abstract']

        # Reference: trailing conclusion sentences; clamped explicit start
        # avoids `[-0:]` marking everything for an empty conclusion.
        true_bound = np.zeros(len(abstract_sents))
        true_bound[max(0, len(abstract_sents) - len(true_conclusion)):] = 1
        ref_bound = get_segment_boundary(true_bound)

        candidate_splits = list(all_splits[paper_id].items())
        if not candidate_splits:
            # np.argmax would raise ValueError on an empty score list.
            continue

        # Score each candidate by the dot product of its premise and conclusion
        # embeddings (presumably unit-length after `normalize_embeddings`, so
        # this is cosine similarity -- confirm against SentBert_model).
        sim_scores = []
        for split_boundary, split in candidate_splits:
            premise_embedding = SentBert_model.normalize_embeddings(SentBert_model.encode([" ".join(split['premise'])]))[0]
            conclusion_embedding = SentBert_model.normalize_embeddings(SentBert_model.encode([" ".join(split['conclusion'])]))[0]
            sim_scores.append(np.dot(premise_embedding, conclusion_embedding))

        # np.argmax keeps the first candidate on ties.
        best_boundary, best_split = candidate_splits[int(np.argmax(sim_scores))]
        hyp_conclusion = best_split['conclusion']

        hyp_bound = np.zeros(len(abstract_sents))
        try:
            # Split keys are string literals; parse once and read the two cut
            # points (1-based sentence positions, hence the -1).
            cut_points = ast.literal_eval(best_boundary)
            hyp_bound[cut_points[0] - 1] = 1
            hyp_bound[cut_points[1] - 1] = 1
        except IndexError:
            print(all_splits[paper_id])
            continue

        jaccard_dist_all.append(jaccard_distance(hyp_conclusion, true_conclusion))

        # Mean of the ROUGE-1 / ROUGE-2 / ROUGE-Lsum F-measures.
        rouge_result = scorer.score(" ".join(true_conclusion), " ".join(hyp_conclusion))
        rouge_score_all.append(
            (rouge_result['rouge1'].fmeasure
             + rouge_result['rouge2'].fmeasure
             + rouge_result['rougeLsum'].fmeasure) / 3
        )

        # Store reference and hypothesis together so the Pk/WD loop never sees
        # a reference without its hypothesis.
        ref_bounds_all[paper_id] = ref_bound
        hyp_bounds_all[paper_id] = hyp_bound

    except KeyError:
        continue

# Pk / WindowDiff with window k=2 over boundary strings such as "00101".
for key, ref_seg in ref_bounds_all.items():
    ref_str = "".join(str(int(item)) for item in ref_seg)
    hyp_str = "".join(str(int(item)) for item in hyp_bounds_all[key])
    # NLTK signatures: pk(ref, hyp, k) and windowdiff(seg1, seg2, k).
    pk_score_all.append(pk(ref_str, hyp_str, 2))
    windiff_score_all.append(windowdiff(ref_str, hyp_str, 2))

print(f"avg. pk scores for SBERT: {np.mean(pk_score_all)}")
print(f"avg. wd scores for SBERT: {np.mean(windiff_score_all)}")
print(f"avg. Jaccard index for SBERT: {1 - np.mean(jaccard_dist_all)}")
print(f"avg. Rouge scores for SBERT: {np.mean(rouge_score_all)}")

#### TextTiling 

In [None]:
### TextTiling baseline: split abstract based on lexical cohesion between segments

paper_list = range(0, len(auto_data))

all_splits, _ = get_all_splits_auto(paper_list, auto_data)

ref_bounds_all = dict()
hyp_bounds_all = dict()

ht = HarvestText(language="en")

jaccard_dist_all = []
rouge_score_all = []
pk_score_all = []
windiff_score_all = []

for paper_id in tqdm(paper_list):
    try:
        true_conclusion = auto_data[paper_id]['conclusion']
        abstract_sents = auto_data[paper_id]['abstract']
        true_premise = [sent for sent in abstract_sents if sent not in true_conclusion]

        true_abstract = true_premise + true_conclusion
        if not true_abstract:
            continue

        # Reference: trailing conclusion sentences; clamped explicit start
        # avoids `[-0:]` marking everything for an empty conclusion.
        true_bound = np.zeros(len(abstract_sents))
        true_bound[max(0, len(abstract_sents) - len(true_conclusion)):] = 1
        ref_bound = get_segment_boundary(true_bound)

        predicted_paras = ht.cut_paragraphs("\n\n".join(true_abstract), num_paras=2)

        # 1 = sentence occurs in the first predicted paragraph, 0 = it does not.
        hyp_bound = [1 if sent in predicted_paras[0] else 0 for sent in true_abstract]
        if 0 not in hyp_bound:
            # Everything landed in the first paragraph: no predicted conclusion.
            # Skip instead of crashing on hyp_bound.index(0).
            continue

        # The index comes from `true_abstract`, so slice that same list.
        hyp_conclusion = true_abstract[hyp_bound.index(0):]

        jaccard_dist_all.append(jaccard_distance(hyp_conclusion, true_conclusion))

        # Mean of the ROUGE-1 / ROUGE-2 / ROUGE-Lsum F-measures.
        rouge_result = scorer.score(" ".join(true_conclusion), " ".join(hyp_conclusion))
        rouge_score_all.append(
            (rouge_result['rouge1'].fmeasure
             + rouge_result['rouge2'].fmeasure
             + rouge_result['rougeLsum'].fmeasure) / 3
        )

        # Turn membership labels into boundary markers: a 1 after every
        # sentence where the label changes, plus a closing marker on the last
        # sentence when it is labelled differently from the first.
        hyp_bound_final = np.zeros(len(abstract_sents))
        for i in range(len(hyp_bound) - 1):
            if hyp_bound[i] != hyp_bound[i + 1]:
                hyp_bound_final[i] = 1
        if hyp_bound[-1] != hyp_bound[0]:
            hyp_bound_final[-1] = 1

        # Store reference and hypothesis together so the Pk/WD loop never sees
        # a reference without its hypothesis.
        ref_bounds_all[paper_id] = ref_bound
        hyp_bounds_all[paper_id] = hyp_bound_final

    except KeyError:
        continue

# Pk / WindowDiff with window k=2 over boundary strings such as "00101".
for key, ref_seg in ref_bounds_all.items():
    ref_str = "".join(str(int(item)) for item in ref_seg)
    hyp_str = "".join(str(int(item)) for item in hyp_bounds_all[key])
    # NLTK signatures: pk(ref, hyp, k) and windowdiff(seg1, seg2, k).
    pk_score_all.append(pk(ref_str, hyp_str, 2))
    windiff_score_all.append(windowdiff(ref_str, hyp_str, 2))

print(f"avg. pk scores for TextTiling: {np.mean(pk_score_all)}")
print(f"avg. wd scores for TextTiling: {np.mean(windiff_score_all)}")
print(f"avg. Jaccard index  for TextTiling: {1 - np.mean(jaccard_dist_all)}")
print(f"avg. Rouge scores for TextTiling: {np.mean(rouge_score_all)}")