In [14]:
import gc

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import auc
from sklearn.metrics import precision_score, recall_score, confusion_matrix
from torch.nn.functional import softmax
from tqdm import tqdm

In [15]:
def specificity_score(y_true, y_pred, zero_division=float("nan")):
    """Return the specificity (true-negative rate) of binary predictions.

    Specificity = TN / (TN + FP): the fraction of actual negatives that were
    predicted negative. Inputs are interpreted as binary: truthy values are
    the positive class, falsy values the negative class. Inputs of any shape
    are flattened before comparison.

    Args:
        y_true: array-like of ground-truth labels (bool or 0/1).
        y_pred: array-like of predicted labels with the same number of elements.
        zero_division: value returned when there are no actual negatives
            (TN + FP == 0), where specificity is undefined. Defaults to NaN.

    Returns:
        float: the specificity, or ``zero_division`` when it is undefined.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different sizes.
    """
    truth = np.asarray(y_true, dtype=bool).reshape(-1)
    predicted = np.asarray(y_pred, dtype=bool).reshape(-1)
    if truth.shape != predicted.shape:
        raise ValueError(
            f"y_true and y_pred have different sizes: {truth.size} vs {predicted.size}"
        )

    # Count directly instead of unpacking a confusion matrix. sklearn's matrix
    # collapses to 1x1 when only one label is present, and treating that single
    # cell as TN made an all-positive input report a perfect specificity.
    negatives = ~truth
    tn = int(np.count_nonzero(negatives & ~predicted))
    fp = int(np.count_nonzero(negatives & predicted))

    if tn + fp == 0:
        return zero_division
    return tn / (tn + fp)

In [None]:
def build_metrics_(model, test_loader, device=None, thresholds=None, num_classes=None):
    """Compute per-batch, per-class precision / recall / specificity over a
    sweep of probability thresholds.

    For every batch, the model output is turned into class probabilities with
    a softmax over dim 1. Then, for each threshold and class, pixels whose
    probability is strictly greater than the threshold count as positive
    predictions and are compared against ``masks == class_index``.

    Args:
        model: model called as ``model(images)``; assumed to be a
            ``torch.nn.Module`` returning logits with classes on dim 1.
            It is switched to eval mode for the sweep, and its previous
            train/eval mode is restored afterwards.
        test_loader: iterable of ``(images, masks)`` batches that supports
            ``len()``. Images are permuted with ``permute(0, 3, 1, 2)``, so
            they are assumed to arrive channels-last (N, H, W, C).
        device: device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        thresholds: probability thresholds to sweep. Defaults to
            0.0, 0.1, ..., 1.0.
        num_classes: number of classes to evaluate. Defaults to the size of
            dim 1 of the model output.

    Returns:
        dict mapping class index -> {'thresholds', 'specificity', 'recall',
        'precision'}. Each entry is a flat list with one value per
        (batch, threshold) pair, in batch-major order. Metric values are
        rounded to 5 decimals.
    """
    if thresholds is None:
        thresholds = [round(x * 0.1, 1) for x in range(0, 11)]
    if device is None:
        # Run where the model lives instead of relying on a notebook-global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    metric_results = {}

    # Evaluation must not use dropout or update BatchNorm statistics; remember
    # the caller's mode so it can be restored even if a batch fails.
    was_training = model.training
    model.eval()
    try:
        # no_grad: no autograd graph is built, so activations are freed right
        # away (this replaces the old per-batch del + gc.collect()).
        with torch.no_grad():
            for images, masks in tqdm(test_loader, total=len(test_loader)):
                # Channels-last batch -> (N, C, H, W) as the model expects.
                images = images.permute(0, 3, 1, 2).to(device, dtype=torch.float)

                probs = softmax(model(images), dim=1).cpu().numpy()
                labels = masks.long().cpu().numpy()
                n_classes = probs.shape[1] if num_classes is None else num_classes

                for threshold in thresholds:
                    for class_index in range(n_classes):
                        # Binarize: one-vs-rest for this class at this threshold.
                        preds = (probs[:, class_index, :, :] > threshold).reshape(-1)
                        true = (labels == class_index).reshape(-1)

                        # zero_division=0 returns the same value as sklearn's default,
                        # without a warning for every absent class / empty threshold.
                        precision = precision_score(true, preds, zero_division=0)
                        recall = recall_score(true, preds, zero_division=0)  # sensitivity
                        specificity = specificity_score(true, preds)

                        entry = metric_results.setdefault(
                            class_index,
                            {'thresholds': [], 'specificity': [], 'recall': [], 'precision': []},
                        )
                        entry['thresholds'].append(round(threshold, 2))
                        entry['specificity'].append(round(specificity, 5))
                        entry['recall'].append(round(recall, 5))
                        entry['precision'].append(round(precision, 5))
    finally:
        model.train(was_training)

    return metric_results

