In [1]:
"""
ND2 -> preprocessing -> pseudocolour RGB export
+ quantitative summary (per-file)
+ object-level measurements (scikit-image regionprops_table; per-file CSV)

Pipeline
1) Read ND2 via ND2Reader and normalise to (C, Z, Y, X) (first timepoint if T exists)
2) Z-project selected channels (max/mean)
3) Per-channel preprocessing: background subtraction, normalisation, CLAHE, smoothing
4) Pseudocolour mapping and export (TIFF + JPG)
5) Quantification:
   A) Per-file distribution stats (nonzero pixels)
   B) Otsu-based signal masks: area fraction, mean signal
   C) Inter-channel correlation within union of signal masks (proxy)
   D) Object-level features from signal mask using scikit-image:
      - label, area, mean intensity, eccentricity, solidity, centroid
      Exported to CSV per file, and optionally concatenated to one CSV.

Dependencies
  pip install nd2reader numpy pandas scikit-image tifffile pillow matplotlib
"""

from __future__ import annotations

import os
from pathlib import Path
import numpy as np
import pandas as pd

from nd2reader import ND2Reader
from skimage import exposure, filters, img_as_ubyte, morphology, measure
import tifffile as tiff
from PIL import Image


# =========================
# 0. Configuration
# =========================
# Module-level settings read by main(). Edit these before running the script.

# INPUT_DIR is listed (non-recursively) for files ending in ".nd2";
# every export (images, QC figures, CSV tables) is written into OUTPUT_DIR.
INPUT_DIR = Path(r"C:\XXX\input")
OUTPUT_DIR = Path(r"C:\XXX\output")
# NOTE: this executes at import time, so importing the module is enough to
# create OUTPUT_DIR (and any missing parents) on disk.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Channel indices (adjust if acquisition order differs)
# These index the first (channel) axis of the (C, Z, Y, X) array read from each file.
CH_DAPI = 0
CH_488 = 1
CH_568 = 2

# Z projection method: "max" or "mean"
PROJECTION = "max"

# Preprocessing parameters
# Passed positionally to preprocess_channel() for every channel.
BG_SIGMA = 50          # background blur sigma
CLAHE_CLIP = 0.01      # CLAHE clip limit
SMOOTH_SIGMA = 1       # post-CLAHE smoothing sigma

# Export options
SAVE_TIFF = True
SAVE_JPG = True
SAVE_QC_PNG = False             # raw projections + RGB preview
SAVE_INTENSITY_HISTS = False    # per-file intensity histograms

# Quantification options
# MIN_OBJECT_SIZE is forwarded as `min_size` to make_signal_mask().
MIN_OBJECT_SIZE = 200
# When False, all mask-derived summary columns are written as NaN and no
# object table is produced (there is no mask to label).
USE_SIGNAL_MASK_FOR_STATS = True

# Object-level quantification (scikit-image)
EXPORT_OBJECT_TABLE_PER_FILE = True     # one "<base>_objects_488.csv" per input file
EXPORT_OBJECT_TABLE_COMBINED = True     # one "objects_488_all_files.csv" across files
# Property names handed to skimage.measure.regionprops_table().
OBJECT_FEATURES = [
    "label",
    "area",
    "mean_intensity",
    "eccentricity",
    "solidity",
    "centroid",
]


# =========================
# 1. Core image functions
# =========================
def ensure_czyx(arr: np.ndarray) -> np.ndarray | None:
    """
    Coerce an ND2 array to a 4-D (C, Z, Y, X) layout.

    A 4-D input is returned as-is. A 5-D input is treated as
    (T, C, Z, Y, X) and its first entry along the leading axis is returned.
    Any other rank yields None so the caller can skip the file.
    """
    rank = arr.ndim
    if rank == 4:
        return arr
    # Only a 5-D array can be reduced by dropping the leading (time) axis.
    return arr[0] if rank == 5 else None


