In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Loading functions
import os
import time
from monai.data import DataLoader, decollate_batch


import torch
import torch.nn.parallel

from src.get_data import CustomDataset, CustomDatasetSeg
import numpy as np
from scipy import ndimage
from types import SimpleNamespace
import wandb
import logging

#####
import json
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    MapTransform,
    Transform,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR
from monai import data

# from monai.data import decollate_batch
from functools import partial
from src.custom_transforms import (ConvertToMultiChannelBasedOnN_Froi, 
                                   ConvertToMultiChannelBasedOnAnotatedInfiltration, 
                                   masked, ConvertToMultiChannelBasedOnBratsClassesdI)

#### Trasnformaciones

In [3]:
roi = (128, 128, 128) # (220, 220, 155) (128, 128, 64)
source_k="label"
train_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        # ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key=source_k,
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
    ]
)
val_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        # ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[-1, -1, -1], #[224, 224, 128],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)



## Modelos

In [4]:
######################
# Crear el modelo
######################

### Hyperparameter
roi = (128, 128, 128)  # (128, 128, 128)

# Create Swin transformer
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def define_model(model_path):
    model = SwinUNETR(
        img_size=roi,
        in_channels=11,
        out_channels=2,  # mdificar con edema
        feature_size=48, #48
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=True,
    )

    # # Load the best model
    # model_path = "artifacts/o9kppyr5_best_model:v0/model.pt"

    # Load the model on CPU
    loaded_model = torch.load(model_path, map_location=torch.device('cuda:0'), weights_only=False)["state_dict"]

    # Load the state dictionary into the model
    model.load_state_dict(loaded_model)
    return model
    

In [5]:
# Modelo TC+Edema
model1=define_model("artifacts/vtzpbajf_best_model:v0/model.pt") # o9kppyr5
model1.to(device)
model1.eval()

# Modelo Infitracion+Edema
model2=define_model("artifacts/1dhzmigz_best_model:v0/model.pt") # uixfayrn - rvu24jip
model2.to(device)
model2.eval()



SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

## Obtener embeddings

In [6]:
# Create dataset data loader
# dataset_path='./Dataset/Dataset_recurrence'
dataset_path='./Dataset/Dataset_30_6'
train_set=CustomDataset(dataset_path, section="test_6", transform=train_transform) # v_transform
train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)

# Directorios para embeddings de cada modelo
embedding_dir_model1 = "Dataset/contrastive_voxel_wise/embeddings_model1"
embedding_dir_model2 = "Dataset/contrastive_voxel_wise/embeddings_model2"
label_output_dir = "Dataset/contrastive_voxel_wise/labels"

# Crear carpetas si no existen
os.makedirs(embedding_dir_model1, exist_ok=True)
os.makedirs(embedding_dir_model2, exist_ok=True)
os.makedirs(label_output_dir, exist_ok=True)

# Variables para las características de los decoders de ambos modelos
decoder_features_model1 = None
decoder_features_model2 = None

# Funciones hook para cada modelo
def decoder_hook_fn_model1(module, input, output):
    global decoder_features_model1
    decoder_features_model1 = output

def decoder_hook_fn_model2(module, input, output):
    global decoder_features_model2
    decoder_features_model2 = output

# Registrar los hooks en los decoders de ambos modelos
hook_handle_decoder1 = model1.decoder1.conv_block.register_forward_hook(decoder_hook_fn_model1)
hook_handle_decoder2 = model2.decoder1.conv_block.register_forward_hook(decoder_hook_fn_model2)

# Extraer y guardar
with torch.no_grad():
    for idx, batch_data in enumerate(train_loader):
        image, label = batch_data["image"], batch_data["label"]
        print("Image", image.shape)  # [1, 11, 128, 128, 128]
        print("label before squeeze", label.shape)  # [1, 2, 128, 128, 128]
        
        image = image.to(device)
        label = label.squeeze(0)  # [2, 128, 128, 128]
        
        # Convertir one-hot a etiquetas únicas
        label_sum = label.sum(dim=0)  # [128, 128, 128], suma de canales
        label_class = torch.zeros_like(label_sum, dtype=torch.long)  # [128, 128, 128]
        
        # Asignar clases:
        # - Fondo (0, 0) -> 0
        # - Vasogénico (1, 0) -> 1
        # - Infiltrado (0, 1) -> 2
        label_class[label[1] == 1] = 2  # Infiltrado
        label_class[(label[0] == 1) & (label[1] == 0)] = 1  # Vasogénico
        # Donde label_sum == 0, ya es fondo (0)
        
        label = label_class.cpu().numpy()  # [128, 128, 128]
        print("label", label.shape)
        
        # Obtener embeddings de ambos modelos
        _ = model1(image)  # Forward para model1
        print("decoder_features_model1:", decoder_features_model1.shape)  # [1, 48, 128, 128, 128]
        
        _ = model2(image)  # Forward para model2
        print("decoder_features_model2:", decoder_features_model2.shape)  # [1, 48, 128, 128, 128]
        
        # Guardar embeddings y etiquetas
        np.save(f"{embedding_dir_model1}/case_{idx}.npy", decoder_features_model1.cpu().numpy())
        np.save(f"{embedding_dir_model2}/case_{idx}.npy", decoder_features_model2.cpu().numpy())
        np.save(f"{label_output_dir}/case_{idx}.npy", label)
        
        print(f"Guardado embeddings y etiquetas para caso {idx}")

