## Testing of the TotalSegmentator

In [1]:
#Imports
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import os
import ipywidgets as widgets
from IPython.display import display
import numpy as np
from ipywidgets import interact, HBox, VBox, widgets
import cv2
from scipy.ndimage import binary_fill_holes, label, center_of_mass

### Segmentation with TotalSegmentator

In [11]:
# Names of the structures to segment, one per line; passed to TotalSegmentator
# below as `roi_subset`.
with open("Labels.txt", 'r', encoding="utf-8") as file:
    # Strip trailing whitespace/newlines and skip blank lines (e.g. a trailing
    # empty line): an empty string is presumably not a valid ROI name.
    labels = [line.rstrip() for line in file if line.strip()]
print(labels)

In [13]:
# All paths are relative to the notebook's working directory.
cwd = os.getcwd()
Segmentation_file_path = os.path.join(cwd, "Data", "s0011", "Segmentation_output.nii.gz")
input_file_path = os.path.join(cwd, "Data", "s0011", "ct.nii.gz")

# Fail early with a clear message instead of an opaque nibabel error.
if not os.path.exists(input_file_path):
    raise FileNotFoundError(f"Input CT not found: {input_file_path}")

# The output file is only written by nib.save below, so a failed run does not
# leave an empty .nii.gz behind for read_data to pick up later.
input_img = nib.load(input_file_path)
Segmentation_output = totalsegmentator(input_img, roi_subset=labels)
nib.save(Segmentation_output, Segmentation_file_path)
del Segmentation_output, input_img

100%|██████████| 12/12 [00:19<00:00,  1.64s/it]
100%|██████████| 4/4 [00:53<00:00, 13.27s/it]


# Slice rendering of the segmentation
An interactive slice-by-slice view of the segmentation, which is easier to use.

In [2]:
def crop_image(out_im):
    """Crop a 3D volume to the bounding box of its non-zero voxels.

    Parameters
    ----------
    out_im : np.ndarray
        3D array; every non-zero voxel counts as foreground.

    Returns
    -------
    filtered : np.ndarray
        A copy of ``out_im[s0:e0, s1:e1, s2:e2]``, spanning the first to the
        last non-zero index along each axis. Empty slices *inside* the box are
        kept, so the crop can be pasted back at the returned offsets without
        misalignment.
    sagital_start, frontal_start, transversal_start : int
        Offsets of the crop along axes 0, 1 and 2 (used for image
        reconstruction).

    Raises
    ------
    ValueError
        If ``out_im`` contains no non-zero voxel.
    """
    foreground = np.asarray(out_im) != 0
    if not foreground.any():
        raise ValueError("crop_image: volume contains no non-zero voxels")

    bounds = []
    for axis in range(3):
        # Collapse the other two axes to see which slices along `axis` hold data.
        other_axes = tuple(a for a in range(3) if a != axis)
        occupied = np.flatnonzero(foreground.any(axis=other_axes))
        bounds.append((int(occupied[0]), int(occupied[-1]) + 1))

    (sagital_start, sagital_end), (frontal_start, frontal_end), (transversal_start, transversal_end) = bounds
    # Copy so callers (e.g. cut_body) can edit the crop without touching the input.
    filtered = out_im[sagital_start:sagital_end, frontal_start:frontal_end, transversal_start:transversal_end].copy()

    return filtered, sagital_start, frontal_start, transversal_start

In [36]:
def read_data(path = "/Data/s0011/Segmentation_output.nii.gz"):
    """Load a NIfTI volume located relative to the current working directory.

    Parameters
    ----------
    path : str
        Path appended verbatim to ``os.getcwd()``; it must therefore start with
        a separator (``os.path.join`` would discard the cwd for such a path).

    Returns
    -------
    np.ndarray
        The image data read fully into memory.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    """
    full_path = os.getcwd() + path
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"No file found: {full_path}")

    output_img_new = nib.load(full_path)
    # Convert to NumPy array
    return np.array(output_img_new.dataobj)

In [37]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import SimpleITK as sitk

# Reset matplotlib, since TotalSegmentator changes the active backend.
# `ipykernel.pylab.backend_inline` is deprecated (ipykernel 6); the inline
# backend now lives in the `matplotlib_inline` package.
plt.close('all')
plt.switch_backend('module://matplotlib_inline.backend_inline')

output_img_data = read_data()

filtered_data, _, _, _ = crop_image(output_img_data)

print(f"Original number of slices: {output_img_data.shape[0]}")
print(f"Number of slices after filtering: {filtered_data.shape[0]}")


