Update inference scripts #19

Merged · 4 commits · Apr 5, 2024
24 changes: 12 additions & 12 deletions experiments/vision-mamba/cremi/run_cremi.py
@@ -168,7 +168,7 @@ def run_cremi_training(args):


def _do_bd_multicut_watershed(bd):
- ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)
+ ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)
@@ -179,14 +179,8 @@ def _do_bd_multicut_watershed(bd):
# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
-
- # in addition, we weight the costs by the size of the corresponding edge
- # for z and xy edges
- z_edges = feats.compute_z_edge_mask(rag, ws_seg)
- xy_edges = np.logical_not(z_edges)
- edge_populations = [z_edges, xy_edges]
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
- costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)
+ costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
@@ -268,14 +262,18 @@ def run_cremi_inference(args, device):
if args.boundaries:
bd = predictions.squeeze()

- # instances = segmentation.watershed_from_components(bd, np.ones_like(bd))
- instances = _do_bd_multicut_watershed(bd)
+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

elif args.affinities:
affs = predictions

- # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)
- instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2])
+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -331,6 +329,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--boundaries", action="store_true")
parser.add_argument("--affinities", action="store_true")
parser.add_argument("--distances", action="store_true")
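A note on the cost transform used throughout these scripts: they call mc.transform_probabilities_to_costs from elf, which the comments above describe as a negative log-likelihood mapping. The sketch below is a hypothetical re-implementation for illustration only; the function name, the beta prior, and the exact size weighting are assumptions, not elf's actual internals.

import numpy as np

def probabilities_to_costs(probs, edge_sizes=None, beta=0.5, eps=1e-6):
    # interpret each edge value as the probability of a true boundary and
    # take the negative log-likelihood ratio, mapping [0, 1] to (-inf, inf):
    # p -> 0 gives a large positive (attractive) cost, p -> 1 a negative
    # (repulsive) one; beta shifts the decision boundary
    p = np.clip(probs, eps, 1.0 - eps)
    costs = np.log((1.0 - p) / p) + np.log((1.0 - beta) / beta)
    if edge_sizes is not None:
        # weight costs by normalized edge size, so evidence along long
        # boundaries counts more than single-pixel contacts
        costs *= edge_sizes / edge_sizes.max()
    return costs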
88 changes: 81 additions & 7 deletions experiments/vision-mamba/light_microscopy/run_lm.py
@@ -1,5 +1,6 @@
import os
import argparse
from glob import glob

import numpy as np
import pandas as pd
@@ -10,10 +11,13 @@

import torch_em
from torch_em.util import segmentation
from torch_em.transform.raw import standardize
from torch_em.model import get_vimunet_model
from torch_em.util.prediction import predict_with_halo
from torch_em.loss import DiceBasedDistanceLoss

import elf.segmentation.multicut as mc
import elf.segmentation.watershed as ws
import elf.segmentation.features as feats
from elf.evaluation import mean_segmentation_accuracy

from obtain_lm_datasets import get_lm_loaders
@@ -65,6 +69,71 @@ def run_lm_training(args):
trainer.fit(iterations=args.iterations)


def _do_bd_multicut_watershed(bd):
ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
costs = feats.compute_boundary_features(rag, bd)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.

# in addition, we weight the costs by the size of the corresponding edge
# for z and xy edges
z_edges = feats.compute_z_edge_mask(rag, ws_seg)
xy_edges = np.logical_not(z_edges)
edge_populations = [z_edges, xy_edges]
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def _do_affs_multicut_watershed(affs, offsets):
# first, we have to make a single channel input map for the watershed,
# which we obtain by averaging the affinities
boundary_input = np.mean(affs, axis=0)

ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
# the offsets encode the pixel transitions captured by the
# individual affinity channels; here, we only have nearest neighbor transitions
costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
# in addition, we weight the costs by the size of the corresponding edge
edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def run_lm_inference(args, device):
# saving the model checkpoints
save_root = os.path.join(
@@ -81,7 +150,8 @@ def run_lm_inference(args, device):
checkpoint=checkpoint
)

- raise NotImplementedError
+ test_image_dir = os.path.join(ROOT, "livecell", "images", "livecell_test_images")
+ all_test_labels = glob(os.path.join(ROOT, "livecell", "annotations", "livecell_test_images", "*", "*"))

res_path = os.path.join(save_root, "results.csv")
if os.path.exists(res_path) and not args.force:
@@ -90,13 +160,17 @@ def run_lm_inference(args, device):
return

msa_list, sa50_list, sa75_list = [], [], []
- for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)):
-     image = imageio.imread(image_path)
+ for label_path in tqdm(all_test_labels):
      labels = imageio.imread(label_path)
+     image_id = os.path.split(label_path)[-1]

-     predictions = predict_with_halo(
-         image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True,
-     )
+     image = imageio.imread(os.path.join(test_image_dir, image_id))
+     image = standardize(image)
+
+     tensor_image = torch.from_numpy(image)[None, None].to(device)
+
+     predictions = model(tensor_image)
+     predictions = predictions.squeeze().detach().cpu().numpy()

fg, cdist, bdist = predictions
instances = segmentation.watershed_from_center_and_boundary_distances(
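The inference hunk above replaces the tiled predict_with_halo call with a single forward pass over the full image. A condensed sketch of that path, assuming a 2D single-channel image that fits into memory; the inline normalization stands in for torch_em's standardize, and the torch.no_grad() guard is an addition of this sketch, not part of the diff.

import imageio
import torch

def predict_full_image(model, image_path, device):
    # load and standardize to zero mean, unit variance
    image = imageio.imread(image_path).astype("float32")
    image = (image - image.mean()) / (image.std() + 1e-7)
    # add batch and channel axes: (H, W) -> (1, 1, H, W)
    tensor_image = torch.from_numpy(image)[None, None].to(device)
    with torch.no_grad():  # inference only, no gradients needed
        predictions = model(tensor_image)
    return predictions.squeeze().cpu().numpy()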
78 changes: 75 additions & 3 deletions experiments/vision-mamba/livecell/run_livecell.py
@@ -16,6 +16,9 @@
from torch_em.data.datasets import get_livecell_loader
from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss

import elf.segmentation.multicut as mc
import elf.segmentation.watershed as ws
import elf.segmentation.features as feats
from elf.evaluation import mean_segmentation_accuracy


@@ -155,6 +158,65 @@ def run_livecell_training(args):
trainer.fit(iterations=int(args.iterations))


def _do_bd_multicut_watershed(bd):
ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
costs = feats.compute_boundary_features(rag, bd)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def _do_affs_multicut_watershed(affs, offsets):
# first, we have to make a single channel input map for the watershed,
# which we obtain by averaging the affinities
boundary_input = np.mean(affs, axis=0)

ws_seg, max_id = ws.distance_transform_watershed(boundary_input, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)

# compute the edge costs
# the offsets encode the pixel transitions captured by the
# individual affinity channels; here, we only have nearest neighbor transitions
costs = feats.compute_affinity_features(rag, affs, offsets)[:, 0]

# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
# in addition, we weight the costs by the size of the corresponding edge
edge_sizes = feats.compute_boundary_mean_and_length(rag, boundary_input)[:, 1]
costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
# http://xilinx.asia/_hdl/4/eda.ee.ucla.edu/EE201A-04Spring/kl.pdf
node_labels = mc.multicut_kernighan_lin(rag, costs)

# map the results back to pixels to obtain the final segmentation
seg = feats.project_node_labels_to_pixels(rag, node_labels)

return seg


def run_livecell_inference(args, device):
output_channels = get_output_channels(args)

@@ -174,7 +236,7 @@ def run_livecell_inference(args, device):
all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))

res_path = os.path.join(save_root, "results.csv")
- if os.path.exists(res_path):
+ if os.path.exists(res_path) and not args.force:
print(pd.read_csv(res_path))
print(f"The result is saved at {res_path}")
return
@@ -194,11 +256,19 @@ def run_livecell_inference(args, device):

if args.boundaries:
fg, bd = predictions
- instances = segmentation.watershed_from_components(bd, fg)

+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, fg)

elif args.affinities:
fg, affs = predictions[0], predictions[1:]
- instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)

+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -251,6 +321,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--train", action="store_true")
parser.add_argument("--predict", action="store_true")

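One behavioral detail in the affinity branch above: the new multicut path consumes only the first four affinity channels and offsets, while the default mutex watershed keeps all of them. A toy sketch of that channel selection; the offset values below are hypothetical placeholders, not the scripts' OFFSETS constant.

import numpy as np

# hypothetical 2d offsets: direct neighbor transitions first,
# longer-range transitions after them
OFFSETS = [[-1, 0], [0, -1], [-2, 0], [0, -2], [-9, 0], [0, -9]]

affs = np.random.rand(len(OFFSETS), 256, 256)  # (channels, H, W) affinities

# multicut path: short-range channels only, matching affs[:4], OFFSETS[:4]
affs_mc, offsets_mc = affs[:4], OFFSETS[:4]
boundary_input = np.mean(affs_mc, axis=0)  # single-channel watershed input

# mutex watershed path: all channels, including the long-range repulsions
affs_mws, offsets_mws = affs, OFFSETS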
24 changes: 12 additions & 12 deletions experiments/vision-transformer/unetr/cremi/run_cremi.py
@@ -187,7 +187,7 @@ def run_cremi_unetr_training(args, device):


def _do_bd_multicut_watershed(bd):
- ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.5, sigma_seeds=2.0)
+ ws_seg, max_id = ws.distance_transform_watershed(bd, threshold=0.25, sigma_seeds=2.0)

# compute the region adjacency graph
rag = feats.compute_rag(ws_seg)
@@ -198,14 +198,8 @@ def _do_bd_multicut_watershed(bd):
# transform the edge costs from [0, 1] to [-inf, inf], which is
# necessary for the multicut. This is done by interpreting the values
# as probabilities for an edge being 'true' and then taking the negative log-likelihood.
-
- # in addition, we weight the costs by the size of the corresponding edge
- # for z and xy edges
- z_edges = feats.compute_z_edge_mask(rag, ws_seg)
- xy_edges = np.logical_not(z_edges)
- edge_populations = [z_edges, xy_edges]
edge_sizes = feats.compute_boundary_mean_and_length(rag, bd)[:, 1]
- costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes, edge_populations=edge_populations)
+ costs = mc.transform_probabilities_to_costs(costs, edge_sizes=edge_sizes)

# run the multicut partitioning; here, we use the Kernighan-Lin
# heuristic to solve the problem, introduced in
@@ -289,14 +283,18 @@ def run_cremi_unetr_inference(args, device):
if args.boundaries:
bd = predictions.squeeze()

- # instances = segmentation.watershed_from_components(bd, np.ones_like(bd))
- instances = _do_bd_multicut_watershed(bd)
+ if args.multicut:
+     instances = _do_bd_multicut_watershed(bd)
+ else:
+     instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

elif args.affinities:
affs = predictions

- # instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)
- instances = _do_affs_multicut_watershed(affs[:2], OFFSETS[:2])
+ if args.multicut:
+     instances = _do_affs_multicut_watershed(affs[:4], OFFSETS[:4])
+ else:
+     instances = segmentation.mutex_watershed_segmentation(np.ones_like(labels), affs, offsets=OFFSETS)

elif args.distances:
fg, cdist, bdist = predictions
@@ -348,6 +346,8 @@ def main(args):

parser.add_argument("--force", action="store_true")

parser.add_argument("--multicut", action="store_true")

parser.add_argument("--train", action="store_true")
parser.add_argument("--predict", action="store_true")

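Finally, on the threshold=0.5 to threshold=0.25 change in both CREMI scripts: in a distance-transform watershed, the threshold decides which boundary evidence is binarized before seeding, so lowering it accepts weaker boundaries and fragments the volume more aggressively, leaving the multicut to merge the pieces back together. A rough stand-in implementation to make the parameter concrete; ws.distance_transform_watershed from elf is the real function and its internals may differ.

import numpy as np
from scipy import ndimage
from skimage.segmentation import watershed

def distance_transform_watershed_sketch(bd, threshold=0.25, sigma_seeds=2.0):
    # pixels above `threshold` count as boundary; a lower threshold accepts
    # weaker boundary evidence, yielding more (smaller) watershed fragments
    boundary_mask = bd > threshold
    # distance to the nearest boundary pixel, smoothed before seed detection
    dist = ndimage.gaussian_filter(
        ndimage.distance_transform_edt(~boundary_mask), sigma=sigma_seeds
    )
    # local maxima of the smoothed distance map become the watershed seeds
    maxima = (dist == ndimage.maximum_filter(dist, size=5)) & ~boundary_mask
    seeds, max_id = ndimage.label(maxima)
    return watershed(bd, markers=seeds), max_id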