In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Loading functions
import os
import time
from monai.data import DataLoader, decollate_batch


import torch
import torch.nn.parallel

from src.get_data import CustomDataset, CustomDatasetSeg
import numpy as np
from scipy import ndimage
from types import SimpleNamespace
import wandb
import logging

#####
import json
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    MapTransform,
    Transform,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR
from monai import data

# from monai.data import decollate_batch
from functools import partial
from src.custom_transforms import ConvertToMultiChannelBasedOnN_Froi, ConvertToMultiChannelBasedOnAnotatedInfiltration, masked, ConvertToMultiChannelBasedOnBratsClassesdI

## Transformaciones Swin UNETR

In [3]:
roi = (128, 128, 128) # (220, 220, 155) (128, 128, 64)
source_k="label"
train_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        # ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key=source_k,
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
    ]
)
val_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        # ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[-1, -1, -1], #[224, 224, 128],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)



In [None]:
import wandb
run = wandb.init()
artifact = run.use_artifact('mlops-team89/Swin_UPENN_10cases/fhosddxt_best_model:v0', type='model')
artifact_dir = artifact.download()

In [4]:
######################
# Crear el modelo
######################

### Hyperparameter
roi = (128, 128, 128)  # (128, 128, 128)

# Create Swin transformer
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=roi,
    in_channels=11,
    out_channels=2,  # mdificar con edema
    feature_size=48, #48
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

# # Load the best model Infiltrado-vasogenico
# model_path = "artifacts/fhosddxt_best_model:v0/model.pt"

# Load the best model TC-Edema
model_path = "artifacts/1dhzmigz_best_model:v0/model.pt"

# Load the model on CPU
loaded_model = torch.load(model_path, map_location=torch.device('cuda:0'), weights_only=False)["state_dict"]

# Load the state dictionary into the model
model.load_state_dict(loaded_model)

model.to(device)

# Set the model to evaluation mode
model.eval()



SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

In [5]:
# Create dataset data loader
# dataset_path='./Dataset/Dataset_recurrence'
dataset_path='./Dataset/Dataset_30_6'
train_set=CustomDataset(dataset_path, section="test_6", transform=train_transform) # v_transform
train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)



Found 6 images and 6 labels.


In [6]:

embedding_dir = "Dataset/contrastive_voxel_wise/temp/embeddings"
label_output_dir = "Dataset/contrastive_voxel_wise/temp/labels"

# Crear carpetas si no existen
os.makedirs(embedding_dir, exist_ok=True)
os.makedirs(label_output_dir, exist_ok=True)

# Variable para las características del decoder
decoder_features = None

# Función hook
def decoder_hook_fn(module, input, output):
    global decoder_features
    decoder_features = output

# Registrar el hook en decoder1.conv_block
hook_handle_decoder = model.decoder1.conv_block.register_forward_hook(decoder_hook_fn)

# Extraer y guardar
with torch.no_grad():
    for idx, batch_data in enumerate(train_loader):
        image, label = batch_data["image"], batch_data["label"]
        print("Image", image.shape)  # [1, 11, 128, 128, 128]
        print("label before squeeze", label.shape)  # [1, 2, 128, 128, 128]
        
        image = image.to(device)
        label = label.squeeze(0)  # [2, 128, 128, 128]
        
        # Convertir one-hot a etiquetas únicas
        label_sum = label.sum(dim=0)  # [128, 128, 128], suma de canales
        label_class = torch.zeros_like(label_sum, dtype=torch.long)  # [128, 128, 128]
        
        # Asignar clases:
        # - Fondo (0, 0) -> 0
        # - Vasogénico (1, 0) -> 1
        # - Infiltrado (0, 1) -> 2
        label_class[label[1] == 1] = 2  # Infiltrado
        label_class[(label[0] == 1) & (label[1] == 0)] = 1  # Vasogénico
        # Donde label_sum == 0, ya es fondo (0)
        
        label = label_class.cpu().numpy()  # [128, 128, 128]
        print("label", label.shape)
        
        # Obtener embeddings
        _ = model(image)  # Ejecuta el forward para activar el hook
        
        print("decoder_features:", decoder_features.shape)  # [1, 48, 128, 128, 128]
        
        # Guardar embeddings y etiquetas
        np.save(f"{embedding_dir}/case_{idx}.npy", decoder_features.cpu().numpy())
        np.save(f"{label_output_dir}/case_{idx}.npy", label)
        
        print(f"Guardado embeddings y etiquetas para caso {idx}")

# Remover el hook
hook_handle_decoder.remove()

Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 0
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 1
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 2
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 3
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torc

## Entrenar modelo contrastivo

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)
        labels = np.load(label_path)
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)
        labels = torch.tensor(labels, dtype=torch.long)
        
        return embeddings, labels

class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

def contrastive_loss(z, labels, temperature=0.5, sample_size_per_class=1024):
    N_total = z.shape[0]
    z = F.normalize(z, dim=1)
    
    classes = torch.unique(labels)
    if len(classes) < 2:
        print(f"Advertencia: Solo una clase presente ({classes.tolist()}), devolviendo pérdida 0")
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    sampled_z = []
    sampled_labels = []
    
    for cls in classes:
        cls_indices = (labels == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        if cls_size > sample_size_per_class:
            indices = torch.randperm(cls_size)[:sample_size_per_class]
            cls_indices = cls_indices[indices]
        sampled_z.append(z[cls_indices])
        sampled_labels.append(labels[cls_indices])
    
    z = torch.cat(sampled_z, dim=0)
    labels = torch.cat(sampled_labels, dim=0)
    N = z.shape[0]
    
    # print(f"Batch size: {N}, Unique labels: {torch.unique(labels).tolist()}")
    
    if N < 2:
        print("Advertencia: Batch con menos de 2 vóxeles, devolviendo pérdida 0")
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    similarity = torch.mm(z, z.T) / temperature
    labels_eq = labels.unsqueeze(1) == labels.unsqueeze(0)
    labels_eq = labels_eq.float()
    eye = torch.eye(N, device=device)
    labels_eq = labels_eq * (1 - eye)
    
    exp_sim = torch.exp(similarity)
    pos_sum = (exp_sim * labels_eq).sum(dim=1)
    neg_sum = exp_sim.sum(dim=1) - exp_sim.diag()
    
    if pos_sum.sum() == 0:
        print("Advertencia: No hay pares positivos, pérdida será 0")
    
    loss = -torch.log((pos_sum + 1e-6) / (neg_sum + 1e-6))
    return loss.mean()

# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1
sample_size_per_class = 4096
temperature = 0.5
num_epochs = 100
patience = 10  # Early stopping

dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

model = ProjectionHead(input_dim=48).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

In [2]:
# Directorio para checkpoints
output_dir = "trained_models/checkpoints_contrastive"
os.makedirs(output_dir, exist_ok=True)

# Variables para early stopping y checkpoints
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_contrastive_projection_head.pth")

# Entrenamiento con scheduler, checkpoints y early stopping
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    
    model.train()  # Modo entrenamiento
    
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        valid_mask = labels_flat >= 0
        embeddings_valid = embeddings_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]
        
        if embeddings_valid.shape[0] < 2:
            print(f"Batch {batch_idx}: Insuficientes vóxeles válidos")
            continue
        
        # Forward
        z = model(embeddings_valid)
        loss = contrastive_loss(z, labels_valid, temperature, sample_size_per_class)
        
        if loss.item() == 0:
            continue  # No contar batches con pérdida 0
        
        # Optimización
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        valid_batches += 1
        
        if batch_idx % 5 == 0:  # Imprimir cada 5 batches
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")
    
    # Calcular pérdida promedio
    avg_loss = total_loss / max(valid_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Scheduler: ajustar tasa de aprendizaje
    scheduler.step(avg_loss)
    
    # Checkpoint: guardar el mejor modelo
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, best_model_path)
        print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Cargar el mejor modelo al final (opcional)
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")

# Guardar el modelo final (opcional)
torch.save(model.state_dict(), os.path.join(output_dir, "contrastive_projection_head_final.pth"))
print("Modelo final guardado en 'trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth'")

Epoch 1/100, Batch 0/30, Loss: 0.6172
Epoch 1/100, Batch 5/30, Loss: 0.5817
Epoch 1/100, Batch 10/30, Loss: 0.5196
Epoch 1/100, Batch 15/30, Loss: 0.5190
Epoch 1/100, Batch 20/30, Loss: 0.5328
Epoch 1/100, Batch 25/30, Loss: 0.5632
Epoch 1/100, Average Loss: 0.4744, Valid Batches: 30/30, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.4744
Epoch 2/100, Batch 0/30, Loss: 0.4523
Epoch 2/100, Batch 5/30, Loss: 0.3414
Epoch 2/100, Batch 10/30, Loss: 0.5353
Epoch 2/100, Batch 15/30, Loss: 0.5268
Epoch 2/100, Batch 20/30, Loss: 0.5174
Epoch 2/100, Batch 25/30, Loss: 0.5479
Epoch 2/100, Average Loss: 0.4327, Valid Batches: 30/30, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.4327
Epoch 3/100, Batch 0/30, Loss: 0.5190
Epoch 3/100, Batch 5/30, Loss: 0.4881
Epoch 3/100, Batch 10/30, Loss: 0.5118
Epoch 3/100, Batch 15/30, Loss: 0.3847
Epoch 3/100, Batch 20/30, Loss: 0.5156
Epoch 3/100, Batch 25/30, Loss: 0.4373
Epoch 3/100, Average Loss: 0.4244, Valid Batches: 

  checkpoint = torch.load(best_model_path, map_location=device)


## Modelo Contrastivo mas robusto

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Dispositivo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)
        labels = np.load(label_path)
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)
        labels = torch.tensor(labels, dtype=torch.long)
        
        return embeddings, labels