# Remover los hooks
hook_handle_decoder1.remove()
hook_handle_decoder2.remove()

Found 6 images and 6 labels.
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 0
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 1
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 2
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder

In [15]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score, f1_score
import matplotlib.pyplot as plt

# Dispositivo
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Dataset (ya lo tienes)
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# # Modelo de proyección
# class ProjectionHead(nn.Module):
#     def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
#         super(ProjectionHead, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
    
#     def forward(self, x):
#         return self.net(x)

# Modelo de proyección (MLP más profundo)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# # Clasificador supervisado
# class Classifier(nn.Module):
#     def __init__(self, input_dim=128, num_classes=3):
#         super(Classifier, self).__init__()
#         self.fc = nn.Linear(input_dim, num_classes)
    
#     def forward(self, x):
#         return self.fc(x)

# Clasificador supervisado (MLP)
class Classifier(nn.Module):
    def __init__(self, input_dim=128, hidden_dim1=256, hidden_dim2=128, num_classes=3, dropout_p=0.3):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, num_classes)
        )
    
    def forward(self, x):
        return self.net(x)

# Configuración para ambos modelos
embedding_dir_model1 = "Dataset/contrastive_voxel_wise/embeddings_model1"
embedding_dir_model2 = "Dataset/contrastive_voxel_wise/embeddings_model2"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1

# Cargar datasets y DataLoaders para ambos modelos
dataset_model1 = EmbeddingDataset(embedding_dir_model1, label_dir)
dataset_model2 = EmbeddingDataset(embedding_dir_model2, label_dir)
loader_model1 = DataLoader(dataset_model1, batch_size=batch_size, shuffle=False)
loader_model2 = DataLoader(dataset_model2, batch_size=batch_size, shuffle=False)

# Cargar modelos contrastivos preentrenados
projection_head1 = ProjectionHead(input_dim=48).to(device)
projection_head1.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new_pipe1_v01_m1.pth", map_location=device))
projection_head1.eval()

projection_head2 = ProjectionHead(input_dim=48).to(device)
projection_head2.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new_pipe2_m1_1dhzmigz.pth", map_location=device))
projection_head2.eval()

# Cargar clasificadores preentrenados
classifier1 = Classifier(input_dim=128, num_classes=3).to(device)
classifier1.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_pipe1_v01_m1.pth", map_location=device))
classifier1.eval()

classifier2 = Classifier(input_dim=128, num_classes=3).to(device)
classifier2.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_pipe2_m1_1dhzmigz.pth", map_location=device))
classifier2.eval()

# Función para generar mapas de probabilidad
def generate_probability_maps(embeddings, projection_head, classifier, device):
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)
        embeddings_flat = embeddings.reshape(-1, 48)
        
        z = projection_head(embeddings_flat)
        z = F.normalize(z, dim=1)
        
        logits = classifier(z)
        probs = F.softmax(logits, dim=1)
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)
        return probs

# Funciones para calcular métricas
def calculate_metrics(pred, true,prob_maps=None, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    auc_scores = []
    f1_scores = []
    
    # Calcular Accuracy global
    accuracy = accuracy_score(true.flatten(), pred.flatten())

    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)

        # F1 Score
        f1 = f1_score(true_cls.flatten(), pred_cls.flatten(), zero_division=0)
        f1_scores.append(f1)

        # AUC-ROC (requiere mapas de probabilidad)
        if prob_maps is not None:
            try:
                auc = roc_auc_score(true_cls.flatten(), prob_maps[cls].flatten())
                auc_scores.append(auc)
            except ValueError:
                auc_scores.append(np.nan)  # Manejar casos donde AUC no se puede calcular
        else:
            auc_scores.append(np.nan)  # Si no se proporcionan mapas de probabilidad
    
    return dice_scores, sensitivity_scores, precision_scores, auc_scores, accuracy, f1_scores

# Directorio de salida
output_dir = "trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test_cube5_center"
os.makedirs(output_dir, exist_ok=True)



## Ejecutar calculo y guardar mapas

In [5]:
# Listas para métricas
all_dice = {0: [], 1: [], 2: []}
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}
all_auc = {0: [], 1: [], 2: []}  # Lista para AUC-ROC
all_accuracy = []  # Lista para Accuracy global
all_f1 = {0: [], 1: [], 2: []}  # Lista para F1 Score
# Listas para almacenar FPR y TPR de todos los casos
all_fpr = {0: [], 1: [], 2: []}
all_tpr = {0: [], 1: [], 2: []}


