In [1]:
from transformers import AutoTokenizer, AutoModel
import jsonlines
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import torch
import torch.nn.functional as F
import os
import copy
import json
import pickle
import numpy as np
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
from transformers.models.bart.modeling_bart import shift_tokens_right

os.environ['CUDA_VISIBLE_DEVICES'] = '7'
def file_reader(file_path):
    with jsonlines.open(file_path, 'r') as reader:
        src = []
        tgt = []
        for obj in reader:
            src.append(obj['src'])
            tgt.append(obj['tgt'])
    return src, tgt

def get_query(input):
    query = []
    src = []
    for item in input:
        temp = item.split('</s>')
        query.append(temp[0].replace('<s>','').strip())
        src.append(temp[1].replace('<s>','').strip())
    return query, src

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
#Encode text
def encode(texts):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    encoded_input['input_ids'] = encoded_input['input_ids'].cuda()
    encoded_input['token_type_ids'] = encoded_input['token_type_ids'].cuda()
    encoded_input['attention_mask'] = encoded_input['attention_mask'].cuda()
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings

# segment 文章到paragraph中

In [None]:
src, tgt = file_reader('processed_data_turn_split/train.jsonl')
querys, src = get_query(src)
max_input_len = 512
number_of_segment = 16
src = [i.split('<turn_seperator>') for i in src]

segmented_input = []
non_segmented_input = []

for i in tqdm(range(len(src))):
    seg_counter = 0
    doc = src[i]
    query = querys[i]
    this_doc = []
    this_seg = query + '</s></s>'
    counter = len(tokenizer.tokenize(this_seg,add_special_tokens=False))
    for sent in doc:
        length = len(tokenizer.tokenize(sent,add_special_tokens=False))
        if counter + length < max_input_len:
            this_seg = this_seg + ' ' + sent
            counter += length
        else:
            speaker = sent.split(':')[0]
            for subsent in sent_tokenize(sent):
                length = len(tokenizer.tokenize(subsent,add_special_tokens=False))
                if counter + length < max_input_len:
                    this_seg = this_seg + ' ' + subsent
                    counter += length
                else:
                    # print(subsent)
                    if speaker not in subsent:
                        subsent = speaker + ': ' + subsent
                    this_doc.append(this_seg.split('</s></s>')[-1])
                    this_seg = query + '</s></s>'+ subsent
                    counter = len(tokenizer.tokenize(this_seg, add_special_tokens=False))
    this_doc.append(this_seg.split('</s></s>')[-1])
    non_segmented_input.append([query + '</s></s>' + doc for doc in this_doc])
    torch.cuda.empty_cache()

: 

In [None]:
# with open( 'processed_data_no_sent_split/test_sort_16.pkl', 'wb') as f:
#     pickle.dump(segmented_input, f)
with open('processed_data_turn_split/train_all_seg.pkl', 'wb') as f:
    pickle.dump(non_segmented_input, f)

# 只取前16段

In [4]:
with open('processed_data_turn_split/val_all_seg.pkl', 'rb') as f:
    segmented_input = pickle.load(f)

number_of_segment = 4
new_segmented_input = []

for doc in segmented_input:
    if len(doc) < number_of_segment:
        new_segmented_input.append(doc + ['' for i in range(number_of_segment - len(doc))])
    else:
        new_segmented_input.append(doc[:number_of_segment])

with open('processed_data_turn_split/val_4_no_sort.pkl', 'wb') as f:
    pickle.dump(new_segmented_input, f)


# 只用knowledge，并且重新排序

In [16]:

from openie import StanfordOpenIE
from nltk.corpus import stopwords

properties = {
    'openie.affinity_probability_cap': 2 / 3,
}


In [None]:
with open( 'processed_data_turn_split/train_all_seg.pkl', 'rb') as f:
    segmented_input = pickle.load(f)

with StanfordOpenIE(properties=properties) as client:
    new_segmented_input = []
    for number, one_segmented_input in tqdm(enumerate(segmented_input)):
        query_without_stopwords = " ".join([item for item in one_segmented_input[0].split('</s></s>')[0].split() if item not in stopwords.words()])
        print(query_without_stopwords)
        print(len(one_segmented_input))
        this_doc_dic = {}
        for i,item in enumerate(segmented_input[number]):
            origin_item = copy.deepcopy(item)
            item = item.split('</s></s>')[-1]
            counter = 0
            triples = []
            for triple in client.annotate(item):
                for word in triple['subject'].split():
                    if word in query_without_stopwords and word not in stopwords.words():
                        counter += 1
                        print(triple, query_without_stopwords)
                        triples.append(triple)
                for word in triple['object'].split():
                    if word in query_without_stopwords and word not in stopwords.words():
                        counter += 1
                        triples.append(triple)
                        print(triple, query_without_stopwords)
            this_doc_dic[origin_item] = [triples, counter]
        new_segmented_input.append(this_doc_dic)

