# In [None]:
"""
Figure 3 Generation: Direct Comparison of Attention Mechanisms (CNN vs ViT)

Description:
    This script generates a side-by-side comparison illustrating the "Stability Gap".
    It visualizes the attention maps of ConvNeXt (Localized) and ViT (Diffuse)
    on the same input image.
    
    Output: A single row figure with 3 panels: (A) Input, (B) ConvNeXt, (C) ViT.
"""

import os
import glob
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

# --- Dependency Check ---
try:
    from pytorch_grad_cam import GradCAM, ScoreCAM 
    from pytorch_grad_cam.utils.image import show_cam_on_image
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except ImportError:
    raise ImportError("Missing grad-cam. Please run: pip install -r requirements.txt")

# ===================================================================
# 1. Configuration
# ===================================================================

class Config:
    """Runtime settings: compute device plus data, weights and test-image paths.

    The path attributes are chosen once, when the class body executes,
    depending on whether '/kaggle/input' exists on this machine.
    """

    # Use CUDA whenever torch reports it as available, otherwise the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if not os.path.exists('/kaggle/input'):
        # No Kaggle input directory: fall back to the local relative layout.
        DATA_ROOT, WEIGHTS_DIR = './data', './weights'
        TEST_IMAGE_PATH = './data/test/PNEUMONIA/case1.jpeg'
    else:
        # Kaggle input directory present: weights are looked up in the
        # current working directory.
        DATA_ROOT, WEIGHTS_DIR = '/kaggle/input', './'
        # Representative test case used for the figure.
        TEST_IMAGE_PATH = os.path.join(DATA_ROOT, 'chest-xray-pneumonia/chest_xray/test/PNEUMONIA/person80_bacteria_389.jpeg')


# Module-level settings instance read by the functions below.
config = Config()

# ===================================================================
# 2. Helper Functions
# ===================================================================

def get_model(arch, weights_dir, device):
    """Build a 2-class classifier for *arch* and load its best weights if found.

    Args:
        arch: Architecture name; must be "ConvNeXt" or "ViT".
        weights_dir: Directory searched for a file matching
            ``*{arch}*best.pth``. Files directly inside the directory take
            precedence; sub-directories are searched only as a fallback.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        A ``(model, layer)`` tuple: the model in eval mode on *device*, and a
        one-element list holding the module intended as the CAM target layer.

    Raises:
        ValueError: If *arch* is not one of the supported architectures.
    """
    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if arch not in ("ConvNeXt", "ViT"):
        raise ValueError(
            f"Unsupported architecture {arch!r}; expected 'ConvNeXt' or 'ViT'."
        )

    if arch == "ConvNeXt":
        model = models.convnext_tiny(weights=None)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, 2)
        layer = [model.features[-1][-1]]
    else:
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, 2)
        layer = [model.encoder.layers[-1].ln_1]

    # Search for weights. Sorting makes the chosen file deterministic, since
    # glob order depends on the filesystem.
    pattern = f"*{arch}*best.pth"
    files = sorted(glob.glob(os.path.join(weights_dir, pattern)))
    if not files:
        files = sorted(
            glob.glob(os.path.join(weights_dir, "**", pattern), recursive=True)
        )

    if files:
        weights_path = files[0]
        # NOTE(security): torch.load unpickles the file, so weights_dir must
        # only contain checkpoints from a trusted source.
        state = torch.load(weights_path, map_location=device)

        # Strip only a *leading* 'module.' (added by DataParallel-style
        # wrappers); a plain str.replace would also mangle keys that merely
        # contain that substring.
        prefix = 'module.'
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

        # strict=False tolerates key mismatches, so report them explicitly
        # rather than silently running a partially initialised model.
        load_result = model.load_state_dict(state, strict=False)
        print(f"Loaded {arch} weights from {os.path.basename(weights_path)}")
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"⚠️ {arch} checkpoint mismatch: "
                f"{len(load_result.missing_keys)} missing key(s), "
                f"{len(load_result.unexpected_keys)} unexpected key(s)."
            )
    else:
        print(f"⚠️ Weights for {arch} not found. Using random initialization.")

    model.to(device).eval()
    return model, layer

