# 相关代码

## 有关词表

### 查看原词表信息

In [4]:
import pickle as pkl
import random
import torch
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np
import os

# 设置环境
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# 加载与模型相同的tokenizer
tokenizer_ckpt = "/home/jiayuanrao/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct"
try:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)
    # 添加与模型中相同的特殊token
    tokenizer.add_tokens(["[PLAYER]", "[TEAM]", "[COACH]",
                         "[REFEREE]", "([TEAM])"], special_tokens=True)
except Exception as e:
    print(f"加载tokenizer时出错: {e}")
    print("如果是权限问题，请确保有访问权限或使用本地可用的tokenizer")

# 加载足球相关词汇ID列表
file_path = './soccer_words_llama3.pkl'
with open(file_path, 'rb') as file:
    token_ids_list = pkl.load(file)

# 添加特殊token ID（与模型中相同）
token_ids_list.append(128000)
token_ids_list.append(128001)

# 计算基本统计信息
print(f"允许的token总数: {len(token_ids_list)}")
print(f"占tokenizer词表大小的百分比: {len(token_ids_list)/len(tokenizer)*100:.2f}%")

使用设备: cuda
允许的token总数: 2909
占tokenizer词表大小的百分比: 2.27%


### 构建新词表

In [6]:
import os
import json
import pickle as pkl
from transformers import AutoTokenizer
from tqdm import tqdm
import torch


def collect_all_tokens(root_dir, tokenizer, timestamp_key="gameTime"):
    """
    遍历所有标注文件并收集所有token
    """
    all_tokens = set()

    # 进度条
    all_files = []
    for subdir, _, files in os.walk(root_dir):
        for file in files:
            if file in ['Labels-caption.json', 'Labels-caption_with_gt.json']:
                all_files.append(os.path.join(subdir, file))

    print(f"Found {len(all_files)} annotation files")

    for file_path in tqdm(all_files, desc="Processing files"):
        try:
            with open(file_path, 'r') as file:
                data = json.load(file)

            # 提取所有的文本
            for annotation in data.get('annotations', []):
                unanonymized = annotation.get('description', '')
                if not unanonymized:
                    continue

                # 对每个文本进行tokenize
                tokens = tokenizer(
                    unanonymized,
                    add_special_tokens=True,
                    return_tensors="pt"
                ).input_ids[0]

                # 将token添加到集合中
                all_tokens.update(tokens.tolist())

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

    return sorted(list(all_tokens))


def main():
    # 设置tokenizer
    tokenizer_ckpt = "/home/jiayuanrao/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)
    tokenizer.add_tokens(
        ["[PLAYER]", "[TEAM]", "[COACH]", "[REFEREE]", "([TEAM])"],
        special_tokens=True
    )

    # 数据根目录
    ann_root = "./dataset"  # 请替换为实际的标注文件根目录

    # 收集所有token
    print("Collecting tokens...")
    all_tokens = collect_all_tokens(ann_root, tokenizer)

    # 添加特殊token ID
    all_tokens.extend([128000, 128001])

    # 输出统计信息
    print(f"Total unique tokens: {len(all_tokens)}")
    print(
        f"Percentage of tokenizer vocabulary: {len(all_tokens)/len(tokenizer)*100:.2f}%")

    # 随机抽样一些token查看
    sample_size = min(20, len(all_tokens))
    sample_tokens = random.sample(all_tokens, sample_size)
    print("\nSample tokens:")
    for token in sample_tokens:
        try:
            print(f"Token ID: {token}, Text: '{tokenizer.decode([token])}'")
        except:
            print(f"Token ID: {token}, Decode failed")

    # 保存为pkl文件
    output_file = "unanonymized_soccer_words_llama3.pkl"
    with open(output_file, 'wb') as f:
        pkl.dump(all_tokens, f)
    print(f"\nTokens saved to {output_file}")


if __name__ == "__main__":
    main()

Collecting tokens...
Found 942 annotation files


Processing files: 100%|██████████| 942/942 [00:27<00:00, 34.88it/s]

Total unique tokens: 6740
Percentage of tokenizer vocabulary: 5.25%

