## get run id

In [2]:
import wandb

api = wandb.Api()
runs = api.runs(
    path="1649466799-xi-an-jiaotong-university/qwen-fine-tuning")
for i in runs:
  print("run name = ",i.name," id: ", i.id)

run name =  qwen-fine-tuning  id:  dypubzid
run name =  qwen-fine-tuning  id:  7v3kzuer
run name =  qwen_team_player_prompt_20250519_025340  id:  lar9qs4a
run name =  qwen_team_prompt_20250519_025412  id:  0g3162ji


## 构建unanonymized词表

先查看特殊token的id

In [3]:
import os
import json
import pickle as pkl
from transformers import AutoTokenizer
from tqdm import tqdm
import torch
    # 设置tokenizer
tokenizer_ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)
tokenizer.encode("<|im_start|>", "<|im_end|>")

[151644, 151645]

构建词表

In [5]:
import os
import json
import random
import pickle as pkl
from transformers import AutoTokenizer
from tqdm import tqdm
import torch


def collect_all_tokens(root_dir, tokenizer, timestamp_key="gameTime"):
    """
    遍历所有标注文件并收集所有token
    """
    all_tokens = set()

    # 进度条
    all_files = []
    for subdir, _, files in os.walk(root_dir):
        for file in files:
            if os.path.splitext(file)[-1] == '.json':
                all_files.append(os.path.join(subdir, file))

    print(f"Found {len(all_files)} annotation files")

    for file_path in tqdm(all_files, desc="Processing files"):
        try:
            with open(file_path, 'r') as file:
                data = json.load(file)

            # 提取所有的文本
            for clip in data:
                unanonymized = clip.get('comments_text', '')
                if not unanonymized:
                    continue

                # 对每个文本进行tokenize
                tokens = tokenizer(
                    unanonymized,
                    add_special_tokens=True,
                    return_tensors="pt"
                ).input_ids[0]

                # 将token添加到集合中
                all_tokens.update(tokens.tolist())

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

    return sorted(list(all_tokens))


def main():
    # 设置tokenizer
    tokenizer_ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)

    # 数据根目录
    ann_root = "../train_data/json/MatchTime"  # 请替换为实际的标注文件根目录

    # 收集所有token
    print("Collecting tokens...")
    all_tokens = collect_all_tokens(ann_root, tokenizer)

    # 添加特殊token ID
    all_tokens.extend([151644, 151645])

    # 输出统计信息
    print(f"Total unique tokens: {len(all_tokens)}")
    print(
        f"Percentage of tokenizer vocabulary: {len(all_tokens)/len(tokenizer)*100:.2f}%")

    # 随机抽样一些token查看
    sample_size = min(20, len(all_tokens))
    sample_tokens = random.sample(all_tokens, sample_size)
    print("\nSample tokens:")
    for token in sample_tokens:
        try:
            print(f"Token ID: {token}, Text: '{tokenizer.decode([token])}'")
        except:
            print(f"Token ID: {token}, Decode failed")

    # 保存为pkl文件
    output_file = "../words_world/qwen/unanonymized_matchtime.pkl"
    with open(output_file, 'wb') as f:
        pkl.dump(all_tokens, f)
    print(f"\nTokens saved to {output_file}")


if __name__ == "__main__":
    main()

Collecting tokens...
Found 3 annotation files


Processing files: 100%|██████████| 3/3 [00:04<00:00,  1.61s/it]

Total unique tokens: 6087
Percentage of tokenizer vocabulary: 4.01%

Sample tokens:
Token ID: 78953, Text: ' sideline'
Token ID: 13391, Text: ' ending'
Token ID: 585, Text: 'ak'
Token ID: 20527, Text: 'ishes'
Token ID: 79630, Text: ' Rooney'
Token ID: 1549, Text: ' again'
Token ID: 3791, Text: 'ising'
Token ID: 2453, Text: ' care'
Token ID: 6050, Text: ' Mo'
Token ID: 7604, Text: 'ady'
Token ID: 11794, Text: ' snow'
Token ID: 78379, Text: ' Toby'
Token ID: 24821, Text: 'icates'
Token ID: 11137, Text: ' causes'
Token ID: 429, Text: ' that'
Token ID: 36112, Text: ' jersey'
Token ID: 6250, Text: ' Ver'
Token ID: 1961, Text: 'oin'
Token ID: 1256, Text: 'ector'
Token ID: 74063, Text: 'olg'