In [1]:
def mean_metrics(metric_results):
    """Average each class's per-batch metrics for every distinct threshold.

    Args:
        metric_results: dict mapping class index -> {'thresholds',
            'specificity', 'recall', 'precision'} parallel lists, with
            repeated thresholds (one entry per batch), e.g. as produced by
            ``build_metrics_``.

    Returns:
        dict with the same keys and per-class structure, holding one value
        per distinct threshold. Thresholds are in ascending order, so the
        output can feed a monotonic ROC/AUC computation directly. Values are
        averaged with ``np.nanmean``: batches where a metric is undefined
        (NaN) are ignored rather than turning the whole average into NaN.
    """
    metric_names = ('specificity', 'recall', 'precision')
    averaged_metrics = {}

    for class_index, results in metric_results.items():
        # Group the per-batch values by threshold in a single pass
        # (the previous version rescanned the whole list for every threshold).
        grouped = {}
        for position, threshold in enumerate(results['thresholds']):
            bucket = grouped.setdefault(threshold, {name: [] for name in metric_names})
            for name in metric_names:
                bucket[name].append(results[name][position])

        averaged = {'thresholds': [], 'specificity': [], 'recall': [], 'precision': []}
        # Iterating a set gave an arbitrary order; sort for a deterministic sweep.
        for threshold in sorted(grouped):
            averaged['thresholds'].append(threshold)
            for name in metric_names:
                averaged[name].append(np.nanmean(grouped[threshold][name]))
        averaged_metrics[class_index] = averaged

    return averaged_metrics

SyntaxError: incomplete input (2932779966.py, line 1)

In [1]:
def print_metrics(averaged_metrics, classes):
    """Sort each class's metrics by threshold and print one table per class.

    Args:
        averaged_metrics: dict mapping class index -> {'thresholds',
            'specificity', 'recall', 'precision'} parallel lists, e.g. the
            output of ``mean_metrics``. NOTE: it is modified in place. Each
            class's lists are replaced by copies sorted by threshold, so
            later cells that reuse the dict see the sorted order.
        classes: indexable mapping class index -> human-readable class name,
            used as the table heading.
    """
    metric_keys = ('thresholds', 'specificity', 'recall', 'precision')

    # Sort every class's rows by threshold (ties broken by the metric values).
    for metrics in averaged_metrics.values():
        rows = sorted(zip(*(metrics[key] for key in metric_keys)))
        if not rows:
            # Nothing to sort, and unpacking zip(*[]) would fail.
            continue
        for key, column in zip(metric_keys, zip(*rows)):
            metrics[key] = list(column)

    # Build one DataFrame per class, keyed by the class name.
    dataframes = {}
    for class_index, metrics in averaged_metrics.items():
        table = pd.DataFrame({
            'Threshold': metrics['thresholds'],
            'Sensibilidad': metrics['recall'],
            'Especificidad': metrics['specificity'],
            'Precision': metrics['precision'],
        }).round(5)
        dataframes[classes[class_index]] = table

    for class_name, table in dataframes.items():
        # The original literal began with "\R", an invalid escape that printed
        # a stray backslash; "\n" is the intended blank line before each table.
        print(f"\nResultados para la clase {class_name}:\n")
        print(table)

In [19]:
def print_roc_curve(metric_results, classes):
    """Plot one ROC curve, annotated with its AUC, for every class.

    Args:
        metric_results: dict mapping class index -> dict with at least
            'specificity' and 'recall' lists (one point per threshold), e.g.
            the output of ``mean_metrics``.
        classes: indexable mapping class index -> class name used in the title.
    """
    for class_index, metrics in metric_results.items():
        # ROC x-axis: false-positive rate = 1 - specificity; y-axis: recall (TPR).
        # sklearn's auc() requires a monotonic x, so sort the points. This makes
        # the curve independent of the order the thresholds were stored in.
        points = sorted(
            (1 - spec, tpr) for spec, tpr in zip(metrics['specificity'], metrics['recall'])
        )
        fpr = [x for x, _ in points]
        tpr = [y for _, y in points]
        # auc() also needs at least two points.
        roc_auc = auc(fpr, tpr) if len(points) >= 2 else float('nan')

        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f"ROC Curve - Class {class_index} ~ {classes[class_index]}")
        ax.legend(loc='lower right')
        plt.show()
        # Free the figure explicitly; many classes would otherwise accumulate.
        plt.close(fig)