In [1]:
import os
import pandas as pd
from pathlib import Path

import numpy as np
import nibabel as nib
from skimage.measure import label as connected_components

import torch
from torch.utils.data import Dataset, DataLoader

import scipy.ndimage as ndi

import micro_sam.training as sam_training
from micro_sam.automatic_segmentation import (
    get_predictor_and_segmenter,
    automatic_instance_segmentation,
)

# Root directory with your finetune patches
ROOT = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches")
print("ROOT exists:", ROOT.exists())

# NEW: where to save finetuned checkpoints and predictions
CHECKPOINT_ROOT = Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_checkpoints")
PRED_ROOT = Path("/midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_preds")

CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
PRED_ROOT.mkdir(parents=True, exist_ok=True)



  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


ROOT exists: True


In [2]:
def compute_distance_to_center(inst_mask: np.ndarray) -> np.ndarray:
    """
    For each instance ID > 0, compute distance to its center of mass.
    Normalize to [0,1] per instance and invert so that center=1, edges=0.
    Outside objects: 0.
    """
    H, W = inst_mask.shape
    out = np.zeros((H, W), dtype=np.float32)

    instance_ids = np.unique(inst_mask)
    instance_ids = instance_ids[instance_ids != 0]

    yy, xx = np.indices((H, W), dtype=np.float32)

    for iid in instance_ids:
        mask = (inst_mask == iid)
        if not np.any(mask):
            continue

        cy, cx = ndi.center_of_mass(mask.astype(np.float32))
        dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)

        maxd = dist[mask].max() + 1e-6
        dist_norm = dist / maxd  # 0..1 inside object
        out[mask] = 1.0 - dist_norm[mask]  # 1 at center, 0 at furthest

    return out


def compute_distance_to_boundary(inst_mask: np.ndarray) -> np.ndarray:
    """
    For each instance ID > 0, compute distance transform inside the object.
    Normalize to [0,1] per instance. Outside objects: 0.
    """
    H, W = inst_mask.shape
    out = np.zeros((H, W), dtype=np.float32)

    instance_ids = np.unique(inst_mask)
    instance_ids = instance_ids[instance_ids != 0]

    for iid in instance_ids:
        mask = (inst_mask == iid)
        if not np.any(mask):
            continue

        dist = ndi.distance_transform_edt(mask)
        maxd = dist[mask].max() + 1e-6
        dist_norm = dist / maxd  # 0..1 inside object
        out[mask] = dist_norm[mask]

    return out


