In [1]:
# Switch on IPython's autoreload extension so that imported modules which
# have been modified are picked up again without restarting the kernel.
from IPython import get_ipython

ipython = get_ipython()
assert ipython is not None
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os
import json
import torch
import torch.nn as nn

# Disable gradient computation - this notebook will only perform forward passes.
# This is a plain function call (not a context manager), so the setting stays
# in effect for the cells that follow, not just for this one.
torch.set_grad_enabled(False)

from pathlib import Path
import sys
import os

# Add the base (root) directory to the path so we can import the util modules
def get_base_folder(project_root = "Count_PIPNet"):
	"""
	Locate the project root by walking upwards from the current working directory.

	Args:
		project_root: Folder name that marks the project root
			(adjust to match your project root folder name).

	Returns:
		pathlib.Path of the closest directory (the working directory itself or
		one of its ancestors) whose name equals project_root.

	Raises:
		RuntimeError: if the filesystem root is reached without a match.
	"""
	start_dir = Path(os.getcwd())

	# Check the working directory first, then each ancestor up to the filesystem root
	for candidate in (start_dir, *start_dir.parents):
		if candidate.name == project_root:
			return candidate

	# Ran out of ancestors (failsafe)
	raise RuntimeError(f"Project root {project_root} not found. Check your folder structure.")

# Resolve the project root and make its packages (util, pipnet) importable.
base_path = get_base_folder()
print(f"Base path: {base_path}")
# Only append when missing, so re-running this cell does not keep stacking
# duplicate entries onto sys.path.
if str(base_path) not in sys.path:
	sys.path.append(str(base_path))

Base path: /mnt/ssd-1/mechinterp/taras/Count_PIPNet


In [2]:
from util.vis_pipnet import visualize_topk
from pipnet.count_pipnet import get_count_network
from util.checkpoint_manager import CheckpointManager
from util.data import get_dataloaders
from util.args import get_args
from util.vis_pipnet import visualize_topk

In [3]:
# Device setup
GPU_TO_USE = 3

# Preference order: Apple MPS first, then the selected CUDA GPU, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = f"cuda:{GPU_TO_USE}"
else:
    device = "cpu"

print(f'>>> Using {device} device <<<')

>>> Using cuda:3 device <<<


In [4]:
# Root folder of the stage_3 runs; the bare expression below displays it as a sanity check
multi_experiment_dir = base_path / 'runs' / 'stage_3'
multi_experiment_dir

PosixPath('/mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3')

In [5]:
# Folder that receives the visualization output; created here if it does not exist yet
visualization_dir = base_path.joinpath('visualizations')
visualization_dir.mkdir(parents=True, exist_ok=True)

In [6]:
# Location of the summary file inside the multi-experiment directory
summary_path = os.path.join(multi_experiment_dir, 'summary.json')

# Load the summary file to get all run directories.
# The encoding is given explicitly so that decoding does not depend on the
# platform's default locale encoding.
with open(summary_path, 'r', encoding='utf-8') as f:
    summary = json.load(f)

print(f"Found {len(summary)} trained models")

Found 2 trained models


In [7]:
def _resolve_checkpoint_path(run_dir, checkpoint_name):
	"""Return the first existing checkpoint path for a run, or None if none exists."""
	checkpoint_path = os.path.join(run_dir, 'checkpoints', checkpoint_name)
	if os.path.exists(checkpoint_path):
		return checkpoint_path

	print(f"Checkpoint not found at {checkpoint_path}, trying alternative paths...")
	# Try with full path as fallback
	if os.path.exists(checkpoint_name):
		return checkpoint_name

	# Try other common checkpoint names
	for alt_name in ('net_trained_last', 'net_trained'):
		alt_path = os.path.join(run_dir, 'checkpoints', alt_name)
		if os.path.exists(alt_path):
			print(f"Found alternative checkpoint at {alt_path}")
			return alt_path

	print("No checkpoint found")
	return None


