In [1]:
# Enable the autoreload extension (mode "2") so that imports which have been
# modified are reloaded automatically.
from IPython import get_ipython

ipython = get_ipython()
# get_ipython() is asserted non-None: the magics below need a live shell object.
assert ipython is not None

ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os
import json
import torch
import torch.nn as nn

# Disable gradient computation - this notebook will only perform forward passes
torch.set_grad_enabled(False)

from pathlib import Path
import sys
import os

# Add the base (root) directory to the path so we can import the util modules
def get_base_folder(project_root = "Count_PIPNet"):
	"""Locate the project root by walking up from the current working directory.

	Args:
		project_root: Directory name to look for. It is compared against the
			last path component of the working directory and of each of its
			ancestors in turn.

	Returns:
		pathlib.Path of the nearest directory (the working directory itself or
		one of its ancestors) whose name equals ``project_root``.

	Raises:
		RuntimeError: If the filesystem root is reached without a match.
	"""
	candidate = os.getcwd()

	# Climb one level per iteration until the directory name matches.
	while os.path.basename(candidate) != project_root:
		one_level_up = os.path.dirname(candidate)
		# dirname() returns its input unchanged only at the filesystem root,
		# so this is the "nowhere left to climb" failsafe.
		if one_level_up == candidate:
			raise RuntimeError(f"Project root {project_root} not found. Check your folder structure.")
		candidate = one_level_up

	return Path(candidate)

# Resolve the project root once and make it importable, so the project-local
# packages imported further down (util.*, pipnet.*) can be found.
base_path = get_base_folder()
print(f"Base path: {base_path}")
# Guard the append: without it, every re-run of this cell would add another
# duplicate entry to sys.path.
if str(base_path) not in sys.path:
	sys.path.append(str(base_path))

Base path: /mnt/ssd-1/mechinterp/taras/Count_PIPNet


In [2]:
from util.vis_pipnet import visualize_topk
from pipnet.count_pipnet import get_count_network
from util.checkpoint_manager import CheckpointManager
from util.data import get_dataloaders
from util.args import get_args
from util.vis_pipnet import visualize_topk

In [3]:
# Device setup: index of the CUDA device to use when CUDA is available.
GPU_TO_USE = 3

# Preference order: MPS first, then the selected CUDA device, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = f"cuda:{GPU_TO_USE}"
else:
    device = "cpu"

print(f'>>> Using {device} device <<<')

>>> Using cuda:3 device <<<


In [4]:
multi_experiment_dir = base_path / 'runs/stage_3'
multi_experiment_dir

PosixPath('/mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3')

In [5]:
visualization_dir = base_path / 'visualizations'
os.makedirs(visualization_dir, exist_ok=True)

In [6]:
# Load the summary file of the multi-experiment directory to get all run directories.
summary_path = os.path.join(multi_experiment_dir, 'summary.json')

with open(summary_path, 'r') as f:
    summary = json.loads(f.read())

print(f"Found {len(summary)} trained models")

Found 2 trained models


