Sentence embedding

In [None]:
import pickle
import json
from sentence_transformers import SentenceTransformer

# Shared sentence encoder. The data-construction cell further down reuses this
# `sentence_model` object to embed its retrieval queries.
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a GPU.
sentence_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1').to('cuda')


def build_embedding_store(data_path, first_field, second_field, subject_field):
    """Embed every record of one training file and pickle the result next to it.

    Each record is turned into the sentence
    ``record[first_field] + ' ' + record[second_field]`` and encoded with the
    module-level ``sentence_model``.

    Args:
        data_path: Path of a JSON file holding a list of record dicts.
        first_field: Key of the first half of the sentence.
        second_field: Key of the second half of the sentence.
        subject_field: Key whose value is stored as the record's subject.

    Returns:
        The path of the pickle that was written: ``data_path`` with its
        extension replaced by ``_embeddings.pkl``. The pickle holds a dict with
        the keys ``'sentences'``, ``'subjects'`` and ``'embeddings'``.
    """
    import os  # local import keeps this cell self-contained

    with open(data_path, 'r', encoding='utf-8') as input_file:
        input_data = json.load(input_file)

    sentences = [record[first_field] + ' ' + record[second_field] for record in input_data]
    subjects = [record[subject_field] for record in input_data]
    embeddings = sentence_model.encode(sentences)

    # splitext only strips the final extension; splitting on the first '.'
    # would cut the path short whenever a directory name contains a dot.
    output_path = os.path.splitext(data_path)[0] + '_embeddings.pkl'
    with open(output_path, "wb") as fOut:
        pickle.dump({'sentences': sentences, 'subjects': subjects, 'embeddings': embeddings},
                    fOut, protocol=pickle.HIGHEST_PROTOCOL)
    return output_path


# (data file, first sentence field, second sentence field, subject field)
EMBEDDING_SOURCES = [
    ### zsre
    ("data/KnowEdit/ZsRE/zsre_mend_train_10000.json", 'src', 'alt', 'subject'),
    ### wiki_counterfact
    ("data/KnowEdit/wiki_counterfact/train_cf.json", 'prompt', 'target_new', 'subject'),
    ### wiki_recent
    ("data/KnowEdit/wiki_recent/recent_train.json", 'prompt', 'target_new', 'subject'),
    ### wikibio
    # NOTE(review): stored as labels + text, while the data-construction cell
    # builds its WikiBio query as text + labels -- confirm this is intended.
    ("data/KnowEdit/WikiBio/wikibio-train-all.json", 'labels', 'text', 'concept'),
]

for data_path, first_field, second_field, subject_field in EMBEDDING_SOURCES:
    build_embedding_store(data_path, first_field, second_field, subject_field)

Sorting

In [None]:
import json

def sort_by_attribute_count(records):
    """Rank records by how many portability and locality groups they carry.

    Every record gets an ``'attribute_len'`` entry (number of keys under
    ``'portability'`` plus number of keys under ``'locality'``). The records
    are mutated in place, so this entry ends up in the written file as well.

    A record without one of the two sections counts as zero for it; the
    data-construction cell treats both sections as optional too.

    Args:
        records: List of record dicts.

    Returns:
        A new list with the same records, largest ``'attribute_len'`` first.
        ``sorted`` is stable, so ties keep their original order.
    """
    for record in records:
        record['attribute_len'] = len(record.get('portability', {})) + len(record.get('locality', {}))
    return sorted(records, key=lambda record: record['attribute_len'], reverse=True)


def write_sorted_copy(source_path, target_path):
    """Read ``source_path``, sort it with ``sort_by_attribute_count`` and write ``target_path``."""
    with open(source_path, 'r', encoding='utf-8') as input_file:
        input_data = json.load(input_file)

    # Serialise before opening the target so a failure cannot leave a truncated file.
    json_str = json.dumps(sort_by_attribute_count(input_data), indent=4)
    with open(target_path, mode='w', encoding='utf-8') as json_file:
        json_file.write(json_str)


### wiki_counterfact
write_sorted_copy("data/KnowEdit/wiki_counterfact/train_cf.json",
                  'data/KnowEdit/wiki_counterfact/train_cf_sorted.json')

### wiki_recent
write_sorted_copy("data/KnowEdit/wiki_recent/recent_train.json",
                  'data/KnowEdit/wiki_recent/recent_train_sorted.json')

Data construction

In [None]:
import random
import json
import torch
import pickle
from sentence_transformers import util

