# Figure 5 - results on simulated data

This notebook runs the axon tracking algorithm on the simulated data and benchmarks its performance against the ground-truth branches and velocities of the simulated cells.

To run this notebook, first run `axon_velocity/simulations/simulation_notebooks/simulate_cell_1-2-3-4.ipynb` to generate the simulated data for cells 1–4; this notebook reads it from `../simulations/simulated_data/allen`.

In [None]:
import numpy as np
import matplotlib.pylab as plt
import numpy as np
from pathlib import Path
from pprint import pprint
import neuroplotlib as nplt
import pandas as pd

from axon_velocity import *
from axon_velocity.models import *
from axon_velocity.evaluation import *

%matplotlib widget

### Define algorithm parameters

In [None]:
# Start from the library defaults and override only the settings tuned for the simulated data.
params = get_default_graph_velocity_params()
params.update(
    detect_threshold=0.01,
    kurt_threshold=0.5,
    peak_std_threshold=None,
    init_delay=0.1,
    upsample=5,
    neighbor_radius=50,
    r2_threshold=0.9,
    max_distance_for_edge=100,
    max_distance_to_init=300,
    max_peak_latency_for_splitting=0.5,
)

pprint(params)

In [None]:
plot_tracking_figures = False

In [None]:
morphology_folder = Path('..') / 'simulations' / 'neuromorpho' / 'allen_cell_types'
data_folder = Path('..') / 'simulations' / 'simulated_data' / 'allen' 

In [None]:
zspan = 0

In [None]:
morphology_files_dict = {"cell1": morphology_folder / 'H16-06-008-01-20-04_561096006_m.CNG.swc', 
                         "cell2": morphology_folder / 'H16-06-004-01-04-01_538906745_m.CNG.swc',
                         "cell3": morphology_folder / 'H16-03-006-01-04-03_563818992_m.CNG.swc', 
                         "cell4": morphology_folder / 'H17-06-006-11-08-02_606834771_m.CNG.swc'}
cell_folders_dict = {"cell1": data_folder / f'allen0_planar_{zspan}um' , 
                     "cell2": data_folder / f'allen1_planar_{zspan}um' ,
                     "cell3": data_folder / f'allen2_planar_{zspan}um' , 
                     "cell4": data_folder / f'allen3_planar_{zspan}um' }

## Cell 1

In [None]:
cell = "cell1"

In [None]:
cell_folder = cell_folders_dict[cell]
morphology_file_1 = morphology_files_dict[cell]

In [None]:
cell_path = [p for p in cell_folder.iterdir() if p.suffix == '.pkl'][0]
locs_path = [p for p in cell_folder.iterdir() if 'locations' in p.name][0]
template_path = [p for p in cell_folder.iterdir() if 'template' in p.name][0]

In [None]:
cell_1, sections_1 = load_cell(cell_path)

In [None]:
locations_1 = np.load(locs_path)
template_1 = np.load(template_path)

In [None]:
fs = 1 / cell_1.dt * 1000

In [None]:
gtr1 = GraphAxonTracking(template_1, locations_1, fs, verbose=True, **params)

In [None]:
# Tracking pipeline for Cell 1, one stage per cell, executed in this order:
# select_channels -> build_graph -> find_paths -> clean_paths.
gtr1.select_channels()

In [None]:
# Stage 2 of 4.
gtr1.build_graph()

In [None]:
# Stage 3 of 4.
gtr1.find_paths()

In [None]:
# Stage 4 of 4; `gtr1.branches` is what the evaluation further down compares to the ground truth.
gtr1.clean_paths()

In [None]:
# fpaths_raw, axpaths_raw = plt.subplots(figsize=(7, 10))
# axpaths_raw = gtr1.plot_raw_branches(cmap="tab20", plot_bp=True, plot_neighbors=True, plot_full_template=True,
#                                     ax=axpaths_raw)
# axpaths_raw.legend(fontsize=12)

In [None]:
# plot_axon_summary(gtr1)

In [None]:
# ani = play_template_map(template_1, locations_1, skip_frames=5, log=False)
# ani.save("cell_1_log.gif", writer='imagemagick', fps=10)