In [3]:
class Selma2DSliceDataset(Dataset):
    def __init__(self, raw_label_pairs):
        """
        raw_label_pairs: list of (raw_path, label_path)
        Each raw_path is a single-channel NIfTI (e.g. *_ch0.nii.gz or *_ch1.nii.gz).
        """
        self.pairs = raw_label_pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        raw_path, label_path = self.pairs[idx]
        raw_path = Path(raw_path)
        label_path = Path(label_path)

        # --- load raw and label volumes ---
        raw_vol = nib.load(str(raw_path)).get_fdata().astype(np.float32)
        lab_vol = nib.load(str(label_path)).get_fdata().astype(np.float32)

        assert raw_vol.shape == lab_vol.shape, f"shape mismatch: {raw_vol.shape} vs {lab_vol.shape}"
        Z = raw_vol.shape[0]

        # --- find a slice that contains at least one object ---
        for _ in range(20):
            z = np.random.randint(0, Z)
            lab_slice = lab_vol[z]
            if lab_slice.max() > 0:
                break
        else:
            # fallback: deterministic scan
            found = False
            for z in range(Z):
                lab_slice = lab_vol[z]
                if lab_slice.max() > 0:
                    found = True
                    break
            if not found:
                # no instances in this volume at all; pick a different pair
                return self.__getitem__((idx + 1) % len(self.pairs))

        img_slice = raw_vol[z]
        lab_slice = lab_vol[z]

        # --- instance mask (connected components) ---
        bin_mask = (lab_slice > 0).astype(np.uint8)
        inst_mask = connected_components(bin_mask, connectivity=1).astype(np.int32)

        if inst_mask.max() == 0:
            # no instance after CC, rare corner case
            return self.__getitem__((idx + 1) % len(self.pairs))

        # --- robust image normalization to [0,255] for SAM ---
        p1 = np.percentile(img_slice, 1)
        p99 = np.percentile(img_slice, 99)
        if p99 > p1:
            img_slice_norm = (img_slice - p1) / (p99 - p1)
        else:
            img_slice_norm = np.zeros_like(img_slice, dtype=np.float32)

        img_slice_norm = np.clip(img_slice_norm, 0.0, 1.0)
        img_slice_uint = (img_slice_norm * 255.0).astype(np.float32)

        # --- label channels ---
        # channel 0: instance ids
        instance_channel = inst_mask.astype(np.float32)

        # foreground mask
        fg = (inst_mask > 0).astype(np.float32)

        # distances
        dist_center = compute_distance_to_center(inst_mask)
        dist_boundary = compute_distance_to_boundary(inst_mask)

        # mask distances and clamp to [0,1]
        dist_center = np.clip(dist_center * fg, 0.0, 1.0)
        dist_boundary = np.clip(dist_boundary * fg, 0.0, 1.0)

        # assemble y: (4, H, W)
        #   y[0] = instance ids (integer-valued)
        #   y[1] = foreground (0..1)
        #   y[2] = center distance (0..1)
        #   y[3] = boundary distance (0..1)
        y_np = np.stack(
            [instance_channel, fg, dist_center, dist_boundary],
            axis=0
        ).astype(np.float32)

        # to torch
        x = torch.from_numpy(img_slice_uint[None, ...])  # (1,H,W), float32 in [0,255]
        y = torch.from_numpy(y_np)                        # (4,H,W)

        return x, y


In [4]:
from collections import defaultdict

rng = np.random.default_rng(42)

# --------------------------------------------------------
# 1) COLLECT (class_name, raw_path, label_path) FOR ALL CHANNELS
# --------------------------------------------------------
records = []  # (class_name, raw_path, label_path)

for class_dir in ROOT.iterdir():
    if not class_dir.is_dir():
        continue

    class_name = class_dir.name  # e.g. "amyloid_plaque_patches", "vessels_patches"

    # any channel: *_ch0.nii.gz, *_ch1.nii.gz, ... but exclude *_label.nii.gz
    for raw_path in sorted(class_dir.glob("*_ch*.nii.gz")):
        if "_label" in raw_path.name:
            continue
        # label for this specific channel file
        label_path = raw_path.with_name(raw_path.name.replace(".nii.gz", "_label.nii.gz"))
        if label_path.exists():
            records.append((class_name, raw_path, label_path))

print("Total labeled volumes:", len(records))

# --------------------------------------------------------
# 2) GROUP BY DATATYPE (class_name)
# --------------------------------------------------------
by_class = defaultdict(list)  # class_name -> list[(raw,label)]

for cls, raw, lab in records:
    by_class[cls].append((raw, lab))

for cls, items in by_class.items():
    print(cls, ":", len(items), "channel-files (samples)")

# --------------------------------------------------------
# 3) STRATIFIED SPLIT:
#    - 2 TEST samples per datatype
#    - Remaining -> 80/20 TRAIN/VAL (per datatype)
#    - Ensure train & val each get >=1 sample per datatype where possible
# --------------------------------------------------------
train_pairs = []
val_pairs = []
test_pairs = []

for cls, items in by_class.items():
    items = items.copy()
    rng.shuffle(items)

    if len(items) < 4:
        raise ValueError(
            f"Not enough samples in class '{cls}' to hold out 2 test and still have train+val."
        )

    # 2 test samples for this datatype
    test_cls = items[:2]
    rest = items[2:]

    n_rest = len(rest)
    # initial 80/20 split
    n_train = int(round(0.8 * n_rest))
    # ensure at least 1 val and 1 train
    if n_train < 1:
        n_train = 1
    if n_train > n_rest - 1:
        n_train = n_rest - 1
    n_val = n_rest - n_train

    train_cls = rest[:n_train]
    val_cls = rest[n_train:]

    print(
        f"{cls}: total={len(items)} -> test={len(test_cls)}, "
        f"train={len(train_cls)}, val={len(val_cls)}"
    )

    test_pairs.extend(test_cls)
    train_pairs.extend(train_cls)
    val_pairs.extend(val_cls)