def knowledge_edit_template(new_facts, question):
    """Build the prompt that presents updated facts followed by a question.

    Both arguments must be strings; they are inserted verbatim under the
    "[Updated Information]" and "[Question]" headings.
    """
    pieces = (
        "Please acknowledge the updated information provided below and respond to the subsequent question.",
        "\n\n[Updated Information]:\n",
        new_facts,
        "\n\n[Question]:\n",
        question,
    )
    return "".join(pieces)

def sentence_completion_prompt(question):
    """Prefix `question` with the sentence-completion instruction."""
    instruction = "Please complete the sentence below. You should ONLY output the completed part."
    return "{}\n\n{}".format(instruction, question)

def text_completion_prompt(question):
    """Prefix `question` with the text-completion instruction."""
    instruction = "Please complete the text below. You should ONLY output the completed part."
    return "{}\n\n{}".format(instruction, question)

def question_answering_prompt(question):
    """Prefix `question` with the question-answering instruction."""
    instruction = "Please answer the question below. You should ONLY output the answer."
    return "{}\n\n{}".format(instruction, question)

def data_append(data, source, idx, new_facts, question, answer):
    """Append one two-turn conversation to `data` and advance the running index.

    Args:
        data: List that collects the conversation records (mutated in place).
        source: Dataset tag that becomes part of the record id.
        idx: Running index that becomes part of the record id.
        new_facts: Edited facts to show before the question. Any falsy value
            (None, "") means the human turn is the bare question.
        question: Text of the human turn (wrapped by `knowledge_edit_template`
            when `new_facts` is given).
        answer: Text of the 'gpt' turn; stored as is, so it may be None.

    Returns:
        Tuple ``(idx + 1, data)``.
    """
    record_id = "identity_{0}_{1}".format(str(idx), source)
    human_turn = knowledge_edit_template(new_facts, question) if new_facts else question
    data.append({
        "id": record_id,
        "conversations": [
            {"from": 'human', "value": human_turn},
            {"from": 'gpt', "value": answer},
        ],
    })
    return idx + 1, data


# embedding_path -> (sentences, subjects, normalised embeddings on the GPU).
# Re-running this cell resets the cache, so regenerated pickles are picked up.
_EMBEDDING_STORE_CACHE = {}


def _load_embedding_store(embedding_path):
    """Load, move to the GPU and normalise one pickled embedding store, once per path."""
    if embedding_path not in _EMBEDDING_STORE_CACHE:
        # pickle.load can execute arbitrary code: only point this at files
        # written by the "Sentence embedding" cell of this notebook.
        with open(embedding_path, "rb") as fIn:
            stored_data = pickle.load(fIn)
        stored_embeddings = torch.tensor(stored_data['embeddings']).to('cuda')
        stored_embeddings = util.normalize_embeddings(stored_embeddings)
        _EMBEDDING_STORE_CACHE[embedding_path] = (
            stored_data['sentences'], stored_data['subjects'], stored_embeddings)
    return _EMBEDDING_STORE_CACHE[embedding_path]


def retrieve_new_facts(embedding_path, sentence_model, query_sentence, query_subject, num):
    """Pad an edited fact with `num` similar facts about other subjects.

    The query is embedded with `sentence_model` and compared (dot product on
    normalised embeddings) against the store at `embedding_path`. Of the five
    best hits, those with a different subject AND a different sentence than
    the query are kept as candidates, and `num` of them are drawn at random.

    Args:
        embedding_path: Pickle with 'sentences', 'subjects' and 'embeddings'.
        sentence_model: Encoder exposing ``encode(text, show_progress_bar=...)``.
        query_sentence: The edited fact itself.
        query_subject: Subject of the edited fact.
        num: Number of additional facts to mix in.

    Returns:
        A numbered list ("1. <query>\\n2. <fact>...") with the query first, or
        the bare `query_sentence` when fewer than `num` candidates survive the
        filter (or `num` is negative).
    """
    stored_sentences, stored_subjects, stored_embeddings = _load_embedding_store(embedding_path)

    query_embedding = util.normalize_embeddings(torch.tensor(sentence_model.encode(
        query_sentence, show_progress_bar=False)).unsqueeze(0).to('cuda'))

    hits = util.semantic_search(query_embedding, stored_embeddings, score_function=util.dot_score, top_k=5)
    assert len(hits) == 1

    candidates = []
    for hit in hits[0]:
        corpus_id = hit["corpus_id"]
        if stored_subjects[corpus_id] != query_subject and stored_sentences[corpus_id] != query_sentence:
            candidates.append(stored_sentences[corpus_id])

    # random.sample raises ValueError for exactly these cases; checking up
    # front keeps the fallback without a catch-all `except` masking real bugs.
    if not 0 <= num <= len(candidates):
        return query_sentence

    numbered_facts = [query_sentence] + random.sample(candidates, num)
    return "\n".join(f"{position}. " + fact for position, fact in enumerate(numbered_facts, start=1))




