In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Loading functions
import os
import time
from monai.data import DataLoader, decollate_batch


import torch
import torch.nn.parallel

from src.get_data import CustomDataset, CustomDatasetSeg
import numpy as np
from scipy import ndimage
from types import SimpleNamespace
import wandb
import logging

#####
import json
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    MapTransform,
    Transform,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR
from monai import data

# from monai.data import decollate_batch
from functools import partial
from src.custom_transforms import ConvertToMultiChannelBasedOnN_Froi, ConvertToMultiChannelBasedOnAnotatedInfiltration, masked

## Transformaciones Swin UNETR

In [3]:
# Spatial size of the training patches (other sizes tried: (220, 220, 155), (128, 128, 64)).
roi = (128, 128, 128)
# Dictionary key whose content drives the foreground crop.
source_k = "image"

# Training pipeline: load image + label, convert the label to multi-channel form,
# crop to the image foreground (padded so each side is divisible by the ROI),
# take a random fixed-size ROI patch and normalise the non-zero intensities per channel.
# Disabled alternatives kept for reference:
#   ConvertToMultiChannelBasedOnN_Froi(keys="label"), masked(keys="image")
train_steps = [
    transforms.LoadImaged(keys=["image", "label"]),
    ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
    transforms.CropForegroundd(
        keys=["image", "label"],
        source_key=source_k,
        k_divisible=list(roi),
    ),
    transforms.RandSpatialCropd(
        keys=["image", "label"],
        roi_size=list(roi),
        random_size=False,
    ),
    transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: same loading / label conversion / normalisation, but no
# foreground crop. roi_size of -1 on every axis asks for the full extent of each
# dimension (a fixed size such as [224, 224, 128] was tried previously).
val_steps = [
    transforms.LoadImaged(keys=["image", "label"]),
    ConvertToMultiChannelBasedOnAnotatedInfiltration(keys="label"),
    transforms.RandSpatialCropd(
        keys=["image", "label"],
        roi_size=[-1, -1, -1],
        random_size=False,
    ),
    transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
val_transform = transforms.Compose(val_steps)



In [4]:
######################
# Build the model
######################

### Hyperparameter
roi = (128, 128, 128)  # input patch size expected by the network

# Create Swin transformer
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=roi,
    in_channels=11,
    out_channels=2,  # modify if an edema class is added
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

# Path of the best checkpoint
model_path = "artifacts/ddfkhvym_best_model:v0/model.pt"

# Deserialize the checkpoint onto the CPU. The previous hard-coded
# map_location of cuda:0 made this cell fail on machines without a GPU, even
# though `device` already has a CPU fallback. The weights are moved to
# `device` together with the model below.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location="cpu")
loaded_model = checkpoint["state_dict"]

# Load the state dictionary into the model
model.load_state_dict(loaded_model)

model.to(device)

# Set the model to evaluation mode (last expression, so the notebook shows the model repr)
model.eval()



SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(11, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_featur

In [5]:
# Training split of the recurrence dataset, pre-processed with train_transform.
dataset_path = "./Dataset/Dataset_recurrence"
train_set = CustomDataset(
    dataset_path,
    section="train",
    transform=train_transform,
)

# One case per batch, visited in dataset order, loaded by a single worker process.
train_loader = DataLoader(
    train_set,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)



Found 36 images and 36 labels.


In [6]:

embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_output_dir = "Dataset/contrastive_voxel_wise/labels"

# Create the output folders if they do not exist yet
os.makedirs(embedding_dir, exist_ok=True)
os.makedirs(label_output_dir, exist_ok=True)

# Most recent output of the hooked decoder block (written by the hook below)
decoder_features = None


def decoder_hook_fn(module, input, output):
    """Forward hook: store the hooked module's output in the global ``decoder_features``."""
    global decoder_features
    decoder_features = output


# Register the hook on decoder1.conv_block
hook_handle_decoder = model.decoder1.conv_block.register_forward_hook(decoder_hook_fn)

# Extract and save one embedding volume and one class-index label volume per case.
# NOTE(review): train_transform contains a random spatial crop and no seed is set,
# so the saved patches differ between runs -- confirm this is intended.
try:
    with torch.no_grad():
        for idx, batch_data in enumerate(train_loader):
            image, label = batch_data["image"], batch_data["label"]
            print("Image", image.shape)  # [1, 11, 128, 128, 128]
            print("label before squeeze", label.shape)  # [1, 2, 128, 128, 128]

            image = image.to(device)
            label = label.squeeze(0)  # [2, 128, 128, 128]

            # Collapse the two label channels into a single class-index volume:
            # - Background  (0, 0) -> 0
            # - Vasogenic   (1, 0) -> 1
            # - Infiltrated (., 1) -> 2  (channel 1 wins whenever it is set)
            label_class = torch.zeros_like(label[0], dtype=torch.long)  # [128, 128, 128]
            label_class[label[1] == 1] = 2  # Infiltrated
            label_class[(label[0] == 1) & (label[1] == 0)] = 1  # Vasogenic
            # Voxels where both channels are 0 keep the background value 0

            label = label_class.cpu().numpy()  # [128, 128, 128]
            print("label", label.shape)

            # The forward pass is run only to trigger the hook; its output is discarded
            _ = model(image)

            print("decoder_features:", decoder_features.shape)  # [1, 48, 128, 128, 128]

            # Save embeddings and labels
            np.save(f"{embedding_dir}/case_{idx}.npy", decoder_features.cpu().numpy())
            np.save(f"{label_output_dir}/case_{idx}.npy", label)

            print(f"Guardado embeddings y etiquetas para caso {idx}")
finally:
    # Always remove the hook, even if the loop fails part-way, so that a
    # re-run of this cell does not leave stale hooks on the model.
    hook_handle_decoder.remove()

Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 0
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 1
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 2
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torch.Size([1, 48, 128, 128, 128])
Guardado embeddings y etiquetas para caso 3
Image torch.Size([1, 11, 128, 128, 128])
label before squeeze torch.Size([1, 2, 128, 128, 128])
label (128, 128, 128)
decoder_features: torc

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

class EmbeddingDataset(Dataset):
    """Dataset of per-case embedding volumes and their voxel-wise label volumes.

    Every ``.npy`` file found in ``embedding_dir`` is one case; the label volume
    is read from the file with the same name in ``label_dir``. Cases are ordered
    by the trailing integer of the file name (``case_2`` before ``case_10``), so
    for files named ``case_0 .. case_{n-1}`` index ``i`` still maps to ``case_i``.

    Args:
        embedding_dir: directory containing the embedding ``.npy`` files.
        label_dir: directory containing the label ``.npy`` files.
    """

    def __init__(self, embedding_dir, label_dir):
        self.embedding_dir = embedding_dir
        self.label_dir = label_dir
        npy_files = [f for f in os.listdir(embedding_dir) if f.endswith(".npy")]
        # Index the files that actually exist instead of assuming that
        # case_0 .. case_{n-1} are all present.
        self.case_files = sorted(npy_files, key=self._case_sort_key)

    @staticmethod
    def _case_sort_key(file_name):
        """Sort by trailing case number; names without one go last, alphabetically."""
        stem = os.path.splitext(file_name)[0]
        suffix = stem.rsplit("_", 1)[-1]
        try:
            return (0, int(suffix), file_name)
        except ValueError:
            return (1, 0, file_name)

    def __len__(self):
        return len(self.case_files)

    def __getitem__(self, idx):
        """Return ``(embeddings, labels)`` for the ``idx``-th case.

        ``embeddings`` is a float32 tensor with a leading size-1 axis removed
        (if present); ``labels`` is an int64 tensor.
        """
        file_name = self.case_files[idx]
        embedding_path = os.path.join(self.embedding_dir, file_name)
        label_path = os.path.join(self.label_dir, file_name)

        embeddings = np.load(embedding_path)
        labels = np.load(label_path)

        embeddings = torch.tensor(embeddings, dtype=torch.float32).squeeze(0)
        labels = torch.tensor(labels, dtype=torch.long)

        return embeddings, labels

class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the last dimension.

    Args:
        input_dim: size of the incoming feature vectors.
        hidden_dim: width of the hidden layer.
        output_dim: size of the projected vectors.
    """

    def __init__(self, input_dim=48, hidden_dim=128, output_dim=128):
        super().__init__()
        # Kept inside a Sequential named `net` so the state-dict keys stay net.0.* / net.2.*
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` of shape (..., input_dim) to (..., output_dim)."""
        projected = self.net(x)
        return projected

def contrastive_loss(z, labels, temperature=0.5, sample_size_per_class=1024):
    """Supervised contrastive loss on a class-balanced random subsample.

    The rows of ``z`` are L2-normalised. At most ``sample_size_per_class`` rows
    are then drawn at random from every class present in ``labels``. For each
    sampled anchor ``i`` the loss is::

        -log( sum_{j != i, y_j == y_i} exp(s_ij) / sum_{j != i} exp(s_ij) )

    with ``s_ij = <z_i, z_j> / temperature``. The result is the mean over the
    anchors that have at least one positive (a same-class sample other than
    themselves). Anchors without positives carry no signal and are excluded,
    so they do not inflate the mean with a large constant term.

    Args:
        z: (N, D) tensor of embeddings.
        labels: (N,) tensor of integer class labels.
        temperature: similarity scaling factor.
        sample_size_per_class: maximum number of rows kept per class.

    Returns:
        Scalar tensor on ``z``'s device. A zero tensor (detached from ``z``)
        is returned when fewer than two classes are present, fewer than two
        rows are sampled, or no anchor has a positive.
    """
    # Work on whatever device the inputs live on, not on a notebook-level global.
    target_device = z.device
    z = F.normalize(z, dim=1)

    classes = torch.unique(labels)
    if len(classes) < 2:
        print(f"Advertencia: Solo una clase presente ({classes.tolist()}), devolviendo pérdida 0")
        return torch.tensor(0.0, device=target_device, requires_grad=True)

    sampled_z = []
    sampled_labels = []

    for cls in classes:
        cls_indices = (labels == cls).nonzero(as_tuple=True)[0]
        cls_size = cls_indices.shape[0]
        if cls_size > sample_size_per_class:
            # Random subset of this class, drawn on the same device as the indices
            keep = torch.randperm(cls_size, device=cls_indices.device)[:sample_size_per_class]
            cls_indices = cls_indices[keep]
        sampled_z.append(z[cls_indices])
        sampled_labels.append(labels[cls_indices])

    z = torch.cat(sampled_z, dim=0)
    labels = torch.cat(sampled_labels, dim=0)
    N = z.shape[0]

    print(f"Batch size: {N}, Unique labels: {torch.unique(labels).tolist()}")

    if N < 2:
        print("Advertencia: Batch con menos de 2 vóxeles, devolviendo pérdida 0")
        return torch.tensor(0.0, device=target_device, requires_grad=True)

    similarity = torch.mm(z, z.T) / temperature

    # positives[i, j] is 1 when j shares i's class and j != i
    same_class = (labels.unsqueeze(1) == labels.unsqueeze(0)).to(target_device)
    not_self = ~torch.eye(N, dtype=torch.bool, device=target_device)
    positives = (same_class & not_self).to(similarity.dtype)

    has_positive = positives.sum(dim=1) > 0
    if not has_positive.any():
        # Make the behaviour match the message: without positives there is nothing to learn.
        print("Advertencia: No hay pares positivos, pérdida será 0")
        return torch.tensor(0.0, device=target_device, requires_grad=True)

    exp_sim = torch.exp(similarity)
    pos_sum = (exp_sim * positives).sum(dim=1)
    # Denominator: every other sample (positives and negatives), self excluded
    neg_sum = exp_sim.sum(dim=1) - exp_sim.diag()

    loss = -torch.log((pos_sum + 1e-6) / (neg_sum + 1e-6))
    return loss[has_positive].mean()

# Configuration
embedding_dir = "Dataset/contrastive_voxel_wise/embeddings"
label_dir = "Dataset/contrastive_voxel_wise/labels"
batch_size = 1
sample_size_per_class = 1024
temperature = 0.5
num_epochs = 10
input_dim = 48  # channel count of the stored embeddings
seed = 0

# Seed PyTorch so shuffling, weight init and voxel sampling are repeatable
torch.manual_seed(seed)

dataset = EmbeddingDataset(embedding_dir, label_dir)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

model = ProjectionHead(input_dim=input_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    total_loss = 0
    valid_batches = 0
    for batch_idx, (embeddings, labels) in enumerate(loader):
        embeddings = embeddings.to(device)  # [B, C, D, H, W]
        labels = labels.to(device)  # [B, D, H, W]

        # Move channels last and flatten to one row per voxel. Unlike the
        # previous squeeze(0) + permute, this works for any batch size and
        # takes the channel count from the data instead of a literal 48.
        feature_dim = embeddings.shape[1]
        embeddings_flat = embeddings.permute(0, 2, 3, 4, 1).reshape(-1, feature_dim)
        labels_flat = labels.reshape(-1)

        # Drop voxels with a negative label (none are produced by the extraction cell)
        valid_mask = labels_flat >= 0
        embeddings_valid = embeddings_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]

        if embeddings_valid.shape[0] < 2:
            print(f"Batch {batch_idx}: Insuficientes vóxeles válidos")
            continue

        # NOTE(review): every voxel is projected although contrastive_loss keeps
        # at most sample_size_per_class per class; sampling before the forward
        # pass would cut memory and compute substantially.
        z = model(embeddings_valid)
        loss = contrastive_loss(z, labels_valid, temperature, sample_size_per_class)

        if loss.item() == 0:
            continue  # Zero-loss batches are neither optimised nor counted in the average

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        valid_batches += 1

        print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / max(valid_batches, 1)  # Avoid division by zero
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Valid Batches: {valid_batches}/{len(loader)}")

torch.save(model.state_dict(), "contrastive_projection_head.pth")

Batch size: 2546, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 0/36, Loss: 0.8071
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 1/36, Loss: 0.8541
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 2/36, Loss: 0.8532
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 3/36, Loss: 0.7374
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 4/36, Loss: 0.6246
Batch size: 2111, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 5/36, Loss: 0.3689
Batch size: 2896, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 6/36, Loss: 0.5647
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 7/36, Loss: 0.5819
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 8/36, Loss: 0.5266
Advertencia: Solo una clase presente ([0]), devolviendo pérdida 0
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 10/36, Loss: 0.5104
Batch size: 3072, Unique labels: [0, 1, 2]
Epoch 1/10, Batch 11/36, Loss: 0.5714
Batch size: 2559, Unique labels: [0, 1, 2]
Epoch 1/1

In [None]:
# Print the label values and the per-class voxel counts of every batch from `loader`.
# NOTE(review): if `loader` shuffles, `idx` is the iteration order, not the case number on disk.
for idx, (embeddings, labels) in enumerate(loader):
    present = torch.unique(labels).tolist()
    print(f"Unique labels: {present}")
    labels_flat = labels.reshape(-1)
    # minlength=3 guarantees a bin for each of the three classes even when one is absent
    class_counts = torch.bincount(labels_flat, minlength=3)
    fondo, vasogenico, infiltrado = class_counts[0], class_counts[1], class_counts[2]
    print(f"Case {idx}: Fondo: {fondo}, Vasogénico: {vasogenico}, Infiltrado: {infiltrado}")

In [None]:
import os
import numpy as np

label_dir = "Dataset/contrastive_voxel_wise/labels"

# Discover the saved label volumes instead of assuming exactly 36 cases:
# collect (case number, file name) for every file named case_<number>.npy.
label_cases = []
if os.path.isdir(label_dir):
    for file_name in os.listdir(label_dir):
        stem, ext = os.path.splitext(file_name)
        case_number = stem[len("case_"):]
        if ext == ".npy" and stem.startswith("case_") and case_number.isdecimal():
            label_cases.append((int(case_number), file_name))

# Report, in case-number order, how many voxels carry each label value.
for idx, file_name in sorted(label_cases):
    label = np.load(os.path.join(label_dir, file_name))
    unique, counts = np.unique(label, return_counts=True)
    # tolist() yields plain Python ints so the dict prints as {0: n, 1: m, ...}
    print(f"Case {idx}: {dict(zip(unique.tolist(), counts.tolist()))}")