# Benchmark lossy strategies on experimental data

In this notebook we analyze how lossy compression affects experimental datasets.

We use two different strategies:

- Bit truncation
- WavPack hybrid mode

The analysis focuses on:

* compression performance
* influence on spike sorting results

This notebook assumes that the `ephys-compression/scripts/benchmark-lossy-nogt.py` script has been run and that the `ephys-compression/data/results/benchmark-lossy-nogt.csv` file is available.

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import sys
from pathlib import Path

import spikeinterface.full as si

sys.path.append("..")

from audio_numcodecs import WavPackCodec
from utils import prettify_axes

%matplotlib notebook

In [None]:
# Toggle for the "savefig" cells further down.
save_fig = True

# Output directory for the PDFs, relative to the notebook's working directory;
# created up front (including parents) so later savefig calls cannot fail on a missing path.
fig_folder = Path("figures", "lossy-nogt")
fig_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Root of the data directory, one level above the notebook folder.
data_folder = Path("..") / "data"

In [None]:
# Load the benchmark results table (one row per benchmarked configuration,
# as far as the queries below suggest).
results_csv = data_folder / "results" / "benchmark-lossy-nogt.csv"
res = pd.read_csv(results_csv, index_col=False)

In [None]:
# Chunking / parallelisation settings.
# NOTE(review): not referenced by any cell visible in this notebook.
job_kwargs = dict(n_jobs=10, chunk_duration="1s", progress_bar=True)

In [None]:
# Per-probe views of the results table.
is_np1 = res["probe"] == "Neuropixels1.0"
is_np2 = res["probe"] == "Neuropixels2.0"
res_np1 = res[is_np1]
res_np2 = res[is_np2]

In [None]:
# Quick look at the first five rows of the results table.
res.head(n=5)

In [None]:
# Split the results by lossy strategy and fix the x-axis order used by the plots below.

# WavPack: factor 0 first, then the remaining factors from largest to smallest.
# np.unique already returns its values sorted ascending, so reversing it gives the
# descending order; the last (smallest) entry is dropped and 0 is put in front instead.
res_wv = res[res["strategy"] == "wavpack"]
wv_factors_descending = np.unique(res_wv["factor"])[::-1]
wv_order = [0, *wv_factors_descending[:-1]]

# Bit truncation: factors in ascending order.
res_bit = res[res["strategy"] == "bit_truncation"]
bit_order = list(np.unique(res_bit["factor"]))

In [None]:
# One colour per factor for each strategy, sampled from "tab10" at evenly spaced
# positions in [0, 1); factor 0 is always forced to black.
bit_cmap = plt.get_cmap("tab10")
wv_cmap = plt.get_cmap("tab10")

n_bit_levels = len(bit_order)
bit_colors = {level: bit_cmap(idx / n_bit_levels) for idx, level in enumerate(bit_order)}
bit_colors[0] = "k"

n_wv_levels = len(wv_order)
wv_colors = {level: wv_cmap(idx / n_wv_levels) for idx, level in enumerate(wv_order)}
wv_colors[0] = "k"

In [None]:
# Sorted array of the distinct probe names present in the results.
probes = np.unique(res["probe"])

# CR - RMSE

In [None]:
# 2x2 summary figure: rows = strategy (bit truncation on top, WavPack hybrid below),
# columns = metric (CR on the left, RMSE on the right), one point series per probe.
fig_cr_rmse, axs_cr_rmse = plt.subplots(ncols=2, nrows=2, figsize=(15, 10))

# Grey horizontal reference lines: (y of line, linestyle, y of text tag, text tag).
cr_marks = [(10, "--", 10.2, "10"), (5, "-.", 5.2, "5")]
rmse_marks = [(1.5, "--", 1.55, "1.5")]

# One entry per panel:
# (axes, results subset, x order, metric column, y-limits, reference marks, y-label or None)
panels = [
    (axs_cr_rmse[0, 0], res_bit, bit_order, "CR", (0, 50), cr_marks, "Bit truncation\nCR"),
    (axs_cr_rmse[0, 1], res_bit, bit_order, "rmse", (-0.5, 30), rmse_marks, None),
    (axs_cr_rmse[1, 0], res_wv, wv_order, "CR", (0, 15), cr_marks, "WavPack hybrid\nCR"),
    (axs_cr_rmse[1, 1], res_wv, wv_order, "rmse", (-0.5, 10), rmse_marks, None),
]

for ax, panel_data, x_order, metric, y_limits, marks, y_label in panels:
    sns.pointplot(data=panel_data, x="factor", y=metric, hue="probe", ax=ax, order=x_order)
    ax.set_ylim(*y_limits)
    for level, line_style, _, _ in marks:
        ax.axhline(level, color="grey", ls=line_style)
    for _, _, tag_y, tag in marks:
        # tags sit just above their line, left of the first category
        ax.text(-0.5, tag_y, tag, color="grey", fontsize=12)
    if y_label is not None:
        # RMSE panels keep the y-label seaborn assigns
        ax.set_ylabel(y_label)

