This notebook explores how to properly segment cells within organoids.
The end goal is to segment cells and extract morphology features with CellProfiler.
The resulting masks must be imported into CellProfiler to extract those features.

In [1]:
import argparse
import pathlib

import matplotlib.pyplot as plt

# Import dependencies
import numpy as np
import skimage
import tifffile
from cellpose import models
from PIL import Image
from stardist.plot import render_label

# Detect whether this code runs inside an IPython/Jupyter kernel: the
# `get_ipython` builtin only exists there, so a plain `python` run hits NameError.
try:
    cfg = get_ipython().config
except NameError:
    in_notebook = False
else:
    in_notebook = True

print(in_notebook)

True


In [None]:
if not in_notebook:
    # Script mode: read the input directory and CLAHE clip limit from the CLI.
    parser = argparse.ArgumentParser(description="Segment the cells of a tiff image")

    parser.add_argument(
        "--input_dir",
        type=str,
        # Required: without it pathlib.Path(None) would raise a TypeError below.
        required=True,
        help="Path to the input directory containing the tiff images",
    )

    parser.add_argument(
        "--clip_limit",
        type=float,
        # Same value the interactive branch uses; previously omitting the flag
        # passed None to equalize_adapthist and crashed.
        default=0.4,
        help="Clip limit for the adaptive histogram equalization (default: 0.4)",
    )

    args = parser.parse_args()
    clip_limit = args.clip_limit
    input_dir = pathlib.Path(args.input_dir).resolve(strict=True)

else:
    # Interactive mode: fixed example well/field and clip limit.
    input_dir = pathlib.Path(
        "../../2.illumination_correction/illum_directory/W0052_F0001"
    ).resolve(strict=True)
    clip_limit = 0.4

## Set up images and paths

In [3]:
# Collect the TIFF images in the input directory as path strings, in sorted
# order so the per-channel lists built from `files` follow a deterministic order.
# The suffix match is case-insensitive so ".TIF"/".TIFF" files are not skipped.
image_extensions = {".tif", ".tiff"}
files = [
    str(path)
    for path in sorted(input_dir.glob("*"))
    if path.suffix.lower() in image_extensions
]
# Fail early with a clear message instead of a cryptic error during stacking.
if not files:
    raise FileNotFoundError(f"No .tif/.tiff images found in {input_dir}")

In [4]:
# Empty per-channel accumulators: the nuclei file paths plus the loaded
# nuclei and cytoplasm (C1-C3) images, each key with its own fresh list.
image_dict = {
    key: []
    for key in (
        "nuclei_file_paths",
        "nuclei",
        "cytoplasm1",
        "cytoplasm2",
        "cytoplasm3",
    )
}

In [5]:
# Split the files by channel, matching on the file name only (portable across
# OS path separators): C4 is the nuclei channel, C1-C3 the cytoplasm channels.
for file in files:
    file_name = pathlib.Path(file).name
    if "C4" in file_name:
        image_dict["nuclei_file_paths"].append(file)
        image_dict["nuclei"].append(tifffile.imread(file).astype(np.float32))
    elif "C1" in file_name:
        image_dict["cytoplasm1"].append(tifffile.imread(file).astype(np.float32))
    elif "C2" in file_name:
        image_dict["cytoplasm2"].append(tifffile.imread(file).astype(np.float32))
    elif "C3" in file_name:
        image_dict["cytoplasm3"].append(tifffile.imread(file).astype(np.float32))

# The channels are paired slice-by-slice with zip below, which silently drops
# unmatched images and would misalign z-slices; fail loudly instead.
channel_counts = {
    channel: len(image_dict[channel])
    for channel in ("nuclei", "cytoplasm1", "cytoplasm2", "cytoplasm3")
}
if len(set(channel_counts.values())) != 1 or channel_counts["nuclei"] == 0:
    raise ValueError(
        f"Expected the same, non-zero number of images per channel in {input_dir}, "
        f"got {channel_counts}"
    )

# Merge the three cytoplasm channels into one image per slice via the
# per-pixel maximum across channels.
cytoplasm_image_list = [
    np.max(np.array([cytoplasm1, cytoplasm2, cytoplasm3]), axis=0)
    for cytoplasm1, cytoplasm2, cytoplasm3 in zip(
        image_dict["cytoplasm1"], image_dict["cytoplasm2"], image_dict["cytoplasm3"]
    )
]
nuclei_image_list = [np.array(nuclei_slice) for nuclei_slice in image_dict["nuclei"]]

# Stack slices into one array per channel (assumes each file is a single 2D
# slice of identical shape). astype(np.int32) truncates the float32 values.
cyto = np.array(cytoplasm_image_list).astype(np.int32)
nuclei = np.array(nuclei_image_list).astype(np.int32)

