# Weakly Supervised Learning 

This notebook will allow a user to experiment with weakly supervised learning. It will provide the user some guidance about the challenges that can be faced in the weakly supervised field.

The notebook is split into two parts. 

The first part introduces how to use weakly supervised learning. At the end of this part, a user should be able to qualitatively visualize the performance of the trained networks and also obtain quantitative measurements of that performance.

A second part will dig into the performance assessment of weakly supervised learning. The user will learn how to show the increase in performance of the trained network compared to the labels it was provided with.

# Part 1 : Introduction to Weakly Supervised Learning

In [None]:
import numpy
import glob, os  
import torch
import warnings
import sys
import pickle
import skimage

from matplotlib import pyplot, patches
from sklearn.metrics import (confusion_matrix, f1_score, precision_score, recall_score,
                             precision_recall_curve)
from skimage import filters, io, draw
from scipy.spatial import distance
from collections import defaultdict

# Flag telling the rest of the notebook whether a GPU can be used
CUDA = torch.cuda.is_available()
if not CUDA:
  # Both hints are written on their own line, exactly as two separate prints would do
  print(
      "Cuda is not available in the current notebook.",
      "You can change this setting in Edit/Notebook settings.",
      sep="\n"
  )

## Load data

In [None]:
# Clones the course repository from GitHub and moves into it.
# NOTE: this only works when the repository is public. While it is private,
# skip this cell and use the next one, which asks for GitHub credentials.

!git clone https://github.com/FLClab/DL4HBM-2020.git
os.chdir("DL4HBM-2020")

In [None]:
# Clones (or updates) the private repository using credentials typed by the user.
# The credentials are requested interactively so that they are never written
# in the notebook itself.
from getpass import getpass

# Changes directory to the main folder
os.chdir("/content")

user = getpass('GitHub user')
password = getpass('GitHub password')
# The shell commands below read the credentials from this environment variable
os.environ['GITHUB_AUTH'] = user + ':' + password
if os.path.isdir("/content/DL4HBM-2020"):
  # The repository was already cloned: only fetch the latest changes
  os.chdir("DL4HBM-2020")
  !git pull 
else:
  # NOTE(review): the credentials are part of the clone URL; git presumably
  # keeps that URL in the clone's configuration -- confirm before sharing the folder
  !git clone https://$GITHUB_AUTH@github.com/FLClab/DL4HBM-2020.git
  # Changes the directory to work on the repo
  os.chdir("DL4HBM-2020")

# Removes the credentials from the Python session and from the environment
del user, password, os.environ["GITHUB_AUTH"]


## Load user-defined utils functions

In [None]:
from tqdm.auto import tqdm

