In [1]:
import json
import re


def extract_entities(reference, anonymized):
    """
    根据 comments_text_anonymized 中的占位符，提取 reference 中相应位置的实体名称
    支持提取所有 [PLAYER] 和 [TEAM] 标记对应的实际名称
    """
    # 如果缺少任一输入，返回空列表
    if not reference or not anonymized:
        return [], []

    # 处理特殊token的正则模式
    player_pattern = r'\[PLAYER\]'
    team_pattern = r'\[TEAM\]'
    referee_pattern = r'\[REFEREE\]'
    coach_pattern = r'\[COACH\]'

    # 创建占位符标记的位置映射
    placeholders = []
    for match in re.finditer(player_pattern, anonymized):
        placeholders.append(("PLAYER", match.start(), match.end()))
    for match in re.finditer(team_pattern, anonymized):
        placeholders.append(("TEAM", match.start(), match.end()))
    for match in re.finditer(referee_pattern, anonymized):
        placeholders.append(("REFEREE", match.start(), match.end()))
    for match in re.finditer(coach_pattern, anonymized):
        placeholders.append(("COACH", match.start(), match.end()))

    # 按位置排序占位符
    placeholders.sort(key=lambda x: x[1])

    # 记录参考文本中每个位置对应的原始文本和匿名文本的字符差异
    offset_map = {}
    ref_idx = 0
    anon_idx = 0

    while anon_idx < len(anonymized) and ref_idx < len(reference):
        # 检查是否匹配到占位符
        is_placeholder = False
        placeholder_type = None
        placeholder_length = 0

        if anon_idx + 8 <= len(anonymized) and anonymized[anon_idx:anon_idx + 8] == "[PLAYER]":
            placeholder_type = "PLAYER"
            placeholder_length = 8
            is_placeholder = True
        elif anon_idx + 6 <= len(anonymized) and anonymized[anon_idx:anon_idx + 6] == "[TEAM]":
            placeholder_type = "TEAM"
            placeholder_length = 6
            is_placeholder = True
        elif anon_idx + 9 <= len(anonymized) and anonymized[anon_idx:anon_idx + 9] == "[REFEREE]":
            placeholder_type = "REFEREE"
            placeholder_length = 9
            is_placeholder = True
        elif anon_idx + 7 <= len(anonymized) and anonymized[anon_idx:anon_idx + 7] == "[COACH]":
            placeholder_type = "COACH"
            placeholder_length = 7
            is_placeholder = True

        if is_placeholder:
            # 找到实体名
            start_ref_idx = ref_idx
            # 向前查找，直到遇到非名称字符
            # 名称可能包含字母、空格和连字符
            while ref_idx < len(reference) and (reference[ref_idx].isalpha() or reference[ref_idx] in [' ', '-']):
                ref_idx += 1
            # 记录偏移量和替换文本
            offset_map[anon_idx] = {
                "type": placeholder_type,
                "start": start_ref_idx,
                "end": ref_idx,
                "text": reference[start_ref_idx:ref_idx].strip()
            }
            anon_idx += placeholder_length
        else:
            # 对于相同的字符，同步前进
            if anon_idx < len(anonymized) and ref_idx < len(reference) and anonymized[anon_idx] == reference[ref_idx]:
                anon_idx += 1
                ref_idx += 1
            else:
                # 处理非占位符的不匹配字符
                anon_idx += 1
                ref_idx += 1

    # 提取所有实体名
    player_names = []
    team_names = []

    for placeholder_type, start, end in placeholders:
        if start in offset_map:
            entity_info = offset_map[start]
            entity_text = entity_info["text"]

            if entity_text:
                # 去除可能的标点和空格
                entity_text = entity_text.strip()

                # 处理不同类型的实体
                if entity_info["type"] == "PLAYER" and entity_text:
                    # 只保留看起来像名字的部分
                    name_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*', entity_text)
                    if name_parts:
                        player_names.append(name_parts[0])
                elif entity_info["type"] == "TEAM" and entity_text:
                    # 球队名可能是多个单词组成
                    team_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[a-zA-Z\'.-]+)*', entity_text)
                    if team_parts:
                        team_names.append(team_parts[0])

    # 兜底方法：查找标准格式的球员名和球队名 "Name (Team)"
    if not player_names or not team_names:
        names_with_team = re.findall(
            r'([A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*)\s*\(([^)]+)\)', reference)
        if names_with_team:
            if not player_names:
                player_names = [name for name, _ in names_with_team]
            if not team_names:
                # 合并所有球队名并去重
                all_teams = [team for _, team in names_with_team]
                team_names = list(set(all_teams))

    return player_names, team_names


