# In [None]:
# Core Python
import os
from pathlib import Path

# Arrays / numerics
import numpy as np

# TIFF IO
from tifffile import imread, imwrite

# Curve fitting
from scipy.optimize import curve_fit

# Optional (only if you want tabular outputs, logging, or to later expand with masks/props)
import pandas as pd


# In [None]:
# ========= Config =========
# Folder holding the input stacks; every TIFF is expected in (Z, C, Y, X) order.
images_dir = Path(r"/path/to/ZCYX_tiffs")
# Folder that receives the corrected stacks (created at the end of this cell).
out_dir = Path(r"/path/to/output_all")
# Suffixes (matched case-insensitively by the run loop) treated as input images.
image_exts = [".tif", ".tiff"]
# When False, inputs whose output file already exists are skipped.
overwrite = False

# Pixel type of the written files:
#   "float32" -> corrected values stored unchanged (no rescaling)
#   "uint16"  -> each channel rescaled between its p_low / p_high percentiles
output_dtype = "float32"
# Percentile window used only by the "uint16" rescaling.
p_low = 1.0
p_high = 99.9
# =========================

out_dir.mkdir(exist_ok=True, parents=True)

def linear(z, m, b):
    """Evaluate the straight line ``m * z + b`` at depth(s) ``z``."""
    return b + m * z
def hyperbolic(z, a, b, c):
    """Evaluate the shifted hyperbola ``a / (z + b) + c`` at depth(s) ``z``."""
    shifted = z + b
    return c + a / shifted

