# In [2]:
import os
import av
import torch
import numpy as np
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import defaultdict
import random

# Load the pretrained video classifier and the frame preprocessor.
# NOTE(review): the image processor is taken from a VideoMAE checkpoint while
# the model is a TimeSformer checkpoint -- confirm the preprocessing matches
# what the model expects.
TIMESFORMER_CHECKPOINT = "facebook/timesformer-base-finetuned-k400"
IMAGE_PROCESSOR_CHECKPOINT = "MCG-NJU/videomae-base-finetuned-kinetics"

model = TimesformerForVideoClassification.from_pretrained(TIMESFORMER_CHECKPOINT)
image_processor = AutoImageProcessor.from_pretrained(IMAGE_PROCESSOR_CHECKPOINT)

# Read the validation list (one "<video name> <integer label>" pair per line)
# and group the video names by their class label.
video_labels = defaultdict(list)
with open("archive/kinetics400_val_list_videos.txt", "r") as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing newline at the end of the file)
            # would otherwise break the two-value unpacking below.
            continue
        name, label = line.split()
        video_labels[int(label)].append(name)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Pick ``clip_len`` frame indices from a video with ``seg_len`` frames.

    For videos longer than the sampling window, a window of
    ``int(clip_len * frame_sample_rate)`` frames is placed at a random
    position and ``clip_len`` evenly spaced indices are taken from it.
    For videos that are not longer than the window, the indices are spread
    evenly over the whole video instead (frames may then repeat).

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Stride factor; the window covers
            ``int(clip_len * frame_sample_rate)`` frames.
        seg_len: Total number of frames available in the video.

    Returns:
        A 1-D ``np.int64`` array of ``clip_len`` non-decreasing indices, each
        within ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is smaller than 1.
    """
    if seg_len < 1:
        raise ValueError(f"seg_len must be at least 1, got {seg_len}")

    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # np.random.randint(converted_len, seg_len) requires low < high and
        # would raise here, so cover the whole (short) video instead.
        return np.linspace(0, seg_len - 1, num=clip_len).astype(np.int64)

    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # end_idx itself is exclusive, hence the clip to end_idx - 1.
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)

def read_video_pyav(container, indices):
    """Decode the frames at the given positions from a video container.

    Args:
        container: Opened video container; it must provide ``seek`` and
            ``decode(video=0)`` (presumably a PyAV container -- verify
            against the caller).
        indices: Sequence of frame positions to keep. The first and last
            entries bound the range of positions that is scanned.

    Returns:
        The selected frames converted with ``to_ndarray(format="rgb24")``
        and stacked along a new leading axis.
    """
    container.seek(0)
    first_wanted, last_wanted = indices[0], indices[-1]
    wanted = set(np.asarray(indices).ravel().tolist())

    decoded = []
    for position, frame in enumerate(container.decode(video=0)):
        if position > last_wanted:
            # Everything requested has been passed; stop decoding.
            break
        if position < first_wanted or position not in wanted:
            continue
        decoded.append(frame.to_ndarray(format="rgb24"))
    return np.stack(decoded)

def balanced_sample_videos(num_samples_per_class):
    """Draw up to ``num_samples_per_class`` videos from every class.

    Reads the module-level ``video_labels`` mapping (label -> list of file
    names). Classes with fewer files than requested contribute all of their
    files; the others are sampled randomly without replacement.

    Args:
        num_samples_per_class: Maximum number of videos to take per class.

    Returns:
        A ``(files, labels)`` tuple of two parallel lists: the chosen file
        names and the class label of each chosen file.
    """
    chosen_files = []
    chosen_labels = []
    for class_id, candidates in video_labels.items():
        has_enough = len(candidates) >= num_samples_per_class
        picked = (
            random.sample(candidates, num_samples_per_class)
            if has_enough
            else candidates
        )
        chosen_files += picked
        chosen_labels += [class_id] * len(picked)
    return chosen_files, chosen_labels

def predict_labels(sampled_files, true_labels):
    """Run the classifier on each sampled video and collect its predictions.

    Args:
        sampled_files: File names relative to ``archive/videos_val``.
        true_labels: Ground-truth labels parallel to ``sampled_files``; they
            are returned unchanged so callers get both lists together.

    Returns:
        A ``(predicted_labels, true_labels)`` tuple, where
        ``predicted_labels`` holds the argmax class index per video.
    """
    predicted_labels = []
    for video_file in sampled_files:
        file_path = os.path.join("archive/videos_val", video_file)
        # The context manager closes the container even if decoding fails;
        # without it every processed video leaks an open file handle.
        with av.open(file_path) as container:
            # NOTE(review): 16 frames are sampled per clip -- confirm this
            # matches the number of frames the checkpoint expects.
            indices = sample_frame_indices(
                clip_len=16,
                frame_sample_rate=1,
                seg_len=container.streams.video[0].frames,
            )
            video = read_video_pyav(container, indices)

        inputs = image_processor(list(video), return_tensors="pt")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_labels.append(outputs.logits.argmax(-1).item())

    return predicted_labels, true_labels

# Choose how many videos to draw per class, then sample a balanced subset.
num_samples_per_class = 1  # videos per class
sampled_files, true_labels = balanced_sample_videos(num_samples_per_class)

# Run the model on the sampled videos.
predicted_labels_index, true_labels_index = predict_labels(sampled_files, true_labels)

# Compute macro-averaged precision, recall and F1 (in that order).
precision, recall, f1 = (
    metric(true_labels_index, predicted_labels_index, average='macro')
    for metric in (precision_score, recall_score, f1_score)
)

for _title, _value in (
    ("Precision:", precision),
    ("Recall:", recall),
    ("F1 Score:", f1),
):
    print(_title, _value)




# Notebook output: KeyboardInterrupt