# QA and knowledge triple in ranking

In [5]:
from transformers import AutoTokenizer, AutoModel
import jsonlines
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import torch
import torch.nn.functional as F
import os
import copy
import json
import pickle

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def file_reader(file_path):
    with jsonlines.open(file_path, 'r') as reader:
        src = []
        tgt = []
        for obj in reader:
            src.append(obj['src'])
            tgt.append(obj['tgt'])
    return src, tgt

def get_query(input):
    query = []
    src = []
    for item in input:
        temp = item.split('</s>')
        query.append(temp[0].replace('<s>','').strip())
        src.append(temp[1].replace('<s>','').strip())
    return query, src

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

#Encode text
def encode(texts):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    encoded_input['input_ids'] = encoded_input['input_ids'].cuda()
    encoded_input['token_type_ids'] = encoded_input['token_type_ids'].cuda()
    encoded_input['attention_mask'] = encoded_input['attention_mask'].cuda()
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings

In [1]:
import torch

In [21]:
# Load model from HuggingFace Hub
# tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
# model = AutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1").cuda()
from transformers.models.bart.modeling_bart import shift_tokens_right

In [27]:
src, tgt = file_reader('processed_data_no_sent_split/test.jsonl')
querys, src = get_query(src)
max_input_len = 512
number_of_segment = 16
src = [sent_tokenize(i) for i in src]
segmented_input = []
non_segmented_input = []

for i in tqdm(range(len(src))):
    seg_counter = 0
    doc = src[i]
    query = querys[i]
    this_doc = []
    this_seg = query + '</s></s>'
    counter = len(tokenizer.tokenize(this_seg,add_special_tokens=False))
    for sent in doc:
        length = len(tokenizer.tokenize(sent,add_special_tokens=False))
        if counter + length < max_input_len:
            this_seg = this_seg + ' ' + sent
            counter += length
        else:
            this_doc.append(this_seg.split('</s></s>')[-1])
            this_seg = query + '</s></s>'+ sent
            counter = len(tokenizer.tokenize(this_seg, add_special_tokens=False))
    if len(tokenizer.tokenize(this_seg, add_special_tokens=False)) > 256:
        this_doc.append(this_seg.split('</s></s>')[-1])

    #Encode query and docs
    query_emb = encode(query)
    doc_emb = encode(this_doc)
    #Compute dot score between query and all document embeddings
    scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()
    #Combine docs & scores
    doc_score_pairs = list(zip(this_doc, scores))
    original_doc_score_pairs = copy.deepcopy(doc_score_pairs)
    #Sort by decreasing score
    doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
    if len(doc_score_pairs) <= number_of_segment:
        segmented_input.append([query + '</s></s>' + doc for doc, score in doc_score_pairs] + ['' for i in range(number_of_segment - len(doc_score_pairs))])
    else:
        final_doc = []
        for doc, score in doc_score_pairs:
            if seg_counter >= number_of_segment:
                break
            final_doc.append(query + '</s></s>' + doc)
            seg_counter += 1
        segmented_input.append(final_doc)
    non_segmented_input.append([query + '</s></s>' + doc for doc, score in original_doc_score_pairs])
    torch.cuda.empty_cache()

  0%|          | 0/281 [00:00<?, ?it/s]


KeyError: 'token_type_ids'

In [8]:
# with open( 'processed_data_no_sent_split/test_sort_16.pkl', 'wb') as f:
#     pickle.dump(segmented_input, f)
# with open( 'processed_data_no_sent_split/test_all_seg.pkl', 'wb') as f:
#     pickle.dump(non_segmented_input, f)

# 只用knowledge，并且重新排序

In [None]:
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

In [3]:

from openie import StanfordOpenIE
from nltk.corpus import stopwords

properties = {
    'openie.affinity_probability_cap': 2 / 3,
}


In [7]:
with open( 'processed_data_no_sent_split/train_all_seg.pkl', 'rb') as f:
    segmented_input = pickle.load(f)

with StanfordOpenIE(properties=properties) as client:
    new_segmented_input = []
    for number, one_segmented_input in tqdm(enumerate(segmented_input)):
        query_without_stopwords = " ".join([item for item in one_segmented_input[0].split('</s></s>')[0].split() if item not in stopwords.words()])
        print(query_without_stopwords)
        print(len(one_segmented_input))
        this_doc_dic = {}
        for i,item in enumerate(segmented_input[number]):
            origin_item = copy.deepcopy(item)
            item = item.split('</s></s>')[-1]
            counter = 0
            triples = []
            for triple in client.annotate(item):
                for word in triple['subject'].split():
                    if word in query_without_stopwords and word not in stopwords.words():
                        counter += 1
                        print(triple, query_without_stopwords)
                        triples.append(triple)
                for word in triple['object'].split():
                    if word in query_without_stopwords and word not in stopwords.words():
                        counter += 1
                        triples.append(triple)
                        print(triple, query_without_stopwords)
            this_doc_dic[origin_item] = [triples, counter]
        new_segmented_input.append(this_doc_dic)

0it [00:00, ?it/s]

summarize meeting .
13
Starting server with command: java -Xmx8G -cp /home/tiezheng/.stanfordnlp_resources/stanford-corenlp-4.1.0/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 60000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-4ab125a500644666.props -preload openie
{'subject': 'this', 'relation': 'is', 'object': 'our last meeting'} summarize meeting .
{'subject': 'i', 'relation': 'go from', 'object': 'meeting'} summarize meeting .
{'subject': 'i', 'relation': 'go from', 'object': 'previous meeting'} summarize meeting .
{'subject': "'s", 'relation': 'see', 'object': 'minutes from last meeting'} summarize meeting .
{'subject': "'s", 'relation': 'see', 'object': 'minutes from meeting'} summarize meeting .


