# Figure 6 - application to HD-MEA datasets

In [None]:
import numpy as np
import matplotlib.pylab as plt
import MEAutility as mu
import numpy as np
from pathlib import Path
from pprint import pprint
from probeinterface import plotting
from tqdm import tqdm

%matplotlib widget

import axon_velocity as av

In [None]:
# Toggle for writing the generated figures to disk in the saving cells below.
save_figs = True

# All panels of this figure are written below figures/figure6.
fig_folder = Path("figures/") / "figure6"
# parents=True: the top-level "figures" directory may not exist yet (e.g. on a fresh checkout),
# in which case mkdir without it raises FileNotFoundError.
fig_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Start from the library defaults for the graph-based velocity estimation.
params = av.get_default_graph_velocity_params()

# Override the defaults with the values used for this figure.
params.update(
    {
        "detect_threshold": 0.01,
        "kurt_threshold": 0.1,
        "peak_std_threshold": 0.8,
        "upsample": 5,
        "neighbor_radius": 100,
        "r2_threshold": 0.8,
    }
)

pprint(params)

In [None]:
def plot_unit_summary(gtr, probe):
    """Draw the four summary panels of one unit, each in its own figure.

    Parameters
    ----------
    gtr
        Result object for one unit (in this notebook, the return value of
        `av.compute_graph_propagation_velocity`). The attributes read here are
        `template`, `locations`, `fs`, `selected_channels` and `branches`.
    probe
        Probe passed to `probeinterface.plotting.plot_probe` as the background
        of the branch panel.

    Returns
    -------
    dict
        The created figures, keyed by "amplitude", "latency", "branches" and "velocity".
    """
    locs = gtr.locations
    unit_branches = gtr.branches
    # Both map panels are drawn without a colorbar.
    map_kwargs = dict(colorbar=False, colorbar_orientation="horizontal")

    # Panel 1: amplitude map (log scale).
    amp_fig, amp_ax = plt.subplots()
    av.plot_amplitude_map(gtr.template, locs, log=True, ax=amp_ax, **map_kwargs)

    # Panel 2: peak latency map.
    latency_fig, latency_ax = plt.subplots()
    av.plot_peak_latency_map(gtr.template, locs, gtr.fs, ax=latency_ax, **map_kwargs)

    # Panel 3: branches on top of a faint probe outline and the selected channels.
    branch_fig, branch_ax = plt.subplots()
    plotting.plot_probe(probe, ax=branch_ax, contacts_kargs={"alpha": 0.1}, probe_shape_kwargs={"alpha": 0.1})
    branch_ax.axis("off")
    branch_ax.set_title("")
    branch_ax.plot(
        locs[gtr.selected_channels, 0],
        locs[gtr.selected_channels, 1],
        marker=".",
        color="k",
        alpha=0.1,
        markersize=3,
        ls="",
    )
    palette = plt.get_cmap("tab20")
    n_branches = len(unit_branches)
    for branch_idx, branch in enumerate(unit_branches):
        channels = branch["channels"]
        # One colour per branch, spread over the tab20 colormap.
        branch_ax.plot(
            locs[channels, 0],
            locs[channels, 1],
            marker=".",
            color=palette(branch_idx / n_branches),
            ls="-",
            alpha=0.8,
            label=branch_idx,
        )

    # Panel 4: per-branch velocities, with the top and right spines hidden.
    vel_fig, vel_ax = plt.subplots()
    av.plot_branch_velocities(unit_branches, legend=False, ax=vel_ax, cmap="tab20")
    for side in ("right", "top"):
        vel_ax.spines[side].set_visible(False)

    return {
        "amplitude": amp_fig,
        "latency": latency_fig,
        "branches": branch_fig,
        "velocity": vel_fig,
    }

In [None]:
# Input data lives one level above the notebook, with one sub-folder per recording system.
data_folder = Path("..") / "data"
mea1k_folder = data_folder.joinpath("mea1k")
dualmode_folder = data_folder.joinpath("dualmode")

## Load MEA1k data

In [None]:
# np.load on an .npz file returns a lazy NpzFile that keeps the archive open.
# The with-block closes it once the three arrays have been read into memory.
with np.load(mea1k_folder / "mea1k.npz") as load_dict:
    templates_mea1k = load_dict["templates"]
    locations_mea1k = load_dict["locations"]
    fs_mea1k = load_dict["fs"]

### Load or recompute axonal branches

