In [None]:
#DAMLET C2 OOD EVALUATION CONFIG WITH OPTUNA BEST PARAMETERS

# ==============================================================================
# 1. IMPORTS & SETUP
# ==============================================================================
import os
import random
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError
from multiprocessing import freeze_support
from typing import Union, Dict, List, Tuple
from scipy.spatial.distance import cdist
import shutil
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for saving figures
import matplotlib.pyplot as plt
import seaborn as sns

from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset, default_collate
from transformers import ViTForImageClassification
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix

try:
    from torch.cuda.amp import autocast
except ImportError:
    from torch.cpu.amp import autocast

# ==============================================================================
# 2. CONFIGURATION - USING OPTUNA'S BEST PARAMETERS (TRIAL 21)
# ==============================================================================
class Config:
    """A centralized configuration class for the C_2 as OOD experiment using best parameters.

    Every setting is a class attribute, so all ``Config()`` instances share the
    same values (including the list objects). C_0, C_1, C_3 and C_4 act as the
    expert/source domains; C_2 is the held-out out-of-distribution test domain.
    """
    # --- Path Configuration ---
    # Experts: C0, C1, C3, C4
    # One checkpoint per expert. Order must match EXPERT_NAMES: load_models
    # labels checkpoint i with EXPERT_NAMES[i].
    MODEL_CHECKPOINTS: List[str] = [
        r"d:/new_patches/new_patches/C_0/model/checkpoint_epoch9_valloss0.0343.pth",
        r"d:/new_patches/new_patches/C_1/model/checkpoint_epoch10_valloss0.0404.pth", 
        r"d:/new_patches/new_patches/C_3/model/checkpoint_epoch9_valloss0.0330_20250429_174446.pth",
        r"D:/new_patches/new_patches/C_4/model/checkpoint_epoch7_valloss0.0344_20250429_133142.pth" 
    ]
    
    # --- Experiment Configuration ---
    # Folder scanned recursively for image files; only filenames that also
    # appear in the feature-space spreadsheet's 'Filename' column are evaluated.
    TEST_SET_PATH: str = r"D:\new_patches\new_patches\C_2"
    
    # --- OPTUNA BEST PARAMETERS (TRIAL 21 - Accuracy: 93.58%) ---
    # Number of K-Means clusters over the test set; the sample nearest each
    # centroid becomes an "anchor".
    NUM_REPRESENTATIVE_SAMPLES: int = 29
    # Number of nearest source-domain samples (in the 2-D embedding) gathered
    # per anchor to build the meta-model training set.
    NUM_CLOSEST_SAMPLES_PER_REP: int = 2300
    
    # XGBoost stacking meta-model hyperparameters.
    XGB_N_ESTIMATORS: int = 75
    XGB_MAX_DEPTH: int = 2
    XGB_LEARNING_RATE: float = 0.28819589755209746

    # MLP stacking meta-model: hidden layer widths, training epochs, Adam LR.
    MLP_HIDDEN_LAYERS: List[int] = [20, 4] # MLP_LAYER_1_SIZE: 20, MLP_LAYER_2_SIZE: 4
    MLP_EPOCHS: int = 13
    MLP_LR: float = 0.0012246613206706332

    # Attention gating network: training epochs and Adam LR.
    ATTENTION_EPOCHS: int = 23
    ATTENTION_LR: float = 0.003003575085410412
    # --- END OPTUNA PARAMETERS ---
    
    # --- Data Source Configuration ---
    # Spreadsheet with per-image 2-D coordinates (t-SNE or UMAP columns, found
    # case-insensitively) plus the 'Path', 'Filename', 'Dataset' and
    # 'Inner_Label' columns the pipeline reads.
    FEATURE_SPACE_EXCEL_PATH: str = r'D:\.kaggle\outputs_newRR5_vit_tsne_full09212025\all_data_with_tsne.xlsx' 
    # 'Dataset' values eligible to serve as nearest-neighbour training samples.
    SOURCE_DATASET_NAMES: List[str] = ["C_0", "C_1", "C_3", "C_4"]
    
    # --- Visualization Configuration ---
    # Set to True to write the feature-space plots for the final report.
    CREATE_VISUALIZATIONS: bool = True
    # Upper bound on background (non-neighbour) source points drawn per anchor plot.
    VIZ_OTHER_SAMPLES_COUNT: int = 5000 
    
    # --- Output & Temporary Directories ---
    # Output folder for the final report (summary image, plots, per-image Excel).
    OUTPUT_DIR_BASE: str = r"D:/.kaggle/C2_OOD_Evaluation_FINAL_REPORT" 
    
    # --- Model & Evaluation Parameters ---
    BATCH_SIZE: int = 128 
    # 0 = load images in the main process (no DataLoader worker processes).
    NUM_WORKERS: int = 0
    # Binary task: labels containing 'L_1' map to 1, everything else to 0.
    NUM_LABELS: int = 2
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    RANDOM_STATE: int = 42

    # --- Expert Model Names for Logic ---
    # Same order as MODEL_CHECKPOINTS; also used to index proximity weights.
    EXPERT_NAMES: List[str] = ["C_0", "C_1", "C_3", "C_4"]
    TARGET_NAME: str = "C_2"

