In [44]:
import json
from joblib import Parallel, delayed
from tqdm_joblib import tqdm_joblib
import time
from tqdm.auto import tqdm

In [45]:
from retrying import retry
import openai

@retry(stop_max_attempt_number=5, wait_exponential_multiplier=1000, wait_exponential_max=10000)
def generate(prompt, history=None):
    """Send ``prompt`` as a single user message and return the reply text.

    The call is wrapped by ``@retry``: any exception triggers another attempt,
    up to 5 attempts in total, with an exponentially growing wait between
    attempts (see the decorator arguments for the multiplier and cap).

    Args:
        prompt: Text placed in the one and only ``user`` message.
        history: Accepted for interface compatibility but not used by the
            request. Defaults to ``None`` instead of a shared mutable list.

    Returns:
        ``completion.choices[0].message.content`` from the chat completion.

    Raises:
        Whatever the client raised on the final attempt, once retries run out.
    """
    # NOTE(review): api_key, base_url and model are placeholder literals.
    # Load the real values from the environment instead of committing them.
    client = openai.OpenAI(api_key="sk-xxx", base_url="xxx")
    completion = client.chat.completions.create(
        model="xxx",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.  # zero temperature to keep generations as repeatable as possible
    )
    return completion.choices[0].message.content

In [46]:
# Load the sampled records into `L`; the "utf-8-sig" codec tolerates a leading BOM.
with open('../../data/antique/ANTIQUE_S5/sample1.json', mode="r", encoding='utf-8-sig') as sample_file:
    L = json.load(sample_file)

In [47]:
# Keep only the fields used downstream, in a fixed key order, one dict per record of `L`.
input_data = [
    {field: record[field] for field in ('id', 'question', 'answer', 'response', 'rank')}
    for record in L
]

In [48]:
# Instruction template sent to the model. `{question}` and `{ref}` are plain
# placeholder tokens that `deal` fills in with str.replace; str.format cannot be
# used here because the template also contains literal JSON braces.
# The model is asked to reply with a JSON object whose keys are exactly
# "Excellent", "Good", "Fair", "Poor" and "Bad" (the set `format_check` tests for).
prompt = '''Please consider the factual accuracy, logic, conciseness, and clarity of the answer based on the input open-ended question and reference answer. Additionally, combine your thoughts to generate five levels of answers: Excellent, Good, Fair, Poor, and Bad.

Factual Accuracy: Analyze whether the information provided in the answer is correct and based on reliable facts and data.
Logic: Analyze whether the answer is logically clear, with reasonable reasoning and consistent coherence.
Conciseness: Analyze whether the answer is brief and to the point, avoiding unnecessary details and verbosity.
Clarity: Analyze whether the answer is expressed clearly and understandably, and whether the language is simple and direct.

Please output the results in the following JSON format:
{
    "Excellent": "xxxx",
    "Good": "xxxx",
    "Fair": "xxxx",
    "Poor": "xxxx",
    "Bad": "xxxx"
}

Question: {question}
Reference Answer: {ref}

Output Generated Answer:
'''

In [49]:
import re

def format_check(input):
    """Tell whether ``input`` is a dict keyed by exactly the five grade labels.

    Returns True only for a dict whose key set equals
    {'Excellent', 'Good', 'Fair', 'Poor', 'Bad'}; anything else gives False.
    """
    if not isinstance(input, dict):
        return False
    expected_labels = {'Excellent', 'Good', 'Fair', 'Poor', 'Bad'}
    return set(input.keys()) == expected_labels

