In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy
import random
import matplotlib

from matplotlib import pyplot
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score,\
                            roc_curve, precision_recall_curve, roc_auc_score, confusion_matrix
from skimage import draw, filters, measure
from collections import defaultdict

# Imports from FLClab metrics package
from metrics.segmentation.commons import iou, dice

# Classification, segmentation and semantic segmentation

This section presents some of the metrics that can be used in cases of classification, segmentation and semantic segmentation. All of the presented metrics are available in `scikit-learn` (`sklearn`).

Note. In the case of semantic segmentation, the metrics can be reported in a _per class_ fashion or averaged over all classes. It's only a matter of iterating over all classes.

In [None]:
%matplotlib notebook
def show_examples(truth, prediction, instance_objects=False, cmap="gray"):
    """
    Plots the truth and prediction masks side by side
    
    Both panels are drawn with the same color limits so that a given value
    (or a given object label) gets the same color in the truth and in the prediction.
    
    :param truth: A `numpy.ndarray` of the truth
    :param prediction: A `numpy.ndarray` of the prediction
    :param instance_objects: A `bool` if the masks are label images; the color scale
                             then spans from 0 to the largest label of both masks,
                             otherwise it spans from 0 to 1
    :param cmap: A `string` or `colormap` object to display the masks
    
    :returns : A `matplotlib.Figure`
               A `numpy.ndarray` of the two `matplotlib.Axes`
    """
    fig, axes = pyplot.subplots(1, 2, figsize=(6, 3), tight_layout=True, sharex=True, sharey=True)
    
    if instance_objects:
        # The largest label may live in either mask; the floor of 1 keeps the
        # range non-degenerate when both masks are empty
        vmin, vmax = 0, max(truth.max(), prediction.max(), 1)
    else:
        vmin, vmax = 0, 1
    for ax, image in zip(axes, (truth, prediction)):
        ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    pyplot.show()
    return fig, axes

def show_tpfpfn(truth, prediction, cmap="magma"):
    """
    Shows True Positive (TP), False Positive (FP) and False Negative (FN) 
    of the predicted mask.
    
    Every pixel is encoded as 0 (TN), 1 (FN), 2 (FP) or 3 (TP). A pixel is
    considered foreground when its value is strictly positive, in the truth
    as well as in the prediction.
    
    :param truth: A `numpy.ndarray` of the truth mask
    :param prediction: A `numpy.ndarray` of the predicted mask    
    :param cmap: A `string` or `colormap` object to display the results
    
    :returns : A `matplotlib.Figure`
               A `matplotlib.Axes`
    """
    truth_fg = truth > 0
    prediction_fg = prediction > 0

    mask = numpy.zeros(truth.shape)
    mask[numpy.logical_and(truth_fg, prediction_fg)] = 3
    mask[numpy.logical_and(numpy.invert(truth_fg), prediction_fg)] = 2
    mask[numpy.logical_and(truth_fg, numpy.invert(prediction_fg))] = 1

    fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
    if isinstance(cmap, str):
        # Discretize a named colormap into one color per category
        cmap = pyplot.get_cmap(cmap, 4)
    # Pin the color limits to the full 0..3 code range: with autoscaling, a missing
    # category (e.g. no TP) would shift the colors away from the colorbar labels
    im = ax.imshow(mask, cmap=cmap, vmin=0, vmax=3)
    cbar = pyplot.colorbar(im, ax=ax)
    # Ticks sit at the center of each of the four equal bins of the [0, 3] range
    cbar.set_ticks((numpy.arange(4) + 0.5) * 0.75)
    cbar.set_ticklabels(["TN", "FN", "FP", "TP"])
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    
    return fig, ax

