In [9]:
import datasets
from datasets import concatenate_datasets
from tqdm import tqdm
import spacy
import json
from itertools import groupby, combinations
import re
from modelzipper.tutils import *
import itertools
import multiprocessing
import threading

In [10]:
# Load the processed records and the dataset whose rows carry an "all_ref_ids"
# field (read later in mp_process_item).
all_data = auto_read_data("/data/zecheng/data/process_wiki_document/one_hop/merged_data/processed_data_2.jsonl")
combine_index = datasets.load_from_disk("/data/zecheng/data/process_wiki_document/one_hop/merged_data/combine_data_hf_2")

# Index every record by its "id". The key is popped, so the stored dicts no longer
# contain an "id" entry; a later record with the same id replaces an earlier one.
all_data_dict = {record.pop("id"): record for record in all_data}

In [None]:
# Load the spaCy pipeline once at module level; extract_entities reuses this object.
nlp = spacy.load("en_core_web_lg")

def extract_entities(text):
    """Return the set of entity strings that the module-level ``nlp`` finds in ``text``.

    The "tagger" and "parser" components are disabled for this call.
    Duplicate mentions collapse because the result is a set.
    """
    parsed = nlp(text, disable=["tagger", "parser"])
    return {entity.text for entity in parsed.ents}

def judge_2hop_pair(prefix_q, prefix_a, suffix_q, suffix_a):
    """Decide whether two QA pairs can be chained as (prefix -> suffix).

    Returns True when the prefix answer shares at least one entity with the
    suffix question AND the suffix answer shares no entity with the prefix
    question; otherwise returns False.
    """
    prefix_q_ents = extract_entities(prefix_q)
    prefix_a_ents = extract_entities(prefix_a)
    suffix_q_ents = extract_entities(suffix_q)
    suffix_a_ents = extract_entities(suffix_a)

    # The prefix answer must mention something the suffix question asks about.
    bridges = not prefix_a_ents.isdisjoint(suffix_q_ents)
    # The suffix answer must not point back at an entity of the prefix question.
    loops_back = not suffix_a_ents.isdisjoint(prefix_q_ents)
    return bridges and not loops_back


def two_hop_judger(group_ids, meta_data, selected_cases=2):
    """Find every ordered (prefix_id, suffix_id) pair in a group that passes judge_2hop_pair.

    Args:
        group_ids: ids to combine; each must be a key of ``meta_data``.
        meta_data: mapping id -> record with 'question' and 'answer' entries.
        selected_cases: length of the permutations drawn from ``group_ids``;
            only the first two ids of each permutation are used.

    Returns:
        A tuple of (prefix_id, suffix_id) tuples, in permutation order.
    """
    matched = []
    for candidate in itertools.permutations(group_ids, selected_cases):
        prefix_id, suffix_id = candidate[0], candidate[1]
        head, tail = meta_data[prefix_id], meta_data[suffix_id]
        if judge_2hop_pair(head['question'], head['answer'], tail['question'], tail['answer']):
            matched.append((prefix_id, suffix_id))
    return tuple(matched)

def mp_process_item(item):
    """Worker wrapper: judge one row's "all_ref_ids" against the module-level all_data_dict.

    Returns a 2-tuple ``(accepted_pairs, all_ref_ids)``.
    """
    ref_ids = item["all_ref_ids"]
    accepted_pairs = two_hop_judger(ref_ids, all_data_dict, 2)
    return (accepted_pairs, ref_ids)

# NOTE(review): the multiprocessing run below is disabled by wrapping it in a bare
# string literal, so none of it executes. `two_hop_raw_data`, which the next cell
# iterates over, is only assigned inside this disabled code in the visible notebook;
# on a fresh kernel that cell would presumably fail with NameError. TODO confirm
# where the variable is expected to come from.
"""
two_hop_raw_data = []

with multiprocessing.Pool(processes=84) as pool:
    results = list(tqdm(pool.imap(mp_process_item, combine_index), total=len(combine_index)))
    two_hop_raw_data = []
    for result in results:
        two_hop_raw_data.append(result)

print("finish")
print(len(two_hop_raw_data))
"""