def project_z(stack_zyx: np.ndarray, method: str = "max") -> np.ndarray:
    """
    Collapse the leading axis of a (Z, Y, X) stack into a (Y, X) image.

    A 2-D input is handed back untouched. `method` is case-insensitive and
    must be "max" or "mean".

    Raises:
        ValueError: if the input is neither 2-D nor 3-D, or if `method`
            is not a recognised projection.
    """
    rank = stack_zyx.ndim
    if rank == 2:
        return stack_zyx
    if rank != 3:
        raise ValueError(f"Expected 2D or 3D stack, got {stack_zyx.shape}")

    # Dispatch table: projection name -> NumPy reduction applied along Z.
    reducers = {"max": np.max, "mean": np.mean}
    reduce_along_z = reducers.get(method.lower())
    if reduce_along_z is None:
        raise ValueError(f"Unknown projection method: {method}")
    return reduce_along_z(stack_zyx, axis=0)


def preprocess_channel(
    ch2d: np.ndarray,
    bg_sigma: float = 50,
    clahe_clip: float = 0.01,
    smooth_sigma: float = 1,
) -> np.ndarray:
    """
    Clean up one 2D fluorescence image for display and simple quantification.

    Steps, in order: subtract a Gaussian-blurred copy (sigma=`bg_sigma`) as a
    background estimate and clamp negatives to zero; min-max rescale; apply
    adaptive histogram equalisation with `clip_limit=clahe_clip`; smooth with
    a Gaussian (sigma=`smooth_sigma`); min-max rescale again.

    Returns the result cast to float32.
    """

    def rescale(img: np.ndarray) -> np.ndarray:
        # Min-max rescale; the small epsilon avoids dividing by zero on a flat image.
        return (img - img.min()) / (np.ptp(img) + 1e-8)

    image = ch2d.astype(np.float32)

    # Remove the slowly varying background, keeping only non-negative residue.
    background = filters.gaussian(image, sigma=bg_sigma)
    image = image - background
    image[image < 0] = 0

    image = rescale(image)

    # Local contrast enhancement followed by light denoising.
    image = exposure.equalize_adapthist(image, clip_limit=clahe_clip)
    image = filters.gaussian(image, sigma=smooth_sigma)

    return rescale(image).astype(np.float32)


def make_signal_mask(ch_norm: np.ndarray, min_size: int = 200) -> np.ndarray:
    """
    Build a boolean signal mask from a preprocessed channel.

    Pixels strictly above the Otsu threshold of `ch_norm` are kept, then the
    mask is passed through `morphology.remove_small_objects` with `min_size`
    to discard small specks.
    """
    otsu_level = filters.threshold_otsu(ch_norm)
    above_threshold = ch_norm > otsu_level
    return morphology.remove_small_objects(above_threshold, min_size=min_size)


# =========================
# 2. Quantification helpers
# =========================
def safe_stats(x: np.ndarray) -> dict:
    """
    Summarise an array of values as a dict.

    Keys, in order: "n", "mean", "median", "p10", "p90", "std".
    An empty input gives n=0 and NaN for every statistic. "std" is the
    sample standard deviation (ddof=1), reported as 0.0 for a single value.
    """
    count = x.size
    if count == 0:
        return {"n": 0, **dict.fromkeys(("mean", "median", "p10", "p90", "std"), np.nan)}

    summary = {"n": int(count)}
    summary["mean"] = float(np.mean(x))
    summary["median"] = float(np.median(x))
    summary["p10"] = float(np.percentile(x, 10))
    summary["p90"] = float(np.percentile(x, 90))
    # ddof=1 is undefined for one sample, so report zero spread instead.
    if count > 1:
        summary["std"] = float(np.std(x, ddof=1))
    else:
        summary["std"] = 0.0
    return summary


