# Nuclei segmentation using StarDist

With this example, we will illustrate how StarDist can be easily trained to segment 2D nuclei images. 

Data were kindly provided by Marcelo Nöllmann, from the Centre de Biologie Structurale, Montpellier (France). 

This notebook was directly inspired by the existing notebooks created by Uwe Schmidt (https://github.com/stardist/stardist).

## I - packages installation: 
Installation of the packages required by StarDist: 
- csbdeep is a library of tools dedicated to deep learning and to handling / processing the data
- stardist is the library implementing the network
- augmend is a package for image augmentation

In [None]:
# Install the three packages listed above (csbdeep, stardist and augmend, the latter from its GitHub repository).
!pip install csbdeep
!pip install stardist
!pip install git+https://github.com/stardist/augmend.git

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division

## Libraries for files management and mathematical operations

import numpy as np
import time, os, sys
import numpy as np
from tqdm import tqdm # for displaying a progression bar
from glob import glob # recursive search of files or folder based on a specific pattern
# import wget
from urllib.parse import urlparse

## Libraries for displaying images and plotting graphs 

import matplotlib
matplotlib.rcParams["image.interpolation"] = None
matplotlib.rcParams['figure.dpi'] = 300
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Libraries for image management and manipulation

from tifffile import imread, imwrite # handle tif file format
import skimage.io

## Specific libraries for Deep Learning and image segmentation

from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from augmend import Augmend, FlipRot90, Elastic, Identity, IntensityScaleShift, AdditiveNoise, Scale

from stardist import random_label_cmap, _draw_polygons, export_imagej_rois
from stardist import calculate_extents, gputools_available, Rays_GoldenSpiral
from stardist import fill_label_holes, relabel_image_stardist, random_label_cmap
from stardist.matching import matching, matching_dataset

# Seed NumPy's global random generator so that the random choices made in this notebook are repeatable
np.random.seed(6)
# Colormap used below to display the label images
lbl_cmap = random_label_cmap()
print('importing libraries done')

## I - Data importation

The data are 2D 16-bit images saved in the tif format on a Google Drive. All images are paired:

- one raw image
- one labelled image where each nucleus is instantiated, meaning that each single nucleus is assigned a unique positive pixel value. By default the background is set to 0. 

In [None]:
# Mount the Google Drive under /content/drive; the data paths defined below all start with this folder.
from google.colab import drive
drive.mount('/content/drive')

StarDist exists in 2D and 3D versions. For this example, we will work with 2D images, but the principle is the same when working with 3D images.
The parameters for the two options are described below. Note that for this example, the data have already been split into two data sets: 
- one for the training
- one for testing

In [None]:
# Dimensionality of the data: '3d' selects the 3d settings, any other value selects the 2d ones.
option = '2d'

if option == '3d':
    # The Ellipsis values are placeholders to be replaced by the paths / name of a 3d data set.
    data_training = ...
    data_testing = ...
    dest_path = ...
    model_name = ...
    n_channel = 1
    n_dim = 3
    axis_norm = (0,1,2)          # axes used for the normalization
    axis_augmentation = (1,2)    # axes used for the augmentation
    # StarDist classes matching the 3d option
    from stardist.models import StarDist3D, Config3D
else:
    # Root folder of the 2d example; the training / testing folders are sub-folders of it.
    dest_path = "/content/drive/MyDrive/Deep_learning_formation_MRI/Doc_JB_2022/Notebooks for workshop /Data/Nuclei_segmentation"
    data_training = dest_path + "/training_data"
    data_testing = dest_path + "/testing_data"
    model_name = 'stardist_embryos_2d_2022_03_21'
    n_channel = 1
    n_dim = 2
    axis_norm = (0,1)            # axes used for the normalization
    axis_augmentation = (0,1)    # axes used for the augmentation
    IMG_HEIGHT = 256
    IMG_WIDTH = 256
    # StarDist classes matching the 2d option
    from stardist.models import Config2D, StarDist2D, StarDistData2D

Load the data using the get_data method.

In [None]:
def get_data(file_list):
    """Read all the images listed in `file_list`.

    The paths are sorted before reading, so the returned list follows the
    sorted order of the paths rather than the order in which they were given.

    Parameters
    ----------
    file_list : iterable of paths to the image files.

    Returns
    -------
    list containing the result of `imread` for each path.
    """
    return [imread(path) for path in sorted(file_list)]

In [None]:
# load the raw images for the training
X = get_data(glob(os.path.join(data_training, 'raw', '*.tif')))
print(f'number of training raw images found : {len(X)}')

# load the label images
Y = get_data(glob(os.path.join(data_training, 'label_class', '*.tif')))
print(f'number of training label images found : {len(Y)}')

# load the raw images for the training
X_test = get_data( glob(os.path.join(data_testing, 'raw', '*.tif')))
print(f'number of training raw images found : {len(X_test)}')

# load the label images
Y_test = get_data(glob(os.path.join(data_testing, 'label_class', '*.tif')))
print(f'number of training label images found : {len(Y_test)}')

## II - Data visualization and analysis:

Below, an example of a pair of raw / labelled images is displayed. Since StarDist performs instance segmentation, each single nucleus is assigned a unique positive pixel value. 

In [None]:
# select an image from the loaded data
# Randomly select an image from the loaded data.
# randint(n) draws from [0, n): every image (including the last one) can be picked,
# and a data set containing a single image does not raise an error.
n_im = np.random.randint(len(X))
raw, lbl = X[n_im], Y[n_im]

if option == '3d':
    # plot the xy-MIP (maximum along axis 0) of the raw image and its associated label image
    plt.figure(figsize=(16,10))
    plt.subplot(121); plt.imshow(np.max(raw,axis=0),cmap='gray')
    plt.axis('off'); plt.title('MIP of raw image (XY slice)')
    plt.subplot(122); plt.imshow(np.max(lbl, axis=0),cmap=lbl_cmap)
    plt.axis('off'); plt.title('GT labels (XY slice)')

    # plot the xz-MIP (maximum along axis 1) of the raw image and its associated label image
    plt.figure(figsize=(16,10))
    plt.subplot(121); plt.imshow(np.max(raw,axis=1),cmap='gray')
    plt.axis('off'); plt.title('MIP of raw image (XZ slice)')
    plt.subplot(122); plt.imshow(np.max(lbl, axis=1),cmap=lbl_cmap)
    plt.axis('off'); plt.title('GT labels (XZ slice)')

else:
    # plot a raw image and its associated label
    plt.figure(figsize=(16,10))
    plt.subplot(121); plt.imshow(raw,cmap='gray')
    plt.axis('off'); plt.title('raw image')
    plt.subplot(122); plt.imshow(lbl,cmap=lbl_cmap)
    plt.axis('off'); plt.title('GT labels')

# the trailing semicolon keeps the notebook from echoing the value of the last expression
None;

Normalize the raw data. 

In [None]:
# Percentile-based normalization (1st and 99.8th percentiles), computed over the axes in axis_norm.
norm_kwargs = dict(pmin=1, pmax=99.8, axis=axis_norm)

X = [normalize(img, **norm_kwargs) for img in tqdm(X)]
# Y = [fill_label_holes(y) for y in tqdm(Y)]

X_test = [normalize(img, **norm_kwargs) for img in tqdm(X_test)]
# Y_test = [fill_label_holes(y) for y in tqdm(Y_test)]

Compute the anisotropy of all the labelled objects. This is important for StarDist since the segmentation is based on the reconstruction of star-convex polygons defined by a specific set of rays. For isotropic objects, the number of rays can be low.  

In [None]:
# compute the anisotropy of the label objects
extents = calculate_extents(Y)
anisotropy = tuple(np.max(extents) / extents)
print('Empirical anisotropy of labeled objects = %s' % str(anisotropy))

# compute the min dimensions of the training set
im_dim = np.zeros((len(X), n_dim))
for n,im in enumerate(X):
    im_dim[n,:] = im.shape
print(f'The smallest dimensions of the training set are : {np.min(im_dim, axis=0)}')

# compute the median size of the nuclei
median_size = calculate_extents(Y, np.median)
print(f"Median object size:      {median_size}")

Below is an illustration of the effect of the number of rays on the reconstruction of the nuclei.

In [None]:
# defines the number of rays to test for the reconstruction
# Numbers of rays to test for the reconstruction: 4, 8, 16, 32, 64, 128
n_rays = [2**i for i in range(2,8)]

# Randomly select an image from the loaded data.
# randint(n) draws from [0, n): every image (including the last one) can be picked,
# and a data set containing a single image does not raise an error.
n_im = np.random.randint(len(X))
raw, lbl = X[n_im], Y[n_im]

# Plot the label image reconstructed with each number of rays (one panel per value)
fig, ax = plt.subplots(2,3, figsize=(16,11))
for a,r in zip(ax.flat,n_rays):
    a.imshow(relabel_image_stardist(lbl, n_rays=r), cmap=lbl_cmap)
    a.set_title('Reconstructed (%d rays)' % r)
    a.axis('off')
plt.tight_layout();

## III- Defining the parameters for StarDist

Below, the parameters for the network are defined: the number of rays, as well as the number of epochs, the batch size, the architecture of the network, etc.

In [None]:
# 96 is a good default choice for anisotropic 3D data (see 1_data.ipynb) - 32 is usually fine for anisotropic 2D data
# Number of rays and size of the training patches for the selected option
if option == '3d':
    n_rays = 96
    train_patch_size = (28,128,128)
else:
    n_rays = 32
    train_patch_size = (128,128)

# Use OpenCL-based computations for data generator during training (requires 'gputools').
# Replace True by False to disable it even when gputools is available.
use_gpu = True and gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
grid = tuple(1 if a > 1.5 else 4 for a in anisotropy)

# Settings shared by the 2d and 3d configurations
common_conf = dict(
    grid             = grid,
    use_gpu          = use_gpu,
    # number of channels declared with the other options, instead of a second hard-coded value
    n_channel_in     = n_channel,
    # adjust for your data below (make patch size as large as possible)
    train_patch_size = train_patch_size,
    train_batch_size = 2,
)

if option == '3d':
    # Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data
    rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)
    conf = Config3D (
        rays                  = rays,
        anisotropy            = anisotropy,
        backbone              = 'unet',
        # NOTE(review): a single epoch is requested for 3d - presumably a placeholder, confirm before a real training
        train_epochs          = 1,
        train_steps_per_epoch = 100,
        **common_conf
    )
