### Import Modules

In [4]:
# Import libraries
from __future__ import print_function, unicode_literals, absolute_import, division
import os
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist import Rays_GoldenSpiral
from stardist.matching import matching, matching_dataset
from stardist.models import Config3D, StarDist3D, StarDistData3D

np.random.seed(42)
lbl_cmap = random_label_cmap()

2022-12-02 14:37:08.929245: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-02 14:37:09.151122: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-12-02 14:37:09.206812: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-12-02 14:37:10.204980: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; 

### Import and Preprocess Training Data

In [None]:
# Set paths to the image files and masks
root = '/home/elyse/Documents/GitHub/LS_evaluation_tool/test_data/stardist_training'
img_dir = 'train/images'
mask_dir = 'train/masks'

# Both listings are sorted so that the i-th image and the i-th mask belong together.
X = sorted(glob(os.path.join(root, img_dir, '*.tif')))
Y = sorted(glob(os.path.join(root, mask_dir, '*.tif')))

# zip() silently stops at the shorter list, so an empty folder or a missing mask would
# otherwise slip through the name check below; verify the counts explicitly first.
assert len(X) > 0, 'no training images found in %s' % os.path.join(root, img_dir)
assert len(X) == len(Y), 'found %d images but %d masks' % (len(X), len(Y))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y)), 'image and mask file names do not match'

In [None]:
# Read every volume from disk, then rescale the intensities
# (the choice of normalization can dramatically affect segmentation)
X = [imread(path) for path in X]
Y = [imread(path) for path in Y]

# percentile-normalize (1st to 99.8th) jointly over axes 0, 1 and 2,
# i.e. channels (if any) are normalized independently
axis_norm = (0,1,2)
X = [normalize(volume, 1, 99.8, axis=axis_norm) for volume in tqdm(X)]
# post-process each label mask with stardist's fill_label_holes
Y = [fill_label_holes(mask) for mask in tqdm(Y)]

In [None]:
# Reproducible random split: about 15% of the volumes (at least one) are held out for
# validation, the rest is used for training.
# With a single volume the validation split would take it and leave the training split
# empty, so fail early with a clear message instead of failing later inside training.
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
# the last n_val shuffled indices form the validation set
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

### Plot Example of Image and Mask

In [None]:
# Function to show an XY slice of a raw image next to the same slice of its label mask
def plot_img_label(img, lbl, img_title="image (XY slice)", lbl_title="label (XY slice)", z=None, **kwargs):
    """Draw one plane of `img` and the matching plane of `lbl` side by side.

    Parameters
    ----------
    img : array indexed as ``img[z]``; shown in grayscale with color limits (0, 1)
        and a colorbar.
    lbl : array indexed as ``lbl[z]``; shown with the notebook-level ``lbl_cmap``.
    img_title, lbl_title : titles of the left and right panel.
    z : index of the plane to show; defaults to the middle of ``img``'s first axis.
    **kwargs : accepted but not used.
    """
    plane = img.shape[0] // 2 if z is None else z
    fig, axes = plt.subplots(1, 2, figsize=(12,5), gridspec_kw={'width_ratios': (1.25,1)})
    ax_img, ax_lbl = axes

    # left panel: raw intensities
    shown = ax_img.imshow(img[plane], cmap='gray', clim=(0,1))
    ax_img.set_title(img_title)
    fig.colorbar(shown, ax=ax_img)

    # right panel: label image
    ax_lbl.imshow(lbl[plane], cmap=lbl_cmap)
    ax_lbl.set_title(lbl_title)

    plt.tight_layout()

# Show the volume at index 2 together with its mask as a sanity check
i = 2
img = X[i]
lbl = Y[i]
if img.ndim != 3:
    # not a plain 3-D volume: keep only the first three entries of the last axis
    img = img[...,:3]
plot_img_label(img,lbl)


### Configure and Train Model

In [None]:
# Estimate the anisotropy of the labeled objects: the largest object extent divided by
# the extent along each axis. Should be (2.0, 1.8, 1.8) for SmartSpim data
extents = calculate_extents(Y)
anisotropy = tuple(np.divide(np.max(extents), extents))
print('empirical anisotropy of labeled objects = %s' % (anisotropy,))

In [None]:
# function to query the free GPU memory (used below to cap TensorFlow's allocation)
import subprocess as sp

