In [None]:
"""
Figure 2 Generation: Comprehensive XAI Analysis (Grad-CAM, Score-CAM, LIME)

Description:
    This script generates the multi-panel visualization (Figure 2 in the manuscript).
    It compares EfficientNet, ConvNeXt, and ViT on a representative test case, displaying:
    1. Model Prediction & Confidence
    2. Attention Heatmaps (Grad-CAM++ for CNNs, Score-CAM for ViT)
    3. Quantitative Metrics (Gini Coefficient, Entropy)
    4. LIME Explanations (Super-pixel perturbations)

    Note: 
    - This process is computationally intensive due to LIME and Score-CAM.
    - Ensure 'grad-cam' and 'lime' are installed via requirements.txt.
"""

import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# --- Dependency Check ---
# The XAI libraries are optional extras; fail early with an actionable message.
try:
    from lime import lime_image
    from pytorch_grad_cam import GradCAMPlusPlus, ScoreCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError as err:
    # Name the module that is actually missing and keep the original traceback attached.
    missing = err.name or str(err)
    raise ImportError(
        f"Missing dependency ({missing}). Please run: pip install -r requirements.txt"
    ) from err

# ===================================================================
# 1. Configuration & Paths
# ===================================================================

class Config:
    """Run-time settings: compute device plus data, weights and test-image paths.

    Everything is resolved once, when the class body executes: the Kaggle
    layout is used when '/kaggle/input' exists, otherwise a local './data'
    layout is assumed.
    """

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Environment auto-detection: local checkout unless the Kaggle input mount is present.
    if not os.path.exists('/kaggle/input'):
        DATA_ROOT = './data'
        WEIGHTS_DIR = './weights'
        # Placeholder for local runs -- point this at the case you want to plot.
        TEST_IMAGE_PATH = './data/test/PNEUMONIA/case1.jpeg'
    else:
        DATA_ROOT = '/kaggle/input'
        WEIGHTS_DIR = './'
        # The specific case shown in the paper (adjust if the dataset is mounted elsewhere).
        TEST_IMAGE_PATH = os.path.join(DATA_ROOT, 'chest-xray-pneumonia/chest_xray/test/PNEUMONIA/person80_bacteria_389.jpeg')

# Module-level settings instance; generate_figure_2 reads its DEVICE,
# WEIGHTS_DIR and TEST_IMAGE_PATH.
config = Config()
print(f"✅ Environment Ready: {config.DEVICE}")

# ===================================================================
# 2. Quantitative Metrics (Gini & Entropy)
# ===================================================================

def calculate_gini(heatmap):
    """Calculates Gini Coefficient: Measure of attention sparsity (Higher = Focused).

    0 means the attention mass is spread perfectly evenly. The value rises to
    (n - 1) / n when all of the mass sits on a single one of the n pixels.

    Args:
        heatmap: array-like of non-negative weights, any shape (it is flattened).

    Returns:
        float Gini coefficient; 0.0 for an all-zero or empty map.
    """
    values = np.sort(np.asarray(heatmap, dtype=np.float64).ravel())
    total = values.sum()
    if total < 1e-9:
        return 0.0
    n = values.size
    ranks = np.arange(1, n + 1)
    # Closed form for ascending-sorted data: G = 2*sum(i * x_i) / (n * sum(x)) - (n + 1) / n
    return float(2.0 * np.dot(ranks, values) / (n * total) - (n + 1) / n)