def reshape_transform_vit(tensor):
    """Reshape ViT token embeddings into a channels-first 2-D map for CAM.

    The first token along dimension 1 (presumably the class token) is
    dropped, and the remaining tokens are laid out row by row on a square
    grid whose side is derived from the token count, so input resolutions
    other than a 14x14 patch grid work as well.

    Args:
        tensor: 3-D tensor indexed as (batch, 1 + side * side, channels).

    Returns:
        Tensor of shape (batch, channels, side, side).

    Raises:
        ValueError: If the number of remaining tokens is not a perfect square.
    """
    patch_tokens = tensor[:, 1:, :]
    batch, num_patches, channels = patch_tokens.shape

    side = round(num_patches ** 0.5)
    if side * side != num_patches:
        raise ValueError(
            f"Cannot arrange {num_patches} patch tokens on a square grid."
        )

    # (batch, H, W, C) -> (batch, C, H, W), the layout CAM expects.
    grid = patch_tokens.reshape(batch, side, side, channels)
    return grid.permute(0, 3, 1, 2)

# ===================================================================
# 3. Main Generation Loop
# ===================================================================

def _prepare_inputs(image_path):
    """Load *image_path* and return (raw RGB image, model input, overlay array)."""
    # Image.open keeps the file handle open; the context manager releases it
    # once the fully loaded RGB copy has been made.
    with Image.open(image_path) as source:
        raw_img = source.convert('RGB')

    # Resize and convert once, then share the result between the normalised
    # model input and the un-normalised visualization image.
    resized = transforms.Resize((224, 224))(raw_img)
    pixels = transforms.ToTensor()(resized)

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_tensor = normalize(pixels).unsqueeze(0).to(config.DEVICE)

    # Channels-last float image, the form passed to show_cam_on_image.
    viz_img = pixels.numpy().transpose(1, 2, 0)
    return raw_img, input_tensor, viz_img


def _plot_comparison(panels, save_path):
    """Draw the (image, title) panels in one row, save to *save_path*, show."""
    plt.rcParams['font.family'] = 'sans-serif'
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image)
        # Negative y places the caption below the panel.
        axis.set_title(title, fontsize=14, y=-0.15)
        axis.axis('off')

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight', pil_kwargs={"compression": "tiff_lzw"})
    print(f"\n✅ Figure 3 saved to {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def generate_figure_3():
    """Generate and save the 3-panel input / ConvNeXt / ViT comparison figure.

    Reads the image at ``config.TEST_IMAGE_PATH``, computes a Grad-CAM map for
    ConvNeXt and a Score-CAM map for ViT (both targeting class index 1), and
    writes ``Figure_3_Comparison.tif`` to the current working directory.

    Returns:
        None. If the test image does not exist, a message is printed and the
        function returns without producing a figure.
    """
    if not os.path.exists(config.TEST_IMAGE_PATH):
        print("❌ Image not found. Please check configuration.")
        return

    raw_img, input_tensor, viz_img = _prepare_inputs(config.TEST_IMAGE_PATH)

    targets = [ClassifierOutputTarget(1)]  # Target: Pneumonia

    # 1. Generate ConvNeXt Map (Grad-CAM)
    print("Generating ConvNeXt Map...")
    model_c, layer_c = get_model("ConvNeXt", config.WEIGHTS_DIR, config.DEVICE)
    with GradCAM(model=model_c, target_layers=layer_c) as cam:
        map_c = cam(input_tensor=input_tensor, targets=targets)[0, :]
    vis_c = show_cam_on_image(viz_img, map_c, use_rgb=True)

    # 2. Generate ViT Map (Score-CAM)
    print("Generating ViT Map (Score-CAM)...")
    model_v, layer_v = get_model("ViT", config.WEIGHTS_DIR, config.DEVICE)
    with ScoreCAM(model=model_v, target_layers=layer_v, reshape_transform=reshape_transform_vit) as cam:
        # Batch size for memory safety
        cam.batch_size = 16
        map_v = cam(input_tensor=input_tensor, targets=targets)[0, :]
    vis_v = show_cam_on_image(viz_img, map_v, use_rgb=True)

    # 3. Plot Comparison
    panels = [
        (raw_img, "(A) Input CXR"),
        (vis_c, "(B) ConvNeXt (Focal Attention)"),
        (vis_v, "(C) ViT (Diffuse Attention)"),
    ]
    _plot_comparison(panels, "Figure_3_Comparison.tif")

# Build the figure only when this file is executed as a script, not on import.
if __name__ == "__main__":
    generate_figure_3()