# _Imports & config_

In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import os

In [3]:
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

# __Class__

In [27]:
import torch
import torch.nn.functional as F
from monai.metrics import DiceMetric, MeanIoU, ConfusionMatrixMetric
from monai.transforms import AsDiscrete

class SegmentationMetrics:
    """
    Unified segmentation metrics for both 2D and 3D pancreas segmentation, using MONAI.
    """
    def __init__(self):
        self.dice_metric = DiceMetric(include_background=True, reduction="none")
        self.iou_metric = MeanIoU(include_background=True, reduction="none")
        self.confusion_matrix = ConfusionMatrixMetric(metric_name=["precision", "recall"], reduction="none")
        self.post_pred = AsDiscrete(argmax=True, to_onehot=None)  # Convert logits to one-hot
        self.post_true = AsDiscrete(to_onehot=None)  # Convert labels to one-hot

    def compute_metrics(self, y_pred, y_true):
        num_classes = y_pred.shape[1]
    
        # Update one-hot transform for the number of classes
        self.post_pred.to_onehot = num_classes
        self.post_true.to_onehot = num_classes
    
        # Convert logits to one-hot predictions
        y_pred_one_hot = self.post_pred(y_pred)
    
        # Verificar si la forma de y_true tiene un canal de tamaño 1
        if y_true.ndim == y_pred.ndim - 1:  
            # Asegúrate de que y_true tenga la forma correcta (B, 1, H, W) para 2D
            y_true = y_true.unsqueeze(1)  # Añadir el canal si falta
    
        print("y_true shape after unsqueeze:", y_true.shape)  # Depuración de la forma
    
        # Convertir a one-hot
        y_true_one_hot = self.post_true(y_true)
    
        # Compute Dice score per class
        dice_scores = self.dice_metric(y_pred_one_hot, y_true_one_hot).cpu().numpy()
        mean_dice = dice_scores.mean()
    
        # Compute IoU score per class
        iou_scores = self.iou_metric(y_pred_one_hot, y_true_one_hot).cpu().numpy()
        mean_iou = iou_scores.mean()
    
        # Compute Precision and Recall per class
        precision, recall = self.confusion_matrix(y_pred_one_hot, y_true_one_hot)
        precision = precision.cpu().numpy()
        recall = recall.cpu().numpy()
        mean_precision = precision.mean()
        mean_recall = recall.mean()
    
        # Format results
        metrics = {
            "dice_mean": mean_dice,
            "iou_mean": mean_iou,
            "precision_mean": mean_precision,
            "recall_mean": mean_recall,
        }
    
        for i in range(num_classes):
            metrics[f"dice_class_{i}"] = dice_scores[i]
            metrics[f"iou_class_{i}"] = iou_scores[i]
            metrics[f"precision_class_{i}"] = precision[i]
            metrics[f"recall_class_{i}"] = recall[i]
    
        return metrics
    
    # def compute_metrics(self, y_pred, y_true):
    #     """
    #     Compute Dice, IoU, Precision, and Recall scores for segmentation.
        
    #     Parameters
    #     ----------
    #     y_pred : torch.Tensor
    #         Predicted logits of shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
    #     y_true : torch.Tensor
    #         Ground truth labels of shape (B, H, W) for 2D or (B, D, H, W) for 3D.
        
    #     Returns
    #     -------
    #     dict
    #         Dictionary with Dice, IoU, Precision, and Recall scores per class and mean.
    #     """
    #     num_classes = y_pred.shape[1]
        
    #     # Update the one-hot transform for the number of classes
    #     self.post_pred.to_onehot = num_classes
    #     self.post_true.to_onehot = num_classes
        
    #     # Convert logits to one-hot predictions
    #     y_pred_one_hot = self.post_pred(y_pred)
    #     print(y_true.shape)
    #     # y_true = y_true = y_true.unsqueeze(1) if y_true.ndim == y_pred.ndim - 1 else y_true
    #     if y_true.ndim == y_pred.ndim - 1:  
    #         y_true = y_true.unsqueeze(1)
    #     print(y_true.shape)
    #     y_true_one_hot = self.post_true(y_true)
        
    #     # Compute Dice score per class
    #     dice_scores = self.dice_metric(y_pred_one_hot, y_true_one_hot).cpu().numpy()
    #     mean_dice = dice_scores.mean()
        
    #     # Compute IoU score per class
    #     iou_scores = self.iou_metric(y_pred_one_hot, y_true_one_hot).cpu().numpy()
    #     mean_iou = iou_scores.mean()
        
    #     # Compute Precision and Recall per class
    #     precision, recall = self.confusion_matrix(y_pred_one_hot, y_true_one_hot)
    #     precision = precision.cpu().numpy()
    #     recall = recall.cpu().numpy()
    #     mean_precision = precision.mean()
    #     mean_recall = recall.mean()
        
    #     # Format results
    #     metrics = {
    #         "dice_mean": mean_dice,
    #         "iou_mean": mean_iou,
    #         "precision_mean": mean_precision,
    #         "recall_mean": mean_recall,
    #     }
        
    #     for i in range(num_classes):
    #         metrics[f"dice_class_{i}"] = dice_scores[i]
    #         metrics[f"iou_class_{i}"] = iou_scores[i]
    #         metrics[f"precision_class_{i}"] = precision[i]
    #         metrics[f"recall_class_{i}"] = recall[i]
        
    #     return metrics


# __Test__

In [28]:
import numpy as np
def test_segmentation_metrics():
    metric_calculator = SegmentationMetrics()
    
    # Test 2D case
    batch_size, num_classes, height, width = 4, 5, 256, 256
    y_pred_2d = torch.randn(batch_size, num_classes, height, width)  # Logits
    y_true_2d = torch.randint(0, num_classes, (batch_size, 1, height, width))  # Ground truth with channel dim
    
    print("Testing 2D segmentation metrics...")
    metrics_2d = metric_calculator.compute_metrics(y_pred_2d, y_true_2d)
    print(metrics_2d)
    
    # Test 3D case
    batch_size, num_classes, depth, height, width = 1, 5, 64, 128, 128
    y_pred_3d = torch.randn(batch_size, num_classes, depth, height, width)  # Logits
    y_true_3d = torch.randint(0, num_classes, (batch_size, 1, depth, height, width))  # Ground truth with channel dim
    
    print("\nTesting 3D segmentation metrics...")
    metrics_3d = metric_calculator.compute_metrics(y_pred_3d, y_true_3d)
    print(metrics_3d)

In [29]:
test_segmentation_metrics()

Testing 2D segmentation metrics...
y_true shape after unsqueeze: torch.Size([4, 1, 256, 256])


AssertionError: labels should have a channel with length equal to one.

# __Class 2__

In [40]:
import torch

