In [1]:
# Save a positives-only BINARY mask (0/1) as a new Zarr + per-image TIFFs.
# Positivity rule: a bacterium is positive if >= COVERAGE_MIN of its band
# pixels have marker intensity >= THR (uint16). The saved mask paints the
# WHOLE bacterium for positives.

import os
from pathlib import Path
import numpy as np
import zarr
from skimage.segmentation import expand_labels
from skimage.morphology import erosion, disk
from skimage.measure import regionprops_table
from skimage.io import imsave

# ==========================================================
# Configuration
# ==========================================================

# Inputs: one Zarr store holding two aligned (N, H, W) stacks.
ZARR_IN = "/mnt/efs/aimbl_2025/student_data/S-LS/marker_data.zarr"
MASKS_KEY = "pred_mask_bacteria/pred_mask_stack"  # int instance labels: 0 = background, 1..K = bacteria
MARKER_KEY = "marker/marker_stack"  # uint16 marker intensities

# Outputs: a fresh Zarr store plus one TIFF per image.
TIFF_DIR = "/mnt/efs/aimbl_2025/student_data/S-LS/binary_rings-tiff"
ZARR_OUT = "/mnt/efs/aimbl_2025/student_data/S-LS/binary_rings.zarr"
DS_NAME = "binary_stack"  # name of the dataset created inside ZARR_OUT

# Tunable parameters for the positivity rule.
R_OUT = 2  # width (px) of the ring grown outward from each bacterium
R_IN = 1  # width (px) of the ring taken inward from each bacterium's edge
THR = 1340  # marker intensity (uint16 units) a band pixel must reach
COVERAGE_MIN = 0.5  # min fraction of band pixels at/above THR to call a bacterium positive

# --------------------------
# HELPERS
# --------------------------
def band_two_sided(labels2d: np.ndarray, r_out: int = 2, r_in: int = 1) -> np.ndarray:
    """Return a label-preserving band straddling the edge of every object.

    The band is the union of two rings:

    * outer ring: background pixels within ``r_out`` px of an object, given
      that object's id by ``expand_labels``;
    * inner ring: object pixels whose ``disk(r_in)`` neighbourhood contains a
      pixel with a different label, i.e. exactly the pixels removed when that
      object alone is eroded with ``disk(r_in)``.

    Parameters
    ----------
    labels2d : 2-D integer instance-label image (0 = background).
    r_out : outer ring width in pixels; negative values are treated as 0.
    r_in : inner ring width in pixels; negative values are treated as 0.

    Returns
    -------
    int32 array, same shape as ``labels2d``: band pixels carry their object
    id, every other pixel is 0.
    """
    from skimage.morphology import dilation  # same package as the file-level erosion/disk

    labels2d = labels2d.astype(np.int32, copy=False)

    # Outer ring: background pixels claimed by expand_labels within r_out px.
    r_out = int(max(0, r_out))
    if r_out > 0:
        outer = expand_labels(labels2d, distance=r_out)  # fresh array; labels2d is not modified
        outer[labels2d != 0] = 0  # drop the objects themselves, keep only the added ring
    else:
        outer = np.zeros_like(labels2d)

    # Inner ring, computed for all objects at once instead of one full-image
    # erosion per label: a labelled pixel is on its object's edge iff the
    # min and max label in its disk(r_in) neighbourhood differ (the pixel's
    # own label is always inside that range, so min == max means "all same").
    # NOTE: relies on erosion and dilation using the same default border mode,
    # which makes this identical to eroding each object's boolean mask.
    r_in = int(max(0, r_in))
    inner = np.zeros_like(labels2d)
    if r_in > 0:
        se = disk(r_in)
        nb_min = erosion(labels2d, se)
        nb_max = dilation(labels2d, se)
        ring = (labels2d != 0) & (nb_min != nb_max)
        inner[ring] = labels2d[ring]

    # Union: outer lies only on background and inner only on objects,
    # so the two rings never overlap.
    return np.where(outer != 0, outer, inner)

