In [None]:
import os
import subprocess

# A `!source ...` shell escape runs in a child shell that exits immediately, so
# nothing the script exports reaches this kernel. Source the script in bash, dump
# the resulting environment, and copy every new or changed variable into
# os.environ so that later cells actually see the settings.
_turbo_script = "/etc/network_turbo"
if os.path.isfile(_turbo_script):
    _env_dump = subprocess.run(
        # The path is passed as a positional argument ($1) rather than interpolated
        # into the command string; `env -0` separates entries with NUL so values
        # containing newlines survive the round trip.
        ["bash", "-c", 'source "$1" >/dev/null 2>&1; env -0', "bash", _turbo_script],
        capture_output=True,
        text=True,
    ).stdout
    _applied = []
    for _entry in _env_dump.split("\0"):
        _name, _sep, _value = _entry.partition("=")
        # Skip shell bookkeeping variables and anything that did not change.
        if not _sep or _name in ("_", "SHLVL", "PWD", "OLDPWD"):
            continue
        if os.environ.get(_name) != _value:
            os.environ[_name] = _value
            _applied.append(_name)
    print(f"Applied variables from {_turbo_script}: {sorted(_applied)}")
else:
    print(f"{_turbo_script} not found; skipping network acceleration setup.")

In [None]:
import os
# Directory exported below through TORCH_INDUCTOR_CACHE_DIR.
cache_dir = "/root/autodl-tmp/torch_cache"

# Send Hugging Face Hub requests to the hf-mirror.com endpoint.
os.environ.update({"HF_ENDPOINT": "https://hf-mirror.com"})

# Make sure the directory exists (no error if it already does) before exporting it.
os.makedirs(cache_dir, exist_ok=True)
os.environ.update({"TORCH_INDUCTOR_CACHE_DIR": cache_dir})

In [None]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import torch
import json
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer

# --- 1. Setup and model loading ---


# Path to the test dataset (a JSON file read by a later cell).
test_data_path = "/root/test_prompts.json"

# Announce that the fine-tuned model and tokenizer are about to be loaded.
# NOTE(review): the original comment referred to a "gptoss" model loaded from a saved
# directory with the LoRA adapter applied automatically. This cell only prints the
# status message; no loading happens here — confirm the comment is not stale.
print("Loading fine-tuned model and tokenizer...")



In [None]:
# Load the base model and tokenizer, attach the LoRA adapter, and merge it into the base weights.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
model_name = "/root/autodl-tmp"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
model = PeftModel.from_pretrained(model, "/root/autodl-tmp/qwen-qlora-split2", torch_dtype=torch.bfloat16)

# merge_and_unload() returns the base model with the adapter weights folded in.
# Keep that return value so `model` is the plain merged model rather than the PEFT
# wrapper; assigning it also stops the notebook from echoing the entire module
# tree as this cell's output.
model = model.merge_and_unload()

# Put the model in evaluation mode explicitly before running inference.
# (Assigned rather than called bare so the cell does not display the module repr.)
model = model.eval()

In [None]:

# --- 2. Load the test samples ---
print(f"Loading test data from {test_data_path}...")

# The whole file is read and parsed as a single JSON document.
with open(test_data_path, mode='r', encoding='utf-8') as test_file:
    test_samples = json.loads(test_file.read())

print(f"Found {len(test_samples)} samples in the test set.")


# --- 3. 辅助函数：用于解析标签 ---


In [None]:
def parse_labels(text_block: str) -> dict:
    """
    Extract the Q1-Q4 labels from a block of text.

    Works on both model output and ground-truth text. Each line of the form
    ``key: value`` whose key is one of ``theme``, ``claim_types``, ``ctas`` or
    ``evidence`` contributes an entry; all other lines are ignored. If a key
    appears more than once, the last occurrence wins.

    Args:
        text_block: Multi-line text containing ``key: value`` lines.

    Returns:
        A dict mapping each recognised key found in the text to its value,
        with surrounding whitespace stripped. Keys that never appear are absent.
    """
    recognised_keys = ("theme", "claim_types", "ctas", "evidence")
    labels = {}

    for raw_line in text_block.strip().split('\n'):
        # Message-ID lines such as 'M61' carry no label.
        looks_like_message_id = raw_line.lower().startswith('m') and raw_line[1:].isdigit()
        if looks_like_message_id:
            continue

        # Split on the first colon only, so values may themselves contain colons.
        name, separator, remainder = raw_line.partition(':')
        if not separator:
            continue

        name = name.strip()
        if name in recognised_keys:
            labels[name] = remainder.strip()

    return labels

