In [None]:
import os

# Make the repository root the working directory so the relative "./datasets/..."
# paths used by later cells resolve. Presumably this notebook sits two levels below
# the root (hence "../../") - TODO confirm. The move is skipped when the kernel is
# already inside the repository folder.
if os.path.basename(os.getcwd()) != 'HUST-NLP-Medical-MultiDocument-Summarization-':
    %cd ../../

In [10]:
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import LEDForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
from tqdm.notebook import tqdm

In [None]:
# Seed constant for the notebook (no cell in view consumes it yet).
RANDOM_SEED = 42

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bare expression so the cell output shows which device was picked.
device

In [None]:
# Hugging Face checkpoint identifier; the model cell loads the same checkpoint.
PATH = 'allenai/led-base-16384'
tokenizer = AutoTokenizer.from_pretrained(PATH)

# Register the document-separator marker as an additional special token.
# The literal must stay identical to DOC_SEP_, which is defined in a later cell.
special_tokens_dict = dict(additional_special_tokens=['<doc-sep>'])
# Left as the last expression so the cell shows the call's return value.
tokenizer.add_special_tokens(special_tokens_dict)

In [None]:
# Load the seq2seq checkpoint named by PATH.
model = AutoModelForSeq2SeqLM.from_pretrained(PATH)

# The tokenizer was extended with an extra special token, so resize the embedding
# table to the tokenizer's current length. Kept as the last expression so the cell
# still displays the returned embedding module.
model.resize_token_embeddings(new_num_tokens=len(tokenizer))

In [None]:
# Separator string placed between (and after) the texts of one review when
# PT_Medical_Dataset builds its joined input; same literal that was registered
# as a special token on the tokenizer.
DOC_SEP_ = "<doc-sep>"
# Vocabulary id the tokenizer assigns to the separator token.
docsep_token_id = tokenizer.convert_tokens_to_ids(DOC_SEP_)

In [None]:
# Install the metric packages into the kernel's environment.
# NOTE(review): versions are not pinned, so results may vary between installs.
%pip install evaluate
%pip install rouge-score
import evaluate
# Module-level ROUGE metric object; PT_Medical_Dataset.calc_target_cluster calls
# rouge.compute(...) on it for every pair of texts.
rouge = evaluate.load('rouge')

In [None]:
# Load the per-study rows of the MS2 training inputs. The code below relies on the
# columns ReviewID, Title and Abstract.
# NOTE(review): the variable is named "cochrane" but the file read is the ms2 split.
cochrane_train_input = pd.read_csv("./datasets/mslr_data/ms2/train-inputs.csv")

# Replace missing abstracts with empty strings. The result is assigned back instead
# of calling fillna(..., inplace=True) on the column selection: that chained in-place
# form operates on a temporary object under pandas Copy-on-Write (default from
# pandas 3) and would silently leave the NaNs in the frame.
cochrane_train_input["Abstract"] = cochrane_train_input["Abstract"].fillna("")
# NOTE(review): Title is not filled here; a missing title would reach the ROUGE
# computation as NaN - confirm the column has no missing values.

# Number of rows per review. Because the abstracts were filled above, count()
# (which skips missing values) sees every row of the review.
abstract_count_agg = cochrane_train_input.groupby('ReviewID')['Abstract'].count().reset_index(name='Abstract_Count')

# Build one flat list per review: [title_1, abstract_1, title_2, abstract_2, ...].
cochrane_train_input['Combined_Abstract_Title'] = cochrane_train_input.apply(lambda row: [row['Title'],row['Abstract']], axis=1)
cochrane_train_input = cochrane_train_input.groupby('ReviewID')['Combined_Abstract_Title'].agg(list).reset_index()
cochrane_train_input['Abstracts'] = cochrane_train_input['Combined_Abstract_Title'].apply(
    lambda pairs: [text for pair in pairs for text in pair]  # flatten the [title, abstract] pairs
)
cochrane_train_input = cochrane_train_input.drop(columns=['Combined_Abstract_Title'])

# Attach the per-review row count and order the reviews by it (ascending), with a
# fresh 0..n-1 index so the label-based chunk slicing in a later cell lines up.
cochrane_train_input = pd.merge(cochrane_train_input, abstract_count_agg, on='ReviewID', how='inner')
cochrane_train_input = cochrane_train_input.sort_values(by='Abstract_Count', ignore_index=True)

# Total number of reviews; later cells use it as the upper bound of their chunk loops.
hehe = cochrane_train_input.shape[0]

