## Import

In [1]:
import os
import json
import numpy as np
from tqdm import tqdm

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

In [None]:
# Instruction text placed at the start of every prompt built below.
# English gloss of the Chinese string: "The following are single-choice
# questions about medical knowledge; output the single answer option for
# each question."
instruction = '''以下是关于医学知识的单项选择题，请根据题目输出唯一的答案选项\n'''
instruction

In [None]:
path_file_data = '../data/MedQA/Mainland/test.jsonl'

## Data

In [None]:
# Read the JSONL test set, one JSON record per line. Each record receives its
# zero-based line number as 'ID', and the nested 'options' mapping is
# flattened into top-level 'A'..'E' keys before being dropped.
list_dict_test = []
with open(path_file_data, 'r', encoding="utf-8") as f:
    for line_no, raw_line in enumerate(f):
        record = json.loads(raw_line)
        record['ID'] = line_no
        options = record['options']
        # Look up all five options first so a missing letter fails before
        # the record is modified any further.
        record.update({letter: options[letter] for letter in 'ABCDE'})
        del record['options']
        list_dict_test.append(record)

In [None]:
list_dict_test[0]

In [None]:
# Build the prompt for every question and store it under the 'Input' key.
# The loop variable is deliberately named `dict_test`: a later cell reads
# dict_test['Input'] after this loop has finished.
for dict_test in tqdm(list_dict_test):
    question = dict_test['question']
    a, b, c, d, e = (dict_test[letter] for letter in 'ABCDE')
    # Remove the "（ ）。" fill-in placeholder from the question text.
    question = question.replace("（ ）。", "")
    # Named `prompt` rather than `input` so the builtin input() is not
    # shadowed for the rest of the notebook session.
    prompt = instruction + f"问题：{question}: (A){a}, (B){b}, (C){c}, (D){d}, (E){e}\n" + "答案："
    dict_test['Input'] = prompt

## Model

In [None]:
path_dir_model = 'path_of_baichuan2_model'

In [None]:
# Restrict this process to GPU index 1 via the CUDA_VISIBLE_DEVICES variable.
# NOTE(review): presumably this has to be set before CUDA is first
# initialised to take effect — confirm no earlier cell touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(path_dir_model, use_fast=False, trust_remote_code=True)

In [None]:
# Load the causal LM from the local checkpoint directory.
# - torch_dtype=torch.float16 requests half-precision weights (the
#   commented-out line below is the same call without that argument).
# - device_map="auto" leaves device placement to the loader.
# - trust_remote_code=True permits model code shipped with the checkpoint to
#   run; NOTE(review): only point path_dir_model at a trusted checkpoint.
# model = AutoModelForCausalLM.from_pretrained(path_dir_model, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path_dir_model, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)

In [None]:
# Replace the model's generation settings with the ones stored alongside the
# checkpoint, then override max_new_tokens to 512 for this evaluation.
model.generation_config = GenerationConfig.from_pretrained(path_dir_model)
model.generation_config.max_new_tokens = 512

In [None]:
model.device

In [None]:
model = model.eval()

## Test

In [None]:
def get_response(text, flag_print=True):
    """Send ``text`` as a single user turn to the chat model and return the reply.

    Relies on the notebook-level ``model`` and ``tokenizer`` objects.

    Args:
        text: Prompt used as the content of the one-message conversation.
        flag_print: When true, print the prompt, a separator line and the
            reply.

    Returns:
        Whatever ``model.chat`` returns for the conversation.
    """
    conversation = [{"role": "user", "content": text}]
    # Inference only, so gradient tracking is switched off for the call.
    with torch.no_grad():
        reply = model.chat(tokenizer, conversation)
    if flag_print:
        for chunk in (text, '-------------------------------', reply):
            print(chunk)
    return reply

In [None]:
# Smoke test on a single prompt. `dict_test` here is the variable left over
# from the preceding `for ... in list_dict_test` loop, i.e. the last test item.
response = get_response(dict_test['Input'])

## Task

In [None]:
name_task = "Prompt-bc2-7B"

In [None]:
list_dict_test[0]

## Run

In [None]:
name_task

In [None]:
# Batched generation over the whole test set. Each item's decoded model
# output is stored in place under the 'Result' key.
batch_size = 6
num_batch = int(np.ceil(len(list_dict_test) / batch_size))

