In [4]:
import os
import SimpleITK as sitk
import numpy as np
from pathlib import Path
import pandas as pd
from scipy import ndimage
from skimage import measure

In [5]:
def create_binary_mask(cmr_volume, myocardium_prediction, threshold_value, min_expansion,
                       closing_radius=6, min_slice_area=800):
    """Build a filled, per-slice region mask from a myocardium prediction volume.

    Each slice along the third SimpleITK axis of ``myocardium_prediction`` is processed
    independently:

    1. Threshold at ``threshold_value`` (``sitk.BinaryThreshold`` with its default upper
       threshold) into a 0/1 slice.
    2. Apply a binary morphological closing with radius ``closing_radius`` to fill holes
       inside the region.
    3. If the closed slice has at least ``min_expansion`` foreground pixels, rasterise the
       first contour returned by ``skimage.measure.find_contours`` and fill its interior.
    4. Dilate the filled region once with a 3x3 (8-connected) structuring element.
    5. Keep the slice (hole-filled again) only if it has more than ``min_slice_area``
       pixels; otherwise the slice is left all-zero.

    Parameters
    ----------
    cmr_volume : SimpleITK.Image
        Image whose origin/spacing/direction are copied onto the result; it must have the
        same size as ``myocardium_prediction`` (required by ``CopyInformation``).
    myocardium_prediction : SimpleITK.Image
        3D prediction volume to threshold (presumably a probability map -- TODO confirm).
    threshold_value : float
        Lower threshold used in step 1.
    min_expansion : int
        Minimum number of foreground pixels after closing for a slice's contour to be used.
    closing_radius : int, optional
        In-plane kernel radius (pixels) of the closing in step 2. Default 6.
    min_slice_area : int, optional
        A slice is kept only if its dilated region has strictly more pixels than this.
        Default 800.

    Returns
    -------
    SimpleITK.Image
        0/1 mask whose pixel type follows the array dtype of ``myocardium_prediction`` and
        whose geometry is copied from ``cmr_volume``.
    """
    prediction_array = sitk.GetArrayFromImage(myocardium_prediction)
    # Same (z, y, x) shape and dtype as the prediction; slices not kept stay zero.
    binary_mask = np.zeros_like(prediction_array)

    # Loop-invariant objects, built once instead of once per slice.
    structure = ndimage.generate_binary_structure(2, 2)  # full 3x3 footprint
    closing_filter = sitk.BinaryMorphologicalClosingImageFilter()
    closing_filter.SetKernelRadius([closing_radius, closing_radius])

    width = myocardium_prediction.GetWidth()
    height = myocardium_prediction.GetHeight()
    for slice_index in range(myocardium_prediction.GetDepth()):
        # A size of 0 along z collapses the extracted region into a 2D image.
        prediction_slice = sitk.Extract(myocardium_prediction, (width, height, 0), (0, 0, slice_index))
        thresholded_slice = sitk.BinaryThreshold(prediction_slice, lowerThreshold=threshold_value,
                                                 insideValue=1, outsideValue=0)
        closed_array = sitk.GetArrayFromImage(closing_filter.Execute(thresholded_slice))

        contour_mask = np.zeros_like(closed_array)
        if np.count_nonzero(closed_array == 1) >= min_expansion:
            contours = measure.find_contours(closed_array)
            if contours:
                # NOTE(review): only the first contour is used; find_contours does not
                # promise that the first one is the largest/outer boundary.
                # Sub-pixel coordinates are truncated to pixel indices (as int() did).
                rows, cols = contours[0].astype(int).T
                contour_mask[rows, cols] = 1

        # Fill the traced outline, then grow it by one pixel in every direction.
        filled = ndimage.binary_fill_holes(contour_mask)
        dilated = ndimage.binary_dilation(filled, iterations=1, structure=structure)

        # Slices whose region is too small are discarded (left at zero).
        if np.count_nonzero(dilated) > min_slice_area:
            binary_mask[slice_index] = ndimage.binary_fill_holes(dilated)

    binary_mask_image = sitk.GetImageFromArray(binary_mask)
    binary_mask_image.CopyInformation(cmr_volume)
    return binary_mask_image