In [None]:
if plot_tracking_figures:
    fchans1 = gtr1.plot_channel_selection()
    fgraph1 = gtr1.plot_graph()
    fbranch1 = gtr1.plot_branches()
    fvel1 = gtr1.plot_velocities()

In [None]:
branch_gt1 = extract_ground_truth_velocity(cell_1, sections_1, min_length=50, min_segs=5)
for i, br in enumerate(branch_gt1): 
    print(f"GT branch {i}: velocity {br['velocity']} length: {len(br['idxs'])}")

In [None]:
ev1 = evaluate_tracking_accuracy(gtr1.branches, branch_gt1, cell_1, locations_1)
print(f'Number of matched branches: {len(ev1)}')

In [None]:
cmap_branches = "tab20"
cmap_footprint = "Greys"
alpha_footprint = 0.5
alpha_marker = 0.7
legend_fs = 18

In [None]:
fig1, ax = plt.subplots()
fig1.set_size_inches((10, 10))

morphology_file = morphology_file_1
cell_model = cell_1
evaluation = ev1
locations = locations_1

nplt.plot_neuron(morphology=str(morphology_file), plane='xy', alpha=0.1, ax=ax, position=cell_model.somapos,
                 exclude_sections=['axon'])
nplt.plot_neuron(morphology=str(morphology_file), plane='xy', alpha=0.1, ax=ax, position=cell_model.somapos,
                 exclude_sections=['soma', 'apic', 'basal'], color='g')
plot_amplitude_map(template_1, locations_1, log=True, alpha=alpha_footprint, ax=ax, cmap=cmap_footprint)

cm = plt.get_cmap(cmap_branches)

for i, ev in enumerate(evaluation):  
    ax_idxs_list = ev['axon_idxs']
    channels = ev['channels']
    locs = locations[channels]
    color = cm(i / len(evaluation))
    
    for i_idx, ax_idxs in enumerate(ax_idxs_list):
        if i_idx == 0:
            ax.plot(np.mean(cell_model.x[ax_idxs], 1), 
                    np.mean(cell_model.y[ax_idxs], 1), ls='-', lw=3, color=color, 
                    alpha=1, label=f"branch {i}")
        else:
            ax.plot(np.mean(cell_model.x[ax_idxs], 1), np.mean(cell_model.y[ax_idxs], 1), 
                    ls='-', lw=3, color=color, alpha=1)
    ax.plot(locs[:, 0], locs[:, 1], ls='', marker='.', markeredgecolor="k", 
            color=color, markersize=10, alpha=alpha_marker)

ax.legend(ncol=4, fontsize=legend_fs, loc=9)
ax.axis('equal')
ax.axis('off')

In [None]:
# Per-branch summary table for Cell 1: velocities rounded to integers, relative error in %,
# tracking error as "mean $\pm$ std".
table_columns = ["model ID", "branch ID", "velocity GT", "velocity est.",
                 "abs. vel. error", "rel. vel. error", "tracking error"]
rows = []
for branch_id, ev in enumerate(ev1):
    gt_velocity = ev['velocity_ground_truth']
    est_velocity = ev['velocity_estimated']
    gtv = int(np.round(gt_velocity))
    estv = int(np.round(est_velocity))
    rows.append({
        "model ID": "Cell 1" if branch_id == 0 else "",  # name the model on its first row only
        "branch ID": branch_id,
        "velocity GT": gtv,
        "velocity est.": estv,
        # absolute error uses the rounded velocities, relative error the unrounded ones
        "abs. vel. error": np.abs(gtv - estv),
        "rel. vel. error": np.round(np.abs(gt_velocity - est_velocity) / gt_velocity * 100, 1),
        # raw f-string: in a plain literal "\p" is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12)
        "tracking error": rf"{np.round(ev['mean_error'], 1)} $\pm$ {np.round(ev['std_error'], 1)}",
    })

# Fixed column list keeps the layout (and the columns) even if no branch was matched.
df1 = pd.DataFrame(rows, columns=table_columns)
df1

## Cell 2

In [None]:
cell = "cell2"

In [None]:
cell_folder = cell_folders_dict[cell]
morphology_file_2 = morphology_files_dict[cell]

