# Benchmark MED strategy on experimental data


In [None]:
import spikeinterface.full as si
import probeinterface as pi

import matplotlib.pyplot as plt
import numpy as np
import sys

import pandas as pd
import seaborn as sns

import shutil

from pathlib import Path

# Make the parent directory importable so the local `utils` module can be found.
sys.path.append("..")

from utils import prettify_axes

# IPython magic (only valid inside IPython/Jupyter): select the interactive "widget" matplotlib backend.
%matplotlib widget

In [None]:
# Folder holding, for each probe version, the original binary and its MED-compressed "_lossy" copy.
data_folder = Path("/home/alessio/Documents/data/allen/med/")

In [None]:
# Parallel-processing options forwarded to si.extract_waveforms.
job_kwargs = {"n_jobs": 10, "progress_bar": True, "chunk_duration": "1s"}

In [None]:
# Whether the figures produced in this notebook are written to disk, and where.
save_fig = True

fig_folder = Path("figures", "med")
fig_folder.mkdir(parents=True, exist_ok=True)

# Neuropixels 1.0

In [None]:
np_version = 1

# Inputs for this probe version: the original binary and its MED-compressed ("_lossy") copy.
np_bin, np_med = (data_folder / f"continuous_np{np_version}{tag}.dat" for tag in ("", "_lossy"))

# Open Ephys session folder; the probe is read from it with pi.read_openephys.
np_folder = (
    "/home/alessio/Documents/data/allen/npix-open-ephys/618382_2022-03-31_14-27-03/"
    if np_version == 1
    else "/home/alessio/Documents/data/allen/npix-open-ephys/595262_2022-02-21_15-18-07"
)


In [None]:
# Per-probe-version folder for saved sortings ("sorting_<name>") and waveform folders ("wf_<name>").
output_folder = Path(f"../data/med/np{np_version}")

### Load recordings

In [None]:
# Binary layout and scaling shared by the raw and MED files (passed to si.read_binary).
num_channels, dtype, fs = 384, "int16", 30000
gain_to_uV, offset_to_uV = 0.195, 0

In [None]:
# Load the original and the MED-compressed recordings with identical layout and scaling,
# then attach the probe read from the Open Ephys session folder to both.
# The MED file path is `np_med` (defined with np_bin above).
rec = si.read_binary(np_bin, sampling_frequency=fs, num_chan=num_channels, dtype=dtype,
                     gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV)
rec_med = si.read_binary(np_med, sampling_frequency=fs, num_chan=num_channels, dtype=dtype,
                         gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV)
probe = pi.read_openephys(np_folder)
rec = rec.set_probe(probe)
rec_med = rec_med.set_probe(probe)

### Visualization

In [None]:
# Overlay the channels at positions 30-39 of the raw (black) and MED (C0) recordings on one axis.
w = si.plot_timeseries(rec, channel_ids=rec.channel_ids[30:40], color="k", 
                       show_channel_ids=True)
si.plot_timeseries(rec_med, channel_ids=rec.channel_ids[30:40], color="C0", 
                   show_channel_ids=True, ax=w.ax)

In [None]:
# Single channel and a 2 s window (30-32 s) used to compare raw and MED traces.
channel_id = 100

start_frame = int(30 * fs)
end_frame = int(32 * fs)

# Time axis in seconds for the selected frames.
ts = np.arange(start_frame, end_frame) / fs

# Bandpass-filtered versions of both recordings (si.bandpass_filter with its default settings).
rec_f = si.bandpass_filter(rec)
rec_med_f = si.bandpass_filter(rec_med)

# Scaled (return_scaled=True) traces of the single channel: unfiltered and filtered, raw and MED.
tr = rec.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id], 
                    return_scaled=True)
tr_med = rec_med.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id], 
                              return_scaled=True)
tr_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id],
                        return_scaled=True)
tr_med_f = rec_med_f.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id],
                                return_scaled=True)

In [None]:
# Overlay raw (black) and MED (C0) traces of the selected channel:
# unfiltered on top, bandpass-filtered at the bottom.
fig_tr, axs_tr = plt.subplots(nrows=2, figsize=(10, 7))

