In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Loading functions
import os
import time
from monai.data import DataLoader, decollate_batch


import torch
import torch.nn.parallel

from src.get_data import CustomDataset, CustomDatasetSeg
import numpy as np
from scipy import ndimage
from types import SimpleNamespace
import wandb
import logging

#####
import json
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    MapTransform,
    Transform,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR
from monai import data

# from monai.data import decollate_batch
from functools import partial
from src.custom_transforms import (ConvertToMultiChannelBasedOnN_Froi, 
                                   ConvertToMultiChannelBasedOnAnotatedInfiltration, 
                                   masked, ConvertToMultiChannelBasedOnBratsClassesdI)

#### Trasnformaciones

In [6]:
roi = (128, 128, 128) # (220, 220, 155) (128, 128, 64)
source_k="label"
train_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        # ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key=source_k,
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
    ]
)
val_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        # ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        ConvertToMultiChannelBasedOnBratsClassesdI(keys="label"),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[-1, -1, -1], #[224, 224, 128],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

## Modelos

In [7]:
######################
# Crear el modelo
######################

### Hyperparameter
roi = (128, 128, 128)  # (128, 128, 128)

# Create Swin transformer
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def define_model(model_path):
    model = SwinUNETR(
        img_size=roi,
        in_channels=11,
        out_channels=2,  # mdificar con edema
        feature_size=48, #48
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=True,
    )

    # # Load the best model
    # model_path = "artifacts/o9kppyr5_best_model:v0/model.pt"

    # Load the model on CPU
    loaded_model = torch.load(model_path, map_location=torch.device('cuda:0'))["state_dict"]

    # Load the state dictionary into the model
    model.load_state_dict(loaded_model)
    return model
    

In [8]:
# Modelo TC+Edema
model1=define_model("artifacts/o9kppyr5_best_model:v0/model.pt") 
model1.to(device)
model1.eval()

# Modelo Infitracion+Edema
model2=define_model("artifacts/uixfayrn_best_model:v0/model.pt") 
model2.to(device)
model2.eval()

  loaded_model = torch.load(model_path, map_location=torch.device('cuda:0'))["state_dict"]


SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

## Obtener embeddings

In [9]:
# Create dataset data loader
dataset_path='./Dataset/Dataset_recurrence'
train_set=CustomDataset(dataset_path, section="train", transform=train_transform) # v_transform
train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)

# Directorios para embeddings de cada modelo
embedding_dir_model1 = "Dataset/contrastive_voxel_wise/embeddings_model1"
embedding_dir_model2 = "Dataset/contrastive_voxel_wise/embeddings_model2"
label_output_dir = "Dataset/contrastive_voxel_wise/labels"

# Crear carpetas si no existen
os.makedirs(embedding_dir_model1, exist_ok=True)
os.makedirs(embedding_dir_model2, exist_ok=True)
os.makedirs(label_output_dir, exist_ok=True)

# Variables para las características de los decoders de ambos modelos
decoder_features_model1 = None
decoder_features_model2 = None

# Funciones hook para cada modelo
def decoder_hook_fn_model1(module, input, output):
    global decoder_features_model1
    decoder_features_model1 = output

def decoder_hook_fn_model2(module, input, output):
    global decoder_features_model2
    decoder_features_model2 = output

# Registrar los hooks en los decoders de ambos modelos
hook_handle_decoder1 = model1.decoder1.conv_block.register_forward_hook(decoder_hook_fn_model1)
hook_handle_decoder2 = model2.decoder1.conv_block.register_forward_hook(decoder_hook_fn_model2)

# Extraer y guardar
with torch.no_grad():
    for idx, batch_data in enumerate(train_loader):
        image, label = batch_data["image"], batch_data["label"]
        print("Image", image.shape)  # [1, 11, 128, 128, 128]
        print("label before squeeze", label.shape)  # [1, 2, 128, 128, 128]
        
        image = image.to(device)
        label = label.squeeze(0)  # [2, 128, 128, 128]
        
        # Convertir one-hot a etiquetas únicas
        label_sum = label.sum(dim=0)  # [128, 128, 128], suma de canales
        label_class = torch.zeros_like(label_sum, dtype=torch.long)  # [128, 128, 128]
        
        # Asignar clases:
        # - Fondo (0, 0) -> 0
        # - Vasogénico (1, 0) -> 1
        # - Infiltrado (0, 1) -> 2
        label_class[label[1] == 1] = 2  # Infiltrado
        label_class[(label[0] == 1) & (label[1] == 0)] = 1  # Vasogénico
        # Donde label_sum == 0, ya es fondo (0)
        
        label = label_class.cpu().numpy()  # [128, 128, 128]
        print("label", label.shape)
        
        # Obtener embeddings de ambos modelos
        _ = model1(image)  # Forward para model1
        print("decoder_features_model1:", decoder_features_model1.shape)  # [1, 48, 128, 128, 128]
        
        _ = model2(image)  # Forward para model2
        print("decoder_features_model2:", decoder_features_model2.shape)  # [1, 48, 128, 128, 128]
        
        # Guardar embeddings y etiquetas
        np.save(f"{embedding_dir_model1}/case_{idx}.npy", decoder_features_model1.cpu().numpy())
        np.save(f"{embedding_dir_model2}/case_{idx}.npy", decoder_features_model2.cpu().numpy())
        np.save(f"{label_output_dir}/case_{idx}.npy", label)
        
        print(f"Guardado embeddings y etiquetas para caso {idx}")