def band_coverage_per_label(band_labels: np.ndarray, marker2d: np.ndarray, thr: int) -> dict[int, float]:
    """Fraction of each label's band pixels whose marker intensity is >= ``thr``.

    coverage[label] = (# band pixels of ``label`` with marker >= thr) / (band area of ``label``)

    Parameters
    ----------
    band_labels : non-negative integer label image (0 = not in any band),
        e.g. the output of ``band_two_sided``.
    marker2d : marker intensity image with the same shape as ``band_labels``.
    thr : intensity threshold; compared as ``marker2d >= int(thr)``.

    Returns
    -------
    dict mapping every nonzero label present in ``band_labels`` to its
    coverage in [0, 1]; empty when there is no nonzero label.

    Raises
    ------
    ValueError
        If the two images differ in shape, or ``band_labels`` holds negative
        values (rejected by ``np.bincount``).
    """
    band_labels = np.asarray(band_labels)
    if band_labels.size == 0 or band_labels.max() == 0:
        return {}
    marker2d = np.asarray(marker2d)
    if marker2d.shape != band_labels.shape:
        raise ValueError(
            f"shape mismatch: band_labels {band_labels.shape} vs marker2d {marker2d.shape}"
        )

    flat_labels = band_labels.ravel()
    passes = (marker2d.ravel() >= int(thr)).astype(np.float64)  # 1.0 where pixel passes

    # Per-label area and per-label count of passing pixels via bincount.
    # This avoids regionprops_table's legacy "mean_intensity" property name
    # (renamed "intensity_mean" in scikit-image 0.19) and yields the same
    # exact ratios, since both sums are exact integer counts.
    area = np.bincount(flat_labels)
    hits = np.bincount(flat_labels, weights=passes, minlength=area.size)
    present = np.flatnonzero(area)
    present = present[present != 0]  # background (0) is not a band
    return {int(lab): float(hits[lab] / area[lab]) for lab in present}

# --------------------------
# LOAD INPUTS
# --------------------------
# Only the first N_IMAGES planes are processed; set to None for the full stack.
N_IMAGES = 3

zin = zarr.open(ZARR_IN, mode="r")
# Slicing a zarr array reads the selection into memory as NumPy arrays
# (it is NOT lazy), so keep N_IMAGES small while experimenting.
masks = zin[MASKS_KEY][:N_IMAGES]
marker = zin[MARKER_KEY][:N_IMAGES]
if marker.shape != masks.shape:
    raise ValueError(f"marker stack {marker.shape} does not match mask stack {masks.shape}")
N, H, W = masks.shape
print(f"Loaded input: N={N}, H={H}, W={W}")

# --------------------------
# PREP OUTPUTS
# --------------------------
# Zarr: overwrite=True gives a clean slate each run (any old DS_NAME is gone).
store = zarr.DirectoryStore(ZARR_OUT)
root = zarr.group(store=store, overwrite=True)
chunks = (1, 552, 688)
# NOTE(review): values are only 0/1 (uint8 would suffice); uint16 is kept so
# existing readers of this store see the same dtype.
binary_ds = root.create_dataset(DS_NAME, shape=(N, H, W), chunks=chunks, dtype="uint16")

# TIFF dir
Path(TIFF_DIR).mkdir(parents=True, exist_ok=True)


def _save_plane(i, img):
    """Write one 0/1 mask to plane ``i`` of the Zarr dataset and to its own TIFF."""
    binary_ds[i, :, :] = img
    # A 0/1 mask always trips imsave's low-contrast warning; it is expected here.
    imsave(
        os.path.join(TIFF_DIR, f"img_{i:03d}_binary_positive.tif"),
        img,
        check_contrast=False,
    )


# --------------------------
# PROCESS & SAVE
# --------------------------
total_pos = total_inst = 0

for i in range(N):
    labels = masks[i]  # (H, W) instance labels
    mrk = marker[i]  # (H, W) marker intensities

    ids = np.unique(labels)
    ids = ids[ids != 0]
    if ids.size == 0:
        # No bacteria in this image: save an all-zero mask.
        _save_plane(i, np.zeros((H, W), dtype=np.uint8))
        continue

    band = band_two_sided(labels, r_out=R_OUT, r_in=R_IN)
    coverage = band_coverage_per_label(band, mrk, thr=THR)

    # A bacterium with no band pixels gets coverage 0.0 and is negative.
    pos_ids = [int(oid) for oid in ids if coverage.get(int(oid), 0.0) >= COVERAGE_MIN]
    mask_pos = np.isin(labels, pos_ids).astype(np.uint8)  # paint WHOLE positive bacteria as 1
    _save_plane(i, mask_pos)

    total_pos += len(pos_ids)
    total_inst += len(ids)

print(f"Saved Zarr: {ZARR_OUT}  dataset: '{DS_NAME}'  shape={binary_ds.shape} chunks={binary_ds.chunks}")
print(f"Saved TIFFs to: {TIFF_DIR}")
print(f"Positive instances: {total_pos}/{total_inst} ({100*total_pos/max(total_inst,1):.1f}%)")

Loaded input: N=3, H=2208, W=2752


  return func(*args, **kwargs)
  return func(*args, **kwargs)


