# Lossy compression


This notebook reproduces the results of the `Lossy compression` section of the paper.

We assume that the `ephys-compression/scripts/benchmark-lossly-sim.py` and `ephys-compression/scripts/benchmark-lossly-exp.py` scripts have been run and that the `../data/ephys-compression-results/results-lossy-sim` and `../data/ephys-compression-results/results-lossy-exp` folders are available.

In [None]:
import warnings
# Silence all warnings so that the rendered notebook only shows the figures.
# NOTE(review): this blanket filter also hides any pandas / seaborn / matplotlib
# warning raised by the cells below.
warnings.filterwarnings("ignore")
# Redundant with the blanket filter above (DeprecationWarning is already ignored).
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import spikeinterface as si

from utils import prettify_axes

%matplotlib inline

In [None]:
# Figure sizes passed to `plt.figure(figsize=...)`. All share the same width (10);
# only the height changes. The 3-row and 4-row layouts use the same size.
figsize_single = (10, 5)
figsize_multi_2rows = (10, 9)
figsize_multi_3rows = figsize_multi_4rows = (10, 12)

In [None]:
# Toggle for writing the figures to disk in the "save" cells below
save_fig = True

# Root folders, relative to the notebook location
data_folder, results_folder = Path("../data/"), Path("../results/")

# Benchmark outputs (simulated and experimental recordings)
results_lossy_sim_folder = data_folder.joinpath("ephys-compression-results", "results-lossy-sim")
results_lossy_exp_folder = data_folder.joinpath("ephys-compression-results", "results-lossy-exp")

# Destination of the saved figures; created up front (with parents) if missing
fig_folder = results_folder.joinpath("figures")
fig_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the three benchmark tables (simulated, simulated waveform metrics, experimental).
# index_col=False tells pandas not to use the first column as the index.
res_sim, res_sim_wfs, res_exp = (
    pd.read_csv(folder / csv_name, index_col=False)
    for folder, csv_name in (
        (results_lossy_sim_folder, "benchmark-lossy-sim.csv"),
        (results_lossy_sim_folder, "benchmark-lossy-sim-waveforms.csv"),
        (results_lossy_exp_folder, "benchmark-lossy-exp.csv"),
    )
)

In [None]:
# Replace the long probe names found in the "probe" column of every results table
# by the short labels used in the figures.
probe_names = {"Neuropixels1.0": "NP1",
               "Neuropixels2.0": "NP2"}
for res_df in (res_sim, res_sim_wfs, res_exp):
    for probe, probe_name in probe_names.items():
        res_df.loc[res_df["probe"] == probe, "probe"] = probe_name
# Short probe labels, in plotting order
probes = list(probe_names.values())

In [None]:
# One colormap per lossy strategy (their names are later used as seaborn palettes)
bit_cmap, wv_cmap = plt.get_cmap("Purples_r"), plt.get_cmap("Greens_r")

# Split both result tables by lossy strategy
res_exp_wv, res_exp_bit = (
    res_exp.query(f"strategy == '{strategy_name}'") for strategy_name in ("wavpack", "bit_truncation")
)
res_sim_wv, res_sim_bit = (
    res_sim.query(f"strategy == '{strategy_name}'") for strategy_name in ("wavpack", "bit_truncation")
)

# WavPack factors: 0 first, then the unique factors in decreasing order with the
# smallest one dropped (presumably the smallest factor is the lossless 0 — TODO confirm).
# np.unique already returns sorted values.
wv_order = [0, *np.unique(res_exp_wv.factor)[::-1][:-1]]
wv_labels = wv_order

# Bit-truncation factors in increasing order, with integer tick labels
bit_order = list(np.unique(res_exp_bit.factor))
bit_labels = list(map(int, bit_order))

### CR *versus* RMSE 

In [None]:
# Figure 7 (experimental data): compression ratio (upper panels) and RMSE (lower panels)
# for bit truncation (left column) and WavPack Hybrid (right column).
fig_7 = plt.figure(figsize=figsize_multi_2rows)

color = bit_cmap.name

# 7 x 2 grid: rows 0-2 hold the CR panels, row 3 is left empty as a spacer,
# rows 4-6 hold the RMSE panels.
gs = GridSpec(7, 2, hspace=0.1, wspace=0.3)
ax1_top = fig_7.add_subplot(gs[0, 0])
ax1_mid = fig_7.add_subplot(gs[1, 0])
ax1_bottom = fig_7.add_subplot(gs[2, 0])
ax12 = fig_7.add_subplot(gs[:3, 1])
ax21 = fig_7.add_subplot(gs[4:, 0])
ax22 = fig_7.add_subplot(gs[4:, 1])

axs = [ax1_top, ax1_mid, ax1_bottom, ax12, ax21, ax22]

# bit truncation: the same CR curve is drawn on three stacked axes which together
# form a broken y-axis (each axis shows a different y-range)
for ax_cr in (ax1_top, ax1_mid, ax1_bottom):
    sns.pointplot(data=res_exp_bit, x="factor", y="CR", hue="probe", ax=ax_cr, errorbar="se", palette=color)

# y-ranges of the three parts of the broken axis
top_lim = 600
mid_lims = [50, 322]
bottom_lims = [-0.1, 40]

ax1_top.set_ylim(bottom=top_lim)
ax1_top.axhline(top_lim, ls="--", color="k", alpha=0.5)
ax1_top.set_xticks([])
ax1_top.set_xlabel("")
ax1_mid.set_ylim(mid_lims)
for lim in mid_lims:
    ax1_mid.axhline(lim, ls="--", color="k", alpha=0.5)
ax1_mid.set_xticks([])
ax1_mid.set_xlabel("")
ax1_bottom.set_ylim(bottom_lims)
ax1_bottom.axhline(bottom_lims[1], ls="--", color="k", alpha=0.5)

sns.despine(ax=ax1_bottom)
sns.despine(ax=ax1_top, bottom=True)
sns.despine(ax=ax1_mid, bottom=True)

# short diagonal marks at the axis breaks, drawn in axes coordinates
d = .015
kwargs = dict(transform=ax1_top.transAxes, color='k', clip_on=False)
ax1_top.plot((-d, +d), (-d, +d), **kwargs)