def generate_masks(is_prob_prediction=False, is_smooth=False, instance_objects=False, num_objs=20, prob_prediction=0.9, prob_truth=0.9, seed=42):
    """
    Randomly generates a truth and prediction masks
    
    Both masks are 256 x 256. Each candidate object is a disk of radius 10 in the
    truth and a randomly shifted disk of random radius in the prediction. Note that
    calling this function reseeds the global `random` and `numpy.random` generators.
    
    :param is_prob_prediction: A `bool` if the prediction is probabilistic
    :param is_smooth: A `bool` if the prediction should be smoothed
    :param instance_objects: A `bool` if the truth and prediction are instance objects
    :param num_objs: An `int` of the number of objects to add
    :param prob_prediction: A `float` of the probability that the corresponding prediction is present
    :param prob_truth: A `float` of the probability that the corresponding truth is present
    :param seed: An `int` of the random seed
    
    :returns : A `numpy.ndarray` of the truth mask
               A `numpy.ndarray` of the prediction mask
    """
    random.seed(seed)
    numpy.random.seed(seed)

    truth = numpy.zeros((256, 256))
    prediction = numpy.zeros((256, 256))
    for i in range(num_objs):
        # The order of the random draws below is kept fixed so that a given seed
        # always produces the same masks
        center = numpy.random.randint(truth.shape[0], size=(2,))

        if random.random() < prob_truth:
            rr, cc = draw.disk(center, radius=10, shape=truth.shape)
            # Instance masks carry a unique label per object, binary masks carry 1
            truth[rr, cc] = i + 1 if instance_objects else 1

        if random.random() < prob_prediction:
            # The predicted disk is shifted by up to 9 pixels and its radius varies in [8, 15)
            shifted_center = center + numpy.random.randint(10, size=(2, ))
            radius = numpy.random.randint(8, 15)
            rr, cc = draw.disk(shifted_center, radius=radius, shape=prediction.shape)
            if instance_objects:
                prediction[rr, cc] = i + 1
            elif is_prob_prediction:
                # One random score per predicted object
                prediction[rr, cc] = random.random()
            else:
                prediction[rr, cc] = 1
    # An empty prediction is left untouched: smoothing zeros yields zeros and the
    # normalization would otherwise divide by zero and fill the mask with NaN
    if is_smooth and not instance_objects and prediction.any():
        prediction = filters.gaussian(prediction, 10)
        prediction = prediction / prediction.max()
    return truth, prediction

def plot_cm(cm, title=None, normalized=True, labels=[0,1], cmap="magma"):
    """
    Plots the confusion matrix on a `matplotlib.Axes`

    The cell colors follow the (optionally row-normalized) matrix while the
    text written in every cell is always the raw count.

    :param cm: A 2D `numpy.ndarray` with true labels on rows and predicted labels on columns
    :param title: A `string` of the title of the axes
    :param normalized: Whether to normalize each row of the confusion matrix by its sum
    :param labels: A `list` of tick labels, one per class
    :param cmap: A `string` or `colormap` object to display the matrix

    :returns : A `matplotlib.Figure`
               A `matplotlib.Axes`
    """
    counts = numpy.asarray(cm)
    if normalized:
        row_totals = counts.sum(axis=1, keepdims=True)
        # Rows without any sample are shown as zeros instead of dividing by zero
        shown = numpy.divide(
            counts, row_totals,
            out=numpy.zeros(counts.shape, dtype=float), where=row_totals != 0
        )
        vmax = 1
    else:
        shown = counts
        # Raw counts need a color scale that reaches the largest count
        vmax = max(counts.max(), 1)
    fig, ax = pyplot.subplots(tight_layout=True, figsize=(3, 3))
    im = ax.imshow(shown, cmap=cmap, vmin=0, vmax=vmax)
    cb = fig.colorbar(im, ax=ax)
    for row in range(shown.shape[0]):
        for col in range(shown.shape[1]):
            # Text color switches at the middle of the color scale
            color = "black" if shown[row, col] < 0.5 * vmax else "white"
            # matplotlib text coordinates are (x, y), i.e. (column, row)
            ax.text(col, row, "{:d}".format(int(counts[row, col])), horizontalalignment="center",
                    verticalalignment="center", fontdict={"fontsize":10, "color":color})
    ax.set(
        title=title,
        xticks=numpy.arange(len(labels)),
        yticks=numpy.arange(len(labels)),
        xticklabels=labels,
        yticklabels=labels,
        ylabel="True label",
        xlabel="Predicted label"
    )
    return fig, ax

In [None]:
# Binary truth and prediction masks generated with the default arguments (seed of 42)
truth, prediction = generate_masks()
# First panel is the truth, second panel is the prediction
fig, axes = show_examples(truth, prediction)

In [None]:
# Pixel-wise map of the TN / FN / FP / TP categories for the masks generated above
fig, ax = show_tpfpfn(truth, prediction)
pyplot.show()

