In [74]:
import json
from joblib import Parallel, delayed
from tqdm_joblib import tqdm_joblib
import time
from tqdm.auto import tqdm
import json

In [75]:
from retrying import retry
import openai

@retry(stop_max_attempt_number=5, wait_exponential_multiplier=1000, wait_exponential_max=10000)
def generate(prompt, history=None):
    """Send ``prompt`` as a single user message and return the model's reply text.

    The ``retry`` decorator re-invokes this function when it raises, for at
    most 5 attempts, with an exponentially growing wait between attempts.

    Args:
        prompt: Full text of the user message.
        history: Accepted for backward compatibility; currently ignored.
            Defaults to ``None`` rather than a shared mutable list.

    Returns:
        The ``content`` of the first choice in the chat completion.

    Raises:
        Whatever the OpenAI client raises once the retry budget is exhausted.
    """
    import os  # local import so this cell does not depend on a later cell's imports

    # Prefer credentials from the environment so real keys never need to be
    # pasted into the notebook; the literals remain as placeholder fallbacks.
    client = openai.OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", "sk-xxx"),
        base_url=os.environ.get("OPENAI_BASE_URL", "xxx"),
    )
    completion = client.chat.completions.create(
        model="xxx",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.,  # temperature 0 keeps the sampling as deterministic as the API allows
    )
    return completion.choices[0].message.content

In [76]:
# Prompt template for key-point extraction. The literal placeholders
# ``{question}`` and ``{answer}`` are filled in with ``str.replace`` (not
# ``str.format``), so the other braces in the template need no escaping.
# The example output must use the same ``"key_points"`` key as the format
# specification, because the consumer reads ``key_points`` from the reply.
prompt = '''<task>
Based on the question, extract the main points from the reference answer. Each bullet point should directly reflect the specifics of the reference answer. </task>

<formatting requirements>
1. Please take the main points from your analysis and add them to a python list modeled after the format below, using as much concise language as possible. 
2. The output python list needs to be able to be loaded by json.loads(), the example is as follows:
{
"key_points": ["point1", "point2", "point3", ... , "point n"]
}
</formatting requirements>

<note>
1. analyze the shortcomings obtained and make sure that they are focused but do not need to be described or explained in detail, keep it concise and clear. 
2.Any output that is not formatted will cause the system to crash!
</note>

<example>
## Input
Question: what does the word remission mean when referring to cancer patients?
Reference answer: Well say the doctor said to me that my mom was on her 6th year of remission it means that she has had no cancer cells in her body. And it also means thats how long the cancer has been gone for! Does that help? Good I hope it did!
Output:
{
    "key_points": ["Remission refers to the absence of cancer cells in the body.","It indicates how long the cancer has been gone.","Remission can be measured in years or other timeframes."]
}
</example>

## Then analyze the following answer and output a python list that matches the formatting
Question: {question}
Reference answer: {answer}
Output:
'''

In [77]:
import re

def format_check(inputs, num=None):
    """Return True when ``inputs`` is a dict that contains a ``key_points`` entry.

    Args:
        inputs: Candidate object parsed from the model reply.
        num: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``True`` if ``inputs`` is a ``dict`` with the key ``'key_points'``,
        otherwise ``False``.
    """
    # The previous check tested the truthiness of a non-empty set literal,
    # which is always True, so any dict was accepted.
    return isinstance(inputs, dict) and 'key_points' in inputs

# Extract every flat ``{...}`` span from a text blob and parse it into a dict.
def find_dicts(markdown):
    """Return all dicts that can be parsed out of ``markdown``.

    Each brace-delimited span without nested braces is tried first as JSON and
    then as a Python literal. Spans that parse to something other than a dict,
    or that do not parse at all, are skipped.

    Args:
        markdown: Text to scan (for example a model reply), or ``None``.

    Returns:
        A list of dicts in order of appearance; empty when ``markdown`` is
        ``None`` or nothing parses.
    """
    import ast  # local import: only needed for the Python-literal fallback

    dicts = []
    if markdown is None:
        return dicts

    # Matches a brace pair whose content holds no further braces, so nested
    # objects are not captured as a whole.
    pattern = r"\{[^{}]*\}"
    for match in re.findall(pattern, markdown):
        try:
            # JSON first: this is the format the reply is asked to follow, and
            # it understands true/false/null.
            parsed = json.loads(match)
        except (ValueError, RecursionError):
            try:
                # Fallback for Python-style dicts (single quotes, True/None).
                # literal_eval only builds literals and never executes code,
                # unlike eval on text that comes from an external model.
                parsed = ast.literal_eval(match)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not parseable as a literal: ignore this span.
                continue
        if isinstance(parsed, dict):
            dicts.append(parsed)
    return dicts

