In [40]:
# %%
# =============================================================================
# 📋 ATTRIBUTION MAPS GENERATION - CNN vs ScatNet
# Visual Intelligence Project - Phase 3: Explainability
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import random

# Global plot appearance: matplotlib's seaborn-like style sheet first, then
# seaborn's "husl" palette on top of it (order matters for the colour cycle).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Start-up banner describing what this cell does.
print(
    "🎯 Attribution Maps Generation - CNN vs ScatNet",
    "=" * 60,
    "📋 This notebook generates attribution maps for CNN and ScatNet models",
    "🧠 Uses DeepLIFT for CNN and perturbation analysis for ScatNet",
    sep="\n",
)

# =============================================================================
# 🧠 DEEPLIFT IMPLEMENTATION
# =============================================================================

class DeepLIFTFromScratch:
    """Gradient-based stand-in for DeepLIFT attribution.

    This class does not implement the DeepLIFT propagation rules. It computes
    ``gradient * (input - reference)``; when no reference is configured the
    result reduces to plain ``gradient * input``.
    """

    def __init__(self, model, reference_input=None):
        """Store the model to explain and an optional reference (baseline) input."""
        self.model = model
        self.reference_input = reference_input
        print(f"🧠 DeepLIFT initialized for {model.__class__.__name__}")

    def set_reference(self, reference_input):
        """Replace the reference used by subsequent ``compute_attributions`` calls."""
        self.reference_input = reference_input

    def compute_attributions(self, input_tensor, target_class=None):
        """Attribute the target logit of the first sample to the input elements.

        Args:
            input_tensor: Model input. It is neither modified nor given a
                ``.grad``; a detached copy is differentiated instead.
            target_class: Output index to explain. Defaults to the arg-max of
                the model output (which requires a batch of exactly one,
                because ``.item()`` is called on the arg-max).

        Returns:
            A detached tensor shaped like ``input_tensor`` holding
            ``gradient * (input - reference)``, or ``gradient * input`` when
            no reference is set.
        """
        self.model.eval()

        # Differentiate with respect to a private leaf copy. Toggling
        # requires_grad on the caller's tensor and calling .backward() would
        # leave a .grad behind that silently accumulates over repeated calls.
        leaf_input = input_tensor.detach().clone().requires_grad_(True)

        output = self.model(leaf_input)
        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()

        target_output = output[0, target_class]
        # autograd.grad returns the gradient directly and leaves the model
        # parameters' .grad buffers untouched.
        (gradient,) = torch.autograd.grad(target_output, leaf_input)

        difference = leaf_input.detach()
        if self.reference_input is not None:
            reference = self.reference_input.to(
                device=difference.device, dtype=difference.dtype
            )
            difference = difference - reference

        return (gradient * difference).detach()

class DeepLIFTVisualizer:
    """Visualization utilities for attribution maps"""

    @staticmethod
    def plot_attribution_map(input_image, attributions, title="Attribution Map", save_path=None):
        """Draw the input, its attribution heat map and an overlay side by side.

        Args:
            input_image: Channels-first image tensor, either ``(C, H, W)`` or
                with a leading batch dimension of size one.
            attributions: Attribution tensor; a leading dimension of size one
                is squeezed away and absolute values are summed over the
                first remaining dimension to get a 2-D heat map.
            title: Figure super-title.
            save_path: When truthy, the figure is also written there
                (300 dpi, tight bounding box).

        Returns:
            The matplotlib figure. It is left open; closing it is up to the
            caller.
        """
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image: drop the batch dimension if present and move the
        # channels last for imshow. detach() lets grad-tracking tensors
        # through .numpy().
        image = input_image.detach()
        if image.dim() == 4:
            image = image.squeeze(0)
        img = image.permute(1, 2, 0).cpu().numpy()

        # Min-max rescale to [0, 1]. A constant image has zero range, which
        # would otherwise divide by zero and yield an all-NaN picture.
        lowest = img.min()
        value_range = img.max() - lowest
        if value_range > 0:
            img = (img - lowest) / value_range
        else:
            img = np.zeros_like(img)

        axes[0].imshow(img)
        axes[0].set_title('Original')
        axes[0].axis('off')

        # Attribution: magnitude summed over the channel dimension
        attr = torch.sum(torch.abs(attributions.detach().squeeze(0)), dim=0).cpu().numpy()
        im = axes[1].imshow(attr, cmap='hot')
        axes[1].set_title('Attribution')
        axes[1].axis('off')
        fig.colorbar(im, ax=axes[1])

        # Overlay: semi-transparent heat map on top of the image
        axes[2].imshow(img, alpha=0.7)
        axes[2].imshow(attr, cmap='hot', alpha=0.5)
        axes[2].set_title('Overlay')
        axes[2].axis('off')

        fig.suptitle(title)
        fig.tight_layout()

        if save_path:
            # Save through the figure object so the right figure is written
            # even if another one has become "current" in the meantime.
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

