In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

In [3]:
import torch
import random
import numpy as np

# __Functions__

## Convert to one-hot

In [4]:
def convert_to_one_hot(y_pred, y_true):
    def is_one_hot(tensor):
        """
        Check if the tensor is one-hot encoded.
        """
        return (tensor.sum(dim=1) == 1).all() and \
               torch.all((tensor == 0) | (tensor == 1))
    
    # Check if the input is already one-hot encoded
    if is_one_hot(y_pred) and is_one_hot(y_true):
        return y_pred, y_true
    
    # Check if the input is 2D or 3D
    if y_pred.dim() == 4 and y_true.dim() == 3: # 2D case
        B, C, H, W = y_pred.shape
        n_classes = C

        # Convert y_pred to one-hot encoding
        y_pred_classes = torch.argmax(y_pred, dim=1, keepdim=True)
        y_pred_one_hot = torch.zeros(B, n_classes, H, W, device=y_pred.device)
        y_pred_one_hot.scatter_(1, y_pred_classes, 1)

        # Convert y_true to one-hot encoding
        y_true_one_hot = torch.zeros(B, n_classes, H, W, device=y_true.device)
        y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)

        return y_pred_one_hot, y_true_one_hot
    
    elif y_pred.dim() == 5 and y_true.dim() == 4:   # 3D case
        B, C, D, H, W = y_pred.shape
        n_classes = C

        # Convert y_pred to one-hot encoding
        y_pred_classes = torch.argmax(y_pred, dim=1, keepdim=True)
        y_pred_one_hot = torch.zeros(B, n_classes, D, H, W, device=y_pred.device)
        y_pred_one_hot.scatter_(1, y_pred_classes, 1)

        # Convert y_true to one-hot encoding
        y_true_one_hot = torch.zeros(B, n_classes, D, H, W, device=y_true.device)
        y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)

        return y_pred_one_hot, y_true_one_hot

    else:
        raise ValueError("Input tensors must be either 2D or 3D.")

## Dice with mean (`segmentation.py`)

In [5]:
def dice_with_mean(y_pred, y_true, smooth=1e-12):
    # Convert to one-hot if inputs are class indices or logits
    y_pred_one_hot, y_true_one_hot = convert_to_one_hot(y_pred, y_true)
    C = y_true_one_hot.size(1)
    sum_dims = tuple(range(2, y_true_one_hot.ndim))
    
    # Compute intersection and union
    intersection = torch.sum(y_pred_one_hot * y_true_one_hot, dim=sum_dims)
    union = torch.sum(y_pred_one_hot, dim=sum_dims) \
          + torch.sum(y_true_one_hot, dim=sum_dims)

    # Compute Dice score
    dice_scores = (2. * intersection + smooth) / (union + smooth)
    dice_scores = dice_scores.mean(dim=0)

    dice_dict = {f"dice_class_{i}": dice_scores[i].item() for i in range(C)}
    dice_dict["dice_mean"] = dice_scores.mean().item()
    
    return dice_scores.mean(), dice_dict

## Dice with loop (`segmentation_bak.py` with `segmentation.py`'s convert to one-hot method)

In [6]:
def dice_with_loop(y_pred, y_true, smooth=1e-12):
    # Convert to one-hot if inputs are class indices or logits
    y_pred_one_hot, y_true_one_hot = convert_to_one_hot(y_pred, y_true)
    n_classes = y_true_one_hot.size(1)

    # Calculate dice for each class
    dice_scores = []
    class_dice = {}
    
    for i in range(n_classes):
        pred_class = y_pred_one_hot[:, i, :, :]
        true_class = y_true_one_hot[:, i, :, :]
        
        intersection = torch.sum(pred_class * true_class)
        union = torch.sum(pred_class) + torch.sum(true_class)
        dice = (2.0 * intersection + smooth) / (union + smooth)
        dice_scores.append(dice)
        class_dice[f"dice_class_{i}"] = dice.item()
    
    mean_dice = torch.mean(torch.stack(dice_scores))
    class_dice["dice_mean"] = mean_dice.item()
    
    return mean_dice, class_dice

## New method: sum

