In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Loading functions
import os
import time
from monai.data import DataLoader, decollate_batch


import torch
import torch.nn.parallel

from src.get_data import CustomDataset, CustomDatasetSeg
import numpy as np
from scipy import ndimage
from types import SimpleNamespace
import wandb
import logging

#####
import json
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    MapTransform,
    Transform,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR
from monai import data

# from monai.data import decollate_batch
from functools import partial
from src.custom_transforms import ConvertToMultiChannelBasedOnN_Froi, ConvertToMultiChannelBasedOnAnotatedInfiltration, masked

## Transformaciones Swin UNETR

In [3]:
# Spatial size of the training patch; the tuples below were earlier candidates:
# (220, 220, 155) (128, 128, 64)
roi = (128, 128, 128)
# Dictionary key that CropForegroundd uses as its `source_key`.
source_k="label"

# Training pipeline applied to every {"image", "label"} sample, in order:
#   1. load both volumes,
#   2. convert the label with the project's annotated-infiltration transform,
#   3. crop image and label around the `label` entry (k_divisible=roi),
#   4. take a fixed-size random crop of `roi`,
#   5. normalise the image intensities.
# NOTE(review): exact cropping/normalisation semantics come from MONAI and the
# project transform; confirm against their documentation.
train_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # Disabled alternatives kept for reference:
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key=source_k,
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),

    ]
)
# Validation pipeline: same loading, label conversion and normalisation, but no
# foreground crop.
# NOTE(review): roi_size=[-1, -1, -1] presumably keeps the full volume, which
# would make the random crop a no-op; verify against MONAI's RandSpatialCropd.
val_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # Disabled alternatives kept for reference:
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[-1, -1, -1],  # previously [224, 224, 128]
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)



In [4]:
######################
# Build the model
######################

# Hyperparameter: spatial size of the patch the network is built for.
roi = (128, 128, 128)

# Number CUDA devices by PCI bus order, then use GPU 0 when CUDA is available.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Swin UNETR segmentation network.
model = SwinUNETR(
    img_size=roi,
    in_channels=11,
    out_channels=2,  # change this when an edema channel is added
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

# Checkpoint with the best segmentation weights.
model_path = "artifacts/1dhzmigz_best_model:v0/model.pt"

# Map the stored tensors onto the selected device. The previous hard-coded
# 'cuda:0' made this cell fail on CPU-only machines even though `device`
# already falls back to the CPU.
# NOTE(review): consider `weights_only=True` once the checkpoint contents are
# confirmed to be plain tensors/containers.
loaded_model = torch.load(model_path, map_location=device)["state_dict"]

# Copy the weights into the network and move it to the device.
model.load_state_dict(loaded_model)
model.to(device)

# Evaluation mode; as the last expression this also displays the architecture.
model.eval()

  loaded_model = torch.load(model_path, map_location=torch.device('cuda:0'))["state_dict"]


SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

In [5]:
# Build the dataset and data loader used for embedding extraction.
# The "train" section is read with `train_transform`; batches hold one case and
# are served in dataset order (shuffle=False).
dataset_path='./Dataset/Dataset_recurrence'
train_set=CustomDataset(dataset_path, section="train", transform=train_transform)  # alternative: v_transform
train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)



Found 36 images and 36 labels.


In [6]:

embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_output_dir = "Dataset/contrastive_voxel_wise/labels"

# Create the output folders if they do not exist yet.
os.makedirs(embedding_dir, exist_ok=True)
os.makedirs(label_output_dir, exist_ok=True)

# Most recent output of the hooked decoder block (written by the hook below).
decoder_features = None


def decoder_hook_fn(module, input, output):
    """Forward hook that stores the hooked module's output in `decoder_features`."""
    global decoder_features
    decoder_features = output


# Register the hook on decoder1.conv_block.
hook_handle_decoder = model.decoder1.conv_block.register_forward_hook(decoder_hook_fn)

# Extract and save one embedding volume and one label volume per case.
# try/finally guarantees the hook is detached even if a case fails, so a re-run
# of this cell never leaves stale hooks on the model.
try:
    with torch.no_grad():
        for idx, batch_data in enumerate(train_loader):
            image, label = batch_data["image"], batch_data["label"]
            print("Image", image.shape)  # [1, 11, 128, 128, 128]
            print("label before squeeze", label.shape)  # [1, 2, 128, 128, 128]

            image = image.to(device)
            label = label.squeeze(0)  # [2, 128, 128, 128]

            # Collapse the two label channels into a single class-index volume.
            label_sum = label.sum(dim=0)  # [128, 128, 128], used only as shape template
            label_class = torch.zeros_like(label_sum, dtype=torch.long)  # [128, 128, 128]

            # Class assignment (channel 0, channel 1):
            #   background  (0, 0) -> 0  (left at the zero initialisation)
            #   vasogenic   (1, 0) -> 1
            #   infiltrated (*, 1) -> 2
            label_class[label[1] == 1] = 2
            label_class[(label[0] == 1) & (label[1] == 0)] = 1

            label = label_class.cpu().numpy()  # [128, 128, 128]
            print("label", label.shape)

            # The forward pass is run only for its side effect: it fires the hook.
            _ = model(image)

            print("decoder_features:", decoder_features.shape)  # [1, 48, 128, 128, 128]

            # Persist embeddings and labels under the same case index.
            np.save(f"{embedding_dir}/case_{idx}.npy", decoder_features.cpu().numpy())
            np.save(f"{label_output_dir}/case_{idx}.npy", label)

            print(f"Guardado embeddings y etiquetas para caso {idx}")
