## Evaluation of the VNC length method

This notebook compares the calculated VNC length values with manually measured (annotated) data. It therefore requires annotated data, which should be placed in a directory named `annotated` inside the experiment directory.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
from pathlib import Path

import matplotlib.pyplot as plt
from tifffile import imread

from snazzy_processing import centerline_errors, find_hatching, grid_search, utils, vnc_length

# Experiment whose calculated VNC lengths are compared with annotations.
experiment_name = '20240611'
root_dir = Path.cwd().parent
project_dir = root_dir.joinpath('data', experiment_name)
annotated_dir = project_dir.joinpath('annotated')
img_dir = project_dir.joinpath('embs')

# Every evaluation below compares against manual annotations, so stop early
# when they are missing. An explicit raise (instead of `assert False`) still
# fires when Python runs with -O, and the explanation ends up in the traceback.
if not annotated_dir.exists():
    raise FileNotFoundError(
        'Cannot evaluate the measurements for the current experiment.\n'
        'The evaluations require manually measured data, expected under '
        f'{annotated_dir}.'
    )

Sometimes the embryo names in the annotated data don't match the names generated by pasnascope.
In these cases, a lookup table (LUT) can be used to associate pasnascope sliced movies with annotated files.
The LUT can also be used to ignore embryos if needed, since only the embryos in the LUT will be used.
The lookup table is a dictionary where the keys are the embryo numbers of the individual movies and the values are the numbers used to identify the embryos in the annotated data.

Use the image generated from the `process-raw-data` notebook to inspect the numbers used by `pasnascope` and build the LUT.

In [None]:
# Lookup table for experiment 20240611: {img_file: annotated_file}.
# It translates the embryo number of each sliced movie into the number used by
# the annotated files. It is only needed when the two numberings differ, and it
# can be removed once every embryo name matches.
name_LUT = {
    # Embryos numbered identically in both datasets.
    **{number: number for number in (1, 2, 3, 4, 5, 6, 9, 10)},
    # Embryos that carry a different number in the annotated data.
    12: 11,
    15: 12,
    18: 15,
    19: 16,
    20: 17,
    21: 18,
    22: 19,
    23: 20,
}

Calculates the VNC lengths using pasnascope.
These values will be compared against annotated data in the next cells.

In [None]:
# Containers for the per-embryo results; they are reassigned at the end of this cell.
measurements, hatching_points, annotated = {}, {}, {}
# Frame step used to subsample the movies. get_annotated_data also reads it
# (as a global) to compute the annotated-CSV cutoff.
interval = 50


def get_hatching_points(embryos):
    """Return a dict mapping each embryo file's stem to its hatching frame.

    Each path in ``embryos`` is passed to ``find_hatching.find_hatching_point``.
    The returned value is later used as the exclusive upper bound of the frame
    range read from the movie.
    """
    points = {}
    for embryo_path in embryos:
        points[embryo_path.stem] = find_hatching.find_hatching_point(embryo_path)
    return points


def measure(
    embryos, hatching_points, interval=20, thres_rel=0.6, min_dist=5, outlier_thres=0.09
):
    """Estimate the VNC length over time for each embryo movie.

    For every embryo, only the frames before its hatching point are read,
    sampled every ``interval`` frames, and passed to
    ``vnc_length.measure_VNC_centerline``.

    Parameters
    ----------
    embryos : iterable of Path
        Movie files. Each file's stem is used as the result key.
    hatching_points : dict
        Maps embryo stem -> hatching frame index (e.g. from
        ``get_hatching_points``).
    interval : int
        Step between sampled frames.
    thres_rel, min_dist, outlier_thres
        Forwarded unchanged to ``vnc_length.measure_VNC_centerline``.

    Returns
    -------
    dict
        Maps embryo stem -> the value returned by ``measure_VNC_centerline``.
    """
    measurements = {}
    # Iterate ``embryos`` exactly once. Pre-building the keys in a separate
    # pass would exhaust a one-shot iterable (e.g. a generator or Path.glob),
    # and the function would then silently return only empty placeholder lists.
    for emb in embryos:
        print(emb.name)
        hp = hatching_points[emb.stem]
        img = imread(emb, key=range(0, hp, interval))
        measurements[emb.stem] = vnc_length.measure_VNC_centerline(
            img, thres_rel=thres_rel, min_dist=min_dist, outlier_thres=outlier_thres
        )
    return measurements