prettify_axes(axs_cr_rmse)

fig_cr_rmse.subplots_adjust(hspace=0.3)

In [None]:
# Write the CR / RMSE summary figure to the figure folder when saving is enabled.
if save_fig:
    cr_rmse_pdf = fig_folder / "cr_rmse.pdf"
    fig_cr_rmse.savefig(cr_rmse_pdf)

# Traces

In [None]:
# Overlay the same short snippet of a single channel for every bit-truncation level.
# Top row: traces as read from the zarr recording ("Raw"); bottom row: the same snippet
# after si.bandpass_filter ("Filtered"). Left column: Neuropixels 1.0, right: Neuropixels 2.0.
fig_bit_traces, axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(15, 10))

alpha = 0.7
# Channel and start time of the snippet for each probe (suffix 1 -> left column, 2 -> right column).
channel_ids1 = ["AP241"]
channel_ids2 = ["AP21"]
nsec = 0.004
t_start1 = 30.056
t_start2 = 30.188

time_range1 = [t_start1, t_start1 + nsec]
time_range2 = [t_start2, t_start2 + nsec]

# NOTE(review): lw_gt is not referenced in this cell; kept in case other cells rely on it.
lw_gt = 3

for bit in bit_order:
    for probe in probes:
        # first matching result row for this (factor, probe) pair
        row = res_bit.query(f"factor == {bit} and probe == '{probe}'").iloc[0]
        factor = row["factor"]
        strategy = row["strategy"]
        rec_zarr = si.read_zarr(row["rec_zarr_path"])
        rec_f = si.bandpass_filter(rec_zarr)
        if "1" in probe:
            ax_idx = 0
            channel_ids = channel_ids1
            time_range = time_range1
        else:
            ax_idx = 1
            channel_ids = channel_ids2
            time_range = time_range2

        fs = rec_zarr.get_sampling_frequency()
        start_frame = int(time_range[0] * fs)
        end_frame = int(time_range[1] * fs)

        traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                     channel_ids=channel_ids, return_scaled=True)[:, 0]
        traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids, return_scaled=True)[:, 0]

        # Build the time axis from the samples actually returned: it then always has the
        # same length as the traces and an exact 1 / fs spacing. Truncating nsec * fs and
        # the two frame bounds independently can disagree by one sample.
        timestamps = (start_frame + np.arange(traces.shape[0])) / fs

        axs[0, ax_idx].plot(timestamps, traces, color=bit_colors[factor], alpha=alpha, label=f"bit{factor}")
        axs[1, ax_idx].plot(timestamps, traces_f, color=bit_colors[factor], alpha=alpha, label=f"bit{factor}")


axs[0, 0].set_title("Neuropixels 1.0\nRaw")
axs[0, 1].set_title("Neuropixels 2.0\nRaw")
axs[1, 0].set_title("Filtered")
axs[1, 1].set_title("Filtered")
axs[0, 0].legend(ncol=4)
axs[0, 1].legend(ncol=4)

# raw strings: "\m" is not a valid escape sequence in a normal string literal
axs[0, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_xlabel("time (s)")
axs[1, 1].set_xlabel("time (s)")

prettify_axes(axs)
fig_bit_traces.suptitle("Bit truncation - traces", fontsize=15)

In [None]:
# Overlay the same short snippet of a single channel for every WavPack hybrid factor.
# Top row: traces as read from the zarr recording ("Raw"); bottom row: the same snippet
# after si.bandpass_filter ("Filtered"). Left column: Neuropixels 1.0, right: Neuropixels 2.0.
fig_wv_traces, axs = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=False, figsize=(15, 10))

alpha = 0.7
# Channel and start time of the snippet for each probe (suffix 1 -> left column, 2 -> right column).
channel_ids1 = ["AP241"]
channel_ids2 = ["AP21"]
nsec = 0.004
t_start1 = 30.056
t_start2 = 30.188

time_range1 = [t_start1, t_start1 + nsec]
time_range2 = [t_start2, t_start2 + nsec]

# NOTE(review): lw_gt is not referenced in this cell; kept in case other cells rely on it.
lw_gt = 3