In [7]:
def dice_with_sum(y_pred, y_true, smooth=1e-12):
    """
    Compute Dice coefficient using global sums over the batch to handle class imbalance.
    
    Parameters
    ----------
    y_pred : torch.Tensor
        Predicted segmentation mask (class indices, logits or one-hot)
    y_true : torch.Tensor
        Ground truth segmentation mask (class indices, logits or one-hot)
    smooth : float, optional
        Smoothing factor to avoid division by zero
    
    Returns
    -------
    tuple
        (mean_dice, per_class_dice_dict)
    """
    # Convert to one-hot
    y_pred_one_hot, y_true_one_hot = convert_to_one_hot(y_pred, y_true)
    C = y_true_one_hot.size(1)
    sum_dims = tuple(range(2, y_true_one_hot.ndim))  # dimensions other than (B, C)

    # Compute intersection and union
    intersection = torch.sum(y_pred_one_hot * y_true_one_hot, dim=sum_dims)  # (B, C)
    union = torch.sum(y_pred_one_hot, dim=sum_dims) + torch.sum(y_true_one_hot, dim=sum_dims)  # (B, C)

    # Sum over batch for global metrics
    intersection = intersection.sum(dim=0)  # (C,)
    union = union.sum(dim=0)  # (C,)

    dice_scores = 2. * intersection / (union + smooth)  # (C,)

    dice_dict = {f"dice_class_{i}": dice_scores[i].item() for i in range(C)}
    dice_dict["dice_mean"] = dice_scores.mean().item()

    return dice_scores.mean(), dice_dict

# __Test__

In [8]:
from src.models.custom_deeplabv3 import CustomDeepLabV3 as dl

In [9]:
B, num_classes, H, W = 4, 5, 256, 256
model = dl(num_classes=num_classes, dropout_rate=0.2).eval()

## Generate prediction

In [10]:
channels = 1
input_tensor = torch.randn(B, channels, H, W)
print(input_tensor.shape)
with torch.no_grad():
    output = model(input_tensor)['out']
print(output.shape)

torch.Size([4, 1, 256, 256])
torch.Size([4, 5, 256, 256])


## Generate mask

In [11]:
random_mask = torch.randint(0, num_classes, (B, H, W))
print(random_mask.shape)

torch.Size([4, 256, 256])


## Compute scores

In [12]:
y_pred = output.clone()
y_true = random_mask.clone()
print(y_pred.shape)
print(y_true.shape)

torch.Size([4, 5, 256, 256])
torch.Size([4, 256, 256])


In [13]:
dice_mean = dice_with_mean(y_pred, y_true)
dice_loop = dice_with_loop(y_pred, y_true)

In [14]:
print("EQUALS:", (dice_mean==dice_loop))
print('-'*40)
print(dice_mean)
print('-'*40)
print(dice_loop)

EQUALS: False
----------------------------------------
(tensor(0.0670), {'dice_class_0': 7.62340438964058e-17, 'dice_class_1': 0.33489900827407837, 'dice_class_2': 7.665002310029137e-17, 'dice_class_3': 7.657760840074602e-17, 'dice_class_4': 7.61560043686951e-17, 'dice_mean': 0.0669798031449318})
----------------------------------------
(tensor(0.0670), {'dice_class_0': 1.9058147014631927e-17, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 1.9162227842387024e-17, 'dice_class_3': 1.9143885939484273e-17, 'dice_class_4': 1.9038190455173472e-17, 'dice_mean': 0.06698023527860641})


### VS sum version

In [15]:
dice_sum = dice_with_sum(y_pred, y_true)

In [16]:
print('_'*40)
print("mean == sum:", (dice_mean==dice_sum))
print('-'*40)
print(dice_mean)
print('-'*40)
print(dice_sum)
print('_'*40)
print("loop == sum:", (dice_loop==dice_sum))
print('-'*40)
print(dice_loop)
print('-'*40)
print(dice_sum)