# =============================================================================
# 📸 DATASET LOADER
# =============================================================================

class LungCancerDataset(Dataset):
    """Two-class image dataset read from per-class sub-directories.

    ``data_path/adenocarcinoma`` supplies label 0 and ``data_path/benign``
    supplies label 1. A missing sub-directory is reported and skipped.
    """

    # File suffixes accepted as images, compared case-insensitively.
    IMAGE_SUFFIXES = frozenset({".jpeg", ".jpg", ".png"})

    # (sub-directory name, integer label, name used in log output)
    CLASS_FOLDERS = (
        ("adenocarcinoma", 0, "Adenocarcinoma"),
        ("benign", 1, "Benign"),
    )

    def __init__(self, data_path, transform=None, max_samples=None):
        """Index the image files under ``data_path``.

        Args:
            data_path: Directory containing the class sub-directories.
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
            max_samples: When truthy and smaller than the number of files
                found, a random subset of that size is kept
                (``random.sample``).
        """
        self.data_path = Path(data_path)
        self.transform = transform
        self.image_paths = []
        self.labels = []

        print(f"📂 Loading dataset from: {self.data_path}")

        for folder, label, display_name in self.CLASS_FOLDERS:
            class_dir = self.data_path / folder
            if not class_dir.is_dir():
                print(f"   ⚠️  {display_name} directory not found")
                continue

            # One directory pass with a case-insensitive suffix check.
            # Globbing "*.jpg" and "*.JPG" separately lists every file twice
            # on case-insensitive filesystems.
            files = sorted(
                entry
                for entry in class_dir.iterdir()
                if entry.is_file() and entry.suffix.lower() in self.IMAGE_SUFFIXES
            )
            self.image_paths.extend(files)
            self.labels.extend([label] * len(files))
            print(f"   {display_name}: {len(files)} images")

        # Limit samples if specified
        if max_samples and len(self.image_paths) > max_samples:
            keep = random.sample(range(len(self.image_paths)), max_samples)
            self.image_paths = [self.image_paths[i] for i in keep]
            self.labels = [self.labels[i] for i in keep]

        print(f"📋 Dataset loaded: {len(self.image_paths)} images")
        print(f"   Adenocarcinoma: {self.labels.count(0)}")
        print(f"   Benign: {self.labels.count(1)}")

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, file_name)`` for the sample at ``idx``.

        An image that cannot be opened is reported and replaced by a gray
        224x224 RGB placeholder so iteration can continue.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # convert() yields an independent in-memory image, so the file
            # handle can be released as soon as the block exits.
            with Image.open(image_path) as handle:
                image = handle.convert('RGB')
        except Exception as e:
            print(f"❌ Error loading image {image_path}: {e}")
            image = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            image = self.transform(image)

        return image, label, str(image_path.name)

# =============================================================================
# 🎯 ATTRIBUTION GENERATORS
# =============================================================================