def entity_recall(prediction, reference, anonymized):
    """
    计算在prediction中正确识别出的实体名字的比例
    分别返回球员和球队的召回率
    """
    gt_players, gt_teams = extract_entities(reference, anonymized)

    # 在prediction中查找每个实体名（不区分大小写）
    pred_lower = prediction.lower()

    # 球员识别
    player_hit = 0
    matched_players = []
    missed_players = []

    for name in gt_players:
        if name.lower() in pred_lower:
            player_hit += 1
            matched_players.append(name)
        else:
            missed_players.append(name)

    # 球队识别
    team_hit = 0
    matched_teams = []
    missed_teams = []

    for name in gt_teams:
        if name.lower() in pred_lower:
            team_hit += 1
            matched_teams.append(name)
        else:
            missed_teams.append(name)

    return {
        "player_hit": player_hit,  # 正确识别的球员数
        "player_total": len(gt_players),  # 总球员数
        "matched_players": matched_players,
        "missed_players": missed_players,
        "team_hit": team_hit,  # 正确识别的球队数
        "team_total": len(gt_teams),  # 总球队数
        "matched_teams": matched_teams,
        "missed_teams": missed_teams,
        "gt_players": gt_players,
        "gt_teams": gt_teams
    }


def unanonymous_metric(inference_result_path):
    """
    评估模型识别球员名字和球队名的能力
    """
    with open(inference_result_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', [])

    # 统计所有样本的总数
    total_player_hit = 0  # 所有样本中正确识别的球员总数
    total_player_count = 0  # 所有样本中球员总数
    total_team_hit = 0  # 所有样本中正确识别的球队总数
    total_team_count = 0  # 所有样本中球队总数

    sample_count = 0  # 有实体的样本数
    detailed_results = []

    for item in pairs:
        prediction = item.get('prediction', '')
        reference = item.get('reference', '')
        anonymized = item.get('comments_text_anonymized', '')

        # 跳过没有必要字段的条目
        if not (prediction and reference and anonymized):
            continue

        recall_info = entity_recall(prediction, reference, anonymized)

        # 只记录有球员名或球队名的样本
        has_entities = recall_info["player_total"] > 0 or recall_info["team_total"] > 0

        if has_entities:
            sample_count += 1

            # 累加各项指标
            total_player_hit += recall_info["player_hit"]
            total_player_count += recall_info["player_total"]
            total_team_hit += recall_info["team_hit"]
            total_team_count += recall_info["team_total"]

            # 计算单个样本的召回率（仅用于打印）
            player_recall = recall_info["player_hit"] / \
                recall_info["player_total"] if recall_info["player_total"] > 0 else 1.0
            team_recall = recall_info["team_hit"] / \
                recall_info["team_total"] if recall_info["team_total"] > 0 else 1.0

            # print(f"参考: {reference}")
            # print(f"预测: {prediction}")
            # print(f"GT球员: {recall_info['gt_players']}")
            # print(
            #     f"匹配球员: {recall_info['matched_players']} | Player Recall: {player_recall:.2f}")
            # print(f"GT球队: {recall_info['gt_teams']}")
            # print(
            #     f"匹配球队: {recall_info['matched_teams']} | Team Recall: {team_recall:.2f}\n")

            detailed_results.append({
                "reference": reference,
                "prediction": prediction,
                "ground_truth_players": recall_info["gt_players"],
                "matched_players": recall_info["matched_players"],
                "player_recall": player_recall,
                "ground_truth_teams": recall_info["gt_teams"],
                "matched_teams": recall_info["matched_teams"],
                "team_recall": team_recall
            })

    # 计算整体召回率
    overall_player_recall = total_player_hit / \
        total_player_count if total_player_count > 0 else 1.0
    overall_team_recall = total_team_hit / \
        total_team_count if total_team_count > 0 else 1.0
    overall_recall = (overall_player_recall + overall_team_recall) / 2

    if sample_count > 0:
        print(f"球员识别 - 总共: {total_player_count}, 正确识别: {total_player_hit}")
        print(f"球队识别 - 总共: {total_team_count}, 正确识别: {total_team_hit}")
        print(f"整体球员Recall: {overall_player_recall:.4f}")
        print(f"整体球队Recall: {overall_team_recall:.4f}")
        print(f"综合Recall: {overall_recall:.4f}")
        print(f"(共{sample_count}条有实体的样本)")

        # 保存详细结果供进一步分析
        # with open('entity_recognition_results.json', 'w', encoding='utf-8') as f:
        #     json.dump({
        #         "player_recall": overall_player_recall,
        #         "team_recall": overall_team_recall,
        #         "overall_recall": overall_recall,
        #         "total_samples": sample_count,
        #         "total_player_count": total_player_count,
        #         "total_player_hit": total_player_hit,
        #         "total_team_count": total_team_count,
        #         "total_team_hit": total_team_hit,
        #         "detailed_results": detailed_results
        #     }, f, indent=2, ensure_ascii=False)
    else:
        print("没有可评估的样本。")

In [2]:
# 执行评估
inference_result_path = './pretrained_both.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 3359, 正确识别: 111
球队识别 - 总共: 3318, 正确识别: 579
整体球员Recall: 0.0330
整体球队Recall: 0.1745
综合Recall: 0.1038
(共3129条有实体的样本)


In [3]:
# 执行评估
inference_result_path = './pretrained_classification.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 3359, 正确识别: 90
球队识别 - 总共: 3318, 正确识别: 656
整体球员Recall: 0.0268
整体球队Recall: 0.1977
综合Recall: 0.1123
(共3129条有实体的样本)