Tokens saved to ../words_world/qwen/unanonymized_matchtime.pkl





## 尝试load adaptor

In [None]:
import torch
from transformers import AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration
import peft

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
model.load_adapter('../results/qwen_team_player_prompt_20250519_025340/final_model/')

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

: 

## 算平均token和最多token

In [None]:
import os
import json
from transformers import AutoTokenizer
from tqdm import tqdm
import torch


def calculate_token_stats(root_dir, tokenizer):
    """
    遍历所有标注文件并计算评论的平均token数量和最大token数量
    """
    total_tokens = 0
    max_tokens = 0
    num_comments = 0

    # 进度条
    all_files = []
    for subdir, _, files in os.walk(root_dir):
        for file in files:
            if os.path.splitext(file)[-1] == '.json':
                all_files.append(os.path.join(subdir, file))

    print(f"Found {len(all_files)} annotation files")

    for file_path in tqdm(all_files, desc="Processing files"):
        try:
            with open(file_path, 'r') as file:
                data = json.load(file)

            # 提取所有的文本
            for clip in data:
                unanonymized = clip.get('comments_text', '')
                if not unanonymized:
                    continue

                # 对每个文本进行tokenize
                tokens = tokenizer(
                    unanonymized,
                    add_special_tokens=True,
                    return_tensors="pt"
                ).input_ids[0]

                token_count = len(tokens)
                total_tokens += token_count
                max_tokens = max(max_tokens, token_count)
                num_comments += 1

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

    avg_tokens = total_tokens / num_comments if num_comments > 0 else 0
    return avg_tokens, max_tokens


def main():
    # 设置tokenizer
    tokenizer_ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)

    # 数据根目录
    ann_root = "../train_data/json/MatchTime"  # 请替换为实际的标注文件根目录

    # 计算token统计信息
    print("Calculating token statistics...")
    avg_tokens, max_tokens = calculate_token_stats(ann_root, tokenizer)

    # 输出统计信息
    print(f"Average tokens per comment: {avg_tokens:.2f}")
    print(f"Maximum tokens in a single comment: {max_tokens}")


if __name__ == "__main__":
    main()

## Inference unanonymized metric

In [1]:
import json
import re


