# In [None]:
import os
import glob
import pandas as pd
from aicsimageio import AICSImage
from cellpose import models
import pyclesperanto_prototype as cle
import numpy as np
from skimage.filters import threshold_otsu, gaussian
from skimage.segmentation import watershed
from skimage.morphology import disk, erosion, remove_small_objects
from skimage.measure import label
from scipy.ndimage import distance_transform_edt, find_objects
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from scipy.ndimage import find_objects  # corrected import for find_objects
from regionpropsExtension import RegionPropertiesExtension, TEXTURE_FEATURE_NAMES
from numpy import linalg as LA
from scipy import ndimage as ndi
from scipy.stats import median_abs_deviation
from skimage.measure import _moments, find_contours
from skimage.measure._regionprops import RegionProperties, _cached, only2d
from skimage.feature import graycomatrix
from skimage.segmentation import find_boundaries
from pyefd import elliptic_fourier_descriptors
from functools import cached_property
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ttest_ind


# Naming scheme for the texture (GLCM) feature columns.
# NOTE: ``TEXTURE_FEATURE_NAMES`` defined here replaces the name of the same
# spelling imported from ``regionpropsExtension`` further up in this file.

# Percentiles reported for every texture category.
TEXTURE_PERCENTILES = (25, 50, 75)

# Texture categories, in the order in which their statistics are emitted.
TEXTURE_CATEGORIES = [
    "Contrast",
    "Dissimilarity",
    "Homogeneity",
    "Energy",
    "Correlation",
]

# Per-category summary statistics: the percentiles first, then mean/std/mad.
TEXTURE_SUMM_STATISTICS = [
    *(f"{percentile}%" for percentile in TEXTURE_PERCENTILES),
    "mean",
    "std",
    "mad",
]

# Flat "<category>_<statistic>" names, category-major.
TEXTURE_FEATURE_NAMES = [
    f"{category}_{statistic}"
    for category in TEXTURE_CATEGORIES
    for statistic in TEXTURE_SUMM_STATISTICS
]


def safe_log_10_v0(value):
    """Return ``log10(|value|)`` element-wise, mapping tiny magnitudes to -12.

    Magnitudes that are not greater than ``1e-12`` (exact zeros and NaN
    included) are replaced by the floor value ``-12`` instead of being passed
    to the logarithm. This avoids the "divide by zero encountered in log"
    warning discussed at
    https://stackoverflow.com/questions/21610198/runtimewarning-divide-by-zero-encountered-in-log

    Args:
        value: Scalar or array-like of real numbers. The sign is discarded.

    Returns:
        ``numpy.ndarray`` (0-d for scalar input) holding ``log10(|value|)``
        where ``|value| > 1e-12`` and ``-12`` everywhere else.
    """
    magnitude = np.abs(np.asarray(value))
    # The logarithm is written back into the working array, so that array has
    # to be floating point. Integer/bool input would otherwise make
    # ``np.log10(..., out=...)`` fail with a casting error.
    if not np.issubdtype(magnitude.dtype, np.floating):
        magnitude = magnitude.astype(np.float64)
    result = np.where(magnitude > 1e-12, magnitude, -12)
    # Only entries that still hold a real magnitude are transformed; the -12
    # placeholders are left untouched by the ``where`` mask.
    np.log10(result, out=result, where=result > 1e-12)
    return result


def safe_log_10_v1(value):
    """Return ``-log(1 + |value|)`` element-wise (variant credited to Pankaj).

    Despite the function name, the natural logarithm (``np.log``) is applied,
    not a base-10 logarithm. Because the argument of the logarithm is always
    at least 1, the result is never positive and no zero/negative input can
    reach the logarithm.

    Args:
        value: Scalar or array-like of real numbers.

    Returns:
        Scalar or array of the same shape as ``value``.
    """
    magnitude = np.abs(value)
    shifted = 1 + magnitude
    return -np.log(shifted)