axs_tr[0].plot(ts, tr, color="k", alpha=0.8, label="raw")
axs_tr[0].plot(ts, tr_med, color="C0", alpha=0.8, label="med")
axs_tr[0].set_title("Raw")
axs_tr[0].set_xlabel("time (s)")
# Raw strings: in a plain literal "\m" is an invalid escape sequence (SyntaxWarning on
# Python 3.12+). The label text itself is unchanged.
axs_tr[0].set_ylabel(r"V ($\mu$ V)")
axs_tr[0].legend()

axs_tr[1].plot(ts, tr_f, color="k", alpha=0.8)
axs_tr[1].plot(ts, tr_med_f, color="C0", alpha=0.8)
axs_tr[1].set_title("Filtered")
axs_tr[1].set_xlabel("time (s)")
axs_tr[1].set_ylabel(r"V ($\mu$ V)")

### Run spike sorting

In [None]:
# Sorter(s) to run and their per-sorter parameters (forwarded to si.run_sorters).
sorter_list = ["kilosort2_5"]
sorter_params = dict(kilosort2_5=dict(n_jobs_bin=10, total_memory="2G"))

In [None]:
# Spike-sort the raw and the MED recordings, or reload sortings saved by a previous run.
rec_dict = {"raw": rec, "med": rec_med}

if (output_folder / "sorting_raw").is_dir():
    print("Loading sorting outputs")
    sorting_raw = si.load_extractor(output_folder / "sorting_raw")
    sorting_med = si.load_extractor(output_folder / "sorting_med")
    sort_dict = {"raw": sorting_raw, "med": sorting_med}
else:
    working_folder = output_folder / "working"
    # Clear leftovers from an interrupted run. Test the folder that is actually removed:
    # output_folder can exist without working_folder, and rmtree would then raise
    # FileNotFoundError.
    if working_folder.is_dir():
        shutil.rmtree(working_folder)
    print(f"Running spike sorting with {sorter_list}")
    sortings = si.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict,
                              working_folder=working_folder, sorter_params=sorter_params,
                              verbose=False, mode_if_folder_exists="keep")

    # Persist each sorting as "sorting_<recording name>" so the branch above can reload it.
    # Result keys are (recording name, sorter name) tuples, unpacked here.
    sort_dict = {}
    for (rec_name, _), sorting in sortings.items():
        sort_dict[rec_name] = sorting.save(folder=output_folder / f"sorting_{rec_name}")
    # clean up the sorter working directory
    shutil.rmtree(working_folder)

print(sort_dict)

In [None]:
# Sortings of the original and of the MED-compressed recording.
sorting_raw, sorting_med = sort_dict["raw"], sort_dict["med"]

In [None]:
# Display the unit ids found in the raw sorting.
sorting_raw.unit_ids

In [None]:
# Keep only the raw-sorting units that Kilosort labelled "good" (KSLabel property);
# these are the reference units for the waveform and sorting comparisons below.
selected_units = sorting_raw.unit_ids[sorting_raw.get_property('KSLabel')=="good"]
sorting_good = sorting_raw.select_units(unit_ids=selected_units)
sorting_good

### Waveforms and waveform features

Here we extract waveforms from both the bandpass-filtered raw and MED recordings, using the "good" units of the sorting obtained on the original raw data:

In [None]:
# Same units (raw "good" sorting) extracted from the filtered raw and the filtered MED
# recordings, so templates can be compared one-to-one. load_if_exists=True reuses an
# existing folder from a previous run (as the parameter name indicates).
we_raw = si.extract_waveforms(rec_f, sorting_good, output_folder / "wf_raw", 
                              load_if_exists=True, ms_after=5, **job_kwargs)
we_med = si.extract_waveforms(rec_med_f, sorting_good, output_folder / "wf_med", 
                              load_if_exists=True, ms_after=5, **job_kwargs)

In [None]:
# Positions (in we_raw.sorting.unit_ids) of the units whose templates are plotted.
unit_idxs = list(range(7))