class SegmentationMetrics:
    """
    Class for computing segmentation metrics that works for both 2D and 3D data.
    """
    
    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """
        Helper method to prepare tensors for metric computation.
        Handles both 2D and 3D cases, and converts to one-hot if needed.
        """
        # Determine if we're working with 2D or 3D data
        is_3d = y_pred.dim() == 4 and (y_true.dim() == 4 or y_true.dim() == 5)  # Adjusted condition
        
        # Convert to one-hot if inputs are class indices
        if y_pred.dim() == (4 if is_3d else 3):  # [B, H, W(, D)] class indices
            n_classes = torch.max(y_true).item() + 1
            shape = (y_pred.size(0), n_classes) + y_pred.size()[1:]
            y_pred_one_hot = torch.zeros(shape, device=y_pred.device)
            y_pred_one_hot.scatter_(1, y_pred.unsqueeze(1).long(), 1)  # Added .long()
            
            y_true_one_hot = torch.zeros(shape, device=y_true.device)
            y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)  # Added .long()
        else:
            # If already in form [B, C, H, W(, D)] (logits or one-hot)
            if y_pred.dim() == (5 if is_3d else 4) and y_true.dim() == (4 if is_3d else 3):
                # y_pred is logits and y_true is class indices
                n_classes = y_pred.size(1)
                y_pred_one_hot = torch.nn.functional.softmax(y_pred, dim=1)
                
                shape = (y_true.size(0), n_classes) + y_true.size()[1:]
                y_true_one_hot = torch.zeros(shape, device=y_true.device)
                y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)  # Added .long()
            else:
                # Assume both are already in proper format
                y_pred_one_hot = y_pred
                y_true_one_hot = y_true
                n_classes = y_pred.size(1)
        
        return y_pred_one_hot, y_true_one_hot, n_classes
    
    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-6):
        """
        Compute Dice coefficient for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate dice for each class
        dice_scores = []
        class_dice = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class)
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
            class_dice[f"dice_class_{i}"] = dice.item()
        
        mean_dice = torch.mean(torch.stack(dice_scores))
        class_dice["dice_mean"] = mean_dice.item()
        
        return mean_dice, class_dice
    
    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-6):
        """
        Compute IoU (Jaccard Index) for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate IoU for each class
        iou_scores = []
        class_iou = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class) - intersection
            iou = (intersection + smooth) / (union + smooth)
            iou_scores.append(iou)
            class_iou[f"iou_class_{i}"] = iou.item()
        
        mean_iou = torch.mean(torch.stack(iou_scores))
        class_iou["iou_mean"] = mean_iou.item()
        
        return mean_iou, class_iou
    
    @staticmethod
    def precision_recall(y_pred, y_true, smooth=1e-6):
        """
        Compute precision and recall for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate precision and recall for each class
        precision_scores = []
        recall_scores = []
        class_precision = {}
        class_recall = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            true_positives = torch.sum(pred_class * true_class)
            predicted_positives = torch.sum(pred_class)
            actual_positives = torch.sum(true_class)
            
            precision = (true_positives + smooth) / (predicted_positives + smooth)
            recall = (true_positives + smooth) / (actual_positives + smooth)
            
            precision_scores.append(precision)
            recall_scores.append(recall)
            
            class_precision[f"precision_class_{i}"] = precision.item()
            class_recall[f"recall_class_{i}"] = recall.item()
        
        mean_precision = torch.mean(torch.stack(precision_scores))
        mean_recall = torch.mean(torch.stack(recall_scores))
        
        class_precision["precision_mean"] = mean_precision.item()
        class_recall["recall_mean"] = mean_recall.item()
        
        return mean_precision, mean_recall, class_precision, class_recall
    
    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all metrics for 2D or 3D data.
        """
        metrics = {}
        
        # Convert logits to class indices if necessary
        if y_pred.dim() in [4, 5]:  # Could be [B, C, H, W] or [B, C, H, W, D]
            y_pred_indices = torch.argmax(y_pred, dim=1)
        else:
            y_pred_indices = y_pred
            
        # Calculate all metrics
        mean_dice, class_dice = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        metrics.update(class_dice)
        
        mean_iou, class_iou = SegmentationMetrics.iou_score(y_pred, y_true)
        metrics.update(class_iou)
        
        mean_precision, mean_recall, class_precision, class_recall = SegmentationMetrics.precision_recall(y_pred, y_true)
        metrics.update(class_precision)
        metrics.update(class_recall)
        
        # Add overall metrics
        metrics['dice'] = mean_dice.item()
        metrics['iou'] = mean_iou.item()
        metrics['precision'] = mean_precision.item()
        metrics['recall'] = mean_recall.item()
        
        return metrics

# __Test 2__

In [54]:
import torch
import numpy as np
from torch import Tensor

def test_metrics():
    # Configuración inicial
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nTesting on: {device}")
    
    # 1. Test 2D case (compatibility check)
    print("\n=== Testing 2D case ===")
    batch_size, classes, h, w = 2, 3, 128, 128
    
    # Create random predictions and ground truth
    y_true_2d = torch.randint(0, classes, (batch_size, h, w)).to(device)
    y_pred_logits_2d = torch.randn(batch_size, classes, h, w).to(device)
    y_pred_indices_2d = torch.argmax(y_pred_logits_2d, dim=1)
    
    # Test with logits
    print("\nTesting with logits (2D):")
    metrics_logits = SegmentationMetrics.all_metrics(y_pred_logits_2d, y_true_2d)
    for k, v in metrics_logits.items():
        print(f"{k}: {v:.4f}")
    
    # Test with class indices
    print("\nTesting with class indices (2D):")
    metrics_indices = SegmentationMetrics.all_metrics(y_pred_indices_2d, y_true_2d)
    for k, v in metrics_indices.items():
        print(f"{k}: {v:.4f}")
    
    # 2. Test 3D case
    print("\n=== Testing 3D case ===")
    batch_size, classes, h, w, d = 2, 3, 64, 64, 32
    
    # Create random predictions and ground truth
    y_true_3d = torch.randint(0, classes, (batch_size, h, w, d)).to(device)
    y_pred_logits_3d = torch.randn(batch_size, classes, h, w, d).to(device)
    y_pred_indices_3d = torch.argmax(y_pred_logits_3d, dim=1)
    
    # Test with logits
    print("\nTesting with logits (3D):")
    metrics_logits_3d = SegmentationMetrics.all_metrics(y_pred_logits_3d, y_true_3d)
    for k, v in metrics_logits_3d.items():
        print(f"{k}: {v:.4f}")
    
    # Test with class indices
    print("\nTesting with class indices (3D):")
    metrics_indices_3d = SegmentationMetrics.all_metrics(y_pred_indices_3d, y_true_3d)
    for k, v in metrics_indices_3d.items():
        print(f"{k}: {v:.4f}")
    
    # 3. Test edge cases
    print("\n=== Testing edge cases ===")
    
    # Perfect prediction test
    print("\nPerfect prediction test (3D):")
    perfect_pred = y_true_3d.clone()
    perfect_metrics = SegmentationMetrics.all_metrics(perfect_pred, y_true_3d)
    print(f"Dice: {perfect_metrics['dice']:.4f} (should be 1.0)")
    print(f"IoU: {perfect_metrics['iou']:.4f} (should be 1.0)")
    
    # Worst prediction test (no overlap)
    print("\nWorst prediction test (3D):")
    worst_pred = (y_true_3d + 1) % classes  # Guarantees no overlap
    worst_metrics = SegmentationMetrics.all_metrics(worst_pred, y_true_3d)
    print(f"Dice: {worst_metrics['dice']:.4f} (should be ~0.0)")
    print(f"IoU: {worst_metrics['iou']:.4f} (should be ~0.0)")
    
    # Empty prediction test
    print("\nEmpty prediction test (3D):")
    empty_pred = torch.zeros_like(y_true_3d)
    empty_metrics = SegmentationMetrics.all_metrics(empty_pred, y_true_3d)
    print(f"Dice: {empty_metrics['dice_class_0']:.4f} (background)")
    print(f"IoU: {empty_metrics['iou_class_0']:.4f} (background)")

