# Imports y config

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import torch.nn.functional as F
from src.metrics import SegmentationMetrics as sm2d
from src.metrics.segmentation3d import SegmentationMetrics3D as sm3d

In [3]:
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

# __Test metrics__

In [4]:
# 2D
y_pred_2d = torch.rand(1, 3, 256, 256)  # 2D logits
y_true_2d = torch.randint(0, 3, (1, 256, 256))  # 2D labels

# 3D
y_pred_3d = torch.rand(1, 3, 64, 256, 256)  # 3D logits
y_true_3d = torch.randint(0, 3, (1, 64, 256, 256))  # 3D labels

# Compute metrics
metrics_2d_a = sm2d.all_metrics(y_pred_2d, y_true_2d)
metrics_2d_b = sm3d.all_metrics(y_pred_2d, y_true_2d)
metrics_3d = sm3d.all_metrics(y_pred_3d, y_true_3d)

In [5]:
print(f"Metrics 2D (a): {metrics_2d_a}")
print(f"Metrics 2D (b): {metrics_2d_b}")
print(f"Metrics 3D: {metrics_3d}")

Metrics 2D (a): {'dice_class_0': 0.33286917209625244, 'dice_class_1': 0.3338851034641266, 'dice_class_2': 0.3329617381095886, 'dice_mean': 0.3332386612892151, 'iou_class_0': 0.19966590404510498, 'iou_class_1': 0.20039740204811096, 'iou_class_2': 0.19973251223564148, 'iou_mean': 0.19993193447589874, 'precision_class_0': 0.332343190908432, 'precision_class_1': 0.33461466431617737, 'precision_class_2': 0.33276012539863586, 'precision_mean': 0.3332393169403076, 'recall_class_0': 0.333396852016449, 'recall_class_1': 0.3331587314605713, 'recall_class_2': 0.33316361904144287, 'recall_mean': 0.33323973417282104, 'dice': 0.3332386612892151, 'iou': 0.19993193447589874, 'precision': 0.3332393169403076, 'recall': 0.33323973417282104}
Metrics 2D (b): {'dice_class_0': 0.33286917209625244, 'dice_class_1': 0.3338851034641266, 'dice_class_2': 0.3329617381095886, 'dice_mean': 0.3332386612892151, 'iou_class_0': 0.19966590404510498, 'iou_class_1': 0.20039740204811096, 'iou_class_2': 0.19973251223564148, '

In [6]:
# Comprobar si los diccionarios son idénticos
if metrics_2d_a == metrics_2d_b:
    print("Las métricas 2D (a) y 2D (b) son exactamente iguales.")
else:
    print("Las métricas 2D (a) y 2D (b) tienen diferencias.")

# Revisar valores que sean distintos
for key in metrics_2d_a:
    if key in metrics_2d_b:
        value_a = metrics_2d_a[key]
        value_b = metrics_2d_b[key]
        if value_a != value_b:
            print(f"Diferencia en {key}: a={value_a}, b={value_b}")

print("Verificación completada.")

Las métricas 2D (a) y 2D (b) son exactamente iguales.
Verificación completada.


# new func

In [9]:
import torch

class sm3dd:
    """
    Robust class for computing segmentation metrics for multiclass segmentation.
    """
    
    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """
        Prepare prediction and ground truth tensors for metric calculation.
        
        Parameters
        ----------
        y_pred : torch.Tensor
            Predicted segmentation mask (logits or one-hot)
        y_true : torch.Tensor
            Ground truth segmentation mask
        
        Returns
        -------
        tuple
            (predictions, ground_truth, num_classes)
        """
        # If prediction is logits, convert to predictions
        if y_pred.dim() in [4, 5]:  # [B, C, H, W] or [B, C, D, H, W]
            y_pred_classes = torch.argmax(y_pred, dim=1)
        else:
            y_pred_classes = y_pred
        
        # Determine number of classes
        num_classes = torch.max(torch.cat([y_pred_classes, y_true])).item() + 1
        
        return y_pred_classes, y_true, num_classes
    
    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-7):
        """
        Compute Dice coefficient for multiclass segmentation.
        
        Returns
        -------
        tuple
            (mean dice, per-class dice dictionary)
        """
        y_pred, y_true, num_classes = sm3dd._prepare_tensors(y_pred, y_true)
        
        dice_scores = []
        class_dice = {}
        
        for cls in range(num_classes):
            # Create binary masks for the current class
            pred_mask = (y_pred == cls).float()
            true_mask = (y_true == cls).float()
            
            # Compute intersection and union
            intersection = torch.sum(pred_mask * true_mask)
            union = torch.sum(pred_mask) + torch.sum(true_mask)
            
            # Compute Dice
            dice = (2.0 * intersection + smooth) / (union + smooth)
            
            dice_scores.append(dice)
            class_dice[f"dice_class_{cls}"] = dice.item()
        
        mean_dice = torch.mean(torch.stack(dice_scores))
        class_dice["dice_mean"] = mean_dice.item()
        
        return mean_dice, class_dice
    
    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-7):
        """
        Compute Intersection over Union (IoU) for multiclass segmentation.
        
        Returns
        -------
        tuple
            (mean iou, per-class iou dictionary)
        """
        y_pred, y_true, num_classes = sm3dd._prepare_tensors(y_pred, y_true)
        
        iou_scores = []
        class_iou = {}
        
        for cls in range(num_classes):
            # Create binary masks for the current class
            pred_mask = (y_pred == cls).float()
            true_mask = (y_true == cls).float()
            
            # Compute intersection and union
            intersection = torch.sum(pred_mask * true_mask)
            union = torch.sum(pred_mask) + torch.sum(true_mask) - intersection
            
            # Compute IoU
            iou = (intersection + smooth) / (union + smooth)
            
            iou_scores.append(iou)
            class_iou[f"iou_class_{cls}"] = iou.item()
        
        mean_iou = torch.mean(torch.stack(iou_scores))
        class_iou["iou_mean"] = mean_iou.item()
        
        return mean_iou, class_iou
    
    @staticmethod
    def precision_recall(y_pred, y_true, smooth=1e-7):
        """
        Compute Precision and Recall for multiclass segmentation.
        
        Returns
        -------
        tuple
            (mean precision, mean recall, per-class precision, per-class recall)
        """
        y_pred, y_true, num_classes = sm3dd._prepare_tensors(y_pred, y_true)
        
        precision_scores = []
        recall_scores = []
        class_precision = {}
        class_recall = {}
        
        for cls in range(num_classes):
            # Create binary masks for the current class
            pred_mask = (y_pred == cls).float()
            true_mask = (y_true == cls).float()
            
            # True Positives, False Positives, False Negatives
            true_positives = torch.sum(pred_mask * true_mask)
            false_positives = torch.sum(pred_mask * (1 - true_mask))
            false_negatives = torch.sum((1 - pred_mask) * true_mask)
            
            # Precision and Recall
            precision = (true_positives + smooth) / (true_positives + false_positives + smooth)
            recall = (true_positives + smooth) / (true_positives + false_negatives + smooth)
            
            precision_scores.append(precision)
            recall_scores.append(recall)
            
            class_precision[f"precision_class_{cls}"] = precision.item()
            class_recall[f"recall_class_{cls}"] = recall.item()
        
        mean_precision = torch.mean(torch.stack(precision_scores))
        mean_recall = torch.mean(torch.stack(recall_scores))
        
        class_precision["precision_mean"] = mean_precision.item()
        class_recall["recall_mean"] = mean_recall.item()
        
        return mean_precision, mean_recall, class_precision, class_recall
    
    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all metrics for multiclass segmentation.
        
        Returns
        -------
        dict
            Dictionary of all computed metrics
        """
        metrics = {}
        
        # Calculate metrics
        mean_dice, class_dice = sm3dd.dice_coefficient(y_pred, y_true)
        mean_iou, class_iou = sm3dd.iou_score(y_pred, y_true)
        mean_precision, mean_recall, class_precision, class_recall = sm3dd.precision_recall(y_pred, y_true)
        
        # Merge all metrics
        metrics.update(class_dice)
        metrics.update(class_iou)
        metrics.update(class_precision)
        metrics.update(class_recall)
        
        # Add overall metrics
        metrics['dice'] = mean_dice.item()
        metrics['iou'] = mean_iou.item()
        metrics['precision'] = mean_precision.item()
        metrics['recall'] = mean_recall.item()
        
        return metrics

# test

In [10]:
# 2D
y_pred_2d = torch.rand(1, 3, 256, 256)  # 2D logits
y_true_2d = torch.randint(0, 3, (1, 256, 256))  # 2D labels

# 3D
y_pred_3d = torch.rand(1, 3, 64, 256, 256)  # 3D logits
y_true_3d = torch.randint(0, 3, (1, 64, 256, 256))  # 3D labels

# Compute metrics
metrics_2d_a = sm2d.all_metrics(y_pred_2d, y_true_2d)
metrics_2d_b = sm3d.all_metrics(y_pred_2d, y_true_2d)
metrics_3d_a = sm3d.all_metrics(y_pred_3d, y_true_3d)
metrics_3d_b = sm3dd.all_metrics(y_pred_3d, y_true_3d)

In [11]:
# Comprobar si los diccionarios son idénticos
if metrics_3d_a == metrics_3d_b:
    print("Las métricas 3D (a) y 3D (b) son exactamente iguales.")
else:
    print("Las métricas 3D (a) y 3D (b) tienen diferencias.")

# Revisar valores que sean distintos
for key in metrics_3d_a:
    if key in metrics_3d_b:
        value_a = metrics_3d_a[key]
        value_b = metrics_3d_b[key]
        if value_a != value_b:
            print(f"Diferencia en {key}: a={value_a}, b={value_b}")

print("Verificación completada.")

Las métricas 3D (a) y 3D (b) tienen diferencias.
Diferencia en dice_class_0: a=1.9345697164535522, b=0.3338319659233093
Diferencia en dice_class_1: a=1.9438923597335815, b=0.3332183361053467
Diferencia en dice_class_2: a=1.9387074708938599, b=0.33346042037010193
Diferencia en dice_mean: a=1.9390565156936646, b=0.3335035741329193
Diferencia en iou_class_0: a=29.56686019897461, b=0.20035912096500397
Diferencia en iou_class_1: a=34.64580535888672, b=0.19991721212863922
Diferencia en iou_class_2: a=31.630409240722656, b=0.20009151101112366
Diferencia en iou_mean: a=31.947690963745117, b=0.20012260973453522
Diferencia en precision_class_0: a=0.9974095225334167, b=0.3335973620414734
Diferencia en precision_class_1: a=1.0023902654647827, b=0.3330245614051819
Diferencia en precision_class_2: a=0.9996621012687683, b=0.33388960361480713
Diferencia en precision_mean: a=0.9998206496238708, b=0.3335038423538208
Diferencia en recall_class_0: a=32.026187896728516, b=0.33406689763069153
Diferencia en 

In [12]:
print(metrics_3d_b)

{'dice_class_0': 0.3338319659233093, 'dice_class_1': 0.3332183361053467, 'dice_class_2': 0.33346042037010193, 'dice_mean': 0.3335035741329193, 'iou_class_0': 0.20035912096500397, 'iou_class_1': 0.19991721212863922, 'iou_class_2': 0.20009151101112366, 'iou_mean': 0.20012260973453522, 'precision_class_0': 0.3335973620414734, 'precision_class_1': 0.3330245614051819, 'precision_class_2': 0.33388960361480713, 'precision_mean': 0.3335038423538208, 'recall_class_0': 0.33406689763069153, 'recall_class_1': 0.3334123492240906, 'recall_class_2': 0.3330323398113251, 'recall_mean': 0.3335038721561432, 'dice': 0.3335035741329193, 'iou': 0.20012260973453522, 'precision': 0.3335038423538208, 'recall': 0.3335038721561432}


In [35]:
import torch

class SegmentationMetrics:
    """
    Class for computing segmentation metrics for 2D and 3D segmentation tasks.
    """
    
    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-6):
        """
        Compute Dice coefficient for 2D and 3D segmentation.
        """
        dim = y_pred.dim()
        if dim == 5:  # 3D case: (B, C, D, H, W)
            reduce_axes = (2, 3, 4)
        elif dim == 4:  # 2D case: (B, C, H, W)
            reduce_axes = (2, 3)
        else:
            raise ValueError("Unsupported tensor dimensions: {}".format(dim))
        
        y_pred = torch.softmax(y_pred.float(), dim=1)
        y_true_one_hot = torch.nn.functional.one_hot(y_true.long(), num_classes=y_pred.shape[1])
        y_true_one_hot = y_true_one_hot.permute(0, -1, *range(1, y_true_one_hot.dim() - 1))
        # y_true_one_hot = y_true_one_hot.permute(0, -1, *range(1, dim-1))
        
        intersection = torch.sum(y_pred * y_true_one_hot, dim=reduce_axes)
        union = torch.sum(y_pred, dim=reduce_axes) + torch.sum(y_true_one_hot, dim=reduce_axes)
        dice = (2.0 * intersection + smooth) / (union + smooth)
        
        mean_dice = torch.mean(dice, dim=0)
        return mean_dice, {f"dice_class_{i}": d.item() for i, d in enumerate(mean_dice)}

    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-6):
        """
        Compute IoU (Jaccard Index) for 2D and 3D segmentation.
        """
        dim = y_pred.dim()
        if dim == 5:
            reduce_axes = (2, 3, 4)
        elif dim == 4:
            reduce_axes = (2, 3)
        else:
            raise ValueError("Unsupported tensor dimensions: {}".format(dim))
        
        y_pred = torch.softmax(y_pred, dim=1)
        y_true_one_hot = torch.nn.functional.one_hot(y_true.long(), num_classes=y_pred.shape[1])
        y_true_one_hot = y_true_one_hot.permute(0, -1, *range(1, dim-1))
        
        intersection = torch.sum(y_pred * y_true_one_hot, dim=reduce_axes)
        union = torch.sum(y_pred, dim=reduce_axes) + torch.sum(y_true_one_hot, dim=reduce_axes) - intersection
        iou = (intersection + smooth) / (union + smooth)
        
        mean_iou = torch.mean(iou, dim=0)
        return mean_iou, {f"iou_class_{i}": d.item() for i, d in enumerate(mean_iou)}
    
    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all segmentation metrics.
        """
        metrics = {}
        mean_dice, class_dice = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        mean_iou, class_iou = SegmentationMetrics.iou_score(y_pred, y_true)
        
        metrics.update(class_dice)
        metrics.update(class_iou)
        metrics['dice'] = mean_dice.mean().item()
        metrics['iou'] = mean_iou.mean().item()
        
        return metrics