# Display the record stored under id 0 (assumes that id exists in all_data_dict).
all_data_dict[0]

In [None]:
# Flatten the per-group results into one record per unique (prefix_id, suffix_id)
# pair; the first group in which a pair appears supplies its "all_ref_ids".
processed_sample = []
tmp = set()  # pairs already emitted
for sample in two_hop_raw_data:
    accepted_pairs = sample[0]
    if not accepted_pairs:
        continue
    for selected_pair in accepted_pairs:
        if selected_pair in tmp:
            continue
        tmp.add(selected_pair)
        processed_sample.append(
            {
                "prefix_id": selected_pair[0],
                "suffix_id": selected_pair[1],
                "all_ref_ids": sample[1],
            }
        )

print(len(processed_sample))

In [None]:
# Attach the question/answer text of both hops to every selected pair and give
# each record a running integer "id" (its position in processed_sample).
final = []
for sample_idx, pair in enumerate(processed_sample):
    prefix_id = pair["prefix_id"]
    suffix_id = pair["suffix_id"]
    prefix_rec = all_data_dict[prefix_id]
    suffix_rec = all_data_dict[suffix_id]
    final.append(
        {
            "id": sample_idx,
            "prefix_q": prefix_rec["question"],
            "prefix_a": prefix_rec["answer"],
            "suffix_q": suffix_rec["question"],
            "suffix_a": suffix_rec["answer"],
            "prefix_id": prefix_id,
            "suffix_id": suffix_id,
            "all_ref_ids": pair["all_ref_ids"],
        }
    )

print(len(final))
final[0]

In [None]:
from volcenginesdkarkruntime import Ark
import threading

# Maps a model display name to the identifier that is passed as `model=` to the
# chat-completions call in `call_with_messages`. The "ep-..." values are
# presumably Ark inference endpoint ids — TODO confirm.
MODELS = {
    "doubao-lite-4k": "ep-20240618124048-xd5vm",
    "doubao-lite-32k": "ep-20240618163533-tzxts",
    "doubao-pro-4k": "ep-20240618125023-lkmzs",
    "doubao-pro-32k": "ep-20240618163715-nmkbp",
}


def create_prompt(prefix_q, prefix_a, suffix_q, suffix_a, prefix_id, suffix_id, use_context=True):
    """Build the LLM prompt that asks for a combined two-hop question/answer pair.

    Args:
        prefix_q, prefix_a: question and answer of the first hop (Q1).
        suffix_q, suffix_a: question and answer of the second hop (Q2).
        prefix_id, suffix_id: keys into the module-level ``all_data_dict``; only
            looked up when ``use_context`` is true.
        use_context: when True (default) the prompt includes each hop's
            "context" text; when False the context-free prompt variant is
            returned and ``all_data_dict`` is not touched.

    Returns:
        The prompt string. The model is asked to answer in the form
        ``###Q: new question ###A: new answer``.

    Raises:
        KeyError: if ``use_context`` is true and an id or its "context" entry is missing.
    """
    if not use_context:
        return f"I want to construct a two-hop question. Given the following two pieces of information, combine them to form a new natural language question according to the provided question and answer. Only give me the new question and do not output any other words.\n\nQ1:\n Question: {prefix_q}\nAnswer: {prefix_a}\n\nQ2:\n Question: {suffix_q}\nAnswer: {suffix_a}\n\nConstruct a new question and answer pair that utilizes information from both Q1 and Q2 questions and answers. The new question should be between 10-40 words long and the answer to the new question should be the related to the answer to Q2. Show me the new question and answer pair in the following format\n\n###Q: new question ###A: new answer\n\nDo not output any other words."

    # Interpolate only the context text. Previously the whole record dict was
    # formatted into the prompt, so its repr (including question/answer keys)
    # ended up under "Context:".
    reference_1 = all_data_dict[prefix_id]["context"]
    reference_2 = all_data_dict[suffix_id]["context"]

    return f"I want to construct a two-hop question. Given the following two pieces of information, combine them to form a new natural language question according to the provided question, answer and the context. Only give me the new question and do not output any other words.\n\nQ1:\nContext: {reference_1}\nQuestion: {prefix_q}\nAnswer: {prefix_a}\n\nQ2:\nContext: {reference_2}\nQuestion: {suffix_q}\nAnswer: {suffix_a}\n\nConstruct a new question and answer pair that utilizes information from both Q1 and Q2 questions, answers and contexts. The new question should be between 10-40 words long and the answer to the new question should be the related to the answer to Q2. Show me the new question and answer pair in the following format\n\n###Q: new question ###A: new answer\n\nDo not output any other words."
    
