# In [None]:
import numpy as np
import os
import glob
import cv2
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns
import pandas as pd
import json

# Importaciones necesarias para el filtro Gaussiano
import torch
import torch.nn.functional as F
import torchvision.transforms as T

# --- FUNCIONES ADICIONALES REQUERIDAS ---
def percentile_normalize(tensor, percentile_cap=99.0):
    """Rescale values to [0, 1] using symmetric percentile clipping.

    The ``100 - percentile_cap`` percentile maps to 0 and the ``percentile_cap``
    percentile maps to 1; values outside that range are clamped.

    Args:
        tensor (np.ndarray | torch.Tensor | array-like): Values of any shape.
            Torch tensors may require grad or live on any device; they are
            detached and copied to the CPU. Integer/bool inputs are cast to float32.
        percentile_cap (float): Upper percentile in [50, 100].

    Returns:
        np.ndarray: float32 array with the input's shape. The result is always a
        NumPy array, also for torch input. It is all zeros when the input is empty
        or the percentile range is (near) zero.
    """
    if isinstance(tensor, torch.Tensor):
        # detach/cpu so the final .numpy() cannot fail on grad-tracking or GPU tensors.
        values = tensor.detach().cpu()
    else:
        # ascontiguousarray: torch.from_numpy rejects negatively-strided views.
        values = torch.from_numpy(np.ascontiguousarray(np.asarray(tensor)))
    # torch.quantile only accepts floating-point input.
    values = values.float()

    if values.numel() == 0:
        return torch.zeros_like(values).numpy()

    q_low = (100 - percentile_cap) / 100.0
    q_high = percentile_cap / 100.0
    flat_values = values.flatten()
    try:
        p_low = torch.quantile(flat_values, q_low)
        p_high = torch.quantile(flat_values, q_high)
    except RuntimeError:
        # torch.quantile refuses very large inputs ("input tensor is too large");
        # NumPy's default 'linear' method computes the same percentiles.
        flat_np = flat_values.numpy()
        p_low = torch.tensor(float(np.quantile(flat_np, q_low)))
        p_high = torch.tensor(float(np.quantile(flat_np, q_high)))

    range_val = p_high - p_low
    if range_val < 1e-8:
        tensor_norm = torch.zeros_like(values)
    else:
        tensor_norm = torch.clamp((values - p_low) / range_val, 0, 1)

    return tensor_norm.numpy()


# --- NUEVA FUNCIÓN: APLICAR SUAVIZADO GAUSSIANO A UN SOLO MAPA ---
@torch.no_grad()
def apply_gaussian_smoothing_to_single_map(score_map_tensor, sigma=10.0):
    """Blur a single 2D score map with a Gaussian filter.

    Args:
        score_map_tensor (torch.Tensor | np.ndarray): 2D (H, W) score map.
            NumPy arrays are converted with ``torch.as_tensor``.
        sigma (float): Standard deviation of the Gaussian, in pixels.

    Returns:
        torch.Tensor: The smoothed (H, W) map, on the CPU. The 2D shape is kept
        even when H or W is 1.
    """
    score_map_tensor = torch.as_tensor(score_map_tensor)

    # Kernel spans ~3 sigma on each side; GaussianBlur requires an odd kernel size.
    kernel_size = int(sigma * 6 + 1)
    if kernel_size % 2 == 0:
        kernel_size += 1
    gaussian_blur = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (H, W) -> (1, 1, H, W): add the batch and channel axes GaussianBlur works on.
    batched_map = score_map_tensor.unsqueeze(0).unsqueeze(0).to(device)
    smoothed = gaussian_blur(batched_map)

    # Index the two added axes away; squeeze() would also drop H or W when it equals 1.
    return smoothed[0, 0].cpu()


# --- NUEVA FUNCIÓN: EVALUACIÓN A NIVEL DE PÍXEL (AUROC-PÍXEL) ---
# Esta función ahora solo maneja un mapa individual para su recolección de datos
# La curva ROC global a nivel de píxel se calculará por separado.
def get_pixel_level_data(predicted_score_map, ground_truth_mask_path):
    """Flatten one score map and its ground-truth mask for the global pixel-level ROC.

    The mask is read as grayscale and binarized: any non-zero pixel is anomalous (1).
    The image is skipped (``(None, None)``) in three cases:
    - the mask cannot be read;
    - the mask's shape differs from the score map's;
    - any exception is raised along the way (deliberate best effort).

    Args:
        predicted_score_map (np.ndarray): Anomaly score map for the image
            (already smoothed and normalized, float).
        ground_truth_mask_path (str): Path to the binary ground-truth mask image.

    Returns:
        tuple: ``(y_true_pixels, y_scores_pixels)`` as flattened copies (uint8 labels,
        scores in the map's dtype), or ``(None, None)`` when the image is skipped.
    """
    skipped = (None, None)
    try:
        mask_gray = cv2.imread(ground_truth_mask_path, cv2.IMREAD_GRAYSCALE)
        if mask_gray is None:
            # Missing or unreadable mask file.
            return skipped

        mask_binary = (mask_gray > 0).astype(np.uint8)
        if mask_binary.shape != predicted_score_map.shape:
            # Different resolutions: skip rather than resample.
            return skipped

        return mask_binary.flatten(), predicted_score_map.flatten()
    except Exception:
        # Best effort: any unexpected failure just excludes this image.
        return skipped

# --- CONFIGURACIÓN DE RUTAS ---
# Every path below is derived from these two settings. Moving to another machine or
# another MVTec AD category means editing one line instead of several.
FEATUP_ROOT = '/home/imercatoma/FeatUp'
MVTEC_CATEGORY = 'transistor'

# Score maps are read from <BASE_MAHALANOBIS_MAPS_DIR>/<class>/mahalanobis_score_maps/maha_*.npy.
BASE_MAHALANOBIS_MAPS_DIR = os.path.join(FEATUP_ROOT, f'graficas_evaluacion_{MVTEC_CATEGORY}')
# MVTec AD test images and ground-truth masks (<class>/<image_id>_mask.png) for the category.
BASE_IMAGE_DIR = os.path.join(FEATUP_ROOT, 'datasets', 'mvtec_anomaly_detection', MVTEC_CATEGORY, 'test')
BASE_GT_MASK_DIR = os.path.join(FEATUP_ROOT, 'datasets', 'mvtec_anomaly_detection', MVTEC_CATEGORY, 'ground_truth')
# Output folder for plots, plus a shared folder for ROC data reused by combined plots.
BASE_PLOT_SAVE_ROOT_DIR = os.path.join(BASE_MAHALANOBIS_MAPS_DIR, 'evaluacion_roc')

ROC_DATA_SAVE_DIR = os.path.join(FEATUP_ROOT, 'roc_data_for_combined_plots')

# Created at import time so later savefig / json.dump calls never hit a missing directory.
os.makedirs(BASE_PLOT_SAVE_ROOT_DIR, exist_ok=True)
os.makedirs(ROC_DATA_SAVE_DIR, exist_ok=True)