def regionprops(w0_mask, w1_mask, w2_mask, w4_mask, img, n_levels):
    """Build ``RegionPropertiesExtension`` objects for every labelled object.

    Labels are taken from ``w0_mask``; the same label value is looked up in
    the other masks. For each label present in ``w0_mask`` five property
    objects are created, one per entry of ``img``:

    * row 0: ``w0_mask`` measured on ``img[0]``
    * row 1: ``w1_mask`` measured on ``img[1]``
    * row 2: ``w2_mask`` measured on ``img[2]`` (``None`` when the label does
      not occur in ``w2_mask``)
    * row 3: ``w1_mask`` measured on ``img[3]`` (there is no separate mask
      argument for this channel)
    * row 4: ``w4_mask`` measured on ``img[4]``

    Args:
        w0_mask, w1_mask, w2_mask, w4_mask: Integer label images that share
            label values.
        img: Indexable of at least five intensity images (``img[0]`` ..
            ``img[4]``).
        n_levels: Kept for backward compatibility; not used. The number of
            grey levels is the ``RegionPropertiesExtension.n_levels`` class
            attribute.

    Returns:
        Tuple ``(regions, has_nucleoli)`` where ``regions`` is an object
        array of shape ``(5, N)`` and ``has_nucleoli`` is a ``uint8`` array of
        shape ``(N, 1)`` that is 1 where row 2 holds a property object.
        ``N`` is the number of labels found in ``w0_mask``.
    """
    max_ = np.amax(w0_mask)
    w0_objects = ndi.find_objects(w0_mask, max_label=max_)
    w1_objects = ndi.find_objects(w1_mask, max_label=max_)
    w2_objects = ndi.find_objects(w2_mask, max_label=max_)
    w4_objects = ndi.find_objects(w4_mask, max_label=max_)

    # Size the outputs from the labels that are really present, so the count
    # is right whether or not the mask contains a background (0) pixel.
    N = sum(obj is not None for obj in w0_objects)
    regions = np.zeros((5, N), dtype=object)
    has_nucleoli = np.zeros((N, 1), dtype=np.uint8)

    cnt = 0
    for ii in range(max_):
        if w0_objects[ii] is None:
            continue
        label_id = ii + 1
        # ``channel_name`` is a required constructor argument of
        # RegionPropertiesExtension; name each object after its channel.
        w0_props = RegionPropertiesExtension(w0_objects[ii], label_id, w0_mask, img[0], "w0")
        w1_props = RegionPropertiesExtension(w1_objects[ii], label_id, w1_mask, img[1], "w1")
        # Channel 3 is measured over the w1 mask.
        w3_props = RegionPropertiesExtension(w1_objects[ii], label_id, w1_mask, img[3], "w3")
        w4_props = RegionPropertiesExtension(w4_objects[ii], label_id, w4_mask, img[4], "w4")
        if w2_objects[ii] is not None:
            w2_props = RegionPropertiesExtension(w2_objects[ii], label_id, w2_mask, img[2], "w2")
            has_nucleoli[cnt] = 1
        else:
            w2_props = None
            has_nucleoli[cnt] = 0

        regions[0, cnt] = w0_props
        regions[1, cnt] = w1_props
        regions[2, cnt] = w2_props
        regions[3, cnt] = w3_props
        regions[4, cnt] = w4_props

        cnt += 1

    return regions, has_nucleoli


