In [1]:
# Libraries
import os
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from transformers import pipeline

In [2]:
# Paths — set the GAP_JUNCTION_ROOT environment variable to point at another checkout
# instead of editing this cell; the default keeps the original project location.
PROJECT_ROOT = Path(
    os.environ.get("GAP_JUNCTION_ROOT", "/Users/tommy/Projects/gap-junction-segmentation")
)
DATA_DIR = PROJECT_ROOT / "data" / "sem_adult_imgs_test"
OUT_DIR = PROJECT_ROOT / "outputs" / "sam3_masks"
OUT_DIR.mkdir(parents=True, exist_ok=True)

image_paths = sorted(DATA_DIR.glob("*.png"))
# Fail with a readable message instead of an IndexError on image_paths[0] below.
assert image_paths, f"No .png images found in {DATA_DIR}"
len(image_paths), image_paths[0]

(10,
 PosixPath('/Users/tommy/Projects/gap-junction-segmentation/data/sem_adult_imgs_test/SEM_adult_image_export_s200_NR.png'))

In [3]:
# Load the SAM3 automatic mask-generation pipeline.
# Alternative checkpoint noted during development: "facebook/sam2.1-hiera-large" (SAM 2.1).
MODEL_ID = "facebook/sam3"

# Prefer a GPU backend when one is present: CUDA, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

pipe = pipeline(
    "mask-generation",
    model=MODEL_ID,
    trust_remote_code=True,
    device=device,
)

Loading weights:   0%|          | 0/685 [00:00<?, ?it/s]

In [4]:
# Run SAM on a single image (sanity check) with the pipeline's default settings.
img_path = image_paths[0]
# The context manager closes the underlying file; convert() returns an independent,
# fully loaded image, so `img` stays usable after the `with` block.
with Image.open(img_path) as src:
    img = src.convert("RGB")

output = pipe(img)

def extract_masks(output, height: int, width: int) -> np.ndarray:
    """Normalize a mask-generation result into a boolean array of shape (N, H, W).

    Recognised layouts of ``output``:

    * ``dict`` with a ``"masks"`` entry (tensor, ndarray, or sequence of 2-D masks);
    * ``dict`` with a single ``"mask"`` or ``"segmentation"`` entry;
    * ``list`` of dicts, each holding a ``"mask"`` or ``"segmentation"`` entry
      (``"segmentation"`` is the key used by segment-anything's automatic generator);
    * ``list`` of 2-D masks (tensors, arrays, or array-likes).

    Args:
        output: Raw value returned by the pipeline.
        height: Image height; only used to shape the result when no masks are found.
        width: Image width; only used to shape the result when no masks are found.

    Returns:
        Boolean array of shape (N, H, W); N may be 0. A single 2-D mask is promoted
        to (1, H, W) and a singleton channel axis, (N, 1, H, W), is squeezed away.

    Raises:
        ValueError: if ``output`` has an unrecognised layout or the masks cannot be
            arranged as a 3-D stack.
    """
    mask_keys = ("mask", "segmentation")
    empty = np.zeros((0, height, width), dtype=bool)

    def _to_numpy(m) -> np.ndarray:
        # Tensors may live on an accelerator: detach and move to host first.
        if isinstance(m, torch.Tensor):
            return m.detach().cpu().numpy()
        return m if isinstance(m, np.ndarray) else np.array(m)

    masks_obj = None
    if isinstance(output, dict):
        masks_obj = output.get("masks", None)
        if masks_obj is None:
            key = next((k for k in mask_keys if k in output), None)
            if key is not None:
                masks_obj = [output[key]]
    elif isinstance(output, list):
        if len(output) == 0:
            return empty
        if isinstance(output[0], dict):
            # Previously a dict without "mask" fell through and was stacked as an
            # object array, silently producing a meaningless 1-D result.
            key = next((k for k in mask_keys if k in output[0]), None)
            if key is None:
                raise ValueError(
                    f"List-of-dict output has none of the keys {mask_keys}: {list(output[0])}"
                )
            masks_obj = [o[key] for o in output]
        else:
            masks_obj = output

    if masks_obj is None:
        raise ValueError(f"Unexpected pipeline output type/format: {type(output)}")

    # Normalize to a stacked numpy array
    if isinstance(masks_obj, (torch.Tensor, np.ndarray)):
        masks = _to_numpy(masks_obj)
    else:
        masks_list = [_to_numpy(m) for m in masks_obj]
        if not masks_list:
            return empty
        masks = np.stack(masks_list, axis=0)

    if masks.size == 0:
        return empty

    # Force shape (N, H, W) and bool dtype
    if masks.ndim == 2:
        masks = masks[None, :, :]
    elif masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks[:, 0]
    if masks.ndim != 3:
        raise ValueError(f"Expected masks of shape (N, H, W), got {masks.shape}")
    return masks.astype(bool)

