In [1]:
import os
from pathlib import Path

import numpy as np
import nibabel as nib
from skimage.measure import label as connected_components

import torch
from torch.utils.data import Dataset, DataLoader

import scipy.ndimage as ndi

import micro_sam.training as sam_training
from micro_sam.automatic_segmentation import (
    get_predictor_and_segmenter,
    automatic_instance_segmentation,
)

# Root directory with your finetune patches
ROOT = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches")
print("ROOT exists:", ROOT.exists())


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


ROOT exists: True


In [2]:
def compute_distance_to_center(inst_mask: np.ndarray) -> np.ndarray:
    """
    For each instance ID > 0, compute distance to its center of mass.
    Normalize to [0,1] per instance and invert so that center=1, edges=0.
    Outside objects: 0.
    """
    H, W = inst_mask.shape
    out = np.zeros((H, W), dtype=np.float32)

    instance_ids = np.unique(inst_mask)
    instance_ids = instance_ids[instance_ids != 0]

    yy, xx = np.indices((H, W), dtype=np.float32)

    for iid in instance_ids:
        mask = (inst_mask == iid)
        if not np.any(mask):
            continue

        cy, cx = ndi.center_of_mass(mask.astype(np.float32))
        dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)

        maxd = dist[mask].max() + 1e-6
        dist_norm = dist / maxd  # 0..1 inside object
        out[mask] = 1.0 - dist_norm[mask]  # 1 at center, 0 at furthest

    return out


def compute_distance_to_boundary(inst_mask: np.ndarray) -> np.ndarray:
    """
    For each instance ID > 0, compute distance transform inside the object.
    Normalize to [0,1] per instance. Outside objects: 0.
    """
    H, W = inst_mask.shape
    out = np.zeros((H, W), dtype=np.float32)

    instance_ids = np.unique(inst_mask)
    instance_ids = instance_ids[instance_ids != 0]

    for iid in instance_ids:
        mask = (inst_mask == iid)
        if not np.any(mask):
            continue

        dist = ndi.distance_transform_edt(mask)
        maxd = dist[mask].max() + 1e-6
        dist_norm = dist / maxd  # 0..1 inside object
        out[mask] = dist_norm[mask]

    return out


In [3]:
class Selma2DSliceDataset(Dataset):
    def __init__(self, raw_label_pairs):
        """
        raw_label_pairs: list of (raw_path, label_path)
        """
        self.pairs = raw_label_pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        raw_path, label_path = self.pairs[idx]

        # Load volumes
        raw_vol = nib.load(str(raw_path)).get_fdata().astype(np.float32)
        lab_vol = nib.load(str(label_path)).get_fdata().astype(np.float32)

        assert raw_vol.shape == lab_vol.shape, f"shape mismatch: {raw_vol.shape} vs {lab_vol.shape}"
        Z = raw_vol.shape[0]

        # --- find a slice that contains at least one object ---
        for _ in range(20):
            z = np.random.randint(0, Z)
            lab_slice = lab_vol[z]
            if lab_slice.max() > 0:
                break
        else:
            # fallback: deterministic scan
            found = False
            for z in range(Z):
                lab_slice = lab_vol[z]
                if lab_slice.max() > 0:
                    found = True
                    break
            if not found:
                # no instances in this volume at all; pick a different pair
                return self.__getitem__((idx + 1) % len(self.pairs))

        img_slice = raw_vol[z]
        lab_slice = lab_vol[z]

        # --- instance mask (connected components) ---
        bin_mask = (lab_slice > 0).astype(np.uint8)
        inst_mask = connected_components(bin_mask, connectivity=1).astype(np.int32)

        if inst_mask.max() == 0:
            # no instance after CC, rare corner case
            return self.__getitem__((idx + 1) % len(self.pairs))

        # --- robust image normalization to [0,255] for SAM ---
        p1 = np.percentile(img_slice, 1)
        p99 = np.percentile(img_slice, 99)
        if p99 > p1:
            img_slice_norm = (img_slice - p1) / (p99 - p1)
        else:
            img_slice_norm = np.zeros_like(img_slice, dtype=np.float32)

        img_slice_norm = np.clip(img_slice_norm, 0.0, 1.0)
        img_slice_uint = (img_slice_norm * 255.0).astype(np.float32)

        # --- label channels ---
        # channel 0: instance ids
        instance_channel = inst_mask.astype(np.float32)

        # foreground mask
        fg = (inst_mask > 0).astype(np.float32)

        # distances
        dist_center = compute_distance_to_center(inst_mask)
        dist_boundary = compute_distance_to_boundary(inst_mask)

        # mask distances and clamp to [0,1]
        dist_center = np.clip(dist_center * fg, 0.0, 1.0)
        dist_boundary = np.clip(dist_boundary * fg, 0.0, 1.0)

        # assemble y: (4, H, W)
        # IMPORTANT: channel order to match microSAM expectations:
        #   y[0] = instance ids (integer-valued)
        #   y[1] = foreground (0..1)
        #   y[2] = center distance (0..1)
        #   y[3] = boundary distance (0..1)
        y_np = np.stack(
            [instance_channel, fg, dist_center, dist_boundary],
            axis=0
        ).astype(np.float32)

        # to torch
        x = torch.from_numpy(img_slice_uint[None, ...])  # (1,H,W), float32 in [0,255]
        y = torch.from_numpy(y_np)                        # (4,H,W)

        return x, y


