In [None]:
import SimpleITK as sitk
import itk
import itkwidgets 
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import ipywidgets as widgets
import cv2
from tqdm.notebook import tqdm

In [None]:
class MultiImageDisplay:
    """
    Notebook GUI for displaying one or more 3D images in a single figure.

    The image slices are selected according to the axis specified by the
    user. Each image can have a title and a slider to scroll through the
    stack. The images can also share a single slider if they have the same
    number of slices along the given axis. Images are either grayscale or
    color. The intensity range used for display (window-level) can be
    specified by the user as input to the constructor or set via the
    displayed slider. For color images the intensity control slider is
    disabled, which lets color and grayscale images share one figure with a
    consistent look to the controls. The range of the intensity slider spans
    the 2nd to 98th intensity percentile (accommodating outliers). Images are
    laid out either horizontally or vertically, depending on the user's
    choice.
    """

    def __init__(
        self,
        image_list,
        axis=0,
        shared_slider=False,
        title_list=None,
        window_level_list=None,
        figure_size=(10, 8),
        horizontal=True,
    ):
        """
        Build the controls, create the figure and draw the initial slices.

        Args:
            image_list: Images to display; each one is converted to a numpy
                array with sitk.GetArrayFromImage.
            axis: Index of the array axis that the slice sliders scroll
                through.
            shared_slider: If True, a single slice slider drives all images.
            title_list: Optional titles, one per image. A falsy value gives
                every image an empty title.
            window_level_list: Optional per-image (window, level) pairs used
                for the initial intensity range. Falsy entries fall back to
                the percentile range; entries for color images are ignored.
            figure_size: Passed to plt.subplots as figsize.
            horizontal: If True the images are placed in one row, otherwise
                in one column.

        Raises:
            ValueError: If title_list and image_list differ in length, or if
                shared_slider is set and the images differ in size along
                the chosen axis.
        """
        self.npa_list, wl_range, wl_init = self.get_window_level_numpy_array(
            image_list, window_level_list
        )
        if title_list:
            if len(image_list) != len(title_list):
                raise ValueError("Title list and image list lengths do not match")
            self.title_list = list(title_list)
        else:
            self.title_list = [""] * len(image_list)

        # Our dynamic slice, based on the axis the user specifies.
        self.slc = [slice(None)] * 3
        self.axis = axis

        # `display` is not imported in this file; it is expected to be
        # available as a global, as it is inside an IPython/Jupyter session.
        ui = self.create_ui(shared_slider, wl_range, wl_init)
        display(ui)

        # Create a figure with one axes per image.
        col_num, row_num = (len(image_list), 1) if horizontal else (1, len(image_list))
        self.fig, self.axes = plt.subplots(row_num, col_num, figsize=figure_size)
        if len(image_list) == 1:
            # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so the
            # rest of the class can always iterate.
            self.axes = [self.axes]

        # The first rendering happens outside "update_display" because that
        # method restores the previous zoom factor, which doesn't exist yet.
        for ax, npa, slider, wl_slider in zip(
            self.axes, self.npa_list, self.slider_list, self.wl_list
        ):
            self._show_slice(ax, npa, slider, wl_slider)
        self.update_display()
        plt.tight_layout()

    def _make_slice_slider(self, num_slices):
        """Create a slice-selection slider for a stack of num_slices slices."""
        slider = widgets.IntSlider(
            description="image slice:",
            min=0,
            max=num_slices - 1,
            step=1,
            value=int((num_slices - 1) / 2),
            width="20em",
        )
        slider.observe(self.on_slice_slider_value_change, names="value")
        return slider

    def _show_slice(self, ax, npa, slider, wl_slider):
        """Draw the slice selected by slider on ax, using wl_slider's range."""
        self.slc[self.axis] = slice(slider.value, slider.value + 1)
        # Need to use squeeze to collapse the degenerate dimension
        # (e.g. RGB image size 124 124 1 3).
        ax.imshow(
            np.squeeze(npa[tuple(self.slc)]),
            cmap=plt.cm.Greys_r,
            vmin=wl_slider.value[0],
            vmax=wl_slider.value[1],
        )

    @staticmethod
    def _robust_intensity_range(npa):
        """
        Return the (2nd, 98th) percentile of all values in npa.

        np.percentile works on the flattened array by default, so the array
        is passed as is instead of making an explicit flattened copy first.
        """
        low, high = np.percentile(npa, [2, 98])
        return (low, high)

    def create_ui(self, shared_slider, wl_range, wl_init):
        """
        Create the slice and intensity sliders and return their container.

        Populates self.slider_list and self.wl_list with one entry per image
        (the same slider object repeated when shared_slider is True).

        Args:
            shared_slider: If True, create one slice slider for all images.
            wl_range: Per-image (min, max) bounds for the intensity sliders.
            wl_init: Per-image (low, high) initial intensity slider values.

        Returns:
            A widgets.VBox holding the slice sliders above the intensity
            sliders.

        Raises:
            ValueError: If shared_slider is True and the images do not all
                have the same size along self.axis.
        """
        # Height and width are specified in 'em' units. This is a html size
        # specification, size relative to current font size.
        if shared_slider:
            # Validate that all the images have the same size along the axis
            # which we scroll through.
            sz = self.npa_list[0].shape[self.axis]
            if any(npa.shape[self.axis] != sz for npa in self.npa_list):
                raise ValueError(
                    "Not all images have the same size along the specified axis, cannot share slider."
                )

            slider = self._make_slice_slider(sz)
            self.slider_list = [slider] * len(self.npa_list)
            slicer_box = widgets.Box(padding=7, children=[slider])
        else:
            self.slider_list = [
                self._make_slice_slider(npa.shape[self.axis])
                for npa in self.npa_list
            ]
            slicer_box = widgets.Box(padding=7, children=self.slider_list)

        self.wl_list = []
        # Each image has a window-level slider, but it is disabled if the
        # image is a color image, len(npa.shape)==4. This allows us to display
        # both color and grayscale images in the same UI while retaining a
        # reasonable layout for the sliders.
        for r_values, i_values, npa in zip(wl_range, wl_init, self.npa_list):
            wl_range_slider = widgets.IntRangeSlider(
                description="intensity:",
                min=r_values[0],
                max=r_values[1],
                step=1,
                value=[i_values[0], i_values[1]],
                width="20em",
                disabled=len(npa.shape) == 4,
            )
            wl_range_slider.observe(self.on_wl_slider_value_change, names="value")
            self.wl_list.append(wl_range_slider)
        wl_box = widgets.Box(padding=7, children=self.wl_list)
        return widgets.VBox(children=[slicer_box, wl_box])

    def get_window_level_numpy_array(self, image_list, window_level_list):
        """
        Convert the images to numpy arrays and compute intensity ranges.

        Args:
            image_list: Images to convert with sitk.GetArrayFromImage.
            window_level_list: Optional per-image (window, level) pairs, or
                a falsy value to use the percentile range for every image.

        Returns:
            A tuple (npa_list, wl_range, wl_init): the numpy arrays, the
            per-image slider bounds, and the per-image initial slider values.
        """
        # Using GetArray and not GetArrayView because we don't keep references
        # to the original images. If they are deleted outside, the view would
        # become invalid, so we use a copy which guarantees that the gui is
        # consistent.
        npa_list = list(map(sitk.GetArrayFromImage, image_list))

        wl_range = []
        wl_init = []
        # We need to iterate over the images because they can be a mix of
        # grayscale and color images. If they are color we set the wl_range
        # to [0,255] and the wl_init is equal, ignoring the window_level_list
        # entry.
        for i, npa in enumerate(npa_list):
            if len(npa.shape) == 4:  # color image
                wl_range.append((0, 255))
                wl_init.append((0, 255))
                continue

            # We don't take the minimum/maximum values, just in case there
            # are outliers (top/bottom 2%).
            wl_range.append(self._robust_intensity_range(npa))
            wl = window_level_list[i] if window_level_list else None
            if wl:
                # wl is (window, level): the range is level -/+ window / 2.
                wl_init.append((wl[1] - wl[0] / 2.0, wl[1] + wl[0] / 2.0))
            else:
                wl_init.append(wl_range[-1])
        return (npa_list, wl_range, wl_init)

    def on_slice_slider_value_change(self, change):
        """Redraw all images when a slice slider value changes."""
        self.update_display()

    def on_wl_slider_value_change(self, change):
        """Redraw all images when an intensity slider value changes."""
        self.update_display()

    def update_display(self):
        """Redraw every image at its current slice and intensity range."""
        for ax, npa, title, slider, wl_slider in zip(
            self.axes, self.npa_list, self.title_list, self.slider_list, self.wl_list
        ):
            # We want to keep the zoom factor which was set prior to display,
            # so we log it before clearing the axes.
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()

            ax.clear()
            self._show_slice(ax, npa, slider, wl_slider)
            ax.set_title(title)
            ax.set_axis_off()

            # Set the zoom factor back to what it was before we cleared the
            # axes and rendered our data.
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

        self.fig.canvas.draw_idle()

