In [None]:
import glob
import json
import os

# Max context length per evaluated model (presumably tokens — values match the
# models' advertised context windows). The keys also decide which models get a
# slot in the result tables below.
model2maxlen = {
    'gpt-4-turbo-2024-04-09': 128000,
    'gpt-4o': 128000,
    'claude-3-haiku-20240307': 200000,
    'claude-3-sonnet-20240229': 200000,
    'moonshot-v1-128k': 128000,
    'chatglm3-6b-128k': 128000,
    'internlm2-chat-7b': 200000,
    'internlm2-chat-20b': 200000,
    'Yarn-Mistral-7b-128k': 128000,
    'Yi-6B-200K': 200000,
}
# Sample-size buckets; each result record's 'sample_size' must be one of these.
SAMPLE_SIZE = [4000, 8000, 16000, 32000, 64000, 128000, 200000]
TASK_NAME_LIST = ['zh_norm', 'en_norm', 'zh_kg', 'en_kg', 'zh_table', 'zh_medcase']
BASE_DIR = '../../evaluation_result/query_result'
# os.path instead of '/'-splitting: glob returns backslash-separated paths on Windows.
result_files = glob.glob(os.path.join(BASE_DIR, '*'))

result_file_list = [os.path.basename(result_file) for result_file in result_files]


def _new_task_model_size_table(make_cell):
    """Build a fresh {task: {model: {size: cell}}} table.

    `make_cell` is called once per (task, model, size), so mutable cells such as
    lists are never shared between buckets.
    """
    return {
        task: {model: {size: make_cell() for size in SAMPLE_SIZE} for model in model2maxlen}
        for task in TASK_NAME_LIST
    }


# Primary metric per bucket; -1 marks "no results for this bucket".
task_model_size_score = _new_task_model_size_table(lambda: -1)
# Substring-match ("ssm") metric per bucket. The misspelled name is kept because
# later cells reference it.
task_model_size_ssm_socre = _new_task_model_size_table(list)

# Per-sample result records, filled by the loading cell.
task_model_size_result = _new_task_model_size_table(list)

In [None]:
# Load every JSONL result file into task_model_size_result[task][model][size].
# Empty the buckets first so re-running this cell does not append duplicates.
for model_buckets in task_model_size_result.values():
    for size_buckets in model_buckets.values():
        for bucket in size_buckets.values():
            bucket.clear()

for result_file in result_files:
    # Unlike split('/'), basename also handles the Windows paths glob returns there.
    file_name = os.path.basename(result_file)
    # A file is attributed to every task whose name occurs in its filename.
    matching_tasks = [task for task in TASK_NAME_LIST if task in file_name]
    if not matching_tasks:
        continue
    with open(result_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing empty line) instead of crashing json.loads.
            if not line.strip():
                continue
            result = json.loads(line)
            record = {
                'id': result['id'],
                'type': result['type'],
                'true_answer': result['true_answer'],
                'pred_answer': result['pred_answer'],
                # The file field 'pred_origin' is stored under 'pred_origin_answer'.
                'pred_origin_answer': result['pred_origin'],
            }
            for task in matching_tasks:
                # KeyError here means the model or sample_size is missing from the config cell.
                task_model_size_result[task][result['model']][int(result['sample_size'])].append(record)

In [None]:
# Norm Acc: case-insensitive exact match on the extracted answer, plus an "ssm"
# score: case-insensitive substring match of the gold answer in the raw output.
for task in ['zh_norm', 'en_norm']:
    for model in task_model_size_result[task]:
        for size in task_model_size_result[task][model]:
            records = task_model_size_result[task][model][size]
            correct = 0
            ssm_correct = 0
            for result in records:
                true_answer = result['true_answer'].lower()
                pred_answer = result['pred_answer']
                # An empty list means nothing was extracted; compare against ''.
                if pred_answer == []:
                    pred_answer = ''
                pred_answer = pred_answer.lower()
                pred_origin_answer = result['pred_origin_answer'].lower()

                if true_answer == pred_answer:
                    correct += 1
                # SSM is checked against the raw model output, consistent with the
                # KG / table / medcase scoring (previously this read was unused and
                # SSM was computed against the extracted answer instead).
                if true_answer in pred_origin_answer:
                    ssm_correct += 1
            total = len(records)
            task_model_size_score[task][model][size] = (float(correct) / float(total)) if total > 0 else -1
            task_model_size_ssm_socre[task][model][size] = (float(ssm_correct) / float(total)) if total > 0 else -1
        
# KG Precision/Recall/F1, micro-averaged (TP/FP/FN summed over every sample in a
# bucket, case-insensitive, duplicates removed), plus an "ssm" score: fraction of
# samples whose raw output contains every gold answer as a substring.
for task in ['zh_kg', 'en_kg']:
    for model in task_model_size_result[task]:
        for size in task_model_size_result[task][model]:
            records = task_model_size_result[task][model][size]
            TP = FP = FN = 0
            ssm_correct = 0
            for result in records:
                true_answer = {t.lower() for t in result['true_answer']}
                pred_answer = {p.lower() for p in result['pred_answer']}
                TP += len(true_answer & pred_answer)
                FN += len(true_answer - pred_answer)
                FP += len(pred_answer - true_answer)

                pred_origin_answer = result['pred_origin_answer'].lower()
                if all(t in pred_origin_answer for t in true_answer):
                    ssm_correct += 1

            # -1 means "undefined" (no predictions / no gold answers at all).
            precision = TP / (TP + FP) if TP + FP > 0 else -1
            recall = TP / (TP + FN) if TP + FN > 0 else -1
            if precision + recall > 0:
                f1 = 2 * precision * recall / (precision + recall)
            elif TP + FP + FN > 0:
                # There was data but nothing matched: F1 is 0, not the "no data" sentinel.
                f1 = 0.0
            else:
                f1 = -1
            task_model_size_score[task][model][size] = (precision, recall, f1)

            total = len(records)
            task_model_size_ssm_socre[task][model][size] = (float(ssm_correct) / float(total)) if total > 0 else -1