# --- FUNCIONES EXISTENTES ---
def load_mahalanobis_maps(base_dir):
    """Discover class folders under ``base_dir`` and load their Mahalanobis score maps.

    Every sub-directory of ``base_dir`` is a class, except the output folders
    'evaluacion_roc' and 'roc_data_for_combined_plots'. For each class, the files
    ``<base_dir>/<class>/mahalanobis_score_maps/maha_*.npy`` are loaded in sorted
    file-name order, so results are reproducible across runs and filesystems.

    Args:
        base_dir (str): Root folder containing one sub-folder per class.

    Returns:
        tuple: ``(all_mahalanobis_maps, valid_classes, map_filepaths)``:
        - ``all_mahalanobis_maps[cls]``: list of loaded arrays.
        - ``valid_classes``: sorted list of class names.
        - ``map_filepaths[cls]``: image id of each map, in the same order. The id is
          the file name without the leading ``maha_`` prefix and without anything
          from the first '.' onward.
        Files that fail to load, or that yield an empty id, are left out of both lists.
    """
    excluded_dirs = {'evaluacion_roc', 'roc_data_for_combined_plots'}
    map_prefix = 'maha_'

    all_mahalanobis_maps = {}
    map_filepaths = {}

    print("--- 1. Detectando clases y cargando mapas de Mahalanobis ---")

    classes = sorted(
        item for item in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, item))
    )
    valid_classes = [cls for cls in classes if cls not in excluded_dirs]
    print(f"    Clases detectadas: {valid_classes}")

    for cls in valid_classes:
        class_specific_maps_dir = os.path.join(base_dir, cls, 'mahalanobis_score_maps')
        # glob.escape: the directory part must match literally even if it contains
        # '[', '*' or '?'. Sorting makes the map order deterministic (glob's is not).
        pattern = os.path.join(glob.escape(class_specific_maps_dir), map_prefix + '*.npy')
        map_files = sorted(glob.glob(pattern))

        if not map_files:
            print(f"Advertencia: No se encontraron archivos .npy para la clase '{cls}' en {class_specific_maps_dir}")
            all_mahalanobis_maps[cls] = []
            map_filepaths[cls] = []
            continue

        class_maps = []
        class_file_names = []
        for f_path in map_files:
            # Strip only the leading prefix; str.replace would also remove later occurrences.
            image_id = os.path.basename(f_path)[len(map_prefix):].split('.')[0]
            if not image_id:
                continue
            try:
                map_data = np.load(f_path)
            except Exception as e:
                print(f"Error al cargar {f_path}: {e}")
                continue
            # Maps and ids are appended together so their indices always stay paired.
            class_maps.append(map_data)
            class_file_names.append(image_id)

        all_mahalanobis_maps[cls] = class_maps
        map_filepaths[cls] = class_file_names
        print(f"    Total de mapas cargados para '{cls}': {len(class_maps)}")
    print("--- Mapas cargados exitosamente ---\n")
    return all_mahalanobis_maps, valid_classes, map_filepaths


def find_global_min_max_and_top_percentile_avg(mahalanobis_maps_dict, percentile_for_avg=1.0):
    """Compute global statistics over every pixel of every non-empty map.

    Args:
        mahalanobis_maps_dict (dict[str, list[np.ndarray]]): Raw maps per class.
        percentile_for_avg (float): Size, in percent, of the top tail whose mean is
            returned (e.g. 0.1 means the top 0.1% of pixel values).

    Returns:
        tuple: ``(min_final, max_final, avg_top_percentile)``, where
        ``avg_top_percentile`` is the mean of all values at or above the
        ``100 - percentile_for_avg`` percentile. Returns ``(None, None, None)``
        when there are no pixels at all.
    """
    print(f"--- 2. Calculando mínimos, máximos globales y promedio del top {percentile_for_avg}% ---")

    # Collect whole arrays and concatenate once. Extending a Python list pixel by
    # pixel boxes every value as a separate object: slow, and several times the memory.
    non_empty_maps = []
    for maps_list in mahalanobis_maps_dict.values():
        if not maps_list:
            continue
        for map_array in maps_list:
            if map_array.size > 0:
                non_empty_maps.append(np.ravel(map_array))

    if not non_empty_maps:
        print("Error: No se encontraron mapas o píxeles para calcular min/max/percentil globales.")
        return None, None, None

    all_pixel_values = np.concatenate(non_empty_maps)
    min_final = np.min(all_pixel_values)
    max_final = np.max(all_pixel_values)

    percentile_value = np.percentile(all_pixel_values, 100 - percentile_for_avg)
    top_percentile_values = all_pixel_values[all_pixel_values >= percentile_value]

    if top_percentile_values.size == 0:
        avg_top_percentile = max_final
        print(f"    Advertencia: No se encontraron valores por encima del percentil {100 - percentile_for_avg} para calcular el promedio. Usando max_final como promedio del top {percentile_for_avg}%.")
    else:
        avg_top_percentile = np.mean(top_percentile_values)

    print(f"    Mínimo global (min_final): {min_final}")
    print(f"    Máximo global (max_final): {max_final}")
    print(f"    Umbral del percentil {100 - percentile_for_avg} (para el top {percentile_for_avg}%): {percentile_value:.4f}")
    print(f"    Promedio de valores en el top {percentile_for_avg}%: {avg_top_percentile:.4f}")
    print(f"--- Cálculo de min/max globales y promedio del top {percentile_for_avg}% finalizado ---\n")
    return min_final, max_final, avg_top_percentile

def normalize_maps(mahalanobis_maps_dict, min_val, max_val_for_norm, sigma=10.0):
    """Min-max normalize every score map to [0, 1], then Gaussian-smooth it.

    Each map is rescaled with ``(x - min_val) / (max_val_for_norm - min_val)``,
    clipped to [0, 1], and blurred with ``apply_gaussian_smoothing_to_single_map``.
    When ``max_val_for_norm <= min_val`` there is no usable range. In that case each
    map is replaced by a constant map of 0s or 1s (see ``_degenerate_constant_map``)
    before smoothing.

    Args:
        mahalanobis_maps_dict (dict[str, list[np.ndarray]]): Raw maps per class.
        min_val (float): Value mapped to 0.
        max_val_for_norm (float): Value mapped to 1; larger values are clipped.
        sigma (float): Standard deviation of the Gaussian filter (default 10.0).

    Returns:
        dict[str, list[np.ndarray]]: Same keys and per-class order as the input, with
        float32 maps. Empty maps are returned empty, without smoothing.
    """
    degenerate_range = max_val_for_norm <= min_val
    print(f"--- 3. Normalizando y aplicando filtro Gaussiano (Sigma {sigma}) a mapas de Mahalanobis ---")
    if degenerate_range:
        print("Advertencia: max_val_for_norm es menor o igual a min_val. La normalización resultará en 0 o 1.")

    normalized_mahalanobis_maps = {}
    for cls, maps_list in mahalanobis_maps_dict.items():
        normalized_class_maps = []
        for map_array in maps_list:
            if degenerate_range:
                normalized_map_np = _degenerate_constant_map(map_array, min_val, max_val_for_norm)
            else:
                normalized_map_np = np.clip((map_array - min_val) / (max_val_for_norm - min_val), 0, 1)
            normalized_class_maps.append(_smooth_normalized_map(normalized_map_np, sigma))
        normalized_mahalanobis_maps[cls] = normalized_class_maps

    if degenerate_range:
        print("--- Normalización y suavizado finalizados (caso especial) ---\n")
    else:
        print("--- Normalización y suavizado de mapas finalizados ---\n")
    return normalized_mahalanobis_maps