else:
    conf = Config2D (
        n_rays                = n_rays,
        backbone              = 'unet', #'resnet'
        train_epochs          = 50,
        train_steps_per_epoch = 50,
        **common_conf
    )
# display all the attributes of the configuration
vars(conf)

Compile the model according to the parameters indicated above :

In [None]:
# create the folder if it does not exist yet
if not os.path.exists(dest_path):
    os.mkdir(dest_path)
    print(f"folder {dest_path} created!")

# based on the parameters defined above, create the StarDist model
model_folder = os.path.join(dest_path, model_name)
if option == '3d':
    model = StarDist3D(conf, name=model_name, basedir=model_folder)
else:
    model = StarDist2D(conf, name=model_name, basedir=model_folder)

Compute the median size of the nuclei and check whether the neural network has a field of view large enough to see up to the boundary of most objects.


In [None]:
# Median size of the labelled objects along each axis
median_size = calculate_extents(Y, np.median)

# Axes of the images for the selected option; the value returned by the model for these axes
# is used here as the field of view of the network.
axes = 'ZYX' if option == '3d' else 'YX'
fov = np.array(model._axes_tile_overlap(axes))

print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")

# Warn as soon as the median size exceeds the field of view along at least one axis
too_large = median_size > fov
if np.any(too_large):
    print("WARNING: median object size larger than field of view of the neural network.")