In [14]:
def generate_synthetic_data(shape, num_classes):
    """
    Genera datos sintéticos de segmentación con valores aleatorios.

    Parameters:
    - shape: Tuple con el tamaño del tensor (ej. (B, H, W) o (B, D, H, W))
    - num_classes: Número de clases en la segmentación

    Returns:
    - y_pred: Tensor con predicciones (logits)
    - y_true: Tensor con ground truth (clases)
    """
    y_true = torch.randint(0, num_classes, shape)  # Ground truth con valores de clase
    y_pred_logits = torch.rand((shape[0], num_classes, *shape[1:]))  # Simula logits

    return y_pred_logits, y_true

# Prueba con datos 2D
num_classes = 3
batch_size = 2
height, width = 128, 128
y_pred_2d, y_true_2d = generate_synthetic_data((batch_size, height, width), num_classes)

# Prueba con datos 3D (volumen)
depth = 32
y_pred_3d, y_true_3d = generate_synthetic_data((batch_size, depth, height, width), num_classes)

# Calcular métricas en 2D
print("\n🔹 Resultados para imágenes 2D:")
metrics_2d_c = SegmentationMetrics.all_metrics(y_pred_2d, y_true_2d)
metrics_2d_a = sm2d.all_metrics(y_pred_2d, y_true_2d)
# for k, v in metrics_2d_c.items():
#     print(f"{k}: {v:.4f}")