Saved Zarr: /mnt/efs/aimbl_2025/student_data/S-LS/binary_rings.zarr  dataset: 'binary_stack'  shape=(3, 2208, 2752) chunks=(1, 552, 688)
Saved TIFFs to: /mnt/efs/aimbl_2025/student_data/S-LS/binary_rings-tiff
Positive instances: 140/529 (26.5%)


  return func(*args, **kwargs)


Why these metrics fit your goal (“positive vs negative overlap”)

Precision (PPV) = TP / (TP + FP): of the pixels you label positive, what fraction are truly positive? This is the key metric if you are worried about false-positive marker calls along borders or in the background.

Recall (Sensitivity) = TP / (TP + FN): of the truly positive pixels, what fraction did you detect? Watch this if missing positive bacteria is costly.

F1 / Dice = 2·TP / (2·TP + FP + FN): the harmonic mean of precision and recall, and the standard overlap score for binary segmentation when you care about both kinds of error.

IoU (Jaccard) = TP / (TP + FP + FN): another commonly reported overlap score. It is stricter than Dice and related to it by IoU = Dice / (2 − Dice).

Specificity & Balanced Accuracy: specificity = TN / (TN + FP). Background pixels usually dominate these images, so plain accuracy can look excellent even when most positives are missed. Balanced accuracy, the mean of sensitivity and specificity, prevents this.

In [None]:
import os, re, csv
from pathlib import Path
import numpy as np
import imageio.v2 as imageio

# ---- Inputs / outputs for the evaluation (edit as needed) ----
PRED_DIR = "/mnt/efs/aimbl_2025/student_data/S-LS/binary_rings-predicted-tiff"  # predicted binary masks
GT_DIR = "/mnt/efs/aimbl_2025/student_data/S-LS/binary_rings-gt-tiff"  # ground-truth binary masks
CSV_PATH = "/mnt/efs/aimbl_2025/student_data/S-LS/binary_rings_metrics.csv"
WRITE_CSV = True  # also dump per-image + dataset-level metrics to CSV_PATH

# ---------- pairing helpers ----------
id_pat = re.compile(r"img_(\d+)")  # grabs the number in filenames like img_002_...


def index_by_id(folder, suffixes=(".tif", ".tiff")):
    """Map numeric image id -> file path for the TIFFs in *folder*.

    Files are visited in sorted name order. A file is indexed when it is a
    regular file, its extension (case-insensitive) is in *suffixes*, and its
    name contains ``img_<digits>`` (see ``id_pat``). Files without an id are
    skipped.

    If two files yield the same id, a warning is issued and the later file in
    sorted order wins, so a stray duplicate cannot silently mis-pair images.

    Raises
    ------
    FileNotFoundError
        If *folder* is not an existing directory.
    """
    import warnings

    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Image folder not found: {root}")

    wanted = {s.lower() for s in suffixes}
    d = {}
    for f in sorted(root.iterdir()):
        if not f.is_file() or f.suffix.lower() not in wanted:
            continue
        m = id_pat.search(f.name)
        if not m:
            continue
        idx = int(m.group(1))
        if idx in d:
            warnings.warn(
                f"{root}: image id {idx} matches both {d[idx].name!r} and {f.name!r}; "
                f"using {f.name!r}"
            )
        d[idx] = f
    return d

# Pair prediction and ground-truth files by their numeric image id.
pred_map = index_by_id(PRED_DIR)
gt_map   = index_by_id(GT_DIR)
common_ids   = sorted(set(pred_map) & set(gt_map))
missing_pred = sorted(set(gt_map) - set(pred_map))  # have GT but no prediction
missing_gt   = sorted(set(pred_map) - set(gt_map))  # have prediction but no GT

print(f"Found {len(common_ids)} matched image IDs.")
if missing_pred:
    print(f"IDs missing in PRED: {missing_pred}")
if missing_gt:
    print(f"IDs missing in GT:   {missing_gt}")

# Explicit raise instead of `assert`, which is stripped when Python runs with -O.
if not common_ids:
    raise RuntimeError(
        "No matching image IDs between predicted and GT folders "
        f"(PRED_DIR={PRED_DIR!r}: {len(pred_map)} indexed files; "
        f"GT_DIR={GT_DIR!r}: {len(gt_map)} indexed files)."
    )

