# Benchmark truncation on GT MEArec data

In this notebook we analyze how bit truncation affects downstream analysis, including spike sorting.

The analysis focuses on:

* assessing if truncating bits from NP1 and NP2 data affects the shapes and features of extracted waveforms/templates (in this case, GT spiking activity is used)
* assessing if spike sorting results are degraded by truncating bits

This notebook assumes that `scripts/benchmark-truncation-gt.py` has been run and that `data/benchmark-truncation-gt.csv` is available.

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

import spikeinterface.full as si

%matplotlib notebook

In [None]:
# NOTE(review): `save_fig` is not read by any visible cell — presumably a toggle for figure export.
save_fig = False

# Figure output folder next to the notebook; created if it does not exist yet.
fig_folder = Path("figures")
fig_folder.mkdir(exist_ok=True)

# Folder containing the ground-truth MEArec files read further below.
mearec_folder = Path("mearec")

In [None]:
# Data folder one level above the notebook; holds the benchmark CSV and study folders.
data_folder = Path("..") / "data"

In [None]:
# Benchmark results table; the default integer index is kept (index_col=False).
benchmark_csv = data_folder / "benchmark-truncation-gt.csv"
res = pd.read_csv(benchmark_csv, index_col=False)

In [None]:
# Display the loaded benchmark table.
res

In [None]:
# Options forwarded as **job_kwargs to the waveform extraction calls below.
job_kwargs = dict(n_jobs=20, chunk_duration="1s", progress_bar=True)

In [None]:
# Split the benchmark results by probe type
res_np1 = res[res["probe"] == "Neuropixels1.0"]
res_np2 = res[res["probe"] == "Neuropixels2.0"]

### Neuropixels 1

In [None]:
# Parameters for the Neuropixels 1.0 section
probe_name = "Neuropixels1.0"
np_version = 1
# highest truncation level included in the template and sorting comparisons
max_truncation = 6
# radius used for the "radius" sparsity of the template metrics
radius_um = 50

In [None]:
# "CR" and "rmse" columns versus truncated bits, one panel each
fig_np1, axs_np1 = plt.subplots(ncols=2)
for ax, metric in zip(axs_np1, ("CR", "rmse")):
    sns.pointplot(data=res_np1, x="trunc_bit", y=metric, ax=ax, color="C1")
    # drop the top/right box lines
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

fig_np1.suptitle("Neuropixels 1.0", fontsize=15)

In [None]:
# Ground-truth recording and sorting for the selected probe version
mearec_file = mearec_folder / f"np{np_version}_mearec_dist-corr.h5"
rec_gt, sort_gt = si.read_mearec(mearec_file)