def find_dicts(markdown):
    """Extract every dict literal embedded in a markdown/text string.

    Args:
        markdown: Text to scan, or ``None``.

    Returns:
        A list of the dicts found, in order of appearance. The list is empty
        when ``markdown`` is ``None`` or contains no parseable dict literal.
    """
    # Local import so this block does not depend on a new file-level import.
    import ast

    if markdown is None:
        return []

    # Matches a brace pair with no braces inside it, so for nested dicts only
    # the innermost one is captured.
    pattern = r"\{[^{}]*\}"

    dicts = []
    for match in re.findall(pattern, markdown):
        try:
            # SECURITY: this used to be eval(match). The text comes from a
            # model reply, so eval could execute arbitrary code embedded in it.
            # literal_eval accepts the same literal syntax without running code.
            d = ast.literal_eval(match)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a valid Python literal: skip this span.
            continue
        # A brace pair can also be a set literal; keep dicts only.
        if isinstance(d, dict):
            dicts.append(d)
    return dicts

In [50]:
import traceback
import os
import json

import traceback
import threading

# Serializes appends to the shared output file: `deal` is dispatched to many
# threads at once and each call opens its own handle on the same path.
_APPEND_LOCK = threading.Lock()


def deal(item, file_path, max_try_retry=10):
    """Generate graded example answers for one record and append it to a JSONL file.

    The module-level ``prompt`` template is filled with the record's question
    and ``item['answer'][0]`` (presumably the first reference answer; confirm
    against the dataset), then sent through ``generate``. The reply is parsed
    as JSON, falling back to the first dict found by ``find_dicts``. When the
    result passes ``format_check`` it is stored under ``item['example']`` and
    the record is appended to ``file_path`` as one JSON line.

    Args:
        item: Record dict with at least ``'question'`` and ``'answer'``.
            It is mutated in place: ``'example'`` is added on success.
        file_path: JSONL file to append the finished record to.
        max_try_retry: Maximum number of generate-and-parse rounds.

    Returns:
        None. A record that never yields a valid reply is simply not written.
    """
    content = prompt.replace('{question}', item['question']).replace('{ref}', item['answer'][0])
    for _ in range(max_try_retry):
        try:
            response = generate(content)
            try:
                tmps = json.loads(response)
            except (TypeError, ValueError):
                # Reply is not pure JSON (or is None): look for an embedded dict.
                # No dict at all is treated like any other badly formatted reply.
                candidates = find_dicts(response)
                tmps = candidates[0] if candidates else None

            if format_check(tmps):
                item['example'] = tmps
                line = json.dumps(item, ensure_ascii=False) + "\n"
                with _APPEND_LOCK:
                    with open(file_path, "a+", encoding="utf8") as f:
                        f.write(line)
                        f.flush()
                break
            # Wrong shape: fall through and ask the model again.

        except Exception as e:
            str_e = str(e)
            # An over-long prompt can never succeed on retry, so give up on
            # this record whatever the exception class is called.
            if 'maximum context' in str_e:
                break
            # NOTE(review): 'InvalidRequestError' looks like the legacy
            # (openai<1.0) error name; confirm it still appears in messages
            # raised by the client used in `generate`.
            if 'InvalidRequestError' in str_e:
                continue
            traceback.print_exc()


In [None]:
file_path = "../../data/antique/non_fact/examples.jsonl"

# Make sure the output file exists so the resume scan below can open it.
if not os.path.exists(file_path):
    with open(file_path, 'w') as f:
        pass

# Resume support: collect the ids already written by a previous run.
# A set gives O(1) membership tests. Blank lines are skipped, and the global
# `L` (the loaded dataset) is deliberately left untouched so that earlier
# cells can be re-run after this one.
with open(file_path,'r',encoding='utf-8') as f:
    finish_id = {json.loads(line)['id'] for line in f if line.strip()}

input_list = [i for i in input_data if i['id'] not in finish_id]

from tqdm.contrib import tzip
from joblib import Parallel, delayed
from tqdm_joblib import tqdm_joblib

num_worker = 32

# Thread-based workers: each call to `deal` handles one record and appends its
# own result to `file_path`; the progress bar tracks completed joblib tasks.
with tqdm_joblib(desc="My calculation", total=len(input_list)) as progress_bar:
    Parallel(n_jobs=num_worker,prefer="threads")([delayed(deal)(x,file_path=file_path) for x in input_list])