________________________________________
mean == sum: False
----------------------------------------
(tensor(0.0670), {'dice_class_0': 7.62340438964058e-17, 'dice_class_1': 0.33489900827407837, 'dice_class_2': 7.665002310029137e-17, 'dice_class_3': 7.657760840074602e-17, 'dice_class_4': 7.61560043686951e-17, 'dice_mean': 0.0669798031449318})
----------------------------------------
(tensor(0.0670), {'dice_class_0': 0.0, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.06698023527860641})
________________________________________
loop == sum: False
----------------------------------------
(tensor(0.0670), {'dice_class_0': 1.9058147014631927e-17, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 1.9162227842387024e-17, 'dice_class_3': 1.9143885939484273e-17, 'dice_class_4': 1.9038190455173472e-17, 'dice_mean': 0.06698023527860641})
----------------------------------------
(tensor(0.0670), {'dice_class_0': 0.0, 'dice_clas

### VS previous implementation, using softmax (`segmentation_bak.py`)

In [17]:
def dice_bak(y_pred, y_true, smooth=1e-12):
    # Convert to one-hot if inputs are class indices
    if y_pred.dim() == 3:
        # Convert predicted class indices to one-hot
        n_classes = torch.max(y_true).item() + 1
        y_pred_one_hot = torch.zeros(
            y_pred.size(0), n_classes, y_pred.size(1), y_pred.size(2), 
            device=y_pred.device
        )
        y_pred_one_hot.scatter_(1, y_pred.unsqueeze(1), 1)
        
        y_true_one_hot = torch.zeros(
            y_true.size(0), n_classes, y_true.size(1), y_true.size(2), 
            device=y_true.device
        )
        y_true_one_hot.scatter_(1, y_true.unsqueeze(1), 1)
        # print("if")
    else:
        # If already in form [B, C, H, W] (logits or one-hot)
        if y_pred.dim() == 4 and y_true.dim() == 3:
            # y_pred is [B, C, H, W] logits and y_true is [B, H, W] indices
            n_classes = y_pred.size(1)
            y_pred_one_hot = torch.nn.functional.softmax(y_pred, dim=1)
            
            y_true_one_hot = torch.zeros(
                y_true.size(0), n_classes, y_true.size(1), y_true.size(2), 
                device=y_true.device
            )
            y_true_one_hot.scatter_(1, y_true.unsqueeze(1), 1)
            # print("else if")
        else:
            # Assume both are already in proper format
            y_pred_one_hot = y_pred
            y_true_one_hot = y_true
            n_classes = y_pred.size(1)
            # print("else")
    
    # Calculate dice for each class
    dice_scores = []
    class_dice = {}
    
    for i in range(n_classes):
        pred_class = y_pred_one_hot[:, i, :, :]
        true_class = y_true_one_hot[:, i, :, :]
        
        intersection = torch.sum(pred_class * true_class)
        union = torch.sum(pred_class) + torch.sum(true_class)
        dice = (2.0 * intersection + smooth) / (union + smooth)
        dice_scores.append(dice)
        class_dice[f"dice_class_{i}"] = dice.item()
    
    mean_dice = torch.mean(torch.stack(dice_scores))
    class_dice["dice_mean"] = mean_dice.item()
    
    return mean_dice, class_dice

In [18]:
dice_prev = dice_bak(y_pred, y_true)

In [19]:
print('_'*40)
print("mean == prev:", (dice_mean==dice_prev))
print('-'*40)
print(dice_mean)
print('-'*40)
print(dice_prev)
print('_'*40)
print("loop == prev:", (dice_loop==dice_prev))
print('-'*40)
print(dice_loop)
print('-'*40)
print(dice_prev)
print('_'*40)
print("sum == prev:", (dice_sum==dice_prev))
print('-'*40)
print(dice_sum)
print('-'*40)
print(dice_prev)

________________________________________
mean == prev: False
----------------------------------------
(tensor(0.0670), {'dice_class_0': 7.62340438964058e-17, 'dice_class_1': 0.33489900827407837, 'dice_class_2': 7.665002310029137e-17, 'dice_class_3': 7.657760840074602e-17, 'dice_class_4': 7.61560043686951e-17, 'dice_mean': 0.0669798031449318})
----------------------------------------
(tensor(0.1984), {'dice_class_0': 0.20519721508026123, 'dice_class_1': 0.22346408665180206, 'dice_class_2': 0.18394768238067627, 'dice_class_3': 0.17030583322048187, 'dice_class_4': 0.20903021097183228, 'dice_mean': 0.19838900864124298})
________________________________________
loop == prev: False
----------------------------------------
(tensor(0.0670), {'dice_class_0': 1.9058147014631927e-17, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 1.9162227842387024e-17, 'dice_class_3': 1.9143885939484273e-17, 'dice_class_4': 1.9038190455173472e-17, 'dice_mean': 0.06698023527860641})
------------------------

# __VS MONAI__

In [20]:
from monai.metrics import DiceMetric

In [21]:
DMM = DiceMetric(
    include_background=True,  # o False, según tu caso
    reduction="mean"         # puede ser "mean", "sum", "none"
)
DMM.reset()

In [22]:
# Simula o convierte tus predicciones y ground truth
y_pred_onehot, y_true_onehot = convert_to_one_hot(y_pred, y_true)

In [23]:
monai_dice = DMM(y_pred_onehot, y_true_onehot)
print(monai_dice.shape)
print(DMM.aggregate())
DMM.reset()

torch.Size([4, 5])
tensor([0.0670])


In [24]:
dice_sum2 = dice_with_sum(y_pred_onehot, y_true_onehot)
print(dice_sum2)

(tensor(0.0670), {'dice_class_0': 0.0, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.06698023527860641})


In [25]:
dice_mean2 = dice_with_mean(y_pred_onehot, y_true_onehot)
print(dice_mean2)

(tensor(0.0670), {'dice_class_0': 7.62340438964058e-17, 'dice_class_1': 0.33489900827407837, 'dice_class_2': 7.665002310029137e-17, 'dice_class_3': 7.657760840074602e-17, 'dice_class_4': 7.61560043686951e-17, 'dice_mean': 0.0669798031449318})


In [26]:
print(monai_dice.mean(0))

tensor([0.0000, 0.3349, 0.0000, 0.0000, 0.0000])


In [27]:
import torch
from monai.metrics import DiceMetric

def compare_dice_with_monai(y_pred, y_true, include_background=False, reduction="mean", smooth=1e-12):
    """
    Compara tu función `dice_with_sum` con la implementación de MONAI.

    Parámetros
    ----------
    y_pred : torch.Tensor
        Predicciones (logits, class indices, o one-hot).
    y_true : torch.Tensor
        Ground truth (class indices o one-hot).
    include_background : bool
        Si incluir la clase de fondo (índice 0) en el cálculo.
    reduction : str
        'mean', 'sum', o 'none'.
    smooth : float
        Valor de suavizado para tu función.

    Returns
    -------
    dict
        Diccionario con ambos resultados y diferencia.
    """
    
    # Primero: tu función
    my_dice, my_dice_dict = dice_with_sum(y_pred, y_true, smooth)

    # Ahora: MONAI
    from monai.metrics.utils import get_mask_edges
    from monai.networks import one_hot

    # Asegurarse de que están en formato one-hot y float
    def to_one_hot(tensor, num_classes):
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(1)  # (B, H, W) → (B, 1, H, W)
        elif tensor.dim() == 4 and tensor.size(1) != 1:
            tensor = tensor.unsqueeze(1)  # (B, D, H, W) → (B, 1, D, H, W)
        return one_hot(tensor.long(), num_classes=num_classes).float()


    # Determinar número de clases
    if y_pred.dim() >= 4 and y_pred.size(1) > 1:
        n_classes = y_pred.size(1)
        y_pred_classes = torch.argmax(y_pred, dim=1)
    else:
        y_pred_classes = y_pred
        n_classes = int(torch.max(torch.cat([y_pred, y_true])) + 1)

    y_pred_monai = to_one_hot(y_pred_classes, n_classes)
    y_true_monai = to_one_hot(y_true, n_classes)

    # MONAI DiceMetric
    monai_dice_metric = DiceMetric(
        include_background=include_background,
        reduction=reduction,
        get_not_nans=False
    )
    monai_dice = monai_dice_metric(y_pred_monai, y_true_monai)

    if isinstance(monai_dice, torch.Tensor):
        monai_dice_val = monai_dice.mean().item() if monai_dice.numel() > 1 else monai_dice.item()
    else:
        monai_dice_val = float(monai_dice)

    return {
        "your_dice": my_dice.item(),
        "monai_dice": monai_dice_val,
        "difference": abs(my_dice.item() - monai_dice_val),
        "your_per_class": my_dice_dict
    }


In [28]:
# y_pred y y_true pueden ser logits, class indices, o one-hot
result = compare_dice_with_monai(
    y_pred, 
    y_true, 
    include_background=False, 
    reduction="sum"  # o "mean", según quieras comparar
)

print(result)


{'your_dice': 0.06698023527860641, 'monai_dice': 0.08372475206851959, 'difference': 0.016744516789913177, 'your_per_class': {'dice_class_0': 0.0, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.06698023527860641}}


# __Test 2__

In [29]:
import torch
from monai.metrics import DiceMetric

# MONAI (por defecto, similar a dice_with_mean)
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric(y_pred_onehot, y_true_onehot)
monai_dice = dice_metric.aggregate().item()
print(f"MONAI (por imagen): {monai_dice}")

# MONAI (suma global, similar a dice_with_sum)
dice_metric_global = DiceMetric(include_background=True, reduction="sum_batch")
dice_metric_global(y_pred_onehot, y_true_onehot)
monai_dice_global = dice_metric_global.aggregate().mean().item()
print(f"MONAI (global): {monai_dice_global}")

# Tu implementación
dice_mean, _ = dice_with_mean(y_pred, y_true)
dice_sum, _ = dice_with_sum(y_pred, y_true)
print(f"dice_with_mean: {dice_mean.item()}")
print(f"dice_with_sum: {dice_sum.item()}")

MONAI (por imagen): 0.0669798031449318
MONAI (global): 0.2679192125797272
dice_with_mean: 0.0669798031449318
dice_with_sum: 0.06698023527860641


In [31]:
def dice_with_sum_monai_style(y_pred_one_hot, y_true_one_hot, smooth=1e-12):
    intersection = torch.sum(y_pred_one_hot * y_true_one_hot, dim=(2, 3))  # Suma sobre H, W
    union = torch.sum(y_pred_one_hot, dim=(2, 3)) + torch.sum(y_true_one_hot, dim=(2, 3))
    
    # Suma global sobre el batch (como MONAI con reduction="mean_batch")
    intersection = intersection.sum(dim=0)  # [C]
    union = union.sum(dim=0)  # [C]
    
    dice_scores = (2. * intersection + smooth) / (union + smooth)  # [C]
    return dice_scores

# Calcula con tus datos one-hot
your_dice = dice_with_sum_monai_style(y_pred_onehot, y_true_onehot)
print("Tu Dice (ajustado):", your_dice)  # Debería coincidir con MONAI

Tu Dice (ajustado): tensor([1.9058e-17, 3.3490e-01, 1.9162e-17, 1.9144e-17, 1.9038e-17])


# __Using MONAI__

In [32]:
def test_sum():
    # MONAI
    dice_metric = DiceMetric(include_background=True, reduction="mean_batch")
    monai_full = dice_metric(y_pred_onehot, y_true_onehot)
    monai_dice = dice_metric.aggregate()
    print(f"MONAI (all): {monai_full}")
    print(f"MONAI (sum_batch): {monai_dice}")
    print(f"MONAI (mean sum_batch): {monai_dice.mean().item()}")
    
    # dice_with_sum
    own = dice_with_sum(y_pred_onehot, y_true_onehot)
    print(f"SUM (sum_batch): {own[1]}")
    print(f"SUM (mean sum_batch): {own[0].item()}")

In [33]:
test_sum()

MONAI (all): tensor([[0.0000, 0.3366, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3325, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3335, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3369, 0.0000, 0.0000, 0.0000]])
MONAI (sum_batch): tensor([0.0000, 0.3349, 0.0000, 0.0000, 0.0000])
MONAI (mean sum_batch): 0.0669798031449318
SUM (sum_batch): {'dice_class_0': 0.0, 'dice_class_1': 0.33490118384361267, 'dice_class_2': 0.0, 'dice_class_3': 0.0, 'dice_class_4': 0.0, 'dice_mean': 0.06698023527860641}
SUM (mean sum_batch): 0.06698023527860641


# __Another test__

In [None]:
import torch
from monai.metrics import DiceMetric
def dice_coeff(y_pred, y_true, smooth=1e-5):
    n = y_true.size(0)
    print(n)
    pred_flat = y_pred.view(n, -1)
    gt_flat = y_true.view(n, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2. * intersection + smooth) / (unionset + smooth)

    return loss.sum() / n, loss

In [None]:
dice_sum = dice_with_sum(y_pred_onehot, y_true_onehot)
dice_mean = dice_with_mean(y_pred_onehot, y_true_onehot)
DMM = DiceMetric(include_background=True, reduction="none")
dice_monai = DMM(y_pred_onehot, y_true_onehot)
dice_new, dt = dice_coeff(y_pred_onehot, y_true_onehot)

In [None]:
print(dice_sum)
print(dice_mean)
print(dice_monai)
print(dice_new)