forked from constantinpape/torch-em
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CREMI experiments using ViMUNet (#12)
CREMI experiments for benchmarking instance segmentation using UNet, UNETR and ViMUNet
- Loading branch information
Showing
5 changed files
with
698 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,338 @@ | ||
import os | ||
import argparse | ||
import numpy as np | ||
import pandas as pd | ||
from glob import glob | ||
from tqdm import tqdm | ||
|
||
import imageio.v3 as imageio | ||
|
||
import torch | ||
|
||
import torch_em | ||
from torch_em.util import segmentation | ||
from torch_em.data import MinInstanceSampler | ||
from torch_em.model import get_vimunet_model | ||
from torch_em.data.datasets import get_cremi_loader | ||
from torch_em.util.prediction import predict_with_halo | ||
from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss | ||
|
||
|
||
import elf.segmentation.multicut as mc | ||
import elf.segmentation.watershed as ws | ||
import elf.segmentation.features as feats | ||
from elf.evaluation import mean_segmentation_accuracy | ||
|
||
|
||
ROOT = "/scratch/usr/nimanwai" | ||
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original" | ||
|
||
OFFSETS = [ | ||
[-1, 0], [0, -1], | ||
[-3, 0], [0, -3], | ||
[-9, 0], [0, -9], | ||
[-27, 0], [0, -27] | ||
] | ||
|
||
|
||
def get_loaders(args, patch_shape=(1, 512, 512)): | ||
if args.distances: | ||
label_trafo = torch_em.transform.label.PerObjectDistanceTransform( | ||
distances=True, boundary_distances=True, directed_distances=False, foreground=True | ||
) | ||
else: | ||
label_trafo = None | ||
|
||
train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} | ||
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} | ||
|
||
sampler = MinInstanceSampler() | ||
|
||
train_loader = get_cremi_loader( | ||
path=args.input, | ||
patch_shape=patch_shape, | ||
batch_size=2, | ||
label_transform=label_trafo, | ||
rois=train_rois, | ||
sampler=sampler, | ||
ndim=2, | ||
label_dtype=torch.float32, | ||
defect_augmentation_kwargs=None, | ||
boundaries=args.boundaries, | ||
offsets=OFFSETS if args.affinities else None, | ||
num_workers=16, | ||
) | ||
val_loader = get_cremi_loader( | ||
path=args.input, | ||
patch_shape=patch_shape, | ||
batch_size=1, | ||
label_transform=label_trafo, | ||
rois=val_rois, | ||
sampler=sampler, | ||
ndim=2, | ||
label_dtype=torch.float32, | ||
defect_augmentation_kwargs=None, | ||
boundaries=args.boundaries, | ||
offsets=OFFSETS if args.affinities else None, | ||
num_workers=16, | ||
) | ||
|
||
return train_loader, val_loader | ||
|
||
|
||
def get_output_channels(args): | ||
if args.boundaries: | ||
output_channels = 1 | ||
elif args.distances: | ||
output_channels = 3 | ||
elif args.affinities: | ||
output_channels = len(OFFSETS) | ||
|
||
return output_channels | ||
|
||
|
||
def get_loss_function(args): | ||
if args.affinities: | ||
loss = LossWrapper( | ||
loss=DiceLoss(), | ||
transform=ApplyAndRemoveMask(masking_method="multiply") | ||
) | ||
elif args.distances: | ||
loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) | ||
|
||
else: | ||
loss = DiceLoss() | ||
|
||
return loss | ||
|
||
|
||
def get_save_root(args): | ||
# experiment_type | ||
if args.boundaries: | ||
experiment_type = "boundaries" | ||
elif args.affinities: | ||
experiment_type = "affinities" | ||
elif args.distances: | ||
experiment_type = "distances" | ||
else: | ||
raise ValueError | ||
|
||
model_name = args.model_type | ||
|
||
# saving the model checkpoints | ||
save_root = os.path.join( | ||
args.save_root, "pretrained" if args.pretrained else "scratch", experiment_type, model_name | ||
) | ||
return save_root | ||
|
||
|
||
def run_cremi_training(args): | ||
# the dataloaders for cremi dataset | ||
train_loader, val_loader = get_loaders(args) | ||
|
||
if args.pretrained: | ||
checkpoint = "/scratch/usr/nimanwai/models/Vim-tiny/vim_tiny_73p1.pth" | ||
else: | ||
checkpoint = None | ||
|
||
output_channels = get_output_channels(args) | ||
|
||
# the vision-mamba + decoder (UNet-based) model | ||
model = get_vimunet_model( | ||
out_channels=output_channels, | ||
model_type=args.model_type, | ||
checkpoint=checkpoint, | ||
with_cls_token=True | ||
) | ||
|
||
save_root = get_save_root(args) | ||
|
||
# loss function | ||
loss = get_loss_function(args) | ||
|
||
# trainer for the segmentation task | ||
trainer = torch_em.default_segmentation_trainer( | ||
name="cremi-vimunet", | ||
model=model, | ||
train_loader=train_loader, | ||
val_loader=val_loader, | ||
learning_rate=1e-4, | ||
loss=loss, | ||
metric=loss, | ||
log_image_interval=50, | ||
save_root=save_root, | ||
compile_model=False | ||
) | ||
trainer.fit(iterations=args.iterations) | ||
|
||
|
||
def _do_bd_multicut_watershed(bd): | ||
ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) | ||
|
||
# compute the region adjacency graph | ||
rag = feats.compute_rag(ws_seg) | ||
|
||
# compute the edge costs | ||
costs = feats.compute_boundary_features(rag, bd)[:, 0] | ||
|
||
# transform the edge costs from [0, 1] to [-inf, inf], which is | ||
# necessary for the multicut. This is done by intepreting the values | ||
# as probabilities for an edge being 'true' and then taking the negative log-likelihood. | ||
|
||
# in addition, we weight the costs by the size of the corresponding edge | ||
# for z and xy edges | ||
z_edges = feats.compute_z_edge_mask(rag, ws_seg) | ||
xy_edges = np.logical_not(z_edges) | ||
edge_populations = [z_edges, xy_edges] | ||
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] | ||
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) | ||
|
||
# run the multicut partitioning, here, we use the kernighan lin | ||
# heuristics to solve the problem, introduced in | ||
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf | ||
node_labels = mc.multicut_kernighan_lin(rag, costs) | ||
|
||
# map the results back to pixels to obtain the final segmentation | ||
seg = feats.project_node_labels_to_pixels(rag, node_labels) | ||
|
||
return seg | ||
|
||
|
||
def _do_affs_multicut_watershed(affs, offsets): | ||
# first, we have to make a single channel input map for the watershed, | ||
# which we obtain by averaging the affinities | ||
boundary_input = np.mean(affs, axis=0) | ||
|
||
ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0) | ||
|
||
# compute the region adjacency graph | ||
rag = feats.compute_rag(ws_seg) | ||
|
||
# compute the edge costs | ||
# the offsets encode the pixel transition encoded by the | ||
# individual affinity channels. Here, we only have nearest neighbor transitions | ||
costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0] | ||
|
||
# transform the edge costs from [0, 1] to [-inf, inf], which is | ||
# necessary for the multicut. This is done by intepreting the values | ||
# as probabilities for an edge being 'true' and then taking the negative log-likelihood. | ||
# in addition, we weight the costs by the size of the corresponding edge | ||
edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1] | ||
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) | ||
|
||
# run the multicut partitioning, here, we use the kernighan lin | ||
# heuristics to solve the problem, introduced in | ||
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf | ||
node_labels = mc.multicut_kernighan_lin(rag, costs) | ||
|
||
# map the results back to pixels to obtain the final segmentation | ||
seg = feats.project_node_labels_to_pixels(rag, node_labels) | ||
|
||
return seg | ||
|
||
|
||
def run_cremi_inference(args, device): | ||
output_channels = get_output_channels(args) | ||
|
||
save_root = get_save_root(args) | ||
|
||
checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt") | ||
|
||
# the vision-mamba + decoder (UNet-based) model | ||
model = get_vimunet_model( | ||
out_channels=output_channels, | ||
model_type=args.model_type, | ||
with_cls_token=True, | ||
checkpoint=checkpoint | ||
) | ||
|
||
all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif")) | ||
all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif")) | ||
|
||
res_path = os.path.join(save_root, "results.csv") | ||
if os.path.exists(res_path) and not args.force: | ||
print(pd.read_csv(res_path)) | ||
print(f"The result is saved at {res_path}") | ||
return | ||
|
||
msa_list, sa50_list, sa75_list = [], [], [] | ||
for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)): | ||
image = imageio.imread(image_path) | ||
labels = imageio.imread(label_path) | ||
|
||
predictions = predict_with_halo( | ||
image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True, | ||
) | ||
|
||
if args.boundaries: | ||
bd = predictions.squeeze() | ||
|
||
# instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) | ||
instances = _do_bd_multicut_watershed(bd) | ||
|
||
elif args.affinities: | ||
affs = predictions | ||
|
||
# instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) | ||
instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) | ||
|
||
elif args.distances: | ||
fg, cdist, bdist = predictions | ||
instances = segmentation.watershed_from_center_and_boundary_distances( | ||
cdist, bdist, fg, min_size=50, | ||
center_distance_threshold=0.5, | ||
boundary_distance_threshold=0.6, | ||
distance_smoothing=1.0 | ||
) | ||
|
||
msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) | ||
msa_list.append(msa) | ||
sa50_list.append(sa_acc[0]) | ||
sa75_list.append(sa_acc[5]) | ||
|
||
res = { | ||
"LiveCELL": "Metrics", | ||
"mSA": np.mean(msa_list), | ||
"SA50": np.mean(sa50_list), | ||
"SA75": np.mean(sa75_list) | ||
} | ||
df = pd.DataFrame.from_dict([res]) | ||
df.to_csv(res_path) | ||
print(df) | ||
print(f"The result is saved at {res_path}") | ||
|
||
|
||
def main(args): | ||
assert (args.boundaries + args.affinities + args.distances) == 1 | ||
|
||
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") | ||
|
||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
if args.train: | ||
run_cremi_training(args) | ||
|
||
if args.predict: | ||
run_cremi_inference(args, device) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi")) | ||
parser.add_argument("--iterations", type=int, default=int(1e5)) | ||
parser.add_argument("-s", "--save_root", type=str, default=os.path.join(ROOT, "experiments", "vimunet")) | ||
parser.add_argument("-m", "--model_type", type=str, default="vim_t") | ||
|
||
parser.add_argument("--pretrained", action="store_true") | ||
|
||
parser.add_argument("--train", action="store_true") | ||
parser.add_argument("--predict", action="store_true") | ||
|
||
parser.add_argument("--force", action="store_true") | ||
|
||
parser.add_argument("--boundaries", action="store_true") | ||
parser.add_argument("--affinities", action="store_true") | ||
parser.add_argument("--distances", action="store_true") | ||
|
||
args = parser.parse_args() | ||
main(args) |
File renamed without changes.
Oops, something went wrong.