In [30]:
import os
import av
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
from PIL import Image
import numpy as np
from tqdm import tqdm
from lime import lime_image
import matplotlib.pyplot as plt
import json
import logging
from torch.cuda.amp import autocast

# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configuration
config = {
    "model_config": "google/vit-base-patch16-224",
    "model_path": "finetuned_vit_model_20.pth",
    "feature_extractor_name": "google/vit-base-patch16-224",
    "results_folder": "ResultsLIMEz",
    "num_frames": 8,  # 修改这里，使用 num_frames 替代 num_segments
    "lime_num_samples": 1000,
    "video_directory": "archive/videos_val",
    "num_labels": 400  
}

# 确保结果目录存在
os.makedirs(config["results_folder"], exist_ok=True)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# 加载模型和特征提取器
try:
    model_config = ViTConfig.from_pretrained(config["model_config"], num_labels=config["num_labels"])
    model = ViTForImageClassification(model_config)
    
    # 加载预训练权重
    state_dict = torch.load(config["model_path"], map_location=device)
    
    # 检查并调整权重字典中的键
    if "classifier.weight" in state_dict and state_dict["classifier.weight"].shape[0] != config["num_labels"]:
        logging.warning(f"Adjusting classifier weights from {state_dict['classifier.weight'].shape} to {config['num_labels']} classes")
        state_dict["classifier.weight"] = state_dict["classifier.weight"][:config["num_labels"], :]
        state_dict["classifier.bias"] = state_dict["classifier.bias"][:config["num_labels"]]
    
    # 加载调整后的权重
    model.load_state_dict(state_dict)
    
    model.to(device)
    model.eval()
    feature_extractor = ViTFeatureExtractor.from_pretrained(config["feature_extractor_name"])
    logging.info("Model and feature extractor loaded successfully")
except Exception as e:
    logging.error(f"Error loading model or feature extractor: {str(e)}")
    raise

def extract_frames(video_path, num_frames):
    frames = []
    try:
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            total_frames = stream.frames
            interval = total_frames // num_frames

            for i in range(num_frames):
                target_frame = i * interval
                container.seek(target_frame, stream=stream)
                for frame in container.decode(video=0):
                    pil_image = frame.to_image()
                    frames.append(pil_image)
                    logging.info(f"Extracted frame {i+1} at position: {target_frame}/{total_frames}")
                    break

        logging.info(f"Extracted {len(frames)} frames from {video_path}")
    except Exception as e:
        logging.error(f"Error extracting frames from {video_path}: {e}")
    return frames

def process_and_explain(video_path, model, feature_extractor):
    frames = extract_frames(video_path, config["num_frames"])
    if not frames:
        logging.error(f"No frames extracted from {video_path}")
        return None, None

    results = []
    for i, frame in enumerate(frames):
        try:
            inputs = feature_extractor(images=frame, return_tensors="pt").to(device)
            with autocast():
                with torch.no_grad():
                    outputs = model(**inputs)
            preds = torch.nn.functional.softmax(outputs.logits, dim=-1)
            top_pred = preds.argmax().item()

            def batch_predict(images):
                batch_inputs = feature_extractor(images=[Image.fromarray(img.astype('uint8')) for img in images], return_tensors="pt").to(device)
                with autocast():
                    with torch.no_grad():
                        batch_outputs = model(**batch_inputs)
                return torch.nn.functional.softmax(batch_outputs.logits, dim=-1).cpu().numpy()

            explainer = lime_image.LimeImageExplainer()
            explanation = explainer.explain_instance(np.array(frame), 
                                                     batch_predict, 
                                                     top_labels=5, 
                                                     hide_color=0, 
                                                     num_samples=config["lime_num_samples"])
            
            saliency_map = explanation.get_image_and_mask(top_pred, positive_only=True, num_features=10, hide_rest=False)[1]
            saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())

            results.append({
                "frame_index": i,
                "top_prediction": top_pred,
                "prediction_score": preds[0, top_pred].item(),
                "saliency_map": saliency_map.tolist()
            })

            logging.info(f"Processed frame {i+1} for {video_path}")
        except Exception as e:
            logging.error(f"Error processing frame {i+1} of {video_path}: {e}")

    return results, frames