kwargs = dict(transform=ax1_mid.transAxes, color='k', clip_on=False)
ax1_mid.plot((-d, +d), (-d, +d), **kwargs)
ax1_mid.plot((-d, +d), (1 - d, 1 + d), **kwargs)

kwargs.update(transform=ax1_bottom.transAxes)
ax1_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

# keep the legend of the top axis only
ax1_mid.legend_.remove()
ax1_bottom.legend_.remove()

# reference lines at CR = 10 and CR = 5
ax = ax1_bottom
ax.axhline(10, color="grey", ls="--")
ax.axhline(5, color="grey", ls="-.")
ax.text(-0.1, 10.3, "10", color="grey", fontsize=12)
ax.text(-0.1, 5.3, "5", color="grey", fontsize=12)
ax.set_xlabel("# bit")

ax1_top.set_ylabel("")
ax1_mid.set_ylabel("Compression Ratio", fontsize=15)
ax1_bottom.set_ylabel("")
ax1_bottom.set_xticklabels(bit_labels)

ax = ax21
sns.pointplot(data=res_exp_bit, x="factor", y="rmse", hue="probe", ax=ax, palette=color)
ax.axhline(1.5, color="grey", ls="--")
ax.text(-0.5, 1.7, "1.5", color="grey", fontsize=12)
# raw string: "\m" is not a valid escape sequence in a regular string literal
ax.set_ylabel(r"RMSE ($\mu$V)")
ax.set_xticklabels(bit_labels)
ax.set_xlabel("# bit")

# wavpack
color = wv_cmap.name

ax = ax12
sns.pointplot(data=res_exp_wv, x="factor", y="CR", hue="probe", ax=ax, errorbar="se",
              order=wv_order, palette=color)
ax.set_ylim(-0.1, 15)
ax.axhline(10, color="grey", ls="--")
ax.axhline(5, color="grey", ls="-.")
ax.text(-0.1, 10.3, "10", color="grey", fontsize=12)
ax.text(-0.1, 5.3, "5", color="grey", fontsize=12)
ax.set_xlabel("bps")

ax.set_ylabel("")
ax.set_xticklabels(wv_labels)

ax = ax22
sns.pointplot(data=res_exp_wv, x="factor", y="rmse", hue="probe", order=wv_order, ax=ax, palette=color)
ax.axhline(1.5, color="grey", ls="--")
ax.text(-0.5, 1.7, "1.5", color="grey", fontsize=12)
ax.set_xticklabels(wv_labels)
ax.set_xlabel("bps")
ax.set_ylabel("")


ax1_top.set_title("Bit truncation", fontsize=20)
ax12.set_title("WavPack Hybrid", fontsize=20)

prettify_axes(axs)

In [None]:
# Figure S4 (simulated data): compression ratio (upper panels) and RMSE (lower panels)
# for bit truncation (left column) and WavPack Hybrid (right column).
fig_s4 = plt.figure(figsize=figsize_multi_2rows)

color = bit_cmap.name

# 7 x 2 grid: rows 0-2 hold the CR panels, row 3 is left empty as a spacer,
# rows 4-6 hold the RMSE panels.
gs = GridSpec(7, 2, hspace=0.1, wspace=0.3)
ax1_top = fig_s4.add_subplot(gs[0, 0])
ax1_mid = fig_s4.add_subplot(gs[1, 0])
ax1_bottom = fig_s4.add_subplot(gs[2, 0])
ax12 = fig_s4.add_subplot(gs[:3, 1])
ax21 = fig_s4.add_subplot(gs[4:, 0])
ax22 = fig_s4.add_subplot(gs[4:, 1])

axs = [ax1_top, ax1_mid, ax1_bottom, ax12, ax21, ax22]

# bit truncation: the same CR curve is drawn on three stacked axes which together
# form a broken y-axis (each axis shows a different y-range)
for ax_cr in (ax1_top, ax1_mid, ax1_bottom):
    sns.pointplot(data=res_sim_bit, x="factor", y="CR", hue="probe", ax=ax_cr, errorbar="se", palette=color)

# y-ranges of the three parts of the broken axis
top_lim = 600
mid_lims = [50, 322]
bottom_lims = [-0.1, 33]

ax1_top.set_ylim(bottom=top_lim)
ax1_top.axhline(top_lim, ls="-", color="k", alpha=0.2)
ax1_top.set_xticks([])
ax1_top.set_xlabel("")
ax1_mid.set_ylim(mid_lims)
for lim in mid_lims:
    ax1_mid.axhline(lim, ls="-", color="k", alpha=0.2)
ax1_mid.set_xticks([])
ax1_mid.set_xlabel("")
ax1_bottom.set_ylim(bottom_lims)
ax1_bottom.axhline(bottom_lims[1], ls="-", color="k", alpha=0.2)

sns.despine(ax=ax1_bottom)
sns.despine(ax=ax1_top, bottom=True)
sns.despine(ax=ax1_mid, bottom=True)

# short diagonal marks at the axis breaks, drawn in axes coordinates
d = .015
kwargs = dict(transform=ax1_top.transAxes, color='k', clip_on=False)
ax1_top.plot((-d, +d), (-d, +d), **kwargs)

kwargs = dict(transform=ax1_mid.transAxes, color='k', clip_on=False)
ax1_mid.plot((-d, +d), (-d, +d), **kwargs)
ax1_mid.plot((-d, +d), (1 - d, 1 + d), **kwargs)

kwargs.update(transform=ax1_bottom.transAxes)
ax1_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

# keep the legend of the top axis only
ax1_mid.legend_.remove()
ax1_bottom.legend_.remove()

# reference lines at CR = 10 and CR = 5
ax = ax1_bottom
ax.axhline(10, color="grey", ls="--")
ax.axhline(5, color="grey", ls="-.")
ax.text(-0.1, 10.3, "10", color="grey", fontsize=12)
ax.text(-0.1, 5.3, "5", color="grey", fontsize=12)
ax.set_xlabel("# bit")

ax1_top.set_ylabel("")
ax1_mid.set_ylabel("Compression Ratio", fontsize=15)
ax1_bottom.set_ylabel("")
ax1_bottom.set_xticklabels(bit_labels)