def get_gpu_memory():
    """Return the free memory reported by ``nvidia-smi``, one integer per output row.

    Runs ``nvidia-smi --query-gpu=memory.free --format=csv`` and parses the leading
    number of every row after the CSV header (presumably one row per GPU, in MiB;
    confirm against the local nvidia-smi output).

    Raises whatever ``subprocess.check_output`` raises when the command is missing
    or exits with a non-zero status.
    """
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    report = sp.check_output(command.split()).decode('ascii')
    # drop the header line at the front and the piece after the final newline
    rows = report.split('\n')[1:-1]
    return [int(row.split()[0]) for row in rows]


# Parameterize model
n_rays = 96
n_channel = 1
# use_gpu enables OpenCL-based computations for the data generator during training, which
# require the optional 'gputools' package. Enable it only when that package is importable
# instead of hard-coding True (which fails later if gputools is not installed).
use_gpu = gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
# (no subsampling along axes whose anisotropy factor is above 1.5)
grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)

# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

conf = Config3D (
    rays             = rays,
    grid             = grid,
    anisotropy       = anisotropy,
    use_gpu          = use_gpu,
    n_channel_in     = n_channel,
    # adjust for your data below (make patch size as large as possible)
    train_patch_size = (16,32,32),
    train_batch_size = 2,
)

# parameterize gpu usage
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    # (80% of the free memory reported for the first row of nvidia-smi output)
    gpu_mem = get_gpu_memory()
    limit_gpu_memory(0.8, total_memory = gpu_mem[0])
    # alternatively, try this:
    # limit_gpu_memory(None, allow_growth=True)



In [None]:
# Instantiate the model under <root>/models and sanity-check its field of view
model_dir = os.path.join(root, 'models')
model = StarDist3D(config=conf, name='stardist', basedir=model_dir)