# Remover los hooks
hook_handle_decoder1.remove()
hook_handle_decoder2.remove()

Found 36 images and 36 labels.
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 0
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 1
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features_model1: torch.Size([1, 48, 128, 128, 128])
decoder_features_model2: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 2
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decod

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import nibabel as nib

# Dispositivo
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Dataset (ya lo tienes)
class EmbeddingDataset(Dataset):
    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        self.case_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        
    def __len__(self):
        return len(self.case_files)
    
    def __getitem__(self, idx):
        embedding_path = os.path.join(self.embedding_dir, f"case_{idx}.npy")
        label_path = os.path.join(self.label_dir, f"case_{idx}.npy")
        
        embeddings = np.load(embedding_path)  # [1, 48, 128, 128, 128]
        labels = np.load(label_path)  # [128, 128, 128]
        
        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)  # [48, 128, 128, 128]
        labels = torch.tensor(labels, dtype=torch.long)  # [128, 128, 128]
        
        return embeddings, labels

# Modelo de proyección
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

# Clasificador supervisado
class Classifier(nn.Module):
    def __init__(self, input_dim=128, num_classes=3):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        return self.fc(x)

# Configuración para ambos modelos
embedding_dir_model1 = "Dataset/contrastive_voxel_wise/embeddings_model1"
embedding_dir_model2 = "Dataset/contrastive_voxel_wise/embeddings_model2"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1

# Cargar datasets y DataLoaders para ambos modelos
dataset_model1 = EmbeddingDataset(embedding_dir_model1, label_dir)
dataset_model2 = EmbeddingDataset(embedding_dir_model2, label_dir)
loader_model1 = DataLoader(dataset_model1, batch_size=batch_size, shuffle=False)
loader_model2 = DataLoader(dataset_model2, batch_size=batch_size, shuffle=False)

# Cargar modelos contrastivos preentrenados
projection_head1 = ProjectionHead(input_dim=48).to(device)
projection_head1.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_o9kppyr5.pth", map_location=device))
projection_head1.eval()

projection_head2 = ProjectionHead(input_dim=48).to(device)
projection_head2.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_uixfayrn.pth", map_location=device))
projection_head2.eval()

# Cargar clasificadores preentrenados
classifier1 = Classifier(input_dim=128, num_classes=3).to(device)
classifier1.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_o9kppyr5.pth", map_location=device))
classifier1.eval()

classifier2 = Classifier(input_dim=128, num_classes=3).to(device)
classifier2.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_uixfayrn.pth", map_location=device))
classifier2.eval()

# Función para generar mapas de probabilidad
def generate_probability_maps(embeddings, projection_head, classifier, device):
    with torch.no_grad():
        embeddings = embeddings.to(device).squeeze(0).permute(1, 2, 3, 0)
        embeddings_flat = embeddings.reshape(-1, 48)
        
        z = projection_head(embeddings_flat)
        z = F.normalize(z, dim=1)
        
        logits = classifier(z)
        probs = F.softmax(logits, dim=1)
        
        probs = probs.view(128, 128, 128, 3).permute(3, 0, 1, 2)
        return probs

# Funciones para calcular métricas
def calculate_metrics(pred, true, num_classes=3):
    dice_scores = []
    sensitivity_scores = []
    precision_scores = []
    
    for cls in range(num_classes):
        pred_cls = (pred == cls).astype(np.uint8)
        true_cls = (true == cls).astype(np.uint8)
        
        tp = np.sum(pred_cls * true_cls)
        fp = np.sum(pred_cls * (1 - true_cls))
        fn = np.sum((1 - pred_cls) * true_cls)
        
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        dice_scores.append(dice)
        
        sensitivity = tp / (tp + fn + 1e-6)
        sensitivity_scores.append(sensitivity)
        
        precision = tp / (tp + fp + 1e-6)
        precision_scores.append(precision)
    
    return dice_scores, sensitivity_scores, precision_scores

# Directorio de salida
output_dir = "trained_models/mapas_combinados"
os.makedirs(output_dir, exist_ok=True)



  projection_head1.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_o9kppyr5.pth", map_location=device))
  projection_head2.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final_uixfayrn.pth", map_location=device))
  classifier1.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_o9kppyr5.pth", map_location=device))
  classifier2.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final_uixfayrn.pth", map_location=device))


## Ejecutar calculo y guardar mapas

In [2]:
# Listas para métricas
all_dice = {0: [], 1: [], 2: []}
all_sensitivity = {0: [], 1: [], 2: []}
all_precision = {0: [], 1: [], 2: []}