Sample tokens:
Token ID: 12055, Text: 'umps'
Token ID: 89502, Text: 'ambi'
Token ID: 33766, Text: ' reflex'
Token ID: 75696, Text: 'Fred'
Token ID: 2970, Text: '58'
Token ID: 12301, Text: 'hen'
Token ID: 1474, Text: '-m'
Token ID: 26235, Text: 'ieu'
Token ID: 20940, Text: ' posit'
Token ID: 70904, Text: ' bounced'
Token ID: 1645, Text: ' ac'
Token ID: 16820, Text: 'yt'
Token ID: 21933, Text: ' striking'
Token ID: 14797, Text: 'ovement'
Token ID: 15916, Text: 'erts'
Token ID: 43085, Text: '-ag'
Token ID: 27561, Text: 'icol'
Token ID: 21117, Text: '-foot'
Token ID: 3729, Text: ' contact'
Token ID: 23083, Text: 'David'

Tokens saved to unanonymized_soccer_words_llama3.pkl





In [1]:
import os
import json
import pickle as pkl
from transformers import AutoTokenizer
from tqdm import tqdm
import torch
    # 设置tokenizer
tokenizer_ckpt = "/home/jiayuanrao/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)
tokenizer.decode([128000, 128001])

'<|begin_of_text|><|end_of_text|>'

## MatchTime inference

In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import clip
from PIL import Image
import torch
import os
import cv2
import argparse
from models.matchvoice_model import matchvoice_model
%tb


class VideoDataset(Dataset):
    def __init__(self, video_path, size=224, fps=2):
        self.video_path = video_path
        self.size = size
        self.fps = fps
        self.transforms = transforms.Compose([
            transforms.Resize((self.size, self.size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[
                                 0.26862954, 0.26130258, 0.27577711]),
        ])
        # Load video using OpenCV
        self.cap = cv2.VideoCapture(self.video_path)
        self.length = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Calculate frames to capture based on FPS
        self.frame_indices = [int(x * self.cap.get(cv2.CAP_PROP_FPS) / self.fps)
                              for x in range(int(self.length / self.cap.get(cv2.CAP_PROP_FPS) * self.fps))]

    def __len__(self):
        return len(self.frame_indices)

    def __getitem__(self, idx):
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_indices[idx])
        ret, frame = self.cap.read()
        if not ret:
            print("Error in reading frame")
            return None
        # Convert color from BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Apply transformations
        frame = self.transforms(Image.fromarray(frame))
        return frame.to(torch.float16)

    def close(self):
        self.cap.release()


def encode_features(data_loader, encoder, device):
    all_features = None  # 初始化为None，用于第一次赋值
    for frames in data_loader:
        features = encoder(frames.to(device))
        if all_features is None:
            all_features = features  # 第一次迭代，直接赋值
        else:
            all_features = torch.cat(
                (all_features, features), dim=0)  # 后续迭代，在第0维（行）上连接
    return all_features


def predict_single_video_CLIP(video_path, predict_model, visual_encoder, size, fps, device):
    # Loading features
    try:
        dataset = VideoDataset(video_path, size=size, fps=fps)
        data_loader = DataLoader(
            dataset, batch_size=40, shuffle=False, pin_memory=True, num_workers=0)
        # print("Start encoding!")
        features = encode_features(data_loader, visual_encoder, device)
        dataset.close()
        print("Features of this video loaded with shape of:", features.shape)
    except:
        print("Error with loading:", video_path)

    sample = {
        "features": features.unsqueeze(dim=0),
        "labels": None,
        "attention_mask": None,
        "input_ids": None
    }

    # Doing prediction:
    comment = predict_model(sample)
    print("The commentary is:", comment)


parser = argparse.ArgumentParser(
    description='Process video files for feature extraction.')
parser.add_argument('--video_path', type=str, default="/remote-home/jiayuanrao/haokai/UniSoccer/train_data/video_clips/europe_uefa-champions-league_2016-2017/2017-04-12 - 21-45 Bayern Munich 1 - 2 Real Madrid/2_27_58.mp4",
                    help='Path to the soccer game video clip.')
parser.add_argument('--device', type=str,
                    default="cuda:2", help='Device to extract.')
