# Benchmark lossy strategies on GT MEArec data

In this notebook we analyze how lossy compression affects downstream analysis, including spike sorting. 

We use two different strategies:

- Bit truncation
- WavPack hybrid mode

The analysis focuses on:

* compression performance
* influence on spike sorting results
* influence on waveform shapes

This notebook assumes that the `ephys-compression/scripts/benchmark-lossy-gt.py` script has been run and that the `ephys-compression/data/results/benchmark-lossy-gt.csv` and `ephys-compression/data/results/benchmark-lossy-gt-wfs.csv` files are available. Moreover, ground-truth recordings for NP1 and NP2 need to be present in the `ephys-compression/data/mearec` folder (they can be generated with the `ephys-compression/notebooks/generate-gt-neuropixels-recordings.ipynb` notebook).

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import sys
from pathlib import Path

import spikeinterface.full as si

sys.path.append("..")

from audio_numcodecs import WavPackCodec
from utils import prettify_axes

%matplotlib notebook

In [None]:
# Toggle for writing the figures produced below to disk
save_fig = True

# Output folder for the figures of this notebook (created if missing)
fig_folder = Path(".", "figures", "lossy")
fig_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Root of the data tree, relative to the notebook folder
data_folder = Path("..") / "data"

In [None]:
# Load the two benchmark tables written to the results folder
results_folder = data_folder / "results"
res = pd.read_csv(results_folder / "benchmark-lossy-gt.csv", index_col=False)
res_wfs = pd.read_csv(results_folder / "benchmark-lossy-gt-wfs.csv", index_col=False)

In [None]:
# Job options: number of workers, chunk length and progress-bar display
job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

In [None]:
# Per-probe subsets of the results table
res_np1 = res[res["probe"] == "Neuropixels1.0"]
res_np2 = res[res["probe"] == "Neuropixels2.0"]

In [None]:
# Preview the first rows of the results table
res.head(n=5)

In [None]:
# Split the results by compression strategy
res_wv = res[res["strategy"] == "wavpack"]
res_bit = res[res["strategy"] == "bit_truncation"]

# WavPack: 0 first, then the remaining factors in decreasing order
# (np.unique already returns sorted values; the smallest one is dropped and replaced by the leading 0)
wv_factors_desc = np.unique(res_wv["factor"])[::-1]
wv_order = [0, *wv_factors_desc[:-1]]

# Bit truncation: factors in increasing order
bit_order = list(np.unique(res_bit["factor"]))

Compute dataframe with relative errors

In [None]:
# Relative error |tested - gt| / |gt| of every template metric, for each lossy setting.
# The first entry of `bit_order` / `wv_order` is skipped, as in the plots below where it is
# treated as the lossless reference.
template_metrics = si.get_template_metric_names()

lossy_settings = [("bit_truncation", bit) for bit in bit_order[1:]]
lossy_settings += [("wavpack", wv) for wv in wv_order[1:]]

# Collect one frame per (strategy, factor) and concatenate once at the end,
# instead of re-copying the accumulated frame at every iteration.
error_frames = []
for strategy, factor in lossy_settings:
    new_e_df = res_wfs[["probe", "unit_id", "distance"]].copy()
    new_e_df["strategy"] = strategy
    new_e_df["factor"] = factor

    for metric in template_metrics:
        gt_values = res_wfs[f"{metric}_gt"]
        tested_values = res_wfs[f"{metric}_{strategy}_{factor}"]
        # NOTE: a ground-truth value of exactly 0 yields inf/NaN for that unit
        new_e_df[f"err_{metric}"] = np.abs(tested_values - gt_values) / np.abs(gt_values)

    error_frames.append(new_e_df)

# Stays None when there is no lossy setting at all (pd.concat rejects an empty list)
df_errors = pd.concat(error_frames) if error_frames else None

In [None]:
# Both strategies draw their colors from the same "tab10" colormap
bit_cmap = plt.get_cmap("tab10")
wv_cmap = plt.get_cmap("tab10")

