# In [None]:
"""
Table 3 Generation: Quantitative Explainability Metrics (Gini & Entropy)

Description:
    This script computes the Attention Concentration (Gini Coefficient) and 
    Dispersion (Entropy) metrics reported in Table 3 of the manuscript.
    
    It processes a stratified sample of images from external datasets to quantify 
    the "Stability Gap" between CNNs and Transformers.
    
    Methodology:
    - ViT: Uses Score-CAM (computationally expensive but accurate for Global Attention).
    - CNNs: Uses Grad-CAM (Standard approach).
"""

import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pydicom
import cv2
from tqdm.notebook import tqdm
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

# --- Dependency Check ---
try:
    from pytorch_grad_cam import GradCAM, ScoreCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except ImportError:
    raise ImportError("Missing 'grad-cam'. Please install via requirements.txt")

# ===================================================================
# 1. Configuration
# ===================================================================

class Config:
    """Run-wide settings for the Table 3 explainability analysis.

    Data and weight locations are resolved once, when the class body executes:
    if ``/kaggle/input`` exists the Kaggle layout is used, otherwise the local
    ``./data`` / ``./weights`` directories.
    """

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 1   # Keep 1 for XAI to avoid OOM with Score-CAM
    # Number of images drawn per dataset.
    # NOTE(review): run_analysis draws these with np.random.choice (uniform),
    # i.e. a random rather than a stratified sample - confirm the manuscript wording.
    SAMPLE_SIZE = 50
    SEED = 42  # Passed to np.random.seed before each dataset's sampling draw

    # --- Path Configuration (Auto-detect) ---
    if os.path.exists('/kaggle/input'):
        DATA_ROOT = '/kaggle/input'
        WEIGHTS_DIR = './'

        # Adjust these to match your Kaggle dataset structure
        KAGGLE_DIR = os.path.join(DATA_ROOT, 'chest-xray-pneumonia/chest_xray/test')
        # VinDr path (example).
        # NOTE(review): 'vixdr' may be a typo for 'vindr' - confirm against the mounted dataset slug.
        VINDR_ROOT = os.path.join(DATA_ROOT, 'vixdr/vindr-pcxr')
    else:
        DATA_ROOT = './data'
        WEIGHTS_DIR = './weights'
        KAGGLE_DIR = os.path.join(DATA_ROOT, 'kaggle/test')
        VINDR_ROOT = os.path.join(DATA_ROOT, 'vindr')


config = Config()
print(f"✅ Environment Ready: {config.DEVICE}")

# ===================================================================
# 2. Metrics & Utilities
# ===================================================================

def calculate_gini(heatmap):
    """Calculate the Gini coefficient of a heatmap's values.

    0 means the mass is spread uniformly (diffuse attention); values approaching
    1 mean it is concentrated on very few pixels (focused attention).

    Uses the closed form for ascending-sorted values x_1..x_n:
        G = sum_i (2i - n - 1) * x_i / (n * sum_i x_i)

    Args:
        heatmap: array-like of any shape. Values are assumed non-negative
            (presumably a CAM already scaled to [0, 1] - TODO confirm).

    Returns:
        The Gini coefficient as a Python float.

    Raises:
        ValueError: if ``heatmap`` contains no values.
    """
    # float64 avoids precision loss when summing ~50k float32 CAM values;
    # the epsilon keeps the denominator non-zero for an all-zero heatmap.
    values = np.sort(np.asarray(heatmap, dtype=np.float64).ravel()) + 1e-7
    n = values.size
    if n == 0:
        raise ValueError("calculate_gini requires a non-empty heatmap")
    ranks = np.arange(1, n + 1)
    return float(np.sum((2 * ranks - n - 1) * values) / (n * np.sum(values)))

def calculate_entropy(heatmap):
    """Return the Shannon entropy, in bits, of a heatmap viewed as a distribution.

    The heatmap is normalised to sum to ~1 (a 1e-7 guard keeps the division
    finite) and zero-probability cells are ignored. Larger values indicate a
    more disordered / diffuse attention map.
    """
    flat_values = heatmap.reshape(-1)
    denominator = np.sum(flat_values) + 1e-7
    probabilities = flat_values / denominator

    # log2(0) is undefined; those cells contribute nothing to the entropy.
    positive = probabilities[probabilities > 0]
    return -np.sum(positive * np.log2(positive))

