LIME with results save 第一帧

In [5]:
import os
import av
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
from PIL import Image
import numpy as np
from tqdm import tqdm
from lime import lime_image
import matplotlib.pyplot as plt
import random
import json

# Configuration
config = {
    "model_config": "google/vit-base-patch16-224",
    "model_path": "finetuned_vit_model.pth",
    "feature_extractor_name": "google/vit-base-patch16-224",
    "video_directory": "archive/videos_val",
    "results_folder": "ResultsLIME",
    "num_classes": 400,
    "num_videos_to_process": 25  # Number of videos to process
}

# Ensure the results directory exists
os.makedirs(config["results_folder"], exist_ok=True)

# Load the model and feature extractor
model_config = ViTConfig.from_pretrained(config["model_config"], num_labels=config["num_classes"])
model = ViTForImageClassification(model_config)
model.load_state_dict(torch.load(config["model_path"]), strict=False)
model.eval()
feature_extractor = ViTFeatureExtractor.from_pretrained(config["feature_extractor_name"])

def process_and_explain(video_path, model, feature_extractor):
    container = av.open(video_path)
    frame = next(container.decode(video=0)).to_image()
    inputs = feature_extractor(images=frame, return_tensors="pt")
    outputs = model(**inputs)
    preds = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_pred = preds.argmax().item()

    def batch_predict(images):
        inputs = feature_extractor(images=[Image.fromarray(img.astype('uint8')) for img in images], return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        return probs.detach().cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(np.array(frame), batch_predict, top_labels=5, hide_color=0, num_samples=100)
    
    saliency_map = explanation.get_image_and_mask(top_pred, positive_only=True, num_features=10, hide_rest=False)[1]
    saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())

    # Create a directory for this video
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_result_dir = os.path.join(config["results_folder"], video_name)
    os.makedirs(video_result_dir, exist_ok=True)

    # Save original image
    plt.figure(figsize=(10, 10))
    plt.imshow(frame)
    plt.axis('off')
    plt.savefig(os.path.join(video_result_dir, "original_image.png"), bbox_inches='tight', pad_inches=0)
    plt.close()

    # Save saliency map overlay
    plt.figure(figsize=(10, 10))
    plt.imshow(frame)
    plt.imshow(saliency_map, cmap='jet', alpha=0.5)
    plt.axis('off')
    plt.savefig(os.path.join(video_result_dir, "saliency_overlay.png"), bbox_inches='tight', pad_inches=0)
    plt.close()

    # Save saliency map
    plt.figure(figsize=(10, 10))
    plt.imshow(saliency_map, cmap='jet')
    plt.axis('off')
    plt.savefig(os.path.join(video_result_dir, "saliency_map.png"), bbox_inches='tight', pad_inches=0)
    plt.close()

    # Save saliency map data as JSON
    saliency_data = {
        "saliency_map": saliency_map.tolist(),
        "top_prediction": top_pred,
        "prediction_score": preds[0, top_pred].item()
    }
    with open(os.path.join(video_result_dir, "saliency_data.json"), 'w') as f:
        json.dump(saliency_data, f)

    return top_pred, video_result_dir

# Select and process videos
all_video_files = os.listdir(config["video_directory"])
selected_video_files = random.sample(all_video_files, min(config["num_videos_to_process"], len(all_video_files)))

for video_file in tqdm(selected_video_files, desc="Processing videos"):
    video_path = os.path.join(config["video_directory"], video_file)
    prediction, result_dir = process_and_explain(video_path, model, feature_extractor)
    print(f"Processed {video_file}: Top prediction index = {prediction}, Results saved to {result_dir}")

Processing videos:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:   4%|▍         | 1/25 [00:09<03:55,  9.80s/it]

Processed sZ8JiPfAoWc.mp4: Top prediction index = 43, Results saved to ResultsLIME/sZ8JiPfAoWc


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:   8%|▊         | 2/25 [00:19<03:43,  9.72s/it]

Processed JhfrG0Yy8B8.mp4: Top prediction index = 94, Results saved to ResultsLIME/JhfrG0Yy8B8


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  12%|█▏        | 3/25 [00:31<03:55, 10.72s/it]

Processed CMoSvXnu2r8.mp4: Top prediction index = 159, Results saved to ResultsLIME/CMoSvXnu2r8


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  16%|█▌        | 4/25 [00:43<03:58, 11.35s/it]

