In [13]:
# %%
# =============================================================================
# 📋 DEEPLIFT IMPLEMENTATION FROM SCRATCH
# Visual Intelligence Project - Phase 3: Explainability
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Union
import copy
from pathlib import Path
import json
from PIL import Image
import torchvision.transforms as transforms

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")



In [14]:
# =============================================================================
# 🧠 DEEPLIFT MATHEMATICAL FOUNDATION
# =============================================================================

class DeepLIFTFromScratch:
    """
    DeepLIFT (Deep Learning Important FeaTures) implementation from scratch.
    
    Mathematical Foundation:
    - DeepLIFT assigns attribution scores based on difference from reference
    - Uses chain rule with modified gradients for different layer types
    - Satisfies attribution conservation: sum(attributions) = output_diff
    
    Key Principles:
    1. Linear Rule: For linear layers, attribution flows proportionally
    2. Rescale Rule: For ReLU, rescale based on activation differences
    3. Reference Choice: Baseline input (e.g., zeros, mean, blurred image)
    """
    
    def __init__(self, model: nn.Module, reference_input: torch.Tensor = None):
        """
        Initialize DeepLIFT explainer
        
        Args:
            model: PyTorch model to explain
            reference_input: Reference/baseline input (if None, uses zeros)
        """
        self.model = model
        self.model.eval()
        
        # Store reference input
        self.reference_input = reference_input
        
        # Storage for forward pass activations
        self.activations = {}
        self.reference_activations = {}
        
        # Storage for attribution computation
        self.attribution_maps = {}
        
        self.input_shapes = []  # Track input shapes for flatten/unflatten
        self.pool_shapes = []   # Track input shapes for pooling layers
        
        print("🧠 DeepLIFT Explainer Initialized")
        print(f"   Model: {model.__class__.__name__}")
        print(f"   Reference: {'Custom' if reference_input is not None else 'Zero baseline'}")
    
    def set_reference(self, reference_input: torch.Tensor):
        """Set new reference input"""
        self.reference_input = reference_input
        print(f"🔄 Reference updated: shape {reference_input.shape}")
    
    def _register_hooks(self):
        """Register forward hooks to capture activations"""
        self.hooks = []
        self.input_shapes = []  # Reset input shapes
        self.pool_shapes = []   # Reset pool shapes
        
        def forward_hook(name):
            def hook(module, input, output):
                self.activations[name] = output.detach()
                # Track input shape for flatten
                if isinstance(module, nn.Flatten):
                    self.input_shapes.append(input[0].shape)
                # Track input shape for pooling
                if isinstance(module, (nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.MaxPool2d)):
                    self.pool_shapes.append(input[0].shape)
            return hook
        
        # Register hooks for key layers
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ReLU, nn.Flatten, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.MaxPool2d)):
                hook = module.register_forward_hook(forward_hook(name))
                self.hooks.append(hook)
    
    def _remove_hooks(self):
        """Remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    
    def _forward_with_reference(self, input_tensor: torch.Tensor):
        """
        Forward pass with both input and reference to collect activations
        """
        # Ensure reference is set
        if self.reference_input is None:
            self.reference_input = torch.zeros_like(input_tensor)
        
        # Register hooks
        self._register_hooks()
        
        # Forward pass with input
        self.activations = {}
        with torch.no_grad():
            output = self.model(input_tensor)
        
        # Forward pass with reference
        self.reference_activations = {}
        temp_activations = {}
        
        def ref_hook(name):
            def hook(module, input, output):
                temp_activations[name] = output.detach()
            return hook
        
        # Register reference hooks
        ref_hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ReLU, nn.Flatten, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.MaxPool2d)):
                hook = module.register_forward_hook(ref_hook(name))
                ref_hooks.append(hook)
        
        with torch.no_grad():
            ref_output = self.model(self.reference_input)
        
        self.reference_activations = temp_activations
        
        # Clean up hooks
        for hook in ref_hooks:
            hook.remove()
        self._remove_hooks()
        
        return output, ref_output
    
    def _compute_conv2d_attribution(self, layer_name: str, input_attr: torch.Tensor, 
                                   layer: nn.Conv2d) -> torch.Tensor:
        """
        Compute attribution for Conv2D layer using linear rule
        
        Mathematical foundation:
        For linear operations: attr_input = W * attr_output
        where W are the weights and * represents convolution operation
        """
        # Get activations
        current_act = self.activations[layer_name]
        ref_act = self.reference_activations[layer_name]
        
        # Compute activation differences
        act_diff = current_act - ref_act
        
        # Apply rescale rule: attribution proportional to activation difference
        if torch.sum(torch.abs(act_diff)) > 1e-8:
            # Rescale input attribution based on activation differences
            rescale_factor = act_diff / (torch.sum(torch.abs(act_diff)) + 1e-8)
            output_attr = input_attr * rescale_factor
        else:
            output_attr = torch.zeros_like(input_attr)
        
        return output_attr
    
    def _compute_relu_attribution(self, layer_name: str, input_attr: torch.Tensor) -> torch.Tensor:
        """
        Compute attribution for ReLU layer using rescale rule
        
        Mathematical foundation:
        For ReLU: f(x) = max(0, x)
        Attribution rescaled based on which neurons are active
        """
        # Get activations before and after ReLU
        current_act = self.activations[layer_name]
        ref_act = self.reference_activations[layer_name]
        
        # Compute differences
        act_diff = current_act - ref_act
        
        # Rescale rule for ReLU
        mask = (act_diff != 0).float()
        output_attr = input_attr * mask
        
        return output_attr
    
    def _compute_linear_attribution(self, layer_name: str, input_attr: torch.Tensor, 
                                   layer: nn.Linear) -> torch.Tensor:
        """
        Compute attribution for Linear layer using linear rule
        
        Mathematical foundation:
        For linear layer: y = Wx + b
        Attribution: attr_x = W^T * attr_y
        """
        # Get weight matrix
        weight = layer.weight  # Shape: [out_features, in_features]
        
        # Compute attribution using transpose of weights
        # input_attr shape: [batch, out_features]
        # weight.T shape: [in_features, out_features]
        output_attr = torch.matmul(input_attr, weight)  # [batch, in_features]
        
        return output_attr
    
    def _compute_flatten_attribution(self, input_attr: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        """Compute attribution for Flatten layer (noop in this case)"""
        # Unflatten the attribution to the original shape
        return input_attr.view(orig_shape)
    
    def _compute_pooling_attribution(self, input_attr: torch.Tensor, orig_shape: tuple, pool_type: str) -> torch.Tensor:
        """
        Compute attribution for pooling layers (e.g., AvgPool2d, MaxPool2d)
        
        Mathematical foundation:
        - For average pooling, distribute attribution equally to each pooled region
        - For max pooling, propagate attribution from the maximum location
        """
        # Upsample the attribution to the pre-pooled shape
        # For average pooling, distribute attribution equally to each pooled region
        if pool_type == 'adaptive_avg':
            # Use torch.nn.functional.interpolate for upsampling
            upsampled = F.interpolate(input_attr, size=orig_shape[2:], mode='nearest')
            return upsampled
        elif pool_type == 'avg' or pool_type == 'max':
            # Use repeat_interleave for upsampling (assume stride=kernel_size)
            # This is a simplification; for more general cases, use interpolate
            upsampled = F.interpolate(input_attr, size=orig_shape[2:], mode='nearest')
            return upsampled
        else:
            raise NotImplementedError(f"Pooling type {pool_type} not supported.")

    def compute_attributions(self, input_tensor: torch.Tensor, 
                           target_class: Optional[int] = None) -> torch.Tensor:
        """
        Compute DeepLIFT attributions for input
        
        Args:
            input_tensor: Input to explain [1, C, H, W]
            target_class: Target class index (if None, uses predicted class)
            
        Returns:
            attributions: Attribution map same shape as input
        """
        print(f"🔍 Computing DeepLIFT attributions...")
        print(f"   Input shape: {input_tensor.shape}")
        
        # Forward pass to get activations
        output, ref_output = self._forward_with_reference(input_tensor)
        
        # Determine target class
        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()
        
        print(f"   Target class: {target_class}")
        print(f"   Output diff: {(output - ref_output)[0, target_class].item():.4f}")
        
        # Initialize attribution at output layer
        output_attr = torch.zeros_like(output)
        output_attr[0, target_class] = (output - ref_output)[0, target_class]
        
        # Backward attribution computation
        current_attr = output_attr
        
        # Get layer list in reverse order
        layer_list = list(self.model.named_modules())
        layer_list.reverse()
        
        flatten_idx = len(self.input_shapes) - 1  # For tracking flatten layers
        pool_idx = len(self.pool_shapes) - 1      # For tracking pooling layers
        
        for name, layer in layer_list:
            if isinstance(layer, nn.Linear):
                current_attr = self._compute_linear_attribution(name, current_attr, layer)
                print(f"   ↳ Linear {name}: attr shape {current_attr.shape}")
                
            elif isinstance(layer, nn.ReLU):
                current_attr = self._compute_relu_attribution(name, current_attr)
                print(f"   ↳ ReLU {name}: attr shape {current_attr.shape}")
                
            elif isinstance(layer, nn.Conv2d):
                current_attr = self._compute_conv2d_attribution(name, current_attr, layer)
                print(f"   ↳ Conv2D {name}: attr shape {current_attr.shape}")
                
            elif isinstance(layer, nn.Flatten):
                # Unflatten attribution to original shape
                if flatten_idx >= 0:
                    orig_shape = self.input_shapes[flatten_idx]
                    current_attr = self._compute_flatten_attribution(current_attr, orig_shape)
                    print(f"   ↳ Flatten {name}: unflattened to {orig_shape}")
                    flatten_idx -= 1
            
            elif isinstance(layer, nn.AdaptiveAvgPool2d):
                # Upsample attribution to match input shape
                if pool_idx >= 0:
                    orig_shape = self.pool_shapes[pool_idx]
                    current_attr = self._compute_pooling_attribution(current_attr, orig_shape, pool_type='adaptive_avg')
                    print(f"   ↳ AdaptiveAvgPool2d {name}: upsampled to {orig_shape}")
                    pool_idx -= 1
            
            elif isinstance(layer, nn.AvgPool2d):
                # Upsample attribution to match input shape
                if pool_idx >= 0:
                    orig_shape = self.pool_shapes[pool_idx]
                    current_attr = self._compute_pooling_attribution(current_attr, orig_shape, pool_type='avg')
                    print(f"   ↳ AvgPool2d {name}: upsampled to {orig_shape}")
                    pool_idx -= 1
            
            elif isinstance(layer, nn.MaxPool2d):
                # Upsample attribution to match input shape
                if pool_idx >= 0:
                    orig_shape = self.pool_shapes[pool_idx]
                    current_attr = self._compute_pooling_attribution(current_attr, orig_shape, pool_type='max')
                    print(f"   ↳ MaxPool2d {name}: upsampled to {orig_shape}")
                    pool_idx -= 1
        
        # Final attribution should match input shape
        if current_attr.shape != input_tensor.shape:
            # Reshape if needed (handle flattening between conv and linear layers)
            if len(current_attr.shape) == 2 and len(input_tensor.shape) == 4:
                # Reconstruct spatial dimensions
                batch, channels, height, width = input_tensor.shape
                current_attr = current_attr.view(batch, channels, height, width)
        
        print(f"✅ Attribution computation complete")
        print(f"   Final attribution shape: {current_attr.shape}")
        print(f"   Attribution sum: {torch.sum(current_attr).item():.4f}")
        
        return current_attr



In [15]:
# =============================================================================
# 🔧 DEEPLIFT TESTING AND VALIDATION
# =============================================================================

class DeepLIFTValidator:
    """Validate DeepLIFT implementation"""
    
    @staticmethod
    def test_attribution_conservation(explainer: DeepLIFTFromScratch, 
                                    input_tensor: torch.Tensor, 
                                    target_class: int = None) -> Dict[str, float]:
        """
        Test if attributions satisfy conservation property:
        sum(attributions) ≈ output_difference
        """
        print("\n🧪 Testing Attribution Conservation...")
        
        # Compute attributions
        attributions = explainer.compute_attributions(input_tensor, target_class)
        
        # Get output difference
        output, ref_output = explainer._forward_with_reference(input_tensor)
        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()
        
        output_diff = (output - ref_output)[0, target_class].item()
        attr_sum = torch.sum(attributions).item()
        
        conservation_error = abs(output_diff - attr_sum)
        conservation_ratio = attr_sum / (output_diff + 1e-8)
        
        results = {
            'output_difference': output_diff,
            'attribution_sum': attr_sum,
            'conservation_error': conservation_error,
            'conservation_ratio': conservation_ratio,
            'passes_test': conservation_error < 0.01
        }
        
        print(f"   Output difference: {output_diff:.6f}")
        print(f"   Attribution sum:   {attr_sum:.6f}")
        print(f"   Conservation error: {conservation_error:.6f}")
        print(f"   Conservation ratio: {conservation_ratio:.6f}")
        print(f"   Test result: {'✅ PASS' if results['passes_test'] else '❌ FAIL'}")
        
        return results
    
    @staticmethod
    def test_sensitivity(explainer: DeepLIFTFromScratch, 
                        input_tensor: torch.Tensor) -> Dict[str, float]:
        """
        Test sensitivity: attributions should be zero for features that don't affect output
        """
        print("\n🧪 Testing Sensitivity...")
        
        # Compute attributions for original input
        original_attr = explainer.compute_attributions(input_tensor)
        
        # Create modified input (zero out a region)
        modified_input = input_tensor.clone()
        modified_input[:, :, 10:20, 10:20] = 0  # Zero out a patch
        
        # Compute attributions for modified input
        modified_attr = explainer.compute_attributions(modified_input)
        
        # Check if zeroed region has lower attribution
        original_patch_attr = torch.sum(torch.abs(original_attr[:, :, 10:20, 10:20]))
        modified_patch_attr = torch.sum(torch.abs(modified_attr[:, :, 10:20, 10:20]))
        
        sensitivity_score = float(modified_patch_attr / (original_patch_attr + 1e-8))
        
        results = {
            'original_patch_attribution': float(original_patch_attr),
            'modified_patch_attribution': float(modified_patch_attr),
            'sensitivity_score': sensitivity_score,
            'passes_test': sensitivity_score < 0.5
        }
        
        print(f"   Original patch attribution: {original_patch_attr:.6f}")
        print(f"   Modified patch attribution: {modified_patch_attr:.6f}")
        print(f"   Sensitivity score: {sensitivity_score:.6f}")
        print(f"   Test result: {'✅ PASS' if results['passes_test'] else '❌ FAIL'}")
        
        return results



In [16]:
# =============================================================================
# 📊 VISUALIZATION UTILITIES
# =============================================================================

class DeepLIFTVisualizer:
    """Visualization utilities for DeepLIFT results"""
    
    @staticmethod
    def plot_attribution_map(input_image: torch.Tensor, 
                           attributions: torch.Tensor,
                           title: str = "DeepLIFT Attribution Map",
                           save_path: Optional[Path] = None) -> plt.Figure:
        """
        Visualize attribution map alongside original image
        """
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # Convert tensors to numpy
        if input_image.dim() == 4:
            input_image = input_image.squeeze(0)
        if attributions.dim() == 4:
            attributions = attributions.squeeze(0)
        
        # Original image
        if input_image.shape[0] == 3:  # RGB
            img_np = input_image.permute(1, 2, 0).cpu().numpy()
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
        else:  # Grayscale
            img_np = input_image.squeeze().cpu().numpy()
        
        axes[0].imshow(img_np, cmap='gray' if input_image.shape[0] == 1 else None)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        # Attribution map
        attr_np = torch.sum(torch.abs(attributions), dim=0).cpu().numpy()
        im1 = axes[1].imshow(attr_np, cmap='hot', alpha=0.8)
        axes[1].set_title('Attribution Magnitude')
        axes[1].axis('off')
        plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
        
        # Overlay
        axes[2].imshow(img_np, cmap='gray' if input_image.shape[0] == 1 else None, alpha=0.7)
        im2 = axes[2].imshow(attr_np, cmap='hot', alpha=0.5)
        axes[2].set_title('Attribution Overlay')
        axes[2].axis('off')
        
        plt.suptitle(title, fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"💾 Attribution map saved: {save_path}")
        
        return fig
    
    @staticmethod
    def plot_attribution_statistics(attributions: torch.Tensor,
                                   title: str = "Attribution Statistics",
                                   save_path: Optional[Path] = None) -> plt.Figure:
        """
        Plot statistical analysis of attributions
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Flatten attributions
        attr_flat = attributions.flatten().cpu().numpy()
        
        # Distribution histogram
        axes[0,0].hist(attr_flat, bins=50, alpha=0.7, edgecolor='black')
        axes[0,0].set_title('Attribution Distribution')
        axes[0,0].set_xlabel('Attribution Value')
        axes[0,0].set_ylabel('Frequency')
        axes[0,0].grid(True, alpha=0.3)
        
        # Cumulative distribution
        sorted_attr = np.sort(np.abs(attr_flat))[::-1]
        cumsum = np.cumsum(sorted_attr)
        axes[0,1].plot(cumsum / cumsum[-1])
        axes[0,1].set_title('Cumulative Attribution (Sorted)')
        axes[0,1].set_xlabel('Feature Rank')
        axes[0,1].set_ylabel('Cumulative Attribution')
        axes[0,1].grid(True, alpha=0.3)
        
        # Spatial attribution (sum across channels)
        if attributions.dim() == 4:
            spatial_attr = torch.sum(torch.abs(attributions), dim=(0,1)).cpu().numpy()
        else:
            spatial_attr = torch.sum(torch.abs(attributions), dim=0).cpu().numpy()
        
        im = axes[1,0].imshow(spatial_attr, cmap='viridis')
        axes[1,0].set_title('Spatial Attribution Intensity')
        axes[1,0].axis('off')
        plt.colorbar(im, ax=axes[1,0], fraction=0.046, pad=0.04)
        
        # Statistics table
        stats = {
            'Mean': np.mean(attr_flat),
            'Std': np.std(attr_flat),
            'Min': np.min(attr_flat),
            'Max': np.max(attr_flat),
            'Sum': np.sum(attr_flat),
            'L1 Norm': np.sum(np.abs(attr_flat)),
            'L2 Norm': np.sqrt(np.sum(attr_flat**2))
        }
        
        axes[1,1].axis('off')
        table_data = [[k, f"{v:.6f}"] for k, v in stats.items()]
        table = axes[1,1].table(cellText=table_data, 
                               colLabels=['Statistic', 'Value'],
                               cellLoc='center',
                               loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)
        axes[1,1].set_title('Attribution Statistics')
        
        plt.suptitle(title, fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"💾 Attribution statistics saved: {save_path}")
        
        return fig