def extract_entities(reference, anonymized):
    """
    根据 comments_text_anonymized 中的占位符，提取 reference 中相应位置的实体名称
    支持提取所有 [PLAYER] 和 [TEAM] 标记对应的实际名称
    """
    # 如果缺少任一输入，返回空列表
    if not reference or not anonymized:
        return [], []

    # 处理特殊token的正则模式
    player_pattern = r'\[PLAYER\]'
    team_pattern = r'\[TEAM\]'
    referee_pattern = r'\[REFEREE\]'
    coach_pattern = r'\[COACH\]'

    # 创建占位符标记的位置映射
    placeholders = []
    for match in re.finditer(player_pattern, anonymized):
        placeholders.append(("PLAYER", match.start(), match.end()))
    for match in re.finditer(team_pattern, anonymized):
        placeholders.append(("TEAM", match.start(), match.end()))
    for match in re.finditer(referee_pattern, anonymized):
        placeholders.append(("REFEREE", match.start(), match.end()))
    for match in re.finditer(coach_pattern, anonymized):
        placeholders.append(("COACH", match.start(), match.end()))

    # 按位置排序占位符
    placeholders.sort(key=lambda x: x[1])

    # 记录参考文本中每个位置对应的原始文本和匿名文本的字符差异
    offset_map = {}
    ref_idx = 0
    anon_idx = 0

    while anon_idx < len(anonymized) and ref_idx < len(reference):
        # 检查是否匹配到占位符
        is_placeholder = False
        placeholder_type = None
        placeholder_length = 0

        if anon_idx + 8 <= len(anonymized) and anonymized[anon_idx:anon_idx + 8] == "[PLAYER]":
            placeholder_type = "PLAYER"
            placeholder_length = 8
            is_placeholder = True
        elif anon_idx + 6 <= len(anonymized) and anonymized[anon_idx:anon_idx + 6] == "[TEAM]":
            placeholder_type = "TEAM"
            placeholder_length = 6
            is_placeholder = True
        elif anon_idx + 9 <= len(anonymized) and anonymized[anon_idx:anon_idx + 9] == "[REFEREE]":
            placeholder_type = "REFEREE"
            placeholder_length = 9
            is_placeholder = True
        elif anon_idx + 7 <= len(anonymized) and anonymized[anon_idx:anon_idx + 7] == "[COACH]":
            placeholder_type = "COACH"
            placeholder_length = 7
            is_placeholder = True

        if is_placeholder:
            # 找到实体名
            start_ref_idx = ref_idx
            # 向前查找，直到遇到非名称字符
            # 名称可能包含字母、空格和连字符
            while ref_idx < len(reference) and (reference[ref_idx].isalpha() or reference[ref_idx] in [' ', '-']):
                ref_idx += 1
            # 记录偏移量和替换文本
            offset_map[anon_idx] = {
                "type": placeholder_type,
                "start": start_ref_idx,
                "end": ref_idx,
                "text": reference[start_ref_idx:ref_idx].strip()
            }
            anon_idx += placeholder_length
        else:
            # 对于相同的字符，同步前进
            if anon_idx < len(anonymized) and ref_idx < len(reference) and anonymized[anon_idx] == reference[ref_idx]:
                anon_idx += 1
                ref_idx += 1
            else:
                # 处理非占位符的不匹配字符
                anon_idx += 1
                ref_idx += 1

    # 提取所有实体名
    player_names = []
    team_names = []

    for placeholder_type, start, end in placeholders:
        if start in offset_map:
            entity_info = offset_map[start]
            entity_text = entity_info["text"]

            if entity_text:
                # 去除可能的标点和空格
                entity_text = entity_text.strip()

                # 处理不同类型的实体
                if entity_info["type"] == "PLAYER" and entity_text:
                    # 只保留看起来像名字的部分
                    name_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*', entity_text)
                    if name_parts:
                        player_names.append(name_parts[0])
                elif entity_info["type"] == "TEAM" and entity_text:
                    # 球队名可能是多个单词组成
                    team_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[a-zA-Z\'.-]+)*', entity_text)
                    if team_parts:
                        team_names.append(team_parts[0])

    # 兜底方法：查找标准格式的球员名和球队名 "Name (Team)"
    if not player_names or not team_names:
        names_with_team = re.findall(
            r'([A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*)\s*\(([^)]+)\)', reference)
        if names_with_team:
            if not player_names:
                player_names = [name for name, _ in names_with_team]
            if not team_names:
                # 合并所有球队名并去重
                all_teams = [team for _, team in names_with_team]
                team_names = list(set(all_teams))

    return player_names, team_names


def entity_recall(prediction, reference, anonymized):
    """
    计算在prediction中正确识别出的实体名字的比例
    分别返回球员和球队的召回率
    """
    gt_players, gt_teams = extract_entities(reference, anonymized)

    # 在prediction中查找每个实体名（不区分大小写）
    pred_lower = prediction.lower()

    # 球员识别
    player_hit = 0
    matched_players = []
    missed_players = []

    for name in gt_players:
        if name.lower() in pred_lower:
            player_hit += 1
            matched_players.append(name)
        else:
            missed_players.append(name)

    # 球队识别
    team_hit = 0
    matched_teams = []
    missed_teams = []

    for name in gt_teams:
        if name.lower() in pred_lower:
            team_hit += 1
            matched_teams.append(name)
        else:
            missed_teams.append(name)

    return {
        "player_hit": player_hit,  # 正确识别的球员数
        "player_total": len(gt_players),  # 总球员数
        "matched_players": matched_players,
        "missed_players": missed_players,
        "team_hit": team_hit,  # 正确识别的球队数
        "team_total": len(gt_teams),  # 总球队数
        "matched_teams": matched_teams,
        "missed_teams": missed_teams,
        "gt_players": gt_players,
        "gt_teams": gt_teams
    }


