In [1]:
"""
ND2 batch processing -> Z projection -> per-channel preprocessing -> pseudocolour RGB export

Dependencies:
  pip install nd2reader numpy scikit-image tifffile pillow matplotlib

Notes:
- ND2 files are read via ND2Reader.
- This script assumes ND2 channel indices 0 / 1 / 2 correspond to DAPI / 488 / 568 by default.
  Adjust CH_DAPI / CH_488 / CH_568 below if your acquisition order differs.
"""

from __future__ import annotations

import os
from pathlib import Path
import numpy as np

from nd2reader import ND2Reader
from skimage import exposure, filters, img_as_ubyte
import tifffile as tiff
from PIL import Image


# =========================
# 0. User configuration
# =========================
# INPUT_DIR is scanned (non-recursively) for *.nd2 files by main();
# OUTPUT_DIR receives the exported images and is created if it does not exist.
INPUT_DIR = Path(r"C:\Users\23351148\OneDrive - The University of Western Australia\Desktop\bioimage_project\ND2")
OUTPUT_DIR = Path(r"C:\Users\23351148\OneDrive - The University of Western Australia\Desktop\bioimage_project\Output")

# Channel mapping (indices in ND2)
# These index the channel axis of each ND2 file; a file with fewer channels
# than the largest index + 1 is skipped by main().
# NOTE(review): a recorded run of this script lists the channel names as
# ['DAPI', 'Alexa 488 antibody', 'Alx647'], so index 2 may be a 647 channel
# rather than 568 for that dataset -- confirm before relying on the name.
CH_DAPI = 0
CH_488 = 1
CH_568 = 2

# Projection method for Z stacks: "max" or "mean"
# (passed to project_z(), which compares it case-insensitively)
PROJECTION = "max"

# Preprocessing parameters
# Passed positionally to preprocess_channel() for every channel.
BG_SIGMA = 50          # large sigma Gaussian for background approximation
CLAHE_CLIP = 0.01      # contrast-limited adaptive histogram equalization
SMOOTH_SIGMA = 1       # light Gaussian smoothing after CLAHE

# Output options
# Each flag independently enables one output file per ND2 input:
# <stem>_rgb.tif, <stem>_rgb.jpg and <stem>_qc.png respectively.
SAVE_TIFF = True
SAVE_JPG = True
SAVE_QC_PNG = False    # Set True to save quick QC image (before/after)


# =========================
# 1. Helper functions
# =========================
def preprocess_channel(
    ch: np.ndarray,
    bg_sigma: float = 50,
    clahe_clip: float = 0.01,
    smooth_sigma: float = 1,
) -> np.ndarray:
    """
    Clean up one fluorescence channel so it can be used as an RGB plane.

    Pipeline:
      1) Estimate the background with a Gaussian blur of width ``bg_sigma``,
         subtract it, and clamp negative residuals to zero
      2) Min-max rescale
      3) CLAHE (``exposure.equalize_adapthist``) with ``clip_limit=clahe_clip``
      4) Gaussian blur of width ``smooth_sigma``
      5) Min-max rescale again

    Args:
      ch: image to process; it is cast to float32 first (the input array is
          not modified).
      bg_sigma: sigma of the background-estimation blur.
      clahe_clip: clip limit handed to CLAHE.
      smooth_sigma: sigma of the final smoothing blur.

    Returns:
      float32 image rescaled to [0, 1]
    """

    def _rescale(img: np.ndarray) -> np.ndarray:
        # Min-max scaling; the small epsilon keeps a completely flat image
        # from triggering a division by zero.
        return (img - img.min()) / (np.ptp(img) + 1e-8)

    raw = ch.astype(np.float32)

    # Remove the slowly varying background and discard negative residuals
    foreground = raw - filters.gaussian(raw, sigma=bg_sigma)
    foreground[foreground < 0] = 0

    # Local contrast enhancement on the rescaled foreground
    enhanced = exposure.equalize_adapthist(_rescale(foreground), clip_limit=clahe_clip)

    # Light blur after CLAHE, then bring the result back to [0, 1]
    smoothed = filters.gaussian(enhanced, sigma=smooth_sigma)
    return _rescale(smoothed).astype(np.float32)