Add data augmentation.

In [None]:
# Parameters shared by the two elastic transforms (one for the raw image, one for the labels)
elastic_kwargs = dict(axis=axis_augmentation, amount=10, use_gpu=model.config.use_gpu)
#scale_kwargs = dict(axis=1, amount=1.5, use_gpu=model.config.use_gpu)

# Each call to add() receives a pair of transforms: the first one for the raw image,
# the second one for the label image (presumably applied with the same random draw - see the augmend documentation).
aug = Augmend()
#aug.add([Scale(order=0,**scale_kwargs),Scale(order=0,**scale_kwargs)], probability=0.25)

# flips / 90-degree rotations of both images
flip_rot_pair = [FlipRot90(axis=axis_augmentation) for _ in range(2)]
aug.add(flip_rot_pair)

# elastic deformation of both images, with probability 0.5
elastic_pair = [Elastic(order=0, **elastic_kwargs) for _ in range(2)]
aug.add(elastic_pair, probability=0.5)

# intensity changes only concern the raw image: the labels are left untouched (Identity)
aug.add([IntensityScaleShift(scale=(.6,2), shift=(-.2,.2)), Identity()])
aug.add([AdditiveNoise(sigma=(0.05,0.05)), Identity()], probability=0.25)

def augmenter(x,y):
    """Apply the augmentation pipeline `aug` to the pair (x, y) and return its result.

    x : raw image
    y : label image associated to x
    """
    pair = [x, y]
    return aug(pair)

