In [None]:
import json
import os
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
import torch
import sys
sys.path.append('/mnt/pvc-data.common/ChenZikang/codes/LLM-Evaluation/code')

from model.eval_model import EvalQwen

In [None]:
class EvalQwen(EvalQwen):
    """Qwen evaluator that asks the model to classify a patient/doctor exchange.

    Subclasses the ``EvalQwen`` imported from ``model.eval_model`` and only
    overrides ``_gen_input``. The prompt asks for one digit:
    1 = condition diagnosis, 2 = intervention advice,
    3 = psychological support, 0 = none of these.
    """

    def _gen_input(self, data):
        """Build the single-turn chat request for one sample.

        Args:
            data: a sample mapping; its ``'input'`` value is inserted as the
                patient turn and its ``'output'`` value as the doctor turn.

        Returns:
            A one-element message list: ``[{'role': 'user', 'content': prompt}]``.
        """
        # The leading '' and trailing '' reproduce the newline that opens and
        # closes the prompt text.
        prompt_lines = [
            '',
            '请作为一名乳腺癌专科医生，回答以下问题。',
            '请判断下面的内容属于病情诊断、干预建议和心理支持三类分类中的哪一类。',
            '如果属于病情诊断请输出1，属于干预建议请输出2，属于心理支持请输出3，不属于以上任何一类请输出0。',
            '注意：只需要输出1，2，3，0四个数字中的一个，不需要输出任何额外的内容！',
            f"input: 【患者】{data['input']}【患者】【医生】{data['output']}【医生】",
            'output:',
            '',
        ]
        prompt = '\n'.join(prompt_lines)
        return [{'role': 'user', 'content': prompt}]

In [None]:
# Runtime configuration. Each value can be overridden through an environment
# variable, so the notebook runs on another machine or GPU without editing the
# cell. The defaults reproduce the original hard-coded settings.
device = os.environ.get("EVAL_DEVICE", "cuda:1")  # torch device used for inference
# Root directory that holds the locally downloaded Hugging Face checkpoints.
local_model_path = os.environ.get(
    "LOCAL_MODEL_PATH", "/mnt/pvc-data.common/ChenZikang/huggingface"
)

In [None]:
def eval_models(eval_model, dataset, batch_size, max_new_tokens, output_path):
    """Run batched inference with ``eval_model`` over ``dataset`` and save the results.

    Args:
        eval_model: evaluator exposing ``clear_pre_ref()``,
            ``inference_batch(dataset, batch_size=..., max_new_tokens=...)`` and
            ``save_pre_ref(path)`` (e.g. an ``EvalQwen`` instance).
        dataset: samples forwarded unchanged to ``inference_batch``.
        batch_size: number of samples per generation batch.
        max_new_tokens: generation-length cap forwarded to ``inference_batch``.
        output_path: directory handed to ``save_pre_ref``. It is created,
            including any missing parents, if it does not exist yet.
    """
    # exist_ok=True replaces the racy "check exists, then create" pattern and
    # makes re-running the cell on an existing directory a no-op.
    os.makedirs(output_path, exist_ok=True)
    # Presumably discards predictions/references from a previous run so they do
    # not mix with this one. TODO confirm against model.eval_model.
    eval_model.clear_pre_ref()
    eval_model.inference_batch(dataset, batch_size=batch_size, max_new_tokens=max_new_tokens)
    eval_model.save_pre_ref(output_path)

In [None]:
# Both the model and the tokenizer are loaded from the same local checkpoint
# directory.
qwen_checkpoint_dir = f"{local_model_path}/Qwen/Qwen2.5-7B-Instruct"

# fp16 weights, moved onto the configured device.
model = AutoModelForCausalLM.from_pretrained(
    qwen_checkpoint_dir,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)

# Left padding, because batched generation with a decoder-only model appends
# new tokens on the right.
tokenizer = AutoTokenizer.from_pretrained(
    qwen_checkpoint_dir,
    padding_side='left',
    trust_remote_code=True,
)

In [None]:
# Wrap the loaded model and tokenizer in the classification-prompt evaluator
# (the EvalQwen subclass defined earlier in this notebook).
eval_qwen = EvalQwen(
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [None]:
# Load the public QA set. The context manager closes the file handle, and the
# explicit UTF-8 encoding keeps decoding of the Chinese text independent of the
# platform's default locale encoding.
with open("../../dataset/public_qa.json", 'r', encoding='utf-8') as qa_file:
    dataset = json.load(qa_file)
# NOTE(review): every sample's 'answer' is blanked, presumably because this
# classification pass has no reference answer. Confirm against
# EvalQwen.inference_batch.
for data in dataset:
    data['answer'] = ''
dataset[0]

In [None]:
# Classify every sample. The prompt asks for a single digit, so a small
# generation budget (max_new_tokens=8) is enough.
eval_models(
    eval_qwen,
    dataset,
    batch_size=12,
    max_new_tokens=8,
    output_path="../../outputs/qwen_class",
)

In [None]:
# Bucket each sample by the category digit the model predicted for it.
# NOTE(review): assumes eval_qwen.predictions is a list of strings aligned
# index-for-index with `dataset`. Confirm against EvalQwen.inference_batch.
LABEL_TO_CATEGORY = {
    '1': "病情诊断",  # condition diagnosis
    '2': "干预建议",  # intervention advice
    '3': "心理支持",  # psychological support
}
VALID_LABELS = set(LABEL_TO_CATEGORY) | {'0'}  # '0' = none of the above


def first_label(prediction):
    """Return the first label digit ('0'-'3') in a model output, or None if there is none.

    The first digit emitted is the model's answer. Testing ``'1' in prediction``
    before '2' and '3' would misfile outputs such as "0，不属于1…" or "2…1"
    under category 1.
    """
    return next((ch for ch in prediction if ch in VALID_LABELS), None)


# Fail loudly instead of silently pairing samples with the wrong predictions.
assert len(eval_qwen.predictions) == len(dataset), (
    f"{len(eval_qwen.predictions)} predictions for {len(dataset)} samples"
)

dataset_classified = {category: [] for category in LABEL_TO_CATEGORY.values()}
for sample, prediction in zip(dataset, eval_qwen.predictions):
    category = LABEL_TO_CATEGORY.get(first_label(prediction))
    if category is not None:  # skip '0' and outputs with no label digit
        dataset_classified[category].append(sample)

# Close the file deterministically, and write UTF-8 because ensure_ascii=False
# emits raw Chinese characters.
with open("../../dataset/public_qa_classified.json", 'w', encoding='utf-8') as out_file:
    json.dump(dataset_classified, out_file, ensure_ascii=False)

In [None]:
# How many samples the model placed in the psychological-support bucket.
psych_support_samples = dataset_classified['心理支持']
len(psych_support_samples)

In [None]:
# Peek at one example from the psychological-support bucket (index 1, the
# second sample; this raises IndexError if the bucket has fewer than two).
psych_support_samples = dataset_classified['心理支持']
psych_support_samples[1]