In [4]:
import os

from datasets import load_dataset, load_metric, DatasetDict
rouge = load_metric("rouge", trust_remote_code=True)
import textdistance

import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, LEDForConditionalGeneration
from collections import defaultdict, Counter
import nltk
from nltk.tokenize import word_tokenize, RegexpTokenizer
from nltk.corpus import stopwords

import sys
sys.path.append('./DMRST_Parser/')

from model_depth import ParsingNet
import numpy as np
import torch
import re
import copy
import string

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

## Load original data

In [2]:
def preprocess_docs(docs):
    """Split a multi-document string on '|||||' and return the usable documents.

    Each segment has blank-line breaks (a newline, optional whitespace, another
    newline, plus any whitespace before the first newline) replaced by a single
    space, followed by one pass that turns double spaces into single spaces.

    Args:
        docs: String holding one or more documents separated by '|||||'.

    Returns:
        List of cleaned document strings. Segments that are whitespace-only or
        shorter than 50 characters after cleaning are left out.
    """
    kept_docs = []
    for raw_segment in docs.split("|||||"):
        cleaned = re.sub(r'\s*\n\s*\n', ' ', str(raw_segment)).replace('  ', ' ')

        # Keep only segments that carry real content (non-blank, >= 50 chars).
        if cleaned.strip() != '' and len(cleaned) >= 50:
            kept_docs.append(cleaned)

    return kept_docs

In [3]:
# Load the 'multi_news' dataset; later cells read its 'train', 'validation' and 'test' splits.
original_dataset = load_dataset('multi_news', trust_remote_code=True)

## PRIMER model

In [None]:
# Tokenizer of the PRIMERA checkpoint; add_new_column_doc_sim uses it to count tokens per document.
PRIMER_path = 'allenai/PRIMERA'
tokenizer = AutoTokenizer.from_pretrained(PRIMER_path)

## EDU model

In [129]:
def inference(model, tokenizer, input_sentences, batch_size):
    """Tokenize the inputs and run the parser's ``TestingLoss`` on them in batches.

    Args:
        model: Object exposing ``TestingLoss(...)`` that returns five values; the
            third and fifth are collected here.
        tokenizer: Object exposing ``tokenize(text, add_special_tokens=False)``.
        input_sentences: Sequence of raw text strings.
        batch_size: Number of inputs passed to the model per call.

    Returns:
        Tuple ``(tokenized_inputs, segmentation_predictions, tree_predictions)``.
        The two prediction lists are the per-batch outputs concatenated in input
        order.
    """
    num_batches = int(np.ceil(len(input_sentences) / batch_size))
    tokenized = [tokenizer.tokenize(text, add_special_tokens=False) for text in input_sentences]
    total = len(tokenized)

    segmentation_preds = []
    tree_preds = []
    with torch.no_grad():
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            # The final batch may be shorter than batch_size.
            stop = min(start + batch_size, total)
            batch = tokenized[start:stop]

            _, _, span_batch, _, edu_breaks = model.TestingLoss(
                batch,
                input_EDU_breaks=None,
                LabelIndex=None,
                ParsingIndex=None,
                GenerateTree=True,
                use_pred_segmentation=True,
            )
            segmentation_preds.extend(edu_breaks)
            tree_preds.extend(span_batch)
    return tokenized, segmentation_preds, tree_preds


def inference_only_EDU_break(model, tokenizer, input_sentences, batch_size):
    """Tokenize the inputs and collect only the encoder's predicted EDU breaks.

    Unlike ``inference``, this calls ``model.encoder`` directly and keeps just
    the fourth value it returns.

    Args:
        model: Object exposing ``encoder(batch, None, is_test=True)`` that
            returns four values.
        tokenizer: Object exposing ``tokenize(text, add_special_tokens=False)``.
        input_sentences: Sequence of raw text strings.
        batch_size: Number of inputs passed to the encoder per call.

    Returns:
        Tuple ``(tokenized_inputs, segmentation_predictions)``.
    """
    num_batches = int(np.ceil(len(input_sentences) / batch_size))
    tokenized = [tokenizer.tokenize(text, add_special_tokens=False) for text in input_sentences]
    total = len(tokenized)

    segmentation_preds = []
    with torch.no_grad():
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            # The final batch may be shorter than batch_size.
            stop = min(start + batch_size, total)
            batch = tokenized[start:stop]

            # Only the predicted breaks are needed; the other outputs are discarded.
            _, _, _, edu_breaks = model.encoder(batch, None, is_test=True)
            segmentation_preds.extend(edu_breaks)
    return tokenized, segmentation_preds