In [None]:
# One row per selected unit: template from the raw waveforms (black) overlaid with the
# template from the MED waveforms (C0); the legend is attached to the middle row only.
fig_u, ax_u = plt.subplots(nrows=len(unit_idxs), figsize=(7, 15))

for i, unit_idx in enumerate(unit_idxs):
    unit_id = we_raw.sorting.unit_ids[unit_idx]
    ax = ax_u[i]

    si.plot_unit_templates(we_raw, unit_ids=[unit_id], unit_colors={unit_id: "k"},
                           axes=[ax], radius_um=100)
    # label the most recently drawn line so the legend can show it
    ax.get_lines()[-1].set_label("raw")

    si.plot_unit_templates(we_med, unit_ids=[unit_id], unit_colors={unit_id: "C0"},
                           axes=[ax], radius_um=100)
    ax.get_lines()[-1].set_label("med")

    if i != len(unit_idxs) // 2:
        continue
    ax.legend(bbox_to_anchor=(1.2, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)

In [None]:
# Names of the template metrics; each one becomes a column of the metrics table and a row of the plot grid.
template_metrics = si.get_template_metric_names()

In [None]:
# Template metrics of the raw and MED waveforms, computed on a sparse set of channels
# per unit (radius-based sparsity, radius_um=30 — presumably around each unit's main
# channel; confirm against the spikeinterface version in use).
sparsity_dict = dict(method="radius", radius_um=30)

tm = si.calculate_template_metrics(we_raw, upsample=10,
                                   sparsity_dict=sparsity_dict)
tm_med = si.calculate_template_metrics(we_med, upsample=10,
                                       sparsity_dict=sparsity_dict)

if sparsity_dict is None:
    tm["unit_id"] = tm.index
    tm["rank"] = ["0"] * len(tm)
else:
    # With sparsity the index is (unit_id, channel_id); expose both levels as columns.
    index_df = tm.index.to_frame()
    tm["unit_id"] = index_df["unit_id"].values
    tm["channel_id"] = index_df["channel_id"].values

    # "rank" = position of each row within its unit's rows, in table order, stored as a
    # string so it can be used as a categorical column label in the plots.
    # NOTE(review): table order is not guaranteed to follow distance from the main
    # channel — confirm before interpreting rank as proximity.
    tm["rank"] = tm.groupby("unit_id", sort=False).cumcount().astype(str)

# MED metrics side by side with the raw ones (rows aligned on the shared index).
for metric in template_metrics:
    tm[f"{metric}_med"] = tm_med[metric]

In [None]:
# Distinct channel ranks (strings); one plot column per rank.
ranks = np.unique(tm["rank"])

In [None]:
# Grid of raw (x) vs MED (y) scatter plots: one row per template metric, one column per
# channel rank; only the outer row/column carry titles and axis labels.
fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(ranks), figsize=(10, 12))

for i, metric in enumerate(template_metrics):
    for j, rank in enumerate(ranks):
        ax = axs_m[i, j]
        sns.scatterplot(data=tm[tm["rank"] == rank], x=metric, y=f"{metric}_med", ax=ax)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.axis("equal")
        if i == 0:
            ax.set_title(f"Rank {rank}")
        if i == len(template_metrics) - 1:
            ax.set_xlabel("(raw)")
        if j == 0:
            ax.set_ylabel(f"{metric}\n(med)")

fig_m.subplots_adjust(wspace=0.5, hspace=0.3)

### Spike sorting comparison

In [None]:
# Summary of both sortings.
print(sorting_raw, sorting_med)

In [None]:
# Keep only the MED-sorting units that Kilosort labelled "good" (KSLabel property).
# (Previously referenced the undefined name `sortings_med`.)
sorting_med_good = sorting_med.select_units(
    unit_ids=sorting_med.unit_ids[sorting_med.get_property("KSLabel") == "good"]
)
sorting_med_good

In [None]:
# Treat the raw "good" units as ground truth and score the MED sorting against them.
# NOTE(review): the NP2 section compares against the MED "good" units only — decide
# which reference set is intended here.
cmp = si.compare_sorter_to_ground_truth(sorting_good, sorting_med)