# Modelo de proyección (MLP más profundo)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# Función para muestreo personalizado
def custom_sample(embeddings_flat, labels_flat, background_factor=2.0, 
                  infiltrado_factor=3.0, min_background=10000, 
                  min_infiltrado=15000, max_voxels_per_class=10000, 
                  max_voxels_per_class_min=15000):
    
    classes = torch.unique(labels_flat)
    sampled_embeddings = []
    sampled_labels = []
    
    # Contar vóxeles de vasogénico (clase 1)
    vasogenico_voxels = 0
    if 1 in classes:
        cls_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
        vasogenico_voxels = min(cls_indices.shape[0], max_voxels_per_class_min)
        if vasogenico_voxels < 10:
            print(f"Muy pocos vóxeles de vasogénico ({vasogenico_voxels}), saltando")
            return torch.tensor([]), torch.tensor([])
        indices = torch.randperm(cls_indices.shape[0], device=cls_indices.device)[:vasogenico_voxels]
        sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
        sampled_labels.append(labels_flat[cls_indices[indices]])
        # print(f"Vasogénico (1): {vasogenico_voxels} vóxeles")
    
    # Muestrear infiltrado (clase 2) y fondo (clase 0)
    for cls in classes:
        cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        
        if cls.item() == 2:  # Infiltrado
            target_voxels = min(int(vasogenico_voxels * infiltrado_factor), cls_size, max_voxels_per_class)
            # print(f"Infiltrado (2): vasogenico_voxels * infiltrado_factor = {int(vasogenico_voxels * infiltrado_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_infiltrado and cls_size >= min_infiltrado:
                target_voxels = min_infiltrado
                # print(f"Infiltrado (2): Aplicando mínimo de {min_infiltrado}")
            elif cls_size < min_infiltrado:
                target_voxels = cls_size
                # print(f"Infiltrado (2): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Infiltrado (2): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
        elif cls.item() == 0:  # Fondo
            target_voxels = min(int(vasogenico_voxels * background_factor), cls_size, max_voxels_per_class)
            # print(f"Fondo (0): vasogenico_voxels * background_factor = {int(vasogenico_voxels * background_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_background and cls_size >= min_background:
                target_voxels = min_background
                # print(f"Fondo (0): Aplicando mínimo de {min_background}")
            elif cls_size < min_background:
                target_voxels = cls_size
                # print(f"Fondo (0): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Fondo (0): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
    
    if not sampled_embeddings:
        return torch.tensor([]), torch.tensor([])
    
    return torch.cat(sampled_embeddings), torch.cat(sampled_labels)

# Pérdida contrastiva (adaptada para custom_sample)
def contrastive_loss(z, labels, temperature=0.5):
    N = z.shape[0]
    z = F.normalize(z, dim=1)
    
    classes = torch.unique(labels)
    if len(classes) < 2:
        print(f"Advertencia: Solo una clase presente ({classes.tolist()}), devolviendo pérdida 0")
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    similarity = torch.mm(z, z.T) / temperature
    labels_eq = labels.unsqueeze(1) == labels.unsqueeze(0)
    labels_eq = labels_eq.float()
    eye = torch.eye(N, device=device)
    labels_eq = labels_eq * (1 - eye)
    
    exp_sim = torch.exp(similarity)
    pos_sum = (exp_sim * labels_eq).sum(dim=1)
    neg_sum = exp_sim.sum(dim=1) - exp_sim.diag()
    
    if pos_sum.sum() == 0:
        print("Advertencia: No hay pares positivos, pérdida será 0")
    
    loss = -torch.log((pos_sum + 1e-6) / (neg_sum + 1e-6))
    return loss.mean()

# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/train_30_1dhzmigz/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/train_30_1dhzmigz/labels"
batch_size = 1
temperature = 0.5
num_epochs = 100
patience = 10


# Cargar dataset y DataLoader
dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Definir modelo
model = ProjectionHead(input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)



In [3]:
# Directorio para checkpoints
output_dir = "trained_models/checkpoints_contrastive"
os.makedirs(output_dir, exist_ok=True)

# Variables para early stopping y checkpoints
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_contrastive_projection_head_new.pth")

# Entrenamiento
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    
    model.train()
    
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        valid_mask = labels_flat >= 0
        embeddings_valid = embeddings_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]
        
        if embeddings_valid.shape[0] < 2:
            print(f"Batch {batch_idx}: Insuficientes vóxeles válidos")
            continue
        
        # Muestreo personalizado
        embeddings_sampled, labels_sampled = custom_sample(
            embeddings_valid, labels_valid,
            background_factor=1.0,
            infiltrado_factor=1.0,
            min_background=4096,
            min_infiltrado=4096,
            max_voxels_per_class=10000,
            max_voxels_per_class_min=20000
        )
        
        if embeddings_sampled.numel() == 0:
            print(f"Batch {batch_idx}: No se encontraron vóxeles válidos, saltando")
            continue
        
        # Sub-batching
        if embeddings_sampled.shape[0] > 20000:            
            batch_size_clf = embeddings_sampled.shape[0] // 2 # El sub-batch es la mitad del batch original
        else:
            batch_size_clf = embeddings_sampled.shape[0]

        batch_loss = 0
        batch_valid = 0
        for i in range(0, embeddings_sampled.shape[0], batch_size_clf):
            z_batch = embeddings_sampled[i:i+batch_size_clf]
            labels_batch = labels_sampled[i:i+batch_size_clf]
            
            z = model(z_batch)
            loss = contrastive_loss(z, labels_batch, temperature)
            
            if loss.item() == 0:
                continue
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            batch_loss += loss.item()
            batch_valid += 1
        
        if batch_valid > 0:
            total_loss += batch_loss
            valid_batches += batch_valid
        
        # Imprimir estadísticas
        class_counts = np.bincount(labels_sampled.cpu().numpy(), minlength=3)
        if batch_idx % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {batch_loss/max(batch_valid, 1):.4f}, "
                  f"Sampled size: {embeddings_sampled.shape[0]}, Classes: {torch.unique(labels_sampled).tolist()}, "
                  f"Counts: {class_counts}")
    
    # Calcular pérdida promedio
    avg_loss = total_loss / max(valid_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Scheduler
    scheduler.step(avg_loss)
    
    # Checkpoint
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, best_model_path)
        print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Cargar el mejor modelo
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")

# Guardar el modelo final
torch.save(model.state_dict(), os.path.join(output_dir, "contrastive_projection_head_final_new.pth"))
print("Modelo final guardado en 'trained_models/checkpoints_contrastive/contrastive_projection_head_final_new.pth'")

Epoch 1/100, Batch 0/30, Loss: 0.7793, Sampled size: 9651, Classes: [0, 1, 2], Counts: [4096 1459 4096]
Advertencia: Solo una clase presente ([1]), devolviendo pérdida 0
Advertencia: Solo una clase presente ([2]), devolviendo pérdida 0
Epoch 1/100, Batch 5/30, Loss: 0.2556, Sampled size: 8704, Classes: [0, 1, 2], Counts: [4096  512 4096]
Muy pocos vóxeles de vasogénico (2), saltando
Batch 8: No se encontraron vóxeles válidos, saltando
Advertencia: Solo una clase presente ([1]), devolviendo pérdida 0
Epoch 1/100, Batch 10/30, Loss: 0.5192, Sampled size: 18993, Classes: [0, 1, 2], Counts: [6331 6331 6331]
Advertencia: Solo una clase presente ([1]), devolviendo pérdida 0
Advertencia: Solo una clase presente ([2]), devolviendo pérdida 0
Epoch 1/100, Batch 15/30, Loss: 0.0709, Sampled size: 22803, Classes: [0, 1, 2], Counts: [7601 7601 7601]
Epoch 1/100, Batch 20/30, Loss: 0.0836, Sampled size: 8250, Classes: [0, 1, 2], Counts: [4096   58 4096]
Advertencia: Solo una clase presente ([1]), de

In [None]:
for idx, (embeddings, labels) in enumerate(loader):
    print(f"Unique labels: {torch.unique(labels).tolist()}")
    labels_flat = labels.reshape(-1)
    class_counts = torch.bincount(labels_flat)
    print(f"Case {idx}: Fondo: {class_counts[0]}, Vasogénico: {class_counts[1] if len(class_counts) > 1 else 0}, Infiltrado: {class_counts[2] if len(class_counts) > 2 else 0}")

## Entrenar modelo de clasificacion supervisado

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Dispositivo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset (ya lo tienes)
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# Modelo de proyección  simple (ya lo tienes)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# Clasificador supervisado
class Classifier(nn.Module):
    def __init__(self, input_dim=128, num_classes=3):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        return self.fc(x)

# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1
sample_size_per_class = 4096 # ~10,000 vóxeles total
num_epochs = 100  # Máximo de épocas
patience = 10  # Early stopping: épocas sin mejora

# Cargar dataset y DataLoader
dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Cargar modelo contrastivo preentrenado
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_fhosddxt.pth", map_location=device))
projection_head.eval()  # Modo evaluación, sin gradientes

# Definir clasificador
classifier = Classifier(input_dim=128, num_classes=3).to(device)
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
criterion = nn.CrossEntropyLoss()



## Muestreo equilibrado

In [None]:
# Directorio para checkpoints
output_dir = "trained_models/checkpoints"
os.makedirs(output_dir, exist_ok=True)

# Variables para early stopping y checkpoints
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_supervised_classifier.pth")

# Entrenamiento del clasificador con muestreo balanceado, scheduler y early stopping
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    
    classifier.train()  # Modo entrenamiento
    
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        # Reorganizar para procesar vóxeles
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        # Muestreo estratificado balanceado
        classes = torch.unique(labels_flat)
        if len(classes) < 2:
            print(f"Batch {batch_idx}: Solo una clase presente ({classes.tolist()}), saltando")
            continue
        
        sampled_embeddings = []
        sampled_labels = []
        
        for cls in classes:
            cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
            cls_size = cls_indices.shape[0]
            if cls_size > sample_size_per_class:
                indices = torch.randperm(cls_size)[:sample_size_per_class]
                cls_indices = cls_indices[indices]
            sampled_embeddings.append(embeddings_flat[cls_indices])
            sampled_labels.append(labels_flat[cls_indices])
        
        embeddings_sampled = torch.cat(sampled_embeddings, dim=0)
        labels_sampled = torch.cat(sampled_labels, dim=0)
        
        # Obtener representaciones contrastivas
        with torch.no_grad():
            z = projection_head(embeddings_sampled)  # [N, 128]
            z = F.normalize(z, dim=1)
        
        # Clasificación
        logits = classifier(z)
        loss = criterion(logits, labels_sampled)
        
        # Optimización
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        valid_batches += 1
        
        if batch_idx % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, "
                  f"Sampled size: {embeddings_sampled.shape[0]}, Classes: {torch.unique(labels_sampled).tolist()}")
    
    # Calcular pérdida promedio
    avg_loss = total_loss / max(valid_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Scheduler: ajustar tasa de aprendizaje basada en la pérdida promedio
    scheduler.step(avg_loss)
    
    # Checkpoint: guardar el mejor modelo
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, best_model_path)
        print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Cargar el mejor modelo al final (opcional)