In [None]:
cell_path = [p for p in cell_folder.iterdir() if p.suffix == '.pkl'][0]
locs_path = [p for p in cell_folder.iterdir() if 'locations' in p.name][0]
template_path = [p for p in cell_folder.iterdir() if 'template' in p.name][0]

In [None]:
cell_2, sections_2 = load_cell(cell_path)

In [None]:
locations_2 = np.load(locs_path)
template_2 = np.load(template_path)

In [None]:
fs = 1 / cell_2.dt * 1000

In [None]:
gtr2 = GraphAxonTracking(template_2, locations_2, fs, verbose=True, **params)

In [None]:
# Tracking pipeline for Cell 2, one stage per cell, executed in this order:
# select_channels -> build_graph -> find_paths -> clean_paths.
gtr2.select_channels()

In [None]:
# Stage 2 of 4.
gtr2.build_graph()

In [None]:
# Stage 3 of 4.
gtr2.find_paths()

In [None]:
# Stage 4 of 4; `gtr2.branches` is what the evaluation further down compares to the ground truth.
gtr2.clean_paths()

In [None]:
# _ = plot_axon_summary(gtr2)

In [None]:
if plot_tracking_figures:
    fchans2 = gtr2.plot_channel_selection()
    fgraph2 = gtr2.plot_graph()
    fbranch2 = gtr2.plot_branches()
    fvel2 = gtr2.plot_velocities()

In [None]:
branch_gt2 = extract_ground_truth_velocity(cell_2, sections_2, min_length=50, min_segs=5)
for i, br in enumerate(branch_gt2): 
    print(f"GT branch {i}: velocity {br['velocity']} length: {len(br['idxs'])}")

In [None]:
ev2 = evaluate_tracking_accuracy(gtr2.branches, branch_gt2, cell_2, locations_2, max_median_dist_for_match=20)
print(f'Number of matched branches: {len(ev2)}')

In [None]:
fig2, ax = plt.subplots(figsize=(10, 10))

morphology_file = morphology_file_2
cell_model = cell_2
evaluation = ev2
locations = locations_2

# Faint morphology: soma + dendrites first, then the axon drawn twice (black, then green).
# NOTE(review): Cells 1 and 4 draw the axon only once in green; confirm the black pass is intended.
neuron_kwargs = dict(morphology=str(morphology_file), plane='xy', alpha=0.1, ax=ax,
                     position=cell_model.somapos)
nplt.plot_neuron(exclude_sections=['axon'], **neuron_kwargs)
nplt.plot_neuron(exclude_sections=['soma', 'apic', 'basal'], color='k', **neuron_kwargs)
nplt.plot_neuron(exclude_sections=['soma', 'apic', 'basal'], color='g', **neuron_kwargs)
plot_amplitude_map(template_2, locations_2, log=True, alpha=alpha_footprint, ax=ax, cmap=cmap_footprint)

cm = plt.get_cmap(cmap_branches)
n_matched = len(evaluation)

for i, ev in enumerate(evaluation):
    color = cm(i / n_matched)
    locs = locations[ev['channels']]
    # Model segments of the matched branch (presumably the ground-truth branch -- see
    # evaluate_tracking_accuracy); only the first segment carries the legend label.
    for i_idx, ax_idxs in enumerate(ev['axon_idxs']):
        label_kwargs = {"label": f"branch {i}"} if i_idx == 0 else {}
        ax.plot(np.mean(cell_model.x[ax_idxs], 1), np.mean(cell_model.y[ax_idxs], 1),
                ls='-', lw=3, color=color, alpha=1, **label_kwargs)
    # Electrodes listed for this matched branch, drawn in the same colour.
    ax.plot(locs[:, 0], locs[:, 1], ls='', marker='.', markeredgecolor="k",
            color=color, markersize=10, alpha=alpha_marker)

ax.legend(ncol=4, fontsize=legend_fs, loc='upper center')
ax.axis('equal')
ax.axis('off')

In [None]:
# Per-branch summary table for Cell 2: velocities rounded to integers, relative error in %,
# tracking error as "mean $\pm$ std".
table_columns = ["model ID", "branch ID", "velocity GT", "velocity est.",
                 "abs. vel. error", "rel. vel. error", "tracking error"]