# Calcular métricas en 3D
print("\n🔹 Resultados para volúmenes 3D:")
metrics_3d = SegmentationMetrics.all_metrics(y_pred_3d, y_true_3d)
for k, v in metrics_3d.items():
    print(f"{k}: {v:.4f}")


🔹 Resultados para imágenes 2D:

🔹 Resultados para volúmenes 3D:
dice_class_0: 0.3337
dice_class_1: 0.3332
dice_class_2: 0.3334
iou_class_0: 0.2002
iou_class_1: 0.1999
iou_class_2: 0.2000
dice: 0.3334
iou: 0.2001


In [15]:
# Comprobar si los diccionarios son idénticos
if metrics_2d_a == metrics_2d_c:
    print("Las métricas 2D (a) y 2D (c) son exactamente iguales.")
else:
    print("Las métricas 2D (a) y 2D (c) tienen diferencias.")

# Revisar valores que sean distintos
for key in metrics_2d_a:
    if key in metrics_2d_c:
        value_a = metrics_2d_a[key]
        value_b = metrics_2d_c[key]
        if value_a != value_b:
            print(f"Diferencia en {key}: a={value_a}, b={value_b}")

print("Verificación completada.")

Las métricas 2D (a) y 2D (c) tienen diferencias.
Diferencia en dice_class_0: a=0.33450204133987427, b=0.334501713514328
Diferencia en dice_class_1: a=0.33226069808006287, b=0.3322582542896271
Diferencia en dice_class_2: a=0.332177996635437, b=0.33217430114746094
Diferencia en iou_class_0: a=0.20084206759929657, b=0.20084208250045776
Diferencia en iou_class_1: a=0.1992282122373581, b=0.19922709465026855
Diferencia en iou_class_2: a=0.19916871190071106, b=0.19916707277297974
Diferencia en dice: a=0.3329802453517914, b=0.33297809958457947
Diferencia en iou: a=0.19974632561206818, b=0.19974541664123535
Verificación completada.