# Map each factor to a color, sampling the colormap at position / number-of-factors
bit_colors = {bit: bit_cmap(b / len(bit_order)) for b, bit in enumerate(bit_order)}
wv_colors = {wv: wv_cmap(w / len(wv_order)) for w, wv in enumerate(wv_order)}

In [None]:
# Sorted array of the distinct probe names present in the results
probes = np.unique(res["probe"])

In [None]:
# Path of the ground-truth MEArec file, read from the first results row of each probe
mearec_file_np1 = res_np1["rec_gt"].iloc[0]
mearec_file_np2 = res_np2["rec_gt"].iloc[0]

# Each MEArec file provides a recording and its ground-truth sorting
rec_gt1, sort_gt1 = si.read_mearec(mearec_file_np1)
rec_gt2, sort_gt2 = si.read_mearec(mearec_file_np2)

gt_dict = dict()
gt_dict["Neuropixels1.0"] = dict(rec=rec_gt1, sort=sort_gt1)
gt_dict["Neuropixels2.0"] = dict(rec=rec_gt2, sort=sort_gt2)

# Sampling frequency of the NP1 recording; reused below to convert times to frames for both probes
fs = rec_gt1.get_sampling_frequency()

# Bit truncation

In [None]:
# Preview the first rows of the bit-truncation results
res_bit.head(n=5)

In [None]:
fig_bit_cr_rmse, axs_bit_cr_rmse = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
ax_cr, ax_rmse = axs_bit_cr_rmse

# Left panel: compression ratio per truncation factor, with reference lines at CR = 10 and CR = 5
sns.pointplot(data=res_bit, x="factor", y="CR", hue="probe", ax=ax_cr)
ax_cr.set_ylim(0, 100)
for cr_level, line_style in ((10, "--"), (5, "-.")):
    ax_cr.axhline(cr_level, color="grey", ls=line_style)
for text_y, text in ((10.2, "10"), (5.2, "5")):
    ax_cr.text(-0.5, text_y, text, color="grey", fontsize=12)

# Right panel: RMSE per truncation factor, with a reference line at 1.5
sns.pointplot(data=res_bit, x="factor", y="rmse", hue="probe", ax=ax_rmse)
ax_rmse.set_ylim(-0.5, 7)
ax_rmse.axhline(1.5, color="grey", ls="--")
ax_rmse.text(-0.5, 1.55, "1.5", color="grey", fontsize=12)

prettify_axes(axs_bit_cr_rmse)

fig_bit_cr_rmse.suptitle("Bit truncation", fontsize=20)
fig_bit_cr_rmse.subplots_adjust(hspace=0.3)

In [None]:
# Export the CR / RMSE figure when saving is enabled
if save_fig:
    bit_cr_rmse_path = fig_folder / "bit_cr_rmse.pdf"
    fig_bit_cr_rmse.savefig(bit_cr_rmse_path)

### Check traces

In [None]:
# Overlay a short snippet of the ground-truth traces (black) with the same snippet read back from
# each bit-truncated recording. Top row: raw traces; bottom row: band-pass filtered traces.
# Left column: Neuropixels 1.0; right column: Neuropixels 2.0.
fig_bit_traces, axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(15, 10))

# One channel and one 3 ms window per probe
alpha = 0.7
channel_ids1 = ["351"]
channel_ids2 = ["313"]
nsec = 0.003
t_start1 = 30.044
t_start2 = 30.02

time_range1 = [t_start1, t_start1 + nsec]
time_range2 = [t_start2, t_start2 + nsec]

lw_gt = 3

rec_gt1_f = si.bandpass_filter(rec_gt1)
rec_gt2_f = si.bandpass_filter(rec_gt2)

# --- ground truth, Neuropixels 1.0 ---
time_range = time_range1
start_frame = int(time_range[0] * fs)
end_frame = int(time_range[1] * fs)
# One timestamp per extracted frame. Deriving them from the frame indices guarantees the same
# length as the traces and a spacing of exactly 1 / fs (np.linspace over int(nsec * fs) points
# includes the end point and can be one sample off after float truncation).
timestamps = np.arange(start_frame, end_frame) / fs
traces_gt1 = rec_gt1.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids1, return_scaled=True)[:, 0]

