In [None]:
import os
import json
import subprocess


def unzip_data(data_zip_path, data_path):
    """Extract the data archive unless the data path is already present.

    The archive is unpacked into the current working directory, which is
    where the previous ``unzip <archive>`` shell command placed its output.
    Extraction goes through the standard-library ``zipfile`` module, so
    archive paths containing spaces or shell metacharacters are handled
    safely and no external ``unzip`` binary is required.

    Args:
        data_zip_path: Path to the ``.zip`` archive to extract.
        data_path: Path whose existence means the data is already extracted;
            when it exists nothing is done.
            NOTE(review): nothing here verifies that the archive actually
            unpacks to this path -- confirm the archive layout matches.

    Raises:
        FileNotFoundError: If extraction is needed but the archive is missing.
        zipfile.BadZipFile: If the archive is not a valid zip file.
    """
    if os.path.exists(data_path):
        return

    # Imported locally so this function does not depend on a new
    # file-level import being added.
    import zipfile

    with zipfile.ZipFile(data_zip_path) as archive:
        archive.extractall()


def load_json_file(json_path):
    """Read a JSON file and return the decoded Python object.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's content (for example
        a ``dict`` for a JSON object or a ``list`` for a JSON array).

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(json_path, encoding="utf-8") as json_file:
        return json.load(json_file)


def generate_sys_prompt(datapoint):
    """Return the system prompt that frames the inference task.

    Args:
        datapoint: Mapping for one annotated example; its ``"Section_id"``
            value is inserted at the end of the prompt.

    Returns:
        The system prompt as a single string.
    """
    section_name = datapoint['Section_id']
    return (
        "You are a helpful assistant. You are going to determine the inference relation "
        "(entailment or contradiction) between pairs of Clinical Trial Reports (CTRs) and "
        "the statements, making claims about one of the summarized sections of the CTRs: "
        f"{section_name}."
    )


def generate_input_prompt(datapoint, primary_section, secondary_section=None):
    """Assemble the user-side prompt for one example.

    The prompt names the task type, quotes the statement, and lists the
    primary CTR section; the secondary CTR section is appended only when a
    non-empty one is supplied.

    Args:
        datapoint: Mapping providing the ``"Type"`` and ``"Statement"`` values.
        primary_section: Iterable of strings, joined with newlines.
        secondary_section: Optional iterable of strings, joined with newlines;
            skipped when ``None`` or empty.

    Returns:
        All prompt fragments joined with single spaces.
    """
    task_type = datapoint["Type"]
    claim = datapoint["Statement"]

    pieces = [f'This task type is "{task_type}".']
    if task_type == "Comparison":
        pieces.append("There are multiple CTRs.")

    pieces.extend([
        f'The statement is "{claim}".',
        "The primary CTR section includes,\n",
        "\n".join(primary_section),
    ])

    if secondary_section:
        pieces.extend([
            "The secondary CTR section includes,\n",
            "\n".join(secondary_section),
        ])

    return " ".join(pieces)


def generate_output_prompt(datapoint):
    """Return the target completion text for one example.

    Args:
        datapoint: Mapping providing the ``"Label"`` value that is quoted in
            the final sentence.

    Returns:
        The reasoning cue followed by the labelled conclusion, separated by
        a single space.
    """
    label = datapoint['Label']
    reasoning_cue = "Based on the provided evidence think step by step . \n"
    verdict = f'I think the relationship is "{label}".'
    return f"{reasoning_cue} {verdict}"


def generate_cot_dataset(json_path="../training_data/train.json", output_path="llama-2-train.json"):
    """Convert an annotated split into a JSON list of prompt/response texts.

    For every datapoint in ``json_path`` the referenced section of the
    primary CTR (and, for ``"Comparison"`` datapoints, of the secondary CTR)
    is loaded from the ``CT json`` directory under ``../training_data``. The
    system, input and output prompts are then combined into one string using
    the ``[INST]`` / ``<<SYS>>`` template, and the resulting list of
    ``{"text": ...}`` records is written to ``output_path``.

    Args:
        json_path: Path to the split file, a JSON object mapping an id to a
            datapoint with ``"Type"``, ``"Statement"``, ``"Label"``,
            ``"Section_id"``, ``"Primary_id"`` and, for comparisons,
            ``"Secondary_id"`` entries.
        output_path: Destination file for the generated JSON list.

    Raises:
        FileNotFoundError: If the split file or a referenced CTR file is missing.
        KeyError: If a datapoint or CTR lacks one of the expected keys.
    """
    data_zip_path = "../training_data.zip"
    data_path = "../training_data"
    unzip_data(data_zip_path, data_path)

    ct_json_dir = os.path.join(data_path, "CT json")
    records = []

    for datapoint in load_json_file(json_path).values():
        section_id = datapoint["Section_id"]

        primary_ctr_path = os.path.join(ct_json_dir, datapoint["Primary_id"] + ".json")
        primary_section = load_json_file(primary_ctr_path)[section_id]

        # Only comparison datapoints reference a second trial report.
        secondary_section = None
        if datapoint["Type"] == "Comparison":
            secondary_ctr_path = os.path.join(ct_json_dir, datapoint["Secondary_id"] + ".json")
            secondary_section = load_json_file(secondary_ctr_path)[section_id]

        sys_prompt = generate_sys_prompt(datapoint)
        # Built once, after both sections are known, instead of being
        # generated and then regenerated for comparison datapoints.
        input_prompt = generate_input_prompt(datapoint, primary_section, secondary_section)
        output_prompt = generate_output_prompt(datapoint)

        records.append({
            "text": f"<s>[INST] <<SYS>>\n{sys_prompt}<</SYS>>. {input_prompt}[/INST] {output_prompt}</s>"
        })

    with open(output_path, 'w', encoding="utf-8") as json_file:
        json.dump(records, json_file, indent=4)


if __name__ == '__main__':
    # Generate the fine-tuning files for the training and development splits,
    # in that order.
    for source_json, target_json in (
        ("../training_data/train.json", "llama-2-train.json"),
        ("../training_data/dev.json", "llama-2-dev.json"),
    ):
        generate_cot_dataset(json_path=source_json, output_path=target_json)