with torch.no_grad():
    for idx_batch in tqdm(range(num_batch)):
        # Slice out the items belonging to this batch.
        start = idx_batch * batch_size
        list_data_batch = list_dict_test[start: start + batch_size]
        list_input_batch = [dict_one['Input'] for dict_one in list_data_batch]

        # Tokenize with padding so every row in the batch has equal length.
        encoded = tokenizer.batch_encode_plus(list_input_batch, padding=True, truncation=True)

        # Wrap each row in the chat markers taken from the generation config:
        # user token in front, assistant token at the end.
        # NOTE(review): padding is applied before the markers are added and no
        # attention mask is passed to generate() — confirm this is intended.
        user_token_id = model.generation_config.user_token_id
        assistant_token_id = model.generation_config.assistant_token_id
        input_id_batch = torch.LongTensor(
            [
                [user_token_id, *input_id, assistant_token_id]
                for input_id in encoded['input_ids']
            ]
        ).to(model.device)

        output_batch = model.generate(input_id_batch, generation_config=model.generation_config)

        # Every row has the same prompt length, so one offset strips the
        # prompt tokens from each generated sequence.
        prompt_length = input_id_batch.shape[1]
        for dict_one, output in zip(list_data_batch, output_batch):
            dict_one["Result"] = tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            )

## Performance

### Read

In [None]:
# Save list_dict_test (including the generated 'Result' fields) as JSON.
# The output directory is created first so a missing '../result' folder does
# not throw away the results of the generation run above.
path_file_result = f'../result/{name_task}.json'
os.makedirs(os.path.dirname(path_file_result), exist_ok=True)
with open(path_file_result, "w", encoding="utf-8") as f:
    json.dump(list_dict_test, f, ensure_ascii=False, indent=4)

In [None]:
len(list_dict_test), len([dict_test for dict_test in list_dict_test if dict_test.get('Result')])

### Prediction

In [None]:
def get_prediction(result_pred, list_option_name):
    """Extract the single predicted option letter from a model response.

    If the response contains one of the known "final answer" lead-in phrases,
    only the text after the last occurrence of the first such phrase (up to
    the next '。') is inspected. An option counts as matched when either its
    text or its letter appears in the inspected text.

    Args:
        result_pred: Raw response string produced by the model.
        list_option_name: Option texts in order; the first is option 'A',
            the second 'B', and so on.

    Returns:
        The option letter when exactly one option matches, otherwise the
        string 'Null' (no match or an ambiguous match).
    """
    # Lead-in phrases that announce the final answer in the response.
    list_str_split = ['最终答案是','最终答案为','最可能的选项是','最可能的诊断是','正确的选项是', '正确选项是', '正确选项为','最可能的答案是','最有可能的答案是']
    for str_split in list_str_split:
        if str_split in result_pred:
            result_pred = result_pred.split(str_split)[-1].split('。')[0]
            break

    list_option_idx = [chr(ord('A') + idx) for idx in range(len(list_option_name))]
    # Longest option text first, so an option whose text is contained in a
    # longer option's text cannot claim that longer text.
    list_option_idx_name = sorted(
        zip(list_option_idx, list_option_name),
        key=lambda pair: len(pair[1]),
        reverse=True,
    )

    list_option_match = []
    # Pass 1: match on option text and strip every matched text. Empty option
    # texts are skipped because '' is a substring of every string.
    for option_idx, option_name in list_option_idx_name:
        if option_name and option_name in result_pred:
            list_option_match.append(option_idx)
            result_pred = result_pred.replace(option_name, '')

    # Pass 2: match on bare option letters in what is left. Running this after
    # all option texts are removed keeps a letter that is part of an option's
    # text (e.g. the 'C' in "维生素C") from counting as a vote for option C.
    for option_idx in list_option_idx:
        if option_idx in result_pred and option_idx not in list_option_match:
            list_option_match.append(option_idx)

    if len(list_option_match) == 1:
        return list_option_match[0]
    return 'Null'

