Skip to content

Commit

Permalink
Add CREMI experiments using ViMUNet (#12)
Browse files Browse the repository at this point in the history
CREMI experiments for benchmarking instance segmentation using UNet, UNETR and ViMUNet
  • Loading branch information
anwai98 committed Mar 30, 2024
1 parent 55fba22 commit 73f7021
Show file tree
Hide file tree
Showing 5 changed files with 698 additions and 83 deletions.
338 changes: 338 additions & 0 deletions experiments/vision-mamba/cremi/run_cremi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
import os
import argparse
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import imageio.v3 as imageio

import torch

import torch_em
from torch_em.util import segmentation
from torch_em.data import MinInstanceSampler
from torch_em.model import get_vimunet_model
from torch_em.data.datasets import get_cremi_loader
from torch_em.util.prediction import predict_with_halo
from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss


import elf.segmentation.multicut as mc
import elf.segmentation.watershed as ws
import elf.segmentation.features as feats
from elf.evaluation import mean_segmentation_accuracy


ROOT = "/scratch/usr/nimanwai"
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original"

OFFSETS = [
[-1, 0], [0, -1],
[-3, 0], [0, -3],
[-9, 0], [0, -9],
[-27, 0], [0, -27]
]


def get_loaders(args, patch_shape=(1, 512, 512)):
if args.distances:
label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True
)
else:
label_trafo = None

train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]}
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]}

sampler = MinInstanceSampler()

train_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=2,
label_transform=label_trafo,
rois=train_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=args.boundaries,
offsets=OFFSETS if args.affinities else None,
num_workers=16,
)
val_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=1,
label_transform=label_trafo,
rois=val_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=args.boundaries,
offsets=OFFSETS if args.affinities else None,
num_workers=16,
)

return train_loader, val_loader


def get_output_channels(args):
if args.boundaries:
output_channels = 1
elif args.distances:
output_channels = 3
elif args.affinities:
output_channels = len(OFFSETS)

return output_channels


def get_loss_function(args):
if args.affinities:
loss = LossWrapper(
loss=DiceLoss(),
transform=ApplyAndRemoveMask(masking_method="multiply")
)
elif args.distances:
loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)

else:
loss = DiceLoss()

return loss


def get_save_root(args):
# experiment_type
if args.boundaries:
experiment_type = "boundaries"
elif args.affinities:
experiment_type = "affinities"
elif args.distances:
experiment_type = "distances"
else:
raise ValueError

model_name = args.model_type

# saving the model checkpoints
save_root = os.path.join(
args.save_root, "pretrained" if args.pretrained else "scratch", experiment_type, model_name
)
return save_root


def run_cremi_training(args):
# the dataloaders for cremi dataset
train_loader, val_loader = get_loaders(args)

if args.pretrained:
checkpoint = "/scratch/usr/nimanwai/models/Vim-tiny/vim_tiny_73p1.pth"
else:
checkpoint = None

output_channels = get_output_channels(args)

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=output_channels,
model_type=args.model_type,
checkpoint=checkpoint,
with_cls_token=True
)

save_root = get_save_root(args)

# loss function
loss = get_loss_function(args)

# trainer for the segmentation task
trainer = torch_em.default_segmentation_trainer(
name="cremi-vimunet",
model=model,
train_loader=train_loader,
val_loader=val_loader,
learning_rate=1e-4,
loss=loss,
metric=loss,
log_image_interval=50,
save_root=save_root,
compile_model=False
)
trainer.fit(iterations=args.iterations)


def _do_bd_multicut_watershed(bd):
ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
costs = feats.compute_boundary_features(rag, bd)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by intepreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.

# in addition, we weight the costs by the size of the corresponding edge
# for z and xy edges
z_edges = feats.compute_z_edge_mask(rag, ws_seg)
xy_edges = np.logical_not(z_edges)
edge_populations = [z_edges, xy_edges]
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)

# run the multicut partitioning, here, we use the kernighan lin
# heuristics to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def _do_affs_multicut_watershed(affs, offsets):
# first, we have to make a single channel input map for the watershed,
# which we obtain by averaging the affinities
boundary_input = np.mean(affs, axis=0)

ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
# the offsets encode the pixel transition encoded by the
# individual affinity channels. Here, we only have nearest neighbor transitions
costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by intepreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
# in addition, we weight the costs by the size of the corresponding edge
edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning, here, we use the kernighan lin
# heuristics to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def run_cremi_inference(args, device):
output_channels = get_output_channels(args)

save_root = get_save_root(args)

checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt")

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=output_channels,
model_type=args.model_type,
with_cls_token=True,
checkpoint=checkpoint
)

all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif"))
all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif"))

res_path = os.path.join(save_root, "results.csv")
if os.path.exists(res_path) and not args.force:
print(pd.read_csv(res_path))
print(f"The result is saved at {res_path}")
return

msa_list, sa50_list, sa75_list = [], [], []
for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)):
image = imageio.imread(image_path)
labels = imageio.imread(label_path)

predictions = predict_with_halo(
image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True,
)

if args.boundaries:
bd = predictions.squeeze()

# instances = segmentation.watershed_from_components(bd, np.ones_like(bd))
instances = _do_bd_multicut_watershed(bd)

elif args.affinities:
affs = predictions

# instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)
instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2])

elif args.distances:
fg, cdist, bdist = predictions
instances = segmentation.watershed_from_center_and_boundary_distances(
cdist, bdist, fg, min_size=50,
center_distance_threshold=0.5,
boundary_distance_threshold=0.6,
distance_smoothing=1.0
)

msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
msa_list.append(msa)
sa50_list.append(sa_acc[0])
sa75_list.append(sa_acc[5])

res = {
"LiveCELL": "Metrics",
"mSA": np.mean(msa_list),
"SA50": np.mean(sa50_list),
"SA75": np.mean(sa75_list)
}
df = pd.DataFrame.from_dict([res])
df.to_csv(res_path)
print(df)
print(f"The result is saved at {res_path}")


def main(args):
assert (args.boundaries + args.affinities + args.distances) == 1

print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU")

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.train:
run_cremi_training(args)

if args.predict:
run_cremi_inference(args, device)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"))
parser.add_argument("--iterations", type=int, default=int(1e5))
parser.add_argument("-s", "--save_root", type=str, default=os.path.join(ROOT, "experiments", "vimunet"))
parser.add_argument("-m", "--model_type", type=str, default="vim_t")

parser.add_argument("--pretrained", action="store_true")

parser.add_argument("--train", action="store_true")
parser.add_argument("--predict", action="store_true")

parser.add_argument("--force", action="store_true")

parser.add_argument("--boundaries", action="store_true")
parser.add_argument("--affinities", action="store_true")
parser.add_argument("--distances", action="store_true")

args = parser.parse_args()
main(args)
File renamed without changes.
Loading

0 comments on commit 73f7021

Please sign in to comment.