def tokens_to_string(tokens):
    """Concatenate sub-word tokens into text.

    A token that starts with '▁' has all of its leading '▁' characters removed
    and is prefixed with a single space; every other token is appended as-is.
    """
    pieces = []
    for token in tokens:
        if token.startswith('▁'):
            pieces.append(' ' + token.lstrip('▁'))
        else:
            pieces.append(token)
    return ''.join(pieces)


def split_list_by_positions(lst, positions):
    """Cut a token list at the given inclusive end indices and join the pieces with '||'.

    Args:
        lst: List of tokens.
        positions: Iterable of indices; each one is the last token (inclusive)
            of a segment. Segments are consecutive, starting at index 0.

    Returns:
        The segments, each rendered with ``tokens_to_string``, joined by '||'.
        Tokens after the final position are not included.
    """
    segment_ends = [pos + 1 for pos in positions]
    # Every segment begins where the previous one stopped; the first begins at 0.
    segment_starts = [0] + segment_ends[:-1]
    return '||'.join(
        tokens_to_string(lst[begin:end])
        for begin, end in zip(segment_starts, segment_ends)
    )

In [130]:
# Device string passed to ParsingNet and used for the parser's initial placement in the model-loading cell.
EDU_device = "cpu"

In [None]:

# Path of the saved parser weights loaded below.
model_path = './DMRST_Parser/depth_mode/Savings/multi_all_checkpoint.torchsave'

# XLM-RoBERTa tokenizer and backbone handed to ParsingNet.
edu_parsing_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=True)
# NOTE(review): the backbone is placed on `device`, while ParsingNet is built with
# EDU_device and the parser is moved to EDU_device right after — confirm the mixed
# placement is intentional.
edu_parsing_model = AutoModel.from_pretrained("xlm-roberta-base").to(device)

model = ParsingNet(edu_parsing_model, bert_tokenizer=edu_parsing_tokenizer, device=EDU_device)

model = model.to(EDU_device)
# NOTE(review): the checkpoint is mapped to `device` although the model sits on
# EDU_device at this point — presumably map_location=EDU_device was meant; verify.
state_dict = torch.load(model_path, map_location=device)

model.load_state_dict(state_dict) 
# Switch to evaluation mode; the notebook only uses the parser for inference.
model = model.eval()

In [None]:
# Move the loaded parser (and, explicitly, its encoder) onto `device` for inference.
# NOTE(review): ParsingNet was constructed with device=EDU_device; confirm it does not
# keep using that value internally after this move.
model = model.to(device)
model.encoder = model.encoder.to(device)

## Make Dataset

In [8]:
import concurrent.futures # for parallel computation (not used by the cells shown in this notebook)

def process_data(sample, doc_i):
    """Segment every document of a sample into EDUs and attach the parser output.

    Args:
        sample: Mapping with a 'document' string (documents separated by
            '|||||') and a 'summary'.
        doc_i: Identifier stored under 'id' in the returned row.

    Returns:
        Tuple ``(row, skip)``. When ``preprocess_docs`` leaves no documents,
        ``row`` is an empty dict and ``skip`` is True. Otherwise ``row`` holds
        'document' (EDUs joined by '||', documents joined by ' ||||| '),
        'summary', 'parsing' (third value returned by ``inference``) and 'id',
        and ``skip`` is False.
    """
    docs_list = preprocess_docs(sample['document'])
    if not docs_list:
        # Nothing usable in this sample: tell the caller to skip it.
        return {}, True

    # One document per batch (batch size 1).
    tokenized_docs, edu_breaks, parse_trees = inference(model, edu_parsing_tokenizer, docs_list, 1)

    # Rebuild each document as text with '||' between its EDUs.
    docs_with_edus = [
        split_list_by_positions(doc_tokens, doc_breaks)
        for doc_tokens, doc_breaks in zip(tokenized_docs, edu_breaks)
    ]

    row = {
        'document': ' ||||| '.join(docs_with_edus),
        'summary': sample['summary'],
        'parsing': parse_trees,
        'id': doc_i,
    }
    return row, False