In [None]:
# Per-unit comparison scores (one swarm per panel): accuracy, precision, recall.
fig_p, axs_p = plt.subplots(ncols=3)
sns.swarmplot(y=cmp.get_performance()["accuracy"], ax=axs_p[0], color="g")
sns.swarmplot(y=cmp.get_performance()["precision"], ax=axs_p[1], color="b")
sns.swarmplot(y=cmp.get_performance()["recall"], ax=axs_p[2], color="r")

In [None]:
# Threshold passed to cmp.get_well_detected_units below.
good_detection_thr = 0.95

In [None]:
# Fraction of ground-truth (raw "good") units that cmp.get_well_detected_units reports
# as well detected at good_detection_thr, rounded to 2 decimals.
# (Previously referenced the undefined name `sortings_good`.)
well_detected_fraction = np.round(
    len(cmp.get_well_detected_units(good_detection_thr)) / len(sorting_good.unit_ids), 2
)
print(f"Fraction of well detected units from MED: {well_detected_fraction}")

# Neuropixels 2.0

In [None]:
np_version = 2

# Inputs for this probe version: the original binary and its MED-compressed ("_lossy") copy.
np_bin, np_med = (data_folder / f"continuous_np{np_version}{tag}.dat" for tag in ("", "_lossy"))

# Open Ephys session folder; the probe is read from it with pi.read_openephys.
np_folder = (
    "/home/alessio/Documents/data/allen/npix-open-ephys/618382_2022-03-31_14-27-03/"
    if np_version == 1
    else "/home/alessio/Documents/data/allen/npix-open-ephys/595262_2022-02-21_15-18-07"
)


In [None]:
# Per-probe-version folder for saved sortings ("sorting_<name>") and waveform folders ("wf_<name>").
output_folder = Path(f"../data/med/np{np_version}")

### Load recordings

In [None]:
# Binary layout and scaling shared by the raw and MED files (passed to si.read_binary).
num_channels, dtype, fs = 384, "int16", 30000
gain_to_uV, offset_to_uV = 0.195, 0

In [None]:
# Read the probe once, then load the original and the MED-compressed binaries with the
# same layout/scaling and attach that probe to each.
probe = pi.read_openephys(np_folder)
rec, rec_med = (
    si.read_binary(path, sampling_frequency=fs, num_chan=num_channels, dtype=dtype,
                   gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV).set_probe(probe)
    for path in (np_bin, np_med)
)

### Visualization

In [None]:
# Overlay the channels at positions 30-39 of the raw (black) and MED (C0) recordings on one axis.
w = si.plot_timeseries(rec, channel_ids=rec.channel_ids[30:40], color="k", 
                       show_channel_ids=True)
si.plot_timeseries(rec_med, channel_ids=rec.channel_ids[30:40], color="C0", 
                   show_channel_ids=True, ax=w.ax)

In [None]:
# Single channel and a 2 s window (30-32 s) used to compare raw and MED traces.
channel_id = 190

start_frame = int(30 * fs)
end_frame = int(32 * fs)

# Time axis in seconds for the selected frames.
ts = np.arange(start_frame, end_frame) / fs

# Bandpass-filtered versions of both recordings (si.bandpass_filter with its default settings).
rec_f = si.bandpass_filter(rec)
rec_med_f = si.bandpass_filter(rec_med)

# Scaled (return_scaled=True) traces of the single channel: unfiltered and filtered, raw and MED.
tr = rec.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id], 
                    return_scaled=True)
tr_med = rec_med.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id], 
                              return_scaled=True)
tr_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id],
                        return_scaled=True)
tr_med_f = rec_med_f.get_traces(start_frame=start_frame, end_frame=end_frame, channel_ids=[channel_id],
                                return_scaled=True)

In [None]:
# Overlay raw (black) and MED (C0) traces of the selected channel:
# unfiltered on top, bandpass-filtered at the bottom.
fig_tr, axs_tr = plt.subplots(nrows=2, figsize=(10, 7))