finally:
    # Detach the hook.
    hook_handle_decoder.remove()

Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 0
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 1
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 2
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 3
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torc

## Entrenar modelo contrastivo

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Prefer the second GPU. "cuda:1" only exists when at least two GPUs are
# visible, so fall back to the first GPU and finally to the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class EmbeddingDataset(Dataset):
    """Per-case voxel embeddings and label volumes stored as ``case_<n>.npy``.

    The embedding directory is scanned once for files named ``case_<n>.npy``.
    They are ordered by the numeric ``<n>``, so item ``i`` is ``case_i.npy``
    whenever the cases are numbered contiguously from 0. Gaps in the numbering
    and unrelated ``.npy`` files no longer break ``len()`` / indexing.

    Args:
        embedding_dir: Directory holding the embedding arrays.
        label_dir: Directory holding the label arrays, with the same file names.
    """

    _PREFIX = "case_"
    _SUFFIX = ".npy"

    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        candidates = [
            f for f in os.listdir(embedding_dir)
            if f.startswith(self._PREFIX)
            and f.endswith(self._SUFFIX)
            and f[len(self._PREFIX):-len(self._SUFFIX)].isdecimal()
        ]
        # Numeric sort: a plain string sort would put case_10 before case_2.
        self.case_files = sorted(
            candidates,
            key=lambda f: int(f[len(self._PREFIX):-len(self._SUFFIX)]),
        )

    def __len__(self):
        """Number of cases found in the embedding directory."""
        return len(self.case_files)

    def __getitem__(self, idx):
        """Return ``(embeddings, labels)`` for the ``idx``-th case.

        A leading singleton dimension of the embedding array is squeezed away;
        embeddings are returned as float32 and labels as int64 tensors.
        """
        file_name = self.case_files[idx]

        embeddings = np.load(os.path.join(self.embedding_dir, file_name))
        labels = np.load(os.path.join(self.label_dir, file_name))

        # as_tensor avoids an extra copy when the stored dtype already matches.
        embeddings = torch.as_tensor(embeddings, dtype=torch.float32).squeeze(0)
        labels = torch.as_tensor(labels, dtype=torch.long)

        return embeddings, labels

class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the last dimension.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of each projected vector.
    """

    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super().__init__()
        # Keep this layer order: it fixes the `net.0` / `net.2` parameter names
        # that saved state dicts rely on.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project `x` of shape [..., input_dim] to [..., output_dim]."""
        projected = self.net(x)
        return projected

def contrastive_loss(z, labels, temperature=0.5, sample_size_per_class=1024):
    """Supervised contrastive loss on a class-balanced random subsample.

    Up to `sample_size_per_class` rows are drawn at random from every class in
    `labels`. For each sampled anchor that has at least one other sample of
    its class, the loss is

        -log( sum_{p in positives} exp(s_ap) / sum_{k != a} exp(s_ak) )

    where `s` is the cosine similarity divided by `temperature`. The result is
    the mean over those anchors.

    Args:
        z: Float tensor [N, D] of projected features (need not be normalised).
        labels: Integer tensor [N] with one class id per row of `z`.
        temperature: Softmax temperature applied to the similarities.
        sample_size_per_class: Maximum number of rows kept per class.

    Returns:
        Scalar tensor on `z.device`. A zero tensor with `requires_grad=True` is
        returned when fewer than two classes are present, fewer than two rows
        were sampled, or no anchor has a positive partner.
    """
    z = F.normalize(z, dim=1)

    def _zero_loss():
        # Lives on the input's device instead of a module-level `device`.
        return torch.zeros((), dtype=z.dtype, device=z.device, requires_grad=True)

    classes = torch.unique(labels)
    if len(classes) < 2:
        print(f"Advertencia: Solo una clase presente ({classes.tolist()}), devolviendo pérdida 0")
        return _zero_loss()

    sampled_z = []
    sampled_labels = []

    for cls in classes:
        cls_indices = (labels == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        if cls_size > sample_size_per_class:
            # Draw the permutation on the same device as the indices it selects.
            keep = torch.randperm(cls_size, device=cls_indices.device)[:sample_size_per_class]
            cls_indices = cls_indices[keep]
        sampled_z.append(z[cls_indices])
        sampled_labels.append(labels[cls_indices])

    z = torch.cat(sampled_z, dim=0)
    labels = torch.cat(sampled_labels, dim=0)
    N = z.shape[0]

    if N < 2:
        print("Advertencia: Batch con menos de 2 vóxeles, devolviendo pérdida 0")
        return _zero_loss()

    similarity = torch.mm(z, z.T) / temperature
    self_mask = torch.eye(N, dtype=torch.bool, device=z.device)
    positive_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)) & ~self_mask

    # Anchors without any positive carry no contrastive signal. Keeping them
    # used to add a constant of about -log(1e-6) each, contradicting the
    # "loss will be 0" warning below.
    has_positive = positive_mask.any(dim=1)
    if not has_positive.any():
        print("Advertencia: No hay pares positivos, pérdida será 0")
        return _zero_loss()

    # Select the usable anchors BEFORE the log-sum-exp so that no row consists
    # solely of -inf entries (which would yield NaN gradients).
    similarity = similarity[has_positive]
    self_mask = self_mask[has_positive]
    positive_mask = positive_mask[has_positive]

    # Log-sum-exp form of -log(pos_sum / all_sum): stable for any temperature.
    neg_inf = float("-inf")
    log_all = torch.logsumexp(similarity.masked_fill(self_mask, neg_inf), dim=1)
    log_pos = torch.logsumexp(similarity.masked_fill(~positive_mask, neg_inf), dim=1)

    return (log_all - log_pos).mean()