# ==============================================================================
# 3. UTILITY FUNCTIONS & META-MODEL CLASSES
# ==============================================================================
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN.

    ``PYTHONHASHSEED`` is exported as well; the running interpreter's hash seed
    is fixed at startup, so this only influences child processes started later.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class AttentionGatingNetwork(nn.Module):
    """Map flattened expert probabilities to one softmax weight per expert.

    Input width is ``num_models * num_classes``; output has ``num_models``
    non-negative columns that sum to 1 along dim 1.
    """

    def __init__(self, num_models: int, num_classes: int):
        super().__init__()
        in_features = num_models * num_classes
        self.layer = nn.Linear(in_features, num_models)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.layer(x)
        return self.softmax(scores)

class MLPMetaModel(nn.Module):
    """Feed-forward stacking head: ``Linear -> ReLU`` per hidden width, then a final ``Linear``.

    Returns raw logits of width ``output_size`` (no output activation). With an
    empty ``hidden_layers`` it is a single linear layer.
    """

    def __init__(self, input_size: int, hidden_layers: List[int], output_size: int):
        super().__init__()
        widths = [input_size, *hidden_layers]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stages.append(nn.Linear(widths[-1], output_size))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

class ImageListDataset(Dataset):
    """Map-style dataset over a DataFrame of image paths and string labels.

    Each row must provide ``'Path'`` (file to load via ``robust_pil_loader``)
    and ``'Inner_Label'`` (stringified; 1 if it contains ``'L_1'``, else 0).
    Unreadable images come back as ``None`` so a collate function can drop them.
    """

    def __init__(self, dataframe: pd.DataFrame, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        picture = robust_pil_loader(record['Path'])
        is_positive = 'L_1' in str(record['Inner_Label'])
        target = torch.tensor(int(is_positive), dtype=torch.long)
        # Only transform successfully loaded images; None passes through untouched.
        if picture and self.transform:
            picture = self.transform(picture)
        return picture, target

def print_section_header(title: str):
    """Print *title*, upper-cased and centred, inside an 80-column banner of '=' rules."""
    rule = "=" * 80
    print("\n".join(["", rule, f"| {title.upper():^76} |", rule]))

def robust_pil_loader(path: str) -> Union[Image.Image, None]:
    """Open *path* as an RGB PIL image, or return ``None`` if it cannot be read.

    Missing files, other OS-level I/O errors and unrecognised image formats are
    reported as ``None`` instead of raising, so callers can skip the sample.
    """
    try:
        with open(path, "rb") as stream:
            # convert() forces a full decode while the file is still open.
            rgb_image = Image.open(stream).convert("RGB")
    except (UnidentifiedImageError, OSError, FileNotFoundError):
        rgb_image = None
    return rgb_image

def get_transform():
    """Return the evaluation preprocessing: resize to 224x224, to tensor, normalise.

    NOTE(review): these are the standard ImageNet mean/std; confirm they match
    the normalisation the expert checkpoints were fine-tuned with.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def collate_fn_skip_corrupted(batch: list) -> Tuple:
    """Collate *batch* after discarding samples whose image is ``None``.

    Returns ``(None, None)`` when nothing survives, which consumers treat as
    "skip this batch"; otherwise the result of ``default_collate``.
    """
    survivors = [sample for sample in batch if sample[0] is not None]
    return default_collate(survivors) if survivors else (None, None)

# **COORDINATE DETECTION (Crucial for stability)**
def get_coordinate_columns(df: pd.DataFrame) -> Tuple[List[str], str, bool]:
    """Locate the 2-D embedding columns of *df*, matching names case-insensitively.

    The t-SNE pair (TSNE_X/TSNE_Y) takes precedence over the UMAP pair
    (UMAP_X/UMAP_Y). If several columns differ only by case, the first one in
    ``df.columns`` is used.

    Returns:
        ``([x_col, y_col], method_label, found)``. If neither pair is present a
        message is printed and ``(['tSNE_X', 'tSNE_Y'], "t-SNE (Missing)", False)``
        is returned so callers can bail out on ``found``.
    """
    by_upper = {}
    for column in df.columns:
        by_upper.setdefault(column.upper(), column)

    candidates = (("TSNE_X", "TSNE_Y", "t-SNE"), ("UMAP_X", "UMAP_Y", "UMAP"))
    for x_key, y_key, method in candidates:
        if x_key in by_upper and y_key in by_upper:
            return [by_upper[x_key], by_upper[y_key]], method, True

    print("❌ FATAL: Could not find UMAP or t-SNE columns in the Excel file.")
    return ['tSNE_X', 'tSNE_Y'], "t-SNE (Missing)", False


# ==============================================================================
# 4. CORE COMPUTATIONAL, VISUALIZATION & REPORTING FUNCTIONS
# ==============================================================================
def find_representative_samples(test_set_df: pd.DataFrame, config: Config) -> Tuple[List[str], pd.DataFrame]:
    """Pick archetypal test samples by K-Means clustering their 2-D embedding coordinates.

    The sample nearest each cluster centre becomes an anchor. The number of
    clusters is ``config.NUM_REPRESENTATIVE_SAMPLES``, reduced to the number of
    available samples when the test set is smaller than that.

    Returns:
        ``(representative_paths, representative_df)``: the unique 'Path' values
        of the anchors in cluster order, and the rows of ``test_set_df`` holding
        those paths (in ``test_set_df`` order). Both are empty when no coordinate
        columns are found or the test set is empty.
    """
    coord_cols, coord_name, found = get_coordinate_columns(test_set_df)
    if not found:
        return [], pd.DataFrame()

    if test_set_df.empty:
        print("  > Test set is empty; no representative samples selected.")
        return [], test_set_df.iloc[0:0]

    requested = config.NUM_REPRESENTATIVE_SAMPLES
    num_representatives = min(requested, len(test_set_df))
    if num_representatives < requested:
        # Previously this case silently produced zero anchors.
        print(f"  ⚠️ Only {len(test_set_df)} samples available; using {num_representatives} clusters instead of {requested}.")
    print(f"  > Clustering {len(test_set_df)} total samples into {num_representatives} groups based on {coord_name} coordinates alone.")

    coords = test_set_df[coord_cols].values
    kmeans = KMeans(n_clusters=num_representatives, random_state=config.RANDOM_STATE, n_init=10).fit(coords)

    representative_paths: List[str] = []
    for centroid in kmeans.cluster_centers_:
        distances = cdist(centroid.reshape(1, -1), coords, 'euclidean').flatten()
        path = test_set_df.iloc[int(np.argmin(distances))]['Path']
        # Two centroids can share the same nearest sample; keep each anchor once
        # so neighbourhoods and plots are not duplicated.
        if path not in representative_paths:
            representative_paths.append(path)

    representative_df = test_set_df[test_set_df['Path'].isin(representative_paths)]
    return representative_paths, representative_df

def find_closest_samples(target_sample_path: str, feature_space_df: pd.DataFrame, config: Config) -> Union[pd.DataFrame, None]:
    """Return the ``config.NUM_CLOSEST_SAMPLES_PER_REP`` source rows nearest one target sample.

    Distances are Euclidean in the 2-D embedding columns. Source rows are those
    whose 'Dataset' is in ``config.SOURCE_DATASET_NAMES``. The result is a copy
    of those rows with an added 'Distance_to_Target' column, sorted ascending by
    it, and it keeps ``feature_space_df``'s index labels.

    Returns ``None`` when the coordinate columns are missing, the target cannot
    be located, or there are no source rows.
    """
    coord_cols, _, found = get_coordinate_columns(feature_space_df)
    if not found:
        return None

    # Prefer the exact 'Path': filenames alone can repeat across datasets.
    target_row = feature_space_df[feature_space_df['Path'] == target_sample_path]
    if target_row.empty:
        target_row = feature_space_df[feature_space_df['Filename'] == Path(target_sample_path).name]
    if target_row.empty:
        return None
    # Use a single query point. Several matching rows would otherwise produce
    # len(matches) * len(sources) distances and break the column assignment below.
    target_vector = target_row[coord_cols].values[:1]

    source_df = feature_space_df[feature_space_df['Dataset'].isin(config.SOURCE_DATASET_NAMES)].copy()
    if source_df.empty:
        return None

    distances = cdist(target_vector, source_df[coord_cols].values, 'euclidean').flatten()
    source_df['Distance_to_Target'] = distances
    return source_df.nsmallest(config.NUM_CLOSEST_SAMPLES_PER_REP, 'Distance_to_Target')

def load_models(config: Config) -> List[torch.nn.Module]:
    """Build one ViT classifier per expert checkpoint, load its weights and put it in eval mode.

    Each model starts from ``google/vit-base-patch16-224-in21k`` with a
    ``config.NUM_LABELS``-way head. It then loads either the checkpoint's
    ``'model_state_dict'`` entry or the checkpoint itself as the state dict.

    Raises:
        SystemExit: exit status 1 if MODEL_CHECKPOINTS and EXPERT_NAMES differ
            in length or any checkpoint fails to load. Evaluation with a partial
            expert set would be meaningless, and a bare ``exit()`` would report
            success (status 0).
    """
    print_section_header("Loading Base Models")
    if len(config.MODEL_CHECKPOINTS) != len(config.EXPERT_NAMES):
        print(f"❌ FATAL: {len(config.MODEL_CHECKPOINTS)} checkpoints configured but {len(config.EXPERT_NAMES)} expert names. Exiting.")
        raise SystemExit(1)

    models = []
    for name, path in zip(config.EXPERT_NAMES, config.MODEL_CHECKPOINTS):
        try:
            model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=config.NUM_LABELS, ignore_mismatched_sizes=True)
            checkpoint = torch.load(path, map_location=config.DEVICE, weights_only=True)
            model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
            model.to(config.DEVICE).eval()
        except Exception as e:
            print(f"❌ FATAL: Could not load model from {path}: {e}. Exiting.")
            raise SystemExit(1) from e
        models.append(model)
        print(f"   > Model {name} loaded successfully.")
    return models

def get_predictions(models: List[torch.nn.Module], loader: DataLoader, config: Config, desc: str, leave_progress=True) -> Tuple[np.ndarray, np.ndarray]:
    """Run every model over ``loader`` and collect softmax probabilities and labels.

    Batches whose images are ``None`` (every sample in the batch failed to
    load, per ``collate_fn_skip_corrupted``) are skipped. Mixed precision is
    enabled only when ``config.DEVICE`` is CUDA.

    Returns:
        ``(all_probs, true_labels)``. ``all_probs`` has shape
        ``(num_models, num_samples, num_classes)`` and ``true_labels`` shape
        ``(num_samples,)``, both in loader order. When no sample is processed,
        ``all_probs`` is an empty ``(num_models, 0, config.NUM_LABELS)`` array,
        so callers can still treat it as 3-D.
    """
    num_models = len(models)
    prob_batches: List[List[np.ndarray]] = [[] for _ in range(num_models)]
    label_batches: List[np.ndarray] = []
    use_amp = config.DEVICE.type == 'cuda'

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=leave_progress):
            if images is None:
                continue
            images = images.to(config.DEVICE)
            label_batches.append(labels.cpu().numpy())
            with torch.amp.autocast('cuda', enabled=use_amp):
                for i, model in enumerate(models):
                    logits = model(images).logits
                    prob_batches[i].append(torch.softmax(logits, dim=1).cpu().numpy())

    true_labels = np.concatenate(label_batches) if label_batches else np.empty(0, dtype=np.int64)
    if not label_batches or num_models == 0:
        return np.empty((num_models, len(true_labels), config.NUM_LABELS), dtype=np.float32), true_labels
    # Concatenate whole batches instead of growing Python lists row by row.
    all_probs = np.stack([np.concatenate(batches, axis=0) for batches in prob_batches])
    return all_probs, true_labels

def calculate_proximity_weights(dataset_counts: pd.Series, config: Config) -> np.ndarray:
    """Weight each expert by the fraction of neighbour samples drawn from its dataset.

    Args:
        dataset_counts: Number of neighbour samples per dataset name, e.g. a
            ``value_counts()`` of the 'Dataset' column.
        config: Supplies ``EXPERT_NAMES``; the weights follow that order, and
            experts absent from ``dataset_counts`` get 0.

    Returns:
        Float array with one weight per expert, summing to 1. If
        ``dataset_counts`` sums to 0 it returns all zeros. Otherwise 0/0 would
        yield NaN weights, and NaN still passes a truthiness check.
    """
    total_samples = dataset_counts.sum()
    if total_samples <= 0:
        return np.zeros(len(config.EXPERT_NAMES))
    return np.array([dataset_counts.get(name, 0) / total_samples for name in config.EXPERT_NAMES])

def train_all_stacking_models(meta_features_np: np.ndarray, meta_train_labels: np.ndarray, config: Config, models: List[torch.nn.Module]) -> Dict[str, any]:
    """Fit the logistic-regression, XGBoost and MLP stacking meta-classifiers.

    Args:
        meta_features_np: One row per training sample. The MLP input width is
            ``len(models) * config.NUM_LABELS``, so rows must have that many columns.
        meta_train_labels: Integer class label per row.
        config: Supplies the seed, XGBoost/MLP hyperparameters, batch size and device.
        models: Base experts; only ``len(models)`` is used.

    Returns:
        ``{"Stacking_LR": ..., "Stacking_XGB": ..., "Stacking_MLP": ...}``.
        The MLP is on ``config.DEVICE`` and in eval mode.
    """
    logistic = LogisticRegression(random_state=config.RANDOM_STATE, max_iter=1000)
    logistic.fit(meta_features_np, meta_train_labels)

    booster = XGBClassifier(
        n_estimators=config.XGB_N_ESTIMATORS,
        max_depth=config.XGB_MAX_DEPTH,
        learning_rate=config.XGB_LEARNING_RATE,
        random_state=config.RANDOM_STATE,
        eval_metric='logloss',
    )
    booster.fit(meta_features_np, meta_train_labels)

    mlp = MLPMetaModel(
        input_size=len(models) * config.NUM_LABELS,
        hidden_layers=config.MLP_HIDDEN_LAYERS,
        output_size=config.NUM_LABELS,
    ).to(config.DEVICE)

    features_t = torch.from_numpy(meta_features_np).float()
    labels_t = torch.from_numpy(meta_train_labels).long()
    batches = DataLoader(TensorDataset(features_t, labels_t), batch_size=config.BATCH_SIZE, shuffle=True)
    optimiser = torch.optim.Adam(mlp.parameters(), lr=config.MLP_LR)
    loss_fn = nn.CrossEntropyLoss()

    mlp.train()
    for _ in tqdm(range(config.MLP_EPOCHS), desc="   => Training MLP", leave=False):
        for features, targets in batches:
            features = features.to(config.DEVICE)
            targets = targets.to(config.DEVICE)
            optimiser.zero_grad()
            loss_fn(mlp(features), targets).backward()
            optimiser.step()
    mlp.eval()

    return {"Stacking_LR": logistic, "Stacking_XGB": booster, "Stacking_MLP": mlp}

def train_attention_network(train_probs: np.ndarray, train_labels: np.ndarray, config: Config, base_models: List[torch.nn.Module]) -> AttentionGatingNetwork:
    """Train a gating network that mixes expert probabilities with per-sample weights.

    Args:
        train_probs: Expert probabilities indexed ``[model, sample, class]``
            (swapped to ``[sample, model, class]`` internally).
        train_labels: Integer class label per sample.
        config: Supplies epochs, learning rate, batch size, NUM_LABELS and device.
        base_models: Experts; only ``len(base_models)`` is used.

    Returns:
        The trained ``AttentionGatingNetwork`` on ``config.DEVICE``, in eval mode.

    NOTE(review): the mixture ``mixed`` is already a probability vector, but it
    is fed to ``nn.CrossEntropyLoss``, which expects logits and applies
    log-softmax again. ``NLLLoss`` on ``log(mixed)`` is the textbook objective.
    ATTENTION_EPOCHS/ATTENTION_LR were tuned against the current loss, so a
    change would need re-validation.
    """
    per_sample_probs = torch.from_numpy(train_probs.swapaxes(0, 1)).to(config.DEVICE)
    num_samples = per_sample_probs.shape[0]
    flat_features = per_sample_probs.reshape(num_samples, -1)
    targets = torch.from_numpy(train_labels).long().to(config.DEVICE)

    loader = DataLoader(
        TensorDataset(flat_features.float(), per_sample_probs.float(), targets),
        batch_size=config.BATCH_SIZE,
        shuffle=True,
    )

    gate = AttentionGatingNetwork(len(base_models), config.NUM_LABELS).to(config.DEVICE)
    optimiser = torch.optim.Adam(gate.parameters(), lr=config.ATTENTION_LR)
    loss_fn = nn.CrossEntropyLoss()

    gate.train()
    for _ in tqdm(range(config.ATTENTION_EPOCHS), desc="   => Training Attention Net", leave=False):
        for features, probs, labels in loader:
            optimiser.zero_grad()
            weights = gate(features)
            mixed = (probs * weights.unsqueeze(-1)).sum(dim=1)
            loss_fn(mixed, labels).backward()
            optimiser.step()

    gate.eval()
    return gate

def apply_ensembles(all_probs: np.ndarray, proximity_weights: np.ndarray, attention_net: AttentionGatingNetwork, config: Config) -> Dict[str, np.ndarray]:
    """Combine per-model probabilities via majority vote, proximity weighting and attention gating.

    Args:
        all_probs: Probabilities indexed ``[model, sample, class]``.
        proximity_weights: One weight per model, in the same order as
            ``all_probs``. The proximity ensemble is skipped when this is
            ``None`` or all zero.
        attention_net: Trained gating network, or ``None`` to skip the
            attention ensemble.
        config: Supplies the device the attention network runs on.

    Returns:
        Mapping from ensemble name to predicted class indices, one per sample.
    """
    num_models, num_samples, num_classes = all_probs.shape
    all_preds = np.argmax(all_probs, axis=2)  # hard label per model: (model, sample)

    # Vectorised vote: votes[s, c] = number of models predicting class c for
    # sample s. argmax picks the lowest class on ties (e.g. a 2-2 split -> 0),
    # exactly like the per-column np.bincount(...).argmax(). Unlike
    # np.apply_along_axis, it also works with zero samples.
    votes = (all_preds[:, :, np.newaxis] == np.arange(num_classes)).sum(axis=0)
    ensembles = {"Majority Vote": np.argmax(votes, axis=1)}

    if proximity_weights is not None and proximity_weights.any():
        # (sample, model, class) weighted over models -> (sample, class) scores.
        soft_scores = np.einsum('j,ijk->ik', proximity_weights, all_probs.transpose(1, 0, 2))
        ensembles["Proximity_Weighted (Soft)"] = np.argmax(soft_scores, axis=1)

    if attention_net is not None:
        test_probs_t = torch.from_numpy(all_probs.swapaxes(0, 1)).to(config.DEVICE)
        meta_features_test = test_probs_t.reshape(num_samples, num_models * num_classes)
        with torch.no_grad():
            attention_weights = attention_net(meta_features_test.float())
        weighted_probs = test_probs_t.cpu().numpy() * attention_weights.cpu().numpy()[:, :, np.newaxis]
        final_probs = np.sum(weighted_probs, axis=1)
        ensembles["Attention_Ensemble (Specialized)"] = np.argmax(final_probs, axis=1)

    return ensembles

def create_detailed_visualizations(feature_space_df: pd.DataFrame, specific_test_df: pd.DataFrame, representative_paths: List[str], representative_df: pd.DataFrame, config: Config):
    """Save 2-D embedding scatter plots under ``<OUTPUT_DIR_BASE>/visualizations``.

    Files written:
      * ``01_global_feature_space.png``: every row of ``feature_space_df``,
        coloured by 'Dataset'.
      * ``03_test_set_with_anchors_kmeans.png``: the target test set coloured
        by 'Inner_Label' (the palette covers only 'L_0'/'L_1'), with anchors
        drawn as gold stars labelled R1..Rk.
      * ``04_<Rk>_anchor_<stem>.png``, one per distinct anchor: its nearest
        source neighbours coloured by dataset, over a random background sample
        of at most ``VIZ_OTHER_SAMPLES_COUNT`` other source points.
    """
    print_section_header("Generating Visualizations")
    viz_dir = Path(config.OUTPUT_DIR_BASE) / "visualizations"
    viz_dir.mkdir(exist_ok=True, parents=True)
    sns.set_style("whitegrid")

    coord_cols, coord_name, _ = get_coordinate_columns(feature_space_df)
    coord_x, coord_y = coord_cols[0], coord_cols[1]

    # One label per distinct anchor path, numbered by its first position in
    # representative_paths. The overview plot and the per-anchor plots share it,
    # so duplicates can no longer receive different R-numbers.
    anchor_labels: Dict[str, str] = {}
    for i, path in enumerate(representative_paths):
        anchor_labels.setdefault(path, f'R{i + 1}')

    # --- Plot 1: Overall Feature Space of All Datasets ---
    plt.figure(figsize=(18, 14))
    sns.scatterplot(x=coord_x, y=coord_y, hue='Dataset', data=feature_space_df, palette='tab10', alpha=0.6, s=15, legend='full')
    plt.title(f'Global {coord_name} Feature Space of All Datasets', fontsize=20, weight='bold')
    plt.xlabel(f'{coord_name} Dimension 1', fontsize=12); plt.ylabel(f'{coord_name} Dimension 2', fontsize=12)
    plt.legend(title='Dataset', loc='best', markerscale=2)
    plt.savefig(viz_dir / "01_global_feature_space.png", dpi=300, bbox_inches='tight')
    plt.close()
    print("  > Saved plot 1: Global Feature Space")

    # --- Plot 3: OOD Test Set with Class Labels and Representative Anchors ---
    plt.figure(figsize=(16, 12))
    sns.scatterplot(x=coord_x, y=coord_y, hue='Inner_Label', data=specific_test_df, palette={'L_0': 'royalblue', 'L_1': 'crimson'}, alpha=0.5, s=20)

    if not representative_df.empty:
        plt.scatter(representative_df[coord_x], representative_df[coord_y], marker='*', s=400, c='gold', edgecolor='black', linewidth=1, label='Representative Anchors')
        for _, row in representative_df.iterrows():
            label = anchor_labels.get(row['Path'])
            if label is None:
                continue
            # Small data-unit offset so the tag does not sit on top of the star.
            plt.text(row[coord_x] + 0.5, row[coord_y] + 0.5, label, fontsize=12, weight='bold', color='black')

    plt.title(f'Distribution of {config.TARGET_NAME} with {len(anchor_labels)} Anchors', fontsize=18, weight='bold')
    plt.xlabel(f'{coord_name} Dimension 1'); plt.ylabel(f'{coord_name} Dimension 2')
    plt.legend(title='Class')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.savefig(viz_dir / "03_test_set_with_anchors_kmeans.png", dpi=300, bbox_inches='tight')
    plt.close()
    print("  > Saved plot 3: OOD Test Set with Anchors (KMeans)")

    # --- Plots 4-N: Individual Anchor Neighborhoods ---
    source_df = feature_space_df[feature_space_df['Dataset'].isin(config.SOURCE_DATASET_NAMES)].copy()

    for path, anchor_name in tqdm(anchor_labels.items(), desc="  > Generating Anchor Neighborhood Plots"):
        closest_df = find_closest_samples(path, feature_space_df, config)
        if closest_df is None:
            continue

        anchor_row = feature_space_df[feature_space_df['Path'] == path]

        plt.figure(figsize=(16, 12))
        # closest_df keeps feature_space_df's index labels, so it can be dropped from source_df directly.
        remaining = source_df.drop(closest_df.index)
        other_source_df = remaining.sample(n=min(config.VIZ_OTHER_SAMPLES_COUNT, len(remaining)), random_state=config.RANDOM_STATE)
        sns.scatterplot(x=coord_x, y=coord_y, data=other_source_df, color='gainsboro', alpha=0.5, label='Other Source Samples', s=15)

        sns.scatterplot(x=coord_x, y=coord_y, hue='Dataset', data=closest_df, palette='viridis', alpha=0.9, s=50, legend='full')

        plt.scatter(anchor_row[coord_x], anchor_row[coord_y], marker='*', s=700, c='red', edgecolor='black', linewidth=1.5, label=f'Anchor {anchor_name}')

        plt.title(f'KNN Neighborhood for Anchor {anchor_name}: {Path(path).name} (N={config.NUM_CLOSEST_SAMPLES_PER_REP})', fontsize=16, weight='bold')
        plt.xlabel(f'{coord_name} Dimension 1'); plt.ylabel(f'{coord_name} Dimension 2')
        plt.legend(title='Dataset')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.savefig(viz_dir / f"04_{anchor_name}_anchor_{Path(path).stem}.png", dpi=300, bbox_inches='tight')
        plt.close()


def compute_full_metrics(true_labels: np.ndarray, preds: np.ndarray) -> Dict[str, any]:
    """Compute accuracy and per-class precision/recall/F1 for binary labels {0, 1}.

    Args:
        true_labels: Ground-truth labels (array-like of ints).
        preds: Predicted labels, same length as ``true_labels``.

    Returns:
        Dict with ``acc``, ``class0_acc`` and ``class1_acc`` (percentages);
        ``precision``, ``recall`` and ``f1`` as ``[class 0, class 1]`` lists;
        ``correct`` (int) and ``total`` (int). A ratio with a zero denominator
        is reported as 0.0, so empty input yields all-zero metrics.
    """
    true_labels = np.asarray(true_labels)
    preds = np.asarray(preds)

    def _ratio(num, den):
        # Undefined metrics (no support / no predictions for a class) -> 0.0.
        return num / den if den > 0 else 0.0

    # Binary confusion counts with class 1 as positive. These are the cells of
    # confusion_matrix(..., labels=[0, 1]), computed directly so the result is
    # always 2x2 and needs no shape guard.
    true0, true1 = true_labels == 0, true_labels == 1
    pred0, pred1 = preds == 0, preds == 1
    tn = int(np.sum(true0 & pred0))
    fp = int(np.sum(true0 & pred1))
    fn = int(np.sum(true1 & pred0))
    tp = int(np.sum(true1 & pred1))

    precision = [_ratio(tn, tn + fn), _ratio(tp, tp + fp)]
    recall = [_ratio(tn, tn + fp), _ratio(tp, tp + fn)]
    f1 = [_ratio(2 * p * r, p + r) for p, r in zip(precision, recall)]

    correct = int(np.sum(preds == true_labels))
    total = len(true_labels)
    return {
        "acc": _ratio(100 * correct, total),
        # Per-class accuracy = share of that class's samples predicted correctly.
        "class0_acc": _ratio(100 * tn, int(true0.sum())),
        "class1_acc": _ratio(100 * tp, int(true1.sum())),
        "precision": precision, "recall": recall,
        "f1": f1, "correct": correct, "total": total,
    }

def generate_final_report(all_predictions: Dict[str, np.ndarray], true_labels: np.ndarray, class_names: List[str], config: Config):
    """Print and save a per-method metrics table, highest accuracy first.

    Each method in *all_predictions* is scored with ``compute_full_metrics``.
    The table is printed to stdout and rendered to
    ``<OUTPUT_DIR_BASE>/final_evaluation_summary.png``. Methods with equal
    accuracy stay in alphabetical order. If *all_predictions* is empty, only the
    header and a notice are printed.
    """
    print_section_header("Master Summary Report")
    if not all_predictions:
        print("  > No predictions to report.")
        return

    summary_data = []
    for name in sorted(all_predictions):
        metrics = compute_full_metrics(true_labels, all_predictions[name])
        summary_data.append({
            "Method": name,
            "Acc (%)": metrics['acc'],  # kept numeric until after sorting
            f"{class_names[0]} Acc (%)": f"{metrics['class0_acc']:.2f}",
            f"{class_names[1]} Acc (%)": f"{metrics['class1_acc']:.2f}",
            "Prec C0": f"{metrics['precision'][0]:.2f}", "Prec C1": f"{metrics['precision'][1]:.2f}",
            "Rec C0": f"{metrics['recall'][0]:.2f}", "Rec C1": f"{metrics['recall'][1]:.2f}",
            "F1 C0": f"{metrics['f1'][0]:.2f}", "F1 C1": f"{metrics['f1'][1]:.2f}",
            "Correct/Total": f"{metrics['correct']}/{metrics['total']}"
        })

    summary_df = pd.DataFrame(summary_data)
    # Stable sort: ties keep the alphabetical order above (pandas' default
    # quicksort is not stable, so tie order would otherwise be arbitrary).
    summary_df = summary_df.sort_values(by='Acc (%)', ascending=False, kind='mergesort').reset_index(drop=True)
    summary_df['Acc (%)'] = summary_df['Acc (%)'].map('{:.2f}'.format)

    print(summary_df.to_string(index=False))

    fig, ax = plt.subplots(figsize=(20, (len(summary_df) * 0.4 + 1.5)))
    ax.axis('off')
    ax.axis('tight')
    table = ax.table(cellText=summary_df.values, colLabels=summary_df.columns, loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1, 1.8)
    plt.title(f"Evaluation Summary for {config.TARGET_NAME} (Optimized - Acc: {summary_df.iloc[0]['Acc (%)']}%)", fontsize=16, y=1.05)
    summary_image_path = Path(config.OUTPUT_DIR_BASE) / "final_evaluation_summary.png"
    plt.savefig(summary_image_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"\n✅ Detailed summary table saved as an image to: {summary_image_path}")


# ==============================================================================
# 5. MAIN EXECUTION SCRIPT
# ==============================================================================
def _rows_with_predictions(test_df: pd.DataFrame, num_predicted: int) -> Union[pd.DataFrame, None]:
    """Return the rows of *test_df* that produced a prediction, in loader order.

    The test DataLoader drops samples whose image ``robust_pil_loader`` cannot
    open, so predictions line up with the loadable rows, not with a prefix of
    ``test_df``. The loadability check is repeated only when the counts differ.
    Returns ``None`` if the counts still disagree (e.g. files changed on disk
    in the meantime).
    """
    if len(test_df) == num_predicted:
        return test_df
    loadable = np.array([robust_pil_loader(p) is not None for p in test_df['Path']], dtype=bool)
    survivors = test_df[loadable]
    return survivors if len(survivors) == num_predicted else None


def main():
    """Run the end-to-end C_2 out-of-distribution evaluation.

    Steps:
      1. Load the expert checkpoints and the coordinate spreadsheet.
      2. Select spreadsheet rows whose filenames appear under ``TEST_SET_PATH``.
      3. Predict that test set once with every expert.
      4. Pick K-Means anchors in the test set and gather each anchor's nearest
         source-domain neighbours.
      5. Train the stacking and attention meta-models on the experts' outputs
         for those neighbours, and derive proximity weights from their dataset mix.
      6. Apply every single-model and ensemble method to the test predictions.
      7. Write the summary table, optional plots and a per-image Excel file.

    Fatal data/configuration problems print a message and exit with status 1.
    A bare ``exit()`` would report success (status 0).
    """
    config = Config()
    output_dir = Path(config.OUTPUT_DIR_BASE)
    output_dir.mkdir(exist_ok=True, parents=True)
    set_seed(config.RANDOM_STATE)

    print_section_header("Pipeline Start: Final Report Generation (Optimized Parameters)")

    # --- 1. Pre-load Static Data (Models and Feature Space) ---
    base_models = load_models(config)
    try:
        feature_space_df = pd.read_excel(config.FEATURE_SPACE_EXCEL_PATH)
        feature_space_df['Path'] = feature_space_df['Path'].astype(str)
    except Exception as e:
        print(f"❌ FATAL ERROR loading coordinate Excel file: {e}")
        raise SystemExit(1) from e
    print(f"✅ Pre-loaded coordinate data from {config.FEATURE_SPACE_EXCEL_PATH}")

    # --- 2. Define the Specific Test Set ---
    # Suffixes are compared case-insensitively so files like 'x.PNG' or 'x.TIF'
    # are not silently skipped.
    image_suffixes = {'.png', '.jpg', '.jpeg', '.tif', '.tiff'}
    test_set_filenames = {
        p.name for p in Path(config.TEST_SET_PATH).rglob('*') if p.suffix.lower() in image_suffixes
    }
    # NOTE(review): rows are matched on 'Filename' alone, so a source-domain
    # image sharing a filename with a C_2 image would also enter the test set.
    # Consider also requiring 'Dataset' == config.TARGET_NAME once that
    # column's values are confirmed.
    specific_test_df = feature_space_df[feature_space_df['Filename'].isin(test_set_filenames)].copy()

    if specific_test_df.empty:
        print(f"❌ FATAL: No matching filenames found between folder '{config.TEST_SET_PATH}' and the Excel file.")
        raise SystemExit(1)

    # --- 3. Run Predictions on Test Set ONCE (Heavy Step) ---
    print_section_header(f"Running Predictions for Test Set ({len(specific_test_df)} images)...")
    transform = get_transform()
    full_test_dataset = ImageListDataset(dataframe=specific_test_df, transform=transform)
    full_test_loader = DataLoader(full_test_dataset, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, collate_fn=collate_fn_skip_corrupted)

    all_test_probs, true_labels = get_predictions(base_models, full_test_loader, config, "   => Final Test Set Preds", leave_progress=True)
    print("✅ Prediction complete.")

    # --- 4. Build Aggregate Neighbor Dataset for Training Meta-Models ---
    print_section_header("Building Aggregate Neighbor Dataset")
    representative_paths, representative_df = find_representative_samples(specific_test_df, config)

    all_neighbor_dfs = []
    for path in representative_paths:
        closest_df = find_closest_samples(path, feature_space_df, config)
        if closest_df is not None:
            all_neighbor_dfs.append(closest_df)

    if not all_neighbor_dfs:
        print("❌ FATAL: Could not find any neighbors for the representative samples. Halting.")
        raise SystemExit(1)

    # Anchor neighbourhoods overlap; keep each source image once (first occurrence wins).
    aggregate_neighbors_df = pd.concat(all_neighbor_dfs).drop_duplicates(subset=['Filename']).reset_index(drop=True)
    dataset_counts = aggregate_neighbors_df['Dataset'].value_counts()
    print(f"  > Created aggregate training set with {len(aggregate_neighbors_df)} unique neighbors.")

    # --- 5. Train specialized meta-models ---
    print_section_header("Training Meta-Models with Optimized Parameters")
    neighbor_dataset = ImageListDataset(dataframe=aggregate_neighbors_df, transform=transform)
    neighbor_loader = DataLoader(neighbor_dataset, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, collate_fn=collate_fn_skip_corrupted)

    meta_train_probs, meta_train_labels = get_predictions(base_models, neighbor_loader, config, "   => Neighbor Preds", leave_progress=True)

    # (model, sample, class) -> (sample, model * class): one flat feature row per neighbour.
    meta_features_np = meta_train_probs.swapaxes(0, 1).reshape(meta_train_probs.shape[1], -1)

    trained_meta_models = train_all_stacking_models(meta_features_np, meta_train_labels, config, base_models)
    attention_network = train_attention_network(meta_train_probs, meta_train_labels, config, base_models)
    proximity_weights = calculate_proximity_weights(dataset_counts, config)

    # --- 6. Apply all ensemble methods on the Test Set ---
    print_section_header("Applying Ensemble Models")
    all_predictions = {f"Model_{name}": np.argmax(all_test_probs[i], axis=1) for i, name in enumerate(config.EXPERT_NAMES)}
    all_predictions.update(apply_ensembles(all_test_probs, proximity_weights, attention_network, config))

    # --- Apply the specialized stacking models ---
    if trained_meta_models:
        meta_features_test = all_test_probs.swapaxes(0, 1).reshape(all_test_probs.shape[1], -1)
        for name, meta_model in trained_meta_models.items():
            if isinstance(meta_model, (LogisticRegression, XGBClassifier)):
                all_predictions[name] = meta_model.predict(meta_features_test)
            elif isinstance(meta_model, nn.Module):
                with torch.no_grad():
                    meta_features_t = torch.from_numpy(meta_features_test).float().to(config.DEVICE)
                    outputs = meta_model(meta_features_t)
                    all_predictions[name] = torch.argmax(outputs, dim=1).cpu().numpy()

    # --- 7. Final Reporting & Saving Results ---
    generate_final_report(all_predictions, true_labels, ['L_0', 'L_1'], config)

    if config.CREATE_VISUALIZATIONS:
        print_section_header("Generating Feature Space Visualizations")
        create_detailed_visualizations(feature_space_df, specific_test_df, representative_paths, representative_df, config)

    # Save per-image results to Excel. Skipped images are removed from the
    # middle of the sequence, not the end, so recover the exact surviving rows.
    processed_paths_df = _rows_with_predictions(specific_test_df, len(true_labels))
    if processed_paths_df is None:
        print("⚠️ Could not align predictions with image paths (loadable image count changed); skipping per-image Excel export.")
    else:
        summary_df = pd.DataFrame({'ImagePath': processed_paths_df['Path'].tolist(), 'TrueLabel': true_labels})
        for method, preds in all_predictions.items():
            summary_df[f'{method}_Pred'] = preds

        summary_path = output_dir / "per_image_results_FINAL_REPORT.xlsx"
        summary_df.to_excel(summary_path, index=False)
        print(f"\n✅ Per-image prediction results saved to: {summary_path}")

    print_section_header("PIPELINE FINISHED - REPORT GENERATED")

if __name__ == '__main__':
    # freeze_support() only has an effect in frozen Windows executables that
    # use multiprocessing; in a normal interpreter run it does nothing.
    freeze_support()
    main()