In [None]:
# Read the MR volume from disk and expose its pixels as a numpy view.
mr_image = sitk.ReadImage(Path("data/anonymise_patient1.mhd"))
npa = sitk.GetArrayViewFromImage(mr_image)

# Take the slice from the middle of the stack (z axis, first array index).
z = mr_image.GetDepth() // 2
npa_zslice = npa[z, :, :]

# Kept as the last expression so the notebook shows the viewer as cell output.
itkwidgets.view(mr_image)

In [None]:
# Print the geometry and pixel metadata of the image under inspection.
selected_image = mr_image
print("Before modification:")
print(f"origin: {selected_image.GetOrigin()}")
print(f"size: {selected_image.GetSize()}")
print(f"spacing: {selected_image.GetSpacing()}")
print(f"direction: {selected_image.GetDirection()}")
print(f"pixel type: {selected_image.GetPixelIDTypeAsString()}")
print(
    f"number of pixel components: {selected_image.GetNumberOfComponentsPerPixel()}"
)


In [None]:
# Plot the intensity histogram of the whole volume using 255 bins.
image_array = sitk.GetArrayFromImage(mr_image)
histogram, bin_edges = np.histogram(image_array.ravel(), bins=255)

# np.histogram returns one more edge than counts; plot each count at its left edge.
plt.plot(bin_edges[:-1], histogram, color="blue")
plt.title("Histogramme de l'image médicale")
plt.xlabel("Valeur des pixels")
plt.ylabel("Fréquence")
plt.show()


