# ZARR compression

This notebook investigates the use of Zarr compressors and filters to compress raw Neuropixels data.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spikeinterface.full as si
import probeinterface as pi
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import os
import time
from pathlib import Path
import time
import shutil

import zarr
import numcodecs
from numcodecs import Blosc, blosc


%matplotlib widget

## Load data in SpikeInterface

- Open Ephys data (`read_openephys()`)
- MEArec simulated data (`read_mearec()`)

In [None]:
# # mearec_file = "/home/alessio/Documents/data/mearec/recordings/recording_Neuronexus-32_1800_int16.h5"
# rec, sort = si.read_mearec(mearec_file)
# print(rec)

In [None]:
# Root folder holding the Open Ephys Neuropixels sessions.
# Set the NPIX_DATA_FOLDER environment variable to point at a local copy of the
# data; the original author's path is kept as the fallback so existing setups
# keep working unchanged.
data_folder = Path(os.environ.get("NPIX_DATA_FOLDER", "/home/alessio/Documents/data/allen/npix-open-ephys/"))

In [None]:
# Open Ephys session folder (Record Node 102) inside the data folder.
oe_folder = data_folder / "595262_2022-02-22_16-47-26/Record Node 102/"
# Load stream "0" of that session as a SpikeInterface recording.
# NOTE(review): presumably the Neuropixels AP band — confirm what stream id "0" maps to.
rec = si.read_openephys(oe_folder, stream_id="0")
print(rec)

In [None]:
# w = si.plot_timeseries(rec, time_range=[100, 130])

### Truncation filter 

