In [1]:
from stardist.models import StarDist3D
from pathlib import Path
import numpy as np
from tifffile import imread, imwrite
from csbdeep.utils import normalize
from tqdm import tqdm

In [2]:
# Root of the benchmark results; every other path in this notebook is built from it.
DATA_PATH = (Path.home() / "Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/").resolve()
# Cross-validation fold whose StarDist weights and training images are used below.
fold = 3
# Passed as `basedir` to StarDist3D when the model is loaded.
pretrained_path = DATA_PATH / f"weights/fold_{fold}"
assert pretrained_path.is_dir(), f"missing weights directory: {pretrained_path}"
images_path = DATA_PATH / f"TRAINING/fold{fold}/IMAGES"
assert images_path.is_dir(), f"missing training images directory: {images_path}"
# Path.glob yields entries in filesystem order, which is not guaranteed; sort so the
# image order (and `images_paths[0]`, used as the reference image) is reproducible.
images_paths = sorted(images_path.glob("*.tif"))

images_paths

[WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/TRAINING/fold3/IMAGES/c1image.tif'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/TRAINING/fold3/IMAGES/c2image.tif'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/TRAINING/fold3/IMAGES/c4image.tif'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/TRAINING/fold3/IMAGES/c5image.tif'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SUPERVISED_PERF_FIG/TRAINING/fold3/IMAGES/visual.tif')]

In [3]:
# Use the first training image only to infer how many channels the data has.
test_image = imread(images_paths[0])
# A 3-D (Z, Y, X) volume is single-channel; otherwise the channel axis is taken to be the last one.
n_channel = 1 if test_image.ndim == 3 else test_image.shape[-1]
axis_norm = (0, 1, 2)   # normalize channels independently (reduce over spatial axes only)
# axis_norm = (0, 1, 2, 3) # normalize channels jointly
if n_channel > 1:
    # For 3-D data the channel axis is 3 (Z, Y, X, C), so channels are normalized
    # jointly only when axis 3 is part of the reduction (the 2-D variant checks axis 2).
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))

In [4]:
# Output folder for the per-threshold masks of the sweep; exist_ok makes this a no-op
# when the directory is already there.
save_dir = DATA_PATH / "TRAINING" / f"fold{fold}" / "stardist"
save_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Load the trained StarDist3D model saved under `pretrained_path` for this fold.
model = StarDist3D(None, name='stardist', basedir=pretrained_path)

# Shared threshold grid 0.1, 0.2, ..., 0.9. Rounding strips float noise from arange
# (e.g. 0.30000000000000004) so the values format cleanly in the output file names.
NMS_threshold_values = np.arange(0.1, 1, 0.1).round(2)
prob_thresh_values = NMS_threshold_values.copy()

Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.5184, nms_thresh=0.3.


In [6]:
# Grid-search (prob_thresh, nms_thresh) on every training image of this fold, writing
# one label volume per combination; combinations already on disk are skipped.
for image_p in tqdm(images_paths):
    im_name = image_p.stem
    print(im_name)
    # The normalized image does not depend on the threshold pair, so load it at most
    # once per image, and only if some combination still needs to be computed (a
    # re-run of a finished sweep then does no image I/O at all).
    img = None
    for prob_t in prob_thresh_values:
        for nms in NMS_threshold_values:
            save_name = str(save_dir / f"{im_name}_stardist_masks_nms_{nms}_prob_{prob_t}.tif")
            if Path(save_name).is_file():
                continue
            if img is None:
                img = normalize(imread(str(image_p)), 1, 99.8, axis=axis_norm)
            masks, _ = model.predict_instances(img, prob_thresh=prob_t, nms_thresh=nms)
            imwrite(save_name, masks)

  0%|          | 0/5 [00:00<?, ?it/s]

c1image


 20%|██        | 1/5 [08:56<35:46, 536.67s/it]

c2image


 40%|████      | 2/5 [15:58<23:28, 469.39s/it]

c4image


 60%|██████    | 3/5 [24:41<16:27, 493.76s/it]

c5image


 80%|████████  | 4/5 [30:42<07:21, 441.29s/it]

visual


100%|██████████| 5/5 [59:35<00:00, 715.16s/it]


In [17]:
# Held-out evaluation volume for each fold, keyed by fold number.
EVAL_IMAGE_NAMES = {1: "visual.tif", 2: "c5image.tif", 3: "c3image.tif"}
fold_1_eval, fold_2_eval, fold_3_eval = (
    imread(DATA_PATH / "INFERENCE" / f"fold{fold_no}" / file_name)
    for fold_no, file_name in EVAL_IMAGE_NAMES.items()
)

In [18]:
# Per-fold (NMS, probability) thresholds used for the final held-out inference.
# NOTE(review): presumably chosen from the threshold sweep results — confirm.
fold_1_nms = 0.3
fold_1_thresh = 0.8

fold_2_nms = 0.7
fold_2_thresh = 0.6
# Backward-compatible alias for the original misspelled name, which other cells may still read.
flod_2_thresh = fold_2_thresh

fold_3_nms = 0.5
fold_3_thresh = 0.7

In [21]:
# Segment each fold's held-out image with that fold's model and tuned thresholds.
threshs = [fold_1_thresh, flod_2_thresh, fold_3_thresh]
nms = [fold_1_nms, fold_2_nms, fold_3_nms]  # list of per-fold NMS thresholds (not the sweep scalar)
eval_images = [fold_1_eval, fold_2_eval, fold_3_eval]
for fold_no, (image_eval, prob_t, nms_t) in enumerate(zip(eval_images, threshs, nms), start=1):
    # Use the weights trained for THIS fold. The global `model` holds only fold `fold`'s
    # weights, and the fold-1/fold-2 eval images (visual.tif, c5image.tif) appear in the
    # fold-3 TRAINING image listing, so reusing it would leak training data into evaluation.
    # Assumes every fold's weights share the layout of weights/fold_{fold}.
    fold_weights = DATA_PATH / f"weights/fold_{fold_no}"
    assert fold_weights.is_dir(), f"missing weights directory: {fold_weights}"
    fold_model = StarDist3D(None, name='stardist', basedir=fold_weights)

    img = normalize(image_eval, 1, 99.8, axis=axis_norm)
    masks, details = fold_model.predict_instances(img, prob_thresh=prob_t, nms_thresh=nms_t)

    out_path = DATA_PATH / f"INFERENCE/fold{fold_no}/stardist/stardist_fold{fold_no}.tif"
    # The per-fold `stardist` sub-folder is not created anywhere else; make sure it exists.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    imwrite(str(out_path), masks)