def _fit_profile(z, y):
    """Fit a depth profile with a linear and a hyperbolic model and pick one.

    Both ``linear`` (m*z + b) and ``hyperbolic`` (a/(z + b) + c) are fitted by
    least squares. The hyperbolic model is kept when its AIC is lower than the
    linear one by more than 2, or when the two AICs are within 2 of each other
    and its R^2 is higher; otherwise the linear fit is returned.

    Parameters
    ----------
    z : array-like
        Slice positions (1-D).
    y : array-like
        One profile value per slice (same length as ``z``), e.g. the mean
        intensity of each Z-plane.

    Returns
    -------
    (str, numpy.ndarray)
        Model name (``"linear"`` or ``"hyperbolic"``) and the fitted values
        evaluated at ``z`` (float64).
    """
    z = np.asarray(z, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y)

    # A line needs at least two points; with fewer there is nothing to fit,
    # so the "fit" is the data itself (previously curve_fit raised here).
    if n < 2:
        return "linear", y.copy()

    # np.polyfit already returns the exact least-squares line, so a second
    # iterative optimizer is unnecessary (it only re-derived the same
    # coefficients and warned about covariance when n == 2).
    m_lin, b_lin = np.polyfit(z, y, 1)
    ylin = linear(z, m_lin, b_lin)

    yhyp = None
    # The hyperbola has 3 parameters: with n <= 3 it can pass through every
    # point (RSS ~ 0, AIC -> -inf) and would always be selected, so only try
    # it when at least one residual degree of freedom remains.
    if n > 3:
        eps = 1e-6
        # b >= -min(z) + eps keeps z + b > 0 over the whole range, so the
        # fitted curve has no pole inside the data.
        b_lower = -float(np.min(z)) + eps
        a0 = (y[0] - y[-1]) * (z[-1] + 1.0) if z[-1] > 0 else 1.0
        b0 = b_lower + 1.0
        c0 = float(np.median(y[-max(3, n // 10):]))
        bounds = ([-np.inf, b_lower, -np.inf], [np.inf, np.inf, np.inf])
        try:
            p_hyp, _ = curve_fit(
                hyperbolic, z, y, p0=[a0, b0, c0], bounds=bounds, maxfev=20000
            )
        except Exception:
            # Deliberate best effort: any fitting failure (non-convergence,
            # non-finite residuals, ...) simply falls back to the linear model.
            p_hyp = None
        if p_hyp is not None:
            yhyp = hyperbolic(z, *p_hyp)
            if not np.all(np.isfinite(yhyp)):
                yhyp = None

    def metrics(y_true, y_pred, k):
        """Return (R^2, AIC) of ``y_pred`` against ``y_true`` for k params."""
        rss = float(np.sum((y_true - y_pred) ** 2))
        tss = float(np.sum((y_true - np.mean(y_true)) ** 2))
        r2 = 1.0 - (rss / tss) if tss > 0 else np.nan
        n_pts = len(y_true)
        # Gaussian-likelihood AIC; an exact fit (rss == 0) counts as best.
        aic = n_pts * np.log(rss / n_pts) + 2 * k if rss > 0 else -np.inf
        return r2, aic

    r2_lin, aic_lin = metrics(y, ylin, 2)
    if yhyp is not None:
        r2_hyp, aic_hyp = metrics(y, yhyp, 3)
        if (aic_hyp + 2 < aic_lin) or (abs(aic_hyp - aic_lin) <= 2 and r2_hyp > r2_lin):
            return "hyperbolic", yhyp
    return "linear", ylin

def _correction_factors(z, y, *, max_factor=1e3):
    """Compute per-slice multiplicative factors that flatten a depth profile.

    The profile ``y`` is fitted with ``_fit_profile``. The target level is the
    median of the values at or above the 75th percentile (roughly the brightest
    quarter of slices), and each factor is ``target / fitted value``.

    Parameters
    ----------
    z, y : array-like
        Slice positions and per-slice profile values (see ``_fit_profile``).
    max_factor : float, keyword-only
        Upper bound on any single factor. A fitted curve that approaches or
        crosses zero (e.g. a line extrapolating below 0 at the deepest
        slices) used to be clamped to 1e-9, giving factors of ~1e9 * target.
        Pass ``np.inf`` to restore that unbounded behaviour.

    Returns
    -------
    (str, numpy.ndarray)
        The selected model name and one factor per slice.
    """
    model, yfit = _fit_profile(z, y)
    y = np.asarray(y, dtype=float)
    q75 = np.percentile(y, 75)
    upper = y[y >= q75]
    # `upper` is empty only when the percentile is NaN (NaNs in y).
    target = float(np.median(upper)) if upper.size else float(np.median(y))
    # Floor the denominator at target / max_factor (and always above zero) so
    # every factor is <= max_factor and division by zero cannot occur.
    floor = max(target / max_factor, 1e-9)
    factors = target / np.maximum(yfit, floor)
    return model, factors

def _save_corrected_all(stack_zcyx: np.ndarray, path: Path):
    """Write a corrected (Z, C, Y, X) stack to ``path`` as a TIFF.

    The pixel type is chosen by the module-level ``output_dtype``:

    * ``"float32"`` - values are written unchanged (cast to float32).
    * ``"uint16"``  - each channel is rescaled independently so that its
      ``p_low`` percentile maps to 0 and its ``p_high`` percentile maps to
      65535; values outside the window are clipped and fractional results
      are truncated by the integer cast.

    Raises
    ------
    ValueError
        If ``output_dtype`` is neither ``"float32"`` nor ``"uint16"``.
    """
    if output_dtype == "float32":
        imwrite(path, stack_zcyx.astype(np.float32), imagej=False)
    elif output_dtype == "uint16":
        out = stack_zcyx.astype(np.float32, copy=True)
        for c in range(out.shape[1]):
            ch = out[:, c, :, :]
            # Python floats (float64) keep the window arithmetic out of float32.
            lo, hi = (float(v) for v in np.percentile(ch, [p_low, p_high]))
            span = hi - lo
            if not span > 0:
                # Degenerate window (e.g. a mostly constant channel): use a tiny
                # span so pixels at `lo` map to 0 and anything brighter
                # saturates. The old `hi = lo + 1e-6` could round back to `lo`
                # (float32, or large magnitudes), producing 0/0 = NaN pixels
                # whose uint16 cast is undefined.
                span = 1e-6
            out[:, c, :, :] = np.clip((ch - lo) / span, 0, 1) * 65535.0
        imwrite(path, out.astype(np.uint16), imagej=False)
    else:
        raise ValueError("output_dtype must be 'float32' or 'uint16'")

def _process_one(img_path: Path):
    """Depth-correct every channel of one TIFF stack and save the result.

    The output is written to ``out_dir / "<stem>_corrALL.tif"``. An existing
    output is kept (and the input skipped) unless ``overwrite`` is True.
    Inputs that are not 4-D are reported and skipped. For each channel, the
    mean of every Z-plane forms a profile, and the planes are scaled by the
    factors from ``_correction_factors``.
    """
    out_path = out_dir / f"{img_path.stem}_corrALL.tif"
    if out_path.exists() and not overwrite:
        print(f"Skipping {img_path.name}: exists (set overwrite=True to redo).")
        return

    stack = imread(img_path)  # expect Z,C,Y,X
    if stack.ndim != 4:
        print(f"⚠️  {stack_name_placeholder}" if False else f"⚠️  {img_path.name} is not ZCYX (got {stack.shape}); skipping.")
        return

    stack_f32 = stack.astype(np.float32, copy=False)
    n_slices, n_channels = stack_f32.shape[:2]
    result = stack_f32.copy()
    depth = np.arange(n_slices, dtype=float)

    for ch_idx in range(n_channels):
        planes = stack_f32[:, ch_idx, :, :]
        profile = planes.mean(axis=(1, 2))
        model_name, gains = _correction_factors(depth, profile)
        result[:, ch_idx, :, :] = planes * gains[:, np.newaxis, np.newaxis]
        print(f"  - Channel {ch_idx}: model={model_name}")

    _save_corrected_all(result, out_path)
    print(f"✓ {img_path.name}: wrote {out_path.name}")

# Run: correct every matching TIFF in images_dir. Errors are reported per file,
# so one bad file does not stop the batch.
files = [
    p
    for p in images_dir.iterdir()
    if p.is_file()
    and p.suffix.lower() in image_exts
    # Skip this notebook's own outputs (suffix used by _process_one), in case
    # out_dir == images_dir; otherwise re-runs produce *_corrALL_corrALL.tif.
    and not p.stem.endswith("_corrALL")
]
if not files:
    print(f"No images found in {images_dir}")
for p in sorted(files):
    try:
        _process_one(p)
    except Exception as e:
        # Deliberate best effort: report and move on. Include the exception
        # type, because str(e) alone can be empty or ambiguous (e.g. KeyError).
        print(f"❌ Error on {p.name}: {type(e).__name__}: {e}")