In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from monai.networks.nets import SwinUNETR
import nibabel as nib
from src.get_data import CustomDataset
from monai.data import DataLoader
from monai import transforms
from src.custom_transforms import ConvertToMultiChannelBasedOnAnotatedInfiltration
from monai.inferers import sliding_window_inference
from functools import partial


# Generar los modelos

In [2]:

# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Modelos
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that projects per-voxel embeddings.

    The layers are stored in ``self.net`` (indices 0 and 2 hold the Linear
    layers) so saved checkpoints keep the ``net.0.*`` / ``net.2.*`` keys.

    Args:
        input_dim: size of each input embedding.
        hidden_dim: width of the hidden layer.
        output_dim: size of the projected embedding.
    """

    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[..., input_dim]`` inputs to ``[..., output_dim]`` projections."""
        projected = self.net(x)
        return projected

class Classifier(nn.Module):
    """Single linear layer mapping an embedding to ``num_classes`` logits.

    Args:
        input_dim: size of the incoming embedding.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim=128, num_classes=3):
        super().__init__()
        # Named ``fc`` so checkpoint keys remain ``fc.weight`` / ``fc.bias``.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return raw (un-normalised) logits for ``x``."""
        logits = self.fc(x)
        return logits

# Función para generar mapas de probabilidad y segmentación por lotes
def generate_probability_maps(embeddings, projection_head, classifier, device, batch_size=100000):
    """Classify every voxel of an embedding volume; return soft and hard maps.

    Each voxel embedding goes through ``projection_head`` -> L2 normalisation
    -> ``classifier`` -> softmax, processed in chunks of ``batch_size`` voxels
    to bound peak memory.

    Args:
        embeddings: tensor of shape ``[C, H, W, D]`` (channels first). ``C``
            must match the input size of ``projection_head``.
        projection_head: module mapping ``[N, C]`` to ``[N, E]``.
        classifier: module mapping ``[N, E]`` to ``[N, K]`` logits.
        device: device the heads run on and the outputs are allocated on.
        batch_size: voxels classified per forward pass; must be positive.

    Returns:
        ``(prob_maps, segmentation)``:

        * ``prob_maps`` holds the ``[K, H, W, D]`` softmax probabilities.
        * ``segmentation`` is the ``[H, W, D]`` arg-max label map stored as
          ``uint8``, so at most 256 classes are representable.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    with torch.no_grad():
        channels, H, W, D = embeddings.shape
        # Channels-last, then flatten: one row per voxel in (h, w, d) order.
        voxels = embeddings.permute(1, 2, 3, 0).reshape(-1, channels)  # [H*W*D, C]
        total_voxels = voxels.shape[0]

        prob_flat = None
        # `or [0]` runs a single empty chunk for an empty volume so the class
        # count is still taken from the classifier's output.
        for start in range(0, total_voxels, batch_size) or [0]:
            chunk = voxels[start:start + batch_size].to(device)
            z = F.normalize(projection_head(chunk), dim=1)
            probs = F.softmax(classifier(z), dim=1)
            if prob_flat is None:
                # Preallocate once the number of classes K is known.
                prob_flat = torch.empty(
                    total_voxels, probs.shape[1], device=device, dtype=probs.dtype
                )
            prob_flat[start:start + probs.shape[0]] = probs

        num_classes = prob_flat.shape[1]
        prob_maps = prob_flat.view(H, W, D, num_classes).permute(3, 0, 1, 2)  # [K, H, W, D]
        segmentation = torch.argmax(prob_maps, dim=0).to(torch.uint8)  # [H, W, D]

    return prob_maps, segmentation


def save_img(I_img, savename, header=None, affine=None):
    """Write an array to disk as a NIfTI-1 image.

    Args:
        I_img: voxel data (anything ``np.asarray`` accepts).
        savename: output path; nibabel picks the format from the extension
            (e.g. ``.nii.gz``).
        header: optional NIfTI header whose metadata is copied into the image.
        affine: optional 4x4 voxel-to-world matrix. It is used whenever given,
            even without a header; the identity is used when omitted.
    """
    data = np.asarray(I_img)
    if affine is None:
        affine = np.eye(4)
    new_img = nib.nifti1.Nifti1Image(data, affine, header=header)
    # With an explicit header nibabel takes the on-disk dtype from that header
    # (here: the reference scan's), which would cast/rescale e.g. float
    # probability maps. Store the data in its own dtype instead.
    new_img.set_data_dtype(data.dtype)
    nib.save(new_img, savename)


In [3]:
# DataLoader configuration.
# Test-time preprocessing:
#   1. load the image and label volumes;
#   2. convert the annotated label to multi-channel form;
#   3. z-normalise each image channel over its non-zero voxels.
# No cropping is applied: inference tiles the whole volume with sliding windows.
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

dataset_path = './Dataset/Dataset_recurrence'
test_set = CustomDataset(dataset_path, section="test", transform=test_transform)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)

# Reference scan whose header + affine are reused when saving the results.
# NOTE(review): this is one specific subject (UPENN-GBM-00036_11); presumably
# only valid for outputs that share its voxel grid — confirm before adding
# more test subjects.
reference_image_path = os.path.join(
    dataset_path,
    "test",
    "images",
    "images_structural",
    "UPENN-GBM-00036_11",
    "UPENN-GBM-00036_11_T1GD.nii.gz",
)
mri = nib.load(reference_image_path)
header = mri.header
affine = mri.affine

# Sanity check on one sample: shapes and types after the transforms.
im_t = test_set[0]

print(im_t['image'].shape, type(im_t['image']))
print(im_t['label'].shape, type(im_t['label']))

Found 1 images and 1 labels.
torch.Size([11, 240, 240, 155]) <class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([2, 240, 240, 155]) <class 'monai.data.meta_tensor.MetaTensor'>


In [4]:
# Load the pretrained SwinUNETR, used here as a frozen feature extractor.
# Architecture hyperparameters must match the checkpoint being loaded.
roi = (128, 128, 128)
swin_model = SwinUNETR(
    img_size=roi,
    in_channels=11,  # 11 stacked input channels per volume
    out_channels=2,  # modify when edema is included
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
).to(device)

# Alternative checkpoint: "Dataset/model.pt"
model_path = "artifacts/15cwmu45_best_model:v0/model.pt"

# Load the weights straight onto `device` (the GPU when available), then copy
# them into the model. torch.load unpickles the file, so only point it at
# trusted checkpoints.
loaded_model = torch.load(model_path, map_location=device)["state_dict"]
swin_model.load_state_dict(loaded_model)

# Inference only.
swin_model.eval()


# Forward hook on decoder1.conv_block: every forward pass of `swin_model`
# leaves that layer's output in the module-level `decoder_features`.
decoder_features = None


def decoder_hook_fn(module, input, output):
    """Keep a reference to the hooked layer's ``output`` in ``decoder_features``.

    ``module`` and ``input`` are supplied by PyTorch and ignored. Returning
    ``None`` leaves the layer's output unchanged.
    """
    global decoder_features
    decoder_features = output


hook_handle = (
    swin_model.decoder1.conv_block
    .register_forward_hook(decoder_hook_fn)
)

# MONAI sliding-window inferer over the full network output.
# NOTE(review): not used by the inference loop, which tiles the volume
# manually (overlap 0.25) so it can collect decoder embeddings via the hook.
roi_size = roi  # same window size the SwinUNETR was built with
model_inferer = partial(
    sliding_window_inference,
    roi_size=roi_size,
    sw_batch_size=1,  # windows per forward pass
    predictor=swin_model,
    overlap=0.5,  # overlap between windows to smooth borders
)

# Contrastive projection head and supervised classifier, loaded for inference.
projection_head_path = "trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth"
classifier_path = "trained_models/checkpoints/supervised_classifier_final.pth"

# input_dim=48 matches the SwinUNETR feature_size (decoder1 embedding width).
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load(projection_head_path, map_location=device))
projection_head.eval()

classifier = Classifier(input_dim=128, num_classes=3).to(device)
classifier.load_state_dict(torch.load(classifier_path, map_location=device))
classifier.eval()





Classifier(
  (fc): Linear(in_features=128, out_features=3, bias=True)
)

In [5]:
# Output directory
output_dir = "trained_models/inference_results"
os.makedirs(output_dir, exist_ok=True)

# Fractional overlap between neighbouring windows of the manual tiling below.
window_overlap = 0.25
# Window stride per axis (96 voxels for a 128 window at 0.25 overlap); >= 1.
window_steps = [max(1, int(size * (1 - window_overlap))) for size in roi_size]


def _window_starts(size, window, step):
    """Return window start offsets that tile ``[0, size)`` with no gaps.

    The last window is aligned with the end of the axis, so trailing voxels
    are always covered. A plain ``range(0, size - window + 1, step)`` can stop
    short: with 240 voxels, window 128 and step 96, voxels 224-239 are never
    visited.

    An axis shorter than the window gets one window at 0; that patch is
    zero-padded before the forward pass.
    """
    if size <= window:
        return [0]
    starts = list(range(0, size - window + 1, step))
    if starts[-1] != size - window:
        starts.append(size - window)
    return starts


# Sliding-window inference: decoder embeddings (captured by the forward hook)
# are averaged over overlapping windows, then classified voxel by voxel.
# NOTE(review): every output is written with the single reference `header` /
# `affine` loaded in the DataLoader cell; confirm all test volumes share it.
try:
    for idx, batch_data in enumerate(test_loader):
        mri_volume = batch_data["image"].to(device)  # [1, 11, 240, 240, 155] here
        print(f"Procesando MRI {idx}, shape: {mri_volume.shape}")

        _, _, H, W, D = mri_volume.shape

        # Running sum of embeddings (allocated once the channel count is known)
        # and the number of windows that covered each voxel.
        embeddings_full = None
        count_map = torch.zeros(H, W, D, device=device)

        starts_h = _window_starts(H, roi_size[0], window_steps[0])
        starts_w = _window_starts(W, roi_size[1], window_steps[1])
        starts_d = _window_starts(D, roi_size[2], window_steps[2])

        with torch.no_grad():
            for h in starts_h:
                for w in starts_w:
                    for d in starts_d:
                        h_end = min(h + roi_size[0], H)
                        w_end = min(w + roi_size[1], W)
                        d_end = min(d + roi_size[2], D)

                        patch = mri_volume[:, :, h:h_end, w:w_end, d:d_end]
                        if patch.shape[2:] != torch.Size(roi_size):
                            # Zero-pad at the far end of axes shorter than the window.
                            pad_h = roi_size[0] - patch.shape[2]
                            pad_w = roi_size[1] - patch.shape[3]
                            pad_d = roi_size[2] - patch.shape[4]
                            patch = F.pad(patch, (0, pad_d, 0, pad_w, 0, pad_h))

                        # The forward hook stores decoder1.conv_block's output in
                        # `decoder_features`. Reset it first: if the hook has been
                        # removed (e.g. this cell re-run), we fail loudly instead
                        # of reusing a stale tensor.
                        decoder_features = None
                        swin_model(patch)
                        if decoder_features is None:
                            raise RuntimeError(
                                "decoder1.conv_block hook did not fire; re-run the "
                                "model-loading cell to register it again"
                            )
                        embeddings_patch = decoder_features.squeeze(0)  # [C, 128, 128, 128]

                        if embeddings_full is None:
                            embeddings_full = torch.zeros(
                                embeddings_patch.shape[0], H, W, D, device=device
                            )

                        # Accumulate into the full volume, cropping away any padding.
                        embeddings_full[:, h:h_end, w:w_end, d:d_end] += embeddings_patch[:, :h_end - h, :w_end - w, :d_end - d]
                        count_map[h:h_end, w:w_end, d:d_end] += 1

                        del patch, embeddings_patch

        # Average overlapping contributions (every voxel is covered at least once).
        embeddings_full /= count_map.clamp(min=1).unsqueeze(0)  # [C, H, W, D]

        # Probability maps and hard segmentation
        prob_maps, segmentation = generate_probability_maps(embeddings_full, projection_head, classifier, device, batch_size=100000)
        print(f"Mapas de probabilidad generados, shape: {prob_maps.shape}")
        print(f"Segmentación generada, shape: {segmentation.shape}")

        # To numpy in NIfTI layout: class channels last for the probability maps.
        prob_maps_np = prob_maps.cpu().numpy()                         # [K, H, W, D]
        prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [H, W, D, K]
        segmentation_np = segmentation.cpu().numpy()                   # [H, W, D]

        # Save probability maps
        prob_output_path = os.path.join(output_dir, f"probability_maps_mri_{idx}.nii.gz")
        save_img(prob_maps_np_nifti, prob_output_path, header, affine)
        print(f"Guardado mapa de probabilidad en {prob_output_path}")

        # Save the semantic segmentation
        seg_output_path = os.path.join(output_dir, f"segmentation_mri_{idx}.nii.gz")
        save_img(segmentation_np, seg_output_path, header, affine)
        print(f"Guardada segmentación en {seg_output_path}")

        # Release per-volume memory before the next volume.
        del mri_volume, embeddings_full, count_map, prob_maps, segmentation
        torch.cuda.empty_cache()
finally:
    # Always detach the hook, even if inference fails part-way.
    hook_handle.remove()
print("Inferencia completada.")

Procesando MRI 0, shape: torch.Size([1, 11, 240, 240, 155])
Mapas de probabilidad generados, shape: torch.Size([3, 240, 240, 155])
Segmentación generada, shape: torch.Size([240, 240, 155])
Guardado mapa de probabilidad en trained_models/inference_results/probability_maps_mri_0.nii.gz
Guardada segmentación en trained_models/inference_results/segmentation_mri_0.nii.gz
Inferencia completada.