# Accumulators shared by every dataset section of this cell:
#   out_of_scope_questions - question/answer pairs unrelated to the edit (sampled again by the MQuAKE sections)
#   data                   - every generated conversation, in generation order
#   need_gpt4              - prompts whose answer is still missing
out_of_scope_questions = []
data = []
need_gpt4 = []


### ZsRE_train
idx = 0
with open("data/KnowEdit/ZsRE/zsre_mend_train_10000.json", 'r', encoding='utf-8') as input_file:
    input_data = json.load(input_file)

for record in input_data:
    # idx counts appended conversations, so this caps the section at roughly 4000 samples.
    if idx >= 4000:
        break
    new_facts = record['src'] + " " + record['alt']
    subject = record['subject']

    # 50%: the edited fact alone; 25%: one retrieved distractor; 25%: two distractors.
    rand_num = random.random()
    if rand_num >= 0.5:
        distractor_count = 1 if rand_num < 0.75 else 2
        new_facts = retrieve_new_facts('data/KnowEdit/ZsRE/zsre_mend_train_10000_embeddings.pkl',
                                       sentence_model, new_facts, subject, num=distractor_count)

    answer = record['alt']
    ground_truth = record['answers'][0] if record['answers'] else ""

    # 30%: ask the original question (rewrite); 70%: ask its rephrasing.
    rand_num = random.random()
    question = record['src'] if rand_num < 0.3 else record['rephrase']
    if question != "" and answer != "" and ground_truth != "":
        # With the new facts the edited answer is expected, without them the original one.
        idx, data = data_append(data, 'zsre_train', idx, new_facts, question, answer)
        idx, data = data_append(data, 'zsre_train', idx, None, question, ground_truth)

    # Locality probe: the answer must be the same with and without the new facts.
    question = record['loc'].replace("nq question: ", "")
    answer = record['loc_ans']
    if question != "" and answer != "":
        idx, data = data_append(data, 'zsre_train', idx, new_facts, question, answer)
        idx, data = data_append(data, 'zsre_train', idx, None, question, answer)
        out_of_scope_questions.append({'question': question, 'answer': answer})


### wiki_counterfact_train
# Builds samples from the pre-sorted wiki_counterfact file (written by the "Sorting" cell).
# with_in  counts in-scope question pairs (direct edit question or portability question).
# with_out counts out-of-scope (locality) question pairs.
idx = 0
with_in = 0
with_out = 0
with open("data/KnowEdit/wiki_counterfact/train_cf_sorted.json", 'r', encoding='utf-8') as input_file:
    input_data = json.load(input_file)

