In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from monai.networks.nets import SwinUNETR
import nibabel as nib
from src.get_data import CustomDataset # Ajusta según tu implementación
from monai.data import DataLoader
from monai import transforms
from src.custom_transforms import ConvertToMultiChannelBasedOnAnotatedInfiltration



In [2]:
# Compute device: the default CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Modelos
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

class Classifier(nn.Module):
    def __init__(self, input_dim=128, num_classes=3):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        return self.fc(x)

# Función para generar mapas de probabilidad y segmentación por lotes
def generate_probability_maps(embeddings, projection_head, classifier, device, batch_size=100000):
    with torch.no_grad():
        embeddings = embeddings.squeeze(0).permute(1, 2, 3, 0)  # [128, 128, 128, 48]
        embeddings_flat = embeddings.reshape(-1, 48)  # [2097152, 48]
        total_voxels = embeddings_flat.shape[0]
        
        prob_maps = torch.zeros(total_voxels, 3, device=device)
        
        for start_idx in range(0, total_voxels, batch_size):
            end_idx = min(start_idx + batch_size, total_voxels)
            batch_embeddings = embeddings_flat[start_idx:end_idx].to(device)  # [batch_size, 48]
            
            z = projection_head(batch_embeddings)
            z = F.normalize(z, dim=1)
            logits = classifier(z)
            probs = F.softmax(logits, dim=1)
            
            prob_maps[start_idx:end_idx] = probs
            
            del batch_embeddings, z, logits, probs
            torch.cuda.empty_cache()
        
        prob_maps = prob_maps.view(128, 128, 128, 3).permute(3, 0, 1, 2)  # [3, 128, 128, 128]
        segmentation = torch.argmax(prob_maps, dim=0).to(torch.uint8)  # [128, 128, 128]
        
        return prob_maps, segmentation
    
def save_img(I_img,savename,header=None,affine=None):
    if header is None or affine is None:
        affine = np.diag([1, 1, 1, 1])
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=None)
    else:
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=header)

    nib.save(new_img, savename)


In [3]:
# Configuración del DataLoader

roi = (128, 128, 128) # (220, 220, 155) (128, 128, 64)
source_k="label"
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        # ConvertToMultiChannelBasedOnN_Froi(keys="label"),
        # masked(keys="image"),
        ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key=source_k,
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
    ]
)

dataset_path = './Dataset/Dataset_recurrence'
test_set = CustomDataset(dataset_path, section="test", transform=test_transform)  # Ajusta transform
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)

# Obtener header + affine
mri = nib.load('./Dataset/Dataset_recurrence/test/images/images_structural/UPENN-GBM-00332_11/UPENN-GBM-00332_11_T1GD.nii.gz')
header = mri.header
affine = mri.affine

im_t = test_set[0]

print(im_t['image'].shape, type(im_t['image']))
print(im_t['label'].shape, type(im_t['label']))




Found 1 images and 1 labels.
torch.Size([11, 128, 128, 128]) <class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([2, 128, 128, 128]) <class 'monai.data.meta_tensor.MetaTensor'>


In [None]:
# import wandb
# run = wandb.init()
# artifact = run.use_artifact('mlops-team89/Swin_UPENN_10cases/15cwmu45_best_model:v0', type='model')
# artifact_dir = artifact.download()

In [None]:
# Cargar SwinUNETR
### Hyperparameter
roi = (128, 128, 128)
swin_model = SwinUNETR(
    img_size=roi,
    in_channels=11,
    out_channels=2,  # mdificar con edema
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
).to(device)

# model_path = "Dataset/model.pt"
model_path = "artifacts/15cwmu45_best_model:v0/model.pt" #"artifacts/15cwmu45_best_model:v0/model.pt"  

# Load the model on CPU
loaded_model = torch.load(model_path, map_location=torch.device(device))["state_dict"]
# model.load_state_dict(torch.load(model_path)["state_dict"])

# Load the state dictionary into the model
swin_model.load_state_dict(loaded_model)

# Set the model to evaluation mode
swin_model.eval()

  loaded_model = torch.load(model_path, map_location=torch.device(device))["state_dict"]


SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