In [25]:
def load_model(run_dir, checkpoint_name='net_trained_best', base_path=base_path, gpu_id=3):
	"""
	Load a model from a checkpoint directory for visualization purposes.

	Rebuilds the dataloaders and the network from the pickled run configuration,
	pushes one training batch through the backbone to record the prototype
	feature-map width on ``args.wshape``, wraps the network in
	``nn.DataParallel`` and then restores the weights from the checkpoint.

	Args:
		run_dir: Directory containing the run results. ``metadata/args.pickle``
			is read from it and checkpoints are looked up in its
			``checkpoints`` sub-directory.
		checkpoint_name: Name of checkpoint to load (default: 'net_trained_best').
			If no such file exists under ``run_dir/checkpoints`` the value is
			also tried as a path on its own, then 'net_trained_last' and
			'net_trained' are tried.
		base_path: Base path forwarded to ``get_dataloaders`` (default: the
			notebook-level ``base_path`` as it was when this function was defined).
		gpu_id: CUDA device index to use (default: 3). When CUDA is not
			available the model is placed on the CPU instead.

	Returns:
		Tuple of (net, projectloader, classes, args, is_count_pipnet) on success.
		A tuple of five ``None`` values when no checkpoint file is found, when
		the checkpoint has no 'model_state_dict' entry, or when loading the
		checkpoint raises.

	Note:
		Errors while reading ``args.pickle`` or while building the dataloaders
		and the network are not caught here and propagate to the caller.
	"""
	import pickle
	import traceback

	# Step 1: Load the configuration used for this run
	args_path = os.path.join(run_dir, 'metadata', 'args.pickle')

	# SECURITY: unpickling executes whatever code the file embeds (the same holds
	# for the torch.load call further down). Only use run directories you trust.
	with open(args_path, 'rb') as f:
		args = pickle.load(f)
	print(f"Loaded configuration from {args_path}")

	# Explicitly set GPU ID to ensure device consistency
	if torch.cuda.is_available():
		args.gpu_ids = str(gpu_id)
		device = torch.device(f'cuda:{gpu_id}')
		torch.cuda.set_device(device)
	else:
		device = torch.device('cpu')

	print(f"Using device: {device}")

	# Step 2: Create dataloaders (needed for projectloader)
	args.log_dir = run_dir  # Use the run directory as log_dir
	# Only the training loader (shape probe below), the projection loader and the
	# class list are used; the remaining five loaders are discarded.
	trainloader, _, _, _, projectloader, _, _, classes = get_dataloaders(args, device, base_path)

	# Step 3: Create a model with the same architecture
	if hasattr(args, 'model') and args.model == 'count_pipnet':
		is_count_pipnet = True
		net, _ = get_count_network(
			num_classes=len(classes),
			args=args,
			max_count=getattr(args, 'max_count', 3),
			use_ste=getattr(args, 'use_ste', False))
	else:
		from pipnet.pipnet import get_pipnet
		is_count_pipnet = False
		net, _ = get_pipnet(len(classes), args)

	# Step 4: Move model to device (don't use DataParallel yet)
	net = net.to(device)

	# Step 5: Forward one batch through the backbone to get the latent output size
	# This needs to happen BEFORE loading the checkpoint
	with torch.no_grad():
		# Use a single batch to determine output shape
		xs1, _, _ = next(iter(trainloader))
		xs1 = xs1.to(device)

		# Single-forward pass without DataParallel
		features = net._net(xs1)
		proto_features = net._add_on(features)

		wshape = proto_features.shape[-1]
		args.wshape = wshape  # needed for calculating image patch size
		print(f"Output shape: {proto_features.shape}, setting wshape={wshape}")

	# Step 6: Now wrap with DataParallel
	device_ids = [gpu_id]
	print(f"Using device_ids: {device_ids}")
	net = nn.DataParallel(net, device_ids=device_ids)

	# Step 7: Direct checkpoint loading
	checkpoint_path = os.path.join(run_dir, 'checkpoints', checkpoint_name)
	if not os.path.exists(checkpoint_path):
		print(f"Checkpoint not found at {checkpoint_path}, trying alternative paths...")
		# Try with full path as fallback
		if os.path.exists(checkpoint_name):
			checkpoint_path = checkpoint_name
		else:
			# Try other common checkpoint names. checkpoint_name itself was
			# already ruled out just above, so it is not listed again.
			alternatives = [
				os.path.join(run_dir, 'checkpoints', 'net_trained_last'),
				os.path.join(run_dir, 'checkpoints', 'net_trained'),
			]
			for alt_path in alternatives:
				if os.path.exists(alt_path):
					checkpoint_path = alt_path
					print(f"Found alternative checkpoint at {checkpoint_path}")
					break
			else:
				# for/else: reached only when no alternative existed
				print("No checkpoint found")
				return None, None, None, None, None

	try:
		# Load just the model state dict, ignore optimizer states
		checkpoint = torch.load(checkpoint_path, map_location=device)

		if 'model_state_dict' not in checkpoint:
			print("Checkpoint doesn't contain model_state_dict")
			return None, None, None, None, None

		net.load_state_dict(checkpoint['model_state_dict'], strict=True)
		print(f"Successfully loaded model state from {checkpoint_path}")

		# Display additional information if available
		if 'epoch' in checkpoint:
			print(f"Checkpoint from epoch {checkpoint['epoch']}")
		if 'accuracy' in checkpoint:
			print(f"Model accuracy: {checkpoint['accuracy']:.4f}")

		return net, projectloader, classes, args, is_count_pipnet

	except Exception as e:
		# Best-effort by design: report the failure and hand back the
		# all-None tuple instead of raising.
		print(f"Error loading checkpoint: {str(e)}")
		traceback.print_exc()
		return None, None, None, None, None

