The goal of this notebook is to reassign segmentation labels based on the objects that they are contained in. 
To do so, a hierarchy of objects must first be defined.
The hierarchy of objects is defined as follows:
- **Cell**
    - **Nucleus**
    - **Cytoplasm**

The index of a given cytoplasm should be the same as that of the cell it came from.
The nucleus index should be the same as that of the cell it came from. 

There will also be rules implemented for sandwiched indexes.
This occurs when an object was not matched properly and was therefore assigned a different index, even though it is surrounded (above and below in the z dimension) by the same object.
Such cases will be assigned the same index as the object that is above and below it.


In [1]:
import argparse
import pathlib
import sys

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import skimage
import skimage.io as io
import tifffile
from cellpose import core, models, utils
from rich.pretty import pprint

sys.path.append("../../utils")
import nviz
from nviz.image_meta import extract_z_slice_number_from_filename, generate_ome_xml
from segmentation_decoupling import euclidian_2D_distance

# Detect the execution context. `get_ipython` is only defined when running
# under IPython/Jupyter, so a NameError means this is a plain Python script.
try:
    cfg = get_ipython().config
except NameError:
    in_notebook = False
else:
    in_notebook = True



Welcome to CellposeSAM, cellpose v
cellpose version: 	4.0.4 
platform:       	linux 
python version: 	3.11.12 
torch version:  	2.7.0+cu126! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow. 
We encourage users to use GPU/MPS if available. 




In [2]:
# Resolve which patient / well-FOV to process: from the command line when run
# as a script, or from hard-coded defaults when run interactively.
if not in_notebook:
    print("Running as script")
    # set up arg parser
    parser = argparse.ArgumentParser(
        description="Reassign cell mask labels to match the nuclei they contain"
    )

    # Both identifiers are interpolated into the data path below; without them
    # the path would silently contain the string "None", so they are required.
    parser.add_argument(
        "--patient",
        type=str,
        required=True,
        help="The patient ID",
    )

    parser.add_argument(
        "--well_fov",
        type=str,
        required=True,
        help="The well and field of view to process (e.g. C4-2)",
    )
    # NOTE: this value is parsed but not read anywhere in the visible notebook.
    # The option is kept so existing invocations that pass it keep working.
    parser.add_argument(
        "--radius_constraint",
        type=int,
        default=10,
        help="The maximum radius of the x-y vector",
    )

    args = parser.parse_args()
    well_fov = args.well_fov
    patient = args.patient
else:
    print("Running in a notebook")
    well_fov = "C4-2"
    patient = "NF0014"

# Directory holding the segmentation masks for this patient / well-FOV.
mask_dir = pathlib.Path(f"../../data/{patient}/processed_data/{well_fov}").resolve()

Running in a notebook


In [3]:
# Locations of the input masks and of the relabelled cell mask that is written
# at the end of the notebook.
# (alternative cell input: mask_dir / "cell_masks_reconstructed_corrected.tiff")
nuclei_mask_path = mask_dir / "nuclei_masks_reconstructed_corrected.tiff"
cell_mask_path = mask_dir / "cell_masks_watershed.tiff"
cell_mask_output_path = mask_dir / "cell_masks_reassigned.tiff"

# Read both mask images into memory (cell mask first, then nuclei mask).
cell_mask, nuclei_mask = io.imread(cell_mask_path), io.imread(nuclei_mask_path)

In [4]:
# Measure every labelled object in the cell mask: label, centroid and bbox.
#
# The label is requested from regionprops itself rather than read back from the
# mask at the centroid position. For a non-convex object the centroid can fall
# on background (the object would be dropped) or inside a neighbouring object
# (the row would carry the neighbour's label, producing duplicate labels).
cell_df = pd.DataFrame.from_dict(
    skimage.measure.regionprops_table(
        cell_mask,
        properties=["label", "centroid", "bbox"],
    )
)
cell_df["compartment"] = "cell"
# Move "label" behind "compartment" so the column order matches the previous
# layout of this table.
cell_df["label"] = cell_df.pop("label")
# regionprops does not report the background (0); kept as an inexpensive guard
cell_df = cell_df[cell_df["label"] > 0].reset_index(drop=True)
# new_label starts as the current label and is overwritten for cells that are
# matched to a nucleus later on.
cell_df["new_label"] = cell_df["label"]

In [5]:
# Measure every labelled object in the nuclei mask: label, centroid and bbox.
#
# The label is requested from regionprops itself rather than read back from the
# mask at the centroid position, because the centroid of a non-convex object
# can fall on background or inside a different object.
nuclei_df = pd.DataFrame.from_dict(
    skimage.measure.regionprops_table(
        nuclei_mask,
        properties=["label", "centroid", "bbox"],
    )
)
nuclei_df["compartment"] = "nuclei"
# Move "label" behind "compartment" so the column order matches the previous
# layout of this table.
nuclei_df["label"] = nuclei_df.pop("label")
# regionprops does not report the background (0); kept as an inexpensive guard
nuclei_df = nuclei_df[nuclei_df["label"] > 0].reset_index(drop=True)