def load_model(run_dir, checkpoint_name='net_trained_best', base_path=base_path, gpu_id=3):
	"""
	Load a model from a checkpoint directory for visualization purposes.

	Args:
		run_dir: Directory containing the run results. It is expected to hold
			'metadata/args.pickle' and a 'checkpoints' sub-directory.
		checkpoint_name: Name of checkpoint to load (default: 'net_trained_best').
			If it is not found under '<run_dir>/checkpoints', it is tried as a
			full path, then 'net_trained_last' and 'net_trained' are tried.
		base_path: Base path forwarded to get_dataloaders (default: the
			module-level base_path at the time this function is defined)
		gpu_id: CUDA device index to use when CUDA is available (default: 3)

	Returns:
		Tuple of (net, projectloader, classes, args, is_count_pipnet).
		If no checkpoint can be found, the checkpoint has no 'model_state_dict'
		entry, or loading raises, a tuple of five None values is returned instead.
	"""
	import pickle
	import traceback

	# Step 1: Load the configuration used for this run
	args_path = os.path.join(run_dir, 'metadata', 'args.pickle')

	# NOTE: unpickling can execute arbitrary code - only load run directories you trust.
	with open(args_path, 'rb') as f:
		args = pickle.load(f)
	print(f"Loaded configuration from {args_path}")

	# Explicitly set GPU ID to ensure device consistency
	if torch.cuda.is_available():
		args.gpu_ids = str(gpu_id)
		device = torch.device(f'cuda:{gpu_id}')
		torch.cuda.set_device(device)
	else:
		device = torch.device('cpu')

	print(f"Using device: {device}")

	# Resolve the checkpoint before the expensive steps below (dataloaders,
	# model construction, forward pass), so a missing checkpoint fails fast.
	checkpoint_path = _resolve_checkpoint_path(run_dir, checkpoint_name)
	if checkpoint_path is None:
		return None, None, None, None, None

	# Step 2: Create dataloaders (needed for projectloader)
	args.log_dir = run_dir  # Use the run directory as log_dir
	# Only the train loader, the project loader and the class list are used here.
	trainloader, _, _, _, projectloader, _, _, classes = get_dataloaders(args, device, base_path)

	# Step 3: Create a model with the same architecture
	if hasattr(args, 'model') and args.model == 'count_pipnet':
		is_count_pipnet = True
		net, _ = get_count_network(
			num_classes=len(classes),
			args=args,
			max_count=getattr(args, 'max_count', 3),
			use_ste=getattr(args, 'use_ste', False))
	else:
		from pipnet.pipnet import get_pipnet
		is_count_pipnet = False
		net, _ = get_pipnet(len(classes), args)

	# Step 4: Move model to device (don't use DataParallel yet)
	net = net.to(device)

	# Step 5: Forward one batch through the backbone to get the latent output size
	# This needs to happen BEFORE loading the checkpoint
	with torch.no_grad():
		# Use a small batch to determine output shape
		xs1, _, _ = next(iter(trainloader))
		xs1 = xs1.to(device)

		# Single-forward pass without DataParallel
		features = net._net(xs1)
		proto_features = net._add_on(features)

		wshape = proto_features.shape[-1]
		args.wshape = wshape  # needed for calculating image patch size
		print(f"Output shape: {proto_features.shape}, setting wshape={wshape}")

	# Step 6: Now wrap with DataParallel. The state dict is loaded into the
	# wrapped module below, so the wrapping has to come first.
	device_ids = [gpu_id]
	print(f"Using device_ids: {device_ids}")
	net = nn.DataParallel(net, device_ids=device_ids)

	# Step 7: Direct checkpoint loading
	try:
		# Load just the model state dict, ignore optimizer states
		# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
		# if this checkpoint stores non-tensor objects that may need revisiting - TODO confirm.
		checkpoint = torch.load(checkpoint_path, map_location=device)

		if 'model_state_dict' not in checkpoint:
			print("Checkpoint doesn't contain model_state_dict")
			return None, None, None, None, None

		net.load_state_dict(checkpoint['model_state_dict'])
		print(f"Successfully loaded model state from {checkpoint_path}")

		# Display additional information if available
		if 'epoch' in checkpoint:
			print(f"Checkpoint from epoch {checkpoint['epoch']}")
		if 'accuracy' in checkpoint:
			print(f"Model accuracy: {checkpoint['accuracy']:.4f}")

		return net, projectloader, classes, args, is_count_pipnet

	except Exception as e:
		# Best-effort by design: report the problem and hand back the
		# all-None tuple so the caller can decide what to do.
		print(f"Error loading checkpoint: {e}")
		traceback.print_exc()
		return None, None, None, None, None

In [8]:
# Display the entries parsed from summary.json for inspection
summary

[{'run_index': 1,
  'config_path': 'configs/full_linear.yaml',
  'status': 'completed',
  'duration': 1631.151186466217,
  'log_dir': './runs/stage_3/20250401_071214_1_full_linear'},
 {'run_index': 2,
  'config_path': 'configs/identity.yaml',
  'status': 'completed',
  'duration': 1676.0133264064789,
  'log_dir': './runs/stage_3/20250401_073925_2_identity'}]

In [9]:
# Name of the run folder inside multi_experiment_dir whose checkpoint gets loaded.
# NOTE(review): this folder name does not appear among the log_dir entries in the
# summary output shown earlier in the notebook - confirm it is the intended run.
checkpoint_to_load = '20250401_064200_0_linear'
path_to_load = multi_experiment_dir.joinpath(checkpoint_to_load)

