In [None]:
import datasets
import collections
import random

In [3]:
# data_all_list = []

In [4]:
def save(data, save_path):
    """Write QA items to ``save_path`` as JSON Lines in chat format.

    Each input item must have ``'id'``, ``'context'`` and ``'qa_pairs'`` keys,
    where ``'qa_pairs'`` is a list of ``{'question': ..., 'answer': ...}``
    dicts. Every pair becomes a user turn followed by an assistant turn, so one
    output line looks like::

        {"id": ..., "context": ..., "conversations": [{"role": "user", ...}, {"role": "assistant", ...}, ...]}

    Parent directories are created when needed; an existing file is overwritten.

    Args:
        data: iterable of item dicts as described above.
        save_path: destination path. A bare filename (no directory component)
            is written into the current working directory.
    """
    import json, os
    parent_dir = os.path.dirname(save_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        for item in data:
            conversations = []
            for qa in item['qa_pairs']:
                conversations.append({"role": "user", "content": qa['question']})
                conversations.append({"role": "assistant", "content": qa['answer']})
            record = {"id": item['id'], "context": item['context'], "conversations": conversations}
            f.write(json.dumps(record) + "\n")

In [None]:
# COQA ('stanfordnlp/coqa')
# Each CoQA story becomes the gold passage of one output item. Between 0 and
# `num_context` other stories are mixed in as distractor passages, and the QA
# turns about the gold story are kept, truncated to a word budget.
output_list = []
data = datasets.load_dataset('stanfordnlp/coqa')
data = list(data['train'])

random.seed(0)

item_list = []
for item in data:
    # `questions` and `answers["input_text"]` are parallel lists (question k <-> answer k).
    item_list.append({
        'context': item['story'],
        'qa_pairs': [{'question': q, 'answer': a} for q, a in zip(item['questions'], item['answers']["input_text"])]
    })

num_context = 8           # upper bound on distractor passages per output item
num_merge = 1             # gold passages (with their QA turns) per output item
max_context_words = 4096  # items whose merged context is longer are skipped
max_qa_words = 512        # word budget for all questions + answers of one item
for i in range(0, len(item_list), num_merge):
    if i + num_merge > len(item_list):
        break
    gold_idx = range(i, i + num_merge)
    # Draw distractors by index and drop the gold indices, so a gold passage is
    # never duplicated inside its own context. Oversampling by num_merge leaves
    # enough candidates after the gold ones are filtered out.
    num_distractors = random.randint(0, num_context)
    candidates = random.sample(range(len(item_list)), min(len(item_list), num_distractors + num_merge))
    distractor_idx = [j for j in candidates if j not in gold_idx][:num_distractors]

    context_list = [item_list[j]['context'] for j in gold_idx]
    context_list += [item_list[j]['context'] for j in distractor_idx]
    random.shuffle(context_list)
    context = "\n\n".join(context_list)
    if len(context.split()) > max_context_words:
        print(f"skip context len: {len(context.split())}")
        continue

    qa_pairs_list = [item_list[j]['qa_pairs'] for j in gold_idx]
    random.shuffle(qa_pairs_list)
    # Keep the longest prefix of QA turns whose question+answer word count fits the budget.
    qa_pairs = []
    words_left = max_qa_words
    for qa in sum(qa_pairs_list, []):
        words_left -= len(qa["question"].split()) + len(qa["answer"].split())
        if words_left < 0:
            break
        qa_pairs.append(qa)
    if not qa_pairs:
        # An item without a single QA turn would be saved with an empty conversation.
        print("skip item: first QA turn exceeds the QA word budget")
        continue

    output_list.append({
        "id": f"coqa.{len(output_list):04d}",
        'context': context,
        'qa_pairs': qa_pairs
    })

print(f"COQA: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/coqa.jsonl')

In [None]:
# DROP ('ucinlp/drop')
# DROP has several questions per passage: questions are grouped by passage so
# each passage becomes one gold item, then 0..num_context distractor passages
# are mixed in and the QA turns are truncated to a word budget.
output_list = []
data = datasets.load_dataset('ucinlp/drop')
data = list(data['train'])
random.seed(0)

# group question rows by passage text (dict keeps first-seen order)
context_to_item = {}
for item in data:
    context_to_item.setdefault(item['passage'], []).append(item)

item_list = []
for context, items in context_to_item.items():
    qa_pairs = []
    for item in items:
        spans = item['answers_spans']['spans']
        if not spans:
            # no reference answer text for this question
            continue
        # only the first reference span is used as the answer
        qa_pairs.append({'question': item['question'], 'answer': spans[0]})
    if not qa_pairs:
        continue
    item_list.append({
        'context': context,
        'qa_pairs': qa_pairs,
    })

num_context = 8           # upper bound on distractor passages per output item
num_merge = 1             # gold passages (with their QA turns) per output item
max_context_words = 4096  # items whose merged context is longer are skipped
max_qa_words = 512        # word budget for all questions + answers of one item
for i in range(0, len(item_list), num_merge):
    if i + num_merge > len(item_list):
        break
    gold_idx = range(i, i + num_merge)
    # Draw distractors by index and drop the gold indices, so a gold passage is
    # never duplicated inside its own context.
    num_distractors = random.randint(0, num_context)
    candidates = random.sample(range(len(item_list)), min(len(item_list), num_distractors + num_merge))
    distractor_idx = [j for j in candidates if j not in gold_idx][:num_distractors]

    context_list = [item_list[j]['context'] for j in gold_idx]
    context_list += [item_list[j]['context'] for j in distractor_idx]
    random.shuffle(context_list)
    context = "\n\n".join(context_list)
    if len(context.split()) > max_context_words:
        print(f"skip context len: {len(context.split())}")
        continue

    qa_pairs_list = [item_list[j]['qa_pairs'] for j in gold_idx]
    random.shuffle(qa_pairs_list)
    # Keep the longest prefix of QA turns whose question+answer word count fits the budget.
    qa_pairs = []
    words_left = max_qa_words
    for qa in sum(qa_pairs_list, []):
        words_left -= len(qa["question"].split()) + len(qa["answer"].split())
        if words_left < 0:
            break
        qa_pairs.append(qa)
    if not qa_pairs:
        # An item without a single QA turn would be saved with an empty conversation.
        print("skip item: first QA turn exceeds the QA word budget")
        continue

    output_list.append({
        "id": f"drop.{len(output_list):04d}",
        'context': context,
        'qa_pairs': qa_pairs
    })

print(f"DROP: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/drop.jsonl')

In [None]:
# narrativeqa ('deepmind/narrativeqa')
# Questions are grouped by story summary; each summary (only the summary text,
# not the full document) becomes one item whose QA pairs are truncated to a
# word budget.
output_list = []
data = datasets.load_dataset('deepmind/narrativeqa')
data = list(data['train'])
context_to_item = {}
for item in data:
    context_to_item.setdefault(item['document']['summary']['text'], []).append(item)

max_context_words = 2048  # summaries longer than this are skipped
max_qa_words = 512        # word budget for all questions + answers of one item
for i, (context, items) in enumerate(context_to_item.items()):
    if len(context.split()) > max_context_words:
        print(f"skip context len: {len(context.split())}")
        continue
    # Keep the longest prefix of QA pairs whose question+answer word count fits the budget.
    qa_pairs = []
    words_left = max_qa_words
    for item in items:
        # only the first reference answer is used
        qa = {'question': item['question']['text'], 'answer': item['answers'][0]['text']}
        words_left -= len(qa["question"].split()) + len(qa["answer"].split())
        if words_left < 0:
            break
        qa_pairs.append(qa)
    if not qa_pairs:
        # An item without a single QA pair would be saved with an empty conversation.
        print("skip item: first QA pair exceeds the QA word budget")
        continue
    output_list.append({
        "id": f"narrativeqa.{i:04d}",
        'context': context,
        'qa_pairs': qa_pairs
    })
print(f"narrativeqa: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/narrativeqa.jsonl')

In [None]:
# PubMed ('qiaojin/PubMedQA')
# Every labeled PubMedQA abstract becomes one item with a single QA pair whose
# answer reads "<Final decision>. <long answer>". Abstracts that are too long,
# or whose QA pair exceeds the word budget, are dropped.
output_list = []
data = datasets.load_dataset('qiaojin/PubMedQA', 'pqa_labeled')
data = list(data['train'])

# group rows by their newline-joined abstract paragraphs (first-seen order)
context_to_item = {}
for row in data:
    context_to_item.setdefault("\n".join(row['context']['contexts']), []).append(row)

for i, (context, rows) in enumerate(context_to_item.items()):
    n_context_words = len(context.split())
    if n_context_words > 2048:
        print(f"skip context len: {n_context_words}")
        continue
    # each abstract is expected to belong to exactly one labeled row
    assert len(rows) == 1
    qa_pairs = [
        {'question': row['question'], 'answer': f'{row["final_decision"].capitalize()}. {row["long_answer"]}'}
        for row in rows
    ]
    q_lens = [len(qa['question'].split()) for qa in qa_pairs]
    a_lens = [len(qa['answer'].split()) for qa in qa_pairs]
    if sum(q_lens) + sum(a_lens) > 512:
        print(f"q eln: {q_lens}, a len: {a_lens}")
        continue
    output_list.append({
        "id": f"pubmedqa.{i:04d}",
        'context': context,
        'qa_pairs': qa_pairs,
    })
print(f"pubmedqa: {len(output_list)}, context len: {sum(len(x['context'].split()) for x in output_list) / len(output_list)}")
print(collections.Counter(len(entry["qa_pairs"]) for entry in output_list))
save(output_list, 'v4/pubmedqa.jsonl')

In [None]:
# quail ('textmachinelab/quail')
# Questions are grouped by passage. The text of the correct multiple-choice
# option is used as the answer; questions whose correct option is
# "not enough information" are dropped. A passage whose QA exceeds the word
# budget is dropped entirely (not truncated).
output_list = []
data = datasets.load_dataset('textmachinelab/quail')
data = list(data['train'])
context_to_item = {}
for row in data:
    context_to_item.setdefault(row['context'], []).append(row)

for i, (context, rows) in enumerate(context_to_item.items()):
    n_context_words = len(context.split())
    if n_context_words > 2048:
        print(f"skip context len: {n_context_words}")
        continue
    qa_pairs = []
    for row in rows:
        answer = row['answers'][row['correct_answer_id']]
        if answer != "not enough information":
            qa_pairs.append({'question': row['question'], 'answer': answer})
    if not qa_pairs:
        continue
    q_lens = [len(qa['question'].split()) for qa in qa_pairs]
    a_lens = [len(qa['answer'].split()) for qa in qa_pairs]
    if sum(q_lens) + sum(a_lens) > 512:
        print(f"q eln: {q_lens}, a len: {a_lens}")
        continue
    output_list.append({
        "id": f"quail.{i:04d}",
        'context': context,
        'qa_pairs': qa_pairs,
    })
print(f"quail: {len(output_list)}, context len: {sum(len(x['context'].split()) for x in output_list) / len(output_list)}")
print(collections.Counter(len(entry["qa_pairs"]) for entry in output_list))
save(output_list, 'v4/quail.jsonl')

In [9]:
# # squad_v2 ('rajpurkar/squad_v2')
# output_list = []
# data = datasets.load_dataset('rajpurkar/squad_v2')
# data = list(data['train'])
# context_to_item = {}
# for item in data:
#     context = item['context']
#     if context not in context_to_item:
#         context_to_item[context] = []
#     context_to_item[context].append(item)
# for i, context in enumerate(context_to_item):
#     if len(context.split()) > 1024:
#         print(f"skip context len: {len(context.split())}")
#         continue
#     qa_pairs = []
#     for item in context_to_item[context]:
#         if len(item['answers']["text"]) == 0:
#             continue
#         question = item['question']
#         answer = item['answers']["text"][0]
#         qa_pairs.append({'question': question, 'answer': answer})
#     if len(qa_pairs) == 0:
#         continue
#     output_list.append({
#         "id": f"squad_v2.{i:04d}",
#         'context': context,
#         'qa_pairs': qa_pairs
#     })
# print(f"squad_v2: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / len(output_list)}")
# print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
# data_all_list.extend(output_list)
# # output_list

In [None]:
# ms marco ('microsoft/ms_marco')
# Each query is attached to the first of its passages marked is_selected, and
# queries sharing that passage are grouped into one gold item. Each output item
# merges `num_merge` consecutive gold items plus 0..num_context random
# distractor passages, with QA turns truncated to a word budget.
output_list = []
data = datasets.load_dataset('microsoft/ms_marco', 'v1.1')
data = list(data['train'])
random.seed(0)

context_to_item = {}
for item in data:
    passages = item['passages']
    if sum(passages['is_selected']) == 0:
        # no passage is marked is_selected for this query; skip it
        continue
    context = passages['passage_text'][passages['is_selected'].index(1)]
    context_to_item.setdefault(context, []).append(item)

item_list = []
for context, items in context_to_item.items():
    # queries without answers are dropped; only the first answer string is used
    qa_pairs = [{'question': item['query'], 'answer': item['answers'][0]} for item in items if item['answers']]
    if not qa_pairs:
        continue
    item_list.append({
        'context': context,
        'qa_pairs': qa_pairs,
    })

num_context = 32          # upper bound on distractor passages per output item
num_merge = 4             # gold passages (with their QA turns) per output item
max_context_words = 4096  # items whose merged context is longer are skipped
max_qa_words = 512        # word budget for all questions + answers of one item
for i in range(0, len(item_list), num_merge):
    if i + num_merge > len(item_list):
        break
    gold_idx = range(i, i + num_merge)
    # Draw distractors by index and drop the gold indices, so a gold passage is
    # never duplicated inside its own context. Oversampling by num_merge leaves
    # enough candidates after the gold ones are filtered out.
    num_distractors = random.randint(0, num_context)
    candidates = random.sample(range(len(item_list)), min(len(item_list), num_distractors + num_merge))
    distractor_idx = [j for j in candidates if j not in gold_idx][:num_distractors]

    context_list = [item_list[j]['context'] for j in gold_idx]
    context_list += [item_list[j]['context'] for j in distractor_idx]
    random.shuffle(context_list)
    context = "\n\n".join(context_list)
    if len(context.split()) > max_context_words:
        print(f"skip context len: {len(context.split())}")
        continue

    qa_pairs_list = [item_list[j]['qa_pairs'] for j in gold_idx]
    random.shuffle(qa_pairs_list)
    # Keep the longest prefix of QA turns whose question+answer word count fits the budget.
    qa_pairs = []
    words_left = max_qa_words
    for qa in sum(qa_pairs_list, []):
        words_left -= len(qa["question"].split()) + len(qa["answer"].split())
        if words_left < 0:
            break
        qa_pairs.append(qa)
    if not qa_pairs:
        # An item without a single QA turn would be saved with an empty conversation.
        print("skip item: first QA turn exceeds the QA word budget")
        continue

    output_list.append({
        "id": f"msmarco.{len(output_list):04d}",
        'context': context,
        'qa_pairs': qa_pairs
    })

print(f"msmarco: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/msmarco.jsonl')

In [None]:
# PwC ('sggetao/PwC')
# Prompt/answer rows are grouped by their input document. Each document becomes
# one item whose QA pairs skip over-long individual pairs and are then
# truncated to a total word budget.
output_list = []
data = datasets.load_dataset('sggetao/PwC')
data = list(data['train'])
context_to_item = {}
for item in data:
    context_to_item.setdefault(item['input'], []).append(item)

max_context_words = 2048  # documents longer than this are skipped
max_turn_words = 128      # drop a pair whose prompt or answer alone is longer
max_qa_words = 512        # word budget for all prompts + answers of one item
for i, (context, items) in enumerate(context_to_item.items()):
    if len(context.split()) > max_context_words:
        print(f"skip context len: {len(context.split())}")
        continue
    # Keep pairs in dataset order, skipping over-long ones, until the budget is exhausted.
    qa_pairs = []
    words_left = max_qa_words
    for item in items:
        question_words = len(item['prompt'].split())
        answer_words = len(item['answer'].split())
        if answer_words > max_turn_words or question_words > max_turn_words:
            continue
        words_left -= question_words + answer_words
        if words_left < 0:
            break
        qa_pairs.append({'question': item['prompt'], 'answer': item['answer']})
    if not qa_pairs:
        # every pair was filtered out; an item without QA turns would be saved
        # with an empty conversation
        continue
    output_list.append({
        "id": f"pwc.{i:04d}",
        'context': context,
        'qa_pairs': qa_pairs
    })
print(f"pwc: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/pwc.jsonl')

In [None]:
import json, torch, transformers, random
from tqdm import tqdm
tokenizer = transformers.AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

# MetaICL (train tasks of hr_to_lr.json): for every task, consecutive chunks of
# k_shot + k_test sampled examples become one item. The first k_shot examples are
# rendered as "Input/Output" demonstrations in the context, and the remaining
# k_test become QA pairs.
output_list = []
random.seed(0)

with open("../../data/metaicl/hr_to_lr.json", "r") as f:
    config = json.load(f)

for task in tqdm(config["train"]):
    # NOTE(review): torch.load unpickles the file despite its .jsonl suffix;
    # presumably a dict with 'input'/'output' lists of token ids -- confirm.
    # Only load trusted local files this way.
    dp = torch.load(f'../../data/metaicl/llama/{task}/{task}_16384_100_train.jsonl')
    k_shot = 64
    k_test = 4
    max_num = 200  # max number of output items per task
    n = len(dp['input'])
    if n == 0:
        print(f"skip {task}: no examples")
        continue

    input_text = [tokenizer.decode(dp['input'][i]) for i in range(n)]
    output_text = [tokenizer.decode(dp['output'][i]) for i in range(n)]
    avg_input_len = sum([len(x.split()) for x in input_text]) / n
    # Shrink the shot counts so demonstration inputs total roughly <= 2048 words
    # and test inputs roughly <= 512. max(1, ...) keeps the floor divisions
    # defined when inputs average under one word (int() would give 0).
    words_per_input = max(1, int(avg_input_len))
    k_shot = min(k_shot, 2048 // words_per_input)
    k_test = min(k_test, 512 // words_per_input)

    if k_shot == 0 or k_test == 0:
        print(f"skip {task} due to avg_input_len: {avg_input_len}")
        continue

    # reduce the number of examples
    n = min(n, max_num * (k_shot + k_test))
    index = random.sample(range(len(input_text)), n)
    input_text = [input_text[i] for i in index]
    output_text = [output_text[i] for i in index]

    # in each chunk, the first k_shot examples are demonstrations and the next k_test are test queries
    for i in range(0, n, k_shot + k_test):
        if i + k_shot + k_test > n:
            break
        context = "\n\n".join([f"Input: {input_text[j]}\nOutput: {output_text[j]}" for j in range(i, i + k_shot)])
        output_list.append({
            "id": f"metaicl.{len(output_list):04d}.{task}.{i // (k_shot + k_test)}",
            'context': context,
            'qa_pairs': [
                {'question': f"Input: {input_text[j]}\nOutput:", 'answer': output_text[j]} for j in range(i + k_shot, i + k_shot + k_test)
            ],
        })
    print(task, n, avg_input_len, k_shot, k_test, len(output_list))
print(f"metaicl: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["qa_pairs"]) for item in output_list]))
save(output_list, 'v4/metaicl.jsonl')

In [None]:
import json
import os
import re

# Long Alpaca ('Yukang/LongAlpaca-12k')
# Each instruction embeds a long document terminated by a marker
# "Now the <name> ends." followed by the actual request. Text up to and
# including the marker becomes the context; the text after it becomes the
# single user turn, and the dataset output becomes the assistant turn.
# Alternative local source:
# data_path = "../../data/long-llm/longalpaca/train.json"
# data = datasets.load_dataset('json', data_files=data_path)
data = datasets.load_dataset('Yukang/LongAlpaca-12k')
data = list(data["train"])
output_list = []
for i, item in enumerate(data):
    first_message = item["instruction"]
    # The greedy (.*) makes the regex prefer the latest marker in the message.
    if isinstance(first_message, str):
        result = re.search(r"^(.*)Now the (.*?) ends\.(.*?)$", first_message, re.DOTALL)
    else:
        result = None
    if result is None:
        continue
    instruction = result.group(3).strip()
    if not instruction:
        # nothing follows the marker, so there is no user request
        continue
    # Cut at the match offset of group 3 rather than searching for the instruction
    # text: that text can also occur earlier inside the document, which would
    # truncate the context.
    context = first_message[:result.start(3)].strip()
    response = item["output"]

    # skip items whose context exceeds 6144 whitespace-separated words
    if len(context.split()) > 6144:
        continue

    output_list.append({
        "id": f"longalpaca.{i:05d}",
        "context": context,
        "conversations": [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": response}
        ]
    })
print(f"LONGALPACA: {len(output_list)}, context len: {sum([len(x['context'].split()) for x in output_list]) / max(len(output_list), 1)}")
print(collections.Counter([len(item["conversations"]) for item in output_list]))
# Items already carry "conversations", so they are written directly rather than via save().
os.makedirs("v4", exist_ok=True)
with open("v4/longalpaca.jsonl", "w") as f:
    for item in output_list:
        f.write(json.dumps(item) + "\n")

In [None]:
# booksum ('kmfoda/booksum')
# One item per chapter: the chapter text, prefixed with "<book>, <summary id>:",
# is the context, and the single QA pair asks for a summary of that chapter.
# Chapters with an over-long context or question+answer are dropped.
output_list = []
data = datasets.load_dataset('kmfoda/booksum')
data = list(data['train'])
for i, chapter_row in enumerate(data):
    assert chapter_row is not None
    book_name = chapter_row["book_id"].split(".")[0]
    summary_name = chapter_row["summary_id"]
    context = f"{book_name}, {summary_name}:\n\n{chapter_row['chapter']}"
    question = f"Summarize {book_name}, {summary_name}"
    answer = chapter_row['summary_text']
    if len(context.split()) > 4096 or len(question.split()) + len(answer.split()) > 512:
        continue
    output_list.append({
        "id": f"booksum.{i:04d}",
        'context': context,
        'qa_pairs': [{'question': question, 'answer': answer}],
    })
print(f"booksum: {len(output_list)}, context len: {sum(len(x['context'].split()) for x in output_list) / len(output_list)}")
print(collections.Counter(len(entry["qa_pairs"]) for entry in output_list))
save(output_list, 'v4/booksum.jsonl')

In [None]:
# Build the combined SFT files: each key of `config` is an output file, and its
# value lists the v4/<name>.jsonl files concatenated into it, in that order.
import os
import json
import glob

config = {
    "sft-v4-qa.jsonl": ["coqa", "drop", "narrativeqa", "pubmedqa", "quail", "msmarco"],
    "sft-v4-ift.jsonl": ["coqa", "drop", "narrativeqa", "pubmedqa", "quail", "msmarco", "booksum", "longalpaca"],
    "sft-v4.jsonl": ["coqa", "drop", "narrativeqa", "pubmedqa", "quail", "msmarco", "booksum", "longalpaca", "metaicl"],
}

for output_path, dataset_names in config.items():
    output_list = []
    for dataset_name in dataset_names:
        with open(f"v4/{dataset_name}.jsonl", "r") as src:
            output_list.extend(json.loads(line) for line in src)
    with open(output_path, "w") as dst:
        dst.writelines(json.dumps(record) + "\n" for record in output_list)
    print(f"{output_path}: {len(output_list)}")


In [None]:
import pandas as pd
import json
with open("sft-v4.jsonl", "r") as f:
    output_list = [json.loads(line) for line in f]

# Per-item statistics (whitespace word counts) behind the table
# "#docs  #instructions  doc len  instruction len  response len".


def _turn_word_counts(conversations, role):
    """Whitespace word count of every turn in `conversations` whose role is `role`."""
    return [len(turn['content'].split()) for turn in conversations if turn['role'] == role]


def _mean_max_min_sum(counts):
    """Return (mean, max, min, sum) of `counts`.

    Mean, max and min are NaN when `counts` is empty, so items without turns of a
    role do not crash the loop (division by zero / max of an empty list), and
    pandas' mean skips them.
    """
    if not counts:
        return float('nan'), float('nan'), float('nan'), 0
    return sum(counts) / len(counts), max(counts), min(counts), sum(counts)


stat_list = []
for item in output_list:
    instruction_lens = _turn_word_counts(item['conversations'], 'user')
    response_lens = _turn_word_counts(item['conversations'], 'assistant')
    instruction_mean, instruction_max, instruction_min, instruction_sum = _mean_max_min_sum(instruction_lens)
    response_mean, response_max, response_min, response_sum = _mean_max_min_sum(response_lens)
    stat_list.append({
        "dataset": item['id'].split(".")[0],  # id prefix, e.g. "coqa" for "coqa.0001"
        "n_docs": 1,  # every item carries a single context string
        "n_instructions": len(instruction_lens),
        "n_responses": len(response_lens),
        "doc_len": len(item['context'].split()),
        "instruction_len_mean": instruction_mean,
        "response_len_mean": response_mean,
        "instruction_len_max": instruction_max,
        "response_len_max": response_max,
        "instruction_len_min": instruction_min,
        "response_len_min": response_min,
        "instruction_len_sum": instruction_sum,
        "response_len_sum": response_sum,
    })

df = pd.DataFrame(stat_list)
df.groupby("dataset").mean()

In [None]:
# df.groupby("dataset").sum()

In [None]:
# df.groupby("dataset").max()

In [None]:
# df.groupby("dataset").min()

In [None]:
# # plot distribution of df["doc_len"] 
# import matplotlib.pyplot as plt
# import seaborn as sns
# sns.histplot(df["doc_len"], bins=100)
# plt.show()