# Configuration of the contrastive training run.
embedding_dir, label_dir = (
    "Dataset/contrastive_voxel_wise/embeddings",
    "Dataset/contrastive_voxel_wise/labels",
)
batch_size = 1                # cases per batch
sample_size_per_class = 3333  # cap on voxels sampled per class inside the loss
temperature = 0.5             # similarity temperature of the contrastive loss
num_epochs = 100              # maximum number of epochs
patience = 10                 # early stopping: epochs without improvement

# Data: one precomputed embedding volume per case, shuffled every epoch.
dataset = EmbeddingDataset(embedding_dir=embedding_dir, label_dir=label_dir)
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Model, optimiser, and a scheduler that halves the learning rate after three
# epochs without improvement of the monitored value.
model = ProjectionHead(input_dim=48)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    verbose=True,
)



In [2]:
# Directory where the projection-head checkpoints are written.
output_dir = "trained_models/checkpoints_contrastive"
os.makedirs(output_dir, exist_ok=True)

# Early-stopping / checkpoint bookkeeping.
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_contrastive_projection_head.pth")
final_model_path = os.path.join(output_dir, "contrastive_projection_head_final.pth")

# Contrastive training with LR scheduling, best-model checkpoints and early stopping.
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0

    model.train()

    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]

        # Channels last, so every row of the flattened tensor is one voxel.
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]

        embeddings_flat = embeddings.reshape(-1, embeddings.shape[-1])  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]

        # Keep only voxels whose label is non-negative.
        valid_mask = labels_flat >= 0
        embeddings_valid = embeddings_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]

        if embeddings_valid.shape[0] < 2:
            print(f"Batch {batch_idx}: Insuficientes vóxeles válidos")
            continue

        # Forward pass.
        # NOTE(review): every voxel is projected here although contrastive_loss
        # only keeps a per-class subsample; sampling before the projection
        # would save considerable memory.
        z = model(embeddings_valid)
        loss = contrastive_loss(z, labels_valid, temperature, sample_size_per_class)

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        if loss_value == 0:
            continue  # zero-loss batches are neither optimised nor averaged

        # Optimisation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss_value
        valid_batches += 1

        if batch_idx % 5 == 0:  # print every 5 batches
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss_value:.4f}")

    if valid_batches == 0:
        # Without a usable batch there is no loss to report. Treating it as 0.0
        # (the old behaviour) would checkpoint this epoch as the best one.
        epochs_no_improve += 1
        print(f"Epoch {epoch+1}/{num_epochs}: ningún batch válido; se omiten scheduler y checkpoint")
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    else:
        # Mean loss over the batches that were actually optimised.
        avg_loss = total_loss / valid_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
              f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Scheduler: adapt the learning rate to the epoch loss.
        scheduler.step(avg_loss)

        # Checkpoint: keep the best model seen so far.
        if avg_loss < best_loss:
            best_loss = avg_loss
            epochs_no_improve = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, best_model_path)
            print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
        else:
            epochs_no_improve += 1
            print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")

    # Early stopping.
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Restore the best weights, but only if this run actually wrote a checkpoint;
# otherwise a stale file from an earlier run could be loaded silently.
if best_loss != float('inf'):
    # weights_only=True: the checkpoint holds only tensors and plain containers.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")
else:
    print("No se guardó ningún checkpoint en esta ejecución; se conservan los pesos actuales")

# Save the final weights (plain state dict) for the downstream cells.
torch.save(model.state_dict(), final_model_path)
print(f"Modelo final guardado en '{final_model_path}'")

Epoch 1/100, Batch 0/36, Loss: 0.8038
Epoch 1/100, Batch 5/36, Loss: 0.4591
Epoch 1/100, Batch 10/36, Loss: 0.2092
Epoch 1/100, Batch 15/36, Loss: 0.4047
Epoch 1/100, Batch 20/36, Loss: 0.3057
Epoch 1/100, Batch 25/36, Loss: 0.3589
Epoch 1/100, Batch 30/36, Loss: 0.5311
Epoch 1/100, Batch 35/36, Loss: 0.3748
Epoch 1/100, Average Loss: 0.3716, Valid Batches: 36/36, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.3716
Epoch 2/100, Batch 0/36, Loss: 0.5252
Epoch 2/100, Batch 5/36, Loss: 0.3832
Epoch 2/100, Batch 10/36, Loss: 0.4601
Epoch 2/100, Batch 15/36, Loss: 0.4649
Epoch 2/100, Batch 20/36, Loss: 0.3860
Epoch 2/100, Batch 25/36, Loss: 0.1903
Epoch 2/100, Batch 30/36, Loss: 0.4668
Epoch 2/100, Batch 35/36, Loss: 0.4273
Epoch 2/100, Average Loss: 0.3182, Valid Batches: 36/36, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.3182
Epoch 3/100, Batch 0/36, Loss: 0.3430
Epoch 3/100, Batch 5/36, Loss: 0.2393
Epoch 3/100, Batch 10/36, Loss: 0.1927
Epoch 3/100

  checkpoint = torch.load(best_model_path, map_location=device)


