Commit

Update inference scripts (#19)
anwai98 committed Apr 5, 2024
1 parent 0e009ef commit 4ad364f
Showing 5 changed files with 198 additions and 45 deletions.
24 changes: 12 additions & 12 deletions experiments/vision-mamba/cremi/run_cremi.py
@@ -168,7 +168,7 @@ def run_cremi_training(args):


def _do_bd_multicut_watershed(bd):
- ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)
+ ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)
@@ -179,14 +179,8 @@ def _do_bd_multicut_watershed(bd):
# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.

- # in addition, we weight the costs by the size of the corresponding edge
- # for z and xy edges
- z_edges = feats.compute_z_edge_mask(rag, ws_seg)
- xy_edges = np.logical_not(z_edges)
- edge_populations = [z_edges, xy_edges]
  edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
- costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)
+ costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning, here, we use the kernighan lin
# heuristics to solve the problem, introduced in
@@ -268,14 +262,18 @@ def run_cremi_inference(args, device):
if args.boundaries:
bd = predictions.squeeze()

- # instances = segmentation.watershed_from_components(bd, np.ones_like(bd))
- instances = _do_bd_multicut_watershed(bd)
+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

elif args.affinities:
affs = predictions

- # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)
- instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2])
+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -331,6 +329,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--boundaries", action="store_true")
parser.add_argument("--affinities", action="store_true")
parser.add_argument("--distances", action="store_true")
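Note on the cost transform referenced in the comments above: it maps boundary probabilities in [0, 1] to signed multicut costs in [-inf, inf] via the negative log-likelihood. As a rough illustration, here is a minimal standalone sketch of what mc.transform_probabilities_to_costs conceptually computes; the beta bias term and the size-weighting scheme are assumptions based on the standard formulation, not read from elf's implementation.

import numpy as np

def probabilities_to_costs_sketch(probs, edge_sizes=None, beta=0.5, eps=1e-6):
    # negative log-likelihood ratio: p -> log((1 - p) / p), plus a bias that
    # shifts the decision boundary away from p = 0.5 when beta != 0.5
    p = np.clip(probs, eps, 1.0 - eps)
    costs = np.log((1.0 - p) / p) + np.log((1.0 - beta) / beta)
    if edge_sizes is not None:
        # weight by relative edge size, so long, well-supported boundaries
        # contribute stronger costs than short, noisy ones
        costs = costs * (edge_sizes / edge_sizes.max())
    return costs

Under this convention, edges with high boundary probability get negative (repulsive) costs and edges with low probability get positive (attractive) costs, which is exactly what the multicut solver below needs.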
88 changes: 81 additions & 7 deletions experiments/vision-mamba/light_microscopy/run_lm.py
@@ -1,5 +1,6 @@
import os
import argparse
+ from glob import glob

import numpy as np
import pandas as pd
@@ -10,10 +11,13 @@

import torch_em
from torch_em.util import segmentation
+ from torch_em.transform.raw import standardize
from torch_em.model import get_vimunet_model
- from torch_em.util.prediction import predict_with_halo
from torch_em.loss import DiceBasedDistanceLoss

+ import elf.segmentation.multicut as mc
+ import elf.segmentation.watershed as ws
+ import elf.segmentation.features as feats
from elf.evaluation import mean_segmentation_accuracy

from obtain_lm_datasets import get_lm_loaders
@@ -65,6 +69,71 @@ def run_lm_training(args):
trainer.fit(iterations=args.iterations)


+ def _do_bd_multicut_watershed(bd):
+     ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)
+
+     # compute the region adjacency graph
+     rag = feats.compute_rag(ws_seg)
+
+     # compute the edge costs
+     costs = feats.compute_boundary_features(rag, bd)[:, 0]
+
+     # transform the edge costs from [0, 1] to [-inf, inf], which is
+     # necessary for the multicut. This is done by interpreting the values
+     # as probabilities for an edge being 'true' and then taking the negative log-likelihood.
+
+     # in addition, we weight the costs by the size of the corresponding edge
+     # for z and xy edges
+     z_edges = feats.compute_z_edge_mask(rag, ws_seg)
+     xy_edges = np.logical_not(z_edges)
+     edge_populations = [z_edges, xy_edges]
+     edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
+     costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)
+
+     # run the multicut partitioning, here, we use the kernighan lin
+     # heuristics to solve the problem, introduced in
+     # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
+     node_labels = mc.multicut_kernighan_lin(rag, costs)
+
+     # map the results back to pixels to obtain the final segmentation
+     seg = feats.project_node_labels_to_pixels(rag, node_labels)
+
+     return seg
+
+
+ def _do_affs_multicut_watershed(affs, offsets):
+     # first, we have to make a single channel input map for the watershed,
+     # which we obtain by averaging the affinities
+     boundary_input = np.mean(affs, axis=0)
+
+     ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0)
+
+     # compute the region adjacency graph
+     rag = feats.compute_rag(ws_seg)
+
+     # compute the edge costs
+     # the offsets encode the pixel transitions represented by the
+     # individual affinity channels. Here, we only have nearest neighbor transitions
+     costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]
+
+     # transform the edge costs from [0, 1] to [-inf, inf], which is
+     # necessary for the multicut. This is done by interpreting the values
+     # as probabilities for an edge being 'true' and then taking the negative log-likelihood.
+     # in addition, we weight the costs by the size of the corresponding edge
+     edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
+     costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)
+
+     # run the multicut partitioning, here, we use the kernighan lin
+     # heuristics to solve the problem, introduced in
+     # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
+     node_labels = mc.multicut_kernighan_lin(rag, costs)
+
+     # map the results back to pixels to obtain the final segmentation
+     seg = feats.project_node_labels_to_pixels(rag, node_labels)
+
+     return seg


def run_lm_inference(args, device):
# saving the model checkpoints
save_root = os.path.join(
@@ -81,7 +150,8 @@ def run_lm_inference(args, device):
checkpoint=checkpoint
)

- raise NotImplementedError
+ test_image_dir = os.path.join(ROOT, "livecell", "images", "livecell_test_images")
+ all_test_labels = glob(os.path.join(ROOT, "livecell", "annotations", "livecell_test_images", "*", "*"))

res_path = os.path.join(save_root, "results.csv")
if os.path.exists(res_path) and not args.force:
@@ -90,13 +160,17 @@ def run_lm_inference(args, device):
return

msa_list, sa50_list, sa75_list = [], [], []
- for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)):
-     image = imageio.imread(image_path)
+ for label_path in tqdm(all_test_labels):
      labels = imageio.imread(label_path)
+     image_id = os.path.split(label_path)[-1]

-     predictions = predict_with_halo(
-         image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True,
-     )
+     image = imageio.imread(os.path.join(test_image_dir, image_id))
+     image = standardize(image)
+
+     tensor_image = torch.from_numpy(image)[None, None].to(device)
+
+     predictions = model(tensor_image)
+     predictions = predictions.squeeze().detach().cpu().numpy()

fg, cdist, bdist = predictions
instances = segmentation.watershed_from_center_and_boundary_distances(
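The run_lm.py inference path above replaces the tiled predict_with_halo call with a single full-image forward pass. A hedged sketch of that pattern on its own, assuming a 2D single-channel input and approximating torch_em's standardize with zero-mean/unit-variance scaling (model and device are stand-ins):

import numpy as np
import torch

def forward_full_image(model, image, device):
    # normalize to zero mean and unit variance, roughly what standardize does
    image = (image - image.mean()) / (image.std() + 1e-7)
    # add batch and channel axes: (H, W) -> (1, 1, H, W)
    tensor = torch.from_numpy(image.astype("float32"))[None, None].to(device)
    with torch.no_grad():
        predictions = model(tensor)
    return predictions.squeeze().cpu().numpy()

Unlike the halo-based tiling it replaces, this assumes the whole image fits into device memory at once, which is plausible for LIVECell-sized test images.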
78 changes: 75 additions & 3 deletions experiments/vision-mamba/livecell/run_livecell.py
@@ -16,6 +16,9 @@
from torch_em.data.datasets import get_livecell_loader
from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss

+ import elf.segmentation.multicut as mc
+ import elf.segmentation.watershed as ws
+ import elf.segmentation.features as feats
from elf.evaluation import mean_segmentation_accuracy


@@ -155,6 +158,65 @@ def run_livecell_training(args):
trainer.fit(iterations=int(args.iterations))


+ def _do_bd_multicut_watershed(bd):
+     ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)
+
+     # compute the region adjacency graph
+     rag = feats.compute_rag(ws_seg)
+
+     # compute the edge costs
+     costs = feats.compute_boundary_features(rag, bd)[:, 0]
+
+     # transform the edge costs from [0, 1] to [-inf, inf], which is
+     # necessary for the multicut. This is done by interpreting the values
+     # as probabilities for an edge being 'true' and then taking the negative log-likelihood.
+     edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
+     costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)
+
+     # run the multicut partitioning, here, we use the kernighan lin
+     # heuristics to solve the problem, introduced in
+     # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
+     node_labels = mc.multicut_kernighan_lin(rag, costs)
+
+     # map the results back to pixels to obtain the final segmentation
+     seg = feats.project_node_labels_to_pixels(rag, node_labels)
+
+     return seg
+
+
+ def _do_affs_multicut_watershed(affs, offsets):
+     # first, we have to make a single channel input map for the watershed,
+     # which we obtain by averaging the affinities
+     boundary_input = np.mean(affs, axis=0)
+
+     ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0)
+
+     # compute the region adjacency graph
+     rag = feats.compute_rag(ws_seg)
+
+     # compute the edge costs
+     # the offsets encode the pixel transitions represented by the
+     # individual affinity channels. Here, we only have nearest neighbor transitions
+     costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]
+
+     # transform the edge costs from [0, 1] to [-inf, inf], which is
+     # necessary for the multicut. This is done by interpreting the values
+     # as probabilities for an edge being 'true' and then taking the negative log-likelihood.
+     # in addition, we weight the costs by the size of the corresponding edge
+     edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
+     costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)
+
+     # run the multicut partitioning, here, we use the kernighan lin
+     # heuristics to solve the problem, introduced in
+     # http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
+     node_labels = mc.multicut_kernighan_lin(rag, costs)
+
+     # map the results back to pixels to obtain the final segmentation
+     seg = feats.project_node_labels_to_pixels(rag, node_labels)
+
+     return seg


def run_livecell_inference(args, device):
output_channels = get_output_channels(args)

@@ -174,7 +236,7 @@ def run_livecell_inference(args, device):
all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))

res_path = os.path.join(save_root, "results.csv")
- if os.path.exists(res_path):
+ if os.path.exists(res_path) and not args.force:
print(pd.read_csv(res_path))
print(f"The result is saved at {res_path}")
return
@@ -194,11 +256,19 @@ def run_livecell_inference(args, device):

if args.boundaries:
fg, bd = predictions
- instances = segmentation.watershed_from_components(bd, fg)
+
+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, fg)

elif args.affinities:
fg, affs = predictions[0], predictions[1:]
- instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)
+
+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -251,6 +321,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--train", action="store_true")
parser.add_argument("--predict", action="store_true")

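Both this script and run_lm.py score the resulting instances with mean_segmentation_accuracy from elf.evaluation before writing results.csv. A short sketch of how that evaluation step is typically wired up; the per-threshold indexing assumes elf's default IoU thresholds 0.5, 0.55, ..., 0.95, which is an assumption rather than something shown in this diff:

from elf.evaluation import mean_segmentation_accuracy

def evaluate_instances(instances, labels):
    # mSA averages the segmentation accuracy over the IoU thresholds;
    # sa_acc holds the per-threshold scores (assumed order 0.5 ... 0.95)
    msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
    return msa, sa_acc[0], sa_acc[5]  # mSA, SA50, SA75 (indices are assumptions)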
24 changes: 12 additions & 12 deletions experiments/vision-transformer/unetr/cremi/run_cremi.py
@@ -187,7 +187,7 @@ def run_cremi_unetr_training(args, device):


def _do_bd_multicut_watershed(bd):
- ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)
+ ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)
@@ -198,14 +198,8 @@ def _do_bd_multicut_watershed(bd):
# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.

- # in addition, we weight the costs by the size of the corresponding edge
- # for z and xy edges
- z_edges = feats.compute_z_edge_mask(rag, ws_seg)
- xy_edges = np.logical_not(z_edges)
- edge_populations = [z_edges, xy_edges]
  edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
- costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)
+ costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning, here, we use the kernighan lin
# heuristics to solve the problem, introduced in
@@ -289,14 +283,18 @@ def run_cremi_unetr_inference(args, device):
if args.boundaries:
bd = predictions.squeeze()

- # instances = segmentation.watershed_from_components(bd, np.ones_like(bd))
- instances = _do_bd_multicut_watershed(bd)
+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

elif args.affinities:
affs = predictions

- # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)
- instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2])
+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -348,6 +346,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--train", action="store_true")
parser.add_argument("--predict", action="store_true")

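Both CREMI files also lower the watershed threshold from 0.5 to 0.25. To convey what that knob does, here is a conceptual sketch of a distance-transform watershed; elf's ws.distance_transform_watershed differs in its actual implementation (anisotropy handling, seed-distance parameters, and so on), and the maximum-filter window size here is an arbitrary stand-in:

import numpy as np
from scipy import ndimage
from skimage.segmentation import watershed

def distance_transform_watershed_sketch(bd, threshold=0.25, sigma_seeds=2.0):
    # distance from each pixel to the nearest thresholded boundary pixel
    dist = ndimage.distance_transform_edt(bd < threshold)
    # smooth the distance map so each object yields one clear maximum
    smoothed = ndimage.gaussian_filter(dist, sigma_seeds)
    # local maxima of the smoothed map serve as watershed seeds
    maxima = smoothed == ndimage.maximum_filter(smoothed, size=10)
    seeds, max_id = ndimage.label(maxima)
    # grow regions from the seeds, treating the boundary map as elevation
    ws_seg = watershed(bd, markers=seeds)
    return ws_seg, max_id

Under this reading, lowering the threshold counts more pixels as boundary, which tends to produce a finer oversegmentation for the subsequent multicut to merge.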
(The diff for the fifth changed file was not loaded in this view.)
