In [1]:
import torch
from src.metrics.segmentation import SegmentationMetrics as SM
import monai.metrics as mm

In [2]:
batch_size = 4
num_classes = 5
height, width = 256, 256

masks_2d = torch.randint(0, num_classes, (batch_size, height, width))
print(masks_2d.shape)

# (logits)
y_pred_logits = torch.randn(batch_size, num_classes, height, width).float()
print(y_pred_logits.shape)

torch.Size([4, 256, 256])
torch.Size([4, 5, 256, 256])


In [3]:
dice = SM.dice_coefficient(y_pred_logits, masks_2d)
print(dice)

(tensor(0.2007), {'dice_class_0': 0.20134757459163666, 'dice_class_1': 0.19990158081054688, 'dice_class_2': 0.20083345472812653, 'dice_class_3': 0.20156258344650269, 'dice_class_4': 0.19991423189640045, 'dice_mean': 0.20071187615394592})


In [4]:
dice_monai, _ = mm.DiceHelper(include_background=True, sigmoid=True, softmax=True, reduction='none')(y_pred_logits, masks_2d)
print(dice_monai)

tensor([[3.8419, 3.7979, 3.9360, 3.7039, 3.8813],
        [4.1168, 3.6410, 3.7574, 3.6230, 3.9416],
        [3.8023, 3.7547, 3.9631, 4.0147, 3.5757],
        [3.9806, 3.7606, 3.8083, 3.7821, 4.1642]])


In [6]:
def to_one_hot(masks, num_classes):
    """
    Convierte una máscara (2D o 3D) a un formato one-hot.

    Args:
        masks (torch.Tensor): Máscara de etiquetas de tamaño [batch_size, ..., height, width] o [batch_size, ..., depth, height, width].
        num_classes (int): Número de clases a convertir.

    Returns:
        torch.Tensor: Máscara en formato one-hot de tamaño [batch_size, num_classes, ..., height, width] o [batch_size, num_classes, ..., depth, height, width].
    """
    # Verifica si la máscara es 2D o 3D
    if masks.dim() == 3:  # Caso 2D (batch_size, height, width)
        batch_size, height, width = masks.shape
        # Reshape para agregar la dimensión de clases
        masks_one_hot = torch.zeros(batch_size, num_classes, height, width, device=masks.device)
        # Convierte los índices en one-hot
        masks_one_hot.scatter_(1, masks.unsqueeze(1), 1)
    
    elif masks.dim() == 4:  # Caso 3D (batch_size, depth, height, width)
        batch_size, depth, height, width = masks.shape
        # Reshape para agregar la dimensión de clases
        masks_one_hot = torch.zeros(batch_size, num_classes, depth, height, width, device=masks.device)
        # Convierte los índices en one-hot
        masks_one_hot.scatter_(1, masks.unsqueeze(1), 1)

    return masks_one_hot

In [7]:
y_pred = torch.argmax(y_pred_logits, dim=1, keepdim=True)
y_pred_one_hot = torch.zeros(batch_size, num_classes, height, width)
y_pred_one_hot.scatter_(1, y_pred, 1)
y_true_one_hot = to_one_hot(masks_2d, num_classes)
dice_monai = mm.DiceMetric(reduction='none')(y_pred_one_hot, y_true_one_hot)
print(dice_monai.mean(dim=0))
print(dice_monai.mean(dim=0).mean())

tensor([0.2013, 0.1999, 0.2008, 0.2016, 0.1999])
tensor(0.2007)


In [8]:
import monai.transforms as MT
def monai_to_one_hot(y, N):
    ad = MT.AsDiscrete(argmax=True, to_onehot=N)
    y_oh = ad(y)
    return y_oh

In [9]:
y_pred_oh_monai = monai_to_one_hot(y_pred_logits, num_classes)

In [10]:
print((y_pred_one_hot.all()==y_pred_oh_monai.all()).item())

True


In [13]:
import torch.nn.functional as F
def test_conversion():
    batch_size = 4
    num_classes = 5
    height, width = 256, 256
    
    y_true = torch.randint(0, num_classes, (batch_size, height, width))
    print("y_true:", y_true.shape)
    
    # (logits)
    y_pred = torch.randn(batch_size, num_classes, height, width).float()
    print("y_pred:", y_pred.shape)

    y_pred_softmax = F.softmax(y_pred, dim=1)
    y_pred_class_soft = torch.argmax(y_pred_softmax, dim=1, keepdim=True) # Argmax to softmax y_pred (probs)
    y_pred_class = torch.argmax(y_pred, dim=1, keepdim=True) # Argmax to logits

    print(f"softmax: \n\t{y_pred_class_soft.shape}\n\t{y_pred_class_soft.unique()}")
    print(f"logits: \n\t{y_pred_class.shape}\n\t{y_pred_class.unique()}")

    print('-'*35)
    print('EQUALITY TEST')
    print(y_pred_class==y_pred_class_soft)
    print('Equals:', (y_pred_class.all()==y_pred_class_soft.all()).item())

In [14]:
test_conversion()

y_true: torch.Size([4, 256, 256])
y_pred: torch.Size([4, 5, 256, 256])
softmax: 
	torch.Size([4, 1, 256, 256])
	tensor([0, 1, 2, 3, 4])
logits: 
	torch.Size([4, 1, 256, 256])
	tensor([0, 1, 2, 3, 4])
-----------------------------------
EQUALITY TEST
tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]]],


        [[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]]],


        [[[True, True, True,  ..., True, True, True],
          

In [19]:
import time
def time_test():
    B, D, H, W, C = 4, 64, 256, 256, 5  # 5 clases
    # Datos en GPU si está disponible
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    y_true = torch.randint(0, C, (B, D, H, W), device=device)
    
    # Benchmark F.one_hot
    start = time.time()
    y_one_hot_f = torch.nn.functional.one_hot(y_true, C).permute(0, 4, 1, 2, 3)
    print(f"F.one_hot time: {time.time() - start:.6f} sec")
    
    # Benchmark scatter_
    start = time.time()
    y_one_hot_s = torch.zeros((B, C, D, H, W), device=device)
    y_one_hot_s.scatter_(1, y_true.unsqueeze(1), 1)
    print(f"scatter_ time: {time.time() - start:.6f} sec")

In [26]:
time_test()

F.one_hot time: 0.101663 sec
scatter_ time: 0.059778 sec