def process_data_only_EDU_breaks(sample, doc_i):
    """Segment every document of a sample into EDUs without producing parse trees.

    Args:
        sample: Mapping with a 'document' string (documents separated by
            '|||||') and a 'summary'.
        doc_i: Identifier stored under 'id' in the returned row.

    Returns:
        Tuple ``(row, skip)``. When ``preprocess_docs`` leaves no documents,
        ``row`` is an empty dict and ``skip`` is True. Otherwise ``row`` holds
        'document' (EDUs joined by '||', documents joined by ' ||||| '),
        'summary' and 'id', and ``skip`` is False.
    """
    docs_list = preprocess_docs(sample['document'])
    if not docs_list:
        # Nothing usable in this sample: tell the caller to skip it.
        return {}, True

    # Segmentation only, one document per batch (batch size 1).
    tokenized_docs, edu_breaks = inference_only_EDU_break(model, edu_parsing_tokenizer, docs_list, 1)

    # Rebuild each document as text with '||' between its EDUs.
    docs_with_edus = [
        split_list_by_positions(doc_tokens, doc_breaks)
        for doc_tokens, doc_breaks in zip(tokenized_docs, edu_breaks)
    ]

    row = {
        'document': ' ||||| '.join(docs_with_edus),
        'summary': sample['summary'],
        'id': doc_i,
    }
    return row, False

In [51]:
# 1. Process the training split.
# Each row is segmented into EDUs; rows for which preprocess_docs keeps no document
# are reported and skipped, so the resulting list can be shorter than the split.
# NOTE: the variable name is misspelled ("trian"); the save cell refers to it by this name.
trian_data_list = []

for i in tqdm(range(original_dataset['train'].num_rows)):
    sample = original_dataset['train'][i]
    
    # The row index in the original split is stored as the row's 'id'.
    row_data_dict, flag_to_skip = process_data_only_EDU_breaks(sample, i)
    if flag_to_skip:
        print(f"doc id: {i}, empty docs")
        continue
    trian_data_list.append(row_data_dict)



  1%|          | 456/44972 [01:02<1:15:35,  9.82it/s]

doc id: 453, empty docs


  8%|▊         | 3731/44972 [07:57<1:15:21,  9.12it/s]

doc id: 3728, empty docs


 36%|███▌      | 16292/44972 [34:54<31:43, 15.07it/s]  

doc id: 16290, empty docs


 37%|███▋      | 16491/44972 [35:21<1:00:10,  7.89it/s]

doc id: 16489, empty docs


 42%|████▏     | 18811/44972 [40:10<1:20:14,  5.43it/s]

doc id: 18812, empty docs


 43%|████▎     | 19282/44972 [41:06<50:13,  8.52it/s]  

doc id: 19279, empty docs


 48%|████▊     | 21622/44972 [45:53<30:50, 12.62it/s]  

doc id: 21620, empty docs


 68%|██████▊   | 30739/44972 [1:05:06<15:33, 15.24it/s]  

doc id: 30735, empty docs


 93%|█████████▎| 41997/44972 [1:28:55<03:27, 14.32it/s]  

doc id: 41993, empty docs


100%|██████████| 44972/44972 [1:35:16<00:00,  7.87it/s]


In [52]:
# 2. Process the validation split.
# Each row is segmented into EDUs; rows for which preprocess_docs keeps no document
# are reported and skipped.
validation_data_list = []

for i in tqdm(range(original_dataset['validation'].num_rows)):
    sample = original_dataset['validation'][i]

    # The row index in the original split is stored as the row's 'id'.
    row_data_dict, flag_to_skip = process_data_only_EDU_breaks(sample, i)
    if flag_to_skip:
        print(f"doc id: {i}, empty docs")
        continue
    validation_data_list.append(row_data_dict)
    

 86%|████████▋ | 4855/5622 [10:03<01:08, 11.15it/s]  

doc id: 4850, empty docs


100%|██████████| 5622/5622 [11:42<00:00,  8.01it/s]


In [53]:
# 3. Process the test split.
# Each row is segmented into EDUs; rows for which preprocess_docs keeps no document
# are reported and skipped.
test_data_list = []

for i in tqdm(range(original_dataset['test'].num_rows)):
    sample = original_dataset['test'][i]

    # The row index in the original split is stored as the row's 'id'.
    row_data_dict, flag_to_skip = process_data_only_EDU_breaks(sample, i)
    if flag_to_skip:
        print(f"doc id: {i}, empty docs")
        continue
    test_data_list.append(row_data_dict)



 84%|████████▍ | 4738/5622 [10:18<01:35,  9.29it/s]  

doc id: 4736, empty docs


100%|██████████| 5622/5622 [12:14<00:00,  7.66it/s]


In [67]:
## Save the processed splits (lists of row dicts) as JSON files

# Create the output directory up front: without it, open() raises FileNotFoundError
# only after the long-running processing loops have already finished.
os.makedirs('dataset/my_processed_dataset', exist_ok=True)

# The 'trian' spelling in the file name is kept on purpose: the loading cell reads
# the file back under exactly this path.
with open('dataset/my_processed_dataset/trian_data_list.json', 'w') as json_file:
    json.dump(trian_data_list, json_file)