traces_gt1_f = rec_gt1_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids1, return_scaled=True)[:, 0]
axs[0, 0].plot(timestamps, traces_gt1, color="k", alpha=alpha, lw=lw_gt, label="GT")
axs[1, 0].plot(timestamps, traces_gt1_f, color="k", alpha=alpha, lw=lw_gt, label="GT")

# --- ground truth, Neuropixels 2.0 ---
time_range = time_range2
start_frame = int(time_range[0] * fs)
end_frame = int(time_range[1] * fs)
timestamps = np.arange(start_frame, end_frame) / fs
traces_gt2 = rec_gt2.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids2, return_scaled=True)[:, 0]

traces_gt2_f = rec_gt2_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids2, return_scaled=True)[:, 0]
axs[0, 1].plot(timestamps, traces_gt2, color="k", alpha=alpha, lw=lw_gt, label="GT")
axs[1, 1].plot(timestamps, traces_gt2_f, color="k", alpha=alpha, lw=lw_gt, label="GT")

# --- bit-truncated recordings ---
for bit in bit_order:
    if bit <= 0:  # skip lossless
        continue
    for probe in probes:
        row = res_bit.query(f"factor == {bit} and probe == '{probe}'").iloc[0]
        factor = row["factor"]
        rec_zarr = si.read_zarr(row["rec_zarr_path"])
        rec_f = si.bandpass_filter(rec_zarr)
        # probe names containing "1" go to the left column, the others to the right one
        if "1" in probe:
            ax_idx = 0
            channel_ids = channel_ids1
            time_range = time_range1
        else:
            ax_idx = 1
            channel_ids = channel_ids2
            time_range = time_range2
        start_frame = int(time_range[0] * fs)
        end_frame = int(time_range[1] * fs)
        timestamps = np.arange(start_frame, end_frame) / fs

        traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                     channel_ids=channel_ids, return_scaled=True)[:, 0]
        traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids, return_scaled=True)[:, 0]

        axs[0, ax_idx].plot(timestamps, traces, color=bit_colors[factor], alpha=alpha, label=f"bit{factor}")
        axs[1, ax_idx].plot(timestamps, traces_f, color=bit_colors[factor], alpha=alpha, label=f"bit{factor}")


axs[0, 0].set_title("Neuropixels 1.0\nRaw")
axs[0, 1].set_title("Neuropixels 2.0\nRaw")
axs[1, 0].set_title("Filtered")
axs[1, 1].set_title("Filtered")
axs[0, 0].legend(ncol=4)
axs[0, 1].legend(ncol=4)