In [None]:


# NOTE(review): this looks like an earlier variant of the contrastive training
# loop, without scheduler, checkpoints or early stopping. It reuses the same
# `model`, `optimizer` and `loader`, so running it after the checkpointed loop
# keeps training the model that loop produced. Confirm whether it is still
# needed.
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)
        labels = labels.to(device)

        # Channels last, then flatten so each row is one voxel with 48 features.
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)
        labels = labels.squeeze(0)

        embeddings_flat = embeddings.reshape(-1, 48)
        labels_flat = labels.reshape(-1)

        # Keep only voxels whose label is non-negative.
        valid_mask = labels_flat >= 0
        embeddings_valid = embeddings_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]

        if embeddings_valid.shape[0] < 2:
            print(f"Batch {batch_idx}: Insuficientes vóxeles válidos")
            continue

        z = model(embeddings_valid)
        loss = contrastive_loss(z, labels_valid, temperature, sample_size_per_class)

        if loss.item() == 0:
            continue  # zero-loss batches are left out of the average

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        valid_batches += 1

        if batch_idx % 1 == 0:  # always true, so every batch is printed
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / max(valid_batches, 1)  # avoid division by zero
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}")

# Save the resulting weights as a plain state dict.
torch.save(model.state_dict(), "trained_models/contrastive_projection_head.pth")

Batch size: 5729, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 0/36, Loss: 0.8219
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 1/36, Loss: 0.7822
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 2/36, Loss: 0.7484
Batch size: 6026, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 3/36, Loss: 0.7338
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 4/36, Loss: 0.6744
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 5/36, Loss: 0.6506
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 6/36, Loss: 0.5996
Batch size: 4550, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 7/36, Loss: 0.3945
Batch size: 5463, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 8/36, Loss: 0.6113
Batch size: 4608, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 9/36, Loss: 0.3738
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 10/36, Loss: 0.4886
Batch size: 6144, Unique labels: [0, 1, 2]
Epoch 1/100, Batch 11/36, Loss: 0.5332
Batch size: 6144, Unique l

In [2]:
# Report, for every case, which labels occur and how many voxels each class has.
for idx, (embeddings, labels) in enumerate(loader):
    present_labels = torch.unique(labels).tolist()
    print(f"Unique labels: {present_labels}")

    labels_flat = labels.reshape(-1)
    class_counts = torch.bincount(labels_flat)

    # bincount stops at the largest label present, so higher classes may be missing.
    num_bins = len(class_counts)
    background_count = class_counts[0]
    vasogenic_count = class_counts[1] if num_bins > 1 else 0
    infiltrated_count = class_counts[2] if num_bins > 2 else 0
    print(f"Case {idx}: Fondo: {background_count}, Vasogénico: {vasogenic_count}, Infiltrado: {infiltrated_count}")

Unique labels: [0, 1, 2]
Case 0: Fondo: 2004144, Vasogénico: 1032, Infiltrado: 91976
Unique labels: [0, 1, 2]
Case 1: Fondo: 2056648, Vasogénico: 18101, Infiltrado: 22403
Unique labels: [0, 1, 2]
Case 2: Fondo: 2079347, Vasogénico: 512, Infiltrado: 17293
Unique labels: [0, 1, 2]
Case 3: Fondo: 2084680, Vasogénico: 6398, Infiltrado: 6074
Unique labels: [0, 1, 2]
Case 4: Fondo: 1994106, Vasogénico: 3204, Infiltrado: 99842
Unique labels: [0, 1, 2]
Case 5: Fondo: 2059831, Vasogénico: 2, Infiltrado: 37319
Unique labels: [0, 1, 2]
Case 6: Fondo: 2065553, Vasogénico: 1930, Infiltrado: 29669
Unique labels: [0, 1, 2]
Case 7: Fondo: 2038238, Vasogénico: 5913, Infiltrado: 53001
Unique labels: [0, 1, 2]
Case 8: Fondo: 2068807, Vasogénico: 17756, Infiltrado: 10589
Unique labels: [0, 1, 2]
Case 9: Fondo: 2054422, Vasogénico: 58, Infiltrado: 42672
Unique labels: [0, 1, 2]
Case 10: Fondo: 1964971, Vasogénico: 23627, Infiltrado: 108554
Unique labels: [0, 1, 2]
Case 11: Fondo: 2081185, Vasogénico: 3571,