with open('dataset/my_processed_dataset/validation_data_list.json', 'w') as json_file:
    json.dump(validation_data_list, json_file)

with open('dataset/my_processed_dataset/test_data_list.json', 'w') as json_file:
    json.dump(test_data_list, json_file)

## Load dataset and push to hub

In [None]:
# Reload the three JSON files written by the save cell as one dataset with
# 'train', 'validation' and 'test' splits.
my_dataset = load_dataset("json", data_files={'train':"dataset/my_processed_dataset/trian_data_list.json",
                                                  'validation':"dataset/my_processed_dataset/validation_data_list.json",
                                                  'test':"dataset/my_processed_dataset/test_data_list.json"})

In [68]:
# Authenticate with the Hugging Face Hub before pushing; login() prompts for a token
# (the notebook output below it is the login widget).
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
# Upload the EDU-segmented dataset to the Hub repository named below.
# NOTE(review): this cell's execution count is lower than the preceding cells', so it was
# run in a different kernel session/order — confirm `my_dataset` was defined when it ran.
my_dataset.push_to_hub("HF-Data-for-Retriever/multi_news")

## Build ground-truth document rankings and similarity scores

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util

# Sentence-embedding model that update_doc_sim passes to sent_transform_sim.
# The commented-out line is an alternative checkpoint that was tried.
# sent_model = SentenceTransformer('all-mpnet-base-v2', device=device)
sent_model = SentenceTransformer('multi-qa-mpnet-base-cos-v1', device=device)


### Get dataset 

In [146]:
# Load the previously pushed EDU-segmented dataset from the Hub (rebinds `my_dataset`).
my_dataset = load_dataset('HF-Data-for-Retriever/multi_news')

### Calculate doc similarity

In [64]:
def get_meaningful_words(text):
    """Tokenize ``text`` and drop English stop words and punctuation-like tokens.

    Filtering is case-insensitive; the returned tokens keep their original
    casing and order.

    Note: ``punctuations`` is a plain string, so the membership test is a
    substring test. Any token whose lower-cased form occurs as a contiguous
    piece of that string (for example "'s", "...", or a lone "s") is removed.
    """
    # ASCII punctuation followed by extra quote/dash characters and multi-character tokens.
    punctuations = string.punctuation + "“”‘’—–``'''s..." # punctuations with additional punctuations
    stop_words = set(stopwords.words('english'))

    kept_words = []
    for token in word_tokenize(text):
        lowered = token.lower()
        if lowered in stop_words or lowered in punctuations:
            continue
        kept_words.append(token)

    return kept_words