In [16]:
# Comprobar si los diccionarios son idénticos
if metrics_3d == metrics_3d_b:
    print("Las métricas 3D y 3D (b) son exactamente iguales.")
else:
    print("Las métricas 3D y 3D (b) tienen diferencias.")

# Revisar valores que sean distintos
for key in metrics_3d:
    if key in metrics_3d_b:
        value_a = metrics_3d[key]
        value_b = metrics_3d_b[key]
        if value_a != value_b:
            print(f"Diferencia en {key}: a={value_a}, b={value_b}")

print("Verificación completada.")

Las métricas 3D y 3D (b) tienen diferencias.
Diferencia en dice_class_0: a=0.3336580693721771, b=0.3338319659233093
Diferencia en dice_class_1: a=0.3331666588783264, b=0.3332183361053467
Diferencia en dice_class_2: a=0.33338677883148193, b=0.33346042037010193
Diferencia en iou_class_0: a=0.20023386180400848, b=0.20035912096500397
Diferencia en iou_class_1: a=0.19988000392913818, b=0.19991721212863922
Diferencia en iou_class_2: a=0.20003849267959595, b=0.20009151101112366
Diferencia en dice: a=0.3334038257598877, b=0.3335035741329193
Diferencia en iou: a=0.20005078613758087, b=0.20012260973453522
Verificación completada.


