In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import spikeinterface.full as si

# Interactive matplotlib backend used for every figure in this notebook.
# NOTE(review): the "notebook" (nbagg) backend does not work in JupyterLab / Notebook 7;
# `%matplotlib widget` (needs ipympl) is the usual replacement there — confirm target frontend.
%matplotlib notebook

In [None]:
# Output folder for figures, created next to the notebook if missing.
fig_folder = Path("figures")
fig_folder.mkdir(exist_ok=True)

# Toggle for writing figures to disk.
# NOTE(review): not referenced by any later cell visible here.
save_fig = False

In [None]:
# Benchmark summary table. Judging from later cells, it has one row per
# (probe, trunc_bit) run, with paths to the compressed recording and sortings.
res = pd.read_csv(
    filepath_or_buffer="../data/benchmark-truncation-nogt.csv",
    index_col=False,  # do not promote the first column to the index
)

In [None]:
# Split the benchmark table by probe type.
# `.copy()` gives each subset its own data, so later cells can add columns
# (e.g. unit counts) without triggering pandas' SettingWithCopyWarning.
res_np1 = res.query("probe == 'Neuropixels1.0'").copy()
res_np2 = res.query("probe == 'Neuropixels2.0'").copy()

# Fail early if the probe labels in the CSV differ from the strings above;
# every downstream cell assumes both subsets are non-empty.
assert len(res_np1) > 0, "no 'Neuropixels1.0' rows in the benchmark table"
assert len(res_np2) > 0, "no 'Neuropixels2.0' rows in the benchmark table"

In [None]:
# Display the whole benchmark table (both probes, all truncation levels).
res

## Spike sorting results

### Neuropixels 1

In [None]:
# The "CR" and "rmse" columns of the NP1 runs plotted against the number of truncated bits.
fig_np1, axs_np1 = plt.subplots(ncols=2)
for ax, metric in zip(axs_np1, ["CR", "rmse"]):
    sns.pointplot(data=res_np1, x="trunc_bit", y=metric, ax=ax, color="C1")
    # Drop the top/right frame lines for a cleaner look.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

fig_np1.suptitle("Neuropixels 1.0", fontsize=15)

In [None]:
# Original (uncompressed) Open Ephys recording of the NP1 session.
# The data root defaults to the author's local folder; set the NPIX_OPEN_EPHYS_ROOT
# environment variable to run the notebook against another copy of the data.
npix_root = Path(os.environ.get("NPIX_OPEN_EPHYS_ROOT",
                                "/home/alessio/Documents/data/allen/npix-open-ephys"))
rec_file = str(npix_root / "618382_2022-03-31_14-27-03" / "Record Node 102")
rec_exp = si.read_openephys(rec_file, stream_id="0")
# Keep the first element returned by split_recording
# (presumably the first segment of a multi-segment recording — confirm).
rec_exp = si.split_recording(rec_exp)[0]
print(rec_exp)

In [None]:
channel_index = 200
nsec = 2
t_start = 30
time_range = [t_start, t_start + nsec]

# Reference window/channel computed from rec_exp itself, before the loop,
# so this cell also runs on a fresh kernel.
fs_exp = rec_exp.get_sampling_frequency()
start_frame = int(time_range[0] * fs_exp)
end_frame = int(time_range[1] * fs_exp)
channel_ids = [rec_exp.channel_ids[channel_index]]
# Time axis in seconds built from the actual sampling rate, one entry per returned sample.
timestamps = np.arange(start_frame, end_frame) / fs_exp

# Separate figure name so the CR/rmse figure handle (fig_np1) is not overwritten.
fig_np1_traces, axs = plt.subplots(nrows=3, sharex=True, sharey=False)

traces_exp = rec_exp.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids, return_scaled=True)[:, 0]

axs[0].plot(timestamps, traces_exp, color="k", alpha=0.8, label="GT")