In [6]:
# Preview the first rows of the nuclei measurement table.
nuclei_df.head()

Unnamed: 0,centroid-0,centroid-1,centroid-2,bbox-0,bbox-1,bbox-2,bbox-3,bbox-4,bbox-5,compartment,label
0,10.934588,746.901225,473.777777,6,687,421,17,808,528,nuclei,14
1,7.888706,224.482524,715.338595,0,165,652,17,283,780,nuclei,24
2,4.147467,253.489444,503.487076,0,195,453,10,313,564,nuclei,34
3,19.420771,424.838544,693.588996,6,369,639,33,474,757,nuclei,39
4,5.122876,694.579159,399.909088,0,649,355,12,745,447,nuclei,46


In [7]:
# Preview the first rows of the cell measurement table
# (new_label still equals label at this point).
cell_df.head()

Unnamed: 0,centroid-0,centroid-1,centroid-2,bbox-0,bbox-1,bbox-2,bbox-3,bbox-4,bbox-5,compartment,label,new_label
0,11.33902,761.656976,418.747893,0,608,261,33,950,564,cell,15,15
1,10.864223,207.736419,724.041699,0,86,560,33,326,874,cell,26,26
2,3.422383,242.744098,509.650772,0,98,371,15,360,730,cell,37,37
3,13.153025,438.660058,704.646145,0,333,553,33,554,933,cell,43,43
4,4.745857,695.063774,388.185504,0,593,296,22,836,461,cell,51,51


In [8]:
def remove_edge_cases(
    mask: np.ndarray,
    border: int = 10,
) -> np.ndarray:
    """
    Remove every labelled object that touches the x-y border of the image

    Any label found within ``border`` pixels of the top, bottom, left or right
    edge (in any z slice) is set to 0 throughout the whole volume.

    Parameters
    ----------
    mask : np.ndarray
        The mask to process, should be a 3D numpy array in (z, y, x) order.
        The array is modified in place.
    border : int, optional
        The width in pixels of the border to scan for edge cases, by default 10.
        A value of 0 disables the filtering.

    Returns
    -------
    np.ndarray
        The same array object as ``mask``, with edge cases removed

    Raises
    ------
    ValueError
        If ``border`` is negative
    """
    if border < 0:
        raise ValueError(f"border must be non-negative, got {border}")
    if border == 0:
        # A zero-width border touches nothing. Slicing with ``-0:`` would
        # select the entire axis and wipe every label, so return early.
        return mask

    # Collect the pixel values found in the four x-y border strips. Each strip
    # spans all z slices, so no per-slice handling is needed.
    border_strips = [
        mask[:, :border, :],  # first n rows (y) - top edge
        mask[:, -border:, :],  # last n rows (y) - bottom edge
        mask[:, :, :border],  # first n columns (x) - left edge
        mask[:, :, -border:],  # last n columns (x) - right edge
    ]
    edge_labels = np.unique(np.concatenate([strip.ravel() for strip in border_strips]))
    # background (0) is not an object
    edge_labels = edge_labels[edge_labels > 0]

    # Zero all edge-touching labels in a single pass over the volume.
    mask[np.isin(mask, edge_labels)] = 0

    return mask

In [9]:
def centroid_within_bbox_detection(
    centroid: tuple,
    bbox: tuple,
) -> bool:
    """
    Test whether a centroid lies inside a bounding box (bounds inclusive)

    Parameters
    ----------
    centroid : tuple
        Centroid of the object, ordered (z, y, x)
    bbox : tuple
        Bounding box, ordered (z_min, y_min, x_min, z_max, y_max, x_max)

    Returns
    -------
    bool
        True if every centroid coordinate lies between the matching min and
        max of the bbox (both ends included), False otherwise
    """
    z_min, y_min, x_min, z_max, y_max, x_max = bbox
    z, y, x = centroid
    # Check one axis at a time and bail out on the first axis that fails.
    per_axis = (
        (z, z_min, z_max),
        (y, y_min, y_max),
        (x, x_min, x_max),
    )
    for coordinate, lower, upper in per_axis:
        if not (coordinate >= lower and coordinate <= upper):
            return False
    return True

