In [6]:
# import pandas as pd
import csv
import numpy as np
import os
import copy
from os.path import join as pjoin
from glob import iglob

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

In [17]:
#!/usr/bin/env python
from __future__ import division

import argparse
import glob
import os
import random
import signal
import time

import torch

model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval', 'rnn_size']


In [18]:
root_path = '/data/ksb/'
bert_root_path = pjoin(root_path, 'BertSum/PreSumm')
bert_model_dir = pjoin(bert_root_path, 'models')

data_dir = pjoin(root_path, 'cnn-dailymail/finished_files')

three_data_dir = pjoin(root_path, 'three-mat')
three_data_test = pjoin(three_data_dir, 'test')

#### Loss function 비교  

*Trained Model parameter 필요*


In [19]:
def get_cos_similarity(inputs, summaries):
    tfidf_vectorizer = TfidfVectorizer()

    cos_similarity_list = []
    for input_, summary_ in zip(inputs, summaries):
        try:
            tfidf_matrix = tfidf_vectorizer.fit_transform([input_, summary_])

            similarity = cosine_similarity(tfidf_matrix[0] , tfidf_matrix[1])[0][0]
        except ValueError:
            similarity = 0.0
            
        cos_similarity_list.append(similarity)

    return cos_similarity_list

In [20]:
import jsonlines
import json

data_list = []
for data_p in iglob(pjoin(three_data_test, '**.json'), recursive=False):
    
    with open(data_p,'r',encoding='utf-8') as f:
        data = json.load(f)
        data_list.append(data)

### Origin candidate set

In [21]:
from rouge import Rouge
rouge = Rouge()

In [31]:
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased', verbose=False)

def bert_encode(x, max_len=-1):
    cls_token_id = tok.cls_token_id
    sep_token_id = tok.sep_token_id

    _ids = tok.encode(x, add_special_tokens=False)
    ids = [cls_token_id] # [CLS]
    if max_len > 0:
        ids.extend(_ids[:max_len - 2])
    else:
        ids.extend(_ids[:512 - 2])
    ids.append(sep_token_id) # [SEP], meaning end of sentence
    return ids

def bert_decode(x):
    result = tok.decode(x, skip_special_tokens=True)
    return result

In [23]:
def detect_trigram(src, tgt):
    assert len(tgt) > 2 and len(src) > 2
        
    tgt_trigrams = [(tgt[i-1],tgt[i],tgt[i+1]) for i in range(1,len(tgt)-1)]
    src_trigrams = [(src[i-1],src[i],src[i+1]) for i in range(1,len(src)-1)]
    
    for src_tri in src_trigrams:
        if src_tri in tgt_trigrams:
            return True ## Detect trigram overlapped with target
        
    return False
    

In [24]:
def detect_4_gram(src, tgt):
    assert len(tgt) > 3 and len(src) > 3
    tgt_4_grams = [(tgt[i-2], tgt[i-1],tgt[i],tgt[i+1]) for i in range(2,len(tgt)-1)]
    src_4_grams = [(src[i-2], src[i-1],src[i],src[i+1]) for i in range(2,len(src)-1)]
    
    for src_gram in src_4_grams:
        if src_gram in tgt_4_grams:
            return True ## Detect 4-gram overlapped with target
        
    return False
    

In [25]:
def detect_5_gram(src, tgt):
    assert len(tgt) > 4 and len(src) > 4
    tgt_5_grams = [(tgt[i-2], tgt[i-1],tgt[i],tgt[i+1], tgt[i+2]) for i in range(2,len(tgt)-2)]
    src_5_grams = [(src[i-2], src[i-1],src[i],src[i+1], src[i+2]) for i in range(2,len(src)-2)]
    
    for src_gram in src_5_grams:
        if src_gram in tgt_5_grams:
            return True ## Detect 5-gram overlapped with target
        
    return False