Launch the training : 

In [None]:
# Train on the training pairs (X, Y), using the testing pairs as validation data and `augmenter` for the augmentation.
model.train(X, Y, validation_data=(X_test,Y_test), augmenter=augmenter)

## IV- Reconstruction and visualization of the results

Finalize the reconstruction and display the results.

In [None]:
# Optimize the thresholds of the trained model using the testing images and their labels.
model.optimize_thresholds(X_test, Y_test)

In [None]:
# using the newly trained model, calculate the segmentation prediction using the test images
Y_test_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
              for x in tqdm(X_test)]

In [None]:
def plot_img_label(img, lbl, img_title="image (XY slice)", lbl_title="label (XY slice)", **kwargs):
    """Display an image (left, gray levels with a colorbar) next to a label image (right).

    When `img` has three dimensions, both `img` and `lbl` are replaced by their
    maximum projection along axis 0 before being displayed.

    img       : image displayed in gray levels, with the color limits fixed to (0,1)
    lbl       : label image displayed with the colormap `lbl_cmap`
    img_title : title of the left panel
    lbl_title : title of the right panel
    kwargs    : accepted but not used
    """
    # the choice between projection and direct display only depends on the raw image
    is_stack = len(img.shape) == 3
    img_2d = np.max(img, axis=0) if is_stack else img
    lbl_2d = np.max(lbl, axis=0) if is_stack else lbl

    fig, (ai,al) = plt.subplots(1,2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))

    # left panel: raw image and its colorbar
    im = ai.imshow(img_2d, cmap='gray', clim=(0,1))
    ai.set_title(img_title)
    fig.colorbar(im, ax=ai)

    # right panel: labels
    al.imshow(lbl_2d, cmap=lbl_cmap)
    al.set_title(lbl_title)

    plt.tight_layout()

# Randomly select a test image.
# randint(n) draws from [0, n): every image (including the last one) can be picked,
# and a test set containing a single image does not raise an error.
n_image = np.random.randint(len(X_test))
# display the ground truth labels, then the predicted labels, next to the same raw image
plot_img_label(X_test[n_image],Y_test[n_image], lbl_title="label GT (XY slice)")
plot_img_label(X_test[n_image],Y_test_pred[n_image], lbl_title="label Pred (XY slice)")

In [None]:
# IoU thresholds at which the predictions are compared to the ground truth
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [matching_dataset(Y_test, Y_test_pred, thresh=t, show_progress=False) for t in tqdm(taus)]

# one dictionary of metrics per threshold, to look the metrics up by name
stats_by_tau = [s._asdict() for s in stats]

fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

# left panel: scores; right panel: counts
panels = (
    (ax1, ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'), 'Metric value'),
    (ax2, ('fp', 'tp', 'fn'), 'Number #'),
)
for panel, metric_names, y_label in panels:
    for m in metric_names:
        values = [d[m] for d in stats_by_tau]
        panel.plot(taus, values, '.-', lw=2, label=m)
    panel.set_xlabel(r'IoU threshold $\tau$')
    panel.set_ylabel(y_label)
    panel.grid()
    panel.legend()