Processed KDwnc2_a1UI.mp4: Top prediction index = 233, Results saved to ResultsLIME/KDwnc2_a1UI


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  20%|██        | 5/25 [00:56<03:55, 11.75s/it]

Processed Ecul2WP41mI.mp4: Top prediction index = 193, Results saved to ResultsLIME/Ecul2WP41mI


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  24%|██▍       | 6/25 [01:08<03:47, 11.95s/it]

Processed h-JWFxsDg7Y.mp4: Top prediction index = 55, Results saved to ResultsLIME/h-JWFxsDg7Y


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  28%|██▊       | 7/25 [01:21<03:40, 12.25s/it]

Processed 3SFFYaz5czo.mp4: Top prediction index = 75, Results saved to ResultsLIME/3SFFYaz5czo


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  32%|███▏      | 8/25 [01:33<03:25, 12.11s/it]

Processed 8NStNQyjIXI.mp4: Top prediction index = 26, Results saved to ResultsLIME/8NStNQyjIXI


  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  32%|███▏      | 8/25 [01:36<03:25, 12.10s/it]


KeyboardInterrupt: 

平均5帧LIME

In [1]:
import os
import av
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
from PIL import Image
import numpy as np
from tqdm import tqdm
from lime import lime_image
import matplotlib.pyplot as plt
import random
import json

# Configuration
config = {
    "model_config": "google/vit-base-patch16-224",
    "model_path": "finetuned_vit_model_20.pth",
    "feature_extractor_name": "google/vit-base-patch16-224",
    "video_directory": "archive/videos_val",
    "results_folder": "ResultsLIME",
    "num_classes": 400,
    "num_videos_to_process": 25,  # Number of videos to process
    "num_frames_per_video": 5     # Number of frames to analyze per video
}

# Ensure the results directory exists
os.makedirs(config["results_folder"], exist_ok=True)

# Load the model and feature extractor
model_config = ViTConfig.from_pretrained(config["model_config"], num_labels=config["num_classes"])
model = ViTForImageClassification(model_config)
model.load_state_dict(torch.load(config["model_path"]), strict=False)
model.eval()
feature_extractor = ViTFeatureExtractor.from_pretrained(config["feature_extractor_name"])

def process_and_explain(video_path, model, feature_extractor, num_frames=5):
    container = av.open(video_path)
    video = container.streams.video[0]
    duration = video.duration
    frame_indices = [int(i * duration / (num_frames + 1)) for i in range(1, num_frames + 1)]
    
    results = []
    for frame_index in frame_indices:
        container.seek(frame_index)
        frame = next(container.decode(video=0)).to_image()
        
        inputs = feature_extractor(images=frame, return_tensors="pt")
        outputs = model(**inputs)
        preds = torch.nn.functional.softmax(outputs.logits, dim=-1)
        top_pred = preds.argmax().item()

        def batch_predict(images):
            inputs = feature_extractor(images=[Image.fromarray(img.astype('uint8')) for img in images], return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            return probs.detach().cpu().numpy()

        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(np.array(frame), batch_predict, top_labels=5, hide_color=0, num_samples=100)
        
        saliency_map = explanation.get_image_and_mask(top_pred, positive_only=True, num_features=10, hide_rest=False)[1]
        saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())

        results.append({
            "frame_index": frame_index,
            "top_prediction": top_pred,
            "prediction_score": preds[0, top_pred].item(),
            "saliency_map": saliency_map.tolist(),
            "frame": np.array(frame).tolist()
        })

    # Create a directory for this video
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_result_dir = os.path.join(config["results_folder"], video_name)
    os.makedirs(video_result_dir, exist_ok=True)

    # Save results for each frame
    for i, result in enumerate(results):
        frame = np.array(result['frame']).astype(np.uint8)
        saliency_map = np.array(result['saliency_map'])

        # Save original image
        plt.figure(figsize=(10, 10))
        plt.imshow(frame)
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"original_image_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

        # Save saliency map overlay
        plt.figure(figsize=(10, 10))
        plt.imshow(frame)
        plt.imshow(saliency_map, cmap='jet', alpha=0.5)
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"saliency_overlay_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

        # Save saliency map
        plt.figure(figsize=(10, 10))
        plt.imshow(saliency_map, cmap='jet')
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"saliency_map_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

    # Save all results as JSON
    with open(os.path.join(video_result_dir, "saliency_data.json"), 'w') as f:
        json.dump(results, f)

    return results, video_result_dir

