# Visualization and Analysis for Pneumonia Detection

This notebook provides comprehensive visualization and analysis tools for understanding model behavior and performance in pneumonia detection. It includes:

1. **Data Exploration**: Visual analysis of the dataset characteristics
2. **Model Interpretability**: Grad-CAM and feature visualization
3. **Performance Analysis**: Detailed evaluation metrics and comparisons
4. **Error Analysis**: Understanding model failures and edge cases
5. **Clinical Insights**: Medical relevance and interpretation
6. **Interactive Visualizations**: Tools for exploring model behavior

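These analyses help establish two things: whether the model's decisions rest on clinically plausible image evidence, and how its errors would affect patients. Both are prerequisites for trusting an AI model in a medical setting.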

## 1. Setup and Imports

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

# Image processing and visualization
import cv2
from PIL import Image, ImageEnhance
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec

# PyTorch and deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm

# Grad-CAM and interpretability
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, LayerCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Scientific computing and metrics
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_curve, auc,
    precision_recall_curve, average_precision_score,
    roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
)
from scipy import ndimage
from scipy.stats import chi2_contingency

# Interactive widgets (if available)
try:
    import ipywidgets as widgets
    from IPython.display import display, HTML
    WIDGETS_AVAILABLE = True
except ImportError:
    WIDGETS_AVAILABLE = False
    print("ipywidgets not available. Interactive features will be disabled.")

# Configuration
plt.style.use('default')
sns.set_palette("husl")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Create output directories
RESULTS_DIR = Path("../results")
VISUALIZATIONS_DIR = RESULTS_DIR / "visualizations"
VISUALIZATIONS_DIR.mkdir(parents=True, exist_ok=True)

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

## 2. Data Exploration and Visualization

In [None]:
def analyze_dataset_statistics(data_dir):
    """
    Comprehensive analysis of dataset characteristics
    """
    stats = {}
    
    for split in ['train', 'test']:
        split_stats = {'normal': 0, 'pneumonia': 0, 'total': 0}
        
        for class_name in ['normal', 'pneumonia']:
            class_dir = Path(data_dir) / split / class_name
            if class_dir.exists():
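                # Count common image extensions; glob('*.jpg') alone would miss .jpeg/.png files
                count = sum(1 for p in class_dir.iterdir()
                            if p.suffix.lower() in {'.jpg', '.jpeg', '.png'})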
                split_stats[class_name] = count
                split_stats['total'] += count
        
        stats[split] = split_stats
    
    return stats

def plot_dataset_distribution(stats, save_path=None):
    """
    Visualize dataset distribution across splits and classes
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Split distribution
    splits = list(stats.keys())
    split_totals = [stats[split]['total'] for split in splits]
    
    axes[0].pie(split_totals, labels=splits, autopct='%1.1f%%', startangle=90)
    axes[0].set_title('Dataset Split Distribution')
    
    # Class distribution per split (reuses `splits` from above)
    normal_counts = [stats[split]['normal'] for split in splits]
    pneumonia_counts = [stats[split]['pneumonia'] for split in splits]
    
    x = np.arange(len(splits))
    width = 0.35
    
    axes[1].bar(x - width/2, normal_counts, width, label='Normal', alpha=0.8)
    axes[1].bar(x + width/2, pneumonia_counts, width, label='Pneumonia', alpha=0.8)
    axes[1].set_xlabel('Dataset Split')
    axes[1].set_ylabel('Number of Images')
    axes[1].set_title('Class Distribution by Split')
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(splits)
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Overall class balance
    total_normal = sum(stats[split]['normal'] for split in stats)
    total_pneumonia = sum(stats[split]['pneumonia'] for split in stats)
    
    axes[2].pie([total_normal, total_pneumonia], 
               labels=['Normal', 'Pneumonia'], 
               autopct='%1.1f%%', 
               startangle=90,
               colors=['lightblue', 'lightcoral'])
    axes[2].set_title('Overall Class Distribution')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
    
    # Print statistics
    print("Dataset Statistics:")
    print("=" * 50)
    for split, split_stats in stats.items():
        print(f"\n{split.upper()} SET:")
        print(f"  Normal: {split_stats['normal']:,}")
        print(f"  Pneumonia: {split_stats['pneumonia']:,}")
        print(f"  Total: {split_stats['total']:,}")
        
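        if split_stats['normal'] > 0:
            # Pneumonia images per normal image (guarded: a split may have no normal images)
            balance_ratio = split_stats['pneumonia'] / split_stats['normal']
            print(f"  Class Balance Ratio (pneumonia:normal): {balance_ratio:.2f}")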

def visualize_sample_images(data_dir, num_samples=8, save_path=None):
    """
    Display sample images from each class
    """
    # squeeze=False keeps axes 2-D even when num_samples == 1
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 8), squeeze=False)
    
    for class_idx, class_name in enumerate(['normal', 'pneumonia']):
        class_dir = Path(data_dir) / 'train' / class_name
        image_files = sorted(p for p in class_dir.glob('*')
                             if p.suffix.lower() in {'.jpg', '.jpeg', '.png'})[:num_samples]
        
        for idx in range(num_samples):
            ax = axes[class_idx, idx]
            ax.axis('off')  # also hides empty slots when a class has fewer images
            if idx < len(image_files):
                img_path = image_files[idx]
                # Load as single-channel: cmap='gray' is ignored for RGB arrays
                img = Image.open(img_path).convert('L')
                ax.imshow(img, cmap='gray')
                ax.set_title(f'{class_name.title()}\n{img_path.name}')
    
    plt.suptitle('Sample Images from Dataset', fontsize=16)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
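
Because the classes are imbalanced, accuracy alone can flatter a weak model: always predicting the larger class already scores that class's share of the split. The stdlib sketch below computes this majority-class baseline from a dictionary shaped like the one `analyze_dataset_statistics` returns. The helper name `majority_baseline` and the counts are hypothetical, chosen only for illustration.

```python
def majority_baseline(stats, split="test"):
    """Accuracy of always predicting the larger class on one split.

    Hypothetical helper; assumes stats[split] holds 'normal', 'pneumonia'
    and 'total' counts as produced by analyze_dataset_statistics.
    """
    counts = stats[split]
    total = counts["total"]
    if total == 0:
        return None, 0.0
    # Ties resolve to 'normal' because max() keeps the first maximum
    majority = max(("normal", "pneumonia"), key=lambda c: counts[c])
    return majority, counts[majority] / total


# Illustrative counts only, not measured on any real dataset
example_stats = {"test": {"normal": 200, "pneumonia": 300, "total": 500}}
cls, acc = majority_baseline(example_stats)
print(f"Always predicting '{cls}' gives accuracy {acc:.3f}")
```

Any reported accuracy is best read against this baseline rather than against 50%.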

## 3. Model Interpretability with Grad-CAM

In [None]:
class GradCAMVisualizer:
    """
    Grad-CAM visualization for model interpretability
    """
    
    def __init__(self, model, target_layers=None):
        self.model = model
        self.model.eval()
        
        # Auto-detect target layers if not provided
        if target_layers is None:
            self.target_layers = self._get_target_layers()
        else:
            self.target_layers = target_layers
        
        # Initialize Grad-CAM
        self.cam = GradCAM(model=model, target_layers=self.target_layers)
    
    def _get_target_layers(self):
        """
        Pick the last nn.Conv2d in the model as the Grad-CAM target layer.
        
        Grad-CAM works best on a late convolutional layer, whose feature maps
        are semantically rich but still spatial. Returning the *first* Conv2d
        whose name contains 'features', 'layer4' or 'block' would usually pick
        an early layer, so only the last Conv2d is used.
        """
        conv_layers = [m for m in self.model.modules() if isinstance(m, nn.Conv2d)]
        if not conv_layers:
            raise ValueError("No nn.Conv2d layer found; pass target_layers explicitly.")
        return [conv_layers[-1]]
    
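    def generate_cam(self, input_tensor, target_class=None):
        """
        Generate a Grad-CAM heatmap for the first image in `input_tensor`.
        
        Assumes a binary model with a single output logit, as the sigmoid calls
        elsewhere imply. ClassifierOutputTarget(1) would index past the end of
        that output, so the target is the logit itself for Pneumonia
        (target_class=1) and its negation for Normal (target_class=0).
        """
        if target_class is None:
            # Use the predicted class
            with torch.no_grad():
                output = self.model(input_tensor)
            target_class = int(torch.sigmoid(output).flatten()[0].item() > 0.5)
        
        sign = 1.0 if target_class == 1 else -1.0
        targets = [lambda model_output: sign * model_output.reshape(-1)[0]]
        grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)
        
        return grayscale_cam[0]  # Return first batch item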
    
    def visualize_prediction(self, image_path, save_path=None):
        """
        Complete visualization pipeline for a single image
        """
        # Load and preprocess image
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        
        # Original image for visualization
        original_img = Image.open(image_path).convert('RGB')
        original_img = original_img.resize((224, 224))
        original_array = np.array(original_img) / 255.0
        
        # Preprocessed image for model
        input_tensor = transform(original_img).unsqueeze(0).to(DEVICE)
        
        # Get prediction
        with torch.no_grad():
            output = self.model(input_tensor)
            probability = torch.sigmoid(output).item()
            prediction = "Pneumonia" if probability > 0.5 else "Normal"
            confidence = probability if probability > 0.5 else (1 - probability)
        
        # Generate Grad-CAM
        cam = self.generate_cam(input_tensor)
        
        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # Original image
        axes[0].imshow(original_img)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        # Grad-CAM heatmap
        axes[1].imshow(cam, cmap='jet', alpha=0.8)
        axes[1].set_title('Grad-CAM Heatmap')
        axes[1].axis('off')
        
        # Overlay
        cam_image = show_cam_on_image(original_array, cam, use_rgb=True)
        axes[2].imshow(cam_image)
        axes[2].set_title(f'Grad-CAM Overlay\nPrediction: {prediction}\nConfidence: {confidence:.3f}')
        axes[2].axis('off')
        
        plt.suptitle(f'Model Interpretation: {Path(image_path).name}', fontsize=14)
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
        
        return {
            'prediction': prediction,
            'probability': probability,
            'confidence': confidence,
            'cam': cam
        }
    
    def batch_visualize(self, image_paths, save_dir=None):
        """
        Generate Grad-CAM visualizations for multiple images
        """
        results = []
        
        for i, image_path in enumerate(image_paths):
            print(f"Processing {i+1}/{len(image_paths)}: {Path(image_path).name}")
            
            save_path = None
            if save_dir:
                Path(save_dir).mkdir(parents=True, exist_ok=True)
                save_path = Path(save_dir) / f"gradcam_{Path(image_path).stem}.png"
            
            result = self.visualize_prediction(image_path, save_path)
            result['image_path'] = image_path
            results.append(result)
        
        return results

def compare_gradcam_methods(model, image_path, save_path=None):
    """
    Compare different Grad-CAM methods on the same image
    """
    model.eval()
    
    # Use the last Conv2d as the target layer
    target_layers = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            target_layers.append(module)
    target_layers = [target_layers[-1]] if target_layers else []
    
    # Different CAM methods
    cam_methods = {
        'GradCAM': GradCAM,
        'GradCAM++': GradCAMPlusPlus,
        'ScoreCAM': ScoreCAM,
        'XGradCAM': XGradCAM
    }
    
    # Load image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    original_img = Image.open(image_path).convert('RGB').resize((224, 224))
    original_array = np.array(original_img) / 255.0
    input_tensor = transform(original_img).unsqueeze(0).to(DEVICE)
    
    # Generate CAMs
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    # Original image
    axes[0].imshow(original_img)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    for idx, (method_name, cam_class) in enumerate(cam_methods.items(), 1):
        try:
            cam = cam_class(model=model, target_layers=target_layers)
            # Single-logit model: index 0 is the pneumonia score (index 1 does not exist)
            targets = [ClassifierOutputTarget(0)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
            
            cam_image = show_cam_on_image(original_array, grayscale_cam, use_rgb=True)
            axes[idx].imshow(cam_image)
            axes[idx].set_title(f'{method_name}')
            axes[idx].axis('off')
        except Exception as e:
            axes[idx].text(0.5, 0.5, f'Error: {method_name}\n{str(e)[:50]}...', 
                          ha='center', va='center', transform=axes[idx].transAxes)
            axes[idx].axis('off')
    
    # Remove extra subplot
    axes[5].remove()
    
    plt.suptitle(f'Grad-CAM Method Comparison: {Path(image_path).name}', fontsize=16)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
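
The `GradCAM` objects above hide the core computation. For one image, Grad-CAM weights each activation map $A^k$ of the target layer by the spatial mean of the gradient of the target score with respect to that map. It then sums the weighted maps, applies a ReLU, and rescales the result to [0, 1]. The library also resizes the map to the input resolution. The stdlib sketch below reproduces the weighting, ReLU and rescaling steps on nested lists, omitting the resize. The activations and gradients are made up.

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap for one image (sketch, no upsampling).

    activations, gradients: lists of K feature maps (H x W nested lists)
    from the same conv layer; gradients hold d(score)/d(activation).
    """
    num_maps = len(activations)
    height, width = len(activations[0]), len(activations[0][0])

    # 1. Channel weights: global-average-pool each gradient map
    weights = [sum(sum(row) for row in grad) / (height * width) for grad in gradients]

    # 2. Weighted sum of activation maps, followed by ReLU
    cam = [[0.0] * width for _ in range(height)]
    for k in range(num_maps):
        for i in range(height):
            for j in range(width):
                cam[i][j] += weights[k] * activations[k][i][j]
    cam = [[max(v, 0.0) for v in row] for row in cam]

    # 3. Min-max rescale to [0, 1] so the map can be overlaid on the image
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    if hi == lo:
        return [[0.0] * width for _ in range(height)]
    return [[(v - lo) / (hi - lo) for v in row] for row in cam]


# Two 2x2 feature maps with hypothetical activations and gradients
acts = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 2.0]]]
grads = [[[0.5, 0.5], [0.5, 0.5]], [[-1.0, -1.0], [-1.0, -1.0]]]
print(grad_cam(acts, grads))  # [[1.0, 0.0], [0.0, 0.0]]
```

The second map gets a negative weight, so its region disappears after the ReLU. Grad-CAM highlights only evidence that raises the target score. For the same reason, explaining a Normal prediction from a single-logit model requires targeting the negated logit.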

## 4. Performance Analysis and Error Analysis
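
Every summary metric printed by `detailed_performance_analysis` derives from the four confusion-matrix counts, with pneumonia as the positive class. The stdlib sketch below spells out those formulas with explicit zero-division guards. The function name `clinical_metrics` is hypothetical and the counts are invented.

```python
def safe_div(num, den):
    """Return num / den, or 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def clinical_metrics(tn, fp, fn, tp):
    """Binary metrics with Pneumonia (label 1) as the positive class."""
    return {
        "accuracy": safe_div(tp + tn, tn + fp + fn + tp),
        "sensitivity": safe_div(tp, tp + fn),    # recall / true positive rate
        "specificity": safe_div(tn, tn + fp),    # true negative rate
        "ppv": safe_div(tp, tp + fp),            # precision
        "npv": safe_div(tn, tn + fn),
        "miss_rate": safe_div(fn, tp + fn),      # share of pneumonia cases missed
        "false_alarm_rate": safe_div(fp, tn + fp),
    }


# Invented counts for illustration
m = clinical_metrics(tn=80, fp=20, fn=10, tp=90)
for name, value in m.items():
    print(f"{name:>16}: {value:.3f}")
```
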

In [None]:
def detailed_performance_analysis(y_true, y_pred, y_prob, class_names=None, save_path=None):
    """
    Comprehensive performance analysis with multiple visualizations
    """
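    # Accept lists as well as arrays; the boolean masks below need NumPy arrays
    y_true, y_pred, y_prob = np.asarray(y_true), np.asarray(y_pred), np.asarray(y_prob)
    
    if class_names is None:
        class_names = ['Normal', 'Pneumonia']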
    
    fig = plt.figure(figsize=(20, 15))
    gs = GridSpec(3, 4, figure=fig)
    
    # 1. Confusion Matrix
    ax1 = fig.add_subplot(gs[0, 0])
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])  # always 2x2
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names, ax=ax1)
    ax1.set_title('Confusion Matrix')
    ax1.set_ylabel('True Label')
    ax1.set_xlabel('Predicted Label')
    
    # 2. ROC Curve
    ax2 = fig.add_subplot(gs[0, 1])
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    ax2.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.3f})', linewidth=2)
    ax2.plot([0, 1], [0, 1], 'k--', label='Random')
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Precision-Recall Curve
    ax3 = fig.add_subplot(gs[0, 2])
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    pr_auc = average_precision_score(y_true, y_prob)
    ax3.plot(recall, precision, label=f'PR Curve (AUC = {pr_auc:.3f})', linewidth=2)
    ax3.set_xlabel('Recall')
    ax3.set_ylabel('Precision')
    ax3.set_title('Precision-Recall Curve')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Prediction Distribution
    ax4 = fig.add_subplot(gs[0, 3])
    ax4.hist(y_prob[y_true == 0], bins=30, alpha=0.5, label='Normal', density=True)
    ax4.hist(y_prob[y_true == 1], bins=30, alpha=0.5, label='Pneumonia', density=True)
    ax4.axvline(x=0.5, color='red', linestyle='--', label='Threshold')
    ax4.set_xlabel('Predicted Probability')
    ax4.set_ylabel('Density')
    ax4.set_title('Prediction Distribution')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # 5. Metrics by Threshold
    ax5 = fig.add_subplot(gs[1, 0:2])
    thresholds = np.linspace(0, 1, 100)
    precisions, recalls, f1s, accuracies = [], [], [], []
    
    for threshold in thresholds:
        y_pred_thresh = (y_prob >= threshold).astype(int)
        # zero_division=0 already covers thresholds where one class is never
        # predicted, so all metrics are computed directly instead of being
        # forced to 0 (accuracy then equals the share of the predicted class)
        precisions.append(precision_score(y_true, y_pred_thresh, zero_division=0))
        recalls.append(recall_score(y_true, y_pred_thresh, zero_division=0))
        f1s.append(f1_score(y_true, y_pred_thresh, zero_division=0))
        accuracies.append(accuracy_score(y_true, y_pred_thresh))
    
    ax5.plot(thresholds, precisions, label='Precision', linewidth=2)
    ax5.plot(thresholds, recalls, label='Recall', linewidth=2)
    ax5.plot(thresholds, f1s, label='F1-Score', linewidth=2)
    ax5.plot(thresholds, accuracies, label='Accuracy', linewidth=2)
    ax5.axvline(x=0.5, color='red', linestyle='--', alpha=0.7, label='Default Threshold')
    ax5.set_xlabel('Threshold')
    ax5.set_ylabel('Metric Value')
    ax5.set_title('Metrics vs Threshold')
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    
    # 6. Class-wise Performance
    ax6 = fig.add_subplot(gs[1, 2:4])
    report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
    
    metrics = ['precision', 'recall', 'f1-score']
    # Index the report by the names actually passed in, not hard-coded labels
    normal_scores = [report[class_names[0]][metric] for metric in metrics]
    pneumonia_scores = [report[class_names[1]][metric] for metric in metrics]
    
    x = np.arange(len(metrics))
    width = 0.35
    
    ax6.bar(x - width/2, normal_scores, width, label='Normal', alpha=0.8)
    ax6.bar(x + width/2, pneumonia_scores, width, label='Pneumonia', alpha=0.8)
    
    ax6.set_xlabel('Metrics')
    ax6.set_ylabel('Score')
    ax6.set_title('Class-wise Performance')
    ax6.set_xticks(x)
    ax6.set_xticklabels([m.title() for m in metrics])
    ax6.legend()
    ax6.grid(True, alpha=0.3)
    ax6.set_ylim(0, 1)
    
    # Add value labels on bars
    for i, (normal, pneumonia) in enumerate(zip(normal_scores, pneumonia_scores)):
        ax6.text(i - width/2, normal + 0.01, f'{normal:.3f}', ha='center', va='bottom')
        ax6.text(i + width/2, pneumonia + 0.01, f'{pneumonia:.3f}', ha='center', va='bottom')
    
    # 7. Error Analysis Summary
    ax7 = fig.add_subplot(gs[2, :])
    ax7.axis('off')
    
    # Calculate detailed metrics
    tn, fp, fn, tp = cm.ravel()
    
    metrics_text = f"""
    DETAILED PERFORMANCE METRICS
    
    Basic Metrics:
    • Accuracy: {accuracy_score(y_true, y_pred):.4f}
    • Precision: {precision_score(y_true, y_pred):.4f}
    • Recall (Sensitivity): {recall_score(y_true, y_pred):.4f}
    • Specificity: {tn/(tn+fp):.4f}
    • F1-Score: {f1_score(y_true, y_pred):.4f}
    
    AUC Scores:
    • ROC AUC: {roc_auc:.4f}
    • PR AUC: {pr_auc:.4f}
    
    Confusion Matrix:
    • True Positives (Pneumonia correctly identified): {tp}
    • False Positives (Normal misclassified as Pneumonia): {fp}
    • True Negatives (Normal correctly identified): {tn}
    • False Negatives (Pneumonia misclassified as Normal): {fn}
    
    Clinical Relevance:
    • Missed Pneumonia Cases: {fn} ({fn/(tp+fn)*100:.1f}% of actual pneumonia cases)
    • False Alarms: {fp} ({fp/(tn+fp)*100:.1f}% of actual normal cases)
    • Positive Predictive Value: {tp/(tp+fp):.4f}
    • Negative Predictive Value: {tn/(tn+fn):.4f}
    """
    
    ax7.text(0.05, 0.95, metrics_text, transform=ax7.transAxes, fontsize=11,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    
    plt.suptitle('Comprehensive Performance Analysis', fontsize=16, y=0.98)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

def analyze_model_errors(model, test_loader, save_dir=None, max_errors=10):
    """
    Analyze and visualize model errors to understand failure cases
    """
    model.eval()
    
    errors = {'false_positives': [], 'false_negatives': []}
    correct_predictions = {'true_positives': [], 'true_negatives': []}
    
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(DEVICE)
            outputs = model(inputs)
            # view(-1) keeps a 1-D tensor even for a batch of one (squeeze() would give 0-D)
            probabilities = torch.sigmoid(outputs.view(-1))
            predictions = (probabilities > 0.5).float()
            
            for i in range(len(inputs)):
                true_label = labels[i].item()
                pred_label = predictions[i].item()
                prob = probabilities[i].item()
                
                sample_info = {
                    'image': inputs[i].cpu(),
                    'true_label': true_label,
                    'pred_label': pred_label,
                    'probability': prob,
                    'batch_idx': batch_idx,
                    'sample_idx': i
                }
                
                if true_label == 0 and pred_label == 1:  # False positive
                    errors['false_positives'].append(sample_info)
                elif true_label == 1 and pred_label == 0:  # False negative
                    errors['false_negatives'].append(sample_info)
                elif true_label == 1 and pred_label == 1:  # True positive
                    correct_predictions['true_positives'].append(sample_info)
                elif true_label == 0 and pred_label == 0:  # True negative
                    correct_predictions['true_negatives'].append(sample_info)
    
    # Visualize errors
    for error_type, error_samples in errors.items():
        if error_samples:
            # Sort by confidence (probability distance from 0.5)
            error_samples.sort(key=lambda x: abs(x['probability'] - 0.5), reverse=True)
            
            # Select top errors
            top_errors = error_samples[:max_errors]
            
            # Create visualization
            n_cols = min(5, len(top_errors))
            n_rows = (len(top_errors) + n_cols - 1) // n_cols
            
            # squeeze=False always returns a 2-D array of axes, even for a single error
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
            
            for idx, error in enumerate(top_errors):
                row, col = idx // n_cols, idx % n_cols
                
                # Denormalize image for visualization
                img = error['image']
                img = img * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                img = torch.clamp(img, 0, 1)
                img = img.permute(1, 2, 0).numpy()
                
                axes[row, col].imshow(img)
                
                true_class = "Pneumonia" if error['true_label'] == 1 else "Normal"
                pred_class = "Pneumonia" if error['pred_label'] == 1 else "Normal"
                
                axes[row, col].set_title(f'True: {true_class}\nPred: {pred_class}\nP(pneumonia): {error["probability"]:.3f}')
                axes[row, col].axis('off')
            
            # Remove empty subplots
            for idx in range(len(top_errors), n_rows * n_cols):
                row, col = idx // n_cols, idx % n_cols
                axes[row, col].remove()
            
            error_type_title = error_type.replace('_', ' ').title()
            plt.suptitle(f'{error_type_title} Examples', fontsize=16)
            plt.tight_layout()
            
            if save_dir:
                Path(save_dir).mkdir(parents=True, exist_ok=True)
                save_path = Path(save_dir) / f'{error_type}_analysis.png'
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
            
            plt.show()
    
    return errors, correct_predictions
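
`analyze_model_errors` ranks misclassified images by how far the predicted pneumonia probability sits from the 0.5 threshold. The most confidently wrong cases, usually the most informative to review, therefore come first. The stdlib sketch below applies the same bucketing and ranking to `(id, label, probability)` tuples. The function name and sample values are hypothetical.

```python
def rank_errors(samples, threshold=0.5, max_errors=10):
    """Split misclassified samples into FP/FN lists, most confident first.

    samples: iterable of (sample_id, true_label, prob_pneumonia) tuples.
    Uses prob > threshold for a positive prediction, as analyze_model_errors does.
    """
    errors = {"false_positives": [], "false_negatives": []}
    for sample_id, label, prob in samples:
        pred = 1 if prob > threshold else 0
        if label == 0 and pred == 1:
            errors["false_positives"].append((sample_id, prob))
        elif label == 1 and pred == 0:
            errors["false_negatives"].append((sample_id, prob))
    for key in errors:
        # Larger distance from the threshold = more confident mistake
        errors[key].sort(key=lambda item: abs(item[1] - threshold), reverse=True)
        errors[key] = errors[key][:max_errors]
    return errors


preds = [("a", 0, 0.55), ("b", 0, 0.97), ("c", 1, 0.08), ("d", 1, 0.45), ("e", 1, 0.90)]
print(rank_errors(preds))
```
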

## 5. Interactive Visualization Tools

In [None]:
if WIDGETS_AVAILABLE:
    def create_interactive_threshold_widget(y_true, y_prob):
        """
        Interactive widget for exploring different classification thresholds
        """
        y_true, y_prob = np.asarray(y_true), np.asarray(y_prob)
        
        def update_metrics(threshold):
            y_pred = (y_prob >= threshold).astype(int)
            
            # Calculate metrics
            accuracy = accuracy_score(y_true, y_pred)
            precision = precision_score(y_true, y_pred, zero_division=0)
            recall = recall_score(y_true, y_pred, zero_division=0)
            f1 = f1_score(y_true, y_pred, zero_division=0)
            
            # labels=[0, 1] keeps the matrix 2x2 at extreme thresholds,
            # where only one class is predicted
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
            tn, fp, fn, tp = cm.ravel()
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            
            # Create visualization
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))
            
            # Metrics plot
            metrics = ['Accuracy', 'Precision', 'Recall', 'Specificity', 'F1-Score']
            values = [accuracy, precision, recall, specificity, f1]
            
            bars = axes[0].bar(metrics, values, color=['skyblue', 'lightgreen', 'lightcoral', 'lightyellow', 'lightpink'])
            axes[0].set_ylim(0, 1)
            axes[0].set_title(f'Metrics at Threshold = {threshold:.2f}')
            axes[0].set_ylabel('Score')
            
            # Add value labels
            for bar, value in zip(bars, values):
                axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                           f'{value:.3f}', ha='center', va='bottom')
            
            # Confusion matrix
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                       xticklabels=['Normal', 'Pneumonia'], 
                       yticklabels=['Normal', 'Pneumonia'], ax=axes[1])
            axes[1].set_title('Confusion Matrix')
            axes[1].set_ylabel('True Label')
            axes[1].set_xlabel('Predicted Label')
            
            plt.tight_layout()
            plt.show()
            
            # Print detailed info
            print(f"Threshold: {threshold:.2f}")
            print(f"True Positives: {tp}, False Positives: {fp}")
            print(f"True Negatives: {tn}, False Negatives: {fn}")
            print(f"Missed Pneumonia Cases: {fn} ({fn/(tp+fn)*100:.1f}% of pneumonia cases)")
            print(f"False Alarms: {fp} ({fp/(tn+fp)*100:.1f}% of normal cases)")
        
        threshold_slider = widgets.FloatSlider(
            value=0.5,
            min=0.0,
            max=1.0,
            step=0.01,
            description='Threshold:',
            style={'description_width': 'initial'}
        )
        
        return widgets.interact(update_metrics, threshold=threshold_slider)
    
    def create_model_comparison_widget(models_results):
        """
        Interactive widget for comparing different models
        """
        def compare_models(metric):
            model_names = list(models_results.keys())
            metric_values = [models_results[model][metric] for model in model_names]
            
            plt.figure(figsize=(10, 6))
            bars = plt.bar(model_names, metric_values, 
                          color=['skyblue', 'lightgreen', 'lightcoral', 'lightyellow'][:len(model_names)])
            
            plt.title(f'Model Comparison: {metric.replace("_", " ").title()}')
            plt.ylabel(metric.replace('_', ' ').title())
            plt.xlabel('Models')
            plt.xticks(rotation=45)
            
            # Add value labels
            for bar, value in zip(bars, metric_values):
                plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                        f'{value:.3f}', ha='center', va='bottom')
            
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()
        
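        # Get available metrics (assumes every model reports the same keys)
        available_metrics = list(next(iter(models_results.values())).keys())
        metric_dropdown = widgets.Dropdown(
            options=available_metrics,
            value='accuracy' if 'accuracy' in available_metrics else available_metrics[0],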
            description='Metric:',
            style={'description_width': 'initial'}
        )
        
        return widgets.interact(compare_models, metric=metric_dropdown)

else:
    def create_interactive_threshold_widget(y_true, y_prob):
        print("Interactive widgets not available. Install ipywidgets for interactive features.")
    
    def create_model_comparison_widget(models_results):
        print("Interactive widgets not available. Install ipywidgets for interactive features.")
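
The threshold slider recomputes the confusion matrix at each threshold. Sweeping a few thresholds by hand shows the trade-off it exposes. Lowering the threshold converts false negatives into true positives, at the price of more false positives. The stdlib sketch below uses the same `prob >= threshold` rule as `update_metrics`. The function name `confusion_at` and the labels and probabilities are invented for illustration.

```python
def confusion_at(y_true, y_prob, threshold):
    """Return (tn, fp, fn, tp) for the rule prob >= threshold -> Pneumonia."""
    tn = fp = fn = tp = 0
    for label, prob in zip(y_true, y_prob):
        pred = 1 if prob >= threshold else 0
        if label == 1:
            tp += pred
            fn += 1 - pred
        else:
            fp += pred
            tn += 1 - pred
    return tn, fp, fn, tp


y_true = [0, 0, 0, 1, 1, 1]
y_prob = [0.1, 0.4, 0.6, 0.35, 0.7, 0.9]
for t in (0.3, 0.5, 0.7):
    tn, fp, fn, tp = confusion_at(y_true, y_prob, t)
    print(f"t={t:.1f}  TN={tn} FP={fp} FN={fn} TP={tp}")
```
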

## 6. Clinical Relevance Analysis
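
`clinical_decision_analysis` picks the threshold that minimizes total misclassification cost. When correct predictions cost nothing, this is $\mathrm{Cost}(t) = c_{FP}\,FP(t) + c_{FN}\,FN(t)$. With the default cost matrix, a missed pneumonia (FN) costs 10 and a false alarm (FP) costs 1, so the search favors lower thresholds. The stdlib sketch below runs that grid search over the same 101-point grid on invented data. The name `min_cost_threshold` is hypothetical.

```python
def min_cost_threshold(y_true, y_prob, cost_fp=1.0, cost_fn=10.0, steps=101):
    """Grid-search the threshold minimizing cost_fp * FP + cost_fn * FN.

    Mirrors the default cost matrix [[0, 1], [10, 0]] (TN and TP cost 0).
    Ties resolve to the lowest threshold, as np.argmin does.
    """
    best_t, best_cost = None, float("inf")
    for i in range(steps):
        t = i / (steps - 1)
        fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= t)
        fn = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p < t)
        cost = cost_fp * fp + cost_fn * fn
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t, best_cost


y_true = [0, 0, 0, 1, 1, 1]
y_prob = [0.1, 0.4, 0.6, 0.35, 0.7, 0.9]
print(min_cost_threshold(y_true, y_prob))                            # (0.11, 2.0)
print(min_cost_threshold(y_true, y_prob, cost_fp=1.0, cost_fn=1.0))  # (0.61, 1.0)
```

On this toy data, raising the FN cost from 1 to 10 moves the optimum from 0.61 down to 0.11. The search accepts two false alarms to avoid missing the low-scoring pneumonia case at 0.35.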

In [None]:
def clinical_decision_analysis(y_true, y_prob, cost_matrix=None, save_path=None):
    """
    Analyze model performance from a clinical decision-making perspective
    
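    Args:
        y_true: True labels (0 = Normal, 1 = Pneumonia)
        y_prob: Predicted probabilities of Pneumonia
        cost_matrix: 2x2 array of costs [[TN, FP], [FN, TP]]
        save_path: Optional path for saving the figure
    """
    # Boolean masks below need NumPy arrays
    y_true, y_prob = np.asarray(y_true), np.asarray(y_prob)
    
    if cost_matrix is None:
        # Default cost matrix (clinical perspective):
        # a missed pneumonia (FN) costs 10x as much as a false alarm (FP)
        cost_matrix = np.array([[0, 1],    # [TN, FP]
                                [10, 0]])  # [FN, TP]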
    
    thresholds = np.linspace(0, 1, 101)
    costs = []
    metrics = {'sensitivity': [], 'specificity': [], 'ppv': [], 'npv': []}
    
    for threshold in thresholds:
        y_pred = (y_prob >= threshold).astype(int)
        # labels=[0, 1] forces a 2x2 matrix even when only one class is predicted,
        # so no iteration is skipped and costs stay aligned with thresholds
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        
        # Calculate cost
        total_cost = (tn * cost_matrix[0, 0] + fp * cost_matrix[0, 1] +
                     fn * cost_matrix[1, 0] + tp * cost_matrix[1, 1])
        costs.append(total_cost)
        
        # Calculate clinical metrics
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # Recall
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # Positive Predictive Value
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value
        
        metrics['sensitivity'].append(sensitivity)
        metrics['specificity'].append(specificity)
        metrics['ppv'].append(ppv)
        metrics['npv'].append(npv)
    
    # Find optimal threshold
    optimal_idx = np.argmin(costs)
    optimal_threshold = thresholds[optimal_idx]
    optimal_cost = costs[optimal_idx]
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Cost vs Threshold
    axes[0, 0].plot(thresholds, costs, linewidth=2)
    axes[0, 0].axvline(x=optimal_threshold, color='red', linestyle='--', 
                      label=f'Optimal Threshold: {optimal_threshold:.3f}')
    axes[0, 0].axvline(x=0.5, color='gray', linestyle=':', alpha=0.7, label='Default (0.5)')
    axes[0, 0].set_xlabel('Threshold')
    axes[0, 0].set_ylabel('Total Cost')
    axes[0, 0].set_title('Cost vs Classification Threshold')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Sensitivity and Specificity
    axes[0, 1].plot(thresholds, metrics['sensitivity'], label='Sensitivity (Recall)', linewidth=2)
    axes[0, 1].plot(thresholds, metrics['specificity'], label='Specificity', linewidth=2)
    axes[0, 1].axvline(x=optimal_threshold, color='red', linestyle='--', alpha=0.7)
    axes[0, 1].axvline(x=0.5, color='gray', linestyle=':', alpha=0.7)
    axes[0, 1].set_xlabel('Threshold')
    axes[0, 1].set_ylabel('Rate')
    axes[0, 1].set_title('Sensitivity vs Specificity')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Predictive Values
    axes[0, 2].plot(thresholds, metrics['ppv'], label='PPV (Precision)', linewidth=2)
    axes[0, 2].plot(thresholds, metrics['npv'], label='NPV', linewidth=2)
    axes[0, 2].axvline(x=optimal_threshold, color='red', linestyle='--', alpha=0.7)
    axes[0, 2].axvline(x=0.5, color='gray', linestyle=':', alpha=0.7)
    axes[0, 2].set_xlabel('Threshold')
    axes[0, 2].set_ylabel('Predictive Value')
    axes[0, 2].set_title('Positive vs Negative Predictive Value')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    # 4. ROC Space with Operating Points
    fpr_points = 1 - np.array(metrics['specificity'])
    tpr_points = metrics['sensitivity']
    
    # Standard ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    
    axes[1, 0].plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.3f})', linewidth=2)
    axes[1, 0].plot([0, 1], [0, 1], 'k--', alpha=0.7)
    
    # Mark optimal point
    opt_fpr = fpr_points[optimal_idx]
    opt_tpr = tpr_points[optimal_idx]
    axes[1, 0].scatter([opt_fpr], [opt_tpr], color='red', s=100, zorder=5,
                      label=f'Optimal Point (FPR={opt_fpr:.3f}, TPR={opt_tpr:.3f})')
    
    # Mark default threshold point (grid value closest to 0.5, not a fixed index)
    default_idx = int(np.argmin(np.abs(np.asarray(thresholds) - 0.5)))
    default_fpr = fpr_points[default_idx]
    default_tpr = tpr_points[default_idx]
    axes[1, 0].scatter([default_fpr], [default_tpr], color='gray', s=100, zorder=5,
                      label='Default (0.5) Point')
    
    axes[1, 0].set_xlabel('False Positive Rate (1 - Specificity)')
    axes[1, 0].set_ylabel('True Positive Rate (Sensitivity)')
    axes[1, 0].set_title('ROC Space with Operating Points')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # 5. Comparison of Thresholds
    thresholds_to_compare = [0.3, 0.5, 0.7, optimal_threshold]
    comparison_data = []
    
    for thresh in thresholds_to_compare:
        y_pred_thresh = (y_prob >= thresh).astype(int)
        cm = confusion_matrix(y_true, y_pred_thresh, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0
        
        comparison_data.append({
            'Threshold': f'{thresh:.3f}',
            'Sensitivity': sensitivity,
            'Specificity': specificity,
            'PPV': ppv,
            'NPV': npv,
            'Missed Cases': fn,
            'False Alarms': fp
        })
    
    comparison_df = pd.DataFrame(comparison_data)
    
    # Plot threshold comparison
    x_pos = np.arange(len(thresholds_to_compare))
    width = 0.15
    
    axes[1, 1].bar(x_pos - 1.5*width, comparison_df['Sensitivity'], width, label='Sensitivity', alpha=0.8)
    axes[1, 1].bar(x_pos - 0.5*width, comparison_df['Specificity'], width, label='Specificity', alpha=0.8)
    axes[1, 1].bar(x_pos + 0.5*width, comparison_df['PPV'], width, label='PPV', alpha=0.8)
    axes[1, 1].bar(x_pos + 1.5*width, comparison_df['NPV'], width, label='NPV', alpha=0.8)
    
    axes[1, 1].set_xlabel('Threshold')
    axes[1, 1].set_ylabel('Metric Value')
    axes[1, 1].set_title('Threshold Comparison')
    axes[1, 1].set_xticks(x_pos)
    axes[1, 1].set_xticklabels(comparison_df['Threshold'])
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # 6. Clinical Impact Summary
    axes[1, 2].axis('off')
    
    # Calculate metrics for optimal threshold
    y_pred_opt = (y_prob >= optimal_threshold).astype(int)
    cm_opt = confusion_matrix(y_true, y_pred_opt, labels=[0, 1])
    tn_opt, fp_opt, fn_opt, tp_opt = cm_opt.ravel()
    
    # Guard every ratio so an extreme threshold cannot raise ZeroDivisionError
    n_pos, n_neg = tp_opt + fn_opt, tn_opt + fp_opt
    sens_opt = tp_opt / n_pos if n_pos > 0 else 0
    spec_opt = tn_opt / n_neg if n_neg > 0 else 0
    fnr_opt = fn_opt / n_pos if n_pos > 0 else 0
    fpr_opt = fp_opt / n_neg if n_neg > 0 else 0
    ppv_opt = tp_opt / (tp_opt + fp_opt) if (tp_opt + fp_opt) > 0 else 0
    npv_opt = tn_opt / (tn_opt + fn_opt) if (tn_opt + fn_opt) > 0 else 0
    
    clinical_summary = f"""
    CLINICAL DECISION ANALYSIS SUMMARY
    
    Cost Matrix Used:
    • True Negative (correct normal): {cost_matrix[0,0]}
    • False Positive (false alarm): {cost_matrix[0,1]}
    • False Negative (missed pneumonia): {cost_matrix[1,0]}
    • True Positive (correct pneumonia): {cost_matrix[1,1]}
    
    Optimal Threshold: {optimal_threshold:.3f}
    Total Cost at Optimal: {optimal_cost:.0f}
    
    Clinical Performance:
    • Sensitivity: {sens_opt:.1%} ({tp_opt}/{n_pos} pneumonia cases detected)
    • Specificity: {spec_opt:.1%} ({tn_opt}/{n_neg} normal cases correct)
    • Missed Pneumonia: {fn_opt} cases ({fnr_opt:.1%})
    • False Alarms: {fp_opt} cases ({fpr_opt:.1%})
    
    Clinical Interpretation:
    • PPV: {ppv_opt:.1%} of positive predictions are correct
    • NPV: {npv_opt:.1%} of negative predictions are correct
    
    Recommendation:
    Threshold {optimal_threshold:.3f} minimizes total cost on this
    evaluation set; validate it on held-out data before deployment.
    """
    
    axes[1, 2].text(0.05, 0.95, clinical_summary, transform=axes[1, 2].transAxes,
                   fontsize=10, verticalalignment='top', fontfamily='monospace',
                   bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    plt.suptitle('Clinical Decision Analysis', fontsize=16, y=0.98)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
    
    return {
        'optimal_threshold': optimal_threshold,
        'optimal_cost': optimal_cost,
        'threshold_comparison': comparison_df,
        'cost_curve': costs
    }
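
The loop that fills `costs` lies outside this excerpt, so the following is a hedged, standalone sketch of the same cost-weighted threshold search. It assumes `cost_matrix[i, j]` is the cost of predicting class `j` for a case whose true class is `i` (the indexing the summary text above uses) and a uniform grid of 101 thresholds; the helper name `cost_threshold_sweep` is hypothetical. It vectorizes the sweep with NumPy broadcasting and locates the grid point nearest 0.5 by value rather than by a fixed index.

```python
import numpy as np


def cost_threshold_sweep(y_true, y_prob, cost_matrix, thresholds=None):
    """Total misclassification cost for every candidate threshold (hypothetical helper).

    cost_matrix[i, j] is the cost of predicting class j when the true class
    is i (0 = normal, 1 = pneumonia). A case is predicted positive when
    y_prob >= threshold, matching the notebook's convention.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    cost_matrix = np.asarray(cost_matrix, dtype=float)
    if thresholds is None:
        thresholds = np.linspace(0.0, 1.0, 101)
    thresholds = np.asarray(thresholds, dtype=float)

    # Predictions for all thresholds at once: shape (n_thresholds, n_samples)
    y_pred = y_prob[None, :] >= thresholds[:, None]
    is_pos = (y_true == 1)[None, :]

    # Confusion-matrix counts per threshold, each of shape (n_thresholds,)
    tp = np.sum(y_pred & is_pos, axis=1)
    fn = np.sum(~y_pred & is_pos, axis=1)
    fp = np.sum(y_pred & ~is_pos, axis=1)
    tn = np.sum(~y_pred & ~is_pos, axis=1)

    # Total cost = sum over the four cells of (count x unit cost)
    costs = (tn * cost_matrix[0, 0] + fp * cost_matrix[0, 1]
             + fn * cost_matrix[1, 0] + tp * cost_matrix[1, 1])

    best = int(np.argmin(costs))  # first index among ties (lowest threshold on an ascending grid)
    default_idx = int(np.argmin(np.abs(thresholds - 0.5)))  # grid point closest to 0.5
    return {
        "thresholds": thresholds,
        "costs": costs,
        "optimal_threshold": float(thresholds[best]),
        "optimal_cost": float(costs[best]),
        "default_idx": default_idx,
    }


result = cost_threshold_sweep(
    y_true=[0, 1, 1, 0, 1, 0, 0, 1],
    y_prob=[0.2, 0.8, 0.4, 0.3, 0.9, 0.6, 0.1, 0.7],
    cost_matrix=[[0, 1], [10, 0]],  # hypothetical costs: a miss costs 10x a false alarm
)
print(result["optimal_threshold"], result["optimal_cost"])
```

With the illustrative labels and probabilities from Section 7 and the hypothetical cost matrix `[[0, 1], [10, 0]]`, the minimum total cost is 1, reached by thresholds between 0.3 and 0.4, while the default 0.5 costs 11 (one missed pneumonia plus one false alarm).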

## 7. Example Usage and Demonstrations

Uncomment the cells below to run visualizations when models and data are available:

In [None]:
# Example: Dataset Analysis
# DATA_DIR = "../data"  # Adjust path to your data directory

# # Analyze dataset statistics
# stats = analyze_dataset_statistics(DATA_DIR)
# plot_dataset_distribution(stats, save_path=VISUALIZATIONS_DIR / "dataset_distribution.png")

# # Visualize sample images
# visualize_sample_images(DATA_DIR, num_samples=8, 
#                        save_path=VISUALIZATIONS_DIR / "sample_images.png")

print("Dataset analysis functions are ready.")
print("Uncomment the above code and set DATA_DIR to analyze your dataset.")

In [None]:
# Example: Model Interpretability with Grad-CAM
# MODEL_PATH = "../models/xception_weights.pth"  # Adjust path to your model
# SAMPLE_IMAGE = "../data/test/pneumonia/sample.jpg"  # Adjust to sample image

# # Load model (example for Xception)
# from notebooks.models import XceptionModel  # Import your model class
# model = XceptionModel()
# model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# model.to(DEVICE)
# model.eval()  # inference mode: use stored BatchNorm statistics, disable dropout

# # Create Grad-CAM visualizer
# gradcam_viz = GradCAMVisualizer(model)

# # Visualize single prediction
# result = gradcam_viz.visualize_prediction(SAMPLE_IMAGE, 
#                                          save_path=VISUALIZATIONS_DIR / "gradcam_example.png")

# # Compare different Grad-CAM methods
# compare_gradcam_methods(model, SAMPLE_IMAGE, 
#                        save_path=VISUALIZATIONS_DIR / "gradcam_comparison.png")

print("Grad-CAM visualization functions are ready.")
print("Uncomment the above code and set paths to visualize model interpretability.")

In [None]:
# Example: Performance Analysis
# # Assuming you have model predictions
# y_true = np.array([0, 1, 1, 0, 1, 0, 0, 1])  # Example true labels
# y_pred = np.array([0, 1, 0, 0, 1, 1, 0, 1])  # Example predictions
# y_prob = np.array([0.2, 0.8, 0.4, 0.3, 0.9, 0.6, 0.1, 0.7])  # Example probabilities

# # Detailed performance analysis
# detailed_performance_analysis(y_true, y_pred, y_prob, 
#                              save_path=VISUALIZATIONS_DIR / "performance_analysis.png")

# # Clinical decision analysis
# clinical_results = clinical_decision_analysis(y_true, y_prob,
#                                              save_path=VISUALIZATIONS_DIR / "clinical_analysis.png")

print("Performance analysis functions are ready.")
print("Uncomment the above code with your model predictions to analyze performance.")
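
The example arrays in the cell above are small enough to check by hand, which makes them a useful sanity test of the metric definitions used in `clinical_decision_analysis`. The sketch below (illustrative toy data, not model results) thresholds `y_prob` at 0.5, which reproduces the example `y_pred`, and recomputes the four rates with scikit-learn. Passing `labels=[0, 1]` keeps the confusion matrix 2×2 even when a threshold produces only one predicted class.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, roc_auc_score

# Toy arrays from the example cell (illustrative, not model results)
y_true = np.array([0, 1, 1, 0, 1, 0, 0, 1])
y_prob = np.array([0.2, 0.8, 0.4, 0.3, 0.9, 0.6, 0.1, 0.7])
y_pred = (y_prob >= 0.5).astype(int)  # same as the example y_pred

# labels=[0, 1] fixes the layout to [[TN, FP], [FN, TP]]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn)  # detected pneumonia / all pneumonia
specificity = tn / (tn + fp)  # correct normals / all normals
ppv = tp / (tp + fp)          # correct positive calls / all positive calls
npv = tn / (tn + fn)          # correct negative calls / all negative calls
auc_value = roc_auc_score(y_true, y_prob)

print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
print(f"sens={sensitivity:.2f} spec={specificity:.2f} ppv={ppv:.2f} npv={npv:.2f} auc={auc_value:.4f}")
```

Each cell can be verified by eye: three true positives, one missed pneumonia (index 2), one false alarm (index 5) and three true negatives, so all four rates equal 0.75. The ROC AUC is 15/16 = 0.9375 because 15 of the 16 positive/negative pairs are ranked correctly; only the pneumonia case at 0.4 scores below a normal case (0.6).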

In [None]:
# Example: Interactive Widgets (if available)
# if WIDGETS_AVAILABLE:
#     # Interactive threshold exploration
#     print("Interactive Threshold Explorer:")
#     create_interactive_threshold_widget(y_true, y_prob)
    
#     # Model comparison widget (placeholder metric values for illustration only)
#     models_results = {
#         'xception': {'accuracy': 0.92, 'precision': 0.89, 'recall': 0.94, 'f1_score': 0.91},
#         'fusion': {'accuracy': 0.94, 'precision': 0.91, 'recall': 0.96, 'f1_score': 0.93},
#         'ensemble': {'accuracy': 0.95, 'precision': 0.93, 'recall': 0.97, 'f1_score': 0.95}
#     }
#     print("\nInteractive Model Comparison:")
#     create_model_comparison_widget(models_results)
# else:
#     print("Install ipywidgets for interactive features")

print("Interactive visualization functions are ready.")
print("Uncomment the above code to create interactive widgets.")

## 8. Visualization Summary and Best Practices

This framework brings together the tools defined in this notebook for understanding and validating pneumonia detection models:

### Key Visualization Categories:

1. **Data Exploration**:
   - Dataset distribution analysis
   - Class balance visualization
   - Sample image inspection

2. **Model Interpretability**:
   - Grad-CAM heatmaps showing model attention
   - Comparison of different CAM methods
   - Feature importance visualization

3. **Performance Assessment**:
   - Comprehensive metric analysis
   - ROC and Precision-Recall curves
   - Confusion matrix analysis
   - Threshold optimization

4. **Error Analysis**:
   - False positive/negative case studies
   - Failure pattern identification
   - Edge case analysis

5. **Clinical Relevance**:
   - Cost-sensitive analysis
   - Clinical decision thresholds
   - Sensitivity vs specificity trade-offs
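
For the false positive/negative case studies in category 4, a practical first step is to list the misclassified images ordered by how confident the model was, so reviewers see the most confidently wrong cases first. This is a minimal sketch, assuming `y_true` and `y_prob` are aligned with a list of image paths; the helper `rank_errors` and the file names are hypothetical.

```python
import numpy as np


def rank_errors(y_true, y_prob, threshold=0.5):
    """Indices of false positives and false negatives, most confident first (hypothetical helper).

    A false positive is a normal image (label 0) predicted as pneumonia;
    a false negative is a pneumonia image (label 1) predicted as normal.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    y_pred = (y_prob >= threshold).astype(int)

    fp_idx = np.where((y_pred == 1) & (y_true == 0))[0]
    fn_idx = np.where((y_pred == 0) & (y_true == 1))[0]

    # Most confident false positives have the highest pneumonia probability;
    # most confident false negatives have the lowest.
    fp_sorted = fp_idx[np.argsort(-y_prob[fp_idx], kind="stable")]
    fn_sorted = fn_idx[np.argsort(y_prob[fn_idx], kind="stable")]
    return fp_sorted, fn_sorted


paths = [f"img_{i}.png" for i in range(8)]  # hypothetical file names
fp_sorted, fn_sorted = rank_errors(
    [0, 1, 1, 0, 1, 0, 0, 1],
    [0.2, 0.8, 0.4, 0.3, 0.9, 0.6, 0.1, 0.7],
)
print("False positives:", [paths[i] for i in fp_sorted])
print("False negatives:", [paths[i] for i in fn_sorted])
```

The ranked indices can then be passed to the Grad-CAM visualizer to check whether the model attended to plausible lung regions in its worst mistakes.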

### Best Practices for Medical AI Visualization:

1. **Interpretability First**: Always provide explanations for model decisions
2. **Clinical Context**: Frame results in terms of clinical impact
3. **Uncertainty Quantification**: Report predicted probabilities and confidence intervals, not just point estimates
4. **Comprehensive Metrics**: Use multiple evaluation metrics
5. **Error Analysis**: Understand and communicate failure modes
6. **Threshold Optimization**: Consider clinical costs in threshold selection
7. **Interactive Exploration**: Enable dynamic analysis when possible
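
For the uncertainty quantification in best practice 3, one widely used, model-agnostic option for evaluation metrics is a percentile bootstrap confidence interval: resample the test set with replacement, recompute the metric, and report the middle 95% of the resampled values. The sketch below applies this to ROC AUC; the helper name, the 1,000 resamples, the fixed seed, and the synthetic scores in the usage lines are assumptions for illustration, not the notebook's settings.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auc_ci(y_true, y_prob, n_boot=1000, alpha=0.05, seed=0):
    """Point estimate and percentile bootstrap CI for ROC AUC (hypothetical helper)."""
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(y_true)

    boot_scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)  # n indices drawn with replacement
        # AUC is undefined if a resample contains a single class, so skip it
        if np.unique(y_true[idx]).size < 2:
            continue
        boot_scores.append(roc_auc_score(y_true[idx], y_prob[idx]))

    if not boot_scores:
        raise ValueError("No bootstrap resample contained both classes")
    lower, upper = np.percentile(boot_scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return roc_auc_score(y_true, y_prob), lower, upper


# Synthetic scores for illustration only: pneumonia cases score higher on average
rng = np.random.default_rng(42)
labels = rng.integers(0, 2, size=200)
probs = np.clip(0.3 + 0.4 * labels + rng.normal(0.0, 0.2, size=200), 0.0, 1.0)
auc_point, auc_lo, auc_hi = bootstrap_auc_ci(labels, probs)
print(f"AUC = {auc_point:.3f} (95% CI {auc_lo:.3f}-{auc_hi:.3f})")
```

Resamples that happen to contain only one class are skipped rather than counted, which matters mainly for small or very imbalanced test sets.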

### Clinical Considerations:

- **Sensitivity vs Specificity**: Balance based on clinical requirements
- **False Negative Cost**: Missing a pneumonia case is typically more costly than raising a false alarm
- **Interpretability**: Clinicians need to understand model reasoning
- **Robustness**: Models must work across diverse patient populations
- **Integration**: Visualizations should fit clinical workflows
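
When the clinical requirement is stated as a minimum sensitivity rather than as explicit costs, an alternative to the cost-minimizing threshold is to keep every operating point that meets the sensitivity floor and choose the one with the highest specificity. A minimal sketch using scikit-learn's `roc_curve` follows; the helper name and the example floors are assumptions. `drop_intermediate=False` keeps every candidate threshold so the exact boundary point is not pruned.

```python
import numpy as np
from sklearn.metrics import roc_curve


def threshold_for_min_sensitivity(y_true, y_prob, min_sensitivity=0.95):
    """Highest-specificity threshold whose sensitivity meets the floor (hypothetical helper).

    Returns (threshold, sensitivity, specificity); predict pneumonia when
    y_prob >= threshold, which is roc_curve's convention.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_prob, drop_intermediate=False)
    # With at least one positive case the last ROC point has tpr == 1,
    # so a floor <= 1 always leaves at least one candidate
    candidates = np.where(tpr >= min_sensitivity)[0]
    best = candidates[np.argmin(fpr[candidates])]
    return float(thresholds[best]), float(tpr[best]), float(1.0 - fpr[best])


y_true = [0, 1, 1, 0, 1, 0, 0, 1]                  # toy labels from Section 7
y_prob = [0.2, 0.8, 0.4, 0.3, 0.9, 0.6, 0.1, 0.7]  # toy probabilities
print(threshold_for_min_sensitivity(y_true, y_prob, min_sensitivity=1.0))
print(threshold_for_min_sensitivity(y_true, y_prob, min_sensitivity=0.75))
```

On the toy data, requiring 100% sensitivity selects a threshold of 0.4 (specificity 0.75), while a 75% floor allows 0.7 with specificity 1.0.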

These tools support a thorough analysis and validation of pneumonia detection models and help assess whether they meet both technical and clinical requirements before deployment in medical settings.