def calculate_entropy(heatmap):
    """Calculates Entropy: Measure of attention disorder (Lower = Focused).

    The heatmap is normalised into a probability distribution, and its Shannon
    entropy is computed with the natural log. The result is 0 when all mass is
    on one pixel and log(n) when the mass is uniform over n pixels.

    Args:
        heatmap: array-like of non-negative weights, any shape (it is flattened).

    Returns:
        float entropy in nats. An all-zero map is treated as maximally
        disordered (log(n)). An empty map returns 0.0.
    """
    values = np.asarray(heatmap, dtype=np.float64).ravel()
    if values.size == 0:
        # log(0) would be -inf; an empty map carries no information at all.
        return 0.0
    total = values.sum()
    if total < 1e-9:
        return float(np.log(values.size))
    probs = values / total  # safe: total >= 1e-9 here
    probs = probs[probs > 0]  # 0 * log(0) is taken as 0
    return float(-np.sum(probs * np.log(probs)))

# ===================================================================
# 3. Model Loading & XAI Utilities
# ===================================================================

def _find_weight_file(architecture_name, weights_dir):
    """Return the checkpoint matching '*<arch>*best.pth' under weights_dir (recursively), or None.

    Files directly inside weights_dir win over files in subdirectories. Ties
    are broken alphabetically, so the choice is reproducible.
    """
    # With recursive=True, '**' also matches zero directories, so this single
    # pattern covers weights_dir itself and every subdirectory without duplicates.
    pattern = os.path.join(glob.escape(weights_dir), '**', f"*{architecture_name}*best.pth")
    matches = glob.glob(pattern, recursive=True)
    if not matches:
        return None
    # glob() order is filesystem-dependent; pick by (depth, path) instead.
    return min(matches, key=lambda p: (os.path.normpath(p).count(os.sep), p))


def _load_weights(model, weight_path, device, architecture_name):
    """Best-effort load of a checkpoint into `model`; problems are reported, never raised."""
    try:
        # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(weight_path, map_location=device)
        if isinstance(checkpoint, dict):
            # Accept {'state_dict': ...} / {'model_state_dict': ...} wrappers as well as bare state dicts.
            for wrapper_key in ('state_dict', 'model_state_dict'):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    checkpoint = checkpoint[wrapper_key]
                    break
        # Strip the DataParallel/DDP 'module.' prefix only at the start of a key.
        prefix = 'module.'
        state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in checkpoint.items()}
        result = model.load_state_dict(state, strict=False)
    except Exception as e:
        # Deliberately best-effort: the figure is still produced with random weights.
        print(f"Error loading weights: {e}")
        return
    # strict=False silently skips mismatched keys; surface that, since the
    # affected layers would otherwise keep their random initialisation unnoticed.
    if result.missing_keys or result.unexpected_keys:
        print(f"⚠️ {architecture_name}: {len(result.missing_keys)} missing / "
              f"{len(result.unexpected_keys)} unexpected keys in "
              f"{os.path.basename(weight_path)} (those layers keep random init).")


def get_model_and_layers(architecture_name, weights_dir, device):
    """Initializes model and identifies target layers for CAM.

    Builds a 2-class torchvision classifier and loads the best checkpoint found
    under `weights_dir` if there is one. If none is found, the model keeps
    random weights as a sanity check.

    Args:
        architecture_name: one of "EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16".
        weights_dir: directory searched (recursively) for '*<arch>*best.pth'.
        device: torch device the model is moved to.

    Returns:
        (model, target_layers, is_vit):
            model: the classifier on `device`, in eval mode.
            target_layers: list of modules for CAM to hook.
            is_vit: True for the ViT, which needs `reshape_transform_vit`.

    Raises:
        ValueError: if `architecture_name` is not supported.
    """
    is_vit = False

    # Build the architecture first so an unknown name fails before any file search.
    if architecture_name == "EfficientNet_B0":
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
        target_layers = [model.features[-1]]

    elif architecture_name == "ConvNeXt_Tiny":
        model = models.convnext_tiny(weights=None)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, 2)
        target_layers = [model.features[-1][-1]]

    elif architecture_name == "ViT_Base_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, 2)
        is_vit = True
        # Target: ln_1, the first LayerNorm inside the LAST encoder block (not the
        # encoder's final `encoder.ln`). Only the class token of the last block's
        # output feeds the head, so patch tokens there carry no class signal.
        target_layers = [model.encoder.layers[-1].ln_1]
    else:
        raise ValueError(f"Unknown architecture: {architecture_name!r}")

    weight_path = _find_weight_file(architecture_name, weights_dir)
    if weight_path:
        print(f"Loading {architecture_name} from: {os.path.basename(weight_path)}")
        _load_weights(model, weight_path, device, architecture_name)
    else:
        print(f"⚠️ Weights for {architecture_name} not found. Using random init (Sanity Check).")

    model.to(device).eval()
    return model, target_layers, is_vit