for wv in wv_order:
    for probe in probes:
        # first matching result row for this (factor, probe) pair
        row = res_wv.query(f"factor == {wv} and probe == '{probe}'").iloc[0]
        factor = row["factor"]
        strategy = row["strategy"]
        rec_zarr = si.read_zarr(row["rec_zarr_path"])
        rec_f = si.bandpass_filter(rec_zarr)
        if "1" in probe:
            ax_idx = 0
            channel_ids = channel_ids1
            time_range = time_range1
        else:
            ax_idx = 1
            channel_ids = channel_ids2
            time_range = time_range2

        fs = rec_zarr.get_sampling_frequency()
        start_frame = int(time_range[0] * fs)
        end_frame = int(time_range[1] * fs)

        traces = rec_zarr.get_traces(start_frame=start_frame, end_frame=end_frame,
                                     channel_ids=channel_ids, return_scaled=True)[:, 0]
        traces_f = rec_f.get_traces(start_frame=start_frame, end_frame=end_frame,
                                    channel_ids=channel_ids, return_scaled=True)[:, 0]

        # Build the time axis from the samples actually returned: it then always has the
        # same length as the traces and an exact 1 / fs spacing. Truncating nsec * fs and
        # the two frame bounds independently can disagree by one sample.
        timestamps = (start_frame + np.arange(traces.shape[0])) / fs

        axs[0, ax_idx].plot(timestamps, traces, color=wv_colors[factor], alpha=alpha, label=f"hf{factor}")
        axs[1, ax_idx].plot(timestamps, traces_f, color=wv_colors[factor], alpha=alpha, label=f"hf{factor}")


axs[0, 0].set_title("Neuropixels 1.0\nRaw")
axs[0, 1].set_title("Neuropixels 2.0\nRaw")
axs[1, 0].set_title("Filtered")
axs[1, 1].set_title("Filtered")
axs[0, 0].legend(ncol=4)
axs[0, 1].legend(ncol=4)

# raw strings: "\m" is not a valid escape sequence in a normal string literal
axs[0, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_ylabel(r"V ($\mu$V)")
axs[1, 0].set_xlabel("time (s)")
axs[1, 1].set_xlabel("time (s)")

prettify_axes(axs)
fig_wv_traces.suptitle("WavPack hybrid - traces", fontsize=15)

In [None]:
# Save both trace-overlay figures. The bit-truncation figure is written next to the
# WavPack one so that every figure built in this notebook ends up in fig_folder.
if save_fig:
    fig_bit_traces.savefig(fig_folder / "exp_bit_traces.pdf")
    fig_wv_traces.savefig(fig_folder / "exp_wv_traces.pdf")

# Spike sorting

In [None]:
# 2x2 grid of curated unit counts: rows = strategy (bit truncation on top, WavPack
# hybrid below), columns = probe (Neuropixels 1.0 left, Neuropixels 2.0 right).
fig_ss, axs_ss = plt.subplots(ncols=2, nrows=2, figsize=(15, 6))

bit_labels = list(map(int, bit_order))
wv_labels = wv_order

for probe in probes:
    col = 0 if "1" in probe else 1

    res_bit_probe = res_bit.query(f"probe == '{probe}'")
    res_wv_probe = res_wv.query(f"probe == '{probe}'")

    # (grid row, per-probe results, x order, tick labels, seaborn palette name, legend location)
    strategy_panels = [
        (0, res_bit_probe, bit_order, bit_labels, "Set1", 2 if col == 0 else 3),
        (1, res_wv_probe, wv_order, wv_labels, "Set2", 4),
    ]
    for grid_row, res_probe, x_order, tick_labels, palette_name, legend_loc in strategy_panels:
        ax = axs_ss[grid_row, col]

        # long format: one row per (factor, good/bad) count so the two unit types become the hue
        df_units = pd.melt(res_probe, id_vars='factor', var_name='Type', value_name='Units',
                           value_vars=('n_curated_good_units', 'n_curated_bad_units'))
        sns.barplot(x='factor', y='Units', hue='Type', data=df_units,
                    order=x_order, ax=ax, palette=sns.color_palette(palette_name))
        ax.set_xticklabels(tick_labels, rotation=30, ha='right')
        ax.legend(loc=legend_loc)
        ax.set_ylabel("")
        ax.set_xlabel("")

        # dashed reference line at the number of good units found for factor 0
        n_good_lossless = res_probe.query("factor == 0")["n_curated_good_units"].values[0]
        ax.axhline(n_good_lossless, color="grey", ls="--")

# Column titles and x-labels (set after the loop, which blanks the labels on every panel).
for column, probe_title in enumerate(("Neuropixels 1.0", "Neuropixels 2.0")):
    axs_ss[0, column].set_title(probe_title, fontsize=18)
    axs_ss[0, column].set_xlabel("# bits")
    axs_ss[1, column].set_xlabel("hybrid factor")

# Row labels only on the left column.
axs_ss[0, 0].set_ylabel("Bit truncation\n# units")
axs_ss[1, 0].set_ylabel("WavPack hybrid\n# units")

prettify_axes(axs_ss)

fig_ss.suptitle("Spike sorting performance", fontsize=20)
fig_ss.subplots_adjust(hspace=0.3)

In [None]:
# Write the spike-sorting summary figure to the figure folder when saving is enabled.
if save_fig:
    spike_sorting_pdf = fig_folder / "exp_ss.pdf"
    fig_ss.savefig(spike_sorting_pdf)