The next cell calculates common metrics such as _accuracy_, _precision_, and _recall_. Another way to report the performance of the model is to calculate the Intersection Over Union (IOU) (also called Jaccard index). This metric is commonly used when assessing the segmentation performance of a model. The Dice metric (also called Sørensen–Dice coefficient) is also reported in some cases but is very similar to IOU.

In case of multiclass classification, the confusion matrix is a nice way to visually see the performance of the model.

In [None]:
# Every metric is computed pixel-wise on the flattened masks
for name, metric_func in (
    ("Accuracy", accuracy_score),
    ("Precision", precision_score),
    ("Recall", recall_score),
    ("IOU", iou),
    ("Dice", dice),
):
    print(name, metric_func(truth.ravel(), prediction.ravel()))

# Two-color gradient; note that the confusion matrix below is drawn with "Blues" instead
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    "space", ["#343434", "#ffc949"]
)
# Confusion matrix of background versus foreground pixels
cm = confusion_matrix(truth.ravel(), prediction.ravel())
fig, ax = plot_cm(cm, labels=["Bckg", "Fg"], cmap="Blues")

In [None]:
# Probabilistic prediction without smoothing: every predicted object carries its own random score
truth, prediction = generate_masks(is_prob_prediction=True, is_smooth=False)
fig, axes = show_examples(truth, prediction)

# ROC curve, with the threshold that generates each operating point overlaid against the FPR
fpr, tpr, thresholds = roc_curve(truth.ravel(), prediction.ravel())
fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
ax.plot(fpr, tpr, label="ROC-curve")
ax.plot(fpr, thresholds, label="Threshold")
ax.set(
    xlabel="FPR", ylabel="TPR",
    ylim=(0, 1), xlim=(0, 1)
)
ax.legend()
pyplot.show()

# Precision-recall curve. `precision` and `recall` hold one more element than
# `thresholds`: the trailing (recall=0, precision=1) point has no threshold, hence
# thresholds[i] is paired with recall[i] and the last recall value is dropped.
precision, recall, thresholds = precision_recall_curve(truth.ravel(), prediction.ravel())
fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
ax.plot(recall, precision, label="PR-curve")
ax.plot(recall[:-1], thresholds, label="Threshold")
ax.set(
    xlabel="Recall", ylabel="Precision",
    ylim=(0, 1), xlim=(0, 1)
)
ax.legend()
pyplot.show()


## Caveats

The accuracy metric is very dependent on the number of negative pixels in the image, even if the segmentation is kept constant.

In [None]:
# Scores to track while the image grows around a fixed pair of rectangles
metrics = {
    "Accuracy": accuracy_score,
    "Precision": precision_score,
    "Recall": recall_score,
}
scores = defaultdict(list)
img_sizes = [32, 64, 128, 256, 512]
for img_size in img_sizes:
    # The truth and predicted rectangles never change; only the amount of
    # surrounding background does
    truth = numpy.zeros((img_size, img_size))
    prediction = numpy.zeros((img_size, img_size))
    truth[:25, :25] = 1
    prediction[:20, :28] = 1

    for name, func in metrics.items():
        scores[name].append(func(truth.ravel(), prediction.ravel()))

# One curve per metric as a function of the image size
fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
for key, value in scores.items():
    ax.plot(img_sizes, value, label=key)
ax.set_ylabel("Score")
ax.set_xlabel("Image Size")
ax.legend()
pyplot.show()

The IOU metric can be sensitive to the size of the objects in the field of view. 

In [None]:
from skimage import transform

radii = [8, 16, 32, 64]
scores = []
for radius in radii:
    # Centered truth disk of the current radius
    truth = numpy.zeros((256, 256))
    rr, cc = draw.disk((128, 128), radius=radius, shape=truth.shape)
    truth[rr, cc] = 1

    # Predicted disk with the same center and a radius smaller by one pixel
    prediction = numpy.zeros((256, 256))
    rr, cc = draw.disk((128, 128), radius=radius - 1, shape=prediction.shape)
    prediction[rr, cc] = 1

    fig, ax = show_tpfpfn(truth, prediction)
    pyplot.show()

    scores.append(iou(truth, prediction).item())