def project_z(stack_zyx: np.ndarray, method: str = "max") -> np.ndarray:
    """
    Collapse the leading axis of a 3-D stack into a single 2-D image.

    Args:
      stack_zyx: 3-D array reduced along axis 0, or a 2-D array, which is
          handed back untouched (``method`` is not inspected in that case).
      method: "max" or "mean", matched case-insensitively.

    Returns:
      The projected image.

    Raises:
      ValueError: if the input is neither 2-D nor 3-D, or if ``method`` is
          not one of the supported names.
    """
    n_dims = stack_zyx.ndim
    if n_dims == 2:
        return stack_zyx
    if n_dims != 3:
        raise ValueError(f"Expected (Z,Y,X) or (Y,X), got shape {stack_zyx.shape}")

    # Name -> NumPy reduction applied along the first axis
    reducers = {"max": np.max, "mean": np.mean}
    reduce_fn = reducers.get(method.lower())
    if reduce_fn is None:
        raise ValueError(f"Unknown projection method: {method}")
    return reduce_fn(stack_zyx, axis=0)


def ensure_czyx(data: np.ndarray) -> np.ndarray | None:
    """
    Ensure ND2 array is shaped (C, Z, Y, X).
    ND2Reader with bundle_axes=('c','z','y','x') typically yields:
      - (T, C, Z, Y, X) if time exists
      - (C, Z, Y, X) otherwise
    """
    if data.ndim == 5:
        # (T,C,Z,Y,X) -> take first timepoint
        return data[0]
    if data.ndim == 4:
        return data
    return None


# Optional QC plot (matplotlib is imported only when this function is called)
def save_qc_png(
    out_path: Path,
    dapi_raw: np.ndarray,
    g488_raw: np.ndarray,
    r568_raw: np.ndarray,
    rgb_8bit: np.ndarray,
) -> None:
    """
    Save a one-row, four-panel QC figure: three raw projections plus the RGB result.

    Args:
      out_path: destination file for the figure (written at 200 dpi).
      dapi_raw: image shown in panel 1 with a gray colormap.
      g488_raw: image shown in panel 2 with a gray colormap.
      r568_raw: image shown in panel 3 with a gray colormap.
      rgb_8bit: composite shown in panel 4 without an explicit colormap.

    The figure is always closed, even if drawing or saving fails, so figures
    do not accumulate when this is called once per file in a batch run.
    """
    # Imported lazily so matplotlib is only required when QC output is enabled
    import matplotlib.pyplot as plt

    # (image, panel title, colormap) in left-to-right order
    panels = (
        (dapi_raw, "Raw DAPI (proj)", "gray"),
        (g488_raw, "Raw 488 (proj)", "gray"),
        (r568_raw, "Raw 568 (proj)", "gray"),
        (rgb_8bit, "Pseudo-colour RGB", None),
    )

    # Work on an explicit figure handle instead of pyplot's implicit
    # "current figure" so exactly this figure is the one saved and closed.
    fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))
    try:
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)