In [6]:
def keep_non_zero_slices(original_3d_image, myoseg, binary_mask_image, min_slices=5):
    """Crop an image, its myocardium prediction and its mask to the mask's non-empty slices.

    A slice index ``z`` (third SimpleITK axis) is kept when ``binary_mask_image`` has any
    non-zero pixel at that index. The kept 2D slices of all three inputs are re-stacked
    with ``sitk.JoinSeries``.

    Parameters
    ----------
    original_3d_image : SimpleITK.Image
        Reference volume; its depth drives the slice selection and its direction and
        spacing are applied to all three outputs.
    myoseg : SimpleITK.Image
        Volume cropped alongside the image (expected to share its slice layout).
    binary_mask_image : SimpleITK.Image
        Mask deciding which slices are kept (e.g. the output of ``create_binary_mask``).
    min_slices : int, optional
        Minimum number of non-empty slices required to return volumes. Default 5.

    Returns
    -------
    tuple
        ``(image_3d, mask_3d, myoseg_3d)``, or ``(None, None, None)`` when fewer than
        ``min_slices`` slices are non-empty.

    Notes
    -----
    The output origin is the physical position of the first kept slice, so dropping
    leading slices does not shift the volume in physical space. If the kept slices are
    not contiguous, the gaps are collapsed and the z spacing no longer matches the true
    distance between them.
    """
    mask_array = sitk.GetArrayFromImage(binary_mask_image)
    kept_indices = [z for z in range(original_3d_image.GetDepth()) if np.any(mask_array[z] != 0)]
    # max(..., 1): JoinSeries needs at least one slice even if min_slices <= 0.
    if len(kept_indices) < max(min_slices, 1):
        return None, None, None

    # Physical position of the first kept slice; reusing the original origin would place
    # the cropped stack as if it still started at slice 0.
    origin = original_3d_image.TransformIndexToPhysicalPoint((0, 0, kept_indices[0]))
    direction = original_3d_image.GetDirection()
    spacing = original_3d_image.GetSpacing()

    def _stack_kept_slices(volume):
        # Extract the kept 2D slices of `volume` and re-stack them with shared geometry.
        size = (volume.GetWidth(), volume.GetHeight(), 0)  # size 0 along z -> 2D slice
        stacked = sitk.JoinSeries([sitk.Extract(volume, size, (0, 0, z)) for z in kept_indices])
        stacked.SetDirection(direction)
        stacked.SetOrigin(origin)
        stacked.SetSpacing(spacing)
        return stacked

    return (_stack_kept_slices(original_3d_image),
            _stack_kept_slices(binary_mask_image),
            _stack_kept_slices(myoseg))

In [7]:
# Crop every patient's LGE image and myocardium prediction to the slices covered by the
# generated mask, and write the mask as MASK_<id>.nii.gz.
# WARNING: LGE_<id>.nii.gz and MYO_<id>.nrrd are OVERWRITTEN IN PLACE with the cropped
# volumes -- keep a copy of the raw data before running this cell.
DATA_ROOT = Path('../LGE_clin_myo_mask_1_iter_highertresh')
THRESHOLD = 0.05      # lower threshold applied to the myocardium prediction
MIN_EXPANSION = 150   # min. closed-mask pixel count before a slice's contour is used