In [None]:
# Reuse the cached axon reconstructions when available; otherwise compute and cache them.
gtrs_mea1k_file = mea1k_folder / "gtrs.npy"
if not gtrs_mea1k_file.is_file():
    print("Computing and saving axonal branches")
    gtrs_mea1k = {}
    for i, template in enumerate(tqdm(templates_mea1k, desc="Extracting axons")):
        try:
            gtr = av.compute_graph_propagation_velocity(
                template, locations_mea1k, fs_mea1k, verbose=False, **params
            )
            gtrs_mea1k[i] = gtr
            print(f"Found axon for unit {i}")
        except Exception as e:
            # Units for which the extraction raises are reported and left out of the dict.
            print(f"Failed on {i}: error {e}")
    np.save(gtrs_mea1k_file, gtrs_mea1k)
else:
    print("Loading existing axonal branches")
    # The cache holds the dict inside a 0-d object array; .item() unwraps it again.
    gtrs_mea1k = np.load(gtrs_mea1k_file, allow_pickle=True).item()

In [None]:
# Report the number of entries in gtrs_mea1k against the total number of templates.
print(f"MEA1k: Found {len(gtrs_mea1k)} units with detectable axons out of {len(templates_mea1k)}")

In [None]:
# Display the unit indices that have an entry in gtrs_mea1k (candidates for the selection below).
gtrs_mea1k.keys()

In [None]:
# Keys of gtrs_mea1k that are highlighted in the overview plot and get a detailed summary.
mea1k_selected_unit_idxs = [8, 76] # 31 (presumably another candidate unit that was not used; TODO confirm)

In [None]:
# Overview panel: every reconstructed MEA1k axon drawn on top of the probe layout.
# NOTE(review): _get_probe is a private helper of av.plotting and may change between releases.
probe_mea1k = av.plotting._get_probe(locations_mea1k)

fig_mea1k, ax = plt.subplots(figsize=(10, 7))
_ = plotting.plot_probe(probe_mea1k, ax=ax, contacts_kargs={"alpha": 0.1}, probe_shape_kwargs={"alpha": 0.1})
ax.axis("off")

i_sel = 0
cmap = "tab20"
cm = plt.get_cmap(cmap)
alpha = 1
for i, gtr in gtrs_mea1k.items():
    if i in mea1k_selected_unit_idxs:
        # Highlighted units: thick line drawn on top, colours C0, C1, ... in order of appearance.
        color = f"C{i_sel}"
        lw = 3
        zorder = 10
        i_sel += 1
    else:
        # NOTE(review): i is the unit key, not its position in the dict, so i / len(gtrs_mea1k)
        # can exceed 1 when some units have no entry; confirm this colour mapping is intended.
        color = cm(i / len(gtrs_mea1k))
        lw = 1
        zorder = 1

    if len(gtr.branches) == 0:
        continue

    # Marker on the unit's initial channel.
    ax.plot(gtr.locations[gtr.init_channel, 0], gtr.locations[gtr.init_channel, 1],
            marker="o", markersize=5, color=color, alpha=alpha, zorder=zorder)
    for b_i, br in enumerate(gtr.branches):
        # Only the first branch carries the unit label, so a legend would list each unit once.
        label_kwargs = {"label": i} if b_i == 0 else {}
        ax.plot(gtr.locations[br["channels"], 0], gtr.locations[br["channels"], 1], marker="", color=color,
                lw=lw, alpha=alpha, zorder=zorder, **label_kwargs)

# ax.legend(ncol=10)
# Scale bar. Raw string: "\m" is an invalid escape sequence in a normal string literal.
ax.plot([0, 500], [1900, 1900], color="k", marker="|")
ax.text(100, 1920, r"500$\mu$m", color="k")
ax.set_title("")

In [None]:
# One dict of summary figures per highlighted MEA1k unit, in the order of mea1k_selected_unit_idxs.
figs_mea1k = [
    plot_unit_summary(gtrs_mea1k[unit_idx], probe_mea1k)
    for unit_idx in mea1k_selected_unit_idxs
]

In [None]:
# Write the per-unit panels and the overview figure as 600-dpi PNGs when saving is enabled.
if save_figs:
    for i, fig_dict in enumerate(figs_mea1k):
        for fig_name in fig_dict:
            # Files are numbered by position in the selection (neuron1, neuron2, ...), not by unit index.
            out_file = fig_folder / f"mea1k_neuron{i+1}_{fig_name}.png"
            fig_dict[fig_name].savefig(out_file, dpi=600)
    fig_mea1k.savefig(fig_folder / f"mea1k.png", dpi=600)

## Load DualMode data

In [None]:
# np.load on an .npz file returns a lazy NpzFile that keeps the archive open.
# The with-block closes it once the three arrays have been read into memory.
with np.load(dualmode_folder / "dualmode.npz") as load_dict:
    templates_dualmode = load_dict["templates"]
    locations_dualmode = load_dict["locations"]
    fs_dualmode = load_dict["fs"]