parser.add_argument('--size', type=int, default=224,
                    help='Size to which each video frame is resized.')
parser.add_argument('--fps', type=int, default=2,
                    help='Frames per second to sample from the video.')
parser.add_argument("--tokenizer_name", type=str, default="/home/jiayuanrao/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct",
                    help="LLM checkpoints, use path in your computer is fine as well")
parser.add_argument("--restricted_token_list_path", type=str,
                    default="./Restricted_Token_List/unanonymized_llama3.pkl",
                    help="Restricted token list, use path in your computer is fine as well")
parser.add_argument("--model_ckpt", type=str, default="./results/unanonymized_matchtime_aligned_time/checkpoints/model_save_best_val_CIDEr.pth",
                    help="Model checkpoints, use path in your computer is fine as well")
parser.add_argument("--num_query_tokens", type=int, default=32)
parser.add_argument("--num_video_query_token", type=int, default=32)
parser.add_argument("--num_features", type=int, default=512)

args = parser.parse_args()

# 创建并配置模型
model, preprocess = clip.load("ViT-B/32", device=args.device)
model.eval()
# print(model.dtype)
clip_image_encoder = model.encode_image
predict_model = matchvoice_model(llm_ckpt=args.tokenizer_name,
                                    restricted_token_list_path=args.restricted_token_list_path,
                                    tokenizer_ckpt=args.tokenizer_name,
                                    num_video_query_token=args.num_video_query_token,
                                    num_features=args.num_features,
                                    device=args.device,
                                    inference=True)
# Load checkpoints
other_parts_state_dict = torch.load(args.model_ckpt)
new_model_state_dict = predict_model.state_dict()
for key, value in other_parts_state_dict.items():
    if key in new_model_state_dict:
        new_model_state_dict[key] = value
predict_model.load_state_dict(new_model_state_dict)
predict_model.eval()

In [None]:
video_path = ''

predict_single_video_CLIP(video_path=video_path, predict_model=predict_model,
                          visual_encoder=clip_image_encoder, device=args.device, size=args.size, fps=args.fps)

In [7]:
import numpy as np

feature_path = './features/features_CLIP/england_epl_2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley/1_224p_CLIP.npy'
features = np.load(feature_path)
features.shape

(5400, 512)

## Unanonymized Metric

In [1]:
import json
import re