checkpoint = torch.load(best_model_path, map_location=device)
classifier.load_state_dict(checkpoint['model_state_dict'])
print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")

# Guardar el modelo final (opcional)
torch.save(classifier.state_dict(), os.path.join(output_dir, "supervised_classifier_final.pth"))
print("Clasificador final guardado en 'trained_models/checkpoints/supervised_classifier_final.pth'")

## Muestreo custom 

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Función para muestreo personalizado
def custom_sample(embeddings_flat, labels_flat, background_factor=2.0, infiltrado_factor=3.0, min_background=10000, min_infiltrado=15000, max_voxels_per_class=50000):
    classes = torch.unique(labels_flat)
    sampled_embeddings = []
    sampled_labels = []
    
    # Contar vóxeles de vasogénico (clase 1)
    vasogenico_voxels = 0
    if 1 in classes:
        cls_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
        vasogenico_voxels = min(cls_indices.shape[0], max_voxels_per_class)
        indices = torch.randperm(cls_indices.shape[0], device=cls_indices.device)[:vasogenico_voxels]
        sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
        sampled_labels.append(labels_flat[cls_indices[indices]])
        # print(f"Clase (1): {vasogenico_voxels} vóxeles")
    
    # Muestrear infiltrado (clase 2) y fondo (clase 0)
    for cls in classes:
        cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        
        if cls.item() == 2:  # Infiltrado
            target_voxels = min(int(vasogenico_voxels * infiltrado_factor), cls_size, max_voxels_per_class)
            #print(f"Infiltrado (2): vasogenico_voxels * infiltrado_factor = {int(vasogenico_voxels * infiltrado_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_infiltrado and cls_size >= min_infiltrado:
                target_voxels = min_infiltrado  # Mínimo de 15,000 si es menor
                #print(f"Infiltrado (2): Aplicando mínimo de {min_infiltrado}")
            elif cls_size < min_infiltrado:
                target_voxels = cls_size  # Tomar todos si hay menos de 15,000
                #print(f"Infiltrado (2): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Clase (2): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
        elif cls.item() == 0:  # Fondo
            target_voxels = min(int(vasogenico_voxels * background_factor), cls_size, max_voxels_per_class)
            #print(f"Fondo (0): vasogenico_voxels * background_factor = {int(vasogenico_voxels * background_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_background and cls_size >= min_background:
                target_voxels = min_background  # Mínimo de 10,000 si es menor
                #print(f"Fondo (0): Aplicando mínimo de {min_background}")
            elif cls_size < min_background:
                target_voxels = cls_size  # Tomar todos si hay menos de 10,000
                #print(f"Fondo (0): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Clase 0 (0): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
    
    if not sampled_embeddings:
        return torch.tensor([]), torch.tensor([])
    
    return torch.cat(sampled_embeddings), torch.cat(sampled_labels)

# Directorio para checkpoints
output_dir = "trained_models/checkpoints"
os.makedirs(output_dir, exist_ok=True)

# Variables para early stopping y checkpoints
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_supervised_classifier.pth")

# Entrenamiento del clasificador con muestreo personalizado, scheduler y early stopping
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    
    classifier.train()  # Modo entrenamiento
    
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        # Reorganizar para procesar vóxeles
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        # Muestreo personalizado
        embeddings_sampled, labels_sampled = custom_sample(
            embeddings_flat, labels_flat,
            background_factor=2.0,
            infiltrado_factor=3.0,
            min_background=10000,
            min_infiltrado=15000,
            max_voxels_per_class=50000
        )
        
        if embeddings_sampled.numel() == 0:
            print(f"Batch {batch_idx}: No se encontraron vóxeles válidos, saltando")
            continue
        
        # Obtener representaciones contrastivas
        with torch.no_grad():
            z = projection_head(embeddings_sampled)  # [N, 128]
            z = F.normalize(z, dim=1)
        
        # Clasificación
        logits = classifier(z)
        loss = criterion(logits, labels_sampled)
        
        # Optimización
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        valid_batches += 1
        
        # Imprimir estadísticas del muestreo
        class_counts = np.bincount(labels_sampled.cpu().numpy(), minlength=3)
        if batch_idx % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, "
                  f"Sampled size: {embeddings_sampled.shape[0]}, Classes: {torch.unique(labels_sampled).tolist()}, "
                  f"Counts: {class_counts}")
    
    # Calcular pérdida promedio
    avg_loss = total_loss / max(valid_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Scheduler: ajustar tasa de aprendizaje basada en la pérdida promedio
    scheduler.step(avg_loss)
    
    # Checkpoint: guardar el mejor modelo
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, best_model_path)
        print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Cargar el mejor modelo al final (opcional)
checkpoint = torch.load(best_model_path, map_location=device)
classifier.load_state_dict(checkpoint['model_state_dict'])
print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")

# Guardar el modelo final (opcional)
torch.save(classifier.state_dict(), os.path.join(output_dir, "supervised_classifier_final.pth"))
print("Clasificador final guardado en 'trained_models/checkpoints/supervised_classifier_final.pth'")

Vasogénico (1): 1032 vóxeles
Fondo (0): vasogenico_voxels * background_factor = 2064, cls_size = 2004144, max_voxels_per_class = 50000
Fondo (0): Aplicando mínimo de 10000
Fondo (0): target_voxels = 10000
Infiltrado (2): vasogenico_voxels * infiltrado_factor = 3096, cls_size = 91976, max_voxels_per_class = 50000
Infiltrado (2): Aplicando mínimo de 15000
Infiltrado (2): target_voxels = 15000
Epoch 1/100, Batch 0/30, Loss: 1.1219, Sampled size: 26032, Classes: [0, 1, 2], Counts: [10000  1032 15000]
Vasogénico (1): 17756 vóxeles
Fondo (0): vasogenico_voxels * background_factor = 35512, cls_size = 2068807, max_voxels_per_class = 50000
Fondo (0): target_voxels = 35512
Infiltrado (2): vasogenico_voxels * infiltrado_factor = 53268, cls_size = 10589, max_voxels_per_class = 50000
Infiltrado (2): Tomando todos los vóxeles disponibles (10589)
Infiltrado (2): target_voxels = 10589
Vasogénico (1): 3571 vóxeles
Fondo (0): vasogenico_voxels * background_factor = 7142, cls_size = 2081185, max_voxels_p

## Clasificador MLP con capas ocultas

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Dispositivo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# # Modelo de proyección simple
# class ProjectionHead(nn.Module):
#     def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
#         super(ProjectionHead, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
    
#     def forward(self, x):
#         return self.net(x)
    
# Modelo de proyección (MLP más profundo)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# Clasificador supervisado (MLP)
class Classifier(nn.Module):
    def __init__(self, input_dim=128, hidden_dim1=256, hidden_dim2=128, num_classes=3, dropout_p=0.3):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, num_classes)
        )
    
    def forward(self, x):
        return self.net(x)

# Función para muestreo personalizado
def custom_sample(embeddings_flat, labels_flat, background_factor=2.0, 
                  infiltrado_factor=3.0, min_background=10000, 
                  min_infiltrado=15000, max_voxels_per_class=10000, max_voxels_per_class_min=15000,):
    classes = torch.unique(labels_flat)
    sampled_embeddings = []
    sampled_labels = []
    
    # Contar vóxeles de vasogénico (clase 1)
    vasogenico_voxels = 0
    if 1 in classes:
        cls_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
        vasogenico_voxels = min(cls_indices.shape[0], max_voxels_per_class_min)
        if vasogenico_voxels < 10:
            print(f"Muy pocos vóxeles de vasogénico ({vasogenico_voxels}), saltando")
            return torch.tensor([]), torch.tensor([])
        indices = torch.randperm(cls_indices.shape[0], device=cls_indices.device)[:vasogenico_voxels]
        sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
        sampled_labels.append(labels_flat[cls_indices[indices]])
        # print(f"Vasogénico (1): {vasogenico_voxels} vóxeles")
    
    # Muestrear infiltrado (clase 2) y fondo (clase 0)
    for cls in classes:
        cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        
        if cls.item() == 2:  # Infiltrado
            target_voxels = min(int(vasogenico_voxels * infiltrado_factor), cls_size, max_voxels_per_class)
            # print(f"Infiltrado (2): vasogenico_voxels * infiltrado_factor = {int(vasogenico_voxels * infiltrado_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_infiltrado and cls_size >= min_infiltrado:
                target_voxels = min_infiltrado
                # print(f"Infiltrado (2): Aplicando mínimo de {min_infiltrado}")
            elif cls_size < min_infiltrado:
                target_voxels = cls_size
                # print(f"Infiltrado (2): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Infiltrado (2): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
        elif cls.item() == 0:  # Fondo
            target_voxels = min(int(vasogenico_voxels * background_factor), cls_size, max_voxels_per_class)
            # print(f"Fondo (0): vasogenico_voxels * background_factor = {int(vasogenico_voxels * background_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_background and cls_size >= min_background:
                target_voxels = min_background
                # print(f"Fondo (0): Aplicando mínimo de {min_background}")
            elif cls_size < min_background:
                target_voxels = cls_size
                # print(f"Fondo (0): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Fondo (0): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
    
    if not sampled_embeddings:
        return torch.tensor([]), torch.tensor([])
    
    return torch.cat(sampled_embeddings), torch.cat(sampled_labels)

# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/train_30_1dhzmigz/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/train_30_1dhzmigz/labels"
batch_size = 1
num_epochs = 100
patience = 10
max_voxels_per_class = 10000  # Usado en custom_sample

# Cargar dataset y DataLoader
dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Cargar modelo contrastivo preentrenado
projection_head = ProjectionHead(input_dim=48).to(device)
# projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_fhosddxt.pth", map_location=device))
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new.pth", map_location=device))
projection_head.eval()  # Modo evaluación, sin gradientes

