In [12]:
import json, re, os, sys
from os import path as osp

In [9]:
# Merge each model's per-dataset prediction files into one zh.json and one
# en.json per model, re-indexed 0..n-1.
loc = "../outputs/comparison/20230824_170033/predictions"

models = ['baichuan-13b-chat', 'GPT-3.5-turbo']
for model in models:
    model_dir = os.path.join(loc, model)
    zh = []
    en = []
    # Sorted so the merged index order is reproducible (os.listdir order is
    # arbitrary). Only files tagged 'sc-zh' / 'sc-en' are merged: the zh.json /
    # en.json this cell writes into the same directory carry neither tag and
    # would otherwise be re-merged (into `en`) on every re-run.
    for file in sorted(os.listdir(model_dir)):
        if not file.endswith('.json'):
            continue
        if 'sc-zh' in file:
            target = zh
        elif 'sc-en' in file:
            target = en
        else:
            continue
        with open(os.path.join(model_dir, file), 'r', encoding='utf-8') as f:
            target.extend(json.load(f).values())
    zh = {idx: value for idx, value in enumerate(zh)}
    en = {idx: value for idx, value in enumerate(en)}
    with open(os.path.join(model_dir, "zh.json"), 'w', encoding='utf-8') as f:
        json.dump(zh, f, indent=4)
    with open(os.path.join(model_dir, "en.json"), 'w', encoding='utf-8') as f:
        json.dump(en, f, indent=4)


In [11]:
from tools.extract_failcases import extract_fail_cases
# Run the project's fail-case extraction over this comparison run. The output
# below lists model names (presumably one line per model processed).
run_dir = "/mnt/mfs/opsgpt/opencompass/outputs/comparison/20230824_170033"
extract_fail_cases(run_dir)

baichuan-13b-chat
GPT-3.5-turbo
xverse_13b
chinese-alpaca-2-13b
chinese-llama-2-13b
llama-2-13b-chat
internlm_chat_7b
baichuan-7b
qwen_chat_7b
chatglm2-6b


In [25]:
from tools.extract_failcases import extract_answer
def analyze(folder, extract=None):
    """Collect each model's extracted zh/en answers per question id.

    Walks ``<folder>/predictions/<model>/`` and reads every prediction file
    whose name contains ``sc-en`` (English) or ``sc-zh`` (Chinese). Any other
    file is skipped. This matters for the merged ``zh.json`` / ``en.json`` that
    an earlier cell writes into the same directories: they carry neither tag,
    and treating them as Chinese would overwrite the real zh answers with the
    English ones.

    Args:
        folder: run directory that contains a ``predictions`` sub-directory.
        extract: callable mapping a raw prediction to an answer; defaults to
            ``tools.extract_failcases.extract_answer``.

    Returns:
        ``{model: {question_id: {'gold': ..., 'en': ..., 'zh': ...}}}``. A
        language key is absent when no file of that language held the question.
    """
    if extract is None:
        extract = extract_answer
    pred_root = osp.join(folder, 'predictions')
    results = {}
    for model in os.listdir(pred_root):
        model_dir = osp.join(pred_root, model)
        if not osp.isdir(model_dir):
            continue
        model_answer = {}
        for filename in os.listdir(model_dir):
            path = osp.join(model_dir, filename)
            if not osp.isfile(path) or not filename.endswith('.json'):
                continue
            if 'sc-en' in filename:
                lang = 'en'
            elif 'sc-zh' in filename:
                lang = 'zh'
            else:
                continue
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for case in data.values():
                qid = case['reference']['id']
                # The first file that sees a question records its gold answer.
                entry = model_answer.setdefault(
                    qid, {'gold': case['reference']['answer']})
                entry[lang] = extract(case['prediction'])
        results[model] = model_answer
    return results


# Per-model {qid: {'gold', 'zh', 'en'}} answers for this comparison run.
run_dir = "/mnt/mfs/opsgpt/opencompass/outputs/comparison/20230824_170033"
results = analyze(folder=run_dir)


In [34]:
def _pct(num, den):
    """Format num/den as a truncated integer percentage, or 'n/a' when den is 0."""
    return f'{int(num * 100 / den)}%' if den else 'n/a'


def analyze_portion(results):
    """Print a CSV table of zh/en answer agreement and accuracy per model.

    Columns are as follows:
      - ``same`` / ``diff``: share of questions whose zh and en answers are
        identical / different.
      - ``same_right``: accuracy when the answers agree.
      - ``diff_en_right`` / ``diff_zh_right``: accuracy of the English /
        Chinese answer among the disagreeing questions.

    Percentages are truncated to integers, so ``same + diff`` may print as 99%.
    Questions missing either language are skipped because there is nothing to
    compare. An empty denominator prints ``n/a`` instead of raising
    ZeroDivisionError.

    Args:
        results: ``{model: {qid: {'gold': ..., 'zh': ..., 'en': ...}}}``.

    Returns:
        ``(qsame, qdiff)`` for the LAST model iterated, kept for backward
        compatibility. Both are empty lists if ``results`` is empty.
    """
    print('model,same,diff,same_right,diff_en_right,diff_zh_right')
    qsame, qdiff = [], []
    for model, result in results.items():
        # Questions with answers in both languages, split by whether they agree.
        paired = [q for q in result.values() if 'zh' in q and 'en' in q]
        qsame = [q for q in paired if q['zh'] == q['en']]
        qdiff = [q for q in paired if q['zh'] != q['en']]

        tot = len(paired)
        same = len(qsame)
        diff = len(qdiff)
        # Accuracy when both languages agree (zh == en, so checking zh suffices).
        ac_same = sum(q['zh'] == q['gold'] for q in qsame)
        # Accuracy of each language when they disagree.
        ac_en = sum(q['en'] == q['gold'] for q in qdiff)
        ac_zh = sum(q['zh'] == q['gold'] for q in qdiff)

        print(','.join([str(model), _pct(same, tot), _pct(diff, tot),
                        _pct(ac_same, same), _pct(ac_en, diff), _pct(ac_zh, diff)]))
    return qsame, qdiff