In [None]:

# Intensity-windowing settings: input window [150, 1000], output range [0, 255].
window_min, window_max = 150, 1000
output_min, output_max = 0, 255

thresholdFilter = sitk.IntensityWindowingImageFilter()
thresholdFilter.SetWindowMinimum(window_min)
thresholdFilter.SetWindowMaximum(window_max)
thresholdFilter.SetOutputMinimum(output_min)
thresholdFilter.SetOutputMaximum(output_max)

# NOTE(review): this result is later fed to Binary*ByReconstruction filters;
# confirm that a windowed (non-binary) image is the intended input there.
resultTreshold = thresholdFilter.Execute(mr_image)

#itkwidgets.view(resultTreshold)


In [None]:
# thresholdMaxCompFilter = sitk.ThresholdMaximumConnectedComponentsImageFilter()
# thresholdMaxCompFilter.SetMinimumObjectSizeInPixels(1500000)
# thresholdMaxCompFilter.SetInsideValue(255)
# thresholdMaxCompFilter.SetOutsideValue(0)
# resultThresholdMaxComp = thresholdMaxCompFilter.Execute(resultTreshold)
#itkwidgets.view(resultThresholdMaxComp)

In [None]:
# Convert the SimpleITK image to an itk image: copy the pixels first, then
# carry over origin, spacing and direction.
windowed_array = sitk.GetArrayFromImage(resultTreshold)
has_multiple_components = resultTreshold.GetNumberOfComponentsPerPixel() > 1
itk_image = itk.GetImageFromArray(windowed_array, is_vector=has_multiple_components)

itk_image.SetOrigin(resultTreshold.GetOrigin())
itk_image.SetSpacing(resultTreshold.GetSpacing())

# GetDirection() is reshaped from a flat sequence into a 3x3 matrix.
direction_matrix = np.array(resultTreshold.GetDirection()).reshape(3, 3)
itk_image.SetDirection(itk.GetMatrixFromArray(direction_matrix))


In [None]:
# Second argument shared by all three morphological passes
# (presumably the kernel radius per axis - confirm against the SimpleITK docs).
morph_kernel = [10, 10, 10]

