# Table 3 Generation: Quantitative XAI Metrics (Gini & Entropy)

**Description:** 
This notebook calculates Attention Concentration (Gini coefficient) and Attention Dispersion (Shannon entropy) of the CAM heatmaps produced by each evaluated architecture.

It validates the **"Stability Gap"** hypothesis:
* **ViT** is expected to have lower Gini (more diffuse attention) and higher Entropy.
* **CNNs (e.g., ConvNeXt)** are expected to have higher Gini (more focal attention) and, correspondingly, lower Entropy.

**Methodology:**
* **CNNs (EfficientNet-B0, ConvNeXt-Tiny):** Grad-CAM
* **ViT:** Score-CAM (gradient-free, used to sidestep shattered gradients)

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import warnings

# Suppress all warnings globally to keep notebook output readable.
# NOTE(review): this also hides potentially useful runtime warnings from
# torch/torchvision/grad-cam; narrow the filter when debugging.
warnings.filterwarnings("ignore")

# --- Dependency Check ---
# pytorch-grad-cam supplies the Grad-CAM / Score-CAM implementations and the
# classifier target wrapper used in the analysis loop.
try:
    from pytorch_grad_cam import GradCAM, ScoreCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except ImportError as err:
    # Chain the original error so the actual missing module stays visible.
    raise ImportError("Missing grad-cam. Please run: pip install -r requirements.txt") from err

In [None]:
# ===================================================================
# 1. Metrics Calculation (The Core Logic)
# ===================================================================
def calculate_gini(heatmap):
    """
    Compute the Gini coefficient of a saliency heatmap.

    All values are pooled, sorted ascending and plugged into the rank-weighted
    closed form ``G = 2 * sum(i * x_i) / (n * sum(x)) - (n + 1) / n`` with
    ranks ``i = 1..n``. For a non-negative heatmap the result lies in
    [0, (n - 1) / n]: 0 for a perfectly uniform map, and it grows as the mass
    concentrates on fewer pixels. Higher = more focused/sparse attention.

    Args:
        heatmap: numpy array of any shape (e.g. a 2-D CAM).

    Returns:
        The Gini coefficient, or 0.0 when the heatmap sums to less than 1e-9.
    """
    if np.sum(heatmap) < 1e-9:
        return 0.0

    ordered = np.sort(heatmap.ravel())
    count = ordered.shape[0]
    ranks = np.arange(1, count + 1)
    rank_weighted_sum = np.sum(ranks * ordered)
    return 2 * rank_weighted_sum / (count * np.sum(ordered)) - (count + 1) / count

def calculate_entropy(heatmap):
    """
    Compute the Shannon entropy (in nats) of a saliency heatmap.

    The heatmap is normalised into a probability distribution over pixels by
    dividing by its total. Zero-probability pixels are dropped, following the
    convention 0 * log(0) = 0. The result is scale-invariant: multiplying the
    heatmap by a positive constant does not change it. It is 0 for a map
    concentrated on one pixel and reaches log(n) for a uniform map over n
    pixels. Higher = more diffuse/uncertain attention.

    Assumes a non-negative heatmap (as CAM outputs normally are). TODO confirm
    for any other input source.

    Args:
        heatmap: numpy array of any shape (e.g. a 2-D CAM).

    Returns:
        The entropy, or 0.0 when the heatmap sums to less than 1e-9.
    """
    values = np.asarray(heatmap, dtype=np.float64).ravel()
    total = np.sum(values)
    if total < 1e-9:
        return 0.0

    # The guard guarantees a positive denominator, so no epsilon is added.
    # An epsilon would make the probabilities sum to < 1 and bias the entropy
    # of low-magnitude heatmaps.
    probs = values / total
    probs = probs[probs > 0]  # drop zeros to avoid log(0)

    return -np.sum(probs * np.log(probs))

In [None]:
# ===================================================================
# 2. Model & Data Setup
# ===================================================================
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _find_checkpoint(arch, weights_dir):
    """
    Return the ``*{arch}*best.pth`` checkpoint under ``weights_dir``, or None.

    Files directly inside ``weights_dir`` take precedence over files in
    subdirectories. Within each group the paths are sorted, so the choice does
    not depend on filesystem listing order.
    """
    pattern = f"*{arch}*best.pth"
    root = glob.escape(weights_dir)  # treat the directory name literally
    top_level = sorted(glob.glob(os.path.join(root, pattern)))
    nested = sorted(glob.glob(os.path.join(root, "**", pattern), recursive=True))
    # `**` also matches zero directories, so `nested` repeats the top-level
    # hits; dict.fromkeys de-duplicates while keeping precedence order.
    candidates = list(dict.fromkeys(top_level + nested))
    return candidates[0] if candidates else None