# Definir clasificador
classifier = Classifier(input_dim=128, hidden_dim1=256, hidden_dim2=128, num_classes=3, dropout_p=0.3).to(device)
optimizer = optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
criterion = nn.CrossEntropyLoss()

# Directorio para checkpoints
output_dir = "trained_models/checkpoints"
os.makedirs(output_dir, exist_ok=True)

# Variables para early stopping y checkpoints
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_supervised_classifier.pth")

# Entrenamiento del clasificador con muestreo personalizado, scheduler y early stopping
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    
    classifier.train()
    
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        # Muestreo personalizado
        embeddings_sampled, labels_sampled = custom_sample(
            embeddings_flat, labels_flat,
            background_factor=1.0,
            infiltrado_factor=1.0,
            min_background=4096,
            min_infiltrado=4096,
            max_voxels_per_class=10000,
            max_voxels_per_class_min=20000,
        )
        
        if embeddings_sampled.numel() == 0:
            print(f"Batch {batch_idx}: No se encontraron vóxeles válidos, saltando")
            continue
        
        # Sub-batching para reducir el uso de memoria
        # batch_size_clf = 8192
        # Sub-batching
        if embeddings_sampled.shape[0] > 20000:            
            batch_size_clf = embeddings_sampled.shape[0] // 2 # El sub-batch es la mitad del batch original
        else:
            batch_size_clf = embeddings_sampled.shape[0]

        for i in range(0, embeddings_sampled.shape[0], batch_size_clf):
            z_batch = embeddings_sampled[i:i+batch_size_clf]
            labels_batch = labels_sampled[i:i+batch_size_clf]
            
            with torch.no_grad():
                z_batch = projection_head(z_batch)  # [N, 128]
                z_batch = F.normalize(z_batch, dim=1)
            
            logits = classifier(z_batch)
            loss = criterion(logits, labels_batch)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            valid_batches += 1
        
        # Imprimir estadísticas
        class_counts = np.bincount(labels_sampled.cpu().numpy(), minlength=3)
        if batch_idx % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, "
                  f"Sampled size: {embeddings_sampled.shape[0]}, Classes: {torch.unique(labels_sampled).tolist()}, "
                  f"Counts: {class_counts}")
    
    # Calcular pérdida promedio
    avg_loss = total_loss / max(valid_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Scheduler
    scheduler.step(avg_loss)
    
    # Checkpoint
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, best_model_path)
        print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Cargar el mejor modelo
checkpoint = torch.load(best_model_path, map_location=device)
classifier.load_state_dict(checkpoint['model_state_dict'])
print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")

# Guardar el modelo final
torch.save(classifier.state_dict(), os.path.join(output_dir, "supervised_classifier_final.pth"))
print("Clasificador final guardado en 'trained_models/checkpoints/supervised_classifier_final.pth'")

Epoch 1/100, Batch 0/30, Loss: 1.1016, Sampled size: 17091, Classes: [0, 1, 2], Counts: [5697 5697 5697]
Epoch 1/100, Batch 5/30, Loss: 0.9882, Sampled size: 11763, Classes: [0, 1, 2], Counts: [4096 3571 4096]
Muy pocos vóxeles de vasogénico (2), saltando
Batch 8: No se encontraron vóxeles válidos, saltando
Epoch 1/100, Batch 10/30, Loss: 0.9229, Sampled size: 17739, Classes: [0, 1, 2], Counts: [5913 5913 5913]
Epoch 1/100, Batch 15/30, Loss: 0.6428, Sampled size: 37756, Classes: [0, 1, 2], Counts: [10000 17756 10000]
Epoch 1/100, Batch 20/30, Loss: 0.5950, Sampled size: 11396, Classes: [0, 1, 2], Counts: [4096 3204 4096]
Epoch 1/100, Batch 25/30, Loss: 0.4642, Sampled size: 9559, Classes: [0, 1, 2], Counts: [4096 1367 4096]
Epoch 1/100, Average Loss: 0.7916, Valid Batches: 40/30, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.7916
Epoch 2/100, Batch 0/30, Loss: 0.4968, Sampled size: 17091, Classes: [0, 1, 2], Counts: [5697 5697 5697]
Epoch 2/100, Batch 5/30, Loss: 0.

## Random Forest, Adaboost, xgboost

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
import xgboost as xgb
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
import wandb
import joblib

# Función para muestreo personalizado
def custom_sample(embeddings_flat, labels_flat, background_factor=2.0, 
                  infiltrado_factor=3.0, min_background=10000, 
                  min_infiltrado=15000, max_voxels_per_class=10000, max_voxels_per_class_min=15000,):
    classes = torch.unique(labels_flat)
    sampled_embeddings = []
    sampled_labels = []
    
    # Contar vóxeles de vasogénico (clase 1)
    vasogenico_voxels = 0
    if 1 in classes:
        cls_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
        vasogenico_voxels = min(cls_indices.shape[0], max_voxels_per_class_min)
        if vasogenico_voxels < 10:
            print(f"Muy pocos vóxeles de vasogénico ({vasogenico_voxels}), saltando")
            return torch.tensor([]), torch.tensor([])
        indices = torch.randperm(cls_indices.shape[0], device=cls_indices.device)[:vasogenico_voxels]
        sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
        sampled_labels.append(labels_flat[cls_indices[indices]])
        # print(f"Vasogénico (1): {vasogenico_voxels} vóxeles")
    
    # Muestrear infiltrado (clase 2) y fondo (clase 0)
    for cls in classes:
        cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        
        if cls.item() == 2:  # Infiltrado
            target_voxels = min(int(vasogenico_voxels * infiltrado_factor), cls_size, max_voxels_per_class)
            # print(f"Infiltrado (2): vasogenico_voxels * infiltrado_factor = {int(vasogenico_voxels * infiltrado_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_infiltrado and cls_size >= min_infiltrado:
                target_voxels = min_infiltrado
                # print(f"Infiltrado (2): Aplicando mínimo de {min_infiltrado}")
            elif cls_size < min_infiltrado:
                target_voxels = cls_size
                # print(f"Infiltrado (2): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Infiltrado (2): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
        elif cls.item() == 0:  # Fondo
            target_voxels = min(int(vasogenico_voxels * background_factor), cls_size, max_voxels_per_class)
            # print(f"Fondo (0): vasogenico_voxels * background_factor = {int(vasogenico_voxels * background_factor)}, cls_size = {cls_size}, max_voxels_per_class = {max_voxels_per_class}")
            if target_voxels < min_background and cls_size >= min_background:
                target_voxels = min_background
                # print(f"Fondo (0): Aplicando mínimo de {min_background}")
            elif cls_size < min_background:
                target_voxels = cls_size
                # print(f"Fondo (0): Tomando todos los vóxeles disponibles ({cls_size})")
            # print(f"Fondo (0): target_voxels = {target_voxels}")
            indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
            sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
            sampled_labels.append(labels_flat[cls_indices[indices]])
    
    if not sampled_embeddings:
        return torch.tensor([]), torch.tensor([])
    
    return torch.cat(sampled_embeddings), torch.cat(sampled_labels)


In [2]:

# Dispositivo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# # Modelo de proyección Simple
# class ProjectionHead(nn.Module):
#     def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
#         super(ProjectionHead, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
    
#     def forward(self, x):
#         return self.net(x)

# Modelo de proyección (MLP más profundo)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)


# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1
random_state = 42
background_factor = 1.0  # Máximo de fondo: 2 * vasogénico
infiltrado_factor = 1.0  # Máximo de infiltrado: 3 * vasogénico
min_background = 10000 # Mínimo de fondo
min_infiltrado = 15000# Mínimo de infiltrado
max_voxels_per_class = 20000  # Límite máximo por clase
max_voxels_per_class_min=40000

# Cargar dataset y DataLoader
dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Cargar modelo contrastivo preentrenado
projection_head = ProjectionHead(input_dim=48).to(device)
# projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_fhosddxt.pth", map_location=device))
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new.pth", map_location=device))
projection_head.eval()

# Definir clasificadores
classifiers = {
    "RandomForest": RandomForestClassifier(
        n_estimators=100, max_depth=10, n_jobs=-1, random_state=random_state
    ),
    # "AdaBoost": AdaBoostClassifier(
    #     base_estimator=DecisionTreeClassifier(max_depth=3),
    #     n_estimators=50, random_state=random_state
    # ),
    "XGBoost": xgb.XGBClassifier(
        n_estimators=100, max_depth=6, learning_rate=0.1, n_jobs=-1, random_state=random_state
    )
}

# Directorio para resultados
output_dir = "trained_models/classifier_experiments"
os.makedirs(output_dir, exist_ok=True)

# # Función para muestreo personalizado
# def custom_sample(embeddings_flat, labels_flat, background_factor=2.0, infiltrado_factor=3.0, min_background=10000, min_infiltrado=15000, max_voxels_per_class=50000):
#     classes = torch.unique(labels_flat)
#     sampled_embeddings = []
#     sampled_labels = []
    
#     # Contar vóxeles de vasogénico (clase 1)
#     vasogenico_voxels = 0
#     if 1 in classes:
#         cls_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
#         vasogenico_voxels = min(cls_indices.shape[0], max_voxels_per_class)
#         indices = torch.randperm(cls_indices.shape[0], device=cls_indices.device)[:vasogenico_voxels]
#         sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
#         sampled_labels.append(labels_flat[cls_indices[indices]])
    
#     # Muestrear infiltrado (clase 2) y fondo (clase 0)
#     for cls in classes:
#         cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
#         cls_size = cls_indices.shape[0]
        
#         if cls.item() == 2:  # Infiltrado
#             target_voxels = min(int(vasogenico_voxels * infiltrado_factor), cls_size, max_voxels_per_class)
#             if target_voxels < min_infiltrado and cls_size >= min_infiltrado:
#                 target_voxels = min_infiltrado  # Mínimo de 15,000 si es menor
#             elif cls_size < min_infiltrado:
#                 target_voxels = cls_size  # Tomar todos si hay menos de 15,000
#             indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
#             sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
#             sampled_labels.append(labels_flat[cls_indices[indices]])
#         elif cls.item() == 0:  # Fondo
#             target_voxels = min(int(vasogenico_voxels * background_factor), cls_size, max_voxels_per_class)
#             if target_voxels < min_background and cls_size >= min_background:
#                 target_voxels = min_background  # Mínimo de 10,000 si es menor
#             elif cls_size < min_background:
#                 target_voxels = cls_size  # Tomar todos si hay menos de 10,000
#             indices = torch.randperm(cls_size, device=cls_indices.device)[:target_voxels]
#             sampled_embeddings.append(embeddings_flat[cls_indices[indices]])
#             sampled_labels.append(labels_flat[cls_indices[indices]])
    
