# _Imports & config_

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

# __Class__

In [5]:
import torch

class SegmentationMetrics2D3D:
    """
    Class for computing segmentation metrics for 2D and 3D segmentation tasks.
    """

    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-6):
        """
        Compute Dice coefficient for 2D and 3D segmentation.

        Parameters
        ----------
        y_pred : torch.Tensor
            Predicted segmentation mask (class indices or one-hot)
        y_true : torch.Tensor
            Ground truth segmentation mask (class indices or one-hot)
        smooth : float, optional
            Smoothing factor to avoid division by zero

        Returns
        -------
        torch.Tensor
            Dice coefficient (overall and per-class)
        """
        # Ensure inputs are 4D (2D: [B, C, H, W], 3D: [B, C, D, H, W])
        if y_pred.dim() == 3 or y_pred.dim() == 4:
            y_pred = y_pred.unsqueeze(1)
            y_true = y_true.unsqueeze(1)

        n_classes = torch.max(y_true).item() + 1
        y_pred_one_hot = torch.zeros(
            y_pred.size(0), n_classes, *y_pred.shape[2:], device=y_pred.device
        )
        y_pred_one_hot.scatter_(1, y_pred.long(), 1)

        y_true_one_hot = torch.zeros(
            y_true.size(0), n_classes, *y_true.shape[2:], device=y_true.device
        )
        y_true_one_hot.scatter_(1, y_true.long(), 1)

        dice_scores = []
        class_dice = {}

        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i]
            true_class = y_true_one_hot[:, i]

            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class)
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
            class_dice[f"dice_class_{i}"] = dice.item()

        mean_dice = torch.mean(torch.stack(dice_scores))
        class_dice["dice_mean"] = mean_dice.item()

        return mean_dice, class_dice

    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-6):
        """
        Compute IoU (Jaccard Index) for 2D and 3D segmentation.

        Parameters
        ----------
        y_pred : torch.Tensor
            Predicted segmentation mask (class indices or one-hot)
        y_true : torch.Tensor
            Ground truth segmentation mask (class indices or one-hot)
        smooth : float, optional
            Smoothing factor to avoid division by zero

        Returns
        -------
        tuple
            (mean_iou, per_class_iou_dict)
        """
        if y_pred.dim() == 3 or y_pred.dim() == 4:
            y_pred = y_pred.unsqueeze(1)
            y_true = y_true.unsqueeze(1)

        n_classes = torch.max(y_true).item() + 1
        y_pred_one_hot = torch.zeros(
            y_pred.size(0), n_classes, *y_pred.shape[2:], device=y_pred.device
        )
        y_pred_one_hot.scatter_(1, y_pred.long(), 1)

        y_true_one_hot = torch.zeros(
            y_true.size(0), n_classes, *y_true.shape[2:], device=y_true.device
        )
        y_true_one_hot.scatter_(1, y_true.long(), 1)

        iou_scores = []
        class_iou = {}

        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i]
            true_class = y_true_one_hot[:, i]

            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class) - intersection
            iou = (intersection + smooth) / (union + smooth)
            iou_scores.append(iou)
            class_iou[f"iou_class_{i}"] = iou.item()

        mean_iou = torch.mean(torch.stack(iou_scores))
        class_iou["iou_mean"] = mean_iou.item()

        return mean_iou, class_iou

    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all metrics for 2D and 3D segmentation.

        Parameters
        ----------
        y_pred : torch.Tensor
            Predicted segmentation mask (class indices)
        y_true : torch.Tensor
            Ground truth segmentation mask (class indices)

        Returns
        -------
        dict
            Dictionary of all metrics
        """
        metrics = {}

        # Calculate Dice coefficient
        mean_dice, class_dice = SegmentationMetrics2D3D.dice_coefficient(y_pred, y_true)
        metrics.update(class_dice)

        # Calculate IoU score
        mean_iou, class_iou = SegmentationMetrics2D3D.iou_score(y_pred, y_true)
        metrics.update(class_iou)

        # Add overall metrics
        metrics['dice'] = mean_dice.item()
        metrics['iou'] = mean_iou.item()

        return metrics

# __Test__

In [3]:
import torch

def to_one_hot(masks, num_classes):
    """
    Convierte una máscara (2D o 3D) a un formato one-hot.

    Args:
        masks (torch.Tensor): Máscara de etiquetas de tamaño [batch_size, ..., height, width] o [batch_size, ..., depth, height, width].
        num_classes (int): Número de clases a convertir.

    Returns:
        torch.Tensor: Máscara en formato one-hot de tamaño [batch_size, num_classes, ..., height, width] o [batch_size, num_classes, ..., depth, height, width].
    """
    # Verifica si la máscara es 2D o 3D
    if masks.dim() == 3:  # Caso 2D (batch_size, height, width)
        batch_size, height, width = masks.shape
        # Reshape para agregar la dimensión de clases
        masks_one_hot = torch.zeros(batch_size, num_classes, height, width, device=masks.device)
        # Convierte los índices en one-hot
        masks_one_hot.scatter_(1, masks.unsqueeze(1), 1)
    
    elif masks.dim() == 4:  # Caso 3D (batch_size, depth, height, width)
        batch_size, depth, height, width = masks.shape
        # Reshape para agregar la dimensión de clases
        masks_one_hot = torch.zeros(batch_size, num_classes, depth, height, width, device=masks.device)
        # Convierte los índices en one-hot
        masks_one_hot.scatter_(1, masks.unsqueeze(1), 1)

    return masks_one_hot


In [4]:
num_classes = 5  # Número de clases
masks_2d = torch.randint(0, num_classes, (4, 256, 256))  # Simulamos una máscara 2D (4 imágenes, 256x256)
print(masks_2d.shape)
masks_one_hot_2d = to_one_hot(masks_2d, num_classes)
print(masks_one_hot_2d.shape)  # Salida esperada: [4, 5, 256, 256]

torch.Size([4, 256, 256])
torch.Size([4, 5, 256, 256])


In [5]:
num_classes = 5  # Número de clases
masks_3d = torch.randint(0, num_classes, (1, 128, 128, 64))  # Simulamos una máscara 3D (1 imagen, 128x128x64)
print(masks_3d.shape)
masks_one_hot_3d = to_one_hot(masks_3d, num_classes)
print(masks_one_hot_3d.shape)  # Salida esperada: [1, 5, 128, 128, 64]

torch.Size([1, 128, 128, 64])
torch.Size([1, 5, 128, 128, 64])


In [7]:
from src.metrics.segmentation_monai import SegmentationMonaiMetrics as smm
from src.metrics.segmentation_bak import SegmentationMetrics as sm

In [8]:
batch_size = 4
num_classes = 5
height, width = 256, 256

# Salida del modelo (logits)
y_pred_logits = torch.randn(batch_size, num_classes, height, width)
print(y_pred_logits.shape)
print(y_pred_logits)
# m2d = smm.convert_to_one_hot(y_pred_logits, masks_2d)
# m2d = sm.dice_coefficient(masks_2d, y_pred_logits)

torch.Size([4, 5, 256, 256])
tensor([[[[-2.8458e-01, -4.9904e-01, -7.3456e-01,  ..., -5.1145e-01,
            4.6625e-01,  5.2171e-02],
          [ 3.7925e-01, -1.3227e+00, -4.8338e-01,  ...,  6.0373e-01,
            2.6324e-01, -6.9275e-01],
          [-4.8050e-01, -1.7489e+00,  2.7859e-01,  ...,  9.5167e-01,
            4.6431e-02,  1.4337e+00],
          ...,
          [ 1.8550e+00,  7.5334e-02, -2.1264e-01,  ..., -1.2077e+00,
           -6.3903e-01,  2.9483e-01],
          [-3.0314e+00,  1.0875e+00, -1.6291e+00,  ...,  3.8931e-01,
           -1.9849e+00, -4.8893e-01],
          [-3.4027e-01, -1.1086e+00,  4.0716e-02,  ..., -4.7110e-01,
           -6.6544e-01,  9.5736e-01]],

         [[-5.2425e-01, -2.9773e-01, -7.7641e-01,  ...,  1.7241e+00,
            5.6370e-01,  9.8772e-01],
          [ 1.6733e+00,  4.8521e-02, -9.2547e-01,  ..., -1.4252e+00,
           -5.1551e-01, -1.6048e-01],
          [ 9.3728e-01, -1.0400e+00, -4.3233e-01,  ...,  2.8268e+00,
           -1.6933e+00, -7.99

In [9]:
print(y_pred_logits.size(3))

256


In [10]:
y_pred_one_hot, y_true_one_hot = smm.convert_to_one_hot(y_pred_logits, masks_2d)

2D case
y_pred shape: torch.Size([4, 5, 256, 256]), y_true shape: torch.Size([4, 256, 256])
y_pred values: tensor([-4.6263, -4.6085, -4.5831,  ...,  4.5355,  4.7422,  4.7642])
y_true_one_hot shape: torch.Size([4, 5, 256, 256]), y_pred_one_hot shape: torch.Size([4, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])


In [11]:
metrics_a = sm.dice_coefficient(y_pred_one_hot, y_true_one_hot)
metrics_b = sm.dice_coefficient(y_pred_logits, masks_2d)

else
else if


In [12]:
print(metrics_a)

(tensor(0.2011), {'dice_class_0': 0.20053575932979584, 'dice_class_1': 0.20105890929698944, 'dice_class_2': 0.20377790927886963, 'dice_class_3': 0.20013177394866943, 'dice_class_4': 0.19978252053260803, 'dice_mean': 0.20105738937854767})


In [13]:
print(metrics_b)

(tensor(0.2004), {'dice_class_0': 0.19954991340637207, 'dice_class_1': 0.20108310878276825, 'dice_class_2': 0.20121122896671295, 'dice_class_3': 0.19977664947509766, 'dice_class_4': 0.20027513802051544, 'dice_mean': 0.20037920773029327})


In [14]:
from monai.metrics import DiceMetric
# ======================
# Cálculo con MONAI
# ======================
n_classes = num_classes
dice_metric = DiceMetric(include_background=True, reduction="none")  # "none" para obtener todas las clases
monai_dice = dice_metric(y_pred_one_hot, y_true_one_hot)

# ======================
# Cálculo con tu implementación
# ======================
_, my_dice = sm.dice_coefficient(y_pred_one_hot, y_true_one_hot)  # Devuelve valores por clase en un diccionario

# ======================
# Comparación
# ======================
print("==> Comparación de Dice Score por clase:")
for i in range(n_classes):
    print(f"Clase {i}: MONAI = {monai_dice[:, i].mean().item():.4f}, Mi implementación = {my_dice[f'dice_class_{i}']:.4f}")

print(f"\n==> Dice medio:")
print(f"MONAI = {monai_dice.mean().item():.4f}, Mi implementación = {my_dice['dice_mean']:.4f}")

else
==> Comparación de Dice Score por clase:
Clase 0: MONAI = 0.2005, Mi implementación = 0.2005
Clase 1: MONAI = 0.2011, Mi implementación = 0.2011
Clase 2: MONAI = 0.2038, Mi implementación = 0.2038
Clase 3: MONAI = 0.2001, Mi implementación = 0.2001
Clase 4: MONAI = 0.1998, Mi implementación = 0.1998

==> Dice medio:
MONAI = 0.2011, Mi implementación = 0.2011


In [15]:
metrics_c = smm.compute_dice(y_pred_logits, masks_2d)

2D case
y_pred shape: torch.Size([4, 5, 256, 256]), y_true shape: torch.Size([4, 256, 256])
y_pred values: tensor([-4.6263, -4.6085, -4.5831,  ...,  4.5355,  4.7422,  4.7642])
y_true_one_hot shape: torch.Size([4, 5, 256, 256]), y_pred_one_hot shape: torch.Size([4, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])


In [16]:
print(metrics_c)

(tensor(0.2011), {'dice_class_0': 0.2005341649055481, 'dice_class_1': 0.20105457305908203, 'dice_class_2': 0.20378108322620392, 'dice_class_3': 0.20013247430324554, 'dice_class_4': 0.19977997243404388, 'dice_mean': 0.20105645060539246})


# __Test 3D__

In [17]:
num_classes = 5  # Número de clases
masks_3d = torch.randint(0, num_classes, (1, 64, 128, 128))  # Simulamos una máscara 3D (1 imagen, 128x128x64)
print(masks_3d.shape)
masks_one_hot_3d = to_one_hot(masks_3d, num_classes)
print(masks_one_hot_3d.shape)  # Salida esperada: [1, 5, 64, 128, 128]

torch.Size([1, 64, 128, 128])
torch.Size([1, 5, 64, 128, 128])


In [18]:
batch_size = 1
num_classes = 5
depth, height, width = 64, 128, 128

# Salida del modelo (logits)
y_pred_logits_3d = torch.randn(batch_size, num_classes, depth, height, width)
print(y_pred_logits_3d.shape)

torch.Size([1, 5, 64, 128, 128])


In [19]:
y_pred_one_hot_3d, y_true_one_hot_3d = smm.convert_to_one_hot(y_pred_logits_3d, masks_3d)

3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])


In [20]:
metrics_3d = smm.compute_dice(y_pred_logits_3d, masks_3d)
print(metrics_3d)

3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
(tensor(0.1999), {'dice_class_0': 0.20025311410427094, 'dice_class_1': 0.20036616921424866, 'dice_class_2': 0.19872121512889862, 'dice_class_3': 0.19920654594898224, 'dice_class_4': 0.20111529529094696, 'dice_mean': 0.19993247091770172})


In [21]:
# ======================
# Cálculo con MONAI
# ======================
n_classes = num_classes
dice_metric_3d = DiceMetric(include_background=True, reduction="none")  # "none" para obtener todas las clases
monai_dice_3d = dice_metric(y_pred_one_hot_3d, y_true_one_hot_3d)

# ======================
# Cálculo con tu implementación
# ======================
_, my_dice_3d = smm.compute_dice(y_pred_logits_3d, masks_3d)  # Devuelve valores por clase en un diccionario

# ======================
# Comparación
# ======================
print("==> Comparación de Dice Score por clase:")
for i in range(n_classes):
    print(f"Clase {i}: MONAI = {monai_dice_3d[:, i].mean().item():.4f}, Mi implementación = {my_dice_3d[f'dice_class_{i}']:.4f}")

print(f"\n==> Dice medio:")
print(f"MONAI = {monai_dice_3d.mean().item():.4f}, Mi implementación = {my_dice_3d['dice_mean']:.4f}")

3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
==> Comparación de Dice Score por clase:
Clase 0: MONAI = 0.2003, Mi implementación = 0.2003
Clase 1: MONAI = 0.2004, Mi implementación = 0.2004
Clase 2: MONAI = 0.1987, Mi implementación = 0.1987
Clase 3: MONAI = 0.1992, Mi implementación = 0.1992
Clase 4: MONAI = 0.2011, Mi implementación = 0.2011

==> Dice medio:
MONAI = 0.1999, Mi implementación = 0.1999


In [22]:
def is_one_hot(tensor):
    """
    Check if the tensor is one-hot encoded.
    """
    return (tensor.sum(dim=1) == 1).all() and torch.all((tensor == 0) | (tensor == 1))

In [23]:
print(is_one_hot(masks_2d), False)
print(is_one_hot(masks_one_hot_2d), True)
print(is_one_hot(y_pred_logits), False)
print(is_one_hot(y_pred_one_hot), True)
print(is_one_hot(y_true_one_hot), True)
print(is_one_hot(masks_3d), False)
print(is_one_hot(masks_one_hot_3d), True)
print(is_one_hot(y_pred_one_hot_3d), True)
print(is_one_hot(y_true_one_hot_3d), True)

tensor(False) False
tensor(True) True
tensor(False) False
tensor(True) True
tensor(True) True
tensor(False) False
tensor(True) True
tensor(True) True
tensor(True) True


In [24]:
metrics_3d_a = smm.compute_dice(y_pred_logits_3d, masks_3d)
print(metrics_3d_a)
metrics_3d_b = smm.compute_dice(y_pred_one_hot_3d, masks_one_hot_3d)
print(metrics_3d_b)

3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
(tensor(0.1999), {'dice_class_0': 0.20025311410427094, 'dice_class_1': 0.20036616921424866, 'dice_class_2': 0.19872121512889862, 'dice_class_3': 0.19920654594898224, 'dice_class_4': 0.20111529529094696, 'dice_mean': 0.19993247091770172})
Both y_pred and y_true are already one-hot encoded.
(tensor(0.1999), {'dice_class_0': 0.20025311410427094, 'dice_class_1': 0.20036616921424866, 'dice_class_2': 0.19872121512889862, 'dice_class_3': 0.19920654594898224, 'dice_class_4': 0.20111529529094696, 'dice_mean': 0.19993247091770172})


## __Perfect and worst case__

In [25]:
num_classes = 5  # Número de clases
y_true_3d_perfect_i = torch.randint(0, num_classes, (1, 64, 128, 128))  # Simulamos una máscara 3D (1 imagen, 128x128x64)
print(y_true_3d_perfect_i.shape)
y_true_3d_perfect = to_one_hot(y_true_3d_perfect_i, num_classes)
print(y_true_3d_perfect.shape)  # Salida esperada: [1, 5, 64, 128, 128]
y_pred_3d_perfect = y_true_3d_perfect.clone()
print(y_pred_3d_perfect.shape)

torch.Size([1, 64, 128, 128])
torch.Size([1, 5, 64, 128, 128])
torch.Size([1, 5, 64, 128, 128])


In [26]:
perfect_3d_dice = smm.compute_dice(y_pred_3d_perfect, y_true_3d_perfect)
print(perfect_3d_dice)

Both y_pred and y_true are already one-hot encoded.
(tensor(1.), {'dice_class_0': 1.0, 'dice_class_1': 1.0, 'dice_class_2': 1.0, 'dice_class_3': 1.0, 'dice_class_4': 1.0, 'dice_mean': 1.0})


In [27]:
y_pred_3d_worst = (y_true_3d_perfect_i + 1) % num_classes
print(y_pred_3d_worst.shape)
y_pred_3d_worst = to_one_hot(y_pred_3d_worst, num_classes)
worst_3d_dice = smm.compute_dice(y_pred_3d_worst, y_true_3d_perfect)
print(worst_3d_dice)

torch.Size([1, 64, 128, 128])
Both y_pred and y_true are already one-hot encoded.
(tensor(0.), {'dice_class_0': 0.0, 'dice_class_1': 0.0, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.0})


## __Consistency test: 3D (depth=1) == 2D__

In [28]:
def consistency_test(metric='dice'):
    b, num_classes, d, h, w = 1, 5, 1, 256, 256
    y_true_2d = torch.randint(0, num_classes, (b, h, w))
    y_pred_2d = torch.randn(b, num_classes, h, w)
    y_true_3d = y_true_2d.unsqueeze(-1)  # [B, H, W, 1]
    y_pred_3d = y_pred_2d.unsqueeze(-1)  # [B, C, H, W, 1]
    print(f"y_true_2d: {y_true_2d.shape} | y_true_3d: {y_true_3d.shape}")
    print(f"y_pred_2d: {y_pred_2d.shape} | y_pred_3d: {y_pred_3d.shape}")
    y_true_3d = torch.permute(y_true_3d, (0, 3, 1, 2))
    y_pred_3d = torch.permute(y_pred_3d, (0, 1, 4, 2, 3))
    print(f"y_true_2d: {y_true_2d.shape} | y_true_3d: {y_true_3d.shape}")
    print(f"y_pred_2d: {y_pred_2d.shape} | y_pred_3d: {y_pred_3d.shape}")

    if metric == 'dice':
        metric_2d = smm.compute_dice(y_pred_2d, y_true_2d)
        metric_3d = smm.compute_dice(y_pred_3d, y_true_3d)
    elif metric == 'iou':
        metric_2d = smm.compute_iou(y_pred_2d, y_true_2d)
        metric_3d = smm.compute_iou(y_pred_3d, y_true_3d)

    print("Equals:", metric_2d==metric_3d)
    print(metric_2d)
    print(metric_3d)

consistency_test()

y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 256, 256, 1])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 256, 256, 1])
y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 1, 256, 256])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 1, 256, 256])
2D case
y_pred shape: torch.Size([1, 5, 256, 256]), y_true shape: torch.Size([1, 256, 256])
y_pred values: tensor([-4.2794, -4.2765, -4.1770,  ...,  4.0452,  4.3115,  4.3712])
y_true_one_hot shape: torch.Size([1, 5, 256, 256]), y_pred_one_hot shape: torch.Size([1, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
3D case
y_pred shape: torch.Size([1, 5, 1, 256, 256]), y_true shape: torch.Size([1, 1, 256, 256])
y_pred values: tensor([-4.2794, -4.2765, -4.1770,  ...,  4.0452,  4.3115,  4.3712])
y_true_one_hot shape: torch.Size([1, 5, 1, 256, 256]), y_pred_one_hot shape: torch.Size([1, 

# __Test IoU__

In [29]:
iou_2d = smm.compute_iou(y_pred_logits, masks_2d)
iou_3d = smm.compute_iou(y_pred_logits_3d, masks_3d)
print(iou_2d)
print(iou_3d)

2D case
y_pred shape: torch.Size([4, 5, 256, 256]), y_true shape: torch.Size([4, 256, 256])
y_pred values: tensor([-4.6263, -4.6085, -4.5831,  ...,  4.5355,  4.7422,  4.7642])
y_true_one_hot shape: torch.Size([4, 5, 256, 256]), y_pred_one_hot shape: torch.Size([4, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
(tensor(0.1118), {'iou_class_0': 0.11144384741783142, 'iou_class_1': 0.11176510900259018, 'iou_class_2': 0.11345276981592178, 'iou_class_3': 0.11119396239519119, 'iou_class_4': 0.110975846

In [30]:
iou_2d_b = sm.iou_score(y_pred_one_hot, masks_one_hot_2d)
print(iou_2d_b)

(tensor(0.1118), {'iou_class_0': 0.11144192516803741, 'iou_class_1': 0.11176513880491257, 'iou_class_2': 0.11344805359840393, 'iou_class_3': 0.11119246482849121, 'iou_class_4': 0.11097687482833862, 'iou_mean': 0.11176488548517227})


In [31]:
def manual_iou(y_pred, y_true):
    y_pred_one_hot, y_true_one_hot = smm.convert_to_one_hot(y_pred, y_true)
    # Calculate IoU for each class
    iou_scores = []
    class_iou = {}
    
    for i in range(n_classes):
        pred_class = y_pred_one_hot[:, i, ...]
        true_class = y_true_one_hot[:, i, ...]
        
        intersection = torch.sum(pred_class * true_class)
        union = torch.sum(pred_class) + torch.sum(true_class) - intersection
        iou = (intersection + 1e-9) / (union + 1e-9)
        iou_scores.append(iou)
        class_iou[f"iou_class_{i}"] = iou.item()
    
    mean_iou = torch.mean(torch.stack(iou_scores))
    class_iou["iou_mean"] = mean_iou.item()
    
    return mean_iou, class_iou

In [32]:
miou_2d = manual_iou(y_pred_logits, masks_2d)
print(miou_2d)

2D case
y_pred shape: torch.Size([4, 5, 256, 256]), y_true shape: torch.Size([4, 256, 256])
y_pred values: tensor([-4.6263, -4.6085, -4.5831,  ...,  4.5355,  4.7422,  4.7642])
y_true_one_hot shape: torch.Size([4, 5, 256, 256]), y_pred_one_hot shape: torch.Size([4, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
(tensor(0.1118), {'iou_class_0': 0.11144192516803741, 'iou_class_1': 0.11176513880491257, 'iou_class_2': 0.11344805359840393, 'iou_class_3': 0.11119246482849121, 'iou_class_4': 0.11097687482833862, 'iou_mean': 0.11176488548517227})


## __Perfect and worst case__

In [33]:
perfect_3d_iou = smm.compute_iou(y_pred_3d_perfect, y_true_3d_perfect)
print(perfect_3d_iou)

Both y_pred and y_true are already one-hot encoded.
(tensor(1.), {'iou_class_0': 1.0, 'iou_class_1': 1.0, 'iou_class_2': 1.0, 'iou_class_3': 1.0, 'iou_class_4': 1.0, 'iou_mean': 1.0})


In [34]:
worst_3d_iou = smm.compute_iou(y_pred_3d_worst, y_true_3d_perfect)
print(worst_3d_iou)

Both y_pred and y_true are already one-hot encoded.
(tensor(0.), {'iou_class_0': 0.0, 'iou_class_1': 0.0, 'iou_class_2': 0.0, 'iou_class_3': 0.0, 'iou_class_4': 0.0, 'iou_mean': 0.0})


## __Consistency test__

In [35]:
consistency_test('iou')

y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 256, 256, 1])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 256, 256, 1])
y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 1, 256, 256])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 1, 256, 256])
2D case
y_pred shape: torch.Size([1, 5, 256, 256]), y_true shape: torch.Size([1, 256, 256])
y_pred values: tensor([-4.8800, -4.5549, -4.5389,  ...,  4.6270,  4.6670,  5.0255])
y_true_one_hot shape: torch.Size([1, 5, 256, 256]), y_pred_one_hot shape: torch.Size([1, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
3D case
y_pred shape: torch.Size([1, 5, 1, 256, 256]), y_true shape: torch.Size([1, 1, 256, 256])
y_pred values: tensor([-4.8800, -4.5549, -4.5389,  ...,  4.6270,  4.6670,  5.0255])
y_true_one_hot shape: torch.Size([1, 5, 1, 256, 256]), y_pred_one_hot shape: torch.Size([1, 

# __Precision and recall__

In [36]:
print(f"y_pred_logits: {y_pred_logits.shape} | masks_2d: {masks_2d.shape}")
print(f"y_pred_logits_3d: {y_pred_logits_3d.shape} | masks_3d: {masks_3d.shape}")

y_pred_logits: torch.Size([4, 5, 256, 256]) | masks_2d: torch.Size([4, 256, 256])
y_pred_logits_3d: torch.Size([1, 5, 64, 128, 128]) | masks_3d: torch.Size([1, 64, 128, 128])


In [37]:
pr_2d = smm.compute_precision_recall(y_pred_logits, masks_2d)
pr_3d = smm.compute_precision_recall(y_pred_logits_3d, masks_3d)

2D case
y_pred shape: torch.Size([4, 5, 256, 256]), y_true shape: torch.Size([4, 256, 256])
y_pred values: tensor([-4.6263, -4.6085, -4.5831,  ...,  4.5355,  4.7422,  4.7642])
y_true_one_hot shape: torch.Size([4, 5, 256, 256]), y_pred_one_hot shape: torch.Size([4, 5, 256, 256])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])
3D case
y_pred shape: torch.Size([1, 5, 64, 128, 128]), y_true shape: torch.Size([1, 64, 128, 128])
y_pred values: tensor([-5.2455, -4.8909, -4.8806,  ...,  4.8246,  4.8773,  4.9632])
y_true_one_hot shape: torch.Size([1, 5, 64, 128, 128]), y_pred_one_hot shape: torch.Size([1, 5, 64, 128, 128])
y_true values: tensor([0, 1, 2, 3, 4])
y_true_one_hot values: tensor([0., 1.])
y_pred_one_hot values: tensor([0., 1.])


In [38]:
print(pr_2d)
print(pr_3d)

(tensor(0.2011), tensor(0.2011), {'precision_class_0': 0.19984491169452667, 'precision_class_1': 0.2016676962375641, 'precision_class_2': 0.20346364378929138, 'precision_class_3': 0.20045655965805054, 'precision_class_4': 0.19988667964935303, 'precision_mean': 0.20106391608715057}, {'recall_class_0': 0.2012466937303543, 'recall_class_1': 0.2004479616880417, 'recall_class_2': 0.20411993563175201, 'recall_class_3': 0.1998201310634613, 'recall_class_4': 0.19967710971832275, 'recall_mean': 0.2010623663663864})
(tensor(0.1999), tensor(0.1999), {'precision_class_0': 0.19995534420013428, 'precision_class_1': 0.2003929167985916, 'precision_class_2': 0.19859257340431213, 'precision_class_3': 0.19917000830173492, 'precision_class_4': 0.20155660808086395, 'precision_mean': 0.1999334841966629}, {'recall_class_0': 0.20055177807807922, 'recall_class_1': 0.2003394216299057, 'recall_class_2': 0.19885002076625824, 'recall_class_3': 0.19924309849739075, 'recall_class_4': 0.20067590475082397, 'recall_mea

In [39]:
pr_2d_b = sm.precision_recall(y_pred_one_hot, masks_one_hot_2d)
print(pr_2d_b)

(tensor(0.2011), tensor(0.2011), {'precision_class_0': 0.1998366117477417, 'precision_class_1': 0.2016657441854477, 'precision_class_2': 0.20344407856464386, 'precision_class_3': 0.20045144855976105, 'precision_class_4': 0.19988928735256195, 'precision_mean': 0.20105743408203125}, {'recall_class_0': 0.20123980939388275, 'recall_class_1': 0.2004557102918625, 'recall_class_2': 0.20411284267902374, 'recall_class_3': 0.1998131275177002, 'recall_class_4': 0.19967585802078247, 'recall_mean': 0.201059490442276})


In [40]:
from monai.metrics import get_confusion_matrix, compute_confusion_matrix_metric
cm = get_confusion_matrix(y_pred_one_hot, masks_one_hot_2d)
# --------------
# Precision
monai_cm_metrics = compute_confusion_matrix_metric('precision', cm)
print("PRECISION")
print(monai_cm_metrics.shape)
print(monai_cm_metrics)
monai_cm_metrics = monai_cm_metrics.mean(dim=0)
print(monai_cm_metrics)
print(monai_cm_metrics.mean())
# --------------
# Recall
monai_cm_metrics = compute_confusion_matrix_metric('recall', cm)
print("RECALL")
print(monai_cm_metrics.shape)
print(monai_cm_metrics)
monai_cm_metrics = monai_cm_metrics.mean(dim=0)
print(monai_cm_metrics)
print(monai_cm_metrics.mean())

PRECISION
torch.Size([4, 5])
tensor([[0.1976, 0.1987, 0.2041, 0.2018, 0.2015],
        [0.2028, 0.1984, 0.1985, 0.1997, 0.1988],
        [0.1955, 0.2047, 0.2079, 0.1983, 0.1988],
        [0.2035, 0.2048, 0.2033, 0.2021, 0.2004]])
tensor([0.1998, 0.2017, 0.2035, 0.2005, 0.1999])
tensor(0.2011)
RECALL
torch.Size([4, 5])
tensor([[0.1995, 0.1974, 0.2062, 0.2006, 0.1999],
        [0.2064, 0.1987, 0.1993, 0.1959, 0.1979],
        [0.1988, 0.2037, 0.2039, 0.2001, 0.1986],
        [0.2003, 0.2020, 0.2070, 0.2026, 0.2022]])
tensor([0.2012, 0.2004, 0.2041, 0.1998, 0.1997])
tensor(0.2011)


# __Manual vs MONAI__

## Time

In [41]:
import torch
import time
from monai.metrics import get_confusion_matrix, compute_confusion_matrix_metric

# Crear datos de ejemplo
batch_size = 4
n_classes = 5
height, width = 256, 256
y_pred = torch.randn(batch_size, n_classes, height, width)  # Predicción
y_true = torch.randint(0, n_classes, (batch_size, height, width))  # Ground truth

# Convertir a one-hot
y_pred_one_hot = torch.zeros(batch_size, n_classes, height, width, device=y_pred.device)
y_pred_classes = torch.argmax(y_pred, dim=1, keepdim=True)
y_pred_one_hot.scatter_(1, y_pred_classes, 1)

y_true_one_hot = torch.zeros(batch_size, n_classes, height, width, device=y_true.device)
y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)

# MONAI
start_time = time.time()
cm = get_confusion_matrix(y_pred_one_hot, y_true_one_hot)
precision = compute_confusion_matrix_metric('precision', cm)
precision = precision.mean(dim=0).mean()
end_time = time.time()
print(f"MONAI time: {end_time - start_time:.6f} seconds")

# Cálculo manual
start_time = time.time()
# Aquí puedes hacer el cálculo manual de precisión (simplificado para ejemplo)
tp = torch.sum(y_pred_one_hot * y_true_one_hot, dim=(0, 2, 3))
fp = torch.sum(y_pred_one_hot, dim=(0, 2, 3)) - tp
fn = torch.sum(y_true_one_hot, dim=(0, 2, 3)) - tp
precision_manual = tp / (tp + fp + 1e-12)  # Evitar división por cero
precision_manual_mean = precision_manual.mean()
end_time = time.time()
print(f"Manual time: {end_time - start_time:.6f} seconds")

# Asumiendo que precision_monai y precision_manual son los resultados de MONAI y manual, respectivamente
print(f"MONAI Precision: {precision}")
print(f"Manual Precision: {precision_manual.mean().item()}")
print(f"Precision difference: {torch.abs(precision - precision_manual.mean().item()).sum()}")

MONAI time: 0.005656 seconds
Manual time: 0.002005 seconds
MONAI Precision: 0.19924600422382355
Manual Precision: 0.1992441862821579
Precision difference: 1.817941665649414e-06


## __Dice__

In [42]:
def compare_metrics(metrics_a, metrics_b):
    if metrics_a == metrics_b:
        print("Both metrics are identical.")
    else:
        print("Metrics have differences.")

    print('-' * 30)

    mean_a, class_a = metrics_a
    mean_b, class_b = metrics_b
    print(f"Mean A: {mean_a.item()}")
    print(f"Mean B: {mean_b.item()}")

    print('-' * 30)

    for key in class_a:
        if key in class_b:
            a = class_a[key]
            b = class_b[key]
            same = 'EQ' if a == b else 'DIFF'
            print(f"{same} - {key}")
            if a != b:
                print(f"\ta: {a}\n\tb: {b}")

    print('-' * 30)

    print("Verification completed")

In [43]:
def manual_dice(y_pred, y_true):
    C = y_true.size(1)
    sum_dims = tuple(range(2, y_true.ndim))
    
    intersection = torch.sum(y_pred * y_true, dim=sum_dims)
    union = torch.sum(y_pred, dim=sum_dims) + torch.sum(y_true, dim=sum_dims)

    dice_scores = 2. * intersection / (union + 1e-12)
    dice_scores = dice_scores.mean(dim=0)

    dice_dict = {f"dice_class_{i}": dice_scores[i].item() for i in range(C)}
    dice_dict["dice_mean"] = dice_scores.mean().item()
    
    return dice_scores.mean(), dice_dict
    
    # print("Manual Dice Scores:", dice_scores_manual)
    # print("Manual Dice Mean:", dice_scores_manual.mean())

In [44]:
monai_dice = smm.compute_dice(y_pred_one_hot, y_true_one_hot)
md = manual_dice(y_pred_one_hot, y_true_one_hot)

Both y_pred and y_true are already one-hot encoded.


In [45]:
print(monai_dice)
print('--------------')
print(md)
print('--------------')
print(monai_dice==md)

(tensor(0.1992), {'dice_class_0': 0.19837766885757446, 'dice_class_1': 0.19877730309963226, 'dice_class_2': 0.2004932165145874, 'dice_class_3': 0.1992664635181427, 'dice_class_4': 0.19928932189941406, 'dice_mean': 0.1992408037185669})
--------------
(tensor(0.1992), {'dice_class_0': 0.19837766885757446, 'dice_class_1': 0.19877730309963226, 'dice_class_2': 0.2004932165145874, 'dice_class_3': 0.1992664635181427, 'dice_class_4': 0.19928932189941406, 'dice_mean': 0.1992408037185669})
--------------
True


In [46]:
monai_dice_3d = smm.compute_dice(y_pred_one_hot_3d, y_true_one_hot_3d)
md_3d = manual_dice(y_pred_one_hot_3d, y_true_one_hot_3d)

Both y_pred and y_true are already one-hot encoded.


In [47]:
print(monai_dice_3d)
print('--------------')
print(md_3d)
print('--------------')
print(monai_dice_3d==md_3d)

(tensor(0.1999), {'dice_class_0': 0.20025311410427094, 'dice_class_1': 0.20036616921424866, 'dice_class_2': 0.19872121512889862, 'dice_class_3': 0.19920654594898224, 'dice_class_4': 0.20111529529094696, 'dice_mean': 0.19993247091770172})
--------------
(tensor(0.1999), {'dice_class_0': 0.20025311410427094, 'dice_class_1': 0.20036616921424866, 'dice_class_2': 0.19872121512889862, 'dice_class_3': 0.19920654594898224, 'dice_class_4': 0.20111529529094696, 'dice_mean': 0.19993247091770172})
--------------
True


In [48]:
compare_metrics(monai_dice_3d, md_3d)

Both metrics are identical.
------------------------------
Mean A: 0.19993247091770172
Mean B: 0.19993247091770172
------------------------------
EQ - dice_class_0
EQ - dice_class_1
EQ - dice_class_2
EQ - dice_class_3
EQ - dice_class_4
EQ - dice_mean
------------------------------
Verification completed


## __IoU Score__

In [49]:
def manual_iou(y_pred, y_true):
    C = y_true.size(1)  # Número de clases
    sum_dims = tuple(range(2, y_true.ndim))  # Sumar en todas las dimensiones espaciales

    intersection = torch.sum(y_pred * y_true, dim=sum_dims)
    union = torch.sum(y_pred, dim=sum_dims) + torch.sum(y_true, dim=sum_dims) - intersection

    iou_scores = intersection / (union + 1e-12)  # Pequeña constante para evitar división por cero
    iou_scores = iou_scores.mean(dim=0)  # Promedio sobre el batch

    iou_dict = {f"iou_class_{i}": iou_scores[i].item() for i in range(C)}
    iou_dict["iou_mean"] = iou_scores.mean().item()
    
    return iou_scores.mean(), iou_dict

In [50]:
monai_iou = smm.compute_iou(y_pred_one_hot, y_true_one_hot)
miou = manual_iou(y_pred_one_hot, y_true_one_hot)

Both y_pred and y_true are already one-hot encoded.


In [51]:
compare_metrics(monai_iou, miou)

Both metrics are identical.
------------------------------
Mean A: 0.11064586788415909
Mean B: 0.11064586788415909
------------------------------
EQ - iou_class_0
EQ - iou_class_1
EQ - iou_class_2
EQ - iou_class_3
EQ - iou_class_4
EQ - iou_mean
------------------------------
Verification completed


In [52]:
monai_iou_3d = smm.compute_iou(y_pred_one_hot_3d, y_true_one_hot_3d)
miou_3d = manual_iou(y_pred_one_hot_3d, y_true_one_hot_3d)

Both y_pred and y_true are already one-hot encoded.


In [53]:
compare_metrics(monai_iou_3d, miou_3d)

Both metrics are identical.
------------------------------
Mean A: 0.1110696792602539
Mean B: 0.1110696792602539
------------------------------
EQ - iou_class_0
EQ - iou_class_1
EQ - iou_class_2
EQ - iou_class_3
EQ - iou_class_4
EQ - iou_mean
------------------------------
Verification completed


## __Precision & recall__

In [54]:
def manual_precision_recall(y_pred, y_true):
    C = y_true.size(1)  # Número de clases
    sum_dims = tuple(range(2, y_true.ndim))
    
    # True Positives (TP), False Positives (FP) y False Negatives (FN)
    TP = torch.sum(y_pred * y_true, dim=sum_dims)  # Elementos correctamente predichos
    FP = torch.sum(y_pred, dim=sum_dims) - TP  # Predicciones incorrectas
    FN = torch.sum(y_true, dim=sum_dims) - TP  # Elementos no detectados
    
    # Cálculo de Precision y Recall
    precision = TP / (TP + FP + 1e-12)
    recall = TP / (TP + FN + 1e-12)
    
    # Promediar sobre el batch
    precision = precision.mean(dim=0)
    recall = recall.mean(dim=0)
    
    precision_dict = {f"precision_class_{i}": precision[i].item() for i in range(C)}
    recall_dict = {f"recall_class_{i}": recall[i].item() for i in range(C)}
    
    precision_dict["precision_mean"] = precision.mean().item()
    recall_dict["recall_mean"] = recall.mean().item()
    
    return precision.mean(), recall.mean(), precision_dict, recall_dict


In [55]:
monai_pr = smm.compute_precision_recall(y_pred_one_hot, y_true_one_hot)
mpr = manual_precision_recall(y_pred_one_hot, y_true_one_hot)

Both y_pred and y_true are already one-hot encoded.


In [56]:
monai_p = monai_pr[::2]
monai_r = monai_pr[1::2]
mp = mpr[::2]
mr = mpr[1::2]
compare_metrics(monai_p, mp)
print('='*30)
compare_metrics(monai_r, mr)

Both metrics are identical.
------------------------------
Mean A: 0.19924600422382355
Mean B: 0.19924600422382355
------------------------------
EQ - precision_class_0
EQ - precision_class_1
EQ - precision_class_2
EQ - precision_class_3
EQ - precision_class_4
EQ - precision_mean
------------------------------
Verification completed
Both metrics are identical.
------------------------------
Mean A: 0.19924746453762054
Mean B: 0.19924746453762054
------------------------------
EQ - recall_class_0
EQ - recall_class_1
EQ - recall_class_2
EQ - recall_class_3
EQ - recall_class_4
EQ - recall_mean
------------------------------
Verification completed


In [57]:
import torch
import time
from monai.metrics import get_confusion_matrix, compute_confusion_matrix_metric

# Crear datos de ejemplo
def test():
    batch_size = 64
    n_classes = 5
    height, width = 256, 256
    y_pred = torch.randn(batch_size, n_classes, height, width)  # Predicción
    y_true = torch.randint(0, n_classes, (batch_size, height, width))  # Ground truth
    
    # Convertir a one-hot
    y_pred_one_hot = torch.zeros(batch_size, n_classes, height, width, device=y_pred.device)
    y_pred_classes = torch.argmax(y_pred, dim=1, keepdim=True)
    y_pred_one_hot.scatter_(1, y_pred_classes, 1)
    
    y_true_one_hot = torch.zeros(batch_size, n_classes, height, width, device=y_true.device)
    y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)
    
    # MONAI
    start_time = time.time()
    cm = get_confusion_matrix(y_pred_one_hot, y_true_one_hot)
    precision = compute_confusion_matrix_metric('precision', cm)
    precision = precision.mean(dim=0).mean()
    end_time = time.time()
    print(f"MONAI time: {end_time - start_time:.6f} seconds")
    
    # Cálculo manual
    start_time = time.time()
    # Aquí puedes hacer el cálculo manual de precisión (simplificado para ejemplo)
    precision_manual = manual_precision_recall(y_pred_one_hot, y_true_one_hot)[0]
    end_time = time.time()
    print(f"Manual time: {end_time - start_time:.6f} seconds")
    
    # Asumiendo que precision_monai y precision_manual son los resultados de MONAI y manual, respectivamente
    print(f"MONAI Precision: {precision}")
    print(f"Manual Precision: {precision_manual.mean().item()}")
    print(f"Precision difference: {torch.abs(precision - precision_manual.mean().item()).sum()}")

In [58]:
for _ in range(10):
    test()
    print('█'*40)

MONAI time: 0.147126 seconds
Manual time: 0.017724 seconds
MONAI Precision: 0.20030975341796875
Manual Precision: 0.20030975341796875
Precision difference: 0.0
████████████████████████████████████████
MONAI time: 0.142831 seconds
Manual time: 0.019365 seconds
MONAI Precision: 0.20011715590953827
Manual Precision: 0.20011715590953827
Precision difference: 0.0
████████████████████████████████████████
MONAI time: 0.214869 seconds
Manual time: 0.178271 seconds
MONAI Precision: 0.19937287271022797
Manual Precision: 0.19937287271022797
Precision difference: 0.0
████████████████████████████████████████
MONAI time: 0.127131 seconds
Manual time: 0.019512 seconds
MONAI Precision: 0.20042698085308075
Manual Precision: 0.20042698085308075
Precision difference: 0.0
████████████████████████████████████████
MONAI time: 0.121785 seconds
Manual time: 0.021129 seconds
MONAI Precision: 0.19999970495700836
Manual Precision: 0.19999970495700836
Precision difference: 0.0
████████████████████████████████████

# __Test new Class__

In [59]:
from src.metrics import SegmentationMetrics as SM

In [60]:
all_metrics = SM.all_metrics(y_pred_logits, masks_2d)

In [61]:
print(all_metrics)

{'dice_class_0': 0.2005341649055481, 'dice_class_1': 0.20105457305908203, 'dice_class_2': 0.20378108322620392, 'dice_class_3': 0.20013247430324554, 'dice_class_4': 0.19977997243404388, 'dice_mean': 0.20105645060539246, 'iou_class_0': 0.11144384741783142, 'iou_class_1': 0.11176510900259018, 'iou_class_2': 0.11345276981592178, 'iou_class_3': 0.11119396239519119, 'iou_class_4': 0.11097584664821625, 'iou_mean': 0.11176630109548569, 'precision_class_0': 0.19984491169452667, 'precision_class_1': 0.2016676962375641, 'precision_class_2': 0.20346364378929138, 'precision_class_3': 0.20045655965805054, 'precision_class_4': 0.19988667964935303, 'recall_class_0': 0.2012466937303543, 'recall_class_1': 0.2004479616880417, 'recall_class_2': 0.20411993563175201, 'recall_class_3': 0.1998201310634613, 'recall_class_4': 0.19967710971832275, 'dice': 0.20105645060539246, 'iou': 0.11176630109548569, 'precision': 0.20106391608715057, 'recall': 0.2010623663663864}


In [62]:
def consistency_test(metric='dice'):
    b, num_classes, d, h, w = 1, 5, 1, 256, 256
    y_true_2d = torch.randint(0, num_classes, (b, h, w))
    y_pred_2d = torch.randn(b, num_classes, h, w)
    y_true_3d = y_true_2d.unsqueeze(-1)  # [B, H, W, 1]
    y_pred_3d = y_pred_2d.unsqueeze(-1)  # [B, C, H, W, 1]
    print(f"y_true_2d: {y_true_2d.shape} | y_true_3d: {y_true_3d.shape}")
    print(f"y_pred_2d: {y_pred_2d.shape} | y_pred_3d: {y_pred_3d.shape}")
    y_true_3d = torch.permute(y_true_3d, (0, 3, 1, 2))
    y_pred_3d = torch.permute(y_pred_3d, (0, 1, 4, 2, 3))
    print(f"y_true_2d: {y_true_2d.shape} | y_true_3d: {y_true_3d.shape}")
    print(f"y_pred_2d: {y_pred_2d.shape} | y_pred_3d: {y_pred_3d.shape}")

    if metric == 'dice':
        metric_2d = SM.dice_coefficient(y_pred_2d, y_true_2d)
        metric_3d = SM.dice_coefficient(y_pred_3d, y_true_3d)
    elif metric == 'iou':
        metric_2d = SM.iou_score(y_pred_2d, y_true_2d)
        metric_3d = SM.iou_score(y_pred_3d, y_true_3d)

    print("Equals:", metric_2d==metric_3d)
    print(metric_2d)
    print(metric_3d)

consistency_test()

y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 256, 256, 1])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 256, 256, 1])
y_true_2d: torch.Size([1, 256, 256]) | y_true_3d: torch.Size([1, 1, 256, 256])
y_pred_2d: torch.Size([1, 5, 256, 256]) | y_pred_3d: torch.Size([1, 5, 1, 256, 256])
Equals: True
(tensor(0.1972), {'dice_class_0': 0.195200115442276, 'dice_class_1': 0.19414156675338745, 'dice_class_2': 0.1973593831062317, 'dice_class_3': 0.19856853783130646, 'dice_class_4': 0.200923353433609, 'dice_mean': 0.19723859429359436})
(tensor(0.1972), {'dice_class_0': 0.195200115442276, 'dice_class_1': 0.19414156675338745, 'dice_class_2': 0.1973593831062317, 'dice_class_3': 0.19856853783130646, 'dice_class_4': 0.200923353433609, 'dice_mean': 0.19723859429359436})


In [63]:
monai_dice = smm.compute_dice(y_pred_one_hot, y_true_one_hot)
md_a = sm.dice_coefficient(y_pred_one_hot, y_true_one_hot)
md_b = SM.dice_coefficient(y_pred_logits, masks_2d)
compare_metrics(monai_dice, md_a)
compare_metrics(monai_dice, md_b)

Both y_pred and y_true are already one-hot encoded.
else
Metrics have differences.
------------------------------
Mean A: 0.1992408037185669
Mean B: 0.1992429792881012
------------------------------
DIFF - dice_class_0
	a: 0.19837766885757446
	b: 0.1983867734670639
DIFF - dice_class_1
	a: 0.19877730309963226
	b: 0.19877557456493378
DIFF - dice_class_2
	a: 0.2004932165145874
	b: 0.20049138367176056
DIFF - dice_class_3
	a: 0.1992664635181427
	b: 0.1992613524198532
DIFF - dice_class_4
	a: 0.19928932189941406
	b: 0.19929976761341095
DIFF - dice_mean
	a: 0.1992408037185669
	b: 0.1992429792881012
------------------------------
Verification completed
Metrics have differences.
------------------------------
Mean A: 0.1992408037185669
Mean B: 0.20105645060539246
------------------------------
DIFF - dice_class_0
	a: 0.19837766885757446
	b: 0.2005341649055481
DIFF - dice_class_1
	a: 0.19877730309963226
	b: 0.20105457305908203
DIFF - dice_class_2
	a: 0.2004932165145874
	b: 0.20378108322620392
DIF

In [64]:
def test_2():
    batch_size = 4
    num_classes = 5
    height, width = 256, 256
    
    # Salida del modelo (logits)
    y_pred_logits = torch.randn(batch_size, num_classes, height, width)
    print(y_pred_logits.shape)
    
    num_classes = 5  # Número de clases
    masks_2d = torch.randint(0, num_classes, (4, 256, 256))  # Simulamos una máscara 2D (4 imágenes, 256x256)
    print(masks_2d.shape)

    # Convert to one hot
    y_pred_one_hot, y_true_one_hot = SM.convert_to_one_hot(y_pred_logits, masks_2d)

    monai_dice = smm.compute_dice(y_pred_one_hot, y_true_one_hot)
    md_a = sm.dice_coefficient(y_pred_one_hot, y_true_one_hot)
    md_b = SM.dice_coefficient(y_pred_logits, masks_2d)
    compare_metrics(monai_dice, md_a)
    compare_metrics(monai_dice, md_b)

In [65]:
test_2()

torch.Size([4, 5, 256, 256])
torch.Size([4, 256, 256])
Both y_pred and y_true are already one-hot encoded.
else
Metrics have differences.
------------------------------
Mean A: 0.20026692748069763
Mean B: 0.2002691924571991
------------------------------
DIFF - dice_class_0
	a: 0.1961025595664978
	b: 0.19610051810741425
DIFF - dice_class_1
	a: 0.20210085809230804
	b: 0.20210175216197968
DIFF - dice_class_2
	a: 0.1993359625339508
	b: 0.1993325650691986
DIFF - dice_class_3
	a: 0.20173121988773346
	b: 0.20173537731170654
DIFF - dice_class_4
	a: 0.20206411182880402
	b: 0.2020758092403412
DIFF - dice_mean
	a: 0.20026692748069763
	b: 0.2002691924571991
------------------------------
Verification completed
Both metrics are identical.
------------------------------
Mean A: 0.20026692748069763
Mean B: 0.20026692748069763
------------------------------
EQ - dice_class_0
EQ - dice_class_1
EQ - dice_class_2
EQ - dice_class_3
EQ - dice_class_4
EQ - dice_mean
------------------------------
Verificat

In [66]:
def test_3():
    batch_size = 4
    num_classes = 5
    height, width = 256, 256
    
    # Salida del modelo (logits)
    y_pred_logits = torch.randn(batch_size, num_classes, height, width)
    
    num_classes = 5  # Número de clases
    masks_2d = torch.randint(0, num_classes, (4, 256, 256))  # Simulamos una máscara 2D (4 imágenes, 256x256)

    # Convert to one hot
    _, y_true_one_hot = SM.convert_to_one_hot(y_pred_logits, masks_2d)

    y_true_one_hot_2 = to_one_hot(masks_2d, num_classes)

    print("Equals:", (y_true_one_hot.all()==y_true_one_hot_2.all()).item())
    print(f"class: {y_true_one_hot.shape}")
    print(f"to 1H: {y_true_one_hot_2.shape}")

In [67]:
test_3()

Equals: True
class: torch.Size([4, 5, 256, 256])
to 1H: torch.Size([4, 5, 256, 256])


In [68]:
def test_perfect_worst():
    num_classes = 5  # Número de clases
    y_true_3d_perfect_i = torch.randint(0, num_classes, (1, 64, 128, 128))  # Simulamos una máscara 3D (1 imagen, 128x128x64)
    # print(y_true_3d_perfect_i.shape)
    y_true_3d_perfect = to_one_hot(y_true_3d_perfect_i, num_classes)
    # print(y_true_3d_perfect.shape)  # Salida esperada: [1, 5, 64, 128, 128]
    y_pred_3d_perfect = y_true_3d_perfect.clone()
    # print(y_pred_3d_perfect.shape)
    perfect_3d_dice = SM.dice_coefficient(y_pred_3d_perfect, y_true_3d_perfect)
    print('Perfect case:')
    print(perfect_3d_dice)
    y_pred_3d_worst = (y_true_3d_perfect_i + 1) % num_classes
    # print(y_pred_3d_worst.shape)
    y_pred_3d_worst = to_one_hot(y_pred_3d_worst, num_classes)
    worst_3d_dice = SM.dice_coefficient(y_pred_3d_worst, y_true_3d_perfect)
    print('Worst case')
    print(worst_3d_dice)

    print('-'*30)
    perfect_3d_dice = sm.dice_coefficient(y_pred_3d_perfect, y_true_3d_perfect)
    print(perfect_3d_dice)
    worst_3d_dice = sm.dice_coefficient(y_pred_3d_worst, y_true_3d_perfect)
    print(worst_3d_dice)

test_perfect_worst()

Perfect case:
(tensor(1.), {'dice_class_0': 1.0, 'dice_class_1': 1.0, 'dice_class_2': 1.0, 'dice_class_3': 1.0, 'dice_class_4': 1.0, 'dice_mean': 1.0})
Worst case
(tensor(0.), {'dice_class_0': 0.0, 'dice_class_1': 0.0, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.0})
------------------------------
else
(tensor(1.), {'dice_class_0': 1.0, 'dice_class_1': 1.0, 'dice_class_2': 1.0, 'dice_class_3': 1.0, 'dice_class_4': 1.0, 'dice_mean': 1.0})
else
(tensor(2.3842e-12), {'dice_class_0': 2.3862248982320367e-12, 'dice_class_1': 2.3887727733373776e-12, 'dice_class_2': 2.38605966582095e-12, 'dice_class_3': 2.378387200727139e-12, 'dice_class_4': 2.3815138229521526e-12, 'dice_mean': 2.384191802318192e-12})