class RegionPropertiesExtension(RegionProperties):
    """``RegionProperties`` subclass adding intensity, texture and shape features.

    Please refer to `skimage.measure.regionprops` for more information
    on the available region properties.

    Additions defined here: log-scaled Hu moments, intensity summary
    statistics over positive pixels, grey-level co-occurrence (GLCM) texture
    summaries, an elliptic-Fourier-coefficient ratio and circularity.
    """
    # Class-level dimensionality constant.
    ndim = 2
    # Width (in pixels) of the zero border added on every side in `efc_ratio`.
    bd_val = 10
    bd_padding = [(bd_val, bd_val), (bd_val, bd_val)]

    # Number of grey levels used when quantising intensities for the GLCM.
    n_levels = 8
    # Minimum number of positive pixels required by `intensity_statistics`.
    n_pos_pixels_lb = 10
    # Standard deviations below this value are treated as zero in the GLCM
    # correlation.
    corr_tolerance = 1e-8
    # Pixel offsets (1..20) and directions (0 and pi/2 radians) of the GLCM.
    distances = np.arange(1, 21)
    angles = np.array([0, np.pi/2])
    # Labels for `angles`; not referenced by the code in this class.
    angles_str = [0, "pi/2"]
    # Percentiles reported by `intensity_statistics`.
    intensity_percentiles = (10, 25, 75, 90)

    def __init__(self, slice, label, label_image, intensity_image, channel_name,
                 cache_active=True, ):
        """Create the property object for one labelled region.

        Args:
            slice, label, label_image, intensity_image, cache_active:
                Forwarded unchanged to ``RegionProperties.__init__``.
            channel_name: Stored as ``self.channel_name``.
        """
        super().__init__(slice, label, label_image, intensity_image, cache_active)

        
        self.channel_name = channel_name
        # Index grids consumed by the weight properties and `glcm_features`.
        self.I, self.J = self.haralick_ij()
        self.corr_I, self.corr_J = self.haralick_corr_ij()



    @property
    @_cached
    def einsum_instruct_2(self):
        """Einsum subscript string for a two-operand reduction over axes i, j.

        Not referenced by the code in this class.
        """
        return "ijkm,ijkm->km"

    @property
    @_cached
    def einsum_instruct_3(self):
        """Einsum subscript string for a three-operand reduction over axes i, j.

        Not referenced by the code in this class.
        """
        return "ijkm,ijkm,ijkm->km"

    @_cached
    def haralick_ij(self):
        """Return open index grids ``I`` (column) and ``J`` (row) over the grey levels."""
        # np.ogrid yields I of shape (n_levels, 1) and J of shape (1, n_levels).
        I, J = np.ogrid[0:self.n_levels, 0:self.n_levels]
        return I, J

    @_cached
    def haralick_corr_ij(self):
        """Return grey-level index arrays shaped to broadcast against the 4-D GLCM.

        ``I`` has shape ``(n_levels, 1, 1, 1)`` and ``J`` has shape
        ``(1, n_levels, 1, 1)``.
        """
        I = np.array(range(0, self.n_levels)).reshape((self.n_levels, 1, 1, 1))
        J = np.array(range(0, self.n_levels)).reshape((1, self.n_levels, 1, 1))
        return I, J

    @property
    @_cached
    def weights0(self):
        """Squared level difference ``(i - j)**2``; weights the GLCM for contrast."""
        weights0 = (self.I - self.J) ** 2
        # Trailing singleton axes broadcast over (distance, angle).
        weights0 = weights0.reshape((self.n_levels, self.n_levels, 1, 1))
        return weights0


    @property
    @_cached
    def weights1(self):
        """Absolute level difference ``|i - j|``; weights the GLCM for dissimilarity."""
        weights1 = np.abs(self.I - self.J)
        weights1 = weights1.reshape((self.n_levels, self.n_levels, 1, 1))
        return weights1

    @property
    @_cached
    def weights2(self):
        """``1 / (1 + (i - j)**2)``; weights the GLCM for homogeneity."""
        return 1. / (1. + self.weights0)

    @property
    @_cached
    def bins(self):
        """``n_levels`` evenly spaced values from ``intensity_min`` to ``intensity_max``."""
        return np.linspace(self.intensity_min, self.intensity_max, self.n_levels)


    @property
    @_cached
    def image_intensity_discrete(self):
        """``image_intensity`` quantised with ``np.digitize`` against `bins`, as int32."""
        # right=True: a value equal to a bin edge is assigned to the lower bin.
        return np.int32(np.digitize(self.image_intensity, self.bins, right=True))

    @property
    @_cached
    def image_int32(self):
        """``self.image`` cast to int32."""
        return np.int32(self.image)

    @property
    @only2d
    @_cached
    def moments_hu(self):
        """Hu moments of ``moments_normalized``, sign-preserving log10-scaled.

        Each moment ``m`` is returned as ``-sign(m) * safe_log_10_v0(m)``.
        """
        mh = _moments.moments_hu(self.moments_normalized)
        return -1 * np.sign(mh) * safe_log_10_v0(mh)

    @property
    @only2d
    @_cached
    def moments_weighted_hu(self):
        """Hu moments of ``moments_weighted_normalized``, sign-preserving log10-scaled."""
        # NOTE(review): `moments_weighted_normalized` is overridden in this
        # class and already returns log-scaled values, so the Hu moments here
        # are computed from log-scaled input — confirm this is intended.
        mhw = _moments.moments_hu(self.moments_weighted_normalized)
        return -1 * np.sign(mhw) * safe_log_10_v0(mhw)

    @property
    @_cached
    def moments_weighted_normalized(self):
        """Normalised weighted central moments (order 3), sign-preserving log10-scaled.

        Each moment ``m`` is returned as ``-sign(m) * safe_log_10_v0(m)``.
        """
        mwn = _moments.moments_normalized(self.moments_weighted_central, order=3)
       
        return -1 * np.sign(mwn) * safe_log_10_v0(mwn)

    @property
    @_cached
    def image_intensity_vec(self):
        """1-D array of the strictly positive values of ``image_intensity``."""
        return self.image_intensity[self.image_intensity > 0]

    @property
    @_cached
    def intensity_statistics(self, ):
        """Summary statistics of the positive intensities in the region.

        Returns:
            Tuple of the `intensity_percentiles` values followed by
            ``(median, mad, mean, std)``. A tuple of zeros of the same length
            is returned when fewer than `n_pos_pixels_lb` positive pixels
            exist.
        """
        if len(self.image_intensity_vec) < self.n_pos_pixels_lb:
            # Too few pixels for meaningful statistics.
            return (0, )*(len(self.intensity_percentiles)+4)
        percentiles = np.nanpercentile(self.image_intensity_vec, self.intensity_percentiles)
        intensity_median, intensity_mad, intensity_mean, intensity_std = \
            np.nanmedian(self.image_intensity_vec), median_abs_deviation(self.image_intensity_vec), \
            np.nanmean(self.image_intensity_vec), np.nanstd(self.image_intensity_vec)
        return tuple(percentiles) + (intensity_median, intensity_mad, intensity_mean, intensity_std,)

    @cached_property
    def voxel_coordinates(self):
        """Coordinates of the non-zero entries of ``self.image``, one row per axis."""
        return np.array(np.where(self.image))

    @property
    @_cached
    def glcm(self, ):  
        """Normalised grey-level co-occurrence matrices of the quantised region image.

        Returns:
            float32 array of shape ``(n_levels, n_levels, len(distances),
            len(angles))``; each (distance, angle) matrix is divided by its
            own sum.
        """
        P = graycomatrix(
            self.image_intensity_discrete,
            distances=self.distances,
            angles=self.angles,
            levels=self.n_levels,
            symmetric=False, normed=False)

       
        P = P.astype(np.float32)
        glcm_sums = np.sum(P, axis=(0, 1), keepdims=True)
        # Empty matrices keep their zeros instead of dividing by zero.
        glcm_sums[glcm_sums == 0] = 1
        P /= glcm_sums
        return P

    @cached_property
    def glcm_features(self,):
        """Texture summaries derived from `glcm`.

        Contrast, dissimilarity, homogeneity, energy and correlation are
        computed for every (distance, angle) pair; each of the five is then
        summarised by the ``TEXTURE_PERCENTILES`` percentiles, mean, standard
        deviation and median absolute deviation over all pairs.

        Returns:
            Flat tuple ordered contrast, dissimilarity, homogeneity, energy,
            correlation, with the summary statistics in the order above
            within each group.
        """
        (num_level, num_level2, num_dist, num_angle) = self.glcm.shape


        # Weighted sums over the two grey-level axes.
        contrast = np.sum(self.glcm * self.weights0, axis=(0, 1))
        dissimilarity = np.sum(self.glcm * self.weights1, axis=(0, 1))
        homogeneity = np.sum(self.glcm * self.weights2, axis=(0, 1))
     
        # Frobenius norm of each matrix, i.e. sqrt(sum(P**2)).
        energy = LA.norm(self.glcm, ord='fro', axis=(0, 1))
   
        # Correlation: covariance of row/column level indices over the product
        # of their standard deviations.
        correlation = np.zeros((num_dist, num_angle), dtype=np.float32)
        diff_i = self.corr_I - np.sum(self.corr_I * self.glcm, axis=(0, 1))
        diff_j = self.corr_J - np.sum(self.corr_J * self.glcm, axis=(0, 1))
        std_i = np.sqrt(np.sum(self.glcm * (diff_i ** 2), axis=(0, 1)))
        std_j = np.sqrt(np.sum(self.glcm * (diff_j ** 2), axis=(0, 1)))
        cov = np.sum(self.glcm * (diff_i * diff_j), axis=(0, 1))

     
        # Where either standard deviation is (numerically) zero the ratio is
        # undefined; correlation is set to 1 there.
        mask_0 = std_i < self.corr_tolerance
        mask_0[std_j < self.corr_tolerance] = True
        correlation[mask_0] = 1
     
        mask_1 = ~mask_0
        correlation[mask_1] = cov[mask_1] / (std_i[mask_1] * std_j[mask_1])
        
        # Collapse each (distance, angle) grid into six summary statistics.
        contrast = tuple(np.percentile(contrast, q=TEXTURE_PERCENTILES)) + \
                   (np.mean(contrast), np.std(contrast), median_abs_deviation(contrast, axis=None), )
        dissimilarity = tuple(np.percentile(dissimilarity, q=TEXTURE_PERCENTILES)) + \
                        (np.mean(dissimilarity), np.std(dissimilarity), median_abs_deviation(dissimilarity, axis=None),)
        homogeneity = tuple(np.percentile(homogeneity, q=TEXTURE_PERCENTILES)) + \
                      (np.mean(homogeneity), np.std(homogeneity), median_abs_deviation(homogeneity, axis=None),)
        energy = tuple(np.percentile(energy, q=TEXTURE_PERCENTILES)) + \
                 (np.mean(energy), np.std(energy), median_abs_deviation(energy, axis=None),)
        correlation = tuple(np.percentile(correlation, q=TEXTURE_PERCENTILES)) + \
                      (np.mean(correlation), np.std(correlation), median_abs_deviation(correlation, axis=None),)
        return contrast+dissimilarity+homogeneity+energy+correlation

    @property
    @_cached
    def efc_ratio(self, ):
        """Ratio of the first elliptic Fourier harmonic to the sum of the others.

        The region mask is zero-padded by `bd_padding`, its boundary is
        traced, and normalised elliptic Fourier descriptors of order 15 are
        computed for the first contour found.
        """
        bd = find_boundaries(np.pad(self.image, self.bd_padding, 'constant', constant_values=(0, 0)))
        # NOTE(review): indexing with [0] raises IndexError if no contour is
        # found, and the division below is unguarded against a zero sum.
        bd_contours = find_contours(bd, .1)[0]
        efc = elliptic_fourier_descriptors(bd_contours,
                                           normalize=True,
                                           order=15)

        # One magnitude per harmonic, combined from its four coefficients.
        efcs = np.sqrt(efc[:, 0] ** 2 + efc[:, 1] ** 2) + np.sqrt(efc[:, 2] ** 2 + efc[:, 3] ** 2)
        ratio = efcs[0] / np.sum(efcs[1:])
        return ratio

    @property
    @_cached
    def circularity(self, ):
        """``4 * pi * area / perimeter**2``, or NaN when the perimeter is not above 1e-6."""
        if self.perimeter > 1e-6:
            return (4 * np.pi * self.area) / self.perimeter ** 2
        else:
            return np.nan
        