def _extract_state_dict(checkpoint):
    """
    Return a plain parameter dict from a loaded checkpoint object.

    Accepts a bare state dict, a dict wrapping one under ``state_dict`` or
    ``model_state_dict``, or a pickled ``nn.Module``. A leading ``module.``
    (the nn.DataParallel prefix) is stripped from each key; only the prefix is
    removed, not later occurrences.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                checkpoint = checkpoint[wrapper_key]
                break
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint.items()
    }


def get_model(arch, weights_dir):
    """
    Build a 2-class classifier, load its best checkpoint and pick CAM layers.

    Args:
        arch: One of "EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16".
        weights_dir: Directory searched (including subdirectories) for
            ``*{arch}*best.pth``.

    Returns:
        (model, target_layers, is_vit):
        - model: the network, on DEVICE and in eval mode.
        - target_layers: list of layers to hand to the CAM algorithm.
        - is_vit: True only for the ViT architecture.

    Raises:
        ValueError: if ``arch`` is not recognised.

    If no checkpoint is found, or it cannot be loaded, a warning is printed and
    the randomly initialised model is returned. A warning is also printed if
    the checkpoint leaves any parameters missing.
    """
    if arch == "EfficientNet_B0":
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
        target_layers = [model.features[-1]]  # last feature stage
        is_vit = False
    elif arch == "ConvNeXt_Tiny":
        model = models.convnext_tiny(weights=None)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, 2)
        target_layers = [model.features[-1][-1]]  # last block of last stage
        is_vit = False
    elif arch == "ViT_Base_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, 2)
        target_layers = [model.encoder.layers[-1].ln_1]  # last encoder block
        is_vit = True
    else:
        raise ValueError(f"Unknown architecture: {arch}")

    ckpt_path = _find_checkpoint(arch, weights_dir)
    if ckpt_path is None:
        # Metrics from an untrained network would be meaningless, so say so.
        print(f"Warning: No checkpoint matching '*{arch}*best.pth' under {weights_dir}, using random init.")
    else:
        try:
            state = _extract_state_dict(torch.load(ckpt_path, map_location=DEVICE))
            # strict=False tolerates extra keys; missing ones are checked below.
            load_result = model.load_state_dict(state, strict=False)
        except Exception as err:  # unreadable file, unexpected format, shape mismatch
            print(f"Warning: Could not load weights for {arch}, using random init. ({err})")
        else:
            print(f"Loaded {arch}")
            if load_result.missing_keys:
                print(
                    f"Warning: {len(load_result.missing_keys)} parameter tensor(s) of {arch} "
                    f"missing from {ckpt_path}; they keep their random init."
                )

    model.to(DEVICE).eval()
    return model, target_layers, is_vit

def get_transforms():
    """
    Return the preprocessing pipeline applied to every input image.

    The image is resized to 224x224, converted to a tensor, and normalised with
    the standard ImageNet per-channel mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(pipeline)

In [None]:
# ===================================================================
# 3. Analysis Loop
# ===================================================================
def _reshape_transform_vit(tensor, height=14, width=14):
    """
    Convert ViT token activations into a CNN-style feature map for CAM.

    Drops token 0 (the class token in torchvision's ViT) and reshapes the
    remaining ``height * width`` patch tokens from (batch, tokens, channels)
    to (batch, channels, height, width). The 14x14 default corresponds to a
    224x224 input split into 16x16 patches.
    """
    grid = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    return grid.transpose(2, 3).transpose(1, 2)


def _sample_images(ds_path, n_images, seed):
    """
    Return a reproducible subset of up to ``n_images`` .jpeg/.jpg paths.

    Paths are sorted before sampling because ``glob`` order depends on the
    filesystem; without sorting, the same seed selects different images on
    different machines. A local RandomState avoids mutating NumPy's global RNG.
    """
    paths = sorted(
        glob.glob(os.path.join(ds_path, "*.jpeg"))
        + glob.glob(os.path.join(ds_path, "*.jpg"))
    )
    if len(paths) <= n_images:
        return paths
    rng = np.random.RandomState(seed)
    return list(rng.choice(paths, n_images, replace=False))


