In [7]:
import os
import glob
import pickle
from tqdm import tqdm
import numpy as np
from skimage import io
from skimage.filters import gaussian
from skimage.color import rgb2gray
from cellpose import models
import torch

# -------- user-tunable ----------
# Settings read by the rest of this cell; edit them here before running.
USE_GPU = torch.cuda.is_available()  # passed as gpu= when the Cellpose model is built
DIAMETER = 10  # diameter= argument of model.eval; also stored per record under "dia"
FLOW_THRESHOLD = 0.4  # flow_threshold= argument of model.eval
CELLPROB_THRESHOLD = 0.0  # cellprob_threshold= argument of model.eval
BLUR = 0.0  # Gaussian sigma used during preprocessing; 0 skips the blur step
IMG_ROOT = "content"  # directory tree searched recursively for input images
OUT_MASK_ROOT = "masks"  # masks are written to OUT_MASK_ROOT/<molecule_id>/
RESULTS_PKL = "results.pkl"  # pickle of all result records; reloaded on start to resume
SAVE_EVERY = 5  # checkpoint cadence for RESULTS_PKL, counted in new records
# --------------------------------

# Load any existing results first and only then refresh the backup. Copying
# before a successful load would overwrite the last good .bak with a file that
# may turn out to be truncated or corrupt (e.g. after an interrupted save).
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# resume from a results file that this notebook wrote itself.
if os.path.exists(RESULTS_PKL):
    with open(RESULTS_PKL, "rb") as f:
        try:
            results = pickle.load(f)
        except Exception as e:
            print("Failed to load existing results.pkl:", e)
            raise
    # The rest of the cell appends to `results` and calls .get() on its items,
    # so fail here with a clear message instead of later with a confusing one.
    if not isinstance(results, list):
        raise TypeError(
            f"{RESULTS_PKL} holds a {type(results).__name__}, expected a list of records"
        )
    print(f"Loaded {len(results)} existing records from {RESULTS_PKL}")

    # The file is known to be readable, so it is safe to use as the backup.
    import shutil
    shutil.copyfile(RESULTS_PKL, RESULTS_PKL + ".bak")
    print(f"Backup created: {RESULTS_PKL}.bak")
else:
    results = []
    print("No existing results.pkl found — starting fresh.")

# Index what earlier runs already covered: the absolute paths of images that
# have a record, and the highest run id handed out so far (0 when there is none).
processed_imgs = {
    os.path.abspath(prev_path)
    for prev_path in (rec.get("image_path") for rec in results)
    if prev_path
}
max_run_id = max([0, *((rec.get("run", 0) or 0) for rec in results)])

print("Already processed images:", len(processed_imgs))
print("Starting run_id from", max_run_id + 1)

# Cellpose model constructed with only the GPU switch; no model_type or
# verbosity argument is passed.
model = models.CellposeModel(gpu=USE_GPU)

def _rescale_to_uint8(values):
    """Min-max stretch *values* onto 0..255 and return the result as uint8.

    A constant (or empty, or entirely non-finite) input yields all zeros of the
    same shape. For floating-point input, NaN and -inf are treated as the
    smallest finite value and +inf as the largest, so a few bad pixels do not
    blank the whole image.
    """
    a = np.asarray(values)
    if a.size == 0:
        return np.zeros(a.shape, dtype=np.uint8)

    if np.issubdtype(a.dtype, np.floating):
        finite = np.isfinite(a)
        if not finite.any():
            return np.zeros(a.shape, dtype=np.uint8)
        lo = float(a[finite].min())
        hi = float(a[finite].max())
        if not finite.all():
            a = np.nan_to_num(a, nan=lo, posinf=hi, neginf=lo)
    else:
        lo, hi = float(a.min()), float(a.max())

    if hi <= lo:
        return np.zeros(a.shape, dtype=np.uint8)
    return ((a - lo) / (hi - lo) * 255.0).astype(np.uint8)