def get_annotated_data(embryos):
    """Load the manually annotated VNC lengths for each embryo.

    This function reads the notebook-level globals ``hatching_points``,
    ``interval``, ``mapping`` and ``annotated_dir``, so they must be set
    before it is called. The annotated file for an embryo is
    ``annotated_dir / mapping[emb.name]``. Column 1 is read up to row
    ``hatching_points[stem] // interval``; this presumably assumes the
    annotations are sampled every ``interval`` frames (TODO confirm).

    Returns
    -------
    dict
        Maps embryo stem -> the value returned by
        ``vnc_length.get_length_from_csv``.
    """
    annotated = {}
    # Single pass, so one-shot iterables (generators) are handled correctly.
    for emb in embryos:
        csv_end = hatching_points[emb.stem] // interval
        ann_path = annotated_dir.joinpath(mapping[emb.name])
        annotated[emb.stem] = vnc_length.get_length_from_csv(
            ann_path, columns=[1], end=csv_end
        )
    return annotated


# Pair every sliced movie with its annotation file; name_LUT restricts and
# renames the pairs.
embryo_files = sorted(img_dir.glob("*ch2.tif"), key=utils.emb_number)
annotated_files = sorted(annotated_dir.glob("*ch2.csv"), key=utils.emb_number)
mapping = centerline_errors.get_matching_embryos(embryo_files, annotated_files, name_LUT)

# The keys of `mapping` are presumably movie file names, since
# get_annotated_data looks them up via `emb.name`.
embryos = [img_dir / movie_name for movie_name in mapping]

# hatching_points must be assigned before get_annotated_data, which reads it
# as a global.
hatching_points = get_hatching_points(embryos)
annotated = get_annotated_data(embryos)
measurements = measure(embryos, hatching_points, interval=interval, thres_rel=0.2, min_dist=7)

Compare the centerline estimation against annotated data for `n` embryos, starting at index `start`.

In [None]:
start = 0
n = 3

embryo_names = list(measurements)[start : start + n]

# squeeze=False always returns a 2-D array of Axes. Without it,
# plt.subplots(1) returns a bare Axes, which has no .ravel(), and the cell
# crashes whenever a single embryo is selected.
fig, ax = plt.subplots(len(embryo_names), squeeze=False)
ax = ax.ravel()
fig.canvas.header_visible = False
fig.canvas.resizable = False
fig.suptitle("Centerline estimation")
for i, emb_name in enumerate(embryo_names):
    # Compare only the frames present in both series.
    n_points = min(len(measurements[emb_name]), len(annotated[emb_name]))
    x = list(range(0, n_points * interval, interval))
    ax[i].plot(x, measurements[emb_name][:n_points], color="r", label="calculated")
    ax[i].plot(x, annotated[emb_name][:n_points], color="g", label="annotated")
    ax[i].set_title(emb_name)
ax[0].legend()
plt.tight_layout()

Centerline estimation for a single embryo.
Parameters that work well for most embryos often don't perform as well for a specific embryo.
This happens because the accuracy of centerline estimation (CLE) depends on the embryo's configuration.
The next cell is used to optimize parameters for a single embryo.

In [None]:
# Index (into `measurements`) of the embryo whose parameters are being tuned.
i = 13
embryo = list(measurements)[i]
embryo_file = next(e for e in embryos if e.stem == embryo)

# Re-run the estimation on this embryo alone with the candidate parameters.
calc_measurements = measure(
    [embryo_file],
    hatching_points,
    thres_rel=0.5,
    min_dist=5,
    interval=interval,
    outlier_thres=0.09,
)
calc = calc_measurements[embryo]

fig, ax = plt.subplots()
fig.canvas.header_visible = False
fig.canvas.resizable = False
fig.suptitle(f"VNC length estimation - {embryo}")

# Plot only the frames present in both series.
n_points = min(len(calc), len(annotated[embryo]))
x = list(range(0, n_points * interval, interval))
ax.plot(x, calc[:n_points], color="r", label="calculated")
ax.plot(x, annotated[embryo][:n_points], color="g", label="annotated")
ax.legend()
ax.set_title(embryo)

Visualize how much the centerline estimation varies between runs, since RANSAC is inherently non-deterministic.