# Print the per-model agreement table; keep the last model's same/diff split.
qsame, qdiff = analyze_portion(results=results)


model,same,diff,same_right,diff_en_right,diff_zh_right
baichuan-13b-chat,63%,36%,66%,35%,28%
GPT-3.5-turbo,76%,23%,81%,61%,20%
xverse_13b,29%,70%,82%,50%,3%
chinese-alpaca-2-13b,65%,34%,64%,46%,24%
chinese-llama-2-13b,58%,41%,60%,43%,25%
llama-2-13b-chat,62%,37%,63%,55%,20%
internlm_chat_7b,67%,32%,65%,42%,29%
baichuan-7b,31%,68%,9%,36%,1%
qwen_chat_7b,68%,31%,67%,48%,25%
chatglm2-6b,62%,37%,54%,36%,30%


In [39]:
# Question ids that EVERY model gets right in exactly one language:
#   common_zhr - right in Chinese, wrong in English
#   common_enr - right in English, wrong in Chinese
# Questions lacking either language's answer are ignored.
zh_right_sets = []
en_right_sets = []
for result in results.values():
    paired = {key: value for key, value in result.items()
              if 'zh' in value and 'en' in value}
    zh_right_sets.append({key for key, value in paired.items()
                          if value['zh'] == value['gold'] and value['en'] != value['gold']})
    en_right_sets.append({key for key, value in paired.items()
                          if value['zh'] != value['gold'] and value['en'] == value['gold']})

# Intersect explicitly across all models. Seeding "when the running set is
# empty" would restart both intersections whenever the zh one became empty.
common_zhr = set.intersection(*zh_right_sets) if zh_right_sets else set()
common_enr = set.intersection(*en_right_sets) if en_right_sets else set()
    

In [40]:
# Question ids that every model answers correctly in English but wrongly in Chinese.
common_enr

{'9780137370337-A0025',
 '9780137370337-A0031',
 '9780137370337-A0082',
 '9780137370337-A0111',
 '9781119513124-A0158',
 '9781119642220-A0084',
 '9781119787631-A0152',
 '9781119810865-A0125',
 '9781119862918-A0105',
 '9781394182930-A0517',
 '9781394182930-A0836',
 '9781839211898-A0016'}

In [41]:
# Question ids that every model answers correctly in Chinese but wrongly in English.
common_zhr

{'9781394182909-A0187'}

In [71]:
def find_ids(id_list,
             loc="/mnt/mfs/opsgpt/opencompass/outputs/comparison/20230824_170033/predictions"):
    """Gather prompts, per-model predictions and the gold answer for given ids.

    Only prediction files tagged ``sc-zh`` (Chinese) or ``sc-en`` (English) are
    read. The merged ``zh.json`` / ``en.json`` written by an earlier cell carry
    neither tag; they would otherwise be treated as English and overwrite the
    English prompt/prediction with the Chinese one.

    Args:
        id_list: question ids (``reference['id']``) to look up.
        loc: predictions directory laid out as ``<loc>/<model>/*.json``.

    Returns:
        ``{id: {...}}`` where each inner dict may hold ``'zh-prompt'``,
        ``'en-prompt'``, ``'answer'`` and one ``'<zh|en>-<model>'`` prediction
        per model and language. Ids that are never found map to ``{}``.
    """
    result = {i: dict() for i in id_list}
    for model in os.listdir(loc):
        model_dir = osp.join(loc, model)
        if not osp.isdir(model_dir):
            continue
        for file in os.listdir(model_dir):
            if not file.endswith('.json'):
                continue
            if 'sc-zh' in file:
                lang = 'zh'
            elif 'sc-en' in file:
                lang = 'en'
            else:
                continue
            with open(osp.join(model_dir, file), encoding='utf-8') as f:
                data = json.load(f)
            for a in data.values():
                qid = a['reference']['id']
                # dict-key lookup: O(1) instead of scanning id_list per record
                if qid not in result:
                    continue
                entry = result[qid]
                entry[f'{lang}-prompt'] = a['origin_prompt']
                entry[f'{lang}-{model}'] = a['prediction']
                entry['answer'] = a['reference']['answer']
    return result

# Full records (prompts + every model's predictions) for the zh-right-only ids.
test = find_ids([*common_zhr])



In [72]:
# Persist the zh-right-only cases for manual inspection.
with open("zh_right.json", mode='w') as out_file:
    json.dump(test, out_file, indent=4, sort_keys=True)

In [74]:
# Same as the zh dump, but for the ids every model only gets right in English.
test = find_ids([*common_enr])
with open("en_right.json", mode='w') as out_file:
    json.dump(test, out_file, indent=4, sort_keys=True)


In [None]:
# Reduce en_right.json to the gold answer plus both prompts per id (dropping the
# per-model predictions) and save the result as en_right_simp.json.
with open("en_right.json", 'r') as f:
    en_right_full = json.load(f)

# .get: ids that find_ids never located map to {} and have no prompt/answer
# keys; store null for those instead of raising KeyError.
en_right_simp = {
    key: dict(answer=value.get('answer'),
              en_prompt=value.get('en-prompt'),
              zh_prompt=value.get('zh-prompt'))
    for key, value in en_right_full.items()
}
print(json.dumps(en_right_simp, indent=4))
with open("en_right_simp.json", 'w') as f:
    json.dump(en_right_simp, f, indent=4)
