# StarDist training for napari-mAIcrobe in Google Colab
---

This notebook is heavily inspired by the 2D demo notebook provided by the StarDist authors: 

https://github.com/stardist/stardist/blob/main/examples/2D/2_training.ipynb

This notebook is adapted to the use case of napari-mAIcrobe, and specifically to the segmentation of *S. aureus* cells in the membrane dye (Nile Red) channel.

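The patch calculation assumes images are 2430x2430 pixels: each image is center-cropped to 2304x2304 and split into 81 non-overlapping tiles of 256x256. You might need to adapt the code if your images have a different size.
By default, the notebook performs data augmentation using random rotations/flips, random intensity scaling and offsets, and additive Gaussian noise. You might want to adapt this to your use case.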


In [None]:
# Install the dependencies of this notebook (needed on Colab; versions are not pinned).
!pip install csbdeep # CSBDeep utilities: normalize and Path are imported from csbdeep.utils below
!pip install tifffile # TIFF reader: tifffile.imread is used to load the training data
!pip install stardist # StarDist: StarDist2D, Config2D and the label utilities imported below
!pip install gputools # optional OpenCL acceleration; use_gpu is hard-coded to False further down, so it is unused by default

In [None]:
import time
import sys
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.models import Config2D, StarDist2D

from skimage.util.shape import view_as_blocks

from itertools import product, chain

# Seed NumPy's global RNG: the augmentation helpers below draw from np.random,
# so seeding here makes a run reproducible.
np.random.seed(12)

# Random colormap used to display label images.
lbl_cmap = random_label_cmap()

# When running inside Google Colab, enable Colab's custom widget manager and
# mount Google Drive so training data/models can be read from and written to it.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive, output

    output.enable_custom_widget_manager()
    drive.mount('/content/drive')

def crop_and_block(img, block_size=256):
    """Center-crop an image and split it into non-overlapping square tiles.

    The image is cropped symmetrically to the largest extent that is an exact
    multiple of ``block_size`` along Y and X. It is then cut into tiles in
    row-major order. For a 2430x2430 image and the default ``block_size`` this
    removes 63 pixels on each side and returns 9 x 9 = 81 tiles of 256x256.

    Parameters
    ----------
    img : array_like
        Image whose first two axes are Y and X. Any trailing axes (e.g. a
        channel axis) are kept whole in every tile.
    block_size : int, optional
        Side length of the square tiles (default 256).

    Returns
    -------
    list of ndarray
        Tiles of shape ``(block_size, block_size) + img.shape[2:]``, ordered
        row by row. They are views into ``img`` when ``img`` is an ndarray.

    Raises
    ------
    ValueError
        If ``block_size`` is not positive, ``img`` has fewer than two axes, or
        the image is smaller than a single tile.
    """
    img = np.asarray(img)
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if img.ndim < 2:
        raise ValueError(f"expected an image with at least 2 axes, got shape {img.shape}")

    ny, nx = img.shape[:2]
    n_rows, n_cols = ny // block_size, nx // block_size
    if n_rows == 0 or n_cols == 0:
        raise ValueError(
            f"image of shape {img.shape} is smaller than one {block_size}x{block_size} tile"
        )

    # Split the leftover pixels evenly between opposite borders
    # (e.g. 2430 - 9*256 = 126 -> 63 pixels removed on each side).
    y0 = (ny - n_rows * block_size) // 2
    x0 = (nx - n_cols * block_size) // 2

    return [
        img[y0 + i * block_size:y0 + (i + 1) * block_size,
            x0 + j * block_size:x0 + (j + 1) * block_size]
        for i, j in product(range(n_rows), range(n_cols))
    ]

def plot_img_label(img, lbl, img_title="image", lbl_title="label", **kwargs):
    """Plot an image next to its label mask in a single two-panel figure.

    The left panel shows ``img`` in grayscale with a fixed 0..1 display range
    (so ``img`` is expected to be normalized) plus a colorbar. The right panel
    shows ``lbl`` with the notebook's random label colormap ``lbl_cmap``.
    Extra keyword arguments are accepted but ignored. Returns None.
    """
    fig, panels = plt.subplots(1, 2, figsize=(12, 5), gridspec_kw={"width_ratios": (1.25, 1)})
    image_ax, label_ax = panels

    shown = image_ax.imshow(img, cmap='gray', clim=(0, 1))
    image_ax.set_title(img_title)
    fig.colorbar(shown, ax=image_ax)

    label_ax.imshow(lbl, cmap=lbl_cmap)
    label_ax.set_title(lbl_title)

    plt.tight_layout()

def random_fliprot(img, mask):
    """Apply the same random axis permutation and flips to an image and its mask.

    The spatial axes are the first ``mask.ndim`` axes of both arrays. Any
    extra trailing axes of ``img`` (e.g. channels) stay in place. Random draws
    come from NumPy's global RNG: one permutation, then one uniform draw per
    spatial axis deciding whether that axis is flipped.
    """
    assert img.ndim >= mask.ndim
    spatial = tuple(range(mask.ndim))
    order = tuple(np.random.permutation(spatial))
    trailing = tuple(range(mask.ndim, img.ndim))

    img = np.transpose(img, order + trailing)
    mask = np.transpose(mask, order)

    for axis in spatial:
        if not np.random.rand() > 0.5:
            continue
        img, mask = np.flip(img, axis=axis), np.flip(mask, axis=axis)
    return img, mask

def random_intensity_change(img):
    """Randomly rescale and shift image intensities.

    Draws a gain from U(0.6, 2) and then an offset from U(-0.2, 0.2), both from
    NumPy's global RNG and in that order, and returns ``img * gain + offset``.
    """
    gain = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * gain + offset


def augmenter(x, y):
    """Randomly augment one (image, label) training pair.

    x is the input image and y its ground-truth label image. Both receive the
    same random axis permutation and flips. Only x then gets a random
    intensity gain/offset and additive Gaussian noise whose standard deviation
    is drawn uniformly from [0, 0.02). Returns the augmented (x, y).
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)

    # Draw the noise level first, then the noise itself (same RNG order as before).
    noise_sigma = np.random.uniform(0, 1) * 0.02
    noise = np.random.normal(0, 1, x.shape)
    return x + noise_sigma * noise, y

# User input

In [None]:
#@title ## User input

#@markdown ###Load data:
# Folders containing the input images and the matching label masks (*.tif).
# Files are paired by sorted order and must have identical file names.
Training_source = '' #@param {type:"string"}
Training_target = '' #@param {type:"string"}

#@markdown ###Model parameters:
# Number of rays passed to Config2D; 32 is a good default choice
# (see the StarDist example notebook examples/2D/1_data.ipynb).
n_rays = 32 #@param {type:"integer"}

# Prediction-grid subsampling factors along x and y, passed to Config2D(grid=...).
gridx = 2 #@param {type:"integer"}
gridy = 2 #@param {type:"integer"}

# Training batch size and square training patch size, passed to Config2D.
# crop_and_block produces 256x256 tiles, so patch_size should not exceed 256.
batch_size = 4 #@param {type:"integer"}
patch_size = 256 #@param {type:"integer"}

# Learning rate passed to Config2D(train_learning_rate=...).
learning_rate = 0.0003 #@param {type:"number"}

# Fraction of the tiles held out for validation (at least one tile is always held out).
valsplit = 0.15 #@param {type:"number"}

# Name and base directory of the StarDist2D model being trained, and the path
# of the weights passed to model.load_weights before training.
modelname = '' #@param {type:"string"}
modelpath = '' #@param {type:"string"}
pretrainedmodelpath = '' #@param {type:"string"}

#@markdown ###Training parameters:
# Number of epochs and steps per epoch passed to model.train.
n_epochs = 300 #@param {type:"integer"}
steps_per_epoch = 400 #@param {type:"integer"}


# Load data

In [None]:
# Collect image/label file pairs; both folders must contain identically named .tif files.
X_ = sorted(glob(Training_source + '/*.tif'))
Y_ = sorted(glob(Training_target + '/*.tif'))
# Fail early with clear messages. An empty folder would otherwise surface as an
# IndexError further down, and zip() silently ignores unpaired extra files.
assert len(X_) > 0, f"No .tif images found in {Training_source!r}"
assert len(X_) == len(Y_), f"Found {len(X_)} images but {len(Y_)} label masks"
assert all(Path(x).name==Path(y).name for x,y in zip(X_,Y_)), "Image and label file names do not match"

X_img = list(map(imread,X_))
Y_img = list(map(imread,Y_))

print(len(X_img),X_img[0].shape)
print(len(Y_img),Y_img[0].shape)

# Center-crop every image/mask and split it into 256x256 tiles, then flatten
# the per-image tile lists into one list of training samples.
X_split = list(map(crop_and_block,X_img))
X = list(chain(*X_split))

Y_split = list(map(crop_and_block,Y_img))
Y = list(chain(*Y_split))

print(len(X),X[0].shape,X[-1].shape)
print(len(Y),Y[0].shape,Y[-1].shape)

# Percentile-normalize each tile over its spatial axes and fill holes in the label masks.
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
X = [normalize(x,1,99.8,axis=(0,1)) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]

# Reproducible random train/validation split; at least one tile goes to validation.
rng = np.random.RandomState(12)
ind = rng.permutation(len(X))
n_val = max(1, int(round(valsplit * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

# Show one example tile with its labels (RGB shown as-is, otherwise the first channel).
i = min(12, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
img = img if (img.ndim==2 or img.shape[-1]==3) else img[...,0]
plot_img_label(img,lbl)
None;

# Load model

In [None]:
# Use OpenCL-based computations for data generator during training (requires 'gputools').
# Disabled by default: replace False with True to enable it when gputools is available.
use_gpu = False and gputools_available()


conf = Config2D (
    n_rays       = n_rays,
    # StarDist expects the grid subsampling factors in axis order (Y, X).
    grid         = (gridy, gridx),
    # Use the flag computed above rather than a hard-coded value so the two stay in sync.
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
    train_batch_size = batch_size,
    train_patch_size = (patch_size, patch_size),
    train_learning_rate = learning_rate)

if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    limit_gpu_memory(None, allow_growth=True)

model = StarDist2D(conf, name=modelname, basedir=modelpath)
# Start from pretrained weights only when a path was given. With an empty path
# (the form default) the model is trained from scratch instead of calling
# load_weights with an empty name.
if pretrainedmodelpath:
    model.load_weights(pretrainedmodelpath)

# Warn if the typical object is larger than what the network can "see".
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

# Start training

In [None]:
start = time.time()
history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter, epochs=n_epochs, steps_per_epoch=steps_per_epoch)

# Display the time elapsed for training.
dt = time.time() - start
hours, remainder = divmod(dt, 3600)
minutes, seconds = divmod(remainder, 60)
print(f"Time elapsed for training: {int(hours)} h {int(minutes)} min {seconds:.1f} s")

# Tune the model's detection thresholds on the validation tiles.
model.optimize_thresholds(X_val, Y_val)