In [11]:
import os
import av
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import defaultdict
import random
from tqdm.notebook import tqdm

# Load the pretrained video-classification model and the frame preprocessor.
# NOTE(review): the processor is loaded from a VideoMAE checkpoint while the
# model is a TimeSformer checkpoint — confirm the preprocessing settings
# (resize / normalization) are the ones this model expects.
model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400")
image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

# Read the video list and group the video names by their integer label.
video_labels = defaultdict(list)
with open("archive/kinetics400_val_list_videos.txt", "r") as f:
    for line in f:
        # Each line must split into exactly two whitespace-separated fields:
        # the video name followed by the integer class label.
        name, label = line.strip().split()
        video_labels[int(label)].append(name)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    When the video is long enough, a window of
    ``clip_len * frame_sample_rate`` frames is placed at a random position
    (drawn from NumPy's global RNG) and ``clip_len`` evenly spaced indices
    are taken inside it.  When the video is too short to hold such a window,
    the indices are spread evenly over the whole video instead, repeating
    frames where necessary.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Spacing factor; the sampling window covers
            ``clip_len * frame_sample_rate`` frames.
        seg_len: Total number of frames available in the video.

    Returns:
        A non-decreasing ``np.int64`` array of length ``clip_len`` whose
        values all lie in ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive (nothing to sample from).
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")

    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # np.random.randint(low, high) raises when low >= high, so a full
        # window cannot be placed; cover the entire (short) video instead.
        return np.round(np.linspace(0, seg_len - 1, num=clip_len)).astype(np.int64)

    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself; clip so the last index stays inside
    # the window [start_idx, end_idx - 1].
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)

def read_video_pyav(container, indices):
    """Decode the frames at ``indices`` from ``container`` as RGB arrays.

    Frames are decoded sequentially from the start of the first video stream
    and decoding stops as soon as the last requested index has been passed.

    Args:
        container: An opened video container providing ``seek`` and
            ``decode(video=0)``.
        indices: Non-empty, ascending sequence of frame indices.  An index
            listed more than once produces that frame more than once, so the
            result has one frame per requested index as long as the stream
            actually contains all of them.

    Returns:
        The selected frames stacked along a new leading axis.

    Raises:
        ValueError: If none of the requested frames could be decoded.
    """
    # Local import keeps this function self-contained.
    from collections import Counter

    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    # Counted lookup: O(1) membership per decoded frame (instead of scanning
    # the index array each time) and remembers how often each index is wanted.
    wanted = Counter(np.asarray(indices).tolist())

    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        repeats = wanted.get(i, 0)
        if i >= start_index and repeats:
            image = frame.to_ndarray(format="rgb24")
            frames.extend([image] * repeats)

    if not frames:
        # np.stack([]) would fail with an unhelpful message; say what happened.
        raise ValueError(
            f"no frames decoded for indices in [{start_index}, {end_index}]"
        )
    return np.stack(frames)

def balanced_sample_videos(num_samples_per_class, num_classes=None):
    """Draw a class-balanced random subset from the global ``video_labels``.

    Args:
        num_samples_per_class: Maximum number of videos taken from each
            class; a class holding fewer videos contributes all of them.
        num_classes: If given and smaller than the number of available
            classes, that many classes are chosen at random; otherwise every
            class is used.

    Returns:
        A ``(files, labels)`` pair of equally long lists, where
        ``labels[i]`` is the class of ``files[i]``.
    """
    candidates = list(video_labels.keys())
    restrict = num_classes is not None and num_classes < len(candidates)
    picked_classes = random.sample(candidates, num_classes) if restrict else candidates

    sampled_files, labels = [], []
    for class_id in picked_classes:
        pool = video_labels[class_id]
        # Too few videos in this class: take the whole pool without sampling.
        if len(pool) < num_samples_per_class:
            chosen = pool
        else:
            chosen = random.sample(pool, num_samples_per_class)
        sampled_files += chosen
        labels += [class_id] * len(chosen)

    return sampled_files, labels

def predict_labels(sampled_files, true_labels):
    """Classify every video in ``sampled_files`` with the global ``model``.

    For each file under ``archive/videos_val`` a 16-frame clip is sampled,
    preprocessed with the global ``image_processor`` and passed through the
    model.  The model runs on CUDA when it is available and on the CPU
    otherwise.

    Args:
        sampled_files: Video file names relative to ``archive/videos_val``.
        true_labels: Ground-truth labels; returned unchanged so callers can
            unpack all results from a single call.

    Returns:
        ``(sampled_files, predicted_labels, prediction_scores, true_labels)``
        where ``predicted_labels[i]`` is the arg-max class for
        ``sampled_files[i]`` and ``prediction_scores[i]`` is the softmax
        probability of that class.
    """
    # Fall back to the CPU instead of crashing on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    predicted_labels = []
    prediction_scores = []
    for video_file in tqdm(sampled_files, desc="Processing videos", unit="video"):
        file_path = os.path.join("archive/videos_val", video_file)
        # The context manager closes the container (and its file handle)
        # even if sampling or decoding raises.
        with av.open(file_path) as container:
            indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
            video = read_video_pyav(container, indices)
        inputs = image_processor(list(video), return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)
            predicted_label = logits.argmax(-1).item()
            predicted_labels.append(predicted_label)
            # Confidence = probability assigned to the predicted class.
            prediction_score = probabilities[0, predicted_label].item()
            prediction_scores.append(prediction_score)

    return sampled_files, predicted_labels, prediction_scores, true_labels

# Sample a class-balanced subset of videos and run the model on it.
num_samples_per_class = 8
num_classes = 5
sampled_files, true_labels = balanced_sample_videos(num_samples_per_class, num_classes)
sampled_files, predicted_labels, prediction_scores, true_labels = predict_labels(sampled_files, true_labels)

# Print one line per video: file name, predicted label, true label and score.
for file, pred_label, true_label, score in zip(sampled_files, predicted_labels, true_labels, prediction_scores):
    print(f"Video: {file}, Predicted Label: {pred_label}, True Label: {true_label}, Prediction Score: {score:.4f}")




Processing videos:   0%|          | 0/40 [00:00<?, ?video/s]

Video: 6yP9Rl1cyH0.mp4, Predicted Label: 390, True Label: 187, Prediction Score: 0.6702
Video: bSDl9KS9JPs.mp4, Predicted Label: 187, True Label: 187, Prediction Score: 0.9041
Video: m3I2xiLM8CA.mp4, Predicted Label: 187, True Label: 187, Prediction Score: 0.7866
Video: 1-Gu8XdbVl8.mp4, Predicted Label: 211, True Label: 187, Prediction Score: 0.4155
Video: R6sPhtiikWQ.mp4, Predicted Label: 352, True Label: 187, Prediction Score: 0.4307
Video: 9N87vt-heao.mp4, Predicted Label: 187, True Label: 187, Prediction Score: 0.9867
Video: YAvhEr0J0K4.mp4, Predicted Label: 232, True Label: 187, Prediction Score: 0.4719
Video: 64Y7-9j3rzU.mp4, Predicted Label: 187, True Label: 187, Prediction Score: 0.7809
Video: aOVJb1yzRNk.mp4, Predicted Label: 101, True Label: 102, Prediction Score: 0.1958
Video: ykfWdDjL5UA.mp4, Predicted Label: 100, True Label: 102, Prediction Score: 0.4959
Video: cFEG5VGBWzg.mp4, Predicted Label: 392, True Label: 102, Prediction Score: 0.5831
Video: 0fmNdKx4cdI.mp4, Predicte