# 0.2. 3D Seeded Watershed Segmentation

This script uses the `watershed` algorithm from `skimage.segmentation` to perform seeded segmentation on a label image where a cluster of objects is classified with a single label. 

The seeds for the watershed segmentation are generated from the original clustered label image. For each label, the z-slice containing the highest number of connected components is selected (when several slices tie, the one closest to the center of the label is used), and its connected regions are used as markers (seeds). These seeds guide the watershed algorithm to separate the cluster into individual objects.


## 0.2.1. Load Python Libraries

Load the necessary Python libraries for image processing and visualization. 

The `skimage` library is used for image segmentation, while `matplotlib` is used for displaying images.

Other libraries such as `numpy` and `os` are used for numerical operations and file handling, respectively.

`warnings` is imported to suppress any low contrast image warnings that may arise during the execution of the code.

In [None]:
from skimage.io import imread, imsave, imshow
from skimage.measure import label
from skimage.segmentation import relabel_sequential, watershed
from skimage.morphology import erosion
import numpy as np
import os
import matplotlib.pyplot as plt

# Suppress warnings for low contrast images
import warnings

warnings.filterwarnings("ignore", message=".*is a low contrast image.*")

## 0.2.2. Load Functions

Define the custom helper functions used by the pipeline for reading label images, finding watershed seeds, and displaying results.

In [None]:
# Define functions for image processing and visualization
def display_label_array(label_array: np.ndarray, title: str = "Label Array") -> None:
    """
    Display a 2D label array, or a 3D label array as a row of 2D z-slices.

    Parameters:
    - label_array (np.ndarray): 2D array, or 3D array whose first axis is
      iterated as the slice index.
    - title (str): Title for the plot. For a 3D array it is placed on the
      first slice; the remaining slices are titled "z=<index>".

    Returns:
    - None: Displays the array in a matplotlib figure.

    Raises:
    - ValueError: If label_array is neither 2D nor 3D.
    """
    # Validate the dimensionality up front so nothing is drawn for bad input
    if label_array.ndim not in (2, 3):
        raise ValueError("Input label_array must be 2D or 3D.")

    if label_array.ndim == 2:
        # Single panel; the title is set before drawing so it is part of the figure
        fig, axes = plt.subplots(1, 1, figsize=(15, 5))
        axes.set_title(title)
        # NOTE(review): skimage.io.imshow appears to be deprecated in recent
        # scikit-image releases - confirm before upgrading the dependency.
        imshow(label_array, cmap="Wistia")

    else:
        # Number of slices to display
        num_slices = label_array.shape[0]

        # One panel per slice, laid out in a single row
        fig, axes = plt.subplots(1, num_slices, figsize=(15, 5))

        # plt.subplots returns a bare Axes (not an array) for a single panel
        if num_slices == 1:
            axes = [axes]

        # Shared upper display limit so a given label has the same gray level
        # in every slice; computed once instead of once per slice
        display_max = label_array.max()

        for z in range(num_slices):
            axes[z].imshow(label_array[z, :, :], cmap="gray", vmin=0, vmax=display_max)
            if z == 0:
                axes[z].set_title(title)
            else:
                axes[z].set_title(f"z={z}")
            axes[z].axis("off")  # Hide axes for cleaner visualization

    plt.show()


def is_label_image(image: np.ndarray) -> bool:
    """
    Heuristically decide whether an array looks like a label image.

    Parameters:
    - image (np.ndarray): Array to inspect.

    Returns:
    - bool: True when the array has an integer dtype and holds fewer than
      256 distinct values, False otherwise.
    """
    # Non-integer data (e.g. float intensities) is never a label image
    if not np.issubdtype(image.dtype, np.integer):
        return False

    # Only count distinct values once the dtype check has passed
    distinct_values = np.unique(image)
    return len(distinct_values) < 256