# Procesar y combinar
for idx, ((embeddings1, labels1), (embeddings2, labels2)) in enumerate(zip(loader_model1, loader_model2)):
    # Generar mapas de probabilidad para ambos modelos
    prob_maps1 = generate_probability_maps(embeddings1, projection_head1, classifier1, device)  # [3, 128, 128, 128]
    prob_maps2 = generate_probability_maps(embeddings2, projection_head2, classifier2, device)  # [3, 128, 128, 128]
    
    # Combinar mapas:
    # - Clase 0: del modelo 1
    # - Clase 1: máximo entre ambos modelos
    # - Clase 2: del modelo 2
    combined_prob_maps = torch.zeros_like(prob_maps1)  # [3, 128, 128, 128]
    combined_prob_maps[0] = prob_maps1[0]  # Clase 0 del modelo 1
    combined_prob_maps[1] = torch.max(prob_maps1[1], prob_maps2[1])  # Clase 1 máximo entre ambos
    combined_prob_maps[2] = prob_maps2[2]  # Clase 2 del modelo 2
    
    # Normalizar probabilidades para que sumen 1 en cada vóxel
    combined_prob_maps = combined_prob_maps / combined_prob_maps.sum(dim=0, keepdim=True)
    
    # Convertir a numpy
    prob_maps_np = combined_prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    
    # Generar segmentación semántica
    segmentation = np.argmax(prob_maps_np, axis=0)  # [128, 128, 128]
    # segmentation_np = segmentation.astype(np.uint8)

    # Aplicar umbral: asignar clase 1 si la probabilidad es > 0.4
    # class_1_mask = prob_maps_np[1] > 0.3  # Máscara booleana para clase 1
    # segmentation[class_1_mask] = 1  # Asignar clase 1 a los vóxeles que cumplen el criterio
    segmentation_np = segmentation.astype(np.uint8)
    
    # Etiquetas (usamos las del modelo 2, asumiendo que son iguales)
    labels = labels2.squeeze(0)  # [128, 128, 128]
    labels_np = labels.cpu().numpy().astype(np.uint8)
    
    # Calcular métricas
    dice, sensitivity, precision, auc, accuracy, f1 = calculate_metrics(segmentation_np, labels_np, prob_maps=prob_maps_np)
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
        all_auc[cls].append(auc[cls])
        all_f1[cls].append(f1[cls])
    all_accuracy.append(accuracy)

     # Graficar curvas ROC
    plt.figure(figsize=(8, 6))
    class_names = ["Fondo", "Vasogénico", "Infiltrado"]
    colors = ['blue', 'green', 'red']
    
    for cls in range(3):
        # Etiquetas binarias para la clase actual
        true_cls = (labels_np == cls).astype(np.uint8).flatten()
        prob_cls = prob_maps_np[cls].flatten()
        
        # Calcular puntos de la curva ROC
        fpr, tpr, _ = roc_curve(true_cls, prob_cls)
        auc_value = auc[cls]  # Usar el AUC calculado previamente
        all_fpr[cls].append(fpr)
        all_tpr[cls].append(tpr)
        
        # Graficar
        plt.plot(fpr, tpr, color=colors[cls], label=f'{class_names[cls]} (AUC = {auc_value:.4f})')
    
    # Configurar el gráfico
    plt.plot([0, 1], [0, 1], 'k--')  # Línea diagonal (clasificador aleatorio)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos (FPR)')
    plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
    plt.title(f'Curva ROC - Caso {idx}')
    plt.legend(loc="lower right")
    
    # Guardar el gráfico
    roc_output_path = os.path.join(output_dir, f"roc_curve_case_{idx}.png")
    plt.savefig(roc_output_path)
    plt.close()
    print(f"Guardada curva ROC en {roc_output_path}")

    print(f"Caso {idx} - Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}, "
          f"AUC-ROC: {auc}, Accuracy: {accuracy:.4f}, F1 Score: {f1}")
    
    # Crear imágenes NIfTI
    affine = np.eye(4)
    
    # Guardar mapas de probabilidad combinados
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    prob_output_path = os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz")
    nib.save(nifti_prob_img, prob_output_path)
    print(f"Guardado mapa de probabilidad en {prob_output_path}")
    
    # Guardar etiquetas
    nifti_label_img = nib.Nifti1Image(labels_np, affine)
    label_output_path = os.path.join(output_dir, f"labels_case_{idx}.nii.gz")
    nib.save(nifti_label_img, label_output_path)
    print(f"Guardadas etiquetas en {label_output_path}")
    
    # Guardar segmentación semántica
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    seg_output_path = os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz")
    nib.save(nifti_seg_img, seg_output_path)
    print(f"Guardada segmentación en {seg_output_path}")

# Curva ROC Promedio
plt.figure(figsize=(8, 6))
for cls in range(3):
    # Interpolar FPR y TPR a una base común
    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    for fpr, tpr in zip(all_fpr[cls], all_tpr[cls]):
        tpr_interp = np.interp(mean_fpr, fpr, tpr)
        tpr_interp[0] = 0.0  # Asegurar que comienza en 0
        tprs.append(tpr_interp)
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # Asegurar que termina en 1
    mean_auc = np.nanmean(all_auc[cls])
    
    plt.plot(mean_fpr, mean_tpr, color=colors[cls], label=f'{class_names[cls]} (AUC = {mean_auc:.4f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Tasa de Falsos Positivos (FPR)')
plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
plt.title('Curva ROC Promedio')
plt.legend(loc="lower right")
roc_avg_path = os.path.join(output_dir, "roc_curve_average.png")
plt.savefig(roc_avg_path)
plt.close()
print(f"Guardada curva ROC promedio en {roc_avg_path}")

# # Calcular promedios y desviaciones estándar
# class_names = ["Fondo", "Vasogénico", "Infiltrado"]
# for cls in range(3):
#     dice_mean = np.mean(all_dice[cls])
#     dice_std = np.std(all_dice[cls])
#     sens_mean = np.mean(all_sensitivity[cls])
#     sens_std = np.std(all_sensitivity[cls])
#     prec_mean = np.mean(all_precision[cls])
#     prec_std = np.std(all_precision[cls])
    
#     print(f"\nClase {cls} ({class_names[cls]}):")
#     print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
#     print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
#     print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")