axs_tr[0].plot(ts, tr, color="k", alpha=0.8, label="raw")
axs_tr[0].plot(ts, tr_med, color="C0", alpha=0.8, label="med")
axs_tr[0].set_title("Raw")
# Raw strings: in a plain literal "\m" is an invalid escape sequence (SyntaxWarning on
# Python 3.12+). The label text itself is unchanged.
axs_tr[0].set_ylabel(r"V ($\mu$ V)")
axs_tr[0].legend()

axs_tr[1].plot(ts, tr_f, color="k", alpha=0.8)
axs_tr[1].plot(ts, tr_med_f, color="C0", alpha=0.8)
axs_tr[1].set_title("Filtered")
# Only the bottom panel carries the time label.
axs_tr[1].set_xlabel("time (s)")
axs_tr[1].set_ylabel(r"V ($\mu$ V)")

prettify_axes(axs_tr)
fig_tr.subplots_adjust(hspace=0.3)

In [None]:
# Save the trace comparison figure for this probe version.
if save_fig:
    fig_tr.savefig(fig_folder / f"med_traces_np{np_version}.pdf")

### Run spike sorting

In [None]:
# Sorter(s) to run and their per-sorter parameters (forwarded to si.run_sorters).
sorter_list = ["kilosort2_5"]
sorter_params = dict(kilosort2_5=dict(n_jobs_bin=10, total_memory="2G"))

In [None]:
# Spike-sort the raw and the MED recordings, or reload sortings saved by a previous run.
rec_dict = {"raw": rec, "med": rec_med}

if (output_folder / "sorting_raw").is_dir():
    sorting_raw = si.load_extractor(output_folder / "sorting_raw")
    sorting_med = si.load_extractor(output_folder / "sorting_med")
    sort_dict = {"raw": sorting_raw, "med": sorting_med}
else:
    working_folder = output_folder / "working"
    # Clear leftovers from an interrupted run. Test the folder that is actually removed:
    # output_folder can exist without working_folder, and rmtree would then raise
    # FileNotFoundError.
    if working_folder.is_dir():
        shutil.rmtree(working_folder)
    print(f"Running spike sorting with {sorter_list}")
    sortings = si.run_sorters(sorter_list=sorter_list, recording_dict_or_list=rec_dict,
                              working_folder=working_folder, sorter_params=sorter_params,
                              verbose=False, mode_if_folder_exists="keep")

    # Persist each sorting as "sorting_<recording name>" so the branch above can reload it.
    # Result keys are (recording name, sorter name) tuples, unpacked here.
    sort_dict = {}
    for (rec_name, _), sorting in sortings.items():
        sort_dict[rec_name] = sorting.save(folder=output_folder / f"sorting_{rec_name}")

    # clean up the sorter working directory
    shutil.rmtree(working_folder)

print(sort_dict)

In [None]:
# Sortings of the original and of the MED-compressed recording.
sorting_raw, sorting_med = sort_dict["raw"], sort_dict["med"]

In [None]:
# Keep only the raw-sorting units that Kilosort labelled "good" (KSLabel property);
# these are the reference units for the waveform and sorting comparisons below.
selected_units = sorting_raw.unit_ids[sorting_raw.get_property('KSLabel')=="good"]
sorting_good = sorting_raw.select_units(unit_ids=selected_units)
sorting_good

### Waveforms and waveform features

Here we extract waveforms from both the bandpass-filtered raw and MED recordings, using the "good" units of the sorting obtained on the original raw data:

In [None]:
# Same units (raw "good" sorting) extracted from the filtered raw and the filtered MED
# recordings, so templates can be compared one-to-one. load_if_exists=True reuses an
# existing folder from a previous run (as the parameter name indicates).
we_raw = si.extract_waveforms(rec_f, sorting_good, output_folder / "wf_raw", 
                              load_if_exists=True, ms_after=5, **job_kwargs)
we_med = si.extract_waveforms(rec_med_f, sorting_good, output_folder / "wf_med", 
                              load_if_exists=True, ms_after=5, **job_kwargs)

# Name -> waveform extractor, iterated when computing template metrics.
we_dict = {"raw": we_raw, "med": we_med}

In [None]:
# Positions (in we_raw.sorting.unit_ids) of the units whose templates are plotted.
unit_idxs = list(range(8))