## Entrenar modelo de clasificación supervisado

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Device: prefer the second GPU. "cuda:1" only exists when at least two GPUs
# are visible, so fall back to the first GPU and finally to the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Dataset of precomputed embeddings (repeated here so this section runs on its own).
class EmbeddingDataset(Dataset):
    """Per-case voxel embeddings and label volumes stored as ``case_<n>.npy``.

    The embedding directory is scanned once for files named ``case_<n>.npy``.
    They are ordered by the numeric ``<n>``, so item ``i`` is ``case_i.npy``
    whenever the cases are numbered contiguously from 0. Gaps in the numbering
    and unrelated ``.npy`` files no longer break ``len()`` / indexing.

    Args:
        embedding_dir: Directory holding the embedding arrays.
        label_dir: Directory holding the label arrays, with the same file names.
    """

    _PREFIX = "case_"
    _SUFFIX = ".npy"

    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        candidates = [
            f for f in os.listdir(embedding_dir)
            if f.startswith(self._PREFIX)
            and f.endswith(self._SUFFIX)
            and f[len(self._PREFIX):-len(self._SUFFIX)].isdecimal()
        ]
        # Numeric sort: a plain string sort would put case_10 before case_2.
        self.case_files = sorted(
            candidates,
            key=lambda f: int(f[len(self._PREFIX):-len(self._SUFFIX)]),
        )

    def __len__(self):
        """Number of cases found in the embedding directory."""
        return len(self.case_files)

    def __getitem__(self, idx):
        """Return ``(embeddings, labels)`` for the ``idx``-th case.

        A leading singleton dimension of the embedding array is squeezed away;
        embeddings are returned as float32 and labels as int64 tensors.
        """
        file_name = self.case_files[idx]

        embeddings = np.load(os.path.join(self.embedding_dir, file_name))  # [1, 48, 128, 128, 128]
        labels = np.load(os.path.join(self.label_dir, file_name))  # [128, 128, 128]

        # as_tensor avoids an extra copy when the stored dtype already matches.
        embeddings = torch.as_tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.as_tensor(labels, dtype=torch.long)  # [128, 128, 128]

        return embeddings, labels

# Projection head (repeated here so this section runs on its own).
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the last dimension.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of each projected vector.
    """

    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super().__init__()
        # Keep this layer order: it fixes the `net.0` / `net.2` parameter names
        # that saved state dicts rely on.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project `x` of shape [..., input_dim] to [..., output_dim]."""
        projected = self.net(x)
        return projected

# Supervised classifier applied to the projected features.
class Classifier(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    Args:
        input_dim: Size of each input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=128, num_classes=3):
        super().__init__()
        # The attribute name `fc` fixes the parameter names in saved state dicts.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape [..., num_classes]."""
        logits = self.fc(x)
        return logits

# Configuration of the supervised classifier stage.
embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1
sample_size_per_class = 3333  # roughly 10,000 voxels in total for three classes
num_epochs = 100  # maximum number of epochs
patience = 10  # early stopping: epochs without improvement

# Dataset and data loader over the precomputed embeddings.
dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Load the contrastively pre-trained projection head.
# weights_only=True: the file is a plain state dict written by torch.save in
# the contrastive-training section, so nothing beyond tensors needs unpickling.
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(
    torch.load(
        "trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth",
        map_location=device,
        weights_only=True,
    )
)
# eval() only switches layer behaviour; it does not stop gradient tracking.
# Freeze the parameters explicitly so the head really is a fixed feature extractor.
projection_head.eval()
projection_head.requires_grad_(False)

# Classifier trained on top of the frozen head, with its optimiser, LR scheduler and loss.
classifier = Classifier(input_dim=128, num_classes=3).to(device)
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
criterion = nn.CrossEntropyLoss()



  projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth", map_location=device))


In [4]:
# Directory where the classifier checkpoints are written.
output_dir = "trained_models/checkpoints"
os.makedirs(output_dir, exist_ok=True)

# Early-stopping / checkpoint bookkeeping.
best_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(output_dir, "best_supervised_classifier.pth")
final_model_path = os.path.join(output_dir, "supervised_classifier_final.pth")

# Classifier training with class-capped sampling, LR scheduling and early stopping.
for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0

    classifier.train()

    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [1, 48, 128, 128, 128]
        labels = labels.to(device)  # [1, 128, 128, 128]

        # Channels last, so every row of the flattened tensor is one voxel.
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        labels = labels.squeeze(0)  # [128, 128, 128]

        embeddings_flat = embeddings.reshape(-1, embeddings.shape[-1])  # [2097152, 48]
        labels_flat = labels.reshape(-1)  # [2097152]

        # Stratified sampling: at most `sample_size_per_class` voxels per class.
        classes = torch.unique(labels_flat)
        if len(classes) < 2:
            print(f"Batch {batch_idx}: Solo una clase presente ({classes.tolist()}), saltando")
            continue

        sampled_embeddings = []
        sampled_labels = []

        for cls in classes:
            cls_indices = (labels_flat == cls).nonzero(as_tuple=True)[0]
            cls_size = cls_indices.shape[0]
            if cls_size > sample_size_per_class:
                # Draw the permutation on the same device as the indices it selects.
                indices = torch.randperm(cls_size, device=cls_indices.device)[:sample_size_per_class]
                cls_indices = cls_indices[indices]
            sampled_embeddings.append(embeddings_flat[cls_indices])
            sampled_labels.append(labels_flat[cls_indices])

        embeddings_sampled = torch.cat(sampled_embeddings, dim=0)
        labels_sampled = torch.cat(sampled_labels, dim=0)

        # Frozen contrastive representation of the sampled voxels.
        with torch.no_grad():
            z = projection_head(embeddings_sampled)  # [N, 128]
            z = F.normalize(z, dim=1)

        # Classification.
        logits = classifier(z)
        loss = criterion(logits, labels_sampled)

        # Optimisation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        total_loss += loss_value
        valid_batches += 1

        if batch_idx % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss_value:.4f}, "
                  f"Sampled size: {embeddings_sampled.shape[0]}, Classes: {torch.unique(labels_sampled).tolist()}")

    if valid_batches == 0:
        # Without a usable batch there is no loss to report. Treating it as 0.0
        # (the old behaviour) would checkpoint this epoch as the best one.
        epochs_no_improve += 1
        print(f"Epoch {epoch+1}/{num_epochs}: ningún batch válido; se omiten scheduler y checkpoint")
        print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")
    else:
        # Mean loss over the batches that were actually optimised.
        avg_loss = total_loss / valid_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}, "
              f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Scheduler: adapt the learning rate to the epoch loss.
        scheduler.step(avg_loss)

        # Checkpoint: keep the best model seen so far.
        if avg_loss < best_loss:
            best_loss = avg_loss
            epochs_no_improve = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, best_model_path)
            print(f"Guardado checkpoint con mejor pérdida: {best_loss:.4f}")
        else:
            epochs_no_improve += 1
            print(f"Épocas sin mejora: {epochs_no_improve}/{patience}")

    # Early stopping.
    if epochs_no_improve >= patience:
        print(f"Early stopping activado tras {epoch+1} épocas. Mejor pérdida: {best_loss:.4f}")
        break