In [4]:
pairs = []

for class_dir in ROOT.iterdir():
    if not class_dir.is_dir():
        continue

    # use ch0 as the image (for vessels: patch_*_ch0.nii.gz)
    for raw_path in sorted(class_dir.glob("*_ch0.nii.gz")):
        label_path = raw_path.with_name(raw_path.name.replace(".nii.gz", "_label.nii.gz"))
        if label_path.exists():
            pairs.append((raw_path, label_path))

print("Found", len(pairs), "raw/label volumes")


Found 68 raw/label volumes


In [5]:
# simple random split
rng = np.random.default_rng(42)
perm = rng.permutation(len(pairs))
pairs = [pairs[i] for i in perm]

n_train = int(0.9 * len(pairs))
train_pairs = pairs[:n_train]
val_pairs = pairs[n_train:]

train_dataset = Selma2DSliceDataset(train_pairs)
val_dataset = Selma2DSliceDataset(val_pairs)

batch_size = 1  # microSAM typically uses small batch sizes

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# microSAM's JointSamTrainer expects these attributes
train_loader.shuffle = True
val_loader.shuffle = False

print("Train samples:", len(train_dataset), "Val samples:", len(val_dataset))


Train samples: 61 Val samples: 7


In [6]:
n_epochs = 10
n_objects_per_batch = 5
train_instance_segmentation = True
model_type = "vit_b"  # must match when loading later

checkpoint_name = "selma3d_microsam_ais"
root_dir = os.getcwd()

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

sam_training.train_sam(
    name=checkpoint_name,
    save_root=os.path.join(root_dir, "models"),
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)


Using device: cuda


Verifying labels in 'train' dataloader: 100%|██████████| 50/50 [00:00<00:00, 54.26it/s]
Verifying labels in 'val' dataloader:  14%|█▍        | 7/50 [00:00<00:01, 38.15it/s]


Start fitting for 610 iterations /  10 epochs
with 61 iterations per epoch
Training with mixed precision


Epoch 10: average [s/it]: 0.441431, current metric: 0.431691, best metric: 0.258571: 100%|█████████▉| 609/610 [05:23<00:00,  1.88it/s]

Finished training after 10 epochs / 610 iterations.
The best epoch is number 7.
Training took 326.2011938095093 seconds (= 00:05:26 hours)





In [7]:
checkpoint_name = "selma3d_microsam_ais"
best_checkpoint = os.path.join(
    root_dir,
    "models",
    "checkpoints",
    checkpoint_name,
    "best.pt",
)
print("Best checkpoint:", best_checkpoint, "exists:", os.path.exists(best_checkpoint))