def test_metric_consistency():
    """Test consistency between 2D and 3D implementations by using 3D with depth=1"""
    print("\n=== Testing 2D-3D consistency ===")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size, classes, h, w = 2, 3, 128, 128
    d = 1  # Treat as 2D
    
    # Create identical test cases in 2D and "3D with depth=1"
    y_true_2d = torch.randint(0, classes, (batch_size, h, w)).to(device)
    y_pred_2d = torch.randn(batch_size, classes, h, w).to(device)
    
    y_true_3d = y_true_2d.unsqueeze(-1)  # [B, H, W, 1]
    y_pred_3d = y_pred_2d.unsqueeze(-1)  # [B, C, H, W, 1]
    
    # Compute metrics
    metrics_2d = SegmentationMetrics.all_metrics(y_pred_2d, y_true_2d)
    metrics_3d = SegmentationMetrics.all_metrics(y_pred_3d, y_true_3d)
    
    # Compare results
    print("\nComparing Dice scores:")
    for i in range(classes):
        key = f"dice_class_{i}"
        print(f"Class {i}: 2D={metrics_2d[key]:.6f} | 3D={metrics_3d[key]:.6f} | Diff={abs(metrics_2d[key]-metrics_3d[key]):.2e}")
    
    print("\nComparing IoU scores:")
    for i in range(classes):
        key = f"iou_class_{i}"
        print(f"Class {i}: 2D={metrics_2d[key]:.6f} | 3D={metrics_3d[key]:.6f} | Diff={abs(metrics_2d[key]-metrics_3d[key]):.2e}")

In [55]:
test_metrics()
test_metric_consistency()


Testing on: cpu

=== Testing 2D case ===

Testing with logits (2D):
dice_class_0: 0.3313
dice_class_1: 0.3347
dice_class_2: 0.3330
dice_mean: 0.3330
iou_class_0: 0.1985
iou_class_1: 0.2010
iou_class_2: 0.1998
iou_mean: 0.1998
precision_class_0: 0.3294
precision_class_1: 0.3355
precision_class_2: 0.3341
precision_mean: 0.3330
recall_class_0: 0.3331
recall_class_1: 0.3339
recall_class_2: 0.3320
recall_mean: 0.3330
dice: 0.3330
iou: 0.1998
precision: 0.3330
recall: 0.3330

Testing with class indices (2D):
dice_class_0: 0.3366
dice_class_1: 0.3397
dice_class_2: 0.3369
dice_mean: 0.3378
iou_class_0: 0.2024
iou_class_1: 0.2046
iou_class_2: 0.2026
iou_mean: 0.2032
precision_class_0: 0.3339
precision_class_1: 0.3405
precision_class_2: 0.3390
precision_mean: 0.3378
recall_class_0: 0.3393
recall_class_1: 0.3390
recall_class_2: 0.3349
recall_mean: 0.3378
dice: 0.3378
iou: 0.2032
precision: 0.3378
recall: 0.3378

=== Testing 3D case ===

Testing with logits (3D):
dice_class_0: 0.3337
dice_class_1

# another

In [43]:
import torch

class SegmentationMetrics:
    """
    Class for computing segmentation metrics that works for both 2D and 3D data.
    """
    
    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """
        Helper method to prepare tensors for metric computation.
        Handles both 2D and 3D cases, and converts to one-hot if needed.
        """
        # Determine if we're working with 2D or 3D data
        is_3d = y_pred.dim() == 5 or (y_pred.dim() == 4 and y_true.dim() == 4)  # Adjusted condition
        
        # Convert to one-hot if inputs are class indices
        if y_pred.dim() == (4 if is_3d else 3):  # [B, H, W(, D)] class indices
            n_classes = torch.max(y_true).item() + 1
            shape = (y_pred.size(0), n_classes) + y_pred.size()[1:]
            y_pred_one_hot = torch.zeros(shape, device=y_pred.device)
            y_pred_one_hot.scatter_(1, y_pred.unsqueeze(1).long(), 1)
            
            y_true_one_hot = torch.zeros(shape, device=y_true.device)
            y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)
        else:
            # If already in form [B, C, H, W(, D)] (logits or one-hot)
            if y_pred.dim() == (5 if is_3d else 4) and y_true.dim() == (4 if is_3d else 3):
                # y_pred is logits and y_true is class indices
                n_classes = y_pred.size(1)
                y_pred_one_hot = torch.nn.functional.softmax(y_pred, dim=1)
                
                shape = (y_true.size(0), n_classes) + y_true.size()[1:]
                y_true_one_hot = torch.zeros(shape, device=y_true.device)
                y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)
            else:
                # Assume both are already in proper format
                y_pred_one_hot = y_pred
                y_true_one_hot = y_true
                n_classes = y_pred.size(1)
        
        return y_pred_one_hot, y_true_one_hot, n_classes
    
    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-6):
        """
        Compute Dice coefficient for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate dice for each class
        dice_scores = []
        class_dice = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class)
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
            class_dice[f"dice_class_{i}"] = dice.item()
        
        mean_dice = torch.mean(torch.stack(dice_scores))
        class_dice["dice_mean"] = mean_dice.item()
        
        return mean_dice, class_dice
    
    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-6):
        """
        Compute IoU (Jaccard Index) for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate IoU for each class
        iou_scores = []
        class_iou = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class) - intersection
            iou = (intersection + smooth) / (union + smooth)
            iou_scores.append(iou)
            class_iou[f"iou_class_{i}"] = iou.item()
        
        mean_iou = torch.mean(torch.stack(iou_scores))
        class_iou["iou_mean"] = mean_iou.item()
        
        return mean_iou, class_iou
    
    @staticmethod
    def precision_recall(y_pred, y_true, smooth=1e-6):
        """
        Compute precision and recall for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate precision and recall for each class
        precision_scores = []
        recall_scores = []
        class_precision = {}
        class_recall = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            true_positives = torch.sum(pred_class * true_class)
            predicted_positives = torch.sum(pred_class)
            actual_positives = torch.sum(true_class)
            
            precision = (true_positives + smooth) / (predicted_positives + smooth)
            recall = (true_positives + smooth) / (actual_positives + smooth)
            
            precision_scores.append(precision)
            recall_scores.append(recall)
            
            class_precision[f"precision_class_{i}"] = precision.item()
            class_recall[f"recall_class_{i}"] = recall.item()
        
        mean_precision = torch.mean(torch.stack(precision_scores))
        mean_recall = torch.mean(torch.stack(recall_scores))
        
        class_precision["precision_mean"] = mean_precision.item()
        class_recall["recall_mean"] = mean_recall.item()
        
        return mean_precision, mean_recall, class_precision, class_recall
    
    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all metrics for 2D or 3D data.
        """
        metrics = {}
        
        # Convert logits to class indices if necessary
        if y_pred.dim() in [4, 5]:  # Could be [B, C, H, W] or [B, C, H, W, D]
            y_pred_indices = torch.argmax(y_pred, dim=1)
        else:
            y_pred_indices = y_pred
            
        # Calculate all metrics
        mean_dice, class_dice = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        metrics.update(class_dice)
        
        mean_iou, class_iou = SegmentationMetrics.iou_score(y_pred, y_true)
        metrics.update(class_iou)
        
        mean_precision, mean_recall, class_precision, class_recall = SegmentationMetrics.precision_recall(y_pred, y_true)
        metrics.update(class_precision)
        metrics.update(class_recall)
        
        # Add overall metrics
        metrics['dice'] = mean_dice.item()
        metrics['iou'] = mean_iou.item()
        metrics['precision'] = mean_precision.item()
        metrics['recall'] = mean_recall.item()
        
        return metrics