for i in range(len(input_data)):
    # Stops only when both counters are exactly equal AND above 1200;
    # if they never line up, the whole file is consumed.
    # NOTE(review): confirm that exact equality (rather than >=) is the intended stop rule.
    if with_in == with_out and with_in > 1200:
        break
    new_facts = input_data[i]['prompt'] + " " + input_data[i]['target_new']
    subject = input_data[i]['subject']
    question = input_data[i]['prompt']
    answer = input_data[i]['target_new']
    ground_truth = input_data[i]['ground_truth']

    # 50%: the edited fact alone; 25%: plus one retrieved fact; 25%: plus two retrieved facts.
    rand_num = random.random()
    if rand_num < 0.5:
        new_facts = new_facts
    elif rand_num < 0.75:
        new_facts = retrieve_new_facts('data/KnowEdit/wiki_counterfact/train_cf_embeddings.pkl', sentence_model, new_facts, subject, num=1)
    else:
        new_facts = retrieve_new_facts('data/KnowEdit/wiki_counterfact/train_cf_embeddings.pkl', sentence_model, new_facts, subject, num=2)

    rand_num = random.random()
    if rand_num < 0.3:
        # 30%: ask the edit prompt itself; with the new facts the answer is target_new,
        # without them it is the original ground_truth.
        if question != "" and answer != "" and ground_truth != "":
            idx, data = data_append(data, 'wiki_counterfact_train', idx, new_facts, question, answer)
            idx, data = data_append(data, 'wiki_counterfact_train', idx, None, question, ground_truth)
            with_in += 1
    else:
        # 70%: ask every portability question of this record instead.
        for attribution in ['portability']:
            if attribution in input_data[i]:
                for k in input_data[i][attribution].keys():
                    for j in range(len(input_data[i][attribution][k])):
                        question = input_data[i][attribution][k][j]['prompt']
                        # NOTE(review): ground_truth is used as is here, while the wiki_recent
                        # section flattens a nested list -- confirm it is a plain string in this file.
                        answer = input_data[i][attribution][k][j]['ground_truth']
                        if question != "" and answer != "":
                            idx, data = data_append(data, 'wiki_counterfact_train', idx, new_facts, question, answer)
                            # The no-context answer is unknown (None) and is queued for later completion.
                            idx, data = data_append(data, 'wiki_counterfact_train', idx, None, question, None)
                            need_gpt4.append({'question': question, 'prompt_new': sentence_completion_prompt(question)})
                            with_in += 1

    # Locality questions are added for every record, whichever branch ran above;
    # their answer is the same with and without the new facts.
    for attribution in ['locality']:
        if attribution in input_data[i]:
            for k in input_data[i][attribution].keys():
                for j in range(len(input_data[i][attribution][k])):
                    question = input_data[i][attribution][k][j]['prompt']
                    answer = input_data[i][attribution][k][j]['ground_truth']
                    if question != "" and answer != "":
                        idx, data = data_append(data, 'wiki_counterfact_train', idx, new_facts, question, answer)
                        idx, data = data_append(data, 'wiki_counterfact_train', idx, None, question, answer)
                        with_out += 1
                        out_of_scope_questions.append({'question': question, 'answer': answer})

### wiki_recent_train
# Builds samples from the pre-sorted wiki_recent file (written by the "Sorting" cell).
# with_in  counts in-scope question pairs (direct edit question or portability question).
# with_out counts out-of-scope (locality) question pairs.
idx = 0
with_in = 0
with_out = 0
with open("data/KnowEdit/wiki_recent/recent_train_sorted.json", 'r', encoding='utf-8') as input_file:
    input_data = json.load(input_file)

for i in range(len(input_data)):
    # Stops only when both counters are exactly equal AND above 1200;
    # if they never line up, the whole file is consumed.
    if with_in == with_out and with_in > 1200:
        break
    new_facts = input_data[i]['prompt'] + " " + input_data[i]['target_new']
    subject = input_data[i]['subject']
    question = input_data[i]['prompt']
    answer = input_data[i]['target_new']

    # 50%: the edited fact alone; 25%: plus one retrieved fact; 25%: plus two retrieved facts.
    rand_num = random.random()
    if rand_num < 0.5:
        new_facts = new_facts
    elif rand_num < 0.75:
        new_facts = retrieve_new_facts('data/KnowEdit/wiki_recent/recent_train_embeddings.pkl', sentence_model, new_facts, subject, num=1)
    else:
        new_facts = retrieve_new_facts('data/KnowEdit/wiki_recent/recent_train_embeddings.pkl', sentence_model, new_facts, subject, num=2)

    # The edit prompt itself is always asked; the no-context answer is unknown (None)
    # and is queued for later completion.
    if question != "" and answer != "":
        idx, data = data_append(data, 'wiki_recent_train', idx, new_facts, question, answer)
        idx, data = data_append(data, 'wiki_recent_train', idx, None, question, None)
        need_gpt4.append({'question': question, 'prompt_new': sentence_completion_prompt(question)})
        with_in += 1

    # Every portability question of the record is asked as well.
    for attribution in ['portability']:
        if attribution in input_data[i]:
            for k in input_data[i][attribution].keys():
                for j in range(len(input_data[i][attribution][k])):
                    question = input_data[i][attribution][k][j]['prompt']
                    # ground_truth is flattened one level and its first item is the answer;
                    # assumes ground_truth is a list of lists of strings -- TODO confirm.
                    if question != "" and [item for sublist in input_data[i][attribution][k][j]['ground_truth'] for item in sublist] != []:
                        answer = [item for sublist in input_data[i][attribution][k][j]['ground_truth'] for item in sublist][0]
                        idx, data = data_append(data, 'wiki_recent_train', idx, new_facts, question, answer)
                        idx, data = data_append(data, 'wiki_recent_train', idx, None, question, None)
                        need_gpt4.append({'question': question, 'prompt_new': sentence_completion_prompt(question)})
                        with_in += 1
    
    # Locality questions: at most the first 3 per key; the answer is the same
    # with and without the new facts.
    for attribution in ['locality']:
        if attribution in input_data[i]:
            for k in input_data[i][attribution].keys():
                for j in range(min(len(input_data[i][attribution][k]), 3)):
                    question = input_data[i][attribution][k][j]['prompt']
                    if question != "" and [item for sublist in input_data[i][attribution][k][j]['ground_truth'] for item in sublist] != []:
                        answer = [item for sublist in input_data[i][attribution][k][j]['ground_truth'] for item in sublist][0]
                        idx, data = data_append(data, 'wiki_recent_train', idx, new_facts, question, answer)
                        idx, data = data_append(data, 'wiki_recent_train', idx, None, question, answer)
                        with_out += 1
                        out_of_scope_questions.append({'question': question, 'answer': answer})