def corr_in_mask(a: np.ndarray, b: np.ndarray, mask: np.ndarray) -> float:
    """
    Pearson correlation of `a` and `b` over the pixels selected by `mask`.

    Returns NaN when fewer than 10 pixels are selected. A tiny constant is
    added to the denominator so a zero-variance input gives ~0 rather than
    a division error.
    """
    xs = a[mask].ravel()
    ys = b[mask].ravel()
    if xs.size < 10:
        return float("nan")

    # Centre both samples, then divide the cross term by the product of norms.
    dx = xs - xs.mean()
    dy = ys - ys.mean()
    norm_product = (np.sqrt((dx**2).sum()) * np.sqrt((dy**2).sum())) + 1e-12
    cross_term = (dx * dy).sum()
    return float(cross_term / norm_product)


def save_qc_png(out_path: Path, dapi_raw: np.ndarray, g488_raw: np.ndarray, r568_raw: np.ndarray, rgb_8bit: np.ndarray) -> None:
    """
    Write a 1x4 quality-control figure to `out_path`.

    The first three panels show the raw projections (DAPI, 488, 568) in
    greyscale; the fourth shows the composed RGB image. Saved at 200 dpi.
    """
    # Imported lazily so matplotlib is only needed when QC export is enabled.
    import matplotlib.pyplot as plt

    # (image, panel title, extra imshow keyword arguments), left to right.
    panels = (
        (dapi_raw, "Raw DAPI (proj)", {"cmap": "gray"}),
        (g488_raw, "Raw 488 (proj)", {"cmap": "gray"}),
        (r568_raw, "Raw 568 (proj)", {"cmap": "gray"}),
        (rgb_8bit, "Pseudo-colour RGB", {}),
    )

    plt.figure(figsize=(12, 4))
    for position, (image, title, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(image, **imshow_kwargs)
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def save_intensity_hist(out_path: Path, g488: np.ndarray, r568: np.ndarray, dapi: np.ndarray) -> None:
    """
    Write overlaid density histograms of the three channels to `out_path`.

    Only strictly positive pixel values are plotted. A channel with no
    positive pixels is represented by a single 0.0 so the histogram call
    still has data. Saved at 200 dpi.
    """
    # Imported lazily so matplotlib is only needed when histogram export is enabled.
    import matplotlib.pyplot as plt

    def positive_values(img):
        flat = img[img > 0].ravel()
        if flat.size:
            return flat
        return np.array([0.0], dtype=np.float32)

    # Plot order (and therefore legend order): 488, 568, DAPI.
    series = [
        ("488", positive_values(g488)),
        ("568", positive_values(r568)),
        ("DAPI", positive_values(dapi)),
    ]

    plt.figure(figsize=(6, 4))
    for channel_label, values in series:
        plt.hist(values, bins=60, alpha=0.6, label=channel_label, density=True)
    plt.xlabel("Normalised intensity")
    plt.ylabel("Density")
    plt.title("Intensity distributions (post-preprocessing)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


# =========================
# 3. Main batch processing
# =========================
def _read_first_timepoint_czyx(fpath: Path) -> np.ndarray:
    """
    Load the first frame of an ND2 file as a (C, Z, Y, X) array.

    Only frame 0 is read, so later timepoints are never loaded into memory,
    and the reader is closed as soon as the frame has been fetched.

    NOTE(review): nd2reader appears to omit the "c" and "z" axes when they
    have length 1 — confirm against the installed version. To cope with
    that, only the axes the reader reports are bundled and a singleton axis
    is inserted for each of "c"/"z" that is absent.
    """
    wanted = ("c", "z", "y", "x")
    # The context manager releases the file handle even if reading raises.
    with ND2Reader(str(fpath)) as nd2:
        present = tuple(ax for ax in wanted if ax in nd2.axes)
        nd2.bundle_axes = present
        nd2.iter_axes = "t" if "t" in nd2.axes else ""
        frame = np.asarray(nd2[0])

    # Re-insert missing leading axes in order so the layout is always C, Z, Y, X.
    for position, axis in enumerate(wanted[:2]):
        if axis not in present:
            frame = np.expand_dims(frame, axis=position)
    return frame


def _object_table_488(mask_488: np.ndarray, g488: np.ndarray, fname: str, base: str) -> pd.DataFrame:
    """Label the 488 signal mask and measure OBJECT_FEATURES on each object."""
    labels_488 = measure.label(mask_488)
    props = measure.regionprops_table(
        labels_488,
        intensity_image=g488,
        properties=OBJECT_FEATURES,
    )
    df_obj = pd.DataFrame(props)
    # Provenance columns first, so rows stay identifiable after concatenation.
    df_obj.insert(0, "file", fname)
    df_obj.insert(1, "base", base)
    df_obj.insert(2, "channel", "488")
    df_obj.insert(3, "projection", PROJECTION)
    return df_obj


def main() -> None:
    """
    Batch-process every ND2 file in INPUT_DIR.

    For each file: read the first timepoint, Z-project the three configured
    channels, preprocess them, export the pseudocolour RGB image (and the
    optional QC figures), compute the per-file summary row and, if enabled,
    the object-level table for the 488 channel. Afterwards the summary
    table and the optional combined object table are written to OUTPUT_DIR.

    Files with an unexpected array rank or too few channels are reported
    and skipped; they contribute no row to the summary.
    """
    # Sorted so the row order of the exported tables is reproducible.
    files = sorted(f for f in os.listdir(INPUT_DIR) if f.lower().endswith(".nd2"))
    print(f"Found {len(files)} ND2 file(s) in: {INPUT_DIR}")

    if not files:
        print("No ND2 files found. Check INPUT_DIR.")
        return

    # Object tables are needed if either export is requested; previously the
    # combined export was silently skipped when the per-file export was off.
    need_objects = EXPORT_OBJECT_TABLE_PER_FILE or EXPORT_OBJECT_TABLE_COMBINED

    summary_rows: list[dict] = []
    all_object_tables: list[pd.DataFrame] = []

    for fname in files:
        fpath = INPUT_DIR / fname
        print(f"\nProcessing: {fpath.name}")

        # ---- Read ND2 ----
        data = _read_first_timepoint_czyx(fpath)
        data_czyx = ensure_czyx(data)

        if data_czyx is None:
            print("  Skip: unexpected ND2 dimensions:", data.shape)
            continue

        C, Z, Y, X = data_czyx.shape
        required = max(CH_DAPI, CH_488, CH_568) + 1
        if C < required:
            print(f"  Skip: insufficient channels (C={C}, need >= {required})")
            continue

        # ---- Z projection ----
        dapi_raw = project_z(data_czyx[CH_DAPI], method=PROJECTION)
        g488_raw = project_z(data_czyx[CH_488], method=PROJECTION)
        r568_raw = project_z(data_czyx[CH_568], method=PROJECTION)

        # ---- Preprocessing ----
        dapi = preprocess_channel(dapi_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)
        g488 = preprocess_channel(g488_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)
        r568 = preprocess_channel(r568_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)

        # ---- Pseudocolour mapping (R<-488, G<-568, B<-DAPI) ----
        # NOTE(review): the variable names suggest 488=green and 568=red, but
        # the stack order below puts 488 in the red plane and 568 in the green
        # plane. Left as-is because the mapping comment states it explicitly;
        # confirm this is the intended display convention.
        rgb = np.dstack([g488, r568, dapi]).astype(np.float32)
        rgb_8bit = img_as_ubyte(np.clip(rgb, 0, 1))

        base = Path(fname).stem
        tif_path = OUTPUT_DIR / f"{base}_rgb.tif"
        jpg_path = OUTPUT_DIR / f"{base}_rgb.jpg"
        qc_path = OUTPUT_DIR / f"{base}_qc.png"
        hist_path = OUTPUT_DIR / f"{base}_intensity_hist.png"

        # ---- Save images ----
        if SAVE_TIFF:
            tiff.imwrite(str(tif_path), rgb_8bit)
        if SAVE_JPG:
            Image.fromarray(rgb_8bit).save(str(jpg_path), format="JPEG", quality=95, subsampling=0)
        if SAVE_QC_PNG:
            save_qc_png(qc_path, dapi_raw, g488_raw, r568_raw, rgb_8bit)
        if SAVE_INTENSITY_HISTS:
            save_intensity_hist(hist_path, g488, r568, dapi)

        # =========================
        # 4. Quantification (per-file)
        # =========================
        row = {
            "file": fname,
            "shape_CZYX": f"{C}x{Z}x{Y}x{X}",
            "projection": PROJECTION,
            "bg_sigma": BG_SIGMA,
            "clahe_clip": CLAHE_CLIP,
            "smooth_sigma": SMOOTH_SIGMA,
        }

        # Global (nonzero) intensity distributions, one column group per channel.
        for prefix, channel in (("g488", g488), ("r568", r568), ("dapi", dapi)):
            nonzero = channel[channel > 0].ravel()
            for k, v in safe_stats(nonzero).items():
                row[f"{prefix}_{k}"] = v

        # Signal-region metrics (Otsu masks)
        if USE_SIGNAL_MASK_FOR_STATS:
            mask_488 = make_signal_mask(g488, min_size=MIN_OBJECT_SIZE)
            mask_568 = make_signal_mask(r568, min_size=MIN_OBJECT_SIZE)

            row["mask488_area_fraction"] = float(mask_488.mean())
            row["mask568_area_fraction"] = float(mask_568.mean())

            row["g488_mean_in_mask488"] = float(g488[mask_488].mean()) if mask_488.any() else float("nan")
            row["r568_mean_in_mask568"] = float(r568[mask_568].mean()) if mask_568.any() else float("nan")

            # Correlation in union mask (proxy; not a colocalisation metric)
            union = mask_488 | mask_568
            row["corr_488_568_in_union_mask"] = corr_in_mask(g488, r568, union)
        else:
            mask_488 = None
            mask_568 = None
            row["mask488_area_fraction"] = float("nan")
            row["mask568_area_fraction"] = float("nan")
            row["g488_mean_in_mask488"] = float("nan")
            row["r568_mean_in_mask568"] = float("nan")
            row["corr_488_568_in_union_mask"] = float("nan")

        summary_rows.append(row)

        # =========================
        # 5. Object-level quantification (scikit-image)
        # =========================
        if need_objects and (mask_488 is not None) and mask_488.any():
            df_obj = _object_table_488(mask_488, g488, fname, base)

            # Save per-file object table
            if EXPORT_OBJECT_TABLE_PER_FILE:
                obj_csv = OUTPUT_DIR / f"{base}_objects_488.csv"
                df_obj.to_csv(obj_csv, index=False)

            # Keep for combined export
            if EXPORT_OBJECT_TABLE_COMBINED:
                all_object_tables.append(df_obj)

        print(f"  Saved outputs for: {base}")

    # ---- Save per-file summary table ----
    if summary_rows:
        df_sum = pd.DataFrame(summary_rows)
        out_csv = OUTPUT_DIR / "quant_summary.csv"
        df_sum.to_csv(out_csv, index=False)
        print(f"\nQuantification summary saved: {out_csv}")

    # ---- Save combined object table ----
    if EXPORT_OBJECT_TABLE_COMBINED and all_object_tables:
        df_all = pd.concat(all_object_tables, ignore_index=True)
        out_obj = OUTPUT_DIR / "objects_488_all_files.csv"
        df_all.to_csv(out_obj, index=False)
        print(f"Combined object table saved: {out_obj}")

    print("\nDone.")


if __name__ == "__main__":
    main()


Found 1 ND2 file(s) in: C:\Users\yixiao Zhou\Desktop\bioimage\input

Processing: RMC1 L152WT MOC EEA CD63 H PH 60OIL 0000.032.nd2
  Saved outputs for: RMC1 L152WT MOC EEA CD63 H PH 60OIL 0000.032

Quantification summary saved: C:\Users\yixiao Zhou\Desktop\bioimage\output\quant_summary.csv
Combined object table saved: C:\Users\yixiao Zhou\Desktop\bioimage\output\objects_488_all_files.csv

Done.