In [44]:
def test_metrics():
    # Configuración inicial
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nTesting on: {device}")
    
    # 1. Test 2D case (compatibility check)
    print("\n=== Testing 2D case ===")
    batch_size, classes, h, w = 2, 3, 128, 128
    
    # Create random predictions and ground truth
    y_true_2d = torch.randint(0, classes, (batch_size, h, w)).to(device)
    y_pred_logits_2d = torch.randn(batch_size, classes, h, w).to(device)
    y_pred_indices_2d = torch.argmax(y_pred_logits_2d, dim=1)
    
    # Test with logits
    print("\nTesting with logits (2D):")
    metrics_logits = SegmentationMetrics.all_metrics(y_pred_logits_2d, y_true_2d)
    for k, v in sorted(metrics_logits.items()):
        print(f"{k}: {v:.4f}")
    
    # Test with class indices
    print("\nTesting with class indices (2D):")
    metrics_indices = SegmentationMetrics.all_metrics(y_pred_indices_2d, y_true_2d)
    for k, v in sorted(metrics_indices.items()):
        print(f"{k}: {v:.4f}")
    
    # 2. Test 3D case
    print("\n=== Testing 3D case ===")
    batch_size, classes, h, w, d = 2, 3, 64, 64, 32
    
    # Create random predictions and ground truth
    y_true_3d = torch.randint(0, classes, (batch_size, h, w, d)).to(device)
    y_pred_logits_3d = torch.randn(batch_size, classes, h, w, d).to(device)
    y_pred_indices_3d = torch.argmax(y_pred_logits_3d, dim=1)
    
    # Test with logits
    print("\nTesting with logits (3D):")
    metrics_logits_3d = SegmentationMetrics.all_metrics(y_pred_logits_3d, y_true_3d)
    for k, v in sorted(metrics_logits_3d.items()):
        print(f"{k}: {v:.4f}")
    
    # Test with class indices
    print("\nTesting with class indices (3D):")
    metrics_indices_3d = SegmentationMetrics.all_metrics(y_pred_indices_3d, y_true_3d)
    for k, v in sorted(metrics_indices_3d.items()):
        print(f"{k}: {v:.4f}")

test_metrics()
print("\nAll tests completed successfully!")


Testing on: cpu

=== Testing 2D case ===

Testing with logits (2D):
dice: 0.3313
dice_class_0: 0.3320
dice_class_1: 0.3326
dice_class_2: 0.3294
dice_mean: 0.3313
iou: 0.1986
iou_class_0: 0.1991
iou_class_1: 0.1995
iou_class_2: 0.1972
iou_mean: 0.1986
precision: 0.3313
precision_class_0: 0.3310
precision_class_1: 0.3334
precision_class_2: 0.3296
precision_mean: 0.3313
recall: 0.3314
recall_class_0: 0.3330
recall_class_1: 0.3319
recall_class_2: 0.3291
recall_mean: 0.3314

Testing with class indices (2D):
dice: 0.3267
dice_class_0: 0.3298
dice_class_1: 0.3285
dice_class_2: 0.3220
dice_mean: 0.3267
iou: 0.1953
iou_class_0: 0.1974
iou_class_1: 0.1965
iou_class_2: 0.1919
iou_mean: 0.1953
precision: 0.3267
precision_class_0: 0.3285
precision_class_1: 0.3291
precision_class_2: 0.3227
precision_mean: 0.3267
recall: 0.3268
recall_class_0: 0.3311
recall_class_1: 0.3279
recall_class_2: 0.3213
recall_mean: 0.3268

=== Testing 3D case ===

Testing with logits (3D):
dice: 0.3335
dice_class_0: 0.3326

# __VS MONAI__

In [48]:
import torch
from monai.metrics import DiceMetric, MeanIoU
from monai.data import decollate_batch
from monai.transforms import AsDiscrete, EnsureType

class MonaiTester:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        # MONAI metrics
        self.dice_metric = DiceMetric(include_background=True, reduction="mean")
        self.iou_metric = MeanIoU(include_background=True, reduction="mean")
        # Post-processing transforms
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)
        self.ensure_type = EnsureType()

    def compute_monai_metrics(self, y_pred, y_true):
        # Add channel dimension to labels if needed
        if y_true.dim() == y_pred.dim() - 1:  # [B, H, W] vs [B, C, H, W]
            y_true = y_true.unsqueeze(1)  # Add channel dim
            
        # Convert to MONAI expected format
        y_pred_ = [self.post_pred(self.ensure_type(i)) for i in decollate_batch(y_pred)]
        y_true_ = [self.post_label(self.ensure_type(i)) for i in decollate_batch(y_true)]
        
        # Compute metrics
        self.dice_metric(y_pred=y_pred_, y=y_true_)
        self.iou_metric(y_pred=y_pred_, y=y_true_)
        
        dice = self.dice_metric.aggregate().item()
        iou = self.iou_metric.aggregate().item()
        
        # Reset for next round
        self.dice_metric.reset()
        self.iou_metric.reset()
        
        return dice, iou