### wikibio_train
idx = 0
with open("data/KnowEdit/WikiBio/wikibio-train-all.json", 'r', encoding='utf-8') as input_file:
    input_data = json.load(input_file)

# Exactly the first 250 records are used (a shorter file raises IndexError).
for i in range(250):
    record = input_data[i]
    new_facts = record['text'] + " " + record['labels']
    subject = record['concept']
    question = record['text']
    answer = record['labels']

    # 50%: the edited fact alone; 25%: one retrieved distractor; 25%: two distractors.
    rand_num = random.random()
    if rand_num >= 0.5:
        distractor_count = 1 if rand_num < 0.75 else 2
        new_facts = retrieve_new_facts('data/KnowEdit/WikiBio/wikibio-train-all_embeddings.pkl',
                                       sentence_model, new_facts, subject, num=distractor_count)

    # With the new facts the answer is the label text; the no-context answer is
    # unknown (None) and is queued for later completion.
    idx, data = data_append(data, 'wikibio_train', idx, new_facts, question, answer)
    idx, data = data_append(data, 'wikibio_train', idx, None, question, None)
    need_gpt4.append({'question': question, 'prompt_new': text_completion_prompt(question)})

    # Locality probes: only the first entry of each key, answered identically
    # with and without the new facts.
    if 'locality' in record:
        for probes in record['locality'].values():
            if len(probes) >= 1:
                probe = probes[0]
                question = probe['prompt']
                if question != "" and probe['ground_truth'] != []:
                    answer = probe['ground_truth'][0]
                    idx, data = data_append(data, 'wikibio_train', idx, new_facts, question, answer)
                    idx, data = data_append(data, 'wikibio_train', idx, None, question, answer)
                    out_of_scope_questions.append({'question': question, 'answer': answer})


### MQUAKE-CF / MQUAKE-T
def _format_requested_rewrites(requested_rewrite):
    """Render MQuAKE edit requests as new-facts text (numbered only when there are several)."""
    statements = [
        rewrite['prompt'].replace("{}", rewrite['subject']) + " " + rewrite['target_new']['str']
        for rewrite in requested_rewrite
    ]
    if len(statements) > 1:
        text = "\n".join(f"{position}. " + statement for position, statement in enumerate(statements, start=1))
    else:
        text = statements[0]
    return text.strip("\n")


def _format_hop_answer(final_answer, single_hops):
    """Render the final answer on the first line, followed by one "<cloze> <answer>. " per hop."""
    hop_text = "".join(hop['cloze'] + " " + hop['answer'] + ". " for hop in single_hops)
    return (final_answer + "\n" + hop_text).strip()