# the median object extent should fit into what the network can "see" along every axis
median_size = calculate_extents(Y, np.median)
fov = np.array(model._axes_tile_overlap('ZYX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if (median_size > fov).any():
    print("WARNING: median object size larger than field of view of the neural network.")

In [None]:
# Fit the network on the training split and monitor it on the held-out validation split.
# augmenter=None: no augmentation callable is supplied.
model.train(
    X_trn,
    Y_trn,
    validation_data=(X_val, Y_val),
    augmenter=None,
)

### Optimize Thresholds

In [None]:
# Tune the model's thresholds on the validation split (presumably the prob_thresh / nms_thresh
# values that are reported when the saved model is reloaded further below).
model.optimize_thresholds(X_val, Y_val)

### Segmentation Evaluation

In [None]:
# Predict a label image for every validation volume; predict_instances returns a sequence
# whose first element is kept here. The tiling is chosen by the model itself.
Y_val_pred = [
    model.predict_instances(
        volume,
        n_tiles=model._guess_n_tiles(volume),
        show_tile_progress=False,
    )[0]
    for volume in tqdm(X_val)
]

In [None]:
# plot the first validation image twice: once with its ground-truth mask, once with the prediction
plot_img_label(img=X_val[0], lbl=Y_val[0], lbl_title="label GT (XY slice)")
plot_img_label(img=X_val[0], lbl=Y_val_pred[0], lbl_title="label Pred (XY slice)")

In [None]:
# Evaluate the validation predictions against the ground truth at a range of IoU thresholds
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [
    matching_dataset(Y_val, Y_val_pred, thresh=tau, show_progress=False)
    for tau in tqdm(taus)
]


In [None]:
# Plot the matching statistics as a function of the IoU threshold:
# score-like metrics on the left axis, raw counts on the right axis
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

# left panel: metrics
ax1.set_xlabel(r'IoU threshold $\tau$')
ax1.set_ylabel('Metric value')
ax1.grid()
for m in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
    values = [s._asdict()[m] for s in stats]
    ax1.plot(taus, values, '.-', linewidth=2, label=m)
ax1.legend()

# right panel: false positives, true positives, false negatives
ax2.set_xlabel(r'IoU threshold $\tau$')
ax2.set_ylabel('Number #')
ax2.grid()
for m in ('fp', 'tp', 'fn'):
    values = [s._asdict()[m] for s in stats]
    ax2.plot(taus, values, '.-', linewidth=2, label=m)
ax2.legend();

### Import Annotated Data

In [2]:
# Get annotated data and image file paths
import re
import os

import utils.preprocess as pp
import utils.evaluate as evaluate

from collections import defaultdict

base_dir = '/home/elyse/Documents/GitHub/LS_evaluation_tool/test_data/images/annotated'
output_root = '/home/elyse/Documents/GitHub/LS_evaluation_tool/test_data/stardist_output'

# create output directory (including missing parents) if it doesn't already exist
os.makedirs(output_root, exist_ok=True)

# blocks[<block number>] -> dict with the keys 'signal' / 'background' (image paths) and
# 'detect' / 'reject' (whatever pp.get_locations returns for the annotation file)
blocks = defaultdict(dict)

# loop through base directory and collect all data from "block_#" subfolders
# NOTE: the walk variable is deliberately not named `root`; that name holds the training
# data folder in the training section and must not be overwritten by this cell.
for dirpath, _subdirs, filenames in os.walk(base_dir):
    if 'block' not in dirpath:
        continue
    for file in filenames:
        # the block number is the first digit found in the file name (one digit only,
        # so block numbers above 9 would collide); files without any digit are skipped
        digits = re.findall('[0-9]', file)
        if not digits:
            continue
        block_num = digits[0]
        path = os.path.join(dirpath, file)
        if 'signal' in file:
            blocks[block_num]['signal'] = path
        elif 'background' in file:
            blocks[block_num]['background'] = path
        elif 'detect' in file:
            blocks[block_num]['detect'] = pp.get_locations(path)
        else:
            # every other numbered file is treated as the "reject" annotation file
            blocks[block_num]['reject'] = pp.get_locations(path)

# print info on blocks
for key in blocks.keys():
    # a block may lack either annotation file; report 0 instead of failing
    detected = blocks[key].get('detect')
    rejected = blocks[key].get('reject')
    n_cells = 0 if detected is None else len(detected)
    n_noncells = 0 if rejected is None else len(rejected)
    print('Pulled annotation data for block {0}: {1} cells and {2} noncells.'.format(str(key), n_cells, n_noncells))

Pulled annotation data for block 4: 2168 cells and 0 noncells.
Pulled annotation data for block 1: 1751 cells and 2 noncells.
Pulled annotation data for block 3: 410 cells and 5 noncells.
Pulled annotation data for block 2: 1641 cells and 2 noncells.


### Import Trained Model

In [5]:
# Reload the trained network from disk: with config=None the model is restored from what is
# stored in <model_dir>/stardist (weights and thresholds are reported in the output below)
model_dir = '/home/elyse/Documents/GitHub/LS_evaluation_tool/test_data/stardist_training/models'
model = StarDist3D(config=None, name='stardist', basedir=model_dir)

2022-12-02 14:37:27.470416: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-02 14:37:29.120173: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 2772 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:65:00.0, compute capability: 7.5
2022-12-02 14:37:29.120957: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 6656 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 2070 SUPER, pci bus id: 0000:b3:00.0, compute capability: 7.5


Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.503434, nms_thresh=0.4.


### Run Trained Model on Annotated Data

In [5]:
import time
import skimage.io

from skimage.measure import regionprops

# NOTE: no Config3D is rebuilt here. The model used below was loaded from disk with
# config=None, so a new configuration would never be used; building one also required
# `anisotropy` from the training section, which does not exist when only the inference
# part of the notebook is run.

# model parameters
chs = ['signal', 'background']
method = 'stardist'
detect_times = defaultdict(int)
axis_norm = (0,1,2)   # normalize channels independently

# preprocessing parameters
bkg_sub = True
estimator = 'SExtractorBackground'
pad = 50

start_detect = time.time()
for key in blocks.keys():
    for ch in chs:
        img = skimage.io.imread(blocks[key][ch])

        if bkg_sub:
            img = pp.astro_preprocess(img, estimator, pad = pad)

        # need to normalize. will send warning if you forget
        img_norm = normalize(img,1,99.8,axis=axis_norm)

        # first element of the prediction result is the label image
        # (named pred_labels so the validation results in Y_val_pred are not overwritten)
        pred_labels = model.predict_instances(img_norm, n_tiles=model._guess_n_tiles(img), show_tile_progress=False)[0]

        # get centroids from masks. see: https://github.com/MouseLand/cellpose/issues/337
        # coordinates are truncated to integers and their axis order is reversed
        # (presumably array order z,y,x -> x,y,z to match the annotations; TODO confirm)
        centroids = [
            [int(x) for x in region['centroid']][::-1]
            for region in regionprops(pred_labels)
        ]

        # objects found in the signal channel count as detections, those found in the
        # background channel as rejections
        if ch == 'signal':
            blocks[key]['pred_detect_' + method] = centroids
        else:
            blocks[key]['pred_reject_' + method] = centroids

        # save masks as tif; the channel is part of the file name, otherwise the
        # background mask would overwrite the signal mask of the same block
        output_file = os.path.join(output_root, 'block_' + key + '_' + ch + '_mask.tif')
        print(pred_labels.shape)
        skimage.io.imsave(output_file, pred_labels)

# total wall-clock time for all blocks and channels
detect_times[method] = time.time() - start_detect

2022-12-01 15:48:20.168930: E tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2022-12-01 15:48:20.169013: E tensorflow/stream_executor/cuda/cuda_dnn.cc:398] Possibly insufficient driver version: 515.86.1
2022-12-01 15:48:20.169066: W tensorflow/core/framework/op_kernel.cc:1780] OP_REQUIRES failed at conv_ops_3d.cc:509 : UNIMPLEMENTED: DNN library is not found.


