In [1]:
import os
import av
import torch
import numpy as np
import torch.nn.functional as F
import random
import json
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from collections import defaultdict
from tqdm.notebook import tqdm
from itertools import combinations

class VideoProcessor:
    def __init__(self, model_name, image_processor_name, device='cuda'):
        self.model = self.load_model(model_name)
        self.image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
        self.device = device
        self.model.to(device)

    def load_model(self, model_name):
        if "timesformer" in model_name.lower():
            return TimesformerForVideoClassification.from_pretrained(model_name)
        else:
            raise ValueError(f"Unsupported model name: {model_name}")

    def split_video_into_segments(self, container, n_segments=8, frames_per_segment=16):
        frame_list = [frame.to_image() for frame in container.decode(video=0)]
        total_frames = len(frame_list)
        segment_length = total_frames // n_segments
        segments = []
        for i in range(n_segments):
            start = i * segment_length
            end = min(start + segment_length, total_frames)
            segment_frames = frame_list[start:end] if end - start == segment_length else frame_list[start:] + [frame_list[-1]] * (segment_length - (end - start))
            segments.append(segment_frames[:frames_per_segment])
        return segments

    def predict_video_and_segments(self, container, true_label):
        video_segments = self.split_video_into_segments(container)
        segment_outputs = []
        with torch.no_grad():
            for segment in video_segments:
                inputs = self.image_processor(list(segment), return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs)
                logits = outputs.logits
                probabilities = F.softmax(logits, dim=-1)
                pred_label = logits.argmax(-1).item()
                pred_score = probabilities[0, pred_label].item()
                segment_outputs.append((pred_label, pred_score, probabilities))
        return segment_outputs

class TemporalShap:
    def __init__(self, num_samples=100):
        self.num_samples = num_samples
    
    def approximate_shapley_values(self, segment_outputs, label_index):
        n = len(segment_outputs)
        shapley_values = [0] * n
        for _ in range(self.num_samples):
            random_subset = sorted(range(n), key=lambda _: random.random())
            subset_prob = torch.zeros_like(segment_outputs[0][2])
            for i, index in enumerate(random_subset):
                old_contribution = subset_prob[0, label_index].item()
                subset_prob += segment_outputs[index][2]
                subset_prob /= (i + 1)
                new_contribution = subset_prob[0, label_index].item()
                shapley_values[index] += new_contribution - old_contribution
        return [val / self.num_samples for val in shapley_values]

    def exact_shapley_values(self, segment_outputs, label_index):
        n = len(segment_outputs)
        shapley_values = [0] * n
        all_indices = list(range(n))
        for i in all_indices:
            marginal_contributions = []
            for subset_size in range(n):
                subsets = list(combinations([x for x in all_indices if x != i], subset_size))
                for subset in subsets:
                    subset_prob = torch.zeros_like(segment_outputs[0][2])
                    if subset:
                        subset_prob = torch.mean(torch.stack([segment_outputs[j][2] for j in subset]), dim=0)
                    with_i_prob = (subset_prob * len(subset) + segment_outputs[i][2]) / (len(subset) + 1)
                    marginal_contributions.append(with_i_prob[0, label_index].item() - subset_prob[0, label_index].item())
            shapley_values[i] = np.mean(marginal_contributions)
        return shapley_values

def process_videos(video_processor, shap_calculator, sampled_files, true_labels, use_exact=False):
    predictions = []
    for video_file, true_label in tqdm(zip(sampled_files, true_labels), desc="Processing videos", total=len(sampled_files), unit="video"):
        file_path = os.path.join(config["video_directory"], video_file)
        container = av.open(file_path)
        segment_outputs = video_processor.predict_video_and_segments(container, true_label)
        video_probs = torch.mean(torch.stack([output[2] for output in segment_outputs]), dim=0)
        video_pred_label = video_probs.argmax().item()
        video_pred_score = video_probs[0, video_pred_label].item()
        video_true_score = video_probs[0, true_label].item()
        
        if use_exact:
            sv_true_label = shap_calculator.exact_shapley_values(segment_outputs, true_label)
            sv_video_pred = shap_calculator.exact_shapley_values(segment_outputs, video_pred_label)
        else:
            sv_true_label = shap_calculator.approximate_shapley_values(segment_outputs, true_label)
            sv_video_pred = shap_calculator.approximate_shapley_values(segment_outputs, video_pred_label)
        
        predictions.append((video_file, video_pred_label, video_pred_score, video_true_score, segment_outputs, sv_true_label, sv_video_pred))
    return predictions

