**In case of problems or questions, please first check the list of [Frequently Asked Questions (FAQ)](https://stardist.net/docs/faq.html).**

Please shut down all other training/prediction notebooks before running this notebook (otherwise they might occupy the GPU memory).

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
# matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tifffile import imread, imwrite
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import random_label_cmap
from stardist.models import StarDist3D

np.random.seed(6)
lbl_cmap = random_label_cmap()

# Data

We assume that the data has already been downloaded via notebook [1_data.ipynb](1_data.ipynb).  
We now load the test images from the `IMAGES` sub-folder (one image per ground-truth set: `visual`, `c5` and `c3`); these have not been used during training.

In [2]:
import pathlib as pt
path_images = pt.Path.home() / "Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SPLITS"
path_weights = path_images / "STARDIST/weights/tuned"
dataset = [
    "1_c15",
    "2_c1_c4_visual",
    "3_c1245_visual"
]
gt_dict = {
    "1_c15" : "visual",
    "2_c1_c4_visual" : "c5",
    "3_c1245_visual" : "c3"
}
save_dict = {
    "1_c15" : "c1_5",
    "2_c1_c4_visual" : "c1-4_v",
    "3_c1245_visual" : "c1245_v"
}
splits = [
    "10",
    "20",
    "60",
    "80"
]
    
path_images = path_images / "IMAGES"
image_dict = {
    "visual" : imread(str(path_images / "small_isotropic_visual.tif")),
    "c5" : imread(str(path_images / "c5image.tif")),
    "c3" : imread(str(path_images / "c3image.tif")),
}


# Load trained model

The StarDist models trained (and with thresholds optimized) via notebook [2_training.ipynb](2_training.ipynb) are loaded from `path_weights`: one model per training dataset and split, each stored in a folder named `stardist_{dataset}_{split}`.

In [3]:
weights_folders = [f"stardist_{dat}_{per}" for per in splits for dat in dataset]
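The prediction loop later recovers the dataset name and split from each folder name with fixed slicing (`w[9:-3]`, `w[-2:]`), which assumes two-character splits. The following is a minimal sketch, using hypothetical helper names that are not part of StarDist, of building these names and parsing them back by splitting on the last underscore instead:

```python
from typing import Tuple

# Hypothetical helpers (not part of StarDist): build and parse the
# "stardist_{dataset}_{split}" folder names used in this notebook.
PREFIX = "stardist_"

def make_folder_name(dataset: str, split: str) -> str:
    """Compose a weights folder name, e.g. ("1_c15", "10") -> "stardist_1_c15_10"."""
    return f"{PREFIX}{dataset}_{split}"

def parse_folder_name(name: str) -> Tuple[str, str]:
    """Split a weights folder name back into (dataset, split).

    Splitting on the last underscore also works for splits that are not
    exactly two characters long (unlike fixed slicing such as name[-2:]).
    """
    if not name.startswith(PREFIX):
        raise ValueError(f"unexpected folder name: {name!r}")
    dataset, split = name[len(PREFIX):].rsplit("_", 1)
    return dataset, split

example_datasets = ["1_c15", "2_c1_c4_visual", "3_c1245_visual"]
example_splits = ["10", "20", "60", "80"]
folders = [make_folder_name(d, s) for s in example_splits for d in example_datasets]
print(folders[:3])
# ['stardist_1_c15_10', 'stardist_2_c1_c4_visual_10', 'stardist_3_c1245_visual_10']
print(parse_folder_name("stardist_2_c1_c4_visual_10"))
# ('2_c1_c4_visual', '10')
```

Underscores inside the dataset name (as in `2_c1_c4_visual`) are harmless here, because only the last underscore separates the split.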

In [4]:
model_parameters = {
    "visual": {
        # "NMS": 0.3,
        "NMS": "auto",
        # "prob_thresh": 0.8
        "prob_thresh": "auto"
    },
    "c5": {
        "NMS": "auto",
        "prob_thresh": "auto"
    },
    "c3": {
        "NMS": "auto",
        "prob_thresh": "auto"
    }
}
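In `model_parameters`, the value `"auto"` is this notebook's own convention for "use the thresholds stored with the model": it is turned into `None` before calling `predict_instances`, which then falls back to the values loaded from `thresholds.json` (see the `Using default values: ...` lines in the output further below). A small sketch of that mapping, with a hypothetical helper name:

```python
from typing import Optional, Union

# Hypothetical helper (not part of StarDist): map the notebook's "auto"
# convention to None, so that predict_instances uses the model's own
# optimized thresholds instead of a fixed value.
def resolve_thresh(value: Union[float, str]) -> Optional[float]:
    """Return None for "auto", otherwise the value as a float."""
    if value == "auto":
        return None
    return float(value)

params = {"NMS": "auto", "prob_thresh": 0.8}
print(resolve_thresh(params["NMS"]), resolve_thresh(params["prob_thresh"]))
# None 0.8
```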

## Prediction

Make sure to normalize the input image beforehand or supply a `normalizer` to the prediction function.
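This notebook normalizes with `csbdeep.utils.normalize(image, 1, 99.8, axis=axis_norm)`, i.e. percentile-based normalization. The NumPy sketch below re-implements the idea for illustration only (it is not csbdeep's code): the `pmin`-th percentile is mapped to 0 and the `pmax`-th percentile to 1, without clipping. Computing the percentiles over the spatial axes only, as with `axis=(0,1,2)` on a `(Z, Y, X, C)` stack, normalizes each channel independently:

```python
import numpy as np

def percentile_normalize(x, pmin=1.0, pmax=99.8, axis=None, eps=1e-20):
    """Sketch of percentile normalization: map the pmin-th percentile to 0
    and the pmax-th percentile to 1 (values outside are not clipped)."""
    x = np.asarray(x, dtype=np.float32)
    lo = np.percentile(x, pmin, axis=axis, keepdims=True)
    hi = np.percentile(x, pmax, axis=axis, keepdims=True)
    return (x - lo) / (hi - lo + eps)

# 1D example: values 0..100
x = np.arange(101, dtype=np.float32)
y = percentile_normalize(x, 1, 99.8)
print(y[1], y[0] < 0, y[100] > 1)   # ~0, values outside the percentiles leave [0, 1]

# (Z, Y, X, C) stack whose second channel is 10x the first:
ch0 = np.arange(64, dtype=np.float32).reshape(4, 4, 4)
stack = np.stack([ch0, 10 * ch0], axis=-1)
z = percentile_normalize(stack, 1, 99.8, axis=(0, 1, 2))
# per-channel percentiles cancel the factor 10, so both channels match
print(np.allclose(z[..., 0], z[..., 1]))
```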

Calling `model.predict_instances` will
- predict object probabilities and star-convex polyhedron distances (see `model.predict` if you want those)
- perform non-maximum suppression (with overlap threshold `nms_thresh`) for polyhedra above the object probability threshold `prob_thresh`
- render all remaining polyhedron instances in a label image
- return the label image together with the details (coordinates, etc.) of all remaining polyhedra
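The thresholding, suppression and rendering steps listed above can be illustrated with a simplified sketch. StarDist computes overlaps between star-convex polyhedra; here, as an assumption for illustration only, each candidate is an axis-aligned 3D box and overlap is measured as intersection over union (IoU). Candidates are visited in order of decreasing probability and kept only if they overlap no already kept candidate by more than `nms_thresh` (greedy NMS). All function names are hypothetical:

```python
import numpy as np

def box_iou(a, b):
    """IoU of two axis-aligned 3D boxes (z0, y0, x0, z1, y1, x1), end exclusive."""
    lo = np.maximum(a[:3], b[:3])
    hi = np.minimum(a[3:], b[3:])
    inter = np.prod(np.clip(hi - lo, 0, None))
    vol_a = np.prod(a[3:] - a[:3])
    vol_b = np.prod(b[3:] - b[:3])
    return inter / (vol_a + vol_b - inter)

def threshold_nms_render(boxes, probs, shape, prob_thresh=0.5, nms_thresh=0.3):
    """Keep candidates with prob > prob_thresh, suppress overlaps greedily,
    and paint the survivors into a label image (0 = background)."""
    # candidates above the probability threshold, highest probability first
    candidates = [i for i in np.argsort(-probs) if probs[i] > prob_thresh]
    kept = []
    for i in candidates:
        if all(box_iou(boxes[i], boxes[j]) <= nms_thresh for j in kept):
            kept.append(i)
    labels = np.zeros(shape, dtype=np.int32)
    # later (lower-probability) boxes overwrite earlier ones where they overlap
    for label, i in enumerate(kept, start=1):
        z0, y0, x0, z1, y1, x1 = boxes[i]
        labels[z0:z1, y0:y1, x0:x1] = label
    return labels, kept

boxes = np.array([[0, 0, 0, 4, 4, 4],    # high probability
                  [0, 0, 1, 4, 4, 5],    # IoU 0.6 with box 0 -> suppressed
                  [5, 5, 5, 8, 8, 8],    # no overlap -> kept
                  [0, 5, 0, 2, 7, 2]])   # below prob_thresh -> discarded
probs = np.array([0.9, 0.8, 0.7, 0.2])
labels, kept = threshold_nms_render(boxes, probs, (8, 8, 8))
print(kept, np.unique(labels))
```

Raising `nms_thresh` tolerates more overlap (fewer suppressed candidates), while raising `prob_thresh` discards more candidates before suppression even starts.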

In [5]:
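axis_norm = (0,1,2)   # normalize over Z, Y, X (channels, if present, independently)
for w in weights_folders:
    # folder names look like "stardist_{dataset}_{split}", e.g. "stardist_1_c15_10"
    ds_name = w[len("stardist_"):-3]   # e.g. "1_c15" (no longer shadows the `dataset` list)
    split = w[-2:]                      # e.g. "10" (assumes two-character splits)
    print(f"Predicting {ds_name} - {split}")
    gt = gt_dict[ds_name]
    save_path = path_images / f"../Analysis/dataset_splits/{save_dict[ds_name]}/sd/tuned/stardist_{ds_name}_{split}.tif"
    # check before loading the model, so that existing results are not recomputed
    if save_path.is_file():
        print("File already exists, skipping", save_path)
        continue
    # "auto" -> None: predict_instances then uses the thresholds stored with the model
    nms_thresh = model_parameters[gt]["NMS"] if model_parameters[gt]["NMS"] != "auto" else None
    prob_thresh = model_parameters[gt]["prob_thresh"] if model_parameters[gt]["prob_thresh"] != "auto" else None
    model = StarDist3D(None, name=w, basedir=path_weights)
    image = normalize(image_dict[gt], 1, 99.8, axis=axis_norm)
    labels, details = model.predict_instances(image, prob_thresh=prob_thresh, nms_thresh=nms_thresh, verbose=True)
    print("Saving to", save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)   # create the output folder if needed
    imwrite(str(save_path), labels)
    del model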

Predicting 1_c15 - 10
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.47128, nms_thresh=0.3.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 5.2696 s
keeping 465/15648 polyhedra
render polygons...
Saving to C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\RESULTS\SPLITS\IMAGES\..\Analysis\dataset_splits\c1_5\sd\tuned\stardist_1_c15_10.tif
Predicting 2_c1_c4_visual - 10
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.490714, nms_thresh=0.3.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 1.5372 s
keeping 237/6704 polyhedra
render polygons...
Saving to C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\RESULTS\SPLITS\IMAGES\..\Analysis\dataset_splits\c1-4_v\sd\tuned\stardist_2_c1_c4_visual_10.tif
Predicting 3_c1245_visual - 10
Loading network weights from 'weights_best.h5'.
Loadi