def get_label_indices(directory: str, label_image_name: str) -> tuple:
    """
    Read a label image and collect the indices of every non-zero label.

    Parameters:
    - directory: Directory containing the label image.
    - label_image_name: Name of the label image file.

    Returns:
    - lbl_idx_list: None when the file does not look like a label image
      (see is_label_image). Otherwise a list with one entry per non-zero
      label value, in ascending label order; each entry is the tuple of index
      arrays returned by np.where for that label. The list is empty when the
      image contains no non-zero label.
    - original_array_shape: Shape of the array as read from disk.
    """
    # Read the label image
    lbl_array = imread(os.path.join(directory, label_image_name))

    # Remember the shape so the caller can rebuild a full-sized array
    original_array_shape = lbl_array.shape

    # Regular (non-label) images are reported with None instead of indices
    if not is_label_image(lbl_array):
        return None, original_array_shape

    # The array is deliberately kept in its original integer dtype: casting to
    # uint8 would wrap label values above 255 (256 would turn into background,
    # 257 would merge with label 1), and is_label_image only limits the number
    # of distinct values, not their magnitude.
    lbl_values = np.unique(lbl_array)

    # Background (0) is not a label; an image with only background yields []
    lbl_values = lbl_values[lbl_values != 0]

    # One np.where result (tuple of per-axis index arrays) per label
    lbl_idx_list = [np.where(lbl_array == lbl) for lbl in lbl_values]

    return lbl_idx_list, original_array_shape


def calculate_connected_components(label_array: np.ndarray) -> tuple:
    """
    Label the connected components of every z-slice of a 3D array.

    Each slice label_array[z, :, :] is labeled independently with
    skimage.measure.label using connectivity=1, so component numbers restart
    in every slice.

    Parameters:
    - label_array: 3D numpy array; the first axis is treated as z.

    Returns:
    - connected_components_list: For each slice, the highest component number
      in that slice (equal to its number of components, 0 for an empty slice).
    - cc_array: Array with the same shape as label_array holding the
      per-slice component numbers. Its dtype is uint8 whenever every slice
      has at most 255 components and is widened otherwise, so component
      numbers are never wrapped.
    """
    connected_components_list = []
    labeled_slices = []

    for z in range(label_array.shape[0]):
        # label() already numbers components consecutively starting at 1,
        # so no extra sequential relabeling pass is required and the
        # maximum value is the component count of the slice
        labeled = label(label_array[z, :, :], connectivity=1)
        labeled_slices.append(labeled)
        connected_components_list.append(labeled.max())

    # Pick the smallest unsigned dtype (at least uint8) that can hold the
    # largest component number found in any slice
    largest_count = int(max(connected_components_list, default=0))
    cc_dtype = np.promote_types(np.uint8, np.min_scalar_type(largest_count))

    cc_array = np.zeros(label_array.shape, dtype=cc_dtype)
    for z, labeled in enumerate(labeled_slices):
        cc_array[z, :, :] = labeled

    return connected_components_list, cc_array


def generate_label_marker_array(
    label_array: np.ndarray,
    cc_array: np.ndarray,
    max_connected_components_idx: list,
    z_center: int,
) -> np.ndarray:
    """
    Generate a marker array for watershed segmentation.

    The markers are taken from one z-slice of cc_array: among the slices
    listed in max_connected_components_idx, the one closest to z_center (for
    a single candidate, that slice). The slice is eroded with a 3x3 footprint;
    if the erosion reduces the number of distinct values in the slice, the
    non-eroded slice is used instead so that no seed is lost. The markers are
    written to the source slice when it lies fewer than 3 slices from
    z_center, otherwise to the z_center slice. All other slices stay zero.

    Parameters:
    - label_array: 3D numpy array with object labels; only its shape is used.
    - cc_array: 3D numpy array with connected components labeled per slice.
    - max_connected_components_idx: Indices of slices with the maximum number
      of connected components.
    - z_center: Center slice index for the label array.

    Returns:
    - marker_array: 3D numpy array (shape of label_array, dtype of cc_array)
      with markers for watershed segmentation.

    Raises:
    - ValueError: If max_connected_components_idx is empty (raised by min()).
    """
    # Same dtype as the component array so marker values are stored unchanged
    marker_array = np.zeros(label_array.shape, dtype=cc_array.dtype)

    # Candidate slice closest to the center; with one candidate this is simply
    # that candidate, so the single-slice and tie cases share one code path
    seed_slice_idx = min(
        max_connected_components_idx, key=lambda idx: abs(idx - z_center)
    )
    seeds = cc_array[seed_slice_idx, :, :]

    # Erode the seeds so that neighbouring markers are not too close together
    eroded_seeds = erosion(seeds, footprint=np.ones((3, 3)))

    # Fewer distinct values after erosion means a component was removed;
    # in that case keep the original, non-eroded seeds
    if np.unique(eroded_seeds).size < np.unique(seeds).size:
        eroded_seeds = seeds

    # Place the markers on the slice they were taken from when it is close to
    # the center, otherwise on the center slice
    if abs(seed_slice_idx - z_center) < 3:
        z_position = seed_slice_idx
    else:
        z_position = z_center

    marker_array[z_position, :, :] = eroded_seeds

    return marker_array