def _degenerate_constant_map(map_array, min_val, max_val_for_norm):
    """Return the constant float32 map used when ``max_val_for_norm <= min_val``.

    The map is all ones when it is non-empty and either:
    - ``max_val_for_norm < min_val`` and the map's max reaches ``max_val_for_norm``, or
    - ``max_val_for_norm == min_val`` and the map's max exceeds it.
    Otherwise it is all zeros.
    """
    has_values = map_array.size > 0
    if has_values and max_val_for_norm != min_val and map_array.max() >= max_val_for_norm:
        fill_value = 1.0
    elif has_values and max_val_for_norm == min_val and map_array.max() > min_val:
        fill_value = 1.0
    else:
        fill_value = 0.0
    return np.full_like(map_array, fill_value, dtype=np.float32)


def _smooth_normalized_map(normalized_map_np, sigma):
    """Gaussian-smooth one normalized map and return it as a float32 NumPy array."""
    if normalized_map_np.size == 0:
        # Nothing to blur; the blur's padding is not defined for an empty image.
        return np.asarray(normalized_map_np, dtype=np.float32)
    map_tensor = torch.from_numpy(np.ascontiguousarray(normalized_map_np)).float()
    return apply_gaussian_smoothing_to_single_map(map_tensor, sigma=sigma).numpy()


def apply_threshold_and_filter(score_map, threshold, min_area_pixels=500):
    """Binarize a score map and drop connected components smaller than a minimum area.

    Pixels strictly above ``threshold`` are foreground. Components are labeled with
    ``skimage.measure.label`` at its default (full) connectivity. Components with
    fewer than ``min_area_pixels`` pixels are removed.

    Args:
        score_map (np.ndarray): Anomaly score map (any number of dimensions).
        threshold (float): Foreground threshold (strict '>').
        min_area_pixels (int): Minimum component area to keep; values <= 1 keep all.

    Returns:
        np.ndarray: uint8 mask with the shape of ``score_map``: 255 on kept pixels, 0 elsewhere.
    """
    binary_mask = (score_map > threshold).astype(np.uint8) * 255
    if not binary_mask.any():
        return np.zeros_like(binary_mask)

    # Every component has area >= 1, so nothing can be removed: skip the labeling.
    if min_area_pixels <= 1:
        return binary_mask

    labeled_mask = label(binary_mask)
    # Pixel count per label in one pass (label 0 is the background), instead of a
    # Python loop over regionprops.
    component_areas = np.bincount(labeled_mask.ravel())
    keep_label = component_areas >= min_area_pixels
    keep_label[0] = False
    return np.where(keep_label[labeled_mask], 255, 0).astype(np.uint8)

def classify_image_anomaly(predicted_mask):
    """Return True when the predicted mask has a positive total, i.e. any flagged pixel."""
    mask_total = np.sum(predicted_mask)
    return mask_total > 0

def plot_roc_curve(fpr, tpr, roc_auc, optimal_thresholds_for_plotting, save_path, thresholds_roc_values, curve_type="Image-level", category_name=""):
    """Plot a ROC curve, mark selected thresholds on it, and save the figure.

    Args:
        fpr (array-like): False-positive rates of the curve points.
        tpr (array-like): True-positive rates of the curve points.
        roc_auc (float): Area under the curve, shown in the legend with 2 decimals.
        optimal_thresholds_for_plotting (list[float] | None): Thresholds to mark
            with a red dot. Each one is drawn at the curve point whose entry in
            ``thresholds_roc_values`` is closest to it.
        save_path (str): Output image path.
        thresholds_roc_values (np.ndarray): Threshold of each curve point,
            aligned with ``fpr`` and ``tpr``.
        curve_type (str): Text shown in the title and the log message.
        category_name (str): Optional extra title text; omitted when empty.
    """
    plot_title = f'Curva ROC de Detección de Anomalías ({curve_type})'
    if category_name:
        plot_title = plot_title + f' ({category_name})'

    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'Curva ROC (AUC = {roc_auc:.2f})')
    # Chance-level diagonal.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos (FPR)')
    plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
    plt.title(plot_title)
    plt.legend(loc="lower right")

    marked_thresholds = [] if optimal_thresholds_for_plotting is None else optimal_thresholds_for_plotting
    for marked_threshold in marked_thresholds:
        nearest = np.argmin(np.abs(thresholds_roc_values - marked_threshold))
        point = (fpr[nearest], tpr[nearest])
        plt.plot(point[0], point[1], 'o', color='red', markersize=8)
        plt.annotate(f'{marked_threshold:.2f}', point, textcoords="offset points", xytext=(5, -10), ha='center', color='red')

    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"✅ Curva ROC ({curve_type}) guardada en: {save_path}")

def visualize_overlay(image_path, score_map, threshold, min_area_pixels, save_path):
    """Overlay the thresholded anomaly mask in red on the original image and save it.

    Errors are reported on the console and never raised (best effort).

    Args:
        image_path (str): Path of the original image (read with OpenCV).
        score_map (np.ndarray): Anomaly score map for the image. If its size differs
            from the image, the mask is resized with nearest-neighbour interpolation.
        threshold (float): Threshold passed to ``apply_threshold_and_filter``.
        min_area_pixels (int): Minimum connected-component area passed along too.
        save_path (str): Output image path.
    """
    fig = None
    try:
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"Error: No se pudo cargar la imagen original desde {image_path}")
            return
        original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

        filtered_mask = apply_threshold_and_filter(score_map, threshold, min_area_pixels)
        img_h, img_w = original_image_rgb.shape[:2]
        if filtered_mask.shape[:2] != (img_h, img_w):
            # Map computed at another resolution: nearest neighbour keeps the mask binary.
            filtered_mask = cv2.resize(filtered_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

        overlay = np.zeros_like(original_image_rgb, dtype=np.uint8)
        overlay[filtered_mask > 0] = np.array([255, 0, 0], dtype=np.uint8)
        alpha = 0.4
        overlaid_image = cv2.addWeighted(original_image_rgb, 1 - alpha, overlay, alpha, 0)

        fig = plt.figure(figsize=(10, 10))
        plt.imshow(overlaid_image)
        plt.title(f'Anomalía Detectada (Umbral: {threshold:.4f})\n{os.path.basename(image_path)}')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path)
    except Exception as e:
        print(f"Error al visualizar la superposición para {image_path}: {e}")
    finally:
        # Close even when savefig (or anything after figure creation) raised,
        # so figures do not accumulate across many images.
        if fig is not None:
            plt.close(fig)

def plot_confusion_matrix(y_true, y_pred, save_path, threshold, display_labels_true, display_labels_pred, title_suffix=""):
    """Plot a confusion matrix as an annotated heatmap and save it.

    Args:
        y_true (array-like): Ground-truth labels.
        y_pred (array-like): Predicted labels.
        save_path (str): Output image path.
        threshold (float): Decision threshold, shown in the title.
        display_labels_true (list[str]): Row tick labels (one per class).
        display_labels_pred (list[str]): Column tick labels (one per class).
        title_suffix (str): Text appended to the title and the log message.
    """
    # Suppose the display labels describe class ids 0..k-1 and the data only uses
    # those ids. Then pin labels=range(k): if a class never appears (e.g. all-normal
    # ground truth and all-normal predictions), scikit-learn's default would shrink
    # the matrix below the number of tick labels. In every other case keep the
    # default (sorted union of observed labels).
    labels = None
    if len(display_labels_true) == len(display_labels_pred):
        candidate_labels = list(range(len(display_labels_true)))
        observed = set(np.asarray(y_true).ravel().tolist()) | set(np.asarray(y_pred).ravel().tolist())
        if observed and observed.issubset(candidate_labels):
            labels = candidate_labels

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(10, 9))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=display_labels_pred,
                yticklabels=display_labels_true)
    plt.xlabel('Predicción')
    plt.ylabel('Etiqueta Verdadera')
    plt.title(f'Matriz de Confusión {title_suffix} (Umbral: {threshold:.4f})')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"✅ Matriz de Confusión {title_suffix} guardada en: {save_path}")