In [None]:
# Raise the upsampling factor for the DualMode data. This mutates the shared params dict,
# so every later use of params in this notebook sees the new value.
params['upsample'] = 10 # to get ~ 100kHz

In [None]:
# Reuse the cached axon reconstructions when available; otherwise compute and cache them.
gtrs_dualmode_file = dualmode_folder / "gtrs.npy"
if gtrs_dualmode_file.is_file():
    print("Loading existing axonal branches")
    # np.save stores the dict inside a 0-d object array, so .item() is needed to get the dict
    # back; without it the later len(gtrs_dualmode), .items() and [i] accesses fail.
    # NOTE: allow_pickle=True unpickles the file; only load caches written by this notebook.
    gtrs_dualmode = np.load(gtrs_dualmode_file, allow_pickle=True).item()
else:
    print("Computing and saving axonal branches")
    gtrs_dualmode = {}
    for i, template in enumerate(tqdm(templates_dualmode, desc="Extracting axons")):
        try:
            gtr = av.compute_graph_propagation_velocity(template, locations_dualmode, fs_dualmode,
                                                        verbose=False, **params)
            gtrs_dualmode[i] = gtr
            print(f"Found axon for unit {i}")
        except Exception as e:
            # Units for which the extraction raises are reported and left out of the dict.
            print(f"Failed on {i}: error {e}")
    np.save(gtrs_dualmode_file, gtrs_dualmode)

In [None]:
# Report the number of entries in gtrs_dualmode against the total number of templates.
print(f"DualMode: Found {len(gtrs_dualmode)} units with detectable axons out of {len(templates_dualmode)}")

In [None]:
# Keys of gtrs_dualmode that are highlighted in the overview plot and get a detailed summary.
dualmode_selected_unit_idxs = [20, 45]

In [None]:
# Overview panel: every reconstructed DualMode axon drawn on top of the probe layout.
# NOTE(review): _get_probe is a private helper of av.plotting and may change between releases.
probe_dualmode = av.plotting._get_probe(locations_dualmode)

fig_dualmode, ax = plt.subplots(figsize=(10, 7))
_ = plotting.plot_probe(probe_dualmode, ax=ax, contacts_kargs={"alpha": 0.1}, probe_shape_kwargs={"alpha": 0.1})
ax.axis("off")

i_sel = 0
cmap = "tab20"
cm = plt.get_cmap(cmap)
alpha = 1
for i, gtr in gtrs_dualmode.items():
    if i in dualmode_selected_unit_idxs:
        # Highlighted units: thick line drawn on top, colours C0, C1, ... in order of appearance.
        color = f"C{i_sel}"
        lw = 3
        zorder = 10
        i_sel += 1
    else:
        # NOTE(review): i is the unit key, not its position in the dict, so i / len(gtrs_dualmode)
        # can exceed 1 when some units have no entry; confirm this colour mapping is intended.
        color = cm(i / len(gtrs_dualmode))
        lw = 1
        zorder = 1

    if len(gtr.branches) == 0:
        continue

    # Marker on the unit's initial channel.
    ax.plot(gtr.locations[gtr.init_channel, 0], gtr.locations[gtr.init_channel, 1],
            marker="o", markersize=5, color=color, alpha=alpha, zorder=zorder)
    for b_i, br in enumerate(gtr.branches):
        # Only the first branch carries the unit label, so a legend would list each unit once.
        label_kwargs = {"label": i} if b_i == 0 else {}
        ax.plot(gtr.locations[br["channels"], 0], gtr.locations[br["channels"], 1], marker="", color=color,
                lw=lw, alpha=alpha, zorder=zorder, **label_kwargs)

# ax.legend(ncol=10)
# Scale bar. Raw string: "\m" is an invalid escape sequence in a normal string literal.
ax.plot([0, 500], [1700, 1700], color="k", marker="|")
ax.text(100, 1720, r"500$\mu$m", color="k")
ax.set_title("")

In [None]:
# One dict of summary figures per highlighted DualMode unit, in the order of dualmode_selected_unit_idxs.
figs_dualmode = [
    plot_unit_summary(gtrs_dualmode[unit_idx], probe_dualmode)
    for unit_idx in dualmode_selected_unit_idxs
]

In [None]:
# Write the per-unit panels and the overview figure as 600-dpi PNGs when saving is enabled.
if save_figs:
    for i, fig_dict in enumerate(figs_dualmode):
        for fig_name in fig_dict:
            # Files are numbered by position in the selection (neuron1, neuron2, ...), not by unit index.
            out_file = fig_folder / f"dualmode_neuron{i+1}_{fig_name}.png"
            fig_dict[fig_name].savefig(out_file, dpi=600)
    fig_dualmode.savefig(fig_folder / f"dualmode.png", dpi=600)