# Ground-truth waveforms; load_if_exists=True is passed so an existing folder can be reused
we_gt = si.extract_waveforms(
    rec_gt,
    sort_gt,
    folder=mearec_folder / f"wf_{mearec_file.stem}",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Overview of the ground-truth recording, showing every 10th channel
si.plot_timeseries(rec_gt, channel_ids=rec_gt.channel_ids[::10], show_channel_ids=True)

In [None]:
# Compare a short snippet of one channel: ground truth (top), truncated raw
# traces (middle) and truncated band-pass filtered traces (bottom).
channel_index = 350
nsec = 2
t_start = 30
time_range = [t_start, t_start + nsec]
timestamps = np.linspace(time_range[0], time_range[1], int(nsec * rec_gt.get_sampling_frequency()))

fig_np1, axs = plt.subplots(nrows=3, sharex=True, sharey=False)

# The frame window and channel must exist before the first get_traces call.
# They were previously only assigned inside the loop below, so this cell
# raised a NameError on a fresh kernel.
fs = rec_gt.get_sampling_frequency()
start_frame = int(time_range[0] * fs)
end_frame = int(time_range[1] * fs)
channel_ids = [rec_gt.channel_ids[channel_index]]

traces_gt = rec_gt.get_traces(start_frame=start_frame, end_frame=end_frame,
                              channel_ids=channel_ids, return_scaled=True)[:, 0]

axs[0].plot(timestamps, traces_gt, color="k", alpha=0.8, label="GT")

for index, row in res_np1.iterrows():
    trunc_bit = row["trunc_bit"]
    rec_zarr = si.read_zarr(row["rec_zarr_path"])
    # rescale by the row's lsb_value while keeping the stored dtype
    rec_zarr = si.scale(rec_zarr, gain=row.lsb_value, dtype=rec_zarr.get_dtype())
    rec_f = si.bandpass_filter(rec_zarr)
    fs = rec_zarr.get_sampling_frequency()

    # recompute the window and channel from this recording's own metadata
    start_frame = int(time_range[0] * fs)
    end_frame = int(time_range[1] * fs)

    channel_ids = [rec_zarr.channel_ids[channel_index]]

    traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                 channel_ids=channel_ids, return_scaled=True)[:, 0]
    traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids, return_scaled=True)[:, 0]
    axs[1].plot(timestamps, traces, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
    axs[2].plot(timestamps, traces_f, color=f"C{trunc_bit}", alpha=0.8, label=trunc_bit)
axs[1].legend()
fig_np1.suptitle("Neuropixels 1.0", fontsize=15)

In [None]:
# Positions (not ids) of the first seven units, used for the template overlay
unit_idxs = list(range(7))

In [None]:
# Overlay ground-truth and truncated templates, one subplot per selected unit.
fig_u, ax_u = plt.subplots(nrows=len(unit_idxs), figsize=(7, 15))

# Load each truncated waveform extractor once. They do not depend on the unit,
# so reloading them inside the unit loop only repeated the same disk reads.
truncated_wes = []
for index, row in res.iterrows():
    if str(np_version) in row["probe"] and row["trunc_bit"] <= max_truncation:
        truncated_wes.append((row["trunc_bit"], si.WaveformExtractor.load_from_folder(row["we_path"])))

for i, unit_idx in enumerate(unit_idxs):
    ax = ax_u[i]
    unit_id = we_gt.sorting.unit_ids[unit_idx]
    si.plot_unit_templates(we_gt, unit_ids=[unit_id], unit_colors={unit_id: "k"}, axes=[ax], radius_um=100)
    # label the line that was just added so it appears in the legend
    ax.get_lines()[-1].set_label("GT")

    for trunc_bit, we_bit in truncated_wes:
        si.plot_unit_templates(we_bit, unit_ids=[unit_id], unit_colors={unit_id: f"C{trunc_bit}"},
                               axes=[ax], radius_um=100)
        ax.get_lines()[-1].set_label(f"bit{trunc_bit}")
    # a single legend, attached to the middle subplot
    if i == len(unit_idxs) // 2:
        ax.legend(bbox_to_anchor=(1.2, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)

In [None]:
sparsity_dict = dict(method="radius", radius_um=radius_um)

# Compute template metrics for every truncation level of the selected probe.
# The per-level tables are collected in a list and concatenated once at the end.
tm_per_bit = []
for _, res_row in res.iterrows():
    if str(np_version) not in res_row["probe"]:
        continue
    trunc_bit = res_row["trunc_bit"]
    if not trunc_bit <= max_truncation:
        continue

    print(f"Calculating template metrics for trunc bit: {trunc_bit}")
    we_bit = si.WaveformExtractor.load_from_folder(res_row["we_path"])
    tm = si.calculate_template_metrics(we_bit, upsample=10,
                                       sparsity_dict=sparsity_dict)
    tm["trunc_bit"] = [trunc_bit] * len(tm)
    if sparsity_dict is None:
        tm["unit_id"] = tm.index
        tm["channel_idx"] = ["0"] * len(tm)
    else:
        # with sparsity the index carries "unit_id" and "channel_id" levels;
        # expose both as regular columns
        index_frame = tm.index.to_frame()
        tm["unit_id"] = index_frame["unit_id"].values
        tm["channel_id"] = index_frame["channel_id"].values

        # add channel rank: position of each row among the rows of its unit,
        # in table order and stored as a string ("0", "1", ...)
        tm["channel_idx"] = (
            tm.groupby(tm["unit_id"].to_numpy(), sort=False).cumcount().astype(str).to_numpy()
        )

    tm_per_bit.append(tm)

# Same result as the previous incremental build: None if nothing matched, the
# table itself for a single level, a re-indexed concatenation otherwise.
if not tm_per_bit:
    df_tm = None
elif len(tm_per_bit) == 1:
    df_tm = tm_per_bit[0]
else:
    df_tm = pd.concat(tm_per_bit, ignore_index=True)


In [None]:
n_unit_to_plot = 10

if sparsity_dict is None:
    style = "unit_id"
else:
    style = "channel_id"

unit_ids = we_gt.sorting.unit_ids

# Seeded generator so the same random subset of units is drawn on every run
unit_rng = np.random.default_rng(0)
units = unit_ids[unit_rng.permutation(len(unit_ids))[:n_unit_to_plot]]

# One figure per channel rank: peak_to_valley vs half_width for the selected units
for rank in np.unique(df_tm.channel_idx):
    fig, ax = plt.subplots(figsize=(10, 7))

    df_rank = df_tm[df_tm["channel_idx"] == rank]
    # isin avoids building a query string from the repr of the unit ids, which
    # breaks for ids containing quotes or numpy scalar reprs
    df_units = df_rank[df_rank["unit_id"].isin(list(units))]
    sns.scatterplot(data=df_units, x="peak_to_valley",
                    y="half_width", hue="unit_id",
                    size="trunc_bit", palette="tab20", ax=ax)
    ax.legend(ncol=3)

    fig.suptitle(f"Rank channel {rank}", fontsize=15)

In [None]:
# Template metrics for which a relative error against the untruncated data is computed
features = [
    "peak_to_valley",
    "peak_trough_ratio",
    "half_width",
    "repolarization_slope",
    "recovery_slope",
]

In [None]:
# compare with bit trunc 0
for index, row in df_tm.iterrows():
    unit_id = row["unit_id"]
    channel_id = row["channel_id"]
    
    df_ref = df_tm.query(f"unit_id == '{unit_id}' and channel_id == '{channel_id}' and trunc_bit == 0")
        
    if len(df_ref) == 1:
        ref_series = df_ref.iloc[0]
        for feat in features:
            df_tm.at[index, f"err_{feat}"] = abs((row[feat] - ref_series[feat]) / ref_series[feat])
    else:
        for feat in features:
            df_tm.at[index, f"err_{feat}"] = np.nan

In [None]:
# One figure per feature: error on channel rank "0" (left) and on all ranks (right)
for feat in features:
    err_col = f"err_{feat}"
    fig_err, ax_err = plt.subplots(ncols=2, figsize=(10, 7))
    ax_best, ax_all = ax_err

    rank0_tm = df_tm[df_tm["channel_idx"] == "0"]
    sns.boxplot(data=rank0_tm, x="trunc_bit", y=err_col, showfliers=False, ax=ax_best)
    ax_best.set_title("Best channel")

    sns.boxplot(data=df_tm, x="channel_idx", y=err_col, hue="trunc_bit", showfliers=False, ax=ax_all)
    ax_all.set_title("All best channels")

    fig_err.suptitle(feat, fontsize=15)

In [None]:
# Row with the largest recovery-slope error
df_tm.iloc[df_tm["err_recovery_slope"].argmax()]

In [None]:
# Unit of the row with the largest recovery-slope error, plotted in the next cell
worst_position = df_tm["err_recovery_slope"].argmax()
unit_id = df_tm.iloc[worst_position].unit_id

In [None]:
# Template overlay (GT plus each truncation level) for the unit in `unit_id`
fig_u, ax_u = plt.subplots(figsize=(7, 15))

ax = ax_u
si.plot_unit_templates(we_gt, unit_ids=[unit_id], unit_colors={unit_id: "k"}, axes=[ax], radius_um=100)
ax.get_lines()[-1].set_label("GT")

for index, row in res.iterrows():
    # keep only rows of the current probe ...
    if str(np_version) not in row["probe"]:
        continue
    trunc_bit = row["trunc_bit"]
    # ... up to the maximum truncation level
    if not trunc_bit <= max_truncation:
        continue

    we_bit = si.WaveformExtractor.load_from_folder(row["we_path"])
    bit_color = f"C{row['trunc_bit']}"
    si.plot_unit_templates(we_bit, unit_ids=[unit_id], unit_colors={unit_id: bit_color},
                           axes=[ax], radius_um=100)
    ax.get_lines()[-1].set_label(f"bit{row['trunc_bit']}")
ax.legend(bbox_to_anchor=(1.1, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)

### Plot study results

In [None]:
# Study folder for the current probe / sorter combination
sorter = "kilosort2_5"
trunc_folder = data_folder / "tmp_compression_bit_gt" / "trunc_GT"
study_folder = trunc_folder / f"study_{probe_name}_{sorter}"

# NOTE(review): the folder is presumably produced by the benchmark script — confirm
study = si.GroundTruthStudy(study_folder)

In [None]:
# Run the study's comparisons (exhaustive_gt=True)
study.run_comparisons(exhaustive_gt=True, verbose=True)

In [None]:
# Number of units in the ground-truth sorting
len(sort_gt.unit_ids)

In [None]:
# Aggregated study results; the "perf_by_unit" and "count_units" entries are plotted below
dfs = study.aggregate_dataframes()

In [None]:
# x-axis order for the plots below: "bit0" ... f"bit{max_truncation}"
# NOTE(review): presumably these labels match the `rec_name` values of the study — confirm
order = [f"bit{i}" for i in range(max_truncation + 1)]

In [None]:
# Per-unit accuracy / precision / recall per recording, one panel per metric
fig_perf, axs_perf = plt.subplots(nrows=3, figsize=(7, 12))

metric_palettes = (("accuracy", "Greens"), ("precision", "Blues"), ("recall", "Reds"))
for ax_metric, (metric, palette) in zip(axs_perf, metric_palettes):
    sns.swarmplot(data=dfs["perf_by_unit"], x="rec_name", y=metric, order=order,
                  palette=palette, ax=ax_metric)
    # panel title is the capitalised metric name
    ax_metric.set_title(metric.capitalize(), fontsize=15)

fig_perf.subplots_adjust(hspace=0.3)

In [None]:
fig, ax = plt.subplots()
# narrow the axes so the legend fits on the right
bbox = ax.get_position()
bbox.x1 = 0.85
ax.set_position(bbox)
sns.set_palette(sns.color_palette("Set1"))

# long format: one row per (recording, metric) score
df = dfs['perf_by_unit'].melt(
    id_vars='rec_name',
    value_vars=['accuracy', 'precision', 'recall'],
    var_name='Metric',
    value_name='Score',
)
sns.swarmplot(data=df, x='rec_name', y='Score', hue='Metric', dodge=True,
              order=order, ax=ax)
ax.set_xticklabels(order, rotation=30, ha='center')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5)
ax.set_xlabel(None);
ax.set_ylabel('Score');

In [None]:
fig, ax = plt.subplots()
# narrow the axes so the legend fits on the right
bbox = ax.get_position()
bbox.x1 = 0.85
ax.set_position(bbox)

# long format: one row per (recording, unit-count type)
count_columns = ['num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged']
df = dfs['count_units'].melt(id_vars='rec_name', value_vars=count_columns,
                             var_name='Type', value_name='Units')
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(data=df, x='rec_name', y='Units', hue='Type',
            order=order, ax=ax)
ax.set_xticklabels(order, rotation=30, ha='right')
# dashed reference line at the number of ground-truth units
ax.axhline(len(sort_gt.unit_ids), color="grey", ls="--")
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.1)

In [None]:
# Collect every available sorting output, keyed by "<rec_name>_<sort_name>".
# Pairs for which the study returns None are skipped.
dsets = {}
for rec_name in study.rec_names:
    for sort_name in study.sorter_names:
        sorting = study.get_sorting(sort_name, rec_name)
        if sorting is not None:
            # reuse the object fetched above instead of requesting it a second time
            dsets[f"{rec_name}_{sort_name}"] = sorting

In [None]:
# Compare all collected sortings with each other; names and objects stay aligned
sorting_names = list(dsets)
sorting_objects = [dsets[name] for name in sorting_names]
mcmp = si.compare_multiple_sorters(sorting_list=sorting_objects,
                                   name_list=sorting_names,
                                   verbose=True)

In [None]:
si.plot_multicomp_agreement(mcmp)

# One panel per compared sorting on a 4-column grid. The row count grows with
# the number of sortings instead of being fixed at 2x4, and the panel count
# follows `dsets` (what was passed to the comparison), not `study.rec_names`.
n_panels = len(dsets)
n_cols = 4
# ceil division without importing math; never fewer than the original two rows
n_rows = max(2, -(-n_panels // n_cols))

fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_panels]
# hide the unused panels
for ax in axs.ravel()[n_panels:]:
    ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp, axes=axes)

### Neuropixels 2

In [None]:
# Parameters for the Neuropixels 2.0 section
probe_name = "Neuropixels2.0"
np_version = 2
# highest truncation level included in the template and sorting comparisons
max_truncation = 6
# radius used for the "radius" sparsity of the template metrics
radius_um = 30

In [None]:
# "CR" and "rmse" columns versus truncated bits, one panel each
fig_np2, axs_np2 = plt.subplots(ncols=2)
for ax, metric in zip(axs_np2, ("CR", "rmse")):
    sns.pointplot(data=res_np2, x="trunc_bit", y=metric, ax=ax, color="C0")
    # drop the top/right box lines
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

fig_np2.suptitle("Neuropixels 2.0", fontsize=15)

In [None]:
# Ground-truth recording and sorting for the selected probe version
mearec_file = mearec_folder / f"np{np_version}_mearec_dist-corr.h5"
rec_gt, sort_gt = si.read_mearec(mearec_file)

# Ground-truth waveforms; load_if_exists=True is passed so an existing folder can be reused
we_gt = si.extract_waveforms(
    rec_gt,
    sort_gt,
    folder=mearec_folder / f"wf_{mearec_file.stem}",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Overlay ground-truth and truncated templates, one subplot per selected unit.
unit_idxs = [0, 1, 2, 3, 4, 5, 6]

fig_u, ax_u = plt.subplots(nrows=len(unit_idxs), figsize=(7, 15))

# Load each truncated waveform extractor once. They do not depend on the unit,
# so reloading them inside the unit loop only repeated the same disk reads.
truncated_wes = []
for index, row in res.iterrows():
    if str(np_version) in row["probe"] and row["trunc_bit"] <= max_truncation:
        truncated_wes.append((row["trunc_bit"], si.WaveformExtractor.load_from_folder(row["we_path"])))

for i, unit_idx in enumerate(unit_idxs):
    ax = ax_u[i]
    unit_id = we_gt.sorting.unit_ids[unit_idx]
    si.plot_unit_templates(we_gt, unit_ids=[unit_id], unit_colors={unit_id: "k"}, axes=[ax], radius_um=100)
    # label the line that was just added so it appears in the legend
    ax.get_lines()[-1].set_label("GT")

    for trunc_bit, we_bit in truncated_wes:
        si.plot_unit_templates(we_bit, unit_ids=[unit_id], unit_colors={unit_id: f"C{trunc_bit}"},
                               axes=[ax], radius_um=100)
        ax.get_lines()[-1].set_label(f"bit{trunc_bit}")
    # a single legend, attached to the middle subplot
    if i == len(unit_idxs) // 2:
        ax.legend(bbox_to_anchor=(1.2, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)

In [None]:
sparsity_dict = dict(method="radius", radius_um=radius_um)

# Compute template metrics for every truncation level of the selected probe.
# The per-level tables are collected in a list and concatenated once at the end.
tm_per_bit = []
for _, res_row in res.iterrows():
    if str(np_version) not in res_row["probe"]:
        continue
    trunc_bit = res_row["trunc_bit"]
    if not trunc_bit <= max_truncation:
        continue

    print(f"Calculating template metrics for trunc bit: {trunc_bit}")
    we_bit = si.WaveformExtractor.load_from_folder(res_row["we_path"])
    tm = si.calculate_template_metrics(we_bit, upsample=10,
                                       sparsity_dict=sparsity_dict)
    tm["trunc_bit"] = [trunc_bit] * len(tm)
    if sparsity_dict is None:
        tm["unit_id"] = tm.index
        tm["channel_idx"] = ["0"] * len(tm)
    else:
        # with sparsity the index carries "unit_id" and "channel_id" levels;
        # expose both as regular columns
        index_frame = tm.index.to_frame()
        tm["unit_id"] = index_frame["unit_id"].values
        tm["channel_id"] = index_frame["channel_id"].values

        # add channel rank: position of each row among the rows of its unit,
        # in table order and stored as a string ("0", "1", ...)
        tm["channel_idx"] = (
            tm.groupby(tm["unit_id"].to_numpy(), sort=False).cumcount().astype(str).to_numpy()
        )

    tm_per_bit.append(tm)

# Same result as the previous incremental build: None if nothing matched, the
# table itself for a single level, a re-indexed concatenation otherwise.
if not tm_per_bit:
    df_tm = None
elif len(tm_per_bit) == 1:
    df_tm = tm_per_bit[0]
else:
    df_tm = pd.concat(tm_per_bit, ignore_index=True)

In [None]:
n_unit_to_plot = 10

if sparsity_dict is None:
    style = "unit_id"
else:
    style = "channel_id"

unit_ids = we_gt.sorting.unit_ids

# Seeded generator so the same random subset of units is drawn on every run
unit_rng = np.random.default_rng(0)
units = unit_ids[unit_rng.permutation(len(unit_ids))[:n_unit_to_plot]]

# One figure per channel rank: peak_to_valley vs half_width for the selected units
for rank in np.unique(df_tm.channel_idx):
    fig, ax = plt.subplots(figsize=(10, 7))

    df_rank = df_tm[df_tm["channel_idx"] == rank]
    # isin avoids building a query string from the repr of the unit ids, which
    # breaks for ids containing quotes or numpy scalar reprs
    df_units = df_rank[df_rank["unit_id"].isin(list(units))]
    sns.scatterplot(data=df_units, x="peak_to_valley",
                    y="half_width", hue="unit_id",
                    size="trunc_bit", palette="tab20", ax=ax)
    ax.legend(ncol=3)

    fig.suptitle(f"Rank channel {rank}", fontsize=15)

In [None]:
features = ["peak_to_valley", "peak_trough_ratio", "half_width", "repolarization_slope", "recovery_slope"]

# compare with bit trunc 0: relative error |x - x_ref| / |x_ref| of each feature
# against the trunc_bit == 0 row with the same (unit_id, channel_id)
key_cols = ["unit_id", "channel_id"]

# Reference rows. Keys that occur more than once are dropped, so an error is
# only computed when exactly one reference exists (NaN otherwise, as before).
ref_tm = (
    df_tm.loc[df_tm["trunc_bit"] == 0, key_cols + features]
    .reset_index(drop=True)
    .drop_duplicates(subset=key_cols, keep=False)
)

# A left merge on unique reference keys yields one row per df_tm row, in df_tm
# order. This replaces one string-built query per row.
ref_aligned = df_tm[key_cols].reset_index(drop=True).merge(ref_tm, on=key_cols, how="left")

for feat in features:
    ref_values = ref_aligned[feat].to_numpy()
    df_tm[f"err_{feat}"] = np.abs((df_tm[feat].to_numpy() - ref_values) / ref_values)

In [None]:
# One figure per feature: error on channel rank "0" (left) and on all ranks (right)
for feat in features:
    err_col = f"err_{feat}"
    fig_err, ax_err = plt.subplots(ncols=2, figsize=(10, 7))
    ax_best, ax_all = ax_err

    rank0_tm = df_tm[df_tm["channel_idx"] == "0"]
    sns.boxplot(data=rank0_tm, x="trunc_bit", y=err_col, showfliers=False, ax=ax_best)
    ax_best.set_title("Best channel")

    sns.boxplot(data=df_tm, x="channel_idx", y=err_col, hue="trunc_bit", showfliers=False, ax=ax_all)
    ax_all.set_title("All best channels")

    fig_err.suptitle(feat, fontsize=15)

In [None]:
# Row with the largest recovery-slope error, located once and reused
worst_row = df_tm.iloc[df_tm["err_recovery_slope"].argmax()]
print(worst_row)

# its unit is plotted in the next cell
unit_id = worst_row.unit_id

In [None]:
# Template overlay (GT plus each truncation level) for the unit in `unit_id`
fig_u, ax_u = plt.subplots(figsize=(7, 15))

ax = ax_u
si.plot_unit_templates(we_gt, unit_ids=[unit_id], unit_colors={unit_id: "k"}, axes=[ax], radius_um=100)
ax.get_lines()[-1].set_label("GT")

for index, row in res.iterrows():
    # keep only rows of the current probe ...
    if str(np_version) not in row["probe"]:
        continue
    trunc_bit = row["trunc_bit"]
    # ... up to the maximum truncation level
    if not trunc_bit <= max_truncation:
        continue

    we_bit = si.WaveformExtractor.load_from_folder(row["we_path"])
    bit_color = f"C{row['trunc_bit']}"
    si.plot_unit_templates(we_bit, unit_ids=[unit_id], unit_colors={unit_id: bit_color},
                           axes=[ax], radius_um=100)
    ax.get_lines()[-1].set_label(f"bit{row['trunc_bit']}")
ax.legend(bbox_to_anchor=(1.1, 0.2))

fig_u.subplots_adjust(hspace=0.5, right=0.8)

### Plot study results

In [None]:
# Show the study folder for the probe selected in this section. At this point
# the `study_folder` variable still holds the path assigned in the previous
# section, so the path is rebuilt from the current `probe_name`.
trunc_folder / f"study_{probe_name}_{sorter}"

In [None]:
# Study folder for the current probe / sorter combination
sorter = "kilosort2_5"
trunc_folder = data_folder / "tmp_compression_bit_gt" / "trunc_GT"
study_folder = trunc_folder / f"study_{probe_name}_{sorter}"

# Load the study and run its comparisons (exhaustive_gt=True)
study = si.GroundTruthStudy(study_folder)
study.run_comparisons(exhaustive_gt=True, verbose=True)

In [None]:
# Number of units in the ground-truth sorting
print(len(sort_gt.unit_ids))

In [None]:
# Aggregated study results; the "perf_by_unit" and "count_units" entries are plotted below
dfs = study.aggregate_dataframes()

# x-axis order for the plots below: "bit0" ... f"bit{max_truncation}"
# NOTE(review): presumably these labels match the `rec_name` values of the study — confirm
order = [f"bit{i}" for i in range(max_truncation + 1)]

In [None]:
# Per-unit accuracy / precision / recall per recording, one panel per metric
fig_perf, axs_perf = plt.subplots(nrows=3, figsize=(7, 12))

metric_palettes = (("accuracy", "Greens"), ("precision", "Blues"), ("recall", "Reds"))
for ax_metric, (metric, palette) in zip(axs_perf, metric_palettes):
    sns.swarmplot(data=dfs["perf_by_unit"], x="rec_name", y=metric, order=order,
                  palette=palette, ax=ax_metric)
    # panel title is the capitalised metric name
    ax_metric.set_title(metric.capitalize(), fontsize=15)

fig_perf.subplots_adjust(hspace=0.3)


In [None]:
fig, ax = plt.subplots()
# narrow the axes so the legend fits on the right
bbox = ax.get_position()
bbox.x1 = 0.85
ax.set_position(bbox)
sns.set_palette(sns.color_palette("Set1"))

# long format: one row per (recording, metric) score
df = dfs['perf_by_unit'].melt(
    id_vars='rec_name',
    value_vars=['accuracy', 'precision', 'recall'],
    var_name='Metric',
    value_name='Score',
)
sns.swarmplot(data=df, x='rec_name', y='Score', hue='Metric', dodge=True,
              order=order, ax=ax)
ax.set_xticklabels(order, rotation=30, ha='center')
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5)
ax.set_xlabel(None);
ax.set_ylabel('Score');

In [None]:
fig, ax = plt.subplots()
# narrow the axes so the legend fits on the right
bbox = ax.get_position()
bbox.x1 = 0.85
ax.set_position(bbox)

# long format: one row per (recording, unit-count type)
count_columns = ['num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged']
df = dfs['count_units'].melt(id_vars='rec_name', value_vars=count_columns,
                             var_name='Type', value_name='Units')
sns.set_palette(sns.color_palette("Set1"))
sns.barplot(data=df, x='rec_name', y='Units', hue='Type',
            order=order, ax=ax)
ax.set_xticklabels(order, rotation=30, ha='right')
# dashed reference line at the number of ground-truth units
ax.axhline(len(sort_gt.unit_ids), color="grey", ls="--")
ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.1)

In [None]:
# Collect every available sorting output, keyed by "<rec_name>_<sort_name>".
# Pairs for which the study returns None are skipped.
dsets = {}
for rec_name in study.rec_names:
    for sort_name in study.sorter_names:
        sorting = study.get_sorting(sort_name, rec_name)
        if sorting is not None:
            # reuse the object fetched above instead of requesting it a second time
            dsets[f"{rec_name}_{sort_name}"] = sorting

In [None]:
# Compare all collected sortings with each other; names and objects stay aligned
sorting_names = list(dsets)
sorting_objects = [dsets[name] for name in sorting_names]
mcmp = si.compare_multiple_sorters(sorting_list=sorting_objects,
                                   name_list=sorting_names,
                                   verbose=True)

In [None]:
si.plot_multicomp_agreement(mcmp)

# One panel per compared sorting on a 4-column grid. The row count grows with
# the number of sortings instead of being fixed at 2x4, and the panel count
# follows `dsets` (what was passed to the comparison), not `study.rec_names`.
n_panels = len(dsets)
n_cols = 4
# ceil division without importing math; never fewer than the original two rows
n_rows = max(2, -(-n_panels // n_cols))

fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 7))

axes = axs.ravel()[:n_panels]
# hide the unused panels
for ax in axs.ravel()[n_panels:]:
    ax.axis("off")

si.plot_multicomp_agreement_by_sorter(mcmp, axes=axes)