# Select and process videos
all_video_files = os.listdir(config["video_directory"])
selected_video_files = random.sample(all_video_files, min(config["num_videos_to_process"], len(all_video_files)))

for video_file in tqdm(selected_video_files, desc="Processing videos"):
    video_path = os.path.join(config["video_directory"], video_file)
    results, result_dir = process_and_explain(video_path, model, feature_extractor, num_frames=config["num_frames_per_video"])
    print(f"Processed {video_file}: Results saved to {result_dir}")
    print(f"Top predictions: {[result['top_prediction'] for result in results]}")

Processing videos:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:   4%|▍         | 1/25 [01:01<24:31, 61.32s/it]

Processed Ny8YzIrC7JI.mp4: Results saved to ResultsLIME/Ny8YzIrC7JI
Top predictions: [183, 183, 183, 183, 183]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:   8%|▊         | 2/25 [01:53<21:25, 55.87s/it]

Processed C5yTd8hS6AY.mp4: Results saved to ResultsLIME/C5yTd8hS6AY
Top predictions: [392, 392, 392, 392, 392]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  12%|█▏        | 3/25 [02:43<19:28, 53.10s/it]

Processed wcPTE-oE4U0.mp4: Results saved to ResultsLIME/wcPTE-oE4U0
Top predictions: [282, 282, 282, 282, 282]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  16%|█▌        | 4/25 [03:37<18:46, 53.63s/it]

Processed zqwKGGvPWZc.mp4: Results saved to ResultsLIME/zqwKGGvPWZc
Top predictions: [150, 150, 150, 150, 150]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  20%|██        | 5/25 [04:40<19:02, 57.12s/it]

Processed UytdOEGAZbs.mp4: Results saved to ResultsLIME/UytdOEGAZbs
Top predictions: [86, 86, 86, 86, 86]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  24%|██▍       | 6/25 [05:35<17:49, 56.27s/it]

Processed gJFNNbxsCN8.mp4: Results saved to ResultsLIME/gJFNNbxsCN8
Top predictions: [221, 221, 221, 221, 221]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  28%|██▊       | 7/25 [06:33<17:05, 56.95s/it]

Processed WZuGgvcl_PY.mp4: Results saved to ResultsLIME/WZuGgvcl_PY
Top predictions: [81, 81, 81, 81, 81]


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Processing videos:  28%|██▊       | 7/25 [07:25<19:06, 63.71s/it]


KeyboardInterrupt: 

完善


In [1]:
import os
import av
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
from PIL import Image
import numpy as np
from tqdm import tqdm
from lime import lime_image
import matplotlib.pyplot as plt
import random
import json
import logging
from torch.cuda.amp import autocast
from collections import defaultdict

# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configuration
config = {
    "model_config": "google/vit-base-patch16-224",
    "model_path": "finetuned_vit_model_20.pth",
    "feature_extractor_name": "google/vit-base-patch16-224",
    "video_directory": "archive/videos_val",
    "results_folder": "ResultsLIME",
    "num_classes": 400,
    "num_samples_per_class": 25,
    "num_frames_per_video": 3,
    "lime_num_samples": 1000,
    "video_list_path": "archive/kinetics400_val_list_videos.txt"
}

# 确保结果目录存在
os.makedirs(config["results_folder"], exist_ok=True)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# 加载模型和特征提取器
try:
    model_config = ViTConfig.from_pretrained(config["model_config"], num_labels=config["num_classes"])
    model = ViTForImageClassification(model_config)
    model.load_state_dict(torch.load(config["model_path"], map_location=device))
    model.to(device)
    model.eval()
    feature_extractor = ViTFeatureExtractor.from_pretrained(config["feature_extractor_name"])
    logging.info("Model and feature extractor loaded successfully")
except Exception as e:
    logging.error(f"Error loading model or feature extractor: {e}")
    raise

def load_video_list(video_list_path):
    video_labels = defaultdict(list)
    with open(video_list_path, "r") as f:
        for line in f:
            name, label = line.strip().split()
            video_labels[int(label)].append(name)
    return video_labels

def select_balanced_samples(video_labels, num_samples_per_class):
    selected_videos = []
    selected_labels = []
    for label, videos in video_labels.items():
        selected = random.sample(videos, min(num_samples_per_class, len(videos)))
        selected_videos.extend(selected)
        selected_labels.extend([label] * len(selected))
    return selected_videos, selected_labels