for index, row in res_np1.iterrows():
    trunc_bit = row["trunc_bit"]
    rec_zarr = si.read_zarr(row["rec_zarr_path"])
    # NOTE(review): presumably rescales the truncated samples by the run's LSB value — confirm.
    rec_zarr = si.scale(rec_zarr, gain=row.lsb_value, dtype=rec_zarr.get_dtype())
    rec_f = si.bandpass_filter(rec_zarr)
    fs = rec_zarr.get_sampling_frequency()

    start_frame = int(time_range[0] * fs)
    end_frame = int(time_range[1] * fs)
    # Assumes the compressed recording keeps rec_exp's channel order — TODO confirm.
    channel_ids = [rec_zarr.channel_ids[channel_index]]
    # Per-recording time axis so its length always matches the returned traces.
    times = np.arange(start_frame, end_frame) / fs

    traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                 channel_ids=channel_ids, return_scaled=True)[:, 0]
    traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids, return_scaled=True)[:, 0]
    axs[1].plot(times, traces, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
    axs[2].plot(times, traces_f, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
axs[1].legend()
axs[-1].set_xlabel("Time (s)")
fig_np1_traces.suptitle("Neuropixels 1.0", fontsize=15)

In [None]:
# Load the raw and curated sortings of every NP1 run; each run is named "bit<trunc_bit>".
sorting_list_np1 = []
sorting_curated_list_np1 = []
sorter_names_np1 = []

for _, row in res_np1.iterrows():
    sorting_list_np1.append(si.load_extractor(row.sort_path))
    sorting_curated_list_np1.append(si.load_extractor(row.sort_curated_path))
    sorter_names_np1.append(f"bit{row.trunc_bit}")

# Attach the unit counts by building a new frame rather than writing into res_np1 row by row:
# res_np1 is derived from res.query(...), and .loc writes into such a frame can raise
# SettingWithCopyWarning. Lists are in iterrows order, so they align positionally.
res_np1 = res_np1.assign(
    num_units=[len(sorting.unit_ids) for sorting in sorting_list_np1],
    num_units_curated=[len(sorting.unit_ids) for sorting in sorting_curated_list_np1],
)

In [None]:
# Cross-match the raw NP1 sortings from all truncation levels.
mcmp_np1 = si.compare_multiple_sorters(
    sorting_list=sorting_list_np1,
    name_list=sorter_names_np1,
    verbose=True,
)

In [None]:
# Cross-match the curated NP1 sortings from all truncation levels.
mcmp_np1_curated = si.compare_multiple_sorters(
    sorting_list=sorting_curated_list_np1,
    name_list=sorter_names_np1,
    verbose=True,
)

In [None]:
fig_num, axs = plt.subplots(nrows=2, figsize=(7, 10))

sns.barplot(data=res_np1, x="trunc_bit", y="num_units", ax=axs[0],
            palette="Reds")
sns.barplot(data=res_np1, x="trunc_bit", y="num_units_curated", ax=axs[1],
            palette="Blues")
axs[0].set_title("NP1 - RAW")
axs[1].set_title("NP1 - CURATED")
fig_num.subplots_adjust(hspace=0.3)

# Overlay the size of the agreement sorting (minimum_agreement_count = agr_count)
# for every level from 2 up to the number of compared sortings.
for agr_count in range(2, len(sorter_names_np1) + 1):
    sort_agreement = mcmp_np1.get_agreement_sorting(minimum_agreement_count=agr_count)
    sort_agreement_curated = mcmp_np1_curated.get_agreement_sorting(minimum_agreement_count=agr_count)
    n_raw = len(sort_agreement.unit_ids)
    n_curated = len(sort_agreement_curated.unit_ids)
    axs[0].axhline(n_raw, label=f"Agr. {agr_count}", color=f"C{agr_count}")
    # The curated panel must show the curated comparison's agreement count.
    axs[1].axhline(n_curated, label=f"Agr. {agr_count}", color=f"C{agr_count}")

    print(f"Agreement {agr_count}: \n\tRAW {n_raw}\n\tCURATED  {n_curated}")
axs[0].legend()
axs[1].legend()

In [None]:
# Agreement summary for the raw NP1 comparison ...
w = si.plot_multicomp_agreement(mcmp_np1)
w.ax.set_title("NP1 RAW")

# ... and for the curated NP1 comparison.
wc = si.plot_multicomp_agreement(mcmp_np1_curated)
wc.ax.set_title("NP1 CURATED")

In [None]:
# One agreement-by-sorter panel per sorting, laid out on two rows; spare panels are hidden.
# The grid size follows the number of compared sortings.
n_sortings = len(sorter_names_np1)
n_cols = max(1, (n_sortings + 1) // 2)

fig, axs = plt.subplots(nrows=2, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_sortings]
for spare_ax in axs.ravel()[n_sortings:]:
    spare_ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp_np1, axes=axes)

fig.suptitle("NP1 RAW")

figc, axs = plt.subplots(nrows=2, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_sortings]
for spare_ax in axs.ravel()[n_sortings:]:
    spare_ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp_np1_curated, axes=axes)

figc.suptitle("NP1 CURATED")

### Neuropixels 2

In [None]:
# Original (uncompressed) Open Ephys recording of the NP2 session.
# The data root defaults to the author's local folder; set the NPIX_OPEN_EPHYS_ROOT
# environment variable to run the notebook against another copy of the data.
npix_root = Path(os.environ.get("NPIX_OPEN_EPHYS_ROOT",
                                "/home/alessio/Documents/data/allen/npix-open-ephys"))
rec_file = str(npix_root / "595262_2022-02-21_15-18-07" / "Record Node 102")
rec_exp = si.read_openephys(rec_file, stream_id="0")
# Keep the first element returned by split_recording
# (presumably the first segment of a multi-segment recording — confirm).
rec_exp = si.split_recording(rec_exp)[0]
print(rec_exp)

In [None]:
# The "CR" and "rmse" columns of the NP2 runs plotted against the number of truncated bits.
fig_np2, axs_np2 = plt.subplots(ncols=2)
for ax, metric in zip(axs_np2, ["CR", "rmse"]):
    sns.pointplot(data=res_np2, x="trunc_bit", y=metric, ax=ax, color="C0")
    # Drop the top/right frame lines for a cleaner look.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

fig_np2.suptitle("Neuropixels 2.0", fontsize=15)

In [None]:
channel_index = 150

# Reference window/channel computed from the NP2 rec_exp itself instead of reusing
# start_frame / end_frame / channel_ids left over from the NP1 traces loop (which
# point at NP1 channel index 200). time_range is set in the NP1 traces cell.
fs_exp = rec_exp.get_sampling_frequency()
start_frame = int(time_range[0] * fs_exp)
end_frame = int(time_range[1] * fs_exp)
channel_ids = [rec_exp.channel_ids[channel_index]]
# Time axis in seconds built from the actual sampling rate, one entry per returned sample.
timestamps = np.arange(start_frame, end_frame) / fs_exp

# Separate figure name so the CR/rmse figure handle (fig_np2) is not overwritten.
fig_np2_traces, axs = plt.subplots(nrows=4, sharex=True, sharey=False)

traces_exp = rec_exp.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids, return_scaled=True)[:, 0]
rec_exp_f = si.bandpass_filter(rec_exp)
traces_exp_f = rec_exp_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids, return_scaled=True)[:, 0]
axs[0].plot(timestamps, traces_exp, color="k", alpha=0.8, label="GT")
axs[1].plot(timestamps, traces_exp_f, color="k", alpha=0.8, label="GT_f")

