In [11]:
import logging
import os
import sys

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt

import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch,list_data_collate, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, MeanIoU
from monai.networks.nets import UNet, SegResNet, flexible_unet
from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandSpatialCropd,
    RandRotate90d,
    ScaleIntensityd,
    Resized,
    ToTensor,
    Transposed,
    Lambda,
    Lambdad,
    RandScaleIntensityd,
    RandShiftIntensityd
)
from monai.visualize import plot_2d_or_3d_image
import time
from tqdm import tqdm 


from utils.image_processing_utils import overlay_prediction, overlay_images
from utils.data_utils import prepare_dataloaders_stroke_2021

# Print the MONAI version together with the versions of its optional dependencies.
monai.config.print_config()
# Route log records of level INFO and above to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


from ldm.models.autoencoder import AutoencoderKL

MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.5.0+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /home/<username>/miniconda3/envs/dpm/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.1
scikit-image version: 0.24.0
scipy version: 1.13.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.0+cu124
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: 4.45.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

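For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies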

In [None]:
# Loss settings handed to AutoencoderKL through its `lossconfig` argument.
lossconfig = dict(
    target="ldm.modules.losses.contperceptual.LPIPSWithDiscriminator",
    params=dict(
        disc_start=50001,
        kl_weight=1.0e-06,
        disc_weight=0.5,
    ),
)

# Build the KL autoencoder and restore its weights from `ckpt_path`.
# NOTE(review): the recorded output of this cell reports a different checkpoint
# than the `ckpt_path` below -- the output looks stale; re-run the cell to confirm
# which weights are actually loaded.
model = AutoencoderKL(
    ddconfig=dict(
        double_z=True,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
    ),
    lossconfig=lossconfig,
    embed_dim=4,
    # Alternative checkpoint, currently disabled:
    # ckpt_path="models/first_stage_models/kl-f8/model.ckpt",
    ckpt_path="logs/2024-12-16T16-18-19_autoencoder_kl_f8_stroke_image/checkpoints/epoch=57-step=253299.ckpt",
    image_key="image_dicom",
)

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
Restored from logs/2024-10-31T00-42-15_autoencoder_kl_f8/checkpoints/epoch=000040.ckpt


In [13]:
# Dataset folder
data_dir = "/home/arms/Workspace/Dataset/Stroke2021"
roi_size = (256, 256)
spatial_size = [256, 256]
# NOTE(review): `num_folds` is assigned here, but the loader call below passes
# the literal num_folds=1 instead of this variable -- confirm which is intended.
num_folds = 5
num_workers = 0
batch_size = 1
random_state = 0

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Move the autoencoder onto the selected device.
model = model.to(device)

# Build the train/validation loaders (and class weights) for the Stroke2021 data.
train_loader, val_loader, class_weights = prepare_dataloaders_stroke_2021(
    data_dir,
    roi_size=roi_size,
    spatial_size=spatial_size,
    expand_img_channel=False,
    num_folds=1,
    val_ratio=0.2,
    num_workers=num_workers,
    batch_size=batch_size,
    random_state=random_state,
    cache_rate=0.0,
    shuffle_valid=True,
)

In [14]:
def show_images(image, reconstructed, idx=0):
    """
    Plot the first three channels of one sample above its reconstruction.

    The figure is a 2x3 grid: the top row shows channels 0-2 of `image[idx]`,
    the bottom row shows channels 0-2 of `reconstructed[idx]`. Every panel is
    drawn in grayscale with a fixed display range of [0, 1].

    Args:
        image (torch.Tensor): Batch indexed as `image[idx, channel]`; must have
            at least three channels.
        reconstructed (torch.Tensor): Batch indexed the same way as `image`.
        idx (int): Index of the sample inside the batch to display.
    """
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    rows = (
        (image, "Input Image"),
        (reconstructed, "Reconstructed"),
    )
    for row, (batch, label) in enumerate(rows):
        for channel in range(3):
            ax = axs[row, channel]
            # detach() lets this work on tensors that still require grad
            # (e.g. a model output produced outside torch.no_grad()); without
            # it, .numpy() raises a RuntimeError for such tensors.
            panel = batch[idx, channel].detach().cpu().numpy()
            ax.imshow(panel, cmap='gray', vmin=0, vmax=1)
            ax.set_title(f"{label} - Channel {channel + 1}")
            ax.axis('off')

    plt.show()


# Dice accumulators: one over every class channel, one skipping the background channel.
dice_metric_all = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
dice_metric_wob = DiceMetric(
    include_background=False,
    reduction="mean",
    get_not_nans=False,
)

# IoU accumulators with the same with/without-background split.
iou_metric_all = MeanIoU(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
iou_metric_wob = MeanIoU(
    include_background=False,
    reduction="mean",
    get_not_nans=False,
)

# Unreduced IoU (reduction="none") for inspecting individual classes.
iou_metric_classwise = MeanIoU(
    include_background=True,
    reduction="none",
    get_not_nans=False,
)

In [15]:
def one_hot_encode_masks(masks, num_classes, device):
    """
    Convert a batch of label masks into channel-first one-hot tensors.

    Args:
    - masks (torch.Tensor): Label masks of shape [batch_size, 1, height, width].
    - num_classes (int): Number of segmentation classes (size of the new channel axis).
    - device (torch.device): Device the encoded tensor is moved to before returning.

    Returns:
    - torch.Tensor: Float one-hot masks of shape [batch_size, num_classes, height, width].
    """
    # Drop the singleton channel axis and cast to integer class indices:
    # [batch_size, height, width].
    class_indices = masks.squeeze(1).long()

    # F.one_hot appends the class axis last ([batch_size, height, width, num_classes]);
    # the permute moves it to the channel position.
    channels_first = F.one_hot(class_indices, num_classes=num_classes).permute(0, 3, 1, 2)

    # Disabled alternative: rescale the one-hot values from {0, 1} to {-1, 1}.
    return channels_first.float().to(device)

def apply_threshold_to_channels(batch_images, thresholds=(0.5, 0.5, 0.5)):
    """
    Binarize a batch of multi-channel images with one threshold per channel.

    Args:
        batch_images (torch.Tensor): Tensor of shape [B, C, H, W] (batch,
            channels, height, width).
        thresholds (sequence of float): Threshold for each channel, e.g.
            [0.5, 0.3, 0.7]. Must contain either exactly C values or a single
            value, which is then applied to every channel.

    Returns:
        torch.Tensor: Float tensor of shape [B, C, H, W] holding 1.0 where the
        pixel is >= its channel threshold and 0.0 elsewhere.

    Raises:
        ValueError: If `batch_images` is not 4-dimensional, or if the number
            of thresholds is neither 1 nor the number of channels.
    """
    if batch_images.dim() != 4:
        raise ValueError(
            f"batch_images must have shape [B, C, H, W], got {tuple(batch_images.shape)}"
        )

    # Build the thresholds directly on the image's device (no CPU round trip)
    # and shape them [1, N, 1, 1] so they broadcast over batch and pixels.
    threshold_tensor = torch.as_tensor(thresholds, device=batch_images.device).reshape(1, -1, 1, 1)

    num_thresholds = threshold_tensor.shape[1]
    num_channels = batch_images.shape[1]
    # Without this check a 1-channel input with N thresholds would silently
    # broadcast to an N-channel output instead of failing.
    if num_thresholds not in (1, num_channels):
        raise ValueError(
            f"got {num_thresholds} thresholds for {num_channels} channels"
        )

    return (batch_images >= threshold_tensor).float()

def one_hot_encode_nearest(reconstructed_image, num_classes=3):
    """
    Snap every pixel of a single-channel image to the nearest class value and one-hot encode it.

    The candidate class values are the integers 0 .. num_classes - 1.

    Args:
        reconstructed_image (torch.Tensor): Tensor of shape [B, 1, H, W].
        num_classes (int): Number of classes (default 3, i.e. values 0, 1, 2).

    Returns:
        torch.Tensor: Float one-hot tensor of shape [B, num_classes, H, W].
    """
    # Candidate class values laid out along the channel axis: [1, num_classes, 1, 1].
    class_values = torch.arange(num_classes, device=reconstructed_image.device)
    class_values = class_values.view(1, num_classes, 1, 1)

    # Broadcasting yields the distance of every pixel to every class value;
    # argmin over the channel axis picks the closest class per pixel.
    distances = torch.abs(reconstructed_image - class_values)
    nearest_class = torch.argmin(distances, dim=1)

    # one_hot puts the class axis last; move it back to the channel position.
    encoded = torch.nn.functional.one_hot(nearest_class, num_classes=num_classes)
    return encoded.permute(0, 3, 1, 2).float()


In [16]:
# Evaluate how well the autoencoder reconstructs one-hot encoded masks:
# encode -> sample a latent -> decode -> threshold -> Dice / IoU against the input.
model.eval()

post_pred = AsDiscrete(argmax=True, to_onehot=3)
post_mask = AsDiscrete(to_onehot=3)

# Clear anything accumulated by a previous run of this cell.
for _metric in (dice_metric_all, dice_metric_wob, iou_metric_all, iou_metric_wob):
    _metric.reset()

for test_data in tqdm(val_loader, total=len(val_loader), desc="Test"):
    val_images = test_data["image_dicom"].to(device)
    val_masks = test_data["mask"].to(device)
    # The 3-class one-hot mask is both the model input and the metric target.
    val_masks_hot_encoded = one_hot_encode_masks(val_masks, 3, device)
    with torch.no_grad():
        posterior = model.encode(val_masks_hot_encoded)
        latent = posterior.sample()
        reconstructed_image = model.decode(latent)
        # Binarize each reconstructed channel (default threshold 0.5).
        reconstructed_image_binary = apply_threshold_to_channels(reconstructed_image)
        for _metric in (dice_metric_all, dice_metric_wob, iou_metric_all, iou_metric_wob):
            _metric(y_pred=reconstructed_image_binary, y=val_masks_hot_encoded)
        # show_images(val_masks_hot_encoded, reconstructed_image_binary)

# Reduce the per-batch scores to a single number per metric.
dice_metric_all_mean = dice_metric_all.aggregate().item()
dice_metric_wob_mean = dice_metric_wob.aggregate().item()
iou_metric_all_mean = iou_metric_all.aggregate().item()
iou_metric_wob_mean = iou_metric_wob.aggregate().item()
print(f"Mean dice (all): {dice_metric_all_mean:.4f}, Mean dice (wo backgorund) {dice_metric_wob_mean:.4f}")
print(f"Mean iou (all): {iou_metric_all_mean:.4f}, Mean iou (wo backgorund) {iou_metric_wob_mean:.4f}")

Test: 100%|██████████| 1331/1331 [04:37<00:00,  4.79it/s]

Mean dice (all): 0.9988, Mean dice (wo backgorund) 0.9928
Mean iou (all): 0.9976, Mean iou (wo backgorund) 0.9859





In [None]:
# model.eval() 
# post_pred = AsDiscrete(argmax=True, to_onehot=3)
# post_mask = AsDiscrete(to_onehot=3)
# dice_metric.reset()
# for test_data in tqdm(val_loader[0], total=len(val_loader[0]), desc=f"Test"):
#         val_images, val_masks = test_data["image_dicom"].to(device), test_data["mask"].to(device)
#         val_masks_repated=val_masks.repeat(1, 3, 1, 1).to(device)
#         val_masks_hot_encoded=one_hot_encode_masks(val_masks,3,device)
#         with torch.no_grad():
#             posterior = model.encode(val_masks_repated)
#             latent = posterior.sample()
#             reconstructed_image = model.decode(latent)
#             reconstructed_image_binary = one_hot_encode_nearest(reconstructed_image)
#             dice_metric(y_pred=reconstructed_image_binary, y=val_masks_hot_encoded)
#             #show_images(val_masks_hot_encoded, reconstructed_image_binary)
# dice_metric_mean = dice_metric.aggregate().item()
# print(f"Mean dice: {dice_metric_mean:.4f}")

In [None]:
# Evaluate reconstruction of binarized masks: the label mask is clamped to {0, 1},
# replicated to three identical channels, passed through the autoencoder, and the
# thresholded reconstruction is scored with Dice on the first channel.
model.eval()
# Not used in this cell; kept so the names stay defined for any cell that expects them.
post_pred = AsDiscrete(argmax=True, to_onehot=3)
post_mask = AsDiscrete(to_onehot=3)

# `dice_metric` is not defined by any earlier cell (only dice_metric_all /
# dice_metric_wob exist), so create the accumulator here instead of hitting a
# NameError on a fresh kernel.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_metric.reset()

# The earlier evaluation cell iterates `val_loader` directly, so it is a single
# loader there; only index with [0] when a per-fold list/tuple was returned.
binary_val_loader = val_loader[0] if isinstance(val_loader, (list, tuple)) else val_loader

for test_data in tqdm(binary_val_loader, total=len(binary_val_loader), desc="Test"):
    # The image tensor is not needed here, so it is no longer moved to the device.
    val_masks = test_data["mask"].to(device)
    # Replicate the single mask channel to the three channels the model was built with.
    # NOTE(review): .float() is a no-op for float32 masks and guards against
    # integer-typed masks -- confirm the mask dtype produced by the loader.
    val_masks_repated = val_masks.repeat(1, 3, 1, 1).float()
    # Collapse every non-zero label into a single foreground value.
    val_masks_repated = torch.clamp(val_masks_repated, max=1)
    with torch.no_grad():
        posterior = model.encode(val_masks_repated)
        latent = posterior.sample()
        reconstructed_image = model.decode(latent)
        reconstructed_image_binary = apply_threshold_to_channels(reconstructed_image)
        # Slice with 0:1 (not 0) so the channel axis is kept: DiceMetric expects
        # [B, C, H, W]; a [B, H, W] input would be read as H "channels".
        dice_metric(
            y_pred=reconstructed_image_binary[:, 0:1, :, :],
            y=val_masks_repated[:, 0:1, :, :],
        )
        # show_images(val_masks_repated, reconstructed_image_binary)
dice_metric_mean = dice_metric.aggregate().item()
print(f"Mean dice: {dice_metric_mean:.4f}")