def call_with_messages(client, model_id, user_query, max_attempts=2, retry_delay=10):
    """Send one user message to the chat-completions API, retrying on failure.

    Args:
        client: object exposing ``chat.completions.create(...)``.
        model_id: value passed as ``model``.
        user_query: text sent as the single user message.
        max_attempts: total number of tries before giving up.
        retry_delay: seconds to sleep between failed tries (default 10).

    Returns:
        The content of the first choice's message, or None when every attempt
        raised an exception.
    """
    # Imported here so the function does not depend on `time` leaking in
    # through a star import elsewhere in the notebook.
    import time

    for attempt in range(1, max_attempts + 1):
        try:
            completion = client.chat.completions.create(
                model=model_id,
                messages=[
                    {"role": "user", "content": user_query},
                ],
                max_tokens=1024,
            )
            return completion.choices[0].message.content
        except Exception as e:  # best-effort: any failure counts as one attempt
            print(f"Attempt {attempt} failed: {e}")
            if attempt < max_attempts:
                time.sleep(retry_delay)  # back off before retrying

    print("Failed after maximum attempts.")
    return None


def process_chunk(chunk, chunk_length, thread_id, model_id, client):
    """Generate a combined QA pair for every sample of a chunk and append it to disk.

    For each sample a prompt is built with ``create_prompt`` and sent through
    ``call_with_messages``. On success the response is stored in
    ``sample["combined_qa"]`` (the sample dict is modified in place), the sample
    is appended as one JSON line to the per-thread output file, and its "id" is
    appended to the per-thread log file. Samples whose call returned None are
    skipped and therefore never logged.

    Args:
        chunk: iterable of sample dicts with the keys read below.
        chunk_length: total passed to the progress bar.
        thread_id: used in the output/log file names and the bar description.
        model_id: forwarded to ``call_with_messages``.
        client: forwarded to ``call_with_messages``.
    """
    import os

    output_file = f"/data/zecheng/data/process_wiki_document/two_hop/generated_QA_pairs_thread_{thread_id}.jsonl"
    log_file = f"/data/zecheng/data/process_wiki_document/two_hop/logs/processing_log_thread_{thread_id}.log"
    # Creating the logs directory also creates its parent, which holds the output file.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    with open(output_file, 'a') as f, open(log_file, 'a') as log_f:
        with tqdm(total=chunk_length, desc=f"Processing thread {thread_id}") as pbar:
            for sample in chunk:
                user_message = create_prompt(sample["prefix_q"], sample["prefix_a"], sample["suffix_q"], sample["suffix_a"], sample["prefix_id"], sample["suffix_id"])
                response = call_with_messages(client=client, model_id=model_id, user_query=user_message)
                if response is not None:
                    sample["combined_qa"] = response
                    # Flush after every record so the files stay usable if the run is interrupted.
                    f.write(json.dumps(sample) + '\n')
                    f.flush()
                    log_f.write(f"{sample['id']}\n")
                    log_f.flush()
                # Advance the bar for failed samples too; previously `continue`
                # skipped this and the bar never reached its total.
                pbar.update(1)


