# Imports y config

In [1]:
# Automatically reload edited project modules (e.g. src.losses) before each cell runs,
# so changes to the loss implementations are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Work from the repository root, presumably so the `from src...` imports below resolve.
# Set the DIGIPANCA_ROOT environment variable to run this notebook on another machine;
# the default is the original author's local checkout.
PROJECT_ROOT = os.environ.get("DIGIPANCA_ROOT", 'C:\\Users\\Usuario\\TFG\\digipanca\\')
os.chdir(PROJECT_ROOT)

# __Dice Loss__

In [4]:
from monai.losses import DiceLoss as MONAIDiceLoss
from src.losses import MulticlassDiceLoss

In [5]:
# ======================
# Data
# ======================

SEED = 0
torch.manual_seed(SEED)  # reproducible random inputs across Restart & Run All

batch_size = 4
n_classes = 5
height, width = 512, 512

# Random logits (NOT discrete labels), as a network would output: (B, C, H, W)
y_pred = torch.randn(batch_size, n_classes, height, width)
# Ground truth as integer class labels: (B, H, W)
y_true = torch.randint(0, n_classes, (batch_size, height, width))

# Loss instances. MONAI's include_background=False drops channel 0; the local
# ignore_background=True is assumed to do the same — TODO confirm in src.losses.
my_dice_loss = MulticlassDiceLoss(ignore_background=True)
monai_dice_loss = MONAIDiceLoss(include_background=False, reduction="mean", softmax=True)

# Compute losses. MONAI needs a one-hot target shaped like the prediction:
# (B, H, W) -> (B, H, W, C) -> (B, C, H, W) via movedim(-1, 1).
loss_my = my_dice_loss(y_pred, y_true)
loss_monai = monai_dice_loss(y_pred, F.one_hot(y_true, num_classes=n_classes).movedim(-1, 1).float())

# Show results
print(f"Dice Loss (Mi implementación): {loss_my.item():.6f}")
print(f"Dice Loss (MONAI): {loss_monai.item():.6f}")

# Both implementations should agree up to float32 rounding (observed gap ~1e-6).
assert abs(loss_my.item() - loss_monai.item()) < 1e-4, "Dice implementations disagree"

Dice Loss (Mi implementación): 0.799942
Dice Loss (MONAI): 0.799943


## __DiceFocalLoss on 2D and 3D inputs__

In [9]:
from src.losses.dice import DiceFocalLoss

In [19]:
torch.manual_seed(0)  # reproducible random inputs

focal_n_classes = 3  # channels in the logits; labels are drawn from [0, focal_n_classes)
n_slices = 32        # depth of the simulated 3D volume
h, w = 512, 512

# 2D data (Batch, Channels, Height, Width)
y_pred_2d = torch.randn(2, focal_n_classes, h, w)  # 2D logits
y_true_2d = torch.randint(0, focal_n_classes, (2, h, w))  # 2D labels

loss_fn = DiceFocalLoss()
loss_2d = loss_fn(y_pred_2d, y_true_2d)

print(f"Loss 2D: {loss_2d.item()}")

# 3D data (Batch, Channels, Depth, Height, Width).
# NOTE: at 512x512x32 the float32 logits alone take ~200 MB; shrink h/w for a quick smoke test.
y_pred_3d = torch.randn(2, focal_n_classes, n_slices, h, w)  # 3D logits
y_true_3d = torch.randint(0, focal_n_classes, (2, n_slices, h, w))  # 3D labels

loss_3d = loss_fn(y_pred_3d, y_true_3d)

print(f"Loss 3D: {loss_3d.item()}")

# The same loss object must accept both ranks and produce a finite value.
assert bool(torch.isfinite(loss_2d)) and bool(torch.isfinite(loss_3d)), "DiceFocalLoss returned NaN/Inf"

Loss 2D: 1.0128384828567505
Loss 3D: 1.013000726699829


In [12]:
import torch
import torch.nn.functional as F

# Simulate a tensor of class-index labels (before one-hot encoding)
torch.manual_seed(0)  # reproducible labels -> reproducible voxel counts below

batch_size, depth, height, width = 2, 16, 128, 128  # 3D example
num_classes = 4  # assume 4 classes

# Batch of class labels (not one-hot; values in [0, num_classes - 1]): (B, D, H, W)
y_true = torch.randint(0, num_classes, (batch_size, depth, height, width))

# One-hot encode, then move the class axis right after the batch axis:
# (B, D, H, W) -> (B, D, H, W, C) -> (B, C, D, H, W).
# movedim(-1, 1) equals permute(0, 4, 1, 2, 3) here but works for inputs of any rank.
y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).movedim(-1, 1).float()

print(f"Forma original: {y_true.shape}")  # (B, D, H, W)
print(f"Forma one-hot: {y_true_one_hot.shape}")  # (B, C, D, H, W) -> must include C

# Enforce the layout instead of eyeballing the printed shapes.
assert y_true_one_hot.shape == (batch_size, num_classes, depth, height, width)
# Round-trip: argmax over the class axis must recover the original labels.
assert torch.equal(y_true_one_hot.argmax(dim=1), y_true)


Forma original: torch.Size([2, 16, 128, 128])
Forma one-hot: torch.Size([2, 4, 16, 128, 128])


In [13]:
# A valid one-hot tensor contains only 0s and 1s.
unique_values = torch.unique(y_true_one_hot)
print("Valores únicos en el tensor convertido:", unique_values)
assert set(unique_values.tolist()) <= {0.0, 1.0}, "one-hot tensor contains values other than 0 and 1"

Valores únicos en el tensor convertido: tensor([0., 1.])


In [15]:
# Per-class voxel counts must match between the one-hot encoding and the original labels.
for class_idx in range(num_classes):
    n_one_hot = (y_true_one_hot[:, class_idx] == 1).sum().item()
    n_original = (y_true == class_idx).sum().item()
    print(f"Voxeles de clase {class_idx} en one-hot:", n_one_hot, "vs original:", n_original)
    assert n_one_hot == n_original, f"class {class_idx}: one-hot count {n_one_hot} != label count {n_original}"

Voxeles de clase 0 en one-hot: 131531 vs original: 131531
Voxeles de clase 1 en one-hot: 131152 vs original: 131152
Voxeles de clase 2 en one-hot: 130723 vs original: 130723
Voxeles de clase 3 en one-hot: 130882 vs original: 130882