# Single rerank 处理重新排序的结果 (rerank) 

In [11]:
def triple2text(triples):
    output = ""
    for triple in triples:
        this_triple_text = 'subject: ' + triple['subject'] + ' relation: ' + triple['relation'] +  ' object: ' + triple['object'] + '</s>'
        output += this_triple_text
    return output
    
with open( 'processed_data_turn_split/test_all_seg_with_triple.pkl', 'rb') as f:
    data = pickle.load(f)

new_seg_input = []
triples = []
number_of_segment = 16
for doc in data:
    new = {k: v for k, v in sorted(doc.items(), key=lambda item: item[1][1], reverse=True)}
    new_seg_list = list(new.keys())
    # print(new[new_seg_list[0]][0])
    # print(triple2text(new[new_seg_list[0]][0]))
    # break
    if len(new_seg_list) < number_of_segment:
        new_seg_input.append(new_seg_list + ['' for i in range(number_of_segment - len(new_seg_list))])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(len(new_seg_list))] + ['' for i in range(number_of_segment - len(new_seg_list))])
    else:
        new_seg_input.append(new_seg_list[:number_of_segment])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(number_of_segment)])

# with open('processed_data_turn_split/test_16_rerank.pkl', 'wb') as f:
#     pickle.dump(new_seg_input, f)

# with open('processed_data_turn_split/test_triple.pkl', 'wb') as f:
#     pickle.dump(triples, f)

# Duals rerank 处理重新排序的结果 (QA + Knowledge) 

In [3]:
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1").cuda()

In [None]:
from nltk.corpus import stopwords
out_words = {}
for word in stopwords.words():
    out_words[word] = 1


def triple2text(triples):
    words = []
    for triple in triples:
        for item in triple['subject'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    for triple in triples:
        for item in triple['relation'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    for triple in triples:
        for item in triple['object'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    output = "</s>".join(words)
        # this_triple_text = triple['subject'] + '</s>' + triple['relation'] +  '</s>' + triple['object'] + '</s>'
        # output += this_triple_text
    return output
    
with open('processed_data_turn_split/train_all_seg_with_triple.pkl', 'rb') as f:
    data = pickle.load(f)

new_seg_input = []
triples = []
number_of_segment = 4

for doc in tqdm(data):
    keys = list(doc.keys())
    query = keys[0].split('</s></s>')[0]
    this_doc = [item.split('</s></s>')[-1] for item in keys]
    
    #Encode query and docs
    query_emb = encode(query)
    doc_emb = encode(this_doc)
    #Compute dot score between query and all document embeddings
    scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()
    knowledge_score = []
    for key in keys:
        knowledge_score.append(doc[key][1])
    knowledge_score = knowledge_score / np.linalg.norm(knowledge_score)
    merge_score = []
    for i in range(len(knowledge_score)):
        merge_score.append(knowledge_score[i] + scores[i])
    
    for i in range(len(keys)):
        doc[keys[i]].append(merge_score[i])

    new = {k: v for k, v in sorted(doc.items(), key=lambda item: item[1][-1], reverse=True)}
    new_seg_list = list(new.keys())
    
    if len(new_seg_list) < number_of_segment:
        new_seg_input.append(new_seg_list + ['' for i in range(number_of_segment - len(new_seg_list))])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(len(new_seg_list))] + ['' for i in range(number_of_segment - len(new_seg_list))])
    else:
        new_seg_input.append(new_seg_list[:number_of_segment])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(number_of_segment)])

with open('processed_data_turn_split/train_4_rerank.pkl', 'wb') as f:
    pickle.dump(new_seg_input, f)

with open('processed_data_turn_split/train_4_triple.pkl', 'wb') as f:
    pickle.dump(triples, f)


 34%|███▍      | 430/1257 [17:20<33:20,  2.42s/it]  


KeyboardInterrupt: 

# 从两个分数（QA和knowledge）中得到最终的排序结果

In [50]:
import pickle
from nltk.corpus import stopwords
out_words = {}
for word in stopwords.words():
    out_words[word] = 1

