In [1]:
import json
import glob

In [2]:
# Models included in the report, mapped to a length limit
# (presumably the maximum context window in tokens -- TODO confirm).
# Entries that are commented out are simply absent from every table below.
model2maxlen = {
    'gpt-4-turbo-2024-04-09': 128000,
    'gpt-4o': 128000,
    # 'gpt-4o-2024-05-13': 128000,
    'claude-3-haiku-20240307': 200000,
    'claude-3-sonnet-20240229': 200000,
    # 'claude-3-opus-20240229': 200000,
    'chatglm3-6b-128k': 128000,
    'Yarn-Mistral-7b-128k': 128000,
    'internlm2-chat-7b': 200000,
    'internlm2-chat-20b': 200000,
    'moonshot-v1-128k': 128000,
}

# Task identifiers; each one is matched against result file names further down.
TASK_NAME_LIST = [
    'zh_norm',
    'en_norm',
    'zh_kg',
    'en_kg',
    'zh_table',
    'zh_medcase',
    'zh_counting',
    'en_counting',
]

# Directory holding the per-task / per-model query results.
BASE_DIR = '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result'

# Every entry directly inside BASE_DIR (glob returns them in arbitrary order).
result_files = glob.glob('/'.join((BASE_DIR, '*')))
result_files

['/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_counting_internlm2-chat-20b_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_kg_gpt-4o_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/en_norm_chatglm3-6b-128k_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/en_kg_gpt-4-turbo-2024-04-09_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_norm_internlm2-chat-7b_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_norm_gpt-4-turbo-2024-04-09_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_norm_claude-3-sonnet-20240229_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/en_counting_moonshot-v1-128k_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/zh_norm_gpt-4o_v1.jsonl',
 '/ailab/user/sunhongli/workspace/MedLongContextEval/query_result/en_kg_int

In [3]:
# Keep only the part after the last '/' of each path, i.e. the bare file name.
result_file_list = [path.rsplit('/', 1)[-1] for path in result_files]
result_file_list

['zh_counting_internlm2-chat-20b_v1.jsonl',
 'zh_kg_gpt-4o_v1.jsonl',
 'en_norm_chatglm3-6b-128k_v1.jsonl',
 'en_kg_gpt-4-turbo-2024-04-09_v1.jsonl',
 'zh_norm_internlm2-chat-7b_v1.jsonl',
 'zh_norm_gpt-4-turbo-2024-04-09_v1.jsonl',
 'zh_norm_claude-3-sonnet-20240229_v1.jsonl',
 'en_counting_moonshot-v1-128k_v1.jsonl',
 'zh_norm_gpt-4o_v1.jsonl',
 'en_kg_internlm2-chat-7b_v1.jsonl',
 'en_norm_internlm2-chat-7b_v1.jsonl',
 'zh_counting_gpt-4o_v1.jsonl',
 'zh_counting_claude-3-haiku-20240307_v1.jsonl',
 'en_norm_moonshot-v1-128k_v1.jsonl',
 'en_counting_internlm2-chat-7b_v1.jsonl',
 'zh_kg_internlm2-chat-20b_v1.jsonl',
 'zh_counting_internlm2-chat-7b_v1.jsonl',
 'zh_counting_claude-3-sonnet-20240229_v1.jsonl',
 'en_kg_moonshot-v1-128k_v1.jsonl',
 'en_norm_claude-3-haiku-20240307_v1.jsonl',
 'en_kg_chatglm3-6b-128k_v1.jsonl',
 'zh_kg_gpt-4-turbo-2024-04-09_v1.jsonl',
 'zh_norm_moonshot-v1-128k_v1.jsonl',
 'zh_medcase_gpt-4o_v1.jsonl',
 'zh_norm_chatglm3-6b-128k_v1.jsonl',
 'en_kg_internlm

In [4]:
# Containers keyed task -> model.
#   task_model_score : starts at -1 for every (task, model) pair; the scoring
#                      sections below overwrite it.
#   task_model_result: raw records read from the result files.
task_model_score = {}
task_model_result = {}
for task in TASK_NAME_LIST:
    task_model_score[task] = dict.fromkeys(model2maxlen, -1)
    task_model_result[task] = {model: [] for model in model2maxlen}

    for result_file in result_files:
        # A file belongs to a task when the task name occurs in its base name.
        if task not in result_file.split('/')[-1]:
            continue
        with open(result_file, 'r', encoding='utf-8') as f:
            # One JSON object per line; only the fields needed for scoring are kept.
            for line in f:
                result = json.loads(line)
                task_model_result[task][result['model']].append({
                    field: result[field]
                    for field in ('id', 'type', 'true_answer', 'pred_answer', 'sample_size')
                })

# Norm Acc
# Case-insensitive exact-match accuracy (in percent) per model for the two
# "norm" tasks; a model without any records keeps the -1 sentinel.
for task in ('zh_norm', 'en_norm'):
    for model, model_results in task_model_result[task].items():
        total = len(model_results)
        correct = 0
        for result in model_results:
            true_answer = result['true_answer']
            # An empty prediction list is scored as the empty string.
            pred_answer = '' if result['pred_answer'] == [] else result['pred_answer']
            if true_answer.lower() == pred_answer.lower():
                correct += 1

        if total > 0:
            task_model_score[task][model] = (float(correct) / float(total)) * 100
        else:
            task_model_score[task][model] = -1

# KG Precision/Recall/F1
# Answers are lower-cased and compared as sets. TP / FP / FN are pooled over all
# records of a model before the ratios are taken; -1 marks an undefined ratio.
for task in ('zh_kg', 'en_kg'):
    for model, model_results in task_model_result[task].items():
        TP = FP = FN = 0
        for result in model_results:
            true_answer = {t.lower() for t in result['true_answer']}
            pred_answer = {p.lower() for p in result['pred_answer']}
            TP += len(true_answer & pred_answer)   # gold items that were predicted
            FN += len(true_answer - pred_answer)   # gold items that were missed
            FP += len(pred_answer - true_answer)   # predictions that are not gold

        if TP + FP > 0:
            precision = TP / (TP + FP)
        else:
            precision = -1
        if TP + FN > 0:
            recall = TP / (TP + FN)
        else:
            recall = -1
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = -1

        task_model_score[task][model] = (precision, recall, f1)

# Table
# Pooled precision / recall / F1 per model for the table task. Spaces are removed
# from every answer and gold / predicted answers are compared as sets; -1 marks a
# ratio whose denominator is zero.
#
# NOTE(review): the original loop contained two statements that are removed here
# without changing any score:
#   * `total += 1` -- `total` was never initialised or read in this section, so it
#     only ran because the name was left over from an earlier section.
#   * a separator branch guarded by `len(true_answer) == 1 and len(true_answer) > 1`,
#     which can never be true. Its body re-joined a one-element gold list, which is
#     also a no-op. If separator-aware matching was intended (the guard probably
#     meant to test `pred_answer`), it has to be designed deliberately, because it
#     would change the reported numbers.
SPLITS = [',','，',';','；','、','+',' ']  # not consulted by the scoring below; see note above
for task in ['zh_table']:
    for model in task_model_result[task]:
        TP = 0
        FP = 0
        FN = 0
        for result in task_model_result[task][model]:
            true_answer = {t.replace(' ', '') for t in result['true_answer']}
            pred_answer = {p.replace(' ', '') for p in result['pred_answer']}

            TP += len(true_answer & pred_answer)   # gold items that were predicted
            FN += len(true_answer - pred_answer)   # gold items that were missed
            FP += len(pred_answer - true_answer)   # predictions that are not gold
        precision = TP / (TP + FP) if TP + FP > 0 else -1
        recall = TP / (TP + FN) if TP + FN > 0 else -1
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else -1
        task_model_score[task][model] = (precision, recall, f1)


# Medcase Acc
# answer_correct.json supplies, for some record ids, a list of acceptable gold
# answers that replaces the record's own `true_answer`.
with open('/ailab/user/sunhongli/workspace/MedLongContextEval/dataset/raw_data/zh_medcase/answer_correct.json', 'r', encoding='utf-8') as f:
    supplement_answers = json.load(f)
# JSON object keys are strings; record ids are looked up as ints.
supplement_answers = {int(k): v for k, v in supplement_answers.items()}
SPLITS = [',','，',';','；','、','+',' ']
for task in ['zh_medcase']:
    for model, model_results in task_model_result[task].items():
        correct = 0
        total = 0
        for result in model_results:
            total += 1
            record_id = result['id']
            true_answers = (
                supplement_answers[record_id]
                if record_id in supplement_answers
                else [result['true_answer']]
            )

            # A record counts as correct as soon as one candidate matches.
            for true_answer in true_answers:
                # Each candidate is indexed with [0]; presumably candidates are
                # one-element lists -- TODO confirm against the data files.
                true_answer = true_answer[0]
                pred_answer = result['pred_answer']
                if pred_answer == []:
                    pred_answer = ''
                elif len(pred_answer) == 1:
                    pred_answer = pred_answer[0]
                elif len(pred_answer) > 1:
                    # Glue the predicted parts together with the first separator
                    # that occurs in the gold answer ...
                    separator = next((s for s in SPLITS if s in true_answer), None)
                    if separator is not None:
                        pred_answer = separator.join(pred_answer)
                    # ... and fall back to the first part when it is still a list.
                    if isinstance(pred_answer, list):
                        pred_answer = pred_answer[0]
                # Comparison ignores spaces on both sides.
                true_answer = true_answer.replace(' ', '')
                pred_answer = pred_answer.replace(' ', '')
                if true_answer == pred_answer:
                    correct += 1
                    break

        if total > 0:
            task_model_score[task][model] = (float(correct) / float(total)) * 100
        else:
            task_model_score[task][model] = -1


# Counting Acc
# Exact-match accuracy (in percent) per sub-task type plus an 'all' aggregate.
# Answers are compared with `==` as stored, with no normalisation.
for task in ('zh_counting', 'en_counting'):
    for model, model_results in task_model_result[task].items():
        all_correct = 0
        all_total = 0
        sub_tasks = ['acquisition_shuffle', 'acquisition_inc', 'acquisition_1', 'reasoning']
        sub_task_scores = {name: {'total': 0, 'correct': 0, 'acc': 0.0} for name in sub_tasks}

        for result in model_results:
            all_total += 1
            bucket = sub_task_scores[result['type']]
            bucket['total'] += 1
            true_answer = result['true_answer']
            # An empty prediction list is scored as the empty string.
            pred_answer = '' if result['pred_answer'] == [] else result['pred_answer']
            if true_answer == pred_answer:
                all_correct += 1
                bucket['correct'] += 1

        # -1 marks a sub-task without any records.
        for stats in sub_task_scores.values():
            if stats['total'] > 0:
                stats['acc'] = (float(stats['correct']) / float(stats['total'])) * 100
            else:
                stats['acc'] = -1

        task_model_score[task][model] = dict(sub_task_scores)
        task_model_score[task][model]['all'] = {
            'total': all_total,
            'correct': all_correct,
            'acc': (float(all_correct) / float(all_total)) * 100 if all_total > 0 else -1,
        }

In [5]:
# Write the per-task / per-model scores to disk; non-ASCII text is kept unescaped.
with open('task_result_score.json', 'w', encoding='utf-8') as f:
    json.dump(task_model_score, f, ensure_ascii=False)

In [6]:
# Grid of the needle-in-a-haystack runs: the `length(origin)` and `position(%)`
# values expected in the result records.
token_len_list = [4000, 8000, 16000, 32000, 64000, 128000, 200000]
position_list = [0, 25, 50, 75, 100]


NIAH_BASE_DIR = '/ailab/user/sunhongli/workspace/MedLongContextEval/niah_result'
niah_result_files = glob.glob('/'.join((NIAH_BASE_DIR, '*')))

# Index the English needles by their id so results can look up the ground truth.
with open('/ailab/user/sunhongli/workspace/MedLongContextEval/dataset/task_data/needles/en_pure_needles.json', 'r', encoding='utf-8') as f:
    en_needles = json.load(f)
en_needles_new = {needle['id']: needle for needle in en_needles}
en_needles = en_needles_new

# model -> length -> position containers: raw records (lists) and scores (dicts).
niah_result_model_length_depth_en = {
    model: {length: {position: [] for position in position_list} for length in token_len_list}
    for model in model2maxlen
}
niah_score_model_length_depth_en = {
    model: {length: {position: {} for position in position_list} for length in token_len_list}
    for model in model2maxlen
}

for niah_result in niah_result_files:
    # Only files whose base name contains 'en' are treated as English results.
    if 'en' not in niah_result.split('/')[-1]:
        continue
    with open(niah_result, 'r', encoding='utf-8') as f:
        for line in f:
            result = json.loads(line)
            idx = result['idx']
            length = result['length(origin)']
            position = result['position(%)']
            model_name = result['model']
            # Both sides are stringified and lower-cased before comparison.
            true_answer = [str(r).lower() for r in en_needles[idx]['ground_truth']]
            pred_answer = str(result['pred_answer']).lower()
            niah_result_model_length_depth_en[model_name][length][position].append({
                'id': idx,
                'true_answer': true_answer,
                'pred_answer': pred_answer,
            })

# Score every (model, length, position) bucket: a record is correct when its
# prediction equals any of the accepted ground-truth strings. Each length also
# gets an 'all' entry that sums its positions. -1 marks an empty bucket.
for model, per_length in niah_result_model_length_depth_en.items():
    for length, per_position in per_length.items():
        all_pos_total = 0
        all_pos_correct = 0
        for position, records in per_position.items():
            total = len(records)
            correct = 0
            # Ids of missed records; collected but not written to the score dict.
            wrong_id_list = []
            for r in records:
                right_answer = r['pred_answer'] in r['true_answer']
                if right_answer:
                    correct += 1
                else:
                    wrong_id_list.append(int(r['id']))

            if total > 0:
                acc = (float(correct) / float(total)) * 100
            else:
                acc = -1
            niah_score_model_length_depth_en[model][length][position] = {
                'total': total,
                'correct': correct,
                'acc': acc,
            }
            all_pos_total += total
            all_pos_correct += correct

        if all_pos_total > 0:
            all_pos_acc = (float(all_pos_correct) / float(all_pos_total)) * 100
        else:
            all_pos_acc = -1
        niah_score_model_length_depth_en[model][length]['all'] = {
            'total': all_pos_total,
            'correct': all_pos_correct,
            'acc': all_pos_acc,
        }
niah_score_model_length_depth_en

{'gpt-4-turbo-2024-04-09': {4000: {0: {'total': 20,
    'correct': 19,
    'acc': 95.0},
   25: {'total': 20, 'correct': 18, 'acc': 90.0},
   50: {'total': 20, 'correct': 16, 'acc': 80.0},
   75: {'total': 20, 'correct': 16, 'acc': 80.0},
   100: {'total': 20, 'correct': 18, 'acc': 90.0},
   'all': {'total': 100, 'correct': 87, 'acc': 87.0}},
  8000: {0: {'total': 20, 'correct': 17, 'acc': 85.0},
   25: {'total': 20, 'correct': 19, 'acc': 95.0},
   50: {'total': 20, 'correct': 18, 'acc': 90.0},
   75: {'total': 20, 'correct': 18, 'acc': 90.0},
   100: {'total': 20, 'correct': 17, 'acc': 85.0},
   'all': {'total': 100, 'correct': 89, 'acc': 89.0}},
  16000: {0: {'total': 20, 'correct': 18, 'acc': 90.0},
   25: {'total': 20, 'correct': 18, 'acc': 90.0},
   50: {'total': 20, 'correct': 17, 'acc': 85.0},
   75: {'total': 20, 'correct': 18, 'acc': 90.0},
   100: {'total': 20, 'correct': 16, 'acc': 80.0},
   'all': {'total': 100, 'correct': 87, 'acc': 87.0}},
  32000: {0: {'total': 20, 'corr

In [7]:
# Write the English needle-in-a-haystack scores to disk.
with open('niah_en_result_score.json', 'w', encoding='utf-8') as f:
    json.dump(niah_score_model_length_depth_en, f, ensure_ascii=False)

In [8]:
# Grid of the needle-in-a-haystack runs: the `length(origin)` and `position(%)`
# values expected in the result records.
token_len_list = [4000, 8000, 16000, 32000, 64000, 128000, 200000]
position_list = [0, 25, 50, 75, 100]


NIAH_BASE_DIR = '/ailab/user/sunhongli/workspace/MedLongContextEval/niah_result'
niah_result_files = glob.glob('/'.join((NIAH_BASE_DIR, '*')))

# Index the Chinese needles by their id so results can look up the ground truth.
with open('/ailab/user/sunhongli/workspace/MedLongContextEval/dataset/task_data/needles/zh_pure_needles.json', 'r', encoding='utf-8') as f:
    zh_needles = json.load(f)
zh_needles_new = {needle['id']: needle for needle in zh_needles}
zh_needles = zh_needles_new

# model -> length -> position containers: raw records (lists) and scores (dicts).
niah_result_model_length_depth_zh = {
    model: {length: {position: [] for position in position_list} for length in token_len_list}
    for model in model2maxlen
}
niah_score_model_length_depth_zh = {
    model: {length: {position: {} for position in position_list} for length in token_len_list}
    for model in model2maxlen
}

for niah_result in niah_result_files:
    # Only files whose base name contains 'zh' are treated as Chinese results.
    if 'zh' not in niah_result.split('/')[-1]:
        continue
    with open(niah_result, 'r', encoding='utf-8') as f:
        for line in f:
            result = json.loads(line)
            idx = result['idx']
            length = result['length(origin)']
            position = result['position(%)']
            model_name = result['model']
            # Both sides are stringified and stripped of spaces before comparison.
            true_answer = [str(r).replace(' ', '') for r in zh_needles[idx]['ground_truth']]
            pred_answer = str(result['pred_answer']).replace(' ', '')

            niah_result_model_length_depth_zh[model_name][length][position].append({
                'id': idx,
                'true_answer': true_answer,
                'pred_answer': pred_answer,
            })

# Score every (model, length, position) bucket: a record is correct when its
# prediction equals any of the accepted ground-truth strings. Each length also
# gets an 'all' entry that sums its positions. -1 marks an empty bucket.
for model, per_length in niah_result_model_length_depth_zh.items():
    for length, per_position in per_length.items():
        all_pos_total = 0
        all_pos_correct = 0
        for position, records in per_position.items():
            total = len(records)
            correct = 0
            # Ids of missed records; collected but not written to the score dict.
            wrong_id_list = []
            for r in records:
                right_answer = r['pred_answer'] in r['true_answer']
                if right_answer:
                    correct += 1
                else:
                    wrong_id_list.append(int(r['id']))

            if total > 0:
                acc = (float(correct) / float(total)) * 100
            else:
                acc = -1
            niah_score_model_length_depth_zh[model][length][position] = {
                'total': total,
                'correct': correct,
                'acc': acc,
            }
            all_pos_total += total
            all_pos_correct += correct

        if all_pos_total > 0:
            all_pos_acc = (float(all_pos_correct) / float(all_pos_total)) * 100
        else:
            all_pos_acc = -1
        niah_score_model_length_depth_zh[model][length]['all'] = {
            'total': all_pos_total,
            'correct': all_pos_correct,
            'acc': all_pos_acc,
        }
niah_score_model_length_depth_zh

{'gpt-4-turbo-2024-04-09': {4000: {0: {'total': 20,
    'correct': 19,
    'acc': 95.0},
   25: {'total': 20, 'correct': 19, 'acc': 95.0},
   50: {'total': 20, 'correct': 18, 'acc': 90.0},
   75: {'total': 20, 'correct': 20, 'acc': 100.0},
   100: {'total': 20, 'correct': 19, 'acc': 95.0},
   'all': {'total': 100, 'correct': 95, 'acc': 95.0}},
  8000: {0: {'total': 20, 'correct': 19, 'acc': 95.0},
   25: {'total': 20, 'correct': 18, 'acc': 90.0},
   50: {'total': 20, 'correct': 18, 'acc': 90.0},
   75: {'total': 20, 'correct': 18, 'acc': 90.0},
   100: {'total': 20, 'correct': 19, 'acc': 95.0},
   'all': {'total': 100, 'correct': 92, 'acc': 92.0}},
  16000: {0: {'total': 20, 'correct': 18, 'acc': 90.0},
   25: {'total': 20, 'correct': 19, 'acc': 95.0},
   50: {'total': 20, 'correct': 19, 'acc': 95.0},
   75: {'total': 20, 'correct': 19, 'acc': 95.0},
   100: {'total': 20, 'correct': 20, 'acc': 100.0},
   'all': {'total': 100, 'correct': 95, 'acc': 95.0}},
  32000: {0: {'total': 20, 'co

In [9]:
# Write the Chinese needle-in-a-haystack scores to disk.
with open('niah_zh_result_score.json', 'w', encoding='utf-8') as f:
    json.dump(niah_score_model_length_depth_zh, f, ensure_ascii=False)

In [10]:
# niah_score_model_length_depth_en
# niah_score_model_length_depth_zh

In [11]:
def _niah_position_cells(results, position, lengths):
    """Return the LaTeX cells of one needle position for one language.

    Args:
        results: length -> position -> {'total', 'correct', ...} score dict.
        position: needle position key to read from every length.
        lengths: context lengths, in column order.

    Returns:
        One cell per length (the correct count, or `$-$` when the bucket holds no
        results), followed by a `correct$/$total` cell summed over all lengths.
    """
    buckets = [results[length][position] for length in lengths]
    cells = ['$-$' if b["total"] == 0 else str(b["correct"]) for b in buckets]
    correct = sum(b["correct"] for b in buckets)
    total = sum(b["total"] for b in buckets)
    cells.append(str(correct) + '$/$' + str(total))
    return cells


def _niah_summary_cells(results, lengths):
    """Return the per-length `correct$/$total` cells of the 'all' summary row.

    The denominator is the recorded total rather than a fixed 100, and a length
    without any results is shown as `$-$`, matching the per-position rows.
    """
    cells = []
    for length in lengths:
        overall = results[length]["all"]
        if overall["total"] == 0:
            cells.append('$-$')
        else:
            cells.append(str(overall["correct"]) + '$/$' + str(overall["total"]))
    return cells


def _latex_cells(cells):
    """Render cells in the table's `& cell ` form (leading ampersand, trailing space)."""
    return ''.join('& ' + cell + ' ' for cell in cells)


# Print, for every model, the body rows of the LaTeX results table: one row per
# needle position (English columns, then Chinese columns) and a final summary row.
for model in model2maxlen:
    print(model)
    zh_results, en_results = niah_score_model_length_depth_zh[model], niah_score_model_length_depth_en[model]
    for position in position_list:
        print(
            _latex_cells(_niah_position_cells(en_results, position, token_len_list))
            + _latex_cells(_niah_position_cells(zh_results, position, token_len_list))
            + '\\\\'
        )

    # Summary row: the column that holds the per-position totals stays empty for
    # both languages, hence the bare '& ' after each group of cells.
    print(
        _latex_cells(_niah_summary_cells(en_results, token_len_list))
        + '& '
        + _latex_cells(_niah_summary_cells(zh_results, token_len_list))
        + '& '
        + '\\\\'
    )

gpt-4-turbo-2024-04-09
& 19 & 17 & 18 & 18 & 18 & 17 & 0 & 107$/$140 & 19 & 19 & 18 & 18 & 18 & 17 & 0 & 109$/$140 \\
& 18 & 19 & 18 & 18 & 15 & 14 & 0 & 102$/$140 & 19 & 18 & 19 & 18 & 19 & 19 & 0 & 112$/$140 \\
& 16 & 18 & 17 & 17 & 16 & 16 & 0 & 100$/$140 & 18 & 18 & 19 & 19 & 18 & 18 & 0 & 110$/$140 \\
& 16 & 18 & 18 & 19 & 18 & 15 & 0 & 104$/$140 & 20 & 18 & 19 & 19 & 18 & 18 & 0 & 112$/$140 \\
& 18 & 17 & 16 & 18 & 16 & 16 & 0 & 101$/$140 & 19 & 19 & 20 & 20 & 20 & 18 & 0 & 116$/$140 \\
& 87$/$100 & 89$/$100 & 87$/$100 & 90$/$100 & 83$/$100 & 78$/$100 & 0$/$100 & & 95$/$100 & 92$/$100 & 95$/$100 & 94$/$100 & 93$/$100 & 90$/$100 & 0$/$100 & \\
gpt-4o
& 16 & 15 & 16 & 17 & 16 & 16 & 0 & 96$/$140 & 19 & 19 & 16 & 19 & 19 & 19 & 0 & 111$/$140 \\
& 16 & 15 & 17 & 18 & 17 & 15 & 0 & 98$/$140 & 19 & 18 & 17 & 18 & 19 & 17 & 0 & 108$/$140 \\
& 16 & 16 & 17 & 17 & 17 & 16 & 0 & 99$/$140 & 19 & 19 & 18 & 19 & 19 & 17 & 0 & 111$/$140 \\
& 16 & 17 & 17 & 16 & 16 & 17 & 0 & 99$/$140 & 16 & 19