# Calcular promedios y desviaciones estándar
class_names = ["Fondo", "Vasogénico", "Infiltrado"]
for cls in range(3):
    dice_mean = np.nanmean(all_dice[cls])
    dice_std = np.nanstd(all_dice[cls])
    sens_mean = np.nanmean(all_sensitivity[cls])
    sens_std = np.nanstd(all_sensitivity[cls])
    prec_mean = np.nanmean(all_precision[cls])
    prec_std = np.nanstd(all_precision[cls])
    auc_mean = np.nanmean(all_auc[cls])
    auc_std = np.nanstd(all_auc[cls])
    f1_mean = np.nanmean(all_f1[cls])
    f1_std = np.nanstd(all_f1[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

# Accuracy global
accuracy_mean = np.nanmean(all_accuracy)
accuracy_std = np.nanstd(all_accuracy)
print(f"\nAccuracy Global: {accuracy_mean:.4f} ± {accuracy_std:.4f}")

Guardada curva ROC en trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test/roc_curve_case_0.png
Caso 0 - Dice: [0.9970035703989371, 0.520193908666654, 0.6986080252122276], Sensitivity: [0.9941729438903477, 0.4651015729859273, 0.8572294069472858], Precision: [0.9998503617499696, 0.5900915958885441, 0.5895229395928157], AUC-ROC: [0.9997574644423223, 0.9910906713472637, 0.9931217498754038], Accuracy: 0.9818, F1 Score: [0.9970035703991849, 0.5201939086743189, 0.698608025218861]
Guardado mapa de probabilidad en trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test/probability_maps_case_0.nii.gz
Guardadas etiquetas en trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test/labels_case_0.nii.gz
Guardada segmentación en trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test/segmentation_case_0.nii.gz
Guardada curva ROC en trained_models/mapas_combinados_pipe1_v01_pipe2_1dhzmigz_test/roc_curve_case_1.png
Caso 1 - Dice: [0.9961940335339866, 0.7573532532253574, 0.709

# Metricas basdas en regions (supervoxeles)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, f1_score
from scipy import stats
import os
import numpy as np
import nibabel as nib
import torch

# Función para calcular métricas (ya definida previamente)
def calculate_metrics(pred, true, prob_maps=None, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    auc_scores = []
    f1_scores = []
    
    accuracy = accuracy_score(true.flatten(), pred.flatten())
    
    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
        
        f1 = f1_score(true_cls.flatten(), pred_cls.flatten(), zero_division=0)
        f1_scores.append(f1)
        
        if prob_maps is not None:
            try:
                auc = roc_auc_score(true_cls.flatten(), prob_maps[cls].flatten())
                auc_scores.append(auc)
            except ValueError:
                auc_scores.append(np.nan)
        else:
            auc_scores.append(np.nan)
    
    return dice_scores, sensitivity_scores, precision_scores, auc_scores, accuracy, f1_scores

# Función para dividir en cubos y obtener clases predominantes
def get_cube_labels(volume, cube_size, num_classes=3):
    dims = volume.shape
    assert dims[0] % cube_size == 0, "El tamaño del cubo debe dividir exactamente el tamaño del volumen"
    num_cubes = dims[0] // cube_size
    
    cube_labels = np.zeros((num_cubes, num_cubes, num_cubes), dtype=np.uint8)
    cube_probs = np.zeros((num_classes, num_cubes, num_cubes, num_cubes))
    
    for i in range(num_cubes):
        for j in range(num_cubes):
            for k in range(num_cubes):
                cube = volume[i*cube_size:(i+1)*cube_size, 
                             j*cube_size:(j+1)*cube_size, 
                             k*cube_size:(k+1)*cube_size]
                # Clase predominante (modo)
                mode_value = stats.mode(cube.flatten(), keepdims=True)[0][0]
                cube_labels[i, j, k] = mode_value
                # Proporción de cada clase como "probabilidad" suavizada
                for cls in range(num_classes):
                    cube_probs[cls, i, j, k] = np.mean(cube == cls)
    
    return cube_labels, cube_probs

# Listas para métricas voxel-wise
all_dice = {0: [], 1: [], 2: []}
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}
all_auc = {0: [], 1: [], 2: []}
all_accuracy = []
all_f1 = {0: [], 1: [], 2: []}

# Listas para métricas cube-wise
all_dice_cube = {0: [], 1: [], 2: []}
all_sensitivity_cube = {0: [], 1: [], 2: []}
all_precision_cube = {0: [], 1: [], 2: []}
all_auc_cube = {0: [], 1: [], 2: []}
all_accuracy_cube = []
all_f1_cube = {0: [], 1: [], 2: []}

# Tamaño del cubo (ajusta según necesites)
cube_size = 8  # 128 / 16 = 8 cubos por dimensión

# Procesar y combinar
for idx, ((embeddings1, labels1), (embeddings2, labels2)) in enumerate(zip(loader_model1, loader_model2)):
    # Generar mapas de probabilidad para ambos modelos
    prob_maps1 = generate_probability_maps(embeddings1, projection_head1, classifier1, device)  # [3, 128, 128, 128]
    prob_maps2 = generate_probability_maps(embeddings2, projection_head2, classifier2, device)  # [3, 128, 128, 128]
    
    # Combinar mapas
    combined_prob_maps = torch.zeros_like(prob_maps1)
    combined_prob_maps[0] = prob_maps1[0]
    combined_prob_maps[1] = torch.max(prob_maps1[1], prob_maps2[1])
    combined_prob_maps[2] = prob_maps2[2]
    
    # Convertir a numpy
    prob_maps_np = combined_prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    
    # Generar segmentación semántica
    segmentation = np.argmax(prob_maps_np, axis=0)
    # Colocar nuevo umbral para infiltracion
    # class_1_mask = prob_maps_np[1] > 0.4
    # segmentation[class_1_mask] = 1
    segmentation_np = segmentation.astype(np.uint8)
    
    # Etiquetas
    labels = labels2.squeeze(0)
    labels_np = labels.cpu().numpy().astype(np.uint8)
    
    # Calcular métricas voxel-wise
    dice, sensitivity, precision, auc, accuracy, f1 = calculate_metrics(segmentation_np, labels_np, prob_maps=prob_maps_np)
    
    # Almacenar métricas voxel-wise
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
        all_auc[cls].append(auc[cls])
        all_f1[cls].append(f1[cls])
    all_accuracy.append(accuracy)
    
    # Evaluación basada en cubos
    pred_cube_labels, pred_cube_probs = get_cube_labels(segmentation_np, cube_size)
    true_cube_labels, true_cube_probs = get_cube_labels(labels_np, cube_size)
    
    # Calcular métricas cube-wise
    dice_cube, sensitivity_cube, precision_cube, auc_cube, accuracy_cube, f1_cube = calculate_metrics(
        pred_cube_labels, true_cube_labels, prob_maps=pred_cube_probs
    )
    
    # Almacenar métricas cube-wise
    for cls in range(3):
        all_dice_cube[cls].append(dice_cube[cls])
        all_sensitivity_cube[cls].append(sensitivity_cube[cls])
        all_precision_cube[cls].append(precision_cube[cls])
        all_auc_cube[cls].append(auc_cube[cls])
        all_f1_cube[cls].append(f1_cube[cls])
    all_accuracy_cube.append(accuracy_cube)
    
    # Mapa de coincidencias/discrepancias
    match_map = (pred_cube_labels == true_cube_labels).astype(np.uint8)
    mismatch_map = (pred_cube_labels != true_cube_labels).astype(np.uint8)
    
    # Guardar mapas de cubos y coincidencias
    affine = np.eye(4) * cube_size  # Ajustar el affine para reflejar el tamaño del cubo
    nib.save(nib.Nifti1Image(pred_cube_labels, affine), os.path.join(output_dir, f"pred_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(true_cube_labels, affine), os.path.join(output_dir, f"true_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(match_map, affine), os.path.join(output_dir, f"match_map_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(mismatch_map, affine), os.path.join(output_dir, f"mismatch_map_case_{idx}.nii.gz"))
    
    # Análisis espacial (centro de masa de la clase Infiltrado)
    infiltrado_pred = (pred_cube_labels == 1).astype(np.uint8)
    infiltrado_true = (true_cube_labels == 1).astype(np.uint8)
    if np.sum(infiltrado_pred) > 0:
        pred_center = np.mean(np.where(infiltrado_pred), axis=1)
    else:
        pred_center = np.array([np.nan, np.nan, np.nan])
    if np.sum(infiltrado_true) > 0:
        true_center = np.mean(np.where(infiltrado_true), axis=1)
    else:
        true_center = np.array([np.nan, np.nan, np.nan])
    
    print(f"Caso {idx} - Voxel-wise:")
    print(f"  Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}, AUC-ROC: {auc}, "
          f"Accuracy: {accuracy:.4f}, F1 Score: {f1}")
    print(f"Caso {idx} - Cube-wise (tamaño {cube_size}):")
    print(f"  Dice: {dice_cube}, Sensitivity: {sensitivity_cube}, Precision: {precision_cube}, "
          f"AUC-ROC: {auc_cube}, Accuracy: {accuracy_cube:.4f}, F1 Score: {f1_cube}")
    print(f"  Centro de masa Infiltrado (Pred): {pred_center}")
    print(f"  Centro de masa Infiltrado (True): {true_center}")
    
    # Guardar mapas de probabilidad y segmentaciones voxel-wise
    nib.save(nib.Nifti1Image(prob_maps_np_nifti, np.eye(4)), os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(labels_np, np.eye(4)), os.path.join(output_dir, f"labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(segmentation_np, np.eye(4)), os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz"))

# Calcular promedios y desviaciones estándar (voxel-wise)
class_names = ["Fondo", "Vasogénico", "Infiltrado"]
print("\nResultados Voxel-wise:")
for cls in range(3):
    dice_mean = np.nanmean(all_dice[cls])
    dice_std = np.nanstd(all_dice[cls])
    sens_mean = np.nanmean(all_sensitivity[cls])
    sens_std = np.nanstd(all_sensitivity[cls])
    prec_mean = np.nanmean(all_precision[cls])
    prec_std = np.nanstd(all_precision[cls])
    auc_mean = np.nanmean(all_auc[cls])
    auc_std = np.nanstd(all_auc[cls])
    f1_mean = np.nanmean(all_f1[cls])
    f1_std = np.nanstd(all_f1[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean = np.nanmean(all_accuracy)
accuracy_std = np.nanstd(all_accuracy)
print(f"\nAccuracy Global: {accuracy_mean:.4f} ± {accuracy_std:.4f}")

# Calcular promedios y desviaciones estándar (cube-wise)
print("\nResultados Cube-wise:")
for cls in range(3):
    dice_mean = np.nanmean(all_dice_cube[cls])
    dice_std = np.nanstd(all_dice_cube[cls])
    sens_mean = np.nanmean(all_sensitivity_cube[cls])
    sens_std = np.nanstd(all_sensitivity_cube[cls])
    prec_mean = np.nanmean(all_precision_cube[cls])
    prec_std = np.nanstd(all_precision_cube[cls])
    auc_mean = np.nanmean(all_auc_cube[cls])
    auc_std = np.nanstd(all_auc_cube[cls])
    f1_mean = np.nanmean(all_f1_cube[cls])
    f1_std = np.nanstd(all_f1_cube[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean_cube = np.nanmean(all_accuracy_cube)
accuracy_std_cube = np.nanstd(all_accuracy_cube)
print(f"\nAccuracy Global Cube-wise: {accuracy_mean_cube:.4f} ± {accuracy_std_cube:.4f}")

# Metricas basadas en regiones centrada en tumor

In [18]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, f1_score
from scipy import stats
import os
import numpy as np
import nibabel as nib
import torch
from mpl_toolkits.mplot3d import Axes3D

# Función para calcular métricas (para todas las clases o un subconjunto)
def calculate_metrics(pred, true, prob_maps=None, num_classes=3, class_offset=0):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    auc_scores = []
    f1_scores = []
    
    accuracy = accuracy_score(true.flatten(), pred.flatten())
    
    for cls in range(class_offset, class_offset + num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
        
        f1 = f1_score(true_cls.flatten(), pred_cls.flatten(), zero_division=0)
        f1_scores.append(f1)
        
        if prob_maps is not None:
            try:
                auc = roc_auc_score(true_cls.flatten(), prob_maps[cls - class_offset].flatten())
                auc_scores.append(auc)
            except ValueError:
                auc_scores.append(np.nan)
        else:
            auc_scores.append(np.nan)
    
    return dice_scores, sensitivity_scores, precision_scores, auc_scores, accuracy, f1_scores

# Función para obtener el cuadro delimitador y centro del tumor
def get_tumor_bbox_and_center(labels_np):
    tumor_mask = (labels_np > 0).astype(np.uint8)  # Clases 1 (Infiltrado) y 2 (Vasogénico)
    if np.sum(tumor_mask) == 0:
        return None, None
    indices = np.where(tumor_mask)
    bbox = {
        'x_min': np.min(indices[0]), 'x_max': np.max(indices[0]),
        'y_min': np.min(indices[1]), 'y_max': np.max(indices[1]),
        'z_min': np.min(indices[2]), 'z_max': np.max(indices[2])
    }
    center = np.mean(indices, axis=1)
    return bbox, center

# Función para dividir la región del tumor en cubos
def get_tumor_cube_labels(segmentation_np, labels_np, bbox, center, grid_size=5):
    # Dimensiones del cuadro delimitador
    dims = [bbox['x_max'] - bbox['x_min'] + 1, 
            bbox['y_max'] - bbox['y_min'] + 1, 
            bbox['z_max'] - bbox['z_min'] + 1]
    
    # Tamaño del cubo (dinámico)
    cube_size = [max(1, d // grid_size) for d in dims]
    
    # Ajustar el cuadro delimitador para centrarlo y cubrir la grilla
    half_grid = np.array([grid_size * cs / 2 for cs in cube_size])
    bbox_min = np.floor(center - half_grid).astype(int)
    bbox_max = np.ceil(center + half_grid).astype(int)
    
    # Asegurar que no exceda los límites del volumen
    bbox_min = np.maximum(bbox_min, 0)
    bbox_max = np.minimum(bbox_max, np.array(segmentation_np.shape))
    
    # Extraer la región del tumor
    tumor_pred = segmentation_np[bbox_min[0]:bbox_max[0], bbox_min[1]:bbox_max[1], bbox_min[2]:bbox_max[2]]
    tumor_true = labels_np[bbox_min[0]:bbox_max[0], bbox_min[1]:bbox_max[1], bbox_min[2]:bbox_max[2]]
    
    # Ajustar para que coincida con la grilla
    cube_labels_pred = np.zeros((grid_size, grid_size, grid_size), dtype=np.uint8)
    cube_labels_true = np.zeros((grid_size, grid_size, grid_size), dtype=np.uint8)
    cube_probs_pred = np.zeros((3, grid_size, grid_size, grid_size))  # Clases 0, 1, 2
    
    for i in range(grid_size):
        for j in range(grid_size):
            for k in range(grid_size):
                x_start = i * cube_size[0]
                y_start = j * cube_size[1]
                z_start = k * cube_size[2]
                x_end = min(x_start + cube_size[0], tumor_pred.shape[0])
                y_end = min(y_start + cube_size[1], tumor_pred.shape[1])
                z_end = min(z_start + cube_size[2], tumor_pred.shape[2])
                
                if x_end <= x_start or y_end <= y_start or z_end <= z_start:
                    continue
                
                cube_pred = tumor_pred[x_start:x_end, y_start:y_end, z_start:z_end]
                cube_true = tumor_true[x_start:x_end, y_start:y_end, z_start:z_end]
                
                # Clase predominante (incluye clase 0)
                cube_labels_pred[i, j, k] = stats.mode(cube_pred.flatten(), keepdims=True)[0][0]
                cube_labels_true[i, j, k] = stats.mode(cube_true.flatten(), keepdims=True)[0][0]
                
                # Proporciones para clases 0, 1, 2
                cube_probs_pred[0, i, j, k] = np.mean(cube_pred == 0)
                cube_probs_pred[1, i, j, k] = np.mean(cube_pred == 1)
                cube_probs_pred[2, i, j, k] = np.mean(cube_pred == 2)
    
    return cube_labels_pred, cube_labels_true, cube_probs_pred, bbox_min, bbox_max

# Listas para métricas voxel-wise (todas las clases)
all_dice = {0: [], 1: [], 2: []}
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}
all_auc = {0: [], 1: [], 2: []}
all_accuracy = []
all_f1 = {0: [], 1: [], 2: []}

# Listas para métricas cube-wise (todas las clases)
all_dice_cube = {0: [], 1: [], 2: []}
all_sensitivity_cube = {0: [], 1: [], 2: []}
all_precision_cube = {0: [], 1: [], 2: []}
all_auc_cube = {0: [], 1: [], 2: []}
all_accuracy_cube = []
all_f1_cube = {0: [], 1: [], 2: []}

# Configuración de la grilla (5x5x5 o 3x3x3)
grid_size = 5  # Cambia a 3 para una grilla 3x3x3

# Procesar y combinar
for idx, ((embeddings1, labels1), (embeddings2, labels2)) in enumerate(zip(loader_model1, loader_model2)):
    # Generar mapas de probabilidad
    prob_maps1 = generate_probability_maps(embeddings1, projection_head1, classifier1, device)
    prob_maps2 = generate_probability_maps(embeddings2, projection_head2, classifier2, device)
    
    # Combinar mapas
    combined_prob_maps = torch.zeros_like(prob_maps1)
    combined_prob_maps[0] = prob_maps1[0]  # Fondo
    combined_prob_maps[1] = torch.max(prob_maps1[1], prob_maps2[1])  # Infiltrado
    combined_prob_maps[2] = prob_maps2[2]  # Vasogénico
    
    # Convertir a numpy
    prob_maps_np = combined_prob_maps.cpu().numpy()
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))
    
    # Generar segmentación
    segmentation = np.argmax(prob_maps_np, axis=0)
    class_1_mask = prob_maps_np[1] > 0.4  # Umbral para Infiltrado
    segmentation[class_1_mask] = 1
    segmentation_np = segmentation.astype(np.uint8)
    
    # Etiquetas
    labels = labels2.squeeze(0)
    labels_np = labels.cpu().numpy().astype(np.uint8)
    
    # Calcular métricas voxel-wise
    dice, sensitivity, precision, auc, accuracy, f1 = calculate_metrics(
        segmentation_np, labels_np, prob_maps=prob_maps_np, num_classes=3, class_offset=0
    )
    
    # Almacenar métricas voxel-wise
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
        all_auc[cls].append(auc[cls])
        all_f1[cls].append(f1[cls])
    all_accuracy.append(accuracy)
    
    # Obtener cuadro delimitador y centro del tumor
    bbox, center = get_tumor_bbox_and_center(labels_np)
    if bbox is None:
        print(f"Caso {idx}: No se detectó tumor en el ground truth. Saltando evaluación cube-wise.")
        continue
    
    # Evaluación basada en cubos
    cube_labels_pred, cube_labels_true, cube_probs_pred, bbox_min, bbox_max = get_tumor_cube_labels(
        segmentation_np, labels_np, bbox, center, grid_size=grid_size
    )
    
    # Calcular métricas cube-wise (clases 0, 1, 2)
    dice_cube, sensitivity_cube, precision_cube, auc_cube, accuracy_cube, f1_cube = calculate_metrics(
        cube_labels_pred, cube_labels_true, prob_maps=cube_probs_pred, num_classes=3, class_offset=0
    )
    
    # Almacenar métricas cube-wise
    for cls in range(3):
        all_dice_cube[cls].append(dice_cube[cls])
        all_sensitivity_cube[cls].append(sensitivity_cube[cls])
        all_precision_cube[cls].append(precision_cube[cls])
        all_auc_cube[cls].append(auc_cube[cls])
        all_f1_cube[cls].append(f1_cube[cls])
    all_accuracy_cube.append(accuracy_cube)
    
    # Mapa de coincidencias/discrepancias
    match_map = (cube_labels_pred == cube_labels_true).astype(np.uint8)
    mismatch_map = (cube_labels_pred != cube_labels_true).astype(np.uint8)
    
    # Guardar mapas de cubos
    affine = np.eye(4)
    affine[:3, 3] = bbox_min  # Ajustar origen al cuadro delimitador
    nib.save(nib.Nifti1Image(cube_labels_pred, affine), 
             os.path.join(output_dir, f"pred_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(cube_labels_true, affine), 
             os.path.join(output_dir, f"true_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(match_map, affine), 
             os.path.join(output_dir, f"match_map_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(mismatch_map, affine), 
             os.path.join(output_dir, f"mismatch_map_case_{idx}.nii.gz"))
    
    # Análisis espacial (centro de masa de Infiltrado)
    infiltrado_pred = (cube_labels_pred == 1).astype(np.uint8)  # Clase 1: Infiltrado
    infiltrado_true = (cube_labels_true == 1).astype(np.uint8)
    pred_center = np.mean(np.where(infiltrado_pred), axis=1) if np.sum(infiltrado_pred) > 0 else np.array([np.nan] * 3)
    true_center = np.mean(np.where(infiltrado_true), axis=1) if np.sum(infiltrado_true) > 0 else np.array([np.nan] * 3)
    distance = np.linalg.norm(pred_center - true_center) if not np.any(np.isnan(pred_center)) and not np.any(np.isnan(true_center)) else np.nan
    
    # Visualización 3D
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    infiltrado_pred_pos = np.where(infiltrado_pred)
    infiltrado_true_pos = np.where(infiltrado_true)
    ax.scatter(infiltrado_pred_pos[0], infiltrado_pred_pos[1], infiltrado_pred_pos[2], 
               c='red', label='Infiltrado Predicho', alpha=0.5)
    ax.scatter(infiltrado_true_pos[0], infiltrado_true_pos[1], infiltrado_true_pos[2], 
               c='blue', label='Infiltrado Verdadero', alpha=0.5)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'Infiltrado - Caso {idx} (Grilla {grid_size}x{grid_size}x{grid_size})')
    ax.legend()
    plt.savefig(os.path.join(output_dir, f"infiltrado_scatter_case_{idx}.png"))
    plt.close()
    
    # Imprimir resultados
    print(f"Caso {idx} - Voxel-wise:")
    print(f"  Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}, AUC-ROC: {auc}, "
          f"Accuracy: {accuracy:.4f}, F1 Score: {f1}")
    print(f"Caso {idx} - Cube-wise (Grilla {grid_size}x{grid_size}x{grid_size}):")
    print(f"  Dice: {dice_cube}, Sensitivity: {sensitivity_cube}, Precision: {precision_cube}, "
          f"AUC-ROC: {auc_cube}, Accuracy: {accuracy_cube:.4f}, F1 Score: {f1_cube}")
    print(f"  Centro de masa Infiltrado (Pred): {pred_center}")
    print(f"  Centro de masa Infiltrado (True): {true_center}")
    print(f"  Distancia entre centros: {distance:.2f}")
    
    # Guardar mapas voxel-wise
    nib.save(nib.Nifti1Image(prob_maps_np_nifti, np.eye(4)), 
             os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(labels_np, np.eye(4)), 
             os.path.join(output_dir, f"labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(segmentation_np, np.eye(4)), 
             os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz"))

# Calcular promedios y desviaciones estándar (voxel-wise)
class_names = ["Fondo", "Infiltrado", "Vasogénico"]
print("\nResultados Voxel-wise:")
for cls in range(3):
    dice_mean = np.nanmean(all_dice[cls])
    dice_std = np.nanstd(all_dice[cls])
    sens_mean = np.nanmean(all_sensitivity[cls])
    sens_std = np.nanstd(all_sensitivity[cls])
    prec_mean = np.nanmean(all_precision[cls])
    prec_std = np.nanstd(all_precision[cls])
    auc_mean = np.nanmean(all_auc[cls])
    auc_std = np.nanstd(all_auc[cls])
    f1_mean = np.nanmean(all_f1[cls])
    f1_std = np.nanstd(all_f1[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean = np.nanmean(all_accuracy)
accuracy_std = np.nanstd(all_accuracy)
print(f"\nAccuracy Global: {accuracy_mean:.4f} ± {accuracy_std:.4f}")

# Calcular promedios y desviaciones estándar (cube-wise, clases 0, 1, 2)
print(f"\nResultados Cube-wise (Grilla {grid_size}x{grid_size}x{grid_size}):")
for cls in range(3):
    dice_mean = np.nanmean(all_dice_cube[cls])
    dice_std = np.nanstd(all_dice_cube[cls])
    sens_mean = np.nanmean(all_sensitivity_cube[cls])
    sens_std = np.nanstd(all_sensitivity_cube[cls])
    prec_mean = np.nanmean(all_precision_cube[cls])
    prec_std = np.nanstd(all_precision_cube[cls])
    auc_mean = np.nanmean(all_auc_cube[cls])
    auc_std = np.nanstd(all_auc_cube[cls])
    f1_mean = np.nanmean(all_f1_cube[cls])
    f1_std = np.nanstd(all_f1_cube[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean_cube = np.nanmean(all_accuracy_cube)
accuracy_std_cube = np.nanstd(all_accuracy_cube)
print(f"\nAccuracy Global Cube-wise: {accuracy_mean_cube:.4f} ± {accuracy_std_cube:.4f}")

Caso 0 - Voxel-wise:
  Dice: [0.9968297778503837, 0.5601522842568509, 0.6596068471084519], Sensitivity: [0.9937781582351525, 0.5815087081236923, 0.7322381451536237], Precision: [0.9999001965623705, 0.540308957804972, 0.6000840239735691], AUC-ROC: [0.9997442857524647, 0.9908721550503753, 0.9926824584823732], Accuracy: 0.9810, F1 Score: [0.9968297778506316, 0.5601522842639594, 0.6596068471153744]
Caso 0 - Cube-wise (Grilla 5x5x5):
  Dice: [0.9350649289930849, 0.5217391190926278, 0.5999999880000003], Sensitivity: [0.8780487697798931, 0.5454545206611581, 0.7142856802721105], Precision: [0.9999999861111113, 0.4999999791666675, 0.5172413614744358], AUC-ROC: [0.9937606352807713, 0.8239187996469549, 0.8413461538461539], Accuracy: 0.7920, F1 Score: [0.9350649350649352, 0.5217391304347826, 0.6000000000000001]
  Centro de masa Infiltrado (Pred): [1.58333333 2.75       2.125     ]
  Centro de masa Infiltrado (True): [2.40909091 2.         1.54545455]
  Distancia entre centros: 1.26
Caso 1 - Voxel-