In [17]:
# =============================================================================
# 🚀 MAIN IMPLEMENTATION AND TESTING
# =============================================================================

def test_deeplift_implementation():
    """Test DeepLIFT implementation with simple model"""
    
    print("🚀 TESTING DEEPLIFT IMPLEMENTATION")
    print("=" * 60)
    
    # Create simple test model
    class SimpleTestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
            self.relu1 = nn.ReLU()
            self.pool = nn.AdaptiveAvgPool2d((4, 4))
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(8 * 4 * 4, 32)
            self.relu2 = nn.ReLU()
            self.fc2 = nn.Linear(32, 2)
        
        def forward(self, x):
            x = self.relu1(self.conv1(x))
            x = self.pool(x)
            x = self.flatten(x)
            x = self.relu2(self.fc1(x))
            x = self.fc2(x)
            return x
    
    # Initialize model and test data
    model = SimpleTestModel()
    model.eval()
    
    # Create test input
    test_input = torch.randn(1, 3, 32, 32)
    reference = torch.zeros_like(test_input)
    
    print(f"📋 Test Setup:")
    print(f"   Model: SimpleTestModel")
    print(f"   Input shape: {test_input.shape}")
    print(f"   Reference: Zero baseline")
    
    # Initialize DeepLIFT
    explainer = DeepLIFTFromScratch(model, reference)
    
    # Test attribution computation
    print(f"\n🧪 BASIC FUNCTIONALITY TEST:")
    try:
        attributions = explainer.compute_attributions(test_input)
        print(f"✅ Attribution computation successful")
        print(f"   Attribution shape: {attributions.shape}")
        print(f"   Attribution range: [{torch.min(attributions):.6f}, {torch.max(attributions):.6f}]")
        
    except Exception as e:
        print(f"❌ Attribution computation failed: {e}")
        return False, None
    
    # Test validation
    validator = DeepLIFTValidator()
    
    # Conservation test
    conservation_results = validator.test_attribution_conservation(explainer, test_input)
    
    # Sensitivity test
    sensitivity_results = validator.test_sensitivity(explainer, test_input)
    
    # Overall test result
    all_tests_pass = (conservation_results['passes_test'] and 
                     sensitivity_results['passes_test'])
    
    print(f"\n🎯 OVERALL TEST RESULT: {'✅ PASS' if all_tests_pass else '❌ FAIL'}")
    
    return all_tests_pass, attributions