# raw strings: "\m" is not a valid escape sequence in a regular string literal
axs[0, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_xlabel("time (s)")
axs[1, 1].set_xlabel("time (s)")

prettify_axes(axs)
fig_bit_traces.suptitle("Bit truncation - traces", fontsize=15)

In [None]:
# Export the traces figure when saving is enabled
if save_fig:
    bit_traces_path = fig_folder / "bit_traces.pdf"
    fig_bit_traces.savefig(bit_traces_path)

In [None]:
# NOTE(review): 0.195 is presumably the gain converting raw counts to uV for both probes -- confirm
gain_to_uV = 0.195
np1_lsb_uV = gain_to_uV * res_np1["lsb_value"].iloc[0]
np2_lsb_uV = gain_to_uV * res_np2["lsb_value"].iloc[0]

In [None]:
# LSB value scaled by 2**bit for every truncation factor, per probe
for bit in bit_order:
    scale = 2**bit
    print("NP1", np1_lsb_uV * scale, "NP2", np2_lsb_uV * scale)

In [None]:
# LaTeX table of CR and RMSE per truncation factor (Neuropixels 2.0 only)
res_bit_np2 = res_bit.query("probe == 'Neuropixels2.0'").sort_values("factor")
print(res_bit_np2.to_latex(columns=["factor", "CR", "rmse"], index=False))

### Spike sorting 

In [None]:
# Spike sorting outcome per bit-truncation factor.
# Top row: average accuracy / precision / recall; bottom row: unit counts by category.
# Left column: Neuropixels 1.0; right column: Neuropixels 2.0.
fig_bit_ss, axs_bit_ss = plt.subplots(ncols=2, nrows=2, figsize=(15, 6))

bit_labels = [int(b) for b in bit_order]

for probe in probes:
    # probe names containing "1" go to the left column, the others to the right one
    col = 0 if "1" in probe else 1

    res_probe = res_bit.query(f"probe == '{probe}'")

    # performance metrics (long format so that seaborn can group the bars by metric)
    ax = axs_bit_ss[0, col]
    df_perf = pd.melt(res_probe, id_vars='factor', var_name='metric', value_name='value', 
                      value_vars=('avg_accuracy', 'avg_precision', 'avg_recall'))
    sns.barplot(x='factor', y='value', hue='metric', data=df_perf,
                order=bit_order, ax=ax, palette=sns.color_palette("Set2"))
    ax.set_xticklabels(bit_labels, rotation=30, ha='right')
    ax.legend(loc=3)
    ax.set_ylabel("")
    ax.set_xlabel("")

    # unit counts; the dashed line marks 100 units
    ax = axs_bit_ss[1, col]
    df_count = pd.melt(res_probe, id_vars='factor', var_name='Type', value_name='Units', 
                       value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=bit_order, ax=ax, palette=sns.color_palette("Set2"))
    ax.set_xticklabels(bit_labels, rotation=30, ha='right')
    ax.legend(loc=2)
    ax.set_ylabel("# units")
    ax.set_xlabel("")
    ax.set_ylim(-1, 500)
    ax.axhline(100, color="grey", ls="--")

axs_bit_ss[0, 0].set_title("Neuropixels 1.0", fontsize=18)
axs_bit_ss[0, 1].set_title("Neuropixels 2.0", fontsize=18)
axs_bit_ss[1, 0].set_xlabel("# bit")
axs_bit_ss[1, 1].set_xlabel("# bit")
# Only the top-left panel is labelled "avg. value"; the bottom row shows unit counts and keeps
# the "# units" label set in the loop (it used to be overwritten on the bottom-left panel).
axs_bit_ss[0, 0].set_ylabel("avg. value")


prettify_axes(axs_bit_ss)

fig_bit_ss.suptitle("Spike sorting performance", fontsize=20)
fig_bit_ss.subplots_adjust(hspace=0.3)

In [None]:
# Export the spike sorting figure when saving is enabled
if save_fig:
    bit_ss_path = fig_folder / "bit_ss.pdf"
    fig_bit_ss.savefig(bit_ss_path)

### Waveforms

In [None]:
# Distribution of the relative template-metric errors per truncation factor, split by distance.
# One row per probe, one column per template metric.
fig_bit_feat, ax_bit_feat = plt.subplots(nrows=2, ncols=len(template_metrics), figsize=(15, 10))

df_errors_bit = df_errors.query("strategy == 'bit_truncation'")

for probe in probes:
    # probe names containing "1" go to the top row, the others to the bottom one
    row = 0 if "1" in probe else 1
    df_errors_probe = df_errors_bit.query(f"probe == '{probe}'")
    is_bottom_row = row == 1
    for i, metric in enumerate(template_metrics):
        ax = ax_bit_feat[row, i]
        sns.boxenplot(
            data=df_errors_probe,
            x="factor",
            y=f"err_{metric}",
            hue="distance",
            order=bit_order[1:],
            showfliers=False,
            ax=ax,
            palette=sns.color_palette("tab10"),
        )
        # x label only on the bottom row, y label only on the first column, title only on the top row
        ax.set_xlabel("# bit" if is_bottom_row else "")
        ax.set_ylabel(f"{probe}\n values" if i == 0 else "")
        if not is_bottom_row:
            ax.set_title(metric, fontsize=15)
        # dashed reference line at a relative error of 0.1
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(-0.05, 0.5)

prettify_axes(ax_bit_feat, label_fs=15)
fig_bit_feat.suptitle("Waveform features", fontsize=20)
fig_bit_feat.subplots_adjust(hspace=0.3, wspace=0.2)

In [None]:
# Export the feature-error figure when saving is enabled
if save_fig:
    bit_feat_path = fig_folder / "bit_feature_errors.pdf"
    fig_bit_feat.savefig(bit_feat_path)

In [None]:
# One figure per (probe, truncation factor): scatter of each template metric computed on the
# truncated data against its ground-truth value. Rows are metrics, columns are distances;
# the dashed grey line is the identity.
distances = np.unique(res_wfs["distance"])
strategy = "bit_truncation"
figs_bit_features = {}
for probe in probes:
    df_wfs_probe = res_wfs.query(f"probe == '{probe}'")
    # the per-distance subsets depend neither on the factor nor on the metric: select them once
    df_wfs_by_dist = [df_wfs_probe.query(f"distance == {dist}") for dist in distances]
    for bit in bit_order[1:]:
        fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(distances), figsize=(15, 10))

        for i, metric in enumerate(template_metrics):
            gt_metric_name = f"{metric}_gt"
            for j, (dist, tm_dist) in enumerate(zip(distances, df_wfs_by_dist)):
                ax = axs_m[i, j]
                sns.scatterplot(data=tm_dist, x=gt_metric_name, y=f"{metric}_{strategy}_{bit}", 
                                color=f"C{j}",
                                ax=ax)
                ax.set_yticks([])
                ax.set_xticks([])
                ax.set_xlabel("")
                ax.set_ylabel("")
                # identity line spanning the ground-truth range extended by 20% on each side
                gt_values = tm_dist[gt_metric_name]
                margin = 0.2 * np.ptp(gt_values)
                lims = [np.min(gt_values) - margin, np.max(gt_values) + margin]
                ax.plot(lims, lims, color="grey", alpha=0.7, ls="--")
                ax.axis("equal")
                if i == 0:
                    # raw f-string: "\m" is not a valid escape sequence in a regular string literal
                    ax.set_title(rf"Dist: {int(dist)} $\mu$m")
                if i == len(template_metrics) - 1:
                    ax.set_xlabel("(gt)")
                if j == 0:
                    ax.set_ylabel(f"{metric}\n({bit})")

        prettify_axes(axs_m, label_fs=11)
        fig_m.suptitle(f"{probe} - {strategy}-{bit}")
        fig_m.subplots_adjust(wspace=0.1, hspace=0.3)
        # figures are kept open on purpose: they are saved from this dict in the next cell
        figs_bit_features[f"{probe}_{strategy}_{bit}"] = fig_m

