# StarDist Training Notebook demo, Feb 2024
## This notebook will walk you through using the demo model, and how the training works including:
* Data Augmentation
* Train/Test split
* Number of rays
* Metrics
* Tau


In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import os
import sys
import numpy as np
import matplotlib
# Display images without interpolation so label masks keep crisp object boundaries.
# Use the string 'none': recent Matplotlib versions reject the Python value None here.
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize, download_and_extract_zip_file


from stardist import fill_label_holes, relabel_image_stardist, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

# Seed NumPy's global RNG so the random choices made later in the notebook
# (e.g. the augmentation functions, which draw from np.random) are reproducible.
np.random.seed(42)
# Random colormap used for every label-image plot below.
# NOTE(review): created after seeding on purpose — it may draw from the global RNG; keep this order.
lbl_cmap = random_label_cmap()

In [None]:
#Check to see if GPU is enabled
# Reports whether gputools (OpenCL) can be used; the returned value is shown as the cell output.
# The same check decides `use_gpu` for the training data generator further below.
gputools_available()

# Data

This code block downloads the sample data (the `dsb2018` example dataset) from the StarDist authors' GitHub repository and unpacks it into the `data` folder.

### When formatting your data for training, pay careful attention to how the sample data is organized: every image in `images/` needs a label mask with exactly the same file name in `masks/`. Name your own images in the same manner.

### Note! Labels that are made by humans are typically called the "ground truth"

<div class="alert alert-block alert-info">
Training data (for input `X` with associated label masks `Y`) can be provided via lists of numpy arrays, where each image can have a different size. Alternatively, a single numpy array can also be used if all images have the same size.  
Input images can either be two-dimensional (single-channel) or three-dimensional (multi-channel) arrays, where the channel axis comes last. Label images need to be integer-valued.
</div>

In [None]:
# Fetch the `dsb2018` example dataset and unpack it into ./data (verbose progress output).
download_and_extract_zip_file(url='https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip',
                              targetdir='data', verbose=1)

In [None]:
# Collect the image and mask file paths; images and masks are paired by (sorted) file name.
X = sorted(glob('data/dsb2018/train/images/*.tif'))
Y = sorted(glob('data/dsb2018/train/masks/*.tif'))
# zip() below stops silently at the shorter list, so check the counts explicitly first.
assert len(X) > 0, "no training images found - did the download succeed?"
assert len(X) == len(Y), "number of images (%d) and masks (%d) differ" % (len(X), len(Y))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))

In [None]:
# X and Y still hold file paths here (they are read into arrays in the next cell).
print('The first three images are: ',*X[0:3], sep='\n')
print('')
print('The first three label images are: ',*Y[0:3], sep='\n')

In [None]:
# Read every image and mask from disk: X and Y switch from file paths to numpy arrays.
X = [imread(fname) for fname in X]
Y = [imread(fname) for fname in Y]
# A 2D array is single-channel; otherwise the channel axis is the last one.
n_channel = X[0].shape[-1] if X[0].ndim != 2 else 1

Normalize images and fill small label holes.

In [None]:
# Percentile normalisation (1st / 99.8th percentile) over the spatial axes.
# (0,1) treats each channel on its own; (0,1,2) would normalise all channels together.
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = list(normalize(x, 1, 99.8, axis=axis_norm) for x in tqdm(X))
Y = list(map(fill_label_holes, tqdm(Y)))

In [None]:

# Hold out ~15% of the images (at least one) for validation; the fixed seed keeps the split reproducible.
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train = ind[:-n_val]
ind_val = ind[-n_val:]
X_val = [X[k] for k in ind_val]
Y_val = [Y[k] for k in ind_val]
X_trn = [X[k] for k in ind_train]
Y_trn = [Y[k] for k in ind_train]
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

### Training data consists of pairs of input image and label instances.

In [None]:

def plot_img_label(img, lbl, img_title="image", lbl_title="label", **kwargs):
    """Show an image (left, gray, fixed 0..1 colour range, with colourbar) next to its labels (right).

    Parameters
    ----------
    img : array
        Image to display; the gray colour range is fixed to (0, 1), i.e. a normalized image.
    lbl : array
        Label image, drawn with the notebook's random label colormap ``lbl_cmap``.
    img_title, lbl_title : str
        Titles of the two panels.
    **kwargs
        Accepted for call compatibility but currently ignored.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))
    ax_img, ax_lbl = axes

    mappable = ax_img.imshow(img, cmap='gray', clim=(0,1))
    ax_img.set_title(img_title)
    fig.colorbar(mappable, ax=ax_img)

    ax_lbl.imshow(lbl, cmap=lbl_cmap)
    ax_lbl.set_title(lbl_title)

    plt.tight_layout()

In [None]:
# Show one training pair (index 9, or the last image for tiny datasets).
i = min(9, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
# Multi-channel images that are not RGB are shown by their first channel only.
if img.ndim != 2 and img.shape[-1] != 3:
    img = img[...,0]
plot_img_label(img,lbl)
None;

# Configuration

A `StarDist2D` model is specified via a `Config2D` object.

In [None]:
# Print the Config2D docstring, which describes the available configuration parameters.
print(Config2D.__doc__)

# Example image


# See an image with its ground truth labels

In [None]:
## Pick one example image (index 4, or the last one for tiny datasets) together with its GT labels

i = min(4, len(X)-1)
img = X[i]
lbl = fill_label_holes(Y[i])
assert img.ndim in (2,3)
if img.ndim != 2:
    img = img[...,:3]   # keep at most three channels so it can be displayed
# assumed axes ordering of img and lbl is: YX(C)

In [None]:
# Raw image (left) next to its ground-truth labels (right)
fig, (ax_raw, ax_gt) = plt.subplots(1, 2, figsize=(16,10))
ax_raw.imshow(img, cmap='gray')
ax_raw.axis('off')
ax_raw.set_title('Raw image')
ax_gt.imshow(lbl, cmap=lbl_cmap)
ax_gt.axis('off')
ax_gt.set_title('GT labels')
None;

# Fitting ground-truth labels with star-convex polygons

## Rays
### StarDist describes each object by rays cast from a central point to the object's perimeter. The more rays, the more closely the resulting polygon follows the actual perimeter. There is a trade-off, though: the more rays you use, the slower the training will be.

#### This will show us if the objects can be fit by star-convex polygons

In [None]:
## Reconstruct every GT label image with 2^2 ... 2^7 rays (4, 8, 16, 32, 64, 128)
## and record how well the star-convex reconstruction matches the original labels.

n_rays = [2**i for i in range(2,8)]
scores = []
for r in tqdm(n_rays):
    Y_reconstructed = [relabel_image_stardist(lbl, n_rays=r) for lbl in Y]
    # mean_true_score is used as the mean IoU of the GT objects with their reconstructions
    mean_iou = matching_dataset(Y, Y_reconstructed, thresh=0, show_progress=False).mean_true_score
    scores.append(mean_iou)

## Intersection over union
 The standard measure of how well a label fits the ground truth is the intersection over union (IoU): the area where the two objects overlap divided by the area they cover together.
 ![image.png](attachment:image.png)

## Here we can see how well we reconstruct the ground truth labels with a varying number of rays. Clearly four rays are too few, but beyond a certain number of rays we get diminishing returns.

In [None]:
# One panel per ray count: the example label image reconstructed as star-convex polygons
fig, ax = plt.subplots(2,3, figsize=(16,11))
for r, a in zip(n_rays, ax.ravel()):
    reconstruction = relabel_image_stardist(lbl, n_rays=r)
    a.imshow(reconstruction, cmap=lbl_cmap)
    a.set(title='Reconstructed (%d rays)' % r)
    a.set_axis_off()
fig.tight_layout();

In [None]:
# Reconstruction quality (mean IoU) as a function of the number of rays
fig_scores, ax_scores = plt.subplots(figsize=(8,5))
ax_scores.plot(n_rays, scores, 'o-')
ax_scores.set_xlabel('Number of rays for star-convex polygon')
ax_scores.set_ylabel('Reconstruction score (mean intersection over union)')
ax_scores.set_title("Accuracy of ground truth reconstruction (should be > 0.8 for a reasonable number of rays)")
None;

# Do the Training!
### Below are the different variables we can use to configure the training.

* Alter the number of rays based on how your images look
* If you have memory issues, change the grid size (but default is generally ok)
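* Set `use_gpu` to True to let the training data generator use OpenCL (via `gputools`) on the GPU; this requires `gputools` to be installed. The flag only affects the data generator, not the neural network training itself.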
* There are many other configuration variables you can use (see the `Config2D` documentation printed above). One common technique is transfer learning: you start from a pre-trained model and continue training it on your own data to improve its accuracy.


In [None]:
# Number of rays per star-convex polygon. 32 is a good default; check the
# ray-reconstruction plot above to see whether your objects need more.
n_rays = 32

# Use OpenCL (via 'gputools') for the data generator during training.
# Falls back to False when gputools is not available.
use_gpu = gputools_available()

# Predict on a grid subsampled by 2 along Y and X: faster, and a larger field of view
grid = (2,2)

conf = Config2D(n_rays=n_rays, grid=grid, use_gpu=use_gpu, n_channel_in=n_channel)
print(conf)
vars(conf)

In [None]:
if use_gpu:
    # TensorFlow and the OpenCL-based (gputools) data generator share the GPU:
    # let TensorFlow grow its memory on demand instead of reserving all of it.
    # Alternative: cap TensorFlow at a fixed fraction, e.g. limit_gpu_memory(0.8).
    from csbdeep.utils.tf import limit_gpu_memory
    limit_gpu_memory(None, allow_growth=True)

**Note:** The trained `StarDist2D` model will *not* predict completed shapes for partially visible objects at the image boundary if `train_shape_completion=False` (which is the default option).

In [None]:
# Create a new (untrained) StarDist2D model from `conf`.
# NOTE(review): with basedir='models' and name='stardist_demo', StarDist presumably
# stores the config/weights under models/stardist_demo — verify on disk.
model = StarDist2D(conf, name='stardist_demo', basedir='models')

Check if the neural network has a large enough field of view to see up to the boundary of most objects.

In [None]:
# Compare the median object extent (per axis) with the network's field of view (per axis).
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
too_large = median_size > fov
if too_large.any():
    print("WARNING: median object size larger than field of view of the neural network.")

# Data Augmentation
You can define a function/callable that applies augmentation to each batch of the data generator.  
We here use an `augmenter` that applies random rotations, flips, and intensity changes, which are typically sensible for (2D) microscopy images (but you can disable augmentation by setting `augmenter = None`).

In [None]:
def random_fliprot(img, mask):
    """Apply the same random axis permutation and random flips to an image and its mask.

    The first ``mask.ndim`` axes are treated as spatial axes: they are permuted randomly
    and each is flipped with probability 1/2. Any trailing image axes (e.g. channels)
    keep their position.

    Returns
    -------
    (img, mask) : tuple of arrays
        The transformed image and mask (possibly views of the inputs).
    """
    assert img.ndim >= mask.ndim
    spatial = tuple(range(mask.ndim))
    order = tuple(np.random.permutation(spatial))
    trailing = tuple(range(mask.ndim, img.ndim))
    img = np.transpose(img, order + trailing)
    mask = np.transpose(mask, order)
    for axis in spatial:
        if np.random.rand() <= 0.5:
            continue
        img = np.flip(img, axis=axis)
        mask = np.flip(mask, axis=axis)
    return img, mask

def random_intensity_change(img):
    """Randomly rescale and shift image intensities.

    The image is multiplied by a factor drawn uniformly from [0.6, 2) and then shifted
    by an offset drawn uniformly from [-0.2, 0.2), both from NumPy's global RNG.
    """
    scale = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * scale + offset


def augmenter(x, y):
    """Augment one input/label image pair.

    x is an input image, y is the corresponding ground-truth label image.
    Geometry (random permutation/flips of the spatial axes) is applied to both;
    intensity changes and additive Gaussian noise (sigma drawn from [0, 0.02))
    are applied to the input image only.
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    noise_sigma = 0.02 * np.random.uniform(0, 1)
    return x + noise_sigma * np.random.normal(0, 1, x.shape), y

In [None]:
# Show the first training pair once, then three independent random augmentations of it.
img, lbl = X[0], Y[0]
plot_img_label(img, lbl)
for k in range(1, 4):
    img_aug, lbl_aug = augmenter(img, lbl)
    plot_img_label(img_aug, lbl_aug, img_title="image augmented", lbl_title="label augmented")

# Callbacks

Callbacks are functions that run at specified points during training, giving you a feel for what is going on. Here, I wrote a callback that every tenth epoch (starting with the very first one) runs the model "as is" on a validation image, generates a label image, and saves it to the folder given by `im_path`.

There are actually two images generated, a label image, and an image with the labels overlaid on the original image.



In [None]:
from tensorflow.keras.callbacks import Callback


class EvaluateModelCallback(Callback):
    """Keras callback that periodically segments one validation image and saves snapshots.

    Every ``every`` epochs (starting with epoch 0) the StarDist model predicts instance
    labels for ``validation_data`` and writes two files into ``im_path``:

    * ``Labels_epoch_<epoch>_<model name>.tiff`` - the predicted label image rendered
      with ``lbl_cmap`` (a colour picture, not the raw integer label ids);
    * ``Overlay_epoch_<epoch>_<model name>.png`` - the labels overlaid on the gray image.

    Parameters
    ----------
    validation_data : array
        A single (normalized) input image to segment.
    im_path : str
        Output directory; created if it does not exist.
    stardist_model : StarDist2D, optional
        Model used for prediction. Keras assigns ``self.model`` to the Keras model being
        fitted, which has no ``predict_instances``, so the StarDist wrapper is passed
        separately. Defaults to the notebook-global ``model`` (previous behaviour).
    every : int, optional
        Snapshot interval in epochs (default 10).
    """

    def __init__(self, validation_data, im_path, stardist_model=None, every=10):
        super(EvaluateModelCallback, self).__init__()
        if int(every) < 1:
            raise ValueError("every must be a positive number of epochs")
        self.validation_data = validation_data
        # NOTE(review): some Keras versions assign `callback.validation_data` during fit();
        # keep a private reference so the snapshot image cannot be replaced.
        self._snapshot_image = validation_data
        self.im_path = im_path
        self.stardist_model = stardist_model
        self.every = int(every)

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.every != 0:
            return
        # Fall back to the notebook-global StarDist model for backward compatibility.
        sd_model = self.stardist_model if self.stardist_model is not None else model
        image = self._snapshot_image
        labels = sd_model.predict_instances(image, n_tiles=sd_model._guess_n_tiles(image),
                                            show_tile_progress=False)[0]

        print(f"\nEvaluation results at epoch {epoch}:")
        os.makedirs(self.im_path, exist_ok=True)
        stem = "epoch_" + str(epoch) + "_" + sd_model.name
        # os.path.join works whether or not im_path ends with a path separator.
        filename_tiff = os.path.join(self.im_path, "Labels_" + stem + ".tiff")
        filename_png = os.path.join(self.im_path, "Overlay_" + stem + ".png")
        print(filename_tiff)

        # Colour-mapped rendering of the predicted labels.
        plt.imsave(filename_tiff, labels, cmap=lbl_cmap)

        # Draw the overlay on its own figure and close it afterwards, so snapshots
        # neither pile up on a shared axes nor leak open figures during long training.
        fig, ax = plt.subplots()
        ax.imshow(image, cmap='gray')
        ax.imshow(labels, cmap=lbl_cmap, alpha=0.35)
        fig.savefig(filename_png)
        plt.close(fig)
            




In [None]:
# Image the callback segments during training, and the folder its snapshots are written to.
validation_data = X_val[0]
# Relative to the notebook's working directory, so it works on any machine
# (a user-specific absolute path fails elsewhere). The trailing separator keeps
# string-concatenated file names inside this folder.
im_path = 'output_images/gray/'
os.makedirs(im_path, exist_ok=True)  # no-op if the folder already exists

evaluate_callback = EvaluateModelCallback(validation_data, im_path) #set up the callback with the image and the path

### Show the image we're validating on

In [None]:
# Gray colormap, consistent with how the callback renders this image in its overlays.
plt.imshow(validation_data, cmap='gray'); plt.axis('off'); plt.title('Validation image used by the callback');

# Training

In [None]:
# quick_demo=False: full training with the snapshot callback attached.
# quick_demo=True: a 2-epoch smoke test, after which the pretrained '2D_demo' model is loaded instead.
quick_demo = False

if not quick_demo:
    epochs = 2000   # number of training epochs
    steps = 20      # steps (batches) per epoch
    model.prepare_for_training()                # set up the model's callback list so we can extend it
    model.callbacks.append(evaluate_callback)   # add our validation-snapshot callback
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,
                epochs=epochs, steps_per_epoch=steps)
else:
    print (
        "NOTE: This is only for a quick demonstration!\n"
        "      Please set the variable 'quick_demo = False' for proper (long) training.",
        file=sys.stderr, flush=True
    )
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,
                epochs=2, steps_per_epoch=10)

    print("====> Stopping training and loading previously trained demo model from disk.", file=sys.stderr, flush=True)
    model = StarDist2D.from_pretrained('2D_demo')
None;

# Metrics:
The code for the metrics can be found in: https://github.com/stardist/stardist/blob/master/stardist/models/base.py

In the training output above, you'll see each metric twice, e.g. `loss` and `val_loss`. `loss` is computed on the training set, and `val_loss` is computed on the validation set.

Some of the metrics that will help you understand what's going on are:

* loss - Measure of how well the model is performing, the error between the predicted values and ground truth. We want this value to be low

* prob_kld - Kullback-Leibler Divergence. This measures the divergence between two probability distributions, in this case the predicted probability distribution and the expected probability distributions. This should be low

* iou_metric - The intersection over union.  The closer to 1 this is, the better

* mae - Mean absolute error, difference between predicted and ground truth. This should decrease over time.

* mse - Mean squared error of difference between predicted and ground truth. This should decrease over time.



In [None]:
# Tune the probability / NMS thresholds on the validation set
# (only its first two images in quick-demo mode, to save time).
X_opt, Y_opt = (X_val[:2], Y_val[:2]) if quick_demo else (X_val, Y_val)
model.optimize_thresholds(X_opt, Y_opt);

# Evaluation and Detection Performance
Besides the losses and metrics during training, we can also quantitatively evaluate the actual detection/segmentation performance on the validation data by considering objects in the ground truth to be correctly matched if there are predicted objects with overlap (here [intersection over union (IoU)](https://en.wikipedia.org/wiki/Jaccard_index)) beyond a chosen IoU threshold $\tau$.


The corresponding matching statistics (average overlap, accuracy, recall, precision, etc.) are typically of greater practical relevance than the losses/metrics computed during training (but harder to formulate as a loss function). 

### Tau: Tau is the IoU threshold.  

### The value of $\tau$ can be between 0 (even slightly overlapping objects count as correctly predicted) and 1 (only pixel-perfectly overlapping objects count) and which $\tau$ to use depends on the needed segmentation precision/application.

Please see `help(matching)` for definitions of the abbreviations used in the evaluation below and see the Wikipedia page on [Sensitivity and specificity](https://en.wikipedia.org/wiki/Sensitivity_and_specificity) for further details.

In [None]:
#help(matching)

### First predict the labels for all validation images:

In [None]:
# Predict instance labels for every validation image (first element of predict_instances' result).
Y_val_pred = []
for x in tqdm(X_val):
    tiles = model._guess_n_tiles(x)
    Y_val_pred.append(model.predict_instances(x, n_tiles=tiles, show_tile_progress=False)[0])

Plot a GT/prediction example 

In [None]:
# Same validation image twice: with its ground-truth labels, then with the predicted labels.
for labels, title in ((Y_val[0], "label GT"), (Y_val_pred[0], "label Pred")):
    plot_img_label(X_val[0], labels, lbl_title=title)

### Choose several IoU thresholds $\tau$ that might be of interest and for each compute matching statistics for the validation data.

In [None]:
# Matching statistics of the validation predictions for a range of IoU thresholds tau.
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = []
for t in tqdm(taus):
    stats.append(matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False))

## Example: Print all available matching statistics for $\tau=0.5$

In [None]:
# Display all matching statistics computed for tau = 0.5.
stats[taus.index(0.5)]

### Plot the matching statistics and the number of true/false positives/negatives as a function of the IoU threshold $\tau$. 

* False positives are the number of objects that are identified that aren't there in reality.
* False negatives are the objects that are missed by the algorithm
* True positives are the ones that are correct


### As we increase tau, we tend to get more false negatives and false positives and fewer true positives, because we require the matches to be closer and closer to perfect.

In [None]:
# Left: summary metrics vs. tau. Right: counts of false/true positives and false negatives vs. tau.
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

panels = (
    (ax1, ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'), 'Metric value'),
    (ax2, ('fp', 'tp', 'fn'), 'Number #'),
)
for axis, metric_names, ylabel in panels:
    for m in metric_names:
        axis.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
    axis.set_xlabel(r'IoU threshold $\tau$')
    axis.set_ylabel(ylabel)
    axis.grid()
    axis.legend()