def save_results(predictions, filename="results.json"):
    results = []
    for video_file, video_pred_label, video_pred_score, video_true_score, segment_outputs, sv_true_label, sv_video_pred in predictions:
        video_result = {
            "video_file": video_file,
            "video_pred_label": video_pred_label,
            "video_pred_score": video_pred_score,
            "video_true_score": video_true_score,
            "segments": []
        }
        for i, (segment_label, segment_score, probabilities) in enumerate(segment_outputs):
            segment_video_label_score = probabilities[0, video_pred_label].item()
            segment_true_label_score = probabilities[0, true_labels[0]].item()
            video_result["segments"].append({
                "segment_index": i + 1,
                "segment_label": segment_label,
                "segment_score": segment_score,
                "segment_video_label_score": segment_video_label_score,
                "segment_true_label_score": segment_true_label_score,
                "sv_true_label": sv_true_label[i],
                "sv_video_pred": sv_video_pred[i]
            })
        results.append(video_result)
    
    with open(filename, "w") as f:
        json.dump(results, f, indent=4)

# 配置
config = {
    "model_name": "facebook/timesformer-base-finetuned-k400",  # 用户可以在这里更换模型名称，例如 "huggingface/vivit"
    "image_processor_name": "MCG-NJU/videomae-base-finetuned-kinetics",
    "num_samples": 100,
    "num_classes": 5,
    "num_samples_per_class": 2,
    "video_list_path": "archive/kinetics400_val_list_videos.txt",
    "video_directory": "archive/videos_val",
    "use_exact": True  # 设置为True以使用精确Shapley Value计算
}

# 初始化处理器
video_processor = VideoProcessor(config["model_name"], config["image_processor_name"])
shap_calculator = TemporalShap(num_samples=config["num_samples"])

# 读取视频列表和标签，并按类别组织
video_labels = defaultdict(list)
with open(config["video_list_path"], "r") as f:
    for line in f:
        name, label = line.strip().split()
        video_labels[int(label)].append(name)

# 准备视频样本
sampled_files = []
true_labels = []
selected_classes = list(video_labels.keys())[:config["num_classes"]]
for cls in selected_classes:
    sampled_files.extend(video_labels[cls][:config["num_samples_per_class"]])
    true_labels.extend([cls] * config["num_samples_per_class"])

# 获取数据和预测
video_data = process_videos(video_processor, shap_calculator, sampled_files, true_labels, use_exact=config["use_exact"])

# 打印并保存结果
save_results(video_data)

# 打印结果以便查看
for video_file, video_pred_label, video_pred_score, video_true_score, segment_outputs, sv_true_label, sv_video_pred in video_data:
    print(f"Video: {video_file}, Overall Predicted Label = {video_pred_label}, Overall Prediction Score = {video_pred_score:.4f}, True Label = {true_labels[0]}, True Label Score = {video_true_score:.4f}")
    for i, (segment_label, segment_score, probabilities) in enumerate(segment_outputs):
        segment_video_label_score = probabilities[0, video_pred_label].item()
        segment_true_label_score = probabilities[0, true_labels[0]].item()
        print(f"  Segment {i+1}: Predicted Label = {segment_label}, Prediction Score = {segment_score:.4f}, Segment Video Label Score = {segment_video_label_score:.4f}, Segment True Label Score = {segment_true_label_score:.4f}, SV True Label = {sv_true_label[i]:.4f}, SV Predicted Label = {sv_video_pred[i]:.4f}")




Processing videos:   0%|          | 0/10 [00:00<?, ?video/s]

  return torch.tensor(value)


Video: jf7RDuUTrsQ.mp4, Overall Predicted Label = 1, Overall Prediction Score = 0.2028, True Label = 325, True Label Score = 0.0010
  Segment 1: Predicted Label = 81, Prediction Score = 0.4557, Segment Video Label Score = 0.0373, Segment True Label Score = 0.0000, SV True Label = -0.0003, SV Predicted Label = -0.0453
  Segment 2: Predicted Label = 171, Prediction Score = 0.6322, Segment Video Label Score = 0.0035, Segment True Label Score = 0.0004, SV True Label = -0.0002, SV Predicted Label = -0.0549
  Segment 3: Predicted Label = 1, Prediction Score = 0.5045, Segment Video Label Score = 0.5045, Segment True Label Score = 0.0000, SV True Label = -0.0003, SV Predicted Label = 0.0871
  Segment 4: Predicted Label = 289, Prediction Score = 0.4299, Segment Video Label Score = 0.1295, Segment True Label Score = 0.0000, SV True Label = -0.0003, SV Predicted Label = -0.0192
  Segment 5: Predicted Label = 1, Prediction Score = 0.4968, Segment Video Label Score = 0.4968, Segment True Label Scor