def build_mquake_samples(records, source, count, out_of_scope_pool, instruction_pool):
    """Build the conversations for one MQuAKE file.

    For each of the first `count` records this produces:
      * the multi-hop question with the edits shown (answer from 'new_answer' / 'new_single_hops'),
      * the same question without edits (answer from 'answer' / 'single_hops'),
      * one unrelated question asked both with and without the edits. With
        probability 0.5 it comes from `out_of_scope_pool` (answer known),
        otherwise from `instruction_pool` (answer None, queued for completion).

    Args:
        records: List of MQuAKE record dicts.
        source: Dataset tag used in the record ids.
        count: Maximum number of records to use; a shorter `records` list is
            used in full instead of raising IndexError.
        out_of_scope_pool: Dicts with 'question' and 'answer'.
        instruction_pool: Conversations whose first turn is the human instruction.

    Returns:
        Tuple ``(samples, gpt4_requests)``: the generated conversations and the
        prompts whose answers are still missing.
    """
    samples = []
    gpt4_requests = []
    idx = 0
    for record in records[:count]:
        new_facts = _format_requested_rewrites(record['requested_rewrite'])
        question = record['questions'][0]

        # with edit
        edited_answer = _format_hop_answer(record['new_answer'], record['new_single_hops'])
        idx, samples = data_append(samples, source, idx, new_facts, question, edited_answer)

        # without edit
        original_answer = _format_hop_answer(record['answer'], record['single_hops'])
        idx, samples = data_append(samples, source, idx, None, question, original_answer)

        # new_fact_unrelated
        if random.random() < 0.5:
            unrelated = random.choice(out_of_scope_pool)
            idx, samples = data_append(samples, source, idx, new_facts, unrelated['question'], unrelated['answer'])
            idx, samples = data_append(samples, source, idx, None, unrelated['question'], unrelated['answer'])
        else:
            first_turn = random.choice(instruction_pool)['conversations'][0]
            assert first_turn['from'] == "human"
            instruction = first_turn['value']
            idx, samples = data_append(samples, source, idx, new_facts, instruction, None)
            idx, samples = data_append(samples, source, idx, None, instruction, None)
            # One request covers both entries above: they share the same instruction.
            gpt4_requests.append({'question': instruction, 'prompt_new': instruction})
    return samples, gpt4_requests


with open("data/WizardLM/WizardLM_evol_instruct_V2_143k.json", 'r', encoding='utf-8') as input_file:
    wizardlm_data = json.load(input_file)

for mquake_path, mquake_source, mquake_count in [
    ("data/MQuAKE/MQuAKE-CF.json", 'mquake_cf', 2500),
    ("data/MQuAKE/MQuAKE-T.json", 'mquake_t', 1500),
]:
    with open(mquake_path, 'r', encoding='utf-8') as input_file:
        input_data = json.load(input_file)
    mquake_samples, mquake_requests = build_mquake_samples(
        input_data, mquake_source, mquake_count, out_of_scope_questions, wizardlm_data)
    data.extend(mquake_samples)
    need_gpt4.extend(mquake_requests)


### COUNTERFACT related question
idx = 0
with open("data/counterfact/counterfact_related_QA.json", 'r', encoding='utf-8') as input_file:
    input_data = json.load(input_file)

with open("data/WizardLM/WizardLM_evol_instruct_V2_143k.json", 'r', encoding='utf-8') as input_file:
    wizardlm_data = json.load(input_file)
# Shuffled copy, so record i below is paired with a random instruction.
wizardlm_data = random.sample(wizardlm_data, len(wizardlm_data))

# Exactly 7500 records are used from each file (a shorter file raises IndexError).
for i in range(7500):
    qa_record = input_data[i]
    new_facts = qa_record['new_fact']

    # Related question: the answer with the new fact is known, the no-context
    # answer is unknown (None) and queued for later completion.
    question = qa_record['question']
    answer = qa_record['answer']
    idx, data = data_append(data, 'counterfact_qa', idx, new_facts, question, answer)
    idx, data = data_append(data, 'counterfact_qa', idx, None, question, None)
    need_gpt4.append({'question': question, 'prompt_new': question_answering_prompt(question)})

    # Unrelated instruction: both answers are unknown; one request covers both entries.
    first_turn = wizardlm_data[i]['conversations'][0]
    question = first_turn['value']
    assert first_turn['from'] == 'human'
    idx, data = data_append(data, 'counterfact_qa', idx, new_facts, question, None)
    idx, data = data_append(data, 'counterfact_qa', idx, None, question, None)
    need_gpt4.append({'question': question, 'prompt_new': question})


### save data
# Serialise first, then open the file, so a serialisation error cannot leave a truncated file.
serialized_data = json.dumps(data, indent=4)
with open("LTE_train_data.json", mode='w', encoding='utf-8') as output_file:
    output_file.write(serialized_data)
print(len(data))

# One JSON object per line: the prompts whose answers still have to be generated.
with open("LTE_train_data_need_gpt4.jsonl", 'w', encoding='utf-8') as output_file:
    output_file.writelines(json.dumps(request) + "\n" for request in need_gpt4)


The file "LTE_train_data_need_gpt4.jsonl" contains every prompt whose answer must be generated by GPT-4. After collecting GPT-4's answers, fill them into the entries of "LTE_train_data.json" whose answer is still empty (`null`) to obtain the final training data.