def extract_frames(video_path, num_frames):
    frames = []
    try:
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            duration = stream.duration * stream.time_base
            for i in range(num_frames):
                target_ts = duration * (i + 1) / (num_frames + 1)
                container.seek(int(target_ts / stream.time_base))
                for frame in container.decode(video=0):
                    frames.append(frame.to_image())
                    logging.info(f"Frame {i+1} extracted at timestamp: {frame.pts * frame.time_base:.2f}s")
                    break
    except Exception as e:
        logging.error(f"Error extracting frames from {video_path}: {e}")
    return frames

def process_and_explain(video_path, model, feature_extractor, num_frames):
    frames = extract_frames(video_path, num_frames)
    if not frames:
        logging.error(f"No frames extracted from {video_path}")
        return None, None

    results = []
    for i, frame in enumerate(frames):
        try:
            inputs = feature_extractor(images=frame, return_tensors="pt").to(device)
            with autocast():
                with torch.no_grad():
                    outputs = model(**inputs)
            preds = torch.nn.functional.softmax(outputs.logits, dim=-1)
            top_pred = preds.argmax().item()

            def batch_predict(images):
                batch_inputs = feature_extractor(images=[Image.fromarray(img.astype('uint8')) for img in images], return_tensors="pt").to(device)
                with autocast():
                    with torch.no_grad():
                        batch_outputs = model(**batch_inputs)
                return torch.nn.functional.softmax(batch_outputs.logits, dim=-1).cpu().numpy()

            explainer = lime_image.LimeImageExplainer()
            explanation = explainer.explain_instance(np.array(frame), 
                                                     batch_predict, 
                                                     top_labels=5, 
                                                     hide_color=0, 
                                                     num_samples=config["lime_num_samples"])
            
            saliency_map = explanation.get_image_and_mask(top_pred, positive_only=True, num_features=10, hide_rest=False)[1]
            saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())

            results.append({
                "frame_index": i,
                "top_prediction": top_pred,
                "prediction_score": preds[0, top_pred].item(),
                "saliency_map": saliency_map.tolist()
            })

            logging.info(f"Processed frame {i+1} for {video_path}")
        except Exception as e:
            logging.error(f"Error processing frame {i+1} of {video_path}: {e}")

    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_result_dir = os.path.join(config["results_folder"], video_name)
    os.makedirs(video_result_dir, exist_ok=True)

    save_results(video_result_dir, frames, results)

    return results, video_result_dir