def _read_dicom_rgb(path):
    """Decode a DICOM file into a min-max normalised RGB uint8 array."""
    dcm = pydicom.dcmread(path)
    # Work in float so the MONOCHROME1 inversion cannot overflow signed int16 data.
    pixels = dcm.pixel_array.astype(np.float64)
    # MONOCHROME1 stores "white = low value"; invert so bright means dense.
    if getattr(dcm, "PhotometricInterpretation", None) == "MONOCHROME1":
        pixels = pixels.max() - pixels
    pixels = (pixels - pixels.min()) / (pixels.max() - pixels.min() + 1e-6) * 255.0
    img = pixels.astype(np.uint8)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return img


def read_image(path):
    """Universal image reader (DICOM/JPG/PNG) -> RGB 224x224.

    Files ending in ``.dcm`` / ``.dicom`` (case-insensitive), or with no
    extension at all, are decoded as DICOM; anything else is opened with PIL.

    Returns:
        A (224, 224, 3) uint8 array, or ``None`` if the file cannot be read or
        decoded (best-effort: the caller skips such images).
    """
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.dcm', '.dicom') or not ext:
            img = _read_dicom_rgb(path)
        else:
            # Context manager closes the underlying file handle PIL keeps open.
            with Image.open(path) as pil_img:
                img = np.array(pil_img.convert('RGB'))
        return cv2.resize(img, (224, 224))
    except Exception:
        # Deliberately best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
        return None

def reshape_transform_vit(tensor, height=None, width=None):
    """Reshapes ViT embeddings for CAM.

    Converts a token sequence of shape (batch, 1 + H*W, channels) into a
    CNN-style activation map of shape (batch, channels, H, W). The first token
    is dropped (presumably the [CLS] token - the slice ``[:, 1:, :]`` assumes it
    sits at index 0).

    Args:
        tensor: 3-D tensor (batch, tokens, channels).
        height, width: patch-grid size. If both are omitted, a square grid is
            inferred from the token count (14x14 for a 224px ViT-B/16 input).
            If only one is given, the other is derived from the token count.

    Returns:
        Tensor of shape (batch, channels, height, width).

    Raises:
        ValueError: if the patch tokens cannot be arranged into the grid.
    """
    patch_tokens = tensor[:, 1:, :]
    n_patches = patch_tokens.size(1)
    if height is None and width is None:
        height = width = int(round(n_patches ** 0.5))
    elif height is None:
        height = n_patches // width
    elif width is None:
        width = n_patches // height
    if height * width != n_patches:
        raise ValueError(
            f"cannot arrange {n_patches} patch tokens into a {height}x{width} grid"
        )
    grid = patch_tokens.reshape(tensor.size(0), height, width, tensor.size(2))
    # (B, H, W, C) -> (B, C, H, W), the layout CAM methods expect.
    return grid.permute(0, 3, 1, 2)

# ===================================================================
# 3. Model Loader
# ===================================================================

def _find_weights(arch, weights_dir):
    """Return the checkpoint path for ``arch`` (weights_dir first, then Kaggle inputs), or None."""
    pattern = f"*{arch}*best.pth"
    # Each group is sorted so the chosen file is deterministic (glob order is filesystem-dependent).
    local = sorted(glob.glob(os.path.join(weights_dir, pattern)))
    # Recursive search to support Kaggle directory structure
    kaggle = sorted(glob.glob(os.path.join('/kaggle/input', '**', pattern), recursive=True))
    candidates = local + kaggle
    return candidates[0] if candidates else None


def _build_architecture(arch):
    """Instantiate a 2-class torchvision backbone plus its CAM configuration.

    Returns ``(model, target_layers, cam_cls, reshape_transform)``, or ``None``
    when ``arch`` matches none of the supported families.
    """
    name = arch.lower()
    if 'efficientnet' in name:
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
        return model, [model.features[-1]], GradCAM, None
    if 'convnext' in name:
        model = models.convnext_tiny(weights=None)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, 2)
        return model, [model.features[-1]], GradCAM, None
    if 'vit' in name:
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, 2)
        # Critical: GradCAM fails on ViT for this task, so Score-CAM is used,
        # with token embeddings reshaped into a spatial map.
        return model, [model.encoder.layers[-1].ln_1], ScoreCAM, reshape_transform_vit
    return None


def _load_weights(model, weights_path, arch):
    """Load ``weights_path`` into ``model``; raise if no parameter matches."""
    # SECURITY NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=config.DEVICE)
    # Unwrap common {'state_dict': ...} containers; a raw state dict passes through.
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            if isinstance(checkpoint.get(key), dict):
                checkpoint = checkpoint[key]
                break
    # Strip the DataParallel/DDP 'module.' prefix (only as a prefix, not mid-key).
    state = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in checkpoint.items()
    }
    result = model.load_state_dict(state, strict=False)
    # strict=False would otherwise silently leave a randomly-initialised model.
    if len(result.missing_keys) >= len(model.state_dict()):
        raise RuntimeError(f"no keys in {weights_path} match the {arch} architecture")
    if result.missing_keys or result.unexpected_keys:
        print(
            f"⚠️ {arch}: {len(result.missing_keys)} missing / "
            f"{len(result.unexpected_keys)} unexpected keys in {os.path.basename(weights_path)}"
        )