def erode3x3(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
    """Binary erosion with a 3x3 square structuring element (no SciPy).

    Pixels outside the image are treated as background, so foreground touching
    the image border is eroded as well.

    Args:
        mask: 2-D array; any dtype is accepted and interpreted as ``mask != 0``
            (e.g. a uint8 0/255 image works as well as a bool mask).
        iterations: Number of successive 3x3 erosions; k iterations equal a single
            erosion with a (2k+1)x(2k+1) square. ``0`` returns a boolean copy.

    Returns:
        Boolean array with the same shape as ``mask``.

    Raises:
        ValueError: if ``mask`` is not 2-D or ``iterations`` is negative.
    """
    # np.array (not asarray) so iterations=0 never hands back the caller's array.
    out = np.array(mask, dtype=bool)
    if out.ndim != 2:
        raise ValueError(f"erode3x3 expects a 2-D mask, got shape {out.shape}")
    if iterations < 0:
        raise ValueError(f"iterations must be >= 0, got {iterations}")

    h, w = out.shape
    for _ in range(iterations):
        p = np.pad(out, 1, mode="constant", constant_values=False)
        eroded = np.ones((h, w), dtype=bool)
        # A pixel survives only if all 9 pixels of its 3x3 neighbourhood are set.
        for dy in range(3):
            for dx in range(3):
                eroded &= p[dy : dy + h, dx : dx + w]
        out = eroded
    return out

masks = extract_masks(output, img.height, img.width)
n_masks = masks.shape[0]
mask_hw = masks.shape[-2:]
print(f"{img_path.name}: {n_masks} masks, size {mask_hw}")

# Foreground = pixel-wise OR across every instance mask, written as a 0/255 PNG.
union_bool = masks.any(axis=0)
union = np.where(union_bool, 255, 0).astype(np.uint8)
union_path = OUT_DIR / f"{img_path.stem}_union.png"
Image.fromarray(union).save(union_path)

# Membrane-like outline: foreground pixels that a 3x3 erosion removes.
membrane_bool = np.logical_and(union_bool, np.logical_not(erode3x3(union_bool)))
membrane = np.where(membrane_bool, 255, 0).astype(np.uint8)
membrane_path = OUT_DIR / f"{img_path.stem}_membrane.png"
Image.fromarray(membrane).save(membrane_path)

union_path, membrane_path

SEM_adult_image_export_s200_NR.png: 37 masks, size (1024, 1024)


(PosixPath('/Users/tommy/Projects/gap-junction-segmentation/outputs/sam3_masks/SEM_adult_image_export_s200_NR_union.png'),
 PosixPath('/Users/tommy/Projects/gap-junction-segmentation/outputs/sam3_masks/SEM_adult_image_export_s200_NR_membrane.png'))

In [None]:
# Batch run: save union + membrane boundary masks per image.
#
# Keyword names below are the ones transformers' MaskGenerationPipeline recognises
# (`points_per_crop`, `crops_n_layers`, ...). The segment-anything-style names used
# before (`points_per_side`, `crop_n_layers`, `min_mask_region_area`) are not
# forwarded by that pipeline, so they silently had no effect.
# NOTE(review): confirm against the pipeline class actually loaded for facebook/sam3.
MIN_MASK_AREA = 200  # px; drop masks smaller than this after generation (tune!)

gen_kwargs = dict(
    pred_iou_thresh=0.01,         # lower -> keep more masks
    stability_score_thresh=0.8,   # lower -> keep more masks
    crops_n_layers=1,             # enables tiled refinement
    crop_overlap_ratio=0.2,
    points_per_crop=64,           # denser sampling
)

for img_path in image_paths:
    # Context manager closes each file; convert() returns an independent loaded copy.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    output = pipe(img, **gen_kwargs)
    masks = extract_masks(output, img.height, img.width)

    # The HF pipeline has no min-area option; approximate one by discarding whole
    # masks below MIN_MASK_AREA pixels (segment-anything instead removed small
    # islands/holes inside each mask). Works for N == 0 as well.
    masks = masks[masks.sum(axis=(1, 2)) >= MIN_MASK_AREA]

    union_bool = np.any(masks, axis=0)
    union = union_bool.astype(np.uint8) * 255
    Image.fromarray(union).save(OUT_DIR / f"{img_path.stem}_union.png")

    membrane_bool = union_bool & (~erode3x3(union_bool))
    membrane = membrane_bool.astype(np.uint8) * 255
    Image.fromarray(membrane).save(OUT_DIR / f"{img_path.stem}_membrane.png")

print(f"Saved {len(image_paths)} union+membrane masks to {OUT_DIR}")

Saved 10 union+membrane masks to /Users/tommy/Projects/gap-junction-segmentation/outputs/sam3_masks


In [6]:
# Re-run the pipeline with default settings on the last image of the batch.
# Load it explicitly so this cell no longer depends on `img` leaking out of the
# batch loop (same result on Restart & Run All, but order-independent).
with Image.open(image_paths[-1]) as src:
    img = src.convert("RGB")
output = pipe(img)