In [78]:
import traceback
import os
import json

import traceback
def deal(item, file_path):
    """Extract key points for one work item and append it to a JSONL file.

    The module-level ``prompt`` template is filled with the item's question and
    answer, sent through ``generate``, and the reply is scanned for a dict that
    carries ``key_points``. On success the list is stored under
    ``item['points']`` and the whole item is appended to ``file_path`` as one
    JSON line. Nothing is written when every attempt fails.

    Args:
        item: Dict with at least the string fields ``'question'`` and
            ``'answer'``; it is mutated in place on success.
        file_path: Path of the JSONL file to append to.

    Returns:
        None.
    """
    max_try_retry = 5
    content = prompt.replace('{question}', item['question']).replace('{answer}', item['answer'])

    for _ in range(max_try_retry):
        try:
            response = generate(content)
            # Keep only dicts that really carry the expected key, so a reply
            # without any usable dict triggers a plain retry instead of an
            # IndexError/KeyError traceback.
            valid = [d for d in find_dicts(response) if format_check(d) and 'key_points' in d]
            if not valid:
                continue

            item['points'] = valid[0]['key_points']
            # One record per write call in append mode; this function is run
            # from several threads by the driver cell.
            with open(file_path, "a+", encoding="utf8") as f:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
                f.flush()
            return

        except Exception as e:
            # Look at the exception class name as well as its message: str(e)
            # alone usually does not contain the class name.
            err_text = type(e).__name__ + ': ' + str(e)

            if 'maximum context' in err_text:
                # The input exceeds the model's context window; retrying the
                # same prompt cannot succeed.
                return

            if 'InvalidRequestError' in err_text or 'BadRequestError' in err_text:
                # Request rejected by the API: retry quietly.
                continue

            traceback.print_exc()

In [None]:
def _expand_answers(records):
    """Turn each ``label == 1`` record into one work item per reference answer.

    Only the first two entries of ``record['answer']`` are used; records with
    any other label produce no work item.
    """
    work_items = []
    for record in records:
        if record['label'] != 1:
            continue
        base = {key: record[key] for key in ('id', 'question', 'label', 'response', 'rank')}
        for answer in record['answer'][:2]:
            work_items.append({**base, 'answer': answer})
    return work_items


def _load_jsonl(path):
    """Read a JSON Lines file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _pending_items(work_items, finished_entries):
    """Return the work items that have no matching finished entry yet.

    Matching is done on ``(id, answer)`` rather than ``id`` alone, because one
    id yields several work items; filtering on ``id`` would drop the second
    answer of an id whose first answer was already written.
    """
    remaining = {}
    for entry in finished_entries:
        key = (entry['id'], entry['answer'])
        remaining[key] = remaining.get(key, 0) + 1

    pending = []
    for work_item in work_items:
        key = (work_item['id'], work_item['answer'])
        if remaining.get(key, 0) > 0:
            # Already done once; consume one finished entry for this key.
            remaining[key] -= 1
        else:
            pending.append(work_item)
    return pending


def _merge_by_id(entries):
    """Group finished entries by ``id``, collecting ``points`` and ``answer`` into lists."""
    merged = {}
    for entry in entries:
        entry_id = entry['id']
        if entry_id not in merged:
            merged[entry_id] = {key: entry[key] for key in entry if key != "points" and key != "answer"}
            merged[entry_id]["points"] = []
            merged[entry_id]["answer"] = []
        merged[entry_id]["points"].append(entry["points"])
        merged[entry_id]["answer"].append(entry["answer"])
    return merged


num_worker = 32

for r in range(1, 6):
    with open('../../data/antique/ANTIQUE_S5/sample' + str(r) + '.json', "r", encoding='utf-8-sig') as f:
        input_data = _expand_answers(json.load(f))

    # Output path handling: make sure the directory and the (possibly empty)
    # JSONL file exist without truncating results from an earlier run.
    file_path = "../../data/antique/fact/sample" + str(r) + "_points.jsonl"
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8'):
        pass

    # Resume support: skip everything that is already in the JSONL file.
    input_list = _pending_items(input_data, _load_jsonl(file_path))

    # Run the extraction in parallel threads with a progress bar.
    with tqdm_joblib(desc="My calculation", total=len(input_list)):
        Parallel(n_jobs=num_worker, prefer="threads")([delayed(deal)(x, file_path=file_path) for x in input_list])

    # Merge the per-answer lines into one record per id.
    merged_data = _merge_by_id(_load_jsonl(file_path))

    with open("../../data/antique/fact/sample" + str(r) + "_points.json", "w", encoding="utf-8") as f:
        json.dump(list(merged_data.values()), f, ensure_ascii=False, indent=4)

