# 3DL_NuCount (Training notebook)

Author: Fabrice Daian

Adapted from the original StarDist 3D example training notebook: https://github.com/stardist/stardist/blob/master/examples/3D/2_training.ipynb

#### Imports

In [18]:
import sys
import numpy as np
import matplotlib
# Disable image interpolation by default. The value must be the string "none":
# assigning Python None to this str-valued rcParam is deprecated (Matplotlib 3.5+)
# and emits a deprecation warning.
matplotlib.rcParams["image.interpolation"] = "none"
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist import Rays_GoldenSpiral
from stardist.models import Config3D, StarDist3D

import tifffile
import os

# Seed numpy's global RNG so that random draws made later in the notebook
# (e.g. the np.random-based augmentation helpers) are reproducible.
np.random.seed(42)
# Random colormap used to display label images.
# NOTE(review): random_label_cmap may itself draw from np.random, so keep this call
# after the seed to get identical colours (and RNG state) on every run -- confirm.
lbl_cmap = random_label_cmap()

ipykernel_launcher.py (4): Support for setting an rcParam that expects a str value to a non-str value is deprecated since 3.5 and support will be removed two minor releases later.


#### Parameters

In [19]:
# --- Input data --------------------------------------------------------------
lbl_path = "./dataset/labels/"   # folder holding the manual-annotation label volumes
img_name = "./dataset/0.tif"     # raw image volume

# --- Intensity normalization -------------------------------------------------
# Percentiles are taken over axes (0, 1, 2), so each channel is normalized independently.
axis_norm = (0, 1, 2)

# --- StarDist training -------------------------------------------------------
train_patch_size, train_batch_size = (48, 96, 96), 32

# --- Model -------------------------------------------------------------------
model_name, model_basedir, epochs = "3dl_nucount", "models", 100





#### Read and prepare Training Dataset

In [20]:
# Read the manual annotations: each .tif/.tiff file in `lbl_path` holds one label
# volume, and all of them are summed into a single label image `y`.
# The file list is sorted so the read order is reproducible (os.listdir order is
# arbitrary). Non-TIFF entries (e.g. .DS_Store) are skipped instead of crashing tifffile.
# NOTE(review): summation only yields valid instance ids if the files use disjoint
# label ids and their objects do not overlap -- confirm for this dataset.
label_files = sorted(
    f for f in os.listdir(lbl_path) if f.lower().endswith((".tif", ".tiff"))
)
assert label_files, "no .tif/.tiff label files found in %s" % lbl_path
y = np.sum([tifffile.imread(os.path.join(lbl_path, f)) for f in label_files], axis=0)

# Read the raw image volume as float32.
x = tifffile.imread(img_name).astype(np.float32)
assert x.shape == y.shape, "image %s and labels %s differ in shape" % (x.shape, y.shape)


def _yx_quadrants(vol):
    """Split a (Z, Y, X) volume into its four Y/X quadrants, keeping the full Z depth.

    The quadrants are returned stacked on a new leading axis, in the order
    top-left, top-right, bottom-left, bottom-right. For the 63x256x256 volume used
    here this gives four 63x128x128 sub-volumes. If Y or X is odd, the last
    row/column is dropped so that all quadrants share the same shape.
    """
    cy, cx = vol.shape[1] // 2, vol.shape[2] // 2
    return np.stack([
        vol[:, :cy, :cx],
        vol[:, :cy, cx:2 * cx],
        vol[:, cy:2 * cy, :cx],
        vol[:, cy:2 * cy, cx:2 * cx],
    ])


# Create 4 sub-volumes (one per Y/X quadrant) from both the image and the labels.
X = _yx_quadrants(x)
Y = _yx_quadrants(y)

# Sanity Check
print(X.shape, Y.shape)


(4, 63, 128, 128) (4, 63, 128, 128)


#### Dataset Normalization

In [None]:
# Stardist Image Normalization procedure

# A (Z, Y, X) volume is single-channel; otherwise the channels are taken from the last axis.
n_channel = X[0].shape[-1] if X[0].ndim != 3 else 1

if n_channel > 1:
    joint = axis_norm is None or 3 in axis_norm
    print("Normalizing image channels %s." % ("jointly" if joint else "independently"))
    sys.stdout.flush()

# Percentile-normalize each image (1st / 99.8th percentile over `axis_norm`),
# then fill holes in each label image.
X = [normalize(img, 1, 99.8, axis=axis_norm) for img in tqdm(X)]
Y = [fill_label_holes(lbl) for lbl in tqdm(Y)]

#### Dataset Split

In [None]:
# Stardist Train/Test split procedure

assert len(X) > 1, "not enough training data"

# Shuffle the volume indices with a dedicated fixed-seed RNG (independent of the
# global np.random state). Hold out ~15% of them -- at least one -- for validation
# and train on the rest.
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, round(len(ind) * 0.15))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]

X_trn = [X[i] for i in ind_train]
Y_trn = [Y[i] for i in ind_train]
X_val = [X[i] for i in ind_val]
Y_val = [Y[i] for i in ind_val]

print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

#### Data Augmentation strategy

In [None]:
# Stardist Data Augmentation procedure


