## 02 Stardist Training

The StarDist model is used to create instance masks for a subsequent watershed. The main drawback of StarDist is that it does not produce accurate cell borders: each object is approximated by a star-convex polygon with a limited number of rays, and more rays mean higher computational cost. To keep that cost moderate, this model 'only' uses 32 rays (angles). For details on how to install StarDist, please check out their [GitHub](https://github.com/mpicbg-csbd/stardist).
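In this pipeline only the object positions predicted by StarDist are needed to seed the watershed. The NumPy-only sketch below illustrates that hand-off under this assumption and is not this project's actual code. It reduces each labeled object to a single seed pixel near its centroid. `centroid_markers` is a hypothetical helper. The resulting marker image could then be passed as `markers` to a watershed such as `skimage.segmentation.watershed`.

```python
import numpy as np

def centroid_markers(labels):
    """Hypothetical helper: one seed pixel per labeled object (0 = background).

    Each object is reduced to its own pixel closest to the object's centroid,
    keeping the original label id, so the seed lies inside the object even
    for non-convex shapes.
    """
    labels = np.asarray(labels)
    markers = np.zeros_like(labels)
    ids = np.unique(labels)
    ids = ids[ids != 0]  # skip background
    for i in ids:
        rr, cc = np.nonzero(labels == i)
        r, c = rr.mean(), cc.mean()
        k = np.argmin((rr - r) ** 2 + (cc - c) ** 2)
        markers[rr[k], cc[k]] = i
    return markers

# Usage: two toy objects
labels = np.zeros((6, 6), dtype=np.int32)
labels[0:3, 0:3] = 1
labels[3:6, 2:6] = 2
markers = centroid_markers(labels)
```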

```python
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import tqdm
import datetime
import skimage.io

import utils.dirtools
import utils.data_provider
import utils.augmentation
import stardist
from stardist.models import Config2D

# Random colormap for displaying label images
lbl_cmap = stardist.random_label_cmap()
```

### Image import

Because StarDist doesn't support reading images at runtime, we have to import them here. `utils.data_provider.stardist_importer` imports the images and normalizes them for the network.
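`stardist_importer` belongs to this project and is not shown here. The StarDist example notebooks normalize each image to its 1st and 99.8th intensity percentiles (via `csbdeep.utils.normalize`). The NumPy-only sketch below mimics that kind of importer. The names `percentile_normalize` and `import_pairs` are hypothetical, and the percentile values are an assumption, not necessarily what `stardist_importer` uses.

```python
import numpy as np

def percentile_normalize(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Map the pmin/pmax intensity percentiles to 0/1 (no clipping)."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.percentile(img, [pmin, pmax])
    return ((img - lo) / (hi - lo + eps)).astype(np.float32)

def import_pairs(x_paths, y_paths, read=None):
    """Hypothetical importer: load image/mask pairs into memory.

    Images are percentile-normalized; masks are kept as integer label images.
    `read` defaults to skimage.io.imread.
    """
    if read is None:
        from skimage.io import imread as read
    xs = [percentile_normalize(read(p)) for p in x_paths]
    ys = [np.asarray(read(p)).astype(np.int32) for p in y_paths]
    return xs, ys
```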

```python
# Root directory containing the images/ and masks/ subfolders
root = '.data/train_val'
```

```python
# Import paths (sorted so that images and masks pair up by filename order)
X = sorted(glob.glob(f'{root}/images/*.tif'))
Y = sorted(glob.glob(f'{root}/masks/*.tif'))
```

```python
# Train / valid split
x_train, x_valid, y_train, y_valid = utils.dirtools.train_valid_split(x_list=X, y_list=Y, valid_split=0.2)

# Import images (StarDist doesn't allow for runtime reading)
x_train, y_train = utils.data_provider.stardist_importer(x_train, y_train)
x_valid, y_valid = utils.data_provider.stardist_importer(x_valid, y_valid)
```
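`utils.dirtools.train_valid_split` is project-specific. The sketch below shows one plausible way to implement a paired split. It assumes the helper shuffles image/mask pairs together and holds out a `valid_split` fraction. `train_valid_split_sketch` and its `seed` parameter are hypothetical. The return order matches the call above.

```python
import numpy as np

def train_valid_split_sketch(x_list, y_list, valid_split=0.2, seed=42):
    """Hypothetical paired split: shuffle image/mask pairs, hold out a fraction."""
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same length")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x_list))
    n_valid = max(1, int(round(valid_split * len(x_list))))
    valid_idx, train_idx = idx[:n_valid], idx[n_valid:]
    # The same indices are used for images and masks, so pairs stay aligned
    x_train = [x_list[i] for i in train_idx]
    y_train = [y_list[i] for i in train_idx]
    x_valid = [x_list[i] for i in valid_idx]
    y_valid = [y_list[i] for i in valid_idx]
    return x_train, x_valid, y_train, y_valid
```

Because pairs are matched by index, this only works if `X` and `Y` are sorted consistently, which is why both path lists are sorted.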

```python
# Sanity check: show a random image and its ground-truth mask
ix = np.random.randint(len(X))  # upper bound is exclusive, so all images can be drawn

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].imshow(skimage.io.imread(X[ix]))
ax[0].set_title(f'Original Image – #{ix}')
ax[1].imshow(skimage.io.imread(Y[ix]), cmap=lbl_cmap)
ax[1].set_title('Ground Truth')
plt.show()
```

### Configure model

The hyperparameters and augmentations can be changed below. We chose 32 rays following StarDist's recommendations, which can be seen [here](https://nbviewer.jupyter.org/github/mpicbg-csbd/stardist/blob/master/examples/2D/1_data.ipynb). If the StarDist model passes further tests, the number of rays could probably be reduced to 16 or 8, because only the centroid location of each object is actually used.
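The sketch below illustrates how the number of rays limits border accuracy. It builds a star-convex polygon from ray lengths at equally spaced angles, as StarDist does, and measures its area with the shoelace formula. It uses an idealized circular object, so the numbers only indicate the trend for real cells. The function names are illustrative and not part of StarDist.

```python
import numpy as np

def star_polygon(dists, center=(0.0, 0.0)):
    """Vertices (y, x) of a star-convex polygon from ray lengths at equally spaced angles."""
    dists = np.asarray(dists, dtype=float)
    n = len(dists)
    phi = 2 * np.pi * np.arange(n) / n
    y = center[0] + dists * np.sin(phi)
    x = center[1] + dists * np.cos(phi)
    return y, x

def polygon_area(y, x):
    """Shoelace formula for a closed polygon."""
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))

# Fraction of a circle (radius r) covered by polygons with different ray counts
r = 10.0
for n_rays in (8, 16, 32):
    y, x = star_polygon(np.full(n_rays, r))
    ratio = polygon_area(y, x) / (np.pi * r ** 2)
    print(f"{n_rays:2d} rays: polygon covers {ratio:.1%} of the circle")
```

For a circle, the covered fraction rises from about 90% (8 rays) to about 97% (16 rays) and 99% (32 rays). If only centroids are used downstream, this loss of outline detail matters little.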

```python
print(Config2D.__doc__)
```

```python
# Hyperparameters
conf = Config2D(
    train_epochs=4,             # 400 for a full training run
    train_steps_per_epoch=10,   # 100 for a full training run
    n_rays=32,
    grid=(2, 2),
    use_gpu=False,              # or gputools_available() (from stardist), requires gputools
    unet_n_depth=3,
    n_channel_in=1 if x_train[0].ndim == 2 else x_train[0].shape[-1],
    train_patch_size=(256, 256),
)

# Augmentation parameters (Keras ImageDataGenerator-style arguments)
data_gen_args = dict(horizontal_flip=True,
                     vertical_flip=True,
                     rotation_range=90,
                     zoom_range=0.5,
                     shear_range=0.5,
                     width_shift_range=0.5,
                     height_shift_range=0.5,
                     fill_mode='reflect',
                     data_format='channels_last')

# vars(conf)  # uncomment to inspect the full configuration
```
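`train_patch_size` has to fit the network's downsampling. From our reading of the StarDist source, the U-Net backbone requires each patch dimension to be divisible by `unet_pool ** unet_n_depth * grid` along that axis. `unet_pool` defaults to `(2, 2)`. The hedged sketch below re-implements that rule so a configuration can be checked before training. The function names are hypothetical and are not StarDist's own check.

```python
def required_divisor(unet_pool=(2, 2), unet_n_depth=3, grid=(2, 2)):
    """Per-axis factor a training patch size must be divisible by (assumed U-Net rule)."""
    return tuple(p ** unet_n_depth * g for p, g in zip(unet_pool, grid))

def check_patch_size(patch_size, **kwargs):
    """Raise ValueError if patch_size violates the assumed divisibility rule."""
    div_by = required_divisor(**kwargs)
    bad = [(s, d) for s, d in zip(patch_size, div_by) if s % d != 0]
    if bad:
        raise ValueError(f"patch size {patch_size} not divisible by {div_by}")
    return div_by

# The configuration above: depth 3, grid (2, 2) -> divisor (16, 16)
print(check_patch_size((256, 256), unet_n_depth=3, grid=(2, 2)))
```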

```python
model_name = f"{datetime.date.today().strftime('%Y%m%d')}_Star"
model = stardist.models.StarDist2D(conf, name=model_name, basedir='./models/')
```

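```python
# Check field of view size: the median object extent should fit into the
# field of view of the network
median_size = stardist.calculate_extents(list(y_train), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f'median object size: {median_size}')
print(f'network field of view: {fov}')
if any(median_size > fov):
    print('WARNING: median object size is larger than the field of view of the network.')
```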
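`stardist.calculate_extents` summarizes object sizes as per-axis bounding-box extents and aggregates them with the given function, here the median. The NumPy-only sketch below shows that idea for a single label image, which makes the field-of-view comparison easier to follow. `bbox_extents` is a hypothetical name. StarDist's own implementation is based on `skimage.measure.regionprops` and may differ in details.

```python
import numpy as np

def bbox_extents(labels, func=np.median):
    """Aggregate per-object bounding-box extents (height, width) of a 2D label image."""
    labels = np.asarray(labels)
    ext = []
    for i in np.unique(labels):
        if i == 0:
            continue  # background
        rr, cc = np.nonzero(labels == i)
        ext.append((rr.max() - rr.min() + 1, cc.max() - cc.min() + 1))
    if not ext:
        return np.zeros(2)
    return func(np.array(ext), axis=0)

# Usage: a 3x5 and a 5x7 object give a median extent of (4, 6)
lab = np.zeros((20, 20), dtype=np.int32)
lab[0:3, 0:5] = 1
lab[10:15, 10:17] = 2
print(bbox_extents(lab))
```

If this median exceeds the field of view along an axis, the network cannot see a typical object in full.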

### Training

Start TensorBoard with the command below, then connect to [http://localhost:6006/](http://localhost:6006/).

    $ tensorboard --logdir=.


```python
model.train(x_train, y_train,
            validation_data=(x_valid, y_valid),
            augmenter=utils.augmentation.StarAugment(**data_gen_args).augment)
```
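`utils.augmentation.StarAugment` is project-specific and is not shown here. StarDist's `train` accepts any callable `augmenter(x, y)` that returns an augmented `(x, y)` pair. The key constraint is that every geometric transform must be applied identically to the image and its label mask. The minimal sketch below follows the spirit of the random flip/rotation augmenter in the StarDist example notebooks. It is not the project's `StarAugment`, and the optional `rng` argument is our addition.

```python
import numpy as np

def random_fliprot(img, mask, rng=None):
    """Randomly transpose and flip the spatial axes of an image/mask pair.

    The same axis permutation and flips are applied to both arrays, so the
    labels stay aligned with the pixels. Trailing image axes (channels) are
    left untouched. Transposing swaps height and width, so non-square
    patches change shape (the 256x256 patches used here are square).
    """
    rng = np.random.default_rng() if rng is None else rng
    assert img.ndim >= mask.ndim
    perm = tuple(int(a) for a in rng.permutation(mask.ndim))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(perm)
    for ax in range(mask.ndim):
        if rng.random() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask
```

Such a function could be passed directly, e.g. `model.train(x_train, y_train, validation_data=(x_valid, y_valid), augmenter=random_fliprot)`.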

### Threshold optimization

While the default values for the probability and non-maximum suppression thresholds already yield good results in many cases, we still recommend adapting the thresholds to your data. The optimized threshold values are saved to disk and will be loaded automatically with the model.

```python
model.optimize_thresholds(x_valid, y_valid)
```
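`optimize_thresholds` searches for the probability and NMS thresholds that maximize an instance-matching score on the validation set. To our understanding, StarDist's default score is TP / (TP + FP + FN), averaged over several IoU thresholds. The NumPy sketch below makes that quantity concrete for a single IoU threshold. It is an illustration, not StarDist's implementation, which lives in `stardist.matching`. It only supports thresholds of at least 0.5, because above that value each object can match at most one other object.

```python
import numpy as np

def matching_accuracy(y_true, y_pred, thresh=0.5):
    """Instance matching score TP / (TP + FP + FN) at an IoU threshold >= 0.5."""
    if thresh < 0.5:
        raise ValueError("this sketch assumes thresh >= 0.5 (unique matches)")
    t = np.asarray(y_true).ravel().astype(np.int64)
    p = np.asarray(y_pred).ravel().astype(np.int64)
    n_t, n_p = t.max() + 1, p.max() + 1
    # overlap[i, j] = number of pixels with true label i and predicted label j
    overlap = np.bincount(t * n_p + p, minlength=n_t * n_p).reshape(n_t, n_p)
    area_t = overlap.sum(axis=1)
    area_p = overlap.sum(axis=0)
    union = area_t[:, None] + area_p[None, :] - overlap
    iou = np.divide(overlap, union, out=np.zeros(overlap.shape), where=union > 0)
    iou = iou[1:, 1:]  # drop background (label 0)
    tp = int((iou > thresh).sum())
    n_true = int((area_t[1:] > 0).sum())  # labels actually present
    n_pred = int((area_p[1:] > 0).sum())
    fp, fn = n_pred - tp, n_true - tp
    denom = tp + fp + fn
    return tp / denom if denom else 1.0
```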