rows = []
for branch_id, ev in enumerate(ev2):
    gt_velocity = ev['velocity_ground_truth']
    est_velocity = ev['velocity_estimated']
    gtv = int(np.round(gt_velocity))
    estv = int(np.round(est_velocity))
    rows.append({
        "model ID": "Cell 2" if branch_id == 0 else "",  # name the model on its first row only
        "branch ID": branch_id,
        "velocity GT": gtv,
        "velocity est.": estv,
        # absolute error uses the rounded velocities, relative error the unrounded ones
        "abs. vel. error": np.abs(gtv - estv),
        "rel. vel. error": np.round(np.abs(gt_velocity - est_velocity) / gt_velocity * 100, 1),
        # raw f-string: in a plain literal "\p" is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12)
        "tracking error": rf"{np.round(ev['mean_error'], 1)} $\pm$ {np.round(ev['std_error'], 1)}",
    })

# Fixed column list keeps the layout (and the columns) even if no branch was matched.
df2 = pd.DataFrame(rows, columns=table_columns)
df2

## Cell 3

In [None]:
cell = "cell3"

In [None]:
cell_folder = cell_folders_dict[cell]
morphology_file_3 = morphology_files_dict[cell]

In [None]:
cell_path = [p for p in cell_folder.iterdir() if p.suffix == '.pkl'][0]
locs_path = [p for p in cell_folder.iterdir() if 'locations' in p.name][0]
template_path = [p for p in cell_folder.iterdir() if 'template' in p.name][0]

In [None]:
cell_3, sections_3 = load_cell(cell_path)

In [None]:
locations_3 = np.load(locs_path)
template_3 = np.load(template_path)

In [None]:
fs = 1 / cell_3.dt * 1000

In [None]:
gtr3 = GraphAxonTracking(template_3, locations_3, fs, verbose=True, **params)

In [None]:
# Tracking pipeline for Cell 3, one stage per cell, executed in this order:
# select_channels -> build_graph -> find_paths -> clean_paths.
gtr3.select_channels()

In [None]:
# Stage 2 of 4.
gtr3.build_graph()

In [None]:
# Stage 3 of 4.
gtr3.find_paths()

In [None]:
# Stage 4 of 4; `gtr3.branches` is what the evaluation further down compares to the ground truth.
gtr3.clean_paths()

In [None]:
if plot_tracking_figures:
    fchans3 = gtr3.plot_channel_selection()
    fgraph3 = gtr3.plot_graph()
    fbranch3 = gtr3.plot_branches()
    fvel3 = gtr3.plot_velocities()

In [None]:
# plot_axon_summary(gtr3)

In [None]:
branch_gt3 = extract_ground_truth_velocity(cell_3, sections_3)
for i, br in enumerate(branch_gt3): 
    print(f"GT branch {i}: velocity {br['velocity']} length: {len(br['idxs'])}")

In [None]:
ev3 = evaluate_tracking_accuracy(gtr3.branches, branch_gt3, cell_3, locations_3)
print(f'Number of matched branches: {len(ev3)}')

In [None]:
fig3, ax = plt.subplots(figsize=(10, 10))

morphology_file = morphology_file_3
cell_model = cell_3
evaluation = ev3
locations = locations_3

# Faint morphology: soma + dendrites first, then the axon drawn twice (black, then green).
# NOTE(review): Cells 1 and 4 draw the axon only once in green; confirm the black pass is intended.
neuron_kwargs = dict(morphology=str(morphology_file), plane='xy', alpha=0.1, ax=ax,
                     position=cell_model.somapos)
nplt.plot_neuron(exclude_sections=['axon'], **neuron_kwargs)
nplt.plot_neuron(exclude_sections=['soma', 'apic', 'basal'], color='k', **neuron_kwargs)
nplt.plot_neuron(exclude_sections=['soma', 'apic', 'basal'], color='g', **neuron_kwargs)
plot_amplitude_map(template_3, locations_3, log=True, alpha=alpha_footprint, ax=ax, cmap=cmap_footprint)