def load_czi_maxproject(czi_path):
    """Load a CZI file and collapse it along Z by maximum projection.

    The image data for ``S=0, T=0`` is requested in ``"CZYX"`` order and the
    maximum along axis 1 (the Z axis of that ordering) is returned, leaving
    one 2-D image per channel.

    Args:
        czi_path: Path of the file handed to ``AICSImage``.

    Returns:
        Array in ``CYX`` order.
    """
    reader = AICSImage(czi_path)
    stack_czyx = reader.get_image_data("CZYX", S=0, T=0)
    projected = stack_czyx.max(axis=1)
    return projected

def segment_cells_with_watershed(nuc, cyto, thr=None):
    """Grow nucleus labels into whole-cell labels with a seeded watershed.

    The cytoplasm image is thresholded to a foreground mask, the Euclidean
    distance transform of that mask is negated to form the watershed
    landscape, and the labels in ``nuc`` act as markers. Flooding is
    restricted to the foreground mask.

    Args:
        nuc: Label image used as watershed markers.
        cyto: Intensity image defining the foreground.
        thr: Foreground threshold; the Otsu threshold of ``cyto`` when
            ``None``.

    Returns:
        ``uint16`` label image.
    """
    cutoff = threshold_otsu(cyto) if thr is None else thr
    foreground = cyto > cutoff
    distance_map = distance_transform_edt(foreground)
    cell_labels = watershed(-distance_map, nuc, mask=foreground)
    return cell_labels.astype(np.uint16)