# __vs MONAI__

In [29]:
import torch
import monai.metrics as monai_metrics

# Simulación de segmentaciones 3D (cambia esto con tus datos reales)
y_true = torch.randint(0, 3, (1, 64, 64, 64))  # (batch, depth, height, width)
y_pred = torch.randint(0, 3, (1, 64, 64, 64))  # (batch, depth, height, width)

# Convertimos a one-hot encoding (MONAI requiere esto para métricas multicategoría)
num_classes = 3
y_true_onehot = torch.nn.functional.one_hot(y_true, num_classes=num_classes).permute(0, 4, 1, 2, 3)
y_pred_onehot = torch.nn.functional.one_hot(y_pred, num_classes=num_classes).permute(0, 4, 1, 2, 3)

# Convertir a tipo float
y_true_onehot = y_true_onehot.float()
y_pred_onehot = y_pred_onehot.float()

# 1️⃣ Cálculo de Dice Score con MONAI
dice_metric = monai_metrics.DiceMetric(include_background=True, reduction="mean")
dice_result = dice_metric(y_pred_onehot, y_true_onehot)

# 2️⃣ Cálculo de IoU con MONAI
iou_metric = monai_metrics.MeanIoU(include_background=True, reduction="mean")
iou_result = iou_metric(y_pred_onehot, y_true_onehot)