In [None]:
i = 6
num_replicates = 3
# Use a dedicated step so that running this cell does not silently change the
# notebook-wide `interval`, which the comparison cells and get_annotated_data use.
replicate_interval = 20

embryo = embryo_files[i]
# `hatching_points` only covers the embryos matched in `mapping`. Any other
# movie in `embryo_files` would raise KeyError, so its hatching point is
# computed on demand instead.
hp = hatching_points.get(embryo.stem)
if hp is None:
    hp = find_hatching.find_hatching_point(embryo)

img = imread(embryo, key=range(0, hp, replicate_interval))

# Same input and parameters every time; any spread comes from RANSAC.
replicates = [
    vnc_length.measure_VNC_centerline(img, min_dist=5, thres_rel=0.7)
    for _ in range(num_replicates)
]

fig, ax = plt.subplots()
fig.canvas.header_visible = False
fig.canvas.resizable = False
fig.suptitle(f"VNC length estimation for {embryo.stem} (n={num_replicates})")

x = list(range(0, len(replicates[0]) * replicate_interval, replicate_interval))
for replicate in replicates:
    ax.plot(x, replicate)

plt.tight_layout()

Tests all combinations of the passed parameters to find the best performance for a given experiment.

In [None]:
interval = 50
results_dir = project_dir.parents[1].joinpath("results", experiment_name)
# parents=True: the top-level `results` directory may not exist yet. Without
# it, mkdir raises FileNotFoundError before the grid search can run.
results_dir.mkdir(parents=True, exist_ok=True)
results_path = results_dir.joinpath("centerline_errors_emb")
exp_dir = project_dir

# Every combination of these values is evaluated.
thres_rels = [0, 0.2, 0.4, 0.6]
min_dists = [1, 3, 5, 7, 9]
num_samples = 5

# Parse previously saved results when available. FileNotFoundError is
# presumably raised when no output exists yet; in that case, run the search,
# write it to `results_path`, and parse the fresh output.
try:
    grid_search.parse_grid_search_output(results_path)
except FileNotFoundError:
    grid = grid_search.search(
        thres_rels,
        min_dists,
        embryos,
        annotated,
        hatching_points,
        interval,
        num_samples,
    )
    grid_search.write_grid(
        grid, results_path, thres_rels, "thres_rel", min_dists, "min_dist"
    )
    grid_search.parse_grid_search_output(results_path)

Calculates the average error for all embryos of an experiment that have annotated data.
The error is defined as the average of the absolute relative error.

In [None]:
errors = centerline_errors.evaluate_CLE_global(measurements, annotated)

# One point per embryo: the first element of each embryo's error entry
# (presumably the mean absolute relative error; verify in centerline_errors).
y = [err[0] for err in errors.values()]
x = range(len(y))

fig, ax = plt.subplots()
fig.canvas.header_visible = False
fig.suptitle(f"Error (compared to annotated data) for exp {project_dir.stem}")
ax.plot(x, y, "b.")
ax.set_xticks([])
ax.set_ylabel("Abs relative error")

plt.tight_layout()

In [None]:
# Point-wise relative error between calculated and annotated lengths, per embryo.
errors_by_frame = {
    emb.stem: centerline_errors.point_wise_err(
        measurements[emb.stem], annotated[emb.stem]
    )
    for emb in embryos
}

# Drop the last 4 characters of each stem (presumably a "-ch2" channel suffix).
x_labels = [emb_name[:-4] for emb_name in errors_by_frame]
x = range(len(x_labels))

fig, ax = plt.subplots()
for i, errs in enumerate(errors_by_frame.values()):
    ax.scatter([i] * len(errs), errs, s=12, linewidths=0)
    # An embryo without any per-frame errors has no mean. Skip its marker
    # instead of dividing by zero. len() also works for array results.
    if len(errs) > 0:
        avg_err = sum(errs) / len(errs)
        ax.plot(i, avg_err, marker="_", color="k", markersize=5)
ax.set_xticks(x, x_labels, rotation="vertical")
ax.set_ylabel("Relative error")
fig.canvas.header_visible = False
fig.suptitle(f"Error (compared to annotated data) for exp {project_dir.stem}")

ax.xaxis.grid(color="0.9")
ax.set_axisbelow(True)
plt.tight_layout()