class ScatNetAttributionHandler:
    """ScatNet attribution using perturbation (occlusion) analysis"""

    def __init__(self, scatnet_model, device='cpu'):
        """Store the model to probe and the device inputs are moved to."""
        self.model = scatnet_model
        self.device = device

    @staticmethod
    def _window_starts(extent, window, stride):
        """Return window start offsets that cover ``extent`` completely.

        Offsets advance by ``stride``; a final offset flush with the far edge
        is appended when the stride does not land there on its own.
        """
        last_start = extent - window
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def compute_perturbation_attribution(self, input_tensor, target_class=None, patch_size=8):
        """Compute attribution using occlusion analysis.

        Square patches are zeroed out one at a time with a stride of half the
        patch size. The drop in the target score is credited to every element
        under the patch, and overlapping credits are averaged.

        Args:
            input_tensor: 4-D input whose first sample is scored.
            target_class: Output index to track. Defaults to the arg-max of
                the unperturbed output.
            patch_size: Side length of the occluded square, at least 1.
                Patches larger than the image are clipped to the image.

        Returns:
            A CPU tensor of shape ``(1, C, H, W)`` scaled so that the largest
            absolute value is 1 (left all-zero if nothing changed the score).

        Raises:
            ValueError: If ``patch_size`` is smaller than 1.
        """
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size}")

        print(f"🔍 Computing ScatNet attribution via perturbation analysis...")

        self.model.eval()
        input_tensor = input_tensor.to(self.device)

        # Get baseline prediction
        with torch.no_grad():
            baseline_output = self.model(input_tensor)
            if target_class is None:
                target_class = torch.argmax(baseline_output, dim=1).item()
            baseline_score = baseline_output[0, target_class].item()

        # Accumulators live on the CPU; only Python floats are added to them.
        _, c, h, w = input_tensor.shape
        attribution_map = torch.zeros((1, c, h, w))
        coverage = torch.zeros((1, 1, h, w))

        patch_h = min(patch_size, h)
        patch_w = min(patch_size, w)
        # patch_size // 2 is 0 for patch_size == 1, which range() rejects.
        stride = max(1, patch_size // 2)
        row_starts = self._window_starts(h, patch_h, stride)
        col_starts = self._window_starts(w, patch_w, stride)

        # Sliding window perturbation
        with torch.no_grad():
            for top in row_starts:
                rows = slice(top, top + patch_h)
                for left in col_starts:
                    cols = slice(left, left + patch_w)

                    # Create perturbed input (zero out patch)
                    perturbed_input = input_tensor.clone()
                    perturbed_input[:, :, rows, cols] = 0
                    perturbed_score = self.model(perturbed_input)[0, target_class].item()

                    # Attribution = baseline - perturbed (importance of removed patch)
                    attribution_map[:, :, rows, cols] += baseline_score - perturbed_score
                    coverage[:, :, rows, cols] += 1

        # Interior elements sit under more windows than border elements.
        # Averaging by coverage stops that count from biasing the map.
        attribution_map = attribution_map / coverage.clamp(min=1)

        # Normalize attribution map
        peak = torch.max(torch.abs(attribution_map))
        if peak > 0:
            attribution_map = attribution_map / peak

        print(f"✅ ScatNet perturbation attribution complete")
        return attribution_map

class AttributionMapGenerator:
    """Main attribution map generator for both CNN and ScatNet.

    The CNN is explained through ``DeepLIFTFromScratch`` once per reference
    input. The optional ScatNet is explained through
    ``ScatNetAttributionHandler`` (occlusion). Figures can be written under
    ``explainability_path / "attribution_maps"``.
    """

    def __init__(self, cnn_model, scatnet_model=None, device='cpu', explainability_path=None):
        """Wire up the explainers.

        Args:
            cnn_model: Model explained with the gradient-based explainer.
            scatnet_model: Optional second model explained by occlusion.
            device: Device inputs are moved to before inference.
            explainability_path: Output directory for figures; nothing is
                saved when it is falsy.
        """
        self.cnn_model = cnn_model
        self.scatnet_model = scatnet_model
        self.device = device
        self.explainability_path = explainability_path

        # Initialize explainers
        self.cnn_explainer = DeepLIFTFromScratch(cnn_model)
        self.scatnet_handler = None

        # Compare against None: container modules define __len__, so an
        # empty one would be falsy.
        if scatnet_model is not None:
            self.scatnet_handler = ScatNetAttributionHandler(scatnet_model, device)

        print(f"🧠 Attribution Map Generator initialized")
        print(f"   Device: {device}")
        print(f"   CNN explainer: ✅ DeepLIFT")
        print(f"   ScatNet handler: {'✅ Perturbation Analysis' if scatnet_model is not None else '❌ Not provided'}")

    def create_reference_inputs(self, input_tensor):
        """Create different types of reference inputs.

        Returns:
            A dict with keys ``'zero'``, ``'mean'`` (per-channel spatial mean
            broadcast back to the input shape), ``'gaussian_blur'`` and
            ``'random_noise'`` (fresh Gaussian noise scaled by 0.1), each
            shaped like ``input_tensor``.
        """
        return {
            'zero': torch.zeros_like(input_tensor),
            'mean': torch.mean(input_tensor, dim=(2, 3), keepdim=True).expand_as(input_tensor),
            'gaussian_blur': self._gaussian_blur(input_tensor),
            'random_noise': torch.randn_like(input_tensor) * 0.1,
        }

    def _gaussian_blur(self, tensor, kernel_size=15, sigma=3.0):
        """Blur ``tensor`` with torchvision, falling back to zeros on failure."""
        try:
            from torchvision.transforms.functional import gaussian_blur
            if tensor.dim() == 4:
                return torch.stack([gaussian_blur(img, kernel_size, sigma) for img in tensor])
            return gaussian_blur(tensor, kernel_size, sigma)
        except Exception as exc:
            # Best-effort fallback, but say why. A bare `except:` would also
            # trap KeyboardInterrupt/SystemExit and hide the cause.
            print(f"   ⚠️  Gaussian blur reference unavailable ({exc}); using zeros instead")
            return torch.zeros_like(tensor)

    @staticmethod
    def _attribution_stats(attribution, **extra):
        """Bundle an attribution map with its summary statistics."""
        stats = {
            'attribution_map': attribution.cpu(),
            'attribution_sum': float(torch.sum(attribution)),
            'attribution_norm': float(torch.norm(attribution)),
            'positive_attribution': float(torch.sum(torch.clamp(attribution, min=0))),
            'negative_attribution': float(torch.sum(torch.clamp(attribution, max=0))),
        }
        stats.update(extra)
        return stats

    def generate_attributions_for_sample(self, input_tensor, filename, target_class=None, save_individual=True):
        """Generate attribution maps for a single sample.

        Args:
            input_tensor: Batch-of-one model input.
            filename: Name used in log output and in saved figure names.
            target_class: Class to explain; defaults to the CNN's prediction.
            save_individual: Save one figure per attribution when an
                explainability path is configured.

        Returns:
            A dict with the file name, input shape, target class, model
            predictions and an ``'attributions'`` mapping from method name
            to map and statistics. A method that raised is reported and
            omitted.
        """
        print(f"\n🔍 Generating attributions for: {filename}")

        # Move to device
        input_tensor = input_tensor.to(self.device)

        # Get CNN predictions
        with torch.no_grad():
            cnn_output = self.cnn_model(input_tensor)
            cnn_pred = torch.softmax(cnn_output, dim=1)
            cnn_class = torch.argmax(cnn_pred, dim=1).item()
            cnn_confidence = cnn_pred[0, cnn_class].item()

        print(f"   CNN prediction: Class {cnn_class} (confidence: {cnn_confidence:.3f})")

        if target_class is None:
            target_class = cnn_class

        # Create reference inputs
        references = self.create_reference_inputs(input_tensor)

        results = {
            'filename': filename,
            'input_shape': list(input_tensor.shape),
            'target_class': target_class,
            'cnn_prediction': {
                'class': cnn_class,
                'confidence': cnn_confidence
            },
            'attributions': {}
        }

        # Generate CNN attributions with different references
        for ref_name, reference in references.items():
            print(f"   🔍 CNN attribution with {ref_name} reference...")

            try:
                self.cnn_explainer.set_reference(reference)
                # Give the explainer its own copy each time so gradient state
                # attached to the tensor cannot carry over between references.
                cnn_attr = self.cnn_explainer.compute_attributions(
                    input_tensor.detach().clone(), target_class
                )
                results['attributions'][f'cnn_{ref_name}'] = self._attribution_stats(cnn_attr)
            except Exception as e:
                print(f"     ❌ Error computing CNN attribution with {ref_name}: {e}")

        # Generate ScatNet attributions if available
        if self.scatnet_model is not None and self.scatnet_handler is not None:
            try:
                with torch.no_grad():
                    scatnet_output = self.scatnet_model(input_tensor)
                    scatnet_pred = torch.softmax(scatnet_output, dim=1)
                    scatnet_class = torch.argmax(scatnet_pred, dim=1).item()
                    scatnet_confidence = scatnet_pred[0, scatnet_class].item()

                results['scatnet_prediction'] = {
                    'class': scatnet_class,
                    'confidence': scatnet_confidence
                }

                print(f"   ScatNet prediction: Class {scatnet_class} (confidence: {scatnet_confidence:.3f})")

                # Generate ScatNet attribution
                scatnet_attr = self.scatnet_handler.compute_perturbation_attribution(
                    input_tensor, target_class
                )

                results['attributions']['scatnet_perturbation'] = self._attribution_stats(
                    scatnet_attr, method='perturbation_analysis'
                )
            except Exception as e:
                print(f"     ❌ Error computing ScatNet attribution: {e}")

        # Save individual visualizations
        if save_individual and self.explainability_path:
            self._save_individual_attributions(input_tensor, results, filename)

        return results

    def _save_individual_attributions(self, input_tensor, results, filename):
        """Save one PNG per attribution map in ``results``.

        Files are named ``<stem>_<attribution name>.png`` inside
        ``explainability_path / "attribution_maps"``. A failure for one map
        is reported and does not stop the others.
        """
        # Path.stem drops the extension whatever its case, and never touches
        # an extension-like substring in the middle of the name.
        base_name = Path(filename).stem

        # Ensure directories exist (accept a plain string path as well)
        attribution_maps_dir = Path(self.explainability_path) / "attribution_maps"
        attribution_maps_dir.mkdir(parents=True, exist_ok=True)

        saved_files = []

        for attr_name, attr_data in results['attributions'].items():
            fig = None
            try:
                # Create visualization
                fig = DeepLIFTVisualizer.plot_attribution_map(
                    input_tensor,
                    attr_data['attribution_map'],
                    title=f"{attr_name.upper()} Attribution - {base_name}",
                    save_path=None
                )

                # Save figure
                save_path = attribution_maps_dir / f"{base_name}_{attr_name}.png"
                fig.savefig(str(save_path), dpi=300, bbox_inches='tight')

                if save_path.exists():
                    saved_files.append(save_path.name)

            except Exception as e:
                print(f"   ❌ Error saving {attr_name}: {e}")
            finally:
                # Close the figure we created. If plotting failed before
                # returning it, fall back to closing the current figure.
                if fig is not None:
                    plt.close(fig)
                else:
                    plt.close()

        if saved_files:
            print(f"   💾 Saved {len(saved_files)} attribution maps")
        else:
            print(f"   ❌ No attribution maps were saved")

# =============================================================================
# 🚀 MAIN ATTRIBUTION GENERATION FUNCTION
# =============================================================================

def main_attribution_generation(num_samples=3):
    """Main function to generate attribution maps.

    Resolves the project layout relative to the working directory (expected
    to be ``notebooks/04_explainability/``), locates model checkpoint files
    and image data, builds the models, runs attribution on a few samples and
    writes figures plus a JSON summary under ``results/explainability``.

    NOTE: the checkpoint files are only looked up, never loaded. Attribution
    runs on freshly initialised demo networks defined inside this function.
    The ScatNet checkpoint's presence only decides whether the second demo
    network is created.

    Args:
        num_samples: Maximum number of images to process.

    Returns:
        The summary dict that was written to JSON, or ``None`` when the
        models directory, a CNN checkpoint or usable images are missing, or
        when no sample could be processed.
    """
    print("\n🚀 MAIN ATTRIBUTION MAP GENERATION")
    print("=" * 60)

    # Setup paths - you're in notebooks/04_explainability/
    PROJECT_ROOT = Path("../..").resolve()
    MODELS_PATH = PROJECT_ROOT / "models"
    RESULTS_PATH = PROJECT_ROOT / "results"
    EXPLAINABILITY_PATH = RESULTS_PATH / "explainability"

    print(f"📁 Project paths:")
    print(f"   Root: {PROJECT_ROOT}")
    print(f"   Models: {MODELS_PATH}")
    print(f"   Results: {RESULTS_PATH}")
    print(f"   Explainability: {EXPLAINABILITY_PATH}")

    # Create directories (parents=True so a missing intermediate level does
    # not abort the run)
    for output_dir in (
        RESULTS_PATH,
        EXPLAINABILITY_PATH,
        EXPLAINABILITY_PATH / "attribution_maps",
        EXPLAINABILITY_PATH / "comparisons",
    ):
        output_dir.mkdir(parents=True, exist_ok=True)

    # Verify paths
    print(f"\n🔍 Path verification:")
    print(f"   Project root exists: {PROJECT_ROOT.exists()}")
    print(f"   Models folder exists: {MODELS_PATH.exists()}")
    print(f"   Results folder exists: {RESULTS_PATH.exists()}")

    if not MODELS_PATH.exists():
        print(f"❌ Models directory not found: {MODELS_PATH}")
        return None

    # Find models
    print(f"\n🔍 Looking for models...")
    pth_files = list(MODELS_PATH.glob("*.pth"))
    print(f"📋 Available .pth files: {[f.name for f in pth_files]}")

    # Find CNN model
    cnn_model_path = MODELS_PATH / "best_cnn_model.pth"
    if not cnn_model_path.exists():
        cnn_candidates = [f for f in pth_files if 'cnn' in f.name.lower()]
        if cnn_candidates:
            cnn_model_path = cnn_candidates[0]
            print(f"✅ Using CNN model: {cnn_model_path.name}")
        else:
            print(f"❌ No CNN model found")
            return None
    else:
        print(f"✅ CNN model found: {cnn_model_path.name}")

    # Find ScatNet model
    scatnet_model_path = MODELS_PATH / "best_scatnet_model.pth"
    if not scatnet_model_path.exists():
        scatnet_candidates = [f for f in pth_files if 'scatnet' in f.name.lower()]
        if scatnet_candidates:
            scatnet_model_path = scatnet_candidates[0]
            print(f"✅ Using ScatNet model: {scatnet_model_path.name}")
        else:
            print(f"📋 No ScatNet model found - CNN attribution only")
            scatnet_model_path = None
    else:
        print(f"✅ ScatNet model found: {scatnet_model_path.name}")

    # Find data: first candidate directory holding at least one class folder
    print(f"\n📁 Searching for data...")
    data_paths = [
        PROJECT_ROOT / "data" / "raw",
        PROJECT_ROOT / "data" / "processed",
        PROJECT_ROOT / "data"
    ]

    DATA_PATH = None
    for data_path in data_paths:
        if not data_path.exists():
            continue
        # Check for subdirectories with images
        if (data_path / "adenocarcinoma").exists() or (data_path / "benign").exists():
            DATA_PATH = data_path
            print(f"✅ Data found at: {DATA_PATH}")
            break

    if DATA_PATH is None:
        print(f"❌ No data found, creating demo data...")
        DATA_PATH = PROJECT_ROOT / "temp_demo_data"
        class_dirs = [DATA_PATH / "adenocarcinoma", DATA_PATH / "benign"]
        for class_dir in class_dirs:
            class_dir.mkdir(parents=True, exist_ok=True)

        # Create demo images: random noise, one independent image per file.
        # The upper bound of randint is exclusive, so 256 covers 0..255.
        for i in range(3):
            for class_dir in class_dirs:
                img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
                Image.fromarray(img).save(class_dir / f"demo_{i}.jpg")

        print(f"✅ Demo data created at: {DATA_PATH}")

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🔧 Using device: {device}")

    # Create simple models for demonstration
    print(f"\n📥 Creating demo models...")

    class SimpleCNN(nn.Module):
        """Single conv layer, global average pooling, 2-way linear head."""

        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1))
            )
            self.classifier = nn.Linear(32, 2)

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            return self.classifier(x)

    class SimpleScatNet(nn.Module):
        """Two strided conv layers, global average pooling, 2-way linear head."""

        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3),
                nn.ReLU(),
                nn.Conv2d(64, 128, 5, stride=2, padding=2),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1))
            )
            self.classifier = nn.Linear(128, 2)

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            return self.classifier(x)

    # Create models
    cnn_model = SimpleCNN().to(device)
    cnn_model.eval()
    print(f"✅ CNN demo model ready")

    scatnet_model = None
    if scatnet_model_path is not None:
        scatnet_model = SimpleScatNet().to(device)
        scatnet_model.eval()
        print(f"✅ ScatNet demo model ready")

    # Setup data loading
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print(f"\n📁 Loading dataset...")
    # Keep the pool at least as large as the number of samples requested.
    dataset = LungCancerDataset(DATA_PATH, transform=transform, max_samples=max(5, num_samples))

    if len(dataset) == 0:
        print(f"❌ No images found in dataset")
        return None

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Initialize attribution generator
    attr_generator = AttributionMapGenerator(
        cnn_model, scatnet_model, device, EXPLAINABILITY_PATH
    )

    # Generate attributions
    print(f"\n🎯 Generating attribution maps...")
    all_results = []
    # The progress denominator is what will actually be attempted, which can
    # be fewer than requested for a small dataset.
    planned = min(num_samples, len(dataset))

    for i, (image, _label, filename) in enumerate(dataloader):
        if i >= num_samples:
            break

        print(f"\n--- Sample {i+1}/{planned} ---")
        try:
            results = attr_generator.generate_attributions_for_sample(
                image, filename[0], target_class=None, save_individual=True
            )
            all_results.append(results)
        except Exception as e:
            print(f"❌ Error processing sample {filename[0]}: {e}")

    if not all_results:
        print(f"❌ No samples processed successfully")
        return None

    # Save results summary
    results_summary = {
        'attribution_generation': {
            'status': 'completed',
            'samples_processed': len(all_results),
            'cnn_attributions': True,
            'scatnet_attributions': scatnet_model is not None,
            'save_path': str(EXPLAINABILITY_PATH)
        }
    }

    # Save JSON summary
    try:
        results_file = EXPLAINABILITY_PATH / "attribution_results_summary.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_summary, f, indent=2)
        print(f"\n💾 Results summary saved: {results_file}")
    except Exception as e:
        print(f"❌ Error saving summary: {e}")

    # Verify saved files
    print(f"\n📂 Verifying saved files:")
    attribution_maps_dir = EXPLAINABILITY_PATH / "attribution_maps"
    if attribution_maps_dir.exists():
        png_files = list(attribution_maps_dir.glob("*.png"))
        print(f"   Attribution maps: {len(png_files)} files")
        for png_file in png_files[:5]:  # Show first 5
            print(f"     - {png_file.name}")

    print(f"\n🎉 ATTRIBUTION MAP GENERATION COMPLETE!")
    print(f"📊 Summary:")
    print(f"   • Samples processed: {len(all_results)}")
    print(f"   • CNN attributions: ✅")
    print(f"   • ScatNet attributions: {'✅' if scatnet_model is not None else '❌'}")
    print(f"   • Files saved to: {EXPLAINABILITY_PATH}")

    return results_summary