predictor, segmenter = get_predictor_and_segmenter(
    model_type=model_type,
    checkpoint=best_checkpoint,
    device=device,
    is_tiled=False,  # we will handle tiling for large images ourselves
)


Best checkpoint: /home/ads4015/ssl_project/models/checkpoints/selma3d_microsam_ais/best.pt exists: True


In [8]:
def segment_slice_2d(predictor, segmenter, img2d, tile_shape=None, halo=None, verbose=False):
    """Run AIS on a single 2D slice."""
    img2d = img2d.astype(np.float32)

    # normalize to [0,255] similarly to training
    p1 = np.percentile(img2d, 1)
    p99 = np.percentile(img2d, 99)
    if p99 > p1:
        img2d = (img2d - p1) / (p99 - p1)
    else:
        img2d = np.zeros_like(img2d)
    img2d = np.clip(img2d, 0.0, 1.0)
    img2d = (img2d * 255.0).astype(np.float32)

    instances = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=img2d,
        ndim=2,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
    )

    return instances


In [9]:
def segment_volume_slices(predictor, segmenter, vol, tile_shape=None, halo=None, verbose=False):
    """
    vol: numpy array (Z, Y, X)
    returns: instances (Z, Y, X) with per-slice instance labels.
    """
    Z, Y, X = vol.shape
    instances_vol = np.zeros((Z, Y, X), dtype=np.int32)

    for z in range(Z):
        if verbose:
            print(f"Segmenting slice {z+1}/{Z}...")
        img2d = vol[z]
        seg2d = segment_slice_2d(
            predictor,
            segmenter,
            img2d,
            tile_shape=tile_shape,
            halo=halo,
            verbose=False,
        )
        instances_vol[z] = seg2d.astype(np.int32)

    return instances_vol


In [10]:
test_path = str(
    "/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/vessels_patches/patch_001_vol021_ch1.nii.gz"
)

nii = nib.load(test_path)
vol = nii.get_fdata().astype(np.float32)  # (Z, Y, X)

instances = segment_volume_slices(
    predictor,
    segmenter,
    vol,
    tile_shape=None,  # no tiling needed for 96^3 patches
    halo=None,
    verbose=True,
)

print("Instances shape:", instances.shape)
print("Unique labels (first 10):", np.unique(instances)[:10])

out_path = test_path.replace(".nii.gz", "_microsam_ais2dstack.nii.gz")
out_nii = nib.Nifti1Image(instances.astype(np.int32), affine=nii.affine, header=nii.header)
nib.save(out_nii, out_path)
print("Saved:", out_path)


Segmenting slice 1/96...
Segmenting slice 2/96...
Segmenting slice 3/96...
Segmenting slice 4/96...
Segmenting slice 5/96...
Segmenting slice 6/96...
Segmenting slice 7/96...
Segmenting slice 8/96...
Segmenting slice 9/96...
Segmenting slice 10/96...
Segmenting slice 11/96...
Segmenting slice 12/96...
Segmenting slice 13/96...
Segmenting slice 14/96...
Segmenting slice 15/96...
Segmenting slice 16/96...
Segmenting slice 17/96...
Segmenting slice 18/96...
Segmenting slice 19/96...
Segmenting slice 20/96...
Segmenting slice 21/96...
Segmenting slice 22/96...
Segmenting slice 23/96...
Segmenting slice 24/96...
Segmenting slice 25/96...
Segmenting slice 26/96...
Segmenting slice 27/96...
Segmenting slice 28/96...
Segmenting slice 29/96...
Segmenting slice 30/96...
Segmenting slice 31/96...
Segmenting slice 32/96...
Segmenting slice 33/96...
Segmenting slice 34/96...
Segmenting slice 35/96...
Segmenting slice 36/96...
Segmenting slice 37/96...
Segmenting slice 38/96...
Segmenting slice 39/9