print(f'Loading a checkpoint from {path_to_load}...')

Loading a checkpoint from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear...


In [10]:
net, projectloader, classes, args, is_count_pipnet = load_model(path_to_load, gpu_id=GPU_TO_USE)

# load_model reports failure by returning None values instead of raising;
# stop here so later cells do not fail with an unrelated-looking error.
if net is None:
    raise RuntimeError(f"Could not load a model from {path_to_load}; see the messages printed above")

Loaded configuration from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear/metadata/args.pickle
Using device: cuda:3
Num classes (k) =  9 ['class_1', 'class_2', 'class_3', 'class_4', 'class_5'] etc.


Detected 192 output channels from last conv layer
Number of prototypes set from 192 to 16. Extra 1x1 conv layer added.
Output shape: torch.Size([64, 16, 24, 24]), setting wshape=24
Using device_ids: [3]
Successfully loaded model state from /mnt/ssd-1/mechinterp/taras/Count_PIPNet/runs/stage_3/20250401_064200_0_linear/checkpoints/net_trained_best
Checkpoint from epoch 93
Model accuracy: 0.9922


In [17]:
# Output folder for this run, named after the run folder selected above
run_vis_dir = visualization_dir.joinpath(checkpoint_to_load)

print(f'Saving viz to {run_vis_dir}')

run_vis_dir.mkdir(parents=True, exist_ok=True)

Saving viz to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/20250401_064200_0_linear


In [18]:
# Run the top-k visualization for the loaded network; outputs go to run_vis_dir
topks = visualize_topk(
	net,
	projectloader,
	len(classes),
	device,
	run_vis_dir,
	args,
	k=10,
	plot_histograms=True,
	visualize_prototype_maps=True,
)
print(f"Visualization saved to {run_vis_dir}")

Visualizing prototypes for topk (CountPIPNet with class-based count selection)...
Using class-to-count mapping: {(1, 3): 1, (4, 6): 2, (7, 9): 3}
Available counts: [1, 2, 3]


Collecting activations:   0%|          | 0/3600 [00:00<?, ?it/s]

Collecting activations:  28%|██▊       | 1000/3600 [00:03<00:08, 295.18it/s]


Detected model type: CountPIPNet with max_count=3
Visualizing trained prototypes
Examples per count group: 10
Pooled shape: torch.Size([16]), min: 0.0, max: 555.0
Feature maps shape: torch.Size([16, 24, 24])
Non-zero pooled values: 4 out of 16
Pooled values (counts): [8.0, 12.0, 0.0, 0.0, 0.0, 0.0, 555.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Classification weights shape: torch.Size([9, 48]) (expanded)
Max classification weight per prototype: [3.101, 0.0, 0.3846, 8.6222, 0.0, 2.9144, 0.0, 9.6965, 0.1289, 0.0, 7.7373, 4.1335, 2.0769, 2.0048, 2.435, 3.9396]
Prototypes with weight > 1e-3: 12
Prototype indices with weight > 1e-3: [0, 2, 3, 5, 7, 8, 10, 11, 12, 13, 14, 15]


Single-pass processing for class-based count selection: 100% 3600/3600 [00:16<00:00, 223.77it/s]

Abstained: 0
0 prototypes do not have any examples. Will be ignored in visualization.
Creating prototype feature map visualizations with count information...





Visualization saved to /mnt/ssd-1/mechinterp/taras/Count_PIPNet/visualizations/20250401_064200_0_linear


In [None]:
# Count how many prototypes are used by each class
classification_weights = net.module._classification.weight

class_prototypes = {}
for class_idx in range(classification_weights.shape[0]):
	class_row = classification_weights[class_idx, :]

	# Apply the threshold inside torch, then convert the row to plain Python
	# values once instead of indexing the tensor element by element.
	above_threshold = (class_row > 1e-3).tolist()
	row_values = class_row.tolist()
	relevant_ps = [
		(proto_idx, weight_value)
		for proto_idx, (weight_value, keep) in enumerate(zip(row_values, above_threshold))
		if keep
	]

	# Fall back to a generic label if there are more weight rows than class names
	if class_idx < len(classes):
		class_name = classes[class_idx]
	else:
		class_name = f"Class {class_idx}"

	class_prototypes[class_name] = relevant_ps
	print(f"Class {class_name} has {len(relevant_ps)} relevant prototypes")

# Save class-prototype information
class_info_path = os.path.join(run_vis_dir, 'class_prototypes.json')
with open(class_info_path, 'w') as f:
	json.dump(class_prototypes, f, indent=2)