for index, row in res_np2.iterrows():
    trunc_bit = row["trunc_bit"]
    rec_zarr = si.read_zarr(row["rec_zarr_path"])
    # NOTE(review): presumably rescales the truncated samples by the run's LSB value — confirm.
    rec_zarr = si.scale(rec_zarr, gain=row.lsb_value, dtype=rec_zarr.get_dtype())
    rec_f = si.bandpass_filter(rec_zarr)
    fs = rec_zarr.get_sampling_frequency()

    start_frame = int(time_range[0] * fs)
    end_frame = int(time_range[1] * fs)
    # Assumes the compressed recording keeps rec_exp's channel order — TODO confirm.
    channel_ids = [rec_zarr.channel_ids[channel_index]]
    # Per-recording time axis so its length always matches the returned traces.
    times = np.arange(start_frame, end_frame) / fs

    traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                 channel_ids=channel_ids, return_scaled=True)[:, 0]
    traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids, return_scaled=True)[:, 0]
    axs[2].plot(times, traces, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
    axs[3].plot(times, traces_f, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
axs[2].legend()
axs[-1].set_xlabel("Time (s)")
fig_np2_traces.suptitle("Neuropixels 2.0", fontsize=15)

In [None]:
# Load the raw and curated sortings of every NP2 run; each run is named "bit<trunc_bit>".
sorting_list_np2 = []
sorting_curated_list_np2 = []
sorter_names_np2 = []

for _, row in res_np2.iterrows():
    sorting_list_np2.append(si.load_extractor(row.sort_path))
    sorting_curated_list_np2.append(si.load_extractor(row.sort_curated_path))
    sorter_names_np2.append(f"bit{row.trunc_bit}")

# Attach the unit counts by building a new frame rather than writing into res_np2 row by row:
# res_np2 is derived from res.query(...), and .loc writes into such a frame can raise
# SettingWithCopyWarning. Lists are in iterrows order, so they align positionally.
res_np2 = res_np2.assign(
    num_units=[len(sorting.unit_ids) for sorting in sorting_list_np2],
    num_units_curated=[len(sorting.unit_ids) for sorting in sorting_curated_list_np2],
)

In [None]:
# Cross-match the raw NP2 sortings from all truncation levels.
mcmp_np2 = si.compare_multiple_sorters(
    sorting_list=sorting_list_np2,
    name_list=sorter_names_np2,
    verbose=True,
)

In [None]:
# Cross-match the curated NP2 sortings from all truncation levels.
mcmp_np2_curated = si.compare_multiple_sorters(
    sorting_list=sorting_curated_list_np2,
    name_list=sorter_names_np2,
    verbose=True,
)

In [None]:
fig_num, axs = plt.subplots(nrows=2, figsize=(7, 10))

sns.barplot(data=res_np2, x="trunc_bit", y="num_units", ax=axs[0],
            palette="Reds")
sns.barplot(data=res_np2, x="trunc_bit", y="num_units_curated", ax=axs[1],
            palette="Blues")
axs[0].set_title("NP2 - RAW")
axs[1].set_title("NP2 - CURATED")
fig_num.subplots_adjust(hspace=0.3)

# Overlay the size of the agreement sorting (minimum_agreement_count = agr_count)
# for every level from 2 up to the number of compared sortings.
for agr_count in range(2, len(sorter_names_np2) + 1):
    sort_agreement = mcmp_np2.get_agreement_sorting(minimum_agreement_count=agr_count)
    sort_agreement_curated = mcmp_np2_curated.get_agreement_sorting(minimum_agreement_count=agr_count)
    n_raw = len(sort_agreement.unit_ids)
    n_curated = len(sort_agreement_curated.unit_ids)
    axs[0].axhline(n_raw, label=f"Agr. {agr_count}", color=f"C{agr_count}")
    # The curated panel must show the curated comparison's agreement count.
    axs[1].axhline(n_curated, label=f"Agr. {agr_count}", color=f"C{agr_count}")

    print(f"Agreement {agr_count}: \n\tRAW {n_raw}\n\tCURATED  {n_curated}")
axs[0].legend()
axs[1].legend()

In [None]:
# Agreement summary for the raw NP2 comparison ...
w = si.plot_multicomp_agreement(mcmp_np2)
w.ax.set_title("NP2 RAW")

# ... and for the curated NP2 comparison.
wc = si.plot_multicomp_agreement(mcmp_np2_curated)
wc.ax.set_title("NP2 CURATED")

In [None]:
# One agreement-by-sorter panel per sorting, laid out on two rows; spare panels are hidden.
# The grid size follows the number of compared sortings.
n_sortings = len(sorter_names_np2)
n_cols = max(1, (n_sortings + 1) // 2)

fig, axs = plt.subplots(nrows=2, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_sortings]
for spare_ax in axs.ravel()[n_sortings:]:
    spare_ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp_np2, axes=axes)

fig.suptitle("NP2 RAW")

figc, axs = plt.subplots(nrows=2, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_sortings]
for spare_ax in axs.ravel()[n_sortings:]:
    spare_ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp_np2_curated, axes=axes)

figc.suptitle("NP2 CURATED")