def unanonymous_metric(inference_result_path):
    """
    评估模型识别球员名字和球队名的能力
    """
    with open(inference_result_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', [])

    # 统计所有样本的总数
    total_player_hit = 0  # 所有样本中正确识别的球员总数
    total_player_count = 0  # 所有样本中球员总数
    total_team_hit = 0  # 所有样本中正确识别的球队总数
    total_team_count = 0  # 所有样本中球队总数

    sample_count = 0  # 有实体的样本数
    detailed_results = []

    for item in pairs:
        prediction = item.get('prediction', '')
        reference = item.get('reference', '')
        anonymized = item.get('comments_text_anonymized', '')

        # 跳过没有必要字段的条目
        if not (prediction and reference and anonymized):
            continue

        recall_info = entity_recall(prediction, reference, anonymized)

        # 只记录有球员名或球队名的样本
        has_entities = recall_info["player_total"] > 0 or recall_info["team_total"] > 0

        if has_entities:
            sample_count += 1

            # 累加各项指标
            total_player_hit += recall_info["player_hit"]
            total_player_count += recall_info["player_total"]
            total_team_hit += recall_info["team_hit"]
            total_team_count += recall_info["team_total"]

            # 计算单个样本的召回率（仅用于打印）
            player_recall = recall_info["player_hit"] / \
                recall_info["player_total"] if recall_info["player_total"] > 0 else 1.0
            team_recall = recall_info["team_hit"] / \
                recall_info["team_total"] if recall_info["team_total"] > 0 else 1.0

            # print(f"参考: {reference}")
            # print(f"预测: {prediction}")
            # print(f"GT球员: {recall_info['gt_players']}")
            # print(
            #     f"匹配球员: {recall_info['matched_players']} | Player Recall: {player_recall:.2f}")
            # print(f"GT球队: {recall_info['gt_teams']}")
            # print(
            #     f"匹配球队: {recall_info['matched_teams']} | Team Recall: {team_recall:.2f}\n")

            detailed_results.append({
                "reference": reference,
                "prediction": prediction,
                "ground_truth_players": recall_info["gt_players"],
                "matched_players": recall_info["matched_players"],
                "player_recall": player_recall,
                "ground_truth_teams": recall_info["gt_teams"],
                "matched_teams": recall_info["matched_teams"],
                "team_recall": team_recall
            })

    # 计算整体召回率
    overall_player_recall = total_player_hit / \
        total_player_count if total_player_count > 0 else 1.0
    overall_team_recall = total_team_hit / \
        total_team_count if total_team_count > 0 else 1.0
    overall_recall = (overall_player_recall + overall_team_recall) / 2

    if sample_count > 0:
        print(f"球员识别 - 总共: {total_player_count}, 正确识别: {total_player_hit}")
        print(f"球队识别 - 总共: {total_team_count}, 正确识别: {total_team_hit}")
        print(f"整体球员Recall: {overall_player_recall:.4f}")
        print(f"整体球队Recall: {overall_team_recall:.4f}")
        print(f"综合Recall: {overall_recall:.4f}")
        print(f"(共{sample_count}条有实体的样本)")

        # 保存详细结果供进一步分析
        # with open('entity_recognition_results.json', 'w', encoding='utf-8') as f:
        #     json.dump({
        #         "player_recall": overall_player_recall,
        #         "team_recall": overall_team_recall,
        #         "overall_recall": overall_recall,
        #         "total_samples": sample_count,
        #         "total_player_count": total_player_count,
        #         "total_player_hit": total_player_hit,
        #         "total_team_count": total_team_count,
        #         "total_team_hit": total_team_hit,
        #         "detailed_results": detailed_results
        #     }, f, indent=2, ensure_ascii=False)
    else:
        print("没有可评估的样本。")

In [2]:
# 执行评估
inference_result_path = './inference_results/raw_team_prompt.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 2998, 正确识别: 51
球队识别 - 总共: 2953, 正确识别: 2336
整体球员Recall: 0.0170
整体球队Recall: 0.7911
综合Recall: 0.4040
(共2790条有实体的样本)


In [4]:
# 执行评估
inference_result_path = './inference_results/raw_team_player_prompt.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 2998, 正确识别: 51
球队识别 - 总共: 2953, 正确识别: 2348
整体球员Recall: 0.0170
整体球队Recall: 0.7951
综合Recall: 0.4061
(共2790条有实体的样本)


In [3]:
# 执行评估
inference_result_path = './inference_results/team_player_prompt.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 2993, 正确识别: 246
球队识别 - 总共: 2948, 正确识别: 1566
整体球员Recall: 0.0822
整体球队Recall: 0.5312
综合Recall: 0.3067
(共2785条有实体的样本)


In [29]:
# 执行评估
inference_result_path = './inference_results/team_prompt.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 2947, 正确识别: 263
球队识别 - 总共: 2902, 正确识别: 1604
整体球员Recall: 0.0892
整体球队Recall: 0.5527
综合Recall: 0.3210
(共2742条有实体的样本)


In [28]:
import json
with open('./inference_results/team_prompt.json', 'r') as f:
    data = json.load(f)
    pairs = data.get('pairs', [])
len(pairs)  # 输出样本数量

2804