In [1]:
%load_ext autoreload
%autoreload 2

In [38]:
import os
import numpy as np
import torch
from src.metrics.segmentation import SegmentationMetrics as MetricsA
from src.metrics.segmentation_bak import SegmentationMetrics as MetricsB
from monai.metrics import DiceMetric

In [3]:
os.getcwd()

'C:\\Users\\Usuario\\TFG\\digipanca\\notebooks\\preprocessing'

In [4]:
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')
os.getcwd()

'C:\\Users\\Usuario\\TFG\\digipanca'

In [21]:
def multiclass_dice_coefficient(pred, target, num_classes, threshold=0.5, smooth=1e-5):
    """
    Compute the Dice coefficient for multi-class segmentation, excluding the background class if specified.

    Args:
        pred (torch.Tensor): The predicted tensor (N, C, H, W), where C is the number of classes.
        target (torch.Tensor): The ground truth tensor (N, H, W) with class labels.
        num_classes (int): The number of classes.
        threshold (float): The threshold for converting predictions to binary (for each class).
        smooth (float): Smoothing factor to avoid division by zero.
        ignore_background (bool): If True, ignore the background class (usually class 0) in the calculation.

    Returns:
        dice_scores (dict): A dictionary containing Dice scores for each class.
        avg_dice (float): The average Dice score across all classes (excluding background if specified).
    """
    dice_scores = {}
    avg_dice = 0.0
    valid_class_count = 0
    pred = torch.softmax(pred, dim=1)

    for class_idx in range(1, num_classes):

        pred_class = pred[:, class_idx, :, :]
        target_class = (target == class_idx).float()

        # Flatten the tensors for calculation
        pred_flat = pred_class.reshape(-1)
        target_flat = target_class.reshape(-1)

        # Calculate the intersection and union
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

        # Store the Dice score for the current class
        dice_scores[f"Class_{class_idx}"] = dice.item()

        # Accumulate for the average Dice score
        avg_dice += dice.item()
        valid_class_count += 1

     # Calculate the average Dice score across all valid classes
    if valid_class_count > 0:
        avg_dice /= valid_class_count

    return dice_scores, avg_dice

In [33]:
# Parámetros
batch_size = 4
num_classes = 5
height = 256
width = 256

# Generar y_true (Ground Truth) como índices de clase
y_true = torch.randint(0, num_classes, (batch_size, height, width))

# Generar y_pred (Predicciones) como logits
y_pred = torch.rand((batch_size, num_classes, height, width))

# Convertir y_true a one-hot encoding (opcional, si quieres probar con one-hot encoding)
y_true_one_hot = torch.nn.functional.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()

In [26]:
y_true = torch.randint(0, 2, (2, 5, 256, 256)).float()
y_pred = torch.rand((2, 5, 256, 256))

In [34]:
dice_a, _ = MetricsA.dice_coefficient(y_pred, y_true)
dice_b = MetricsB.dice_coefficient(y_pred, y_true_one_hot)
_, dice_c = multiclass_dice_coefficient(y_pred, y_true, 5) 

In [35]:
print(dice_a)
print(dice_b)
print(dice_c)

tensor(0.2000)
tensor(0.2857)
0.1999007798731327


In [55]:
def test_metrics(times=10):
    # Parámetros
    batch_size = 4
    num_classes = 5
    height = 512
    width = 512

    mm = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    for i in range(1, times+1):
        # Generar y_true (Ground Truth) como índices de clase
        # y_true = torch.randint(0, num_classes, (batch_size, height, width))
        # Generar y_pred (Predicciones) como logits
        # y_pred = torch.rand((batch_size, num_classes, height, width))

        # y_true = torch.ones((0, num_classes, (batch_size, height, width)), dtype=torch.int16)
        y_pred = torch.ones((batch_size, num_classes, height, width), dtype=torch.int16)
        # Convertir y_true a one-hot encoding (opcional, si quieres probar con one-hot encoding)
        y_true_one_hot = torch.nn.functional.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()
    
        dice_a, _ = MetricsA.dice_coefficient(y_pred, y_true)
        dice_b = MetricsB.dice_coefficient(y_pred, y_true_one_hot)
        _, dice_c = multiclass_dice_coefficient(y_pred, y_true, 5)
        dice_m = mm(y_pred, y_true)
    
        print("---------------")
        print(f"test {i}")
        print(f"Dice A: {dice_a.float()}")
        print(f"Dice B: {dice_b.float()}")
        print(f"Dice C: {dice_c}")
        print(f"Dice M: {dice_m}")
        print("---------------")
        mm.reset()

In [56]:
test_metrics()

RuntimeError: "softmax_kernel_impl" not implemented for 'Short'