# In [None]:
# -*- coding: utf-8 -*-
"""
Created on Thu May 22 20:51:03 2025

@author: Jorge
"""

from glob import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tifffile import imread
from cc3d import connected_components
from stardist.matching import matching
from tqdm import tqdm
from skimage.segmentation import clear_border
def normalize(img):
    """
    Linearly rescale an array so its minimum maps to 0 and its maximum to 1.

    Args:
        img (array-like): Numeric (or boolean) array of any shape.

    Returns:
        numpy.ndarray: Array of the same shape. Floating-point inputs keep their
        dtype; integer/boolean inputs are returned as float64. A constant input
        (max == min) yields an array of zeros instead of the NaNs a 0/0 division
        would produce.

    Raises:
        ValueError: If ``img`` is empty (raised by numpy's ``min``).
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # Integer subtraction can overflow (e.g. int8 in [-100, 100]) and bool
        # arrays do not support "-", so work in float64 for these dtypes.
        img = img.astype(np.float64)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        return np.zeros_like(img)
    return (img - lo) / span

def leer_imagenes(path):
    """
    Load every TIFF file in a directory as a connected-component label image.

    Files matching ``*.[tT][iI][fF]*`` inside ``path`` are read in sorted order.
    Each one is passed through ``cc3d.connected_components`` and then
    ``skimage.segmentation.clear_border``.

    Args:
        path (str): Directory containing the TIFF files.

    Returns:
        tuple: (list of processed label arrays, list of file names with their
        final extension removed), both in the same sorted order.

    Raises:
        ValueError: If no matching file is found in ``path``.
    """
    from glob import escape

    # Escape the directory part so folder names containing "[" or "]" are not
    # interpreted as glob character classes.
    files = sorted(glob(os.path.join(escape(path), "*.[tT][iI][fF]*")))
    if not files:
        raise ValueError(f"No se encontraron imágenes TIFF en {path}")
    print(f"Imágenes leídas desde {path}: {len(files)}")
    images = [clear_border(connected_components(imread(f))) for f in files]
    # splitext strips only the extension, so "cell.1.tif" -> "cell.1" rather than
    # the colliding "cell" that split(".")[0] produced.
    names = [os.path.splitext(os.path.basename(f))[0] for f in files]
    return images, names

def evaluate_models(gt_images, preds_dict, output_dir=".", thresholds=None,
                    fmt=None, names=None):
    """
    Evaluate several segmentation models against a ground truth; save plots and tables.

    Writes to ``output_dir``: ``detection_scores.png`` (AP vs IoU threshold),
    ``detailed_results.csv``, ``mean_results.csv`` and ``std_results.csv``
    (";"-separated). It also calls ``plot_metrics_comparison`` and
    ``plot_violin_metrics``.

    Args:
        gt_images (list): Ground-truth label images.
        preds_dict (dict): ``{model_name: list_of_predicted_label_images}``. Each
            list must be aligned with ``gt_images`` (same length, same order).
        output_dir (str): Directory for the outputs (created if missing).
        thresholds (array-like, optional): IoU thresholds for the AP curve.
            Defaults to ``np.linspace(0.5, 1.0, 11)``.
        fmt (sequence of str, optional): Matplotlib format strings, one per model.
            They are reused cyclically when there are more models than entries.
            Defaults to ``['o-', 's-', '^-', 'o-', 's-', '^-']``.
        names (list of str, optional): One identifier per ground-truth image, used
            for the ``Image`` column. If omitted, the module-level ``names`` variable
            is used when it exists (previous behaviour); otherwise image indices.

    Returns:
        tuple: (per-image results DataFrame, per-model mean DataFrame,
        per-model std DataFrame)

    Raises:
        ValueError: If ``preds_dict`` is empty, or if ``names`` or a model's
            prediction list does not have the same length as ``gt_images``.
    """
    if not preds_dict:
        raise ValueError("preds_dict está vacío: no hay modelos que evaluar")
    if thresholds is None:
        thresholds = np.linspace(0.5, 1.0, 11)
    thresholds = np.asarray(thresholds, dtype=float)
    if fmt is None or len(fmt) == 0:
        fmt = ['o-', 's-', '^-', 'o-', 's-', '^-']

    n_images = len(gt_images)
    if names is None:
        # Backward compatibility: earlier versions read the module-level ``names``
        # (set by the ``__main__`` block) instead of receiving it as an argument.
        names = globals().get("names")
    if names is None:
        names = [str(j) for j in range(n_images)]
    names = list(names)
    if len(names) != n_images:
        raise ValueError(
            f"Se esperaban {n_images} nombres de imagen, se recibieron {len(names)}"
        )
    for model_name, pred_images in preds_dict.items():
        if len(pred_images) != n_images:
            raise ValueError(
                f"{model_name}: {len(pred_images)} predicciones para "
                f"{n_images} imágenes ground truth"
            )

    os.makedirs(output_dir, exist_ok=True)

    # 1. AP (stardist "accuracy") over several IoU thresholds
    ap_results = {}
    std_results = {}
    for model_name, pred_images in preds_dict.items():
        print(f"\nEvaluando {model_name}...")
        # Rows = images, columns = thresholds. stardist's ``matching`` accepts an
        # iterable of thresholds and returns one result per threshold, so there is
        # one call per image pair instead of one per (image, threshold).
        ap_values = [
            [m.accuracy for m in matching(gt, pred, thresh=tuple(thresholds))]
            for gt, pred in tqdm(zip(gt_images, pred_images), total=n_images,
                                 desc=f"Thresholds {model_name}")
        ]
        ap_results[model_name] = np.mean(ap_values, axis=0)
        std_results[model_name] = np.std(ap_values, axis=0)

    plt.figure(figsize=(10, 6))
    for idx, model_name in enumerate(preds_dict):
        plt.errorbar(
            thresholds,
            ap_results[model_name],
            yerr=std_results[model_name],
            fmt=fmt[idx % len(fmt)],  # cycle instead of IndexError past len(fmt)
            label=model_name,
            capsize=5,
        )
    plt.legend()
    plt.ylabel("AP")
    plt.xlabel(r"IoU threshold ($\tau$)")
    plt.title("Detection Scores")
    plt.savefig(os.path.join(output_dir, "detection_scores.png"), dpi=300)
    plt.close()

    # 2. Detailed per-image metrics at a fixed IoU threshold
    thresh = 0.5
    keys = ['criterion', 'thresh', 'fp', 'tp', 'fn', 'precision', 'recall',
            'accuracy', 'f1', 'n_true', 'n_pred', 'mean_true_score',
            'mean_matched_score', 'panoptic_quality']
    detailed_results = []
    for model_name, pred_images in preds_dict.items():
        print(f"\nEvaluando detalladamente {model_name} con IoU={thresh}...")
        matches = [matching(gt, pred, thresh=thresh)
                   for gt, pred in tqdm(zip(gt_images, pred_images), total=n_images)]
        df = pd.DataFrame([m._asdict() for m in matches], columns=keys)
        df.insert(0, 'Image', names)
        df.insert(0, 'Model', model_name)
        detailed_results.append(df)

    df_results = pd.concat(detailed_results, ignore_index=True)

    float_cols = ['precision', 'recall', 'accuracy', 'f1', 'mean_true_score',
                  'mean_matched_score', 'panoptic_quality']
    int_cols = ['fp', 'tp', 'fn', 'n_true', 'n_pred']
    df_results[float_cols] = df_results[float_cols].astype(float).round(4)
    df_results[int_cols] = df_results[int_cols].astype(int)

    new_names = {
        "fp": "FP", "tp": "TP", "fn": "FN",
        "precision": "Precision", "recall": "Recall",
        "accuracy": "Average Precision", "f1": "F1-Score",
        "n_true": "N True", "n_pred": "N Pred",
    }
    df_results = df_results.rename(columns=new_names)

    # Fractions (0-1). FP and TP are predicted objects, so they are relative to
    # N Pred. FN are unmatched ground-truth objects, so they are relative to N True;
    # dividing FN by N Pred could exceed 1. A 0/0 division gives NaN, which the
    # mean/std below skip.
    df_results['FP (%)'] = (df_results['FP'] / df_results['N Pred']).round(4)
    df_results['TP (%)'] = (df_results['TP'] / df_results['N Pred']).round(4)
    df_results['FN (%)'] = (df_results['FN'] / df_results['N True']).round(4)

    metrics = ['FP (%)', 'TP (%)', 'FN (%)', 'Precision', 'Recall',
               'Average Precision', 'F1-Score']
    df_mean = df_results.groupby('Model', sort=False)[metrics].mean()
    df_std = df_results.groupby('Model', sort=False)[metrics].std()

    df_results.to_csv(os.path.join(output_dir, "detailed_results.csv"), sep=";", index=False)
    df_mean.to_csv(os.path.join(output_dir, "mean_results.csv"), sep=";")
    df_std.to_csv(os.path.join(output_dir, "std_results.csv"), sep=";")

    plot_metrics_comparison(df_mean, df_std, output_dir)
    plot_violin_metrics(df_results, metrics, output_dir)

    return df_results, df_mean, df_std

def plot_metrics_comparison(df_mean, df_std, output_dir):
    """
    Save a grouped bar chart (mean with std error bars) comparing models per metric.

    Args:
        df_mean (DataFrame): Rows are models, columns are metrics (mean values).
        df_std (DataFrame): Same layout as ``df_mean`` with standard deviations.
        output_dir (str): Directory where ``metrics_comparison.png`` is written.
    """
    metric_labels = df_mean.columns
    n_models = len(df_mean)
    base = np.arange(len(metric_labels))
    bar_w = 0.8 / n_models  # all bars of one metric share 0.8 x-units
    # Bars are centred on their x position, so a group's midpoint sits
    # (n_models - 1) / 2 bar widths to the right of the first bar.
    group_center = base + bar_w * (n_models / 2 - 0.5)

    plt.figure(figsize=(14, 6))

    for offset, model in enumerate(df_mean.index):
        plt.bar(base + offset * bar_w, df_mean.loc[model], bar_w,
                yerr=df_std.loc[model], label=model, capsize=5, alpha=0.8)

    plt.xticks(group_center, metric_labels, rotation=45)
    plt.ylabel("Value")
    plt.title("Model Metrics Comparison (Mean ± STD)")
    plt.legend(title="Model")
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "metrics_comparison.png"), dpi=300)
    plt.close()

def plot_violin_metrics(df_results, metrics, output_dir):
    """
    Save violin plots of each metric per model, with individual values overlaid.

    Args:
        df_results (DataFrame): Per-image results; must contain a ``Model`` column
            and every column listed in ``metrics``.
        metrics (list): Names of the metric columns to plot.
        output_dir (str): Directory where ``violin_metrics.png`` is written.
    """
    # Seaborn expects long format: one row per (result row, metric) pair.
    tidy = df_results.melt(id_vars=['Model'], value_vars=metrics,
                           var_name='Métrica', value_name='Valor')
    shared = dict(data=tidy, x='Métrica', y='Valor', hue='Model')

    plt.figure(figsize=(14, 8))
    sns.set_style("whitegrid")

    # Distribution shape per model and metric.
    sns.violinplot(**shared, inner=None, palette="muted", split=False)
    # Individual observations, dodged to line up with their violin;
    # legend=False keeps the legend to the violin entries only.
    sns.stripplot(**shared, dodge=True, jitter=True, color='black', size=4,
                  alpha=0.5, linewidth=1, legend=False)

    plt.title("Distribución de Métricas por Modelo")
    plt.xticks(rotation=45)
    plt.ylabel("Valor")
    plt.xlabel("Métricas")
    plt.grid(axis='y', alpha=0.3)

    # Legend placed outside the axes, to the right.
    legend_kw = dict(title="Modelos", bbox_to_anchor=(1.05, 1),
                     loc='upper left', borderaxespad=0.)
    plt.legend(**legend_kw)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "violin_metrics.png"), dpi=300, bbox_inches='tight')
    plt.close()

# Example usage:
if __name__ == "__main__":
    # Work from the script's own folder so the relative paths below resolve.
    # ``__file__`` is undefined when this code runs as a notebook cell, so the
    # current working directory is kept in that case instead of raising NameError.
    if "__file__" in globals():
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Ground truth. evaluate_models reads the module-level ``names`` for the
    # "Image" column of the detailed results.
    gt, names = leer_imagenes("./gt")

    # Models to compare: {display name: folder with its predicted label images}
    models_to_compare = {
        "CellPose": "./prediction_cellpose",
    }

    # Only the image list is needed from each prediction folder.
    preds_dict = {name: leer_imagenes(path)[0] for name, path in models_to_compare.items()}

    # Raw string so backslashes are literal: in a normal literal "\\WS1" collapses
    # to "\WS1" (no longer a UNC path) and the "\t" in "\test" becomes a TAB.
    output_dir = r"\\WS1\WS1_Remote_Disk\Current Segovia Lab\Emilio Gutiérrez\Emilio Gutiérrez SOS\test Cellpose\test set\results plots"

    results, mean_results, std_results = evaluate_models(
        gt,
        preds_dict,
        output_dir=output_dir,
        thresholds=np.linspace(0.5, 1.0, 11),
        fmt=['o-', 's-', '^-', 'o-', 's-', '^-'],
    )

    print("\nResultados promedio:")
    print(mean_results)