In [132]:
from tifffile import imread, imwrite
import napari
from skimage import measure
import numpy as np

In [133]:
# Load one label image from the training data (`y(512x512)` folder, sample X_15).
# Later cells take labels[0, :, :] as a 2-D slice, so this is expected to be a stack.
# NOTE(review): relative path; it resolves against the kernel's working directory.
labels = imread('./training_data/thick_brain_Nikon_dagnysd_0.5/y(512x512)/X_15.tif')

In [134]:
# Open a napari viewer in 2-D display mode and add the label image as a Labels layer.
# The layer name shown in the output ('labels') matches the variable passed in,
# so renaming the variable would also rename the layer.
viewer = napari.Viewer(ndisplay=2)
viewer.add_labels(labels)



<Labels layer 'labels' at 0x23c1ec54820>

In [135]:
# Take the first slice along axis 0 (the z-axis in this notebook) as a 2-D test case.
# The explicit [0, :, :] indexing also fails loudly if `labels` has fewer than 3 axes.
labels_2D = labels[0,:,:]

In [136]:
# Show the 2-D slice as its own layer (named 'labels_2D' after the variable).
viewer.add_labels(labels_2D)

<Labels layer 'labels_2D' at 0x23c1ec54f10>

In [137]:
# Load the ROI mask from a TIFF file.
# NOTE(review): the next cell overwrites `roi` with an all-ones array of the same shape,
# so in the cells shown here only this mask's shape is used downstream.
roi = imread("./test_mask.tiff")

In [138]:
# Simulate an ROI that covers the full image:
# keep only the shape of the loaded mask and fill it with int8 ones.
# Afterwards `ones_array` and `roi` refer to the same array.
ones_array = np.full(roi.shape, 1, dtype=np.int8)
roi = ones_array

In [139]:
# Display the ROI as a Labels layer (named 'roi' after the variable).
viewer.add_labels(roi)

<Labels layer 'roi' at 0x23c1f45c400>

In [140]:
def extract_contour(roi: np.ndarray) -> np.ndarray:
    """
    Extracts the contour of a binary ROI image and returns a binary mask
    where the contour pixels are set to 1.

    Parameters:
    -----------
    roi : np.ndarray
        A 2D NumPy array representing the region of interest (ROI).
        Any nonzero value is treated as foreground, so 0/1, boolean and
        0/255 masks are all accepted.

    Returns:
    --------
    np.ndarray
        A uint8 mask of the same shape as `roi`, where contour pixels are set
        to 1 and all other pixels are 0. An ROI without any foreground pixel
        yields an all-zero mask.

    Notes:
    ------
    - Uses `skimage.measure.find_contours` at level 0.5 on the binarized ROI.
    - The ROI is padded with one background pixel on every side before
      contouring. Without padding, `find_contours` returns open contours for
      regions touching the image border and the border segment is missing.
      With padding, such regions get a closed contour that runs along the
      image border, and an ROI covering the whole image yields the image frame.
    - Contour points lie between pixel centres. They are rounded to the
      nearest pixel (NumPy rounds .5 to even, so the mask can mix pixels just
      inside and just outside the boundary) and clipped to the image bounds.

    Example:
    --------
    >>> import numpy as np
    >>> from skimage.draw import disk
    >>> roi = np.zeros((100, 100), dtype=np.int8)
    >>> rr, cc = disk((50, 50), 20)
    >>> roi[rr, cc] = 1
    >>> contour_mask = extract_contour(roi)
    >>> print(contour_mask.sum())  # Nonzero values represent contour pixels
    """
    # Binarize: any nonzero value is foreground, so 0/255 masks behave like 0/1.
    foreground = np.asarray(roi) != 0
    contour_mask = np.zeros(foreground.shape, dtype=np.uint8)

    # One pixel of background padding closes contours at the image border.
    padded = np.pad(foreground, 1).astype(np.float64)

    # List of (K, 2) arrays of sub-pixel (row, column) coordinates in the padded frame
    contours = measure.find_contours(padded, level=0.5)
    if not contours:
        # No foreground at all: np.vstack([]) would raise, so return the empty mask.
        return contour_mask

    # Concatenate all contour points and shift them back to unpadded coordinates
    points = np.vstack(contours) - 1  # Shape: (N, 2)

    # Round to integer pixel indices; points traced through the padding sit at
    # -0.5 or (size - 0.5), so clipping maps them onto the first/last row/column.
    points = np.round(points).astype(int)
    rows = np.clip(points[:, 0], 0, foreground.shape[0] - 1)
    cols = np.clip(points[:, 1], 0, foreground.shape[1] - 1)

    # Set pixels at contour locations to 1 using NumPy advanced indexing
    contour_mask[rows, cols] = 1

    return contour_mask