def plot_sagittal(slice_index):
    """Show one sagittal slice (index along axis 0) of the cropped segmentation."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(filtered_data[slice_index, :, :], cmap='gray')
    ax.set_title(f"Sagittal Slice {slice_index}")
    ax.axis('off')
    plt.show()
    plt.close(fig)  # interact redraws on every slider move; free each figure


# Interactive slider over the sagittal slices; `;` hides the returned function's repr.
interact(plot_sagittal, slice_index=(0, filtered_data.shape[0] - 1));


Original number of slices: 311
Number of slices after filtering: 59


interactive(children=(IntSlider(value=29, description='slice_index', max=58), Output()), _dom_classes=('widget…

<function __main__.plot_sagittal(slice_index)>

# Vertebral body extraction

In [20]:

def plot_segmentation(vis_arr, value=-1):
    """Display an interactive three-view slice viewer for a 3D label volume.

    One slider plus image column is shown per array axis: axis 0 ("Sagittal"),
    axis 1 ("Frontal") and axis 2 ("Transversal"). The axis-to-anatomy naming is
    this notebook's convention; confirm it against the image orientation.

    Parameters
    ----------
    vis_arr : np.ndarray
        3D volume to display (rendered with a gray colormap).
    value : int, optional
        Label whose voxel count per slice ("area") is shown in each plot title.
        With the default ``-1``, every non-zero voxel is counted instead.

    Each slider initially points at the slice with the largest counted area.
    """

    def make_slice_plotter(axis, area_per_slice):
        # Factory so each callback captures its own axis and area array
        # (a closure defined directly in the loop below would bind late).
        def plot_slice(slice_index):
            fig = plt.figure(figsize=(8, 8))
            plt.imshow(np.take(vis_arr, slice_index, axis=axis), cmap='gray')
            plt.title(f"Area of Slice {area_per_slice[slice_index]}")
            plt.axis('off')
            plt.show()
            plt.close(fig)  # Close the figure to avoid warning and memory issues

        return plot_slice

    # Voxels that contribute to the per-slice area.
    if value != -1:
        counted = vis_arr == value
    else:
        # No specific label requested: count every labelled (non-zero) voxel.
        counted = vis_arr != 0

    columns = []
    for description, axis in (("Sagittal", 0), ("Frontal", 1), ("Transversal", 2)):
        other_axes = tuple(a for a in range(3) if a != axis)
        area_per_slice = np.sum(counted, axis=other_axes)
        slider = widgets.IntSlider(
            min=0,
            max=vis_arr.shape[axis] - 1,
            value=int(np.argmax(area_per_slice)),
            description=description,
        )
        plot = widgets.interactive_output(
            make_slice_plotter(axis, area_per_slice), {'slice_index': slider}
        )
        columns.append(VBox([slider, plot]))

    # Arrange the three views side by side.
    display(HBox(columns))
    

In [45]:
def hole_find(binary_image_arg, invert = False, plot=False):
    """Find connected regions in a 2D mask and return their centres of mass.

    Parameters
    ----------
    binary_image_arg : array_like
        2D mask; any non-zero pixel is treated as foreground.
    invert : bool
        If True, select the enclosed holes (background pixels filled by
        ``binary_fill_holes``). If False, select the foreground itself.
    plot : bool
        If True, draw the selected regions with red dots at their centres and
        print a short summary.

    Returns
    -------
    list of tuple
        One ``(row, col)`` centre of mass per connected region, in
        ``scipy.ndimage.label`` order. Empty when there is no region.
    """
    # Work on a boolean mask: `~` on an integer array is a bitwise NOT
    # (e.g. ~2 == 253 for uint8), which would not select the background.
    mask = np.asarray(binary_image_arg) != 0
    filled = binary_fill_holes(mask)
    if invert:
        holes = filled & ~mask
    else:
        holes = filled & mask

    # Label connected hole regions
    labeled_holes, num_holes = label(holes)

    # Find hole centers
    hole_centers = center_of_mass(holes, labeled_holes, range(1, num_holes + 1))

    if plot:
        fig, ax = plt.subplots()
        ax.imshow(holes, cmap='gray')
        ax.set_title("Binary Image with Holes")

        # Mark hole centers
        for y, x in hole_centers:
            ax.plot(x, y, 'ro', markersize=6)  # Red dot at the hole center

        plt.show()
        plt.close(fig)

        print(f"Number of holes found: {num_holes}")
        print(f"Hole centers: {hole_centers}")

    return hole_centers

In [43]:
def cut_body(new_arr, value, plot=False):
    """Crop a single-label volume and zero everything before its first hole.

    The volume is cropped with ``crop_image``. In the middle slice along axis 2
    of the crop, the first enclosed hole of label ``value`` is located. All
    voxels whose axis-1 index is below that hole's centre are then set to 0.
    Presumably the hole is the vertebral foramen, so the cut keeps the part on
    one side of it; confirm the direction against the image orientation.

    Parameters
    ----------
    new_arr : np.ndarray
        3D volume containing the label to cut.
    value : int
        Label value to search for holes in.
    plot : bool
        Forwarded to ``hole_find`` to visualise the detected holes.

    Returns
    -------
    tuple
        ``(cut_crop, sagital_start, frontal_start, transversal_start)``; the
        offsets locate the crop inside ``new_arr`` (see ``restore_array``).

    Raises
    ------
    ValueError
        If no enclosed hole is found in the middle transversal slice.
    """
    # Crop the image for the vertebra
    temp, sagital_start, frontal_start, transversal_start = crop_image(new_arr)

    # 0/1 mask of the requested label, taken at the middle transversal slice.
    binary_image = (temp == value).astype(np.uint8)
    mid_transversal = temp.shape[2] // 2
    binary_image_trans = binary_image[:, :, mid_transversal]

    # Hole centres in this slice are (sagittal, frontal) = (axis 0, axis 1) coordinates.
    transversal_holes = hole_find(binary_image_trans, True, plot)
    if len(transversal_holes) == 0:
        raise ValueError(
            f"cut_body: no enclosed hole found for label {value} "
            f"in transversal slice {mid_transversal}"
        )

    cut_index = int(np.floor(transversal_holes[0][1]))
    temp[:, :cut_index, :] = 0
    return temp, sagital_start, frontal_start, transversal_start

def restore_array(original_array, cut_array, sagital_start, frontal_start, transversal_start, value):
    """Paste an edited crop back into a copy of the full volume.

    Inside the crop's bounding box (given by the three start offsets and
    ``cut_array.shape``), each voxel of ``original_array`` equal to ``value`` is
    replaced by the corresponding voxel of ``cut_array``. All other voxels,
    including other labels, are kept. ``original_array`` itself is not modified.
    """
    offsets = (sagital_start, frontal_start, transversal_start)
    window = tuple(
        slice(start, start + extent) for start, extent in zip(offsets, cut_array.shape)
    )

    restored = np.copy(original_array)
    region = restored[window]  # basic slicing: a view that writes through to `restored`
    target = region == value
    region[target] = cut_array[target]
    return restored


In [47]:
# Cut a single vertebra (label 32) and show the crop and the restored full volume.
output_img_data = read_data()
value = 32
label_mask = output_img_data == value
label_volume = label_mask.astype(np.uint8) * value

cut_vertebra, sag_offset, front_offset, trans_offset = cut_body(label_volume, value, plot=True)
plot_segmentation(cut_vertebra, value)

restored_volume = restore_array(output_img_data, cut_vertebra, sag_offset, front_offset, trans_offset, value)
plot_segmentation(restored_volume, value)

HBox(children=(VBox(children=(IntSlider(value=24, description='Sagittal', max=34), Output())), VBox(children=(…

HBox(children=(VBox(children=(IntSlider(value=161, description='Sagittal', max=310), Output())), VBox(children…

In [None]:
# Cut every labelled structure in turn, writing each result back into the full volume.
output_img_data = read_data()
values = np.unique(output_img_data)
# Drop only the background label; `values[1:]` would also drop a real label
# if the volume happened to contain no 0 voxels.
values = values[values != 0]
for value in values:
    print(value)
    label_volume = (output_img_data == value).astype(np.uint8) * value
    cut_vertebra, sagital_start, frontal_start, transversal_start = cut_body(label_volume, value, plot=True)
    output_img_data = restore_array(output_img_data, cut_vertebra, sagital_start, frontal_start, transversal_start, value)
plot_segmentation(output_img_data)

[26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44]


### Save results

In [106]:
# Save the processed segmentation (`output_img_data` from the cells above).
# Reuse the affine of the TotalSegmentator output, presumably the CT geometry,
# so the result stays aligned with the scan; an identity affine would discard
# the voxel spacing and orientation.
reference_img = nib.load(os.path.join(os.getcwd(), "Data", "s0011", "Segmentation_output.nii.gz"))
nii_img = nib.Nifti1Image(output_img_data, affine=reference_img.affine)

# Save the NIfTI file; os.path.join keeps the path portable across operating systems.
nib.save(nii_img, os.path.join(os.getcwd(), "Data", "s0011", "Cropped_segmentation.nii"))