**In case of problems or questions, please first check the list of [Frequently Asked Questions (FAQ)](https://stardist.net/docs/faq.html).**

Please shut down all other training/prediction notebooks before running this notebook (otherwise they may still occupy GPU memory).

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
# matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tifffile import imread, imwrite
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import random_label_cmap
from stardist.models import StarDist3D

np.random.seed(6)
lbl_cmap = random_label_cmap()

# Data

We assume that the data has already been downloaded via notebook [1_data.ipynb](1_data.ipynb).  
We now load the test images, which were not used during training, from the `IMAGES` sub-folder.

In [2]:
import pathlib as pt
path_images = pt.Path.home() / "Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SPLITS"
path_weights = path_images / "STARDIST/weights"
# One weights folder per training dataset
weights_folders = [
    path_weights / "c1_5",
    path_weights / "c1-4_v",
    path_weights / "c1245_v",
]
# Train/test splits to evaluate (each split is the name of one trained model)
splits_folders = [
    "10-90",
    # "20-80",
    # "60-40",
    # "80-20",
]
path_images = path_images / "IMAGES"
# Test images, in the same order as weights_folders
images = [
    imread(str(path_images / "small_isotropic_visual.tif")),
    imread(str(path_images / "c5image.tif")),
    imread(str(path_images / "c3image.tif")),
]

# Load trained model

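The StarDist models trained (and threshold-optimized) via notebook [2_training.ipynb](2_training.ipynb) are loaded from the folders in `weights_folders`, one model per dataset and split. The thresholds in `model_parameters` override the ones optimized during training; `"auto"` keeps the optimized values.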

In [3]:
# Index of {dataset: {split: None}}; each model is loaded from disk in the prediction loop
models = {w.name: {s: None for s in splits_folders} for w in weights_folders}

In [4]:
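# Per-dataset thresholds for predict_instances; "auto" = use the thresholds optimized during training
model_parameters = {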
    "c1_5": {
        "NMS": 0.3,
        "prob_thresh": 0.8
    },
    "c1-4_v": {
        "NMS": "auto",
        "prob_thresh": "auto"
    },
    "c1245_v": {
        "NMS": 0.5,
        "prob_thresh": 0.7
    }
}
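The `"auto"` entries rely on StarDist's fallback: passing `None` as `prob_thresh` or `nms_thresh` to `predict_instances` makes it use the thresholds stored with the model (`thresholds.json`), which is what the "Using default values" lines in the log below report. A minimal sketch of that lookup, using a hypothetical helper `resolve_thresholds` that is not part of StarDist:

```python
def resolve_thresholds(params):
    """Map a model_parameters entry to (prob_thresh, nms_thresh).

    Hypothetical helper: "auto" becomes None, which lets StarDist fall back
    to the thresholds stored with the trained model.
    """
    def resolve(value):
        return None if value == "auto" else float(value)

    return resolve(params["prob_thresh"]), resolve(params["NMS"])


model_parameters = {
    "c1_5": {"NMS": 0.3, "prob_thresh": 0.8},
    "c1-4_v": {"NMS": "auto", "prob_thresh": "auto"},
    "c1245_v": {"NMS": 0.5, "prob_thresh": 0.7},
}

for dataset, params in model_parameters.items():
    prob_thresh, nms_thresh = resolve_thresholds(params)
    print(f"{dataset}: prob_thresh={prob_thresh}, nms_thresh={nms_thresh}")
```

Mapping to `None` (rather than copying the numbers from `thresholds.json` by hand) keeps the notebook correct if a model is retrained and its thresholds are re-optimized.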

## Prediction

Make sure to normalize the input image beforehand or supply a `normalizer` to the prediction function.
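As a rough sketch of what `normalize(img, 1, 99.8, axis=axis_norm)` does, assuming csbdeep's default `clip=False`: intensities are shifted and scaled so that the 1st percentile maps to about 0 and the 99.8th percentile to about 1, without clipping values outside that range. The NumPy function `percentile_normalize` below is a hypothetical reimplementation for illustration, not csbdeep's code:

```python
import numpy as np


def percentile_normalize(x, pmin=1.0, pmax=99.8, axis=None, eps=1e-20):
    """Percentile-based min/max normalization (illustrative sketch)."""
    x = x.astype(np.float32)
    lo = np.percentile(x, pmin, axis=axis, keepdims=True)
    hi = np.percentile(x, pmax, axis=axis, keepdims=True)
    # lo maps to 0 and hi maps to 1; values outside [lo, hi] are not clipped
    return (x - lo) / (hi - lo + eps)


# Synthetic int32 volume (Z, Y, X), like the int32 images loaded above
rng = np.random.default_rng(0)
vol = rng.integers(0, 4096, size=(8, 32, 32), dtype=np.int32)
vol_n = percentile_normalize(vol, 1, 99.8, axis=(0, 1, 2))
print(np.percentile(vol_n, 1), np.percentile(vol_n, 99.8))
```

With `axis=(0, 1, 2)` the percentiles are taken over the whole Z, Y, X volume, so a single-channel image gets one pair of percentiles.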

Calling `model.predict_instances` will
- predict object probabilities and star-convex polyhedron distances (see `model.predict` if you want those)
- perform non-maximum suppression (with overlap threshold `nms_thresh`) on the polyhedra whose object probability exceeds `prob_thresh`
- render all remaining polyhedra as instances in a label image
- return the label image together with the details (coordinates, etc.) of all remaining polyhedra
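The thresholding and suppression steps above follow the usual greedy scheme: drop candidates below `prob_thresh`, visit the rest by decreasing probability, and keep a candidate only if it does not overlap an already-kept one by more than `nms_thresh`. The sketch below is a simplification: it swaps StarDist's star-convex polyhedra and its own overlap measure for axis-aligned 3D boxes and intersection-over-union, and the helpers `box_iou` and `greedy_nms` are hypothetical:

```python
import numpy as np


def box_iou(a, b):
    """IoU of two axis-aligned 3D boxes given as (z0, y0, x0, z1, y1, x1)."""
    lo = np.maximum(a[:3], b[:3])
    hi = np.minimum(a[3:], b[3:])
    inter = np.prod(np.clip(hi - lo, 0, None))
    vol_a = np.prod(a[3:] - a[:3])
    vol_b = np.prod(b[3:] - b[:3])
    return inter / (vol_a + vol_b - inter)


def greedy_nms(boxes, probs, prob_thresh, nms_thresh):
    """Return indices of kept candidates, highest probability first."""
    # Only candidates above the probability threshold enter the suppression step
    candidates = np.flatnonzero(probs > prob_thresh)
    order = candidates[np.argsort(-probs[candidates], kind="stable")]
    keep = []
    for i in order:
        # Keep i only if it does not overlap too much with any kept box
        if all(box_iou(boxes[i], boxes[j]) <= nms_thresh for j in keep):
            keep.append(int(i))
    return keep


boxes = np.array([
    [0, 0, 0, 10, 10, 10],     # object A
    [0, 0, 1, 10, 10, 11],     # near-duplicate of A
    [20, 20, 20, 30, 30, 30],  # object B
    [40, 40, 40, 50, 50, 50],  # low-probability candidate
], dtype=float)
probs = np.array([0.9, 0.8, 0.7, 0.2])
print(greedy_nms(boxes, probs, prob_thresh=0.5, nms_thresh=0.3))  # → [0, 2]
```

If no candidate exceeds `prob_thresh`, nothing reaches the suppression step and the result is empty; this is most likely what the `keeping 0/0 polyhedra` line for `c1245_v` in the log below reports.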

In [5]:
images[0].dtype

dtype('int32')

In [6]:
axis_norm = (0,1,2)   # normalize over Z, Y, X (i.e. each channel independently)
# Test image for each dataset (same order as `images` above)
dataset_images = {"c1_5": images[0], "c1-4_v": images[1], "c1245_v": images[2]}
for w in weights_folders:
    dataset = w.name
    params = model_parameters[dataset]
    # "auto" -> None: predict_instances then uses the thresholds stored with the model
    prob_thresh = None if params["prob_thresh"] == "auto" else params["prob_thresh"]
    NMS = None if params["NMS"] == "auto" else params["NMS"]
    img = normalize(dataset_images[dataset], 1, 99.8, axis=axis_norm)
    for split in models[dataset]:
        print(f"Predicting {dataset} - {split}")
        model = StarDist3D(None, name=split, basedir=w)
        labels, details = model.predict_instances(img, prob_thresh=prob_thresh, nms_thresh=NMS, verbose=True)
        save_path = path_images / f"../Analysis/{dataset}/sd/stardist_{split.replace('-', '')}_labels.tif"
        save_path.parent.mkdir(parents=True, exist_ok=True)  # create the output folder if needed
        imwrite(str(save_path), labels)
        del model

Predicting c1_5 - 10-90
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.476826, nms_thresh=0.4.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 0.1346 s
keeping 197/274 polyhedra
render polygons...
Predicting c1-4_v - 10-90
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.165246, nms_thresh=0.3.
predicting instances with nms_thresh = 0.3
non-maximum suppression...
NMS took 0.5974 s
keeping 271/1566 polyhedra
render polygons...
Predicting c1245_v - 10-90
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.219767, nms_thresh=0.3.
predicting instances with nms_thresh = 0.5
non-maximum suppression...
NMS took 0.0020 s
keeping 0/0 polyhedra
render polygons...
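To cross-check a saved label image against the log (for example `keeping 197/274 polyhedra` for `c1_5`), the number of instances can be counted from its distinct non-zero label values; the count should roughly match the number of kept polyhedra. A minimal sketch with a hypothetical helper `count_instances`, shown on a synthetic array; in practice one would pass `imread(str(save_path))`:

```python
import numpy as np


def count_instances(labels):
    """Number of distinct non-zero labels (0 is background)."""
    ids = np.unique(labels)
    return int(np.count_nonzero(ids))


# Synthetic 3D label image with two objects (ids 1 and 5)
labels = np.zeros((4, 8, 8), dtype=np.uint16)
labels[0:2, 0:3, 0:3] = 1
labels[2:4, 4:8, 4:8] = 5
print(count_instances(labels))  # → 2
```

An all-zero label image, as expected for `c1245_v` given its log, yields a count of 0.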
