# B-cos PIP-Net Comprehensive Visualizations

This notebook provides comprehensive visualizations for B-cos PIP-Net, summarized below.

## 🔍 What we visualize:
1. **Input Images**: Original RGB images from 6-channel input
2. **Learned Part Prototypes**: Spatial activation maps showing where prototypes activate
3. **B-cos Spatial Contributions**: Gradient-based explanation maps showing spatial contributions

## 🎯 Key Features:
- **Interactive Exploration**: Examine individual images and their prototype activations
- **Spatial Understanding**: See where in the image each prototype activates most strongly
- **Gradient-based Explanations**: B-cos contribution maps show positive/negative evidence
- **Comprehensive Views**: Combined visualizations for complete interpretability

## 📊 Visualization Types:
- Single image comprehensive view
- Prototype activation heatmaps
- B-cos spatial contribution maps
- Multi-image comparison grids

## 1. Setup and Installation

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install required packages (torch/torchvision are usually preinstalled on Colab)
!pip install torch torchvision
!pip install matplotlib seaborn
!pip install opencv-python
!pip install pillow numpy
!pip install tqdm

In [None]:
# Clone the repository (if not already done)
# NOTE: 'your-username' is a placeholder; replace it with the actual repository owner.
import os
if not os.path.exists('improved-Bcos-PIPNet'):
    !git clone https://github.com/your-username/improved-Bcos-PIPNet.git
%cd improved-Bcos-PIPNet

# Initialize submodules
!git submodule update --init --recursive

In [None]:
import sys
import os

# Add source directories to Python path
sys.path.append('src')
sys.path.append('B-cos')
sys.path.append('PIPNet')

# Verify visualization module exists
print("Current directory:", os.getcwd())
print("Visualization module exists:", os.path.exists('src/visualizations.py'))

## 2. Upload Fine-tuned Model

Upload your fine-tuned B-cos PIP-Net model from the previous fine-tuning step.
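For reference, the loading cell in Section 3 reads a checkpoint dictionary with the keys `num_classes`, `class_names`, `training_args` (containing `pretrained_path`) and `model_state_dict`, and falls back to defaults when a key is missing. The hypothetical helper `summarize_checkpoint` below mirrors those lookups on a plain dictionary, so an uploaded file (for example the result of `torch.load(model_path, map_location='cpu')`) can be checked before the model is built. It is a sketch that assumes the layout implied by Section 3, not a description of how the fine-tuning script actually saves checkpoints.

```python
from typing import Any, Dict, List


def summarize_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return the metadata the Section 3 loading cell reads, with the same fallbacks.

    Hypothetical helper: the keys and defaults mirror the notebook's loading code.
    """
    num_classes = checkpoint.get('num_classes', 10)
    class_names: List[str] = checkpoint.get(
        'class_names', [f'Class_{i}' for i in range(num_classes)]
    )
    training_args = checkpoint.get('training_args', {})
    pretrained_path = training_args.get(
        'pretrained_path', './checkpoints/pretrained_model.pth'
    )
    return {
        'num_classes': num_classes,
        'class_names': class_names,
        'pretrained_path': pretrained_path,
        # Without this key, model.load_state_dict(...) in Section 3 cannot run.
        'has_state_dict': 'model_state_dict' in checkpoint,
    }


# A plain dict standing in for a loaded checkpoint
example = {'num_classes': 3, 'model_state_dict': {}}
print(summarize_checkpoint(example))
```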

In [None]:
from google.colab import files
import os

# Create model directory
os.makedirs('./models', exist_ok=True)

print("Please upload your fine-tuned B-cos PIP-Net model (.pth file):")
uploaded = files.upload()

# Move uploaded file to models directory
model_path = None
for filename in uploaded.keys():
    if filename.endswith('.pth'):
        os.rename(filename, f'./models/{filename}')
        model_path = f'./models/{filename}'
        print(f"Model saved to: {model_path}")
        break

if model_path is None:
    # Use default path for testing
    model_path = './models/finetuned_model_final.pth'
    print(f"No model uploaded. Will use: {model_path}")
    print("Note: Make sure to upload your fine-tuned model for actual visualization.")

## 3. Import Modules and Load Model

In [None]:
# Import visualization modules
from src.visualizations import BcosPIPNetVisualizer, create_prototype_comparison_grid
from src.finetune_classifier import create_scoring_sheet_classifier
from src.datasets import SixChannelDataset

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image
from tqdm.auto import tqdm

# Set matplotlib backend and style
plt.style.use('default')
%matplotlib inline

print("All modules imported successfully!")

In [None]:
# Load the fine-tuned model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if os.path.exists(model_path):
    print(f"Loading model from: {model_path}")
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Extract model info
    num_classes = checkpoint.get('num_classes', 10)
    class_names = checkpoint.get('class_names', [f'Class_{i}' for i in range(num_classes)])
    
    print(f"Model info:")
    print(f"  Number of classes: {num_classes}")
    print(f"  Classes: {class_names}")
    
    # Get pretrained model path from training args
    training_args = checkpoint.get('training_args', {})
    pretrained_path = training_args.get('pretrained_path', './checkpoints/pretrained_model.pth')
    
    print(f"Looking for pretrained model at: {pretrained_path}")
    
    # Create model architecture
    try:
        model = create_scoring_sheet_classifier(
            pretrained_path=pretrained_path,
            num_classes=num_classes,
            freeze_prototypes=True
        )
        
        # Load fine-tuned weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        model.eval()
        
        print(f"✓ Model loaded successfully!")
        print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Number of prototypes: {model.num_prototypes}")
        
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        print("Please ensure both the fine-tuned model and original pre-trained model are available.")
        model = None

else:
    print(f"Model file not found: {model_path}")
    print("Please upload your fine-tuned model first.")
    model = None

## 4. Dataset Setup and Sample Loading

In [None]:
# Setup dataset for visualization
print("Setting up CIFAR-10 dataset for visualization...")

# Test transforms (no augmentation for visualization).
# The normalization must match the one used when the model was trained.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])

# Load CIFAR-10 test set
test_dataset_base = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)

# Wrap with 6-channel transformation
test_dataset = SixChannelDataset(test_dataset_base)

# Create dataloader (seeded so the same samples are drawn on every run)
test_loader = DataLoader(
    test_dataset, batch_size=1, shuffle=True, num_workers=2,
    generator=torch.Generator().manual_seed(0)
)

# CIFAR-10 class names
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {cifar10_classes}")

# Get some sample images for visualization
sample_images = []
sample_labels = []
sample_targets = []

for i, (img_6ch, label) in enumerate(test_loader):
    sample_images.append(img_6ch)
    sample_labels.append(label.item())
    sample_targets.append(cifar10_classes[label.item()])
    
    if len(sample_images) >= 10:  # Get 10 samples
        break

print(f"Loaded {len(sample_images)} sample images for visualization")
print(f"Sample classes: {sample_targets[:5]}...")
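Because `test_loader` shuffles, the ten samples may repeat some classes and miss others. As an alternative, a minimal sketch assuming integer labels: the hypothetical helper `first_index_per_class` returns the first dataset index of each class, so the visualizations could cover all ten CIFAR-10 classes.

```python
from typing import Dict, List, Sequence


def first_index_per_class(labels: Sequence[int], num_classes: int) -> List[int]:
    """Return the first index of each class label 0..num_classes-1, ordered by class id.

    Classes that never occur are skipped.
    """
    found: Dict[int, int] = {}
    for idx, label in enumerate(labels):
        if 0 <= label < num_classes and label not in found:
            found[label] = idx
            if len(found) == num_classes:
                break  # every class covered, stop scanning
    return [found[c] for c in sorted(found)]


print(first_index_per_class([3, 1, 3, 0, 2, 1], num_classes=4))
```

In this notebook it could be called as `first_index_per_class(test_dataset_base.targets, 10)` (torchvision's `CIFAR10` exposes its labels as `.targets`), after which each image could be fetched with `test_dataset[idx][0].unsqueeze(0)`. That last step assumes `SixChannelDataset` returns `(image, label)` pairs like the dataset it wraps.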

## 5. Create Visualizer and Test Basic Functionality

In [None]:
# Create visualizer
if model is not None:
    visualizer = BcosPIPNetVisualizer(model, device=device)
    print("✓ Visualizer created successfully!")
    
    # Test basic functionality
    if sample_images:
        test_image = sample_images[0]
        print(f"Test image shape: {test_image.shape}")
        
        # Test input image visualization
        print("\nTesting input image visualization...")
        img_np = visualizer.visualize_input_image(test_image, dataset='cifar10', 
                                                 title=f"Test Image: {sample_targets[0]}")
        plt.show()
        
        # Test prototype activation extraction
        print("\nTesting prototype activation extraction...")
        proto_features, pooled_features, locations = visualizer.get_prototype_activations(test_image)
        print(f"Prototype features shape: {proto_features.shape}")
        print(f"Pooled features shape: {pooled_features.shape}")
        print(f"Number of active prototypes (>0.1): {(pooled_features > 0.1).sum().item()}")
        
        print("✓ Basic functionality test passed!")
    else:
        print("No sample images available for testing")
else:
    print("Cannot create visualizer - model not loaded")
    visualizer = None

## 6. Visualization 1: Input Image Display

Let's start with basic input image visualization, showing how we extract RGB from the 6-channel input.
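The summary at the end of this notebook describes the input as `[r, g, b, 1-r, 1-g, 1-b]`. A minimal NumPy sketch of that encoding and of the reverse steps used for display follows. `add_inverse`, `extract_rgb`, `denormalize` and `to_hwc` are illustrative names, not the `BcosPIPNetVisualizer` methods, and the mean/std are the CIFAR-10 values from Section 4. Whether `SixChannelDataset` computes the inverse channels before or after `Normalize` is not shown here, so check that before relying on `denormalize`.

```python
import numpy as np

# CIFAR-10 normalization constants used in Section 4, shaped for (C, H, W) broadcasting
CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
CIFAR10_STD = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)


def add_inverse(rgb: np.ndarray) -> np.ndarray:
    """Stack [r, g, b, 1-r, 1-g, 1-b] along the channel axis of a (3, H, W) image."""
    return np.concatenate([rgb, 1.0 - rgb], axis=0)


def extract_rgb(six: np.ndarray) -> np.ndarray:
    """Recover the RGB image: the first three of the six channels."""
    return six[:3]


def denormalize(rgb_norm: np.ndarray) -> np.ndarray:
    """Undo transforms.Normalize and clip to [0, 1] so imshow does not warn."""
    return np.clip(rgb_norm * CIFAR10_STD + CIFAR10_MEAN, 0.0, 1.0)


def to_hwc(chw: np.ndarray) -> np.ndarray:
    """Convert (C, H, W) to the (H, W, C) layout expected by plt.imshow."""
    return np.transpose(chw, (1, 2, 0))


# Example: encode a random 32x32 "image" and recover its RGB part
rgb = np.random.default_rng(0).random((3, 32, 32))
six = add_inverse(rgb)
print(six.shape, to_hwc(extract_rgb(six)).shape)
```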

In [None]:
# Visualize several input images
if visualizer and sample_images:
    print("Input Image Visualization")
    print("=" * 50)
    
    # Show first 6 sample images
    n_samples = min(6, len(sample_images))
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
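    for i in range(n_samples):
        # Extract RGB from the 6-channel input and undo the normalization
        rgb_tensor = visualizer.extract_rgb_from_6channel(sample_images[i])
        img_denorm = visualizer.denormalize_image(rgb_tensor, 'cifar10')
        img_np = img_denorm.squeeze(0).permute(1, 2, 0).cpu().numpy()
        
        axes[i].imshow(np.clip(img_np, 0, 1))  # clip float round-off so imshow does not warn
        axes[i].set_title(f'{sample_targets[i]}', fontsize=12, fontweight='bold')
        axes[i].axis('off')
    
    # Hide any unused subplots when fewer than 6 samples are available
    for ax in axes[n_samples:]:
        ax.axis('off')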
    
    plt.suptitle('Sample Input Images (RGB extracted from 6-channel)', fontsize=16)
    plt.tight_layout()
    plt.show()
    
    # Show 6-channel breakdown for one image
    print("\n6-Channel Breakdown Example:")
    test_img = sample_images[0]
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    
    # Original RGB channels
    rgb_channels = test_img[0, :3].cpu().numpy()
    channel_names = ['Red', 'Green', 'Blue']
    
    for i in range(3):
        axes[0, i].imshow(rgb_channels[i], cmap='gray')
        axes[0, i].set_title(f'{channel_names[i]} Channel')
        axes[0, i].axis('off')
    
    # Inverse channels
    inv_channels = test_img[0, 3:].cpu().numpy()
    inv_names = ['1-Red', '1-Green', '1-Blue']
    
    for i in range(3):
        axes[1, i].imshow(inv_channels[i], cmap='gray')
        axes[1, i].set_title(f'{inv_names[i]} Channel')
        axes[1, i].axis('off')
    
    plt.suptitle(f'6-Channel Breakdown: {sample_targets[0]}', fontsize=16)
    plt.tight_layout()
    plt.show()
else:
    print("Visualizer or sample images not available")

## 7. Visualization 2: Learned Part Prototypes

Now let's visualize the learned prototypes, showing where in the image each prototype activates most strongly.
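In the original PIP-Net, each spatial location's prototype scores come from a softmax across prototypes, and a prototype's presence score is the maximum of its map over all locations. The NumPy sketch below shows that pooling, how the peak location is found, how the top-k prototypes above the 0.1 threshold used in this notebook could be selected, and a naive mapping from a feature-map cell to an image patch. The `(num_prototypes, H, W)` shape and the integer-stride mapping (which ignores the true receptive field) are assumptions for illustration; this is not the `visualize_prototype_activations` implementation.

```python
from typing import Tuple

import numpy as np


def softmax_over_prototypes(logits: np.ndarray) -> np.ndarray:
    """Softmax across the prototype axis (axis 0) at every spatial location."""
    shifted = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=0, keepdims=True)


def pool_and_locate(proto_maps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Global max-pool each (H, W) map and return each map's peak (row, col)."""
    num_protos, h, w = proto_maps.shape
    flat = proto_maps.reshape(num_protos, h * w)
    pooled = flat.max(axis=1)
    rows, cols = np.unravel_index(flat.argmax(axis=1), (h, w))
    return pooled, np.stack([rows, cols], axis=1)


def top_k_prototypes(pooled: np.ndarray, k: int, threshold: float = 0.1):
    """Indices and scores of the k strongest prototypes whose score exceeds threshold."""
    order = np.argsort(pooled)[::-1][:k]  # descending by score
    keep = order[pooled[order] > threshold]
    return keep, pooled[keep]


def peak_to_pixel_box(row: int, col: int, feat_hw, img_hw):
    """Map a feature-map cell to the image patch (y0, x0, y1, x1) it naively covers."""
    sy, sx = img_hw[0] // feat_hw[0], img_hw[1] // feat_hw[1]
    return row * sy, col * sx, (row + 1) * sy, (col + 1) * sx


# Example on random prototype logits: 8 prototypes on a 4x4 feature map
maps = softmax_over_prototypes(np.random.default_rng(0).normal(size=(8, 4, 4)))
pooled, peaks = pool_and_locate(maps)
idx, scores = top_k_prototypes(pooled, k=3)
print(idx, np.round(scores, 3), peaks[idx])
```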

In [None]:
# Visualize prototype activations for individual images
if visualizer and sample_images:
    print("Learned Part Prototypes Visualization")
    print("=" * 50)
    
    # Show prototype activations for first few images
    n_examples = min(3, len(sample_images))
    
    for i in range(n_examples):
        print(f"\nExample {i+1}: {sample_targets[i]}")
        print("-" * 30)
        
        # Visualize top prototypes for this image
        top_prototypes, prototype_scores = visualizer.visualize_prototype_activations(
            sample_images[i], 
            top_k=9, 
            dataset='cifar10',
            threshold=0.1,
            figsize=(15, 10)
        )
        
        plt.show()
        
        # Print prototype statistics
        if len(top_prototypes) > 0:
            print(f"Top active prototypes:")
            for j, (proto_idx, score) in enumerate(zip(top_prototypes, prototype_scores)):
                print(f"  {j+1}. Prototype {proto_idx}: {score:.3f}")
        else:
            print("No prototypes above threshold found")
        
        print()
else:
    print("Visualizer or sample images not available")

In [None]:
# Create prototype comparison grid across different images
if visualizer and len(sample_images) >= 4:
    print("\nPrototype Comparison Across Images")
    print("=" * 50)
    
    # Select 4 diverse images
    comparison_images = sample_images[:4]
    comparison_titles = [f"{sample_targets[i]}" for i in range(4)]
    
    fig = create_prototype_comparison_grid(
        visualizer, 
        comparison_images, 
        titles=comparison_titles,
        dataset='cifar10',
        figsize=(16, 12)
    )
    
    plt.show()
    
    print("Each row shows an image and its top 3 prototype activations.")
    print("Notice how different prototypes activate for different classes and regions.")
else:
    print("Need at least 4 sample images for comparison grid")

## 8. Visualization 3: B-cos Spatial Contribution Maps

Now let's visualize the B-cos explanations: spatial contribution maps that show which parts of the image contribute positively or negatively to the prediction.
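B-cos networks are designed to be dynamic linear: an output can be written as `W(x)^T x`, so the contribution of each input element is `W(x) * x` (elementwise). Summing that product over channels gives a per-pixel map whose total equals the class logit when there is no bias term. The sketch below illustrates this with a fixed linear weight standing in for `W(x)`; `contribution_map`, `split_evidence` and `symmetric_limits` are hypothetical names, and the notebook's visualizer may compute its maps differently (for example from gradients).

```python
from typing import Tuple

import numpy as np


def contribution_map(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Per-pixel contributions: sum over channels of w * x, shape (H, W).

    x, w: arrays of shape (C, H, W). If the logit is exactly (w * x).sum()
    (dynamic-linear model, no bias), the map sums to the logit.
    """
    return (w * x).sum(axis=0)


def split_evidence(cmap: np.ndarray) -> Tuple[float, float]:
    """Total positive and total negative evidence in a contribution map."""
    return float(cmap[cmap > 0].sum()), float(cmap[cmap < 0].sum())


def symmetric_limits(cmap: np.ndarray) -> Tuple[float, float]:
    """(vmin, vmax) centred on zero, for a diverging colormap such as 'bwr'."""
    m = float(np.abs(cmap).max()) or 1.0  # avoid a degenerate (0, 0) range
    return -m, m


# Toy check with a fixed linear model: logit = sum(w * x) over a 6-channel input
rng = np.random.default_rng(1)
x = rng.random((6, 8, 8))
w = rng.normal(size=(6, 8, 8))
logit = float((w * x).sum())
cmap = contribution_map(x, w)
pos, neg = split_evidence(cmap)
print(f"logit={logit:.3f}  map sum={cmap.sum():.3f}  (+{pos:.3f}, {neg:.3f})")
```

For display, `plt.imshow(cmap, cmap='bwr', vmin=lo, vmax=hi)` with `lo, hi = symmetric_limits(cmap)` keeps zero at the white center of the colormap, so red and blue read as positive and negative evidence.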

In [None]:
# Visualize B-cos spatial contributions
if visualizer and sample_images:
    print("B-cos Spatial Contribution Maps")
    print("=" * 50)
    print("Red regions: Positive contribution to prediction")
    print("Blue regions: Negative contribution to prediction")
    print()
    
    # Show B-cos contributions for first few images
    n_examples = min(3, len(sample_images))
    
    for i in range(n_examples):
        print(f"\nExample {i+1}: {sample_targets[i]}")
        print("-" * 30)
        
        # Get B-cos contributions
        contributions, pred_class, class_score = visualizer.visualize_bcos_contributions(
            sample_images[i],
            target_class=None,  # Use predicted class
            dataset='cifar10',
            figsize=(20, 12)
        )
        
        plt.show()
        
        # Print prediction info
        pred_name = cifar10_classes[pred_class] if pred_class < len(cifar10_classes) else f"Class {pred_class}"
        true_name = sample_targets[i]
        
        print(f"True class: {true_name}")
        print(f"Predicted class: {pred_name}")
        print(f"Class score: {class_score:.3f}")
        print(f"Correct: {'✓' if pred_name == true_name else '✗'}")
        
        if contributions:
            print(f"B-cos layers analyzed: {len(contributions)}")
            for layer_name in contributions.keys():
                contrib_map = contributions[layer_name]
                pos_contrib = np.sum(contrib_map[contrib_map > 0])
                neg_contrib = np.sum(contrib_map[contrib_map < 0])
                print(f"  {layer_name}: +{pos_contrib:.2f}, {neg_contrib:.2f}")
        
        print()
else:
    print("Visualizer or sample images not available")

## 9. Comprehensive Visualization: All-in-One View

Let's create comprehensive visualizations that show input, prototypes, and B-cos contributions all together.
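The combined views overlay low-resolution maps (prototype activations or contributions) on the 32×32 input. A minimal NumPy sketch of one way to do this, assuming the map size divides the image size evenly: nearest-neighbor upsampling with `np.repeat`, min-max normalization, and alpha blending into the red channel. The function names are hypothetical; in PyTorch, `torch.nn.functional.interpolate` handles arbitrary output sizes.

```python
from typing import Tuple

import numpy as np


def upsample_nearest(heat: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
    """Nearest-neighbor upsampling of an (h, w) map; requires integer scale factors."""
    fy, fx = out_hw[0] // heat.shape[0], out_hw[1] // heat.shape[1]
    if fy * heat.shape[0] != out_hw[0] or fx * heat.shape[1] != out_hw[1]:
        raise ValueError("output size must be an integer multiple of the map size")
    return np.repeat(np.repeat(heat, fy, axis=0), fx, axis=1)


def overlay(img_hwc: np.ndarray, heat: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a float RGB image in [0, 1] with a min-max normalized heatmap drawn in red."""
    h = upsample_nearest(heat, img_hwc.shape[:2])
    span = h.max() - h.min()
    h = (h - h.min()) / span if span > 0 else np.zeros_like(h, dtype=float)
    red = np.zeros_like(img_hwc)
    red[..., 0] = h
    return np.clip((1 - alpha) * img_hwc + alpha * red, 0.0, 1.0)


# Example: a gray 32x32 image and a 4x4 map that increases toward the bottom-right
img = np.full((32, 32, 3), 0.5)
heat = np.arange(16, dtype=float).reshape(4, 4)
blended = overlay(img, heat, alpha=0.4)
print(blended.shape, blended.min(), blended.max())
```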

In [None]:
# Create comprehensive visualizations
if visualizer and sample_images:
    print("Comprehensive B-cos PIP-Net Visualizations")
    print("=" * 50)
    print("This view combines:")
    print("1. Input image and prediction")
    print("2. Top prototype activations")
    print("3. B-cos spatial contribution map")
    print()
    
    # Show comprehensive view for first few images
    n_examples = min(3, len(sample_images))
    
    for i in range(n_examples):
        print(f"\n{'='*20} Example {i+1}: {sample_targets[i]} {'='*20}")
        
        # Create comprehensive visualization
        results = visualizer.create_comprehensive_visualization(
            sample_images[i],
            target_class=None,
            dataset='cifar10',
            class_names=cifar10_classes,
            top_k_prototypes=6,
            figsize=(20, 15)
        )
        
        plt.show()
        
        # Print detailed results
        print(f"\nDetailed Analysis:")
        print(f"  True class: {sample_targets[i]}")
        print(f"  Predicted: {results['prediction']}")
        print(f"  Class score: {results['confidence']:.3f}")
        print(f"  Correct: {'✓' if results['prediction'] == sample_targets[i] else '✗'}")
        
        if results['top_prototypes'] is not None and len(results['top_prototypes']) > 0:
            print(f"\n  Top Contributing Prototypes:")
            for j, (proto_idx, score) in enumerate(zip(results['top_prototypes'], results['prototype_scores'])):
                print(f"    {j+1}. Prototype {proto_idx}: {score:.3f}")
        
        if results['contributions']:
            print(f"\n  B-cos Contribution Analysis:")
            for layer_name, contrib_map in results['contributions'].items():
                pos_contrib = np.sum(contrib_map[contrib_map > 0])
                neg_contrib = np.sum(contrib_map[contrib_map < 0])
                net_contrib = pos_contrib + neg_contrib
                print(f"    {layer_name}: net={net_contrib:.2f} (+{pos_contrib:.2f}, {neg_contrib:.2f})")
        
        print("\n" + "-"*60)
else:
    print("Visualizer or sample images not available")

## 10. Interactive Exploration

Let's create an interactive way to explore different images and their explanations.
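When exploring, it often helps to jump straight to the misclassified images. The hypothetical helper `select_indices` below picks image indices by comparing true and predicted class names. Once the Section 11 cell has filled `all_predictions`, a loop such as `for idx in select_indices(sample_targets, all_predictions): explore_image(idx)` would step through the errors.

```python
from typing import List, Sequence


def select_indices(true_names: Sequence[str], pred_names: Sequence[str],
                   mode: str = 'misclassified') -> List[int]:
    """Return image indices to explore: 'misclassified', 'correct', or 'all'."""
    if mode not in {'misclassified', 'correct', 'all'}:
        raise ValueError(f"unknown mode: {mode!r}")
    pairs = list(zip(true_names, pred_names))
    if mode == 'all':
        return list(range(len(pairs)))
    want_correct = mode == 'correct'
    # Keep an index when "prediction matches truth" equals the requested outcome
    return [i for i, (t, p) in enumerate(pairs) if (t == p) == want_correct]


print(select_indices(['cat', 'dog', 'ship'], ['cat', 'cat', 'ship']))
```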

In [None]:
# Interactive exploration function
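def explore_image(image_idx, show_prototypes=True, show_bcos=True, top_k=6):
    """
    Explore one sample image with customizable visualization options.

    image_idx: index into sample_images
    show_prototypes: draw the top-k prototype activations
    show_bcos: draw the B-cos spatial contribution maps
    top_k: number of prototypes to display
    """
    if not visualizer or not sample_images:
        print("Visualizer or sample images not available")
        return
    
    # Reject negative indices too; Python would otherwise silently index from the end
    if not 0 <= image_idx < len(sample_images):
        print(f"Image index {image_idx} out of range. Valid range: 0-{len(sample_images)-1}")
        return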
    
    image_6ch = sample_images[image_idx]
    true_class = sample_targets[image_idx]
    
    print(f"Exploring Image {image_idx}: {true_class}")
    print("=" * 50)
    
    if show_prototypes and show_bcos:
        # Comprehensive view
        results = visualizer.create_comprehensive_visualization(
            image_6ch, dataset='cifar10', class_names=cifar10_classes,
            top_k_prototypes=top_k, figsize=(20, 15)
        )
        
    elif show_prototypes:
        # Prototype view only
        visualizer.visualize_input_image(image_6ch, 'cifar10', f"Input: {true_class}")
        plt.show()
        
        top_prototypes, scores = visualizer.visualize_prototype_activations(
            image_6ch, top_k=top_k, dataset='cifar10', figsize=(15, 10)
        )
        
    elif show_bcos:
        # B-cos view only
        visualizer.visualize_input_image(image_6ch, 'cifar10', f"Input: {true_class}")
        plt.show()
        
        contributions, pred_class, score = visualizer.visualize_bcos_contributions(
            image_6ch, dataset='cifar10', figsize=(18, 12)
        )
        
    else:
        # Input only
        visualizer.visualize_input_image(image_6ch, 'cifar10', f"Input: {true_class}")
    
    plt.show()

# Example usage - you can change these parameters
print("Available images:")
for i, target in enumerate(sample_targets):
    print(f"  {i}: {target}")

print("\nExploring different images:")
print("(You can change the image_idx and options below)")

In [None]:
# Explore specific images - modify these as needed
if visualizer and sample_images:
    
    # Example 1: Full comprehensive view
    print("Example 1: Comprehensive View")
    explore_image(0, show_prototypes=True, show_bcos=True, top_k=6)
    
    # Example 2: Prototypes only
    print("\n\nExample 2: Prototypes Only")
    explore_image(1, show_prototypes=True, show_bcos=False, top_k=9)
    
    # Example 3: B-cos contributions only
    print("\n\nExample 3: B-cos Contributions Only")
    explore_image(2, show_prototypes=False, show_bcos=True)
    
else:
    print("Visualizer not available for interactive exploration")

# You can add more examples by calling:
# explore_image(image_index, show_prototypes=True/False, show_bcos=True/False, top_k=number)

## 11. Analysis Summary and Insights

Let's analyze what we've learned from the visualizations.
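With only ten randomly drawn images, per-class numbers are anecdotal, but a breakdown can still show whether errors cluster in a few classes. A minimal pure-Python sketch (the helper `per_class_accuracy` is hypothetical) that could be applied to `sample_targets` and `all_predictions` after the cell below has run:

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def per_class_accuracy(true_names: Sequence[str],
                       pred_names: Sequence[str]) -> Dict[str, Tuple[int, int]]:
    """Map each true class name to (correct, total) over the sampled images."""
    stats = defaultdict(lambda: [0, 0])
    for t, p in zip(true_names, pred_names):
        stats[t][1] += 1          # total images of this true class
        stats[t][0] += int(t == p)  # correctly predicted ones
    return {name: (c, n) for name, (c, n) in sorted(stats.items())}


for name, (c, n) in per_class_accuracy(['cat', 'dog', 'cat'], ['cat', 'cat', 'dog']).items():
    print(f"{name:<12} {c}/{n}")
```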

In [None]:
# Analyze patterns across all sample images
if visualizer and sample_images:
    print("Analysis Summary")
    print("=" * 50)
    
    # Collect statistics across all samples
    all_predictions = []
    all_confidences = []
    all_active_prototypes = []
    correct_predictions = 0
    
    print("Processing all sample images...")
    
    for i, (image_6ch, true_class) in enumerate(zip(sample_images, sample_targets)):
        with torch.no_grad():
            _, pooled_features, _, class_scores = visualizer.model(image_6ch.to(device))
        
        pred_class = torch.argmax(class_scores, dim=1).item()
        confidence = class_scores[0, pred_class].item()
        active_protos = (pooled_features > 0.1).sum().item()
        
        pred_name = cifar10_classes[pred_class] if pred_class < len(cifar10_classes) else f"Class {pred_class}"
        
        all_predictions.append(pred_name)
        all_confidences.append(confidence)
        all_active_prototypes.append(active_protos)
        
        if pred_name == true_class:
            correct_predictions += 1
    
    # Print summary statistics
    print(f"\nOverall Statistics:")
    print(f"  Accuracy: {correct_predictions}/{len(sample_images)} ({100*correct_predictions/len(sample_images):.1f}%)")
    print(f"  Mean class score: {np.mean(all_confidences):.3f} ± {np.std(all_confidences):.3f}")
    print(f"  Mean active prototypes: {np.mean(all_active_prototypes):.1f} ± {np.std(all_active_prototypes):.1f}")
    
    # Print per-image results
    print(f"\nPer-Image Results:")
    print(f"{'#':<3} {'True':<12} {'Predicted':<12} {'Score':<6} {'Active':<6} {'Correct'}")
    print("-" * 50)
    
    for i in range(len(sample_images)):
        correct = "✓" if all_predictions[i] == sample_targets[i] else "✗"
        print(f"{i:<3} {sample_targets[i]:<12} {all_predictions[i]:<12} {all_confidences[i]:<6.3f} {all_active_prototypes[i]:<6} {correct}")
    
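    # Insights (computed from the samples above; treat them as anecdotal for small samples)
    print("\nKey Insights:")
    print(f"1. Prototype Usage: the model uses an average of {np.mean(all_active_prototypes):.1f} prototypes (>0.1) per image")
    if len(set(all_active_prototypes)) > 1 and len(set(all_confidences)) > 1:
        r = np.corrcoef(all_confidences, all_active_prototypes)[0, 1]
        print(f"2. Class score vs. active prototypes: Pearson r = {r:.2f} over {len(sample_images)} images")
    else:
        print("2. Class score vs. active prototypes: correlation undefined (one of the values does not vary)")
    print("3. Interpretability: each prediction can be traced to prototype activations and B-cos contributions")
    print("4. Spatial Understanding: B-cos maps show which image regions contribute to each prediction")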
    
    # Create summary visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Predicted-class score distribution
    ax1.hist(all_confidences, bins=10, alpha=0.7, edgecolor='black')
    ax1.set_xlabel('Predicted-Class Score')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Distribution of Predicted-Class Scores')
    ax1.grid(True, alpha=0.3)
    
    # Active prototypes distribution
    ax2.hist(all_active_prototypes, bins=range(min(all_active_prototypes), max(all_active_prototypes)+2), 
             alpha=0.7, edgecolor='black')
    ax2.set_xlabel('Number of Active Prototypes (>0.1)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Distribution of Active Prototypes per Image')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

else:
    print("Analysis not available - visualizer or sample images missing")

## 12. Cleanup and Summary

In [None]:
# Cleanup hooks to free memory
if visualizer:
    visualizer.cleanup_hooks()
    print("✓ Visualization hooks cleaned up")

print("\n" + "=" * 60)
print("B-COS PIP-NET VISUALIZATION SUMMARY")
print("=" * 60)

print("✅ Successfully demonstrated:")
print("   1. Input Image Visualization")
print("      - RGB extraction from 6-channel input")
print("      - Proper denormalization for display")
print("      - 6-channel breakdown showing [r,g,b,1-r,1-g,1-b]")
print()
print("   2. Learned Part Prototypes")
print("      - Spatial activation heatmaps")
print("      - Peak activation locations")
print("      - Prototype activation scores")
print("      - Top-k prototype selection")
print()
print("   3. B-cos Spatial Contribution Maps")
print("      - Gradient-based explanations")
print("      - Positive/negative contribution regions")
print("      - Multi-layer contribution analysis")
print("      - Target class-specific explanations")
print()
print("   4. Comprehensive Visualizations")
print("      - All-in-one interpretability view")
print("      - Interactive exploration capabilities")
print("      - Cross-image comparison grids")
print("      - Statistical analysis of patterns")
print()
print("🔬 Key Interpretability Features:")
print("   - Each prediction explained through prototype activations")
print("   - Spatial understanding of where prototypes activate")
print("   - B-cos contributions show pixel-level evidence")
print("   - Non-masking approach preserves full spatial information")
print()
print("💡 Usage Tips:")
print("   - Use explore_image() function for interactive analysis")
print("   - Adjust top_k parameter to see more/fewer prototypes")
print("   - Red regions in B-cos maps = positive evidence")
print("   - Blue regions in B-cos maps = negative evidence")
print("   - Download this notebook to save your visualization setup")
print()
print("🎯 Every prediction comes with explanations you can inspect: prototype activations and B-cos contribution maps.")
print("=" * 60)