In [None]:
# 2-row grid of unit templates: raw waveforms (black, thicker line) vs MED waveforms (C0).
# The legend goes on panel len(unit_idxs) // 2, i.e. the first panel of the second row.
fig_u, ax_u = plt.subplots(nrows=2, ncols=len(unit_idxs) // 2, figsize=(15, 10))
ax_u = ax_u.flatten()

for i, unit_idx in enumerate(unit_idxs):
    unit_id = we_raw.sorting.unit_ids[unit_idx]
    ax = ax_u[i]

    si.plot_unit_templates(we_raw, unit_ids=[unit_id], unit_colors={unit_id: "k"},
                           axes=[ax], radius_um=100, lw=3)
    # label the most recently drawn line so the legend can show it
    ax.get_lines()[-1].set_label("raw")

    si.plot_unit_templates(we_med, unit_ids=[unit_id], unit_colors={unit_id: "C0"},
                           axes=[ax], radius_um=100)
    ax.get_lines()[-1].set_label("med")

    if i != len(unit_idxs) // 2:
        continue
    ax.legend(bbox_to_anchor=(1.2, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)
prettify_axes(ax_u)

In [None]:
# Save the template comparison figure for this probe version.
if save_fig:
    fig_u.savefig(fig_folder / f"med_templates_np{np_version}.pdf")

In [None]:
# Names of the template metrics; each one becomes a column of the metrics table and a row of the plot grid.
template_metrics = si.get_template_metric_names()

In [None]:
# For every "ground-truth" unit (raw "good" sorting), select one channel per target
# distance from the unit's extremum channel: the first channel, in order of increasing
# distance, whose distance is >= the target. The resulting unit -> channel ids dict is
# used as the per-unit sparsity for the template metrics.
target_distances = [0, 30, 60, 90]

extremum_channels = si.get_template_extremum_channel(we_raw)
rec_locs = rec.get_channel_locations()

sparsity = {}
for unit, main_ch in extremum_channels.items():
    main_loc = rec_locs[rec.id_to_index(main_ch)]

    # Euclidean distance of every channel from the unit's extremum channel.
    distances = np.linalg.norm(rec_locs - main_loc, axis=1)
    distances_sort_idxs = np.argsort(distances)
    dist_idxs = np.searchsorted(distances[distances_sort_idxs], target_distances)
    # A target beyond the farthest channel would index past the end; fall back to the
    # farthest channel instead.
    dist_idxs = np.minimum(dist_idxs, len(distances_sort_idxs) - 1)
    sparsity[unit] = rec.channel_ids[distances_sort_idxs[dist_idxs]]

In [None]:
# Template metrics for the raw and MED waveforms on the per-unit channels selected above.
# The raw table becomes df_tm; MED values are added to it as "<metric>_med" columns.
df_tm = None
# channel id -> location, used to compute each row's distance from its unit's extremum channel
ch_locs = dict(zip(rec.channel_ids, rec.get_channel_locations()))
for we_name, we in we_dict.items():
    print(f"Calculating template metrics for {we_name}")
    tm = si.calculate_template_metrics(we, upsample=10,
                                       sparsity=sparsity)
    tm["name"] = [we_name] * len(tm)
    if sparsity is None:
        tm["unit_id"] = tm.index
        tm["distance"] = [0] * len(tm)
    else:
        # With sparsity the index is (unit_id, channel_id); expose both levels as columns.
        index_df = tm.index.to_frame()
        tm["unit_id"] = index_df["unit_id"].values
        tm["channel_id"] = index_df["channel_id"].values

        # Distance of each row's channel from its unit's extremum channel (vectorized;
        # avoids building a query string per unit and an extractor call per row).
        main_locs = np.array([ch_locs[extremum_channels[u]] for u in tm["unit_id"]])
        row_locs = np.array([ch_locs[c] for c in tm["channel_id"]])
        tm["distance"] = np.linalg.norm(row_locs - main_locs, axis=-1)

    if we_name == "raw":
        df_tm = tm
    else:
        # Rows align on the shared (unit_id, channel_id) index.
        for metric in template_metrics:
            df_tm[f"{metric}_med"] = tm[metric]


In [None]:
# Distinct channel-to-extremum distances present in the metrics table; one plot column each.
distances = np.unique(df_tm["distance"])
print(distances)

In [None]:
# Raw (x) vs MED (y) value of every template metric: one row per metric, one column per
# distance from the unit's extremum channel. Points on the dashed identity line are
# unchanged by MED compression.
# squeeze=False keeps axs_m 2-D even with a single metric or a single distance.
fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(distances), figsize=(15, 10),
                            squeeze=False)

for i, metric in enumerate(template_metrics):
    for j, dist in enumerate(distances):
        ax = axs_m[i, j]
        # Exact comparison is safe: `distances` holds the unique values of this column.
        tm_dist = df_tm[df_tm["distance"] == dist]
        sns.scatterplot(data=tm_dist, x=metric, y=f"{metric}_med",
                        color=f"C{j}",
                        ax=ax)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        # Identity line over the raw value range padded by 20% on each side
        # (pandas min/max skip NaN metric values).
        raw_vals = tm_dist[metric]
        lo, hi = raw_vals.min(), raw_vals.max()
        pad = 0.2 * (hi - lo)
        lims = [lo - pad, hi + pad]
        ax.plot(lims, lims, color="grey", alpha=0.7, ls="--")
        ax.axis("equal")
        if i == 0:
            # rf-string: "\m" in a plain f-string is an invalid escape sequence.
            ax.set_title(rf"Dist: {int(dist)} $\mu$m")
        if i == len(template_metrics) - 1:
            ax.set_xlabel("(raw)")
        if j == 0:
            ax.set_ylabel(f"{metric}\n(med)")

prettify_axes(axs_m, label_fs=11)
fig_m.subplots_adjust(wspace=0.1, hspace=0.3)

In [None]:
# Save the template-metric comparison figure for this probe version.
if save_fig:
    fig_m.savefig(fig_folder / f"med_features_np{np_version}.pdf")

### Spike sorting comparison

In [None]:
# Summary of both sortings.
print(sorting_raw, sorting_med)

In [None]:
# Keep only the MED-sorting units that Kilosort labelled "good" (KSLabel property).
selected_units = sorting_med.unit_ids[sorting_med.get_property('KSLabel')=="good"]
sorting_med_good = sorting_med.select_units(unit_ids=selected_units)
sorting_med_good

In [None]:
# Summary of the "good"-only raw and MED sortings.
print(sorting_good, sorting_med_good)

In [None]:
# Treat the raw "good" units as ground truth and score the MED "good" units against them.
cmp = si.compare_sorter_to_ground_truth(sorting_good, sorting_med_good)

In [None]:
# Per-unit comparison scores (one swarm per panel): accuracy, precision, recall.
fig_p, axs_p = plt.subplots(ncols=3, figsize=(15, 6))
sns.swarmplot(y=cmp.get_performance()["accuracy"], ax=axs_p[0], color="g")
sns.swarmplot(y=cmp.get_performance()["precision"], ax=axs_p[1], color="b")
sns.swarmplot(y=cmp.get_performance()["recall"], ax=axs_p[2], color="r")
prettify_axes(axs_p, label_fs=18)
fig_p.subplots_adjust(wspace=0.3)

In [None]:
# Save the spike-sorting comparison figure for this probe version.
if save_fig:
    fig_p.savefig(fig_folder / f"med_ss_np{np_version}.pdf")

In [None]:
# Threshold passed to cmp.get_well_detected_units below (NP1 section uses 0.95).
good_detection_thr = 0.9

In [None]:
# Fraction of ground-truth (raw "good") units that cmp.get_well_detected_units reports
# as well detected at good_detection_thr, rounded to 2 decimals.
well_detected_fraction = np.round(len(cmp.get_well_detected_units(good_detection_thr)) / len(sorting_good.unit_ids),2)
print(f"Fraction of well detected units from MED: {well_detected_fraction}")

## Conclusion

In its current form, MED is too lossy: it strongly affects downstream analyses, altering both waveform shapes and spike-sorting performance.