def _score_images(model, layers, is_vit, image_paths, transform, targets):
    """
    Compute per-image Gini and entropy of the CAM heatmaps for one model.

    The ViT uses Score-CAM with the token reshape transform; the CNNs use
    Grad-CAM. Images that fail to load or to produce a CAM are skipped and
    counted rather than aborting the run.

    Returns:
        (gini_scores, entropy_scores, n_failed, first_error), where
        first_error is a "path: message" string or None.
    """
    cam_algo = ScoreCAM if is_vit else GradCAM
    reshape = _reshape_transform_vit if is_vit else None
    gini_scores, entropy_scores = [], []
    n_failed, first_error = 0, None

    with cam_algo(model=model, target_layers=layers, reshape_transform=reshape) as cam:
        if is_vit:
            # NOTE(review): read by pytorch-grad-cam's ScoreCAM as its batch
            # size for the masked forward passes — confirm against library version.
            cam.batch_size = 16

        for img_path in image_paths:
            try:
                with Image.open(img_path) as img:
                    input_tensor = transform(img.convert('RGB')).unsqueeze(0).to(DEVICE)
                grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
            except Exception as err:  # unreadable image, CAM/CUDA failure, ...
                n_failed += 1
                if first_error is None:
                    first_error = f"{img_path}: {err}"
                continue
            gini_scores.append(calculate_gini(grayscale_cam))
            entropy_scores.append(calculate_entropy(grayscale_cam))

    return gini_scores, entropy_scores, n_failed, first_error


def _report_results(results, csv_path="table3_metrics.csv"):
    """Print the formatted Table 3 and save all columns to ``csv_path``."""
    print("\n" + "="*60)
    print("        TABLE 3: QUANTITATIVE EXPLAINABILITY METRICS")
    print("="*60)
    df = pd.DataFrame(results)
    if df.empty:
        print("No results generated. Check data paths.")
        return

    # Format for display similar to manuscript ("mean ± SD").
    df['Gini'] = [f"{m:.3f} ± {s:.3f}" for m, s in zip(df['Gini_Mean'], df['Gini_SD'])]
    df['Entropy'] = [f"{m:.2f} ± {s:.2f}" for m, s in zip(df['Entropy_Mean'], df['Entropy_SD'])]
    # to_string avoids pandas truncating wide tables in the console.
    print(df[['Dataset', 'Architecture', 'Gini', 'Entropy']].to_string(index=False))

    df.to_csv(csv_path, index=False)
    print(f"\nMetrics saved to {csv_path}")


def run_table3_analysis(n_images=50, seed=2025):
    """
    Compute Table 3 (Gini and entropy of CAM heatmaps) per dataset and model.

    For each dataset folder that exists, a fixed-seed subset of up to
    ``n_images`` images is drawn. Each architecture is loaded, one CAM is
    computed per image, and the mean and SD of both metrics are recorded. The
    table is printed and written to ``table3_metrics.csv``.

    Args:
        n_images: Subset size per dataset (50 as per the paper).
        seed: Seed for the subset selection (fixed for reproducibility).
    """
    # Detect environment
    if os.path.exists('/kaggle/input'):
        data_root, weights_dir = '/kaggle/input', './'
    else:
        data_root, weights_dir = './data', './weights'

    # Define Datasets (Using specific subfolders)
    datasets = {
        'Kaggle': os.path.join(data_root, 'chest-xray-pneumonia/chest_xray/test/PNEUMONIA'),
        # Add VinDr path if available in your structure
        # 'VinDr': os.path.join(data_root, 'vindr-pcxr/test/Pneumonia')
    }
    models_to_test = ["EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16"]

    transform = get_transforms()
    # NOTE(review): assumes output index 1 is the Pneumonia class — confirm
    # against the label order used in training.
    targets = [ClassifierOutputTarget(1)]
    results = []

    print(f"Starting XAI Metric Analysis on {DEVICE}...")

    for ds_name, ds_path in datasets.items():
        if not os.path.exists(ds_path):
            print(f"Skipping {ds_name} (Path not found: {ds_path})")
            continue

        subset_imgs = _sample_images(ds_path, n_images, seed)
        if not subset_imgs:
            print(f"Skipping {ds_name} (no .jpeg/.jpg images in {ds_path})")
            continue

        for arch in models_to_test:
            model, layers, is_vit = get_model(arch, weights_dir)
            print(f"Processing {ds_name} with {arch}...")

            try:
                gini_scores, entropy_scores, n_failed, first_error = _score_images(
                    model, layers, is_vit, subset_imgs, transform, targets
                )
            except Exception as e:  # CAM construction / hook failure
                print(f"CAM generation failed for {arch}: {e}")
                continue

            if n_failed:
                print(
                    f"  Warning: skipped {n_failed}/{len(subset_imgs)} image(s) for {arch} "
                    f"(first error: {first_error})"
                )

            # Aggregate Results (np.std defaults to population SD, ddof=0).
            if gini_scores:
                results.append({
                    "Dataset": ds_name,
                    "Architecture": arch,
                    "N_Images": len(gini_scores),
                    "Gini_Mean": np.mean(gini_scores),
                    "Gini_SD": np.std(gini_scores),
                    "Entropy_Mean": np.mean(entropy_scores),
                    "Entropy_SD": np.std(entropy_scores),
                })

    _report_results(results)

if __name__ == "__main__":
    run_table3_analysis()