cm = plt.get_cmap(cmap_branches)
n_matched = len(evaluation)

for i, ev in enumerate(evaluation):
    color = cm(i / n_matched)
    locs = locations[ev['channels']]
    # Model segments of the matched branch (presumably the ground-truth branch -- see
    # evaluate_tracking_accuracy); only the first segment carries the legend label.
    for i_idx, ax_idxs in enumerate(ev['axon_idxs']):
        label_kwargs = {"label": f"branch {i}"} if i_idx == 0 else {}
        ax.plot(np.mean(cell_model.x[ax_idxs], 1), np.mean(cell_model.y[ax_idxs], 1),
                ls='-', lw=3, color=color, alpha=1, **label_kwargs)
    # Electrodes listed for this matched branch, drawn in the same colour.
    ax.plot(locs[:, 0], locs[:, 1], ls='', marker='.', markeredgecolor="k",
            color=color, markersize=10, alpha=alpha_marker)
ax.legend(ncol=4, fontsize=legend_fs, loc='upper center')
ax.axis('equal')
ax.axis('off')

In [None]:
# Per-branch summary table for Cell 3: velocities rounded to integers, relative error in %,
# tracking error as "mean $\pm$ std".
table_columns = ["model ID", "branch ID", "velocity GT", "velocity est.",
                 "abs. vel. error", "rel. vel. error", "tracking error"]
rows = []
for branch_id, ev in enumerate(ev3):
    gt_velocity = ev['velocity_ground_truth']
    est_velocity = ev['velocity_estimated']
    gtv = int(np.round(gt_velocity))
    estv = int(np.round(est_velocity))
    rows.append({
        "model ID": "Cell 3" if branch_id == 0 else "",  # name the model on its first row only
        "branch ID": branch_id,
        "velocity GT": gtv,
        "velocity est.": estv,
        # absolute error uses the rounded velocities, relative error the unrounded ones
        "abs. vel. error": np.abs(gtv - estv),
        "rel. vel. error": np.round(np.abs(gt_velocity - est_velocity) / gt_velocity * 100, 1),
        # raw f-string: in a plain literal "\p" is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12)
        "tracking error": rf"{np.round(ev['mean_error'], 1)} $\pm$ {np.round(ev['std_error'], 1)}",
    })

# Fixed column list keeps the layout (and the columns) even if no branch was matched.
df3 = pd.DataFrame(rows, columns=table_columns)
df3

## Cell 4

In [None]:
cell = "cell4"

In [None]:
cell_folder = cell_folders_dict[cell]
morphology_file_4 = morphology_files_dict[cell]

In [None]:
cell_path = [p for p in cell_folder.iterdir() if p.suffix == '.pkl'][0]
locs_path = [p for p in cell_folder.iterdir() if 'locations' in p.name][0]
template_path = [p for p in cell_folder.iterdir() if 'template' in p.name][0]

In [None]:
cell_4, sections_4 = load_cell(cell_path)

In [None]:
locations_4 = np.load(locs_path)
template_4 = np.load(template_path)

In [None]:
fs = 1 / cell_4.dt * 1000

In [None]:
gtr4 = GraphAxonTracking(template_4, locations_4, fs, verbose=True, **params)

In [None]:
# Tracking pipeline for Cell 4, one stage per cell, executed in this order:
# select_channels -> build_graph -> find_paths -> clean_paths.
gtr4.select_channels()

In [None]:
# Stage 2 of 4.
gtr4.build_graph()

In [None]:
# Stage 3 of 4.
gtr4.find_paths()

In [None]:
# Stage 4 of 4; `gtr4.branches` is what the evaluation further down compares to the ground truth.
gtr4.clean_paths()

In [None]:
if plot_tracking_figures:
    fchans3 = gtr4.plot_channel_selection()
    fgraph3 = gtr4.plot_graph()
    fbranch3 = gtr4.plot_branches()
    fvel3 = gtr4.plot_velocities()

In [None]:
# plot_axon_summary(gtr4)

In [None]:
branch_gt4 = extract_ground_truth_velocity(cell_4, sections_4)
for i, br in enumerate(branch_gt4): 
    print(f"GT branch {i}: velocity {br['velocity']} length: {len(br['idxs'])}")

