# Figure 3 - channel selection

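This notebook illustrates the four channel-selection filters (detection threshold, kurtosis threshold, peak-latency std threshold, and initial-delay threshold) applied to a single unit from the DualMode dataset.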
The dataset file (`dualmode.npz`) must be downloaded from Zenodo (https://doi.org/10.5281/zenodo.4896745) and placed in the `axon_velocity/data/dualmode` folder; the notebook loads it as `../data/dualmode/dualmode.npz`, relative to its own location.

In [None]:
import numpy as np
import matplotlib.pylab as plt
import MEAutility as mu
from scipy.signal import resample_poly
from scipy.stats import kurtosis, linregress
from matplotlib import gridspec
from scipy import io
import numpy as np
import networkx as nx
from pathlib import Path
from pprint import pprint
import sys
import os
import matplotlib as mpl

#%matplotlib widget
%matplotlib inline

In [None]:
import axon_velocity as av

In [None]:
# Output settings: whether to write the panels to disk, and where.
save_fig = True
fig_folder = Path("figures", "figure3")
fig_folder.mkdir(parents=True, exist_ok=True)

### Define algorithm params

In [None]:
params = av.get_default_graph_velocity_params()

# Values that differ from the library defaults for this figure
# (applied in this order, one key at a time).
_param_overrides = {
    'detect_threshold': 0.02,
    'kurt_threshold': 0.3,
    'peak_std_threshold': 0.8,
    'init_delay': 0.2,
    'upsample': 10,
}
for _name, _value in _param_overrides.items():
    params[_name] = _value

pprint(params)

In [None]:
# Load the templates, electrode locations and sampling frequency of the
# DualMode dataset (expected at ../data/dualmode/dualmode.npz).
dualmode_folder = Path('..') / 'data' / 'dualmode'
# Use the NpzFile as a context manager so its underlying file handle is
# closed; each member is read into memory on access, so the arrays stay
# valid after the file is closed.
with np.load(dualmode_folder / "dualmode.npz") as load_dict:
    templates = load_dict["templates"]
    locations = load_dict["locations"]
    fs = load_dict["fs"]

In [None]:
# Index (into `templates`) of the unit analysed in this figure.
selected_unit: int = 20

In [None]:
# Template of the selected unit (first-axis slice of `templates`).
template = templates[selected_unit, ...]

In [None]:
# Build the graph-based axon tracker with the customised parameters.
gtr = av.GraphAxonTracking(
    template,
    locations,
    fs,
    **params,
    verbose=True,
)

In [None]:
# Run the channel-selection step on the tracker. Presumably this populates the
# per-filter selections plotted below (the `_selected_channels_*` attributes
# and `selected_channels`) — confirm against axon_velocity.
gtr.select_channels()

In [None]:
# Panel A: amplitude map and peak-latency map of the selected template.
fig_amp, ax_amp = plt.subplots(figsize=(10, 5))
ax_amp = av.plot_amplitude_map(template, locations, log=True, ax=ax_amp)
ax_amp.set_title("Amplitude", fontsize=20)

fig_peaks, ax_peaks = plt.subplots(figsize=(10, 5))
ax_peaks = av.plot_peak_latency_map(template, locations, fs=fs, log=True, ax=ax_peaks)
ax_peaks.set_title("Peak latency", fontsize=20)

In [None]:
# Channels kept by the detection-threshold filter (black) over all channels (grey).
fig_detect = plt.figure(figsize=(10, 5))
ax_detect = fig_detect.add_subplot(111)
ax_detect.set_title(f"Selected after detection threshold: {gtr._detect_threshold}", fontsize=20)
# Explicit int dtype: an empty selection would otherwise yield a float array,
# which numpy rejects as an index. Sorting only makes the order deterministic.
channels_detect = np.array(sorted(gtr._selected_channels_detect), dtype=int)
ax_detect.plot(gtr.locations[:, 0], gtr.locations[:, 1], marker=".", color="grey", ls="", alpha=0.2)
ax_detect.plot(gtr.locations[channels_detect, 0],
               gtr.locations[channels_detect, 1], marker=".", color="k", ls="", alpha=0.5)
ax_detect.axis("off")

In [None]:
# Channels kept by the kurtosis-threshold filter (black) over all channels (grey).
fig_kurt = plt.figure(figsize=(10, 5))
ax_kurt = fig_kurt.add_subplot(111)
ax_kurt.set_title(f"Selected after kurtosis threshold: {gtr._kurt_threshold}", fontsize=20)
# Explicit int dtype so an empty selection is still a valid index array.
channels_kurt = np.array(sorted(gtr._selected_channels_kurt), dtype=int)
ax_kurt.plot(gtr.locations[:, 0], gtr.locations[:, 1], marker=".", color="grey", ls="", alpha=0.2)
ax_kurt.plot(gtr.locations[channels_kurt, 0],
             gtr.locations[channels_kurt, 1], marker=".", color="k", ls="", alpha=0.5)
ax_kurt.axis("off")

In [None]:
# Channels kept by the peak-latency std filter (black) over all channels (grey).
fig_peak = plt.figure(figsize=(10, 5))
ax_peak = fig_peak.add_subplot(111)
# NOTE(review): attribute spelled "threhsold" as in the original notebook;
# confirm this matches the axon_velocity attribute name.
ax_peak.set_title(f"Selected after peak std threshold: {gtr._peak_std_threhsold} ms", fontsize=20)
# Explicit int dtype so an empty selection is still a valid index array.
channels_peak = np.array(sorted(gtr._selected_channels_peakstd), dtype=int)
ax_peak.plot(gtr.locations[:, 0], gtr.locations[:, 1], marker=".", color="grey", ls="", alpha=0.2)
ax_peak.plot(gtr.locations[channels_peak, 0],
             gtr.locations[channels_peak, 1], marker=".", color="k", ls="", alpha=0.5)
ax_peak.axis("off")

In [None]:
# Channels kept by the initial-delay filter (black) over all channels (grey).
fig_init = plt.figure(figsize=(10, 5))
ax_init = fig_init.add_subplot(111)
ax_init.set_title(f"Selected after init delay threshold: {gtr._init_delay} ms", fontsize=20)
# Explicit int dtype so an empty selection is still a valid index array.
channels_init = np.array(sorted(gtr._selected_channels_init), dtype=int)
ax_init.plot(gtr.locations[:, 0], gtr.locations[:, 1], marker=".", color="grey", ls="", alpha=0.2)
ax_init.plot(gtr.locations[channels_init, 0],
             gtr.locations[channels_init, 1], marker=".", color="k", ls="", alpha=0.5)
ax_init.axis("off")

In [None]:
# Channels surviving all selection filters (black) over all channels (grey).
fig_all = plt.figure(figsize=(10, 5))
ax_all = fig_all.add_subplot(111)
ax_all.set_title("Selected after all thresholds", fontsize=20)
# Normalise to an int index array so an empty selection still indexes cleanly.
channels_all = np.asarray(gtr.selected_channels, dtype=int)
ax_all.plot(gtr.locations[:, 0], gtr.locations[:, 1], marker=".", color="grey", ls="", alpha=0.2)
ax_all.plot(gtr.locations[channels_all, 0],
            gtr.locations[channels_all, 1], marker=".", color="k", ls="", alpha=0.5)
ax_all.axis("off")

In [None]:
# save figures
if save_fig:
    fig_amp.savefig(fig_folder / 'panelA_amp-map.png', dpi=600)
    # The peak-latency map is `fig_peaks`; `fig_peak` is the peak-std
    # selection panel (saved below as panel D).
    fig_peaks.savefig(fig_folder / 'panelA_peak-map.png', dpi=600)

    fig_detect.savefig(fig_folder / 'panelB_channel-amp.png', dpi=600)
    fig_kurt.savefig(fig_folder / 'panelC_channel-kurt.png', dpi=600)
    fig_peak.savefig(fig_folder / 'panelD_init-peak-std.png', dpi=600)
    fig_init.savefig(fig_folder / 'panelE_init-init-peak.png', dpi=600)
    fig_all.savefig(fig_folder / 'panelF_all.png', dpi=600)