In [None]:
# Spatial Contribution Maps and B-cos Explanations

def to_numpy_img(tensor):
    """Convert an image tensor to a channels-last numpy array for matplotlib.

    Accepts a batched ``(1, C, H, W)`` tensor, a channels-first ``(C, H, W)``
    tensor or a channels-last ``(H, W, C)`` tensor. Six-channel images (in
    either layout) are reduced to their first three channels, and values are
    clamped to ``[0, 1]``.

    Args:
        tensor: torch.Tensor holding a single image.

    Returns:
        numpy.ndarray, detached from the autograd graph and moved to the CPU.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    if tensor.dim() == 3:
        leading, trailing = tensor.shape[0], tensor.shape[-1]
        # Up to 3 leading entries is treated as CHW (as before). A leading 6 is
        # also CHW unless the last axis already looks like a channel axis, so a
        # 6-pixel-high HWC image keeps its layout.
        channels_first = leading <= 3 or (leading == 6 and trailing not in (3, 6))
        if channels_first:
            tensor = tensor.permute(1, 2, 0)
        # Handle 6-channel input by taking the first three (RGB) channels
        if tensor.shape[-1] == 6:
            tensor = tensor[..., :3]
    # Ensure valid range for display
    tensor = torch.clamp(tensor, 0, 1)
    return tensor.detach().cpu().numpy()

def extract_prototype_patches(model, dataloader, prototype_idx, num_patches=9, threshold=0.1):
    """Collect images that activate a given prototype above a threshold.

    Batches are scanned in loader order and scanning stops after the first
    batch that brings the number of candidates to ``num_patches`` or more, so
    the result is the best of the candidates seen so far, not necessarily the
    global top activations of the whole dataset.

    Args:
        model: Callable returning a dict with a ``"sim_maps"`` entry whose
            second axis indexes prototypes.
        dataloader: Iterable of ``(images, _)`` batches.
        prototype_idx: Index of the prototype to inspect.
        num_patches: Maximum number of entries to return.
        threshold: Minimum peak similarity for an image to be kept
            (defaults to the previously hard-coded 0.1).

    Returns:
        List of dicts with keys ``'image'``, ``'similarity'`` and
        ``'sim_map'``, sorted by decreasing similarity.
    """
    model.eval()
    patches = []

    with torch.no_grad():
        for imgs, _ in dataloader:
            imgs = imgs.to(device)
            sim_maps = model(imgs)["sim_maps"]

            # Similarities of every image in the batch to the chosen prototype
            proto_sims = sim_maps[:, prototype_idx]
            has_spatial_map = proto_sims.dim() == 3

            for img, sims in zip(imgs, proto_sims):
                # max() covers spatial maps, vectors and single values alike
                max_sim = sims.max().item()
                if max_sim > threshold:
                    patches.append({
                        'image': img,
                        'similarity': max_sim,
                        'sim_map': sims if has_spatial_map else sims.unsqueeze(0),
                    })

            if len(patches) >= num_patches:
                break

    # Strongest activations first
    patches.sort(key=lambda entry: entry['similarity'], reverse=True)
    return patches[:num_patches]

def visualize_prototype_analysis(model, cfg, test_loader, proto_idx):
    """Print a short activation report for one prototype.

    Looks up at most six activating patches via ``extract_prototype_patches``
    and reports how many were found, the strongest similarity, and the class
    the prototype belongs to (``proto_idx // cfg.num_prototypes_per_class``).
    """
    print(f"Analyzing prototype {proto_idx}...")
    top_patches = extract_prototype_patches(model, test_loader, proto_idx, 6)

    # Guard clause: nothing cleared the similarity threshold
    if not top_patches:
        print("No significant activations found")
        return

    print(f"Found {len(top_patches)} activating patches")
    print(f"Top similarity: {top_patches[0]['similarity']:.3f}")
    print(f"Prototype class: {proto_idx // cfg.num_prototypes_per_class}")

class BcosSpatialContributionAnalyzer:
    """
    Analyzer for B-cos spatial contributions integrated with PIP-Net.

    Forward hooks on every sub-module that exposes ``explanation_mode`` record
    the layer output and the gradient flowing back into that output. The
    public methods turn these recordings into per-layer contribution maps and
    input-level explanations. Call ``cleanup_hooks()`` when finished.
    """

    def __init__(self, model, device='cuda'):
        """Put ``model`` in eval mode and attach the recording hooks."""
        self.model = model
        self.device = device
        self.model.eval()

        # Per-layer recordings, keyed by the module's qualified name
        self.activations = {}
        self.gradients = {}
        self.hooks = []

        # Register hooks for B-cos explanations
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every B-cos sub-module.

        The hook stores the detached output and, when the output takes part in
        autograd, a tensor hook that stores the gradient w.r.t. that output.
        A tensor hook replaces the deprecated module-level
        ``register_backward_hook``.
        """
        def make_hook(name):
            def hook(module, inputs, output):
                # Containers may return dicts/tuples; only tensors are recorded
                if not isinstance(output, torch.Tensor):
                    return
                self.activations[name] = output.detach()
                if output.requires_grad:
                    def save_gradient(grad):
                        # Must return None so the gradient is left untouched
                        self.gradients[name] = grad.detach()
                    output.register_hook(save_gradient)
            return hook

        for name, module in self.model.named_modules():
            # named_modules() also yields the model itself; it exposes
            # explanation_mode too but is not a layer to be recorded.
            if module is self.model:
                continue
            if hasattr(module, 'explanation_mode'):  # B-cos layers
                self.hooks.append(module.register_forward_hook(make_hook(name)))

    def cleanup_hooks(self):
        """Remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    @staticmethod
    def _resolve_target(logits, target_class):
        """Return (class index, softmax score) for the first sample in ``logits``."""
        if target_class is None:
            target_idx = int(torch.argmax(logits, dim=1)[0].item())
        else:
            target_idx = int(target_class)
        class_score = torch.softmax(logits, dim=1)[0, target_idx].item()
        return target_idx, class_score

    def _fresh_input(self, image_tensor):
        """Return a new leaf tensor on ``self.device`` that requires grad.

        Working on a detached copy keeps the caller's tensor untouched and
        prevents gradients from accumulating across successive calls.
        """
        return image_tensor.detach().to(self.device).requires_grad_(True)

    def get_bcos_contributions(self, image_tensor, target_class=None):
        """
        Get B-cos spatial contribution maps using gradients

        Args:
            image_tensor: Input tensor (1, 3, H, W)
            target_class: Target class for explanation (if None, use predicted class)

        Returns:
            contribution_maps: Dictionary of contribution maps from B-cos layers
            predicted_class: Index of the explained class
            class_score: Softmax score of the explained class
        """
        # Clear previous recordings
        self.activations.clear()
        self.gradients.clear()

        inp = self._fresh_input(image_tensor)

        # Explanation mode is only needed for this forward/backward pass;
        # switch it off again so the model is not left in it.
        self.model.explanation_mode(detach=True)
        try:
            logits = self.model(inp)["logits"]
            predicted_class, class_score = self._resolve_target(logits, target_class)
            logits[0, predicted_class].backward()
        finally:
            self.model.explanation_mode(detach=False)

        contribution_maps = {}
        for name, activation in self.activations.items():
            gradient = self.gradients.get(name)
            if gradient is None:
                continue
            # Contribution = activation * gradient, summed over channels
            spatial_contribution = (activation * gradient).sum(dim=1, keepdim=True)
            contribution_maps[name] = spatial_contribution.squeeze(0).squeeze(0).cpu().numpy()

        return contribution_maps, predicted_class, class_score

    def get_input_gradient_contribution(self, image_tensor, target_class=None):
        """
        Get input-level gradient contribution map (Input x Gradient)

        A single forward/backward pass yields both the gradient w.r.t. the
        image and, via ``retain_grad``, the gradient w.r.t. the model's
        ``"encoded_input"`` output.

        Args:
            image_tensor: Input tensor (1, 3, H, W)
            target_class: Target class for explanation (if None, use predicted class)

        Returns:
            Dict with keys ``'input_contribution'``, ``'bcos_explanation'``,
            ``'predicted_class'``, ``'class_score'``, ``'input_gradient'`` and
            ``'encoded_gradient'``.

        Raises:
            RuntimeError: If the model's ``"encoded_input"`` output is not part
                of the autograd graph.
        """
        # Clear previous recordings
        self.activations.clear()
        self.gradients.clear()

        inp = self._fresh_input(image_tensor)

        self.model.explanation_mode(detach=True)
        try:
            out = self.model(inp)
            logits = out["logits"]
            encoded_input = out["encoded_input"]
            if not encoded_input.requires_grad:
                raise RuntimeError(
                    "model output 'encoded_input' is detached from the autograd graph"
                )
            # Non-leaf tensor: ask autograd to keep its gradient
            encoded_input.retain_grad()

            predicted_class, class_score = self._resolve_target(logits, target_class)
            logits[0, predicted_class].backward()
        finally:
            self.model.explanation_mode(detach=False)

        input_grad = inp.grad
        encoded_grad = encoded_input.grad
        encoded_img = encoded_input.detach()[0]

        # A missing flag (import cell not executed) counts as "not available"
        if globals().get("BCOS_UTILS_AVAILABLE", False):
            # Use official B-cos grad_to_img function
            bcos_explanation = grad_to_img(encoded_img, encoded_grad[0])
        else:
            # Fallback implementation
            bcos_explanation = self._grad_to_img_fallback(encoded_img, encoded_grad[0])

        # Input * Gradient, summed over the image channels. The input is
        # detached first because numpy() rejects tensors that require grad.
        input_contrib = (inp.detach()[0] * input_grad[0]).sum(0).cpu().numpy()

        return {
            'input_contribution': input_contrib,
            'bcos_explanation': bcos_explanation,
            'predicted_class': predicted_class,
            'class_score': class_score,
            'input_gradient': input_grad[0].cpu().numpy(),
            'encoded_gradient': encoded_grad[0].cpu().numpy()
        }

    def _grad_to_img_fallback(self, img_6ch, linear_mapping, smooth=3, alpha_percentile=99.5):
        """Fallback B-cos grad_to_img implementation.

        Args:
            img_6ch: Encoded image with six channels, indexed as (6, H, W).
            linear_mapping: Gradient w.r.t. ``img_6ch``, same shape.
            smooth: Kernel size of the average pooling applied to alpha
                (skipped when ``smooth <= 1``).
            alpha_percentile: Percentile of alpha used as normalisation scale.

        Returns:
            numpy array (H, W, 4): three colour channels plus alpha.
        """
        # Ensure tensors are on CPU and detached
        if hasattr(img_6ch, 'detach'):
            img_6ch = img_6ch.detach().cpu()
        if hasattr(linear_mapping, 'detach'):
            linear_mapping = linear_mapping.detach().cpu()

        # Per-pixel contribution of the linear mapping
        contribs = (img_6ch * linear_mapping).sum(0, keepdim=True)[0]

        # Normalize gradient per pixel, keep positive parts, then compare the
        # first three channels against the last three
        rgb_grad = (linear_mapping / (linear_mapping.abs().max(0, keepdim=True)[0] + 1e-12))
        rgb_grad = rgb_grad.clamp(0)
        rgb_grad = rgb_grad[:3] / (rgb_grad[:3] + rgb_grad[3:] + 1e-12)

        # Alpha from the gradient norm; negative contributions become ~transparent
        alpha = linear_mapping.norm(p=2, dim=0, keepdim=True)
        alpha = torch.where(contribs[None] < 0, torch.zeros_like(alpha) + 1e-12, alpha)

        if smooth > 1:
            alpha = F.avg_pool2d(alpha, smooth, stride=1, padding=(smooth-1)//2)

        alpha = alpha.numpy()
        # Floor the scale so an all-zero gradient does not produce 0/0 = NaN
        alpha_scale = max(float(np.percentile(alpha, alpha_percentile)), 1e-12)
        alpha = (alpha / alpha_scale).clip(0, 1)
        rgb_grad = rgb_grad.numpy()

        rgba_grad = np.concatenate([rgb_grad, alpha], axis=0)
        return rgba_grad.transpose((1, 2, 0))

    def visualize_comprehensive_explanation(self, image_tensor, target_class=None,
                                          figsize=(20, 12), class_names=None):
        """
        Create comprehensive visualization with:
        1. Input image
        2. Prediction summary
        3. Input-level explanations (Input x Gradient and B-cos RGBA)
        4. Layer-wise B-cos spatial contributions (up to six layers)

        Returns:
            Dict with keys ``'contributions'``, ``'input_explanation'``,
            ``'predicted_class'`` and ``'class_score'``.
        """

        fig = plt.figure(figsize=figsize)
        gs = GridSpec(3, 4, figure=fig, hspace=0.3, wspace=0.3)

        # Original image
        ax_orig = fig.add_subplot(gs[0, 0])
        orig_img = to_numpy_img(image_tensor)
        ax_orig.imshow(orig_img)
        ax_orig.set_title('Original Image', fontweight='bold')
        ax_orig.axis('off')

        # Get B-cos contributions
        contributions, pred_class, class_score = self.get_bcos_contributions(image_tensor, target_class)

        # Get input-level explanation
        input_explanation = self.get_input_gradient_contribution(image_tensor, target_class)

        # Prediction info
        ax_pred = fig.add_subplot(gs[0, 1])
        if class_names is not None and pred_class < len(class_names):
            class_name = class_names[pred_class]
        else:
            class_name = f"Class {pred_class}"
        ax_pred.text(0.1, 0.7, 'Prediction:', fontsize=12, fontweight='bold')
        ax_pred.text(0.1, 0.5, class_name, fontsize=14, color='red')
        ax_pred.text(0.1, 0.3, f'Confidence: {class_score:.3f}', fontsize=12)
        ax_pred.axis('off')

        # Input * Gradient contribution
        ax_input = fig.add_subplot(gs[0, 2])
        input_contrib = input_explanation['input_contribution']
        ax_input.imshow(orig_img, alpha=0.7)
        # Symmetric colour range; fall back to 1 when the map is all zeros
        vmax = float(np.abs(input_contrib).max()) or 1.0
        im1 = ax_input.imshow(input_contrib, cmap='RdBu_r', alpha=0.8, vmin=-vmax, vmax=vmax)
        ax_input.set_title('Input × Gradient')
        ax_input.axis('off')
        plt.colorbar(im1, ax=ax_input, fraction=0.046, pad=0.04)

        # B-cos explanation (RGBA)
        ax_bcos = fig.add_subplot(gs[0, 3])
        bcos_rgba = input_explanation['bcos_explanation']
        ax_bcos.imshow(bcos_rgba)
        ax_bcos.set_title('B-cos Explanation (RGBA)')
        ax_bcos.axis('off')

        # Layer-wise B-cos contributions: first six layers with a 2-D map
        # (maps of other rank cannot be resized onto the image)
        spatial_layers = [
            (layer_name, contrib_map)
            for layer_name, contrib_map in contributions.items()
            if contrib_map.ndim == 2
        ][:6]
        for idx, (layer_name, contrib_map) in enumerate(spatial_layers):
            ax = fig.add_subplot(gs[1 + idx // 3, idx % 3])

            # Resize to match image
            contrib_resized = cv2.resize(contrib_map, (orig_img.shape[1], orig_img.shape[0]))

            ax.imshow(orig_img, alpha=0.6)
            vmax = float(np.abs(contrib_resized).max()) or 1.0
            im = ax.imshow(contrib_resized, cmap='RdBu_r', alpha=0.8, vmin=-vmax, vmax=vmax)

            display_name = layer_name.split('.')[-1] if '.' in layer_name else layer_name
            ax.set_title(f'{display_name}', fontsize=10)
            ax.axis('off')
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.suptitle('B-cos PIP-Net: Comprehensive Spatial Contribution Analysis',
                     fontsize=16, fontweight='bold')

        return {
            'contributions': contributions,
            'input_explanation': input_explanation,
            'predicted_class': pred_class,
            'class_score': class_score
        }

# Usage example function
def analyze_single_image_comprehensive(model, image_tensor, cfg, class_names=None):
    """
    Run the full explanation pipeline on one image and print a summary.

    Builds a ``BcosSpatialContributionAnalyzer`` on the file-level ``device``,
    renders its comprehensive figure, prints the prediction, the number of
    analysed layers and the explanation shape, shows the figure and returns
    the analyzer's result dict. ``cfg`` is accepted but not used here. The
    analyzer's hooks are removed even when an error occurs.
    """
    print("🔍 Comprehensive B-cos PIP-Net Analysis")
    print("=" * 50)

    explainer = BcosSpatialContributionAnalyzer(model, device)

    try:
        report = explainer.visualize_comprehensive_explanation(
            image_tensor,
            figsize=(20, 15),
            class_names=class_names,
        )

        # Summary lines
        predicted, confidence = report['predicted_class'], report['class_score']
        print(f"Prediction: Class {predicted} (confidence: {confidence:.3f})")

        layer_count = len(report['contributions'])
        print(f"Number of B-cos layers analyzed: {layer_count}")

        rgba_shape = report['input_explanation']['bcos_explanation'].shape
        print(f"B-cos explanation shape: {rgba_shape}")

        plt.show()
        return report
    finally:
        # Hooks must not outlive this call
        explainer.cleanup_hooks()

# Confirmation message emitted once the definitions above have been executed.
print("✅ Spatial contribution maps and B-cos explanations ready!")

B-cos PiP-Net (improved, no ReLU/BN/MaxPool): LA+LT pretraining followed by supervised training.
Downsampling happens ONLY via strided B-cos convs. Explanations stay compact via a non-negative scoring sheet.



In [None]:
import os, time, math, json, random
from datetime import timedelta
from collections import OrderedDict
from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, random_split, Dataset
from tqdm import tqdm

# Visualization imports
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
import cv2
from PIL import Image

# Add PIPNet and visualization modules to path
import sys
sys.path.append('./PIPNet')
sys.path.append('./src')
from pipnet.pipnet import NonNegLinear

# B-cos interpretability utilities
sys.path.append('./B-cos')
try:
    from interpretability.utils import grad_to_img, plot_contribution_map, explanation_mode
    from project_utils import to_numpy
    BCOS_UTILS_AVAILABLE = True
except ImportError:
    print("B-cos interpretability utils not available, using fallback implementations")
    BCOS_UTILS_AVAILABLE = False

#Utils

In [47]:
def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value passed to every generator.
    """
    # Local import: numpy is only imported further down in this file.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


seed_everything(42)
# cuDNN autotuning picks the fastest kernels for the observed input shapes.
# NOTE(review): autotuned kernels may not be deterministic, which limits how
# reproducible seeded runs are — confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True

# Device used as the default by the rest of the file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA: {torch.cuda.is_available()} | Device: {device}")
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

CUDA: True | Device: cuda
NVIDIA A100-SXM4-40GB


#B-cos modules

In [None]:
# Import proper B-cos modules
import sys
import numpy as np
sys.path.append('./B-cos')

from modules.bcosconv2d import BcosConv2d as OriginalBcosConv2d
from data.data_transforms import AddInverse

class BcosInputEncoder(nn.Module):
    """Append the inverted image as extra channels: [r,g,b,1-r,1-g,1-b].

    Inputs whose maximum exceeds 1.1 are taken to be on a 0-255 scale and
    divided by 255; values are then clamped to [0, 1] before the inverse is
    concatenated along dim 1.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        looks_like_byte_range = x.max() > 1.1
        if looks_like_byte_range:
            x = x / 255.0
        image = torch.clamp(x, min=0.0, max=1.0)
        inverse = 1.0 - image
        return torch.cat((image, inverse), dim=1)

class BcosConv2d(OriginalBcosConv2d):
    """
    Thin adapter around the B-cos conv imported from ``modules.bcosconv2d``.

    Translates this file's short argument names (``k``, ``s``, ``p``, ``B``)
    into the keyword names of the imported implementation. ``bias`` is
    accepted for interface compatibility but is not forwarded.
    """
    def __init__(self, inc, outc, k=3, s=1, p=1, bias=False, B=2, max_out=1, **kwargs):
        # Short names on the left map to the imported layer's keywords
        translated = dict(
            inc=inc,
            outc=outc,
            kernel_size=k,
            stride=s,
            padding=p,
            b=B,
            max_out=max_out,
        )
        # NOTE(review): per the original author's note the imported layer has
        # built-in max_out and never uses a bias — confirm against its source.
        super().__init__(**translated, **kwargs)

class BcosLinear(nn.Module):
    """
    B-cos linear layer following the same principles as BcosConv2d.

    The output is the linear response to unit-norm weight rows, multiplied by
    ``|cos|^(B-1)`` and divided by a fixed scale. Inputs may carry any number
    of leading dimensions; the last dimension is the feature dimension.

    Args:
        in_f: Number of input features.
        out_f: Number of output features.
        bias: Whether to learn an additive bias (initialised to zero).
        B: B-cos exponent; ``B == 1`` reduces to the scaled linear response.
    """
    def __init__(self, in_f, out_f, bias=True, B=2):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_f, in_f) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_f)) if bias else None
        self.B = float(B)
        self.detach = False

        # Fixed output scale derived from the input width
        self.scale = (np.sqrt(in_f)) / 100

    def explanation_mode(self, detach=True):
        """Enable explanation mode like in BcosConv2d"""
        self.detach = detach

    def forward(self, x):
        """Apply the B-cos transform to ``x`` of shape (..., in_f)."""
        # Unit-norm weight rows
        w_hat = self.weight / (self.weight.norm(p=2, dim=1, keepdim=True) + 1e-8)
        # NOTE(review): when a bias is used it is part of z_lin and therefore
        # also enters the cosine term below — confirm this is intended.
        z_lin = F.linear(x, w_hat, self.bias)

        if self.B == 1:
            return z_lin / self.scale

        # Norm over the feature (last) dimension so that inputs with extra
        # leading dimensions are handled the same way F.linear handles them
        x_norm = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
        abs_cos = (z_lin / x_norm).abs() + 1e-6

        # In explanation mode the cosine factor is treated as a constant
        if self.detach:
            abs_cos = abs_cos.detach()

        # Multiply by |cos|^(B-1) to get the cos^B effect
        return z_lin * abs_cos.pow(self.B - 1) / self.scale

#Prototype layer

In [49]:
class PrototypeLayer(nn.Module):
    """Scores feature maps against learned prototypes.

    Features are L2-normalised over the channel axis, prototypes over all of
    their entries, and the two are correlated with a stride-1, unpadded
    convolution. ``forward`` returns the per-prototype maximum, the (row, col)
    location of that maximum as floats, and the full similarity maps.
    """

    def __init__(self, num_prototypes: int, prototype_shape: Tuple[int,int,int], eps=1e-6):
        super().__init__()
        channels, kernel_h, kernel_w = prototype_shape
        initial = torch.randn(num_prototypes, channels, kernel_h, kernel_w) * 0.1
        self.prototypes = nn.Parameter(initial)
        self.eps = eps

    @property
    def P(self):
        """Number of prototypes."""
        return self.prototypes.size(0)

    def _norm_protos(self):
        """Return the prototypes scaled to (approximately) unit L2 norm each."""
        flat = self.prototypes.view(self.P, -1)
        lengths = flat.norm(p=2, dim=1, keepdim=True)
        unit = flat / (lengths + self.eps)
        return unit.view_as(self.prototypes)

    def forward(self, x):
        unit_feats = F.normalize(x, p=2, dim=1, eps=self.eps)
        unit_protos = self._norm_protos()
        # (B,P,Hout,Wout)
        sim_maps = F.conv2d(unit_feats, unit_protos, bias=None, stride=1, padding=0)

        batch, n_protos, _, width = sim_maps.shape
        maxv, flat_idx = sim_maps.view(batch, n_protos, -1).max(dim=2)

        # Unravel the flat index of each maximum into (row, col)
        rows = (flat_idx // width).float()
        cols = (flat_idx % width).float()
        loc = torch.stack([rows, cols], dim=2)
        return maxv, loc, sim_maps

#Model

In [None]:
class BcosPiPNet(nn.Module):
    """
    No ReLU/BN/MaxPool. Downsampling via strided B-cos convs only.
    Uses the imported B-cos conv implementation with its max_out option.

    Pipeline: input encoder -> B-cos conv stack -> prototype layer ->
    non-negative linear classifier; logits are ``log(scores^2 + 1)``.

    Args:
        num_classes: Number of output classes.
        num_prototypes_per_class: Prototypes allotted to each class.
        prototype_shape: (channels, kH, kW) of one prototype.
        base_channels: Width of the first conv stage.
        B: B-cos exponent passed to every conv.
        patchify_stem: If True use a 4x4 stride-4 stem, otherwise a 3x3
            stride-1 first conv.
    """
    def __init__(self, num_classes, num_prototypes_per_class=15,
                 prototype_shape=(128,1,1), base_channels=64, B=2,
                 patchify_stem=True):
        super().__init__()
        self.num_classes = num_classes
        self.num_prototypes_per_class = num_prototypes_per_class
        self.num_prototypes = num_classes * num_prototypes_per_class
        self.prototype_shape = prototype_shape

        self.input_encoder = BcosInputEncoder()
        in_ch = 6  # the encoder doubles the three image channels
        Cproto = prototype_shape[0]

        feats = []
        if patchify_stem:
            feats += [("stem", BcosConv2d(in_ch, base_channels, k=4, s=4, p=0, B=B))]   # /4
        else:
            feats += [("conv0", BcosConv2d(in_ch, base_channels, k=3, s=1, p=1, B=B))]

        # Original author's note: max_out=2 means 2x output channels, then max.
        # NOTE(review): "down1"/"conv2"/"down2" are declared with
        # base_channels*2 outputs while the layers that follow them are
        # declared with base_channels inputs (and "conv4" with Cproto//2
        # inputs after "conv3" with Cproto outputs). Whether these chain
        # together depends on how the imported conv treats max_out — confirm
        # with a forward pass.
        feats += [
            ("conv1", BcosConv2d(base_channels, base_channels, k=3, s=1, p=1, B=B, max_out=2)),  # Built-in MaxOut
            ("down1", BcosConv2d(base_channels, base_channels*2, k=3, s=2, p=1, B=B, max_out=2)),  # /2, built-in MaxOut

            ("conv2", BcosConv2d(base_channels, base_channels*2, k=3, s=1, p=1, B=B, max_out=2)),  # Built-in MaxOut
            ("down2", BcosConv2d(base_channels, base_channels*2, k=3, s=2, p=1, B=B, max_out=2)),  # /2, built-in MaxOut

            ("conv3", BcosConv2d(base_channels, Cproto, k=3, s=1, p=1, B=B, max_out=2)),  # Built-in MaxOut
            ("conv4", BcosConv2d(Cproto//2, Cproto, k=3, s=1, p=1, B=B)),  # No MaxOut for final layer
        ]

        self.features = nn.Sequential(OrderedDict(feats))

        self.prototype_layer = PrototypeLayer(self.num_prototypes, prototype_shape)
        self.classifier = NonNegLinear(self.num_prototypes, num_classes)
        self._init_classifier_bias()

    def _init_classifier_bias(self):
        """Seed the classifier so each class starts tied to its own prototype block.

        Builds a (num_prototypes, num_classes) matrix with ones on each class's
        prototype block plus small Gaussian noise and copies its inverse
        softplus into ``classifier._w``.
        """
        with torch.no_grad():
            W = torch.zeros(self.num_prototypes, self.num_classes)
            for c in range(self.num_classes):
                s, e = c*self.num_prototypes_per_class, (c+1)*self.num_prototypes_per_class
                W[s:e, c] = 1.0
            W += 0.05 * torch.randn_like(W)
            # inverse softplus approx to seed non-negative weights
            # NOTE(review): assumes classifier._w is laid out as
            # (num_prototypes, num_classes) — confirm against NonNegLinear.
            self.classifier._w.copy_(torch.log(torch.expm1(W.clamp_min(1e-4))))

    def explanation_mode(self, detach=True):
        """Enable (or disable) explanation mode on all B-cos sub-modules."""
        for module in self.modules():
            # modules() yields this model first; calling its own
            # explanation_mode again would recurse without end.
            if module is self:
                continue
            if hasattr(module, 'explanation_mode'):
                module.explanation_mode(detach)

    def forward(self, x):
        """Return a dict with logits, prototype similarities/maps/locations,
        the conv features and the encoded input."""
        enc = self.input_encoder(x)
        feats = self.features(enc)
        sims_max, locs, sim_maps = self.prototype_layer(feats)       # sims_max: (B,P), sim_maps: (B,P,H,W)

        scores = self.classifier(sims_max)                           # (B,C)
        logits = torch.log(scores.pow(2) + 1.0)                      # training logits (keep abstention & compactness)
        return {"logits": logits, "similarities": sims_max, "sim_maps": sim_maps,
                "locations": locs, "features": feats, "encoded_input": enc}

    def get_prototype_activations(self, x, return_features=False):
        """
        Get prototype activations for contrastive learning.

        Returns ``(sim_maps, sims_max)``, or ``(sim_maps, sims_max, feats)``
        when ``return_features`` is True. The classifier is not evaluated.
        """
        enc = self.input_encoder(x)
        feats = self.features(enc)
        sims_max, locs, sim_maps = self.prototype_layer(feats)

        if return_features:
            return sim_maps, sims_max, feats
        return sim_maps, sims_max

#Losses

In [None]:
class ContrastiveLoss(nn.Module):
    """
    Weighted sum of the alignment loss (LA) and the tanh loss (LT) computed
    from the similarity maps of two views.
    """
    def __init__(self, align_weight=1.0, tanh_weight=1.0):
        super().__init__()
        self.align_weight = align_weight
        self.tanh_weight = tanh_weight

    def forward(self, sim_maps1, sim_maps2):
        """
        Args:
            sim_maps1, sim_maps2: (B,P,H,W) similarity maps from two augmented views

        Returns:
            Dict with 'total_loss', 'align_loss' and 'tanh_loss'.
        """
        # Per-location assignments and pooled presence for each view
        assign_a, presence_a = proto_presence_from_sim_maps(sim_maps1)
        assign_b, presence_b = proto_presence_from_sim_maps(sim_maps2)

        # LA compares the two views location by location
        align_term = loss_LA(assign_a, assign_b)

        # LT looks at the presence vectors of both views stacked along the batch
        stacked_presence = torch.cat([presence_a, presence_b], dim=0)
        tanh_term = loss_LT(stacked_presence)

        return {
            'total_loss': self.align_weight * align_term + self.tanh_weight * tanh_term,
            'align_loss': align_term,
            'tanh_loss': tanh_term,
        }

class BcosPiPNetLoss(nn.Module):
    """
    Supervised loss: classification (BCE or cross-entropy on temperature-scaled,
    bias-shifted logits) plus cluster, separation and L1 regularisers.

    Args:
        num_classes: Number of classes (used for the one-hot BCE targets).
        class_weight, cluster_weight, separation_weight, l1_weight: Weights of
            the four terms in the total loss.
        use_bce: Use ``BCEWithLogitsLoss`` on one-hot targets if True,
            otherwise ``CrossEntropyLoss`` with label smoothing 0.05.
        logit_bias: Constant added to the logits; defaults to log(0.1/0.9).
        logit_temperature: Divisor applied to the logits before the bias.
    """
    def __init__(self, num_classes, class_weight=1.0, cluster_weight=0.6, separation_weight=0.05, l1_weight=2e-4,
                 use_bce=True, logit_bias=None, logit_temperature=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.class_weight, self.cluster_weight = class_weight, cluster_weight
        self.separation_weight, self.l1_weight = separation_weight, l1_weight
        self.use_bce = use_bce
        self.logit_bias = logit_bias if logit_bias is not None else np.log(0.1/0.9)  # B-cos default
        self.logit_temperature = logit_temperature

        if use_bce:
            self.classification_loss = nn.BCEWithLogitsLoss(reduction='mean')
        else:
            self.classification_loss = nn.CrossEntropyLoss(label_smoothing=0.05)

    def forward(self, out, targets, model):
        """Compute all loss terms.

        Args:
            out: Model output dict with "logits" (B,C) and "similarities" (B,P).
            targets: Integer class labels, shape (B,).
            model: Object exposing ``num_classes``, ``num_prototypes_per_class``
                and ``classifier.weight``.

        Returns:
            Dict with "total_loss", "classification_loss", "cluster_loss",
            "separation_loss", "l1_loss" and "processed_logits".
        """
        logits, sims = out["logits"], out["similarities"]

        # Temperature scaling and bias
        processed_logits = logits / self.logit_temperature + self.logit_bias

        if self.use_bce:
            # BCE needs one-hot float targets
            targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
            classification_loss = self.classification_loss(processed_logits, targets_one_hot)
        else:
            classification_loss = self.classification_loss(processed_logits, targets)

        C = model.num_classes
        Ppc = model.num_prototypes_per_class
        cluster, cluster_count = 0.0, 0
        separation, separation_count = 0.0, 0
        for c in range(C):
            own_block = sims[:, c*Ppc:(c+1)*Ppc]   # similarities to class c's prototypes
            is_c = targets == c
            # cluster: raise sims for target-class prototypes
            if is_c.any():
                cluster = cluster + (1.0 - own_block[is_c]).mean()
                cluster_count += 1
            # separation: lower sims for other-class prototypes
            if (~is_c).any():
                separation = separation + F.relu(own_block[~is_c] - 0.1).mean()
                separation_count += 1
        # Both terms average over the classes that actually contributed, so
        # their magnitude does not depend on which classes are in the batch.
        cluster = cluster / max(1, cluster_count)
        separation = separation / max(1, separation_count)

        l1 = model.classifier.weight.sum()  # non-negative
        total = (self.class_weight*classification_loss + self.cluster_weight*cluster +
                 self.separation_weight*separation + self.l1_weight*l1)
        return {"total_loss": total, "classification_loss": classification_loss, "cluster_loss": cluster,
                "separation_loss": separation, "l1_loss": l1, "processed_logits": processed_logits}

#PIP-Net

In [52]:
def proto_presence_from_sim_maps(sim_maps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Turn similarity maps into soft assignments and pooled presence scores.

    Args:
        sim_maps: (B,P,H,W) similarity maps.

    Returns:
        z: (B,H,W,P) softmax over the prototype axis at every location.
        p: (B,P) maximum of ``z`` over the spatial axes, so each entry lies in [0,1].

    Raises:
        ValueError: If ``sim_maps`` is not 4-dimensional.
    """
    if sim_maps.dim() != 4:
        raise ValueError(f"sim_maps must have shape (B,P,H,W), got {tuple(sim_maps.shape)}")
    z = F.softmax(sim_maps.permute(0, 2, 3, 1), dim=-1)   # (B,H,W,P)
    p = z.amax(dim=(1, 2))                                # (B,P)
    return z, p

def loss_LA(z1: torch.Tensor, z2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Alignment loss: negative mean log of the per-location dot product.

    z1 and z2 are (B,H,W,P) assignment tensors; the dot product is taken over
    the last axis and floored at ``eps`` before the log.
    """
    overlap = torch.sum(z1 * z2, dim=-1)                  # (B,H,W)
    log_overlap = torch.log(torch.clamp(overlap, min=eps))
    return log_overlap.mean().neg()

def loss_LT(p_batch: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Tanh loss: negative mean log of tanh of each prototype's batch-summed presence.

    p_batch is (B,P); the sum runs over the batch axis and ``eps`` is added
    before the log.
    """
    presence_total = p_batch.sum(dim=0)                   # (P,)
    saturation = torch.tanh(presence_total) + eps
    return torch.log(saturation).mean().neg()

#Metric utils

In [53]:
!pip -q install scikit-learn tqdm

import numpy as np, sklearn.metrics as skm
from tqdm import tqdm

def evaluate_metrics(net,
                     dl_in,
                     cfg,
                     dl_ood=None,
                     thr_sim=0.05,
                     device=device):
    """
    Evaluate accuracy, prototype purity, sparsity and (optionally) FPR95.

    net      : model returning a dict with "logits" and "similarities"
    dl_in    : DataLoader of (images, labels) for the ID validation/test split
    cfg      : config object (only num_prototypes_per_class is read)
    dl_ood   : DataLoader of (images, _) with OoD images (optional)
    thr_sim  : similarity threshold for sparsity
    device   : device the batches are moved to
    returns  : dict with "Accuracy (%)", "Purity (%)", "Sparsity (avg #)" and,
               when dl_ood is given, "FPR95 (%)". Accuracy, purity and
               sparsity are NaN when dl_in yields no samples.
    """
    net.eval()
    correct = total = 0
    purity_ok = 0
    sparsity_list = []
    conf_in = []

    with torch.no_grad():
        for imgs, labels in tqdm(dl_in, desc="Eval ID", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            out = net(imgs)
            logits, sims = out["logits"], out["similarities"]

            # ---------- accuracy ----------
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # ---------- prototype purity ----------
            # class that owns the most similar prototype vs. the true label
            proto_cls = sims.argmax(1) // cfg.num_prototypes_per_class
            purity_ok += (proto_cls == labels).sum().item()

            # ---------- sparsity ----------
            # number of prototypes above the threshold, per image
            sparsity_list.extend((sims > thr_sim).sum(1).cpu().tolist())

            # ---------- confidence (for FPR95) ----------
            conf_in.extend(logits.softmax(1).max(1).values.cpu().tolist())

    nan = float("nan")
    acc = 100.0 * correct / total if total else nan
    purity = 100.0 * purity_ok / total if total else nan
    sparsity = float(np.mean(sparsity_list)) if sparsity_list else nan

    # ---------- OoD FPR95 ----------
    fpr95 = None
    if dl_ood is not None:
        conf_ood = []
        with torch.no_grad():
            for imgs, _ in tqdm(dl_ood, desc="Eval OoD", leave=False):
                imgs = imgs.to(device)
                ood_logits = net(imgs)["logits"]
                conf_ood.extend(ood_logits.softmax(1).max(1).values.cpu().tolist())

        # ID is the positive class, so the score handed to roc_curve must be
        # HIGHER for ID samples: use the confidence itself, not its negation.
        scores = np.concatenate([conf_in, conf_ood])
        is_id = np.concatenate([np.ones(len(conf_in)), np.zeros(len(conf_ood))])
        fpr, tpr, _ = skm.roc_curve(is_id, scores)
        reached = np.flatnonzero(tpr >= 0.95)
        # 100 % if the ID true-positive rate never reaches 95 %
        fpr95 = 100.0 * fpr[reached[0]] if reached.size else 100.0

    metrics = {"Accuracy (%)"     : acc,
               "Purity (%)"       : purity,
               "Sparsity (avg #)" : sparsity}
    if fpr95 is not None:
        metrics["FPR95 (%)"] = fpr95
    return metrics

#Data

In [None]:
class Config:
    """Experiment settings, kept as plain class-level attributes."""

    # ---- Data ----
    dataset = "CIFAR10"  # or 'CIFAR100'
    data_root = "./data"
    batch_size = 1024
    # Half of the available CPU cores, never fewer than two workers.
    num_workers = max(2, (os.cpu_count() or 4) // 2)
    image_size = 160

    # ---- Model ----
    num_classes = 10
    num_prototypes_per_class = 10
    prototype_shape = (128, 1, 1)
    base_channels = 64
    B = 2.5
    patchify_stem = True

    # ---- Training schedule ----
    epochs_pretrain = 16     # contrastive learning only (classifier frozen)
    epochs_supervised = 100  # CE + regs + light contrastive regularization
    lr = 1e-3
    wd = 1e-4

    # ---- B-cos specific parameters (following official implementation) ----
    use_bce = True                    # BCE instead of CrossEntropy
    logit_bias = np.log(0.1/0.9)      # B-cos default bias
    logit_temperature = 1.0           # temperature scaling

    # ---- Loss weights ----
    class_weight = 1.0
    cluster_weight = 0.6
    separation_weight = 0.05
    l1_weight = 2e-4
    lambda_A_pre = 1.0  # align loss weight, pretraining
    lambda_T_pre = 0.5  # tanh loss weight, pretraining
    lambda_A_sup = 0.2  # align loss weight, supervised phase (regularization)
    lambda_T_sup = 0.1  # tanh loss weight, supervised phase (regularization)

    # ---- Checkpointing ----
    save_dir = "./checkpoints"
    save_frequency = 10

#Training

In [None]:
def pretrain_prototypes(model, pre_loader, contrastive_crit, cfg: Config):
    """
    Contrastive pretraining (LA + LT) of the feature extractor and prototypes.

    The classifier parameters are frozen for the duration of this phase and
    only ``model.features`` and ``model.prototype_layer`` are optimised, using
    a dedicated Adam optimiser built from ``cfg.lr`` / ``cfg.wd``.

    Args:
        model: network exposing ``features``, ``prototype_layer``,
            ``classifier`` and ``get_prototype_activations``.
        pre_loader: loader yielding ``((view1, view2), _)`` batches.
        contrastive_crit: callable ``(sim_maps1, sim_maps2) -> dict`` providing
            the keys ``'total_loss'``, ``'align_loss'`` and ``'tanh_loss'``.
        cfg: configuration; ``epochs_pretrain``, ``lr`` and ``wd`` are read.

    Returns:
        None. ``model`` is updated in place.
    """
    # Freeze the classifier head for the contrastive phase.
    for p in model.classifier.parameters():
        p.requires_grad = False
    opt = optim.Adam(list(model.features.parameters()) + list(model.prototype_layer.parameters()),
                     lr=cfg.lr, weight_decay=cfg.wd)
    # Guard the per-epoch averages against an empty loader.
    n_batches = max(len(pre_loader), 1)

    print(f"pretraining prototypes: {cfg.epochs_pretrain} epochs (Contrastive Learning: LA+LT)")
    try:
        for ep in range(cfg.epochs_pretrain):
            model.train()
            ep_loss = ep_LA = ep_LT = 0.0
            pbar = tqdm(pre_loader, desc=f"Pretrain {ep+1}/{cfg.epochs_pretrain}", ncols=120, leave=False)
            for (x1, x2), _ in pbar:
                x1, x2 = x1.to(device), x2.to(device)
                opt.zero_grad()

                # Prototype activations for both augmented views
                sim_maps1, _ = model.get_prototype_activations(x1)
                sim_maps2, _ = model.get_prototype_activations(x2)

                loss_dict = contrastive_crit(sim_maps1, sim_maps2)
                loss = loss_dict['total_loss']

                loss.backward()
                opt.step()

                # Convert each tensor to a Python scalar once and reuse it for
                # both the running sums and the progress bar.
                loss_val = loss.item()
                la_val = loss_dict['align_loss'].item()
                lt_val = loss_dict['tanh_loss'].item()

                ep_loss += loss_val
                ep_LA += la_val
                ep_LT += lt_val
                pbar.set_postfix({"Loss": f"{loss_val:.3f}",
                                  "LA": f"{la_val:.3f}",
                                  "LT": f"{lt_val:.3f}"})

            print(f"pretrain epoch {ep+1}: Loss {ep_loss/n_batches:.3f} | "
                  f"LA {ep_LA/n_batches:.3f} | LT {ep_LT/n_batches:.3f}")
    finally:
        # Always unfreeze the classifier, even when pretraining is interrupted
        # (e.g. KeyboardInterrupt in a notebook), so a later supervised phase
        # is not silently left with a frozen head.
        for p in model.classifier.parameters():
            p.requires_grad = True

def _config_to_dict(cfg):
    """Collect the public, non-callable settings of *cfg* into a plain dict.

    ``cfg.__dict__`` alone only contains attributes assigned on the instance;
    settings declared at class level would be missing from it. Class attributes
    (walking the MRO) are therefore merged with the instance attributes.
    NumPy scalars are converted to Python scalars so the result contains only
    built-in types.
    """
    merged = {}
    for klass in reversed(type(cfg).__mro__):
        merged.update(vars(klass))
    merged.update(getattr(cfg, "__dict__", {}))

    snapshot = {}
    for key, value in merged.items():
        if key.startswith("_") or callable(value):
            continue
        if isinstance(value, (classmethod, staticmethod, property)):
            continue
        snapshot[key] = value.item() if isinstance(value, np.generic) else value
    return snapshot


def train_supervised(model,
                     train_loader,
                     val_loader,
                     test_loader,
                     supervised_crit,
                     opt,
                     sch,
                     cfg,
                     ood_loader=None):
    """
    Supervised phase with metric tracking.
    Uses contrastive learning regularization instead of prototype pushing.
    Updated for B-cos BCE loss.

    Args:
        model: network whose forward pass returns a dict with at least the
            keys ``"logits"`` and ``"sim_maps"``.
        train_loader / val_loader / test_loader: loaders yielding
            ``(imgs, targets)`` batches.
        supervised_crit: callable ``(out, targets, model) -> dict`` providing
            ``"total_loss"``.
        opt: optimizer stepped once per training batch.
        sch: learning-rate scheduler stepped once per epoch.
        cfg: configuration; ``save_dir``, ``epochs_supervised``,
            ``lambda_A_sup``, ``lambda_T_sup`` and ``save_frequency`` are read
            here.
        ood_loader: optional out-of-distribution loader forwarded to
            ``evaluate_metrics``.

    Returns:
        ``(model, history)`` where ``history`` maps metric names to per-epoch
        lists. The final test evaluation uses the weights of the last epoch.
    """
    os.makedirs(cfg.save_dir, exist_ok=True)

    history = {
        "train_loss": [], "train_acc": [],
        "val_loss":   [], "val_acc":   [],
        "purity":     [], "sparsity":  [], "fpr95": [],
        "lr": [], "epoch_times": []
    }
    best_val = 0.0

    print(f"supervised training for {cfg.epochs_supervised} epochs...")
    for ep in range(cfg.epochs_supervised):
        t0 = time.time()
        # -------------------- TRAIN --------------------
        model.train()
        tr_loss = tr_correct = tr_total = 0

        pbar = tqdm(train_loader,
                    desc=f"train {ep+1}/{cfg.epochs_supervised}",
                    ncols=120, leave=False)
        for imgs, targets in pbar:
            imgs, targets = imgs.to(device), targets.to(device)
            opt.zero_grad()
            out = model(imgs)

            # core PiP-Net losses (updated for B-cos BCE)
            losses = supervised_crit(out, targets, model)

            # lightweight contrastive regularization to keep prototype learning stable
            sm  = out["sim_maps"]
            sm2 = torch.roll(sm, shifts=1, dims=-1)               # shifted view for regularization
            # One call on `sm` supplies both outputs used below (previously it
            # was evaluated twice on the same tensor).
            z1, p = proto_presence_from_sim_maps(sm)
            z2, _ = proto_presence_from_sim_maps(sm2)
            LA_s  = loss_LA(z1, z2)
            LT_s  = loss_LT(p)

            extra = cfg.lambda_A_sup*LA_s + cfg.lambda_T_sup*LT_s
            loss  = losses["total_loss"] + extra

            loss.backward()
            opt.step()

            tr_loss   += loss.item()
            tr_correct += (out["logits"].argmax(1) == targets).sum().item()
            tr_total  += targets.size(0)
            pbar.set_postfix({"Loss": f"{loss.item():.4f}",
                              "Acc":  f"{100.0*tr_correct/tr_total:.2f}%"})

        # -------------------- VALIDATION --------------------
        model.eval()
        val_loss = val_correct = val_total = 0
        with torch.no_grad():
            for imgs, targets in val_loader:
                imgs, targets = imgs.to(device), targets.to(device)
                out   = model(imgs)
                loss_dict = supervised_crit(out, targets, model)
                val_loss += loss_dict["total_loss"].item()
                val_correct += (out["logits"].argmax(1) == targets).sum().item()
                val_total   += targets.size(0)

        # -------------------- PER-EPOCH METRICS --------------------
        metrics_val = evaluate_metrics(model, val_loader, cfg, dl_ood=ood_loader)
        print(f"val metrics: {metrics_val}")

        epoch_time = time.time() - t0
        train_loss_avg = tr_loss / len(train_loader)
        train_acc      = 100. * tr_correct / tr_total
        val_loss_avg   = val_loss / len(val_loader)
        val_acc        = 100. * val_correct / val_total

        # ---------- history ----------
        history["train_loss"].append(train_loss_avg)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss_avg)
        history["val_acc"].append(val_acc)
        history["purity"].append(metrics_val["Purity (%)"])
        history["sparsity"].append(metrics_val["Sparsity (avg #)"])
        history["fpr95"].append(metrics_val.get("FPR95 (%)", None))
        history["lr"].append(opt.param_groups[0]["lr"])
        history["epoch_times"].append(epoch_time)

        print(f"train {train_loss_avg:.4f}/{train_acc:.2f}% | "
              f"val {val_loss_avg:.4f}/{val_acc:.2f}% | "
              f"LR {opt.param_groups[0]['lr']:.3g} | "
              f"time {timedelta(seconds=int(epoch_time))}")

        # -------------------- CHECKPOINTS --------------------
        if val_acc > best_val:
            best_val = val_acc
            torch.save({
                "epoch": ep + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "val_acc": val_acc,
                # full settings snapshot, not just instance attributes
                "config": _config_to_dict(cfg),
                "history": history
            }, os.path.join(cfg.save_dir, "best_model.pth"))
            print(f"new best model saved! Val Acc: {val_acc:.2f}%")

        if (ep + 1) % cfg.save_frequency == 0:
            torch.save({
                "epoch": ep + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "history": history,
                "config": _config_to_dict(cfg),
            }, os.path.join(cfg.save_dir, f"checkpoint_epoch_{ep+1}.pth"))
            print(f"checkpoint saved at epoch {ep+1}")

        sch.step()

    # -------------------- TEST / OoD METRICS --------------------
    metrics_test = evaluate_metrics(model, test_loader, cfg, dl_ood=ood_loader)
    print("test / OoD metrics:", metrics_test)

    print(f"done. Best Val Acc: {best_val:.2f}%")
    return model, history

# Two-view transform with shared geometry, different color jitter
class TwoViewSharedGeom:
    """
    Produce two tensor views of one image that share the same geometric
    transform (resize, optional horizontal flip, optional rotation) but
    receive independently sampled color jitter of different strengths.

    Args:
        image_size: side length both views are resized to.
        max_rot: rotation angle is drawn uniformly from [-max_rot, max_rot].
        flip_p: probability of a horizontal flip.
        cj_s1: brightness/contrast/saturation strength for the first view.
        cj_s2: brightness/contrast/saturation strength for the second view.
    """

    def __init__(self, image_size, max_rot=10, flip_p=0.5,
                 cj_s1=0.2, cj_s2=0.4):
        self.image_size = image_size
        self.max_rot = max_rot
        self.flip_p = flip_p
        self.cj1 = T.ColorJitter(brightness=cj_s1, contrast=cj_s1, saturation=cj_s1, hue=0.05)
        self.cj2 = T.ColorJitter(brightness=cj_s2, contrast=cj_s2, saturation=cj_s2, hue=0.1)

    def _geom(self, img, flip, angle):
        """Resize, then apply the given flip / rotation decision."""
        img = TF.resize(img, [self.image_size, self.image_size], antialias=True)
        if flip:
            img = TF.hflip(img)
        # Skip the rotation for (near-)zero angles.
        if abs(angle) > 1e-3:
            img = TF.rotate(img, angle)
        return img

    def __call__(self, img):
        """Return ``(view1, view2)`` as tensors."""
        flip = random.random() < self.flip_p
        angle = random.uniform(-self.max_rot, self.max_rot)
        # The geometry is identical for both views, so compute it once and
        # feed the same result to both color jitters (each returns a new
        # image rather than modifying its input).
        shared = self._geom(img, flip, angle)
        i1 = self.cj1(shared)
        i2 = self.cj2(shared)
        return TF.to_tensor(i1), TF.to_tensor(i2)

def get_dataloaders(cfg: Config):
    """
    Build the supervised train/val/test loaders and the two-view loader used
    for contrastive pretraining.

    ``cfg.dataset`` equal to "CIFAR10" (case-insensitive) selects CIFAR-10;
    any other value selects CIFAR-100. ``cfg.num_classes`` is overwritten to
    match the chosen dataset.

    The training data is split 80/20 with a fixed seed. The validation subset
    uses the same indices as the split but reads from a dataset with the
    evaluation transform, so validation images are not randomly augmented.
    The pretraining loader covers the whole training set (including the
    images of the validation split).

    Returns:
        ``(train_loader, val_loader, test_loader, pre_loader)``
    """
    from torch.utils.data import Subset  # only needed for the validation view

    # Supervised train/val/test transforms
    train_tf = T.Compose([
        T.Resize((cfg.image_size, cfg.image_size)),
        T.RandomHorizontalFlip(0.5),
        T.RandomRotation(10),
        T.ColorJitter(0.2,0.2,0.2,0.1),
        T.ToTensor(),
    ])
    test_tf = T.Compose([T.Resize((cfg.image_size, cfg.image_size)), T.ToTensor()])

    if cfg.dataset.upper() == "CIFAR10":
        dataset_cls, cfg.num_classes = torchvision.datasets.CIFAR10, 10
    else:
        dataset_cls, cfg.num_classes = torchvision.datasets.CIFAR100, 100

    train_set = dataset_cls(cfg.data_root, train=True, download=True, transform=train_tf)
    test_set  = dataset_cls(cfg.data_root, train=False, download=True, transform=test_tf)
    # Same images as train_set, but with the deterministic eval transform.
    val_base  = dataset_cls(cfg.data_root, train=True, download=False, transform=test_tf)

    # Split
    n_train = int(0.8 * len(train_set))
    n_val = len(train_set) - n_train
    train_sub, val_split = random_split(train_set, [n_train, n_val],
                                        generator=torch.Generator().manual_seed(42))
    # Reuse the split's indices on the un-augmented dataset.
    val_sub = Subset(val_base, val_split.indices)

    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=True)
    train_loader = DataLoader(train_sub, shuffle=True, **loader_kwargs)
    val_loader   = DataLoader(val_sub,   shuffle=False, **loader_kwargs)
    test_loader  = DataLoader(test_set,  shuffle=False, **loader_kwargs)

    # Two-view dataloader for contrastive pretraining (same images, different transform)
    two_view = TwoViewSharedGeom(cfg.image_size)
    pre_set = dataset_cls(cfg.data_root, train=True, download=False, transform=two_view)
    pre_loader = DataLoader(pre_set, shuffle=True, drop_last=True, **loader_kwargs)
    return train_loader, val_loader, test_loader, pre_loader

#Main

In [None]:
def main():
    """
    Run the whole pipeline: build the loaders, set up model and optimisation
    objects, pretrain the prototypes contrastively, then train supervised.

    Returns:
        ``(model, history, cfg)``
    """
    cfg = Config()

    loaders = get_dataloaders(cfg)
    train_loader, val_loader, test_loader, pre_loader = loaders
    print(f"data ready. train:{len(train_loader.dataset)} Val:{len(val_loader.dataset)} Test:{len(test_loader.dataset)}")

    model, supervised_crit, contrastive_crit, opt, sch = setup(cfg)

    # Phase 1: contrastive pretraining of features + prototypes
    print("phase 1/2: contrastive pretraining (LA+LT)")
    pretrain_prototypes(model, pre_loader, contrastive_crit, cfg)

    # Phase 2: supervised training (no OoD loader is passed here)
    print("phase 2/2: supervised training with contrastive regularization")
    trained_model, history = train_supervised(
        model,
        train_loader,
        val_loader,
        test_loader,
        supervised_crit,
        opt,
        sch,
        cfg,
    )

    return trained_model, history, cfg

# Script entry point: run the full pipeline and keep its results as
# module-level names. Note that the configuration returned by main() is bound
# to `config` here (not `cfg`), and the data loaders stay local to main().
if __name__ == "__main__":
    model, history, config = main()

#load trained model

In [None]:
# Run Comprehensive Prototype Analysis

# If you have a trained model, run this comprehensive analysis.
#
# NOTE: the training entry point binds the configuration as `config` and keeps
# the data loaders local to main(), so `cfg` / `test_loader` are resolved
# defensively here instead of being assumed to exist. The analysis functions
# called below are defined further down in this cell; if they have not been
# defined yet, the NameError handler reports exactly which name is missing.
try:
    if 'model' in locals() and model is not None:
        # Prefer an existing `cfg`; otherwise fall back to `config`.
        analysis_cfg = cfg if 'cfg' in locals() else config
        if 'test_loader' in locals():
            analysis_loader = test_loader
        else:
            # Rebuild the (unshuffled) test loader from the configuration.
            _, _, analysis_loader, _ = get_dataloaders(analysis_cfg)

        print("🚀 Starting comprehensive prototype analysis...")
        run_comprehensive_prototype_analysis(model, analysis_cfg, analysis_loader)
        
        # Also demonstrate individual functions
        print(f"\n\n🎯 QUICK PROTOTYPE EXAMPLES")
        print("="*50)
        
        # Show how to analyze a specific prototype
        print("Example 1: Analyzing prototype 5...")
        visualize_prototype_analysis(model, analysis_cfg, analysis_loader, 5)
        
        print("Example 2: Showing patch grid for prototype 15...")
        visualize_prototype_patches_grid(model, analysis_cfg, analysis_loader, 15, num_patches=6)
        
    else:
        print("⚠️  No trained model loaded. Please run the training first or load a checkpoint.")
        print("💡 You can still examine the visualization functions!")
        
except NameError as err:
    # The model itself is guarded above, so a NameError here means some other
    # prerequisite (config, data-loader helper or an analysis function) is
    # not defined yet. Report the real cause instead of blaming the model.
    print(f"⚠️  Analysis prerequisites missing: {err}")
    print("💡 Define the model, config and visualization helpers first, then re-run.")
    print("💡 Example usage:")
    print("""
    # Load your model
    model = BcosPiPNet(...)
    checkpoint = torch.load('path/to/checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Run analysis
    run_comprehensive_prototype_analysis(model, cfg, test_loader)
    """)

# Additional Prototype Visualization Functions

def visualize_prototype_patches_grid(model, cfg, test_loader, prototype_idx, num_patches=9):
    """
    Show a grid of the top-activating images for a specific prototype.

    The entries come from ``extract_prototype_patches``; for each one the
    stored image is displayed with its similarity score as the title.

    Args:
        model: trained model, forwarded to ``extract_prototype_patches``.
        cfg: configuration; ``num_prototypes_per_class`` is used to derive the
            class index shown in the figure title.
        test_loader: loader the candidate images are drawn from.
        prototype_idx: index of the prototype to visualise.
        num_patches: maximum number of entries to show.

    Returns:
        The matplotlib figure, or ``None`` when nothing activated.
    """
    print(f"Finding top {num_patches} patches for prototype {prototype_idx}...")

    # Extract top patches for this prototype
    patches = extract_prototype_patches(model, test_loader, prototype_idx, num_patches)

    if not patches:
        print(f"No activating patches found for prototype {prototype_idx}")
        return None

    # Grid layout: fixed three columns, as many rows as needed
    cols = 3
    rows = (len(patches) + cols - 1) // cols

    # squeeze=False always yields a 2-D (rows, cols) array of axes, so no
    # special-casing of the single-row layout is needed.
    fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*4), squeeze=False)

    proto_class = prototype_idx // cfg.num_prototypes_per_class
    fig.suptitle(f'Prototype {prototype_idx} (Class {proto_class}) - Top Activating Patches', 
                 fontsize=16, fontweight='bold')

    # Row-major flattening keeps the original (idx // cols, idx % cols) order.
    flat_axes = axes.ravel()
    for ax, patch_info in zip(flat_axes, patches):
        ax.imshow(to_numpy_img(patch_info['image']))
        ax.set_title(f'Similarity: {patch_info["similarity"]:.3f}')
        ax.axis('off')

    # Hide unused subplots
    for ax in flat_axes[len(patches):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return fig

def visualize_class_prototypes_summary(model, cfg, test_loader, class_idx, max_prototypes=5):
    """
    Print a banner for one class and run the single-prototype analysis on the
    first prototypes belonging to it.

    Prototype indices are assumed to be laid out class by class, i.e. class
    ``c`` starts at ``c * cfg.num_prototypes_per_class``. At most
    ``max_prototypes`` (and never more than the per-class count) are shown.
    """
    print(f"\n" + "="*50)
    print(f"CLASS {class_idx} PROTOTYPE SUMMARY")
    print("="*50)

    per_class = cfg.num_prototypes_per_class
    first_proto = class_idx * per_class
    shown = min(max_prototypes, per_class)

    # `rank` is the 1-based position of the prototype within its class
    for rank, proto_idx in enumerate(range(first_proto, first_proto + shown), start=1):
        print(f"\n--- Prototype {proto_idx} (Class {class_idx}, Proto {rank}/{cfg.num_prototypes_per_class}) ---")

        # Delegate to the single-prototype visualisation
        visualize_prototype_analysis(model, cfg, test_loader, proto_idx)

def analyze_prototype_diversity(model, cfg):
    """
    Print summary statistics of the learned prototypes, plot the distribution
    of their pairwise cosine similarities and L2 norms, and return the key
    numbers.

    Returns:
        dict with ``similarity_matrix`` ([P, P] cosine similarities),
        ``mean_similarity`` (mean over off-diagonal entries), ``mean_norm``
        and ``diversity_score`` (``1 - mean_similarity``).
    """
    prototypes = model.prototype_layer.prototypes.data  # [P, C, H, W]
    P, C, H, W = prototypes.shape

    print(f"\n" + "="*50)
    print("PROTOTYPE DIVERSITY ANALYSIS")
    print("="*50)
    print(f"Total prototypes: {P}")
    print(f"Prototype shape: [{C}, {H}, {W}]")
    print(f"Classes: {cfg.num_classes}")
    print(f"Prototypes per class: {cfg.num_prototypes_per_class}")

    # One row per prototype: [P, C*H*W]
    proto_flat = prototypes.view(P, -1)

    # Per-prototype statistics, each of shape [P]
    proto_means = proto_flat.mean(dim=1)
    proto_stds = proto_flat.std(dim=1)
    proto_norms = proto_flat.norm(dim=1)

    # Cosine similarity between every pair of prototypes: [P, P]
    unit_protos = F.normalize(proto_flat, dim=1)
    similarity_matrix = unit_protos @ unit_protos.t()

    # Keep everything except the diagonal (self-similarity)
    diagonal = torch.eye(P).bool()
    off_diagonal_sims = similarity_matrix[~diagonal]

    print(f"\nPrototype Statistics:")
    print(f"  Mean activation: {proto_means.mean():.4f} ± {proto_means.std():.4f}")
    print(f"  Mean std: {proto_stds.mean():.4f} ± {proto_stds.std():.4f}")
    print(f"  Mean norm: {proto_norms.mean():.4f} ± {proto_norms.std():.4f}")

    print(f"\nPrototype Diversity:")
    print(f"  Mean pairwise similarity: {off_diagonal_sims.mean():.4f}")
    print(f"  Max pairwise similarity: {off_diagonal_sims.max():.4f}")
    print(f"  Min pairwise similarity: {off_diagonal_sims.min():.4f}")

    def _hist_panel(ax, values, xlabel, title, **hist_kwargs):
        # Histogram of `values` with a dashed red line marking their mean.
        ax.hist(values.cpu().numpy(), bins=50, alpha=0.7, edgecolor='black', **hist_kwargs)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')
        ax.set_title(title)
        ax.axvline(values.mean().item(), color='red', linestyle='--',
                   label=f'Mean: {values.mean():.3f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    _hist_panel(ax1, off_diagonal_sims,
                'Pairwise Cosine Similarity', 'Distribution of Prototype Similarities')
    _hist_panel(ax2, proto_norms,
                'Prototype L2 Norm', 'Distribution of Prototype Norms', color='orange')

    plt.tight_layout()
    plt.show()

    mean_similarity = off_diagonal_sims.mean().item()
    return {
        'similarity_matrix': similarity_matrix,
        'mean_similarity': mean_similarity,
        'mean_norm': proto_norms.mean().item(),
        'diversity_score': 1 - mean_similarity  # Higher = more diverse
    }

# Enhanced demonstration
def run_comprehensive_prototype_analysis(model, cfg, test_loader):
    """
    Run the full prototype report: diversity statistics, per-class prototype
    summaries for the first classes, and patch grids for a few prototypes.
    """
    print("🔍 COMPREHENSIVE PROTOTYPE ANALYSIS")
    print("="*80)

    # Step 1: global diversity statistics
    diversity_stats = analyze_prototype_diversity(model, cfg)
    print(f"\n📊 Diversity Score: {diversity_stats['diversity_score']:.4f} (higher = more diverse)")

    # Step 2: summaries for at most the first two classes, two prototypes each
    n_summary_classes = min(2, cfg.num_classes)
    for class_idx in range(n_summary_classes):
        visualize_class_prototypes_summary(model, cfg, test_loader, class_idx, max_prototypes=2)

    # Step 3: patch grids
    print(f"\n🔍 DETAILED PATCH ANALYSIS")
    print("="*50)

    # First prototype of each of the first (up to three) classes ...
    interesting_prototypes = [
        class_idx * cfg.num_prototypes_per_class
        for class_idx in range(min(3, cfg.num_classes))
    ]
    # ... of which only the first two are plotted to limit the output
    for proto_idx in interesting_prototypes[:2]:
        visualize_prototype_patches_grid(model, cfg, test_loader, proto_idx, num_patches=6)

    print("\n✅ Analysis complete!")

In [None]:
def grad_to_img_bcos(img_6ch, linear_mapping, smooth=3, alpha_percentile=99.5):
    """
    Computing color image from dynamic linear mapping of B-cos models.

    If ``interpretability.utils.grad_to_img`` can be imported (the local
    ``./B-cos`` directory is added to ``sys.path`` for this), it is used
    directly. Otherwise a built-in fallback with the same steps is used.

    Args:
        img_6ch: Original 6-channel input image [r,g,b,1-r,1-g,1-b] - shape (6,H,W)
        linear_mapping: linear mapping from B-cos model - shape (6,H,W)
        smooth: kernel size for smoothing the alpha values (values <= 1 or
            ``None`` disable smoothing in the fallback)
        alpha_percentile: cut-off percentile for the alpha value

    Returns:
        RGBA image explanation of the B-cos model - shape (H,W,4)
    """
    # Register the repository path only once; appending on every call would
    # let sys.path grow without bound.
    bcos_repo = './B-cos'
    if bcos_repo not in sys.path:
        sys.path.append(bcos_repo)

    # Keep the try body minimal: only a failing import selects the fallback.
    try:
        from interpretability.utils import grad_to_img
    except ImportError:
        grad_to_img = None

    if grad_to_img is not None:
        # Use the official B-cos grad_to_img function
        return grad_to_img(img_6ch, linear_mapping, smooth=smooth, alpha_percentile=alpha_percentile)

    # Fallback implementation if B-cos utils not available
    print("Warning: Using fallback B-cos explanation (B-cos utils not found)")

    # Ensure tensors are on CPU and detached
    if hasattr(img_6ch, 'detach'):
        img_6ch = img_6ch.detach().cpu()
    if hasattr(linear_mapping, 'detach'):
        linear_mapping = linear_mapping.detach().cpu()

    # Batched input: use the first element so shapes are [6, H, W]
    if len(img_6ch.shape) == 4:
        img_6ch = img_6ch[0]
    if len(linear_mapping.shape) == 4:
        linear_mapping = linear_mapping[0]

    # Contribution per location: sum of input * weight over channels -> [H, W]
    contribs = (img_6ch * linear_mapping).sum(0)

    # Normalize each pixel vector s.t. max entry is 1, maintaining direction
    rgb_grad = linear_mapping / (linear_mapping.abs().max(0, keepdim=True)[0] + 1e-12)

    # Clip off values below 0 (set negatively weighted channels to 0)
    rgb_grad = rgb_grad.clamp(0)

    # Normalize s.t. each pair (e.g., r and 1-r) sums to 1, use only RGB values
    rgb_grad = rgb_grad[:3] / (rgb_grad[:3] + rgb_grad[3:] + 1e-12)

    # Alpha is the strength (L2 norm) of each location's weight vector: [1, H, W]
    alpha = linear_mapping.norm(p=2, dim=0, keepdim=True)

    # Only show positive contributions
    alpha = torch.where(contribs[None] < 0, torch.zeros_like(alpha) + 1e-12, alpha)

    # Smooth alpha with a box filter. Zero-padding k-1 cells in total
    # ((k-1)//2 before, k//2 after) keeps the output at H x W for odd and
    # even kernel sizes; for odd k this is the same symmetric padding as
    # avg_pool2d(padding=(k-1)//2).
    if smooth and smooth > 1:
        lead, trail = (smooth - 1) // 2, smooth // 2
        alpha = F.pad(alpha, (lead, trail, lead, trail))
        alpha = F.avg_pool2d(alpha, smooth, stride=1)

    # Convert to numpy and normalize alpha by the chosen percentile
    alpha = alpha.numpy()
    scale = np.percentile(alpha, alpha_percentile)
    if scale <= 0:
        # alpha is non-negative, so this only happens for an all-zero map;
        # avoid 0/0 and keep the result fully transparent.
        scale = 1e-12
    alpha = (alpha / scale).clip(0, 1)

    # Combine RGB + Alpha: shape [4, H, W], then reorder to [H, W, 4]
    rgba_grad = np.concatenate([rgb_grad.numpy(), alpha], axis=0)
    grad_image = rgba_grad.transpose((1, 2, 0))

    return grad_image