# 3️⃣ Precisión y Recall con MONAI
precision_metric = monai_metrics.ConfusionMatrixMetric(metric_name="precision", reduction="mean")
recall_metric = monai_metrics.ConfusionMatrixMetric(metric_name="recall", reduction="mean")

precision_result = precision_metric(y_pred_onehot, y_true_onehot)
recall_result = recall_metric(y_pred_onehot, y_true_onehot)

# 4️⃣ Mostramos los resultados
print("MONAI - Dice per class:", dice_result.mean().item())
print("MONAI - IoU per class:", iou_result.mean().item())
print("MONAI - Precision per class:", precision_result)
print("MONAI - Recall per class:", recall_result)

MONAI - Dice per class: 0.3337126672267914
MONAI - IoU per class: 0.20027339458465576
MONAI - Precision per class: tensor([[[ 29321.,  58304., 116020.,  58499.],
         [ 29030.,  58262., 116687.,  58165.],
         [ 29130.,  58097., 116918.,  57999.]]])
MONAI - Recall per class: tensor([[[ 29321.,  58304., 116020.,  58499.],
         [ 29030.,  58262., 116687.,  58165.],
         [ 29130.,  58097., 116918.,  57999.]]])


In [36]:
import torch

# Suponiendo que ya tienes y_pred y y_true cargados en memoria
# Asegúrate de que tengan el mismo formato que usaste en MONAI
# y_pred y y_true deben ser tensores de PyTorch

# Evaluar con SegmentationMetrics
metrics1 = SegmentationMetrics.all_metrics(y_pred, y_true)
print("Resultados SegmentationMetrics:", metrics1)

# Evaluar con sm3dd
metrics2 = sm3dd.all_metrics(y_pred, y_true)
print("Resultados sm3dd:", metrics2)

# Evaluar con SegmentationMetrics3D
metrics3 = SegmentationMetrics3D.all_metrics(y_pred, y_true)
print("Resultados SegmentationMetrics3D:", metrics3)

# Resultados de MONAI (asegúrate de tenerlos en un diccionario similar)
monai_results = {
    "dice": dice_metric.mean().item(),  # Ejemplo, reemplaza con los valores reales
    "iou": iou_metric.mean().item(), 
    # Otras métricas si las tienes
}

# Comparación
print("\nComparación con MONAI:")
for metric in ["dice", "iou"]:
    print(f"{metric} - MONAI: {monai_results[metric]:.4f}, SegmentationMetrics: {metrics1[metric]:.4f}, sm3dd: {metrics2[metric]:.4f}, SegmentationMetrics3D: {metrics3[metric]:.4f}")

RuntimeError: a Tensor with 64 elements cannot be converted to Scalar