Optional truncation filter function (see [here](https://github.com/AllenNeuralDynamics/lightsheet-compression-tests/blob/main/compress_zarr.py#L18))

In [None]:
def trunc_filter(bits, recording):
    """Build the (optional) zarr filter list used for bit truncation.

    Parameters
    ----------
    bits : int
        Number of truncation bits. ``0`` disables the filter.
    recording
        Object whose ``get_dtype()`` result is handed to the codec as ``dtype``.

    Returns
    -------
    list
        An empty list when ``bits == 0``; otherwise a one-element list holding a
        ``FixedScaleOffset`` codec configured with ``offset=0`` and
        ``scale = 1 / 2**bits``.
    """
    downscale = 1.0 / (2 ** bits)
    sample_dtype = recording.get_dtype()

    # No truncation requested: zarr gets no filters at all.
    if bits == 0:
        return []

    codec = numcodecs.fixedscaleoffset.FixedScaleOffset(
        offset=0, scale=downscale, dtype=sample_dtype
    )
    return [codec]

### Choose zarr compressor

`zstd` seems to provide good compression ratios and good compression/decompression speeds.
Let's choose a middle compression level (e.g. 5).

In [None]:
# Blosc compressor using the zstd codec at compression level 5 with bit-shuffling;
# it is passed to `save()` when each zarr store is written.
compressor = Blosc(cname='zstd', clevel=5, shuffle=Blosc.BITSHUFFLE,)

### Define output folder and configs

In [None]:
# When True, the output folder and existing zarr stores are deleted and regenerated,
# and sorters are re-run; when False, results already on disk are loaded instead.
overwrite = False

In [None]:
# Folder collecting every zarr store and sorter output written by this notebook.
zarr_output_folder = Path("zarr_tests")

# With overwrite enabled, start from a clean folder.
if overwrite and zarr_output_folder.is_dir():
    shutil.rmtree(zarr_output_folder)

# Always make sure the folder exists. It used to be created only when
# `overwrite` was True, so a first run with overwrite=False had no output folder.
zarr_output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Base name of the zarr stores; each store is named "<zarr_root>_trunc<N>.zarr".
zarr_root = "test_npix_full_zarr"

In [None]:
# Number of parallel jobs passed to `save()` when writing the zarr stores.
n_jobs = 4
# NOTE(review): `chunksize` is not referenced by the visible cells — the save call
# hard-codes chunk_size=10000. Confirm which value is intended.
chunksize = 100000
# Ask numcodecs' Blosc to use multi-threading.
blosc.use_threads = True

# blosc.set_nthreads(8)
# print(blosc.get_nthreads())

In [None]:
# define truncation bits: one zarr store is written per value (0 = no truncation filter)
trunc_bits = [0, 1, 2, 3, 4, 5, 6, 7, 8]

# optionally define stub as the number of seconds to cut the recording (e.g. stub = 30);
# set stub = None to compress the full recording
stub = 30

In [None]:
# Per-truncation results:
#   {trunc_bit: {"CR": compression ratio, "rec": zarr-backed recording}}
rec_zarr_dict = {}

# The (optionally shortened) source recording does not depend on the truncation
# level, so build it once. Doing this outside the loop also guarantees that
# `rec_stub` exists for later cells even when `trunc_bits` is empty.
if stub is not None:
    end_frame = int(stub * rec.get_sampling_frequency())
    rec_stub = rec.frame_slice(start_frame=0, end_frame=end_frame)
else:
    rec_stub = rec

for trunc_bit in trunc_bits:
    rec_zarr_dict[trunc_bit] = {}
    zarr_path = zarr_output_folder / f"{zarr_root}_trunc{trunc_bit}.zarr"

    # Force regeneration by dropping any previously written store.
    if overwrite and zarr_path.is_dir():
        shutil.rmtree(zarr_path)

    if zarr_path.is_dir():
        # Reuse the store written by a previous run; nothing is timed in this case.
        rec_zarr = si.read_zarr(zarr_path)
        elapsed_time = 0
    else:
        t_start = time.perf_counter()
        filters = trunc_filter(trunc_bit, rec_stub)
        # NOTE(review): chunk_size is hard-coded here while the `chunksize`
        # config variable (100000) is never used — confirm which one is intended.
        rec_zarr = rec_stub.save(format="zarr", zarr_path=zarr_path,
                                 compressor=compressor, filters=filters, n_jobs=n_jobs,
                                 chunk_size=10000, progress_bar=True)
        t_stop = time.perf_counter()
        elapsed_time = np.round(t_stop - t_start, 2)

    # Compression ratio = uncompressed bytes / bytes actually stored for the traces array.
    traces_array = rec_zarr._root['traces_seg0']
    cr = np.round(traces_array.nbytes / traces_array.nbytes_stored, 2)

    rec_zarr_dict[trunc_bit]["CR"] = cr
    rec_zarr_dict[trunc_bit]["rec"] = rec_zarr

    print(f"Elapsed time truncation: {trunc_bit}: {elapsed_time}s - CR: {cr}")

### Visualization

In [None]:
# Time window (in seconds) and channel subset shown in the trace plots.
time_range = [15, 16]
channel_ids = rec.get_channel_ids()[100:110]

In [None]:
# Top panel: raw traces for every truncation level; bottom panel: the same
# traces after band-pass filtering. One colour per truncation level.
fig, axs = plt.subplots(nrows=2, sharex=True)
for i, (trunc_bit, trunc_dict) in enumerate(rec_zarr_dict.items()):
    rec_trunc = trunc_dict["rec"]
    line_color = f"C{i}"

    _ = si.plot_timeseries(rec_trunc, time_range=time_range, channel_ids=channel_ids,
                           mode="line", color=line_color, ax=axs[0])
    # Label only the last line drawn so the legend gets one entry per truncation
    # level. The label uses the dict key directly instead of indexing the separate
    # `trunc_bits` list, which could disagree with the dict contents.
    axs[0].get_lines()[-1].set_label(f"trunc {trunc_bit}")

    # Keep the filtered recording around: later cells read "rec_filt".
    rec_f = si.bandpass_filter(rec_trunc)
    trunc_dict["rec_filt"] = rec_f

    _ = si.plot_timeseries(rec_f, time_range=time_range, channel_ids=channel_ids,
                           mode="line", color=line_color, ax=axs[1])
axs[0].legend()
axs[0].set_title("Raw", fontsize=15)
axs[0].set_xlabel("", fontsize=15)
axs[1].set_title("Filtered", fontsize=15)
fig.subplots_adjust(hspace=0.3)

### RMS error on filtered traces

In [None]:
# Window (in seconds) over which the RMS error is evaluated, converted to
# integer frame indices using the recording's sampling frequency.
time_range = [15, 18]
frames = (np.array(time_range) * rec.get_sampling_frequency()).astype(int)

In [None]:
# Reference signal: band-pass filtered version of the uncompressed source recording.
rec_orig_f = si.bandpass_filter(rec_stub)
traces_orig = rec_orig_f.get_traces(start_frame=frames[0], end_frame=frames[1], return_scaled=True)
traces_orig_flat = traces_orig.ravel()

for trunc_bit, trunc_dict in rec_zarr_dict.items():
    # "rec_filt" is normally filled in by the plotting cell; compute it here if
    # that cell was skipped so this cell does not depend on it having run.
    rec_f = trunc_dict.get("rec_filt")
    if rec_f is None:
        rec_f = si.bandpass_filter(trunc_dict["rec"])
        trunc_dict["rec_filt"] = rec_f

    traces_trunc_f = rec_f.get_traces(start_frame=frames[0], end_frame=frames[1], return_scaled=True)

    # Root-mean-square difference over all samples and channels in the window.
    error_rms = np.sqrt(((traces_trunc_f.ravel() - traces_orig_flat) ** 2).mean())
    trunc_dict["rmse"] = error_rms
    print(f"RMS for truncation {trunc_bit}: {error_rms}")

In [None]:
fig_e, ax_e = plt.subplots()

# Gather RMSE and compression ratio per truncation level, in dict order.
errors_rms = [entry["rmse"] for entry in rec_zarr_dict.values()]
crs = [entry["CR"] for entry in rec_zarr_dict.values()]

# Both metrics share the same axis, plotted against the number of truncation bits.
ax_e.plot(trunc_bits, errors_rms, "d", ls="--", label="RMSE")
ax_e.plot(trunc_bits, crs, "o", ls="-", label="CR")
ax_e.set_title("Error VS CR")
ax_e.set_xlabel("# truncation bits")
ax_e.set_ylabel("")
# Horizontal reference line at y = 3.
ax_e.axhline(3, color="grey", ls="--", alpha=0.4)
ax_e.legend()

## Spike Sorting

Test whether compression affects spike sorting results.

In [None]:
# Show which sorters SpikeInterface reports as installed.
si.installed_sorters()

In [None]:
# Sorter name passed to `si.run_sorter`; also used to name the per-truncation output folders.
sorter = "tridesclous"

In [None]:
# Run (or reload) the sorter once per truncation level and store the result
# under the "sort" key of the corresponding entry.
for trunc_bit, trunc_dict in rec_zarr_dict.items():
    print(f"Running {sorter} for truncation bits {trunc_bit}")

    sorter_folder = zarr_output_folder / f"{sorter}_trunc{trunc_bit}"
    if sorter_folder.is_dir() and not overwrite:
        # Reuse the sorting saved by a previous run. Nothing is timed here, so
        # report 0 instead of a stale `elapsed_time` left over from an earlier
        # cell or iteration.
        sort = si.load_extractor(sorter_folder)
        elapsed_time = 0
    else:
        # With overwrite enabled, drop the previous result so `save()` writes
        # into a fresh folder.
        if sorter_folder.is_dir():
            shutil.rmtree(sorter_folder)

        # "rec_filt" is normally filled in by the plotting cell; compute it
        # here if that cell was skipped.
        rec_f = trunc_dict.get("rec_filt")
        if rec_f is None:
            rec_f = si.bandpass_filter(trunc_dict["rec"])
            trunc_dict["rec_filt"] = rec_f

        t_start = time.perf_counter()
        sort = si.run_sorter(sorter, rec_f, verbose=True)
        sort = sort.save(folder=sorter_folder)
        t_stop = time.perf_counter()
        elapsed_time = np.round(t_stop - t_start, 2)
    trunc_dict["sort"] = sort

    print(f"Elapsed {sorter} - truncation {trunc_bit}: {elapsed_time}s")

In [None]:
import spikeinterface.sortingcomponents as scp

In [None]:
# Pick the filtered recording explicitly (last truncation level in the results
# dict) instead of relying on whatever `rec_f` an earlier loop happened to leave
# behind, and reuse the notebook-wide `n_jobs` setting.
rec_f = rec_zarr_dict[list(rec_zarr_dict)[-1]]["rec_filt"]
peaks = scp.detect_peaks(rec_f, n_jobs=n_jobs, progress_bar=True, chunk_duration="2s")

In [None]:
# Scatter-mode drift plot of the detected peaks.
# NOTE(review): `rec_f` is whichever filtered recording was bound last by earlier cells.
si.plot_drift_over_time(rec_f, peaks, mode="scatter")