# Restore the best weights, but only if this run actually wrote a checkpoint;
# otherwise a stale file from an earlier run could be loaded silently.
if best_loss != float('inf'):
    # weights_only=True: the checkpoint holds only tensors and plain containers.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=True)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    print(f"Cargado el mejor modelo desde {best_model_path} con pérdida: {checkpoint['loss']:.4f}")
else:
    print("No se guardó ningún checkpoint en esta ejecución; se conservan los pesos actuales")

# Save the final weights (plain state dict) for the inference cells.
torch.save(classifier.state_dict(), final_model_path)
print(f"Clasificador final guardado en '{final_model_path}'")

Epoch 1/100, Batch 0/36, Loss: 1.1146, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 5/36, Loss: 1.0719, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 10/36, Loss: 1.0363, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 15/36, Loss: 0.9843, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 20/36, Loss: 0.9499, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 25/36, Loss: 0.9107, Sampled size: 7935, Classes: [0, 1, 2]
Epoch 1/100, Batch 30/36, Loss: 0.9555, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Batch 35/36, Loss: 0.8333, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 1/100, Average Loss: 0.9847, Valid Batches: 36/36, Learning Rate: 0.001000
Guardado checkpoint con mejor pérdida: 0.9847
Epoch 2/100, Batch 0/36, Loss: 0.8926, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 2/100, Batch 5/36, Loss: 0.7909, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 2/100, Batch 10/36, Loss: 0.7685, Sampled size: 9999, Classes: [0, 1, 2]
Epoch 2/

  checkpoint = torch.load(best_model_path, map_location=device)


In [5]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import nibabel as nib

# Inference device: the default CUDA device when available, otherwise the CPU.
# NOTE(review): the training cells select "cuda:1" instead; confirm that using
# a different GPU here is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-voxel class-probability maps.
def generate_probability_maps(embeddings, projection_head, classifier, device):
    """Turn a volume of voxel embeddings into per-class probability maps.

    Every voxel embedding is projected, L2-normalised, classified and passed
    through a softmax. Channel count, spatial size and number of classes are
    taken from the tensors themselves, so inputs other than 48 x 128^3 with
    three classes work as well.

    Args:
        embeddings: Tensor [1, C, D, H, W] (or [C, D, H, W]) of voxel features,
            e.g. [1, 48, 128, 128, 128] from SwinUNETR.
        projection_head: Callable mapping [N, C] features to [N, P].
        classifier: Callable mapping [N, P] features to [N, K] logits.
        device: Device on which the computation runs.

    Returns:
        Tensor [K, D, H, W] of probabilities that sum to 1 over the first
        dimension, e.g. [3, 128, 128, 128]. No gradients are tracked.
    """
    with torch.no_grad():
        # [1, C, D, H, W] -> [D, H, W, C]: one row per voxel after flattening.
        voxels_last = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)
        spatial_shape = voxels_last.shape[:-1]
        embeddings_flat = voxels_last.reshape(-1, voxels_last.shape[-1])  # [D*H*W, C]

        z = projection_head(embeddings_flat)  # [D*H*W, P]
        z = F.normalize(z, dim=1)

        logits = classifier(z)  # [D*H*W, K]
        probs = F.softmax(logits, dim=1)

        # Back to a volume, classes first: [K, D, H, W].
        probs = probs.reshape(*spatial_shape, probs.shape[1]).permute(3, 0, 1, 2)
        return probs

# Cases are read in dataset order (shuffle=False) so output indices are stable.
dataset = EmbeddingDataset(embedding_dir="Dataset/contrastive_voxel_wise/embeddings",
                          label_dir="Dataset/contrastive_voxel_wise/labels")
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

# Load the trained models. Both files are plain state dicts written by
# torch.save in the training sections, hence weights_only=True.
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(
    torch.load(
        "trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth",
        map_location=device,
        weights_only=True,
    )
)
projection_head.eval()