def split_data(data, n_splits):
    """Split ``data`` into ``n_splits`` contiguous chunks of near-equal size.

    Chunk sizes differ by at most one, with the larger chunks first. The
    previous ``len // n + 1`` sizing could leave trailing chunks empty even
    when there were enough items (8 items over 4 splits gave 3, 3, 2, 0).

    Args:
        data: a sliceable sequence supporting ``len``.
        n_splits: number of chunks to produce.

    Returns:
        A list of ``n_splits`` slices of ``data`` that concatenate back to
        ``data``. A negative ``n_splits`` yields an empty list.

    Raises:
        ZeroDivisionError: if ``n_splits`` is 0.
    """
    base_size, remainder = divmod(len(data), n_splits)
    chunks = []
    start = 0
    for split_idx in range(n_splits):
        # The first `remainder` chunks take one extra element.
        end = start + base_size + (1 if split_idx < remainder else 0)
        chunks.append(data[start:end])
        start = end
    return chunks

import os

testing_cases = final[:len(final)]
# testing_cases = remain_cases
# The API key is read from the environment (KeyError if ARK_API_KEY is unset)
# instead of being embedded in the notebook. NOTE(review): the key that used to
# be hard-coded here should be treated as leaked and rotated.
client = Ark(api_key=os.environ["ARK_API_KEY"])
model_name = "doubao-pro-32k"

# Fan the samples out over worker threads; each thread appends to its own
# output and log file (see process_chunk) and all threads share `client`.
num_threads = 4
data_splits = split_data(testing_cases, num_threads)
threads = []

for i in range(num_threads):
    chunk_length = len(data_splits[i])
    thread = threading.Thread(target=process_chunk, args=(data_splits[i], chunk_length, i, MODELS[model_name], client))
    threads.append(thread)
    thread.start()

# Block until every worker has finished its chunk.
for thread in threads:
    thread.join()


# Disabled single-threaded variant, kept for reference (bare string, never executed).
"""
testing_cases = final[:len(final)]
client = Ark(api_key=os.environ["ARK_API_KEY"])
model_name = "doubao-pro-32k"
with open("/data/zecheng/data/process_wiki_document/two_hop/doubao_data.jsonl", "a") as f, tqdm(total=len(testing_cases)) as pbar:
    for item in testing_cases:
        user_message = create_prompt(item["prefix_q"], item["prefix_a"], item["suffix_q"], item["suffix_a"], item["prefix_id"], item["suffix_id"]) 
        back_message = call_with_messages(client, MODELS[model_name], user_message)
        item["combined_qa"] = back_message
        f.write(json.dumps(item) + "\n")
        pbar.update(1)
"""

In [None]:
# Work out which sample ids have not been logged as processed yet.
tmp = auto_read_dir("/data/zecheng/data/process_wiki_document/two_hop/logs")
all_process_ids = []

for file in tmp:
    with open(file) as f:
        # Each log line holds one processed sample id.
        all_process_ids.extend(int(line.strip()) for line in f)

all_indexs = set(range(len(final)))
remain_indexs = all_indexs.difference(all_process_ids)
print(len(remain_indexs))
print(remain_indexs)

In [None]:
# Read every file returned for the two_hop directory and flatten the records
# into one list. Note that this rebinds `all_data`, which earlier held the
# records used to build all_data_dict; all_data_dict itself is not rebuilt.
all_files = auto_read_dir("/data/zecheng/data/process_wiki_document/two_hop")
all_data = list(itertools.chain.from_iterable(auto_read_data(path) for path in all_files))
print(len(all_data))

In [None]:
# Show the first loaded record to check its fields.
all_data[0]