def save_results(video_result_dir, frames, results):
    for i, (frame, result) in enumerate(zip(frames, results)):
        frame_array = np.array(frame)
        saliency_map = np.array(result['saliency_map'])

        # Save original frame
        plt.figure(figsize=(10, 10))
        plt.imshow(frame_array)
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"original_frame_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

        # Save saliency map overlay
        plt.figure(figsize=(10, 10))
        plt.imshow(frame_array)
        plt.imshow(saliency_map, cmap='jet', alpha=0.5)
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"saliency_overlay_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

        # Save saliency map
        plt.figure(figsize=(10, 10))
        plt.imshow(saliency_map, cmap='jet')
        plt.axis('off')
        plt.savefig(os.path.join(video_result_dir, f"saliency_map_{i}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

    # Save all results as JSON
    with open(os.path.join(video_result_dir, "saliency_data.json"), 'w') as f:
        json.dump(results, f)

def main():
    # 加载视频列表
    video_labels = load_video_list(config["video_list_path"])
    
    # 选择平衡的样本
    selected_video_files, selected_labels = select_balanced_samples(video_labels, config["num_samples_per_class"])
    
    for video_file, label in tqdm(zip(selected_video_files, selected_labels), desc="Processing videos", total=len(selected_video_files)):
        video_path = os.path.join(config["video_directory"], video_file)
        logging.info(f"Processing video: {video_path} (Label: {label})")
        
        results, result_dir = process_and_explain(video_path, model, feature_extractor, config["num_frames_per_video"])
        
        if results:
            logging.info(f"Processed {video_file}: Results saved to {result_dir}")
            logging.info(f"Top predictions: {[result['top_prediction'] for result in results]}")
        else:
            logging.warning(f"Failed to process {video_file}")

if __name__ == "__main__":
    main()

2024-10-02 16:16:03,060 - INFO - Using device: cuda
2024-10-02 16:16:04,914 - INFO - Model and feature extractor loaded successfully
Processing videos:   0%|          | 0/10000 [00:00<?, ?it/s]2024-10-02 16:16:04,936 - INFO - Processing video: archive/videos_val/qi5Jtwx066I.mp4 (Label: 325)
2024-10-02 16:16:04,953 - INFO - Frame 1 extracted at timestamp: 0.00s
2024-10-02 16:16:04,957 - INFO - Frame 2 extracted at timestamp: 0.00s
2024-10-02 16:16:04,960 - INFO - Frame 3 extracted at timestamp: 0.00s


  0%|          | 0/1000 [00:00<?, ?it/s]

Processing videos:   0%|          | 0/10000 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [2]:
import os
import av
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
from PIL import Image
import numpy as np
from lime import lime_image
import matplotlib.pyplot as plt
import json
import logging
from torch.cuda.amp import autocast

# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configuration
config = {
    "model_config": "google/vit-base-patch16-224",
    "model_path": "finetuned_vit_model_20.pth",
    "feature_extractor_name": "google/vit-base-patch16-224",
    "video_directory": "archive/videos_val",
    "results_folder": "ResultsLIME3",
    "num_classes": 400,
    "num_frames_per_video": 8,
    "lime_num_samples": 1000,
}

# 确保结果目录存在
os.makedirs(config["results_folder"], exist_ok=True)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# 加载模型和特征提取器
try:
    model_config = ViTConfig.from_pretrained(config["model_config"], num_labels=config["num_classes"])
    model = ViTForImageClassification(model_config)
    model.load_state_dict(torch.load(config["model_path"], map_location=device))
    model.to(device)
    model.eval()
    feature_extractor = ViTFeatureExtractor.from_pretrained(config["feature_extractor_name"])
    logging.info("Model and feature extractor loaded successfully")
except Exception as e:
    logging.error(f"Error loading model or feature extractor: {e}")
    raise

def extract_frames(video_path, num_frames):
    frames = []
    try:
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            duration = stream.duration * stream.time_base
            for i in range(num_frames):
                target_ts = duration * (i + 1) / (num_frames + 1)
                container.seek(int(target_ts / stream.time_base))
                for frame in container.decode(video=0):
                    frames.append(frame.to_image())
                    logging.info(f"Frame {i+1} extracted at timestamp: {frame.pts * frame.time_base:.2f}s")
                    break
    except Exception as e:
        logging.error(f"Error extracting frames from {video_path}: {e}")
    return frames

def process_and_explain(video_path, model, feature_extractor, num_frames):
    frames = extract_frames(video_path, num_frames)
    if not frames:
        logging.error(f"No frames extracted from {video_path}")
        return None, None

    results = []
    for i, frame in enumerate(frames):
        try:
            inputs = feature_extractor(images=frame, return_tensors="pt").to(device)
            with autocast():
                with torch.no_grad():
                    outputs = model(**inputs)
            preds = torch.nn.functional.softmax(outputs.logits, dim=-1)
            top_pred = preds.argmax().item()

            def batch_predict(images):
                batch_inputs = feature_extractor(images=[Image.fromarray(img.astype('uint8')) for img in images], return_tensors="pt").to(device)
                with autocast():
                    with torch.no_grad():
                        batch_outputs = model(**batch_inputs)
                return torch.nn.functional.softmax(batch_outputs.logits, dim=-1).cpu().numpy()

            explainer = lime_image.LimeImageExplainer()
            explanation = explainer.explain_instance(np.array(frame), 
                                                     batch_predict, 
                                                     top_labels=5, 
                                                     hide_color=0, 
                                                     num_samples=config["lime_num_samples"])
            
            saliency_map = explanation.get_image_and_mask(top_pred, positive_only=True, num_features=10, hide_rest=False)[1]
            saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())

            results.append({
                "frame_index": i,
                "top_prediction": top_pred,
                "prediction_score": preds[0, top_pred].item(),
                "saliency_map": saliency_map.tolist()
            })

            logging.info(f"Processed frame {i+1} for {video_path}")
        except Exception as e:
            logging.error(f"Error processing frame {i+1} of {video_path}: {e}")

    return results, frames

def save_results(video_path, frames, results):
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_result_dir = os.path.join(config["results_folder"], video_name)
    os.makedirs(video_result_dir, exist_ok=True)

    plt.figure(figsize=(20, 5 * len(frames)))

    for i, (frame, result) in enumerate(zip(frames, results)):
        frame_array = np.array(frame)
        saliency_map = np.array(result['saliency_map'])

        # Save and plot original frame
        plt.subplot(len(frames), 3, i*3 + 1)
        plt.imshow(frame_array)
        plt.title(f"Frame {i+1}")
        plt.axis('off')
        plt.imsave(os.path.join(video_result_dir, f"original_frame_{i}.png"), frame_array)

        # Save and plot saliency map overlay
        plt.subplot(len(frames), 3, i*3 + 2)
        plt.imshow(frame_array)
        plt.imshow(saliency_map, cmap='jet', alpha=0.5)
        plt.title(f"Overlay {i+1}")
        plt.axis('off')
        overlay = plt.gcf()
        overlay.savefig(os.path.join(video_result_dir, f"saliency_overlay_{i}.png"))

        # Save and plot saliency map
        plt.subplot(len(frames), 3, i*3 + 3)
        plt.imshow(saliency_map, cmap='jet')
        plt.title(f"Saliency {i+1}")
        plt.axis('off')
        plt.imsave(os.path.join(video_result_dir, f"saliency_map_{i}.png"), saliency_map, cmap='jet')

    # Save overall visualization
    plt.tight_layout()
    plt.savefig(os.path.join(video_result_dir, "all_frames_visualization.png"))
    plt.close()

    # Save all results as JSON
    with open(os.path.join(video_result_dir, "saliency_data.json"), 'w') as f:
        json.dump(results, f, indent=4)

    logging.info(f"Results saved to {video_result_dir}")

def process_single_video(video_filename):
    video_path = os.path.join(config["video_directory"], video_filename)
    logging.info(f"Processing video: {video_path}")
    
    results, frames = process_and_explain(video_path, model, feature_extractor, config["num_frames_per_video"])
    
    if results and frames:
        save_results(video_path, frames, results)
        logging.info(f"Processed {video_filename}")
        logging.info(f"Top predictions: {[result['top_prediction'] for result in results]}")
    else:
        logging.warning(f"Failed to process {video_filename}")

# 主函数
if __name__ == "__main__":
    video_filename = "zZr_mmwaeKc.mp4"  # 指定要处理的视频文件名
    process_single_video(video_filename)

2024-09-10 22:38:24,365 - INFO - Using device: cuda
2024-09-10 22:38:25,776 - INFO - Model and feature extractor loaded successfully
2024-09-10 22:38:25,777 - INFO - Processing video: archive/videos_val/zZr_mmwaeKc.mp4
2024-09-10 22:38:25,782 - INFO - Frame 1 extracted at timestamp: 0.00s
2024-09-10 22:38:25,784 - INFO - Frame 2 extracted at timestamp: 0.00s
2024-09-10 22:38:25,787 - INFO - Frame 3 extracted at timestamp: 0.00s
2024-09-10 22:38:25,789 - INFO - Frame 4 extracted at timestamp: 0.00s
2024-09-10 22:38:25,792 - INFO - Frame 5 extracted at timestamp: 0.00s
2024-09-10 22:38:25,794 - INFO - Frame 6 extracted at timestamp: 0.00s
2024-09-10 22:38:25,797 - INFO - Frame 7 extracted at timestamp: 0.00s
2024-09-10 22:38:25,799 - INFO - Frame 8 extracted at timestamp: 0.00s


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:38:36,660 - INFO - Processed frame 1 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:38:47,637 - INFO - Processed frame 2 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:38:58,393 - INFO - Processed frame 3 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:39:08,680 - INFO - Processed frame 4 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:39:19,622 - INFO - Processed frame 5 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:39:30,545 - INFO - Processed frame 6 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:39:41,036 - INFO - Processed frame 7 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:39:51,189 - INFO - Processed frame 8 for archive/videos_val/zZr_mmwaeKc.mp4
2024-09-10 22:39:59,500 - INFO - Results saved to ResultsLIME3/zZr_mmwaeKc
2024-09-10 22:39:59,500 - INFO - Processed zZr_mmwaeKc.mp4
2024-09-10 22:39:59,500 - INFO - Top predictions: [188, 188, 188, 188, 188, 188, 188, 188]