# helper: robust preprocess that handles RGBA, single-channel, weird shapes
def preprocess_img_safe(img, blur):
    """Convert a loaded image into a single-channel uint8 array.

    Steps, in order:
      1. Arrays with more than three axes are reduced by taking index 0 along
         each leading axis (first frame / first plane).
      2. A three-axis array is interpreted by its last axis: size 1 is squeezed
         away, size 3 or 4 is treated as RGB / RGBA (alpha is dropped, not
         composited) and converted to gray, and anything else is treated as a
         (frames, H, W) stack from which the first frame is kept.
      3. Non-uint8 data is min-max stretched to 0..255.
      4. If ``blur`` is a positive number, a Gaussian blur with that sigma is
         applied and the result is stretched to 0..255 again.

    NOTE: the layout is guessed from the shape alone, so a (frames, H, W) stack
    whose width happens to be 3 or 4 is read as a colour image.

    Args:
        img: array-like image data.
        blur: Gaussian sigma; ``0``, ``0.0`` or ``None`` disables blurring.

    Returns:
        numpy.ndarray of dtype uint8 (two-dimensional for 2-D or higher input).
    """
    arr = np.asarray(img)

    # (T, Z, H, W), (frames, H, W, C), ...: keep the first item of every
    # leading axis until at most three axes remain.
    while arr.ndim > 3 and arr.shape[0] > 0:
        arr = arr[0]

    if arr.ndim == 3:
        channels = arr.shape[2]
        if channels == 1:
            # (H, W, 1) is already single-channel; just drop the axis.
            arr = arr[..., 0]
        elif channels in (3, 4):
            rgb = arr[..., :3]  # drop alpha if present
            is_uint8 = rgb.dtype == np.uint8
            try:
                gray = rgb2gray(rgb)  # float in 0..1 for uint8 input
            except Exception:
                # Deliberate best-effort fallback: plain channel average,
                # which stays in the input's own value range.
                gray = None
            if gray is None:
                gray = rgb.mean(axis=2)
            elif is_uint8:
                gray = gray * 255
            # Only uint8 input has a known 0..255 range. Other dtypes are left
            # as float here and stretched below, which avoids the wrap-around
            # a direct uint8 cast would cause for out-of-range values.
            arr = gray.astype(np.uint8) if is_uint8 else gray
        elif arr.shape[0] > 0:
            # Presumably (frames, H, W): keep the first frame.
            arr = arr[0]

    if arr.dtype != np.uint8:
        arr = _rescale_to_uint8(arr)

    # blur if requested
    if blur and blur > 0:
        arr = _rescale_to_uint8(gaussian(arr, sigma=blur))
    return arr

# collect all image paths under IMG_ROOT (recursive)
exts = ("*.png", "*.tif", "*.tiff", "*.jpg", "*.jpeg")
# Filter the os.walk listing by lower-cased suffix instead of globbing in each
# directory: a glob pattern built from the directory path breaks when that path
# contains glob metacharacters such as "[" or "]", is case-sensitive on POSIX
# (so "IMG.TIF" is missed), and can match directories. Dotfiles are skipped,
# as they were with glob's "*" patterns.
_img_suffixes = {os.path.splitext(pattern)[1].lower() for pattern in exts}
found = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(IMG_ROOT)
    for name in files
    if not name.startswith(".")
    and os.path.splitext(name)[1].lower() in _img_suffixes
)
print("Total images found under", IMG_ROOT, ":", len(found))

# filter only not-yet-processed
to_process = [p for p in found if os.path.abspath(p) not in processed_imgs]
print("Images to process:", len(to_process))

