From f89dd01661f2caa49d551cb8bd047baa6c83a9ac Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 5 Apr 2024 00:00:06 +0200 Subject: [PATCH 1/4] Update LIVECell inference --- .../vision-mamba/livecell/run_livecell.py | 76 ++++++++++++++++++- .../livecell/run_livecell_for_vimunet.py | 19 +++-- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/experiments/vision-mamba/livecell/run_livecell.py b/experiments/vision-mamba/livecell/run_livecell.py index b2fc01d2..fac4c348 100644 --- a/experiments/vision-mamba/livecell/run_livecell.py +++ b/experiments/vision-mamba/livecell/run_livecell.py @@ -16,6 +16,9 @@ from torch_em.data.datasets import get_livecell_loader from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss +import elf.segmentation.multicut as mc +import elf.segmentation.watershed as ws +import elf.segmentation.features as feats from elf.evaluation import mean_segmentation_accuracy @@ -155,6 +158,71 @@ def run_livecell_training(args): trainer.fit(iterations=int(args.iterations)) +def _do_bd_multicut_watershed(bd): + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + + # compute the region adjacency graph + rag = feats.compute_rag(ws_seg) + + # compute the edge costs + costs = feats.compute_boundary_features(rag, bd)[:, 0] + + # transform the edge costs from [0, 1] to [-inf, inf], which is + # necessary for the multicut. This is done by intepreting the values + # as probabilities for an edge being 'true' and then taking the negative log-likelihood. + + # in addition, we weight the costs by the size of the corresponding edge + # for z and xy edges + z_edges = feats.compute_z_edge_mask(rag, ws_seg) + xy_edges = np.logical_not(z_edges) + edge_populations = [z_edges, xy_edges] + edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) + + # run the multicut partitioning, here, we use the kernighan lin + # heuristics to solve the problem, introduced in + # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf + node_labels = mc.multicut_kernighan_lin(rag, costs) + + # map the results back to pixels to obtain the final segmentation + seg = feats.project_node_labels_to_pixels(rag, node_labels) + + return seg + + +def _do_affs_multicut_watershed(affs, offsets): + # first, we have to make a single channel input map for the watershed, + # which we obtain by averaging the affinities + boundary_input = np.mean(affs, axis=0) + + ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0) + + # compute the region adjacency graph + rag = feats.compute_rag(ws_seg) + + # compute the edge costs + # the offsets encode the pixel transition encoded by the + # individual affinity channels. Here, we only have nearest neighbor transitions + costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0] + + # transform the edge costs from [0, 1] to [-inf, inf], which is + # necessary for the multicut. This is done by intepreting the values + # as probabilities for an edge being 'true' and then taking the negative log-likelihood. + # in addition, we weight the costs by the size of the corresponding edge + edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1] + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) + + # run the multicut partitioning, here, we use the kernighan lin + # heuristics to solve the problem, introduced in + # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf + node_labels = mc.multicut_kernighan_lin(rag, costs) + + # map the results back to pixels to obtain the final segmentation + seg = feats.project_node_labels_to_pixels(rag, node_labels) + + return seg + + def run_livecell_inference(args, device): output_channels = get_output_channels(args) @@ -174,7 +242,7 @@ def run_livecell_inference(args, device): all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*")) res_path = os.path.join(save_root, "results.csv") - if os.path.exists(res_path): + if os.path.exists(res_path) and not args.force: print(pd.read_csv(res_path)) print(f"The result is saved at {res_path}") return @@ -194,11 +262,13 @@ def run_livecell_inference(args, device): if args.boundaries: fg, bd = predictions - instances = segmentation.watershed_from_components(bd, fg) + # instances = segmentation.watershed_from_components(bd, fg) + instances = _do_bd_multicut_watershed(bd) elif args.affinities: fg, affs = predictions[0], predictions[1:] - instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) + # instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) + instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) elif args.distances: fg, cdist, bdist = predictions diff --git a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py index c5c1f860..61040afb 100644 --- a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py +++ b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py @@ -12,9 +12,9 @@ import torch_em from torch_em.util import segmentation from torch_em.model import UNETR, UNet2d -from torch_em.data import MinInstanceSampler +from torch_em.transform.raw import standardize from torch_em.data.datasets import get_livecell_loader -from torch_em.util.prediction import predict_with_halo +from torch_em.util.prediction import predict_with_padding from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss import elf.segmentation.multicut as mc @@ -245,7 +245,6 @@ def _do_affs_multicut_watershed(affs, offsets): def run_livecell_unetr_inference(args, device): - raise NotImplementedError save_root = get_save_root(args) checkpoint = os.path.join( @@ -266,7 +265,7 @@ def run_livecell_unetr_inference(args, device): all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*")) res_path = os.path.join(save_root, "results.csv") - if os.path.exists(res_path): + if os.path.exists(res_path) and not args.force: print(pd.read_csv(res_path)) print(f"The result is saved at {res_path}") return @@ -279,18 +278,18 @@ def run_livecell_unetr_inference(args, device): image = imageio.imread(os.path.join(test_image_dir, image_id)) image = standardize(image) - tensor_image = torch.from_numpy(image)[None, None].to(device) - - predictions = model(tensor_image) - predictions = predictions.squeeze().detach().cpu().numpy() + predictions = predict_with_padding(model, image, min_divisible=(16, 16), device=device) + predictions = predictions.squeeze() if args.boundaries: fg, bd = predictions - instances = segmentation.watershed_from_components(bd, fg) + # instances = segmentation.watershed_from_components(bd, fg) + instances = _do_bd_multicut_watershed(bd) elif args.affinities: fg, affs = predictions[0], predictions[1:] - instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) + # instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) + instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) elif args.distances: fg, cdist, bdist = predictions From 82c18975a786a00fd2118dd04961eec36d312949 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 5 Apr 2024 08:49:19 +0200 Subject: [PATCH 2/4] Add multicut-based segmentation for livecell --- .../vision-mamba/light_microscopy/run_lm.py | 88 +++++++++++++++++-- .../unetr/cremi/run_cremi.py | 2 +- .../livecell/run_livecell_for_vimunet.py | 2 +- 3 files changed, 83 insertions(+), 9 deletions(-) diff --git a/experiments/vision-mamba/light_microscopy/run_lm.py b/experiments/vision-mamba/light_microscopy/run_lm.py index fe0832c6..79823a42 100644 --- a/experiments/vision-mamba/light_microscopy/run_lm.py +++ b/experiments/vision-mamba/light_microscopy/run_lm.py @@ -1,5 +1,6 @@ import os import argparse +from glob import glob import numpy as np import pandas as pd @@ -10,10 +11,13 @@ import torch_em from torch_em.util import segmentation +from torch_em.transform.raw import standardize from torch_em.model import get_vimunet_model -from torch_em.util.prediction import predict_with_halo from torch_em.loss import DiceBasedDistanceLoss +import elf.segmentation.multicut as mc +import elf.segmentation.watershed as ws +import elf.segmentation.features as feats from elf.evaluation import mean_segmentation_accuracy from obtain_lm_datasets import get_lm_loaders @@ -65,6 +69,71 @@ def run_lm_training(args): trainer.fit(iterations=args.iterations) +def _do_bd_multicut_watershed(bd): + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + + # compute the region adjacency graph + rag = feats.compute_rag(ws_seg) + + # compute the edge costs + costs = feats.compute_boundary_features(rag, bd)[:, 0] + + # transform the edge costs from [0, 1] to [-inf, inf], which is + # necessary for the multicut. This is done by intepreting the values + # as probabilities for an edge being 'true' and then taking the negative log-likelihood. + + # in addition, we weight the costs by the size of the corresponding edge + # for z and xy edges + z_edges = feats.compute_z_edge_mask(rag, ws_seg) + xy_edges = np.logical_not(z_edges) + edge_populations = [z_edges, xy_edges] + edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) + + # run the multicut partitioning, here, we use the kernighan lin + # heuristics to solve the problem, introduced in + # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf + node_labels = mc.multicut_kernighan_lin(rag, costs) + + # map the results back to pixels to obtain the final segmentation + seg = feats.project_node_labels_to_pixels(rag, node_labels) + + return seg + + +def _do_affs_multicut_watershed(affs, offsets): + # first, we have to make a single channel input map for the watershed, + # which we obtain by averaging the affinities + boundary_input = np.mean(affs, axis=0) + + ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0) + + # compute the region adjacency graph + rag = feats.compute_rag(ws_seg) + + # compute the edge costs + # the offsets encode the pixel transition encoded by the + # individual affinity channels. Here, we only have nearest neighbor transitions + costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0] + + # transform the edge costs from [0, 1] to [-inf, inf], which is + # necessary for the multicut. This is done by intepreting the values + # as probabilities for an edge being 'true' and then taking the negative log-likelihood. + # in addition, we weight the costs by the size of the corresponding edge + edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1] + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) + + # run the multicut partitioning, here, we use the kernighan lin + # heuristics to solve the problem, introduced in + # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf + node_labels = mc.multicut_kernighan_lin(rag, costs) + + # map the results back to pixels to obtain the final segmentation + seg = feats.project_node_labels_to_pixels(rag, node_labels) + + return seg + + def run_lm_inference(args, device): # saving the model checkpoints save_root = os.path.join( @@ -81,7 +150,8 @@ def run_lm_inference(args, device): checkpoint=checkpoint ) - raise NotImplementedError + test_image_dir = os.path.join(ROOT, "livecell", "images", "livecell_test_images") + all_test_labels = glob(os.path.join(ROOT, "livecell", "annotations", "livecell_test_images", "*", "*")) res_path = os.path.join(save_root, "results.csv") if os.path.exists(res_path) and not args.force: @@ -90,13 +160,17 @@ def run_lm_inference(args, device): return msa_list, sa50_list, sa75_list = [], [], [] - for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)): - image = imageio.imread(image_path) + for label_path in tqdm(all_test_labels): labels = imageio.imread(label_path) + image_id = os.path.split(label_path)[-1] - predictions = predict_with_halo( - image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True, - ) + image = imageio.imread(os.path.join(test_image_dir, image_id)) + image = standardize(image) + + tensor_image = torch.from_numpy(image)[None, None].to(device) + + predictions = model(tensor_image) + predictions = predictions.squeeze().detach().cpu().numpy() fg, cdist, bdist = predictions instances = segmentation.watershed_from_center_and_boundary_distances( diff --git a/experiments/vision-transformer/unetr/cremi/run_cremi.py b/experiments/vision-transformer/unetr/cremi/run_cremi.py index 8dcb1b66..ea33e59a 100644 --- a/experiments/vision-transformer/unetr/cremi/run_cremi.py +++ b/experiments/vision-transformer/unetr/cremi/run_cremi.py @@ -296,7 +296,7 @@ def run_cremi_unetr_inference(args, device): affs = predictions # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) elif args.distances: fg, cdist, bdist = predictions diff --git a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py index 61040afb..ed272177 100644 --- a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py +++ b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py @@ -289,7 +289,7 @@ def run_livecell_unetr_inference(args, device): elif args.affinities: fg, affs = predictions[0], predictions[1:] # instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) elif args.distances: fg, cdist, bdist = predictions From 30967ed52513de4f7d0cd9121c20094f2364f03a Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 5 Apr 2024 20:06:21 +0200 Subject: [PATCH 3/4] Update multicut parameters --- experiments/vision-mamba/cremi/run_cremi.py | 24 +++++++++---------- .../unetr/cremi/run_cremi.py | 24 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/experiments/vision-mamba/cremi/run_cremi.py b/experiments/vision-mamba/cremi/run_cremi.py index 52edc569..28db5374 100644 --- a/experiments/vision-mamba/cremi/run_cremi.py +++ b/experiments/vision-mamba/cremi/run_cremi.py @@ -168,7 +168,7 @@ def run_cremi_training(args): def _do_bd_multicut_watershed(bd): - ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0) # compute the region adjacency graph rag = feats.compute_rag(ws_seg) @@ -179,14 +179,8 @@ def _do_bd_multicut_watershed(bd): # transform the edge costs from [0, 1] to [-inf, inf], which is # necessary for the multicut. This is done by intepreting the values # as probabilities for an edge being 'true' and then taking the negative log-likelihood. - - # in addition, we weight the costs by the size of the corresponding edge - # for z and xy edges - z_edges = feats.compute_z_edge_mask(rag, ws_seg) - xy_edges = np.logical_not(z_edges) - edge_populations = [z_edges, xy_edges] edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] - costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) # run the multicut partitioning, here, we use the kernighan lin # heuristics to solve the problem, introduced in @@ -268,14 +262,18 @@ def run_cremi_inference(args, device): if args.boundaries: bd = predictions.squeeze() - # instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) - instances = _do_bd_multicut_watershed(bd) + if args.multicut: + instances = _do_bd_multicut_watershed(bd) + else: + instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) elif args.affinities: affs = predictions - # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) + if args.multicut: + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + else: + instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) elif args.distances: fg, cdist, bdist = predictions @@ -331,6 +329,8 @@ def main(args): parser.add_argument("--force", action="store_true") + parser.add_argument("--multicut", action="store_true") + parser.add_argument("--boundaries", action="store_true") parser.add_argument("--affinities", action="store_true") parser.add_argument("--distances", action="store_true") diff --git a/experiments/vision-transformer/unetr/cremi/run_cremi.py b/experiments/vision-transformer/unetr/cremi/run_cremi.py index ea33e59a..dfb238fc 100644 --- a/experiments/vision-transformer/unetr/cremi/run_cremi.py +++ b/experiments/vision-transformer/unetr/cremi/run_cremi.py @@ -187,7 +187,7 @@ def run_cremi_unetr_training(args, device): def _do_bd_multicut_watershed(bd): - ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0) # compute the region adjacency graph rag = feats.compute_rag(ws_seg) @@ -198,14 +198,8 @@ def _do_bd_multicut_watershed(bd): # transform the edge costs from [0, 1] to [-inf, inf], which is # necessary for the multicut. This is done by intepreting the values # as probabilities for an edge being 'true' and then taking the negative log-likelihood. - - # in addition, we weight the costs by the size of the corresponding edge - # for z and xy edges - z_edges = feats.compute_z_edge_mask(rag, ws_seg) - xy_edges = np.logical_not(z_edges) - edge_populations = [z_edges, xy_edges] edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] - costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) # run the multicut partitioning, here, we use the kernighan lin # heuristics to solve the problem, introduced in @@ -289,14 +283,18 @@ def run_cremi_unetr_inference(args, device): if args.boundaries: bd = predictions.squeeze() - # instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) - instances = _do_bd_multicut_watershed(bd) + if args.multicut: + instances = _do_bd_multicut_watershed(bd) + else: + instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) elif args.affinities: affs = predictions - # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + if args.multicut: + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + else: + instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS) elif args.distances: fg, cdist, bdist = predictions @@ -348,6 +346,8 @@ def main(args): parser.add_argument("--force", action="store_true") + parser.add_argument("--multicut", action="store_true") + parser.add_argument("--train", action="store_true") parser.add_argument("--predict", action="store_true") From f8a8c9c30a93debbdd5a4dc30d5239993d09c272 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 5 Apr 2024 20:49:47 +0200 Subject: [PATCH 4/4] Update livecell inference --- .../vision-mamba/livecell/run_livecell.py | 26 ++++++++++--------- .../livecell/run_livecell_for_vimunet.py | 18 +++++++++---- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/experiments/vision-mamba/livecell/run_livecell.py b/experiments/vision-mamba/livecell/run_livecell.py index fac4c348..adb0e98c 100644 --- a/experiments/vision-mamba/livecell/run_livecell.py +++ b/experiments/vision-mamba/livecell/run_livecell.py @@ -159,7 +159,7 @@ def run_livecell_training(args): def _do_bd_multicut_watershed(bd): - ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0) # compute the region adjacency graph rag = feats.compute_rag(ws_seg) @@ -170,14 +170,8 @@ def _do_bd_multicut_watershed(bd): # transform the edge costs from [0, 1] to [-inf, inf], which is # necessary for the multicut. This is done by intepreting the values # as probabilities for an edge being 'true' and then taking the negative log-likelihood. - - # in addition, we weight the costs by the size of the corresponding edge - # for z and xy edges - z_edges = feats.compute_z_edge_mask(rag, ws_seg) - xy_edges = np.logical_not(z_edges) - edge_populations = [z_edges, xy_edges] edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1] - costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations) + costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes) # run the multicut partitioning, here, we use the kernighan lin # heuristics to solve the problem, introduced in @@ -262,13 +256,19 @@ def run_livecell_inference(args, device): if args.boundaries: fg, bd = predictions - # instances = segmentation.watershed_from_components(bd, fg) - instances = _do_bd_multicut_watershed(bd) + + if args.multicut: + instances = _do_bd_multicut_watershed(bd) + else: + instances = segmentation.watershed_from_components(bd, fg) elif args.affinities: fg, affs = predictions[0], predictions[1:] - # instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2]) + + if args.multicut: + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + else: + instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) elif args.distances: fg, cdist, bdist = predictions @@ -321,6 +321,8 @@ def main(args): parser.add_argument("--force", action="store_true") + parser.add_argument("--multicut", action="store_true") + parser.add_argument("--train", action="store_true") parser.add_argument("--predict", action="store_true") diff --git a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py index ed272177..305a02da 100644 --- a/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py +++ b/experiments/vision-transformer/unetr/livecell/run_livecell_for_vimunet.py @@ -180,7 +180,7 @@ def run_livecell_unetr_training(args, device): def _do_bd_multicut_watershed(bd): - ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0) + ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0) # compute the region adjacency graph rag = feats.compute_rag(ws_seg) @@ -283,13 +283,19 @@ def run_livecell_unetr_inference(args, device): if args.boundaries: fg, bd = predictions - # instances = segmentation.watershed_from_components(bd, fg) - instances = _do_bd_multicut_watershed(bd) + + if args.multicut: + instances = _do_bd_multicut_watershed(bd) + else: + instances = segmentation.watershed_from_components(bd, fg) elif args.affinities: fg, affs = predictions[0], predictions[1:] - # instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) - instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + + if args.multicut: + instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4]) + else: + instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) elif args.distances: fg, cdist, bdist = predictions @@ -341,6 +347,8 @@ def main(args): parser.add_argument("--force", action="store_true") + parser.add_argument("--multicut", action="store_true") + parser.add_argument("--train", action="store_true") parser.add_argument("--predict", action="store_true")