def segment_stage1(img_w1, img_w2,
                   alg_w1, alg_w2,
                   p1, p2):
    """Segment the two primary channels and merge them into a cell mask.

    Args:
        img_w1: Image segmented with ``alg_w1`` (produces ``w1``).
        img_w2: Image segmented with ``alg_w2`` (produces ``w2``).
        alg_w1: ``"cellpose"`` or ``"pycle"``.
        alg_w2: ``"cellpose"``, ``"pycle"`` or ``"watershed"``. With
            ``"watershed"`` the ``w1`` labels seed
            ``segment_cells_with_watershed`` on ``img_w2``.
        p1: Parameter dict for ``alg_w1``. Keys read for Cellpose:
            ``diameter``, ``channels``, ``flow_threshold``,
            ``cellprob_threshold``, ``tile_norm_blocksize``, ``niter``.
            Keys read for pyclesperanto: ``spot_sigma``, ``outline_sigma``.
        p2: Parameter dict for ``alg_w2``; same keys as ``p1``, plus
            ``mask_threshold`` for the watershed.

    Returns:
        Tuple ``(w1, w2, w_cell)``. ``w_cell`` is a copy of ``w2`` in which
        every pixel with ``w1 > 0`` carries the ``w1`` label.

    Raises:
        ValueError: If ``alg_w1`` or ``alg_w2`` is not a supported name. The
            check happens before any model is created.
    """
    # Reject unknown algorithm names before any model is loaded or any
    # segmentation work is done.
    if alg_w1 not in ("cellpose", "pycle"):
        raise ValueError(alg_w1)
    if alg_w2 not in ("cellpose", "pycle", "watershed"):
        raise ValueError(alg_w2)

    cp = None  # Cellpose model; created on first use and then shared

    def run_cp(img, p):
        """Run Cellpose on one image and return its label mask as uint16."""
        nonlocal cp
        if cp is None:
            cp = models.CellposeModel(gpu=True)
        masks, *_ = cp.eval(
            [img],
            diameter=p.get("diameter", None),
            channels=p.get("channels", [0, 0]),
            flow_threshold=p.get("flow_threshold", 0.4),
            cellprob_threshold=p.get("cellprob_threshold", 0.0),
            normalize={"tile_norm_blocksize": p.get("tile_norm_blocksize", 0)},
            niter=p.get("niter", None),
        )
        # A one-element list went in, so the first mask is the result.
        return masks[0].astype(np.uint16)

    def run_pycle(img, p):
        """Run pyclesperanto Voronoi-Otsu labelling and return uint16 labels."""
        g = cle.push(img.astype(np.float32))
        return cle.voronoi_otsu_labeling(
            g,
            spot_sigma=p.get("spot_sigma", 10),
            outline_sigma=p.get("outline_sigma", 1),
        ).astype(np.uint16)

    if alg_w1 == "cellpose":
        w1 = run_cp(img_w1, p1)
    else:
        w1 = run_pycle(img_w1, p1)

    if alg_w2 == "cellpose":
        w2 = run_cp(img_w2, p2)
    elif alg_w2 == "pycle":
        w2 = run_pycle(img_w2, p2)
    else:
        w2 = segment_cells_with_watershed(
            w1, img_w2, thr=p2.get("mask_threshold", None)
        )

    # Overlay the w1 labels on top of the w2 labels.
    w_cell = w2.copy()
    w1_pixels = w1 > 0
    w_cell[w1_pixels] = w1[w1_pixels]

    return w1, w2, w_cell



def _label_bright_structures(label_img, intensity_img, min_size, smooth_sigma=None):
    """Within each labelled object keep the above-Otsu pixels, tagged with that label."""
    out = np.zeros_like(label_img, dtype=np.uint16)
    for lid, slc in enumerate(find_objects(label_img), start=1):
        if slc is None:
            continue
        seed = label_img[slc] == lid
        if not seed.any():
            continue
        # Zero everything in the bounding box that belongs to other objects.
        sub = intensity_img[slc] * seed
        if smooth_sigma is not None:
            sub = gaussian(sub, sigma=smooth_sigma)
        # The threshold is computed from this object's own pixels only.
        thr = threshold_otsu(sub[seed])
        # Restrict to the object: smoothing can push pixels just outside it
        # above the threshold, and those must not be assigned to this label.
        bright = (sub > thr) & seed
        components = remove_small_objects(label(bright), min_size)
        out[slc][components > 0] = lid
    return out