#     if not sampled_embeddings:
#         return torch.tensor([]), torch.tensor([])
    
#     return torch.cat(sampled_embeddings), torch.cat(sampled_labels)

# Configurar WandB
os.environ["WANDB_NOTEBOOK_NAME"] = "use_contrastive.ipynb"

# Entrenamiento y evaluación
for clf_name, clf in classifiers.items():
    print(f"\nEntrenando {clf_name}...")
    
    # # Iniciar experimento en WandB
    # wandb.init(project="voxel_classifier", name=f"{clf_name}_05", config={
    #     "classifier": clf_name,
    #     "background_factor": background_factor,
    #     "min_background": min_background,
    #     "random_state": random_state,
    #     **clf.get_params()
    # })
    
    # Listas para datos de entrenamiento y validación
    all_embeddings = []
    all_labels = []
    
    # Recolectar datos balanceados
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]
        
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]
        
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]
        
        # Muestreo personalizado
        embeddings_sampled, labels_sampled = custom_sample(embeddings_flat, 
                                                           labels_flat, 
                                                           background_factor, 
                                                           infiltrado_factor, 
                                                           min_background, min_infiltrado, 
                                                           max_voxels_per_class,
                                                           max_voxels_per_class_min)
        if embeddings_sampled.numel() == 0:
            print(f"Batch {batch_idx}: No se encontraron vóxeles válidos, saltando")
            continue
        
        # Obtener representaciones contrastivas
        with torch.no_grad():
            z = projection_head(embeddings_sampled)  # [N, 128]
            z = F.normalize(z, dim=1)
        
        all_embeddings.append(z.cpu().numpy())
        all_labels.append(labels_sampled.cpu().numpy())
        
        # Imprimir estadísticas del muestreo
        class_counts = np.bincount(labels_sampled.cpu().numpy(), minlength=3)
        print(f"Batch {batch_idx}/{len(loader)}, Sampled size: {len(labels_sampled)}, "
              f"Classes: {np.unique(labels_sampled.cpu())}, Counts: {class_counts}")
    
    # Concatenar todos los datos
    X = np.concatenate(all_embeddings, axis=0)  # [N_total, 128]
    y = np.concatenate(all_labels, axis=0)  # [N_total]
    
    # Dividir en entrenamiento y validación
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=random_state
    )
    
    # Entrenar clasificador
    clf.fit(X_train, y_train)
    
    # Evaluar en entrenamiento
    y_train_pred = clf.predict(X_train)
    train_accuracy = accuracy_score(y_train, y_train_pred)
    train_f1 = f1_score(y_train, y_train_pred, average='weighted')
    
    # Evaluar en validación
    y_val_pred = clf.predict(X_val)
    val_accuracy = accuracy_score(y_val, y_val_pred)
    val_f1 = f1_score(y_val, y_val_pred, average='weighted')
    
    # Reporte detallado
    val_report = classification_report(y_val, y_val_pred, target_names=["Fondo", "Vasogénico", "Infiltrado"])
    
    # Registrar métricas en WandB
    # wandb.log({
    #     "train_accuracy": train_accuracy,
    #     "train_f1": train_f1,
    #     "val_accuracy": val_accuracy,
    #     "val_f1": val_f1
    # })
    
    # Registrar reporte detallado
    # wandb.log({"classification_report": wandb.Html(val_report.replace('\n', '<br>'))})
    
    # Guardar modelo
    model_path = os.path.join(output_dir, f"{clf_name}_model.pkl")
    joblib.dump(clf, model_path)
    print(f"Modelo {clf_name} guardado en {model_path}")
    
    # Finalizar experimento
    # wandb.finish()
    
    print(f"\nResultados para {clf_name}:")
    print(f"Train Accuracy: {train_accuracy:.4f}, Train F1: {train_f1:.4f}")
    print(f"Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")
    print("Classification Report:\n", val_report)

print("Entrenamiento y evaluación completados.")


Entrenando RandomForest...
Batch 0/30, Sampled size: 26367, Classes: [0 1 2], Counts: [10000  1367 15000]
Batch 1/30, Sampled size: 27325, Classes: [0 1 2], Counts: [10000  2325 15000]
Batch 2/30, Sampled size: 28204, Classes: [0 1 2], Counts: [10000  3204 15000]
Batch 3/30, Sampled size: 31331, Classes: [0 1 2], Counts: [10000  6331 15000]
Batch 4/30, Sampled size: 30913, Classes: [0 1 2], Counts: [10000  5913 15000]
Batch 5/30, Sampled size: 26032, Classes: [0 1 2], Counts: [10000  1032 15000]
Batch 6/30, Sampled size: 25454, Classes: [0 1 2], Counts: [10000   454 15000]
Batch 7/30, Sampled size: 73491, Classes: [0 1 2], Counts: [20000 33491 20000]
Batch 8/30, Sampled size: 39888, Classes: [0 1 2], Counts: [12444 12444 15000]
Batch 9/30, Sampled size: 33917, Classes: [0 1 2], Counts: [10000  8917 15000]
Batch 10/30, Sampled size: 26930, Classes: [0 1 2], Counts: [10000  1930 15000]
Batch 11/30, Sampled size: 25339, Classes: [0 1 2], Counts: [10000   339 15000]
Batch 12/30, Sampled s

## Hacer inferencia guardar mapas y calcular metricas

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import nibabel as nib

# Dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# # Modelo de proyección Simple
# class ProjectionHead(nn.Module):
#     def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
#         super(ProjectionHead, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
    
#     def forward(self, x):
#         return self.net(x)   

# Modelo de proyección (MLP más profundo)
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)
    

# # Clasificador supervisado Lineal
# class Classifier(nn.Module):
#     def __init__(self, input_dim=128, num_classes=3):
#         super(Classifier, self).__init__()
#         self.fc = nn.Linear(input_dim, num_classes)
    
#     def forward(self, x):
#         return self.fc(x)
    
# Clasificador supervisado (MLP)
class Classifier(nn.Module):
    def __init__(self, input_dim=128, hidden_dim1=256, hidden_dim2=128, num_classes=3, dropout_p=0.3):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            # nn.Dropout(dropout_p),
            nn.Linear(hidden_dim2, num_classes)
        )
    
    def forward(self, x):
        return self.net(x)
       
# Función para generar mapas de probabilidad
def generate_probability_maps(embeddings, projection_head, classifier, device):
    """
    embeddings: tensor [1, 48, 128, 128, 128] - Características de SwinUNETR
    Retorna: mapas de probabilidad [3, 128, 128, 128]
    """
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        
        z = projection_head(embeddings_flat)  # [2097152, 128]
        z = F.normalize(z, dim=1)
        
        logits = classifier(z)  # [2097152, 3]
        probs = F.softmax(logits, dim=1)  # [2097152, 3]
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)  # [3, 128, 128, 128]
        return probs

dataset = EmbeddingDataset(embedding_dir="Dataset/contrastive_voxel_wise/valid_1dhzmigz/embeddings", 
                          label_dir="Dataset/contrastive_voxel_wise/valid_1dhzmigz/labels")
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

# Cargar modelos (asumiendo que ya los tienes cargados)
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new_pipe2_m1_1dhzmigz.pth", map_location=device))
projection_head.eval()

classifier = Classifier(input_dim=128, num_classes=3).to(device)
classifier.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_pipe2_m1_1dhzmigz.pth", map_location=device))
classifier.eval()

# Directorio de salida
output_dir = "trained_models/mapas_valid_pipe2_m1_1dhzmigz"
os.makedirs(output_dir, exist_ok=True)

# Función para generar mapas de probabilidad
def generate_probability_maps(embeddings, projection_head, classifier, device):
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)
        embeddings_flat = embeddings.reshape(-1, 48)
        
        z = projection_head(embeddings_flat)
        z = F.normalize(z, dim=1)
        
        logits = classifier(z)
        probs = F.softmax(logits, dim=1)
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)
        return probs

# Funciones para calcular métricas
def calculate_metrics(pred, true, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    
    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        # True Positives (TP), False Positives (FP), False Negatives (FN)
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        # Dice
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)  # Evitar división por 0
        dice_scores.append(dice)
        
        # Sensibilidad (Recall)
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        # Precisión
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
    
    return dice_scores, sensitivity_scores, precision_scores
# Directorio de salida
# output_dir = "trained_models/mapas"
# os.makedirs(output_dir, exist_ok=True)

# Listas para almacenar métricas por caso
all_dice = {0: [], 1: [], 2: []}  # Fondo, Vasogénico, Infiltrado
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}

# Procesar y guardar como NIfTI
for idx, (embeddings, labels) in enumerate(loader):
    # Generar mapas de probabilidad
    prob_maps = generate_probability_maps(embeddings, projection_head, classifier, device)
    print(f"Mapas de probabilidad para caso {idx}, shape: {prob_maps.shape}")
    
    # Convertir mapas de probabilidad a numpy
    prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    
    # Generar segmentación semántica
    segmentation = np.argmax(prob_maps_np, axis=0)  # [128, 128, 128]
    segmentation_np = segmentation.astype(np.uint8)
    
    # Convertir etiquetas a numpy
    labels = labels.squeeze(0)  # [128, 128, 128]
    labels_np = labels.cpu().numpy().astype(np.uint8)

    # hacer cero segmentation_np en donde labels_np es cero
    # segmentation_np[labels_np == 0] = 0
    
    # Calcular métricas
    dice, sensitivity, precision = calculate_metrics(segmentation_np, labels_np)
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
    
    print(f"Caso {idx} - Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}")
    
    # Crear imágenes NIfTI
    affine = np.eye(4)
    
    # Guardar mapas de probabilidad
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    prob_output_path = os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz")
    nib.save(nifti_prob_img, prob_output_path)
    print(f"Guardado mapa de probabilidad en {prob_output_path}")
    
    # Guardar etiquetas
    nifti_label_img = nib.Nifti1Image(labels_np, affine)
    label_output_path = os.path.join(output_dir, f"labels_case_{idx}.nii.gz")
    nib.save(nifti_label_img, label_output_path)
    print(f"Guardadas etiquetas en {label_output_path}")
    
    # Guardar segmentación semántica
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    seg_output_path = os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz")
    nib.save(nifti_seg_img, seg_output_path)
    print(f"Guardada segmentación en {seg_output_path}")