In [None]:
# Derive a 'Prediction' for every test item from its 'Result' text and sort
# the items into three buckets: correct, wrong, and none (no usable answer).
list_dict_result_correct = []
list_dict_result_wrong = []
list_dict_result_none = []
for dict_test in tqdm(list_dict_test):
    list_option_name = [dict_test[letter] for letter in 'ABCDE']
    prediction = get_prediction(dict_test['Result'], list_option_name)
    dict_test['Prediction'] = prediction
    if prediction == 'Null':
        list_dict_result_none.append(dict_test)
        continue
    # NOTE(review): the prompt-building cell reads the lowercase key 'answer'
    # while this comparison uses 'Answer' — confirm the records really carry
    # an 'Answer' field holding the gold option letter.
    if dict_test['Answer'] == prediction:
        list_dict_result_correct.append(dict_test)
    else:
        list_dict_result_wrong.append(dict_test)

In [None]:
# Bucket sizes and overall accuracy (correct / all test items, as a
# percentage rounded to two decimals).
count_correct, count_wrong, count_none = (
    len(bucket)
    for bucket in (list_dict_result_correct, list_dict_result_wrong, list_dict_result_none)
)
acc = round(count_correct / len(list_dict_test)*100, 2)
print(f"Accuracy is {acc}% with {count_correct} correct, {count_wrong} wrong, and {count_none} none")

In [None]:
# Too long 111

In [None]:
# Rewrite the result file so it also contains the 'Prediction' field that the
# scoring loop added to every item.
with open(f'../result/{name_task}.json', 'w', encoding="utf-8") as f:
    json.dump(list_dict_test, f, indent=4, ensure_ascii=False)

### Final

In [None]:
# Reload the saved results and read the two index files as raw text lines
# (newline characters still attached; they are parsed in the next cell).
with open(f'../result/{name_task}.json', 'r', encoding="utf-8") as f:
    list_dict_test_sample = json.loads(f.read())
with open('../data/list_test_knowledge_idx.txt', 'r') as f:
    list_test_knowledge_idx = list(f)
with open('../data/list_test_example_idx.txt', 'r') as f:
    list_test_example_idx = list(f)

In [None]:
# Turn the raw index lines into integers (blank lines are ignored) and use
# them as list positions to select the two subsets of the results.
list_test_knowledge_idx = [int(text) for text in map(str.strip, list_test_knowledge_idx) if text]
list_test_example_idx = [int(text) for text in map(str.strip, list_test_example_idx) if text]
list_dict_test_knowledge = [list_dict_test_sample[position] for position in list_test_knowledge_idx]
list_dict_test_example = [list_dict_test_sample[position] for position in list_test_example_idx]
print(len(list_dict_test_sample), len(list_dict_test_knowledge), len(list_dict_test_example))

In [None]:
def cal_performance(list_dict_test_sample):
    """Count none/correct/wrong predictions and print and return the accuracy.

    Each item must provide a 'Prediction' key; items whose prediction is not
    'Null' must also provide an 'Answer' key to compare against.

    Args:
        list_dict_test_sample: Sequence of scored test items (dicts).

    Returns:
        A tuple ``(acc, count_none, count_correct, count_wrong)`` where
        ``acc`` is correct / total as a fraction rounded to four decimals.
        An empty input yields ``(0.0, 0, 0, 0)`` instead of raising
        ZeroDivisionError.
    """
    count_none = count_correct = count_wrong = 0
    for dict_test in list_dict_test_sample:
        prediction = dict_test['Prediction']
        if prediction == 'Null':
            count_none += 1
        elif dict_test['Answer'] == prediction:
            count_correct += 1
        else:
            count_wrong += 1
    total = len(list_dict_test_sample)
    # Guard the division so an empty subset reports 0 rather than crashing.
    acc = count_correct / total if total else 0.0
    print(f"Accuracy is {round(acc*100, 2)}% with {count_none} none, {count_correct} correct, and {count_wrong} wrong， total {len(list_dict_test_sample)}")
    return round(acc, 4), count_none, count_correct, count_wrong

In [None]:
# Score the subset selected by the knowledge index file, the subset selected
# by the example index file, and the full result set. Only the accuracy and
# the "none" count are kept from each call.
acc_MK, count_none_MK, _, _ = cal_performance(list_dict_test_knowledge)
acc_CA, count_none_CA, _, _ = cal_performance(list_dict_test_example)
acc_all, count_none_all, _, _ = cal_performance(list_dict_test_sample)

In [None]:
print(count_none_MK, acc_MK, count_none_CA, acc_CA, count_none_all, acc_all, sep='\t')

## End

In [None]:
print('Done')