skipped_patients = []
# Only descend into directories; stray files (e.g. .DS_Store) would otherwise crash listing.
for data_set_dir in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    for patient_dir in sorted(p for p in data_set_dir.iterdir() if p.is_dir()):
        patient_id = patient_dir.name.split("_")[0]
        image_path = patient_dir / f"LGE_{patient_id}.nii.gz"
        myoseg_path = patient_dir / f"MYO_{patient_id}.nrrd"
        mask_path = patient_dir / f"MASK_{patient_id}.nii.gz"

        img = sitk.ReadImage(str(image_path))
        myoseg = sitk.ReadImage(str(myoseg_path))

        binary_mask_image = create_binary_mask(img, myoseg, THRESHOLD, MIN_EXPANSION)
        cropped_image, cropped_mask, cropped_myoseg = keep_non_zero_slices(img, myoseg, binary_mask_image)
        if cropped_image is None:
            # Too few non-empty mask slices: leave this patient's files untouched.
            skipped_patients.append(patient_dir.name)
            continue

        sitk.WriteImage(cropped_image, str(image_path))
        sitk.WriteImage(cropped_mask, str(mask_path))
        sitk.WriteImage(cropped_myoseg, str(myoseg_path))

print(f"Skipped {len(skipped_patients)} patient(s) with too few mask slices: {skipped_patients}")

In [5]:
""""
    prev method!
""""
# import os
# import SimpleITK as sitk
# import numpy as np
# import mclahe as mc
# from pathlib import Path
# import pandas as pd

# threshold_value = 0.05
# kernel_size=(10, 25, 25)

# # def discard_slices(volume, mask, pixel_threshold):
# #     # Calculate the sum of non-masked pixels for each slice
# #     non_masked_pixel_sums = np.sum(volume * (1 - mask), axis=(1, 2))

# #     # Identify slices that meet the pixel threshold
# #     selected_slices = np.where(non_masked_pixel_sums >= pixel_threshold)[0]

# #     # Discard slices that do not meet the threshold
# #     filtered_volume = volume[selected_slices]

# #     return filtered_volume

# path = '../LGE_CINE_clin_myo_masked'
# for data_set in os.listdir(path):
#     for patient in os.listdir(os.path.join(path, data_set)):
#         # for patient in os.listdir(os.path.join(path, data_set)):
#         id = patient.split("_")[0]
#         src_patient_folder = Path(os.path.join(path, data_set, patient))
#         image_path = os.path.join(src_patient_folder, f"LGE_{id}.nii.gz")
#         myoseg_f = os.path.join(src_patient_folder, f"MYO_{id}.nrrd")
#         img = sitk.ReadImage(image_path)
#         myoseg = sitk.ReadImage(str(myoseg_f))

#         binary_mask = sitk.BinaryThreshold(myoseg, lowerThreshold=threshold_value, upperThreshold=1.0, insideValue=1, outsideValue=0)

#         # Create a SimpleITK structuring element
#         structuring_element = sitk.sitkBox
#         closing_filter = sitk.BinaryMorphologicalClosingImageFilter()
#         closing_filter.SetKernelType(structuring_element)
#         closing_filter.SetKernelRadius(kernel_size)
        
#         # Perform 3D morphological closing on the integer image
#         result_stack = closing_filter.Execute(binary_mask)

#         # # Convert the result back to the original floating-point type if needed
#         # result_stack.SetSpacing(img.GetSpacing())
#         # result_stack.SetOrigin(img.GetOrigin())
#         # result_stack.SetDirection(img.GetDirection())
#         result_stack.CopyInformation(img)

#         masked_image = sitk.Mask(img, result_stack)
#         # # Set the pixel threshold
#         # pixel_threshold = 1000  # Adjust as needed

#         # # Discard slices based on the pixel threshold
#         # filtered_volume = discard_slices(your_volume_data, your_mask_data, pixel_threshold)

#         # # Print the number of remaining slices
#         # print(f"Number of remaining slices: {filtered_volume.shape[0]}")

#         # Save preprocessed image as NIfTI
#         # output_path_image = os.path.join(src_patient_folder, f'LGE_{id}.nii.gz')
#         output_path_mask = os.path.join(src_patient_folder, f'MASK_{id}.nii.gz')
#         # sitk.WriteImage(masked_image, output_path_image)
#         sitk.WriteImage(result_stack, output_path_mask)