In [None]:
import json
import jsonlines
import pandas as pd
from tqdm.auto import tqdm
import re
import random

In [None]:
# Prepare CWQ train data for BM25 search: attach the final sub-question from the
# GPT decomposition trace to each question, then build the BM25 query
# "question final_q answer1,answer2,...".

train_data_path = 'your_cwq_train_data_path'
with open(train_data_path) as f:
    data = json.load(f)
print('Data length:', len(data))
# Keyed by question text; duplicate questions collapse to the last record.
data_dic = {i['question']: i for i in data}

gpt_data = pd.read_csv('your_cwq_gpt_results_path', sep='\t')
total = 0  # number of questions for which a final sub-question was extracted
for question, gpt_out in zip(gpt_data['question'], gpt_data['gpt_out']):
    assert question in data_dic, question
    # Empty TSV cells are read back as NaN (a float), not a string.
    if not isinstance(gpt_out, str):
        continue
    # The trace ends with "Action <k> Finish"; the final sub-question is the
    # Question[...] / Multi_Answer_Question[...] issued at step k-1.
    # \d+ (not a single digit) so traces with 10+ steps are not silently dropped.
    finish = re.search(r'Action (\d+) Finish', gpt_out)
    if finish is None:
        continue
    step = int(finish.group(1)) - 1
    final = re.search(
        r'Action ' + str(step) + r' (?:Question|Multi_Answer_Question)\[(.*?)\]',
        gpt_out,
    )
    if final is None:
        continue
    data_dic[question]['final_q'] = final.group(1).strip()
    total += 1

print(total)
data = list(data_dic.values())
for item in data:
    item.setdefault('final_q', '')
    query = item['question'] + ' ' + item['final_q'] + ' ' + ','.join(item['answers'])
    # Truncate to 1020 characters, as the WebQSP-train and CWQ-dev cells do
    # (presumably a BM25 query-length limit).
    item['query'] = query[:1020]

with open('cwq_train_data_with_final_q.json', 'w') as f:
    json.dump(data, f, indent=2)

In [None]:
# Prepare WebQSP train data for BM25 search: attach the final sub-question from
# the GPT decomposition trace to each question, then build the BM25 query
# "question final_q answer1,answer2,...".

train_data_path = 'your_webqsp_train_data_path'
with open(train_data_path) as f:
    data = json.load(f)
print('Data length:', len(data))
# Keyed by question text; duplicate questions collapse to the last record.
data_dic = {i['question']: i for i in data}

with open('your_webqsp_gpt_results_path') as f:
    gpt_data = json.load(f)
total = 0  # number of questions for which a final sub-question was extracted
for record in gpt_data:
    question = record['question']
    # Upper-case the first character to match the dataset's question keys;
    # slicing (not question[0]) keeps an empty string from raising IndexError.
    key = question[:1].upper() + question[1:]
    assert key in data_dic, question
    gpt_out = record.get('gpt_out')
    if not isinstance(gpt_out, str):
        continue
    # The trace ends with "Action <k> Finish"; the final sub-question is the
    # Question[...] / Multi_Answer_Question[...] issued at step k-1.
    # \d+ (not a single digit) so traces with 10+ steps are not silently dropped.
    finish = re.search(r'Action (\d+) Finish', gpt_out)
    if finish is None:
        continue
    step = int(finish.group(1)) - 1
    final = re.search(
        r'Action ' + str(step) + r' (?:Question|Multi_Answer_Question)\[(.*?)\]',
        gpt_out,
    )
    if final is None:
        continue
    data_dic[key]['final_q'] = final.group(1).strip()
    total += 1

print(total)
data = list(data_dic.values())
for item in data:
    item.setdefault('final_q', '')
    query = item['question'] + ' ' + item['final_q'] + ' ' + ','.join(item['answers'])
    # Truncate to 1020 characters (presumably a BM25 query-length limit).
    item['query'] = query[:1020]

with open('webqsp_train_data_with_final_q.json', 'w') as f:
    json.dump(data, f, indent=2)

In [None]:
# Prepare CWQ dev data for BM25 search. No GPT trace is attached here, so
# final_q stays empty unless the dev file already provides one.

dev_data_path = 'your_cwq_dev_data_path'
with open(dev_data_path) as f:
    data = json.load(f)
print('Data length:', len(data))

for item in data:
    item.setdefault('final_q', '')
    query = item['question'] + ' ' + item['final_q'] + ' ' + ','.join(item['answers'])
    # Truncate to 1020 characters (presumably a BM25 query-length limit).
    item['query'] = query[:1020]

with open('cwq_dev_data_with_final_q.json', 'w') as f:
    json.dump(data, f, indent=2)

In [None]:
import spacy

# spaCy's small English pipeline; the selection cell reads token.pos_ from it
# to find proper nouns (PROPN) in each query.
SPACY_MODEL = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL)

In [None]:
# BM25 retrieval results for the WebQSP training questions (presumably produced
# by searching with the queries written to webqsp_train_data_with_final_q.json).
with open('webqsp_train_bm25_results.json') as bm25_file:
    bm25_results = json.load(bm25_file)