# One opening-by-reconstruction followed by two closing-by-reconstruction passes.
cleaned_thresh_img = sitk.BinaryOpeningByReconstruction(resultTreshold, morph_kernel)
for closing_pass in range(2):
    cleaned_thresh_img = sitk.BinaryClosingByReconstruction(
        cleaned_thresh_img, morph_kernel
    )
# MultiImageDisplay(
#     image_list=[cleaned_thresh_img],
#     title_list=["Cleaned Binary Segmentation"],
#     figure_size=(8, 4),
# );


In [None]:
# Label the connected components and view them as an 8-bit image.
# NOTE(review): UInt8 cannot represent more than 255 distinct labels - confirm
# this is acceptable for visualization.
component_labels = sitk.ConnectedComponent(cleaned_thresh_img)
itkwidgets.view(sitk.Cast(component_labels, sitk.sitkUInt8))

In [None]:
# Compute shape statistics for every connected component.
stats = sitk.LabelShapeStatisticsImageFilter()
stats.Execute(sitk.ConnectedComponent(cleaned_thresh_img))

# Pixel count of every connected component except label 1.
label_sizes = [
    stats.GetNumberOfPixels(component)
    for component in stats.GetLabels()
    if component != 1
]

# Histogram of component sizes, drawn with log=True.
plt.figure()
plt.hist(label_sizes, bins=200, log=True)
plt.title("Distribution of Object Sizes")
plt.xlabel("size in pixels")
plt.ylabel("number of objects")
plt.show()
print(label_sizes)

In [None]:
# Collect the physical size of every sufficiently large region, keyed by
# label, then print the label of the largest one.
num_labels = stats.GetNumberOfLabels()
area_list = {}
print("Max des aires des régions détectées:")
for label in range(1, num_labels + 1):  # labels start at 1
    area = stats.GetPhysicalSize(label)
    box = stats.GetBoundingBox(label)
    # NOTE(review): this rejects a region when ANY bounding-box entry is 0,
    # which includes start indices as well as sizes - confirm that excluding
    # regions that start at index 0 is intended.
    if area > 5000 and 0 not in box:
        #print(f"Région {label}: {area} pixels")
        area_list[label] = area

# max() raises ValueError on an empty dict, so guard the case where no region
# passes the filter.
if area_list:
    print(max(area_list, key=area_list.get))
else:
    print("Aucune région ne satisfait les critères.")

In [None]:
# NOTE(review): the region printed here (label 192) is not the one extracted
# below (label 69685) - confirm which label is intended.
print(stats.GetRegion(192))

# GetBoundingBox returns the start indices followed by the sizes, while
# sitk.RegionOfInterest expects (image, size, index). Passing the whole box as
# `size` would use the start indices as the size and leave the index at the
# origin, so split the box into its two halves.
bounding_box = stats.GetBoundingBox(69685)
box_dim = len(bounding_box) // 2
roi = sitk.RegionOfInterest(
    cleaned_thresh_img, bounding_box[box_dim:], bounding_box[:box_dim]
)
#itkwidgets.view(roi)

In [None]:
# Window the cleaned image: input window [175, 255], output range [0, 255].
thresholdFilter = sitk.IntensityWindowingImageFilter()
thresholdFilter.SetWindowMinimum(175)
thresholdFilter.SetWindowMaximum(255)
thresholdFilter.SetOutputMinimum(0)
thresholdFilter.SetOutputMaximum(255)
resultTargetBonesAndVessels = thresholdFilter.Execute(cleaned_thresh_img)

# NOTE(review): the array plotted and saved below comes from
# cleaned_thresh_img, not from resultTargetBonesAndVessels - confirm which
# image is meant to be exported.
image_array = sitk.GetArrayFromImage(cleaned_thresh_img)

# Show the last slice along the first array axis with a yellow marker at (255, 284).
plt.title("Coupe de l'image")
plt.plot([255], [284], marker='o', ls='none', ms=5, color="y")
plt.imshow(image_array[-1, :, :])

# np.save does not create missing directories, so make sure "temp/" exists.
output_path = Path("temp/preprocessed_image.npy")
output_path.parent.mkdir(parents=True, exist_ok=True)
np.save(output_path, image_array)