ax = ax21
sns.pointplot(data=res_sim_bit, x="factor", y="rmse", hue="probe", ax=ax, palette=color)
ax.axhline(1.5, color="grey", ls="--")
ax.set_ylim(-0.1, 7)
ax.text(-0.5, 1.7, "1.5", color="grey", fontsize=12)
# raw string: "\m" is not a valid escape sequence in a regular string literal
ax.set_ylabel(r"RMSE ($\mu$V)")
ax.set_xticklabels(bit_labels)
ax.set_xlabel("# bit")

# wavpack
color = wv_cmap.name

ax = ax12
sns.pointplot(data=res_sim_wv, x="factor", y="CR", hue="probe", ax=ax, errorbar="se",
              order=wv_order, palette=color)

ax.axhline(10, color="grey", ls="--")
ax.axhline(5, color="grey", ls="-.")
ax.text(-0.1, 10.3, "10", color="grey", fontsize=12)
ax.text(-0.1, 5.3, "5", color="grey", fontsize=12)
ax.set_xlabel("bps")
ax.set_ylim(-0.1, 15)
ax.set_ylabel("")
ax.set_xticklabels(wv_labels)

ax = ax22
sns.pointplot(data=res_sim_wv, x="factor", y="rmse", hue="probe", order=wv_order, ax=ax, palette=color)
ax.axhline(1.5, color="grey", ls="--")
ax.text(-0.5, 1.7, "1.5", color="grey", fontsize=12)
ax.set_ylim(-0.1, 7)
ax.set_xticklabels(wv_labels)
ax.set_xlabel("bps")
ax.set_ylabel("")

ax1_top.set_title("Bit truncation", fontsize=20)
ax12.set_title("WavPack Hybrid", fontsize=20)

prettify_axes(axs)

In [None]:
# Write the CR/RMSE figures as PDF files into the figure folder (only when enabled)
if save_fig:
    for fig_name, fig in (("fig7.pdf", fig_7), ("figS4.pdf", fig_s4)):
        fig.savefig(fig_folder / fig_name)

## Spike sorting


### Simulated

In [None]:
# Figure 8 (simulated data): spike sorting accuracy / precision / recall for each
# probe (rows) and each lossy strategy (columns).
fig_8 = plt.figure(figsize=figsize_multi_2rows)
color = "Accent"

# 9-row grid: rows 0-3 for the first probe, row 4 as spacer, rows 5-8 for the second probe
gs = fig_8.add_gridspec(9, 2, hspace=0.2, wspace=0.2)
ax11, ax12 = (fig_8.add_subplot(gs[:4, col]) for col in range(2))
ax21, ax22 = (fig_8.add_subplot(gs[5:, col]) for col in range(2))

axs = np.array([ax11, ax12, ax21, ax22]).reshape(2, 2)

for probe in probes:
    # probes whose label contains "1" go on the first row, the others on the second
    row = 0 if "1" in probe else 1

    res_bit_probe = res_sim_bit.query(f"probe == '{probe}'")
    res_wv_probe = res_sim_wv.query(f"probe == '{probe}'")

    # left column: bit truncation, right column: WavPack
    panels = (
        (res_bit_probe, bit_order, bit_labels),
        (res_wv_probe, wv_order, wv_labels),
    )
    for col, (res_probe, factor_order, factor_labels) in enumerate(panels):
        ax = axs[row, col]
        # long format: one row per (factor, metric) pair, as expected by the grouped barplot
        df_perf = res_probe.melt(id_vars='factor', value_vars=('accuracy', 'precision', 'recall'),
                                 var_name='metric', value_name='value')
        sns.barplot(data=df_perf, x='factor', y='value', hue='metric',
                    order=factor_order, palette=color, ax=ax)
        ax.set_xticklabels(factor_labels)
        ax.set_ylabel("")
        ax.set_xlabel("")
        ax.get_legend().remove()

# column titles on the first row, x labels on the last row
for col, (title, xlabel) in enumerate((("Bit truncation", "# bit"), ("WavPack Hybrid", "bps"))):
    axs[0, col].set_title(title, fontsize=18)
    axs[1, col].set_xlabel(xlabel)
for ax_left in axs[:, 0]:
    ax_left.set_ylabel("Avg. values")
# single legend, in the first panel
axs[0, 0].legend(fontsize=15, loc=3)

prettify_axes(axs, label_fs=15)
fig_8.subplots_adjust(hspace=0.2, wspace=0.2, left=0.1)
# probe labels in the left margin
_ = fig_8.text(0, 0.68, "NP1", transform=fig_8.transFigure, rotation=90, fontsize=20)
_ = fig_8.text(0, 0.26, "NP2", transform=fig_8.transFigure, rotation=90, fontsize=20)

In [None]:
# Write figure 8 as a PDF file into the figure folder (only when saving is enabled)
if save_fig:
    fig_8.savefig(fig_folder / "fig8.pdf")

In [None]:
# Figure 9 (simulated data): number of well detected / false positive / redundant /
# overmerged units for each probe (rows) and each lossy strategy (columns).
# The bit-truncation panels use a broken y-axis made of two stacked axes.
fig_9 = plt.figure(figsize=figsize_multi_2rows)
color = "Set2"

gs = fig_9.add_gridspec(9, 2, hspace=0.2, wspace=0.2)
ax11_top = fig_9.add_subplot(gs[:2, 0])
ax11_bottom = fig_9.add_subplot(gs[2:4, 0])
ax12 = fig_9.add_subplot(gs[:4, 1])
ax21_top = fig_9.add_subplot(gs[5:7, 0])
ax21_bottom = fig_9.add_subplot(gs[7:, 0])
ax22 = fig_9.add_subplot(gs[5:, 1])
ax11 = [ax11_bottom, ax11_top]
ax21 = [ax21_bottom, ax21_top]

# Nested list rather than np.array: the left-column entries are [bottom, top] lists
# while the right-column entries are single axes, and building a NumPy array from
# such a ragged sequence raises ValueError on NumPy >= 1.24.
axs = [[ax11, ax12], [ax21, ax22]]

# per probe: [y-limits of the bottom axis, lower y-limit of the top axis]
ylims = {
    "NP1": [[-1, 110], 200],
    "NP2": [[-1, 110], 120],
}

