# VideoMAE GradCAM Visualization

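This notebook implements Grad-CAM for VideoMAE models to visualize which spatio-temporal regions a model relies on for classification. Two checkpoints are evaluated on the same test set, and their class-averaged Grad-CAM maps are compared side by side.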

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path
from collections import OrderedDict
from typing import List, Tuple, Optional
import pandas as pd

# VideoMAE imports
import sys
sys.path.append('/home/tianze/Code/VideoMAE')
from timm.models import create_model
import modeling_finetune
from datasets import build_dataset

## Configuration

### Two-Model Comparison Setup

In [4]:
# --- Checkpoints under comparison -------------------------------------------
CHECKPOINT_PATH_1 = '/home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-cov_loss_comb_diag-linear_probe/checkpoint-best.pth'
CHECKPOINT_PATH_2 = '/home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-MCI_CN-finetune-linear_probe/checkpoint-49.pth'
# Alternative second checkpoint, kept here for quick swapping:
# CHECKPOINT_PATH_2 = '/home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-cov_loss_comb_trace-linear_probe-NIFD/checkpoint-best.pth'

# --- Model / clip geometry ---------------------------------------------------
MODEL_NAME = 'vit_base_patch16_224'
NUM_CLASSES = 2  # MCI_CN is a two-class task
NUM_FRAMES = 16
SAMPLING_RATE = 4
INPUT_SIZE = 224
PATCH_SIZE = 16
TUBELET_SIZE = 2

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# --- Dataset -----------------------------------------------------------------
DATA_PATH = '/home/tianze/DATA_2T/MRI/csv_for_finetuning/Finalized_all/Downstreamtask_csv'
TEST_CSV = 'A_MCI_CN_allscantest_split.csv'
TEST_CSV_PATH = os.path.join(DATA_PATH, TEST_CSV)

# --- Results -----------------------------------------------------------------
# Every saved tensor and figure of this notebook lands in this directory.
OUTPUT_DIR = '/home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the effective configuration so it is recorded in the cell output.
print(f"Using device: {DEVICE}")
print(f"Checkpoint 1: {CHECKPOINT_PATH_1}")
print(f"Checkpoint 2: {CHECKPOINT_PATH_2}")
print(f"Test CSV: {TEST_CSV_PATH}")
print(f"Output directory: {OUTPUT_DIR}")

Using device: cuda:0
Checkpoint 1: /home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-cov_loss_comb_diag-linear_probe/checkpoint-best.pth
Checkpoint 2: /home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-MCI_CN-finetune-linear_probe/checkpoint-49.pth
Test CSV: /home/tianze/DATA_2T/MRI/csv_for_finetuning/Finalized_all/Downstreamtask_csv/A_MCI_CN_allscantest_split.csv
Output directory: /home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad


## GradCAM Implementation for Vision Transformer

In [5]:
class ViTGradCAMHook:
    """Small store for the latest activations and gradients of a hooked layer.

    The two ``save_*`` methods have the call signatures of a module forward hook and of a
    tensor gradient hook respectively, so they can be registered directly.
    """

    def __init__(self):
        # Both slots stay None until the matching callback has fired.
        self.activations, self.gradients = None, None

    def save_activation(self, module, input, output):
        """Forward-hook callback: remember the layer output, detached from the autograd graph."""
        detached_output = output.detach()
        self.activations = detached_output

    def save_gradient(self, grad):
        """Tensor-hook callback: remember the gradient that arrived at the hooked tensor."""
        self.gradients = grad


class GradCAM:
    """GradCAM implementation for Vision Transformers, following vision_transformer_gradcam.py pattern.

    A forward hook on ``target_layer`` stores the layer output (detached) and attaches a
    tensor hook that stores the gradient flowing back into that output.  Both public methods
    run one forward and one backward pass and read the stored tensors afterwards.

    Only the first sample of the batch is used (``one_hot[0, ...]`` and ``tensor[0]``), and an
    automatic target class requires a batch of exactly one sample because it is obtained
    with ``.item()``.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        """Register the forward hook on ``target_layer`` of ``model``."""
        self.model = model
        self.target_layer = target_layer
        self.hook = ViTGradCAMHook()
        self.hook_handle = None

        # The closure deliberately captures only the hook store, not `self`.  Capturing `self`
        # would form a reference cycle (GradCAM -> model -> layer -> forward hook -> GradCAM),
        # which keeps the model alive after `del` until the cycle collector happens to run.
        hook_store = self.hook

        def forward_hook(module, input, output):
            # Save activations (detached, similar to vision_transformer_gradcam.py)
            hook_store.save_activation(module, input, output)

            # A tensor hook can only be attached to a tensor that requires grad.  Skipping it
            # otherwise keeps plain inference (e.g. under torch.no_grad()) working while this
            # forward hook is still installed on the model.
            if output.requires_grad:
                output.register_hook(hook_store.save_gradient)

        self.hook_handle = self.target_layer.register_forward_hook(forward_hook)

    def remove_hooks(self) -> None:
        """Detach the forward hook from the target layer. Safe to call more than once."""
        handle = getattr(self, 'hook_handle', None)
        if handle is not None:
            handle.remove()
            self.hook_handle = None

    def __del__(self):
        """Remove hooks when object is deleted"""
        self.remove_hooks()

    def _forward_backward(self, input_tensor: torch.Tensor,
                          target_class: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]:
        """Run one forward/backward pass and return what the hooks captured.

        Returns:
            Tuple of (activations [N, C], gradients [N, C], target_class, output), where the
            first two are detached and taken from sample 0 of the batch.

        Raises:
            ValueError: if the hooks did not capture activations or gradients.
        """
        # Start from a clean slate so tensors of a previous call can never be returned.
        self.hook.activations = None
        self.hook.gradients = None

        # Eval mode, but with gradients enabled.
        self.model.eval()

        # Gradients are enabled on every parameter for the duration of this call only; the
        # previous flags are restored afterwards so a frozen model stays frozen.
        parameters = list(self.model.parameters())
        original_flags = [param.requires_grad for param in parameters]
        for param in parameters:
            param.requires_grad = True

        try:
            # Ensure input requires grad
            if not input_tensor.requires_grad:
                input_tensor = input_tensor.requires_grad_(True)

            # Zero gradients before forward pass
            self.model.zero_grad()
            if input_tensor.grad is not None:
                input_tensor.grad.zero_()

            # Forward pass; enable_grad makes this work even inside a caller's no_grad block.
            with torch.enable_grad():
                output = self.model(input_tensor)

            # Default to the predicted class.
            if target_class is None:
                target_class = output.argmax(dim=1).item()

            # Backward pass using same logic as eval.py: a one-hot vector selects the logit of
            # the target class for sample 0.
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1

            output.backward(gradient=one_hot, retain_graph=False)
        finally:
            for param, flag in zip(parameters, original_flags):
                param.requires_grad = flag

        gradients = self.hook.gradients      # [B, N, C]
        activations = self.hook.activations  # [B, N, C] (already detached)

        if gradients is None or activations is None:
            raise ValueError(f"Gradients or activations not captured. Gradients: {gradients is not None}, Activations: {activations is not None}")

        # Drop the batch dimension (single sample) and cut the gradients loose from the graph.
        return activations[0], gradients[0].detach(), target_class, output

    def generate_cam(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> Tuple[np.ndarray, int, torch.Tensor]:
        """
        Generate GradCAM heatmap

        Args:
            input_tensor: Input video tensor [B, C, T, H, W]
            target_class: Target class index. If None, uses the predicted class.

        Returns:
            Tuple of (cam, target_class, output)
            cam: Heatmap as numpy array [num_patches], min-max scaled to [0, 1]
                 (all zeros when the map is constant)
            target_class: Target class index used
            output: Model output logits
        """
        activations, gradients, target_class, output = self._forward_backward(input_tensor, target_class)

        # Note: no token is dropped here, i.e. the layer output is assumed to contain patch
        # tokens only.  If your model uses a CLS token, drop the first token first:
        # activations = activations[1:, :]
        # gradients = gradients[1:, :]

        # Channel weights = gradient averaged over all tokens.  This mirrors the CNN Grad-CAM
        # pattern (mean over HxW -> one weight per channel).
        weights = gradients.mean(dim=0)  # [C]

        # Weighted sum over channels gives one importance value per token; ReLU keeps only
        # positive contributions.
        cam = F.relu((activations * weights.unsqueeze(0)).sum(dim=1))  # [N_patches]

        # Normalize
        cam_np = cam.cpu().numpy()
        if cam_np.max() > cam_np.min():
            cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)
        else:
            cam_np = np.zeros_like(cam_np)

        return cam_np, target_class, output

    def get_activations_and_gradients(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]:
        """
        Get raw activations and gradients for saving (similar to eval.py)

        Args:
            input_tensor: Input video tensor [B, C, T, H, W]
            target_class: Target class index. If None, uses the predicted class.

        Returns:
            Tuple of (activations, gradients, target_class, output)
            activations: [N, C] tensor (already detached)
            gradients: [N, C] tensor (detached)
            target_class: Target class index used
            output: Model output logits
        """
        return self._forward_backward(input_tensor, target_class)

## Load Model

In [6]:
def load_model(checkpoint_path: str, model_name: str, num_classes: int, 
               num_frames: int, tubelet_size: int) -> nn.Module:
    """Load VideoMAE model from checkpoint

    Args:
        checkpoint_path: Path of the ``.pth`` checkpoint file.
        model_name: Name passed to ``create_model``.
        num_classes: Number of output classes of the classification head.
        num_frames: Passed to ``create_model`` as ``all_frames``.
        tubelet_size: Passed to ``create_model`` as ``tubelet_size``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.  When the checkpoint file
        does not exist, only a warning is printed and the model keeps the weights it was
        created with.
    """
    # Create model
    model = create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
        all_frames=num_frames,
        tubelet_size=tubelet_size,
        fc_drop_rate=0.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
    )

    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Handle different checkpoint formats
        if 'model' in checkpoint:
            raw_state = checkpoint['model']
        elif 'module' in checkpoint:
            raw_state = checkpoint['module']
        else:
            raw_state = checkpoint

        # Remove 'backbone.' or 'encoder.' prefix if present.  This happens BEFORE the head
        # check below so that a prefixed head (e.g. 'backbone.head.weight') is checked too.
        checkpoint_model = OrderedDict()
        for key, value in raw_state.items():
            if key.startswith('backbone.'):
                key = key[len('backbone.'):]
            elif key.startswith('encoder.'):
                key = key[len('encoder.'):]
            checkpoint_model[key] = value

        # Remove 'head' if shape mismatch; load_state_dict raises on size mismatches even
        # with strict=False.
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and k in state_dict and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Load state dict.  strict=False tolerates key mismatches, so report them instead of
        # claiming success unconditionally.
        load_result = model.load_state_dict(checkpoint_model, strict=False)
        if load_result.missing_keys:
            print(f"Warning: {len(load_result.missing_keys)} model keys missing from checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by the model: {load_result.unexpected_keys}")
        print("Checkpoint loaded successfully")
    else:
        print(f"Warning: Checkpoint not found at {checkpoint_path}")

    model.to(DEVICE)
    model.eval()

    return model

# Note: Models will be loaded one by one to save GPU memory
# Model 1 will be loaded first, then Model 2 after processing Model 1

## Load and Preprocess Video Data

In [7]:
# Load MCI_CN test dataset using build_dataset (same as finetune notebook)
class Args:
    """Plain attribute bag handed to ``build_dataset`` in place of parsed CLI arguments."""

    def __init__(self):
        # Values come from the notebook configuration; the rest are fixed settings.
        settings = {
            'data_set': 'MCI_CN',
            'data_path': DATA_PATH,
            'nb_classes': NUM_CLASSES,
            'num_frames': NUM_FRAMES,
            'sampling_rate': SAMPLING_RATE,
            'input_size': INPUT_SIZE,
            'short_side_size': INPUT_SIZE,
            'test_num_segment': 1,
            'test_num_crop': 1,
            'reprob': 0.0,
            'aa': None,
            'smoothing': 0.0,
        }
        for attribute, value in settings.items():
            setattr(self, attribute, value)

# Settings object consumed by build_dataset.
args = Args()

# Build the test split; the second return value is not needed here.
print('Loading test dataset...')
test_dataset, _ = build_dataset(args=args, test_mode=True, is_train=False)
print(f'Test dataset loaded. Total samples: {len(test_dataset)}')

Loading test dataset...
Dataset initialized with 292 samples for mode: test
Number of the class = 2
Test dataset loaded. Total samples: 292


## Save Class-Specific Grad-CAM Tensors and Visualize Comparisons

In [8]:
# Save class-specific gradients/activations for later interpretation.
#
# This cell runs the test samples through each model and saves, per predicted class, the
# averaged gradients/activations (similar to eval.py). The saved files can then be
# interpreted with read_tensor.py, or with read_tensor / read_tensor_comparison defined below.

def save_class_specific_gradcam(test_dataset, gradcam, model_name, output_dir, num_samples=None):
    """
    Process test samples and save class-specific averaged gradients/activations

    Samples are grouped by the class the model PREDICTS (not by the dataset label).  For each
    class one file ``gradcam_{model_name}_class_{class_id}.pt`` is written to ``output_dir``
    holding the averaged and the individual activations/gradients plus the first input that
    was predicted as that class.

    A sample that raises is reported (with traceback) and skipped; the number of skipped
    samples is printed at the end.

    Args:
        test_dataset: Test dataset; ``test_dataset[i][0]`` must be the video tensor
        gradcam: GradCAM instance (only ``get_activations_and_gradients`` is used)
        model_name: Name identifier for the model (e.g., 'model1', 'model2')
        output_dir: Directory to save the results (created if missing)
        num_samples: Number of samples to process (None for all, 0 for none)
    """
    import traceback

    dataset_size = len(test_dataset)
    # Only None means "everything"; an explicit 0 must not fall back to the whole dataset.
    if num_samples is None:
        num_samples = dataset_size
    # Clamp once so the loop bound and the progress messages agree.
    num_samples = min(num_samples, dataset_size)

    os.makedirs(output_dir, exist_ok=True)

    # Per-class accumulators for class-conditional Grad-CAM
    class_gradients = {}    # {class_id: [gradients of each sample]}
    class_activations = {}  # {class_id: [activations of each sample]}
    class_inputs = {}       # {class_id: first input predicted as that class}
    failed_samples = 0

    print(f"Processing {num_samples} samples to collect class-specific gradients/activations...")

    for idx in range(num_samples):
        try:
            # Load sample
            sample = test_dataset[idx]
            video_tensor = sample[0]  # Already in [C, T, H, W] format
            if len(video_tensor.shape) == 4:
                video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension [1, C, T, H, W]
            video_tensor = video_tensor.to(DEVICE)

            # Get activations and gradients for the predicted class
            activations, gradients, pred_class, _ = gradcam.get_activations_and_gradients(
                video_tensor, target_class=None
            )

            # Use predicted class for grouping
            target_class = pred_class.item() if isinstance(pred_class, torch.Tensor) else pred_class

            # Move both tensors to the CPU BEFORE touching the accumulators, so a failure
            # cannot leave the two lists with different lengths.
            activations_cpu = activations.cpu()
            gradients_cpu = gradients.cpu()
            class_activations.setdefault(target_class, []).append(activations_cpu)
            class_gradients.setdefault(target_class, []).append(gradients_cpu)

            # Store a representative input for this class (first occurrence)
            if target_class not in class_inputs:
                class_inputs[target_class] = video_tensor[0].detach().cpu()

            if (idx + 1) % 50 == 0:
                print(f"Processed {idx + 1}/{num_samples} samples")

        except Exception as e:
            failed_samples += 1
            print(f"Error processing sample {idx}: {e}")
            traceback.print_exc()
            continue

    if failed_samples:
        print(f"Warning: {failed_samples} of {num_samples} samples failed and were skipped")
    if not class_gradients:
        print("Warning: no samples were collected; nothing will be saved")

    # Compute class-conditional average Grad-CAM after processing all samples
    print(f"\nComputing class-conditional average Grad-CAM for {len(class_gradients)} classes...")

    for class_id in sorted(class_gradients.keys()):
        sample_count = len(class_gradients[class_id])
        print(f"Processing class {class_id} with {sample_count} samples")

        # Stack all activations and gradients for this class
        stacked_activations = torch.stack(class_activations[class_id])  # [num_samples, N, C]
        stacked_gradients = torch.stack(class_gradients[class_id])      # [num_samples, N, C]

        # Compute average for this class
        avg_activations = torch.mean(stacked_activations, dim=0)  # [N, C]
        avg_gradients = torch.mean(stacked_gradients, dim=0)      # [N, C]

        # Save class-specific Grad-CAM (compatible with read_tensor.py format)
        class_save_path = os.path.join(output_dir, f'gradcam_{model_name}_class_{class_id}.pt')
        torch.save({
            'avg_activations': avg_activations,  # [N, C] - compatible with read_tensor.py
            'avg_gradients': avg_gradients,      # [N, C] - compatible with read_tensor.py
            'input': class_inputs[class_id],     # Representative input [C, T, H, W]
            'class_id': class_id,
            'num_samples': sample_count,
            'individual_activations': stacked_activations,  # [num_samples, N, C]
            'individual_gradients': stacked_gradients       # [num_samples, N, C]
        }, class_save_path)

        print(f"Saved class {class_id} Grad-CAM to {class_save_path}")
        print(f"  - avg_activations shape: {avg_activations.shape}")
        print(f"  - avg_gradients shape: {avg_gradients.shape}")
        print(f"  - num_samples: {sample_count}\n")

import gc  # used below to collect the model <-> hook reference cycle before emptying the CUDA cache

# Process Model 1
print("=" * 60)
print("SAVING CLASS-SPECIFIC GRAD-CAM FOR MODEL 1")
print("=" * 60)

# Load Model 1 if needed (both the model and its GradCAM wrapper must already exist to skip this)
if globals().get('model1') is None or globals().get('gradcam1') is None:
    print("Loading Model 1...")
    model1 = load_model(CHECKPOINT_PATH_1, MODEL_NAME, NUM_CLASSES, NUM_FRAMES, TUBELET_SIZE)
    target_layer1 = model1.blocks[-1]
    gradcam1 = GradCAM(model1, target_layer1)
    print("Model 1 loaded and GradCAM initialized")

save_class_specific_gradcam(
    test_dataset, gradcam1, 'model1', OUTPUT_DIR, num_samples=None
)

# Clear Model 1 from GPU memory.
# `target_layer1` still references the hooked block, and a hooked layer can keep its GradCAM
# wrapper -- and through it the whole model -- reachable.  Drop that reference as well and run
# the cycle collector, otherwise empty_cache() has nothing to release.
del model1, gradcam1
target_layer1 = None
gc.collect()
torch.cuda.empty_cache()
print("Model 1 cleared from GPU memory")

# Process Model 2
print("\n" + "=" * 60)
print("SAVING CLASS-SPECIFIC GRAD-CAM FOR MODEL 2")
print("=" * 60)

# Load Model 2
print("Loading Model 2...")
model2 = load_model(CHECKPOINT_PATH_2, MODEL_NAME, NUM_CLASSES, NUM_FRAMES, TUBELET_SIZE)
target_layer2 = model2.blocks[-1]
gradcam2 = GradCAM(model2, target_layer2)
print("Model 2 loaded and GradCAM initialized")

save_class_specific_gradcam(
    test_dataset, gradcam2, 'model2', OUTPUT_DIR, num_samples=None
)

# Clear Model 2 from GPU memory (same reasoning as for Model 1)
del model2, gradcam2
target_layer2 = None
gc.collect()
torch.cuda.empty_cache()
print("Model 2 cleared from GPU memory")

print("\n" + "=" * 60)
print("All class-specific Grad-CAM data saved!")
print("Saved files can be interpreted using read_tensor.py")
print(f"Output directory: {OUTPUT_DIR}")
print("=" * 60)

SAVING CLASS-SPECIFIC GRAD-CAM FOR MODEL 1
Loading Model 1...
Loading checkpoint from /home/tianze/Code/VideoMAE/checkpoints/final_experiments/K400-MRI-cov_loss_comb_diag-linear_probe/checkpoint-best.pth
Checkpoint loaded successfully
Model 1 loaded and GradCAM initialized
Processing 292 samples to collect class-specific gradients/activations...
Processed 50/292 samples
Processed 100/292 samples
Processed 150/292 samples
Processed 200/292 samples
Processed 250/292 samples

Computing class-conditional average Grad-CAM for 2 classes...
Processing class 0 with 213 samples
Saved class 0 Grad-CAM to /home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad/gradcam_model1_class_0.pt
  - avg_activations shape: torch.Size([1568, 768])
  - avg_gradients shape: torch.Size([1568, 768])
  - num_samples: 213

Processing class 1 with 79 samples
Saved class 1 Grad-CAM to /home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad/gradcam_model1_class_1.pt
  - avg_activations shape: torch.Si

In [9]:
def _gradcam_volume(record, num_frames, height, width, channels):
    """Compute a [num_frames, height, width] numpy CAM from saved [N, C] activations/gradients."""
    grid = (num_frames, height, width, channels)
    # reshape (rather than view) also accepts non-contiguous tensors
    activations = record['avg_activations'].reshape(grid)
    gradients = record['avg_gradients'].reshape(grid)

    # Channel weights: gradients averaged over all token positions (frames, height, width)
    weights = gradients.mean(dim=(0, 1, 2))  # [channels]

    # Weighted sum of activations over channels.  No ReLU is applied; negative values are
    # kept and each slice is min-max scaled later.
    cam = (weights * activations).sum(dim=3)  # [num_frames, height, width]
    return cam.detach().cpu().numpy()


def _gradcam_input_clip(record, num_frames):
    """Return the stored input as numpy [C, T, H, W] and the number of input frames per CAM slice."""
    input_tensor = record['input']
    print('input shape:', input_tensor.shape)
    input_np = input_tensor.cpu().numpy() if hasattr(input_tensor, 'cpu') else np.array(input_tensor)

    frames_per_patch = input_np.shape[1] // num_frames  # e.g. 16 // 8 = 2
    if frames_per_patch < 1:
        # An empty frame slice would otherwise silently produce NaN images.
        raise ValueError(
            f"input has {input_np.shape[1]} frames, fewer than num_frames={num_frames}"
        )
    return input_np, frames_per_patch


def _gradcam_input_image(input_np, start, end):
    """Average input frames [start, end), scale to [0, 1] and rotate by 180 degrees -> [H, W, C]."""
    mean_frame = input_np[:, start:end, :, :].mean(axis=1)  # [C, H, W]
    image = np.transpose(mean_frame, (1, 2, 0))              # [H, W, C]
    image = (image - image.min()) / (image.max() - image.min() + 1e-8)
    return np.rot90(image, 2)


def _gradcam_overlay(frame_cam, size):
    """Min-max scale one CAM slice, resize it to ``size`` (width, height), rotate by 180 degrees."""
    heatmap = frame_cam.astype(np.float32)  # astype copies, so the CAM volume is left untouched
    heatmap -= heatmap.min()
    if heatmap.max() > 0:
        heatmap /= heatmap.max()
    return np.rot90(cv2.resize(heatmap, size), 2)


def read_tensor(tensor_path, class_id=None, channels=768, num_frames=8, height=14, width=14):
    """
    Visualize one saved class-specific Grad-CAM file, one figure per temporal token slice.

    Each figure shows the (frame-averaged) input next to the same image with the Grad-CAM
    overlay and is saved as ``class_{class_id}_gradcam_patch_{i}.png`` next to ``tensor_path``.

    Args:
        tensor_path: Path to a Grad-CAM .pt file with 'avg_activations', 'avg_gradients'
            (both [num_frames * height * width, channels]) and 'input' ([C, T, H, W])
        class_id: Class identifier for naming (default 'not_specified_class')
        channels: Number of channels (default 768)
        num_frames: Temporal size of the token grid (default 8)
        height: Token grid height (default 14)
        width: Token grid width (default 14)
    """
    if class_id is None:
        class_id = 'not_specified_class'
    # A bare file name has an empty dirname; fall back to the current directory instead of
    # building a path that starts at the filesystem root.
    parentdir = os.path.dirname(tensor_path) or '.'

    # Load the saved Grad-CAM tensors
    data = torch.load(tensor_path, map_location='cpu')
    print('activations shape:', data['avg_activations'].shape)
    print('gradients shape:', data['avg_gradients'].shape)

    cam = _gradcam_volume(data, num_frames, height, width, channels)
    print('min cam:', cam.min(), 'max cam:', cam.max())

    input_np, frames_per_patch = _gradcam_input_clip(data, num_frames)
    overlay_size = (input_np.shape[3], input_np.shape[2])  # cv2.resize expects (width, height)

    for i in range(num_frames):
        # Input frames covered by temporal slice i
        start = i * frames_per_patch
        end = (i + 1) * frames_per_patch
        input_patch_img = _gradcam_input_image(input_np, start, end)
        frame_cam_resized = _gradcam_overlay(cam[i], overlay_size)

        # Plot input and Grad-CAM side by side
        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(input_patch_img)
        plt.title(f'Input Patch Mean {i} ({start}-{end-1})')
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(input_patch_img)
        plt.imshow(frame_cam_resized, cmap='jet', alpha=0.5)
        plt.title(f'Grad-CAM Patch {i}')
        plt.axis('off')
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.savefig(os.path.join(parentdir, f'class_{class_id}_gradcam_patch_{i}.png'))
        plt.close()


def read_tensor_comparison(tensor_path_model1, tensor_path_model2, class_id=None, channels=768,
                           num_frames=8, height=14, width=14):
    """
    Compare Grad-CAM from two models side by side

    Figures (Input | Input + Model 1 | Input + Model 2) are saved as
    ``class_{class_id}_comparison_patch_{i}.png`` next to ``tensor_path_model1``.

    Both overlays are drawn on the representative input stored in the MODEL 1 file.  The
    input stored in the model 2 file is not used and is not checked to be the same.

    Args:
        tensor_path_model1: Path to model1 Grad-CAM .pt file
        tensor_path_model2: Path to model2 Grad-CAM .pt file
        class_id: Class identifier for naming (default 'comparison')
        channels: Number of channels (default 768)
        num_frames: Temporal size of the token grid (default 8)
        height: Token grid height (default 14)
        width: Token grid width (default 14)
    """
    if class_id is None:
        class_id = 'comparison'
    # See read_tensor: never let an empty dirname turn into a root-level path.
    parentdir = os.path.dirname(tensor_path_model1) or '.'

    # Load both models' data
    data1 = torch.load(tensor_path_model1, map_location='cpu')
    data2 = torch.load(tensor_path_model2, map_location='cpu')

    print('Model 1 - activations shape:', data1['avg_activations'].shape, 'gradients shape:', data1['avg_gradients'].shape)
    print('Model 2 - activations shape:', data2['avg_activations'].shape, 'gradients shape:', data2['avg_gradients'].shape)

    cam1 = _gradcam_volume(data1, num_frames, height, width, channels)
    cam2 = _gradcam_volume(data2, num_frames, height, width, channels)

    print('Model 1 CAM - min:', cam1.min(), 'max:', cam1.max())
    print('Model 2 CAM - min:', cam2.min(), 'max:', cam2.max())

    # Background image: representative input of model 1
    input_np, frames_per_patch = _gradcam_input_clip(data1, num_frames)
    overlay_size = (input_np.shape[3], input_np.shape[2])  # cv2.resize expects (width, height)

    for i in range(num_frames):
        # Input frames covered by temporal slice i
        start = i * frames_per_patch
        end = (i + 1) * frames_per_patch
        input_patch_img = _gradcam_input_image(input_np, start, end)
        frame_cam1_resized = _gradcam_overlay(cam1[i], overlay_size)
        frame_cam2_resized = _gradcam_overlay(cam2[i], overlay_size)

        # Plot: Input | Input + Model1 GradCAM | Input + Model2 GradCAM
        plt.figure(figsize=(15, 5))

        # Column 1: Original input
        plt.subplot(1, 3, 1)
        plt.imshow(input_patch_img)
        plt.title(f'Input Patch {i}\n({start}-{end-1})')
        plt.axis('off')

        # Column 2: Input + Model1 GradCAM
        plt.subplot(1, 3, 2)
        plt.imshow(input_patch_img)
        plt.imshow(frame_cam1_resized, cmap='jet', alpha=0.5)
        plt.title(f'Model 1 GradCAM\nPatch {i}')
        plt.axis('off')

        # Column 3: Input + Model2 GradCAM
        plt.subplot(1, 3, 3)
        plt.imshow(input_patch_img)
        plt.imshow(frame_cam2_resized, cmap='jet', alpha=0.5)
        plt.title(f'Model 2 GradCAM\nPatch {i}')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(parentdir, f'class_{class_id}_comparison_patch_{i}.png'), dpi=150, bbox_inches='tight')
        plt.close()

    print(f"Saved comparison visualizations to {parentdir}")

In [10]:
# Run read_tensor_comparison to visualize both models side by side
# This will generate comparison visualizations showing: Input | Input + Model1 GradCAM | Input + Model2 GradCAM

# One comparison per class.  A class is skipped (with a message) when either model has no
# saved file for it.
for class_index in range(NUM_CLASSES):
    # Only the headers after the first one are preceded by a blank line.
    print(("" if class_index == 0 else "\n") + "=" * 60)
    print(f"Comparing Model 1 vs Model 2 - Class {class_index}")
    print("=" * 60)

    class_tensor_paths = [
        os.path.join(OUTPUT_DIR, f'gradcam_{model_tag}_class_{class_index}.pt')
        for model_tag in ('model1', 'model2')
    ]
    missing_paths = [path for path in class_tensor_paths if not os.path.exists(path)]

    if not missing_paths:
        read_tensor_comparison(
            class_tensor_paths[0],
            class_tensor_paths[1],
            class_id=f'class{class_index}',
            channels=768
        )
    else:
        print("Files not found:")
        for path in missing_paths:
            print(f"  - {path}")

print("\n" + "=" * 60)
print("All comparison visualizations complete!")
print(f"Check {OUTPUT_DIR} for saved images")
print("Files saved as: " + " and ".join(
    f"class_class{class_index}_comparison_patch_X.png" for class_index in range(NUM_CLASSES)
))
print("=" * 60)

Comparing Model 1 vs Model 2 - Class 0
Model 1 - activations shape: torch.Size([1568, 768]) gradients shape: torch.Size([1568, 768])
Model 2 - activations shape: torch.Size([1568, 768]) gradients shape: torch.Size([1568, 768])
Model 1 CAM - min: -0.0020994954 max: 0.0028355212
Model 2 CAM - min: -0.0016992074 max: 0.002305457
input shape: torch.Size([3, 16, 224, 224])
Saved comparison visualizations to /home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad

Comparing Model 1 vs Model 2 - Class 1
Model 1 - activations shape: torch.Size([1568, 768]) gradients shape: torch.Size([1568, 768])
Model 2 - activations shape: torch.Size([1568, 768]) gradients shape: torch.Size([1568, 768])
Model 1 CAM - min: -0.0031615943 max: 0.002051248
Model 2 CAM - min: -0.002597673 max: 0.0016390154
input shape: torch.Size([3, 16, 224, 224])
Saved comparison visualizations to /home/tianze/Code/VideoMAE/gradcam_comparison_results_bestvsbad

All comparison visualizations complete!
Check /home/tianze/