In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import os

import numpy as np
from tifffile import imread, imwrite

from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import random_label_cmap
from stardist.models import StarDist3D

%matplotlib inline
%config InlineBackend.figure_format = 'retina'


# Seed NumPy's global RNG before anything draws from it, so later random
# draws (NOTE(review): presumably including random_label_cmap) are reproducible.
SEED = 6
np.random.seed(SEED)
lbl_cmap = random_label_cmap()

In [None]:
# Root folder holding the benchmark datasets; predictions are written below it.
# Set the CELLSEG_RESULTS_DIR environment variable to point elsewhere instead of
# editing the machine-specific default below.
PATH_RESULTS = Path(
    os.environ.get(
        "CELLSEG_RESULTS_DIR",
        "C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/",
    )
).resolve()

# dataset name -> (StarDist model directory, input image, output folder for labels)
PATHS = {
    "Mouse skull": (
        Path("./MouseSkull"),
        PATH_RESULTS / "Mouse-Skull-Nuclei-CBG/X1.tif",
        PATH_RESULTS / "Mouse-Skull-Nuclei-CBG/stardist",
    ),
    "Platy-ISH": (
        Path("./PlISH"),
        PATH_RESULTS / "Platynereis-ISH-Nuclei-CBG/X01_cropped_downsampled.tif",
        # NOTE(review): results go to "Platynereis-ISH-CBG", not the
        # "Platynereis-ISH-Nuclei-CBG" folder the input comes from — confirm intended.
        PATH_RESULTS / "Platynereis-ISH-CBG/stardist",
    ),
    "Platy-Nuc": (
        Path("./PlNuc"),
        PATH_RESULTS / "Platynereis-Nuclei-CBG/downsmapled_cropped_dataset_hdf5_100_0.tif",
        PATH_RESULTS / "Platynereis-Nuclei-CBG/stardist",
    ),
}

In [None]:
# Predict and save StarDist3D labels for every dataset in PATHS except "Mouse skull".
for key, paths in PATHS.items():
    if key == "Mouse skull":  # done separately below with tiled prediction
        continue
    p_weights, p_images, p_results = (Path(p) for p in paths)
    p_weights = p_weights.resolve()
    p_images = p_images.resolve()
    print(f"Model: {key}")
    print(f"Loading model from {p_weights}")
    print(f"Loading images from {p_images}")
    print(f"Saving results to {p_results}")
    # exist_ok=True replaces the exists()/mkdir() pair and keeps re-runs idempotent.
    p_results.mkdir(parents=True, exist_ok=True)

    X = imread(str(p_images))
    print(f"Loaded shape {X.shape} from {p_images}")

    # A 3-D array is treated as a single-channel volume; otherwise the last
    # axis holds the channels.
    n_channel = 1 if X.ndim == 3 else X[0].shape[-1]
    axis_norm = (0, 1, 2)      # normalize channels independently
    # axis_norm = (0, 1, 2, 3) # normalize channels jointly
    if n_channel > 1:
        # For (Z, Y, X, C) input the channel axis is 3; normalizing over it
        # means the channels are normalized jointly.
        print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))

    model = StarDist3D(None, name='stardist', basedir=str(p_weights / "models"))
    # csbdeep normalization with (presumably percentile) bounds 1 and 99.8.
    img = normalize(X, 1, 99.8, axis=axis_norm)
    labels, details = model.predict_instances(img)
    # NOTE(review): save_tiff_imagej_compatible may convert the label dtype for
    # ImageJ compatibility; confirm the number of labels fits, or write with
    # tifffile.imwrite as is done for "Mouse skull".
    save_tiff_imagej_compatible(
        f"{str(p_results)}/{key}_labels.tif",
        labels, axes='ZYX')

In [None]:
# Load and normalize the "Mouse skull" volume; prediction runs in the next cell
# with tiling.
p_weights, p_images, p_results = PATHS["Mouse skull"]
p_weights = p_weights.resolve()
p_images = p_images.resolve()
print(f"Loading model from {p_weights}")
print(f"Loading images from {p_images}")

X = imread(str(p_images))
print(f"Loaded shape {X.shape} from {p_images}")

# A 3-D array is treated as a single-channel volume; otherwise the last axis
# holds the channels.
n_channel = 1 if X.ndim == 3 else X[0].shape[-1]
axis_norm = (0, 1, 2)  # normalize channels independently
if n_channel > 1:
    # For (Z, Y, X, C) input the channel axis is 3.
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))

# Load the model
model = StarDist3D(None, name='stardist', basedir=str(p_weights / "models"))

# Normalize the image
img = normalize(X, 1, 99.8, axis=axis_norm)

In [None]:
# Tiled prediction: split the normalized volume 1 x 2 x 2 along its axes
# (presumably Z, Y, X). NOTE(review): likely to limit memory use — confirm.
MOUSE_SKULL_N_TILES = (1, 2, 2)
labels, details = model.predict_instances(img, n_tiles=MOUSE_SKULL_N_TILES)


In [None]:
# The prediction loop skips "Mouse skull" before creating its results folder,
# so create it here; otherwise the write fails when the folder does not exist.
p_results.mkdir(parents=True, exist_ok=True)
# Written with plain tifffile.imwrite (no ImageJ axes metadata), unlike the other
# datasets. NOTE(review): presumably to avoid save_tiff_imagej_compatible's dtype
# conversion for large label counts — confirm.
imwrite(str(p_results / "Mouse_skull_labels.tif"), labels)