In [3]:
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载模型和 tokenizer
model_path = "/root/autodl-tmp/outputs/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

def get_bio_tags(text):
    """
    调用大模型进行 BIO 标签预测
    """
    sys_prompt = """
    你是一个医疗命名实体专家！请根据当前对话文本内容，识别出每一句话中的BIO实体标签
    """
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": text}
    ]

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

def calculate_f1(y_true, y_pred):
    """
    计算 F1 分数
    """
    true_labels = [label.split() for label in y_true]
    pred_labels = [label.split() for label in y_pred]

    true_labels_flat = [item for sublist in true_labels for item in sublist]
    pred_labels_flat = [item for sublist in pred_labels for item in sublist]

    precision = precision_score(true_labels_flat, pred_labels_flat, average='micro')
    recall = recall_score(true_labels_flat, pred_labels_flat, average='micro')
    f1 = f1_score(true_labels_flat, pred_labels_flat, average='micro')
    return precision, recall, f1

# 读取 TSV 文件
data = pd.read_csv("dev.tsv", sep='\t', encoding='utf8')

# 随机抽取100条数据
sampled_data = data.sample(n=100, random_state=42)

results = []
predicted_labels = []
true_labels = list(sampled_data["output"])

# 创建列表来存储预测内容
predictions = []

for idx, row in tqdm(sampled_data.iterrows(), total=len(sampled_data)):
    instruction = row["instruction"]
    actual_output = row["output"]

    # 获取预测输出
    predicted_output = get_bio_tags(instruction).strip()  # 去掉首尾空白字符

    # 如果实际输出和预测输出的长度不同，视为预测错误
    if len(predicted_output.split()) != len(actual_output.split()):
        print(f"长度不一致: index={idx}, instruction={instruction}, predicted_output={predicted_output}")
        predicted_output = "O " * len(actual_output.split())  # 使用默认值避免程序中断
        predicted_output = predicted_output.strip()  # 去掉最后的多余空格

    # 将 predicted_output 添加到 predicted_labels 列表中
    predicted_labels.append(predicted_output)
    predictions.append([instruction, actual_output, predicted_output])
    results.append(predicted_output == actual_output)

# 计算 F1 分数
precision, recall, f1 = calculate_f1(true_labels, predicted_labels)

print(f"精确率(Precision): {precision}")
print(f"召回率(Recall): {recall}")
print(f"F1 分数(F1 Score): {f1}")

# 打印预测内容
print(predictions)

# 将预测内容保存为新的 TSV 文件
output_df = pd.DataFrame(predictions, columns=['instruction', 'output', 'pred_output'])
output_df.to_csv("predictions_sampled.tsv", sep='\t', index=False)

print("随机抽样预测结果已保存到 predictions_sampled.tsv 文件中")


 18%|█▊        | 18/100 [01:57<07:15,  5.31s/it]

长度不一致: index=14816, instruction=另外孩子如果吃母乳，孩子妈妈忌辛辣刺激，生冷食品, predicted_output=O O O O O O O O O O O O O O O O O O O O O O O O O


100%|██████████| 100/100 [10:57<00:00,  6.57s/it]

精确率(Precision): 0.9589702333065165
召回率(Recall): 0.9589702333065165
F1 分数(F1 Score): 0.9589702333065165
[['家里有康崔乐', 'O O O B-Drug I-Drug I-Drug', 'O O O B-Drug I-Drug I-Drug'], ['但是反反复复这样。所以很担心', 'O O O O O O O O O O O O O O', 'O O O O O O O O O O O O O O'], ['最好化验一下，查看是哪种腹泻', 'O O O O O O O O O O O O B-Symptom I-Symptom', 'O O O O O O O O O O O O B-Symptom I-Symptom'], ['注意孩子拉的很多的话，注意适当补液，如果拉的水不多，可以继续光母乳，预防脱水', 'O O O O O O O O O O O O O O O B-Operation I-Operation O O O O O O O O O O O O O O O O O O O B-Symptom I-Symptom', 'O O O O B-Symptom I-Symptom I-Symptom I-Symptom O O O O O O O B-Operation I-Operation O O O B-Symptom I-Symptom I-Symptom O O O O O O O O O O O O O B-Symptom I-Symptom'], ['要不要放冰箱喝剩下的奶', 'O O O O O O O O O O O', 'O O O O O O O O O O O'], ['头疼', 'B-Symptom I-Symptom', 'B-Symptom I-Symptom'], ['可以口服助消化的药物。', 'O O O O B-Drug_Category I-Drug_Category I-Drug_Category I-Drug_Category I-Drug_Category I-Drug_Category O', 'O O O O B-Drug_Category I-Drug_Category I-Drug_Cate