# Shuffle globally
rng.shuffle(train_pairs)
rng.shuffle(val_pairs)
rng.shuffle(test_pairs)

print("Final sizes:")
print("  train:", len(train_pairs))
print("  val:  ", len(val_pairs))
print("  test: ", len(test_pairs))


Total labeled volumes: 88
vessels_patches : 40 channel-files (samples)
cell_nucleus_patches : 25 channel-files (samples)
c_fos_positive_patches : 4 channel-files (samples)
amyloid_plaque_patches : 19 channel-files (samples)
vessels_patches: total=40 -> test=2, train=30, val=8
cell_nucleus_patches: total=25 -> test=2, train=18, val=5
c_fos_positive_patches: total=4 -> test=2, train=1, val=1
amyloid_plaque_patches: total=19 -> test=2, train=14, val=3
Final sizes:
  train: 63
  val:   17
  test:  8


In [5]:
# Map each raw_path to its datatype (class_name)
cls_by_raw = {Path(raw): cls for cls, raw, lab in records}

print("Example mapping entries (first 5):")
for i, (raw, cls) in enumerate(cls_by_raw.items()):
    if i >= 5:
        break
    print(cls, "->", raw)


Example mapping entries (first 5):
vessels_patches -> /midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_000_vol019_ch0.nii.gz
vessels_patches -> /midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_000_vol019_ch1.nii.gz
vessels_patches -> /midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_001_vol021_ch0.nii.gz
vessels_patches -> /midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_001_vol021_ch1.nii.gz
vessels_patches -> /midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_002_vol018_ch0.nii.gz


In [6]:
train_dataset = Selma2DSliceDataset(train_pairs)
val_dataset = Selma2DSliceDataset(val_pairs)

batch_size = 1  # keep small for microSAM

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# microSAM's JointSamTrainer expects these attributes
train_loader.shuffle = True
val_loader.shuffle = False

print("Train samples:", len(train_dataset), "Val samples:", len(val_dataset))
print("Held-out test samples:", len(test_pairs))




Train samples: 63 Val samples: 17
Held-out test samples: 8


In [7]:
n_epochs = 10
n_objects_per_batch = 5
train_instance_segmentation = True
model_type = "vit_b_lm"  # must match when loading later

checkpoint_name = "selma3d_microsam_ais"
# root_dir is no longer needed for checkpoints

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

sam_training.train_sam(
    name=checkpoint_name,
    save_root=str(CHECKPOINT_ROOT),  # <-- save into your finetuned_checkpoints dir
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)



Using device: cuda


Verifying labels in 'train' dataloader: 100%|██████████| 50/50 [00:00<00:00, 61.66it/s]
Verifying labels in 'val' dataloader:  34%|███▍      | 17/50 [00:00<00:00, 49.74it/s]


Start fitting for 630 iterations /  10 epochs
with 63 iterations per epoch
Training with mixed precision


Epoch 10: average [s/it]: 0.441256, current metric: 0.238749, best metric: 0.238749: 100%|█████████▉| 629/630 [05:51<00:00,  1.79it/s]

Finished training after 10 epochs / 630 iterations.
The best epoch is number 9.
Training took 354.7764194011688 seconds (= 00:05:55 hours)





In [8]:
checkpoint_name = "selma3d_microsam_ais"
best_checkpoint = str(
    CHECKPOINT_ROOT / "checkpoints" / checkpoint_name / "best.pt"
)
print("Best checkpoint:", best_checkpoint, "exists:", os.path.exists(best_checkpoint))


predictor, segmenter = get_predictor_and_segmenter(
    model_type=model_type,
    checkpoint=best_checkpoint,
    device=device,
    is_tiled=False,  # we will handle tiling for large images ourselves
)


Best checkpoint: /midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_checkpoints/checkpoints/selma3d_microsam_ais/best.pt exists: True