for probe in probes:
    # probes whose label contains "1" go on the first row, the others on the second
    if "1" in probe:
        row = 0
    else:
        row = 1

    res_bit_probe = res_sim_bit.query(f"probe == '{probe}'")
    res_wv_probe = res_sim_wv.query(f"probe == '{probe}'")

    # bit truncation: the same bars are drawn on both parts of the broken axis
    df_count = pd.melt(res_bit_probe, id_vars='factor', var_name='Type', value_name='Units',
             value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
    ax_split = axs[row][0]
    ax_bottom = ax_split[0]
    ax_top = ax_split[1]
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=bit_order, ax=ax_bottom, palette=color)
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=bit_order, ax=ax_top, palette=color)
    top_lim = ylims[probe][1]
    bottom_lims = ylims[probe][0]
    ax_top.set_ylim(bottom=top_lim)
    # add a tick 10 units above the break, then drop the lowest tick
    yticks = [top_lim + 10] + list(ax_top.get_yticks())
    ax_top.set_yticks(sorted(yticks)[1:])
    ax_top.axhline(top_lim, ls="-", color="k", alpha=0.2)
    ax_top.set_xticks([])
    ax_top.set_xlabel("")
    ax_top.set_ylabel("")
    ax_bottom.set_xlabel("")
    ax_bottom.set_ylabel("")
    ax_bottom.set_ylim(bottom_lims)
    ax_bottom.axhline(bottom_lims[1], ls="-", color="k", alpha=0.2)

    sns.despine(ax=ax_bottom)
    sns.despine(ax=ax_top, bottom=True)

    # short diagonal marks at the axis break, drawn in axes coordinates
    d = .015
    kwargs = dict(transform=ax_top.transAxes, color='k', clip_on=False)
    ax_top.plot((-d, +d), (-d, +d), **kwargs)
    kwargs.update(transform=ax_bottom.transAxes)
    ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

    ax_top.get_legend().remove()
    ax_bottom.get_legend().remove()
    # reference line at 100 units (presumably the number of ground-truth units — TODO confirm)
    ax_bottom.axhline(100, color="grey", ls="--", lw=1)
    ax_bottom.set_xticklabels(bit_labels)

    # wavpack: a single axis is enough
    ax = axs[row][1]
    df_count = pd.melt(res_wv_probe, id_vars='factor', var_name='Type', value_name='Units',
             value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=wv_order, ax=ax, palette=color)
    ax.set_xticklabels(wv_labels)
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.axhline(100, color="grey", ls="--", lw=1)
    ax.get_legend().remove()

ax11_top.set_title("Bit truncation", fontsize=18)
ax12.set_title("WavPack Hybrid", fontsize=18)
ax21_bottom.set_xlabel("# bit")
ax22.set_xlabel("bps")
# single legend, in the first panel
ax11_top.legend(fontsize=12)

# axes passed to prettify_axes below
axs = np.array([ax12, ax21_bottom, ax22])

fig_9.subplots_adjust(left=0.1)
# y labels and probe labels in the left margin
_ = fig_9.text(0.04, 0.65, "# Units", transform=fig_9.transFigure, rotation=90, fontsize=15)
_ = fig_9.text(0.04, 0.22, "# Units", transform=fig_9.transFigure, rotation=90, fontsize=15)
_ = fig_9.text(0, 0.68, "NP1", transform=fig_9.transFigure, rotation=90, fontsize=20)
_ = fig_9.text(0, 0.25, "NP2", transform=fig_9.transFigure, rotation=90, fontsize=20)

prettify_axes(axs, label_fs=15)
fig_9.subplots_adjust(hspace=0.2, wspace=0.2)

In [None]:
# Write figure 9 as a PDF file into the figure folder (only when saving is enabled)
if save_fig:
    fig_9.savefig(fig_folder / "fig9.pdf")

### Experimental

In [None]:
# Figure 10 (experimental data): number of curated good / bad units, relative to the
# lossless run of the same session, for each probe (rows) and lossy strategy (columns).
def _normalize_to_lossless(res_probe):
    """Return a copy of `res_probe` with unit counts relative to the lossless run.

    For every session, the "n_curated_good_units" and "n_curated_bad_units" columns are
    divided by their value at `factor == 0` for that session and stored in the new
    "% good units" and "% bad units" columns (as fractions: 1 means "same as lossless").
    Working on a copy avoids assigning new columns onto the result of `DataFrame.query`.
    """
    res_norm = res_probe.copy()
    for session in np.unique(res_norm.session):
        res_session = res_norm.query(f"session == '{session}'")
        res_lossless = res_session.query("factor == 0")
        n_good_units_0 = res_lossless["n_curated_good_units"].values[0]
        n_bad_units_0 = res_lossless["n_curated_bad_units"].values[0]

        res_norm.loc[res_session.index, "% good units"] = \
                res_session["n_curated_good_units"] / n_good_units_0
        res_norm.loc[res_session.index, "% bad units"] = \
                res_session["n_curated_bad_units"] / n_bad_units_0
    return res_norm


fig_10 = plt.figure(figsize=figsize_multi_2rows)

color = "Set1"

gs = fig_10.add_gridspec(9, 2, hspace=0.1, wspace=0.3)
ax11_top = fig_10.add_subplot(gs[:2, 0])
ax11_bottom = fig_10.add_subplot(gs[2:4, 0])
ax12 = fig_10.add_subplot(gs[:4, 1])
ax21 = fig_10.add_subplot(gs[5:, 0])
ax22 = fig_10.add_subplot(gs[5:, 1])
ax11 = [ax11_bottom, ax11_top]

# Nested list rather than np.array: the first entry is a [bottom, top] list while the
# others are single axes, and building a NumPy array from such a ragged sequence
# raises ValueError on NumPy >= 1.24.
axs = [[ax11, ax12], [ax21, ax22]]

xlim_bit = [-0.5, 7.5]
ylim = [0, 1.5]
# lower y-limit of the top part of the broken axis
top_lim = 2
# after normalization the lossless run has a value of 1
n_good_lossless = 1