def calculate_and_print_metrics(y_true, y_pred, threshold, min_connected_component_area, auc_value):
    """Compute, print and return binary image-level metrics.

    Label 1 is the positive (anomalous) class. Precision, recall and F1 use
    ``zero_division=0``. Specificity is ``TN / (TN + FP)``, or 0 when there are
    no negatives.

    Args:
        y_true (array-like): Ground-truth labels (0 = normal, 1 = anomalous).
        y_pred (array-like): Predicted labels (0 / 1).
        threshold (float): Threshold used for the predictions (reported only).
        min_connected_component_area (int): Minimum component area used (reported only).
        auc_value (float): Image-level AUC (reported only).

    Returns:
        dict: Metric name -> value formatted with 4 decimals (the area as given).
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)

    # labels=[0, 1] forces a 2x2 matrix. Otherwise, when only one label appears in
    # y_true and y_pred, the matrix collapses to 1x1 and specificity becomes NaN.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    f1 = f1_score(y_true, y_pred, zero_division=0)

    print(f"\n--- Métricas de Rendimiento a Nivel de Imagen (Umbral: {threshold:.4f}, MCC Area: {min_connected_component_area}) ---")
    print(f"    Accuracy:        {accuracy:.4f}")
    print(f"    Precision:       {precision:.4f}")
    print(f"    Recall (Sensibilidad): {recall:.4f}")
    print(f"    Especificidad: {specificity:.4f}")
    print(f"    F1-Score:        {f1:.4f}")
    print("--------------------------------------------------------------------")
    return {
        "Umbral": f"{threshold:.4f}",
        "Min_Connected_Component_Area": min_connected_component_area,
        "AUC": f"{auc_value:.4f}",
        "Accuracy": f"{accuracy:.4f}",
        "Precision": f"{precision:.4f}",
        "Recall (Sensibilidad)": f"{recall:.4f}",
        "Especificidad": f"{specificity:.4f}",
        "F1-Score": f"{f1:.4f}"
    }
    
def get_top_n_values_from_maps(mahalanobis_maps_dict, map_file_ids_dict, n=10):
    """Print, per class and per map, the ``n`` largest values in descending order.

    A map is identified by the id at the same position in ``map_file_ids_dict[cls]``.
    When no id exists at that position, ``'Index {i}'`` is used instead. Empty maps
    and classes without maps are reported as such. Output only; returns None.
    """
    print(f"\n--- Top {n} valores más altos de Mahalanobis para cada mapa ---")
    for cls_name in mahalanobis_maps_dict:
        maps_list = mahalanobis_maps_dict[cls_name]
        if not maps_list:
            print(f"    No hay mapas para la clase '{cls_name}'.")
            continue

        file_ids = map_file_ids_dict.get(cls_name, [])
        print(f" Clase: '{cls_name}'")
        for idx, score_map in enumerate(maps_list):
            map_label = file_ids[idx] if idx < len(file_ids) else f'Index {idx}'
            if score_map.size == 0:
                print(f"      Mapa {map_label}: Vacío.")
                continue

            descending = np.sort(score_map.flatten())[::-1]
            formatted = [f'{val:.4f}' for val in descending[:n]]
            print(f"      Mapa {map_label} (Top {n}): {formatted}")
    print("--- Fin de la visualización de los top valores ---")

def print_raw_top_10_mahalanobis_scores_from_loaded(mahalanobis_maps_dict, map_file_ids_dict, n=10):
    """Print the ``n`` largest raw values of every loaded map, in ascending order.

    A map is identified by the id at the same position in ``map_file_ids_dict[cls]``.
    When no id exists at that position, ``'Index {i}'`` is used. Values are printed
    with 3 decimals. ``n <= 0`` prints an empty list. Output only; returns None.
    """
    print(f"\n--- Top {n} valores de Mahalanobis (RAW - Sin normalizar) ---")

    for cls_name, maps_list in mahalanobis_maps_dict.items():
        file_ids = map_file_ids_dict.get(cls_name, [])
        if not maps_list:
            print(f"    Clase: '{cls_name}' - No hay mapas cargados.")
            continue

        print(f"Clase: '{cls_name}'")
        for i, score_map in enumerate(maps_list):
            image_id = file_ids[i] if i < len(file_ids) else f'Index {i}'

            if score_map.size == 0:
                print(f"    Imagen: {image_id} - Mapa vacío.")
                continue

            sorted_values = np.sort(score_map.ravel())
            # Guard n <= 0: the slice [-0:] would select the entire map.
            top_n_values = sorted_values[-n:] if n > 0 else sorted_values[:0]

            print(f"    Imagen: {image_id}")
            print(f"    Top {n} valores: {[f'{val:.3f}' for val in top_n_values]}")
    print("--- Fin de la visualización de los top valores raw ---")

def plot_multi_class_binary_prediction_confusion_matrix(y_true_class_names, y_pred_binary_labels, save_path, threshold, title_suffix=""):
    """Plot a (true class) x (binary prediction) confusion matrix and save it.

    Rows are the distinct true class names, sorted, with 'good' moved first when
    present. Columns are the binary predictions 'Normal' and 'Anómalo'. A pair whose
    prediction is not one of those two labels is reported on the console and not counted.

    Args:
        y_true_class_names (list[str]): True class name per image (e.g. 'good', 'crack').
        y_pred_binary_labels (list[str]): Predicted label per image, 'Normal' or 'Anómalo'.
        save_path (str): Output image path.
        threshold (float): Threshold behind the predictions, shown in the title.
        title_suffix (str): Used only in the console confirmation message; the plot
            title does not include it.
    """
    if len(y_true_class_names) != len(y_pred_binary_labels):
        print("Error: Las listas de etiquetas verdaderas y predichas tienen longitudes diferentes.")
        return

    row_labels = sorted(set(y_true_class_names))
    if 'good' in row_labels:
        row_labels.remove('good')
        row_labels.insert(0, 'good')

    column_labels = ['Normal', 'Anómalo']
    row_index = {name: r for r, name in enumerate(row_labels)}
    column_index = {'Normal': 0, 'Anómalo': 1}

    counts = np.zeros((len(row_labels), len(column_labels)), dtype=int)
    for true_class, predicted_binary_label in zip(y_true_class_names, y_pred_binary_labels):
        r = row_index.get(true_class)
        c = column_index.get(predicted_binary_label)
        if r is None or c is None:
            print(f"Advertencia: Etiqueta inesperada encontrada. Verdadera: '{true_class}', Predicha: '{predicted_binary_label}'")
            continue
        counts[r, c] += 1

    plt.figure(figsize=(5, 5))
    sns.heatmap(counts, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=column_labels,
                yticklabels=row_labels)
    plt.xlabel('Predicción')
    plt.ylabel('Clase Verdadera')
    plt.title(f'Matriz de Confusión (Umbral: {threshold:.4f})', loc='center')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"✅ Matriz de Confusión {title_suffix} guardada en: {save_path}")



# --- EJECUCIÓN DE LOS PASOS ---
if __name__ == "__main__":
    # Category label for titles and output file names: the last path component of
    # BASE_MAHALANOBIS_MAPS_DIR (e.g. 'graficas_evaluacion_transistor').
    current_category = os.path.basename(os.path.normpath(BASE_MAHALANOBIS_MAPS_DIR))
    print(f"\n***** Procesando categoría: {current_category.upper()} *****\n")

    # Step 1: raw maps per class, plus the image id of each map in the same order.
    mahalanobis_maps, MAP_CLASSES, MAP_FILE_IDS = load_mahalanobis_maps(BASE_MAHALANOBIS_MAPS_DIR)

    # Console diagnostics only: highest raw values per map.
    print_raw_top_10_mahalanobis_scores_from_loaded(mahalanobis_maps, MAP_FILE_IDS, n=10)
    get_top_n_values_from_maps(mahalanobis_maps, MAP_FILE_IDS, n=10)

    # Defensive re-filter; load_mahalanobis_maps already excludes the two output folders.
    CLASSES = [cls for cls in MAP_CLASSES if cls not in ['evaluacion_roc', 'roc_data_for_combined_plots', '']]
    CLASSES.sort()
    print(f"    Clases finales para procesamiento: {CLASSES}")

    # Class name <-> integer id lookups.
    # NOTE(review): within this part of the script they are only printed.
    class_to_id = {cls_name: i for i, cls_name in enumerate(CLASSES)}
    id_to_class = {i: cls_name for i, cls_name in enumerate(CLASSES)}
    print(f"\n    Mapeo de clases a IDs: {class_to_id}")

    # Step 2: global statistics over all raw pixels. With percentile_for_avg=0.1 the
    # third value is the mean of the top 0.1% of pixels. It is used below as the upper
    # normalization bound instead of the absolute maximum (presumably to limit the
    # influence of a few extreme pixels).
    # NOTE(review): the messages mention "top 1%" although 0.1% is requested here.
    min_final_val, max_final_val_original, avg_top_percentile_val = find_global_min_max_and_top_percentile_avg(mahalanobis_maps, percentile_for_avg=0.1)

    if min_final_val is None or avg_top_percentile_val is None:
        print("No se pudo proceder con la normalización y evaluación debido a un error en el cálculo de min/max/promedio del top 1%.")
        # NOTE(review): exit() comes from the site module; sys.exit() is the portable choice in scripts.
        exit()

    # Step 3: normalize each map to [0, 1] with [min_final_val, avg_top_percentile_val],
    # then Gaussian-smooth it (done inside normalize_maps).
    normalized_mahalanobis_maps = normalize_maps(mahalanobis_maps, min_final_val, avg_top_percentile_val)

    print(f"\nProceso completado para las clases: {CLASSES}")
    print("\n--- 4. Evaluando a nivel de imagen para la curva ROC y preparando datos para métricas ---")

    # Minimum area (pixels) a thresholded connected component needs in order to count;
    # 0 keeps every component.
    MIN_CONNECTED_COMPONENT_AREA = 0

    # Image-level ROC inputs: label 1 for every non-'good' class; score = max of the normalized map.
    all_true_labels_binary_roc = []
    all_anomaly_scores_for_roc = []

    # Pixel-level ROC inputs, accumulated over all images.
    # NOTE(review): extending Python lists pixel by pixel is memory-heavy for many large
    # maps. Collecting per-image arrays and concatenating would be leaner, if the code
    # that later consumes these lists allows it.
    all_true_pixels = []
    all_predicted_scores_pixels = []

    # Label -> id table for detailed predictions: 'Predicted Normal' is 0, then one id per class.
    predicted_label_to_id_detailed = {}
    predicted_normal_id_detailed = 0
    predicted_label_to_id_detailed['Predicted Normal'] = predicted_normal_id_detailed

    predicted_class_id_counter = 1

    if 'good' in CLASSES:
        predicted_label_to_id_detailed['Predicted Anomaly (from good)'] = predicted_class_id_counter
        predicted_class_id_counter += 1

    anomaly_classes = [cls for cls in CLASSES if cls != 'good']
    anomaly_classes.sort()
    for cls_anomaly in anomaly_classes:
        predicted_label_to_id_detailed[f'Predicted {cls_anomaly.capitalize()} Anomaly'] = predicted_class_id_counter
        predicted_class_id_counter += 1

    print("    Recolectando puntuaciones de anomalía y etiquetas verdaderas (para ROC y CM)...")

    for cls_name in CLASSES:
        maps_list = normalized_mahalanobis_maps.get(cls_name, [])
        file_ids = MAP_FILE_IDS.get(cls_name, [])

        if not maps_list:
            continue

        gt_label_is_anomaly = (cls_name != 'good')

        for i, score_map in tqdm(enumerate(maps_list), total=len(maps_list), desc=f"    Procesando mapas de {cls_name}"):
            image_max_anomaly_score = 0.0
            if score_map.size > 0:
                image_max_anomaly_score = np.max(score_map)

            all_true_labels_binary_roc.append(1 if gt_label_is_anomaly else 0)
            all_anomaly_scores_for_roc.append(image_max_anomaly_score)

            # --- Pixel-level data: mask path is <BASE_GT_MASK_DIR>/<class>/<image_id>_mask.png ---
            image_id = file_ids[i]
            gt_mask_sub_dir = cls_name
            gt_mask_filename = f"{image_id}_mask.png"
            ground_truth_mask_full_path = os.path.join(BASE_GT_MASK_DIR, gt_mask_sub_dir, gt_mask_filename)

            y_true_pixels_img, y_scores_pixels_img = get_pixel_level_data(score_map, ground_truth_mask_full_path)
            if (y_true_pixels_img is None and not gt_label_is_anomaly
                    and not os.path.exists(ground_truth_mask_full_path)):
                # MVTec AD has no ground_truth/good folder: every pixel of a defect-free image
                # is normal. Without this fallback, good images were silently dropped, and the
                # pixel-level ROC lacked all their negative pixels.
                y_true_pixels_img = np.zeros(score_map.size, dtype=np.uint8)
                y_scores_pixels_img = score_map.flatten()
            if y_true_pixels_img is not None and y_scores_pixels_img is not None:
                all_true_pixels.extend(y_true_pixels_img)
                all_predicted_scores_pixels.extend(y_scores_pixels_img)


    print(f"\n[DEBUG] all_true_labels_binary_roc (first 10): {all_true_labels_binary_roc[:10]}")
    print(f"[DEBUG] all_anomaly_scores_for_roc (first 10): {[f'{s:.4f}' for s in all_anomaly_scores_for_roc[:10]]}")
    print(f"[DEBUG] Unique values in all_true_labels_binary_roc: {np.unique(all_true_labels_binary_roc)}")
    print(f"[DEBUG] Unique values in all_anomaly_scores_for_roc: {np.unique(all_anomaly_scores_for_roc)}")

    category_name_for_roc = current_category

    # --- Image-level ROC curve and AUC ---
    # roc_curve needs both classes among the labels and at least two distinct scores.
    # Both degenerate cases fall back to the chance diagonal [0, 1].
    if len(np.unique(all_true_labels_binary_roc)) < 2:
        print("\nAdvertencia: Solo hay una clase en all_true_labels_binary_roc. No se puede calcular la curva ROC a nivel de imagen.")
        roc_auc_image_level = float('nan')
        fpr_image_level, tpr_image_level, thresholds_roc_raw_image_level = np.array([0,1]), np.array([0,1]), np.array([0,1])
    elif len(np.unique(all_anomaly_scores_for_roc)) < 2:
        # A single distinct score already means "all scores equal"; no separate check needed.
        print("\nAdvertencia: all_anomaly_scores_for_roc contiene solo un valor único o muy pocos que impiden ROC a nivel de imagen.")
        roc_auc_image_level = 0.5
        fpr_image_level, tpr_image_level, thresholds_roc_raw_image_level = np.array([0,1]), np.array([0,1]), np.array([0,1])
    else:
        fpr_image_level, tpr_image_level, thresholds_roc_raw_image_level = roc_curve(all_true_labels_binary_roc, all_anomaly_scores_for_roc)
        roc_auc_image_level = auc(fpr_image_level, tpr_image_level)

    print(f"\n--- Cálculo de ROC y AUC (Nivel de Imagen) finalizado ---")
    print(f"Área Bajo la Curva (AUC - Nivel de Imagen): {roc_auc_image_level:.4f}")

    # Persist the curve so combined multi-category plots can be drawn later.
    roc_data_filename_image = f"roc_data_image_level_{category_name_for_roc}.json"
    roc_data_filepath_image = os.path.join(ROC_DATA_SAVE_DIR, roc_data_filename_image)

    roc_data_to_save_image = {
        'category': category_name_for_roc,
        'fpr': fpr_image_level.tolist(),
        'tpr': tpr_image_level.tolist(),
        'roc_auc': roc_auc_image_level
    }

    with open(roc_data_filepath_image, 'w') as f:
        json.dump(roc_data_to_save_image, f)
    print(f"✅ Datos de la curva ROC (Nivel de Imagen) para '{category_name_for_roc}' guardados en: {roc_data_filepath_image}")

    optimal_thresholds_for_plotting_image = []
    optimal_thresholds_for_metrics_image = []
    # Threshold that maximizes Youden's J (TPR - FPR); preferred for the final evaluation.
    youden_threshold_image = None

    if not np.isnan(roc_auc_image_level) and roc_auc_image_level > 0:
        youden_j = tpr_image_level - fpr_image_level
        best_idx = np.argmax(youden_j)
        best_threshold = thresholds_roc_raw_image_level[best_idx] if best_idx < len(thresholds_roc_raw_image_level) else None

        # roc_curve's first threshold is a sentinel above every score (np.inf in recent
        # scikit-learn versions); a non-finite threshold cannot be used for evaluation.
        if best_threshold is not None and np.isfinite(best_threshold):
            youden_threshold_image = best_threshold
            optimal_thresholds_for_metrics_image.append(best_threshold)
            optimal_thresholds_for_plotting_image.append(best_threshold)

        # Add a few representative thresholds (up to 5 candidates in total), evenly
        # spaced over the unique ROC thresholds and limited to (0.001, 0.999).
        unique_thresholds_sorted = sorted(list(np.unique(thresholds_roc_raw_image_level)))
        num_to_add = 5 - len(optimal_thresholds_for_metrics_image)
        if num_to_add > 0:
            step = max(1, len(unique_thresholds_sorted) // (num_to_add + 1))
            for i in range(step, len(unique_thresholds_sorted), step):
                if len(optimal_thresholds_for_metrics_image) >= 5:
                    break
                current_threshold = unique_thresholds_sorted[i]
                if 0.001 < current_threshold < 0.999 and current_threshold not in optimal_thresholds_for_metrics_image:
                    idx = np.where(thresholds_roc_raw_image_level == current_threshold)[0][0]
                    # Skip the trivial (0, 0) and (1, 1) corners of the curve.
                    if (fpr_image_level[idx] > 0 or tpr_image_level[idx] > 0) and (fpr_image_level[idx] < 1 or tpr_image_level[idx] < 1):
                        optimal_thresholds_for_metrics_image.append(current_threshold)
                        optimal_thresholds_for_plotting_image.append(current_threshold)

        optimal_thresholds_for_metrics_image.sort()
        optimal_thresholds_for_plotting_image.sort()


    print(f"\n--- Umbrales 'Óptimos' detectados (Nivel de Imagen) ---")
    if not optimal_thresholds_for_metrics_image:
        print("    No se pudieron encontrar umbrales óptimos únicos en el rango (0,1) para el nivel de imagen.")
    for i, opt_thresh in enumerate(optimal_thresholds_for_metrics_image):
        idx = np.argmin(np.abs(thresholds_roc_raw_image_level - opt_thresh))
        print(f"    Umbral {i+1}: {opt_thresh:.4f} (TPR: {tpr_image_level[idx]:.4f}, FPR: {fpr_image_level[idx]:.4f})")

    roc_save_path_image = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'roc_curve_image_level.png')
    plot_roc_curve(fpr_image_level, tpr_image_level, roc_auc_image_level, optimal_thresholds_for_plotting_image, roc_save_path_image, thresholds_roc_raw_image_level, "Image-level", current_category)

    # The candidate list is sorted ascending, so its first entry is the *lowest*
    # candidate, not the Youden-optimal one. Select the Youden threshold explicitly
    # and fall back to the lowest candidate, then to 0.5.
    if youden_threshold_image is not None:
        selected_threshold_for_eval = youden_threshold_image
        print(f"\n    Umbral seleccionado para visualización y métricas: {selected_threshold_for_eval:.4f} (J de Youden máximo)")
    elif optimal_thresholds_for_metrics_image:
        selected_threshold_for_eval = optimal_thresholds_for_metrics_image[0]
        print(f"\n    Umbral seleccionado para visualización y métricas: {selected_threshold_for_eval:.4f}")
    else:
        print("\nAdvertencia: No se encontraron umbrales óptimos. Usando un umbral por defecto de 0.5 para visualización y métricas.")
        selected_threshold_for_eval = 0.5

    print(f"\n--- Generando predicciones finales con el umbral seleccionado ({selected_threshold_for_eval:.4f}) ---")
    # Binary predictions (1 = anomalous), in the same order as all_true_labels_binary_roc.
    all_predicted_labels_cm_binary = []

    # True class name and 'Normal'/'Anómalo' prediction per image, for the class-vs-binary matrix.
    all_true_class_names_for_cm = []
    all_predicted_binary_labels_for_cm = []

    for cls_name in CLASSES:
        maps_list = normalized_mahalanobis_maps.get(cls_name, [])
        if not maps_list:
            continue

        gt_label_is_anomaly = (cls_name != 'good')

        for score_map in tqdm(maps_list, desc=f"    Aplicando umbral para {cls_name}"):
            # roc_curve treats score >= threshold as positive, but apply_threshold_and_filter
            # keeps only pixels strictly above its threshold. Passing the largest value of the
            # map's float type below the threshold makes '>' act as '>='. An image whose max
            # score equals the selected (ROC-derived) threshold is then flagged, exactly as at
            # that point of the ROC curve.
            float_type = score_map.dtype.type if np.issubdtype(score_map.dtype, np.floating) else np.float64
            map_threshold = np.nextafter(float_type(selected_threshold_for_eval), float_type(-np.inf))
            binary_mask = apply_threshold_and_filter(score_map, map_threshold, MIN_CONNECTED_COMPONENT_AREA)
            is_predicted_anomaly = classify_image_anomaly(binary_mask)

            # For the 2x2 binary confusion matrix (true anomaly vs predicted anomaly).
            all_predicted_labels_cm_binary.append(1 if is_predicted_anomaly else 0)

            # For the class-vs-binary confusion matrix (one row per true class).
            all_true_class_names_for_cm.append(cls_name)
            all_predicted_binary_labels_for_cm.append('Anómalo' if is_predicted_anomaly else 'Normal')

    print("\n--- 5.1: Generando Matriz de Confusión Binaria (Normal vs. Anómala) ---")

    cm_binary_save_path = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, f'confusion_matrix_binary_thresh_{selected_threshold_for_eval:.4f}.png')

    display_labels_true_binary = ['Normal (Good)', 'Anomalous (Any Type)']
    display_labels_pred_binary = ['Predicted Normal', 'Predicted Anomalous']

    if len(all_true_labels_binary_roc) != len(all_predicted_labels_cm_binary):
        print(f"Error: Longitud de etiquetas verdaderas ({len(all_true_labels_binary_roc)}) y predichas ({len(all_predicted_labels_cm_binary)}) para CM binaria no coinciden.")
    else:
        plot_confusion_matrix(all_true_labels_binary_roc, all_predicted_labels_cm_binary, cm_binary_save_path,
                              selected_threshold_for_eval, display_labels_true_binary, display_labels_pred_binary,
                              title_suffix=" - Binary")


    print("\n--- 5.2: Generando Matriz de Confusión Detallada (True Class vs. Predicted Binary) ---")

    cm_multi_binary_save_path = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, f'confusion_matrix_multi_binary_thresh_{selected_threshold_for_eval:.4f}.png')
    
    # Llamada a la función corregida para la matriz 5x2
    plot_multi_class_binary_prediction_confusion_matrix(
        all_true_class_names_for_cm,
        all_predicted_binary_labels_for_cm,
        cm_multi_binary_save_path,
        selected_threshold_for_eval,
        title_suffix=" - Multi-Class vs. Binary Prediction"
    )



    print("\n--- 6. Calculando, mostrando y guardando Tabla de Métricas de Rendimiento (Nivel de Imagen) ---")
    metrics_data = calculate_and_print_metrics(all_true_labels_binary_roc, all_predicted_labels_cm_binary, selected_threshold_for_eval, MIN_CONNECTED_COMPONENT_AREA, roc_auc_image_level)

    # Convertir a JSON en lugar de Excel
    image_level_metrics_save_path = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'image_level_metrics.json')

    # Si ya existe, cargar y añadir. Si no, crear.
    if os.path.exists(image_level_metrics_save_path):
        try:
            with open(image_level_metrics_save_path, 'r') as f:
                existing_data = json.load(f)
            if not isinstance(existing_data, list): # Asegurarse de que sea una lista para añadir
                existing_data = [existing_data] if existing_data else []
            existing_data.append(metrics_data)
            with open(image_level_metrics_save_path, 'w') as f:
                json.dump(existing_data, f, indent=4)
            print(f"✅ Métricas añadidas al archivo JSON existente: {image_level_metrics_save_path}")
        except json.JSONDecodeError as e:
            print(f"⚠️ Error al leer el archivo JSON existente. Creando uno nuevo. Error: {e}")
            with open(image_level_metrics_save_path, 'w') as f:
                json.dump([metrics_data], f, indent=4)
            print(f"✅ Métricas guardadas en un nuevo archivo JSON: {image_level_metrics_save_path}")
    else:
        with open(image_level_metrics_save_path, 'w') as f:
            json.dump([metrics_data], f, indent=4)
        print(f"✅ Métricas guardadas en un nuevo archivo JSON: {image_level_metrics_save_path}")

    # --- NUEVO BLOQUE: EVALUACIÓN GLOBAL A NIVEL DE PÍXEL (AUROC-Píxel y Curva ROC) ---
    print("\n--- 7. Evaluando a nivel de Píxel (AUROC-Píxel Global y Curva ROC) ---")

    if len(np.unique(all_true_pixels)) < 2:
        print("\nAdvertencia: No hay suficientes clases en los píxeles de ground truth para calcular la curva ROC a nivel de píxel.")
        pixel_auroc_global = float('nan')
        fpr_pixel_level, tpr_pixel_level, thresholds_roc_raw_pixel_level = np.array([0,1]), np.array([0,1]), np.array([0,1])
    elif len(np.unique(all_predicted_scores_pixels)) < 2 or np.all(all_predicted_scores_pixels == all_predicted_scores_pixels[0]):
        print("\nAdvertencia: Los scores predichos a nivel de píxel contienen solo un valor único o muy pocos que impiden ROC.")
        pixel_auroc_global = 0.5
        fpr_pixel_level, tpr_pixel_level, thresholds_roc_raw_pixel_level = np.array([0,1]), np.array([0,1]), np.array([0,1])
    else:
        fpr_pixel_level, tpr_pixel_level, thresholds_roc_raw_pixel_level = roc_curve(all_true_pixels, all_predicted_scores_pixels)
        pixel_auroc_global = auc(fpr_pixel_level, tpr_pixel_level)

    print(f"\n--- Cálculo de ROC y AUC (Nivel de Píxel Global) finalizado ---")
    print(f"Área Bajo la Curva (AUC - Nivel de Píxel Global): {pixel_auroc_global:.4f}")

    # Guarda los datos de la curva ROC a nivel de píxel
    roc_data_filename_pixel = f"roc_data_pixel_level_{category_name_for_roc}.json"
    roc_data_filepath_pixel = os.path.join(ROC_DATA_SAVE_DIR, roc_data_filename_pixel)

    roc_data_to_save_pixel = {
        'category': category_name_for_roc,
        'fpr': fpr_pixel_level.tolist(),
        'tpr': tpr_pixel_level.tolist(),
        'roc_auc': pixel_auroc_global
    }

    with open(roc_data_filepath_pixel, 'w') as f:
        json.dump(roc_data_to_save_pixel, f)
    print(f"✅ Datos de la curva ROC (Nivel de Píxel) para '{category_name_for_roc}' guardados en: {roc_data_filepath_pixel}")

    # Plotea la curva ROC a nivel de píxel
    roc_save_path_pixel = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'roc_curve_pixel_level.png')
    plot_roc_curve(fpr_pixel_level, tpr_pixel_level, pixel_auroc_global, None, roc_save_path_pixel, thresholds_roc_raw_pixel_level, "Pixel-level", current_category)

    # También puedes guardar el AUROC-Píxel global en un JSON para métricas generales
    pixel_metrics_data = {
        "Overall_Pixel_AUROC": f"{pixel_auroc_global:.4f}" if not np.isnan(pixel_auroc_global) else "N/A"
    }
    pixel_metrics_save_path = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'pixel_level_summary_metrics.json')
    with open(pixel_metrics_save_path, 'w') as f:
        json.dump(pixel_metrics_data, f, indent=4)
    print(f"✅ Métricas resumen a nivel de píxel guardadas en: {pixel_metrics_save_path}")

    # --- NUEVO BLOQUE: 8. Generando visualizaciones de máscaras de anomalía para TODOS los mapas de Mahalanobis cargados ---
    print("\n--- 8. Generando visualizaciones de máscaras de anomalía para TODOS los mapas de Mahalanobis cargados ---")
    print(f"     (Las imágenes se guardarán en: {os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'overlays_all_images')})")

    overlays_save_dir = os.path.join(BASE_PLOT_SAVE_ROOT_DIR, 'overlays_all_images')
    os.makedirs(overlays_save_dir, exist_ok=True)

    total_visualizations = 0
    for cls in CLASSES:
        maps_list = normalized_mahalanobis_maps.get(cls, [])
        file_ids = MAP_FILE_IDS.get(cls, [])

        if not maps_list:
            continue

        print(f"     Procesando visualizaciones para la clase: '{cls}' ({len(maps_list)} imágenes)")
        for i, score_map in enumerate(tqdm(maps_list, desc=f"     Generando overlays para {cls}")):
            image_id = file_ids[i]
            # Construye la ruta a la imagen original correctamente
            original_image_path = os.path.join(BASE_IMAGE_DIR, cls, image_id + '.png')

            if os.path.exists(original_image_path):
                save_viz_path = os.path.join(overlays_save_dir, f'overlay_{cls}_{image_id}_thresh_{selected_threshold_for_eval:.4f}.png')
                visualize_overlay(original_image_path, score_map, selected_threshold_for_eval, MIN_CONNECTED_COMPONENT_AREA, save_viz_path)
                total_visualizations += 1
            else:
                print(f"Advertencia: La imagen original no se encontró en {original_image_path}. No se generó visualización para esta.")

    print(f"\n¡Se generaron {total_visualizations} visualizaciones de máscaras de anomalía!")

    print("\n¡Proceso de evaluación completado!")


***** Procesando categoría: GRAFICAS_EVALUACION_TRANSISTOR *****

--- 1. Detectando clases y cargando mapas de Mahalanobis ---
    Clases detectadas: ['bent_lead', 'cut_lead', 'damaged_case', 'good', 'misplaced']
    Total de mapas cargados para 'bent_lead': 9
    Total de mapas cargados para 'cut_lead': 10
    Total de mapas cargados para 'damaged_case': 7
    Total de mapas cargados para 'good': 55
    Total de mapas cargados para 'misplaced': 6
--- Mapas cargados exitosamente ---


--- Top 10 valores de Mahalanobis (RAW - Sin normalizar) ---
Clase: 'bent_lead'
    Imagen: 004
    Top 10 valores: ['6794.454', '6797.185', '6799.210', '6805.154', '6806.046', '6806.520', '6812.883', '6815.855', '6819.720', '6826.556']
    Imagen: 001
    Top 10 valores: ['6245.067', '6247.272', '6292.881', '6294.206', '6295.530', '6296.854', '6298.178', '6299.502', '6300.827', '6302.151']
    Imagen: 002
    Top 10 valores: ['7079.948', '7084.437', '7087.191', '7088.212', '7089.232', '7090.253', '7091.

    Procesando mapas de bent_lead: 9it [00:01,  5.61it/s]
    Procesando mapas de cut_lead: 10it [00:01,  5.79it/s]
    Procesando mapas de damaged_case: 7it [00:01,  5.94it/s]
    Procesando mapas de good: 0it [00:00, ?it/s][ WARN:0@835.979] global loadsave.cpp:268 findDecoder imread_('/home/imercatoma/FeatUp/datasets/mvtec_anomaly_detection/transistor/ground_truth/good/004_mask.png'): can't open/read file: check file path/integrity
[ WARN:0@835.980] global loadsave.cpp:268 findDecoder imread_('/home/imercatoma/FeatUp/datasets/mvtec_anomaly_detection/transistor/ground_truth/good/029_mask.png'): can't open/read file: check file path/integrity
[ WARN:0@835.980] global loadsave.cpp:268 findDecoder imread_('/home/imercatoma/FeatUp/datasets/mvtec_anomaly_detection/transistor/ground_truth/good/039_mask.png'): can't open/read file: check file path/integrity
[ WARN:0@835.981] global loadsave.cpp:268 findDecoder imread_('/home/imercatoma/FeatUp/datasets/mvtec_anomaly_detection/transistor/groun


[DEBUG] all_true_labels_binary_roc (first 10): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[DEBUG] all_anomaly_scores_for_roc (first 10): ['0.7816', '0.6468', '0.7605', '0.8044', '0.7199', '0.8293', '0.6565', '0.7951', '0.8013', '0.6584']
[DEBUG] Unique values in all_true_labels_binary_roc: [0 1]
[DEBUG] Unique values in all_anomaly_scores_for_roc: [0.41146728 0.41486073 0.41601264 0.42867944 0.4312841  0.43708363
 0.44192314 0.45256364 0.48091993 0.48748183 0.4899455  0.49219564
 0.49474004 0.5061231  0.511478   0.51757145 0.5273648  0.52989537
 0.53651285 0.5374994  0.54983675 0.55554384 0.5577757  0.56198716
 0.5673141  0.5796086  0.5910822  0.59270495 0.6135061  0.6153472
 0.6173675  0.61946464 0.6212621  0.6301865  0.6303536  0.63288385
 0.6365425  0.63667613 0.6375533  0.63839525 0.63898695 0.64132756
 0.6457016  0.6468082  0.6545765  0.6554706  0.6565181  0.6572018
 0.65839666 0.6691309  0.6753833  0.6931536  0.69603187 0.6973092
 0.6995614  0.7198877  0.7206137  0.744116   0.7605251  0.771

    Aplicando umbral para bent_lead: 100%|██████████| 9/9 [00:00<00:00, 58.29it/s]
    Aplicando umbral para cut_lead: 100%|██████████| 10/10 [00:00<00:00, 61.78it/s]
    Aplicando umbral para damaged_case: 100%|██████████| 7/7 [00:00<00:00, 62.76it/s]
    Aplicando umbral para good: 100%|██████████| 55/55 [00:00<00:00, 106.82it/s]
    Aplicando umbral para misplaced: 100%|██████████| 6/6 [00:00<00:00, 27.88it/s]



--- 5.1: Generando Matriz de Confusión Binaria (Normal vs. Anómala) ---
✅ Matriz de Confusión  - Binary guardada en: /home/imercatoma/FeatUp/graficas_evaluacion_transistor/evaluacion_roc/confusion_matrix_binary_thresh_0.6468.png

--- 5.2: Generando Matriz de Confusión Detallada (True Class vs. Predicted Class) ---
✅ Matriz de Confusión  - Detailed by Class guardada en: /home/imercatoma/FeatUp/graficas_evaluacion_transistor/evaluacion_roc/confusion_matrix_detailed_thresh_0.6468.png

--- 6. Calculando, mostrando y guardando Tabla de Métricas de Rendimiento (Nivel de Imagen) ---

--- Métricas de Rendimiento a Nivel de Imagen (Umbral: 0.6468, MCC Area: 0) ---
    Accuracy:        0.8046
    Precision:       0.6744
    Recall (Sensibilidad): 0.9062
    Especificidad: 0.7455
    F1-Score:        0.7733
--------------------------------------------------------------------
✅ Métricas añadidas al archivo JSON existente: /home/imercatoma/FeatUp/graficas_evaluacion_transistor/evaluacion_roc/image

     Generando overlays para bent_lead: 100%|██████████| 9/9 [00:08<00:00,  1.06it/s]


     Procesando visualizaciones para la clase: 'cut_lead' (10 imágenes)


     Generando overlays para cut_lead: 100%|██████████| 10/10 [00:09<00:00,  1.08it/s]


     Procesando visualizaciones para la clase: 'damaged_case' (7 imágenes)


     Generando overlays para damaged_case: 100%|██████████| 7/7 [00:06<00:00,  1.12it/s]


     Procesando visualizaciones para la clase: 'good' (55 imágenes)


     Generando overlays para good: 100%|██████████| 55/55 [00:51<00:00,  1.08it/s]


     Procesando visualizaciones para la clase: 'misplaced' (6 imágenes)


     Generando overlays para misplaced: 100%|██████████| 6/6 [00:05<00:00,  1.05it/s]


¡Se generaron 87 visualizaciones de máscaras de anomalía!

¡Proceso de evaluación completado!



