In [1]:
!pip install nibabel monai matplotlib torch numpy

Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Segmentation on secret test set
import os
import glob
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import monai
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityd, Resized
from monai.networks.nets import UNet
from monai.networks.utils import one_hot
from torch.utils.data import DataLoader
from scipy.ndimage import zoom

# Configuration: ensemble checkpoint, output folder, and root folder of the secret test scans
ensemble_model_path, save_folder, data_path = (
    "ensembleHeartUNet.pt",
    "Segmentations_secret",
    r'./Secret3d',
)

# exist_ok=True makes re-running this cell a no-op for the folder
os.makedirs(save_folder, exist_ok=True)

def load_nii(img_path):
    """Read a NIfTI file from disk.

    Returns:
        (data, affine): the floating-point voxel array from ``get_fdata()`` and the
        image's affine matrix.
    """
    nifti = nib.load(img_path)
    data = nifti.get_fdata()
    return data, nifti.affine

def build_test_dict(data_path):
    """List every ``*.nii.gz`` scan inside the immediate sub-folders of *data_path*.

    Only entries of *data_path* that are directories are searched; files lying directly in
    *data_path* are ignored.

    Returns:
        A list of ``{"img": path}`` dicts, sorted by path.
    """
    scans = []
    for entry in os.listdir(data_path):
        patient_folder = os.path.join(data_path, entry)
        if not os.path.isdir(patient_folder):
            continue
        scans.extend(glob.glob(os.path.join(patient_folder, "*.nii.gz")))
    return [{"img": scan} for scan in sorted(scans)]

class LoadHeartData(monai.transforms.Transform):
    """Expand one ``{"img": path}`` record into one sample dict per slice.

    The volume is split along its last axis. Each returned dict holds the slice as float32
    under ``'img'``, a fresh ``'img_meta_dict'`` with the volume's affine and its full
    (pre-split) shape, and the source file path under ``'img_path'``.
    """

    def __call__(self, sample):
        source_path = sample['img']
        volume, volume_affine = load_nii(source_path)
        volume_shape = volume.shape
        # Bring the slice axis to the front so iterating the array yields the 2D slices
        slices_first = np.moveaxis(volume, -1, 0)
        return [
            {
                'img': plane.astype(np.float32),
                'img_meta_dict': {'affine': volume_affine, 'original_shape': volume_shape},
                'img_path': source_path,
            }
            for plane in slices_first
        ]

def flatten_dataset(dataset_list, transform):
    """Apply *transform* to every record and concatenate the resulting iterables into one list.

    Used to turn per-volume records into a flat list of per-slice samples.
    """
    return [item for data in dataset_list for item in transform(data)]

def resize_or_pad_slice(pred_np, original_shape):
    """Resize a 2D label map to the in-plane size of the original scan.

    Despite the name, no padding is performed. When the first two dimensions of *pred_np*
    differ from ``original_shape[:2]``, the slice is rescaled with
    ``scipy.ndimage.zoom(order=0)`` (nearest neighbour, so label values are never
    interpolated). Otherwise *pred_np* itself is returned unchanged.

    Args:
        pred_np: predicted label map; its first two axes are treated as rows/columns.
        original_shape: shape of the source volume; only its first two entries are used.
    """
    target_hw = original_shape[:2]
    if pred_np.shape[:2] == target_hw:
        return pred_np
    rows_factor = target_hw[0] / pred_np.shape[0]
    cols_factor = target_hw[1] / pred_np.shape[1]
    return zoom(pred_np, zoom=(rows_factor, cols_factor), order=0)

# LoadHeartData returns a list of per-slice dicts; Compose maps the remaining transforms
# over each item of that list.
test_transforms = Compose([
    LoadHeartData(),
    EnsureChannelFirstd(keys=['img'], channel_dim="no_channel"),
    ScaleIntensityd(keys=['img']),
    # Network input size; predictions are mapped back to the original in-plane size later
    Resized(keys=['img'], spatial_size=(256, 256), mode=['bilinear']),
])

test_dicts = build_test_dict(data_path)
# Fail early with a clear message rather than an unhelpful IndexError on an empty folder
if not test_dicts:
    raise FileNotFoundError(
        f"No *.nii.gz scans found in the patient sub-folders of {data_path!r}"
    )
test_flat = flatten_dataset(test_dicts, test_transforms)

test_dataset = monai.data.Dataset(data=test_flat)
# NOTE: the inference loop in this cell iterates test_dataset directly, not this loader
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