In [9]:
def segment_slice_2d(predictor, segmenter, img2d, tile_shape=None, halo=None, verbose=False):
    """Run AIS on a single 2D slice."""
    img2d = img2d.astype(np.float32)

    # normalize to [0,255] similarly to training
    p1 = np.percentile(img2d, 1)
    p99 = np.percentile(img2d, 99)
    if p99 > p1:
        img2d = (img2d - p1) / (p99 - p1)
    else:
        img2d = np.zeros_like(img2d)
    img2d = np.clip(img2d, 0.0, 1.0)
    img2d = (img2d * 255.0).astype(np.float32)

    instances = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=img2d,
        ndim=2,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    return instances


In [10]:
def segment_volume_slices(predictor, segmenter, vol, tile_shape=None, halo=None, verbose=False):
    """
    vol: numpy array (Z, Y, X)
    returns: instances (Z, Y, X) with per-slice instance labels.
    """
    Z, Y, X = vol.shape
    instances_vol = np.zeros((Z, Y, X), dtype=np.int32)

    for z in range(Z):
        if verbose:
            print(f"Segmenting slice {z+1}/{Z}...")
        img2d = vol[z]
        seg2d = segment_slice_2d(
            predictor,
            segmenter,
            img2d,
            tile_shape=tile_shape,
            halo=halo,
            verbose=False,
        )
        instances_vol[z] = seg2d.astype(np.int32)

    return instances_vol


In [11]:
def compute_binary_metrics(gt, pred):
    """Compute Dice and IoU for binary masks (numpy arrays of 0/1)."""
    gt = gt.astype(bool)
    pred = pred.astype(bool)

    intersection = np.logical_and(gt, pred).sum()
    union = np.logical_or(gt, pred).sum()
    gt_sum = gt.sum()
    pred_sum = pred.sum()

    # Dice
    denom = gt_sum + pred_sum
    dice = 2.0 * intersection / denom if denom > 0 else np.nan

    # IoU
    iou = intersection / union if union > 0 else np.nan

    return dice, iou




In [12]:
# ---------------------------------------------------------------------
# EVALUATE ON HELD-OUT TEST SET AND SAVE PREDICTIONS
# ---------------------------------------------------------------------
results = []  # per-volume metrics

eval_tile_shape = None  # fine for 96x96 patches
eval_halo = None

for raw_path, label_path in test_pairs:
    raw_path = Path(raw_path)
    label_path = Path(label_path)

    cls = cls_by_raw[raw_path]

    print(f"Evaluating {raw_path.name} (class={cls})")

    # load full volumes
    raw_vol = nib.load(str(raw_path)).get_fdata().astype(np.float32)
    lab_vol = nib.load(str(label_path)).get_fdata().astype(np.float32)

    # run AIS slice-wise over the full volume
    instances = segment_volume_slices(
        predictor,
        segmenter,
        raw_vol,
        tile_shape=eval_tile_shape,
        halo=eval_halo,
        verbose=False,
    )

    # --- SAVE PREDICTION ---
    cls_dir = PRED_ROOT / cls
    cls_dir.mkdir(parents=True, exist_ok=True)

    pred_fname = raw_path.name.replace(".nii.gz", "_microsam_ais2dstack.nii.gz")
    pred_path = cls_dir / pred_fname

    pred_nii = nib.Nifti1Image(instances.astype(np.int32), affine=nib.load(str(raw_path)).affine)
    nib.save(pred_nii, pred_path)
    print(f"  Saved prediction to: {pred_path}")

    # --- COMPUTE SLICE-WISE DICE/IOU (BINARY FOREGROUND) ---
    Z = raw_vol.shape[0]
    dice_list = []
    iou_list = []

    for z in range(Z):
        gt_slice = lab_vol[z]
        if gt_slice.max() == 0:
            # skip slices without any foreground in GT
            continue

        pred_slice = instances[z]

        gt_bin = (gt_slice > 0).astype(np.uint8)
        pred_bin = (pred_slice > 0).astype(np.uint8)

        dice, iou = compute_binary_metrics(gt_bin, pred_bin)
        if not np.isnan(dice):
            dice_list.append(dice)
        if not np.isnan(iou):
            iou_list.append(iou)

    if len(dice_list) == 0:
        mean_dice = np.nan
        mean_iou = np.nan
    else:
        mean_dice = float(np.mean(dice_list))
        mean_iou = float(np.mean(iou_list))

    results.append(
        {
            "class": cls,
            "raw_path": str(raw_path),
            "label_path": str(label_path),
            "pred_path": str(pred_path),
            "mean_dice": mean_dice,
            "mean_iou": mean_iou,
            "n_slices_with_fg": len(dice_list),
        }
    )

