# Watershed with Distance Map
This example illustrates how to segment an image using the watershed method
and the signed Maurer distance map.

In [None]:
from types import SimpleNamespace

import itk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


Define the image types used throughout the pipeline.

In [None]:
dimension = 3

uchar_pixel_type = itk.UC
uchar_image_type = itk.Image[uchar_pixel_type, dimension]

float_pixel_type = itk.F
float_image_type = itk.Image[float_pixel_type, dimension]

rgb_pixel_type = itk.RGBPixel[uchar_pixel_type]
RGBImageType = itk.Image[rgb_pixel_type, dimension]

# Pipeline parameters. Later cells read these through `args`, which is not
# defined anywhere else in this notebook, so it must exist before they run.
# TODO: the values below are illustrative starting points, not tuned for
# this dataset — adjust them to your data.
args = SimpleNamespace(
    # Neighborhood radius (voxels, same on every axis) of the voting
    # hole-filling filter
    binarizing_radius=2,
    # Majority threshold of the voting hole-filling filter
    majority_threshold=1,
    # Threshold and level of the watershed filter
    watershed_threshold=0.01,
    level=0.2,
    # Radius (voxels) of the ball used to open each kept bubble
    cleaning_structuring_element_radius=3,
    # Output files
    distance_map_output_filename="DistanceMap.mha",
    watershed_output_filename="WatershedSegmentation.tif",
)

Read the input image, print its shape and pixel type, and display its projections.

In [None]:
# Read the 3D stack and report its array shape and pixel type
input_filename = "PlateauBorder.tif"
stack_image = itk.imread(input_filename)
print(stack_image.shape, stack_image.dtype)

In [None]:
%matplotlib inline

# Display the input image projections
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
for i, (cax, clabel) in enumerate(zip([ax1, ax2, ax3], ["xy", "zy", "zx"])):
    cax.imshow(np.sum(stack_image, i).squeeze(), cmap="bone_r")
    cax.set_title("{} projection".format(clabel))
    cax.set_xlabel(clabel[0])
    cax.set_ylabel(clabel[1])

## Create a bubble image
The bubble image is the complement of the plateau border image: there cannot be
air where there is water. Holes are filled with ITK's iterative voting binary
hole-filling filter.

In [None]:
# Neighborhood radius (in voxels, identical along every axis) for the vote
index_radius = itk.Size[dimension]()
index_radius.Fill(args.binarizing_radius)

# Fill holes in the binary stack by iterative majority voting
hole_filling_options = dict(
    radius=index_radius,
    background_value=0,
    foreground_value=255,
    majority_threshold=args.majority_threshold,
)
bubble_image = itk.voting_binary_iterative_hole_filling_image_filter(
    stack_image, **hole_filling_options
)

# Quick look at a single slice of the result
plt.imshow(bubble_image[5], cmap="bone")

# Write bubble image
itk.imwrite(bubble_image, "ReversedInputImageTest01.tif")

## Watershed on bubbles
Compute the signed Maurer distance map of the bubble image and apply the ITK watershed filter to it.

In [None]:
%%time

bbl_cast_image = itk.cast_image_filter(
    bubble_image,
    ttype=(uchar_image_type, float_image_type),
)

# Normalize the image to the [0, 255] range
bubble_image_preclamp = itk.multiply_image_filter(
    bbl_cast_image,
    constant=255.0,
)
bubble_image_clamp = itk.clamp_image_filter(
    bubble_image_preclamp,
    bounds=(0, 255),
)

# Get the distance map of the input image
distance_map_image = itk.signed_maurer_distance_map_image_filter(
    bubble_image_clamp,
    inside_is_positive=False,
)
itk.imwrite(distance_map_image, args.distance_map_output_filename)

# Apply the watershed segmentation
watershed_image = itk.watershed_image_filter(
    distance_map_image,
    threshold=args.watershed_threshold,
    level=args.level,
)

# Cast to unsigned char so that it can be written as a TIFF image
# WatershedImageFilter produces itk.ULL, but CastImageFilter does not wrap
# itk.ULL: see ITK issue 2551
# ws_cast_image = itk.cast_image_filter(watershed_image, ttype=(watershed_image.__class__, uchar_image_type))
# Workaround
rgb_ws_image = itk.scalar_to_rgb_colormap_image_filter(
    watershed_image,
    colormap=itk.ScalarToRGBColormapImageFilterEnums.RGBColormapFilter_Jet,
)

# itk.imwrite(ws_cast_image, args.watershed_output_filename)
itk.imwrite(rgb_ws_image, args.watershed_output_filename)

## Cut-through
Show the middle slice of the bubble image, the distance map, and the watershed
labels to get a feel for what each step did.

In [None]:
mid_slice = watershed_image.shape[0] // 2

# NumPy views of the watershed labels and of the distance map
ws_vol_arr = itk.array_view_from_image(watershed_image)
dmap_vol = itk.array_view_from_image(distance_map_image)
dmap_slice = dmap_vol[mid_slice]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))

ax1.imshow(bubble_image[mid_slice], cmap="bone")
ax1.set_title("Bubble image")

# Symmetric color limits so that a zero distance maps to the RdBu midpoint
m_val = np.abs(dmap_slice).std()
ax2.imshow(dmap_slice, cmap="RdBu", vmin=-m_val, vmax=m_val)
ax2.set_title(
    "Distance image\nmin: {:2.2f}; max: {:2.2f}; mean: {:2.2f}".format(
        dmap_slice.min(),
        dmap_slice.max(),
        dmap_slice.mean(),
    )
)

ax3.imshow(ws_vol_arr[mid_slice], cmap="nipy_spectral")
# Count the distinct nonzero labels over the whole volume (0 is excluded)
ax3.set_title(
    "Watershed\nLabels found: {}".format(
        len(np.unique(ws_vol_arr[ws_vol_arr > 0]))
    )
)

fig.savefig("Pipeline_images.png")

In [None]:
%matplotlib inline

# Show segmentation result projections
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
for i, (cax, clabel) in enumerate(zip([ax1, ax2, ax3], ["xy", "zy", "zx"])):
    cax.imshow(np.max(ws_vol_arr, i).squeeze(), cmap="nipy_spectral")
    cax.set_title("{} projection".format(clabel))
    cax.set_xlabel(clabel[0])
    cax.set_ylabel(clabel[1])

fig.savefig("Segmentation_projections.png")

## Clean the segmentation and relabel
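Clean the segmentation: keep only the regions whose volume lies within a given
range, remove small protrusions from each one with a morphological opening, and
relabel the kept regions consecutively in order of increasing volume.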

In [None]:
from tqdm import tqdm

# Regions whose voxel count lies strictly between these bounds are kept.
# TODO: confirm the bounds for this dataset.
MIN_BUBBLE_VOLUME = 40000
MAX_BUBBLE_VOLUME = 400000

# NumPy view of the watershed labels (ITK images do not support elementwise
# comparison / boolean indexing directly)
ws_labels = itk.array_view_from_image(watershed_image)
bubble_label_image = np.zeros(ws_labels.shape, dtype=np.uint16)

# (label, voxel count) for every non-background region, computed in a single
# pass instead of one full-volume comparison per label
region_labels, region_volumes = np.unique(
    ws_labels[ws_labels > 0], return_counts=True
)
bubble_ids = list(zip(region_labels, region_volumes))

structuring_element_type = itk.FlatStructuringElement[ws_labels.ndim]
# Bubbles are round
structuring_element = structuring_element_type.Ball(
    args.cleaning_structuring_element_radius
)

# Relabel kept bubbles 1, 2, ... in order of increasing volume
new_idx = 1
for old_idx, vol in tqdm(sorted(bubble_ids, key=lambda x: x[1])):
    if MIN_BUBBLE_VOLUME < vol < MAX_BUBBLE_VOLUME:
        # The ITK filter needs an ITK image; encode the region as 0/1 and tell
        # the filter that 1 is the foreground value.
        region_mask = itk.image_from_array(
            (ws_labels == old_idx).astype(np.uint8)
        )
        cleaned_img = itk.binary_morphological_opening_image_filter(
            region_mask,
            kernel=structuring_element,
            foreground_value=1,
        )
        # Label the opened mask so that the cleaning actually takes effect
        cleaned_mask = itk.array_view_from_image(cleaned_img) > 0
        bubble_label_image[cleaned_mask] = new_idx
        new_idx += 1

# new_idx is one past the last label assigned
print("Total bubbles kept: {}/{}".format(new_idx - 1, len(bubble_ids)))

In [None]:
%matplotlib inline
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
for i, (cax, clabel) in enumerate(zip([ax1, ax2, ax3], ["xy", "zy", "zx"])):
    cax.imshow(
        np.max(bubble_label_image, i),
        cmap="jet",
        vmin=0,
        vmax=new_idx,
    )
    cax.set_title("{} projection".format(clabel))
    cax.set_xlabel(clabel[0])
    cax.set_ylabel(clabel[1])

fig.savefig("Clean_segmentation_projections.png")

## Show 3D rendering

In [None]:
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from skimage import measure
from tqdm import tqdm

# Show 3D rendering
def show_3d_mesh(image, thresholds):
    """Render the surfaces of selected labels of a 3D label image.

    Parameters
    ----------
    image : numpy.ndarray
        3D label image.
    thresholds : sized iterable of int
        Label values to render. Each one gets its own surface, computed with
        marching cubes on the mask ``image == label``, colored from the
        ``nipy_spectral_r`` colormap.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the 3D axes.
    """
    # Reverse the first axis and swap the last two before meshing
    p = image[::-1].swapaxes(1, 2)
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap("nipy_spectral_r")
    _fig = plt.figure(figsize=(10, 10))
    ax = _fig.add_subplot(111, projection="3d")
    n_labels = len(thresholds)
    for _i, c_threshold in tqdm(list(enumerate(thresholds))):
        # The mask is 0/1, so its boundary is the 0.5 iso-surface; level 0
        # would lie on the background value itself.
        mask = (p == c_threshold).astype(np.float32)
        verts, faces, _, _ = measure.marching_cubes(mask, level=0.5)
        mesh = Poly3DCollection(
            verts[faces],
            alpha=0.25,
            edgecolor=None,
            linewidth=0.1,
        )
        mesh.set_facecolor(cmap(_i / n_labels)[:3])
        mesh.set_edgecolor([1, 0, 0])
        ax.add_collection3d(mesh)

    ax.set_xlim(0, p.shape[0])
    ax.set_ylim(0, p.shape[1])
    ax.set_zlim(0, p.shape[2])

    ax.view_init(45, 45)
    return _fig

In [None]:
# Render every RENDER_LABEL_STEP-th bubble label. The upper bound is max + 1
# so the highest label is included (with a single bubble, range(1, 1) would
# render nothing). int() avoids uint16 overflow in the + 1.
RENDER_LABEL_STEP = 10
max_label = int(np.max(bubble_label_image))
fig = show_3d_mesh(bubble_label_image, range(1, max_label + 1, RENDER_LABEL_STEP))

# Write 3D rendering of segmented image
fig.savefig("Volume_rendering.png")

## Calculate bubble centers

In [None]:
def meshgrid3d_like(in_img):
    """Build voxel-index coordinate grids with the same shape as ``in_img``.

    Returns the grids produced by ``np.meshgrid`` (default ``"xy"`` indexing)
    over the index ranges of axes 1, 0 and 2 of ``in_img``. The first grid
    varies along axis 1, the second along axis 0, and the third along axis 2.
    """
    axis_order = (1, 0, 2)
    index_ranges = [range(in_img.shape[axis]) for axis in axis_order]
    return np.meshgrid(*index_ranges)

# NOTE(review): meshgrid3d_like returns index grids along axes 1, 0 and 2 (in
# that order), so zz holds axis-1 indices, xx axis-0 indices and yy axis-2
# indices. For a (z, y, x)-indexed array this does not match the names. Left
# unchanged to keep the output consistent with bubble_volume.csv, which is
# compared against below — TODO confirm the intended convention.
zz, xx, yy = meshgrid3d_like(bubble_label_image)

out_results = []
for c_label in np.unique(bubble_label_image):  # one bubble at a time
    if c_label > 0:  # ignore background
        cur_roi = bubble_label_image == c_label
        out_results.append(
            {
                "x": xx[cur_roi].mean(),
                "y": yy[cur_roi].mean(),
                "z": zz[cur_roi].mean(),
                "volume": np.sum(cur_roi),
            }
        )

# Write the bubble volume stats. Explicit columns keep the table well-formed
# even when no bubble was kept.
out_table = pd.DataFrame(out_results, columns=["x", "y", "z", "volume"])
out_table.to_csv("bubble_volume_out.csv")

# Write a reproducible sample of the stats table. A DataFrame has no `save`
# method, so the sample is written as CSV; min() guards tables with < 5 rows.
volume_sample = out_table.sample(min(5, len(out_table)), random_state=0)
volume_sample.to_csv("Bubble_volume_stats_sample.csv")
volume_sample

In [None]:
# Write the bubble volume density plot. pandas plotting returns a matplotlib
# Axes, which has no `save` method, so the parent figure is saved instead.
density_fig, bubble_volume_density = plt.subplots()
out_table["volume"].plot.density(ax=bubble_volume_density)
bubble_volume_density.set_title("Bubble volume density")
bubble_volume_density.set_xlabel("Volume (voxels)")
density_fig.savefig("Bubble_volume_density_stats.png")

In [None]:
# Write the bubble center plot. As above, the Axes returned by pandas has no
# `save` method; save its figure instead.
centers_fig, bubble_centers_plot = plt.subplots()
out_table.plot.hexbin("x", "y", gridsize=(5, 5), ax=bubble_centers_plot)
bubble_centers_plot.set_title("Bubble centers")
centers_fig.savefig("Bubble_centers.png")

## Compare with the training values

In [None]:
# Reference bubble statistics to compare the watershed result against
training_values_path = "bubble_volume.csv"
train_values = pd.read_csv(training_values_path)

In [None]:
%matplotlib inline
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
_, n_bins, _ = ax1.hist(
    np.log10(train_values["volume"]), bins=20, label="Training volumes"
)
ax1.hist(np.log10(
    out_table["volume"]),
    n_bins,
    alpha=0.5,
    label="Watershed volumes",
)
ax1.legend()
ax1.set_title("Volume comparison\n(Log10)")
ax2.plot(
    out_table["x"],
    out_table["y"],
    "r.",
    train_values["x"],
    train_values["y"],
    "b.",
)
ax2.legend(["Watershed bubbles", "Training bubbles"])

fig.savefig("Bubble_stats.png")