def compare_with_monai():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nComparison Testing on: {device}")
    num_classes = 3
    
    # Initialize testers
    monai_tester = MonaiTester(num_classes=num_classes)
    
    # Test cases
    test_cases = [
        ("2D Random", (2, 128, 128)),
        ("3D Random", (2, 64, 64, 32)),
        ("2D Perfect", (2, 128, 128)),
        ("3D Perfect", (2, 64, 64, 32)),
        ("2D Worst", (2, 128, 128)),
        ("3D Worst", (2, 64, 64, 32))
    ]
    
    for name, shape in test_cases:
        print(f"\n=== {name} Case ===")
        
        # Generate test data
        if "Perfect" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot(y_true, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        elif "Worst" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot((y_true + 1) % num_classes, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        else:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.randn((shape[0], num_classes) + shape[1:]).to(device)
        
        # Compute metrics with both implementations
        # MONAI metrics
        monai_dice, monai_iou = monai_tester.compute_monai_metrics(y_pred, y_true)
        
        # Your metrics
        your_metrics = SegmentationMetrics.all_metrics(y_pred, y_true)
        your_dice = your_metrics['dice']
        your_iou = your_metrics['iou']
        
        # Print comparison
        print(f"Dice Score:")
        print(f"  MONAI: {monai_dice:.6f}")
        print(f"  Yours: {your_dice:.6f}")
        print(f"  Difference: {abs(monai_dice - your_dice):.2e}")
        
        print(f"\nIoU (Jaccard):")
        print(f"  MONAI: {monai_iou:.6f}")
        print(f"  Yours: {your_iou:.6f}")
        print(f"  Difference: {abs(monai_iou - your_iou):.2e}")
        
        # Verify they're almost equal
        assert torch.allclose(torch.tensor(monai_dice), torch.tensor(your_dice), atol=1e-5), \
               f"Dice scores differ too much! {monai_dice} vs {your_dice}"
        assert torch.allclose(torch.tensor(monai_iou), torch.tensor(your_iou), atol=1e-5), \
               f"IoU scores differ too much! {monai_iou} vs {your_iou}"

# First run your original tests
# test_metrics()

# Then run the MONAI comparison
compare_with_monai()

print("\nAll tests passed! Your implementation matches MONAI's results.")


Comparison Testing on: cpu

=== 2D Random Case ===
Dice Score:
  MONAI: 0.333610
  Yours: 0.333237
  Difference: 3.73e-04

IoU (Jaccard):
  MONAI: 0.200209
  Yours: 0.199932
  Difference: 2.76e-04


AssertionError: Dice scores differ too much! 0.33361002802848816 vs 0.33323726058006287

In [49]:
def compare_with_monai():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nComparison Testing on: {device}")
    num_classes = 3
    
    # Initialize testers
    monai_tester = MonaiTester(num_classes=num_classes)
    
    # Test cases
    test_cases = [
        ("2D Random", (2, 128, 128)),
        ("3D Random", (2, 64, 64, 32)),
        ("2D Perfect", (2, 128, 128)),
        ("3D Perfect", (2, 64, 64, 32)),
        ("2D Worst", (2, 128, 128)),
        ("3D Worst", (2, 64, 64, 32))
    ]
    
    for name, shape in test_cases:
        print(f"\n=== {name} Case ===")
        
        # Generate test data
        if "Perfect" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot(y_true, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        elif "Worst" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot((y_true + 1) % num_classes, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        else:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.randn((shape[0], num_classes) + shape[1:]).to(device)
        
        # Compute metrics with both implementations
        # MONAI metrics
        monai_dice, monai_iou = monai_tester.compute_monai_metrics(y_pred, y_true)
        
        # Your metrics
        your_metrics = SegmentationMetrics.all_metrics(y_pred, y_true)
        your_dice = your_metrics['dice']
        your_iou = your_metrics['iou']
        
        # Print comparison
        print(f"Dice Score:")
        print(f"  MONAI: {monai_dice:.6f}")
        print(f"  Yours: {your_dice:.6f}")
        print(f"  Difference: {abs(monai_dice - your_dice):.2e}")
        
        print(f"\nIoU (Jaccard):")
        print(f"  MONAI: {monai_iou:.6f}")
        print(f"  Yours: {your_iou:.6f}")
        print(f"  Difference: {abs(monai_iou - your_iou):.2e}")
        
        # More tolerant verification (1e-3 instead of 1e-5)
        assert abs(monai_dice - your_dice) < 1e-3, \
               f"Dice scores differ too much! {monai_dice} vs {your_dice}"
        assert abs(monai_iou - your_iou) < 1e-3, \
               f"IoU scores differ too much! {monai_iou} vs {your_iou}"


compare_with_monai()
print("\nAll tests passed! Your implementation matches MONAI's results within tolerance.")


Comparison Testing on: cpu

=== 2D Random Case ===
Dice Score:
  MONAI: 0.333076
  Yours: 0.333989
  Difference: 9.13e-04

IoU (Jaccard):
  MONAI: 0.199828
  Yours: 0.200475
  Difference: 6.47e-04

=== 3D Random Case ===
Dice Score:
  MONAI: 0.331340
  Yours: 0.332777
  Difference: 1.44e-03

IoU (Jaccard):
  MONAI: 0.198568
  Yours: 0.199600
  Difference: 1.03e-03


AssertionError: Dice scores differ too much! 0.33133983612060547 vs 0.33277687430381775

In [50]:
def compare_with_monai():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nComparison Testing on: {device}")
    num_classes = 3
    
    # Initialize testers
    monai_tester = MonaiTester(num_classes=num_classes)
    
    # Test cases
    test_cases = [
        ("2D Random", (2, 128, 128)),
        ("3D Random", (2, 64, 64, 32)),
        ("2D Perfect", (2, 128, 128)),
        ("3D Perfect", (2, 64, 64, 32)),
        ("2D Worst", (2, 128, 128)),
        ("3D Worst", (2, 64, 64, 32))
    ]
    
    for name, shape in test_cases:
        print(f"\n=== {name} Case ===")
        
        # Generate test data
        if "Perfect" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot(y_true, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        elif "Worst" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot((y_true + 1) % num_classes, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        else:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.randn((shape[0], num_classes) + shape[1:]).to(device)
        
        # Compute metrics with both implementations
        # MONAI metrics
        monai_dice, monai_iou = monai_tester.compute_monai_metrics(y_pred, y_true)
        
        # Your metrics
        your_metrics = SegmentationMetrics.all_metrics(y_pred, y_true)
        your_dice = your_metrics['dice']
        your_iou = your_metrics['iou']
        
        # Print comparison
        print(f"Dice Score:")
        print(f"  MONAI: {monai_dice:.6f}")
        print(f"  Yours: {your_dice:.6f}")
        print(f"  Difference: {abs(monai_dice - your_dice):.2e}")
        
        print(f"\nIoU (Jaccard):")
        print(f"  MONAI: {monai_iou:.6f}")
        print(f"  Yours: {your_iou:.6f}")
        print(f"  Difference: {abs(monai_iou - your_iou):.2e}")
        
        # More tolerant verification (1e-3 instead of 1e-5)
        assert abs(monai_dice - your_dice) < 1e-3, \
               f"Dice scores differ too much! {monai_dice} vs {your_dice}"
        assert abs(monai_iou - your_iou) < 1e-3, \
               f"IoU scores differ too much! {monai_iou} vs {your_iou}"

compare_with_monai()
print("\nAll tests passed! Your implementation matches MONAI's results within tolerance.")


Comparison Testing on: cpu

=== 2D Random Case ===
Dice Score:
  MONAI: 0.331915
  Yours: 0.332830
  Difference: 9.15e-04

IoU (Jaccard):
  MONAI: 0.198990
  Yours: 0.199639
  Difference: 6.50e-04

=== 3D Random Case ===
Dice Score:
  MONAI: 0.331693
  Yours: 0.332525
  Difference: 8.32e-04

IoU (Jaccard):
  MONAI: 0.198821
  Yours: 0.199419
  Difference: 5.98e-04

=== 2D Perfect Case ===
Dice Score:
  MONAI: 1.000000
  Yours: 0.576109
  Difference: 4.24e-01

IoU (Jaccard):
  MONAI: 1.000000
  Yours: 0.404603
  Difference: 5.95e-01


AssertionError: Dice scores differ too much! 1.0 vs 0.5761087536811829

In [69]:
import torch

class SegmentationMetrics:
    """
    Class for computing segmentation metrics that works for both 2D and 3D data.
    """
    
    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """
        Helper method to prepare tensors for metric computation.
        Handles both 2D and 3D cases, and converts to one-hot if needed.
        """
        # Case 1: y_pred is class indices [B, H, W(, D)]
        if y_pred.dim() == y_true.dim():
            n_classes = torch.max(y_true).item() + 1
            y_pred_one_hot = torch.zeros((y_pred.size(0), n_classes) + y_pred.size()[1:], 
                                      device=y_pred.device)
            y_pred_one_hot.scatter_(1, y_pred.unsqueeze(1).long(), 1)
            
            y_true_one_hot = torch.zeros_like(y_pred_one_hot)
            y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)
        
        # Case 2: y_pred is logits [B, C, H, W(, D)]
        else:
            n_classes = y_pred.size(1)
            y_pred_one_hot = torch.softmax(y_pred, dim=1)
            
            y_true_one_hot = torch.zeros_like(y_pred_one_hot)
            y_true_one_hot.scatter_(1, y_true.unsqueeze(1).long(), 1)
            
        return y_pred_one_hot, y_true_one_hot, n_classes
    
    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-8):
        """
        Compute Dice coefficient for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate dice for each class
        dice_scores = []
        class_dice = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class)
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
            class_dice[f"dice_class_{i}"] = dice.item()
        
        mean_dice = torch.mean(torch.stack(dice_scores))
        class_dice["dice_mean"] = mean_dice.item()
        
        return mean_dice, class_dice
    
    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-8):
        """
        Compute IoU (Jaccard Index) for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate IoU for each class
        iou_scores = []
        class_iou = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            intersection = torch.sum(pred_class * true_class)
            union = torch.sum(pred_class) + torch.sum(true_class) - intersection
            iou = (intersection + smooth) / (union + smooth)
            iou_scores.append(iou)
            class_iou[f"iou_class_{i}"] = iou.item()
        
        mean_iou = torch.mean(torch.stack(iou_scores))
        class_iou["iou_mean"] = mean_iou.item()
        
        return mean_iou, class_iou
    
    @staticmethod
    def precision_recall(y_pred, y_true, smooth=1e-8):
        """
        Compute precision and recall for 2D or 3D data.
        """
        y_pred_one_hot, y_true_one_hot, n_classes = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        # Calculate precision and recall for each class
        precision_scores = []
        recall_scores = []
        class_precision = {}
        class_recall = {}
        
        for i in range(n_classes):
            pred_class = y_pred_one_hot[:, i, ...]
            true_class = y_true_one_hot[:, i, ...]
            
            true_positives = torch.sum(pred_class * true_class)
            predicted_positives = torch.sum(pred_class)
            actual_positives = torch.sum(true_class)
            
            precision = (true_positives + smooth) / (predicted_positives + smooth)
            recall = (true_positives + smooth) / (actual_positives + smooth)
            
            precision_scores.append(precision)
            recall_scores.append(recall)
            
            class_precision[f"precision_class_{i}"] = precision.item()
            class_recall[f"recall_class_{i}"] = recall.item()
        
        mean_precision = torch.mean(torch.stack(precision_scores))
        mean_recall = torch.mean(torch.stack(recall_scores))
        
        class_precision["precision_mean"] = mean_precision.item()
        class_recall["recall_mean"] = mean_recall.item()
        
        return mean_precision, mean_recall, class_precision, class_recall
    
    @staticmethod
    def all_metrics(y_pred, y_true):
        """
        Compute all metrics for 2D or 3D data.
        """
        metrics = {}
        
        # Convert logits to class indices if necessary
        if y_pred.dim() in [4, 5]:  # Could be [B, C, H, W] or [B, C, H, W, D]
            y_pred_indices = torch.argmax(y_pred, dim=1)
        else:
            y_pred_indices = y_pred
            
        # Calculate all metrics
        mean_dice, class_dice = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        metrics.update(class_dice)
        
        mean_iou, class_iou = SegmentationMetrics.iou_score(y_pred, y_true)
        metrics.update(class_iou)
        
        mean_precision, mean_recall, class_precision, class_recall = SegmentationMetrics.precision_recall(y_pred, y_true)
        metrics.update(class_precision)
        metrics.update(class_recall)
        
        # Add overall metrics
        metrics['dice'] = mean_dice.item()
        metrics['iou'] = mean_iou.item()
        metrics['precision'] = mean_precision.item()
        metrics['recall'] = mean_recall.item()
        
        return metrics

In [70]:
def compare_with_monai():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nComparison Testing on: {device}")
    num_classes = 3
    
    # Initialize testers
    monai_tester = MonaiTester(num_classes=num_classes)
    
    # Test cases
    test_cases = [
        # ("2D Random", (2, 128, 128)),
        ("3D Random", (2, 64, 64, 32)),
        # ("2D Perfect", (2, 128, 128)),
        ("3D Perfect", (2, 64, 64, 32)),
        # ("2D Worst", (2, 128, 128)),
        ("3D Worst", (2, 64, 64, 32))
    ]
    
    for name, shape in test_cases:
        print(f"\n=== {name} Case ===")
        
        # Generate test data
        if "Perfect" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot(y_true, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        elif "Worst" in name:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.nn.functional.one_hot((y_true + 1) % num_classes, num_classes).permute(0, -1, *range(1, y_true.dim())).float()
        else:
            y_true = torch.randint(0, num_classes, shape).to(device)
            y_pred = torch.randn((shape[0], num_classes) + shape[1:]).to(device)
        
        # Compute metrics with both implementations
        # MONAI metrics
        monai_dice, monai_iou = monai_tester.compute_monai_metrics(y_pred, y_true)
        
        # Your metrics
        your_metrics = SegmentationMetrics.all_metrics(y_pred, y_true)
        your_dice = your_metrics['dice']
        your_iou = your_metrics['iou']
        
        # Print comparison
        print(f"Dice Score:")
        print(f"  MONAI: {monai_dice:.6f}")
        print(f"  Yours: {your_dice:.6f}")
        print(f"  Difference: {abs(monai_dice - your_dice):.2e}")
        
        print(f"\nIoU (Jaccard):")
        print(f"  MONAI: {monai_iou:.6f}")
        print(f"  Yours: {your_iou:.6f}")
        print(f"  Difference: {abs(monai_iou - your_iou):.2e}")
        
        # More tolerant verification (1e-3 instead of 1e-5)
        assert abs(monai_dice - your_dice) < 1e-3, \
               f"Dice scores differ too much! {monai_dice} vs {your_dice}"
        assert abs(monai_iou - your_iou) < 1e-3, \
               f"IoU scores differ too much! {monai_iou} vs {your_iou}"

compare_with_monai()
print("\nAll tests passed! Your implementation matches MONAI's results within tolerance.")


Comparison Testing on: cpu

=== 3D Random Case ===
Dice Score:
  MONAI: 0.332715
  Yours: 0.332840
  Difference: 1.25e-04

IoU (Jaccard):
  MONAI: 0.199557
  Yours: 0.199645
  Difference: 8.84e-05

=== 3D Perfect Case ===
Dice Score:
  MONAI: 1.000000
  Yours: 0.576117
  Difference: 4.24e-01

IoU (Jaccard):
  MONAI: 1.000000
  Yours: 0.404610
  Difference: 5.95e-01


AssertionError: Dice scores differ too much! 1.0 vs 0.5761168599128723

In [56]:
import torch
import numpy as np

class SegmentationMetrics:
    @staticmethod
    def _ensure_onehot(tensor, num_classes):
        """Convert tensor to one-hot format consistently with MONAI"""
        if tensor.dim() == 4:  # [B, H, W, D] -> [B, C, H, W, D]
            tensor = tensor.unsqueeze(1)
        if tensor.shape[1] != num_classes:
            tensor = torch.nn.functional.one_hot(tensor.long(), num_classes)
            tensor = tensor.permute(0, -1, *range(1, tensor.dim()-1)).float()
        return tensor

    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """Unified tensor preparation matching MONAI's expectations"""
        # Get number of classes from y_true if not provided
        num_classes = y_pred.shape[1] if y_pred.dim() > y_true.dim() else torch.max(y_true).item() + 1
        
        # Convert both tensors to one-hot format
        y_pred_oh = SegmentationMetrics._ensure_onehot(y_pred, num_classes)
        y_true_oh = SegmentationMetrics._ensure_onehot(y_true, num_classes)
        
        return y_pred_oh, y_true_oh, num_classes

    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-5):  # Adjusted smoothing to match MONAI
        y_pred_oh, y_true_oh, _ = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        intersection = torch.sum(y_pred_oh * y_true_oh, dim=(2,3,4))
        union = torch.sum(y_pred_oh, dim=(2,3,4)) + torch.sum(y_true_oh, dim=(2,3,4))
        
        dice = (2. * intersection + smooth) / (union + smooth)
        return torch.mean(dice), {f"dice_class_{i}": dice[:,i].mean().item() 
                                for i in range(dice.shape[1])}

    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-5):
        y_pred_oh, y_true_oh, _ = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        
        intersection = torch.sum(y_pred_oh * y_true_oh, dim=(2,3,4))
        union = torch.sum(y_pred_oh, dim=(2,3,4)) + torch.sum(y_true_oh, dim=(2,3,4)) - intersection
        
        iou = (intersection + smooth) / (union + smooth)
        return torch.mean(iou), {f"iou_class_{i}": iou[:,i].mean().item() 
                               for i in range(iou.shape[1])}

    @staticmethod
    def all_metrics(y_pred, y_true):
        metrics = {}
        mean_dice, dice_classes = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        mean_iou, iou_classes = SegmentationMetrics.iou_score(y_pred, y_true)
        
        metrics.update(dice_classes)
        metrics.update(iou_classes)
        metrics.update({
            'dice': mean_dice.item(),
            'iou': mean_iou.item()
        })
        return metrics

In [60]:
def compare_with_monai():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nComparison Testing on: {device}")
    num_classes = 3
    
    monai_tester = MonaiTester(num_classes=num_classes)
    
    # Test cases - now using proper one-hot format
    test_cases = [
        ("2D Random", (2, num_classes, 128, 128)),
        ("3D Random", (2, num_classes, 64, 64, 32)),
        ("2D Perfect", (2, num_classes, 128, 128)),
        ("3D Perfect", (2, num_classes, 64, 64, 32))
    ]
    
    for name, shape in test_cases:
        print(f"\n=== {name} Case ===")
        
        # Generate proper one-hot test data
        if "Perfect" in name:
            y_true = torch.zeros(shape, device=device)
            y_true[:, 0] = 1  # All background
            y_pred = y_true.clone()
        else:
            y_true = torch.rand(shape, device=device)
            y_true = y_true / y_true.sum(dim=1, keepdim=True)  # Simulate softmax
            y_pred = torch.rand_like(y_true)
        
        # MONAI metrics
        monai_dice, monai_iou = monai_tester.compute_monai_metrics(y_pred, torch.argmax(y_true, dim=1))
        
        # Your metrics
        your_metrics = SegmentationMetrics.all_metrics(y_pred, torch.argmax(y_true, dim=1))
        
        # Print comparisons
        print(f"Dice Score:")
        print(f"  MONAI: {monai_dice:.6f}")
        print(f"  Yours: {your_metrics['dice']:.6f}")
        print(f"  Difference: {abs(monai_dice - your_metrics['dice']):.2e}")
        
        print(f"\nIoU (Jaccard):")
        print(f"  MONAI: {monai_iou:.6f}")
        print(f"  Yours: {your_metrics['iou']:.6f}")
        print(f"  Difference: {abs(monai_iou - your_metrics['iou']):.2e}")
        
        # Verify perfect cases match exactly
        if "Perfect" in name:
            assert abs(monai_dice - 1.0) < 1e-6, "MONAI perfect prediction failed"
            assert abs(your_metrics['dice'] - 1.0) < 1e-6, "Your perfect prediction failed"
        else:
            assert abs(monai_dice - your_metrics['dice']) < 1e-3, "Dice mismatch"
            assert abs(monai_iou - your_metrics['iou']) < 1e-3, "IoU mismatch"


compare_with_monai()
print("\nAll tests passed! Implementation matches MONAI.")


Comparison Testing on: cpu

=== 2D Random Case ===
Dice Score:
  MONAI: 0.333918
  Yours: 0.400248
  Difference: 6.63e-02

IoU (Jaccard):
  MONAI: 0.200429
  Yours: 0.250199
  Difference: 4.98e-02


AssertionError: Dice mismatch

In [59]:
class SegmentationMetrics:
    @staticmethod
    def _ensure_onehot(tensor, num_classes):
        """Convert tensor to one-hot format consistently with MONAI"""
        if tensor.dim() == 4 and tensor.shape[1] != num_classes:  # [B, H, W, D] case
            tensor = tensor.unsqueeze(1)
        if tensor.shape[1] != num_classes:
            tensor = torch.nn.functional.one_hot(tensor.long(), num_classes)
            tensor = tensor.permute(0, -1, *range(1, tensor.dim()-1)).float()
        return tensor

    @staticmethod
    def _prepare_tensors(y_pred, y_true):
        """Unified tensor preparation matching MONAI's expectations"""
        # Get number of classes from y_true if not provided
        num_classes = y_pred.shape[1] if y_pred.dim() > y_true.dim() else torch.max(y_true).item() + 1
        
        # Convert both tensors to one-hot format
        y_pred_oh = SegmentationMetrics._ensure_onehot(y_pred, num_classes)
        y_true_oh = SegmentationMetrics._ensure_onehot(y_true, num_classes)
        
        return y_pred_oh, y_true_oh, num_classes

    @staticmethod
    def _calculate_spatial_dims(tensor):
        """Determine spatial dimensions based on tensor shape"""
        if tensor.dim() == 4:  # 2D case [B, C, H, W]
            return (2, 3)
        elif tensor.dim() == 5:  # 3D case [B, C, H, W, D]
            return (2, 3, 4)
        else:
            raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")

    @staticmethod
    def dice_coefficient(y_pred, y_true, smooth=1e-5):
        y_pred_oh, y_true_oh, _ = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        spatial_dims = SegmentationMetrics._calculate_spatial_dims(y_pred_oh)
        
        intersection = torch.sum(y_pred_oh * y_true_oh, dim=spatial_dims)
        union = torch.sum(y_pred_oh, dim=spatial_dims) + torch.sum(y_true_oh, dim=spatial_dims)
        
        dice = (2. * intersection + smooth) / (union + smooth)
        return torch.mean(dice), {f"dice_class_{i}": dice[:,i].mean().item() 
                                for i in range(dice.shape[1])}

    @staticmethod
    def iou_score(y_pred, y_true, smooth=1e-5):
        y_pred_oh, y_true_oh, _ = SegmentationMetrics._prepare_tensors(y_pred, y_true)
        spatial_dims = SegmentationMetrics._calculate_spatial_dims(y_pred_oh)
        
        intersection = torch.sum(y_pred_oh * y_true_oh, dim=spatial_dims)
        union = torch.sum(y_pred_oh, dim=spatial_dims) + torch.sum(y_true_oh, dim=spatial_dims) - intersection
        
        iou = (intersection + smooth) / (union + smooth)
        return torch.mean(iou), {f"iou_class_{i}": iou[:,i].mean().item() 
                               for i in range(iou.shape[1])}

    @staticmethod
    def all_metrics(y_pred, y_true):
        metrics = {}
        mean_dice, dice_classes = SegmentationMetrics.dice_coefficient(y_pred, y_true)
        mean_iou, iou_classes = SegmentationMetrics.iou_score(y_pred, y_true)
        
        metrics.update(dice_classes)
        metrics.update(iou_classes)
        metrics.update({
            'dice': mean_dice.item(),
            'iou': mean_iou.item()
        })
        return metrics

In [61]:
from monai.metrics import DiceMetric, MeanIoU
from monai.data import decollate_batch
from monai.transforms import AsDiscrete, EnsureType

class SegmentationMetrics:
    def __init__(self, num_classes, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.num_classes = num_classes
        self.device = device
        
        # MONAI metrics setup
        self.dice_metric = DiceMetric(include_background=True, reduction="mean")
        self.iou_metric = MeanIoU(include_background=True, reduction="mean")
        
        # Post-processing transforms
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)
        self.ensure_type = EnsureType(device=device)

    def _prepare_monai_inputs(self, y_pred, y_true):
        """Convert inputs to MONAI expected format"""
        y_pred = self.ensure_type(y_pred)
        y_true = self.ensure_type(y_true)
        
        # Add channel dim if needed (for class indices)
        if y_true.ndim == y_pred.ndim - 1:
            y_true = y_true.unsqueeze(1)
            
        return [self.post_pred(i) for i in decollate_batch(y_pred)], \
               [self.post_label(i) for i in decollate_batch(y_true)]

    def compute_metrics(self, y_pred, y_true):
        """Main method that matches your original API"""
        y_pred_pp, y_true_pp = self._prepare_monai_inputs(y_pred, y_true)
        
        # Compute metrics
        self.dice_metric(y_pred=y_pred_pp, y=y_true_pp)
        self.iou_metric(y_pred=y_pred_pp, y=y_true_pp)
        
        dice = self.dice_metric.aggregate().item()
        iou = self.iou_metric.aggregate().item()
        
        # Reset for next batch
        self.dice_metric.reset()
        self.iou_metric.reset()
        
        return {
            'dice': dice,
            'iou': iou,
            'dice_mean': dice,
            'iou_mean': iou
        }

# Usage example:
if __name__ == "__main__":
    # Initialize
    metrics = SegmentationMetrics(num_classes=3)
    
    # Test case - batch of 2 128x128 2D images
    y_true = torch.randint(0, 3, (2, 128, 128))  # Class indices
    y_pred = torch.randn(2, 3, 128, 128)  # Logits
    
    # Compute
    results = metrics.compute_metrics(y_pred, y_true)
    print(results)

{'dice': 0.33498871326446533, 'iou': 0.20119662582874298, 'dice_mean': 0.33498871326446533, 'iou_mean': 0.20119662582874298}


In [65]:
from monai.metrics import DiceMetric, MeanIoU
from monai.data import decollate_batch
from monai.transforms import AsDiscrete, EnsureType
import torch

class SegmentationMetrics:
    """
    Preserves original interface while using MONAI internally.
    Output format matches exactly what you had before.
    """
    
    def __init__(self, num_classes):
        self.num_classes = num_classes
        
        # MONAI metrics setup
        self.dice_metric = DiceMetric(include_background=True, 
                                    reduction="none",
                                    get_not_nans=False)
        self.iou_metric = MeanIoU(include_background=True,
                                 reduction="none",
                                 get_not_nans=False)
        
        # Post-processing transforms
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)
        self.ensure_type = EnsureType()

    def _prepare_inputs(self, y_pred, y_true):
        """Convert inputs to MONAI expected format"""
        y_pred = self.ensure_type(y_pred)
        y_true = self.ensure_type(y_true)
        
        # Add channel dim if needed (for class indices)
        if y_true.dim() == y_pred.dim() - 1:
            y_true = y_true.unsqueeze(1)
            
        return [self.post_pred(i) for i in decollate_batch(y_pred)], \
               [self.post_label(i) for i in decollate_batch(y_true)]

    def _create_metrics_dict(self, mean_value, class_values, prefix):
        """Creates the exact output format you had before"""
        metrics = {
            f"{prefix}_mean": mean_value.item()
        }
        
        # Handle both tensor and list/array inputs
        if torch.is_tensor(class_values):
            if class_values.dim() == 0:  # Single value
                metrics.update({f"{prefix}_class_{i}": class_values.item() 
                              for i in range(self.num_classes)})
            else:  # Per-class values
                metrics.update({f"{prefix}_class_{i}": class_values[i].item() 
                              for i in range(len(class_values))})
        else:  # List or array
            metrics.update({f"{prefix}_class_{i}": class_values[i] 
                          for i in range(len(class_values))})
            
        metrics[prefix] = metrics[f"{prefix}_mean"]  # Add the short version
        return metrics

    def _compute_precision_recall(self, y_pred_pp, y_true_pp):
        """Proper precision/recall calculation using TP/FP/FN"""
        # Convert to tensors if they're lists
        y_pred = torch.stack(y_pred_pp) if isinstance(y_pred_pp, list) else y_pred_pp
        y_true = torch.stack(y_true_pp) if isinstance(y_true_pp, list) else y_true_pp
        
        # Flatten all predictions and labels
        y_pred = y_pred.flatten(1)  # [N, C, H*W*D]
        y_true = y_true.flatten(1)
        
        # Calculate TP, FP, FN per class
        tp = (y_pred * y_true).sum(1)
        fp = (y_pred * (1 - y_true)).sum(1)
        fn = ((1 - y_pred) * y_true).sum(1)
        
        precision = (tp + 1e-7) / (tp + fp + 1e-7)
        recall = (tp + 1e-7) / (tp + fn + 1e-7)
        
        return precision.mean(0), recall.mean(0)  # Mean across batch

    def all_metrics(self, y_pred, y_true):
        """
        Returns metrics in the EXACT original format.
        """
        y_pred_pp, y_true_pp = self._prepare_inputs(y_pred, y_true)
        
        # Compute Dice and IoU (MONAI)
        self.dice_metric(y_pred=y_pred_pp, y=y_true_pp)
        dice_scores = self.dice_metric.aggregate()  # Shape: [batch, classes]
        dice_mean = torch.mean(dice_scores)
        self.dice_metric.reset()
        
        self.iou_metric(y_pred=y_pred_pp, y=y_true_pp)
        iou_scores = self.iou_metric.aggregate()
        iou_mean = torch.mean(iou_scores)
        self.iou_metric.reset()
        
        # Compute precision and recall properly
        precision, recall = self._compute_precision_recall(y_pred_pp, y_true_pp)
        
        # Build output
        metrics = {}
        metrics.update(self._create_metrics_dict(dice_mean, dice_scores.mean(0), 'dice'))
        metrics.update(self._create_metrics_dict(iou_mean, iou_scores.mean(0), 'iou'))
        metrics.update(self._create_metrics_dict(
            torch.mean(precision), 
            precision, 
            'precision'))
        metrics.update(self._create_metrics_dict(
            torch.mean(recall),
            recall,
            'recall'))
        
        return metrics

# Usage example:
if __name__ == "__main__":
    metrics = SegmentationMetrics(num_classes=3)
    
    # Test case - batch of 2 128x128 2D images
    y_true = torch.randint(0, 3, (2, 128, 128))  # Class indices
    y_pred = torch.randn(2, 3, 128, 128)  # Logits
    
    # Compute - same call signature as before
    results = metrics.all_metrics(y_pred, y_true)
    
    # Print in the exact original format
    for k, v in results.items():
        print(f"{k}: {v:.4f}")

dice_mean: 0.3302
dice_class_0: 0.3312
dice_class_1: 0.3336
dice_class_2: 0.3258
dice: 0.3302
iou_mean: 0.1978
iou_class_0: 0.1985
iou_class_1: 0.2002
iou_class_2: 0.1946
iou: 0.1978
precision_mean: 0.3302
precision_class_0: 0.3302
precision_class_1: 0.3302
precision_class_2: 0.3302
precision: 0.3302
recall_mean: 0.3302
recall_class_0: 0.3302
recall_class_1: 0.3302
recall_class_2: 0.3302
recall: 0.3302


In [66]:
from src.metrics import SegmentationMetrics as sm

In [67]:
metrics = sm.all_metrics(y_true, y_pred)
print(metrics)

TypeError: zeros(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got float"