# Pre-trained prototypes visualization

In [31]:
checkpoint_to_load = 'runs/45_shapesGN_linear'
checkpoint_to_load_pretrain_dir = base_path / checkpoint_to_load
checkpoint_name = 'net_pretrained_e9aef2d9e9'

print(f'Loading a checkpoint {checkpoint_name} from {checkpoint_to_load_pretrain_dir}')

Loading a checkpoint net_pretrained_e9aef2d9e9 from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/45_shapesGN_linear


In [32]:
net, projectloader, classes, args, is_count_pipnet = load_model(checkpoint_to_load_pretrain_dir, gpu_id=GPU_TO_USE,
                                                                checkpoint_name=checkpoint_name)

Loaded configuration from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/45_shapesGN_linear/metadata/args.pickle
Using device: cuda:3
Num classes (k) =  9 ['class_1', 'class_2', 'class_3', 'class_4', 'class_5'] etc.


Detected 192 output channels from last conv layer
Number of prototypes set from 192 to 16. Extra 1x1 conv layer added.
Output shape: torch.Size([64, 16, 24, 24]), setting wshape=24
Using device_ids: [3]
Successfully loaded model state from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/45_shapesGN_linear/checkpoints/net_pretrained_e9aef2d9e9


In [33]:
# Output folder for this run: named after the last '/'-separated component of
# checkpoint_to_load, placed under the shared visualization directory.
run_vis_dir = visualization_dir / checkpoint_to_load.rsplit('/', 1)[-1]

print(f'Saving viz to {run_vis_dir}')

run_vis_dir.mkdir(parents=True, exist_ok=True)

Saving viz to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/45_shapesGN_linear


In [34]:
# NOTE(review): `device` is the notebook-level string picked in the device-setup
# cell, not the torch.device load_model() put the network on. Both derive from
# GPU_TO_USE when CUDA is available, but load_model() has no MPS branch - confirm
# they agree before running this on a non-CUDA machine.
topks = visualize_topk(
	net, projectloader, len(classes), device, run_vis_dir, args,
	k=10,
	plot_histograms=True,
	visualize_prototype_maps=True,
	are_pretraining_prototypes=True,
)
print(f"Visualization saved to {run_vis_dir}")