def get_model_and_cam(arch, weights_dir):
    """Initializes model and selects appropriate CAM method.

    Args:
        arch: architecture name; matched case-insensitively against
            'efficientnet', 'convnext' and 'vit' (checked in that order).
        weights_dir: directory searched first for ``*{arch}*best.pth``; Kaggle
            inputs are searched recursively afterwards.

    Returns:
        ``(model, cam)`` on success, or ``(None, None)`` when no weights are
        found or loading fails (the reason is printed).
    """
    weights_path = _find_weights(arch, weights_dir)
    if weights_path is None:
        print(f"⚠️ No weights found for {arch}. Skipping.")
        return None, None

    try:
        spec = _build_architecture(arch)
        if spec is None:
            raise ValueError(f"unsupported architecture '{arch}'")
        model, target_layers, cam_cls, reshape_transform = spec

        _load_weights(model, weights_path, arch)
        model.to(config.DEVICE).eval()

        # reshape_transform=None is the CAM default, so CNNs behave as before.
        cam = cam_cls(model=model, target_layers=target_layers, reshape_transform=reshape_transform)
        if cam_cls is ScoreCAM:
            # Batch processing for Score-CAM speedup.
            # NOTE(review): confirm the installed pytorch_grad_cam version reads `batch_size`.
            cam.batch_size = 16
        return model, cam

    except Exception as e:
        # Boundary handler: one broken checkpoint must not abort the whole table.
        print(f"Error loading {arch}: {e}")
        return None, None

# ===================================================================
# 4. Main Execution
# ===================================================================

def run_analysis():
    # --- A. Prepare Samples (Stratified) ---
    samples = {}
    
    # 1. Kaggle (Pneumonia class only for attention analysis)
    if os.path.exists(config.KAGGLE_DIR):
        files = glob.glob(os.path.join(config.KAGGLE_DIR, 'PNEUMONIA', '*'))
        if files:
            np.random.seed(config.SEED)
            # Take random sample to ensure diversity
            samples['Kaggle'] = np.random.choice(files, min(len(files), config.SAMPLE_SIZE), replace=False)
    
    # 2. VinDr (Scan for images if CSV logic is complex)
    # Simplified logic: scan for files in VinDr test directory
    vin_files = glob.glob(os.path.join(config.VINDR_ROOT, '**', '*.dicom'), recursive=True)
    if not vin_files: # Try jpg/png
        vin_files = glob.glob(os.path.join(config.VINDR_ROOT, '**', '*.jp*g'), recursive=True)
        
    if vin_files:
        np.random.seed(config.SEED)
        samples['VinDr'] = np.random.choice(vin_files, min(len(vin_files), config.SAMPLE_SIZE), replace=False)

    print(f"üìä Samples loaded: { {k:len(v) for k,v in samples.items()} }")
    
    if not any(samples.values()):
        print("‚ùå No images found. Check paths in Config.")
        return

    # --- B. Processing Loop ---
    architectures = ['EfficientNet_B0', 'ConvNeXt_Tiny', 'ViT_Base_16']
    results = []
    
    preprocess = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    for arch in architectures:
        print(f"\nü§ñ Analyzing {arch}...")
        model, cam = get_model_and_cam(arch, config.WEIGHTS_DIR)
        if not model: continue
        
        for dataset_name, file_paths in samples.items():
            ginis, entropies = [], []
            
            # Progress bar for images
            for path in tqdm(file_paths, desc=f"   {dataset_name}", leave=False):
                img = read_image(path)
                if img is None: continue
                
                input_tensor = preprocess(Image.fromarray(img)).unsqueeze(0).to(config.DEVICE)
                
                try:
                    # Generate Heatmap for 'Pneumonia' class (Target=1)
                    heatmap = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(1)])[0, :]
                    
                    # Compute Metrics
                    ginis.append(calculate_gini(heatmap))
                    entropies.append(calculate_entropy(heatmap))
                except Exception:
                    continue # Skip failed images
            
            if ginis:
                results.append({
                    'Dataset': dataset_name,
                    'Model': arch,
                    'Gini Mean': np.mean(ginis),
                    'Gini SD': np.std(ginis),
                    'Entropy Me