In [10]:
def check_if_centroid_within_mask(
    centroid: tuple, mask: np.ndarray, label: int
) -> bool:
    """
    Check if the centroid falls on a pixel of the mask that carries ``label``

    Parameters
    ----------
    centroid : tuple
        Centroid of the object in the order of (z, y, x)
        Order of the centroid is important
        Each coordinate is rounded to the nearest integer index
    mask : np.ndarray
        The label mask to check against, indexed as mask[z, y, x]
    label : int
        The label the mask must have at the centroid position.
        Background (0) never counts as a match.

    Returns
    -------
    bool
        True if the mask value at the rounded centroid is positive and equal
        to ``label``. False otherwise, including when the rounded centroid
        lies outside the mask.
    """
    z, y, x = centroid
    voxel = tuple(int(np.round(coordinate)) for coordinate in (z, y, x))
    # A negative index would silently wrap around to the far side of the
    # volume and a too-large one would raise; both mean "not inside the mask".
    if any(index < 0 or index >= size for index, size in zip(voxel, mask.shape)):
        return False
    # check if the centroid is within the segmentation mask
    label_at_centroid = mask[voxel]
    return bool(label_at_centroid > 0 and label_at_centroid == label)

In [11]:
# nuclei_df = nuclei_df.head(10)
# cell_df = cell_df.head(10)

In [12]:
# Report how many nuclei and cell rows are in the measurement tables.
print(f"Number of nuclei: {len(nuclei_df)}\nNumber of cells: {len(cell_df)}\n")

Number of nuclei: 33
Number of cells: 35



In [13]:
# If the centroid of a nucleus lies inside a cell,
# then make the cell retain the label of that nucleus.
#
# The cell label under each nucleus centroid is read from the mask once,
# instead of re-reading the same voxel for every (nucleus, cell) pair.
nuclei_centroid_voxels = (
    nuclei_df[["centroid-0", "centroid-1", "centroid-2"]]
    .to_numpy(dtype=float)
    .round()  # nearest voxel index
    .astype(int)
)
cell_label_under_nucleus = cell_mask[
    nuclei_centroid_voxels[:, 0],
    nuclei_centroid_voxels[:, 1],
    nuclei_centroid_voxels[:, 2],
]

for nucleus_label, enclosing_cell_label in zip(
    nuclei_df["label"], cell_label_under_nucleus
):
    if enclosing_cell_label <= 0:
        # the nucleus centroid sits on background: no cell to relabel
        continue
    matching_cell_rows = cell_df.index[cell_df["label"] == enclosing_cell_label]
    if len(matching_cell_rows) > 0:
        # only the first matching row is updated
        cell_df.at[matching_cell_rows[0], "new_label"] = nucleus_label

In [14]:
# Preview the cell table after matching: new_label now holds the nucleus label
# for cells that were matched to a nucleus.
cell_df.head()

Unnamed: 0,centroid-0,centroid-1,centroid-2,bbox-0,bbox-1,bbox-2,bbox-3,bbox-4,bbox-5,compartment,label,new_label
0,11.33902,761.656976,418.747893,0,608,261,33,950,564,cell,15,14
1,10.864223,207.736419,724.041699,0,86,560,33,326,874,cell,26,24
2,3.422383,242.744098,509.650772,0,98,371,15,360,730,cell,37,34
3,13.153025,438.660058,704.646145,0,333,553,33,554,933,cell,43,39
4,4.745857,695.063774,388.185504,0,593,296,22,836,461,cell,51,46


In [15]:
# Preview the nuclei table for comparison with the cell table.
nuclei_df.head()

Unnamed: 0,centroid-0,centroid-1,centroid-2,bbox-0,bbox-1,bbox-2,bbox-3,bbox-4,bbox-5,compartment,label
0,10.934588,746.901225,473.777777,6,687,421,17,808,528,nuclei,14
1,7.888706,224.482524,715.338595,0,165,652,17,283,780,nuclei,24
2,4.147467,253.489444,503.487076,0,195,453,10,313,564,nuclei,34
3,19.420771,424.838544,693.588996,6,369,639,33,474,757,nuclei,39
4,5.122876,694.579159,399.909088,0,649,355,12,745,447,nuclei,46


In [16]:
# Compare the set of nuclei labels with the labels now assigned to the cells.
print(nuclei_df["label"].unique())
print(cell_df["new_label"].unique())

[ 14  24  34  39  46  52  62  63  71  78  96 102 114 117 130 141 147 149
 155 156 167 168 179 188 189 200 203 204 215 218 219 228 229]
[ 14  24  34  39  46  62  63  71  78 149  96  15 102 114 117 130 141 147
 165 155 156 167 168 193 179 188 189 200 203 204 215 218 219 228 229]


In [17]:
# Show every column when the wide merged table is displayed.
pd.options.display.max_columns = None