classifier = Classifier(input_dim=128, num_classes=3).to(device)
classifier.load_state_dict(
    torch.load(
        "trained_models/checkpoints/supervised_classifier_final.pth",
        map_location=device,
        weights_only=True,
    )
)
classifier.eval()

# Output directory.
output_dir = "trained_models/mapas"
os.makedirs(output_dir, exist_ok=True)

# Identity affine shared by every saved volume (loop-invariant).
# NOTE(review): replace it with the real affine of the source scans if the maps
# must align with the original images.
affine = np.eye(4)

# Process every case and save its maps as NIfTI.
for idx, (embeddings, labels) in enumerate(loader):
    # Per-class probability maps.
    prob_maps = generate_probability_maps(embeddings, projection_head, classifier, device)
    print(f"Mapas de probabilidad para caso {idx}, shape: {prob_maps.shape}")

    # To numpy, with the class axis moved last for NIfTI.
    prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]

    # Semantic segmentation: most probable class per voxel.
    segmentation = np.argmax(prob_maps_np, axis=0)  # [128, 128, 128], values 0, 1, 2
    segmentation_np = segmentation.astype(np.uint8)

    # Ground-truth labels as uint8.
    labels = labels.squeeze(0)  # [128, 128, 128]
    labels_np = labels.cpu().numpy().astype(np.uint8)

    # Save the probability maps.
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    prob_output_path = os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz")
    nib.save(nifti_prob_img, prob_output_path)
    print(f"Guardado mapa de probabilidad como NIfTI en {prob_output_path}")

    # Save the labels.
    nifti_label_img = nib.Nifti1Image(labels_np, affine)
    label_output_path = os.path.join(output_dir, f"labels_case_{idx}.nii.gz")
    nib.save(nifti_label_img, label_output_path)
    print(f"Guardadas etiquetas como NIfTI en {label_output_path}")

    # Save the semantic segmentation.
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    seg_output_path = os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz")
    nib.save(nifti_seg_img, seg_output_path)
    print(f"Guardada segmentación semántica como NIfTI en {seg_output_path}")

    # break  # uncomment to process a single case only

  projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth", map_location=device))
  classifier.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final.pth", map_location=device))


Mapas de probabilidad para caso 0, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad como NIfTI en trained_models/mapas/probability_maps_case_0.nii.gz
Guardadas etiquetas como NIfTI en trained_models/mapas/labels_case_0.nii.gz
Guardada segmentación semántica como NIfTI en trained_models/mapas/segmentation_case_0.nii.gz
Mapas de probabilidad para caso 1, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad como NIfTI en trained_models/mapas/probability_maps_case_1.nii.gz
Guardadas etiquetas como NIfTI en trained_models/mapas/labels_case_1.nii.gz
Guardada segmentación semántica como NIfTI en trained_models/mapas/segmentation_case_1.nii.gz
Mapas de probabilidad para caso 2, shape: torch.Size([3, 128, 128, 128])
Guardado mapa de probabilidad como NIfTI en trained_models/mapas/probability_maps_case_2.nii.gz
Guardadas etiquetas como NIfTI en trained_models/mapas/labels_case_2.nii.gz
Guardada segmentación semántica como NIfTI en trained_models/mapas/segmentati

In [6]:
# Per-voxel class-probability maps.
def generate_probability_maps(embeddings, projection_head, classifier, device):
    """Turn a volume of voxel embeddings into per-class probability maps.

    Every voxel embedding is projected, L2-normalised, classified and passed
    through a softmax. Channel count, spatial size and number of classes are
    taken from the tensors themselves, so inputs other than 48 x 128^3 with
    three classes work as well.

    Args:
        embeddings: Tensor [1, C, D, H, W] (or [C, D, H, W]) of voxel features.
        projection_head: Callable mapping [N, C] features to [N, P].
        classifier: Callable mapping [N, P] features to [N, K] logits.
        device: Device on which the computation runs.

    Returns:
        Tensor [K, D, H, W] of probabilities that sum to 1 over the first
        dimension. No gradients are tracked.
    """
    with torch.no_grad():
        # [1, C, D, H, W] -> [D, H, W, C]: one row per voxel after flattening.
        voxels_last = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)
        spatial_shape = voxels_last.shape[:-1]
        embeddings_flat = voxels_last.reshape(-1, voxels_last.shape[-1])  # [D*H*W, C]

        z = projection_head(embeddings_flat)  # [D*H*W, P]
        z = F.normalize(z, dim=1)

        logits = classifier(z)  # [D*H*W, K]
        probs = F.softmax(logits, dim=1)

        # Back to a volume, classes first: [K, D, H, W].
        probs = probs.reshape(*spatial_shape, probs.shape[1]).permute(3, 0, 1, 2)
        return probs