In [None]:
ev4 = evaluate_tracking_accuracy(gtr4.branches, branch_gt4, cell_4, locations_4, max_median_dist_for_match=30)
print(f'Number of matched branches: {len(ev4)}')

In [None]:
fig4, ax = plt.subplots(figsize=(10, 10))

morphology_file = morphology_file_4
cell_model = cell_4
evaluation = ev4
locations = locations_4

# Faint morphology: soma + dendrites first, then the axon alone in green.
neuron_kwargs = dict(morphology=str(morphology_file), plane='xy', alpha=0.1, ax=ax,
                     position=cell_model.somapos)
nplt.plot_neuron(exclude_sections=['axon'], **neuron_kwargs)
nplt.plot_neuron(exclude_sections=['soma', 'apic', 'basal'], color='g', **neuron_kwargs)
plot_amplitude_map(template_4, locations_4, log=True, alpha=alpha_footprint, ax=ax, cmap=cmap_footprint)

cm = plt.get_cmap(cmap_branches)
n_matched = len(evaluation)

for i, ev in enumerate(evaluation):
    color = cm(i / n_matched)
    locs = locations[ev['channels']]
    # Model segments of the matched branch (presumably the ground-truth branch -- see
    # evaluate_tracking_accuracy); only the first segment carries the legend label.
    for i_idx, ax_idxs in enumerate(ev['axon_idxs']):
        label_kwargs = {"label": f"branch {i}"} if i_idx == 0 else {}
        ax.plot(np.mean(cell_model.x[ax_idxs], 1), np.mean(cell_model.y[ax_idxs], 1),
                ls='-', lw=3, color=color, alpha=1, **label_kwargs)
    # Electrodes listed for this matched branch, drawn in the same colour.
    ax.plot(locs[:, 0], locs[:, 1], ls='', marker='.', markeredgecolor="k",
            color=color, markersize=10, alpha=alpha_marker)

ax.legend(ncol=4, fontsize=legend_fs, loc='upper center')
ax.axis('equal')
ax.axis('off')

In [None]:
# Per-branch summary table for Cell 4: velocities rounded to integers, relative error in %,
# tracking error as "mean $\pm$ std".
table_columns = ["model ID", "branch ID", "velocity GT", "velocity est.",
                 "abs. vel. error", "rel. vel. error", "tracking error"]
rows = []
for branch_id, ev in enumerate(ev4):
    gt_velocity = ev['velocity_ground_truth']
    est_velocity = ev['velocity_estimated']
    gtv = int(np.round(gt_velocity))
    estv = int(np.round(est_velocity))
    rows.append({
        "model ID": "Cell 4" if branch_id == 0 else "",  # name the model on its first row only
        "branch ID": branch_id,
        "velocity GT": gtv,
        "velocity est.": estv,
        # absolute error uses the rounded velocities, relative error the unrounded ones
        "abs. vel. error": np.abs(gtv - estv),
        "rel. vel. error": np.round(np.abs(gt_velocity - est_velocity) / gt_velocity * 100, 1),
        # raw f-string: in a plain literal "\p" is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12)
        "tracking error": rf"{np.round(ev['mean_error'], 1)} $\pm$ {np.round(ev['std_error'], 1)}",
    })

# Fixed column list keeps the layout (and the columns) even if no branch was matched.
df4 = pd.DataFrame(rows, columns=table_columns)
df4

## Combined results

In [None]:
model_names = ['Cell 1', 'Cell 2', 'Cell 3', 'Cell 4']

In [None]:
df = pd.concat([df1, df2, df3, df4])
df.reset_index()

In [None]:
print(df.to_latex(index=False))

In [None]:
save_figures = True

In [None]:
figures = [fig1, fig2, fig3, fig4]
fig_folder = Path('figures') / "figure5"
fig_folder.mkdir(exist_ok=True, parents=True)

if save_figures:
    for f, m in zip(figures, model_names):
        f.savefig(fig_folder / f"{m}_branches.png", dpi=600)
        f.savefig(fig_folder / f"{m}_branches.pdf")