def segment_stage2(w1, w_cell, data2d, channel_names,
                   nucleoli_channel="RNA",
                   mito_channel="MITO",
                   min_nuc_size=10,
                   min_mito_size=20,
                   erode_radius=2):
    """Segment sub-structures inside the stage-1 masks.

    Nucleoli are searched inside the eroded ``w1`` labels on the
    Gaussian-smoothed (sigma 1) ``nucleoli_channel`` image; mitochondria are
    searched inside the ``w_cell`` labels on the unsmoothed ``mito_channel``
    image. In both cases a per-object Otsu threshold is applied, connected
    components smaller than the given minimum size are dropped, and the
    remaining pixels receive the label of the enclosing object.

    Args:
        w1: Label image whose eroded objects bound the nucleoli search.
        w_cell: Label image whose objects bound the mitochondria search.
        data2d: Channel-first image stack indexed by position in
            ``channel_names``.
        channel_names: List of channel names matching ``data2d``.
        nucleoli_channel: Name of the channel used for nucleoli.
        mito_channel: Name of the channel used for mitochondria.
        min_nuc_size: Minimum component size kept for nucleoli.
        min_mito_size: Minimum component size kept for mitochondria.
        erode_radius: Radius of the disk used to erode ``w1``.

    Returns:
        Tuple ``(w3, w5)`` of ``uint16`` label images (nucleoli,
        mitochondria).

    Raises:
        ValueError: If a channel name is not in ``channel_names``.
    """
    # Erosion is applied to the label image itself (a minimum filter).
    # NOTE(review): where two differently labelled objects touch, a minimum
    # filter lets the lower label spread into the higher one — confirm the
    # objects in w1 are separated by background.
    w1e = erosion(w1, disk(erode_radius))

    img_n = data2d[channel_names.index(nucleoli_channel)]
    img_m = data2d[channel_names.index(mito_channel)]

    w3 = _label_bright_structures(w1e, img_n, min_nuc_size, smooth_sigma=1)
    w5 = _label_bright_structures(w_cell, img_m, min_mito_size)

    return w3, w5


from scipy.ndimage import find_objects

def extract_features_with_rp(data2d, channels, masks):
    """Tabulate shape, intensity and texture features per cell label.

    For every label present in ``masks['cell']`` one row is produced. Each of
    the compartments nucleus, cyto, nucleoli, mito and cell contributes its
    area, eccentricity, solidity, extent, intensity statistics and GLCM
    texture statistics, measured on a fixed channel (DAPI, ER, RNA, MITO and
    DAPI respectively). A compartment whose mask lacks the label contributes
    no columns to that row.

    Args:
        data2d: Channel-first image stack indexed by position in
            ``channels``.
        channels: List of channel names matching ``data2d``.
        masks: Dict of label images with the keys ``'nucleus'``, ``'cyto'``,
            ``'nucleoli'``, ``'cell'`` and ``'mito'``.

    Returns:
        ``pandas.DataFrame`` with a ``label`` column and
        ``<compartment>_<feature>`` columns.
    """
    n_labels = int(masks['cell'].max())

    # Bounding slices per mask, padded with None so that every list can be
    # indexed by ``label - 1`` for all cell labels.
    slices = {}
    for key, mask_img in masks.items():
        found = list(find_objects(mask_img))
        found.extend([None] * (n_labels - len(found)))
        slices[key] = found

    rows = []
    for idx in range(n_labels):
        lab = idx + 1
        if slices['cell'][idx] is None:
            continue

        feat = {'label': lab}
        # (compartment, channel measured, label image, bounding slices)
        compartments = (
            ('nucleus',  'DAPI', masks['nucleus'],  slices['nucleus']),
            ('cyto',     'ER',   masks['cyto'],     slices['cyto']),
            ('nucleoli', 'RNA',  masks['nucleoli'], slices['nucleoli']),
            ('mito',     'MITO', masks['mito'],     slices['mito']),
            ('cell',     'DAPI', masks['cell'],     slices['cell']),
        )
        for comp, ch_name, mask_img, obj_slices in compartments:
            slc = obj_slices[idx]
            if slc is None:
                continue

            rp = RegionPropertiesExtension(
                slice=slc,
                label=lab,
                label_image=mask_img,
                intensity_image=data2d[channels.index(ch_name)],
                channel_name=comp,
            )

            # Shape descriptors.
            feat[f"{comp}_area"] = rp.area
            feat[f"{comp}_eccentricity"] = rp.eccentricity
            feat[f"{comp}_solidity"] = rp.solidity
            feat[f"{comp}_extent"] = rp.extent

            # Intensity statistics, named after the percentile or statistic.
            stats = rp.intensity_statistics
            stat_names = [*rp.intensity_percentiles, 'median', 'mad', 'mean', 'std']
            feat.update(
                {f"{comp}_int_{name}": val for name, val in zip(stat_names, stats)}
            )

            # Texture statistics.
            feat.update(
                {
                    f"{comp}_tex_{tex_name}": tex_val
                    for tex_name, tex_val in zip(TEXTURE_FEATURE_NAMES, rp.glcm_features)
                }
            )

        rows.append(feat)

    return pd.DataFrame(rows)