def triple2text(triples):
    words = []
    for triple in triples:
        for item in triple['subject'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    for triple in triples:
        for item in triple['relation'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    for triple in triples:
        for item in triple['object'].split():
            if item not in words:
                if item.isalpha():
                    if item not in stopwords.words():
                        words.append(item)
    output = "</s>".join(words)
        # this_triple_text = triple['subject'] + '</s>' + triple['relation'] +  '</s>' + triple['object'] + '</s>'
        # output += this_triple_text
    return output

with open('processed_data_turn_split/test_all_seg_with_triple.pkl', 'rb') as f:
    data = pickle.load(f)
with open('processed_data_turn_split/test_rankscores.pkl', 'rb') as f:
    qa_scores = pickle.load(f)

In [52]:
for i in tqdm(range(len(data))):
    keys = list(data[i].keys())
    a = qa_scores[i][0]
    this_qascores = [(float(i)-min(a))/(max(a)-min(a)) for i in a]
    ka_scores = []
    for key in keys:
        ka_scores.append(data[i][key][-1])
    a = ka_scores
    if max(a)-min(a) == 0:
        pass
    else:    
        ka_scores = [(float(i)-min(a))/(max(a)-min(a)) for i in a]
    all_scores = []
    for j in range(len(ka_scores)):
        all_scores.append(ka_scores[j]+this_qascores[j])
    print(all_scores)
    print(ka_scores)
    break
    for j in range(len(keys)):
        this_seg_info = {}
        this_seg_info['triples'] = data[i][keys[j]]
        this_seg_info['qa_score'] = this_qascores[j]
        this_seg_info['ka_score'] = ka_scores[j]
        this_seg_info['all_scores'] = all_scores[j]
        # key_phrase = triple2text(data[i][keys[j]][0])
        # this_seg_info['key_phrase'] = key_phrase
        data[i][keys[j]] = this_seg_info

  0%|          | 0/281 [00:00<?, ?it/s]

[0.39629098796272816, 0.0, 0.5004784694158335, 0.4339330207563356, 0.8503727076038221, 0.5425439163487842, 0.7586785438731851, 0.7870905196367776, 0.5776469966170464, 0.4896485936732833, 0.41190044296728734, 0.2466511403305259, 0.34279931585563905, 0.6003288209300695, 0.6032938205793645, 0.817669441731727, 0.1701137659773067, 0.4440076183898867, 0.6683264685093961, 0.3689582337421294, 0.44705734834710076, 0.5836230515644137, 0.2612245968242321, 0.20285591962922397, 0.7000073162441122, 0.4353438830931291, 0.36996960197689666, 0.3161959846984747, 2.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18181818181818182, 0.0, 0.0, 1.0]





[0.39629098796272816, 0.0, 0.5004784694158335, 0.4339330207563356, 0.8503727076038221, 0.5425439163487842, 0.7586785438731851, 0.7870905196367776, 0.5776469966170464, 0.4896485936732833, 0.41190044296728734, 0.2466511403305259, 0.34279931585563905, 0.6003288209300695, 0.6032938205793645, 0.817669441731727, 0.1701137659773067, 0.4440076183898867, 0.6683264685093961, 0.3689582337421294, 0.44705734834710076, 0.5836230515644137, 0.2612245968242321, 0.20285591962922397, 0.7000073162441122, 0.2535257012749473, 0.36996960197689666, 0.3161959846984747, 1.0]


In [16]:
all_keywords = []
for doc in data:
    keywords = []
    for seg_results in doc:
        keywords += seg_results.split('</s>')
    keywords = list(set(keywords))
    keywords = [item for item in keywords if len(item) > 2]
    all_keywords.append(keywords)

In [2]:
src, tgt = file_reader('processed_data_turn_split/train.jsonl')
querys, src = get_query(src)

FileNotFoundError: [Errno 2] No such file or directory: 'processed_data_turn_split/test_rankscores.jsonl'

In [17]:
counter = 0
overlap = 0
for i in range(len(all_keywords)):
    counter += len(all_keywords[i])
    for word in all_keywords[i]:
        if word in tgt[i]:
            overlap += 1
print(overlap/counter)

0.14826708439396907


In [60]:
with open('processed_data_turn_split/test_8_dual_rank.pkl', 'rb') as f:
    data = pickle.load(f)

In [63]:
print(data[0][:2])

["summarize the whole meeting .</s></s>barry hughes: it 's a potential barrier , but i do n't think it is a barrier . there was a shortage of registered intermediaries in wales , and i know that the ministry of justice have taken action to deal with that , and we have had a number of people who are now in a position to act as intermediaries . now , of course , if they were to decide not to do that anymore , we may have a problem , but , in turn , we would be looking to recruit more people into those positions . so , yes , it has the potential to serve as a barrier , but in practice , i do n't think it would be a barrier . i think , particularly given the very low numbers we 're talking about , we would be able to manage it . i 've got no significant concerns , i have to say . lynne neagle am: thank you . well , we 've come to the end of our time . can i thank you for attending , the three of you , and for your answers , which have been fascinating and very clear and most helpful to the