# Nuclear segmentation of nuclei

Scripts to detect nuclei from 5d images (t,c,z,y,x) with a spot channel and an optional nuclear marker channel.
Mostly based on scripts from Lucien Hinderling, with some modifications and cleanup by Jennifer Semple.

Nuclear segmentation carried out with Cellpose.

**Inputs**:

1) Pandas dataframe fileList.csv (or fileList_wormMasks.csv) with absolute paths to raw .nd2 files and denoised .tif files in columns named raw_filepath and denoised_filepath.
An 'id' column has a unique id for each image. Other metadata columns can also be present.

Example column names:
*filename	date	experiment	strain	protein	id  raw_filepath    denoised_filepath*

example line:
*20240915_1268_E_bean_15um	20240915	3d	1268	DPY27	DPY27_3d_20240915_1268_E_bean_15um	/mnt/external.data/MeisterLab/Dario/Imaging/DP...	/mnt/external.data/MeisterLab/Dario/Imaging/DP...*

2) output_path is the location of the file produced by this script

**Outputs**:

segmentation masks (.tif files) in output_path/segmentation/

distance masks (.tif files) in output_path/edt/

qc plots of segmentation on original image (segmentation_XXX.pdf), individual masked nuclei (cropped_nuclei_XXX.pdf) in output_path/qc/


### Setting you might need to change

output_path - create a directory for the analysis. results will be stored in a protein/strain/date structure same as in the raw_input_path.

path_type - “server”, “mac” or “wsl” so I can switch between working on server or with izbkingston mounted on mac/pc. (for PC izbkingston needs to be mounted with sshfs as /mnt/izbkingston/ from within WSL). Since this script works best with gpu, it will almost always be "server"

Set the channel number (as it is in the orignal image) for nuclear marker (nucChannel) and for spots (spotChannel). If there is no additional nuclear marker channel, set it to the same as the spotChannel

use_worm_masks - True if you presegmented particular worm regions which you later want to filter your nuclei by (this determines whether you use fileList.csv or fileList_wormMasks.csv)

model_path - path to pretrained cellpose model



In [None]:
import torch
from skimage.color import label2rgb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cellpose
from cellpose import models
import edt
import glob
import os
import tqdm
from matplotlib_scalebar.scalebar import ScaleBar
import gc
import seaborn as sns
from bioio import BioImage
import bioio_nd2
import bioio_tifffile
from bioio.writers import OmeTiffWriter
from convert_paths import correct_path, correct_save_path, correct_loaded_path


pd.set_option('display.max_columns', None)

## Input settings

In [2]:

nucChannel = 0 # red emerin rings
spotChannel = 0 # green spots
path_type="server"
use_worm_masks = True

#output_path = correct_path('/mnt/external.data/MeisterLab/jsemple/lhinder/segmentation_Kalyan/2025-25-02_bet1-mSG_wPM1353')
#output_path = correct_path('/mnt/external.data/MeisterLab/jsemple/lhinder/segmentation_Kalyan/2025-04-03_bet1-mSG_wPM1353')
#output_path = correct_path('/mnt/external.data/MeisterLab/jsemple/lhinder/segmentation_Kalyan/2025-10-05_bet1-mSG_wPM1353')
output_path = '/mnt/external.data/MeisterLab/jsemple/demo_VIBE/results/2025-10-05_bet1-mSG_wPM1353'
output_path = correct_path(output_path, path_type)

model_path=correct_path('/mnt/external.data/MeisterLab/jsemple/demo_VIBE/cellpose_model/worms_1000epochs_v0', path_type)


if use_worm_masks:
    df = pd.read_csv(os.path.join(output_path,'fileList_wormMasks.csv'))
    df = correct_loaded_path(df, path_type)
else:
    df = pd.read_csv(os.path.join(output_path,'fileList.csv'))
    df = correct_loaded_path(df, path_type)

df.head()