In [None]:
def process_twohopqa(item):
    """Parse ``item["combined_qa"]`` of the form ``###Q: ... ###A: ...`` into two fields.

    The text is split on "###" and must yield exactly three parts. The first
    two characters of the second and third part (the "Q:" / "A:" labels) are
    dropped and the remainder is stripped.

    Args:
        item: dict with a "combined_qa" entry; it is modified in place.

    Returns:
        ``item`` with "combined_question" (the question) and "final_question"
        added, or None when "combined_qa" is not a string or does not split
        into three parts. Despite its name, "final_question" holds the ANSWER
        text; ``add_ref_text`` later renames it to "final_answer", so the key
        is kept as-is here.

    Raises:
        KeyError: if ``item`` has no "combined_qa" entry.
    """
    combined_qa = item['combined_qa']
    # A failed generation may have stored None; treat it as unparsable instead
    # of crashing on `.split`.
    if not isinstance(combined_qa, str):
        return None
    parts = combined_qa.split("###")
    if len(parts) != 3:
        return None
    item["combined_question"] = parts[1][2:].strip()
    item["final_question"] = parts[2][2:].strip()
    return item

# Keep only the records whose generated text could be parsed (process_twohopqa
# returns None otherwise).
final_data = [parsed for parsed in map(process_twohopqa, all_data) if parsed]

print(f"all_data: {len(all_data)} | final_data: {len(final_data)}")

In [None]:
# Show the first successfully parsed record.
final_data[0]

In [None]:
from collections import defaultdict
import datasets
# Pivot the list of record dicts into a dict of columns (key -> list of values)
# as expected by Dataset.from_dict, then persist the dataset.
transformed_data = {}

for record in final_data:
    for column, cell in record.items():
        transformed_data.setdefault(column, []).append(cell)

hf_dataset = datasets.Dataset.from_dict(transformed_data)
hf_dataset.save_to_disk("/data/zecheng/data/process_wiki_document/two_hop/hf_datasets")

In [None]:
import datasets

# Reload the dataset that was just saved and print two rows as a sanity check.
data = datasets.load_from_disk("/data/zecheng/data/process_wiki_document/two_hop/hf_datasets")

print(data[1])
print(data[2])

#### Create Rejected / Chosen Pairs

In [19]:
from pprint import pprint
import datasets

# Load the saved two-hop dataset; it is mapped through add_ref_text below.
twohop_dataset = datasets.load_from_disk("/data/zecheng/data/process_wiki_document/two_hop/hf_datasets")

def reorder_ids(ids_list, prefix_id, suffix_id):
    """Return the ids ordered so that ``prefix_id`` comes before ``suffix_id``.

    If either id is missing, or the prefix already precedes the suffix (or
    both ids are the same), the order is left unchanged. Otherwise the first
    occurrence of ``prefix_id`` is moved to sit directly before ``suffix_id``.

    The input is never modified; a new list is always returned. Previously the
    caller's list was mutated in place, and equal ids that occurred only once
    raised ValueError.

    Args:
        ids_list: sequence of ids.
        prefix_id: id that must come first.
        suffix_id: id that must come after ``prefix_id``.

    Returns:
        A new list with the (possibly) reordered ids.
    """
    reordered = list(ids_list)
    if prefix_id not in reordered or suffix_id not in reordered:
        return reordered
    if reordered.index(prefix_id) <= reordered.index(suffix_id):
        return reordered
    reordered.remove(prefix_id)
    # Look the suffix up again: removing the prefix may have shifted its position.
    reordered.insert(reordered.index(suffix_id), prefix_id)
    return reordered

def add_ref_text(item):
    """Row mapper: order the reference ids, attach their context text, tidy the keys.

    Steps applied to ``item`` (modified and returned):
      * "all_ref_ids" is replaced by the result of ``reorder_ids``.
      * "all_ref_text" gets the "context" of each id from the module-level
        ``all_data_dict``, in the same order.
      * "final_question" is renamed to "final_answer".
      * "combined_qa" is removed.
    """
    ordered_ids = reorder_ids(item["all_ref_ids"], item["prefix_id"], item["suffix_id"])
    item["all_ref_ids"] = ordered_ids
    item["all_ref_text"] = [all_data_dict[ref_id]["context"] for ref_id in ordered_ids]
    item["final_answer"] = item.pop("final_question")
    item.pop("combined_qa")
    return item