class MetricCalculator:
  """
  Implements a MetricCalculator class to facilitate the calculation of metrics
  between the targets and the predictions

  Every metric is implemented as a private method named `_<metric>`; the
  `get` method looks the requested metrics up by name.
  """
  def __init__(self, targets, predictions, foregrounds=None):
    """
    Instantiates the `MetricCalculator` class

    :param targets: A `numpy.ndarray` of the targets
    :param predictions: A `numpy.ndarray` of the predictions
    :param foregrounds: (Optional) A `numpy.ndarray` of the foregrounds. When
                        given, only the pixels inside the foreground are scored
    """
    # Assign member variables
    self.targets = targets
    self.predictions = predictions
    if foregrounds is None:
      # One placeholder per target so that the three sequences can be zipped
      self.foregrounds = [None] * len(self.targets)
    else:
      self.foregrounds = foregrounds

  def get(self, metric_names, **kwargs):
    """
    Implements a get method to get the metric score between the targets
    and the predictions

    :param metric_names: A `list` of the names of the metrics to compute
    :param kwargs: Extra keyword arguments forwarded to every metric

    :returns : `None` (after a warning) if a metric does not exist. A `list`
               with one entry per kept image if `precision_recall_curve` is
               requested. Otherwise an array with one row per metric and one
               column per kept image
    """
    scorers = []
    for metric_name in metric_names:
      scorer = getattr(self, f"_{metric_name}", None)
      if scorer is None:
        warnings.warn(f"The chosen method `{metric_name}` does not exist.\nExiting...", category=UserWarning)
        return None
      scorers.append(scorer)

    is_pr_curve = "precision_recall_curve" in metric_names
    all_scores = []
    for t, p, f in zip(tqdm(self.targets, leave=False), self.predictions, self.foregrounds):
      # Multi-channel predictions are converted to a label map; single channel
      # predictions only lose their channel axis
      p = numpy.argmax(p, axis=0) if p.shape[0] > 1 else p.squeeze()
      if f is not None:
        mask = f.astype(bool)
        t, p = t[mask], p[mask]
      # Images without any positive target are skipped when the prediction is
      # empty as well, or whenever a precision-recall curve is requested
      if (not numpy.any(t)) and ((not numpy.any(p)) or is_pr_curve):
        continue
      all_scores.append([scorer(t, p, **kwargs) for scorer in scorers])

    if is_pr_curve:
      return all_scores

    # Scalar-only scores (and the empty case) stack into a regular array
    if all(numpy.ndim(score) == 0 for scores in all_scores for score in scores):
      return numpy.array(all_scores).T

    # Mixed scalars and matrices (e.g. a confusion matrix) cannot form a regular
    # array; store them in an object array with the same [metric, image] layout
    stacked = numpy.empty((len(scorers), len(all_scores)), dtype=object)
    for column, scores in enumerate(all_scores):
      for row, score in enumerate(scores):
        stacked[row, column] = score
    return stacked

  def _confusion_matrix(self, target, prediction, normalized=False):
    """
    Computes the confusion matrix between the target and the prediction

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction
    :param normalized: (Optional) Whether each row should be divided by its sum

    :returns : A 2x2 `numpy.ndarray` of the confusion matrix
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    # Binarizes both inputs so that any non-zero value counts as the positive class
    target = target.astype(bool).astype(int)
    prediction = prediction.astype(bool).astype(int)
    # The labels are given explicitly so that the matrix stays 2x2 even when
    # only one class is present
    cm = confusion_matrix(target, prediction, labels=[0, 1])
    if normalized:
      cm = cm / (cm.sum(axis=1)[:, numpy.newaxis] + 1e-12)
    return cm

  def _iou(self, target, prediction):
    """
    Computes the intersection over union

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : Intersection over union
    """
    target, prediction = target.astype(bool), prediction.astype(bool)
    intersection = numpy.logical_and(target, prediction).sum()
    # The small constant avoids a division by zero when both masks are empty
    union = numpy.logical_or(target, prediction).sum() + 1e-12
    return intersection / union

  def _dice(self, target, prediction):
    """
    Computes the dice similarity coefficient

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : Dice similarity coefficient
    """
    target, prediction = target.astype(bool), prediction.astype(bool)
    intersection = numpy.logical_and(target, prediction).sum()
    return 2 * intersection / (target.sum() + prediction.sum() + 1e-12)

  def _f1_score(self, target, prediction):
    """
    Computes the f1 score between the target and the prediction. It is
    computed with the dice similarity coefficient

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : F1-score
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    target, prediction = target.astype(bool), prediction.astype(bool)
    return self._dice(target, prediction)

  def _precision(self, target, prediction):
    """
    Computes the precision between the target and the prediction

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : Precision (0 when it is undefined)
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    target, prediction = target.astype(bool), prediction.astype(bool)
    return precision_score(target, prediction, zero_division=0)

  def _recall(self, target, prediction):
    """
    Computes the recall between the target and the prediction

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : Recall (0 when it is undefined)
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    target, prediction = target.astype(bool), prediction.astype(bool)
    return recall_score(target, prediction, zero_division=0)

  def _accuracy(self, target, prediction):
    """
    Computes the accuracy between the target and the prediction

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction

    :returns : accuracy
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    target, prediction = target.astype(bool), prediction.astype(bool)
    return (target == prediction).sum() / len(target)

  def _precision_recall_curve(self, target, prediction):
    """
    Computes the precision recall curve between the target and the prediction

    :param target: A `numpy.ndarray` with shape [H, W] of the target
    :param prediction: A `numpy.ndarray` with shape [H, W] of the prediction
                       scores (they are not binarized)

    :returns : A `numpy.ndarray` of precision
    :returns : A `numpy.ndarray` of recall
    :returns : A `numpy.ndarray` of thresholds
    """
    if target.ndim > 1:
      target, prediction = target.ravel(), prediction.ravel()
    target = target.astype(bool)
    return precision_recall_curve(target, prediction)

def sigmoid(x):
  """
  Applies the logistic function element-wise to the input array

  The input is first limited to the [-10, 10] interval, hence the output
  never reaches exactly 0 or 1.

  :param x: A `numpy.ndarray`

  :returns : A `numpy.ndarray` of the transformed array
  """
  # Bounding the input keeps the exponential below from overflowing
  bounded = numpy.clip(x, -10, 10)
  denominator = 1 + numpy.exp(-bounded)
  return 1 / denominator

def plot_learning_curves(model_path, title=None, stats_name="statsCkpt_490.pkl"):
  """
  Plots the learning curves of the loaded model

  The mean loss of the training and validation sets is drawn with a band of
  one standard deviation. A dashed vertical line marks the epoch at which the
  validation loss is minimal.

  :param model_path: The path of the loaded model
  :param title: (Optional) The title of the figure
  :param stats_name: (Optional) The name of the pickled statistics file that
                     is stored in `model_path`
  """
  # NOTE: unpickling can run arbitrary code; only load statistics from a trusted source
  with open(os.path.join(model_path, stats_name), "rb") as stats_file:
    stats = pickle.load(stats_file)

  fig, ax = pyplot.subplots()
  for condition in ["train", "test"]:
    mean, std = map(numpy.array, (stats[f"{condition}Mean"], stats[f"{condition}Std"]))
    epochs = numpy.arange(len(mean))
    # The statistics stored under "test" are displayed as the validation curve
    ax.plot(epochs, mean, label="validation" if condition == "test" else condition)
    if condition == "test":
      ax.axvline(x=numpy.argmin(mean), color="black", linestyle="dashed")
    ax.fill_between(epochs, mean - std, mean + std)
  ax.legend()
  ax.set(
      xlabel="Epochs", ylabel="Cross Entropy Loss",
      title=title
  )
  pyplot.show()

def show_random(*args, samples=5):
  """
  Randomly samples items from the given input arrays and shows them side by side

  One figure is created per sampled index, with one panel per input array.
  Single channel items are shown as grayscale images, multi-channel items
  are shown as the index of their highest channel.

  :param *args: `numpy.ndarray` with shape [B, C, H, W] (or [B, H, W])
  :param samples: The number of samples to sample from each arrays
  """
  num_items = len(args[0])
  chosen = numpy.random.choice(num_items, size=min(samples, num_items), replace=False)
  for sample in chosen:
    # squeeze=False always returns a 2D array of axes, even with a single input array
    fig, axes = pyplot.subplots(1, len(args), figsize=(8, 3), squeeze=False)
    for ax, ary in zip(axes.ravel(), args):
      # The item is selected before squeezing so that the batch axis is never
      # removed, even when the batch holds a single item
      item = numpy.squeeze(numpy.asarray(ary)[sample])
      if item.ndim == 2:
        ax.imshow(item, vmax=0.3 * item.max(), cmap="gray")
      else:
        ax.imshow(numpy.argmax(item, axis=0), cmap="gray")
      ax.axis("off")
  pyplot.show()
    
def plot_scores(data, show_points=True, avail_metrics=("iou",), ylim=(0,1),
                show_lines=False, rotation=None):
  """
  Plots the scores from the given data as grouped bars

  Each bar is the median of the scores and its error bar is the distance
  between the median and the third quartile.

  :param data: A `dict` of scores with the following architecture
               {condition : {metric : []}}
  :param show_points: (Optional) Whether to scatter the individual scores
  :param avail_metrics: (Optional) The names of the metrics to plot
  :param ylim: (Optional) The limits of the y axis
  :param show_lines: (Optional) Whether to link the scores of a same sample
                     across conditions
  :param rotation: (Optional) The rotation of the condition labels

  :returns : A `matplotlib.Figure` instance
  :returns : A `matplotlib.Axes` instance
  """
  # Creates the matplotlib figure and define constants
  fig, ax = pyplot.subplots()
  num_metrics = len(avail_metrics)
  width = 1 / (num_metrics + 1)
  # `pyplot.get_cmap` is used since `matplotlib.cm.get_cmap` was removed from recent releases
  cmap = pyplot.get_cmap("tab10")

  # Plots the data
  conditions, metrics = [], []
  for i, (condition, scores) in enumerate(data.items()):
    conditions.append(condition)
    j = 0
    for (metric, values) in scores.items():
      if metric not in avail_metrics:
        continue
      if i == 0:
        # The legend is built from the metrics of the first condition
        metrics.append(metric)
      upper_spread = numpy.diff(numpy.quantile(values, [0.5, 0.75]))
      ax.bar(i + j * width, numpy.median(values), yerr=upper_spread,
             align="edge", width=width, color=cmap(j), alpha=0.8)

      # Add the scatter points to the graph
      if show_points:
        xs = numpy.random.normal(loc=i + j * width + width / 2, scale=width/25, size=len(values))
        ax.scatter(xs, values, color="black", alpha=0.3, s=10)

      j += 1

  if show_lines:
    for metric in avail_metrics:
      if metric not in metrics:
        # A requested metric that was never plotted has no bar to link
        continue
      delta = metrics.index(metric)
      for points in zip(*[data[condition][metric] for condition in conditions]):
        ax.plot(numpy.arange(len(points)) + delta * width + width / 2, points,
                color="black", alpha=0.3)

  # Sets the axes
  ax.set(
      ylim = ylim,
      xticks = numpy.arange(len(data)) + width * num_metrics / 2,
      xticklabels = conditions,
      ylabel = "Scores"
  )
  if rotation is not None:
    ax.set_xticklabels(conditions, rotation=rotation)

  # Sets the legend of the graph
  ax.legend(handles=[patches.Patch(color=cmap(i), label=metric) for i, metric in enumerate(metrics)])

  return fig, ax

def plot_cumfreq(data, avail_metrics=("iou",)):
  """
  Plots the cumulative frequency of the scores from the given data

  Each metric has its own color and each condition its own line style. The
  line styles are reused when there are more conditions than styles.

  :param data: A `dict` of scores with the following architecture
               {condition : {metric : []}}
  :param avail_metrics: (Optional) The names of the metrics to plot

  :returns : A `matplotlib.Figure` instance
  :returns : A `matplotlib.Axes` instance
  """
  # Creates the matplotlib figure and define constants
  fig, ax = pyplot.subplots()
  # `pyplot.get_cmap` is used since `matplotlib.cm.get_cmap` was removed from recent releases
  cmap = pyplot.get_cmap("tab10")

  # Plots the data
  metrics = []
  linestyles = ["-", "--", "-."]
  for i, scores in enumerate(data.values()):
    # Cycles through the styles instead of failing past the third condition
    linestyle = linestyles[i % len(linestyles)]

    j = 0
    for (metric, values) in scores.items():
      if metric not in avail_metrics:
        continue
      if i == 0:
        # The legend is built from the metrics of the first condition
        metrics.append(metric)

      hist, bins = numpy.histogram(values)
      cumsum = numpy.cumsum(hist) / hist.sum()
      ax.plot(bins[:-1], cumsum, color=cmap(j), linestyle=linestyle)

      j += 1

  # Sets the axes
  ax.set(
      ylim = (0,1),
      ylabel = "Cumulative frequency"
  )

  # Sets the legend of the graph
  ax.legend(handles=[patches.Patch(color=cmap(i), label=metric) for i, metric in enumerate(metrics)])

  return fig, ax

def plot_cm(cm):
  """
  Plots a confusion matrix with the value of each cell written on top of it

  :param cm: A `numpy.ndarray` of the confusion matrix

  :returns : A `matplotlib.Figure` instance
  :returns : A `matplotlib.Axes` instance
  """
  fig, ax = pyplot.subplots()
  ax.imshow(cm, cmap="Blues")
  for row in range(cm.shape[0]):
    for col in range(cm.shape[1]):
      # `ax.text` expects (x, y), i.e. (column, row), to match the cell drawn by `imshow`
      ax.text(col, row, "{:0.4f}".format(cm[row, col]), horizontalalignment="center", verticalalignment="center")
  ax.set(
      xticks=numpy.arange(cm.shape[1]), xticklabels=["Background", "Structure"],
      yticks=numpy.arange(cm.shape[0]), yticklabels=["Background", "Structure"],
  )
  return fig, ax

def show_average_cm(data):
  """
  Plots, for every condition, the confusion matrix accumulated over all samples

  The confusion matrices of a condition are summed, then every column of the
  sum is divided by its total before being plotted.

  :param data: A `dict` of scores with the following architecture
               {condition : {metric : []}}

  returns: A `list` of (fig, ax) tuple, one per condition with a
           "confusion_matrix" entry
  """
  figures = []
  for scores in data.values():
    if "confusion_matrix" not in scores:
      continue
    summed = numpy.array(scores["confusion_matrix"]).sum(axis=0)
    # The small constant protects against empty columns
    column_totals = summed.sum(axis=0, keepdims=True) + 1e-12
    fig, ax = plot_cm(summed / column_totals)
    figures.append((fig, ax))
  return figures

def get_foreground(ary):
  """
  Detects the foreground of an image, or of every image of a stack

  The image is smoothed with a gaussian filter (sigma of 5) and thresholded
  with the triangle method. A 3D input is processed one 2D image at a time.

  :param ary: A 2D or 3D `numpy.ndarray` of the array(s) to detect the foreground

  :returns : A 2D or 3D `numpy.ndarray` of the detected foreground(s), with
             1 inside the foreground and 0 elsewhere
  """
  if ary.ndim != 3:
    smoothed = filters.gaussian(ary, sigma=5)
    mask = smoothed >= filters.threshold_triangle(smoothed)
    return mask.astype(int)
  # Stack of images: the foreground of each image is detected independently
  return numpy.array([get_foreground(plane) for plane in ary])


## Load network and infer on testing data

In [None]:
import loader 
from network import UNet

# Normalization bounds stored with the raw data; they are forwarded to `model.predict`
minmax = numpy.load(os.path.join("raw_data", "minmax.npy"))

# For every trained network: where its dataset and its checkpoints are stored
networks_infos = dict(
    (
        name,
        {
            "data_path" : os.path.join("raw_data", f"data_{name}.npz"),
            "model_path" : os.path.join("trained-networks", name)
        }
    )
    for name in ("polygonal_bbox", "bbox")
)

In [None]:
# Plots the learning curves of every available network, using the name of
# the network as the title of its figure
for name, network_infos in networks_infos.items():
  plot_learning_curves(network_infos["model_path"], title=name)

In [None]:
# Selects the desired network from the networks information
network_infos = networks_infos["polygonal_bbox"]

# Loads the data from the cloned folder
data = numpy.load(network_infos["data_path"])

# Loads the images and their labels from the data, then the indices of the
# training, validation and testing splits
images, targets = data["images"], data["labels"]
train_idx, valid_idx, test_idx = loader.get_idx(data)

# Creation of the model and loading of its trained parameters
# (one input channel and two output channels -- presumably background and structure)
model = UNet(in_channels=1, out_channels=2)
model.load_model(network_infos["model_path"], cuda=CUDA)

In [None]:
# Infers the network on the testing dataset and shows, for each batch, a few
# random images next to their label and their prediction
for (X, y, pred, idx) in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):
  show_random(X, y, pred, samples=5)
  pyplot.show()

## Quantitative assessment

In [None]:
# Names of the metrics that are computed on every image of the testing dataset
to_compute_metrics = ["dice", "iou", "confusion_matrix", "precision", "recall", "accuracy"]

# Scores are gathered per condition, then per metric; "bbox" is the only condition here
complete_scores = {"bbox" : {name : [] for name in to_compute_metrics}}

for X, y, pred, idx in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):
  # No foreground is given: every pixel of the image is scored
  foregrounds = None
  metric_calculator = MetricCalculator(y, pred, foregrounds=foregrounds)
  scores = metric_calculator.get(to_compute_metrics)

  # `scores` holds one sequence of per-image values per metric
  for metric, score in zip(to_compute_metrics, scores):
    complete_scores["bbox"][metric].extend(score)

In [None]:
# Choose which metrics to show (the confusion matrix is plotted on its own below)
keep_keys = ["iou", "dice", "precision", "recall", "accuracy"]

# Bar plot of the scores, their cumulative frequency, and the confusion matrix
# accumulated over the images of each condition
fig, ax = plot_scores(complete_scores, show_points=True, avail_metrics=keep_keys, ylim=(0, 1))
fig, ax = plot_cumfreq(complete_scores, avail_metrics=keep_keys)
figs_axes = show_average_cm(complete_scores)
pyplot.show()

In [None]:
# Impact of foreground detection on metrics: the scores are computed only on
# the pixels that belong to the detected foreground of each image
complete_scores["bbox + foreground"] = {
    name : [] for name in to_compute_metrics
}
for (X, y, pred, idx) in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):

  # Only the channel axis is dropped (X is assumed to be [B, 1, H, W] or [B, H, W]).
  # A plain squeeze would also drop the batch axis of a batch holding a single
  # image, which would misalign the foregrounds with the targets
  foregrounds = get_foreground(X.reshape(len(X), *X.shape[-2:]))
  metric_calculator = MetricCalculator(y, pred, foregrounds=foregrounds)
  scores = metric_calculator.get(to_compute_metrics)
  for metric, score in zip(to_compute_metrics, scores):
    complete_scores["bbox + foreground"][metric].extend(score)

In [None]:
# Choose which metrics to show (the confusion matrix is plotted on its own below)
keep_keys = ["iou", "dice", "precision", "recall", "accuracy"]

# Same figures as before; `complete_scores` now also holds the scores that
# were computed inside the detected foreground
fig, ax = plot_scores(complete_scores, show_points=True, avail_metrics=keep_keys, ylim=(0, 1))
fig, ax = plot_cumfreq(complete_scores, avail_metrics=keep_keys)
figs_axes = show_average_cm(complete_scores)
pyplot.show()

# Part 2 : Increased Performance & Metrics

## Compare prediction with manual labels

In [None]:
# Loads the network model and parameters
network_infos = networks_infos["polygonal_bbox"]
data = numpy.load(network_infos["data_path"])
targets_unprecised = data["labels"]
images, targets = data["images"], data["labels"]
train_idx, valid_idx, test_idx = loader.get_idx(data)
model = UNet(in_channels=1, out_channels=2)
model.load_model(network_infos["model_path"], cuda=CUDA)

# Predict all images in the testing dataset
to_compute_metrics = ["dice", "iou", "confusion_matrix", "precision", "recall", "accuracy"]
manual_comparison = {
    condition : {
        name : [] for name in to_compute_metrics
    } for condition in ["bbox - manual", "prediction - manual", "prediction - bbox"]
}

# Extracts the index of the images for which a precise (manual) label exists.
# The files are named "<index>_man.tif"
_man = glob.glob(os.path.join("testing", "*_man.tif"))
available_manuals = [int(os.path.basename(man_name).split("_")[0]) for man_name in _man]
for (X, y, pred, idx) in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):
  # Only the images of the batch that have a manual label are compared
  keep = [tidx in available_manuals for tidx in test_idx[idx]]
  if not any(keep):
    continue
  X, y, pred, idx = X[keep], y[keep], pred[keep], idx[keep]
  manual = numpy.stack([io.imread(os.path.join("testing", "{}_man.tif".format(test_idx[i]))) for i in idx], axis=0)

  # Shows some examples
  show_random(X, y, pred, manual, samples=2)
  pyplot.show()

  # (condition, reference, candidate). The labels get a channel axis when they
  # are used as the candidate, and are used as is when they are the reference.
  # They are never squeezed, since squeezing would also drop the batch axis
  # when a single image of the batch has a manual label
  comparisons = (
      ("bbox - manual", manual, y[:, numpy.newaxis, ...]),
      ("prediction - manual", manual, pred),
      ("prediction - bbox", y, pred),
  )
  for condition, reference, candidate in comparisons:
    metric_calculator = MetricCalculator(reference, candidate)
    scores = metric_calculator.get(to_compute_metrics)
    for metric, score in zip(to_compute_metrics, scores):
      manual_comparison[condition][metric].extend(score)


In [None]:
# Shows the intersection over union of each comparison with the manual labels
fig, ax = plot_scores(manual_comparison, avail_metrics=["iou"], show_lines=False)
pyplot.show()

## Crafting a representative metric

The crafted metric for the F-actin dataset on axons uses the spatial frequency information in the extracted masks of an expert and the prediction. 

The F-actin periodical lattice has a known distance of 190 nm between adjacent rings. This implies that a precise segmentation of the structure, for instance a manual segmentation, should contain increased power at the corresponding spatial frequency.

The idea of the metric is that the frequency power of a manual segmentation and that of a precise predicted segmentation should be similar, so the difference between them should be small. A network segmentation that differs greatly from the manual one should result in a larger difference.

<div>
<img src="https://drive.google.com/uc?id=1oGHSapjPGtC3bMfJVdYu_jZxevR1zZdy" width="500"/>
</div>


In [None]:
def fft_ratio(image, truth, predicted, foreground):
  """
  Computes the Fourier ratio between the pixels in the ground truth mask
  and the pixels in the predicted mask

  The spectrum of the image is compared with the spectrum of the image
  restricted to each mask. Only the frequencies inside a ring of wavelengths
  between 170 and 200 (in the units of the pixel size used by `fft`) are
  summed, over sectors of 10 degrees.

  :param image: A 2D `numpy.ndarray` of the image
  :param truth: A 2D `numpy.ndarray` of the ground truth mask
  :param predicted: A 2D `numpy.ndarray` of the predicted mask
  :param foreground: A 2D `numpy.ndarray` of the detected foreground

  :returns : The absolute difference of the variation. 0 is returned when
             both masks are empty inside the foreground or when the spectrum
             of the image is null
  """
  truth, predicted, foreground = (ary.astype(bool) for ary in (truth, predicted, foreground))
  if (not numpy.any(truth * foreground)) and (not numpy.any(predicted * foreground)):
    return 0
  masked_truth = image * truth * foreground
  masked_pred = image * predicted * foreground

  everywhere = numpy.ones(image.shape)
  fft_orig, freq = fft(image, everywhere)
  fft_truth, _ = fft(masked_truth.reshape(image.shape), everywhere)
  fft_pred, _ = fft(masked_pred.reshape(image.shape), everywhere)

  total_power = fft_orig.sum()
  if total_power == 0:
    # A null image has a null spectrum: both ratios are equal
    return 0

  # Creates a mesh grid of frequencies that has the shape of the spectrum,
  # including for non-square images
  xx, yy = numpy.meshgrid(*freq, indexing="ij")
  # The division is protected with a copy so that the null frequencies are
  # left untouched for the computation of the radius
  xx_safe = numpy.where(xx == 0., 1., xx)
  atan = numpy.arctan(yy / xx_safe) * 180 / numpy.pi + 90
  atan[xx >= 0] += 180
  radius_sq = xx**2 + yy**2

  forig, ftruth, fpred = [], [], []
  angles, wavelengths = numpy.arange(0, 360, 10), numpy.arange(170, 200, 10)
  for angle in angles:
    in_sector = (atan >= angle) & (atan <= angle + 10)
    for wavelength in wavelengths:
      # Ring of frequencies between 1 / (wavelength + 10) and 1 / wavelength
      z = (radius_sq <= (1 / wavelength)**2) & (radius_sq >= (1 / (wavelength + 10))**2) & in_sector
      forig.append(fft_orig[z].sum())
      ftruth.append(fft_truth[z].sum())
      fpred.append(fft_pred[z].sum())

  forig, ftruth, fpred = (numpy.array(l).reshape(len(angles), len(wavelengths)) for l in (forig, ftruth, fpred))

  ratio_truth = (ftruth.sum() - forig.sum()) / total_power
  ratio_pred = (fpred.sum() - forig.sum()) / total_power
  return abs(ratio_truth - ratio_pred)

def fft(ary, foreground, timestamp=20):
  """
  Computes the magnitude of the 2D FFT of an image

  Both the spectrum and the frequency axes are shifted so that the null
  frequency is at the center: `spectrum[i, j]` is the magnitude at the
  frequency `(freq[0][i], freq[1][j])`.

  :param ary: A 2D `numpy.ndarray`
  :param foreground: A 2D `numpy.ndarray` of the detected foreground, which
                     multiplies the image before the transform
  :param timestamp: (Optional) The sample spacing given to `numpy.fft.fftfreq`

  :returns : The magnitude of the Fourier transform and a `tuple` of the
             frequency axes (one per dimension of `ary`)
  """
  spectrum = numpy.abs(numpy.fft.fft2(ary * foreground))
  # The frequency axes are shifted, so the spectrum has to be shifted as well
  spectrum = numpy.fft.fftshift(spectrum)
  freqs = tuple(numpy.fft.fftshift(numpy.fft.fftfreq(size, d=timestamp)) for size in ary.shape)
  return spectrum, freqs


In [None]:
# Loads the network architecture in memory
network_infos = networks_infos["polygonal_bbox"]
data = numpy.load(network_infos["data_path"])
targets_unprecised = data["labels"]
images, targets = data["images"], data["labels"]
train_idx, valid_idx, test_idx = loader.get_idx(data)
model = UNet(in_channels=1, out_channels=2)
model.load_model(network_infos["model_path"], cuda=CUDA)

# Predict all images in the testing dataset
to_compute_metrics = ["fft"]
manual_comparison = {
    condition : {
        name : [] for name in to_compute_metrics
    } for condition in ["prediction - manual", "prediction - bbox"]
}

# Index of the images for which a manual label ("<index>_man.tif") exists
_man = glob.glob(os.path.join("testing", "*_man.tif"))
available_manuals = [int(os.path.basename(man_name).split("_")[0]) for man_name in _man]
for (X, y, pred, idx) in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):
  keep = [tidx in available_manuals for tidx in test_idx[idx]]
  if any(keep):
    X, y, pred, idx = X[keep], y[keep], pred[keep], idx[keep]
    manual = numpy.stack([io.imread(os.path.join("testing", "{}_man.tif".format(test_idx[i]))) for i in idx], axis=0)

    # Only the channel axis is dropped (X is assumed to be [B, 1, H, W] or [B, H, W]).
    # A plain squeeze would also drop the batch axis when a single image is
    # kept, and the loops below would then iterate over the rows of the foreground
    foregrounds = get_foreground(X.reshape(len(X), *X.shape[-2:]))

    show_random(X, y, pred, manual, samples=2)
    pyplot.show()

    y = y[:, numpy.newaxis, ...]

    # Prediction against the labels used for training
    for xx, yy, ppred, fforeground in zip(X, y, pred, foregrounds):
      ratio = fft_ratio(xx.squeeze(), yy.squeeze(), numpy.argmax(ppred, axis=0), fforeground)
      manual_comparison["prediction - bbox"]["fft"].append(ratio)

    # Prediction against the manual labels
    for xx, yy, ppred, fforeground in zip(X, manual, pred, foregrounds):
      ratio = fft_ratio(xx.squeeze(), yy.squeeze(), numpy.argmax(ppred, axis=0), fforeground)
      manual_comparison["prediction - manual"]["fft"].append(ratio)


In [None]:
# Shows the crafted metric; the limits of the y axis are left to matplotlib
fig, ax = plot_scores(manual_comparison, avail_metrics=["fft"], ylim=None)
pyplot.show()

## Provide a different threshold



In [None]:
# Selects the desired network from the networks information; this time the
# network that is stored under the "bbox" name
network_infos = networks_infos["bbox"]

# Loads the data from the cloned folder
data = numpy.load(network_infos["data_path"])

# Loads the images and their labels from the data, then the indices of the
# training, validation and testing splits
images, targets = data["images"], data["labels"]
train_idx, valid_idx, test_idx = loader.get_idx(data)

# Creation of the model and loading of its trained parameters
model = UNet(in_channels=1, out_channels=2)
model.load_model(network_infos["model_path"], cuda=CUDA)

In [None]:
# Infer the network on the validation dataset and compute the precision-recall curves
to_compute_metrics = ["precision_recall_curve"]

# Load precise labels from dataset using polygonal_bbox
precise_targets = numpy.load(networks_infos["polygonal_bbox"]["data_path"])["labels"]

# Predict all images in the validation dataset
all_scores = {}
for (X, y, pred, idx) in model.predict(images, targets, idx=valid_idx, cuda=CUDA, minmax=minmax):

  precise = precise_targets[valid_idx[idx]]

  # Only the channel axis is dropped (X is assumed to be [B, 1, H, W] or [B, H, W]).
  # A plain squeeze would also drop the batch axis of a batch holding a single
  # image, which would misalign the foregrounds with the targets
  foregrounds = get_foreground(X.reshape(len(X), *X.shape[-2:]))
  # Score of the second output channel, kept with a channel axis of size 1
  pred_probas = sigmoid(pred)[:, 1]
  pred_probas = pred_probas[:, numpy.newaxis]
  metric_calculator = MetricCalculator(precise, pred_probas, foregrounds=foregrounds)
  scores = metric_calculator.get(to_compute_metrics)

  # NOTE(review): `get` skips the images without any positive target, hence
  # `scores` may be shorter than `idx` and the pairing below may then be
  # shifted -- confirm that this is acceptable
  for i, score in zip(idx, scores):
    all_scores[i] = score


In [None]:
def get_threshold(scores, num_valid_images=10, show_dist=False):
  """
  Computes the optimal threshold from the validation images. We simply use the
  point from the PR-Curve that is the nearest to a precision of 1 and a recall
  of 1.

  :param scores: A `dict` where each key is referencing an index in the valid_idx
                 and each value is a `list` of (precision, recall, thresholds)
  :param num_valid_images: (Optional) Only the images with a key smaller than
                           this number are kept. None results in all images
  :param show_dist: (Optional) Whether to show the distribution of points

  :returns : A `list` of all optimal thresholds
  """
  all_thresholds = []

  for key, image_scores in scores.items():
    # None keeps every image, whatever its key
    if (num_valid_images is not None) and (not key < num_valid_images):
      continue
    for precision, recall, thresholds in image_scores:
      # A PR-curve may end with a point that has no associated threshold;
      # only the points that have a threshold are candidates
      pr = numpy.stack([precision, recall]).T[:len(thresholds)]
      if len(pr) == 0:
        continue
      distances = distance.cdist(pr, [[1, 1]])
      all_thresholds.append(thresholds[distances.argmin()])

  if show_dist:
    fig, ax = pyplot.subplots()
    ax.boxplot([all_thresholds])
    # The thresholds are scattered around the box with a small horizontal jitter
    ax.scatter(numpy.random.normal(loc=1, scale=0.02, size=len(all_thresholds)),
               all_thresholds)
    ax.set(
        ylim=(0, 1), ylabel="Threshold"
    )
    pyplot.show()

  return all_thresholds

# The median of the per-image optimal thresholds is used as the single
# threshold applied to the predictions in the next cells
thresholds = get_threshold(all_scores, num_valid_images=25, show_dist=True)
threshold = numpy.median(thresholds)
print(threshold)

In [None]:
# Load precise labels from dataset
precise_targets = numpy.load(networks_infos["polygonal_bbox"]["data_path"])["labels"]

# Predict all images in the testing dataset
to_compute_metrics = ["dice", "iou", "confusion_matrix", "precision", "recall", "accuracy"]
naive_threshold_comparison = {
    condition : {
        name : [] for name in to_compute_metrics
    } for condition in ["prediction - polygonal bbox", "prediction - bbox", "threshold prediction - polygonal bbox", "threshold prediction - bbox"]
}

for (X, y, pred, idx) in model.predict(images, targets, idx=test_idx, cuda=CUDA, minmax=minmax):

  precise = precise_targets[test_idx[idx]]

  # Only the channel axis is dropped (X is assumed to be [B, 1, H, W] or [B, H, W]).
  # A plain squeeze would also drop the batch axis of a batch holding a single
  # image, which would misalign the foregrounds with the targets
  foregrounds = get_foreground(X.reshape(len(X), *X.shape[-2:]))
  # Binary prediction obtained with the threshold selected on the validation images
  pred_probas = sigmoid(pred)[:, 1]
  pred_probas = pred_probas[:, numpy.newaxis] >= threshold

  show_random(X, y, pred, pred_probas, samples=5)

  # (condition, reference, candidate), in the order in which they are scored
  comparisons = (
      ("threshold prediction - polygonal bbox", precise, pred_probas),
      ("prediction - polygonal bbox", precise, pred),
      ("prediction - bbox", y, pred),
      ("threshold prediction - bbox", y, pred_probas),
  )
  for condition, reference, candidate in comparisons:
    metric_calculator = MetricCalculator(reference, candidate, foregrounds=foregrounds)
    scores = metric_calculator.get(to_compute_metrics)
    for metric, score in zip(to_compute_metrics, scores):
      naive_threshold_comparison[condition][metric].extend(score)

In [None]:
# Shows the precision of each comparison; the condition labels are rotated
# since they are long
fig, ax = plot_scores(naive_threshold_comparison, avail_metrics=["precision"], rotation=15)

## Intersection over union problem

In [None]:
def _filled_circle(row, col, radius, shape=(20, 20)):
  """
  Creates a boolean mask that contains a plain circle

  :param row: The row of the center of the circle
  :param col: The column of the center of the circle
  :param radius: The radius of the circle
  :param shape: (Optional) The shape of the mask

  :returns : A boolean `numpy.ndarray` that is True inside the circle
  """
  # `numpy.bool` was removed from recent numpy releases; the builtin is equivalent
  mask = numpy.zeros(shape, dtype=bool)
  if hasattr(draw, "disk"):
    # `draw.circle` was replaced by `draw.disk` in recent scikit-image releases
    rr, cc = draw.disk((row, col), radius, shape=shape)
  else:
    rr, cc = draw.circle(row, col, radius, shape=shape)
  mask[rr, cc] = True
  return mask

# Creation of the figure
fig, (ax1, ax2, ax3, ax4) = pyplot.subplots(1, 4)

# A plain circle centered at (10, 10) with radii of 5 is our ground truth
ground_truth = _filled_circle(10, 10, 5)
ax1.imshow(ground_truth, cmap="gray")
ax1.set_title("Ground Truth")

# Set a prediction centered at (10, 10) with radii of 4
pred_smaller = _filled_circle(10, 10, 4)
ax2.imshow(pred_smaller, cmap="gray")
iou = MetricCalculator([None], [None])._iou(ground_truth, pred_smaller)
ax2.set_title(f"IOU of {iou:0.2f}")

# Set a prediction centered at (10, 10) with radii of 6
pred_larger = _filled_circle(10, 10, 6)
ax3.imshow(pred_larger, cmap="gray")
iou = MetricCalculator([None], [None])._iou(ground_truth, pred_larger)
ax3.set_title(f"IOU of {iou:0.2f}")

# Set a prediction centered one pixel away from ground truth
pred_offset = _filled_circle(10, 11, 5)
iou = MetricCalculator([None], [None])._iou(ground_truth, pred_offset)
ax4.set_title(f"IOU of {iou:0.2f}")
ax4.imshow(pred_offset, cmap="gray")

pyplot.show()