# =========================
# 2. Main batch routine
# =========================
def main() -> None:
    """
    Convert every ND2 file in INPUT_DIR to a pseudocolour RGB image in OUTPUT_DIR.

    For each file (visited in sorted name order):
      1) read it with axes bundled as (c, z, y, x), iterating over 't' if present
      2) reduce it to (C, Z, Y, X) with ensure_czyx() (first timepoint only)
      3) Z-project the three configured channels with project_z()
      4) run preprocess_channel() on each projection
      5) stack them as R <- 488, G <- 568, B <- DAPI and convert to 8-bit
      6) write the outputs enabled by SAVE_TIFF / SAVE_JPG / SAVE_QC_PNG

    Files with an unexpected number of dimensions or too few channels are
    reported and skipped. A summary of processed and skipped files is printed
    at the end.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so that the batch
    # order and the resulting log are reproducible between runs.
    files = sorted(f for f in os.listdir(INPUT_DIR) if f.lower().endswith(".nd2"))
    print(f"Found {len(files)} ND2 files in: {INPUT_DIR}")

    if not files:
        print("No ND2 files found. Please check INPUT_DIR.")
        return

    # Highest channel index that will be accessed; the same for every file
    max_ch = max(CH_DAPI, CH_488, CH_568)
    n_processed = 0

    for fname in files:
        fpath = INPUT_DIR / fname
        print(f"\nProcessing: {fpath.name}")

        # ---- Read ND2 ----
        # The context manager closes the reader (and its file handle) as soon
        # as the pixel data and metadata have been copied out, instead of
        # leaving one open reader behind per input file.
        with ND2Reader(str(fpath)) as nd2:
            nd2.bundle_axes = ("c", "z", "y", "x")
            nd2.iter_axes = "t" if "t" in nd2.axes else ""

            data = np.array(nd2)
            channel_names = nd2.metadata.get("channels", [])

        print("  Raw data shape:", data.shape)
        if channel_names:
            print("  Channel names:", channel_names)

        data_czyx = ensure_czyx(data)
        if data_czyx is None:
            print("  Skip: unexpected ND2 dimensions:", data.shape)
            continue

        n_channels = data_czyx.shape[0]
        if n_channels <= max_ch:
            print(f"  Skip: insufficient channels (C={n_channels}). Need at least {max_ch+1}.")
            continue

        # ---- Z projection for each channel ----
        dapi_raw = project_z(data_czyx[CH_DAPI], method=PROJECTION)
        g488_raw = project_z(data_czyx[CH_488], method=PROJECTION)
        r568_raw = project_z(data_czyx[CH_568], method=PROJECTION)

        # ---- Preprocess each channel ----
        dapi = preprocess_channel(dapi_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)
        g488 = preprocess_channel(g488_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)
        r568 = preprocess_channel(r568_raw, BG_SIGMA, CLAHE_CLIP, SMOOTH_SIGMA)

        # ---- Pseudocolour mapping (swap R/G as your earlier script) ----
        # R <- 488, G <- 568, B <- DAPI
        rgb = np.dstack([g488, r568, dapi]).astype(np.float32)
        # Clip before conversion: img_as_ubyte expects floats within [0, 1]
        rgb_8bit = img_as_ubyte(np.clip(rgb, 0, 1))

        base = fpath.stem
        tif_path = OUTPUT_DIR / f"{base}_rgb.tif"
        jpg_path = OUTPUT_DIR / f"{base}_rgb.jpg"
        qc_path = OUTPUT_DIR / f"{base}_qc.png"

        # ---- Save ----
        if SAVE_TIFF:
            tiff.imwrite(str(tif_path), rgb_8bit)
            print("  Saved TIFF:", tif_path.name)

        if SAVE_JPG:
            Image.fromarray(rgb_8bit).save(str(jpg_path), format="JPEG", quality=95, subsampling=0)
            print("  Saved JPG :", jpg_path.name)

        if SAVE_QC_PNG:
            save_qc_png(qc_path, dapi_raw, g488_raw, r568_raw, rgb_8bit)
            print("  Saved QC  :", qc_path.name)

        n_processed += 1

    # Report skipped files explicitly rather than claiming every file was handled
    n_skipped = len(files) - n_processed
    print(f"\nDone. Processed {n_processed} of {len(files)} ND2 files ({n_skipped} skipped).")


# Run the batch export only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Found 43 ND2 files in: C:\Users\23351148\OneDrive - The University of Western Australia\Desktop\bioimage_project\ND2

Processing: ccz1 RANK15 surface  X DC h pha x60035.nd2
  Raw data shape: (1, 3, 5, 1024, 1024)
  Channel names: ['DAPI', 'Alexa 488 antibody', 'Alx647']
  Saved TIFF: ccz1 RANK15 surface  X DC h pha x60035_rgb.tif
  Saved JPG : ccz1 RANK15 surface  X DC h pha x60035_rgb.jpg

Processing: ccz1 RANK15 surface  X DC h pha x60036.nd2
  Raw data shape: (1, 3, 5, 1024, 1024)
  Channel names: ['DAPI', 'Alexa 488 antibody', 'Alx647']
  Saved TIFF: ccz1 RANK15 surface  X DC h pha x60036_rgb.tif
  Saved JPG : ccz1 RANK15 surface  X DC h pha x60036_rgb.jpg

Processing: ccz1 RANK15 surface  X DC h pha x60037.nd2
  Raw data shape: (1, 3, 5, 1024, 1024)
  Channel names: ['DAPI', 'Alexa 488 antibody', 'Alx647']
  Saved TIFF: ccz1 RANK15 surface  X DC h pha x60037_rgb.tif
  Saved JPG : ccz1 RANK15 surface  X DC h pha x60037_rgb.jpg

Processing: ccz1 RANK15 surface  X DC h pha x60038.nd2