In [26]:
def detect_ngram_list(src, tgt_list, n_gram='trigram'):
    
    if n_gram =='trigram':
        return sum([detect_trigram(src, tgt) for tgt in tgt_list]) > 0
    elif n_gram =='4-gram':
        return sum([detect_4_gram(src, tgt) for tgt in tgt_list]) > 0
    else :
        return sum([detect_4_gram(src, tgt) for tgt in tgt_list])>0

In [27]:
def get_candidate_set(sent_set, reference=None, n_gram='trigram'):
    
    assert n_gram in ['trigram', '4-gram','5-gram']
        
    if n_gram == 'trigram':
        detect_ngram = detect_trigram
    elif n_gram == '4-gram':
        detect_ngram = detect_4_gram
    else:
        detect_ngram = detect_5_gram
        
    
    possible_set_ids = []
    
    for sent_id, sent in sent_set:
        possible_2_sent_idx = []
        
        # number of summary sentences = 2
        for tgt_sent_id, tgt_sent in sent_set[sent_id+1:]:
            
            # Detect n-gram (default= trigram) 
            if not detect_ngram(src=sent, tgt=tgt_sent):
                possible_2_sent_idx.append(set([sent_id, tgt_sent_id]))
                
        possible_3_sent_idx = copy.deepcopy(possible_2_sent_idx)
        
        # number of summary sentences = 3
        for tgt_sent_id, tgt_sent in sent_set[sent_id+1:]:
            for poss_sent_ids in possible_2_sent_idx:
                
                poss_sent = [sent_set[ids][1] for ids in poss_sent_ids]
                if not detect_ngram_list(src=tgt_sent, tgt_list=poss_sent, n_gram=n_gram):
                    poss_3_ids = copy.deepcopy(poss_sent_ids)
                    poss_3_ids.add(tgt_sent_id)
                    
                    possible_3_sent_idx.append(poss_3_ids)
                    

        possible_sent_idx = possible_2_sent_idx + possible_3_sent_idx
        
        for ids in possible_sent_idx:
            if not ids in possible_set_ids:
                possible_set_ids.append(ids)

    return possible_set_ids
        

In [28]:
import pylcs

def compute_txt_redundancy_score(candidate_id):

    cand_num = len(candidate_id)
    
    score = torch.zeros([cand_num], dtype=torch.float64)
        
    def _compute_redundancy(cand):
        redundancy = 0.0
        
        for i, src_sen in enumerate(cand):
            for j, tgt_sen in enumerate(cand[i+1:]):
                if i != j:
                    lcs_val = pylcs.lcs(src_sen, tgt_sen)
                    redundancy += lcs_val 
        
        sents_len = sum([len(s) for sents in cand for s in sents])
        return redundancy / sents_len

    for i in range(cand_num):
        score[i] = np.mean(_compute_redundancy(candidate_id[i]))

    return score

In [29]:
new_data_path = pjoin(three_data_dir,'reconstructed_test.jsonl')

In [40]:

with open(new_data_path, 'w', encoding='utf-8') as f:
    writer = jsonlines.Writer(f)
    
    for data in data_list:
        candidates = data['candidates']
        article = data['article']
        abstract = data["abstract"]


        summaries = [cand[0] for cand in candidates]    
        encoded_cand_set = [[bert_encode(s, 180) for s in cs] for cs in summaries]
        threshold = min([cand[1] for cand in candidates]) 

        sent_set = []

        for i, encoded_cand in enumerate(encoded_cand_set):
            for j, encoded_sent in enumerate(encoded_cand):

                sent_id = sum([len(prev) for prev in encoded_cand_set[:i]])+j
                sent_set.append((sent_id, encoded_sent))

        reduced_cand_ids = get_candidate_set(sent_set)
        reduced_cand_sents = [[sent_set[i][1] for i in ids] for ids in reduced_cand_ids]
        reduced_cand_set_dec = [[bert_decode(x) for x in cand] for cand in reduced_cand_sents]


        # Drop candidate which has lower score than threshold
        rouge_cands_set = []
        for c in reduced_cand_set_dec:
            scores = rouge.get_scores(' '.join(abstract), ' '.join(c))[0]
            score = scores['rouge-l']['f']

            rouge_cands_set.append((score, c))

        rouge_cands_set = sorted(rouge_cands_set, key=lambda x: x[0], reverse=True)
        reconstructed_candidates = [s[-1] for s in rouge_cands_set[:20]]
        
        origin_cand_rouge = [round(rouge.get_scores(' '.join(abstract), ' '.join(cand[0]))[0]['rouge-l']['f'],4) for cand in candidates]
        new_cand_rouge = [round(rouge.get_scores(' '.join(abstract), ' '.join(cand))[0]['rouge-l']['f'],4) for cand in reconstructed_candidates]
        
        print("Origin ROUGE : {}".format(origin_cand_rouge))
        print("New Candidate ROUGE : {}\n".format(new_cand_rouge))
        
        if max(new_cand_rouge) <= max(origin_cand_rouge):
            origin_cand_id = np.argmax(origin_cand_rouge)
            origin_cand_redun = compute_txt_redundancy_score([cand[0] for cand in candidates]).item()
            reconst_cand_redun = compute_txt_redundancy_score([cand for cand in reconstructed_candidates]).item()
            
            print(origin_cand_redun)
            print(reconst_cand_redun)
            assert origin_cand_redun[origin_cand_id] > max(reconst_cand_redun)
            
        
        new_data = {'article':article, 'candidates':candidates, 'abstract':abstract, 'new_candidates': reconstructed_candidates}
        writer.write(new_data)

Origin ROUGE : [0.4474, 0.5116, 0.4478]
New Candidate ROUGE : [0.5333, 0.5333, 0.5333, 0.5128, 0.5128, 0.5128, 0.48, 0.48, 0.4762, 0.4762, 0.4762, 0.4687, 0.4561, 0.4561, 0.4561, 0.4533, 0.4533, 0.4533, 0.4533, 0.4478]

Origin ROUGE : [0.4138, 0.3373, 0.375]
New Candidate ROUGE : [0.4211, 0.4186, 0.4051, 0.4045, 0.3918, 0.3908, 0.38, 0.3778, 0.3678, 0.3556, 0.3556, 0.3441, 0.34, 0.3301, 0.3297, 0.3288, 0.321, 0.3191, 0.3158, 0.3143]

Origin ROUGE : [0.3208, 0.3368, 0.2936]
New Candidate ROUGE : [0.3956, 0.3956, 0.3908, 0.3908, 0.3846, 0.3846, 0.3488, 0.3368, 0.3368, 0.3368, 0.3368, 0.3299, 0.3299, 0.3297, 0.3297, 0.3243, 0.3226, 0.3226, 0.321, 0.321]

Origin ROUGE : [0.4286, 0.4737, 0.4717]
New Candidate ROUGE : [0.5306, 0.52, 0.4884, 0.48, 0.4792, 0.4762, 0.4752, 0.4752, 0.4694, 0.4694, 0.4646, 0.4646, 0.4598, 0.4583, 0.4565, 0.4471, 0.4444, 0.4444, 0.44, 0.4356]

Origin ROUGE : [0.1791, 0.3023, 0.1724]
New Candidate ROUGE : [0.5306, 0.4444, 0.4444, 0.4262, 0.4211, 0.3939, 0.3939, 0.3

ValueError: only one element tensors can be converted to Python scalars

In [None]:
print("Origin Redundancy score : {}".format(round(np.mean(origin_redun), 4)))
print("Origin cosine similarity between document and summaries : {}".format(round(np.mean(origin_doc_sims), 4)))
print("Origin ROUGE score between reference and summaries : {}".format(round(np.mean(origin_ref_rouges), 4)))

In [None]:
print("Origin Redundancy score : {}".format(round(np.mean(refine_redun), 4)))
print("Origin cosine similarity between document and summaries : {}".format(round(np.mean(refine_doc_sims), 4)))
print("Origin ROUGE score between reference and summaries : {}".format(round(np.mean(refine_ref_rouges), 4)))