def chunk_text_with_slide(text, chunk_size=512, window_size=200):
    """Split ``text`` into overlapping word-level chunks with a sliding window.

    ``chunk_size`` and ``window_size`` are expressed in tokens and converted to
    word counts with a fixed ratio of 1.5 tokens per word.

    Args:
        text: Input string; words are obtained by splitting on single spaces.
        chunk_size: Token budget of one chunk.
        window_size: Token stride between the starts of consecutive chunks.

    Returns:
        List of chunk strings: the full windows that fit inside the text,
        followed by a tail chunk made of the last ``chunk_word_num`` words so the
        end of the text is always covered. The tail is left out when it is
        identical to the last window. Text shorter than one chunk is returned as
        a single chunk.
    """
    assume_ratio_to_token_num = 1.5
    words = text.split(' ')
    # Clamp both sizes to at least one word: a zero stride would divide by zero
    # and a zero chunk length would make words[-0:] return the whole text.
    chunk_word_num = max(1, int(chunk_size // assume_ratio_to_token_num))
    window_word_size = max(1, int(window_size // assume_ratio_to_token_num))

    # Number of full windows that fit; zero or negative when the text is shorter
    # than a single chunk, in which case only the tail chunk is produced.
    chunk_assume_num = (len(words) - chunk_word_num) // window_word_size + 1
    text_chunks = []
    for i in range(chunk_assume_num):  # build the sliding-window chunks
        chunk_start_id = i * window_word_size
        text_chunks.append(" ".join(words[chunk_start_id:chunk_start_id + chunk_word_num]))

    # The tail is measured in words (chunk_word_num), not tokens (chunk_size),
    # so it respects the same budget as every other chunk.
    tail_chunk = " ".join(words[-chunk_word_num:])
    if not text_chunks or text_chunks[-1] != tail_chunk:
        text_chunks.append(tail_chunk)
    return text_chunks

In [7]:


def sent_transform_sim(seq_list, summary, model=None):
    """Cosine similarity between each sequence and the summary.

    Args:
        seq_list: Iterable of text sequences to score.
        summary: Reference text every sequence is compared against.
        model: Object exposing ``encode(text, convert_to_tensor=True)``. It must
            be supplied; the ``None`` default is not usable.

    Returns:
        List of floats, one per sequence, in input order. An empty ``seq_list``
        yields an empty list without encoding anything.
    """
    results = []
    # The summary embedding does not depend on the sequence, so it is computed
    # once (on first use) instead of once per sequence.
    summary_embedding = None
    for seq in seq_list:
        seq_embedding = model.encode(seq, convert_to_tensor=True)
        if summary_embedding is None:
            summary_embedding = model.encode(summary, convert_to_tensor=True)

        cosine_sim = util.pytorch_cos_sim(seq_embedding, summary_embedding)
        results.append(cosine_sim.item())
    return results
    

### Update dataset

In [72]:

def add_new_column(sample):
    """Store the character length of ``sample['document']`` under 'doc_len'.

    The sample is modified in place and also returned.
    """
    document_text = sample['document']
    sample['doc_len'] = len(document_text)
    return sample



def add_new_column_doc_sim(sample): # rough_word_num
    """Add per-document token counts and rough word counts to a sample.

    ``sample['document']`` is split into documents on '|||||'. Each document is
    left-stripped and its '||' EDU separators are removed before counting.

    Adds two lists, aligned with the documents:
        'doc_token_num': number of ids returned by the module-level
            ``tokenizer.encode(..., add_special_tokens=False)`` (tokenizing is
            the slow part).
        'doc_rough_word_num': number of pieces obtained by splitting the cleaned
            document on single spaces.

    The sample is modified in place and also returned.
    """
    doc_token_num = []
    doc_rough_word_num = []
    for doc in sample['document'].split('|||||'):
        # Drop the EDU separators to recover the plain document text.
        clean_doc = "".join(doc.lstrip().split('||'))
        doc_token_num.append(len(tokenizer.encode(clean_doc, add_special_tokens=False)))
        doc_rough_word_num.append(len(clean_doc.split(' ')))

    sample['doc_token_num'] = doc_token_num
    sample['doc_rough_word_num'] = doc_rough_word_num
    return sample



def update_doc_sim(sample):
    """Score every document, and every EDU within it, against the sample's summary.

    ``sample['document']`` is split into documents on '|||||'; each document is
    left-stripped and split into EDUs on '||'.

    Adds:
        'sent_trans_doc_score': one similarity per document (EDU separators
            removed), as returned by ``sent_transform_sim`` with ``sent_model``.
        'sent_trans_doc_edu_score': for each document, a list with one
            similarity per EDU.

    The sample is modified in place and also returned.
    """
    raw_docs = sample['document'].split('|||||')
    summary = sample['summary']

    edus_per_doc = [raw_doc.lstrip().split('||') for raw_doc in raw_docs]
    # Document text without the '||' separators.
    clean_docs = ["".join(edus) for edus in edus_per_doc]

    sample['sent_trans_doc_score'] = sent_transform_sim(clean_docs, summary, sent_model)
    sample['sent_trans_doc_edu_score'] = [
        sent_transform_sim(edus, summary, sent_model) for edus in edus_per_doc
    ]

    return sample

In [None]:
# Alternative pass (disabled): add per-document token and rough word counts instead.
# my_dataset_new_train = my_dataset['train'].map(add_new_column_doc_sim)
# my_dataset_new_validation = my_dataset['validation'].map(add_new_column_doc_sim)
# my_dataset_new_test = my_dataset['test'].map(add_new_column_doc_sim)

# Add the sentence-transformer similarity columns to every split.
my_dataset_new_train = my_dataset['train'].map(update_doc_sim)
my_dataset_new_validation = my_dataset['validation'].map(update_doc_sim)
my_dataset_new_test = my_dataset['test'].map(update_doc_sim)

In [88]:
# Bundle the three updated splits back into a single DatasetDict for pushing.
my_dataset_new = DatasetDict({'train':my_dataset_new_train,
                              'validation': my_dataset_new_validation, 
                              'test': my_dataset_new_test})

## Log in to Hugging Face and push the data to the Hub

In [None]:
# Authenticate with the Hugging Face Hub; login() prompts for a token.
from huggingface_hub import login ## -> Need your own Huggingface token!
login()

In [None]:
# Upload the dataset with the added similarity columns to the same Hub repository
# that the EDU-segmented dataset was pushed to earlier in this notebook.
my_dataset_new.push_to_hub("HF-Data-for-Retriever/multi_news")