class EnsembleModel(torch.nn.Module):
    """Soft-voting ensemble that averages the members' per-class softmax probabilities.

    ``forward`` returns probabilities (already softmax-normalised over dim 1), not logits.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members so .to()/.eval()/state_dict() reach them
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        probabilities = torch.stack(
            [member(x).softmax(dim=1) for member in self.models]
        )
        return probabilities.mean(dim=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled EnsembleModel object (it is used directly as a module below),
# so unpickling needs the EnsembleModel class defined above and weights_only=False: PyTorch >= 2.6
# defaults to weights_only=True and would refuse to load it. Unpickling can execute arbitrary
# code, so only load checkpoints you trust.
ensemble_model = torch.load(ensemble_model_path, map_location=device, weights_only=False)
ensemble_model.to(device)
ensemble_model.eval()

#%% Inference and Save Segmentations as 3D Volumes
with torch.no_grad():
    # scan name -> {"slices": per-slice predictions in slice order, "affine": that scan's affine}
    scan_segmentations = {}

    for sample in test_dataset:
        img_path = sample["img_path"]
        img = sample["img"].to(device).unsqueeze(0)  # add batch dimension
        affine = sample["img_meta_dict"]['affine']
        original_shape = sample["img_meta_dict"]['original_shape']

        # The ensemble returns softmax-averaged class probabilities; take the argmax over the
        # class axis on the device so only the (H, W) label map is copied back to the host.
        avg_output = ensemble_model(img)
        pred = avg_output.argmax(dim=1).squeeze(0).to(torch.uint8).cpu().numpy()  # Shape: (H, W)

        # Map the prediction back to the scan's original in-plane size (nearest neighbour)
        pred_resized = resize_or_pad_slice(pred, original_shape)

        # Add a trailing slice axis so the slices can be concatenated into a volume
        if pred_resized.ndim == 2:
            pred_resized = np.expand_dims(pred_resized, axis=-1)

        # Slices are grouped by the scan's file name (text before the first '.')
        scan_name = os.path.basename(img_path).split('.')[0]

        # The first slice of a scan creates its entry and records the scan's affine
        if scan_name not in scan_segmentations:
            scan_segmentations[scan_name] = {"slices": [], "affine": affine}

        scan_segmentations[scan_name]["slices"].append(pred_resized)

# Save each 3D segmentation: one NIfTI file per scan, named after the source scan
for scan_name, scan_data in scan_segmentations.items():
    segmentation_slices, affine = scan_data["slices"], scan_data["affine"]

    # Stack the (H, W, 1) slice predictions along the last axis -> (H, W, n_slices)
    segmentation_3d = np.concatenate(segmentation_slices, axis=-1)

    # The scan's own affine is reused unchanged (defensive copy); nothing is rescaled here
    new_affine = affine.copy()

    output_filename = f"{scan_name}.nii.gz"
    output_path = os.path.join(save_folder, output_filename)
    pred_nii = nib.Nifti1Image(segmentation_3d, new_affine)
    nib.save(pred_nii, output_path)
    print(f"Saved segmentation for {output_filename}")

2025-04-01 09:23:59.809690: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-01 09:23:59.848857: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-01 09:23:59.848880: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-01 09:23:59.848916: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-01 09:23:59.857694: I tensorflow/core/platform/cpu_feature_g

Saved segmentation for patient151_frame01.nii.gz
Saved segmentation for patient152_frame01.nii.gz
Saved segmentation for patient153_frame01.nii.gz
Saved segmentation for patient154_frame01.nii.gz
Saved segmentation for patient155_frame01.nii.gz
Saved segmentation for patient156_frame01.nii.gz
Saved segmentation for patient157_frame01.nii.gz
Saved segmentation for patient158_frame01.nii.gz
Saved segmentation for patient159_frame01.nii.gz
Saved segmentation for patient160_frame01.nii.gz
Saved segmentation for patient161_frame01.nii.gz
Saved segmentation for patient162_frame01.nii.gz
Saved segmentation for patient163_frame01.nii.gz
Saved segmentation for patient164_frame01.nii.gz
Saved segmentation for patient165_frame01.nii.gz
Saved segmentation for patient166_frame01.nii.gz
Saved segmentation for patient167_frame01.nii.gz
Saved segmentation for patient168_frame01.nii.gz
Saved segmentation for patient169_frame01.nii.gz
Saved segmentation for patient170_frame01.nii.gz
Saved segmentation f

In [3]:
# Segmentation on original ACDC test set
import os
import glob
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import monai
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityd, Resized, Resize
from monai.networks.nets import UNet
from monai.networks.utils import one_hot
from torch.utils.data import DataLoader
import random

# Ensemble checkpoint to evaluate and the output folder for the predicted ACDC test segmentations
ensemble_model_path, save_folder = "ensembleHeartUNet.pt", "Segmentations_ensemble"

# exist_ok=True makes re-running this cell a no-op for the folder
os.makedirs(save_folder, exist_ok=True)

#%% Load NIfTI images
def load_nii(img_path):
    """Return ``(voxel_data, affine)`` for the NIfTI file at *img_path*.

    ``voxel_data`` comes from ``get_fdata()`` (a floating-point array).
    """
    image = nib.load(img_path)
    voxel_data = image.get_fdata()
    affine = image.affine
    return voxel_data, affine

#%% Build a dictionary for the test dataset
def build_test_dict(data_path):
    """Pair each ACDC test image with its ground-truth mask.

    Images are read from ``<data_path>/testing/image/*.nii.gz`` and masks from
    ``<data_path>/testing/segmentation/*_gt.nii.gz``. A mask matches an image when its
    file name with ``"_gt"`` removed equals the image's file name. Images without a
    matching, existing mask are skipped.

    Returns:
        A list of ``{"img": path, "mask": path}`` dicts in sorted image-path order.
    """
    testing_root = os.path.join(data_path, "testing")
    images = sorted(glob.glob(os.path.join(testing_root, "image", "*.nii.gz")))
    masks = sorted(glob.glob(os.path.join(testing_root, "segmentation", "*_gt.nii.gz")))

    mask_by_image_name = {os.path.basename(m).replace("_gt", ""): m for m in masks}

    pairs = []
    for image in images:
        mask = mask_by_image_name.get(os.path.basename(image))
        if not mask or not os.path.exists(mask):
            continue
        pairs.append({"img": image, "mask": mask})
    return pairs

#%% Custom Transform to Load All Slices of Data
class LoadHeartData(monai.transforms.Transform):
    """Expand one ``{"img": path, "mask": path}`` record into one sample dict per slice.

    Both volumes are split along their last axis. Each returned dict holds the image slice
    (float32), the mask slice (uint8), separate meta dicts that both carry the *image*
    affine, and the source image path under ``'img_path'``.
    """

    def __call__(self, sample):
        image_path = sample['img']
        img_vol, img_affine = load_nii(image_path)
        mask_vol, _ = load_nii(sample['mask'])
        image_slices = np.moveaxis(img_vol, -1, 0)
        mask_slices = np.moveaxis(mask_vol, -1, 0)
        # Masks are indexed by image-slice position (not zipped), so a mask with fewer
        # slices than its image raises instead of being silently truncated.
        return [
            {
                'img': image_slice.astype(np.float32),
                'mask': mask_slices[i].astype(np.uint8),
                'img_meta_dict': {'affine': img_affine},
                'mask_meta_dict': {'affine': img_affine},
                'img_path': image_path,
            }
            for i, image_slice in enumerate(image_slices)
        ]

#%% Define test transforms
# Per slice: add a channel axis to image and mask, rescale image intensities, then resize
# both to 256x256 (bilinear for the image, nearest for the integer mask).
test_transforms = Compose(
    [
        LoadHeartData(),
        EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
        ScaleIntensityd(keys=['img']),
        Resized(
            keys=['img', 'mask'],
            spatial_size=(256, 256),
            mode=['bilinear', 'nearest'],
        ),
    ]
)

#%% Load and flatten the test dataset
test_dicts = build_test_dict(r'./database')
def flatten_dataset(dataset_list, transform):
    """Run *transform* on each record and collect every item it yields into a single list."""
    flat_list = []
    for data in dataset_list:
        for item in transform(data):
            flat_list.append(item)
    return flat_list

test_flat = flatten_dataset(test_dicts, test_transforms)

#%% Create test DataLoader
# One transformed 2D slice per item; shuffle=False keeps the slice order from flatten_dataset.
test_dataset = monai.data.Dataset(data=test_flat)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

#%% Define the EnsembleModel globally to avoid pickle issues
class EnsembleModel(torch.nn.Module):
    """Soft-voting ensemble over a list of segmentation networks.

    Defined at notebook top level so the pickled checkpoint can resolve the class on load.
    ``forward`` returns averaged class probabilities, not logits.
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        # Softmax each member's output over the class axis, then average across members
        member_probs = [torch.softmax(net(x), dim=1) for net in self.models]
        return torch.stack(member_probs, dim=0).mean(dim=0)

#%% Load the saved ensemble model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled EnsembleModel object (it is used directly as a module below),
# so unpickling needs the EnsembleModel class defined above and weights_only=False: PyTorch >= 2.6
# defaults to weights_only=True and would refuse to load it. Unpickling can execute arbitrary
# code, so only load checkpoints you trust.
ensemble_model = torch.load(ensemble_model_path, map_location=device, weights_only=False)
ensemble_model.to(device)
ensemble_model.eval()  # Ensure the model is in evaluation mode

#%% Inference and Save Segmentations as 3D Volumes
with torch.no_grad():
    # scan name -> list of (256, 256, 1) slice predictions, in slice order
    scan_segmentations = {}
    # scan name -> path of the source image those slices were taken from
    scan_image_paths = {}

    for sample in test_dataset:
        img_path = sample["img_path"]  # Get the image path directly from the sample
        img = sample["img"].to(device).unsqueeze(0)  # Add batch dimension

        # EnsembleModel.forward already returns softmax-averaged class probabilities, so the
        # argmax is taken directly; a second softmax would be redundant.
        output = ensemble_model(img)
        pred = torch.argmax(output, dim=1, keepdim=True)

        # (1, 1, H, W) -> (H, W) uint8 label map
        pred_np = pred.squeeze().cpu().numpy().astype(np.uint8)

        # Add a trailing slice axis so the slices can be concatenated into a volume
        if pred_np.ndim == 2:
            pred_np = np.expand_dims(pred_np, axis=-1)

        # Scan name = file name before the first '.'
        scan_name = os.path.basename(img_path).split('.')[0]
        scan_image_paths.setdefault(scan_name, img_path)
        scan_segmentations.setdefault(scan_name, []).append(pred_np)

# After processing all slices, save each 3D segmentation using the geometry of its own scan
for scan_name, segmentation_slices in scan_segmentations.items():
    # Stack the slice predictions along the last axis -> (256, 256, n_slices)
    segmentation_3d = np.concatenate(segmentation_slices, axis=-1)

    # Shape and affine come from this scan's own NIfTI header (voxel data is not needed).
    # Using the per-scan affine avoids stamping every output with the affine of whichever
    # slice happened to be processed last.
    img_path = scan_image_paths[scan_name]
    src_nii = nib.load(img_path)
    original_shape = src_nii.shape
    affine = src_nii.affine

    # Nearest-neighbour resize back to the original grid keeps the labels integral
    resize_transform = Resize(spatial_size=original_shape, mode='nearest')
    resized_segmentation = resize_transform(segmentation_3d[None])[0].astype(np.uint8)

    # Named like the ground-truth masks in testing/segmentation (the "_gt" suffix)
    output_filename = f"{scan_name}_gt.nii.gz"
    output_path = os.path.join(save_folder, output_filename)

    pred_nii = nib.Nifti1Image(resized_segmentation, affine)
    nib.save(pred_nii, output_path)

    print(f"Saved segmentation for {output_filename}")

Saved segmentation for patient101_frame01_gt.nii.gz
Saved segmentation for patient101_frame14_gt.nii.gz
Saved segmentation for patient102_frame01_gt.nii.gz
Saved segmentation for patient102_frame13_gt.nii.gz
Saved segmentation for patient103_frame01_gt.nii.gz
Saved segmentation for patient103_frame11_gt.nii.gz
Saved segmentation for patient104_frame01_gt.nii.gz
Saved segmentation for patient104_frame11_gt.nii.gz
Saved segmentation for patient105_frame01_gt.nii.gz
Saved segmentation for patient105_frame10_gt.nii.gz
Saved segmentation for patient106_frame01_gt.nii.gz
Saved segmentation for patient106_frame13_gt.nii.gz
Saved segmentation for patient107_frame01_gt.nii.gz
Saved segmentation for patient107_frame10_gt.nii.gz
Saved segmentation for patient108_frame01_gt.nii.gz
Saved segmentation for patient108_frame09_gt.nii.gz
Saved segmentation for patient109_frame01_gt.nii.gz
Saved segmentation for patient109_frame10_gt.nii.gz
Saved segmentation for patient110_frame01_gt.nii.gz
Saved segmen