Visualizing prototypes for topk (CountPIPNet with class-based count selection)...
Using class-to-count mapping: {(1, 3): 1, (4, 6): 2, (7, 9): 3}
Available counts: [1, 2, 3]
classifier_input_weights shape: torch.Size([48])
classifier_input_weights:
tensor([0.3333, 0.6667, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], device='cuda:3')
Detected model type: CountPIPNet with max_count=3
Visualizing pretraining prototypes
Examples per count group: 10
Pooled shape: torch.Size([16]), min: 0.0, max: 561.0
Feature maps shape: torch.Size([16, 24, 24])
Non-zero pooled values: 10 out of 16
Pooled values (counts): [3.0, 0.0, 2.0, 0.0, 0.0, 2.

Collecting activations:  28%|██▊       | 1000/3600 [00:03<00:08, 289.88it/s]
Single-pass processing for class-based count selection: 100% 3600/3600 [00:19<00:00, 183.06it/s]

Abstained: 0





0 prototypes do not have any examples. Will be ignored in visualization.
Creating prototype feature map visualizations with count information...


Traceback (most recent call last):
  File "/home/taras/.conda/envs/taras_ami/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/taras/.conda/envs/taras_ami/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/taras/.conda/envs/taras_ami/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


Visualization saved to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/45_shapesGN_linear


# Best trained model visualization

In [35]:
checkpoint_to_load = '20250401_064200_0_linear'
path_to_load = multi_experiment_dir / checkpoint_to_load

print(f'Loading a checkpoint from {path_to_load}...')

Loading a checkpoint from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear...


In [36]:
net, projectloader, classes, args, is_count_pipnet = load_model(path_to_load, gpu_id=GPU_TO_USE)

Loaded configuration from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear/metadata/args.pickle
Using device: cuda:3
Num classes (k) =  9 ['class_1', 'class_2', 'class_3', 'class_4', 'class_5'] etc.


Detected 192 output channels from last conv layer
Number of prototypes set from 192 to 16. Extra 1x1 conv layer added.
Output shape: torch.Size([64, 16, 24, 24]), setting wshape=24
Using device_ids: [3]
Successfully loaded model state from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear/checkpoints/net_trained_best
Checkpoint from epoch 93
Model accuracy: 0.9922


In [37]:
# Output folder for this run, placed under the shared visualization directory.
run_vis_dir = visualization_dir / checkpoint_to_load

print(f'Saving viz to {run_vis_dir}')

run_vis_dir.mkdir(parents=True, exist_ok=True)

Saving viz to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/20250401_064200_0_linear


In [38]:
# NOTE(review): `device` is the notebook-level string picked in the device-setup
# cell, not the torch.device load_model() put the network on. Both derive from
# GPU_TO_USE when CUDA is available, but load_model() has no MPS branch - confirm
# they agree before running this on a non-CUDA machine.
topks = visualize_topk(
	net, projectloader, len(classes), device, run_vis_dir, args,
	k=10,
	plot_histograms=True,
	visualize_prototype_maps=True,
)
print(f"Visualization saved to {run_vis_dir}")

Visualizing prototypes for topk (CountPIPNet with class-based count selection)...
Using class-to-count mapping: {(1, 3): 1, (4, 6): 2, (7, 9): 3}
Available counts: [1, 2, 3]
classifier_input_weights shape: torch.Size([48])
classifier_input_weights:
tensor([0.3333, 0.6667, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], device='cuda:3')
Detected model type: CountPIPNet with max_count=3
Visualizing trained prototypes
Examples per count group: 10


Pooled shape: torch.Size([16]), min: 0.0, max: 554.0
Feature maps shape: torch.Size([16, 24, 24])
Non-zero pooled values: 4 out of 16
Pooled values (counts): [9.0, 12.0, 0.0, 0.0, 0.0, 0.0, 554.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Classification weights shape: torch.Size([9, 48]) (expanded)
Max classification weight per prototype: [2.0725, 0.0, 0.2801, 5.7195, 0.0, 1.96, 0.0, 6.4921, 0.043, 0.0, 5.0559, 2.8195, 1.4803, 1.3423, 1.5896, 2.6441]
Prototypes with weight > 1e-1: 11
Prototype indices with weight > 1e-1: [0, 2, 3, 5, 7, 10, 11, 12, 13, 14, 15]


Collecting activations:  28%|██▊       | 1000/3600 [00:03<00:09, 286.51it/s]
Single-pass processing for class-based count selection: 100% 3600/3600 [00:17<00:00, 211.10it/s]

Abstained: 0
0 prototypes do not have any examples. Will be ignored in visualization.
Creating prototype feature map visualizations with count information...





Visualization saved to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/20250401_064200_0_linear


## Global explanation

In [72]:
def calculate_global_explanation(net, classes):
    """
    Collect, for every class, the importance of each prototype.

    Args:
        net: Wrapped model. It is accessed through ``net.module``, which must
            provide ``_num_prototypes`` and
            ``get_prototype_importance_per_class(prototype_index)``.
        classes: List of class names. Not used by the computation; class
            indices come from the order of the per-class importances.

    Returns:
        Dictionary mapping each class index to a 1-D tensor of length
        ``_num_prototypes`` holding that class's importance for every prototype.
        Empty when there are no prototypes.
    """
    n_protos = net.module._num_prototypes
    importances_by_class = {}

    for proto_idx in range(n_protos):
        # One importance value per class for the current prototype
        per_class_scores = net.module.get_prototype_importance_per_class(proto_idx)

        for cls_idx, score in enumerate(per_class_scores):
            row = importances_by_class.get(cls_idx)
            if row is None:
                # First time this class shows up: start from an all-zero row on
                # the same device as the incoming score.
                row = torch.zeros([n_protos], device=score.device)
                importances_by_class[cls_idx] = row

            row[proto_idx] += score

    return importances_by_class

def show_global_explanation(net, classes, global_explanation=None, top_k_prototypes=None, output_path=None):
    """
    Visualize the global explanation as a heatmap using Plotly.

    Args:
        net: The trained model wrapper; ``net.module._num_prototypes`` gives the
            number of heatmap columns.
        classes: List of class names, used as row labels. Rows beyond
            ``len(classes)`` are labelled "Class <index>".
        global_explanation: Pre-computed global explanation (optional): a dict
            mapping class index to a tensor of per-prototype importances.
            Computed with ``calculate_global_explanation`` when omitted.
        top_k_prototypes: Number of top prototypes to display per class
            (optional). All other cells are blanked out. Values are clamped to
            ``[0, number of prototypes]``; 0 blanks every cell.
        output_path: Path to save the visualization (optional). Saving goes
            through ``fig.write_image``.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    import plotly.graph_objects as go
    import numpy as np

    # Calculate global explanation if not provided
    if global_explanation is None:
        global_explanation = calculate_global_explanation(net, classes)

    # Convert dictionary to numpy array for heatmap
    num_classes = len(global_explanation)
    num_prototypes = net.module._num_prototypes

    # Rows are classes, columns are prototypes
    data_matrix = np.zeros((num_classes, num_prototypes))
    for class_idx, prototype_importances in global_explanation.items():
        # detach() so tensors that still track gradients can be converted too
        data_matrix[class_idx] = prototype_importances.detach().cpu().numpy()

    # Optionally filter to show only the top-k most important prototypes per class
    if top_k_prototypes is not None:
        # Clamp k: a plain [-k:] slice would keep *every* prototype for k == 0
        # (because -0 == 0) and gives meaningless results for negative k.
        k = max(0, min(top_k_prototypes, num_prototypes))

        top_k_mask = np.zeros_like(data_matrix, dtype=bool)
        if k > 0:
            for class_idx in range(num_classes):
                # argsort is ascending, so the last k indices are the top-k
                top_indices = np.argsort(data_matrix[class_idx])[-k:]
                top_k_mask[class_idx, top_indices] = True

        # Apply mask (set non-top-k values to NaN for better visibility)
        filtered_data = np.where(top_k_mask, data_matrix, np.nan)
    else:
        filtered_data = data_matrix

    # Create x and y labels
    y_labels = [str(classes[i]) if i < len(classes) else f"Class {i}"
                for i in range(num_classes)]
    x_labels = [f"Prototype {i}" for i in range(num_prototypes)]

    # Create heatmap
    fig = go.Figure(data=go.Heatmap(
        z=filtered_data,
        x=x_labels,
        y=y_labels,
        colorscale='Plasma',
        hoverongaps=False,
        # The nested title.text/title.side form replaces the deprecated flat
        # `titleside` attribute, which newer Plotly releases reject.
        colorbar=dict(
            title=dict(text="Importance", side="right")
        )
    ))

    # Update layout for better readability
    fig.update_layout(
        title="Global Explanation: Prototype Importance per Class",
        xaxis=dict(
            title="Prototypes",
            tickangle=-45,
        ),
        yaxis=dict(
            title="Classes",
        ),
        width=max(800, num_prototypes * 25),  # Adjust width based on number of prototypes
        height=max(600, num_classes * 30),    # Adjust height based on number of classes
        margin=dict(l=150, r=50, t=100, b=150)
    )

    # Save the figure if output path is provided
    if output_path:
        fig.write_image(output_path)
        print(f"Saved global explanation visualization to {output_path}")

    # Display the figure
    fig.show()

In [73]:
show_global_explanation(net, classes)

In [1]:
(8 + 2) / 9

1.1111111111111112

In [1]:
3.488888888888889 / 4.68

0.7454890788224122

In [2]:
1.2777777777777777 / 1.8555555555555556

0.6886227544910178