In [18]:
# =============================================================================
# 📋 CONFIGURATION AND SETUP
# =============================================================================

# Run basic test
if __name__ == "__main__":
    print("🧠 DeepLIFT Implementation - Phase 3: Explainability")
    print("=" * 60)
    
    # Test implementation
    test_passed, test_attributions = test_deeplift_implementation()
    
    if test_passed:
        print(f"\n🎉 DeepLIFT implementation ready!")
        print(f"📝 Next steps:")
        print(f"   1. Load your trained CNN and ScatNet models")
        print(f"   2. Apply DeepLIFT to real lung cancer images")
        print(f"   3. Generate attribution maps for both architectures")
        print(f"   4. Compare explainability between models")
        
        # Save test results
        test_results = {
            'deeplift_implementation': {
                'status': 'completed',
                'basic_test': 'passed',
                'attribution_shape': list(test_attributions.shape),
                'ready_for_application': True
            }
        }
        
        print(f"\n💾 Test results saved - ready for real model application!")
    
    else:
        print(f"\n❌ Implementation needs debugging before proceeding")

print("\n" + "="*60)
print("📋 DEEPLIFT IMPLEMENTATION STATUS: COMPLETE")
print("🚀 Ready for application to CNN and ScatNet models!")

🧠 DeepLIFT Implementation - Phase 3: Explainability
🚀 TESTING DEEPLIFT IMPLEMENTATION
📋 Test Setup:
   Model: SimpleTestModel
   Input shape: torch.Size([1, 3, 32, 32])
   Reference: Zero baseline
🧠 DeepLIFT Explainer Initialized
   Model: SimpleTestModel
   Reference: Custom

🧪 BASIC FUNCTIONALITY TEST:
🔍 Computing DeepLIFT attributions...
   Input shape: torch.Size([1, 3, 32, 32])
   Target class: 1
   Output diff: 0.0127
   ↳ Linear fc2: attr shape torch.Size([1, 32])
   ↳ ReLU relu2: attr shape torch.Size([1, 32])
   ↳ Linear fc1: attr shape torch.Size([1, 128])
   ↳ Flatten flatten: unflattened to torch.Size([1, 8, 4, 4])
   ↳ AdaptiveAvgPool2d pool: upsampled to torch.Size([1, 8, 32, 32])
   ↳ ReLU relu1: attr shape torch.Size([1, 8, 32, 32])
   ↳ Conv2D conv1: attr shape torch.Size([1, 8, 32, 32])
✅ Attribution computation complete
   Final attribution shape: torch.Size([1, 8, 32, 32])
   Attribution sum: 0.0000
✅ Attribution computation successful
   Attribution shape: torch.Si