In [None]:
# Export every per-(probe, factor) scatter figure into its own sub-folder
if save_fig:
    fig_features_folder = fig_folder / "features_bit"
    fig_features_folder.mkdir(exist_ok=True)
    for fig_name in figs_bit_features:
        fig = figs_bit_features[fig_name]
        fig.savefig(fig_features_folder / f"{fig_name}.pdf")

# WavPack Hybrid

In [None]:
# Preview the first rows of the WavPack results
res_wv.head(n=5)

In [None]:
fig_wv_cr_rmse, axs_wv_cr_rmse = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
ax_cr, ax_rmse = axs_wv_cr_rmse

# Left panel: compression ratio per WavPack hybrid factor, with reference lines at CR = 10 and CR = 5
sns.pointplot(data=res_wv, x="factor", y="CR", hue="probe", ax=ax_cr, order=wv_order)
ax_cr.set_ylim(0, 11)
for cr_level, line_style in ((10, "--"), (5, "-.")):
    ax_cr.axhline(cr_level, color="grey", ls=line_style)
for text_y, text in ((10.2, "10"), (5.2, "5")):
    ax_cr.text(-0.5, text_y, text, color="grey", fontsize=12)

# Right panel: RMSE per WavPack hybrid factor, with a reference line at 1.5
sns.pointplot(data=res_wv, x="factor", y="rmse", hue="probe", ax=ax_rmse, order=wv_order)
ax_rmse.set_ylim(-0.5, 7)
ax_rmse.axhline(1.5, color="grey", ls="--")
ax_rmse.text(-0.5, 1.55, "1.5", color="grey", fontsize=12)

prettify_axes(axs_wv_cr_rmse)