for probe in probes:
    # probes whose label contains "1" go on the first row, the others on the second
    if "1" in probe:
        row = 0
    else:
        row = 1

    # normalize with respect to lossless num units
    res_bit_probe = _normalize_to_lossless(res_exp_bit.query(f"probe == '{probe}'"))
    res_wv_probe = _normalize_to_lossless(res_exp_wv.query(f"probe == '{probe}'"))

    # bit truncation
    df_bit = pd.melt(res_bit_probe, id_vars='factor', var_name='Type', value_name='Units',
                     value_vars=('% good units', '% bad units'))
    axs_bit = axs[row][0]
    if isinstance(axs_bit, list):
        # broken y-axis: the same bars are drawn on both stacked axes
        ax_bottom = axs_bit[0]
        ax_top = axs_bit[1]
        sns.barplot(x='factor', y='Units', hue='Type', data=df_bit,
                    order=bit_order, ax=ax_top, palette=color)
        sns.barplot(x='factor', y='Units', hue='Type', data=df_bit,
                    order=bit_order, ax=ax_bottom, palette=color)
        ax_top.set_ylim(bottom=top_lim)
        ax_top.axhline(top_lim, ls="-", color="k", alpha=0.2)
        ax_top.set_xticks([])
        ax_top.set_xlabel("")
        ax_top.set_ylabel("")
        ax_bottom.set_ylim(ylim)
        ax_bottom.axhline(ylim[1], ls="-", color="k", alpha=0.2)

        sns.despine(ax=ax_bottom)
        sns.despine(ax=ax_top, bottom=True)

        # short diagonal marks at the axis break, drawn in axes coordinates
        d = .015
        kwargs = dict(transform=ax_top.transAxes, color='k', clip_on=False)
        ax_top.plot((-d, +d), (-d, +d), **kwargs)
        kwargs.update(transform=ax_bottom.transAxes)
        ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

        ax_top.get_legend().remove()
        ax_bottom.set_xticklabels(bit_labels)
        ax_top.set_xlim(xlim_bit)
        ax = ax_bottom
    else:
        ax = axs_bit
        sns.barplot(x='factor', y='Units', hue='Type', data=df_bit,
                    order=bit_order, ax=ax, palette=color)
        ax.set_xticklabels(bit_labels)

    ax.get_legend().remove()
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.set_xlim(xlim_bit)
    ax.set_ylim(ylim)
    ax.axhline(n_good_lossless, color="grey", ls="--", lw=2)

    # wavpack
    ax = axs[row][1]
    df_wv = pd.melt(res_wv_probe, id_vars='factor', var_name='Type', value_name='Units',
                    value_vars=('% good units', '% bad units'))
    sns.barplot(x='factor', y='Units', hue='Type', data=df_wv,
                order=wv_order, ax=ax, palette=color)
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.set_ylim(ylim)
    ax.axhline(n_good_lossless, color="grey", ls="--")
    ax.get_legend().remove()


ax11_top.set_title("Bit truncation", fontsize=18)
ax12.set_title("WavPack Hybrid", fontsize=18)
ax21.set_xlabel("# bit")
ax22.set_xlabel("bps")
# single legend, in the first panel
ax11_top.legend(fontsize=15, loc=2)

# axes passed to prettify_axes below
axs = np.array([ax12, ax21, ax22])

prettify_axes(axs, label_fs=20)

fig_10.subplots_adjust(left=0.1)
# y labels and probe labels in the left margin
_ = fig_10.text(0.04, 0.63, "% Units", transform=fig_10.transFigure, rotation=90, fontsize=15)
_ = fig_10.text(0.04, 0.2, "% Units", transform=fig_10.transFigure, rotation=90, fontsize=15)
_ = fig_10.text(0, 0.65, "NP1", transform=fig_10.transFigure, rotation=90, fontsize=20)
_ = fig_10.text(0, 0.22, "NP2", transform=fig_10.transFigure, rotation=90, fontsize=20)

In [None]:
# Write figure 10 as a PDF file into the figure folder (only when saving is enabled)
if save_fig:
    fig_10.savefig(fig_folder / "fig10.pdf")

## Waveform features

Compute dataframe with relative errors

In [None]:
# Template metrics whose relative error is analysed (prefixes of the waveform table columns)
metrics = ["peak_to_valley", "half_width", "peak_trough_ratio"]
# Values of the "distance" column to keep (labelled 0 µm and 60 µm in the figure legend)
selected_distances = [0, 60]

In [None]:
template_metrics = metrics

# Keep only the rows measured at the selected distances
res_wfs_dist = res_sim_wfs.query(f"distance in {selected_distances}")


def _relative_error_frame(strategy, factor, column_suffix):
    """Build the relative-error table of one (strategy, factor) pair.

    For every metric in `template_metrics`, the tested value is read from the column
    "<metric>_<strategy>_<column_suffix>" and the reference from "<metric>_gt"; the
    error |tested - gt| / |gt| is stored in "err_<metric>".

    Returns a DataFrame with the same index as `res_wfs_dist` and the columns
    "probe", "unit_id", "distance", "strategy", "factor" and one "err_<metric>" per metric.
    """
    err_df = res_wfs_dist[["probe", "unit_id", "distance"]].copy()
    err_df["strategy"] = strategy
    err_df["factor"] = factor
    for metric in template_metrics:
        gt_values = res_wfs_dist[f"{metric}_gt"]
        tested_values = res_wfs_dist[f"{metric}_{strategy}_{column_suffix}"]
        err_df[f"err_{metric}"] = np.abs(tested_values - gt_values) / np.abs(gt_values)
    return err_df


# The first entry of each order is skipped (for wv_order it is the lossless 0).
# Bit-truncation columns are named with the integer number of bits.
error_frames = [_relative_error_frame("bit_truncation", bit, int(bit)) for bit in bit_order[1:]]
error_frames += [_relative_error_frame("wavpack", wv, wv) for wv in wv_order[1:]]

# Concatenate once instead of growing the frame inside the loops; the original row
# index is kept. None when there is nothing to compare.
df_errors = pd.concat(error_frames) if error_frames else None

In [None]:
# Figure 11 (simulated data): relative errors of the template metrics (columns) for
# each probe and lossy strategy (rows: NP1 bit truncation, NP1 WavPack, NP2 bit
# truncation, NP2 WavPack), split by distance.
fig_11 = plt.figure(figsize=figsize_multi_3rows)
# number of grid rows per panel row; one extra grid row separates consecutive panel rows
multiplier = 4
gs = fig_11.add_gridspec((multiplier + 1) * 4, len(template_metrics), hspace=0.5, wspace=0.2, left=0.13)

