**In case of problems or questions, please first check the list of [Frequently Asked Questions (FAQ)](https://stardist.net/docs/faq.html).**

Please shut down all other training/prediction notebooks before running this notebook, as they might otherwise occupy GPU memory.

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
# matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tifffile import imread, imwrite
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import random_label_cmap
from stardist.models import StarDist3D

np.random.seed(6)
lbl_cmap = random_label_cmap()

# Data

We assume that the data has already been downloaded via notebook [1_data.ipynb](1_data.ipynb).  
We now load the test volume `small_isotropic_visual.tif` from the `IMAGES` sub-folder, which has not been used during training.

In [2]:
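import pathlib as pt
# Root folder of the benchmark splits (adjust to your machine)
path_images = pt.Path.home() / "Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SPLITS"
# Trained StarDist weights, one sub-folder per (seed, split) combination
path_weights = path_images / "STARDIST/weights"
# Random seeds of the training runs
seed = [
    "34936339",
    "34936397",
    "34936345"
    ]
# Data splits the models were trained on
splits = [
    "10",
    "20",
    "40",
    "80"
]

path_images = path_images / "IMAGES"
# Single 3D test volume
image = imread(str(path_images / "small_isotropic_visual.tif"))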


# Load trained model

The StarDist models trained (and threshold-optimized) via notebook [2_training.ipynb](2_training.ipynb) are loaded from `path_weights`. Each model lives in its own sub-folder named `stardist_<seed>_<split>`, one per combination of training seed and data split.

In [3]:
weights_folders = [f"stardist_{s}_{per}" for per in splits for s in seed]
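The prediction loop further down recovers the seed and split from each folder name by splitting on `_`. A minimal sketch of that round trip, assuming folder names always have the form `stardist_<seed>_<split>` with no extra underscores; `build_weight_folders` and `parse_weight_folder` are hypothetical helpers, not part of StarDist:

```python
def build_weight_folders(seeds, splits):
    """Return one folder name per (split, seed) pair, with splits in the outer loop."""
    # Assumed naming convention, mirroring the cell above: stardist_<seed>_<split>
    return [f"stardist_{s}_{per}" for per in splits for s in seeds]


def parse_weight_folder(name):
    """Split 'stardist_<seed>_<split>' back into (seed, split) strings."""
    # Unpacking into exactly three parts raises ValueError on extra underscores
    prefix, seed, split = name.split("_")
    if prefix != "stardist":
        raise ValueError(f"unexpected folder name: {name!r}")
    return seed, split


folders = build_weight_folders(["34936339", "34936397"], ["10", "20"])
for name in folders:
    print(name, parse_weight_folder(name))
```

Unpacking into exactly three parts makes a malformed folder name fail loudly instead of silently picking the wrong fields.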

In [4]:
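model_parameters = {
    "visual": {
        # "auto": use the thresholds stored with each model (thresholds.json);
        # set a number (e.g. 0.3 / 0.8) to override them for all models
        # "NMS": 0.3,
        "NMS": "auto",
        # "prob_thresh": 0.8
        "prob_thresh": "auto"
    },
}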
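The string `"auto"` acts as a sentinel: the prediction loop turns it into `None`, and when `predict_instances` receives `None` StarDist falls back to the thresholds loaded with the model from `thresholds.json` (the "Using default values" lines in the log). A minimal sketch of that mapping; `resolve_threshold` and `example_params` are hypothetical names, chosen so the notebook's own `model_parameters` is not overwritten:

```python
def resolve_threshold(value):
    """Map the sentinel "auto" to None; pass explicit thresholds through as float."""
    if value == "auto":
        # None lets the model use its own stored (optimized) threshold
        return None
    return float(value)


# Hypothetical example: automatic NMS threshold, fixed probability threshold
example_params = {"visual": {"NMS": "auto", "prob_thresh": 0.8}}
nms_thresh = resolve_threshold(example_params["visual"]["NMS"])
prob_thresh = resolve_threshold(example_params["visual"]["prob_thresh"])
print(nms_thresh, prob_thresh)
```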

## Prediction

Make sure to normalize the input image beforehand or supply a `normalizer` to the prediction function.
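Percentile-based normalization maps a low intensity percentile to 0 and a high one to 1 with a single affine rescaling. The sketch below is our reading of what `csbdeep.utils.normalize(x, 1, 99.8, axis=...)` does with its default settings (no clipping, a tiny `eps` in the denominator); `percentile_normalize` is a hypothetical NumPy-only stand-in, not the library function:

```python
import numpy as np


def percentile_normalize(x, pmin=1.0, pmax=99.8, axis=None, eps=1e-20):
    """Affinely rescale x so its pmin-th percentile maps to 0 and pmax-th to 1.

    Values outside the percentile range are not clipped.
    """
    x = np.asarray(x, dtype=np.float32)
    # Percentiles computed over `axis`, kept broadcastable against x
    mi = np.percentile(x, pmin, axis=axis, keepdims=True)
    ma = np.percentile(x, pmax, axis=axis, keepdims=True)
    return ((x - mi) / (ma - mi + eps)).astype(np.float32)


rng = np.random.default_rng(0)
volume = rng.normal(100.0, 20.0, size=(8, 16, 16))  # synthetic ZYX stack
normed = percentile_normalize(volume, axis=(0, 1, 2))
print(np.percentile(normed, 1), np.percentile(normed, 99.8))
```

Because percentiles are preserved under an increasing affine map, normalizing an already normalized image a second time is, up to rounding, a no-op.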

Calling `model.predict_instances` will
- predict object probabilities and star-convex polyhedron distances (see `model.predict` if you want those)
- perform non-maximum suppression (with overlap threshold `nms_thresh`) for polyhedra above the object probability threshold `prob_thresh`
- render all remaining polyhedron instances in a label image
- return the label image and the details (coordinates, etc.) of all remaining polyhedra
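The probability filtering and non-maximum suppression steps follow the usual greedy scheme: discard low-probability candidates, then visit the rest from most to least probable and keep each one only if it does not overlap an already kept object too much. The sketch below illustrates this with axis-aligned 3D boxes and intersection-over-union as a stand-in; StarDist itself compares star-convex polyhedra, and its exact overlap measure may differ. `box_iou_3d` and `greedy_nms` are hypothetical helpers:

```python
import numpy as np


def box_iou_3d(a, b):
    """IoU of two non-degenerate axis-aligned boxes given as (z0, y0, x0, z1, y1, x1)."""
    lo = np.maximum(a[:3], b[:3])
    hi = np.minimum(a[3:], b[3:])
    inter = np.prod(np.clip(hi - lo, 0, None))
    vol_a = np.prod(a[3:] - a[:3])
    vol_b = np.prod(b[3:] - b[:3])
    return inter / (vol_a + vol_b - inter)


def greedy_nms(boxes, probs, prob_thresh=0.5, nms_thresh=0.3):
    """Return indices of kept boxes, highest probability first."""
    boxes = np.asarray(boxes, dtype=float)
    probs = np.asarray(probs, dtype=float)
    # 1) drop candidates below the object probability threshold
    candidates = np.flatnonzero(probs > prob_thresh)
    # 2) visit the remaining candidates from most to least probable
    order = candidates[np.argsort(-probs[candidates], kind="stable")]
    kept = []
    for i in order:
        # 3) keep i only if its overlap with every kept box is at most nms_thresh
        if all(box_iou_3d(boxes[i], boxes[j]) <= nms_thresh for j in kept):
            kept.append(int(i))
    return kept


boxes = [
    [0, 0, 0, 10, 10, 10],
    [1, 1, 1, 11, 11, 11],     # overlaps box 0 heavily (IoU ~ 0.57)
    [20, 20, 20, 30, 30, 30],  # isolated
    [0, 0, 0, 5, 5, 5],        # below prob_thresh
]
probs = [0.9, 0.8, 0.7, 0.4]
print(greedy_nms(boxes, probs))
```

A higher `nms_thresh` tolerates more overlap, so more candidates generally survive the suppression step.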

In [5]:
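axis_norm = (0,1,2)   # normalize over Z, Y, X (channels independently, if present)

# Normalize once: the input is the same for every model
image = normalize(image, 1, 99.8, axis=axis_norm)

# "auto" -> None, so each model falls back to its own stored thresholds
NMS = model_parameters["visual"]["NMS"] if model_parameters["visual"]["NMS"] != "auto" else None
prob_thresh = model_parameters["visual"]["prob_thresh"] if model_parameters["visual"]["prob_thresh"] != "auto" else None

# Output folder (same as IMAGES/../Analysis/stardist/default); create it if missing
save_dir = path_images.parent / "Analysis/stardist/default"
save_dir.mkdir(parents=True, exist_ok=True)

for w in weights_folders:
    # Folder names follow stardist_<seed>_<split>; avoid shadowing the `seed` list
    _, seed_id, split_id = w.split("_")
    print(f"Predicting {seed_id} - {split_id}")
    model = StarDist3D(None, name=w, basedir=path_weights)
    labels, details = model.predict_instances(image, prob_thresh=prob_thresh, nms_thresh=NMS, verbose=True)
    imwrite(str(save_dir / f"stardist_{split_id}_{seed_id}.tif"), labels)
    del model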

Predicting 34936339 - 10
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.519435, nms_thresh=0.3.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 5.0266 s
keeping 572/8650 polyhedra
render polygons...
Predicting 34936397 - 10
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.528682, nms_thresh=0.5.
predicting instances with nms_thresh = 0.5
non-maximum suppression...
NMS took 6.6283 s
keeping 502/8341 polyhedra
render polygons...
Predicting 34936345 - 10
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.53442, nms_thresh=0.3.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 4.4441 s
keeping 542/8553 polyhedra
render polygons...
Predicting 34936339 - 20
Loading network weights from 'weights_best.h5'