def extract_entities(reference, anonymized):
    """
    根据 comments_text_anonymized 中的占位符，提取 reference 中相应位置的实体名称
    支持提取所有 [PLAYER] 和 [TEAM] 标记对应的实际名称
    """
    # 如果缺少任一输入，返回空列表
    if not reference or not anonymized:
        return [], []

    # 处理特殊token的正则模式
    player_pattern = r'\[PLAYER\]'
    team_pattern = r'\[TEAM\]'
    referee_pattern = r'\[REFEREE\]'
    coach_pattern = r'\[COACH\]'

    # 创建占位符标记的位置映射
    placeholders = []
    for match in re.finditer(player_pattern, anonymized):
        placeholders.append(("PLAYER", match.start(), match.end()))
    for match in re.finditer(team_pattern, anonymized):
        placeholders.append(("TEAM", match.start(), match.end()))
    for match in re.finditer(referee_pattern, anonymized):
        placeholders.append(("REFEREE", match.start(), match.end()))
    for match in re.finditer(coach_pattern, anonymized):
        placeholders.append(("COACH", match.start(), match.end()))

    # 按位置排序占位符
    placeholders.sort(key=lambda x: x[1])

    # 记录参考文本中每个位置对应的原始文本和匿名文本的字符差异
    offset_map = {}
    ref_idx = 0
    anon_idx = 0

    while anon_idx < len(anonymized) and ref_idx < len(reference):
        # 检查是否匹配到占位符
        is_placeholder = False
        placeholder_type = None
        placeholder_length = 0

        if anon_idx + 8 <= len(anonymized) and anonymized[anon_idx:anon_idx + 8] == "[PLAYER]":
            placeholder_type = "PLAYER"
            placeholder_length = 8
            is_placeholder = True
        elif anon_idx + 6 <= len(anonymized) and anonymized[anon_idx:anon_idx + 6] == "[TEAM]":
            placeholder_type = "TEAM"
            placeholder_length = 6
            is_placeholder = True
        elif anon_idx + 9 <= len(anonymized) and anonymized[anon_idx:anon_idx + 9] == "[REFEREE]":
            placeholder_type = "REFEREE"
            placeholder_length = 9
            is_placeholder = True
        elif anon_idx + 7 <= len(anonymized) and anonymized[anon_idx:anon_idx + 7] == "[COACH]":
            placeholder_type = "COACH"
            placeholder_length = 7
            is_placeholder = True

        if is_placeholder:
            # 找到实体名
            start_ref_idx = ref_idx
            # 向前查找，直到遇到非名称字符
            # 名称可能包含字母、空格和连字符
            while ref_idx < len(reference) and (reference[ref_idx].isalpha() or reference[ref_idx] in [' ', '-']):
                ref_idx += 1
            # 记录偏移量和替换文本
            offset_map[anon_idx] = {
                "type": placeholder_type,
                "start": start_ref_idx,
                "end": ref_idx,
                "text": reference[start_ref_idx:ref_idx].strip()
            }
            anon_idx += placeholder_length
        else:
            # 对于相同的字符，同步前进
            if anon_idx < len(anonymized) and ref_idx < len(reference) and anonymized[anon_idx] == reference[ref_idx]:
                anon_idx += 1
                ref_idx += 1
            else:
                # 处理非占位符的不匹配字符
                anon_idx += 1
                ref_idx += 1

    # 提取所有实体名
    player_names = []
    team_names = []

    for placeholder_type, start, end in placeholders:
        if start in offset_map:
            entity_info = offset_map[start]
            entity_text = entity_info["text"]

            if entity_text:
                # 去除可能的标点和空格
                entity_text = entity_text.strip()

                # 处理不同类型的实体
                if entity_info["type"] == "PLAYER" and entity_text:
                    # 只保留看起来像名字的部分
                    name_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*', entity_text)
                    if name_parts:
                        player_names.append(name_parts[0])
                elif entity_info["type"] == "TEAM" and entity_text:
                    # 球队名可能是多个单词组成
                    team_parts = re.findall(
                        r'[A-Z][a-zA-Z\'.-]+(?:\s+[a-zA-Z\'.-]+)*', entity_text)
                    if team_parts:
                        team_names.append(team_parts[0])

    # 兜底方法：查找标准格式的球员名和球队名 "Name (Team)"
    if not player_names or not team_names:
        names_with_team = re.findall(
            r'([A-Z][a-zA-Z\'.-]+(?:\s+[A-Z][a-zA-Z\'.-]+)*)\s*\(([^)]+)\)', reference)
        if names_with_team:
            if not player_names:
                player_names = [name for name, _ in names_with_team]
            if not team_names:
                # 合并所有球队名并去重
                all_teams = [team for _, team in names_with_team]
                team_names = list(set(all_teams))

    return player_names, team_names


def entity_recall(prediction, reference, anonymized):
    """
    计算在prediction中正确识别出的实体名字的比例
    分别返回球员和球队的召回率
    """
    gt_players, gt_teams = extract_entities(reference, anonymized)

    # 在prediction中查找每个实体名（不区分大小写）
    pred_lower = prediction.lower()

    # 球员识别
    player_hit = 0
    matched_players = []
    missed_players = []

    for name in gt_players:
        if name.lower() in pred_lower:
            player_hit += 1
            matched_players.append(name)
        else:
            missed_players.append(name)

    # 球队识别
    team_hit = 0
    matched_teams = []
    missed_teams = []

    for name in gt_teams:
        if name.lower() in pred_lower:
            team_hit += 1
            matched_teams.append(name)
        else:
            missed_teams.append(name)

    return {
        "player_hit": player_hit,  # 正确识别的球员数
        "player_total": len(gt_players),  # 总球员数
        "matched_players": matched_players,
        "missed_players": missed_players,
        "team_hit": team_hit,  # 正确识别的球队数
        "team_total": len(gt_teams),  # 总球队数
        "matched_teams": matched_teams,
        "missed_teams": missed_teams,
        "gt_players": gt_players,
        "gt_teams": gt_teams
    }