fig_wv_cr_rmse.suptitle("WavPack Hybrid", fontsize=20)
fig_wv_cr_rmse.subplots_adjust(hspace=0.3)

In [None]:
# Export the CR / RMSE figure when saving is enabled
if save_fig:
    wv_cr_rmse_path = fig_folder / "wv_cr_rmse.pdf"
    fig_wv_cr_rmse.savefig(wv_cr_rmse_path)

### Check traces

In [None]:
# Overlay a short snippet of the ground-truth traces (black) with the same snippet read back from
# each WavPack-hybrid recording. Top row: raw traces; bottom row: band-pass filtered traces.
# Left column: Neuropixels 1.0; right column: Neuropixels 2.0.
fig_wv_traces, axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(15, 10))

# One channel and one 3 ms window per probe
alpha = 0.8
channel_ids1 = ["351"]
channel_ids2 = ["313"]
nsec = 0.003
t_start1 = 30.044
t_start2 = 30.02

time_range1 = [t_start1, t_start1 + nsec]
time_range2 = [t_start2, t_start2 + nsec]

lw_gt = 3

rec_gt1_f = si.bandpass_filter(rec_gt1)
rec_gt2_f = si.bandpass_filter(rec_gt2)

# --- ground truth, Neuropixels 1.0 ---
time_range = time_range1
start_frame = int(time_range[0] * fs)
end_frame = int(time_range[1] * fs)
# One timestamp per extracted frame. Deriving them from the frame indices guarantees the same
# length as the traces and a spacing of exactly 1 / fs (np.linspace over int(nsec * fs) points
# includes the end point and can be one sample off after float truncation).
timestamps = np.arange(start_frame, end_frame) / fs
traces_gt1 = rec_gt1.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids1, return_scaled=True)[:, 0]

traces_gt1_f = rec_gt1_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids1, return_scaled=True)[:, 0]
axs[0, 0].plot(timestamps, traces_gt1, color="k", alpha=alpha, lw=lw_gt, label="GT")
axs[1, 0].plot(timestamps, traces_gt1_f, color="k", alpha=alpha, lw=lw_gt, label="GT")

# --- ground truth, Neuropixels 2.0 ---
time_range = time_range2
start_frame = int(time_range[0] * fs)
end_frame = int(time_range[1] * fs)
timestamps = np.arange(start_frame, end_frame) / fs
traces_gt2 = rec_gt2.get_traces(start_frame=start_frame, end_frame=end_frame,
                                channel_ids=channel_ids2, return_scaled=True)[:, 0]

traces_gt2_f = rec_gt2_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids2, return_scaled=True)[:, 0]
axs[0, 1].plot(timestamps, traces_gt2, color="k", alpha=alpha, lw=lw_gt, label="GT")
axs[1, 1].plot(timestamps, traces_gt2_f, color="k", alpha=alpha, lw=lw_gt, label="GT")

# --- WavPack-hybrid recordings ---
for wv in wv_order:
    if wv <= 0:  # skip lossless
        continue
    for probe in probes:
        row = res_wv.query(f"factor == {wv} and probe == '{probe}'").iloc[0]
        factor = row["factor"]
        rec_zarr = si.read_zarr(row["rec_zarr_path"])
        rec_f = si.bandpass_filter(rec_zarr)
        # probe names containing "1" go to the left column, the others to the right one
        if "1" in probe:
            ax_idx = 0
            channel_ids = channel_ids1
            time_range = time_range1
        else:
            ax_idx = 1
            channel_ids = channel_ids2
            time_range = time_range2
        start_frame = int(time_range[0] * fs)
        end_frame = int(time_range[1] * fs)
        timestamps = np.arange(start_frame, end_frame) / fs

        traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                     channel_ids=channel_ids, return_scaled=True)[:, 0]
        traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids, return_scaled=True)[:, 0]

        axs[0, ax_idx].plot(timestamps, traces, color=wv_colors[factor], alpha=alpha, label=f"wv{factor}")
        axs[1, ax_idx].plot(timestamps, traces_f, color=wv_colors[factor], alpha=alpha, label=f"wv{factor}")