# axs[p_i][t_i] is a [top, bottom] pair of axes (broken y-axis) on even rows
# (bit truncation) and a single axis on odd rows (WavPack)
axs = []
for p_i in range(4):
    ax_cols = []
    row_start = (multiplier + 1) * p_i
    row_stop = row_start + multiplier
    for t_i in range(len(template_metrics)):
        if p_i % 2 == 0:
            shift = multiplier // 2
            ax1 = fig_11.add_subplot(gs[row_start:row_start + shift, t_i])
            ax2 = fig_11.add_subplot(gs[row_start + shift:row_stop, t_i])
            ax = [ax1, ax2]
        else:
            ax = fig_11.add_subplot(gs[row_start:row_stop, t_i])
        ax_cols.append(ax)
    axs.append(ax_cols)

# [y-limits of the bottom (or single) axis, lower y-limit of the top axis]
ylims = [(-0.01, 0.2), 0.2]

df_errors_bit = df_errors.query("strategy == 'bit_truncation'")
df_errors_wv = df_errors.query("strategy == 'wavpack'")

for probe in probes:
    # probes whose label contains "1" use rows 0-1, the others rows 2-3
    if "1" in probe:
        row = 0
    else:
        row = 2
    df_bit_probe = df_errors_bit.query(f"probe == '{probe}'")
    df_wv_probe = df_errors_wv.query(f"probe == '{probe}'")

    for i, metric in enumerate(template_metrics):
        # bit truncation
        color = bit_cmap.name
        boxen_kwargs = dict(data=df_bit_probe, x="factor", y=f"err_{metric}", hue="distance",
                            order=bit_order[1:], showfliers=False, palette=color)
        axs_bit = axs[row][i]
        if isinstance(axs_bit, list):
            # broken y-axis: the same boxes are drawn on both stacked axes
            ax_top = axs_bit[0]
            ax_bottom = axs_bit[1]
            sns.boxenplot(ax=ax_top, **boxen_kwargs)
            sns.boxenplot(ax=ax_bottom, **boxen_kwargs)
            ax_top.set_ylim(bottom=ylims[1])
            ax_top.axhline(ylims[1], ls="-", color="k", alpha=0.2)
            ax_top.set_xticks([])
            ax_top.set_xlabel("")
            ax_top.set_ylabel("")
            ax_bottom.set_ylim(ylims[0])
            ax_bottom.axhline(ylims[0][1], ls="-", color="k", alpha=0.2)

            sns.despine(ax=ax_bottom)
            sns.despine(ax=ax_top, bottom=True)

            # short diagonal marks at the axis break, drawn in axes coordinates
            d = .015
            kwargs = dict(transform=ax_top.transAxes, color='k', clip_on=False)
            ax_top.plot((-d, +d), (-d, +d), **kwargs)
            kwargs.update(transform=ax_bottom.transAxes)
            ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

            ax_top.get_legend().remove()
            ax = ax_bottom
        else:
            # single axis: draw on this axis (not on a bottom axis left over from a
            # previous iteration)
            ax = axs_bit
            sns.boxenplot(ax=ax, **boxen_kwargs)

        ax.set_ylabel("")
        ax.set_xticklabels(bit_labels[1:])
        ax.set_xlabel("# bit")
        ax.set_title("")
        # reference line at a relative error of 0.1
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(ylims[0])
        ax.get_legend().remove()

        # wavpack
        color = wv_cmap.name
        ax = axs[row + 1][i]
        sns.boxenplot(data=df_wv_probe,
                      x="factor", y=f"err_{metric}", hue="distance",
                      order=wv_order[1:], showfliers=False, ax=ax,
                      palette=color)
        ax.set_ylabel("")
        ax.set_xticklabels(wv_labels[1:])
        ax.set_xlabel("bps")
        ax.set_title("")
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(ylims[0])
        ax.get_legend().remove()

# add legends (first bit-truncation panel and first WavPack panel only)
ax_legends = [axs[0][0][0], axs[1][0]]
for ax_leg in ax_legends:
    handles, labels = ax_leg.get_legend_handles_labels()
    # raw strings: "\m" is not a valid escape sequence in a regular string literal
    ax_leg.legend(handles, [r"0$\mu$m", r"60$\mu$m"], loc=2, fontsize=12)

axs[0][0][0].set_title("Peak-to-valley", fontsize=18)
axs[0][1][0].set_title("Half-width", fontsize=18)
axs[0][2][0].set_title("Peak-trough ratio", fontsize=18)

# strategy, y-axis and probe labels in the left margin
_ = fig_11.text(0.07, 0.76, "Bit truncation", transform=fig_11.transFigure, rotation=90, fontsize=12)
_ = fig_11.text(0.07, 0.56, "WavPack Hybrid", transform=fig_11.transFigure, rotation=90, fontsize=12)
_ = fig_11.text(0.07, 0.37, "Bit truncation", transform=fig_11.transFigure, rotation=90, fontsize=12)
_ = fig_11.text(0.07, 0.17, "WavPack Hybrid", transform=fig_11.transFigure, rotation=90, fontsize=12)

_ = fig_11.text(0.035, 0.64, "Relative Errors", transform=fig_11.transFigure, rotation=90, fontsize=15)
_ = fig_11.text(0.035, 0.26, "Relative Errors", transform=fig_11.transFigure, rotation=90, fontsize=15)

_ = fig_11.text(0, 0.68, "NP1", transform=fig_11.transFigure, rotation=90, fontsize=20)
_ = fig_11.text(0, 0.30, "NP2", transform=fig_11.transFigure, rotation=90, fontsize=20)

# flatten the nested axes structure before styling
axs_all = []
for ax_row in axs:
    for ax in ax_row:
        if isinstance(ax, list):
            axs_all.extend(ax)
        else:
            axs_all.append(ax)
prettify_axes(axs_all, label_fs=12)


In [None]:
# Write figure 11 as a PDF file into the figure folder (only when saving is enabled)
if save_fig:
    fig_11.savefig(fig_folder / "fig11.pdf")

## Plot templates

In [None]:
import spikeinterface.widgets as sw