def _save_results(records, path):
    """Pickle *records* to *path* without ever leaving a half-written file.

    The data is written to ``path + ".tmp"`` and then moved over *path* with
    os.replace, so an interrupt during the dump leaves the previous file intact.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(records, f)
    os.replace(tmp_path, path)


# Segment every not-yet-processed image, write one mask file per returned mask
# and append one record per mask to `results`, checkpointing along the way.
run_id = max_run_id

new_count = 0   # records added during this run
unsaved = 0     # records added since the last checkpoint
for img_path in tqdm(to_process, desc="Resuming processing"):
    try:
        img_raw = io.imread(img_path)
    except Exception as e:
        print("Failed to read", img_path, "->", e)
        continue

    img = preprocess_img_safe(img_raw, BLUR)

    # call model.eval
    try:
        res = model.eval([img], diameter=DIAMETER, flow_threshold=FLOW_THRESHOLD, cellprob_threshold=CELLPROB_THRESHOLD)
    except Exception as e:
        print("Cellpose eval failed for", img_path, "->", e)
        continue

    # The length of the returned tuple varies (3 or 4 items were both handled
    # here originally); only the first item, the masks, is used.
    if not isinstance(res, (tuple, list)):
        print("Unexpected return from model.eval for", img_path)
        continue
    masks = res[0] if len(res) > 0 else None
    if masks is None:
        print("No masks returned for", img_path)
        continue

    # molecule id = first subdir under IMG_ROOT (if exists), else the file stem
    base = os.path.splitext(os.path.basename(img_path))[0]
    parts = os.path.relpath(img_path, IMG_ROOT).split(os.sep)
    molecule_id = parts[0] if len(parts) >= 2 else base

    out_mask_dir = os.path.join(OUT_MASK_ROOT, molecule_id)
    os.makedirs(out_mask_dir, exist_ok=True)

    for mask in masks:
        run_id += 1
        out_name = f"{molecule_id}__{base}__run{run_id:04d}.png"
        out_path = os.path.join(out_mask_dir, out_name)
        try:
            io.imsave(out_path, mask.astype(np.uint16))
        except Exception as e:
            # Fall back to a raw .npy dump and point the record at the file
            # that actually exists rather than at the PNG that was not written.
            png_path = out_path
            out_path = png_path + ".npy"
            np.save(out_path, mask)
            print("PNG save failed for", png_path, "->", e, "; saved", out_path, "instead")

        obj_pixels = int(np.count_nonzero(mask > 0))
        n_objects = int(np.count_nonzero(np.unique(mask) > 0))

        rec = {
            "run": run_id,
            "molecule_id": molecule_id,
            "image_path": img_path,
            "image_name": os.path.basename(img_path),
            "mask_path": out_path,
            "mask": mask,
            "dia": DIAMETER,
            "flow_threshold": FLOW_THRESHOLD,
            "cellprob_threshold": CELLPROB_THRESHOLD,
            "blur": BLUR,
            "obj_pixels": obj_pixels,
            "n_objects": n_objects,
        }
        results.append(rec)
        new_count += 1
        unsaved += 1

    # periodic save: only when something new exists, and robust to
    # SAVE_EVERY <= 0 (which then checkpoints after every image with output)
    if unsaved > 0 and unsaved >= SAVE_EVERY:
        _save_results(results, RESULTS_PKL)
        unsaved = 0
        print(f"Saved intermediate results ({len(results)} total) to {RESULTS_PKL}")

# final save
_save_results(results, RESULTS_PKL)
print(f"Finished. Total records: {len(results)}. Saved to {RESULTS_PKL}")


Backup created: results.pkl.bak
Loaded 0 existing records from results.pkl
Already processed images: 0
Starting run_id from 1
Total images found under content : 27
Images to process: 27


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing:  19%|███████████▋                                                   | 5/27 [04:32<19:59, 54.54s/it]

Saved intermediate results (5 total) to results.pkl


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing:  37%|██████████████████████▉                                       | 10/27 [09:06<15:28, 54.64s/it]

Saved intermediate results (10 total) to results.pkl


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing:  56%|██████████████████████████████████▍                           | 15/27 [13:37<10:50, 54.19s/it]

Saved intermediate results (15 total) to results.pkl


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing:  74%|█████████████████████████████████████████████▉                | 20/27 [18:10<06:24, 54.94s/it]

Saved intermediate results (20 total) to results.pkl


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing:  93%|█████████████████████████████████████████████████████████▍    | 25/27 [22:43<01:49, 54.84s/it]

Saved intermediate results (25 total) to results.pkl


  return func(*args, **kwargs)
  return func(*args, **kwargs)
Resuming processing: 100%|██████████████████████████████████████████████████████████████| 27/27 [24:33<00:00, 54.58s/it]

Finished. Total records: 27. Saved to results.pkl



