# Imports y config

In [1]:
%load_ext autoreload
%autoreload 2

In [16]:
import os
import numpy as np
import torch
from monai.losses import FocalLoss as MONAIFocalLoss

In [5]:
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

# `y_true` e `y_pred`

In [28]:
# Datos de prueba
batch_size, num_classes, height, width = 2, 4, 128, 128

# Simulamos predicciones (logits antes de softmax)
y_pred = torch.randn(batch_size, num_classes, height, width, requires_grad=True)

# Simulamos etiquetas verdaderas (enteros de 0 a num_classes-1)
y_true = torch.randint(0, num_classes, (batch_size, height, width))

# FocalLoss

In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import FocalLoss as MONAIFocalLoss

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, reduction='mean', include_background=False):
        """
        Implementación de Focal Loss basada en MONAI.

        Parámetros:
        -----------
        gamma : float
            Factor de modulación que controla el enfoque en muestras difíciles.
        reduction : str
            Método de reducción ('mean', 'sum' o 'none').
        """
        super(FocalLoss, self).__init__()
        self.monai_focal_loss = MONAIFocalLoss(gamma=gamma, reduction=reduction, include_background=include_background)

    def forward(self, y_pred, y_true):
        """
        Calcula la pérdida focal usando la implementación de MONAI.

        Parámetros:
        -----------
        y_pred : tensor de forma (B, C, H, W)
            Predicciones del modelo (logits, NO softmax aplicado).
        y_true : tensor de forma (B, H, W)
            Etiquetas verdaderas con valores en [0, C-1] (NO one-hot).

        Retorna:
        --------
        Tensor escalar con la pérdida focal.
        """
        num_classes = y_pred.shape[1]

        # Convertir y_true a one-hot si es necesario
        if y_true.dim() == 3:  # (B, H, W) → (B, C, H, W)
            y_true = F.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()

        return self.monai_focal_loss(y_pred, y_true)

# Testing

In [32]:
# Instanciamos ambas pérdidas con gamma=2.0
custom_focal_loss = FocalLoss(gamma=2.0, reduction='mean')
monai_focal_loss = MONAIFocalLoss(gamma=2.0)

# Calculamos la pérdida con la implementación propia
loss_custom = custom_focal_loss(y_pred, y_true).item()

# MONAI requiere one-hot encoding de y_true
y_true_one_hot = F.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()
loss_monai = monai_focal_loss(y_pred, y_true_one_hot).item()

# Comparación de resultados
print(f"Pérdida FocalLoss Implementación Propia: {loss_custom:.6f}")
print(f"Pérdida FocalLoss MONAI: {loss_monai:.6f}")
print(f"¿Son iguales? {'Sí' if abs(loss_custom - loss_monai) < 1e-6 else 'No'}")

Pérdida FocalLoss Implementación Propia: 0.311796
Pérdida FocalLoss MONAI: 0.347751
¿Son iguales? No


In [52]:
# Crear datos de prueba
batch_size, num_classes, height, width = 2, 4, 128, 128
y_pred = torch.randn(batch_size, num_classes, height, width, requires_grad=True)
y_true = torch.randint(0, num_classes, (batch_size, height, width))

# Instanciar ambas pérdidas
custom_focal_loss = FocalLoss(gamma=2.0, reduction='mean')
monai_focal_loss = MONAIFocalLoss(gamma=2.0, reduction='mean')

# Calcular pérdidas
loss_custom = custom_focal_loss(y_pred, y_true).item()
loss_monai = monai_focal_loss(y_pred, F.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()).item()

print(f"Pérdida FocalLoss Implementación Propia: {loss_custom:.6f}")
print(f"Pérdida FocalLoss MONAI: {loss_monai:.6f}")
print(f"¿Son iguales? {'Sí' if abs(loss_custom - loss_monai) < 1e-6 else 'No'}")


Pérdida FocalLoss Implementación Propia: 0.348555
Pérdida FocalLoss MONAI: 0.348555
¿Son iguales? Sí


In [55]:
focal_with_bg = FocalLoss(gamma=2.0, include_background=True)
focal_no_bg = FocalLoss(gamma=2.0, include_background=False)

loss_with_bg = focal_with_bg(y_pred, y_true).item()
loss_no_bg = focal_no_bg(y_pred, y_true).item()

print(f"Pérdida con fondo: {loss_with_bg:.6f}")
print(f"Pérdida sin fondo: {loss_no_bg:.6f}")

Pérdida con fondo: 0.348555
Pérdida sin fondo: 0.348450


In [56]:
unique, counts = torch.unique(y_true, return_counts=True)
print("Distribución de clases en y_true:", dict(zip(unique.tolist(), counts.tolist())))


Distribución de clases en y_true: {0: 8165, 1: 8313, 2: 8232, 3: 8058}