Unnamed: 0,filename,date,protein,strain,treatment,worm_id,id,raw_filepath,denoised_filepath,worm_region
0,2025_10_05_wPM1353_HS_001,2025-10-05,bet1-mSG,wPM1353,HS,1,bet1-mSG_2025-10-05_2025_10_05_wPM1353_HS_001,/mnt/external.data/MeisterLab/jsemple/demo_VIB...,/mnt/external.data/MeisterLab/jsemple/demo_VIB...,head;body_other
1,2025_10_05_wPM1353_nHS_001,2025-10-05,bet1-mSG,wPM1353,nHS,1,bet1-mSG_2025-10-05_2025_10_05_wPM1353_nHS_001,/mnt/external.data/MeisterLab/jsemple/demo_VIB...,/mnt/external.data/MeisterLab/jsemple/demo_VIB...,head


## Create output directories and load cellpose model

In [3]:
if not os.path.exists(os.path.join(output_path,"qc")):
    os.makedirs(os.path.join(output_path,"qc"))

if not os.path.exists(os.path.join(output_path,"segmentation")):
    os.makedirs(os.path.join(output_path,"segmentation"))

if not os.path.exists(os.path.join(output_path,"edt")):
    os.makedirs(os.path.join(output_path,"edt"))


In [4]:
torch.cuda.device(0)

if torch.cuda.is_available():
    print("GPU is available")
    model = models.CellposeModel(pretrained_model=model_path, gpu=True, device =torch.device('cuda:0'))
else:
    print("Only CPU is available")
    model = models.CellposeModel(pretrained_model=model_path, gpu=False)


GPU is available


  state_dict = torch.load(filename, map_location=device)


## Functions for nuclear segmentation and qc

In [None]:
# Disable do_3D, there is a bug. 2D and stitching with overlap works much better.
# Takes around 7min for the whole image on the macbook
def segment_nuclei(img, model):
    ''' use pytorch cellpose model to segment nuclei'''
    masks,flows,styles = model.eval(img,do_3D=False,stitch_threshold=0.3,cellprob_threshold =0,diameter =36)
    return masks,flows,styles


def calc_distance_mask(masks,anisotropy):
    '''Calculate the distance map from the nuclei-edge towards the center of nucleus'''
    masks_edt = edt.edt(masks,anisotropy = anisotropy)
    return masks_edt



def plot_single_nucleus_crop(df, index, df_region_props, nuc_index, img):
    '''Plot a cropped region of a particular nucleus'''
    if spotChannel != nucChannel:
        fig, axs = plt.subplots(nrows = 1, ncols = 2, figsize = (3,1.5),dpi = 250, sharey=True)
    else:
        fig, axs = plt.subplots(nrows = 1, ncols = 1, figsize = (1.5,1.5),dpi = 250, sharey=True)
    fig.suptitle(f'{df.id.iloc[index]}', fontsize=6)

    intensity_image = df_region_props['intensity_image'][nuc_index][:,:,:,spotChannel] #show first spot channel
    image = df_region_props['image'][nuc_index]
    mx = np.ma.masked_array(intensity_image, mask = ~image)
    z_height = image.shape[0] 
    axs[0].imshow(mx[int(z_height/2)])
    axs[0].spines['top'].set_visible(False)
    axs[0].spines['right'].set_visible(False)
    axs[0].spines['bottom'].set_visible(False)
    axs[0].spines['left'].set_visible(False)
    axs[0].get_xaxis().set_ticks([])
    axs[0].get_yaxis().set_ticks([])

    if spotChannel != nucChannel:
        intensity_image = df_region_props['intensity_image'][nuc_index][:,:,:,nucChannel] #show second nuclear channel
        image = df_region_props['image'][nuc_index]
        mx = np.ma.masked_array(intensity_image, mask = ~image)
        z_height = image.shape[0]
        axs[1].imshow(mx[int(z_height/2)])
        axs[1].spines['top'].set_visible(False)
        axs[1].spines['right'].set_visible(False)
        axs[1].spines['bottom'].set_visible(False)
        axs[1].spines['left'].set_visible(False)
        axs[1].get_xaxis().set_ticks([])
        axs[1].get_yaxis().set_ticks([])


    scalebar = ScaleBar(0.065, "um", length_fraction=1, box_alpha=0.7,color='black',location='lower right',height_fraction = 0.05,border_pad =-1)
    if spotChannel != nucChannel:
        axs[1].add_artist(scalebar)
    else:
        axs[0].add_artist(scalebar) 

    plt.show()