# Pair each nucleus with the cell whose new_label equals the nucleus label.
# Columns present in both tables get the "_nuclei" / "_cell" suffixes.
nuclei_and_cell_df = nuclei_df.merge(
    cell_df,
    left_on="label",
    right_on="new_label",
    suffixes=("_nuclei", "_cell"),
)
nuclei_and_cell_df

Unnamed: 0,centroid-0_nuclei,centroid-1_nuclei,centroid-2_nuclei,bbox-0_nuclei,bbox-1_nuclei,bbox-2_nuclei,bbox-3_nuclei,bbox-4_nuclei,bbox-5_nuclei,compartment_nuclei,label_nuclei,centroid-0_cell,centroid-1_cell,centroid-2_cell,bbox-0_cell,bbox-1_cell,bbox-2_cell,bbox-3_cell,bbox-4_cell,bbox-5_cell,compartment_cell,label_cell,new_label
0,10.934588,746.901225,473.777777,6,687,421,17,808,528,nuclei,14,11.33902,761.656976,418.747893,0,608,261,33,950,564,cell,15,14
1,7.888706,224.482524,715.338595,0,165,652,17,283,780,nuclei,24,10.864223,207.736419,724.041699,0,86,560,33,326,874,cell,26,24
2,4.147467,253.489444,503.487076,0,195,453,10,313,564,nuclei,34,3.422383,242.744098,509.650772,0,98,371,15,360,730,cell,37,34
3,19.420771,424.838544,693.588996,6,369,639,33,474,757,nuclei,39,13.153025,438.660058,704.646145,0,333,553,33,554,933,cell,43,39
4,5.122876,694.579159,399.909088,0,649,355,12,745,447,nuclei,46,4.745857,695.063774,388.185504,0,593,296,22,836,461,cell,51,46
5,5.289912,386.904649,742.61498,1,343,671,11,436,811,nuclei,62,5.600186,405.137394,778.324064,0,332,648,21,511,926,cell,69,62
6,15.276792,419.787447,913.823796,9,358,860,22,477,964,nuclei,63,18.250629,423.970203,915.309609,1,327,836,33,567,981,cell,70,63
7,4.353928,557.908332,506.448514,1,488,458,9,630,552,nuclei,71,4.393998,553.833433,527.829793,0,433,369,20,688,714,cell,79,71
8,3.934457,673.236804,1104.533547,1,605,1074,8,737,1136,nuclei,78,5.383401,662.726887,1095.956705,0,520,1045,21,779,1145,cell,86,78
9,1.992485,1430.440448,1239.7706,1,1354,1174,4,1508,1305,nuclei,96,1.349833,1423.265032,1231.128856,0,1308,1134,7,1509,1311,cell,106,96


In [18]:
# Summarise the table sizes; the last line is the number of rows in the merged
# nucleus/cell table.
print(
    f"Number of nuclei: {len(nuclei_df)}\n"
    f"Number of cells: {len(cell_df)}\n"
    f"Number of cells with nuclei: {len(nuclei_and_cell_df)}"
)

Number of nuclei: 33
Number of cells: 35
Number of cells with nuclei: 32


In [19]:
def mask_label_reassignment(
    mask_df: pd.DataFrame,
    mask_input: np.ndarray,
) -> np.ndarray:
    """
    Reassign the labels of the mask based on the mask_df

    Every pixel is relabelled according to the label it had when the function
    was called. Reassignments therefore never chain: relabelling A -> B does
    not expose A's pixels to a later B -> C row.

    Parameters
    ----------
    mask_df : pd.DataFrame
        DataFrame with a "label" column (current label in the mask) and a
        "new_label" column (label to assign instead)
    mask_input : np.ndarray
        The input mask to reassign the labels of. It is modified in place.

    Returns
    -------
    np.ndarray
        The same array object as ``mask_input``, with reassigned labels
    """
    # Snapshot of the labels before any change; all comparisons use it so an
    # earlier reassignment cannot be picked up by a later row.
    original_labels = mask_input.copy()
    already_reassigned = set()
    for old_label, new_label in zip(mask_df["label"], mask_df["new_label"]):
        if old_label == new_label:
            # if the label is already the new label, skip
            continue
        if old_label in already_reassigned:
            # the first row that changes a given label wins
            continue
        mask_input[original_labels == old_label] = new_label
        already_reassigned.add(old_label)
    return mask_input

In [20]:
# Drop every cell that touches the outer 10 pixels of the image in x-y.
cell_mask = remove_edge_cases(cell_mask, border=10)

In [21]:
# Apply the label -> new_label mapping from cell_df to the cell mask,
# then write the relabelled mask to disk.
cell_mask = mask_label_reassignment(cell_df, cell_mask)
tifffile.imwrite(cell_mask_output_path, cell_mask)