# Table size version: micro-averaged P/R/F1 over table values (spaces ignored,
# duplicates removed), plus an "ssm" score: fraction of samples whose raw output
# contains every gold value verbatim.
SPLITS = [',','，',';','；','、','+',' ']
for task in ['zh_table']:
    for model in task_model_size_result[task]:
        for size in task_model_size_result[task][model]:
            records = task_model_size_result[task][model][size]
            TP = FP = FN = 0
            ssm_correct = 0
            for result in records:
                true_answer = result['true_answer']
                pred_answer = list(set(result['pred_answer']))
                # The model answered with one string but there are several gold values:
                # join the gold values with the separator the model used (looked up in
                # its single answer), so e.g. "a,b" can match ["a", "b"]. This mirrors
                # the medcase block, which looks the separator up on the single-string side.
                if len(pred_answer) == 1 and len(true_answer) > 1:
                    for split in SPLITS:
                        if split in pred_answer[0]:
                            true_answer = [split.join(true_answer)]
                            break
                true_set = {t.replace(' ', '') for t in true_answer}
                pred_set = {p.replace(' ', '') for p in pred_answer}
                TP += len(true_set & pred_set)
                FN += len(true_set - pred_set)
                FP += len(pred_set - true_set)

                # SSM uses the original (un-joined, un-stripped) gold values.
                pred_origin_answer = result['pred_origin_answer']
                if all(t in pred_origin_answer for t in result['true_answer']):
                    ssm_correct += 1

            # -1 means "undefined" (no predictions / no gold values at all).
            precision = TP / (TP + FP) if TP + FP > 0 else -1
            recall = TP / (TP + FN) if TP + FN > 0 else -1
            if precision + recall > 0:
                f1 = 2 * precision * recall / (precision + recall)
            elif TP + FP + FN > 0:
                # There was data but nothing matched: F1 is 0, not the "no data" sentinel.
                f1 = 0.0
            else:
                f1 = -1
            task_model_size_score[task][model][size] = (precision, recall, f1)

            total = len(records)
            task_model_size_ssm_socre[task][model][size] = (float(ssm_correct) / float(total)) if total > 0 else -1

# Medcase Acc size version: exact match (spaces ignored) against any accepted gold
# answer, plus an "ssm" score: fraction of samples whose raw output contains at
# least one accepted gold answer.

# Additional accepted answers per sample id; keys are converted to int so they can
# be compared with result['id'].
with open('../../dataset/raw_data/zh_medcase/answer_correct.json', 'r', encoding='utf-8') as f:
    supplement_answers = json.load(f)
supplement_answers = {int(k): v for k, v in supplement_answers.items()}

SPLITS = [',','，',';','；','、','+',' ']


def _collapse_medcase_pred(pred, gold):
    """Reduce an extracted prediction to a single value comparable with `gold`.

    [] becomes ''; a length-1 value yields its first element; a longer value is
    joined with the first separator from SPLITS that occurs in `gold`, and a longer
    list with no such separator falls back to its first element. Anything else is
    returned as-is.
    """
    if pred == []:
        return ''
    if len(pred) == 1:
        return pred[0]
    if len(pred) > 1:
        separator = next((s for s in SPLITS if s in gold), None)
        if separator is not None:
            return separator.join(pred)
        if type(pred) is list:
            return pred[0]
    return pred


for task in ['zh_medcase']:
    for model, size_buckets in task_model_size_result[task].items():
        for size, records in size_buckets.items():
            correct = 0
            ssm_correct = 0
            for result in records:
                sample_id = result['id']
                if sample_id in supplement_answers:
                    accepted = supplement_answers[sample_id]
                else:
                    accepted = [result['true_answer']]

                # Each accepted answer is a list; only its first element is compared.
                for gold in accepted:
                    gold_text = gold[0]
                    pred_text = _collapse_medcase_pred(result['pred_answer'], gold_text)
                    if gold_text.replace(' ', '') == pred_text.replace(' ', ''):
                        correct += 1
                        break

                if any(gold[0] in result['pred_origin_answer'] for gold in accepted):
                    ssm_correct += 1

            total = len(records)
            task_model_size_score[task][model][size] = (float(correct) / float(total)) if total > 0 else -1
            task_model_size_ssm_socre[task][model][size] = (float(ssm_correct) / float(total)) if total > 0 else -1



In [None]:

# Print both score tables and persist each to its own JSON file (non-ASCII kept as-is).
_score_outputs = [
    (task_model_size_score, 'task_result_score_samplesize.json'),
    (task_model_size_ssm_socre, 'task_result_ssm_score_samplesize.json'),
]
for position, (payload, out_path) in enumerate(_score_outputs):
    if position > 0:
        print('\n')
    rendered = json.dumps(payload, ensure_ascii=False, indent=4)
    print(rendered)
    with open(out_path, 'w', encoding='utf-8') as out_file:
        out_file.write(rendered)