# IOU obtained for each object size
fig, ax = pyplot.subplots(figsize=(3, 3), tight_layout=True)
ax.plot(radii, scores)
ax.set_ylabel("IOU")
ax.set_xlabel("Object Size")
ax.set_xticks(radii)
pyplot.show()


# Instance segmentation or object detection

The general goal of this section is to associate truth and predicted objects in some fashion using some metric.

In instance segmentation, the goal is to associate the truth objects with the predicted objects. Instance segmentation is mostly used and seen in cases of cell body segmentation, where all cell bodies in the image should be specific instances of a cell. The metrics that can be used to perform the association can be the `f1-score`, `IOU`, or Bounding Boxes (BBOX).

Object detection is quite similar to instance segmentation in the sense that you would like to detect instances of objects. In the object detection task, you may wish to associate centroids of the truth and predicted objects if they are within some distance.


## Instance Segmentation

In [None]:
# Colormap for label images: starts at black (value 0) followed by a gradient through nine colors
label_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    name="nice-prism",
    colors=["#000000", "#5F4690","#1D6996","#38A6A5","#0F8554","#73AF48","#EDAD08","#E17C05","#CC503E","#94346E"]
)

# Instance masks: up to 25 objects, each kept with a probability of 0.75 in the truth
# and, independently, in the prediction, so that some objects are missed and some are extra
truth, prediction = generate_masks(instance_objects=True, num_objs=25, prob_prediction=0.75, prob_truth=0.75)
fig, ax = show_examples(truth, prediction, instance_objects=True, cmap=label_cmap)
pyplot.show()

In [None]:
from metrics.detection import IOUDetectionError

# The masks are stored as floats; they are cast to integer label images for the scorer
scorer = IOUDetectionError(truth.astype(int), prediction.astype(int))

# Display of the scorer's cost matrix
# NOTE(review): presumably the pairwise IOU between truth (rows) and predicted (columns)
# objects, as the labels below suggest -- confirm against IOUDetectionError
fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
im = ax.imshow(scorer.cost_matrix, cmap="Blues")
cbar = pyplot.colorbar(im, ax=ax)
cbar.set_label("IOU")
ax.set(
    ylabel="Truth Objects", xlabel="Predicted Objects"
)
pyplot.show()

In [None]:
# The scorer object allows the user to extract common metrics such as F1-score and 
# Average Precision as a function of the IOU
thresholds = numpy.linspace(0., 1.0, num=10)
# Stored under `f1_scores` so that the `f1_score` function imported from
# `sklearn.metrics` at the top of the notebook is not overwritten
f1_scores, thresholds = scorer.get_f1_score(thresholds)

fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
ax.plot(thresholds, f1_scores)
ax.set(
    ylim=(0, 1), xlim=(0, 1),
    ylabel="F1-score", xlabel="IOU"
)
pyplot.show()

# Same IOU thresholds, this time for the average precision
thresholds = numpy.linspace(0., 1.0, num=10)
average_precision, thresholds = scorer.get_average_precision(thresholds)

fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
ax.plot(thresholds, average_precision)
ax.set(
    ylim=(0, 1), xlim=(0, 1),
    ylabel="Average Precision", xlabel="IOU"
)
pyplot.show()

An approach that was described by Caicedo _et al._ [1] is to report the number of merges, splits, extra objects and missed objects as a function of the IOU. This can be done using `IOUDetectionError`.

[1] Caicedo, J. C. et al. Evaluation of Deep Learning Strategies for Nucleus Segmentation in Fluorescence Images. Cytometry Part A 95, 952–965 (2019).

In [None]:
# Ten evenly spaced IOU thresholds at which each type of error is extracted
thresholds = numpy.linspace(0, 1, 10)
missed_objs, _ = scorer.get_missed_objects(threshold=thresholds)
extra_objs, _ = scorer.get_extra_objects(threshold=thresholds)
split_objs, _ = scorer.get_split_objects(threshold=thresholds)
merged_objs, _ = scorer.get_merged_objects(threshold=thresholds)

fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
for name, objs in (
    ("Missed", missed_objs),
    ("Extra", extra_objs),
    ("Split", split_objs),
    ("Merged", merged_objs),
):
    # Objs is a list of list of regionprops; only the number of objects is plotted
    objs = list(map(len, objs))
    ax.plot(thresholds, objs, label=name)
