# Grad-CAM Visualization for Model Interpretability

This notebook demonstrates explainable-AI techniques for understanding which image regions the trained models focus on when making cancer-detection decisions:
- Grad-CAM implementation and visualization
- Attention analysis across different models
- Clinical interpretation of model focus
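- Error case analysis with visualizations (no ground-truth labels are used, so low-confidence predictions serve as a proxy for likely errors)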

**Authors:** Sneh Gupta and Arpit Bhardwaj  
**Course:** CSET211 - Statistical Machine Learning

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
from pathlib import Path
import warnings

# Add src to path
sys.path.append('../src')

from gradcam import GradCAM, GradCAMVisualizer
from models import get_model
from data_loader import get_transforms
from utils import load_config

# Configuration
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')

# Reproducibility: seed the global NumPy and PyTorch RNGs so dummy inputs and
# (when no checkpoint is found) the randomly initialized model are identical
# on every Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Paths (relative to the notebook's working directory)
CHECKPOINTS_DIR = '../experiments/checkpoints'
SAMPLE_IMAGES_DIR = '../data/raw/images'
RESULTS_DIR = '../experiments/gradcam_results'

# Create results directory
os.makedirs(RESULTS_DIR, exist_ok=True)
print(f"Results will be saved to: {RESULTS_DIR}")

## 1. Load Trained Model

In [None]:
# Configuration for model loading
config = {
    'model': {
        'architecture': 'resnet50',  # Change this to match your trained model
        'num_classes': 1,
        'pretrained': True,
        'dropout': 0.5
    }
}

# Build the architecture, then try to restore trained weights from the checkpoint.
model = get_model(config).to(device)
checkpoint_path = os.path.join(CHECKPOINTS_DIR, 'best_model.pth')
trained_model_available = False

if not os.path.exists(checkpoint_path):
    print(f"✗ Checkpoint not found at {checkpoint_path}")
    print("Using randomly initialized model for demonstration")
else:
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        trained_model_available = True
        print(f"✓ Loaded trained model from {checkpoint_path}")
        print(f"Model trained for {checkpoint.get('epoch', 'unknown')} epochs")
        # Format the AUC on its own: applying :.4f to a missing/non-numeric value
        # would raise inside this try and be misreported as a failed load.
        try:
            auc_text = f"{float(checkpoint.get('val_auc')):.4f}"
        except (TypeError, ValueError):
            auc_text = 'unknown'
        print(f"Best validation AUC: {auc_text}")
    except Exception as e:
        print(f"✗ Error loading checkpoint: {e}")
        print("Using randomly initialized model for demonstration")
        # A strict load_state_dict that raises may already have copied some
        # matching tensors; rebuild so the fallback really is freshly initialized.
        model = get_model(config).to(device)

# eval() in every branch so dropout/batch-norm layers behave as at inference time.
model.eval()

print(f"\nModel architecture: {config['model']['architecture']}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

## 2. Grad-CAM Setup and Testing

In [None]:
# Initialize the Grad-CAM visualizer and smoke-test it on a random input.
# `gradcam_working` gates every later analysis cell, so both names are always
# defined even when initialization fails.
gradcam_viz = None
gradcam_working = False

try:
    gradcam_viz = GradCAMVisualizer(model, device)
    print("✓ Grad-CAM visualizer initialized successfully")

    # Test with a random 224x224 RGB input
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        output = model(dummy_input)
        prediction = torch.sigmoid(output).item()

    print("✓ Model forward pass successful")
    print(f"Dummy input prediction: {prediction:.4f}")

    # Grad-CAM generation is tested separately so its failure is reported on its own
    try:
        heatmap = gradcam_viz.gradcam.generate_gradcam(dummy_input)
        print(f"✓ Grad-CAM heatmap generated: {heatmap.shape}")
        print(f"Heatmap range: [{np.min(heatmap):.3f}, {np.max(heatmap):.3f}]")
        gradcam_working = True
    except Exception as e:
        print(f"✗ Grad-CAM generation failed: {e}")

except Exception as e:
    print(f"✗ Error initializing Grad-CAM: {e}")

## 3. Load Sample Images

In [None]:
# Find sample images for visualization.
# Only raster formats are collected: every analysis cell opens images with
# PIL's Image.open, which cannot decode DICOM (.dcm/.dicom) files.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
MAX_PER_EXTENSION = 5
MAX_SAMPLE_IMAGES = 10


def find_sample_images(image_dir, extensions=IMAGE_EXTENSIONS,
                       per_extension=MAX_PER_EXTENSION, limit=MAX_SAMPLE_IMAGES):
    """Return up to `limit` image paths from `image_dir` (non-recursive).

    At most `per_extension` files are taken per extension, in sorted order so
    re-runs pick the same files. Extensions are matched case-insensitively,
    so `scan.JPG` is found as well as `scan.jpg`.
    """
    files = sorted(p for p in Path(image_dir).iterdir() if p.is_file())
    found = []
    for ext in extensions:
        matches = [str(p) for p in files if p.suffix.lower() == ext]
        found.extend(matches[:per_extension])
        if len(found) >= limit:
            break
    return found[:limit]


def create_synthetic_samples(output_dir, n_images=3, seed=0):
    """Write `n_images` synthetic 256x256 grayscale 'chest X-ray' PNGs and return their paths.

    Each image is a dim noisy background with two brighter discs standing in
    for the lungs. A dedicated seeded generator makes the files reproducible.
    """
    rng = np.random.default_rng(seed)
    os.makedirs(output_dir, exist_ok=True)

    y, x = np.ogrid[:256, :256]
    center_y = 128
    lung_masks = [((x - lung_x) ** 2 + (y - center_y) ** 2) < 3000 for lung_x in (80, 176)]

    paths = []
    for i in range(n_images):
        img = rng.random((256, 256)) * 0.3 + 0.2  # dark background
        for mask in lung_masks:
            img[mask] += 0.4  # lung regions (brighter)
        img += rng.normal(0, 0.1, img.shape)
        img = np.clip(img, 0, 1)

        # A 2-D uint8 array is stored as 8-bit grayscale ('L') automatically.
        img_path = os.path.join(output_dir, f'synthetic_xray_{i+1}.png')
        Image.fromarray((img * 255).astype(np.uint8)).save(img_path)
        paths.append(img_path)
    return paths


sample_images = []
if os.path.isdir(SAMPLE_IMAGES_DIR):
    sample_images = find_sample_images(SAMPLE_IMAGES_DIR)
    print(f"Found {len(sample_images)} sample images")
    for i, img_path in enumerate(sample_images[:5]):
        print(f"  {i+1}. {os.path.basename(img_path)}")
else:
    print(f"Sample images directory not found: {SAMPLE_IMAGES_DIR}")

# Fall back to synthetic images whenever no readable image was found, not only
# when the directory is missing.
if not sample_images:
    print("Creating synthetic sample images for demonstration...")
    sample_images = create_synthetic_samples(os.path.join(RESULTS_DIR, 'synthetic_samples'))
    print(f"Created {len(sample_images)} synthetic sample images")

print(f"\nTotal sample images available: {len(sample_images)}")

## 4. Single Image Grad-CAM Analysis

In [None]:
def analyze_single_image_detailed(image_path, gradcam_viz, save_prefix=None):
    """Run the model and Grad-CAM on one image and show a 2x2 summary figure.

    Panels: original image, raw Grad-CAM heatmap, heatmap overlaid on the image,
    and a text summary built from ``gradcam_viz.analyze_attention_patterns``.
    Uses the notebook-level ``model``, ``device`` and ``RESULTS_DIR``.

    Args:
        image_path: Path of an image PIL can open; it is converted to RGB.
        gradcam_viz: ``GradCAMVisualizer`` wrapping ``model``.
        save_prefix: If given, the figure is saved as
            ``RESULTS_DIR/<save_prefix>_<image stem>.png``.

    Returns:
        Dict with ``image_path``, ``prediction`` (sigmoid probability),
        ``heatmap`` and ``analysis``; None if any step raised (the error is printed).
    """
    fig = None
    try:
        # Load and preprocess image
        original_image = Image.open(image_path).convert('RGB')
        input_tensor = gradcam_viz.preprocess(original_image).unsqueeze(0).to(device)

        # Get model prediction
        with torch.no_grad():
            prediction = torch.sigmoid(model(input_tensor)).item()
        label = 'Cancer' if prediction > 0.5 else 'No Cancer'
        confidence = max(prediction, 1 - prediction)

        print(f"\nAnalyzing: {os.path.basename(image_path)}")
        print(f"Prediction: {prediction:.4f} ({label})")
        print(f"Confidence: {confidence:.4f}")

        # Generate Grad-CAM and summarize its attention pattern
        heatmap = gradcam_viz.gradcam.generate_gradcam(input_tensor)
        analysis = gradcam_viz.analyze_attention_patterns(heatmap)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Original image
        axes[0, 0].imshow(original_image)
        axes[0, 0].set_title(f'Original Image\n{os.path.basename(image_path)}')
        axes[0, 0].axis('off')

        # Grad-CAM heatmap
        im = axes[0, 1].imshow(heatmap, cmap='jet')
        axes[0, 1].set_title('Grad-CAM Heatmap')
        axes[0, 1].axis('off')
        plt.colorbar(im, ax=axes[0, 1])

        # Superimposed image
        superimposed = gradcam_viz.gradcam.superimpose_heatmap(
            np.array(original_image), heatmap, alpha=0.4
        )
        axes[1, 0].imshow(superimposed)
        axes[1, 0].set_title(f'Superimposed\nPrediction: {prediction:.3f}')
        axes[1, 0].axis('off')

        # This tier describes the cancer probability itself, not the model's
        # confidence: a 0.05 prediction is a confident "No Cancer", not "Low confidence".
        probability_tier = 'High' if prediction > 0.7 else 'Moderate' if prediction > 0.3 else 'Low'
        analysis_text = "\n".join([
            "Prediction Analysis:",
            f"Probability: {prediction:.4f}",
            f"Class: {label}",
            f"Confidence: {confidence:.4f}",
            "",
            "Attention Statistics:",
            f"Max Activation: {analysis['max_activation']:.3f}",
            f"Mean Activation: {analysis['mean_activation']:.3f}",
            f"Focus Percentage: {analysis['focus_percentage']:.1f}%",
            f"Attention Spread: {analysis['attention_spread']:.3f}",
            "",
            "Clinical Notes:",
            f"- {probability_tier} cancer probability",
            f"- {'Concentrated' if analysis['focus_percentage'] > 15 else 'Distributed'} attention pattern",
            f"- Model focus on {'specific regions' if analysis['attention_spread'] < 0.3 else 'broad areas'}",
        ])

        axes[1, 1].text(0.05, 0.95, analysis_text, transform=axes[1, 1].transAxes,
                        verticalalignment='top', fontfamily='monospace', fontsize=10,
                        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
        axes[1, 1].set_xlim(0, 1)
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].axis('off')
        axes[1, 1].set_title('Analysis Summary')

        plt.tight_layout()

        if save_prefix:
            # Use the stem so 'scan.png' is saved as '<prefix>_scan.png', not '<prefix>_scan.png.png'.
            save_path = os.path.join(RESULTS_DIR, f'{save_prefix}_{Path(image_path).stem}.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved analysis to: {save_path}")

        plt.show()

        return {
            'image_path': image_path,
            'prediction': prediction,
            'heatmap': heatmap,
            'analysis': analysis
        }

    except Exception as e:
        print(f"Error analyzing {os.path.basename(image_path)}: {e}")
        # Don't leave a half-drawn figure open for a later cell to render.
        if fig is not None:
            plt.close(fig)
        return None

# Run the detailed analysis on the first sample image, if possible.
result = None
if not (sample_images and gradcam_working):
    print("Cannot perform single image analysis - no images or Grad-CAM not working")
else:
    result = analyze_single_image_detailed(sample_images[0], gradcam_viz, save_prefix='single_analysis')

## 5. Batch Grad-CAM Analysis

In [None]:
def analyze_image_batch(image_paths, gradcam_viz, max_images=6):
    """Run the model and Grad-CAM on up to ``max_images`` images and plot a grid.

    Images are laid out three per "band"; each band has three rows (original,
    heatmap, overlay), so 4-6 images use two bands. Images that fail to load or
    process are reported and skipped. The figure is saved to
    ``RESULTS_DIR/batch_gradcam_analysis.png``. Uses notebook-level ``model``
    and ``device``.

    Returns:
        List of dicts with ``image_path``, ``image`` (PIL RGB image),
        ``prediction`` and ``heatmap`` for each successfully processed image.
    """
    results = []
    selected_paths = image_paths[:max_images]
    print(f"Analyzing batch of {len(selected_paths)} images...")

    for image_path in selected_paths:
        try:
            original_image = Image.open(image_path).convert('RGB')
            input_tensor = gradcam_viz.preprocess(original_image).unsqueeze(0).to(device)

            with torch.no_grad():
                prediction = torch.sigmoid(model(input_tensor)).item()

            heatmap = gradcam_viz.gradcam.generate_gradcam(input_tensor)

            results.append({
                'image_path': image_path,
                'image': original_image,
                'prediction': prediction,
                'heatmap': heatmap
            })
            print(f"  ✓ {os.path.basename(image_path)}: {prediction:.3f}")

        except Exception as e:
            print(f"  ✗ Error processing {os.path.basename(image_path)}: {e}")

    if not results:
        return results

    n_images = len(results)
    n_cols = min(3, n_images)
    n_bands = -(-n_images // n_cols)  # ceil division: one 3-row band per row of images

    # squeeze=False keeps `axes` 2-D for every grid shape, including a single image.
    fig, axes = plt.subplots(n_bands * 3, n_cols, figsize=(5 * n_cols, 4 * n_bands * 3),
                             squeeze=False)

    for idx, result in enumerate(results):
        # Offset each image into its own band so images 4-6 do not draw over 1-3.
        top = (idx // n_cols) * 3
        col = idx % n_cols

        # Original image
        axes[top, col].imshow(result['image'])
        axes[top, col].set_title(f"Original\n{os.path.basename(result['image_path'])}")
        axes[top, col].axis('off')

        # Grad-CAM heatmap
        axes[top + 1, col].imshow(result['heatmap'], cmap='jet')
        axes[top + 1, col].set_title('Grad-CAM')
        axes[top + 1, col].axis('off')

        # Superimposed
        superimposed = gradcam_viz.gradcam.superimpose_heatmap(
            np.array(result['image']), result['heatmap'], alpha=0.4
        )
        axes[top + 2, col].imshow(superimposed)
        axes[top + 2, col].set_title(f'Overlay\nPred: {result["prediction"]:.3f}')
        axes[top + 2, col].axis('off')

    # Remove only the unused cells at the end of the last band.
    for idx in range(n_images, n_bands * n_cols):
        top = (idx // n_cols) * 3
        col = idx % n_cols
        for offset in range(3):
            axes[top + offset, col].remove()

    plt.tight_layout()
    plt.suptitle('Batch Grad-CAM Analysis', fontsize=16, y=1.02)

    save_path = os.path.join(RESULTS_DIR, 'batch_gradcam_analysis.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nBatch analysis saved to: {save_path}")

    plt.show()

    return results

# A batch analysis needs a working Grad-CAM and at least two images.
batch_results = []
if sample_images and gradcam_working and len(sample_images) > 1:
    batch_results = analyze_image_batch(sample_images, gradcam_viz, max_images=6)
else:
    print("Cannot perform batch analysis - insufficient images or Grad-CAM not working")

## 6. Attention Pattern Analysis

In [None]:
def _annotate_correlation(ax, x_values, y_values):
    """Write the Pearson correlation of two sequences in the top-left corner of `ax`.

    Skipped when there are fewer than two points or either sequence is
    constant, because the coefficient is undefined (NaN) in those cases.
    """
    if len(x_values) < 2 or np.std(x_values) == 0 or np.std(y_values) == 0:
        return
    corr = np.corrcoef(x_values, y_values)[0, 1]
    ax.text(0.05, 0.95, f'Corr: {corr:.3f}', transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))


def analyze_attention_patterns(results, focus_threshold=0.5):
    """Summarize and plot Grad-CAM attention statistics across several images.

    Args:
        results: Sequence of dicts with at least 'prediction' (float) and
            'heatmap' (array-like), e.g. the output of analyze_image_batch.
        focus_threshold: A pixel counts as "focused" when its value is at least
            this fraction of its heatmap's maximum. (A per-image percentile cut
            is not used: the share of pixels above a heatmap's own 90th
            percentile is ~10% for every heatmap by construction.)

    Returns:
        List with one dict per image ('image_idx', 'prediction',
        'max_attention', 'mean_attention', 'std_attention', 'focused_pixels',
        'focus_percentage'); an empty list when `results` is empty. The figure
        is saved to ``RESULTS_DIR/attention_pattern_analysis.png``.
    """
    if not results:
        print("No results available for attention analysis")
        return []

    print(f"Analyzing attention patterns across {len(results)} images...")

    heatmaps = [np.asarray(r['heatmap'], dtype=float) for r in results]
    attention_stats = []
    for i, (result, heatmap) in enumerate(zip(results, heatmaps)):
        peak = heatmap.max()
        # A heatmap with no positive activation has no focused region.
        focused = int(np.count_nonzero(heatmap >= focus_threshold * peak)) if peak > 0 else 0
        attention_stats.append({
            'image_idx': i,
            'prediction': result['prediction'],
            'max_attention': peak,
            'mean_attention': heatmap.mean(),
            'std_attention': heatmap.std(),
            'focused_pixels': focused,
            'focus_percentage': focused / heatmap.size * 100,
        })

    preds = [s['prediction'] for s in attention_stats]
    max_attns = [s['max_attention'] for s in attention_stats]
    mean_attns = [s['mean_attention'] for s in attention_stats]
    std_attns = [s['std_attention'] for s in attention_stats]
    focus_pcts = [s['focus_percentage'] for s in attention_stats]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Top row: prediction against three per-image attention summaries.
    scatter_panels = [
        (axes[0, 0], max_attns, None, 'Max Attention', 'Prediction vs Max Attention'),
        (axes[0, 1], mean_attns, 'orange', 'Mean Attention', 'Prediction vs Mean Attention'),
        (axes[0, 2], focus_pcts, 'green', 'Focus Percentage (%)', 'Prediction vs Focus Concentration'),
    ]
    for ax, values, color, ylabel, title in scatter_panels:
        ax.scatter(preds, values, alpha=0.7, s=100, color=color)
        ax.set_xlabel('Prediction Probability')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        _annotate_correlation(ax, preds, values)

    # Pooled distribution of every heatmap value
    all_attentions = np.concatenate([h.ravel() for h in heatmaps])
    axes[1, 0].hist(all_attentions, bins=50, alpha=0.7, color='purple', edgecolor='black')
    axes[1, 0].set_xlabel('Attention Value')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Overall Attention Distribution')
    axes[1, 0].axvline(np.mean(all_attentions), color='red', linestyle='--', label='Mean')
    axes[1, 0].legend()

    # Attention variability
    axes[1, 1].scatter(mean_attns, std_attns, alpha=0.7, s=100, color='red')
    axes[1, 1].set_xlabel('Mean Attention')
    axes[1, 1].set_ylabel('Attention Std Dev')
    axes[1, 1].set_title('Attention Mean vs Variability')
    axes[1, 1].grid(True, alpha=0.3)

    # Summary statistics
    stats_text = "\n".join([
        "Attention Analysis Summary:",
        "",
        f"Images analyzed: {len(results)}",
        "",
        "Prediction Range:",
        f"  Min: {min(preds):.3f}",
        f"  Max: {max(preds):.3f}",
        f"  Mean: {np.mean(preds):.3f}",
        "",
        "Attention Statistics:",
        f"  Mean Max Attention: {np.mean(max_attns):.3f}",
        f"  Mean Focus %: {np.mean(focus_pcts):.1f}%",
        "",
        f"High Confidence Predictions: {sum(1 for p in preds if abs(p-0.5) > 0.3)}",
        f"Low Confidence Predictions: {sum(1 for p in preds if abs(p-0.5) <= 0.3)}",
    ])
    axes[1, 2].text(0.05, 0.95, stats_text, transform=axes[1, 2].transAxes,
                    verticalalignment='top', fontfamily='monospace', fontsize=10,
                    bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    axes[1, 2].set_xlim(0, 1)
    axes[1, 2].set_ylim(0, 1)
    axes[1, 2].axis('off')
    axes[1, 2].set_title('Summary Statistics')

    plt.tight_layout()
    plt.suptitle('Attention Pattern Analysis', fontsize=16, y=1.02)

    save_path = os.path.join(RESULTS_DIR, 'attention_pattern_analysis.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Attention analysis saved to: {save_path}")

    plt.show()

    return attention_stats

# Attention statistics are only available when the batch analysis produced results.
attention_stats = []
if not batch_results:
    print("No batch results available for attention pattern analysis")
else:
    attention_stats = analyze_attention_patterns(batch_results)

## 7. Clinical Interpretation Guidelines

In [None]:
def generate_clinical_interpretation(prediction, attention_stats):
    """Turn a cancer probability and optional attention statistics into report text.

    Args:
        prediction: Sigmoid probability of cancer.
        attention_stats: Optional dict; reads 'focus_percentage' and
            'max_attention' (each defaulting to 0). A falsy value skips the
            attention part of the interpretation.

    Returns:
        Dict with 'interpretation' (list of sentences), 'recommendations'
        (list of actions, always ending with the decision-aid disclaimer),
        'confidence_level' ('High' or 'Moderate') and 'model_confidence'
        (max(prediction, 1 - prediction)).
    """
    # Probability tiers checked top-down; the first cutoff the prediction meets wins.
    tiers = (
        (0.8, "HIGH probability of malignant findings detected.", "High", (
            "Immediate radiologist review recommended",
            "Consider additional imaging (CT scan) if clinically indicated",
            "Follow-up with oncology consultation",
        )),
        (0.6, "MODERATE probability of suspicious findings.", "Moderate", (
            "Radiologist review within 24-48 hours",
            "Consider repeat imaging in 3-6 months",
            "Correlate with clinical symptoms",
        )),
        (0.4, "LOW to MODERATE probability of abnormal findings.", "Moderate", (
            "Routine radiologist review",
            "Standard follow-up protocols",
            "Consider patient history and symptoms",
        )),
    )
    for cutoff, finding, level, actions in tiers:
        if prediction >= cutoff:
            break
    else:
        # Below every cutoff: low-probability tier.
        finding = "LOW probability of malignant findings."
        level = "High" if prediction < 0.2 else "Moderate"
        actions = (
            "Routine radiologist review",
            "Standard follow-up protocols",
            "Continue regular screening as appropriate",
        )

    interpretation = [finding]
    recommendations = list(actions)
    confidence_level = level

    if attention_stats:
        focus_pct = attention_stats.get('focus_percentage', 0)
        peak_attention = attention_stats.get('max_attention', 0)

        # NOTE(review): a larger focus percentage is read as *more* concentrated
        # attention here; confirm this matches how focus_percentage is computed.
        if focus_pct > 15:
            interpretation += [
                "Model attention is HIGHLY CONCENTRATED on specific regions.",
                "This suggests the presence of localized abnormalities.",
            ]
        elif focus_pct > 5:
            interpretation.append("Model attention is MODERATELY FOCUSED on specific regions.")
        else:
            interpretation += [
                "Model attention is DIFFUSE across the image.",
                "No specific regions of high concern identified.",
            ]

        if peak_attention > 0.8:
            interpretation.append("Very strong activation detected in attention regions.")
        elif peak_attention > 0.5:
            interpretation.append("Moderate activation detected in attention regions.")

    # Distance-from-decision confidence, independent of the tier above.
    model_confidence = max(prediction, 1 - prediction)
    if model_confidence < 0.7:
        recommendations.append("LOW MODEL CONFIDENCE - Prioritize human expert review")
        interpretation.append("Model shows uncertainty in this case.")

    recommendations.append("IMPORTANT: This AI system is a diagnostic aid only - clinical judgment should always take precedence")

    return {
        'interpretation': interpretation,
        'recommendations': recommendations,
        'confidence_level': confidence_level,
        'model_confidence': model_confidence
    }

def display_clinical_report(image_path, prediction, attention_stats=None):
    """Print a formatted clinical interpretation report for one image.

    Args:
        image_path: Image the prediction belongs to (only its basename is printed).
        prediction: Sigmoid cancer probability from the model.
        attention_stats: Optional dict with 'focus_percentage', 'max_attention'
            and 'mean_attention' (missing keys are printed as 0).

    Returns:
        The dict produced by generate_clinical_interpretation.
    """
    # pandas is not imported in this notebook's import cell, so pd.Timestamp
    # raised NameError; stdlib datetime yields the same timestamp format.
    from datetime import datetime

    clinical_info = generate_clinical_interpretation(prediction, attention_stats)

    print("=" * 70)
    print("                    CLINICAL INTERPRETATION REPORT")
    print("=" * 70)
    print(f"Image: {os.path.basename(image_path)}")
    print(f"Analysis Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Model Architecture: {config['model']['architecture']}")
    print("-" * 70)

    print("PREDICTION RESULTS:")
    print(f"  Probability Score: {prediction:.4f}")
    print(f"  Binary Prediction: {'CANCER' if prediction > 0.5 else 'NO CANCER'}")
    print(f"  Model Confidence: {clinical_info['model_confidence']:.4f} ({clinical_info['confidence_level']})")

    if attention_stats:
        print("\nATTENTION ANALYSIS:")
        print(f"  Focus Percentage: {attention_stats.get('focus_percentage', 0):.1f}%")
        print(f"  Max Attention: {attention_stats.get('max_attention', 0):.3f}")
        print(f"  Mean Attention: {attention_stats.get('mean_attention', 0):.3f}")

    print("\nCLINICAL INTERPRETATION:")
    for i, interp in enumerate(clinical_info['interpretation'], 1):
        print(f"  {i}. {interp}")

    print("\nRECOMMENDATIONS:")
    for i, rec in enumerate(clinical_info['recommendations'], 1):
        print(f"  {i}. {rec}")

    print("=" * 70)

    return clinical_info

# Generate a clinical report for the first analyzed image. `clinical_report`
# is always defined (None when there is nothing to report), like the fallbacks
# used by the other analysis cells.
clinical_report = None
if batch_results:
    first_result = batch_results[0]
    first_attention = attention_stats[0] if attention_stats else None

    clinical_report = display_clinical_report(
        first_result['image_path'],
        first_result['prediction'],
        first_attention
    )
else:
    print("No results available for clinical interpretation")

## 8. Summary and Insights

In [None]:
print("Grad-CAM Visualization Analysis Summary")
print("=" * 50)

if gradcam_working:
    print("✓ Grad-CAM successfully implemented and tested")

    if batch_results:
        n_analyzed = len(batch_results)
        predictions = [r['prediction'] for r in batch_results]

        print(f"✓ Successfully analyzed {n_analyzed} images")
        print(f"  - Prediction range: [{min(predictions):.3f}, {max(predictions):.3f}]")
        print(f"  - Mean prediction: {np.mean(predictions):.3f}")
        print(f"  - Cancer predictions (>0.5): {sum(1 for p in predictions if p > 0.5)}")

        if attention_stats:
            focus_percentages = [s['focus_percentage'] for s in attention_stats]
            print(f"  - Mean focus percentage: {np.mean(focus_percentages):.1f}%")
            print(f"  - Highly focused images (>15%): {sum(1 for f in focus_percentages if f > 15)}")
else:
    print("✗ Grad-CAM implementation encountered issues")

print("\nKey Insights:")
print("-" * 15)

insights = [
    "• Grad-CAM provides valuable interpretability for cancer detection models",
    "• Attention patterns vary significantly across different images",
    "• Higher prediction confidence often correlates with focused attention",
    "• Clinical interpretation requires combining prediction scores with attention analysis",
    "• Model explanations can aid radiologists in understanding AI decisions",
    "• Visual explanations help identify potential biases or artifacts",
    "• Different architectures may show different attention patterns"
]

for insight in insights:
    print(insight)

print(f"\nAnalysis complete! All visualizations and reports saved to: {RESULTS_DIR}")

if not trained_model_available:
    print("\nNote: Analysis was performed with a randomly initialized model.")
    print("For meaningful clinical insights, please train the model first using the baseline notebook.")

print("\nRecommendations for Further Analysis:")
print("- Train models with larger datasets for better clinical relevance")
print("- Compare attention patterns across different model architectures")
print("- Validate Grad-CAM interpretations with radiologist annotations")
print("- Analyze failure cases to identify model limitations")
print("- Implement additional explainability techniques (LIME, SHAP, etc.)")

## 9. Error Case Analysis (Low-Confidence Predictions)

In [None]:
def analyze_prediction_errors(batch_results, threshold=0.5):
    """Bucket predictions by confidence and visualize the uncertain ones.

    No ground-truth labels are available here, so likely errors are
    approximated by low-confidence predictions (probability in [0.3, 0.7]).
    Up to three such cases are plotted (original image and Grad-CAM overlay,
    using the notebook-level ``gradcam_viz``) and the figure is saved to
    ``RESULTS_DIR/error_case_analysis.png``.

    Args:
        batch_results: Dicts as returned by analyze_image_batch. Only
            'prediction' is read unless there are low-confidence cases to plot,
            which also need 'image_path', 'image' and 'heatmap'.
        threshold: Not used by the bucketing below; kept so existing calls
            that pass it keep working.

    Returns:
        Dict with lists 'high_confidence_cancer' (p > 0.7),
        'high_confidence_normal' (p < 0.3) and 'low_confidence' (the rest).
        The dict is always returned, with empty lists for empty input.
    """
    buckets = {
        'high_confidence_cancer': [],
        'high_confidence_normal': [],
        'low_confidence': [],
    }

    if not batch_results:
        print("No batch results available for error analysis")
        return buckets

    print("Analyzing Potential Prediction Errors...")
    print("=" * 50)

    for result in batch_results:
        pred = result['prediction']
        if pred > 0.7:
            buckets['high_confidence_cancer'].append(result)
        elif pred < 0.3:
            buckets['high_confidence_normal'].append(result)
        else:
            buckets['low_confidence'].append(result)

    low_confidence = buckets['low_confidence']
    print(f"High Confidence Cancer (>0.7): {len(buckets['high_confidence_cancer'])} cases")
    print(f"High Confidence Normal (<0.3): {len(buckets['high_confidence_normal'])} cases")
    print(f"Low Confidence (0.3-0.7): {len(low_confidence)} cases")

    if low_confidence:
        print(f"\nAnalyzing {len(low_confidence)} Low Confidence Cases:")

        shown = low_confidence[:3]  # at most three columns are plotted
        n_cols = len(shown)
        # squeeze=False keeps `axes` 2-D even for a single case.
        fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8), squeeze=False)

        for col, result in enumerate(shown):
            pred = result['prediction']

            print(f"  Case {col+1}: {os.path.basename(result['image_path'])}")
            print(f"    Prediction: {pred:.4f} (Confidence: {max(pred, 1-pred):.4f})")

            # Original image
            axes[0, col].imshow(result['image'])
            axes[0, col].set_title(f"Low Confidence Case {col+1}\nPred: {pred:.3f}")
            axes[0, col].axis('off')

            # Grad-CAM overlay
            superimposed = gradcam_viz.gradcam.superimpose_heatmap(
                np.array(result['image']), result['heatmap'], alpha=0.4
            )
            axes[1, col].imshow(superimposed)
            # |p - 0.5| is the distance from the decision boundary (smaller = less certain).
            axes[1, col].set_title(f"Attention Pattern\nMargin from 0.5: {abs(pred-0.5):.3f}")
            axes[1, col].axis('off')

        plt.tight_layout()
        plt.suptitle('Low Confidence Prediction Analysis', fontsize=14, y=1.02)

        save_path = os.path.join(RESULTS_DIR, 'error_case_analysis.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nError case analysis saved to: {save_path}")

        plt.show()

    return buckets

# Error-case analysis runs only on the batch results produced above.
error_analysis = {}
if not batch_results:
    print("No batch results available for error analysis")
else:
    error_analysis = analyze_prediction_errors(batch_results)

## 10. Multi-Model Comparison (Optional)

In [None]:
def compare_models_gradcam(image_path, model_configs, device):
    """Compare Grad-CAM visualizations across different model architectures.

    Each entry in ``model_configs`` is built with ``get_model`` and evaluated on
    the same image. No checkpoint is loaded here, so the models run with
    whatever weights ``get_model`` gives them. Architectures that fail to build
    or run are reported and skipped. The figure and summary are produced for
    any non-empty result set, including a single architecture; the figure is
    saved to ``RESULTS_DIR/multi_model_comparison.png``.

    Args:
        image_path: Image to analyze (converted to RGB).
        model_configs: Mapping of display name -> model config dict (the value
            placed under the 'model' key passed to get_model).
        device: torch.device to run on.

    Returns:
        Dict of display name -> {'prediction', 'heatmap', 'model'} for every
        architecture that ran successfully.
    """
    print(f"Comparing models on: {os.path.basename(image_path)}")
    print("-" * 50)

    # Load the image once; each model applies its own preprocessing.
    original_image = Image.open(image_path).convert('RGB')
    results = {}
    visualizers = {}

    for model_name, model_cfg in model_configs.items():
        try:
            candidate = get_model({'model': model_cfg}).to(device)
            candidate.eval()

            viz = GradCAMVisualizer(candidate, device)
            input_tensor = viz.preprocess(original_image).unsqueeze(0).to(device)

            with torch.no_grad():
                prediction = torch.sigmoid(candidate(input_tensor)).item()

            heatmap = viz.gradcam.generate_gradcam(input_tensor)

            results[model_name] = {
                'prediction': prediction,
                'heatmap': heatmap,
                'model': candidate
            }
            visualizers[model_name] = viz

            print(f"✓ {model_name}: Prediction = {prediction:.4f}")

        except Exception as e:
            print(f"✗ Error with {model_name}: {e}")

    if not results:
        return results

    n_models = len(results)
    # squeeze=False keeps `axes` 2-D even when only one model succeeded.
    fig, axes = plt.subplots(3, n_models, figsize=(4 * n_models, 12), squeeze=False)

    for idx, (model_name, result) in enumerate(results.items()):
        # Overlay with this model's own visualizer, not whichever one was created last.
        viz = visualizers[model_name]

        # Original image (file name shown only in the first column)
        axes[0, idx].imshow(original_image)
        axes[0, idx].set_title(
            f'Original Image\n{os.path.basename(image_path)}' if idx == 0 else 'Original Image'
        )
        axes[0, idx].axis('off')

        # Grad-CAM heatmap
        axes[1, idx].imshow(result['heatmap'], cmap='jet')
        axes[1, idx].set_title(f'{model_name}\nGrad-CAM')
        axes[1, idx].axis('off')

        # Superimposed
        superimposed = viz.gradcam.superimpose_heatmap(
            np.array(original_image), result['heatmap'], alpha=0.4
        )
        axes[2, idx].imshow(superimposed)
        axes[2, idx].set_title(f'{model_name}\nPred: {result["prediction"]:.3f}')
        axes[2, idx].axis('off')

    plt.tight_layout()
    plt.suptitle('Multi-Model Grad-CAM Comparison', fontsize=16, y=1.02)

    save_path = os.path.join(RESULTS_DIR, 'multi_model_comparison.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nMulti-model comparison saved to: {save_path}")

    plt.show()

    print("\nModel Comparison Summary:")
    print("-" * 30)
    for model_name, result in results.items():
        pred = result['prediction']
        confidence = max(pred, 1 - pred)
        decision = "Cancer" if pred > 0.5 else "Normal"
        print(f"{model_name:12}: {pred:.4f} ({decision}, Conf: {confidence:.3f})")

    return results

# Example model comparison. It is opt-in because compare_models_gradcam builds
# every architecture below without loading the trained checkpoint; set the
# flag to True to run it.
RUN_MODEL_COMPARISON = False

model_comparison_configs = {
    'ResNet18': {
        'architecture': 'resnet18',
        'num_classes': 1,
        'pretrained': True,
        'dropout': 0.5
    },
    'ResNet50': {
        'architecture': 'resnet50',
        'num_classes': 1,
        'pretrained': True,
        'dropout': 0.5
    },
    'Custom CNN': {
        'architecture': 'custom_cnn',
        'num_classes': 1,
        'pretrained': False,
        'dropout': 0.5
    }
}

if not RUN_MODEL_COMPARISON:
    print("Multi-model comparison function defined (set RUN_MODEL_COMPARISON = True to use)")
elif sample_images:
    model_comparison = compare_models_gradcam(sample_images[0], model_comparison_configs, device)
else:
    print("No sample images available for multi-model comparison")

## 11. Advanced Visualization Techniques

In [None]:
def analyze_layer_attention(model, image_tensor, target_layers=None):
    """Plot channel-averaged activation maps and their value distributions per layer.

    Forward hooks are registered on named child modules of ``model``. A single
    forward pass fills them, and each captured activation is averaged over its
    channel dimension. The result is drawn as a heatmap (top row) and a
    histogram (bottom row). The figure is saved to
    RESULTS_DIR/layer_attention_analysis.png.

    The probe runs with the model in eval mode. In train mode, BatchNorm would
    update its running statistics even under ``torch.no_grad()``, and dropout
    would make the maps non-deterministic. The caller's top-level mode is
    restored via ``model.train(was_training)`` afterwards.

    Args:
        model: torch.nn.Module whose layers are looked up by attribute name.
        image_tensor: model input, presumably (N, C, H, W) with N >= 1 (TODO
            confirm for non-image models). Only the first sample is visualized,
            and maps are resized to the input's H x W.
        target_layers: attribute names to hook. When None, defaults to
            ['layer1', 'layer2', 'layer3', 'layer4'] if the model has a
            ``layer1`` attribute (ResNet-style); otherwise the caller must
            supply them.

    Returns:
        dict mapping layer name -> detached activation tensor, or None when no
        usable target layers were found.
    """
    print("Analyzing Layer-wise Attention Patterns...")

    if target_layers is None:
        # Default layers for ResNet
        if hasattr(model, 'layer1'):
            target_layers = ['layer1', 'layer2', 'layer3', 'layer4']
        else:
            print("Custom target layers needed for this architecture")
            return None

    # Only hook/plot layers that actually exist so no subplot column stays empty.
    available_layers = [name for name in target_layers if hasattr(model, name)]
    missing_layers = [name for name in target_layers if not hasattr(model, name)]
    if missing_layers:
        print(f"Skipping layers not found on model: {missing_layers}")
    if not available_layers:
        print("None of the requested target layers exist on this model")
        return None

    layer_outputs = {}
    hooks = []

    def get_activation(name):
        def hook(module, inputs, output):
            layer_outputs[name] = output.detach()
        return hook

    for layer_name in available_layers:
        layer = getattr(model, layer_name)
        hooks.append(layer.register_forward_hook(get_activation(layer_name)))

    was_training = model.training
    model.eval()
    # Resize maps to the input's spatial size instead of assuming 224x224.
    out_h, out_w = int(image_tensor.shape[-2]), int(image_tensor.shape[-1])

    try:
        with torch.no_grad():
            _ = model(image_tensor)

        n_layers = len(available_layers)
        # squeeze=False keeps axes 2-D even for a single layer.
        fig, axes = plt.subplots(2, n_layers, figsize=(4 * n_layers, 8), squeeze=False)

        for idx, layer_name in enumerate(available_layers):
            if layer_name not in layer_outputs:
                # Hook never fired (layer not on the forward path).
                continue
            activation = layer_outputs[layer_name]

            # Average over channels for the first sample; float() keeps cv2 happy
            # if the model runs in reduced precision.
            avg_activation = torch.mean(activation[0].float(), dim=0).cpu().numpy()
            # cv2.resize takes (width, height).
            resized = cv2.resize(avg_activation, (out_w, out_h))

            # Min-max normalize to [0, 1] for display.
            resized = (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)

            # Heatmap
            axes[0, idx].imshow(resized, cmap='jet')
            axes[0, idx].set_title(f'{layer_name}\nActivation Map')
            axes[0, idx].axis('off')

            # Histogram of activations
            axes[1, idx].hist(avg_activation.flatten(), bins=50, alpha=0.7)
            axes[1, idx].set_title(f'{layer_name}\nActivation Distribution')
            axes[1, idx].set_xlabel('Activation Value')
            axes[1, idx].set_ylabel('Frequency')

        fig.tight_layout()
        fig.suptitle('Layer-wise Attention Analysis', fontsize=14, y=1.02)

        # Save layer analysis
        save_path = os.path.join(RESULTS_DIR, 'layer_attention_analysis.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Layer analysis saved to: {save_path}")

        plt.show()

    finally:
        # Always detach hooks and restore the caller's train/eval mode.
        for hook in hooks:
            hook.remove()
        model.train(was_training)

    return layer_outputs

# Perform layer-wise analysis if we have results.
# Reset first so a result from an earlier run cannot survive a re-execution
# that takes the "not available" branch.
layer_analysis = None
if sample_images and gradcam_working:
    # Use first sample image; the context manager closes the file handle
    # that Image.open keeps open for lazy loading.
    with Image.open(sample_images[0]) as img:
        original_image = img.convert('RGB')
    input_tensor = gradcam_viz.preprocess(original_image).unsqueeze(0).to(device)

    layer_analysis = analyze_layer_attention(model, input_tensor)
else:
    print("Cannot perform layer analysis - no images or model not available")

In [None]:
def _pearson_or_nan(a, b):
    """Pearson correlation of two equal-length sequences; NaN if either is constant."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.size < 2 or np.std(a) == 0 or np.std(b) == 0:
        return float('nan')
    return float(np.corrcoef(a, b)[0, 1])


def _attention_entropy(heatmap):
    """Shannon entropy (nats) of a heatmap treated as a probability distribution.

    Negative values are clipped to zero so the log stays defined (presumably
    Grad-CAM heatmaps are already non-negative - TODO confirm).
    """
    weights = np.clip(np.asarray(heatmap, dtype=np.float64), 0.0, None)
    probs = weights / (weights.sum() + 1e-8)
    return float(-np.sum(probs * np.log(probs + 1e-8)))


def _spatial_coherence(heatmap, sigma=2.0):
    """Correlation between a heatmap and its Gaussian-smoothed copy (NaN if flat)."""
    from scipy.ndimage import gaussian_filter

    values = np.asarray(heatmap, dtype=np.float64)
    smoothed = gaussian_filter(values, sigma=sigma)
    return _pearson_or_nan(values.ravel(), smoothed.ravel())


def _summarize_scores(values):
    """Mean/std over the finite entries of ``values`` plus the raw per-image list."""
    finite = [v for v in values if np.isfinite(v)]
    return {
        'mean': float(np.mean(finite)) if finite else float('nan'),
        'std': float(np.std(finite)) if finite else float('nan'),
        'individual': list(values),
    }


def _plot_gradcam_metrics(predictions, entropies, coherence_scores):
    """Draw the 2x2 metric figure and save it to RESULTS_DIR/gradcam_quality_metrics.png."""
    # Flat heatmaps give NaN coherence; only finite scores go into the histogram.
    finite_coherence = [c for c in coherence_scores if np.isfinite(c)]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Entropy distribution
    axes[0, 0].hist(entropies, bins=20, alpha=0.7, color='blue')
    axes[0, 0].set_title('Attention Entropy Distribution')
    axes[0, 0].set_xlabel('Entropy')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].axvline(np.mean(entropies), color='red', linestyle='--', label='Mean')
    axes[0, 0].legend()

    # Spatial coherence distribution
    axes[0, 1].hist(finite_coherence, bins=20, alpha=0.7, color='green')
    axes[0, 1].set_title('Spatial Coherence Distribution')
    axes[0, 1].set_xlabel('Coherence Score')
    axes[0, 1].set_ylabel('Frequency')
    if finite_coherence:
        axes[0, 1].axvline(np.mean(finite_coherence), color='red', linestyle='--', label='Mean')
        axes[0, 1].legend()

    # Prediction vs Entropy
    axes[1, 0].scatter(predictions, entropies, alpha=0.7)
    axes[1, 0].set_xlabel('Prediction Probability')
    axes[1, 0].set_ylabel('Attention Entropy')
    axes[1, 0].set_title('Prediction vs Attention Entropy')
    axes[1, 0].grid(True, alpha=0.3)

    # Prediction vs Coherence (NaN points are simply not drawn)
    axes[1, 1].scatter(predictions, coherence_scores, alpha=0.7, color='orange')
    axes[1, 1].set_xlabel('Prediction Probability')
    axes[1, 1].set_ylabel('Spatial Coherence')
    axes[1, 1].set_title('Prediction vs Spatial Coherence')
    axes[1, 1].grid(True, alpha=0.3)

    fig.tight_layout()
    fig.suptitle('Grad-CAM Quality Metrics', fontsize=14, y=1.02)

    # Save metrics visualization
    save_path = os.path.join(RESULTS_DIR, 'gradcam_quality_metrics.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nMetrics visualization saved to: {save_path}")

    plt.show()


def calculate_gradcam_metrics(batch_results, attention_stats, plot=True):
    """Calculate quantitative metrics for Grad-CAM quality assessment.

    Args:
        batch_results: list of dicts, each with at least 'prediction' (scalar
            probability) and 'heatmap' (2-D array).
        attention_stats: list of dicts with 'focus_percentage' and
            'max_attention', aligned index-by-index with batch_results.
        plot: when True (default), draw the metric figure and save it to
            RESULTS_DIR.

    Returns:
        None if either input is empty. Otherwise a dict with:
          'focus_consistency': float, std of focus percentages (lower = more
              consistent);
          'prediction_attention_correlation': dict with 'max_attention' and
              'focus_percentage' Pearson r values, or {} when fewer than two
              paired samples exist;
          'attention_entropy', 'spatial_coherence': dicts with 'mean', 'std'
              (over finite values) and 'individual' (per image).
        All numbers are plain Python floats so the dict is JSON-serializable.
        Undefined correlations (constant inputs, flat heatmaps) are NaN.
    """
    if not batch_results or not attention_stats:
        print("Insufficient data for quantitative evaluation")
        return None

    print("Calculating Grad-CAM Quality Metrics...")
    print("=" * 40)

    predictions = [float(r['prediction']) for r in batch_results]
    focus_percentages = [float(s['focus_percentage']) for s in attention_stats]
    max_attentions = [float(s['max_attention']) for s in attention_stats]

    # 1. Focus Consistency (std of focus percentages)
    focus_std = float(np.std(focus_percentages))

    # 2. Prediction-Attention Correlation, over samples present in both lists
    #    (the two lists are not guaranteed to have equal length).
    n_pairs = min(len(predictions), len(attention_stats))
    correlations = {}
    if n_pairs > 1:
        correlations = {
            'max_attention': _pearson_or_nan(predictions[:n_pairs], max_attentions[:n_pairs]),
            'focus_percentage': _pearson_or_nan(predictions[:n_pairs], focus_percentages[:n_pairs]),
        }

    # 3. Attention Entropy (spread) and 4. Spatial Coherence (clustering)
    entropies = [_attention_entropy(r['heatmap']) for r in batch_results]
    coherence_scores = [_spatial_coherence(r['heatmap']) for r in batch_results]

    metrics = {
        'focus_consistency': focus_std,
        'prediction_attention_correlation': correlations,
        'attention_entropy': _summarize_scores(entropies),
        'spatial_coherence': _summarize_scores(coherence_scores),
    }

    # Print results
    print("Grad-CAM Quality Assessment:")
    print(f"Focus Consistency (lower=more consistent): {focus_std:.3f}")

    if correlations:
        print(f"Prediction-Max Attention Correlation: {correlations['max_attention']:.3f}")
        print(f"Prediction-Focus Correlation: {correlations['focus_percentage']:.3f}")

    entropy_summary = metrics['attention_entropy']
    coherence_summary = metrics['spatial_coherence']
    print(f"Mean Attention Entropy: {entropy_summary['mean']:.3f} ± {entropy_summary['std']:.3f}")
    print(f"Mean Spatial Coherence: {coherence_summary['mean']:.3f} ± {coherence_summary['std']:.3f}")

    if plot:
        _plot_gradcam_metrics(predictions, entropies, coherence_scores)

    return metrics

# Quality metrics need both the per-image Grad-CAM results and their attention stats;
# with either missing, record an explicit None instead.
if not (batch_results and attention_stats):
    print("Cannot calculate metrics - insufficient data")
    gradcam_metrics = None
else:
    gradcam_metrics = calculate_gradcam_metrics(batch_results, attention_stats)

## 12. Export Results and Generate Report

In [None]:
def _to_json_compatible(obj):
    """``json.dump`` default hook: convert NumPy scalars/arrays to built-in types."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def generate_comprehensive_report(batch_results, attention_stats, gradcam_metrics, output_path):
    """Generate a comprehensive analysis report (JSON + Markdown summary).

    Reads the notebook globals ``config``, ``gradcam_working`` and
    ``trained_model_available`` for run metadata.

    Args:
        batch_results: list of per-image dicts with 'image_path' and 'prediction'.
        attention_stats: optional list of per-image dicts with 'max_attention',
            'mean_attention', 'focus_percentage' and 'std_attention', aligned by
            index with batch_results (may be shorter).
        gradcam_metrics: metrics dict (or None), embedded verbatim under
            'quality_metrics'.
        output_path: directory that receives gradcam_analysis_report.json and
            gradcam_analysis_summary.md. It is created if missing.

    Returns:
        The report dict that was written to JSON.
    """
    import json
    from datetime import datetime

    # A prediction counts as high confidence when |p - 0.5| > margin,
    # i.e. p > 0.8 or p < 0.2 with the value below.
    confidence_margin = 0.3
    # Images whose focus percentage exceeds this count as highly focused.
    focus_threshold = 15

    report = {
        'analysis_metadata': {
            'timestamp': datetime.now().isoformat(),
            'model_architecture': config['model']['architecture'],
            'total_images_analyzed': len(batch_results) if batch_results else 0,
            'gradcam_working': gradcam_working,
            'trained_model_available': trained_model_available
        },
        'image_analysis': [],
        'summary_statistics': {},
        'quality_metrics': gradcam_metrics,
        'clinical_insights': []
    }

    # Individual image analysis
    if batch_results:
        for idx, result in enumerate(batch_results):
            image_analysis = {
                'image_path': os.path.basename(result['image_path']),
                'prediction': float(result['prediction']),
                'prediction_class': 'Cancer' if result['prediction'] > 0.5 else 'Normal',
                'confidence': float(max(result['prediction'], 1 - result['prediction']))
            }

            if attention_stats and idx < len(attention_stats):
                stats = attention_stats[idx]
                image_analysis.update({
                    'max_attention': float(stats['max_attention']),
                    'mean_attention': float(stats['mean_attention']),
                    'focus_percentage': float(stats['focus_percentage']),
                    'attention_spread': float(stats['std_attention'])
                })

            report['image_analysis'].append(image_analysis)

    # Summary statistics
    if batch_results:
        predictions = [r['prediction'] for r in batch_results]
        report['summary_statistics'] = {
            'prediction_range': [float(min(predictions)), float(max(predictions))],
            'mean_prediction': float(np.mean(predictions)),
            'cancer_predictions': int(sum(1 for p in predictions if p > 0.5)),
            'normal_predictions': int(sum(1 for p in predictions if p <= 0.5)),
            'high_confidence_predictions': int(sum(1 for p in predictions if abs(p - 0.5) > confidence_margin)),
            'low_confidence_predictions': int(sum(1 for p in predictions if abs(p - 0.5) <= confidence_margin))
        }

        if attention_stats:
            focus_percentages = [s['focus_percentage'] for s in attention_stats]
            report['summary_statistics'].update({
                'mean_focus_percentage': float(np.mean(focus_percentages)),
                'highly_focused_images': int(sum(1 for f in focus_percentages if f > focus_threshold))
            })

    # Clinical insights
    report['clinical_insights'] = [
        "Grad-CAM visualizations provide interpretable insights into model decision-making",
        "Attention patterns vary significantly across different images and prediction confidence levels",
        "Higher prediction confidence often correlates with more focused attention patterns",
        "Low confidence predictions require additional clinical review and validation",
        "Visual explanations can help identify potential model biases or artifacts",
        "Integration with clinical workflows requires validation against radiologist annotations"
    ]

    os.makedirs(output_path, exist_ok=True)

    # Save JSON report. quality_metrics may contain NumPy scalars (e.g. float32),
    # which json cannot serialize without the default hook.
    json_path = os.path.join(output_path, 'gradcam_analysis_report.json')
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=_to_json_compatible)

    # Generate markdown summary
    meta = report['analysis_metadata']
    md_path = os.path.join(output_path, 'gradcam_analysis_summary.md')
    with open(md_path, 'w', encoding='utf-8') as f:
        f.write("# Grad-CAM Analysis Summary Report\n\n")
        # Two trailing spaces force Markdown line breaks; a bare "\n" would merge
        # these three fields into a single run-on paragraph.
        f.write(f"**Generated:** {meta['timestamp']}  \n")
        f.write(f"**Model:** {meta['model_architecture']}  \n")
        f.write(f"**Images Analyzed:** {meta['total_images_analyzed']}\n\n")

        if report['summary_statistics']:
            f.write("## Summary Statistics\n\n")
            stats = report['summary_statistics']
            f.write(f"- **Prediction Range:** {stats['prediction_range'][0]:.3f} - {stats['prediction_range'][1]:.3f}\n")
            f.write(f"- **Mean Prediction:** {stats['mean_prediction']:.3f}\n")
            f.write(f"- **Cancer Predictions:** {stats['cancer_predictions']}\n")
            f.write(f"- **Normal Predictions:** {stats['normal_predictions']}\n")
            f.write(f"- **High Confidence:** {stats['high_confidence_predictions']}\n")
            f.write(f"- **Low Confidence:** {stats['low_confidence_predictions']}\n")

            if 'mean_focus_percentage' in stats:
                f.write(f"- **Mean Focus Percentage:** {stats['mean_focus_percentage']:.1f}%\n")
                f.write(f"- **Highly Focused Images:** {stats['highly_focused_images']}\n")

        f.write("\n## Clinical Insights\n\n")
        for insight in report['clinical_insights']:
            f.write(f"- {insight}\n")

        if not trained_model_available:
            f.write("\n## Important Note\n\n")
            f.write("**This analysis was performed with a randomly initialized model.** ")
            f.write("For clinically meaningful insights, please train the model using the baseline notebook first.\n")

    print("Comprehensive report generated:")
    print(f"- JSON Report: {json_path}")
    print(f"- Markdown Summary: {md_path}")

    return report

# Generate comprehensive report.
# Reset first so a report from an earlier run cannot survive a re-execution
# that has no results.
comprehensive_report = None
if batch_results:
    comprehensive_report = generate_comprehensive_report(
        batch_results, attention_stats, gradcam_metrics, RESULTS_DIR
    )
    print("\n" + "="*60)
    print("ANALYSIS COMPLETE!")
    print("="*60)
    print(f"All results saved to: {RESULTS_DIR}")
    # Count regular files only (a dotted sub-directory name would also match '*.*').
    generated_count = sum(1 for p in Path(RESULTS_DIR).glob('*.*') if p.is_file())
    print(f"Total files generated: {generated_count}")
else:
    print("No results available for comprehensive report generation")

In [None]:
print("üè• GRAD-CAM VISUALIZATION ANALYSIS COMPLETE")
print("="*60)

# Analysis summary
if gradcam_working:
    print("‚úÖ Grad-CAM implementation: SUCCESS")
    if batch_results:
        print(f"‚úÖ Images analyzed: {len(batch_results)}")
        predictions = [r['prediction'] for r in batch_results]
        print(f"‚úÖ Prediction range: {min(predictions):.3f} - {max(predictions):.3f}")
        cancer_count = sum(1 for p in predictions if p > 0.5)
        print(f"‚úÖ Cancer predictions: {cancer_count}/{len(predictions)}")
    else:
        print("‚ö†Ô∏è  No images were successfully analyzed")
else:
    print("‚ùå Grad-CAM implementation: FAILED")

# Quality assessment
if gradcam_metrics:
    print(f"‚úÖ Quality metrics calculated")
    entropy_mean = gradcam_metrics['attention_entropy']['mean']
    coherence_mean = gradcam_metrics['spatial_coherence']['mean']
    print(f"   - Attention entropy: {entropy_mean:.3f}")
    print(f"   - Spatial coherence: {coherence_mean:.3f}")

# Files generated
results_files = list(Path(RESULTS_DIR).glob('*.*'))
print(f"üìÅ Files generated: {len(results_files)}")
for file_path in results_files:
    print(f"   - {file_path.name}")

print("\nüî¨ CLINICAL INTERPRETATION GUIDELINES:")
print("-" * 40)
print("1. High confidence predictions (>0.7 or <0.3): Prioritize for review")
print("2. Focused attention patterns (>15%): Indicate localized findings")
print("3. Low confidence predictions (0.3-0.7): Require expert validation")
print("4. Diffuse attention patterns: May indicate subtle or absent findings")
print("5. Always combine AI predictions with clinical judgment")

print("\nüìã NEXT STEPS & RECOMMENDATIONS:")
print("-" * 35)
next_steps = [
    "Train model with real data for meaningful clinical insights",
    "Validate Grad-CAM explanations against radiologist annotations", 
    "Implement additional explainability techniques (LIME, SHAP)",
    "Conduct clinical evaluation with medical experts",
    "Analyze model performance on different patient demographics",
    "Integrate findings into clinical decision support system"
]

for i, step in enumerate(next_steps, 1):
    print(f"{i}. {step}")

if not trained_model_available:
    print("\n‚ö†Ô∏è  IMPORTANT NOTICE:")
    print("This analysis used a randomly initialized model for demonstration.")
    print("Train the model using the baseline notebook for clinical relevance.")

print(f"\nüìä Complete analysis saved to: {RESULTS_DIR}")
print("="*60)