def reshape_transform_vit(tensor, height=None, width=None):
    """Reshapes ViT patch embeddings to 2D grid for CAM compatibility.

    Drops the leading class token, lays the remaining patch tokens out on a
    (height, width) grid, and moves channels first.

    Args:
        tensor: (batch, 1 + height * width, channels) token activations.
        height, width: patch-grid size. When both are omitted, a square grid
            is inferred from the token count (14 x 14 for ViT-B/16 at 224 px).
            When only one is given, the other is derived.

    Returns:
        (batch, channels, height, width) tensor.

    Raises:
        ValueError: if the number of patch tokens does not fill the grid exactly.
    """
    patches = tensor[:, 1:, :]
    num_patches = patches.size(1)
    if height is None and width is None:
        side = int(round(num_patches ** 0.5))
        height = width = side
    elif height is None:
        height = num_patches // width
    elif width is None:
        width = num_patches // height
    if height * width != num_patches:
        raise ValueError(
            f"Cannot arrange {num_patches} patch tokens into a {height}x{width} grid."
        )
    grid = patches.reshape(tensor.size(0), height, width, tensor.size(2))
    return grid.permute(0, 3, 1, 2)

def batch_predict_for_lime(model, numpy_images, device):
    """Classifier callback handed to LIME.

    Each HWC image from LIME is converted with ToTensor and normalised with the
    same ImageNet mean/std used for the model input. The images are stacked into
    one batch and run through `model` without gradient tracking.

    Returns:
        numpy array of softmax class probabilities, one row per input image.
    """
    model.eval()
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    tensors = [normalize(image) for image in numpy_images]
    batch = torch.stack(tensors, dim=0).to(device)
    with torch.no_grad():
        scores = F.softmax(model(batch), dim=1)
    return scores.cpu().numpy()

# ===================================================================
# 4. Main Generation Loop
# ===================================================================

def _cam_heatmap(model, target_layers, is_vit, input_tensor, pred_class):
    """Return the 2-D class-activation map for `pred_class` (Score-CAM for ViT, Grad-CAM++ otherwise)."""
    cam_algo = ScoreCAM if is_vit else GradCAMPlusPlus
    targets = [ClassifierOutputTarget(pred_class)]
    print(f"     Generating {cam_algo.__name__}...")
    with cam_algo(model=model, target_layers=target_layers,
                  reshape_transform=reshape_transform_vit if is_vit else None) as cam:
        if is_vit:
            # Batch size limited to avoid OOM on ViT
            cam.batch_size = 16
        return cam(input_tensor=input_tensor, targets=targets)[0, :]


def _lime_overlay(model, viz_img, pred_class):
    """Run LIME on the visualisation image and return its top-5 positive super-pixels for `pred_class`."""
    print("     Running LIME (This takes time)...")
    # Round (not truncate) back to uint8: float32 k/255 * 255 can land just below k.
    lime_input = np.clip(np.rint(viz_img * 255), 0, 255).astype(np.uint8)
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        lime_input,
        lambda x: batch_predict_for_lime(model, x, config.DEVICE),
        top_labels=2,
        hide_color=0,
        num_samples=500,  # Adjust samples for speed vs quality
        random_seed=42,
    )
    lime_vis, _ = explanation.get_image_and_mask(pred_class, positive_only=True, num_features=5, hide_rest=True)
    return lime_vis


