# Benchmark lossy strategies on GT MEArec data

In this notebook we analyze how lossy compression affects downstream analysis, including spike sorting. 

We use two different strategies:

- Bit truncation
- WavPack hybrid mode

The analysis focuses on:

* compression performance
* influence on spike sorting results
* influence on waveform shapes

This notebook assumes that the `ephys-compression/scripts/benchmark-lossy-gt.py` script has been run and that the `ephys-compression/data/results-lossy-sim/benchmark-lossy-sim.csv` and `ephys-compression/data/results-lossy-sim/benchmark-lossy-sim-waveforms.csv` files are available. Moreover, ground-truth recordings for NP1 and NP2 need to be present in the `ephys-compression/data/mearec` folder (they can be generated with the `ephys-compression/notebooks/generate-gt-neuropixels-recordings.ipynb` notebook).

In [None]:
import warnings
# Ignore every warning raised from here on (presumably to keep the cell outputs readable).
warnings.filterwarnings("ignore")

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import spikeinterface as si
import spikeinterface.widgets as sw

from utils import prettify_axes

%matplotlib inline

In [None]:
# Figure sizes shared by all plots below: a wide layout and a square layout
figsize_rect, figsize_square = (10, 5), (10, 10)

In [None]:
# Set to False to skip writing the figures to disk
save_fig = True

# Input and output locations, relative to the notebook folder
data_folder = Path("../data/")
results_folder = Path("../results/")
results_lossy_sim_folder = data_folder.joinpath("results-lossy-sim")

# Folder receiving the exported figures; created together with any missing parent
fig_folder = results_folder.joinpath("figures", "lossy")
fig_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# `res` holds the per-recording benchmark results, `res_wfs` the per-unit waveform metrics
res, res_wfs = (
    pd.read_csv(results_lossy_sim_folder / csv_name, index_col=False)
    for csv_name in ("benchmark-lossy-sim.csv", "benchmark-lossy-sim-waveforms.csv")
)

In [None]:
# Replace the long probe names by the short labels used in all figures
probe_names = dict([("Neuropixels1.0", "NP1"), ("Neuropixels2.0", "NP2")])
for long_name in probe_names:
    short_name = probe_names[long_name]
    for results_df in (res, res_wfs):
        results_df.loc[results_df["probe"] == long_name, "probe"] = short_name

# Sorted unique probe labels present in the results
probes = np.unique(res["probe"])

In [None]:
# Split the benchmark results by lossy strategy
res_wv = res.query("strategy == 'wavpack'")
res_bit = res.query("strategy == 'bit_truncation'")

# WavPack factors: 0 first, followed by the remaining factors in decreasing order.
# np.unique already returns sorted values; after reversing, the smallest value is
# dropped and replaced by the leading 0.
wv_order = [0] + list(np.unique(res_wv.factor)[::-1][:-1])
wv_labels = wv_order

# Bit-truncation factors in increasing order, with integer tick labels
bit_order = list(np.unique(res_bit.factor))
bit_labels = [int(b) for b in bit_order]

# One sequential colormap per strategy
bit_cmap = plt.get_cmap("Purples_r")
wv_cmap = plt.get_cmap("Greens_r")

Compute dataframe with relative errors

In [None]:
# Values of the `distance` column kept for the waveform analysis
# (shown in um in the plot titles further down)
selected_distances = [
    0,
    60,
]

# Template metrics compared between ground-truth and lossy waveforms
metrics = [
    "peak_to_valley",
    "half_width",
    "peak_trough_ratio",
]

In [None]:
template_metrics = metrics

# Keep only the waveform metrics measured at the selected distances
res_wfs_dist = res_wfs.query(f"distance in {selected_distances}")


def _relative_errors(strategy, factor, column_suffix):
    """Relative error of every template metric for one strategy/factor pair.

    Parameters
    ----------
    strategy : str
        Strategy name as it appears in the metric column names of ``res_wfs``.
    factor : scalar
        Value stored in the ``factor`` column of the returned frame.
    column_suffix : scalar
        Suffix of the tested metric columns, i.e. ``{metric}_{strategy}_{column_suffix}``.

    Returns
    -------
    pandas.DataFrame
        One row per row of ``res_wfs_dist`` with the columns ``probe``, ``unit_id``,
        ``distance``, ``strategy``, ``factor`` and one ``err_{metric}`` column per metric,
        computed as ``|tested - gt| / |gt|``.
    """
    errors = res_wfs_dist[["probe", "unit_id", "distance"]].copy()
    errors["strategy"] = strategy
    errors["factor"] = factor
    for metric in template_metrics:
        # Work on the filtered frame directly rather than relying on index alignment
        gt_values = res_wfs_dist[f"{metric}_gt"]
        tested_values = res_wfs_dist[f"{metric}_{strategy}_{column_suffix}"]
        errors[f"err_{metric}"] = np.abs(tested_values - gt_values) / np.abs(gt_values)
    return errors


# The first entry of each order is skipped (presumably the lossless reference - TODO confirm).
# Bit-truncation columns use the integer factor as suffix, WavPack columns the raw factor.
error_frames = [_relative_errors("bit_truncation", bit, int(bit)) for bit in bit_order[1:]]
error_frames += [_relative_errors("wavpack", wv, wv) for wv in wv_order[1:]]

# Concatenate once instead of growing the frame inside the loops;
# `df_errors` stays None when there is nothing to compare.
df_errors = pd.concat(error_frames) if error_frames else None

# Bit truncation

In [None]:
color = bit_cmap.name

fig_bit_cr_rmse = plt.figure(figsize=figsize_rect)

# Left column: three stacked axes emulating a broken y-axis for the compression ratio.
# Right column: a single axis spanning all rows for the RMSE.
gs = GridSpec(3, 2, hspace=0.1, wspace=0.3)
ax1_top = fig_bit_cr_rmse.add_subplot(gs[0, 0])
ax1_mid = fig_bit_cr_rmse.add_subplot(gs[1, 0])
ax1_bottom = fig_bit_cr_rmse.add_subplot(gs[2, 0])
ax22 = fig_bit_cr_rmse.add_subplot(gs[:, 1])

# y-ranges of the three segments of the broken axis
top_lim = 600
mid_lims = [50, 322]
bottom_lims = [-0.1, 33]

# The same data is drawn on the three stacked axes; only the visible y-range differs
for ax in (ax1_top, ax1_mid, ax1_bottom):
    sns.pointplot(data=res_bit, x="factor", y="CR", hue="probe", ax=ax, palette=color)

# Top segment: open-ended above `top_lim`
ax1_top.set_ylim(bottom=top_lim)
ax1_top.axhline(top_lim, ls="--", color="k", alpha=0.5)
ax1_top.set_xticks([])
ax1_top.set_xlabel("")

# Middle segment: dashed line at both cuts
ax1_mid.set_ylim(mid_lims)
for lim in mid_lims:
    ax1_mid.axhline(lim, ls="--", color="k", alpha=0.5)
ax1_mid.set_xticks([])
ax1_mid.set_xlabel("")

# Bottom segment: dashed line at the upper cut only
ax1_bottom.set_ylim(bottom_lims)
ax1_bottom.axhline(bottom_lims[1], ls="--", color="k", alpha=0.5)

sns.despine(ax=ax1_bottom)
sns.despine(ax=ax1_top, bottom=True)
sns.despine(ax=ax1_mid, bottom=True)

# Diagonal break marks on the left spine, drawn in axes coordinates:
# y anchor 0 is the lower-left corner of an axis, y anchor 1 its upper-left corner.
d = .015  # how big to make the diagonal lines in axes coordinates
for ax, y_anchors in ((ax1_top, (0,)), (ax1_mid, (0, 1)), (ax1_bottom, (1,))):
    for y0 in y_anchors:
        ax.plot((-d, +d), (y0 - d, y0 + d), transform=ax.transAxes, color='k', clip_on=False)

# Keep only the legend of the top axis
ax1_mid.legend_.remove()
ax1_bottom.legend_.remove()

# Reference lines at compression ratios of 10 and 5
ax = ax1_bottom
ax.axhline(10, color="grey", ls="--")
ax.axhline(5, color="grey", ls="-.")
ax.text(-0.1, 10.3, "10", color="grey", fontsize=12)
ax.text(-0.1, 5.3, "5", color="grey", fontsize=12)
ax.set_xlabel("# bit")

# A single y-label, on the middle axis, for the whole broken axis
ax1_top.set_ylabel("")
ax1_mid.set_ylabel("Compression Ratio", fontsize=15)
ax1_bottom.set_ylabel("")
ax1_bottom.set_xticklabels(bit_labels)

# RMSE panel
ax = ax22
sns.pointplot(data=res_bit, x="factor", y="rmse", hue="probe", ax=ax, palette=color)
ax.set_ylim(-0.5, 7)
ax.axhline(1.5, color="grey", ls="--")
ax.text(-0.5, 1.55, "1.5", color="grey", fontsize=12)
# Raw string: "\m" is not a valid escape sequence in a regular string literal
ax.set_ylabel(r"RMSE ($\mu$V)")
ax.set_xticklabels(bit_labels)
ax.set_xlabel("# bit")

prettify_axes([ax1_bottom, ax22])

fig_bit_cr_rmse.suptitle("Bit truncation", fontsize=20)
fig_bit_cr_rmse.subplots_adjust(bottom=0.15)

In [None]:
# Export the bit-truncation CR / RMSE figure when saving is enabled
if save_fig:
    fig_bit_cr_rmse.savefig(fig_folder.joinpath("bit_cr_rmse.pdf"))

### Spike sorting 

In [None]:
fig_bit_ss, axs_bit_ss = plt.subplots(ncols=2, nrows=2, figsize=figsize_rect)

bit_labels = [int(b) for b in bit_order]

perf_palette = sns.color_palette("Accent")
count_palette = sns.color_palette("Set2")

for probe in probes:
    # Probes whose name contains "1" go to the left column, the others to the right
    col = 0 if "1" in probe else 1
    res_probe = res_bit.query(f"probe == '{probe}'")

    # Top row: accuracy / precision / recall per bit-truncation factor
    ax_perf = axs_bit_ss[0, col]
    df_perf = pd.melt(res_probe, id_vars='factor', value_vars=('accuracy', 'precision', 'recall'),
                      var_name='metric', value_name='value')
    sns.barplot(data=df_perf, x='factor', y='value', hue='metric',
                order=bit_order, palette=perf_palette, ax=ax_perf)
    ax_perf.set_xticklabels(bit_labels)
    ax_perf.legend(loc=3)
    ax_perf.set_ylabel("")
    ax_perf.set_xlabel("")

    # Bottom row: unit counts per category, with a reference line at 100 units
    ax_count = axs_bit_ss[1, col]
    df_count = pd.melt(res_probe, id_vars='factor',
                       value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'),
                       var_name='Type', value_name='Units')
    sns.barplot(data=df_count, x='factor', y='Units', hue='Type',
                order=bit_order, palette=count_palette, ax=ax_count)
    ax_count.legend(loc=2)
    ax_count.set_ylabel("")
    ax_count.set_xlabel("")
    ax_count.set_ylim(-1, 500)
    ax_count.axhline(100, color="grey", ls="--")
    ax_count.set_xticklabels(bit_labels)

# Column titles
for ax_title, title in zip(axs_bit_ss[0], ("NP1", "NP2")):
    ax_title.set_title(title, fontsize=18)

# x-labels on the bottom row, y-labels on the left column only
for ax_bottom in axs_bit_ss[1]:
    ax_bottom.set_xlabel("# bit")
axs_bit_ss[0, 0].set_ylabel("Values")
axs_bit_ss[1, 0].set_ylabel("# Units")

# Legends are kept on the left column only
for ax_right in axs_bit_ss[:, 1]:
    ax_right.get_legend().remove()

prettify_axes(axs_bit_ss)
fig_bit_ss.subplots_adjust(hspace=0.3)

In [None]:
# Export the bit-truncation spike sorting figure when saving is enabled
if save_fig:
    fig_bit_ss.savefig(fig_folder.joinpath("bit_ss.pdf"))

### Waveform features

In [None]:
fig_bit_feat, ax_bit_feat = plt.subplots(ncols=len(template_metrics), nrows=2, figsize=figsize_rect)

df_errors_bit = df_errors.query("strategy == 'bit_truncation'")
color = "pastel"
ylims = (-0.05, 1)

for probe in probes:
    # Probes whose name contains "1" go to the top row, the others to the bottom row
    row = 0 if "1" in probe else 1
    df_errors_probe = df_errors_bit.query(f"probe == '{probe}'")
    for i, metric in enumerate(template_metrics):
        ax = ax_bit_feat[row, i]
        sns.boxenplot(data=df_errors_probe, x="factor", y=f"err_{metric}", hue="distance",
                      order=bit_order[1:], showfliers=False, palette=color, ax=ax)
        # y-label on the first column only, x-label on the bottom row only
        ax.set_ylabel(f"Relative errors\n({probe})" if i == 0 else "")
        ax.set_xlabel("# bit" if row == 1 else "")
        ax.set_xticklabels(bit_labels[1:])
        if row != 1:
            ax.set_title(metric, fontsize=15)
        # Reference line at a relative error of 10%
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(ylims)
        # Only the top-left panel keeps its legend
        if (row, i) != (0, 0):
            ax.get_legend().remove()

# Human-readable column titles replace the raw metric names set in the loop
ax_bit_feat[0, 0].set_title("Peak-to-valley", fontsize=15)
ax_bit_feat[0, 1].set_title("Half-width", fontsize=15)
ax_bit_feat[0, 2].set_title("Peak-trough ratio", fontsize=15)

prettify_axes(ax_bit_feat, label_fs=15)
fig_bit_feat.subplots_adjust(hspace=0.3, wspace=0.2)

In [None]:
# Export the bit-truncation feature-error figure when saving is enabled
if save_fig:
    fig_bit_feat.savefig(fig_folder.joinpath("bit_feature_errors.pdf"))

# WavPack Hybrid

In [None]:
color = wv_cmap.name

fig_wv_cr_rmse, axs_wv_cr_rmse = plt.subplots(ncols=2, nrows=1, figsize=figsize_rect)
ax_cr, ax_rmse = axs_wv_cr_rmse

# Left panel: compression ratio per WavPack factor, with reference lines at 10 and 5
sns.pointplot(data=res_wv, x="factor", y="CR", hue="probe", order=wv_order, palette=color, ax=ax_cr)
ax_cr.set_ylim(0, 20)
for level, line_style, text_y, text in ((10, "--", 10.2, "10"), (5, "-.", 5.2, "5")):
    ax_cr.axhline(level, color="grey", ls=line_style)
    ax_cr.text(-0.5, text_y, text, color="grey", fontsize=12)
ax_cr.set_ylabel("Compression Ratio")
ax_cr.set_xlabel("bps")

# Right panel: RMSE per WavPack factor, with a reference line at 1.5
sns.pointplot(data=res_wv, x="factor", y="rmse", hue="probe", order=wv_order, palette=color, ax=ax_rmse)
ax_rmse.set_ylim(-0.5, 7)
ax_rmse.axhline(1.5, color="grey", ls="--")
ax_rmse.text(-0.5, 1.55, "1.5", color="grey", fontsize=12)
ax_rmse.set_ylabel("RMSE")
ax_rmse.set_xlabel("bps")

prettify_axes(axs_wv_cr_rmse)

fig_wv_cr_rmse.suptitle("WavPack Hybrid", fontsize=20)
fig_wv_cr_rmse.subplots_adjust(hspace=0.3)

In [None]:
# Export the WavPack CR / RMSE figure when saving is enabled
if save_fig:
    fig_wv_cr_rmse.savefig(fig_folder.joinpath("wv_cr_rmse.pdf"))

### Spike sorting

In [None]:
fig_wv_ss, axs_wv_ss = plt.subplots(ncols=2, nrows=2, figsize=figsize_rect)

wv_labels = wv_order

for probe in probes:
    # Probes whose name contains "1" go to the left column, the others to the right
    col = 0 if "1" in probe else 1

    res_probe = res_wv.query(f"probe == '{probe}'")

    # Top row: accuracy / precision / recall per WavPack factor
    ax = axs_wv_ss[0, col]
    df_perf = pd.melt(res_probe, id_vars='factor', var_name='metric', value_name='value',
                      value_vars=('accuracy', 'precision', 'recall'))
    sns.barplot(x='factor', y='value', hue='metric', data=df_perf,
                order=wv_order, ax=ax, palette=sns.color_palette("Accent"))
    ax.legend(loc=3)
    ax.set_ylabel("")
    ax.set_xlabel("")

    # Bottom row: unit counts per category, with a reference line at 100 units
    ax = axs_wv_ss[1, col]
    df_count = pd.melt(res_probe, id_vars='factor', var_name='Type', value_name='Units',
                       value_vars=('num_well_detected', 'num_false_positive', 'num_redundant', 'num_overmerged'))
    sns.barplot(x='factor', y='Units', hue='Type', data=df_count,
                order=wv_order, ax=ax, palette=sns.color_palette("Set2"))
    ax.legend(loc=2, ncol=2)
    # The y-label is set once below, on the left panel only (same convention as the
    # bit-truncation figure); previously the right panel kept a stray "# units" label.
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.set_ylim(-1, 150)
    ax.axhline(100, color="grey", ls="--")

# Column titles, bottom-row x-labels and left-column y-labels
axs_wv_ss[0, 0].set_title("NP1", fontsize=18)
axs_wv_ss[0, 1].set_title("NP2", fontsize=18)
axs_wv_ss[1, 0].set_xlabel("bps")
axs_wv_ss[1, 1].set_xlabel("bps")
axs_wv_ss[0, 0].set_ylabel("Values")
axs_wv_ss[1, 0].set_ylabel("# Units")

# Legends are kept on the left column only
axs_wv_ss[0, 1].get_legend().remove()
axs_wv_ss[1, 1].get_legend().remove()


prettify_axes(axs_wv_ss)

fig_wv_ss.subplots_adjust(hspace=0.3)

In [None]:
# Export the WavPack spike sorting figure when saving is enabled
if save_fig:
    fig_wv_ss.savefig(fig_folder.joinpath("wv_ss.pdf"))

In [None]:
fig_wv_feat, ax_wv_feat = plt.subplots(ncols=len(template_metrics), nrows=2, figsize=figsize_rect)

color = "pastel"
df_errors_wv = df_errors.query("strategy == 'wavpack'")

# Human-readable column titles, matching the bit-truncation feature figure;
# metrics missing from this mapping fall back to their raw name.
metric_titles = {
    "peak_to_valley": "Peak-to-valley",
    "half_width": "Half-width",
    "peak_trough_ratio": "Peak-trough ratio",
}

for probe in probes:
    # Probes whose name contains "1" go to the top row, the others to the bottom row
    if "1" in probe:
        row = 0
    else:
        row = 1
    for i, metric in enumerate(template_metrics):
        ax = ax_wv_feat[row, i]
        sns.boxenplot(data=df_errors_wv.query(f"probe == '{probe}'"),
                      x="factor", y=f"err_{metric}", hue="distance",
                      order=wv_order[1:], showfliers=False, ax=ax, palette=color)
        ax.set_xlabel("")
        ax.set_ylabel("")
        # y-label on the first column only
        if i == 0:
            ax.set_ylabel(f"Relative errors\n({probe})")
        # x-label on the bottom row, title on the top row
        if row == 1:
            ax.set_xlabel("bps")
        else:
            ax.set_title(metric_titles.get(metric, metric), fontsize=15)
        # Only the top-left panel keeps its legend
        if i > 0 or row > 0:
            ax.get_legend().remove()
        # Reference line at a relative error of 10%
        ax.axhline(0.1, color="grey", ls="--")
        ax.text(0, 0.11, "10%", fontsize=10, color="grey")
        ax.set_ylim(-0.05, 0.3)
prettify_axes(ax_wv_feat, label_fs=15)
fig_wv_feat.subplots_adjust(hspace=0.3, wspace=0.2)

In [None]:
# Export the WavPack feature-error figure when saving is enabled
if save_fig:
    fig_wv_feat.savefig(fig_folder.joinpath("wv_feature_errors.pdf"))

### Plot templates

In [None]:
# Load the ground-truth waveforms and, per dataset, the waveforms extracted after
# each lossy setting (bit-truncation keys are ints, WavPack keys are the raw factors)
dsets = ["NP1", "NP2"]
gt_dict, bit_dict, wv_dict = {}, {}, {}
for dset in dsets:
    we = si.load_waveforms(results_lossy_sim_folder / f"gt-{dset}" / "waveforms", with_recording=False)
    gt_dict[dset] = we

    bit_dict[dset] = {}
    bit_root = results_lossy_sim_folder / f"waveforms-{dset}-bit_truncation"
    for bit in bit_order[1:]:
        we = si.load_waveforms(bit_root / f"wf_lossy_bit_truncation_{int(bit)}", with_recording=False)
        bit_dict[dset][int(bit)] = we

    wv_dict[dset] = {}
    wv_root = results_lossy_sim_folder / f"waveforms-{dset}-wavpack"
    for bps in wv_order[1:]:
        we = si.load_waveforms(wv_root / f"wf_lossy_wavpack_{bps}", with_recording=False)
        wv_dict[dset][bps] = we

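For one example unit per probe, plot the full template across neighbouring channels, and, in separate figures, the single-channel templates on the main (extremum) channel and on a peripheral channel (the closest channel more than `dist` µm away from the main one).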

In [None]:
# Denominators used to turn a factor index into a colormap position; with the +2
# offset the position never reaches the end of the colormap
n_bits, n_bps = (len(factor_order) + 2 for factor_order in (bit_order, wv_order))

In [None]:
# unit ids
unit_ids = dict(
    NP1="#10",
    NP2="#30"
)
# The "last" channel is the closest channel farther than `dist` from the main channel
dist = 60
# Radius of the sparsity mask used for the full-template plot
radius_um = 100

gt_lw = 3


def _plot_main_and_last_templates(we_gt, lossy_extractors, unit_id, main_channel_ind,
                                  last_channel_ind, cmap, n_colors):
    """Overlay ground-truth and lossy templates of one unit on two single channels.

    Parameters
    ----------
    we_gt : waveform extractor
        Ground-truth waveforms; also provides the time axis.
    lossy_extractors : dict
        Maps each lossy factor to its waveform extractor; the keys are used as legend labels.
    unit_id : str
        Unit whose template is plotted.
    main_channel_ind, last_channel_ind : int
        Channel indices shown in the first and second figure, respectively.
    cmap : matplotlib colormap
        Colormap sampled at ``i / n_colors`` for the i-th lossy factor.
    n_colors : int
        Denominator of the colormap position.

    Returns
    -------
    (fig_main, fig_last) : tuple of matplotlib figures
    """
    fig_main, ax_main = plt.subplots(ncols=1, figsize=figsize_rect)
    fig_last, ax_last = plt.subplots(ncols=1, figsize=figsize_rect)
    panels = ((ax_main, main_channel_ind), (ax_last, last_channel_ind))

    # Time axis in ms, taken from the ground-truth extractor rather than from a loop
    # variable left over by another cell
    ts = np.arange(-we_gt.nbefore, we_gt.nafter) / we_gt.sampling_frequency * 1000

    template_gt = we_gt.get_template(unit_id)
    for ax, channel_ind in panels:
        ax.plot(ts, template_gt[:, channel_ind], color="k", label="GT", lw=gt_lw)

    for i, (factor, we_lossy) in enumerate(lossy_extractors.items()):
        color = cmap(i / n_colors)
        template = we_lossy.get_template(unit_id)
        for ax, channel_ind in panels:
            ax.plot(ts, template[:, channel_ind], color=color, label=factor)

    # (axis, y of the horizontal scale bar, y of the "1ms" label)
    for ax, bar_y, time_label_y in ((ax_main, -50, -60), (ax_last, -20, -19)):
        # legend
        ax.legend(ncol=4, fontsize=12)
        # scale bar: 1 ms horizontally, 10 uV vertically
        ax.plot([-3, -2], [bar_y, bar_y], color="k", lw=3)
        ax.plot([-3, -3], [bar_y, bar_y + 10], color="k", lw=3)
        ax.text(-2.7, time_label_y, "1ms", fontsize=12)
        # Raw string: "\m" is not a valid escape sequence in a regular string literal
        ax.text(-2.9, bar_y + 5, r"10$\mu$V", fontsize=12)
        ax.axis("off")

    return fig_main, fig_last


fig_templates = {}

for dset in dsets:
    fig_templates[dset] = {}
    unit_id = unit_ids[dset]
    we_gt = gt_dict[dset]

    # Main channel: extremum channel of the unit's ground-truth template
    extremum_channels_inds = si.get_template_extremum_channel(we_gt, outputs="index")
    main_channel_ind = extremum_channels_inds[unit_id]

    # Last channel: first channel, by increasing distance, farther than `dist` from the main one
    locs = we_gt.get_channel_locations()
    distances = np.linalg.norm(locs - locs[main_channel_ind], axis=1)
    order = np.argsort(distances)
    idx = np.where(distances[order] > dist)[0][0]
    last_channel_ind = order[idx]
    sparsity = si.compute_sparsity(we_gt, method="radius", radius_um=radius_um)

    # Two copies of the full-template plot: `w_circles` additionally marks the two selected
    # channels, `w` is the clean version that gets saved
    w_circles = sw.plot_unit_templates(we_gt, sparsity=sparsity, unit_ids=[unit_id], same_axis=True,
                                       unit_colors={unit_id : "k"})
    w = sw.plot_unit_templates(we_gt, sparsity=sparsity, unit_ids=[unit_id], same_axis=True,
                               unit_colors={unit_id : "k"})
    w.legend.remove()
    w.ax.set_title("")
    w.ax.axis("off")
    w_circles.ax.plot(*locs[main_channel_ind], "o", markersize=10, markerfacecolor=None)
    w_circles.ax.plot(*locs[last_channel_ind], "o", markersize=10, markerfacecolor=None)

    # Bit truncation
    fig_bit_main, fig_bit_last = _plot_main_and_last_templates(
        we_gt, bit_dict[dset], unit_id, main_channel_ind, last_channel_ind, bit_cmap, n_bits
    )

    # WavPack
    n_bps = len(wv_dict[dset]) + 2
    fig_wv_main, fig_wv_last = _plot_main_and_last_templates(
        we_gt, wv_dict[dset], unit_id, main_channel_ind, last_channel_ind, wv_cmap, n_bps
    )

    fig_templates[dset]["gt"] = w.figure
    fig_templates[dset]["bit_main"] = fig_bit_main
    # NOTE(review): "bit_lat" looks like a typo for "bit_last" (cf. "wv_last"); kept as is
    # because the key ends up in the exported file name.
    fig_templates[dset]["bit_lat"] = fig_bit_last
    fig_templates[dset]["wv_main"] = fig_wv_main
    fig_templates[dset]["wv_last"] = fig_wv_last

In [None]:
# Export every template figure as "<dataset>_<figure name>.pdf" when saving is enabled
if save_fig:
    for fig_dset, dset_figs in fig_templates.items():
        for fig_name in dset_figs:
            dset_figs[fig_name].savefig(fig_folder / f"{fig_dset}_{fig_name}.pdf")

In [None]:
# PLOT distributions with different colors
figs_features = {}

# Row labels of the scatter grids; metrics missing here fall back to their raw name
metric_labels = {
    "peak_to_valley": "Peak-to-valley",
    "half_width": "Half-width",
    "peak_trough_ratio": "Peak-trough ratio",
}


def _plot_feature_scatter(probe, strategy, factor_suffixes, cmap, n_colors, title):
    """Scatter lossy vs. ground-truth template metrics for one probe and one strategy.

    The figure has one row per entry of ``template_metrics`` and one column per entry of
    ``selected_distances``; every lossy factor is drawn with its own colour.

    Parameters
    ----------
    probe : str
        Probe label used to filter ``res_wfs``.
    strategy : str
        Strategy name as it appears in the metric column names.
    factor_suffixes : sequence
        Suffixes of the tested columns (``{metric}_{strategy}_{suffix}``), also used as legend labels.
    cmap : matplotlib colormap
        Colormap sampled at ``b / n_colors`` for the b-th factor.
    n_colors : int
        Denominator of the colormap position.
    title : str
        Strategy name shown in the figure title.

    Returns
    -------
    matplotlib.figure.Figure
    """
    df_wfs_probe = res_wfs.query(f"probe == '{probe}'")
    # squeeze=False keeps `axs_m` two-dimensional whatever the number of metrics/distances
    fig_m, axs_m = plt.subplots(nrows=len(template_metrics), ncols=len(selected_distances),
                                figsize=figsize_square, squeeze=False)
    n_factors = len(factor_suffixes)
    for i, metric in enumerate(template_metrics):
        gt_metric_name = f"{metric}_gt"
        for j, distance in enumerate(selected_distances):
            ax = axs_m[i, j]
            tm_dist = df_wfs_probe.query(f"distance == {distance}")
            for b, suffix in enumerate(factor_suffixes):
                sns.scatterplot(data=tm_dist, x=gt_metric_name, y=f"{metric}_{strategy}_{suffix}",
                                ax=ax, color=cmap(b / n_colors), label=suffix)
                # Earlier factors are drawn on top of later ones
                plt.setp(ax.collections[-1], zorder=n_factors - b)

            ax.set_yticks([])
            ax.set_xticks([])
            ax.set_xlabel("")
            # Compare against the number of plotted distances (columns), not against the
            # per-channel `distances` array defined in the template cell
            if j == len(selected_distances) - 1:
                ax.set_ylabel("")

            # Identity line over the ground-truth range extended by 20% on each side;
            # drawn once per panel instead of once per factor
            lims = [np.min(tm_dist[gt_metric_name]) - 0.2 * np.ptp(tm_dist[gt_metric_name]),
                    np.max(tm_dist[gt_metric_name]) + 0.2 * np.ptp(tm_dist[gt_metric_name])]
            ax.plot(lims, lims, color="grey", alpha=0.7, ls="--")

            if i == 0:
                # Raw string: "\m" is not a valid escape sequence in a regular string literal
                ax.set_title(rf"Distance: {int(distance)} $\mu$m")
            if i == len(template_metrics) - 1:
                ax.set_xlabel("GT metric values")
            if j == 0:
                ax.set_ylabel(metric_labels.get(metric, metric))

            ext = lims[1] - lims[0]
            ax.set_xlim(lims[0] - 0.5*ext, lims[1] + 0.5*ext)
            ax.set_ylim(lims[0] - ext, lims[1] + ext)
            ax.legend().remove()

    # A single legend, on the top-left panel
    axs_m[0, 0].legend()
    prettify_axes(axs_m, label_fs=12)

    fig_m.suptitle(f"{probe} - {title}", fontsize=20)
    fig_m.subplots_adjust(wspace=0.1, hspace=0.3)
    return fig_m


strategy = "bit_truncation"
for probe in probes:
    figs_features[f"features_{probe}_{strategy}"] = _plot_feature_scatter(
        probe, strategy, [int(bit) for bit in bit_order[1:]], bit_cmap, n_bits, "Bit truncation"
    )

strategy = "wavpack"
for probe in probes:
    # WavPack colours are normalised with n_bps, as in the template figures
    figs_features[f"features_{probe}_{strategy}"] = _plot_feature_scatter(
        probe, strategy, wv_order[1:], wv_cmap, n_bps, "Wavpack Hybrid"
    )

In [None]:
# Export every feature scatter figure under its dictionary key when saving is enabled
if save_fig:
    for fig_name in figs_features:
        figs_features[fig_name].savefig(fig_folder / f"{fig_name}.pdf")