In [16]:
class PT_Medical_Dataset(Dataset):
    """Build (input, target) text pairs for pretraining from grouped review texts.

    For every kept row, the text with the highest score from
    ``calc_target_cluster`` becomes the ``Target``. The ``Abstracts`` cell is
    replaced by the remaining texts joined with ``DOC_SEP_``, with one trailing
    ``DOC_SEP_`` appended. All work happens in ``__init__``; the result is the
    ``data`` DataFrame. The class defines ``__len__`` but no ``__getitem__``.
    """

    def __init__(self,tokenizer:AutoTokenizer,train_data):
        """Filter ``train_data`` and compute the target and joined input per row.

        Args:
            tokenizer: stored on the instance as ``self.tokenizer``; not used
                anywhere else in this class.
            train_data: DataFrame with an ``Abstracts`` column holding a list of
                texts per row. Rows with fewer than 3 entries are dropped. The
                frame itself is not modified; a filtered copy is used.
        """
        long_enough = train_data['Abstracts'].apply(len) >= 3
        self.data = train_data.loc[long_enough].copy()
        self.data.reset_index(drop=True,inplace=True)
        self.tokenizer = tokenizer
        self.data["Target"] = ''

        for row in tqdm(range(len(self))):
            texts = self.data.loc[row,"Abstracts"]
            centrality = self.calc_target_cluster(texts)
            # The text with the highest score becomes the generation target.
            target = texts[centrality.argmax()]
            self.data.loc[row,'Target'] = target
            # Comparison is by value, so every text equal to the target is
            # dropped from the input, including verbatim duplicates.
            remaining = [text for text in texts if text != target]
            self.data.at[row,'Abstracts'] = DOC_SEP_.join(remaining) + DOC_SEP_

    def __len__(self):
        """Return the number of rows kept after filtering."""
        return self.data.shape[0]

    def calc_target_cluster(self,sentence):
        """Score each text by its average ROUGE against the texts of the same row.

        For every unordered pair, the three values returned for rouge1, rouge2
        and rougeL are averaged and stored symmetrically. A text's score is its
        row sum divided by the total number of texts (the zero diagonal is
        included in that divisor).

        Args:
            sentence: sequence of texts (despite the singular name).

        Returns:
            1-D numpy array with one score per text, in input order.
        """
        count = len(sentence)
        pairwise_rouge = np.zeros((count,count))
        for first in range(count):
            for second in range(first+1,count):
                score = rouge.compute(
                    predictions=[sentence[first]],
                    references=[sentence[second]],
                    rouge_types=['rouge1','rouge2','rougeL'],
                )
                mean_score = sum(score.values())/3
                pairwise_rouge[first][second] = mean_score
                pairwise_rouge[second][first] = mean_score

        res = np.zeros(count)
        for first, row_scores in enumerate(pairwise_rouge):
            res[first] = sum(row_scores)/count
        return res

In [None]:
# Process the reviews in consecutive slices of 500 rows and write each processed
# slice to its own CSV, named after the slice's starting offset.
for num in range(0,hehe,500):
    # .loc slicing is label-based and inclusive, so num..num+499 is 500 rows
    # on the 0..n-1 index of cochrane_train_input.
    chunk_rows = cochrane_train_input.loc[num:num+499,:]
    print(chunk_rows.shape)

    # Constructing the dataset does the work: filtering, target selection, joining.
    train_dataset = PT_Medical_Dataset(tokenizer,chunk_rows)
    cochrane_train_input_2 = train_dataset.data

    cochrane_train_input_2.to_csv(f"./datasets/mslr_data/ms2/hehe{num}.csv",index=False)
    print(num)

In [18]:
# Start the merged table from the first chunk file (offset 0) written by the
# chunked processing loop; the next cell appends the remaining chunk files.
final = pd.read_csv("./datasets/mslr_data/ms2/hehe0.csv")

In [19]:
# Read every remaining chunk file (offsets 500, 1000, ... below hehe) and append
# them to `final`, in offset order, with a single concatenation and a fresh index.
remaining_chunks = [
    pd.read_csv(f"./datasets/mslr_data/ms2/hehe{num}.csv")
    for num in range(500,hehe,500)
]
final = pd.concat([final, *remaining_chunks],ignore_index=True)

In [20]:
# Persist the merged pretraining table.
# NOTE(review): unlike the per-chunk files, index=False is not passed here, so the
# row index is written as an extra unnamed first column - confirm that downstream
# readers expect it.
final.to_csv("./datasets/mslr_data/ms2/train-inputs-pretrain.csv")

In [None]:
# Bare expression: shows the merged table as the cell's output.
final