In [None]:
import numpy as np
from skimage.io import imread, imsave
import napari
from pystackreg import StackReg
from cellpose import models

import numpy as np
from skimage.measure import label, regionprops
from skimage.measure import regionprops, label

In [4]:
model = models.Cellpose(model_type='cyto2',gpu=True)

In [5]:
im1_path = 'aligned_00.tif'
im2_path = 'aligned_01.tif'
im3_path = 'aligned_02.tif'

In [7]:
im1 = imread(im1_path)
im2 = imread(im2_path)
im3 = imread(im3_path)

In [8]:
mask1,_,_,_ = model.eval(im1,diameter=50)
mask2,_,_,_ = model.eval(im2,diameter=50)
mask3,_,_,_ = model.eval(im3,diameter=50)

In [9]:
mask_stack = np.stack([mask1, mask2, mask3], axis=0)

In [89]:
import numpy as np
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential

def compute_iou(mask1, mask2):
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    return intersection / union if union > 0 else 0.0

def get_overlap_bbox(bbox1, bbox2, shape):
    minr = max(bbox1[0], bbox2[0])
    minc = max(bbox1[1], bbox2[1])
    maxr = min(bbox1[2], bbox2[2])
    maxc = min(bbox1[3], bbox2[3])
    if minr >= maxr or minc >= maxc:
        return None
    return (minr, minc, maxr, maxc)

def crop_mask(mask, bbox):
    minr, minc, maxr, maxc = bbox
    return mask[minr:maxr, minc:maxc]

def prune_ambiguous_matches(match_map, n_frames):
    # Step 1: Build reverse index: frame_id → label → [anchor_ids]
    label_usage = [dict() for _ in range(n_frames)]

    for anchor_label, matched in match_map.items():
        for frame_idx, label in enumerate(matched):
            if label not in label_usage[frame_idx]:
                label_usage[frame_idx][label] = []
            label_usage[frame_idx][label].append(anchor_label)

    # Step 2: Collect anchor_labels to remove
    anchors_to_remove = set()
    for frame_idx in range(n_frames):
        for label, anchors in label_usage[frame_idx].items():
            if len(anchors) > 1:
                anchors_to_remove.update(anchors)

    # Step 3: Filter out ambiguous anchor matches
    pruned = {
        anchor: matched
        for anchor, matched in match_map.items()
        if anchor not in anchors_to_remove
    }
    return pruned

def relabel_stack_by_anchor(stack, anchor_index=0, iou_threshold=0.7):
    # Step 1: Clean and relabel masks to sequential ints
    #stack = [relabel_sequential(mask)[0] for mask in stack]
    height, width = stack[0].shape

    # Step 2: Precompute bounding boxes per frame
    label_bboxes = []
    label_coords = []
    for frame in stack:
        props = regionprops(frame)
        bbox_dict = {p.label: p.bbox for p in props}
        label_bboxes.append(bbox_dict)
        coord_dict = {p.label: p.coords for p in props}
        label_coords.append(coord_dict)

    anchor = stack[anchor_index]
    anchor_props = regionprops(anchor)

    match_map = {}  # anchor_label -> [label_in_frame_0, ..., label_in_frame_n]

    for anchor_prop in anchor_props:
        anchor_label = anchor_prop.label
        anchor_bbox = anchor_prop.bbox

        consistent_labels = []
        valid = True

        
        for i, frame in enumerate(stack):
            if i == anchor_index:
                consistent_labels.append(anchor_label)
                continue

            best_label = []

            frame_crop = frame[anchor_bbox[0]:anchor_bbox[2], anchor_bbox[1]:anchor_bbox[3]]
            candidate_labels = np.unique(frame_crop)
            candidate_labels = candidate_labels[candidate_labels != 0]

            for cand_label in candidate_labels:
                if cand_label not in label_bboxes[i]:
                    continue
                cand_bbox = label_bboxes[i][cand_label]
                overlap_bbox = get_overlap_bbox(anchor_bbox, cand_bbox, frame.shape)
                if overlap_bbox is None:
                    continue

                anchor_part = crop_mask(anchor, overlap_bbox)
                anchor_part = (anchor_part==anchor_label)
                cand_part = crop_mask(frame, overlap_bbox)
                cand_part = (cand_part == cand_label)

                iou = compute_iou(anchor_part, cand_part)
                if iou > iou_threshold:
                    best_label.append(cand_label)

            if len(best_label) == 1:
                consistent_labels.append(best_label[0])
            else:
                valid = False
                break

        # Must match uniquely across all frames
        if valid and len(set(consistent_labels)) == len(stack):
            match_map[anchor_label] = consistent_labels

    # Step 3: Prune ambiguous matches
    match_map = prune_ambiguous_matches(match_map, n_frames=len(stack))

    # Step 4: Construct output stack
    output_stack = []
    for i, frame in enumerate(stack):
        relabeled = np.zeros_like(frame, dtype=np.uint16)
        coord_dict = label_coords[i]
        for new_label, matched in enumerate(match_map.values(), start=1):
            old_label = matched[i]
            if old_label in coord_dict:
                coords = coord_dict[old_label]
                relabeled[tuple(coords.T)] = new_label
        output_stack.append(relabeled)

    return output_stack


In [90]:
mask_stack_match = relabel_stack_by_anchor(mask_stack[:, :, :], anchor_index=2, iou_threshold=0.5)

In [91]:
viewer = napari.Viewer()
viewer.add_image(im1, name='Image 1',colormap = 'green', blending='additive')
viewer.add_image(im2, name='Image 2', colormap='yellow', blending='additive')
viewer.add_image(im3, name='Image 3', colormap='magenta', blending='additive')
viewer.add_labels(mask_stack_match[0], name='Mask 1')
viewer.add_labels(mask_stack_match[1], name='Mask 2')
viewer.add_labels(mask_stack_match[2], name='Mask 3')

<Labels layer 'Mask 3' at 0x1baa3f24050>

In [92]:
[len(np.unique(frame)) for frame in mask_stack_match]

[2047, 2047, 2047]