# Segmentation quality metrics
def calculate_metrics(pred, true, num_classes=3):
    """Compute per-class Dice, sensitivity and precision between two label maps.

    Args:
        pred: NumPy array of predicted class indices.
        true: NumPy array of reference class indices, same shape as ``pred``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        Three lists ``(dice_scores, sensitivity_scores, precision_scores)``,
        each holding one value per class in class order. Every denominator
        carries a 1e-6 term, so a class with no true positives scores 0.0
        instead of dividing by zero.
    """
    eps = 1e-6  # keeps every denominator strictly positive
    dice_scores, sensitivity_scores, precision_scores = [], [], []

    for cls in range(num_classes):
        # One-vs-rest masks for the current class
        pred_mask = (pred == cls).astype(bool)
        true_mask = (true == cls).astype(bool)

        # Confusion counts: hits, false alarms and misses
        tp = np.sum(pred_mask & true_mask)
        fp = np.sum(pred_mask & ~true_mask)
        fn = np.sum(~pred_mask & true_mask)

        dice_scores.append(2 * tp / (2 * tp + fp + fn + eps))
        sensitivity_scores.append(tp / (tp + fn + eps))  # recall
        precision_scores.append(tp / (tp + fp + eps))

    return dice_scores, sensitivity_scores, precision_scores
# Folder that receives the exported NIfTI volumes (created when missing)
output_dir = "trained_models/mapas"
os.makedirs(output_dir, exist_ok=True)

# Per-case metric histories, one independent list per class index
# (0 = background, 1 = vasogenic, 2 = infiltrated)
all_dice, all_sensitivity, all_precision = (
    {cls: [] for cls in range(3)} for _ in range(3)
)

# Run inference on every case, score it and export the volumes as NIfTI
for idx, (embeddings, labels) in enumerate(loader):
    # Per-voxel class probabilities for this case
    prob_maps = generate_probability_maps(embeddings, projection_head, classifier, device)
    print(f"Mapas de probabilidad para caso {idx}, shape: {prob_maps.shape}")

    # Class-first array for the argmax, class-last view for the 4D NIfTI file
    prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128] in the recorded run
    prob_maps_np_nifti = prob_maps_np.transpose(1, 2, 3, 0)  # [128, 128, 128, 3]

    # Hard segmentation: most probable class at each voxel
    segmentation = prob_maps_np.argmax(axis=0)  # [128, 128, 128]
    segmentation_np = segmentation.astype(np.uint8)

    # Reference labels without their leading batch axis
    labels = labels.squeeze(0)  # [128, 128, 128]
    labels_np = labels.cpu().numpy().astype(np.uint8)

    # Optional post-processing (disabled): force the prediction to background
    # wherever the reference is background.
    # segmentation_np[labels_np == 0] = 0

    # Score the case and append each class score to its running history
    dice, sensitivity, precision = calculate_metrics(segmentation_np, labels_np)
    for cls in range(3):
        for history, scores in (
            (all_dice, dice),
            (all_sensitivity, sensitivity),
            (all_precision, precision),
        ):
            history[cls].append(scores[cls])

    print(f"Caso {idx} - Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}")

    # Identity affine: voxel indices are used directly as world coordinates
    affine = np.eye(4)

    # Probability maps (4D, class axis last)
    prob_output_path = os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz")
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    nib.save(nifti_prob_img, prob_output_path)
    print(f"Guardado mapa de probabilidad en {prob_output_path}")

    # Reference labels
    label_output_path = os.path.join(output_dir, f"labels_case_{idx}.nii.gz")
    nifti_label_img = nib.Nifti1Image(labels_np, affine)
    nib.save(nifti_label_img, label_output_path)
    print(f"Guardadas etiquetas en {label_output_path}")

    # Predicted semantic segmentation
    seg_output_path = os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz")
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    nib.save(nifti_seg_img, seg_output_path)
    print(f"Guardada segmentación en {seg_output_path}")

# Summarize the per-case scores as mean and standard deviation for each class
# (np.std is called with its default ddof=0, i.e. population standard deviation)
class_names = ["Fondo", "Vasogénico", "Infiltrado"]
for cls in range(3):
    dice_mean, dice_std = np.mean(all_dice[cls]), np.std(all_dice[cls])
    sens_mean, sens_std = np.mean(all_sensitivity[cls]), np.std(all_sensitivity[cls])
    prec_mean, prec_std = np.mean(all_precision[cls]), np.std(all_precision[cls])

    # One header line per class followed by its three "mean ± std" rows
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")

Mapas de probabilidad para caso 0, shape: torch.Size([3, 128, 128, 128])
Caso 0 - Dice: [0.9980268587698456, 0.43403419138467214, 0.711034133147085], Sensitivity: [0.9961672066033376, 0.2909914894661557, 0.9805196318075143], Precision: [0.9998934671460645, 0.8536755043013314, 0.5577440519509876]
Guardado mapa de probabilidad en trained_models/mapas/probability_maps_case_0.nii.gz
Guardadas etiquetas en trained_models/mapas/labels_case_0.nii.gz
Guardada segmentación en trained_models/mapas/segmentation_case_0.nii.gz
Mapas de probabilidad para caso 1, shape: torch.Size([3, 128, 128, 128])
Caso 1 - Dice: [0.9948751915126847, 0.7238126352660077, 0.49861937127409695], Sensitivity: [0.9898191186127809, 0.683409593319536, 0.8973623852639295], Precision: [0.9999831832432895, 0.7692930929491798, 0.3452205882268328]
Guardado mapa de probabilidad en trained_models/mapas/probability_maps_case_1.nii.gz
Guardadas etiquetas en trained_models/mapas/labels_case_1.nii.gz
Guardada segmentación en trained_