In [None]:
# Select positive and hard-negative passages for DPR training.
#
# Each retrieved passage is bucketed by whether it mentions a query entity
# (a PROPN token from question + final_q) and/or an answer string:
#   * answer + query entity -> positive
#   * answer only           -> positive until an answer has `positives_per_answer`
#                              positives; up to the last `positives_per_answer` of
#                              the remainder become hard negatives
#   * query entity only     -> hard negative (first `use_has_query_number`)
#   * neither               -> fallback negatives, used only if no hard negative survives

max_positive = 100          # cap on positive passages kept per question
max_hard = 100              # cap on hard-negative passages kept per question
use_has_query_number = 70   # max "query entity, no answer" passages used as hard negatives
use_normal_number = 30      # fallback negatives when no hard negative is found
positives_per_answer = 5    # target number of positives per answer string


def _query_entities(text):
    """Return the text of every token that spaCy tags as a proper noun (PROPN)."""
    return [token.text for token in nlp(text) if token.pos_ == "PROPN"]


def _mentions_any(text, terms):
    """Return True if any of `terms` occurs in `text` as a case-sensitive substring."""
    return any(term in text for term in terms)


def _dedup_passages(ctxs):
    """Reduce ctxs to {'title', 'text'} dicts, dropping duplicate (title, text) pairs.

    Keeps the first occurrence and preserves input order. Unlike a set(), this
    keeps any preceding score sort intact, so later truncation is deterministic
    across runs.
    """
    seen = set()
    unique = []
    for ctx in ctxs:
        key = (ctx['title'], ctx['text'])
        if key not in seen:
            seen.add(key)
            unique.append({'title': ctx['title'], 'text': ctx['text']})
    return unique


def _select_ctxs(result):
    """Build one DPR training example from a single BM25 result.

    Args:
        result: dict with 'id', 'question', 'final_q', 'answers' (list of str)
            and 'ctxs' (list of dicts with at least 'title', 'text', 'score').

    Returns:
        dict with 'id', 'question', 'answers', 'positive_ctxs',
        'negative_ctxs' (always empty) and 'hard_negative_ctxs'.
    """
    query_entities = _query_entities(result['question'] + ' ' + result['final_q'])
    answers = result['answers']
    # Sorted copy, so bm25_results itself is not mutated.
    # NOTE(review): ascending order puts the lowest scores first, so the slices
    # below favour low-scoring passages, while hard negatives are later ranked
    # highest-score-first. Confirm which direction is intended.
    ctxs = sorted(result['ctxs'], key=lambda x: x['score'])

    answer_ctxs = {answer: [] for answer in answers}
    has_query_ctxs = []
    normal_ctxs = []
    for ctx in ctxs:
        ctx_text = ctx['text']
        a_ok = False
        for answer in answers:
            if answer in ctx_text:
                a_ok = True
                answer_ctxs[answer].append(ctx)
        if a_ok:
            continue  # answer-bearing passages are handled per answer below
        if _mentions_any(ctx_text, query_entities):
            has_query_ctxs.append(ctx)
        else:
            normal_ctxs.append(ctx)

    positive_ctxs = []
    hard_negative_ctxs = has_query_ctxs[:use_has_query_number]
    for matched in answer_ctxs.values():
        # Partition rather than list.remove() while iterating, which skipped the
        # element following every removal.
        with_query = [c for c in matched if _mentions_any(c['text'], query_entities)]
        without_query = [c for c in matched if not _mentions_any(c['text'], query_entities)]
        positive_ctxs += with_query
        n_missing = positives_per_answer - len(with_query)
        if n_missing > 0:
            positive_ctxs += without_query[:n_missing]
            # Up to the last `positives_per_answer` of the rest, never overlapping
            # the slice just taken as positives.
            start = max(len(without_query) - positives_per_answer, n_missing)
            hard_negative_ctxs += without_query[start:]

    positives = _dedup_passages(positive_ctxs)
    positive_keys = {(p['title'], p['text']) for p in positives}
    ranked_hard = sorted(hard_negative_ctxs, key=lambda x: x['score'], reverse=True)
    # A passage containing several answers can be a positive for one answer and
    # a tail negative for another; never label it both ways.
    hard_negatives = [
        hn for hn in _dedup_passages(ranked_hard)
        if (hn['title'], hn['text']) not in positive_keys
    ][:max_hard]
    if not hard_negatives:
        hard_negatives = normal_ctxs[:use_normal_number]

    return {
        'id': result['id'],
        'question': result['question'],
        'answers': answers,
        "positive_ctxs": positives[:max_positive],
        "negative_ctxs": [],
        "hard_negative_ctxs": hard_negatives,
    }


train_data = [_select_ctxs(result) for result in tqdm(bm25_results)]

# Examples usable for training: at least one positive and one hard negative.
num = sum(1 for example in train_data
          if example['positive_ctxs'] and example['hard_negative_ctxs'])
print(num)

n_examples = max(len(train_data), 1)  # avoid ZeroDivisionError on empty input
mean_positive = sum(len(example['positive_ctxs']) for example in train_data) / n_examples
mean_hard = sum(len(example['hard_negative_ctxs']) for example in train_data) / n_examples
print(mean_positive, mean_hard)

with open('webqsp_dpr_train_data.json', 'w') as out_file:
    json.dump(train_data, out_file, indent=2)