# Calcular promedios y desviaciones estándar
class_names = ["Fondo", "Infiltrado", "Vasogénico"]
for cls in range(3):
    dice_mean = np.mean(all_dice[cls])
    dice_std = np.std(all_dice[cls])
    sens_mean = np.mean(all_sensitivity[cls])
    sens_std = np.std(all_sensitivity[cls])
    prec_mean = np.mean(all_precision[cls])
    prec_std = np.std(all_precision[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")

Mapas de probabilidad para caso 0, shape: torch.Size([3, 128, 128, 128])
Caso 0 - Dice: [0.9975996073793137, 0.40547157974856896, 0.693651942576132], Sensitivity: [0.9952967624783551, 0.33052916335673893, 0.8597227076579349], Precision: [0.9999131333011678, 0.5243627989467544, 0.5813532081088277]
Guardado mapa de probabilidad en trained_models/mapas_valid_pipe2_m1_1dhzmigz/probability_maps_case_0.nii.gz
Guardadas etiquetas en trained_models/mapas_valid_pipe2_m1_1dhzmigz/labels_case_0.nii.gz
Guardada segmentación en trained_models/mapas_valid_pipe2_m1_1dhzmigz/segmentation_case_0.nii.gz
Mapas de probabilidad para caso 1, shape: torch.Size([3, 128, 128, 128])
Caso 1 - Dice: [0.9935457241878816, 0.3192793234269347, 0.652329860884344], Sensitivity: [0.9872634080825368, 0.7268415176543411, 0.8881858331540604], Precision: [0.9999085056472076, 0.20457044132208493, 0.5154524842144259]
Guardado mapa de probabilidad en trained_models/mapas_valid_pipe2_m1_1dhzmigz/probability_maps_case_1.nii.gz
G

## Evaluar Voxel-wise y Region-wise

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, f1_score
from scipy import stats
import matplotlib.pyplot as plt

# Dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# Modelo de proyección
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# Clasificador supervisado
class Classifier(nn.Module):
    def __init__(self, input_dim=128, hidden_dim1=256, hidden_dim2=128, num_classes=3, dropout_p=0.3):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, num_classes)
        )
    
    def forward(self, x):
        return self.net(x)

# Función para generar mapas de probabilidad
def generate_probability_maps(embeddings, projection_head, classifier, device):
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        
        z = projection_head(embeddings_flat)  # [2097152, 128]
        z = F.normalize(z, dim=1)
        
        logits = classifier(z)  # [2097152, 3]
        probs = F.softmax(logits, dim=1)  # [2097152, 3]
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)  # [3, 128, 128, 128]
        return probs

# Función para calcular métricas
def calculate_metrics(pred, true, prob_maps=None, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    auc_scores = []
    f1_scores = []
    
    accuracy = accuracy_score(true.flatten(), pred.flatten())
    
    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        # True Positives (TP), False Positives (FP), False Negatives (FN)
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        # Dice
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        # Sensibilidad (Recall)
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        # Precisión
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
        
        # F1 Score
        f1 = f1_score(true_cls.flatten(), pred_cls.flatten(), zero_division=0)
        f1_scores.append(f1)
        
        # AUC-ROC
        if prob_maps is not None:
            try:
                auc = roc_auc_score(true_cls.flatten(), prob_maps[cls].flatten())
                auc_scores.append(auc)
            except ValueError:
                auc_scores.append(np.nan)
        else:
            auc_scores.append(np.nan)
    
    return dice_scores, sensitivity_scores, precision_scores, auc_scores, accuracy, f1_scores

# Función para dividir en cubos y obtener clases predominantes
def get_cube_labels(volume, cube_size, num_classes=3):
    dims = volume.shape
    assert dims[0] % cube_size == 0, "El tamaño del cubo debe dividir exactamente el tamaño del volumen"
    num_cubes = dims[0] // cube_size
    
    cube_labels = np.zeros((num_cubes, num_cubes, num_cubes), dtype=np.uint8)
    cube_probs = np.zeros((num_classes, num_cubes, num_cubes, num_cubes))
    
    for i in range(num_cubes):
        for j in range(num_cubes):
            for k in range(num_cubes):
                cube = volume[i*cube_size:(i+1)*cube_size, 
                             j*cube_size:(j+1)*cube_size, 
                             k*cube_size:(k+1)*cube_size]
                # Clase predominante (modo)
                mode_value = stats.mode(cube.flatten(), keepdims=True)[0][0]
                cube_labels[i, j, k] = mode_value
                # Proporción de cada clase como "probabilidad" suavizada
                for cls in range(num_classes):
                    cube_probs[cls, i, j, k] = np.mean(cube == cls)
    
    return cube_labels, cube_probs

# Cargar dataset y modelos
dataset = EmbeddingDataset(embedding_dir="Dataset/contrastive_voxel_wise/valid_1dhzmigz/embeddings", 
                          label_dir="Dataset/contrastive_voxel_wise/valid_1dhzmigz/labels")
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new_pipe2_m1_1dhzmigz.pth", map_location=device))
projection_head.eval()

classifier = Classifier(input_dim=128, num_classes=3).to(device)
classifier.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_pipe2_m1_1dhzmigz.pth", map_location=device))
classifier.eval()

# Directorio de salida
output_dir = "trained_models/mapas_valid_pipe2_m1_1dhzmigz"
os.makedirs(output_dir, exist_ok=True)

# Listas para métricas voxel-wise
all_dice = {0: [], 1: [], 2: []}
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}
all_auc = {0: [], 1: [], 2: []}
all_f1 = {0: [], 1: [], 2: []}
all_accuracy = []
all_fpr = {0: [], 1: [], 2: []}
all_tpr = {0: [], 1: [], 2: []}
all_center_distance_voxel = []

# Listas para métricas cube-wise
all_dice_cube = {0: [], 1: [], 2: []}
all_sensitivity_cube = {0: [], 1: [], 2: []}
all_precision_cube = {0: [], 1: [], 2: []}
all_auc_cube = {0: [], 1: [], 2: []}
all_f1_cube = {0: [], 1: [], 2: []}
all_accuracy_cube = []
all_fpr_cube = {0: [], 1: [], 2: []}
all_tpr_cube = {0: [], 1: [], 2: []}
all_center_distance = []

# Tamaño del cubo
cube_size = 8  # 128 / 8 = 16 cubos por dimensión