# ---------------------------------------------------------------------
# Aggregate into a table: per-volume + per-class + overall
# ---------------------------------------------------------------------
df = pd.DataFrame(results)
display(df)

per_class = (
    df.groupby("class")[["mean_dice", "mean_iou"]]
    .mean()
    .sort_values("mean_dice", ascending=False)
)
print("\n=== Per-class mean metrics on held-out test set ===")
display(per_class)

overall = df[["mean_dice", "mean_iou"]].mean()
print("\n=== Overall macro-average metrics over all test volumes ===")
print(overall)


Evaluating patch_016_vol008_ch0.nii.gz (class=vessels_patches)
  Saved prediction to: /midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_preds/vessels_patches/patch_016_vol008_ch0_microsam_ais2dstack.nii.gz
Evaluating patch_010_vol000_ch0.nii.gz (class=cell_nucleus_patches)
  Saved prediction to: /midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_preds/cell_nucleus_patches/patch_010_vol000_ch0_microsam_ais2dstack.nii.gz
Evaluating patch_013_vol001_ch0.nii.gz (class=cell_nucleus_patches)
  Saved prediction to: /midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_preds/cell_nucleus_patches/patch_013_vol001_ch0_microsam_ais2dstack.nii.gz
Evaluating patch_011_vol019_ch0.nii.gz (class=amyloid_plaque_patches)
  Saved prediction to: /midtier/paetzollab/scratch/ads4015/compare_methods/micro_sam/finetuned_preds/amyloid_plaque_patches/patch_011_vol019_ch0_microsam_ais2dstack.nii.gz
Evaluating patch_000_vol009_ch0.nii.gz (class=c_fos_po

Unnamed: 0,class,raw_path,label_path,pred_path,mean_dice,mean_iou,n_slices_with_fg
0,vessels_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.812819,0.686667,96
1,cell_nucleus_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.060692,0.034644,96
2,cell_nucleus_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.236245,0.140939,96
3,amyloid_plaque_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.04,0.022222,5
4,c_fos_positive_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.129663,0.077383,96
5,vessels_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.564281,0.398104,96
6,c_fos_positive_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.186658,0.109517,96
7,amyloid_plaque_patches,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/data_selma...,/midtier/paetzollab/scratch/ads4015/compare_me...,0.595893,0.494006,60



=== Per-class mean metrics on held-out test set ===


Unnamed: 0_level_0,mean_dice,mean_iou
class,Unnamed: 1_level_1,Unnamed: 2_level_1
vessels_patches,0.68855,0.542386
amyloid_plaque_patches,0.317947,0.258114
c_fos_positive_patches,0.15816,0.09345
cell_nucleus_patches,0.148468,0.087792



=== Overall macro-average metrics over all test volumes ===
mean_dice    0.328281
mean_iou     0.245435
dtype: float64


In [13]:
# test_path = str(
#     "/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_001_vol021_ch1.nii.gz"
# )

# nii = nib.load(test_path)
# vol = nii.get_fdata().astype(np.float32)  # (Z, Y, X)

# instances = segment_volume_slices(
#     predictor,
#     segmenter,
#     vol,
#     tile_shape=None,  # no tiling needed for 96^3 patches
#     halo=None,
#     verbose=True,
# )

# print("Instances shape:", instances.shape)
# print("Unique labels (first 10):", np.unique(instances)[:10])

# out_path = test_path.replace(".nii.gz", "_microsam_ais2dstack.nii.gz")
# out_nii = nib.Nifti1Image(instances.astype(np.int32), affine=nii.affine, header=nii.header)
# nib.save(out_nii, out_path)
# print("Saved:", out_path)