## 0.2.3. Pipeline

### User-Defined Variables
Set the input and output directories, and choose whether empty label images are written to the output directory.

In [None]:
# Directory that is scanned for input label images (.tif/.tiff files)
base_directory = "/path/to/your/directory"

# Directory the watersheded label images are written to, using the same
# file names as the inputs; it is created if it does not exist
save_directory = "/path/to/save/directory"

# True:  images without any label are saved as all-zero images in save_directory
# False: images without any label are skipped and nothing is written for them
keep_blank_labels = False

### Run the code

In [None]:
# Create the save directory if it does not exist
# If it exists, it will not raise an error
os.makedirs(save_directory, exist_ok=True)

# Get the list of label images in the base directory
# The label images should be in .tif/.tiff format
# Sorted so the processing order is the same on every run and platform
lbl_image_list = sorted(
    f for f in os.listdir(base_directory) if f.lower().endswith((".tif", ".tiff"))
)

# Iterate over each label image in the directory
for lbl_image in lbl_image_list:
    print(f"Processing {lbl_image}...")

    # get_label_indices returns, for each label, the indices of its voxels,
    # together with the shape of the original image:
    # - None        -> the file is not a label image
    # - empty list  -> the label image contains no labels
    lbl_idx_list, original_array_shape = get_label_indices(base_directory, lbl_image)

    if lbl_idx_list is None:
        # If the label image is likely not a label image, skip it
        print(f"{lbl_image} is likely not a label image. Skipping...")
        continue

    if len(lbl_idx_list) == 0:
        # If no labels are found, skip the image
        print("No labels found. Skipping...")

        if keep_blank_labels:
            # If keep_blank_labels is True, save the empty label image
            imsave(
                os.path.join(save_directory, lbl_image),
                np.zeros(original_array_shape, dtype=np.uint16),
            )
            print("Saved empty label image to save directory.")

        continue

    if len(original_array_shape) != 3:
        # The per-label processing below indexes three axes (z, x, y);
        # anything else would fail with an IndexError further down
        print(f"{lbl_image} is not a 3D (z, x, y) stack. Skipping...")
        continue

    # New array to store the relabeled labels, same shape as the original.
    # It is uint16 (the dtype it is saved with) so that label IDs above 255
    # are not wrapped, which a uint8 array would do.
    lbl_array = np.zeros(original_array_shape, dtype=np.uint16)

    # Number of label IDs handed out so far; kept as a plain Python int so the
    # arithmetic below does not depend on NumPy scalar promotion rules
    re_lbl_count = 0

    # Iterate over each label in the label index list
    for i, lbl_idx in enumerate(lbl_idx_list):
        print(f"Processing label {i + 1}...")

        # lbl_idx is a tuple of index arrays (z, x, y) for the label
        z_idx, x_idx, y_idx = lbl_idx
        z_min, x_min, y_min = z_idx.min(), x_idx.min(), y_idx.min()

        # Dimensions of the label bounding box (+1 to include the last index)
        z_dim = z_idx.max() - z_min + 1
        x_dim = x_idx.max() - x_min + 1
        y_dim = y_idx.max() - y_min + 1

        # Center slice of the bounding box along z
        z_center = z_dim // 2

        # Bounding-box array with a one-pixel zero border in x and y so the
        # label does not touch the array edges during the watershed
        label_array = np.zeros((z_dim, x_dim + 2, y_dim + 2), dtype=np.uint8)

        # Shift the indices so the bounding box starts at 0 in z and at 1
        # (after the border) in x and y
        adjusted_z = z_idx - z_min
        adjusted_x = x_idx - x_min + 1
        adjusted_y = y_idx - y_min + 1

        # Paint the label into the bounding-box array
        label_array[adjusted_z, adjusted_x, adjusted_y] = 1

        # Connected components per z slice; these provide the watershed markers
        connected_components_list, cc_array = calculate_connected_components(
            label_array
        )

        # Largest number of components found in any single slice
        max_components = int(max(connected_components_list))

        # With at most one component per slice there is nothing to split:
        # copy the label over under a fresh ID and move on
        if max_components == 1:
            print(
                f"Only one connected component found for label {i + 1}, no markers for watershed applied.\nTransfering to relabeled array"
            )
            lbl_array[lbl_idx] = re_lbl_count + 1
            # Increment the label count for the next label
            re_lbl_count += 1
            continue

        # Slices that reach the maximum number of connected components.
        # The list is never empty: at least one slice attains the maximum.
        max_connected_components_idx = [
            z
            for z, n_components in enumerate(connected_components_list)
            if n_components == max_components
        ]

        # Generate the marker array for the watershed from those slices
        marker_array = generate_label_marker_array(
            label_array,
            cc_array,
            max_connected_components_idx,
            z_center,
        )

        # Apply the watershed to the label using the markers as seed points.
        # connectivity=3 selects the full 3x3x3 neighborhood (26 neighbors) of
        # a 3D image; the integer is the maximum number of orthogonal steps to
        # a neighbor, so it is the documented spelling of the former value 26.
        watersheded_array = watershed(
            image=label_array,
            markers=marker_array,
            mask=label_array,
            connectivity=3,
            compactness=0.001,
        )

        # Uncomment the function below to visualize the watersheded array
        # This will display the watersheded label array in a matplotlib figure
        # display_label_array(watersheded_array, title="watershed")

        # Watershed result at the voxels of this label
        new_ids = watersheded_array[adjusted_z, adjusted_x, adjusted_y].astype(
            np.int64
        )
        n_new_labels = max_components

        # Voxels that no marker reached are still 0. Adding re_lbl_count to
        # them would give them the ID of the previously written object (or
        # erase them for the first label), so they get one extra ID instead.
        unreached = new_ids == 0
        if unreached.any():
            new_ids[unreached] = max_components + 1
            n_new_labels += 1

        # Add the watersheded label to the new full sized label array,
        # offset by the IDs already in use
        lbl_array[lbl_idx] = (new_ids + re_lbl_count).astype(np.uint16)

        # Reserve the IDs used by this label for the following labels
        re_lbl_count += n_new_labels

    if re_lbl_count > np.iinfo(np.uint16).max:
        # More IDs than a 16-bit image can hold: the highest IDs have wrapped
        print(
            f"Warning: {lbl_image} needs {re_lbl_count} label IDs, more than a 16-bit image can store."
        )

    # Save the new label array with the watersheded labels
    # The label array is saved as a 16-bit unsigned integer image
    imsave(
        os.path.join(save_directory, lbl_image),
        lbl_array,
    )


Processing C4-16122021_Label45_367L_w3_1076_100x_0p21_01_scaled_oriScale.tif...
Processing label 1...
Only one conected component found for label 1, no markers for watershed applied.
Transfering to relabeled array
Processing label 2...
Only one conected component found for label 2, no markers for watershed applied.
Transfering to relabeled array
Processing label 3...
Processing label 4...
Only one conected component found for label 4, no markers for watershed applied.
Transfering to relabeled array
Processing label 5...
Only one conected component found for label 5, no markers for watershed applied.
Transfering to relabeled array
Processing label 6...
Processing label 7...
Processing label 8...
Only one conected component found for label 8, no markers for watershed applied.
Transfering to relabeled array
Processing label 9...
Only one conected component found for label 9, no markers for watershed applied.
Transfering to relabeled array
Processing C4-12012022_Label46_367L_Cd16_100x_0p21