def remove_labels_touching_roi_edge(labels, roi):
    """
    Remove every label that touches the edge of a 2-D ROI and renumber the rest.

    Parameters:
    -----------
    labels : np.ndarray
        Integer label image where 0 is background. Either 2-D ``(Y, X)`` or a
        stack ``(..., Y, X)`` such as a 3-D ``(Z, Y, X)`` volume.
    roi : np.ndarray
        2-D mask of shape ``(Y, X)``; nonzero pixels lie inside the ROI. For a
        stacked ``labels`` array the same ROI is applied to every slice.

    Returns:
    --------
    np.ndarray
        Integer array with the shape of ``labels``. Labels overlapping the ROI
        edge are set to 0. The remaining labels are renumbered to 1..K in their
        original order, and each keeps exactly its original pixels, even when
        those pixels are not connected.

    Raises:
    -------
    ValueError
        If ``roi`` is not 2-D or does not match the last two axes of ``labels``.

    Notes:
    ------
    - If every ROI pixel is nonzero, the ROI edge is the 1-pixel image border.
    - If the ROI has no nonzero pixel, it has no edge and no label is removed.
    - Label IDs are global: a label touching the edge in any slice is removed
      from all slices.
    """
    labels = np.asarray(labels)
    inside = np.asarray(roi) != 0

    if inside.ndim != 2 or labels.shape[-2:] != inside.shape:
        raise ValueError(
            "roi must be 2-D and match the last two axes of labels; "
            f"got roi shape {inside.shape} and labels shape {labels.shape}"
        )

    if inside.all():
        # Full-image ROI: its edge is the outer border of the image.
        edge_mask = np.zeros(inside.shape, dtype=bool)
        edge_mask[0, :] = True   # Top edge
        edge_mask[-1, :] = True  # Bottom edge
        edge_mask[:, 0] = True   # Left edge
        edge_mask[:, -1] = True  # Right edge
    elif not inside.any():
        # Empty ROI: there is no edge, so nothing can touch it.
        edge_mask = np.zeros(inside.shape, dtype=bool)
    else:
        # User-defined ROI: use the traced contour as the edge.
        edge_mask = extract_contour(roi).astype(bool)

    # `...` broadcasts the 2-D edge mask over any leading (e.g. z) axes,
    # so no per-slice copy of the mask is needed.
    edge_labels = np.unique(labels[..., edge_mask])
    edge_labels = edge_labels[edge_labels != 0]  # Background is never removed

    # Set labels intersecting with the ROI edge to 0
    filtered = np.where(np.isin(labels, edge_labels), 0, labels)

    # Renumber the surviving labels to consecutive integers 1..K. Unlike
    # measure.label, this does not re-run connected-component analysis, so a
    # label made of disconnected pieces keeps a single ID.
    kept = np.unique(filtered)
    kept = kept[kept != 0]
    relabeled = np.zeros(filtered.shape, dtype=np.intp)
    foreground = filtered != 0
    relabeled[foreground] = np.searchsorted(kept, filtered[foreground]) + 1

    return relabeled


In [141]:
# Filter the full label stack: the 2-D ROI edge is applied to every slice, and
# any label touching it is removed; show the result as a new layer.
filtered_labels = remove_labels_touching_roi_edge(labels, roi)
viewer.add_labels(filtered_labels)

<Labels layer 'filtered_labels' at 0x23c239569a0>

In [142]:
# Apply the same filter to the single 2-D slice.
# NOTE(review): this rebinds `filtered_labels`, replacing the stack result from the
# previous cell, and the layer ends up named 'filtered_labels [1]'. A distinct variable
# name (e.g. `filtered_labels_2D`) would keep both results and give the layer a clear name.
filtered_labels = remove_labels_touching_roi_edge(labels_2D, roi)
viewer.add_labels(filtered_labels)

<Labels layer 'filtered_labels [1]' at 0x23c239bc760>