ax.set_ylabel("Num Objects")
ax.set_xlabel("IOU")
ax.legend()
pyplot.show()

In [None]:
# IOUDetectionError has the possibility of showing the performance of the model 
# extracting merges, splits, extra objects and missed objects
# Here the summary is requested at a single IOU threshold of 0.10, drawn with the label colormap
fig, axes = scorer.show(threshold=0.10, cmap=label_cmap)
pyplot.show()

In the Cell Tracking Challenge, the metrics that are used to compute the performance of the model are `SEG` and `DET` [1]. These metrics are quite challenging to use and require a specific folder architecture. A tool was created to facilitate their use.

**This metric will not work on Windows. If you want to test or make it work on Windows, please share!**

<br>
<i>
[1] Ulman, V. et al. An objective comparison of cell-tracking algorithms. Nature Methods 14, 1141–1152 (2017).
</i>

In [None]:
from metrics.aogm import CTCMeasure

# CTCMeasure receives a list of truth masks and a list of prediction masks (one of each here)
# and is used as a context manager
# NOTE(review): presumably the context manager creates and cleans up the folder
# architecture required by the Cell Tracking Challenge tools -- confirm against CTCMeasure
with CTCMeasure([truth], [prediction]) as ctc_measure:
    print("SEG:", ctc_measure.get_seg())
    print("DET:", ctc_measure.get_det())

## Object Detection

In [None]:
from metrics.detection import CentroidDetectionError

# One centroid per labeled object of the instance masks; the masks are cast to
# integer label images before the region properties are measured
truth_centroids = numpy.array([rprop.centroid for rprop in measure.regionprops(truth.astype(int))])
prediction_centroids = numpy.array([rprop.centroid for rprop in measure.regionprops(prediction.astype(int))])

# NOTE(review): `threshold` is presumably the maximal distance (in pixels) allowed
# between associated centroids -- confirm against CentroidDetectionError
scorer = CentroidDetectionError(truth_centroids, prediction_centroids, threshold=25)
display(scorer.get_score_summary())

In [None]:
# Paired indices of associated objects: `truth_couple` indexes `truth_centroids`
# and `pred_couple` indexes `prediction_centroids` (see the loop below)
truth_couple, pred_couple = scorer.get_coupled()
# Retrieved for reference; they are not used by the plot in this cell
false_positives = scorer.get_false_positives()
false_negatives = scorer.get_false_negatives()

fig, ax = pyplot.subplots(figsize=(3,3), tight_layout=True)
# Column 1 of the centroids is plotted on the x axis and column 0 on the y axis
ax.scatter(
    truth_centroids[:, 1], truth_centroids[:, 0], marker="+", label="Truth"
)
ax.scatter(
    prediction_centroids[:, 1], prediction_centroids[:, 0], marker="+", label="Prediction"
)
# A black segment links every truth centroid to its associated predicted centroid
for tidx, pidx in zip(truth_couple, pred_couple):
    ax.plot(
        [truth_centroids[tidx, 1], prediction_centroids[pidx, 1]], 
        [truth_centroids[tidx, 0], prediction_centroids[pidx, 0]],
        color="black"
    )
ax.legend()
pyplot.show()

In [None]:
# CentroidDetectionError can also work in >2D (actually works in ND)
# Example where centroids are events from 2D+t movie, e.g. calcium movies
# 128 random 3-coordinate points drawn uniformly in [0, 512)
# NOTE: the global `numpy.random` state is used and is not reseeded in this cell,
# so the values depend on what was executed before
truth_centroids = numpy.random.rand(128, 3) * 512
# Predictions are the truth points jittered by Gaussian noise (standard deviation of 5)
prediction_centroids = truth_centroids + numpy.random.normal(loc=0, scale=5, size=truth_centroids.shape)
# Centroids are subsampled to simulate missing and extra objects
# (125 of the 128 truth points and 115 of the 128 predicted points are kept, without replacement)
choices = numpy.random.choice(len(truth_centroids), size=125, replace=False)
truth_centroids = truth_centroids[choices]
choices = numpy.random.choice(len(prediction_centroids), size=115, replace=False)
prediction_centroids = prediction_centroids[choices]

scorer = CentroidDetectionError(truth_centroids, prediction_centroids, threshold=10)
display(scorer.get_score_summary())