UnimplementedError: Graph execution error:

Detected at node 'model/conv3d/Conv3D' defined at (most recent call last):
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 390, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2914, in run_cell
      result = self._run_cell(
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2960, in _run_cell
      return runner(coro)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3185, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3377, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_16269/705827369.py", line 29, in <module>
      Y_val_pred = model.predict_instances(img_norm, n_tiles=model._guess_n_tiles(img), show_tile_progress=False)[0]
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 775, in predict_instances
      for r in self._predict_instances_generator(*args, **kwargs):
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 727, in _predict_instances_generator
      for res in self._predict_sparse_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 549, in _predict_sparse_generator
      tile_generator, output_shape, create_empty_output = tiling_setup()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 405, in tiling_setup
      axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 1084, in _axes_tile_overlap
      self._tile_overlap = self._compute_receptive_field()
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/stardist/models/base.py", line 1069, in _compute_receptive_field
      y  = self.keras_model.predict(x)[0][0,...,0]
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 2253, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 2041, in predict_function
      return step_function(self, iterator)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 2027, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 2015, in run_step
      outputs = model.predict_step(data)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 1983, in predict_step
      return self(x, training=False)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/layers/convolutional/base_conv.py", line 283, in call
      outputs = self.convolution_op(inputs, self.kernel)
    File "/home/elyse/miniconda3/envs/segment_compare/lib/python3.8/site-packages/keras/layers/convolutional/base_conv.py", line 255, in convolution_op
      return tf.nn.convolution(
Node: 'model/conv3d/Conv3D'
DNN library is not found.
	 [[{{node model/conv3d/Conv3D}}]] [Op:__inference_predict_function_867]

### Stardist Outputs and Profile

In [None]:
# Report, per method, how many objects were predicted in each block and the total run time
for method in ['stardist']:
    print('Using ' + method + ' method for 2D to 3D')
    for key in blocks:
        print('Classified {0} cells and {1} noncells for annocation block {2}.'.format(
            len(blocks[key]['pred_detect_' + method]),
            len(blocks[key]['pred_reject_' + method]),
            str(key),
        ))

    print('Detection and Classification via ' + method + ' took {0} seconds.\n'.format(detect_times[method]))

### Plot Performance

In [None]:
# get metrics on model performance and plot results
# (matplotlib is imported again here so this section also works when the training
# section above was not run)
import matplotlib.pyplot as plt
%matplotlib inline

# Arguments handed to the project helper evaluate.get_performance. Their exact meaning is
# defined in utils/evaluate.py, which is not shown here; the notes below are assumptions.
# methods: detection methods to evaluate
methods = ['stardist']
# max_dist: presumably the maximum distance for matching a prediction to an annotation - TODO confirm units
max_dist = 10
# trained: TODO confirm meaning against utils.evaluate
trained = False
# get_keys: presumably the block keys to evaluate (block '4' only) - TODO confirm
get_keys = ['4']

performance = evaluate.get_performance(blocks, methods, max_dist, trained, get_keys)
evaluate.plot_performance(performance, methods)

In [None]:
# Classification Matrix Plots
# Take every third row of `performance` starting at row index 2 and drop the last column;
# presumably these rows hold the F1 scores - TODO confirm against evaluate.get_performance.
F1_scores = performance[2::3, :-1]
# (row, column) position of the largest value in F1_scores
opt_vals = np.unravel_index(np.argmax(F1_scores, axis=None), F1_scores.shape)
# The row index (+1) and the method selected by the column index are passed to the project
# helper; NOTE(review): this assumes the columns of F1_scores follow the order of `methods`.
metric_df = evaluate.annot_pred_overlap(blocks, opt_vals[0] + 1, methods[opt_vals[1]])

evaluate.plot_cm(metric_df)