def random_fliprot(img, mask, axis=None):
    """Apply one random axis permutation plus random flips jointly to an image and its mask.

    The axes listed in `axis` (default: every axis of `mask`) are randomly
    permuted among themselves. Each of them is then flipped with probability 0.5.
    `img` may carry extra trailing axes (e.g. channels) beyond `mask.ndim`; those
    axes are left in place. All randomness comes from the global ``np.random`` state.

    Returns the transformed ``(img, mask)`` pair, possibly as views of the inputs.
    """
    axes = tuple(range(mask.ndim)) if axis is None else tuple(axis)

    assert img.ndim >= mask.ndim
    shuffled = tuple(np.random.permutation(axes))

    # Build the transpose order: each listed axis is replaced by its shuffled partner.
    order = list(range(mask.ndim))
    for src, dst in zip(axes, shuffled):
        order[src] = dst
    order = tuple(order)
    trailing = tuple(range(mask.ndim, img.ndim))

    img = img.transpose(order + trailing)
    mask = mask.transpose(order)

    for ax in axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask

def random_intensity_change(img):
    """Randomly rescale and shift image intensities.

    Draws a gain from U(0.6, 2) first and then an offset from U(-0.2, 0.2), both
    from the global ``np.random`` state. Returns ``img * gain + offset``.
    """
    gain = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * gain + offset

def augmenter(x, y):
    """Augment a single (image, ground-truth label) training pair.

    x : input image volume
    y : matching ground-truth label volume

    Returns the augmented ``(x, y)`` pair. The random permutation/flips are
    restricted to axes 1 and 2 (the YX plane), because 3D microscopy acquisitions
    are usually not axially symmetric. The intensity jitter is applied to the
    image only, never to the labels.
    """
    x_aug, y_aug = random_fliprot(x, y, axis=(1, 2))
    return random_intensity_change(x_aug), y_aug

In [None]:
# Show the middle XY slice of the first volume, then of three random augmentations of it.

def plot_img_label(img, lbl, img_title="image (XY slice)", lbl_title="label (XY slice)", z=None):
    """Plot one XY slice of an image volume next to the same slice of its label volume.

    Parameters
    ----------
    img, lbl : arrays indexed along their first axis by slice; ``img[z]`` and ``lbl[z]`` are drawn.
    img_title, lbl_title : str
        Titles of the left (image) and right (label) panels.
    z : int, optional
        Slice index along the first axis; defaults to the middle slice.
    """
    if z is None:
        z = img.shape[0] // 2
    fig, (ax_img, ax_lbl) = plt.subplots(
        1, 2, figsize=(12, 5), gridspec_kw=dict(width_ratios=(1.25, 1))
    )
    # clim=(0, 1) assumes percentile-normalized intensities; values outside are clipped in the display.
    im = ax_img.imshow(img[z], cmap="gray", clim=(0, 1))
    ax_img.set_title(img_title)
    fig.colorbar(im, ax=ax_img)
    ax_lbl.imshow(lbl[z], cmap=lbl_cmap)
    ax_lbl.set_title(lbl_title)
    fig.tight_layout()


img, lbl = X[0], Y[0]
plot_img_label(img, lbl)
for _ in range(3):
    img_aug, lbl_aug = augmenter(img, lbl)
    plot_img_label(img_aug, lbl_aug, img_title="image augmented (XY slice)", lbl_title="label augmented (XY slice)")

#### Stardist 3D model configuration

In [None]:
# Stardist Hyper-parameter settings

# Empirical anisotropy: largest per-axis object extent (as reported by stardist's
# calculate_extents) divided by each axis' extent.
extents = calculate_extents(Y)
anisotropy = tuple(np.max(extents) / extents)
print('empirical anisotropy of labeled objects = %s' % (anisotropy,))

# Number of radial distances (rays) predicted per object
n_rays = 96

# OpenCL data generation is switched off: the leading `False` short-circuits the
# `and`, so gputools_available() is never called. Replace `False` with `True` to
# use OpenCL for the training data generator (requires 'gputools').
use_gpu = False and gputools_available()

# Subsample the prediction grid by 2 along every axis whose anisotropy factor is
# not above 1.5. Keep full resolution on the strongly anisotropic axes.
grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)

# Rays on a Fibonacci (golden-spiral) lattice, adjusted for the measured anisotropy
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

# Assemble the StarDist 3D configuration
conf = Config3D(
    rays=rays,
    grid=grid,
    anisotropy=anisotropy,
    use_gpu=use_gpu,
    n_channel_in=n_channel,
    train_patch_size=train_patch_size,
    train_batch_size=train_batch_size,
)
print(conf)
vars(conf)

#### Stardist3D Model creation

In [None]:
# Build the StarDist 3D model from `conf`; it is stored under <model_basedir>/<model_name>.
model = StarDist3D(config=conf, name=model_name, basedir=model_basedir)

#### Stardist3D Model training

In [None]:
# Train on the training split and validate on the held-out split. Pass the
# `augmenter` (random YX flips/transposes + intensity jitter) explicitly:
# StarDist's default (augmenter=None) trains on un-augmented patches.
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augmenter, epochs=epochs)

#### Threshold optimization for model inference

In [None]:
# Tune the object-probability and NMS thresholds used at inference time on the validation set

model.optimize_thresholds(X_val=X_val, Y_val=Y_val)