# Contrast-limited adaptive histogram equalization (CLAHE) on each stack.
cyto = skimage.exposure.equalize_adapthist(cyto, clip_limit=clip_limit)
nuclei = skimage.exposure.equalize_adapthist(nuclei, clip_limit=clip_limit)

print(cyto.shape, nuclei.shape)

dtype.py (527): Downcasting int32 to uint16 without scaling because max value 57587 fits in uint16
dtype.py (527): Downcasting int32 to uint16 without scaling because max value 64547 fits in uint16


(18, 2000, 2000) (18, 2000, 2000)


In [6]:
# Keep independent copies of the contrast-equalized stacks so later
# rebinding or in-place edits of `cyto`/`nuclei` cannot alter them.
original_cyto_image = cyto.copy()
original_nuclei_image = nuclei.copy()

In [7]:
def _scale_to_uint8(plane: np.ndarray) -> np.ndarray:
    """Linearly rescale a 2D plane so its maximum maps to 255, as uint8.

    An all-zero (or non-positive) plane returns zeros instead of the NaN
    garbage that dividing by a zero maximum would produce.
    """
    peak = plane.max()
    if peak <= 0:
        return np.zeros(plane.shape, dtype=np.uint8)
    return (plane / peak * 255).astype(np.uint8)


# Build one 8-bit RGB composite per z-slice (cytoplasm in red, nuclei in blue,
# green empty) as input for Cellpose. The images are kept in memory only.
imgs = []
for z in range(nuclei.shape[0]):
    nuclei_8bit = _scale_to_uint8(nuclei[z, :, :])
    cyto_8bit = _scale_to_uint8(cyto[z, :, :])
    # Channels are already uint8 in [0, 255]; no second rescale is needed
    # (the float round-trip of one could truncate some values down by 1).
    rgb = np.stack([cyto_8bit, np.zeros_like(cyto_8bit), nuclei_8bit], axis=-1)
    imgs.append(Image.fromarray(rgb))

## Cellpose

In [None]:
# Segment whole cells on each RGB z-slice with a pretrained Cellpose model.
# model_type can be 'cyto', 'nuclei', 'cyto2' or 'cyto3'.
model_name = "cyto3"
model = models.Cellpose(model_type=model_name, gpu=True)

# channels=[cytoplasm channel, nuclear channel]: red cells, blue nuclei.
# NOTE(review): a nested [[1, 3]] is passed for each single image; Cellpose
# documents a length-2 list per image — confirm this is interpreted as intended.
channels = [[1, 3]]

masks_all_dict = {"masks": [], "imgs": []}
imgs = np.array(imgs)

# Segment each slice independently and keep masks and inputs for later use.
for img in imgs:
    masks, _flows, _styles, _diams = model.eval(img, channels=channels)

    masks_all_dict["masks"].append(masks)
    masks_all_dict["imgs"].append(img)

# Number of segmented slices (previously printed the number of dict keys).
print(len(masks_all_dict["masks"]))
masks_all = masks_all_dict["masks"]
imgs = masks_all_dict["imgs"]

resnet_torch.py (271): You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


In [None]:
# Write the Cellpose mask of each z-slice next to the input images, named
# after that slice's nuclei (C4) file, so CellProfiler can import them.
# (Previously the contrast-equalized nuclei image was written instead.)
if len(masks_all) != len(image_dict["nuclei_file_paths"]):
    raise ValueError(
        f"Got {len(masks_all)} masks for "
        f"{len(image_dict['nuclei_file_paths'])} nuclei files"
    )
for frame_index, nuclei_file_path in enumerate(image_dict["nuclei_file_paths"]):
    file_prefix = pathlib.Path(nuclei_file_path).name.split("_C4")[0]
    tifffile.imwrite(
        input_dir / f"{file_prefix}_cell_mask.tiff",
        masks_all[frame_index],
    )

In [None]:
# In a notebook, show each z-slice's inputs next to its Cellpose cell masks.
if in_notebook:
    for z in range(len(masks_all)):
        # One figure per slice; suptitle avoids the stray full-figure axes that
        # plt.title() created before the subplots (and plt.axi was a typo).
        fig, axes = plt.subplots(1, 4, figsize=(30, 10))
        fig.suptitle(f"z: {z}")

        panels = [
            (nuclei[z], "Nuclei", "gray"),
            (cyto[z], "Cytoplasm", "gray"),
            (imgs[z], "Red: Cytoplasm, Blue: Nuclei", "gray"),
            (render_label(masks_all[z]), "Cell masks", None),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        plt.show()
        # Release the figure so many slices do not accumulate in memory.
        plt.close(fig)