# Procesar y combinar
for idx, ((embeddings1, labels1), (embeddings2, labels2)) in enumerate(zip(loader_model1, loader_model2)):
    # Generar mapas de probabilidad para ambos modelos
    prob_maps1 = generate_probability_maps(embeddings1, projection_head1, classifier1, device)  # [3, 128, 128, 128]
    prob_maps2 = generate_probability_maps(embeddings2, projection_head2, classifier2, device)  # [3, 128, 128, 128]
    
    # Combinar mapas:
    # - Clase 0: del modelo 1
    # - Clase 1: máximo entre ambos modelos
    # - Clase 2: del modelo 2
    combined_prob_maps = torch.zeros_like(prob_maps1)  # [3, 128, 128, 128]
    combined_prob_maps[0] = prob_maps1[0]  # Clase 0 del modelo 1
    combined_prob_maps[1] = torch.max(prob_maps1[1], prob_maps2[1])  # Clase 1 máximo entre ambos
    combined_prob_maps[2] = prob_maps2[2]  # Clase 2 del modelo 2
    
    # Normalizar probabilidades para que sumen 1 en cada vóxel
    combined_prob_maps = combined_prob_maps / combined_prob_maps.sum(dim=0, keepdim=True)
    
    # Convertir a numpy
    prob_maps_np = combined_prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    
    # Generar segmentación semántica
    segmentation = np.argmax(prob_maps_np, axis=0)  # [128, 128, 128]
    segmentation_np = segmentation.astype(np.uint8)
    
    # Etiquetas (usamos las del modelo 2, asumiendo que son iguales)
    labels = labels2.squeeze(0)  # [128, 128, 128]
    labels_np = labels.cpu().numpy().astype(np.uint8)
    
    # Calcular métricas
    dice, sensitivity, precision = calculate_metrics(segmentation_np, labels_np)
    for cls in range(3):
        all_dice[cls].append(dice[cls])
        all_sensitivity[cls].append(sensitivity[cls])
        all_precision[cls].append(precision[cls])
    
    print(f"Caso {idx} - Dice: {dice}, Sensitivity: {sensitivity}, Precision: {precision}")
    
    # Crear imágenes NIfTI
    affine = np.eye(4)
    
    # Guardar mapas de probabilidad combinados
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    prob_output_path = os.path.join(output_dir, f"probability_maps_case_{idx}.nii.gz")
    nib.save(nifti_prob_img, prob_output_path)
    print(f"Guardado mapa de probabilidad en {prob_output_path}")
    
    # Guardar etiquetas
    nifti_label_img = nib.Nifti1Image(labels_np, affine)
    label_output_path = os.path.join(output_dir, f"labels_case_{idx}.nii.gz")
    nib.save(nifti_label_img, label_output_path)
    print(f"Guardadas etiquetas en {label_output_path}")
    
    # Guardar segmentación semántica
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    seg_output_path = os.path.join(output_dir, f"segmentation_case_{idx}.nii.gz")
    nib.save(nifti_seg_img, seg_output_path)
    print(f"Guardada segmentación en {seg_output_path}")

# Calcular promedios y desviaciones estándar
class_names = ["Fondo", "Vasogénico", "Infiltrado"]
for cls in range(3):
    dice_mean = np.mean(all_dice[cls])
    dice_std = np.std(all_dice[cls])
    sens_mean = np.mean(all_sensitivity[cls])
    sens_std = np.std(all_sensitivity[cls])
    prec_mean = np.mean(all_precision[cls])
    prec_std = np.std(all_precision[cls])
    
    print(f"\nClase {cls} ({class_names[cls]}):")
    print(f"  Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"  Sensibilidad: {sens_mean:.4f} ± {sens_std:.4f}")
    print(f"  Precisión: {prec_mean:.4f} ± {prec_std:.4f}")

Caso 0 - Dice: [0.9984164779656527, 0.421750040611237, 0.686223303396002], Sensitivity: [0.9970624177938241, 0.30777540641562523, 0.8939531632088092], Precision: [0.9997742208999835, 0.6697821100533382, 0.5568312118008502]
Guardado mapa de probabilidad en trained_models/mapas_combinados/probability_maps_case_0.nii.gz
Guardadas etiquetas en trained_models/mapas_combinados/labels_case_0.nii.gz
Guardada segmentación en trained_models/mapas_combinados/segmentation_case_0.nii.gz
Caso 1 - Dice: [0.9952324265824696, 0.6344339622552444, 0.47209805837979996], Sensitivity: [0.990528524581974, 0.5808441725211855, 0.861302242554708], Precision: [0.9999812182992083, 0.6989174141447907, 0.32516355588981233]
Guardado mapa de probabilidad en trained_models/mapas_combinados/probability_maps_case_1.nii.gz
Guardadas etiquetas en trained_models/mapas_combinados/labels_case_1.nii.gz
Guardada segmentación en trained_models/mapas_combinados/segmentation_case_1.nii.gz
Caso 2 - Dice: [0.9960911779713028, 0.779