axs[0, 0].set_title("Neuropixels 1.0\nRaw")
axs[0, 1].set_title("Neuropixels 2.0\nRaw")
axs[1, 0].set_title("Filtered")
axs[1, 1].set_title("Filtered")
axs[0, 0].legend(ncol=4)
axs[0, 1].legend(ncol=4)

# raw strings: "\m" is not a valid escape sequence in a regular string literal
axs[0, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_xlabel("time (s)")
axs[1, 1].set_xlabel("time (s)")

prettify_axes(axs)
fig_wv_traces.suptitle("WavPack Hybrid - traces", fontsize=15)

In [None]:
# Export the traces figure when saving is enabled
if save_fig:
    wv_traces_path = fig_folder / "wv_traces.pdf"
    fig_wv_traces.savefig(wv_traces_path)

In [None]:
# LaTeX table of CR and RMSE per hybrid factor (decreasing factor), one table per probe
for probe_name in ("Neuropixels1.0", "Neuropixels2.0"):
    res_wv_probe = res_wv.query(f"probe == '{probe_name}'").sort_values("factor", ascending=False)
    print(res_wv_probe.to_latex(columns=["probe", "factor", "CR", "rmse"], index=False))

In [None]:
# Spike sorting outcome per WavPack hybrid factor.
# Top row: average accuracy / precision / recall; bottom row: unit counts by category.
# Left column: Neuropixels 1.0; right column: Neuropixels 2.0.
fig_wv_ss, axs_wv_ss = plt.subplots(ncols=2, nrows=2, figsize=(15, 6))

wv_labels = wv_order

for probe in probes:
    # probe names containing "1" go to the left column, the others to the right one
    col = 0 if "1" in probe else 1

    res_probe = res_wv.query(f"probe == '{probe}'")

    # performance metrics (long format so that seaborn can group the bars by metric)
    ax = axs_wv_ss[0, col]
    df_perf = pd.melt(res_probe, id_vars='factor', var_name='metric', value_name='value', 
                      value_vars=('avg_accuracy', 'avg_precision', 'avg_recall'))
    sns.barplot(x='factor', y='value', hue='metric', data=df_perf,
                order=wv_order, ax=ax, palette=sns.color_palette("Set2"))
    ax.set_xticklabels(wv_labels, rotation=30, ha='right')
    ax.legend(loc=3)
    ax.set_ylabel("")
    ax.set_xlabel("")

    # unit counts; the dashed line marks 100 units
    ax = axs_wv_ss[1, col]
    df_count = pd.melt(res_probe, id_vars='factor', var_name='Type', value_name='Units', 
                       value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=wv_order, ax=ax, palette=sns.color_palette("Set1"))
    ax.set_xticklabels(wv_labels, rotation=30, ha='right')
    ax.legend(loc=2, ncol=2)
    ax.set_ylabel("# units")
    ax.set_xlabel("")
    ax.set_ylim(-1, 150)
    ax.axhline(100, color="grey", ls="--")

axs_wv_ss[0, 0].set_title("Neuropixels 1.0", fontsize=18)
axs_wv_ss[0, 1].set_title("Neuropixels 2.0", fontsize=18)
axs_wv_ss[1, 0].set_xlabel("hybrid factor")
axs_wv_ss[1, 1].set_xlabel("hybrid factor")
# Only the top-left panel is labelled "avg. value"; the bottom row shows unit counts and keeps
# the "# units" label set in the loop (it used to be overwritten on the bottom-left panel).
axs_wv_ss[0, 0].set_ylabel("avg. value")


prettify_axes(axs_wv_ss)

fig_wv_ss.suptitle("Spike sorting performance", fontsize=20)
fig_wv_ss.subplots_adjust(hspace=0.3)

In [None]:
# Export the spike sorting figure when saving is enabled
if save_fig:
    wv_ss_path = fig_folder / "wv_ss.pdf"
    fig_wv_ss.savefig(wv_ss_path)

In [None]:
# Distribution of the relative template-metric errors per hybrid factor, split by distance.
# One row per probe, one column per template metric.
fig_wv_feat, ax_wv_feat = plt.subplots(nrows=2, ncols=len(template_metrics), figsize=(15, 10))

df_errors_wv = df_errors.query("strategy == 'wavpack'")

for probe in probes:
    # probe names containing "1" go to the top row, the others to the bottom one
    row = 0 if "1" in probe else 1
    df_errors_probe = df_errors_wv.query(f"probe == '{probe}'")
    is_bottom_row = row == 1
    for i, metric in enumerate(template_metrics):
        ax = ax_wv_feat[row, i]
        sns.boxenplot(
            data=df_errors_probe,
            x="factor",
            y=f"err_{metric}",
            hue="distance",
            order=wv_order[1:],
            showfliers=False,
            ax=ax,
        )
        # x label only on the bottom row, y label only on the first column, title only on the top row
        ax.set_xlabel("hybrid factor" if is_bottom_row else "")
        ax.set_ylabel(f"{probe}\n values" if i == 0 else "")
        if not is_bottom_row:
            ax.set_title(metric, fontsize=15)
        # dashed reference line at a relative error of 0.1
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(-0.05, 0.5)

prettify_axes(ax_wv_feat, label_fs=15)
fig_wv_feat.suptitle("Waveform features", fontsize=20)
fig_wv_feat.subplots_adjust(hspace=0.3, wspace=0.2)

In [None]:
# Export the feature-error figure when saving is enabled
if save_fig:
    wv_feat_path = fig_folder / "wv_feature_errors.pdf"
    fig_wv_feat.savefig(wv_feat_path)

In [None]:
# One figure per (probe, hybrid factor): scatter of each template metric computed on the
# WavPack-compressed data against its ground-truth value. Rows are metrics, columns are distances;
# the dashed grey line is the identity.
distances = np.unique(res_wfs["distance"])
strategy = "wavpack"
figs_wv_features = {}
for probe in probes:
    df_wfs_probe = res_wfs.query(f"probe == '{probe}'")
    # the per-distance subsets depend neither on the factor nor on the metric: select them once
    df_wfs_by_dist = [df_wfs_probe.query(f"distance == {dist}") for dist in distances]
    for wv in wv_order[1:]:
        fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(distances), figsize=(15, 10))

        for i, metric in enumerate(template_metrics):
            gt_metric_name = f"{metric}_gt"
            for j, (dist, tm_dist) in enumerate(zip(distances, df_wfs_by_dist)):
                ax = axs_m[i, j]
                sns.scatterplot(data=tm_dist, x=gt_metric_name, y=f"{metric}_{strategy}_{wv}", 
                                color=f"C{j}",
                                ax=ax)
                ax.set_yticks([])
                ax.set_xticks([])
                ax.set_xlabel("")
                ax.set_ylabel("")
                # identity line spanning the ground-truth range extended by 20% on each side
                gt_values = tm_dist[gt_metric_name]
                margin = 0.2 * np.ptp(gt_values)
                lims = [np.min(gt_values) - margin, np.max(gt_values) + margin]
                ax.plot(lims, lims, color="grey", alpha=0.7, ls="--")
                ax.axis("equal")
                if i == 0:
                    # raw f-string: "\m" is not a valid escape sequence in a regular string literal
                    ax.set_title(rf"Dist: {int(dist)} $\mu$m")
                if i == len(template_metrics) - 1:
                    ax.set_xlabel("(gt)")
                if j == 0:
                    ax.set_ylabel(f"{metric}\n({wv})")

        prettify_axes(axs_m, label_fs=11)
        fig_m.suptitle(f"{probe} - {strategy}-{wv}")
        fig_m.subplots_adjust(wspace=0.1, hspace=0.3)
        # figures are kept open on purpose: they are saved from this dict in the next cell
        figs_wv_features[f"{probe}_{strategy}_{wv}"] = fig_m

In [None]:
# Export every per-(probe, factor) scatter figure into its own sub-folder
if save_fig:
    fig_features_folder = fig_folder / "features_wv"
    fig_features_folder.mkdir(exist_ok=True)
    for fig_name in figs_wv_features:
        fig = figs_wv_features[fig_name]
        fig.savefig(fig_features_folder / f"{fig_name}.pdf")