def save_results(video_path, frames, results):
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    video_result_dir = os.path.join(config["results_folder"], video_name)
    os.makedirs(video_result_dir, exist_ok=True)

    plt.figure(figsize=(20, 5 * len(frames)))

    for i, (frame, result) in enumerate(zip(frames, results)):
        frame_array = np.array(frame)
        saliency_map = np.array(result['saliency_map'])

        # Save and plot original frame
        plt.subplot(len(frames), 3, i*3 + 1)
        plt.imshow(frame_array)
        plt.title(f"Frame {i+1}")
        plt.axis('off')
        plt.imsave(os.path.join(video_result_dir, f"original_frame_{i}.png"), frame_array)

        # Save and plot saliency map overlay
        plt.subplot(len(frames), 3, i*3 + 2)
        plt.imshow(frame_array)
        plt.imshow(saliency_map, cmap='jet', alpha=0.5)
        plt.title(f"Overlay {i+1}")
        plt.axis('off')
        overlay = plt.gcf()
        overlay.savefig(os.path.join(video_result_dir, f"saliency_overlay_{i}.png"))

        # Save and plot saliency map
        plt.subplot(len(frames), 3, i*3 + 3)
        plt.imshow(saliency_map, cmap='jet')
        plt.title(f"Saliency {i+1}")
        plt.axis('off')
        plt.imsave(os.path.join(video_result_dir, f"saliency_map_{i}.png"), saliency_map, cmap='jet')

    # Save overall visualization
    plt.tight_layout()
    plt.savefig(os.path.join(video_result_dir, "all_frames_visualization.png"))
    plt.close()

    # Save all results as JSON
    with open(os.path.join(video_result_dir, "saliency_data.json"), 'w') as f:
        json.dump(results, f, indent=4)

    logging.info(f"Results saved to {video_result_dir}")

def process_single_video(video_filename):
    video_path = os.path.join(config["video_directory"], video_filename)
    logging.info(f"Processing video: {video_path}")
    
    results, frames = process_and_explain(video_path, model, feature_extractor)
    
    if results and frames:
        save_results(video_path, frames, results)
        logging.info(f"Processed {video_filename}")
        logging.info(f"Top predictions: {[result['top_prediction'] for result in results]}")
    else:
        logging.warning(f"Failed to process {video_filename}")

# 在notebook中使用时，可以直接调用这个函数
process_single_video("zZr_mmwaeKc.mp4")

2024-09-10 22:23:24,783 - INFO - Using device: cuda
2024-09-10 22:23:26,195 - INFO - Model and feature extractor loaded successfully
2024-09-10 22:23:26,197 - INFO - Processing video: archive/videos_val/zZr_mmwaeKc.mp4
2024-09-10 22:23:26,201 - INFO - Extracted frame 1 at position: 0/300
2024-09-10 22:23:26,204 - INFO - Extracted frame 2 at position: 37/300
2024-09-10 22:23:26,206 - INFO - Extracted frame 3 at position: 74/300
2024-09-10 22:23:26,209 - INFO - Extracted frame 4 at position: 111/300
2024-09-10 22:23:26,211 - INFO - Extracted frame 5 at position: 148/300
2024-09-10 22:23:26,214 - INFO - Extracted frame 6 at position: 185/300
2024-09-10 22:23:26,216 - INFO - Extracted frame 7 at position: 222/300
2024-09-10 22:23:26,219 - INFO - Extracted frame 8 at position: 259/300
2024-09-10 22:23:26,219 - INFO - Extracted 8 frames from archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:23:36,504 - INFO - Processed frame 1 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:23:46,376 - INFO - Processed frame 2 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:23:56,741 - INFO - Processed frame 3 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:24:06,933 - INFO - Processed frame 4 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:24:17,613 - INFO - Processed frame 5 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:24:27,961 - INFO - Processed frame 6 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:24:38,470 - INFO - Processed frame 7 for archive/videos_val/zZr_mmwaeKc.mp4


  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-10 22:24:48,989 - INFO - Processed frame 8 for archive/videos_val/zZr_mmwaeKc.mp4
2024-09-10 22:24:57,835 - INFO - Results saved to ResultsLIMEz/zZr_mmwaeKc
2024-09-10 22:24:57,835 - INFO - Processed zZr_mmwaeKc.mp4
2024-09-10 22:24:57,836 - INFO - Top predictions: [188, 188, 188, 188, 188, 188, 188, 188]