def plot_stage2_qc(
    data2d, 
    channel_names, 
    w1, w2, w3, w_cell, w5, 
    alpha=0.6, 
    figsize=(20, 8)
):
    """
    Display a 2×N grid: top row raw channels, bottom row segmentation masks.

    N is the larger of the number of channels and the number of masks (five),
    so any number of channels can be shown. Grid cells without content are
    left blank.

    Args:
        data2d: Channel-first image stack; ``data2d[i]`` is shown under
            ``channel_names[i]``.
        channel_names: Titles for the top row.
        w1, w2, w3, w_cell, w5: Label images shown in the bottom row as
            "Nuclei", "Cytosol", "Nucleoli", "Cell" and "Mito".
        alpha: Passed to ``label2rgb``.
        figsize: Figure size passed to ``plt.subplots``.
    """
    masks = [w1, w2, w3, w_cell, w5]
    mask_titles = ["Nuclei", "Cytosol", "Nucleoli", "Cell", "Mito"]

    # Size the grid from the data instead of assuming exactly five channels.
    n_cols = max(len(channel_names), len(masks))
    _, axes = plt.subplots(2, n_cols, figsize=figsize)

    # Hide every frame up front so unused grid cells stay blank.
    for ax in axes.ravel():
        ax.axis('off')

    # Top row: raw channels
    for i, ch in enumerate(channel_names):
        ax = axes[0, i]
        ax.imshow(data2d[i], cmap='gray')
        ax.set_title(ch)

    # Bottom row: masks
    for i, (msk, title) in enumerate(zip(masks, mask_titles)):
        ax = axes[1, i]
        ax.imshow(label2rgb(msk, bg_label=0, alpha=alpha))
        ax.set_title(title)

    plt.tight_layout()
    plt.show()
    
def batch_process_czi_folder_split(
    folder_path: str,
    output_folder:str,
    channel_names: list,
    p1: dict,
    p2: dict,
    min_nuc_size: int,
    min_mito_size: int,
    erode_radius: int
):
    """Segment every ``*.czi`` file in a folder and write one feature CSV each.

    Files are processed in sorted order. For each file the image is loaded
    and Z-projected, the DAPI channel is segmented with Cellpose and the ER
    channel with the seeded watershed, nucleoli and mitochondria are
    segmented, a QC figure is shown, and the per-cell features are saved as
    ``<file stem>_features.csv`` in ``output_folder``.

    Args:
        folder_path: Folder searched (non-recursively) for ``*.czi`` files.
        output_folder: Destination folder; created when missing.
        channel_names: Channel names in file order; must contain ``"DAPI"``
            and ``"ER"`` plus the channels used by the later stages.
        p1: Parameters for the DAPI segmentation.
        p2: Parameters for the ER segmentation.
        min_nuc_size: Minimum nucleoli component size.
        min_mito_size: Minimum mitochondria component size.
        erode_radius: Erosion radius applied to the nuclei mask.
    """
    os.makedirs(output_folder, exist_ok=True)

    czi_files = sorted(glob.glob(os.path.join(folder_path, "*.czi")))
    for filepath in czi_files:
        fname, _ = os.path.splitext(os.path.basename(filepath))

        # Stage I: primary segmentation of the DAPI and ER channels.
        data2d = load_czi_maxproject(filepath)
        dapi_img = data2d[channel_names.index("DAPI")]
        er_img = data2d[channel_names.index("ER")]
        w1, w2, w_cell = segment_stage1(
            dapi_img, er_img, "cellpose", "watershed", p1, p2
        )

        # Stage II: nucleoli and mitochondria.
        # NOTE(review): the w2 mask (not w_cell) is handed over as the
        # ``w_cell`` argument here — confirm this is intended.
        w3, w5 = segment_stage2(
            w1, w2, data2d, channel_names,
            nucleoli_channel="RNA",
            mito_channel="MITO",
            min_nuc_size=min_nuc_size,
            min_mito_size=min_mito_size,
            erode_radius=erode_radius,
        )

        # QC figure, then feature extraction.
        masks = dict(nucleus=w1, cyto=w2, nucleoli=w3, cell=w_cell, mito=w5)
        plot_stage2_qc(data2d, channel_names, w1, w2, w3, w_cell, w5)
        df = extract_features_with_rp(data2d, channel_names, masks)

        # Tag every row with its source file.
        df['filename'] = fname

        # One CSV per input file.
        out_csv = os.path.join(output_folder, f"{fname}_features.csv")
        df.to_csv(out_csv, index=False)
        print(f"✔ {fname}: {len(df)} cells → '{out_csv}'")

# In [None]:
# Run the full pipeline on one folder of CZI files. The two folder strings are
# placeholders and must be replaced with real paths before running.
# NOTE(review): erode_radius (10) is large relative to the Cellpose diameter
# (15) — confirm the nuclei survive the erosion.
batch_settings = {
    "folder_path": r"folder of multicolor czi file",
    "output_folder": r"folder to store the feature extract from each image",
    "channel_names": ["DAPI", "ER", "RNA", "AG", "MITO"],
    "p1": {"diameter": 15},
    "p2": {"mask_threshold": None},
    "min_nuc_size": 5,
    "min_mito_size": 20,
    "erode_radius": 10,
}
batch_process_czi_folder_split(**batch_settings)