def plot_qc_segmentation_xyz(img, masks, index, df, t=0, display_plot=False, plotContours=False):
    '''Plot a 2x3 grid of xy, xz, yz slices of the image and the corresponding segmentation'''
    nucChannel = 0
    num_z=img.shape[1]
    num_y=img.shape[2]
    num_x=img.shape[3]
    nlabel=100

    fig = plt.figure(layout='constrained',dpi=450,figsize = (10,10))
    fig.suptitle(f'Segmentation for {df.id.iloc[index]}', fontsize=10)
    subfigs = fig.subfigures(2, 1, wspace=0.1)

    axsTop = subfigs[0].subplots(2, 3,sharex=True, sharey=True)
    #xy
    axsTop[0,0].imshow(label2rgb(masks[int(num_z*0.3),:,:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsTop[1,0].set_title('z='+str(int(num_z*0.3)), fontsize=8)
    axsTop[0,1].imshow(label2rgb(masks[int(num_z*0.5),:,:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsTop[1,1].set_title('z='+str(int(num_z*0.5)), fontsize=8)
    axsTop[0,2].imshow(label2rgb(masks[int(num_z*0.7),:,:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsTop[1,2].set_title('z='+str(int(num_z*0.7)), fontsize=8)

    axsTop[1,0].imshow(img[nucChannel,int(num_z*0.3),:,:],cmap = 'gray_r')
    axsTop[1,1].imshow(img[nucChannel,int(num_z*0.5),:,:],cmap = 'gray_r')
    axsTop[1,2].imshow(img[nucChannel,int(num_z*0.7),:,:],cmap = 'gray_r')

    if plotContours:
        axsTop[1,0].contour(masks[int(num_z*0.3),:,:], [0.5], linewidths=0.5, colors='r')
        axsTop[1,1].contour(masks[int(num_z*0.5),:,:], [0.5], linewidths=0.5, colors='r')
        axsTop[1,2].contour(masks[int(num_z*0.7),:,:], [0.5], linewidths=0.5, colors='r')


    for axss in axsTop:
        for ax in axss:
            #ax.set_xlim(0,num_x)
            #ax.set_ylim(0,num_y)
            ax.set_xticks([])
            ax.set_yticks([])

    axsBottom = subfigs[1].subplots(4, 3,sharex=True,sharey=True)
    #xz
    axsBottom[0,0].imshow(label2rgb(masks[:,int(num_y*0.3),:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[1,0].set_title('y='+str(int(num_y*0.3)), fontsize=8)
    axsBottom[0,1].imshow(label2rgb(masks[:,int(num_y*0.5),:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[1,1].set_title('y='+str(int(num_y*0.5)), fontsize=8)
    axsBottom[0,2].imshow(label2rgb(masks[:,int(num_y*0.7),:],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[1,2].set_title('y='+str(int(num_y*0.7)), fontsize=8)

    axsBottom[1,0].imshow(img[nucChannel,:,int(num_y*0.3),:],cmap = 'gray_r')
    axsBottom[1,1].imshow(img[nucChannel,:,int(num_y*0.5),:],cmap = 'gray_r')
    axsBottom[1,2].imshow(img[nucChannel,:,int(num_y*0.7),:],cmap = 'gray_r')

    if plotContours:
        axsBottom[1,0].contour(masks[:,int(num_y*0.3),:], [0.5], linewidths=0.5, colors='r')
        axsBottom[1,1].contour(masks[:,int(num_y*0.5),:], [0.5], linewidths=0.5, colors='r')
        axsBottom[1,2].contour(masks[:,int(num_y*0.7),:], [0.5], linewidths=0.5, colors='r')


    #yz
    axsBottom[2,0].imshow(label2rgb(masks[:,:,int(num_x*0.3)],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[3,0].set_title('x='+str(int(num_x*0.3)), fontsize=8)
    axsBottom[2,1].imshow(label2rgb(masks[:,:,int(num_x*0.5)],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[3,1].set_title('x='+str(int(num_x*0.5)), fontsize=8)
    axsBottom[2,2].imshow(label2rgb(masks[:,:,int(num_x*0.7)],bg_label=0,bg_color=(255, 255, 255),colors=np.random.random((nlabel, 3))))
    axsBottom[3,2].set_title('x='+str(int(num_x*0.7)), fontsize=8)

    axsBottom[3,0].imshow(img[nucChannel,:,:,int(num_x*0.3)],cmap = 'gray_r')
    axsBottom[3,1].imshow(img[nucChannel,:,:,int(num_x*0.5)],cmap = 'gray_r')
    axsBottom[3,2].imshow(img[nucChannel,:,:,int(num_x*0.7)],cmap = 'gray_r')

    if plotContours:
        axsBottom[3,0].contour(masks[:,:,int(num_x*0.3)], [0.5], linewidths=0.5, colors='r')
        axsBottom[3,1].contour(masks[:,:,int(num_x*0.5)], [0.5], linewidths=0.5, colors='r')
        axsBottom[3,2].contour(masks[:,:,int(num_x*0.7)], [0.5], linewidths=0.5, colors='r')

    for axss in axsBottom:
        for ax in axss:
            #ax.set_ylim(0,num_z)
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    if display_plot:
        plt.show()
    else:
        fig.savefig(os.path.join(output_path,'qc','segmentation_'+df.id.iloc[index]+'_t'+'{:02d}'.format(t)+'.png'))
        plt.close()

In [6]:
## Run the segmentation script on all images (reserve more than 24GB!)
# this produces segmentation, segmentation_qc and edt files
def run_nuclear_segmentation(indices, df, rerun=False, use_denoised=True):
    '''Run the segmentation on all images in the dataframe'''
    for index in tqdm.tqdm(indices):
        if rerun or not os.path.exists(os.path.join(output_path,'edt',df.id.iloc[index]+'_t0.tif')):
            # get anisotropy from raw image metadata
            img_5d = BioImage(df.raw_filepath.iloc[index], reader=bioio_nd2.Reader)

            ZvX = np.round(img_5d.physical_pixel_sizes.Z/img_5d.physical_pixel_sizes.X,0)
            anisotropy = (ZvX,1,1)
            # Load the denoised data
            if use_denoised:
                img_5d = BioImage(df.denoised_filepath.iloc[index], reader=bioio_tifffile.Reader)
            for t in range(img_5d.dims.T):
                img = img_5d.get_image_data("CZYX", T=t)

                # Segment nuclei 
                masks,flows,styles = segment_nuclei(img[nucChannel,:,:,:],model) # Run the segmentation
                plot_qc_segmentation_xyz(img,masks,index, df, t, display_plot = False)                         # Create qc plot
                OmeTiffWriter.save(masks, os.path.join(output_path,'segmentation',df.id.iloc[index]+'_t'+'{:02d}'.format(t)+'.tif'))

                del flows
                del styles
                gc.collect()
                
                # Calculate edt 
                masks_edt = calc_distance_mask(masks,anisotropy)
                OmeTiffWriter.save(masks_edt, os.path.join(output_path,'edt',df.id.iloc[index]+'_t'+'{:02d}'.format(t)+'.tif'))

                del masks
                del masks_edt
                gc.collect()
                continue


## Running the analysis for nuclear segmentation

In [7]:
# run analysis to segment nuclei
indices=range(0,len(df))

#indices=[7,50]

run_nuclear_segmentation(indices, df, rerun=True, use_denoised=True) 



100%|██████████| 80/80 [00:04<00:00, 19.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0030314768628217914..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.003044939462909957..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0022763987603476865..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [255.0..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [255.0..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.004826562097455023..255.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] fo