# 3DL_NuCount (Inference notebook)

Author: Fabrice Daian

Inspired by the original StarDist 3D example notebook: https://github.com/stardist/stardist/blob/master/examples/3D/3_prediction.ipynb

#### Imports

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division

import warnings
from pathlib import Path

import numpy as np
from tqdm.notebook import tqdm as tqdm
import tifffile
from csbdeep.utils import normalize

from stardist.models import StarDist3D

# Pin NumPy's global random generator so that any NumPy-based randomness used
# later in this session is reproducible from run to run.
np.random.seed(seed=42)

#### Parameters

In [None]:
# ---- model -----------------------------------------------------------------
# The trained model is looked up as the folder `model_name` inside `model_basedir`.
model_name: str = "3dl_nucount"
model_basedir: str = "models"

# ---- input / output --------------------------------------------------------
image_to_predict: str = "./images/22-07-13 Mef2 twi GFP M1BP RNAi fort 40x Wstack4 zoom 0.7.tif"
image_result: str = "./results/result.tif"

# ---- normalization ---------------------------------------------------------
# Axes passed to csbdeep's `normalize`. For a single-channel 3D tile, (0, 1, 2)
# covers every axis, i.e. ONE joint percentile range per tile (there is no
# channel axis here to normalize independently).
axis_norm: tuple = (0, 1, 2)

# ---- tiling ----------------------------------------------------------------
# Side (pixels) of the square XY tiles sent to the model one at a time.
patchsize: int = 784
# Y/X extent covered by the tile grid, origins at 0, patchsize, 2*patchsize, ...
# (with 2048 / 784 the last row/column of tiles is only 480 px wide).
# NOTE(review): an earlier comment said this must be divisible by 32 for the
# network; the tiles actually fed to it are 784 and 480 px, so confirm which
# size constraint really applies.
imsize: int = 2048



#### Model Loading
(See _nucount_training.ipynb_ for the training workflow.)

In [None]:
# Load the already-trained network. With `config=None`, StarDist reads the
# saved configuration and weights from <model_basedir>/<model_name> instead of
# building a new model (per the StarDist API docs).
model = StarDist3D(config=None, name=model_name, basedir=model_basedir)

#### Read one 3D volume

In [None]:
# Load the whole stack into memory.
A = tifffile.imread(image_to_predict)

# The tiling below slices axes 1 and 2 of `A` over an `imsize` x `imsize`
# grid, so it needs a 3D array (presumably (Z, Y, X); the axis order itself
# cannot be checked here). Fail early on any other rank. Warn when the Y/X
# extent differs from `imsize`: a larger image is only partly segmented (the
# rest of the result volume stays -1), a smaller one yields truncated or
# empty tiles.
if A.ndim != 3:
    raise ValueError(
        "Expected a 3D single-channel volume, got an array of shape {}".format(A.shape)
    )
if A.shape[1:] != (imsize, imsize):
    warnings.warn(
        "Image Y/X size {} differs from imsize={}; adjust `imsize` in the "
        "Parameters cell so the tiles cover the whole image.".format(A.shape[1:], imsize)
    )

#### Tiling and normalization

In [None]:
# Tile origins along axis 1 (`l`) and axis 2 (`c`); tiles near the end of the
# image may be narrower than `patchsize` because slicing stops at the array edge.
l = list(range(0, imsize, patchsize))
c = list(range(0, imsize, patchsize))

# Cut every tile across the full first axis, visiting origins in row-major
# order (all `c` for the first `l`, then the next `l`, ...).
X = [A[:, row:row + patchsize, col:col + patchsize] for row in l for col in c]

# Percentile normalization (1st .. 99.8th percentile over `axis_norm`),
# computed separately for each tile.
X = [normalize(tile, 1, 99.8, axis=axis_norm) for tile in tqdm(X)]

#### Inference by tiles

In [None]:

# Segment each normalized tile independently. Only the label volume is kept;
# the second value returned by `predict_instances` is not used here.
# `p` keeps the same order as `X`, which the reconstruction cell relies on.
p = []
for tile in tqdm(X, desc="Processing", leave=False):
    labels, _details = model.predict_instances(tile)
    p.append(labels)

    

#### Volume reconstruction from tiles

In [None]:
# result image initialization
# Stitch the per-tile label volumes `p` back into one volume `R`.
#
# Tiles are revisited on the same `imsize`/`patchsize` grid and in the same
# row-major order as the tiling cell, so `p[k]` belongs to the k-th window.
# To keep instances from different tiles distinct, every foreground label of a
# tile is shifted by `offset`, the largest label id already placed; background
# (0) stays 0. Shifting by the tile maximum (rather than its instance count)
# stays collision-free even if a tile's label ids have gaps.
#
# NOTE: a nucleus cut by a tile border is segmented once per tile it touches
# and therefore receives several labels; this stitching does not merge them.

# Integer label volume; -1 marks voxels that no tile covers.
R = np.full(A.shape, -1, dtype=np.int32)

l = list(range(0, imsize, patchsize))
c = list(range(0, imsize, patchsize))

if len(p) != len(l) * len(c):
    raise ValueError(
        "Expected {} tile predictions, got {}".format(len(l) * len(c), len(p))
    )

k = 0
offset = 0  # largest label id assigned so far, across all previous tiles
for i in l:
    for j in c:
        tile_labels = np.asarray(p[k])
        R[:, i:i + patchsize, j:j + patchsize] = np.where(
            tile_labels > 0, tile_labels + offset, 0
        )
        # `initial=0` keeps an empty tile from raising and leaves offset unchanged.
        offset += int(tile_labels.max(initial=0))
        k += 1



#### Counting

In [None]:
# Count nuclei as the number of distinct positive labels in `R` (0 is
# background, negative values mark voxels no tile covered). Unlike the maximum
# label value, this does not depend on label ids being consecutive.
# NOTE: a nucleus split by a tile border is counted once per tile it touches.
n_nuclei = len(np.unique(R[R > 0]))
print("Count nuclei:", n_nuclei)

#### Save result

In [None]:
# Write the stitched label volume. Create the output folder first; writing
# into a directory that does not exist yet would fail.
Path(image_result).parent.mkdir(parents=True, exist_ok=True)
tifffile.imwrite(image_result, R)