# Procesar y combinar
for idx, (embeddings, labels) in enumerate(loader):
    # Generar mapas de probabilidad
    prob_maps = generate_probability_maps(embeddings, projection_head, classifier, device)
    print(f"Mapas de probabilidad para caso {idx}, shape: {prob_maps.shape}")
    
    # Convertir a numpy
    prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    
    # Generar segmentación semántica
    segmentation = np.argmax(prob_maps_np, axis=0)
    segmentation_np = segmentation.astype(np.uint8)
    
    # Etiquetas
    labels_np = labels.squeeze(0).cpu().numpy().astype(np.uint8)
    
    # Calcular métricas voxel-wise
    dice, sensitivity, precision, auc, accuracy, f1 = calculate_metrics(segmentation_np, labels_np, prob_maps=prob_maps_np)
    
    # Almacenar métricas voxel-wise
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
        all_auc[cls].append(auc[cls])
        all_f1[cls].append(f1[cls])
    all_accuracy.append(accuracy)

    # ***** NUEVO: Análisis espacial voxel-wise para la clase Infiltrado (Clase 1) *****
    infiltrado_pred_voxel = (segmentation_np == 1).astype(np.uint8)
    infiltrado_true_voxel = (labels_np == 1).astype(np.uint8)
    if np.sum(infiltrado_pred_voxel) > 0:
        pred_center_voxel = np.mean(np.where(infiltrado_pred_voxel), axis=1)
    else:
        pred_center_voxel = np.array([np.nan, np.nan, np.nan])
    if np.sum(infiltrado_true_voxel) > 0:
        true_center_voxel = np.mean(np.where(infiltrado_true_voxel), axis=1)
    else:
        true_center_voxel = np.array([np.nan, np.nan, np.nan])
    distance_voxel = np.linalg.norm(pred_center_voxel - true_center_voxel) if not np.any(np.isnan(pred_center_voxel)) and not np.any(np.isnan(true_center_voxel)) else np.nan
    all_center_distance_voxel.append(distance_voxel)
    
    # Graficar curvas ROC voxel-wise
    plt.figure(figsize=(8, 6))
    class_names = ["Fondo", "Infiltrado", "Vasogénico"]
    colors = ['blue', 'green', 'red']
    
    for cls in range(3):
        true_cls = (labels_np == cls).astype(np.uint8).flatten()
        prob_cls = prob_maps_np[cls].flatten()
        fpr, tpr, _ = roc_curve(true_cls, prob_cls)
        auc_value = auc[cls]
        all_fpr[cls].append(fpr)
        all_tpr[cls].append(tpr)
        plt.plot(fpr, tpr, color=colors[cls], label=f'{class_names[cls]} (AUC = {auc_value:.4f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos (FPR)')
    plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
    plt.title(f'Curva ROC Voxel-wise - Caso {idx}')
    plt.legend(loc="lower right")
    roc_output_path = os.path.join(output_dir, f"roc_curve_case_{idx}.png")
    plt.savefig(roc_output_path)
    plt.close()
    print(f"Guardada curva ROC voxel-wise en {roc_output_path}")
    
    # Evaluación basada en cubos
    pred_cube_labels, pred_cube_probs = get_cube_labels(segmentation_np, cube_size)
    true_cube_labels, true_cube_probs = get_cube_labels(labels_np, cube_size)
    
    # Calcular métricas cube-wise
    dice_cube, sensitivity_cube, precision_cube, auc_cube, accuracy_cube, f1_cube = calculate_metrics(
        pred_cube_labels, true_cube_labels, prob_maps=pred_cube_probs
    )
    
    # Almacenar métricas cube-wise
    for cls in range(3):
        all_dice_cube[cls].append(dice_cube[cls])
        all_sensitivity_cube[cls].append(sensitivity_cube[cls])
        all_precision_cube[cls].append(precision_cube[cls])
        all_auc_cube[cls].append(auc_cube[cls])
        all_f1_cube[cls].append(f1_cube[cls])
    all_accuracy_cube.append(accuracy_cube)
    
    # Graficar curvas ROC cube-wise
    plt.figure(figsize=(8, 6))
    for cls in range(3):
        true_cls_cube = (true_cube_labels == cls).astype(np.uint8).flatten()
        prob_cls_cube = pred_cube_probs[cls].flatten()
        fpr_cube, tpr_cube, _ = roc_curve(true_cls_cube, prob_cls_cube)
        auc_value_cube = auc_cube[cls]
        all_fpr_cube[cls].append(fpr_cube)
        all_tpr_cube[cls].append(tpr_cube)
        plt.plot(fpr_cube, tpr_cube, color=colors[cls], label=f'{class_names[cls]} (AUC = {auc_value_cube:.4f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Tasa de Falsos Positivos (FPR)')
    plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
    plt.title(f'Curva ROC Cube-wise - Caso {idx}')
    plt.legend(loc="lower right")
    roc_cube_output_path = os.path.join(output_dir, f"roc_curve_cube_case_{idx}.png")
    plt.savefig(roc_cube_output_path)
    plt.close()
    print(f"Guardada curva ROC cube-wise en {roc_cube_output_path}")
    
    # Mapa de coincidencias/discrepancias
    match_map = (pred_cube_labels == true_cube_labels).astype(np.uint8)
    mismatch_map = (pred_cube_labels != true_cube_labels).astype(np.uint8)
    
    # Guardar mapas de cubos
    affine_cube = np.eye(4) * cube_size
    nib.save(nib.Nifti1Image(pred_cube_labels, affine_cube), os.path.join(output_dir, f"pred_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(true_cube_labels, affine_cube), os.path.join(output_dir, f"true_cube_labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(match_map, affine_cube), os.path.join(output_dir, f"match_map_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(mismatch_map, affine_cube), os.path.join(output_dir, f"mismatch_map_case_{idx}.nii.gz"))
    
    # Análisis espacial (centro de masa de la clase Infiltrado, clase 1)
    infiltrado_pred = (pred_cube_labels == 1).astype(np.uint8)
    infiltrado_true = (true_cube_labels == 1).astype(np.uint8)
    if np.sum(infiltrado_pred) > 0:
        pred_center = np.mean(np.where(infiltrado_pred), axis=1)
    else:
        pred_center = np.array([np.nan, np.nan, np.nan])
    if np.sum(infiltrado_true) > 0:
        true_center = np.mean(np.where(infiltrado_true), axis=1)
    else:
        true_center = np.array([np.nan, np.nan, np.nan])
    distance = np.linalg.norm(pred_center - true_center) if not np.any(np.isnan(pred_center)) and not np.any(np.isnan(true_center)) else np.nan
    all_center_distance.append(distance)
    
    # Imprimir métricas por caso
    print(f"Caso {idx} - Voxel-wise:")
    print(f"  Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}, AUC-ROC: {auc}, "
          f"Accuracy: {accuracy:.4f}, F1 Score: {f1}")
    print(f"  Centro de masa Infiltrado (Pred, Voxel-wise): {pred_center_voxel}")
    print(f"  Centro de masa Infiltrado (True, Voxel-wise): {true_center_voxel}")
    print(f"  Distancia entre centros (Voxel-wise): {distance_voxel}")

    print(f"Caso {idx} - Cube-wise (tamaño {cube_size}):")
    print(f"  Dice: {dice_cube}, Sensitivity: {sensitivity_cube}, Precision: {precision_cube}, "
          f"AUC-ROC: {auc_cube}, Accuracy: {accuracy_cube:.4f}, F1 Score: {f1_cube}")
    print(f"  Centro de masa Infiltrado (Pred): {pred_center}")
    print(f"  Centro de masa Infiltrado (True): {true_center}")
    print(f"  Distancia entre centros: {distance}")
    
    # Guardar mapas de probabilidad y segmentaciones voxel-wise
    nib.save(nib.Nifti1Image(prob_maps_np_nifti, np.eye(4)), os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(labels_np, np.eye(4)), os.path.join(output_dir, f"labels_case_{idx}.nii.gz"))
    nib.save(nib.Nifti1Image(segmentation_np, np.eye(4)), os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz"))

# Calcular promedios y desviaciones estándar (voxel-wise)
print("\nResultados Voxel-wise:")
for cls in range(3):
    dice_mean = np.nanmean(all_dice[cls])
    dice_std = np.nanstd(all_dice[cls])
    sens_mean = np.nanmean(all_sensitivity[cls])
    sens_std = np.nanstd(all_sensitivity[cls])
    prec_mean = np.nanmean(all_precision[cls])
    prec_std = np.nanstd(all_precision[cls])
    auc_mean = np.nanmean(all_auc[cls])
    auc_std = np.nanstd(all_auc[cls])
    f1_mean = np.nanmean(all_f1[cls])
    f1_std = np.nanstd(all_f1[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean = np.nanmean(all_accuracy)
accuracy_std = np.nanstd(all_accuracy)
print(f"\nAccuracy Global Voxel-wise: {accuracy_mean:.4f} ± {accuracy_std:.4f}")

# ***** NUEVO: Calcular promedio y desviación estándar para distancia voxel-wise *****
dist_mean_voxel = np.nanmean(all_center_distance_voxel)
dist_std_voxel = np.nanstd(all_center_distance_voxel)
print(f"\nDistancia entre centros Global (Voxel-wise): {dist_mean_voxel:.4f} ± {dist_std_voxel:.4f}")

# Curva ROC promedio voxel-wise
plt.figure(figsize=(8, 6))
for cls in range(3):
    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    for fpr, tpr in zip(all_fpr[cls], all_tpr[cls]):
        tpr_interp = np.interp(mean_fpr, fpr, tpr)
        tpr_interp[0] = 0.0
        tprs.append(tpr_interp)
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = np.nanmean(all_auc[cls])
    
    plt.plot(mean_fpr, mean_tpr, color=colors[cls], label=f'{class_names[cls]} (AUC = {mean_auc:.4f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Tasa de Falsos Positivos (FPR)')
plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
plt.title('Curva ROC Promedio Voxel-wise')
plt.legend(loc="lower right")
roc_avg_path = os.path.join(output_dir, "roc_curve_average.png")
plt.savefig(roc_avg_path)
plt.close()
print(f"Guardada curva ROC promedio voxel-wise en {roc_avg_path}")

# Calcular promedios y desviaciones estándar (cube-wise)
print("\nResultados Cube-wise:")
for cls in range(3):
    dice_mean = np.nanmean(all_dice_cube[cls])
    dice_std = np.nanstd(all_dice_cube[cls])
    sens_mean = np.nanmean(all_sensitivity_cube[cls])
    sens_std = np.nanstd(all_sensitivity_cube[cls])
    prec_mean = np.nanmean(all_precision_cube[cls])
    prec_std = np.nanstd(all_precision_cube[cls])
    auc_mean = np.nanmean(all_auc_cube[cls])
    auc_std = np.nanstd(all_auc_cube[cls])
    f1_mean = np.nanmean(all_f1_cube[cls])
    f1_std = np.nanstd(all_f1_cube[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    print(f"  AUC-ROC: {auc_mean:.4f} ± {auc_std:.4f}")
    print(f"  F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

accuracy_mean_cube = np.nanmean(all_accuracy_cube)
accuracy_std_cube = np.nanstd(all_accuracy_cube)
print(f"\nAccuracy Global Cube-wise: {accuracy_mean_cube:.4f} ± {accuracy_std_cube:.4f}")

dist_mean = np.nanmean(all_center_distance)
dist_std = np.nanstd(all_center_distance)
print(f"\nDistancia entre centros Global: {dist_mean:.4f} ± {dist_std:.4f}")

# Curva ROC promedio cube-wise
print("\nGenerando curva ROC promedio para métricas cube-wise...")
plt.figure(figsize=(8, 6))
for cls in range(3):
    mean_fpr_cube = np.linspace(0, 1, 100)
    tprs_cube = []
    for fpr, tpr in zip(all_fpr_cube[cls], all_tpr_cube[cls]):
        tpr_interp = np.interp(mean_fpr_cube, fpr, tpr)
        tpr_interp[0] = 0.0
        tprs_cube.append(tpr_interp)
    mean_tpr_cube = np.mean(tprs_cube, axis=0)
    mean_tpr_cube[-1] = 1.0
    mean_auc_cube = np.nanmean(all_auc_cube[cls])
    
    plt.plot(mean_fpr_cube, mean_tpr_cube, color=colors[cls], label=f'{class_names[cls]} (AUC = {mean_auc_cube:.4f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Tasa de Falsos Positivos (FPR)')
plt.ylabel('Tasa de Verdaderos Positivos (TPR)')
plt.title('Curva ROC Promedio Cube-wise')
plt.legend(loc="lower right")
roc_avg_cube_path = os.path.join(output_dir, "roc_curve_average_cube.png")
plt.savefig(roc_avg_cube_path)
plt.close()
print(f"Guardada curva ROC promedio cube-wise en {roc_avg_cube_path}")

Mapas de probabilidad para caso 0, shape: torch.Size([3, 128, 128, 128])
Guardada curva ROC voxel-wise en trained_models/mapas_valid_pipe2_m1_1dhzmigz/roc_curve_case_0.png
Guardada curva ROC cube-wise en trained_models/mapas_valid_pipe2_m1_1dhzmigz/roc_curve_cube_case_0.png
Caso 0 - Voxel-wise:
  Dice: [0.9975996073793137, 0.40547157974856896, 0.693651942576132], Sensitivity: [0.9952967624783551, 0.33052916335673893, 0.8597227076579349], Precision: [0.9999131333011678, 0.5243627989467544, 0.5813532081088277], AUC-ROC: [0.9998564277133127, 0.9906417912065751, 0.9926824584823732], Accuracy: 0.9829, F1 Score: [0.9975996073795605, 0.40547157975677367, 0.6936519425826526]
  Centro de masa Infiltrado (Pred, Voxel-wise): [59.00476265 71.05636678 67.55492751]
  Centro de masa Infiltrado (True, Voxel-wise): [64.85411718 61.54701108 57.60398522]
  Distancia entre centros (Voxel-wise): 14.955401920357406
Caso 0 - Cube-wise (tamaño 8):
  Dice: [0.998103426349968, 0.43010752225690835, 0.66315789124

## Evaluar varios clasificadores

In [7]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
import joblib
import wandb
from sklearn.metrics import classification_report

# Dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset (asumiendo que ya lo tienes definido)
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# Modelo de proyección Simple
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# # Modelo de proyección (MLP más profundo)
# class ProjectionHead(nn.Module):
#     def __init__(self, input_dim=48, hidden_dim1=256, hidden_dim2=128, output_dim=128, dropout_p=0.3):
#         super(ProjectionHead, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim1),
#             nn.ReLU(),
#             # nn.Dropout(dropout_p),
#             nn.Linear(hidden_dim1, hidden_dim2),
#             nn.ReLU(),
#             # nn.Dropout(dropout_p),
#             nn.Linear(hidden_dim2, output_dim)
#         )
    
#     def forward(self, x):
#         return self.net(x)
    
# Clasificador lineal original
class Classifier(nn.Module):
    def __init__(self, input_dim=128, num_classes=3):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        return self.fc(x)

# Función para generar mapas de probabilidad (adaptada para modelos PyTorch y scikit-learn)
def generate_probability_maps(embeddings, projection_head, classifier, device, is_pytorch_model=True):
    """
    embeddings: tensor [1, 48, 128, 128, 128] - Características de SwinUNETR
    projection_head: modelo de proyección (PyTorch)
    classifier: modelo clasificador (PyTorch o scikit-learn/XGBoost)
    is_pytorch_model: True para PyTorch, False para scikit-learn/XGBoost
    Retorna: mapas de probabilidad [3, 128, 128, 128]
    """
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        
        z = projection_head(embeddings_flat)  # [2097152, 128]
        z = F.normalize(z, dim=1)
        
        if is_pytorch_model:
            logits = classifier(z)  # [2097152, 3]
            probs = F.softmax(logits, dim=1)  # [2097152, 3]
        else:
            # scikit-learn/XGBoost espera arrays de NumPy
            z_np = z.cpu().numpy()
            probs = classifier.predict_proba(z_np)  # [2097152, 3]
            probs = torch.tensor(probs, dtype=torch.float32, device=device)
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)  # [3, 128, 128, 128]
        return probs

# Funciones para calcular métricas
def calculate_metrics(pred, true, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    
    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        # True Positives (TP), False Positives (FP), False Negatives (FN)
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        # Dice
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        # Sensibilidad (Recall)
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        # Precisión
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
    
    return dice_scores, sensitivity_scores, precision_scores

# Configuración
embedding_dir = "Dataset/contrastive_voxel_wise/valid/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/valid/labels"
output_dir = "trained_models/mapas"
os.makedirs(output_dir, exist_ok=True)

# Cargar dataset
dataset = EmbeddingDataset(embedding_dir=embedding_dir, label_dir=label_dir)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

# Cargar modelo de proyección
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_fhosddxt.pth", map_location=device))
# projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_new.pth", map_location=device))
projection_head.eval()

# Definir clasificadores
classifiers = {
    # "Linear": (Classifier(input_dim=128, num_classes=3).to(device), True),
    "RandomForest": (joblib.load("trained_models/classifier_experiments-v05/RandomForest_model.pkl"), False),
    # "AdaBoost": (joblib.load("trained_models/classifier_experiments/AdaBoost_model.pkl"), False),
    "XGBoost": (joblib.load("trained_models/classifier_experiments-v05/XGBoost_model.pkl"), False)
}

# # Cargar estado para el clasificador lineal
# classifiers["Linear"][0].load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_fhosddxt.pth", map_location=device))
# classifiers["Linear"][0].eval()

# # Configurar WandB
# os.environ["WANDB_NOTEBOOK_NAME"] = "use_contrastive.ipynb"  # Ajusta si es necesario
# wandb.login()

# Evaluar cada clasificador
for clf_name, (classifier, is_pytorch_model) in classifiers.items():
    print(f"\nEvaluando {clf_name}...")
    
    # # Iniciar experimento en WandB
    # wandb.init(project="voxel_classifier_evaluation", name=f"eval_{clf_name}", config={
    #     "classifier": clf_name,
    #     "dataset": "contrastive_voxel_wise"
    # })
    
    # Listas para almacenar métricas por caso
    all_dice = {0: [], 1: [], 2: []}  # Fondo, Vasogénico, Infiltrado
    all_sensitivity = {0: [], 1: [], 2: []}
    all_precision = {0: [], 1: [], 2: []}
    
    # Procesar y guardar como NIfTI
    for idx, (embeddings, labels) in enumerate(loader):
        # Generar mapas de probabilidad
        prob_maps = generate_probability_maps(embeddings, projection_head, classifier, device, is_pytorch_model)
        print(f"Mapas de probabilidad para caso {idx}, shape: {prob_maps.shape}")
        
        # Convertir mapas de probabilidad a numpy
        prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128]
        prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
        
        # Generar segmentación semántica
        segmentation = np.argmax(prob_maps_np, axis=0)  # [128, 128, 128]
        segmentation_np = segmentation.astype(np.uint8)
        
        # Convertir etiquetas a numpy
        labels = labels.squeeze(0)  # [128, 128, 128]
        labels_np = labels.cpu().numpy().astype(np.uint8)
        
        # Calcular métricas
        dice, sensitivity, precision = calculate_metrics(segmentation_np, labels_np)
        for cls in range(3):
            all_dice[cls].append(dice[cls])
            all_sensitivity[cls].append(sensitivity[cls])
            all_precision[cls].append(precision[cls])
        
        # Generar reporte detallado por caso
        report = classification_report(labels_np.flatten(), segmentation_np.flatten(), 
                                     target_names=["Fondo", "Vasogénico", "Infiltrado"], output_dict=True)
        
        # # Registrar métricas por caso en WandB
        # wandb.log({
        #     f"case_{idx}/dice_fondo": dice[0],
        #     f"case_{idx}/dice_vasogenico": dice[1],
        #     f"case_{idx}/dice_infiltrado": dice[2],
        #     f"case_{idx}/sensitivity_fondo": sensitivity[0],
        #     f"case_{idx}/sensitivity_vasogenico": sensitivity[1],
        #     f"case_{idx}/sensitivity_infiltrado": sensitivity[2],
        #     f"case_{idx}/precision_fondo": precision[0],
        #     f"case_{idx}/precision_vasogenico": precision[1],
        #     f"case_{idx}/precision_infiltrado": precision[2],
        # })
        
        # print(f"Caso {idx} - Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}")
        
        # Crear imágenes NIfTI
        affine = np.eye(4)
        
        # Guardar mapas de probabilidad
        nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
        prob_output_path = os.path.join(output_dir, f"{clf_name}_probability_maps_case_{idx}.nii.gz")
        nib.save(nifti_prob_img, prob_output_path)
        print(f"Guardado mapa de probabilidad en {prob_output_path}")
        
        # Guardar etiquetas
        nifti_label_img = nib.Nifti1Image(labels_np, affine)
        label_output_path = os.path.join(output_dir, f"{clf_name}_labels_case_{idx}.nii.gz")
        nib.save(nifti_label_img, label_output_path)
        print(f"Guardadas etiquetas en {label_output_path}")
        
        # Guardar segmentación semántica
        nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
        seg_output_path = os.path.join(output_dir, f"{clf_name}_segmentation_case_{idx}.nii.gz")
        nib.save(nifti_seg_img, seg_output_path)
        print(f"Guardada segmentación en {seg_output_path}")
    
    # Calcular promedios y desviaciones estándar
    class_names = ["Fondo", "Vasogénico", "Infiltrado"]
    metrics_summary = {}
    for cls in range(3):
        dice_mean = np.mean(all_dice[cls])
        dice_std = np.std(all_dice[cls])
        sens_mean = np.mean(all_sensitivity[cls])
        sens_std = np.std(all_sensitivity[cls])
        prec_mean = np.mean(all_precision[cls])
        prec_std = np.std(all_precision[cls])
        
        # metrics_summary.update({
        #     f"mean_dice_{class_names[cls]}": dice_mean,
        #     f"std_dice_{class_names[cls]}": dice_std,
        #     f"mean_sensitivity_{class_names[cls]}": sens_mean,
        #     f"std_sensitivity_{class_names[cls]}": sens_std,
        #     f"mean_precision_{class_names[cls]}": prec_mean,
        #     f"std_precision_{class_names[cls]}": prec_std
        # })
        
        print(f"\nClase {cls} ({class_names[cls]}):")
        print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
        print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
        print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")
    
    # # Registrar métricas agregadas en WandB
    # wandb.log(metrics_summary)
    
    # # Finalizar experimento
    # wandb.finish()

print("Evaluación completada.")


Evaluando RandomForest...
Mapas de probabilidad para caso 0, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad en trained_models/mapas/RandomForest_probability_maps_case_0.nii.gz
Guardadas etiquetas en trained_models/mapas/RandomForest_labels_case_0.nii.gz
Guardada segmentación en trained_models/mapas/RandomForest_segmentation_case_0.nii.gz
Mapas de probabilidad para caso 1, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad en trained_models/mapas/RandomForest_probability_maps_case_1.nii.gz
Guardadas etiquetas en trained_models/mapas/RandomForest_labels_case_1.nii.gz
Guardada segmentación en trained_models/mapas/RandomForest_segmentation_case_1.nii.gz
Mapas de probabilidad para caso 2, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad en trained_models/mapas/RandomForest_probability_maps_case_2.nii.gz
Guardadas etiquetas en trained_models/mapas/RandomForest_labels_case_2.nii.gz
Guardada segmentación en trained_models/mapas/RandomFo