In [6]:
# Hook para capturar embeddings
decoder_features = None
def decoder_hook_fn(module, input, output):
    global decoder_features
    decoder_features = output
    
hook_handle = swin_model.decoder1.conv_block.register_forward_hook(decoder_hook_fn)

# Cargar modelos contrastivo y clasificador
projection_head = ProjectionHead(input_dim=48).to(device)
projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth", map_location=device))
projection_head.eval()

classifier = Classifier(input_dim=128, num_classes=3).to(device)
classifier.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final.pth", map_location=device))
classifier.eval()

# Directorio de salida
output_dir = "trained_models/inference_results"
os.makedirs(output_dir, exist_ok=True)

# Pipeline de inferencia
for idx, batch_data in enumerate(test_loader):
    # Extraer la imagen y convertir MetaTensor a tensor puro
    mri = batch_data["image"]
    if isinstance(mri, torch.Tensor) and hasattr(mri, 'meta'):
        mri = mri.as_tensor()  # Convertir MetaTensor a tensor puro
    mri = mri.to(device)  # [1, 11, 128, 128, 128]
    print(f"Procesando MRI {idx}, shape: {mri.shape}")
    
    # Extraer embeddings con SwinUNETR
    with torch.no_grad():
        _ = swin_model(mri)  # Activa el hook
        embeddings = decoder_features  # [1, 48, 128, 128, 128]
        if isinstance(embeddings, torch.Tensor) and hasattr(embeddings, 'meta'):
            embeddings = embeddings.as_tensor()  # Convertir MetaTensor a tensor puro
    
    # Generar mapas de probabilidad y segmentación
    prob_maps, segmentation = generate_probability_maps(embeddings, projection_head, classifier, device, batch_size=100000)
    print(f"Mapas de probabilidad generados, shape: {prob_maps.shape}")
    print(f"Segmentación generada, shape: {segmentation.shape}")
    
    # Convertir a numpy para guardar
    prob_maps_np = prob_maps.cpu().numpy()  # [3, 128, 128, 128]
    prob_maps_np_nifti = np.transpose(prob_maps_np, (1, 2, 3, 0))  # [128, 128, 128, 3]
    segmentation_np = segmentation.cpu().numpy()  # [128, 128, 128]
    
    # Crear imágenes NIfTI
    affine = np.eye(4)  # Ajusta si tienes una matriz afín real
    
    # Guardar mapas de probabilidad
    nifti_prob_img = nib.Nifti1Image(prob_maps_np_nifti, affine)
    prob_output_path = os.path.join(output_dir, f"probability_maps_mri_{idx}.nii.gz")
    # nib.save(nifti_prob_img, prob_output_path)
    save_img(
            prob_maps_np_nifti, #output_tensor.numpy(),
            prob_output_path,
            header,
            affine,
        )
    print(f"Guardado mapa de probabilidad en {prob_output_path}")
    
    # Guardar segmentación semántica
    nifti_seg_img = nib.Nifti1Image(segmentation_np, affine)
    seg_output_path = os.path.join(output_dir, f"segmentation_mri_{idx}.nii.gz")
    # nib.save(nifti_seg_img, seg_output_path)
    save_img(
            segmentation_np, #output_tensor.numpy(),
            seg_output_path,
            header,
            affine,
        )
    print(f"Guardada segmentación en {seg_output_path}")
    
    # Liberar memoria
    del mri, embeddings, prob_maps, segmentation
    torch.cuda.empty_cache()

# Remover el hook
hook_handle.remove()
print("Inferencia completada.")

  projection_head.load_state_dict(torch.load("trained_models/checkpoints_contrastive/contrastive_projection_head_final.pth", map_location=device))
  classifier.load_state_dict(torch.load("trained_models/checkpoints/supervised_classifier_final.pth", map_location=device))


Procesando MRI 0, shape: torch.Size([1, 11, 128, 128, 128])
Mapas de probabilidad generados, shape: torch.Size([3, 128, 128, 128])
Segmentación generada, shape: torch.Size([128, 128, 128])
Guardado mapa de probabilidad en trained_models/inference_results/probability_maps_mri_0.nii.gz
Guardada segmentación en trained_models/inference_results/segmentation_mri_0.nii.gz
Inferencia completada.