# Load the waveform extractors for each dataset: the ground truth (GT) plus one
# extractor per bit-truncation and per WavPack setting (first entry of each order skipped)
gt_dict, bit_dict, wv_dict = {}, {}, {}
dsets = ["NP1", "NP2"]
for dset in dsets:
    # ground-truth waveforms
    wfs_folder = results_lossy_sim_folder.joinpath(f"gt-{dset}", "waveforms")
    we = si.load_waveforms(wfs_folder, with_recording=False)
    gt_dict[dset], bit_dict[dset], wv_dict[dset] = we, {}, {}

    # bit truncation, keyed by the integer number of bits
    for bit in bit_order[1:]:
        wfs_folder = results_lossy_sim_folder.joinpath(
            f"waveforms-{dset}-bit_truncation", f"wf_lossy_bit_truncation_{int(bit)}"
        )
        we = si.load_waveforms(wfs_folder, with_recording=False)
        bit_dict[dset][int(bit)] = we

    # WavPack, keyed by the bps value as listed in wv_order
    for bps in wv_order[1:]:
        wfs_folder = results_lossy_sim_folder.joinpath(
            f"waveforms-{dset}-wavpack", f"wf_lossy_wavpack_{bps}"
        )
        we = si.load_waveforms(wfs_folder, with_recording=False)
        wv_dict[dset][bps] = we

Plot the full template at the center, and the single-channel traces in separate panels (the main channel and a peripheral channel).

In [None]:
# Normalisers for the colormap lookups below (colour = cmap(i / n)).
# The +2 keeps i / n strictly below 1 for every plotted setting.
n_bits, n_bps = len(bit_order) + 2, len(wv_order) + 2

In [None]:
# Unit displayed for each dataset
unit_ids = dict(
    NP1="#10",
    NP2="#30"
)
# The "last" channel is the closest channel lying farther than `dist` from the
# main channel (same units as the channel locations).
dist = 60
# Radius used to sparsify the template overview plot
radius_um = 100

gt_lw = 3


def _add_scale_bar(ax, y0, y_time_label):
    """Draw an L-shaped scale bar (1 along x, 10 along y) with its corner at (-3, y0).

    The horizontal arm is labelled "1ms" at height `y_time_label`; the vertical
    arm is labelled with the 10 uV mathtext string at its mid-height.
    """
    ax.plot([-3, -2], [y0, y0], color="k", lw=3)
    ax.plot([-3, -3], [y0, y0 + 10], color="k", lw=3)
    ax.text(-2.7, y_time_label, "1ms", fontsize=12)
    ax.text(-2.9, y0 + 5, r"10$\mu$V", fontsize=12)


fig_S5_panels = {}

for dset in dsets:
    fig_S5_panels[dset] = {}
    unit_id = unit_ids[dset]
    we_gt = gt_dict[dset]

    # Main channel = template extremum of the unit; "last" channel = nearest
    # channel whose distance from the main channel exceeds `dist`.
    extremum_channels_inds = si.get_template_extremum_channel(we_gt, outputs="index")
    main_channel_ind = extremum_channels_inds[unit_id]
    locs = we_gt.get_channel_locations()
    distances = np.linalg.norm(locs - locs[main_channel_ind], axis=1)
    order = np.argsort(distances)
    idx = np.where(distances[order] > dist)[0][0]
    last_channel_ind = order[idx]
    sparsity = si.compute_sparsity(we_gt, method="radius", radius_um=radius_um)

    # Two copies of the GT template overview: `w_circles` gets markers on the two
    # selected channels (inline display only), `w` is the bare version kept as a panel.
    w_circles = sw.plot_unit_templates(we_gt, sparsity=sparsity, unit_ids=[unit_id], same_axis=True,
                                       unit_colors={unit_id : "k"})
    w = sw.plot_unit_templates(we_gt, sparsity=sparsity, unit_ids=[unit_id], same_axis=True,
                               unit_colors={unit_id : "k"})
    w.legend.remove()
    w.ax.set_title("")
    w.ax.axis("off")
    # NOTE(review): markerfacecolor=None selects matplotlib's default (filled) face;
    # if hollow circles were intended this should be "none" - confirm.
    w_circles.ax.plot(*locs[main_channel_ind], "o", markersize=10, markerfacecolor=None)
    w_circles.ax.plot(*locs[last_channel_ind], "o", markersize=10, markerfacecolor=None)

    # Time axis and GT traces are shared by all four single-channel panels. They are
    # derived from this dataset's GT extractor rather than from whichever extractor
    # a previous cell or loop iteration happened to leave in `we`.
    # (x 1000 converts to ms, assuming sampling_frequency is in Hz.)
    ts = np.arange(-we_gt.nbefore, we_gt.nafter) / we_gt.sampling_frequency * 1000
    template_gt = we_gt.get_template(unit_id)
    template_main_gt = template_gt[:, main_channel_ind]
    template_last_gt = template_gt[:, last_channel_ind]

    # --- Bit truncation panels ---
    fig_bit_main, ax_bit_main = plt.subplots(ncols=1, figsize=figsize_multi_2rows)
    fig_bit_last, ax_bit_last = plt.subplots(ncols=1, figsize=figsize_multi_2rows)

    ax_bit_main.plot(ts, template_main_gt, color="k", label="GT", lw=gt_lw)
    ax_bit_last.plot(ts, template_last_gt, color="k", label="GT", lw=gt_lw)
    for i, (bit, we_lossy) in enumerate(bit_dict[dset].items()):
        color = bit_cmap(i / n_bits)
        template = we_lossy.get_template(unit_id)
        template_main = template[:, main_channel_ind]
        ax_bit_main.plot(ts, template_main, color=color, label=int(bit))
        template_last = template[:, last_channel_ind]
        ax_bit_last.plot(ts, template_last, color=color, label=int(bit))

    # legend
    ax_bit_main.legend(ncol=4, fontsize=12)
    ax_bit_last.legend(ncol=4, fontsize=12)

    # scale bar
    _add_scale_bar(ax_bit_main, y0=-50, y_time_label=-60)
    _add_scale_bar(ax_bit_last, y0=-20, y_time_label=-19)

    ax_bit_main.axis("off")
    ax_bit_last.axis("off")

    # --- WavPack panels ---
    fig_wv_main, ax_wv_main = plt.subplots(ncols=1, figsize=figsize_multi_2rows)
    fig_wv_last, ax_wv_last = plt.subplots(ncols=1, figsize=figsize_multi_2rows)

    # NOTE(review): this rebinds the notebook-level `n_bps` with a value based on
    # the number of loaded WavPack settings; kept to preserve the colours - confirm intent.
    n_bps = len(wv_dict[dset]) + 2
    ax_wv_main.plot(ts, template_main_gt, color="k", label="GT", lw=gt_lw)
    ax_wv_last.plot(ts, template_last_gt, color="k", label="GT", lw=gt_lw)
    for i, (bps, we_lossy) in enumerate(wv_dict[dset].items()):
        color = wv_cmap(i / n_bps)
        template = we_lossy.get_template(unit_id)
        template_main = template[:, main_channel_ind]
        ax_wv_main.plot(ts, template_main, color=color, label=bps)
        template_last = template[:, last_channel_ind]
        ax_wv_last.plot(ts, template_last, color=color, label=bps)

    # legend
    ax_wv_main.legend(ncol=4, fontsize=12)
    ax_wv_last.legend(ncol=4, fontsize=12)

    # scale bar
    _add_scale_bar(ax_wv_main, y0=-50, y_time_label=-60)
    _add_scale_bar(ax_wv_last, y0=-20, y_time_label=-19)

    ax_wv_main.axis("off")
    ax_wv_last.axis("off")

    # Collect the panels; the keys become part of the saved file names.
    fig_S5_panels[dset]["gt"] = w.figure
    fig_S5_panels[dset]["bit_main"] = fig_bit_main
    # "bit_last" (previously misspelled "bit_lat"), consistent with "wv_last"
    fig_S5_panels[dset]["bit_last"] = fig_bit_last
    fig_S5_panels[dset]["wv_main"] = fig_wv_main
    fig_S5_panels[dset]["wv_last"] = fig_wv_last