def generate_figure_2():
    """Render Figure 2: one row per model for the image at config.TEST_IMAGE_PATH.

    Each row shows three panels:
        1. the prediction and its confidence;
        2. the CAM overlay with its Gini/entropy scores;
        3. the LIME super-pixels.

    Writes 'Figure_2_XAI_Comparison.tif' (300 dpi, LZW) to the working
    directory and displays the figure. If the test image does not exist, prints
    a message and returns without plotting.
    """
    if not os.path.exists(config.TEST_IMAGE_PATH):
        print(f"❌ Test image not found at {config.TEST_IMAGE_PATH}. Please check the path.")
        return

    # Two views of the same 224x224 resize: raw [0, 1] pixels for overlays and
    # LIME, and an ImageNet-normalised tensor for the models.
    transform_viz = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    transform_model = transforms.Compose([
        transforms.Resize((224, 224)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    with Image.open(config.TEST_IMAGE_PATH) as img:
        raw_pil = img.convert('RGB')
    input_tensor = transform_model(raw_pil).unsqueeze(0).to(config.DEVICE)
    viz_img = transform_viz(raw_pil).numpy().transpose(1, 2, 0)  # HWC float32 in [0, 1]

    model_list = ["EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16"]

    # One row per model; squeeze=False keeps axes 2-D even for a single model.
    fig, axes = plt.subplots(len(model_list), 3, figsize=(15, 5 * len(model_list)), squeeze=False)
    fig.subplots_adjust(wspace=0.1, hspace=0.2)

    print(f"\nProcessing Image: {os.path.basename(config.TEST_IMAGE_PATH)}")

    for idx, arch in enumerate(model_list):
        print(f"\n---> Analyzing {arch}...")
        model, layers, is_vit = get_model_and_layers(arch, config.WEIGHTS_DIR, config.DEVICE)

        # 1. Prediction (class index 1 = Pneumonia)
        with torch.no_grad():
            probs = F.softmax(model(input_tensor), dim=1)[0]
        pred_class = torch.argmax(probs).item()
        conf = probs[pred_class].item()

        # 2. CAM (attention) and its concentration metrics
        grayscale_cam = _cam_heatmap(model, layers, is_vit, input_tensor, pred_class)
        cam_vis = show_cam_on_image(viz_img, grayscale_cam, use_rgb=True)
        gini = calculate_gini(grayscale_cam)
        entropy = calculate_entropy(grayscale_cam)

        # 3. LIME (explanation)
        lime_vis = _lime_overlay(model, viz_img, pred_class)

        # --- Plotting ---
        row = axes[idx]
        # Col 1: Original + Label
        row[0].imshow(raw_pil)
        row[0].set_title(f"{arch}\nPred: {'Pneumonia' if pred_class==1 else 'Normal'} ({conf:.1%})", fontsize=12, fontweight='bold', loc='left')
        row[0].axis('off')

        # Col 2: CAM + Metrics
        row[1].imshow(cam_vis)
        row[1].set_title(f"{'Score-CAM' if is_vit else 'Grad-CAM++'}\nGini: {gini:.3f} | Ent: {entropy:.2f}", fontsize=12)
        row[1].axis('off')

        # Col 3: LIME
        row[2].imshow(lime_vis)
        row[2].set_title("LIME Super-pixels", fontsize=12)
        row[2].axis('off')

        # Release this model before the next one is built (matters on GPU).
        del model, layers
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save
    save_path = "Figure_2_XAI_Comparison.tif"
    fig.savefig(save_path, dpi=300, bbox_inches='tight', pil_kwargs={"compression": "tiff_lzw"})
    print(f"\n✅ Figure 2 saved to {save_path}")
    plt.show()
    plt.close(fig)

# Entry point: runs when executed directly (and in a notebook, where __name__
# is "__main__"); a plain `import` of this module does not build the figure.
if __name__ == "__main__":
    generate_figure_2()