# Apply add_ref_text to every row (num_proc=16). add_ref_text reads the
# module-level all_data_dict, so that dict must already be built in this kernel.
new_twohop_dataset = twohop_dataset.map(add_ref_text, num_proc=16)
# Print the resulting column names, then persist the mapped dataset.
pprint(new_twohop_dataset[0].keys())
new_twohop_dataset.save_to_disk("/data/zecheng/data/process_wiki_document/two_hop/hf_dataset_step2")

dict_keys(['id', 'prefix_q', 'prefix_a', 'suffix_q', 'suffix_a', 'prefix_id', 'suffix_id', 'all_ref_ids', 'combined_question', 'all_ref_text', 'final_answer'])


Saving the dataset (0/1 shards):   0%|          | 0/878 [00:00<?, ? examples/s]

In [20]:
# Inspect the first mapped example.
new_twohop_dataset[0]

{'id': 840,
 'prefix_q': 'What is the potential annual generating capacity from agricultural waste in Scotland for biogas?',
 'prefix_a': 'It is estimated that 0.4 GW of generating capacity might be available from agricultural waste in Scotland.',
 'suffix_q': 'What is the capacity of installed renewable electricity in Scotland at the end of 2015?',
 'suffix_a': '7,723 megawatts (MW)',
 'prefix_id': 7422,
 'suffix_id': 7421,
 'all_ref_ids': [7420, 7422, 7421, 7423],
 'combined_question': 'How does the capacity of installed renewables in 2015 compare to potential tidal power in Scotland?',
 'all_ref_text': ['in 1941, is thought by some to be an early example, with his pink or transparent costume. Writer Roy Thomas penned thought balloons that suggested Firebrand had been involved in a gay relationship with his sidekick and bodyguard Slugger Dunn, although these hints never moved beyond subtext. A more modern example is the violent vigilante superhero Midnighter. The Batman like Midnight

In [23]:
def combine_fn(lst, max_candidates=2, rng=None):
    """Concatenate one tensor from each sub-list, for every combination.

    Sub-lists longer than ``max_candidates`` are first reduced to a random
    sample of ``max_candidates`` elements; shorter ones are used as they are.
    The Cartesian product of the (trimmed) sub-lists is then taken and each
    combination is joined with ``torch.cat``.

    Args:
        lst: list of lists of tensors.
        max_candidates: maximum number of elements kept per sub-list.
        rng: optional object with a ``sample(population, k)`` method (for
            example ``random.Random(seed)``) for reproducible sampling;
            defaults to the global ``random`` module.

    Returns:
        A list of concatenated tensors, in product order. Empty when ``lst``
        is empty (previously this case raised inside ``torch.cat``) or when
        any sub-list is empty.
    """
    # Local imports: the notebook does not import these names explicitly.
    import random
    import torch

    if not lst:
        return []

    sampler = random if rng is None else rng
    trimmed_lists = [
        sampler.sample(sublst, max_candidates) if len(sublst) > max_candidates else sublst
        for sublst in lst
    ]
    return [torch.cat(combination) for combination in itertools.product(*trimmed_lists)]

# Small demo input: three groups of single-element tensors.
lst = [
    [torch.tensor([value]) for value in group]
    for group in ([1, 2, 3], [4, 5], [6, 7, 8])
]

# Groups with more than two tensors are randomly trimmed inside combine_fn, so
# the printed combinations can differ between runs.
results = combine_fn(lst)
for result in results:
    print(result)

tensor([1, 4, 7])
tensor([1, 4, 8])
tensor([1, 5, 7])
tensor([1, 5, 8])
tensor([2, 4, 7])
tensor([2, 4, 8])
tensor([2, 5, 7])
tensor([2, 5, 8])