# --- 4. Run the evaluation loop ---
#
# For every test sample: render the prompt, generate greedily, parse the labels
# from the generated text and from the reference answer, and tally exact matches.
# Per-sample output is not printed (it would bury the progress bar); every
# prompt, prediction and reference is kept in `results` instead.

LABEL_KEYS = ("theme", "claim_types", "ctas", "evidence")

results = []
correct_counts = {"theme": 0, "claim_types": 0, "ctas": 0, "evidence": 0, "all": 0}
total_samples = len(test_samples)

# Fall back to the EOS id when the tokenizer defines no pad token, so that
# generate() is never handed pad_token_id=None.
generation_pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)

for sample in tqdm(test_samples, desc="Evaluating"):
    # Every message except the last forms the prompt; the last message holds the reference answer.
    prompt_messages = sample['message'][:-1]

    # Render the chat template once as text and reuse it both for tokenization and
    # for the saved record. add_generation_prompt=True appends the assistant header
    # so the model continues with its answer.
    prompt_text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)

    # The rendered template already contains its special tokens, so they must not
    # be added a second time. Tokenizing this way also yields the attention mask.
    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)
    inputs = encoded["input_ids"]

    # Gradients are not needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=128,  # upper bound on the answer length
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=generation_pad_token_id,
            do_sample=False,  # deterministic (greedy) decoding
        )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    response_ids = outputs[0][inputs.shape[1]:]
    generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)

    ground_truth_text = sample['message'][-1]['content']

    parsed_prediction = parse_labels(generated_text)
    parsed_ground_truth = parse_labels(ground_truth_text)

    # --- 5. Compare prediction and reference label by label (exact string match) ---
    # NOTE(review): a label missing from BOTH the prediction and the ground truth
    # compares equal (None == None) and is therefore scored as correct — confirm
    # that the ground truth always contains all four labels.
    label_matches = {
        key: parsed_prediction.get(key) == parsed_ground_truth.get(key)
        for key in LABEL_KEYS
    }
    for key, matched in label_matches.items():
        if matched:
            correct_counts[key] += 1

    is_all_correct = all(label_matches.values())
    if is_all_correct:
        correct_counts["all"] += 1

    # Keep the full record so that failures can be inspected afterwards.
    results.append({
        "message_id": sample['message_id'],
        "prompt": prompt_text,
        "generated_text": generated_text.strip(),
        "ground_truth_text": ground_truth_text.strip(),
        "parsed_prediction": parsed_prediction,
        "parsed_ground_truth": parsed_ground_truth,
        "is_all_correct": is_all_correct
    })

# --- 6. Compute and print the final accuracy ---

print("\n--- Evaluation Complete ---")
print(f"Total samples evaluated: {total_samples}")

# Percentages are only computed when at least one sample was evaluated.
if total_samples > 0:
    accuracies = {
        label: (hits / total_samples) * 100
        for label, hits in correct_counts.items()
    }

    print("\nAccuracy Scores:")
    print(f"  Q1 (Theme) Accuracy:         {accuracies['theme']:.2f}% ({correct_counts['theme']}/{total_samples})")
    print(f"  Q2 (Claim Types) Accuracy:   {accuracies['claim_types']:.2f}% ({correct_counts['claim_types']}/{total_samples})")
    print(f"  Q3 (CTAs) Accuracy:          {accuracies['ctas']:.2f}% ({correct_counts['ctas']}/{total_samples})")
    print(f"  Q4 (Evidence) Accuracy:      {accuracies['evidence']:.2f}% ({correct_counts['evidence']}/{total_samples})")
    print("  ----------------------------------")
    print(f"  Overall Exact Match Accuracy: {accuracies['all']:.2f}% ({correct_counts['all']}/{total_samples})")

# (Optional) Persist the per-sample records so that failures can be reviewed later.
output_results_file = "evaluation_results.json"
with open(output_results_file, 'w', encoding='utf-8') as results_file:
    json.dump(results, results_file, indent=2, ensure_ascii=False)

print(f"\nDetailed results saved to {output_results_file}")

# (Optional) Show the first sample that was not an exact match on all labels.
print("\n--- Example of an Incorrect Prediction ---")
incorrect_examples = [record for record in results if not record['is_all_correct']]
if not incorrect_examples:
    print("Congratulations! No incorrect predictions found.")
else:
    example = incorrect_examples[0]
    print(f"Message ID: {example['message_id']}")
    print(f"\n[Prompt Sent to Model]\n{example['prompt']}")
    print(f"\n[Model Prediction]\n{example['generated_text']}")
    print(f"\n[Ground Truth]\n{example['ground_truth_text']}")