def unanonymous_metric(inference_result_path):
    """
    评估模型识别球员名字和球队名的能力
    """
    with open(inference_result_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', [])

    # 统计所有样本的总数
    total_player_hit = 0  # 所有样本中正确识别的球员总数
    total_player_count = 0  # 所有样本中球员总数
    total_team_hit = 0  # 所有样本中正确识别的球队总数
    total_team_count = 0  # 所有样本中球队总数

    sample_count = 0  # 有实体的样本数
    detailed_results = []

    for item in pairs:
        prediction = item.get('prediction', '')
        reference = item.get('reference', '')
        anonymized = item.get('anonymous_comment', '')

        # 跳过没有必要字段的条目
        if not (prediction and reference and anonymized):
            continue

        recall_info = entity_recall(prediction, reference, anonymized)

        # 只记录有球员名或球队名的样本
        has_entities = recall_info["player_total"] > 0 or recall_info["team_total"] > 0

        if has_entities:
            sample_count += 1

            # 累加各项指标
            total_player_hit += recall_info["player_hit"]
            total_player_count += recall_info["player_total"]
            total_team_hit += recall_info["team_hit"]
            total_team_count += recall_info["team_total"]

            # 计算单个样本的召回率（仅用于打印）
            player_recall = recall_info["player_hit"] / \
                recall_info["player_total"] if recall_info["player_total"] > 0 else 1.0
            team_recall = recall_info["team_hit"] / \
                recall_info["team_total"] if recall_info["team_total"] > 0 else 1.0

            # print(f"参考: {reference}")
            # print(f"预测: {prediction}")
            # print(f"GT球员: {recall_info['gt_players']}")
            # print(
            #     f"匹配球员: {recall_info['matched_players']} | Player Recall: {player_recall:.2f}")
            # print(f"GT球队: {recall_info['gt_teams']}")
            # print(
            #     f"匹配球队: {recall_info['matched_teams']} | Team Recall: {team_recall:.2f}\n")

            detailed_results.append({
                "reference": reference,
                "prediction": prediction,
                "ground_truth_players": recall_info["gt_players"],
                "matched_players": recall_info["matched_players"],
                "player_recall": player_recall,
                "ground_truth_teams": recall_info["gt_teams"],
                "matched_teams": recall_info["matched_teams"],
                "team_recall": team_recall
            })

    # 计算整体召回率
    overall_player_recall = total_player_hit / \
        total_player_count if total_player_count > 0 else 1.0
    overall_team_recall = total_team_hit / \
        total_team_count if total_team_count > 0 else 1.0
    overall_recall = (overall_player_recall + overall_team_recall) / 2

    if sample_count > 0:
        print(f"球员识别 - 总共: {total_player_count}, 正确识别: {total_player_hit}")
        print(f"球队识别 - 总共: {total_team_count}, 正确识别: {total_team_hit}")
        print(f"整体球员Recall: {overall_player_recall:.4f}")
        print(f"整体球队Recall: {overall_team_recall:.4f}")
        print(f"综合Recall: {overall_recall:.4f}")
        print(f"(共{sample_count}条有实体的样本)")

        # 保存详细结果供进一步分析
        # with open('entity_recognition_results.json', 'w', encoding='utf-8') as f:
        #     json.dump({
        #         "player_recall": overall_player_recall,
        #         "team_recall": overall_team_recall,
        #         "overall_recall": overall_recall,
        #         "total_samples": sample_count,
        #         "total_player_count": total_player_count,
        #         "total_player_hit": total_player_hit,
        #         "total_team_count": total_team_count,
        #         "total_team_hit": total_team_hit,
        #         "detailed_results": detailed_results
        #     }, f, indent=2, ensure_ascii=False)
    else:
        print("没有可评估的样本。")

In [3]:
# 执行评估
inference_result_path = './results/unanonymized_matchtime_1fps_aligned_time/final_validation_results.json'
unanonymous_metric(inference_result_path)

球员识别 - 总共: 3237, 正确识别: 227
球队识别 - 总共: 3193, 正确识别: 1229
整体球员Recall: 0.0701
整体球队Recall: 0.3849
综合Recall: 0.2275
(共3005条有实体的样本)