# =============================================================================
# 🚀 EXECUTE ATTRIBUTION GENERATION
# =============================================================================

# Announce that the definitions above are in place and the run is starting.
print(
    "\n" + "=" * 60,
    "📋 ATTRIBUTION MAPS GENERATION: READY",
    "🚀 Starting attribution generation...",
    sep="\n",
)

# Run the pipeline. Any exception is reported with a traceback rather than
# propagated, so the closing banner below is always printed.
try:
    results = main_attribution_generation()
    print(
        "\n🎉 Attribution generation completed successfully!"
        if results
        else "\n❌ Attribution generation encountered issues"
    )
except Exception as e:
    print(f"\n❌ Error during attribution generation: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "=" * 60, "📋 SCRIPT COMPLETE!", sep="\n")

🎯 Attribution Maps Generation - CNN vs ScatNet
📋 This notebook generates attribution maps for CNN and ScatNet models
🧠 Uses DeepLIFT for CNN and perturbation analysis for ScatNet

📋 ATTRIBUTION MAPS GENERATION: READY
🚀 Starting attribution generation...

🚀 MAIN ATTRIBUTION MAP GENERATION
📁 Project paths:
   Root: D:\University\4th Semester\4. Visual Intelligence\Project
   Models: D:\University\4th Semester\4. Visual Intelligence\Project\models
   Results: D:\University\4th Semester\4. Visual Intelligence\Project\results
   Explainability: D:\University\4th Semester\4. Visual Intelligence\Project\results\explainability

🔍 Path verification:
   Project root exists: True
   Models folder exists: True
   Results folder exists: True

🔍 Looking for models...
📋 Available .pth files: ['best_cnn_model.pth', 'best_scatnet_model.pth', 'cnn_architecture.pth', 'cnn_final_trained.pth', 'cnn_fold_1_best.pth', 'cnn_fold_2_best.pth', 'cnn_fold_3_best.pth', 'cnn_fold_4_best.pth', 'cnn_fold_5_best.pth',