# ---------- metrics helpers ----------
def confusion_counts(y_true: np.ndarray, y_pred: np.ndarray):
    """Return pixel-level ``(tp, fp, fn, tn)`` for two same-shape masks.

    Both inputs are coerced to bool first, so any nonzero value counts as
    positive. Without this coercion, ``~`` on integer masks (e.g. uint8 0/1)
    is a bitwise NOT and produces wrong counts.

    Raises
    ------
    ValueError
        If the masks differ in shape, which would otherwise broadcast silently.
    """
    t = np.asarray(y_true, dtype=bool)
    p = np.asarray(y_pred, dtype=bool)
    if t.shape != p.shape:
        raise ValueError(f"shape mismatch: y_true {t.shape} vs y_pred {p.shape}")

    tp = int(np.count_nonzero(t & p))
    fp = int(np.count_nonzero(p)) - tp  # predicted positive, not truly positive
    fn = int(np.count_nonzero(t)) - tp  # truly positive, not predicted
    tn = int(t.size) - tp - fp - fn     # everything else
    return tp, fp, fn, tn

def metrics_from_counts(tp, fp, fn, tn):
    """Turn confusion-matrix counts into overlap and classification scores.

    Every denominator gets a tiny epsilon (1e-12), so empty categories give
    0.0 instead of raising ZeroDivisionError.

    Returns a dict with precision, recall (sensitivity), f1, dice (identical
    to f1 for binary masks), iou, specificity, accuracy, balanced_accuracy,
    and the raw tp / fp / fn / tn counts.
    """
    eps = 1e-12

    def ratio(num, den):
        return num / (den + eps)

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)  # a.k.a. sensitivity
    specificity = ratio(tn, tn + fp)
    f1 = ratio(2 * precision * recall, precision + recall)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "dice": f1,
        "iou": ratio(tp, tp + fp + fn),
        "specificity": specificity,
        "accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "balanced_accuracy": 0.5 * (recall + specificity),
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "tn": tn,
    }

# ---------- evaluate per-image and aggregate ----------
def _plain_numbers(metrics):
    """Copy *metrics*, casting float / NumPy floating values to built-in ``float``."""
    return {
        key: (float(val) if isinstance(val, (float, np.floating)) else val)
        for key, val in metrics.items()
    }


def _read_binary(path):
    """Load a mask image and binarize it: every nonzero pixel counts as positive."""
    return imageio.imread(path).astype(bool)


rows = []  # one dict per matched image id
agg_tp = agg_fp = agg_fn = agg_tn = 0  # running pixel counts over all images

for idx in common_ids:
    pred_path, gt_path = pred_map[idx], gt_map[idx]
    y_pred = _read_binary(pred_path)
    y_true = _read_binary(gt_path)
    if y_pred.shape != y_true.shape:
        raise ValueError(f"Shape mismatch for ID {idx}: pred {y_pred.shape} vs gt {y_true.shape}")

    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    rows.append({
        "img_id": idx,
        "pred_file": pred_path.name,
        "gt_file": gt_path.name,
        **_plain_numbers(metrics_from_counts(tp, fp, fn, tn)),
    })

    agg_tp += tp
    agg_fp += fp
    agg_fn += fn
    agg_tn += tn

# Micro-average: pool every pixel of every image, then score once.
M = metrics_from_counts(agg_tp, agg_fp, agg_fn, agg_tn)

print("\n=== Dataset-level (micro over pixels) ===")
for caption, key in (
    ("Precision:        ", "precision"),
    ("Recall (sens.):   ", "recall"),
    ("F1 / Dice:        ", "f1"),
    ("IoU (Jaccard):    ", "iou"),
    ("Specificity:      ", "specificity"),
    ("Balanced Acc.:    ", "balanced_accuracy"),
):
    print(f"{caption}{M[key]:.4f}")
print(f"Accuracy:         {M['accuracy']:.4f}  (can be misleading with lots of background)")
print(f"TP/FP/FN/TN:      {M['tp']} / {M['fp']} / {M['fn']} / {M['tn']}")

# Short per-image preview (at most five rows).
print("\n=== Per-image (first few) ===")
for row in rows[:5]:
    print(
        f"img_{row['img_id']:03d}: F1={row['f1']:.4f}  IoU={row['iou']:.4f}  "
        f"Prec={row['precision']:.4f}  Rec={row['recall']:.4f}"
    )

# Optional CSV: one line per image, then a dataset-level summary line.
if WRITE_CSV:
    fieldnames = [
        "img_id", "pred_file", "gt_file",
        "precision", "recall", "f1", "dice", "iou",
        "specificity", "accuracy", "balanced_accuracy", "tp", "fp", "fn", "tn",
    ]
    summary_row = {
        "img_id": "DATASET",
        "pred_file": f"{len(rows)} images",
        "gt_file": "",
        **_plain_numbers(M),
    }
    with open(CSV_PATH, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows({name: row.get(name, "") for name in fieldnames} for row in rows)
        writer.writerow(summary_row)
    print(f"\nWrote per-image metrics and dataset summary to: {CSV_PATH}")