# In [None]:

# Merge all per-image feature tables into a single CSV.
# Both strings are placeholders and must be replaced with real paths.
input_folder = r"folder to store the feature extract from each image"
output_csv   = r"combined multi-well (replicates) to one csv file"

# Sorted so that the row order of the merged table is reproducible.
csv_paths = sorted(glob.glob(os.path.join(input_folder, "*_features.csv")))

# Fail with a clear message instead of pandas' "No objects to concatenate".
if not csv_paths:
    raise ValueError(f"No '*_features.csv' files found in: {input_folder}")

dfs = [pd.read_csv(fp) for fp in csv_paths]
all_df = pd.concat(dfs, ignore_index=True)

all_df.to_csv(output_csv, index=False)
print(f"Merged {len(dfs)} files → {all_df.shape[0]} cells saved to:\n  {output_csv}")

# In [None]:

def significance_stars(p):
    """Map a p-value to its conventional significance label.

    Returns ``'***'`` below 0.001, ``'**'`` below 0.01, ``'*'`` below 0.05 and
    ``'ns'`` otherwise (which includes NaN, as it compares below nothing).
    """
    # Cut-offs are checked from strictest to loosest; the first match wins.
    for cutoff, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return stars
    return 'ns'


def add_stat_annotation(ax, x1, x2, y, h, text):
    """Draw a significance bracket between two x positions with a label on top.

    Args:
        ax: Axes to draw on.
        x1, x2: X positions of the two bracket legs.
        y: Height at which the legs start.
        h: Height of the legs; the horizontal bar sits at ``y + h``.
        text: Label centred above the bar.
    """
    top = y + h
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, lw=1.5, c='k')

    midpoint = (x1 + x2) / 2
    ax.text(midpoint, top, text, ha='center', va='bottom', fontsize=12)


# Plot one row of panels per cell line and one column per feature: bars with
# SD error bars, the individual points, and Welch t-test brackets comparing
# each later time point with '0 h'.

# Time-point groups in plotting order, and the pairs that are compared.
group_order = ['0 h', '12 h', '24 h', '48 h']
comparisons = [('0 h','12 h'), ('0 h','24 h'), ('0 h','48 h')]
# NOTE(review): the file name is spelled 'featruesofinterest' — confirm it
# matches the file on disk.
agg = pd.read_csv('featruesofinterest.csv')

# An ordered categorical keeps the x-axis order (and the category positions)
# identical in every panel, even when a subset lacks a group.
agg['Group'] = pd.Categorical(agg['Group'], categories=group_order, ordered=True)

agg_features = ["median_nucleus_area", "median_cyto_area", 
                "median_mito_area", "median_nucleoli_area"]
cells = agg["Cell"].unique()

# squeeze=False keeps ``axes`` two-dimensional, so ``axes[i, j]`` also works
# when the table contains a single cell line.
fig, axes = plt.subplots(
    nrows=len(cells), 
    ncols=len(agg_features),
    figsize=(len(agg_features)*4, len(cells)*4),
    sharex='col',
    squeeze=False
)

for i, cell in enumerate(cells):
    sub = agg[agg["Cell"] == cell]
    
    for j, feat in enumerate(agg_features):
        ax = axes[i, j]

        sns.barplot(
            x="Group", y=feat, data=sub,
            errorbar="sd", edgecolor="k", fill=False, color="black",ax=ax
        )
        sns.stripplot(
            x="Group", y=feat, data=sub,
            color="black", alpha=0.6, size=4,
            jitter=True, ax=ax
        )
        
        # Stack the brackets above the current top of the axes, one level per
        # comparison; each level is two bracket heights tall.
        y_min, y_max = ax.get_ylim()
        h = (y_max - y_min) * 0.05
        for k, (g1, g2) in enumerate(comparisons):
            d1 = sub.loc[sub['Group'] == g1, feat]
            d2 = sub.loc[sub['Group'] == g2, feat]
            # equal_var=False selects Welch's t-test.
            stat, p = ttest_ind(d1, d2, equal_var=False)
            stars = significance_stars(p)
            # Category positions on the x axis follow ``group_order``.
            x1 = group_order.index(g1)
            x2 = group_order.index(g2)
            y = y_max + k * h * 2
            add_stat_annotation(ax, x1, x2, y, h, stars)
        

        # Column titles on the first row only, row labels on the first column
        # only, x labels on the last row only.
        if i == 0:
            ax.set_title(feat.replace("_", " ").title(), pad=10)
        if j == 0:
            ax.set_ylabel(f"{cell}\n", fontsize=12)
        else:
            ax.set_ylabel("")
        if i == len(cells) - 1:
            ax.set_xlabel("Group")
        else:
            ax.set_xlabel("")

        ax.tick_params(axis='x', rotation=45, direction='in')
        ax.tick_params(axis='y', direction='in')

plt.tight_layout()

plt.show()