In [None]:
# Save every fig S5 panel as "figS5-<dataset>-<panel>.pdf"
if save_fig:
    for fig_dset, panels in fig_S5_panels.items():
        for fig_name in panels:
            fig = panels[fig_name]
            fig.savefig(fig_folder.joinpath(f"figS5-{fig_dset}-{fig_name}.pdf"))

In [None]:
# PLOT distributions with different colors
def _plot_metric_scatter(probe, strategy, levels, cmap, n_colors, title):
    """Scatter lossy template metrics against their GT values for one probe.

    One subplot per (template metric, selected distance); one scatter layer per
    compression level, coloured with ``cmap(b / n_colors)``.

    Parameters
    ----------
    probe : str
        Value of the "probe" column to select in `res_sim_wfs`.
    strategy : str
        Strategy name used in the column names ``{metric}_{strategy}_{level}``.
    levels : list
        Compression levels, already in the form used in the column names.
    cmap : callable
        Colormap mapping a float to a colour.
    n_colors : int
        Normaliser for the colormap lookup.
    title : str
        Human-readable strategy name used in the figure title.

    Returns
    -------
    matplotlib.figure.Figure
        The assembled figure.
    """
    df_wfs_probe = res_sim_wfs.query(f"probe == '{probe}'")
    fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(selected_distances),
                                figsize=figsize_multi_2rows)
    for b, level in enumerate(levels):
        color = cmap(b / n_colors)
        # earlier levels get a higher zorder, i.e. they are drawn on top
        zorder = len(levels) - b
        for i, metric in enumerate(template_metrics):
            gt_metric_name = f"{metric}_gt"
            for j, dist in enumerate(selected_distances):
                ax = axs_m[i, j]
                tm_dist = df_wfs_probe.query(f"distance == {dist}")
                sns.scatterplot(data=tm_dist, x=gt_metric_name, y=f"{metric}_{strategy}_{level}",
                                ax=ax, color=color, label=level)
                plt.setp(ax.collections[-1], zorder=zorder)
                ax.set_yticks([])
                ax.set_xticks([])
                ax.set_xlabel("")
                if j == len(selected_distances) - 1:
                    ax.set_ylabel("")
                # identity line spanning the GT range plus a 20% margin on each side
                lims = [np.min(tm_dist[gt_metric_name]) - 0.2 * np.ptp(tm_dist[gt_metric_name]),
                        np.max(tm_dist[gt_metric_name]) + 0.2 * np.ptp(tm_dist[gt_metric_name])]
                ax.plot(lims, lims, color="grey", alpha=0.7, ls="--")
                if i == 0:
                    # raw f-string: keeps the mathtext "\mu" without an invalid escape
                    ax.set_title(rf"Distance: {int(dist)} $\mu$m")
                if i == len(template_metrics) - 1:
                    ax.set_xlabel("GT metric values")
                ext = lims[1] - lims[0]
                ax.set_xlim(lims[0] - 0.5*ext, lims[1] + 0.5*ext)
                ax.set_ylim(lims[0] - ext, lims[1] + ext)
                # drop the per-axis legend; a single one is added below
                ax.legend().remove()

    # Row labels and legend (hard-coded for three metrics and two distances)
    axs_m[0, 0].set_ylabel("Peak-to-valley")
    axs_m[1, 0].set_ylabel("Half-width")
    axs_m[2, 0].set_ylabel("Peak-trough ratio")
    axs_m[0, 1].set_ylabel("")
    axs_m[1, 1].set_ylabel("")
    axs_m[2, 1].set_ylabel("")
    axs_m[0, 0].legend()
    prettify_axes(axs_m, label_fs=12)

    fig_m.suptitle(f"{probe} - {title}", fontsize=20)
    fig_m.subplots_adjust(wspace=0.1, hspace=0.3)
    return fig_m


fig_S6_panels = {}

# (strategy, levels, colormap, colormap normaliser, title).
# WavPack shades are normalised by n_bps, not n_bits.
strategy_specs = [
    ("bit_truncation", [int(bit) for bit in bit_order[1:]], bit_cmap, n_bits, "Bit truncation"),
    ("wavpack", list(wv_order[1:]), wv_cmap, n_bps, "Wavpack Hybrid"),
]
for strategy, levels, cmap, n_colors, title in strategy_specs:
    for probe in probes:
        fig_S6_panels[f"fig_S6_features_{probe}_{strategy}"] = _plot_metric_scatter(
            probe, strategy, levels, cmap, n_colors, title
        )

In [None]:
# Save every fig S6 panel as "figS6-<panel name>.pdf"
if save_fig:
    for fig_name in fig_S6_panels:
        fig = fig_S6_panels[fig_name]
        fig.savefig(fig_folder.joinpath(f"figS6-{fig_name}.pdf"))