1it [00:06,  6.91s/it]

{'subject': 'i', 'relation': 'have do for', 'object': 'last meeting'} summarize meeting .
{'subject': 'i', 'relation': 'still have do for', 'object': 'last meeting'} summarize meeting .
{'subject': 'i', 'relation': 'still have do for', 'object': 'meeting'} summarize meeting .
{'subject': 'i', 'relation': 'have do for', 'object': 'meeting'} summarize meeting .
project manager user interface introduce prototype remote control ?
13
{'subject': 'we', 'relation': "'ll have", 'object': 'prototype presentation'} project manager user interface introduce prototype remote control ?
{'subject': 'we', 'relation': 'were talking', 'object': 'about trying incorporate into our prototype'} project manager user interface introduce prototype remote control ?
{'subject': 'trying', 'relation': 'incorporate into', 'object': 'our prototype'} project manager user interface introduce prototype remote control ?
{'subject': 'we', 'relation': 'so were talking', 'object': 'about trying incorporate into our prototy

2it [00:14,  7.07s/it]

{'subject': 'i', 'relation': "'m", 'object': 'project manager'} project manager user interface introduce prototype remote control ?
{'subject': 'i', 'relation': "'m", 'object': 'project manager'} project manager user interface introduce prototype remote control ?
{'subject': 'you', 'relation': "'ve know", 'object': 'user interface'} project manager user interface introduce prototype remote control ?
{'subject': 'you', 'relation': "'ve know", 'object': 'user interface'} project manager user interface introduce prototype remote control ?
{'subject': 'they', 'relation': "'ve just know", 'object': 'user interface'} project manager user interface introduce prototype remote control ?
{'subject': 'they', 'relation': "'ve just know", 'object': 'user interface'} project manager user interface introduce prototype remote control ?
{'subject': 'they', 'relation': "'ve know", 'object': 'user interface'} project manager user interface introduce prototype remote control ?
{'subject': 'they', 'relatio

2it [00:20, 10.21s/it]


# 处理重新排序的结果 (rerank)

In [28]:
def triple2text(triples):
    output = ""
    for triple in triples:
        this_triple_text = 'subject: ' + triple['subject'] + ' relation: ' + triple['relation'] +  ' object: ' + triple['object'] + '</s>'
        output += this_triple_text
    return output
    
with open( 'processed_data_no_sent_split/val_all_seg_with_triple.pkl', 'rb') as f:
    data = pickle.load(f)
new_seg_input = []
triples = []
number_of_segment = 16
for doc in data:
    new = {k: v for k, v in sorted(doc.items(), key=lambda item: item[1][1], reverse=True)}
    new_seg_list = list(new.keys())
    # print(new[new_seg_list[0]][0])
    # print(triple2text(new[new_seg_list[0]][0]))
    # break
    if len(new_seg_list) < number_of_segment:
        new_seg_input.append(new_seg_list + ['' for i in range(number_of_segment - len(new_seg_list))])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(len(new_seg_list))] + ['' for i in range(number_of_segment - len(new_seg_list))])
    else:
        new_seg_input.append(new_seg_list[:number_of_segment])
        triples.append([triple2text(new[new_seg_list[i]][0]) for i in range(number_of_segment)])


with open('ka_rerank/val_triple.pkl', 'wb') as f:
    pickle.dump(triples, f)

In [23]:
with open( 'ka_rerank/val.pkl', 'rb') as f:
    data = pickle.load(f)

In [25]:
for doc in data:
    for seg in doc:
        length = len(tokenizer.tokenize(seg,add_special_tokens=True))
        if length > 512:
            print(length)


515
518
513
516
519
515
518
516
516
525
517
513
514
513
518
513
516
517
513
514
516
517
515
516
514
515
513
514
513
515
514
513
513
513
522
522
516
513
516
518
516
514
514
514
513
516
515
516
516
513
515
519
514
517
515
517
516
513
526
514
514
513
515
514
518
516
516
513
519
516
513
513
515
523
516
521
513
516
520
514
516
521
519
513
515
515
519
515
515
513
524
517
520
518
514
520
516
520
514
519
514
521
514
522
514
519
516
516
515
514
522
517
513
518
525
514
517
516
516
529
515
516
515
514
515
514
518
516
517
515
514
526
513
522
513
521
517
524
520
516
514
518
525
522
514
513
526
522
516
519
522
516
514
522
515
519
515
527
524
515
521
518
519
525
516
522
519
514
516
531
518
515
517
514
515
518
514
516
528
515
525
524
518
516
530
519
533
519
518
518
532
526
515
523
522
529
520
516
526
533
516
517
519
518
533
520
521
519
520
516
525
533
522
519
522
514
529
518
519
519
531
517
539
518
523
516
517
524
515
538
518
514
518
516
519
523
520
519
519
515
540
522
524
518
518
523
515
516
528
513


In [2]:
model = AutoModel.from_pretrained("facebook/bart-large")
model.knowledge_encoder = AutoModel.from_pretrained("facebook/bart-large").encoder

In [3]:
a = model.encoder

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

In [4]:
tokenizer('He </s></s> I am not perfrct', add_special_tokens=True, padding='max_length', truncation=True, max_length=512, return_tensors='pt')

{'input_ids': tensor([[    0,   894,  1437,     2,     2,    38,   524,    45,   228, 12997,
          3894,     2,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,  