# AUDIO compression for ephys data

This notebook showcases the current implementation for writing SpikeInterface (SI) recording objects to audio formats:

- lossless
  - FLAC
  - WavPack
- lossy
  - MP3
  - WavPack (hybrid mode)

In [None]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import os
from pathlib import Path
import shutil
import scipy.io.wavfile as wavfile

import sys

import spikeinterface.full as si

%matplotlib widget

sys.path.append("..")

from audiocompression import write_recording_audio, AudioRecordingExtractor, _max_channels_per_stream
from utils import get_median_and_lsb

In [None]:
# Root folder under which every compressed output of this notebook is written.
test_audio_folder = Path("../data/audio/")

In [None]:
# Parallel-processing options, forwarded as **job_kwargs to the write/save calls below.
n_jobs = 10
job_kwargs = {
    "n_jobs": n_jobs,
    "chunk_duration": "1s",
    "progress_bar": True,
}

In [None]:
def get_dir_size(path='.'):
    """Return the cumulative size, in bytes, of all files below ``path``.

    Sub-directories are visited recursively. Entries that are neither a file
    nor a directory contribute nothing to the total.

    Parameters
    ----------
    path : str or path-like
        Directory to measure (defaults to the current directory).

    Returns
    -------
    int
        Sum of ``st_size`` over every file found.
    """
    with os.scandir(path) as entries:
        # The generator is fully consumed by sum() while the scandir handle is open.
        return sum(
            get_dir_size(item.path) if item.is_dir()
            else (item.stat().st_size if item.is_file() else 0)
            for item in entries
        )

In [None]:
# Choose which Open Ephys session to load: `np_version == 2` selects the first
# folder, any other value selects the second one.
# NOTE(review): both paths are absolute and machine-specific.
np_version = 1

oe_folder = (
    "/home/alessio/Documents/data/allen/npix-open-ephys/595262_2022-02-21_15-18-07/Record Node 102"
    if np_version == 2
    else "/home/alessio/Documents/data/allen/npix-open-ephys/618382_2022-03-31_14-27-03/Record Node 102/"
)

# Read stream "0", then keep the first element returned by `split_recording`.
rec_oe = si.split_recording(si.read_openephys(oe_folder, stream_id="0"))[0]
print(rec_oe)

# Presumably the LSB step and per-channel medians of the raw data - confirm in `utils`.
lsb_value, median_values = get_median_and_lsb(rec_oe)

In [None]:
# Recording length: number of samples divided by the sampling frequency.
dur = rec_oe.get_num_samples() / rec_oe.get_sampling_frequency()
# Sample dtype of the raw recording; passed as `dtype` to the scaling step below.
dtype = rec_oe.get_dtype()
# Gain of the first channel.
gain = rec_oe.get_channel_gains()[0]

In [None]:
# Sampling frequency of the raw recording.
fs = rec_oe.get_sampling_frequency()

In [None]:
# median correction
rec_to_compress = si.scale(rec_oe, gain=1., offset=-median_values, dtype=dtype)
rec_to_compress = si.scale(rec_oe, gain=1. / lsb_value, dtype=dtype)

In [None]:
# Plot the probe map of the recording and restrict the view to a small x/y window.
w = si.plot_probe_map(rec_to_compress)
w.ax.set_xlim(-20, 60)
w.ax.set_ylim(300, 600)

In [None]:
# Uncompressed size in bytes: samples x channels x bytes per sample.
# Used as the numerator of the compression ratios (CR) computed below.
total_bytes = rec_oe.get_num_samples() * rec_oe.get_num_channels() * rec_oe.get_dtype().itemsize

In [None]:
# NOTE(review): scratch cell - the result is not assigned and `math` is not
# used anywhere else in the visible notebook; consider removing.
import math

math.gcd(257, 256)

# LOSSLESS

### FLAC

In [None]:
# Output folders for the two FLAC write modes ("stream" and "concat").
flac_stream = test_audio_folder / "flac_stream"
flac_concat = test_audio_folder / "flac_concat"

In [None]:
# Write the recording as FLAC twice, once per mode, overwriting existing output.
# The returned objects are read back by the trace benchmarks further down.
rec_audio_flac_stream = write_recording_audio(rec_to_compress, flac_stream, cformat="flac",
                                              overwrite=True, mode="stream", **job_kwargs)

rec_audio_flac_concat = write_recording_audio(rec_to_compress, flac_concat, cformat="flac",
                                              overwrite=True, mode="concat", **job_kwargs)

In [None]:
from audio_numcodecs import FlacCodec, WavPackCodec

In [None]:
# Save the recording to zarr using the FLAC numcodecs compressor.
zarr_path = test_audio_folder / "flac-si.zarr"

# Remove any previous output so the save starts from a clean folder.
if zarr_path.is_dir():
    shutil.rmtree(zarr_path)

rec_flac_zarr = rec_to_compress.save(format="zarr", zarr_path=zarr_path, 
                                     compressor=FlacCodec(compression_level=5),
                                     **job_kwargs)

In [None]:
# Read the first 30000 frames back from the FLAC-compressed zarr recording.
tr_flac = rec_flac_zarr.get_traces(end_frame=30000)


In [None]:
# Overlay the original and the FLAC-decoded traces for channel index 100.
# `tr_or` is loaded here explicitly: it was previously only defined by a later
# benchmark cell, so this cell failed on a fresh kernel (Restart & Run All).
tr_or = rec_to_compress.get_traces(end_frame=tr_flac.shape[0])

fig, ax = plt.subplots()
ax.plot(tr_or[:, 100], label="original")
ax.plot(tr_flac[:, 100], label="FLAC (zarr)")
ax.legend()

In [None]:
# Compression ratio (CR) = uncompressed bytes / bytes on disk, for each FLAC output.
total_bytes_flac_stream = get_dir_size(flac_stream)
total_bytes_flac_concat = get_dir_size(flac_concat)

cr_flac_stream = total_bytes / total_bytes_flac_stream
cr_flac_concat = total_bytes / total_bytes_flac_concat

print(f"CR FLAC - stream mode: {cr_flac_stream}")
print(f"CR FLAC - concat mode: {cr_flac_concat}")
# For the zarr output the ratio is read from the recording's own annotation.
print(f"CR FLAC SI - concat mode: {rec_flac_zarr.get_annotation('compression_ratio')}")

### WAVPACK

In [None]:
# Output folders for the two WavPack write modes ("stream" and "concat").
wv_stream = test_audio_folder / "wv_stream"
wv_concat = test_audio_folder / "wv_concat"

In [None]:
# Write the recording as WavPack twice, once per mode, overwriting existing output.
rec_audio_wv_stream = write_recording_audio(rec_to_compress, wv_stream, cformat="wavpack",
                                            overwrite=True, mode="stream", **job_kwargs)
rec_audio_wv_concat = write_recording_audio(rec_to_compress, wv_concat, cformat="wavpack",
                                     overwrite=True, mode="concat", **job_kwargs)

In [None]:
# Save the recording to zarr using the WavPack numcodecs compressor.
# NOTE(review): `zarr_path` is re-bound here; it pointed at the FLAC zarr folder before.
zarr_path = test_audio_folder / "wavpack-si.zarr"

# Remove any previous output so the save starts from a clean folder.
if zarr_path.is_dir():
    shutil.rmtree(zarr_path)
    
rec_wv_zarr = rec_to_compress.save(format="zarr", zarr_path=zarr_path, 
                                   compressor=WavPackCodec(),
                                   **job_kwargs)

In [None]:
# Compression ratio (CR) = uncompressed bytes / bytes on disk, for each WavPack output.
# NOTE(review): `total_bytes_wv` is reused for both modes, so after this cell it
# holds the concat size only; statement order matters here.
total_bytes_wv = get_dir_size(wv_stream)
cr_wv_stream = total_bytes / total_bytes_wv 
print(f"CR WV - stream: {cr_wv_stream}")

total_bytes_wv = get_dir_size(wv_concat)
cr_wv_concat = total_bytes / total_bytes_wv 
print(f"CR WV - concat: {cr_wv_concat}")

# For the zarr output the ratio is read from the recording's own annotation.
print(f"CR WV SI - concat mode: {rec_wv_zarr.get_annotation('compression_ratio')}")

# LOSSY

In [None]:
# Output location for the MP3 write (its size is later measured with get_dir_size).
mp3_file = test_audio_folder / "mp3_stream"

In [None]:
# Write the recording as MP3 in "stream" mode, overwriting existing output.
rec_audio_mp3 = write_recording_audio(rec_to_compress, mp3_file, cformat="mp3",
                                      overwrite=True, mode="stream", **job_kwargs)

In [None]:
# Compression ratio of the MP3 output: uncompressed bytes / bytes on disk.
total_bytes_mp3 = get_dir_size(mp3_file)
cr_mp3 = total_bytes / total_bytes_mp3 
print(f"CR MP3: {cr_mp3}")

In [None]:
# Output location for the lossy (lossless=False) WavPack write.
wv_file_lossy = test_audio_folder / "wv_stream_lossy"

In [None]:
# Write the recording as WavPack with lossless=False, in "stream" mode.
rec_audio_wv_lossy = write_recording_audio(rec_to_compress, wv_file_lossy, cformat="wavpack",
                                           lossless=False, overwrite=True, mode="stream", **job_kwargs)

In [None]:
# Compression ratio of the lossy WavPack output: uncompressed bytes / bytes on disk.
total_bytes_wv_lossy = get_dir_size(wv_file_lossy)
cr_wv_lossy = total_bytes / total_bytes_wv_lossy 
print(f"CR WV-HYBRID: {cr_wv_lossy}")

In [None]:
# check traces

In [None]:
# Stream vs. concat (FLAC): for several snippet durations, read random windows
# from both FLAC recordings, time each read and overlay the decoded traces on
# the original ones.
snippet_durations = [0.5, 1, 2, 5]
num_chunks = 3
num_channels = [10, 384]
num_samples = rec_to_compress.get_num_samples()
sampling_frequency = rec_to_compress.get_sampling_frequency()

# `snippet_dur` (not `dur`) so the total-duration variable `dur` is not clobbered.
for snippet_dur in snippet_durations:
    print(f"Snippet duration {snippet_dur}s")
    samples = int(sampling_frequency * snippet_dur)
    fig, axs = plt.subplots(nrows=num_chunks, ncols=len(num_channels))
    for ch in range(num_chunks):
        start_frame = np.random.randint(num_samples - samples - 1)
        end_frame = start_frame + samples

        for inc, nc in enumerate(num_channels):
            # Random subset of (at most) `nc` channels.
            channel_ids = rec_to_compress.channel_ids[np.random.permutation(rec_to_compress.get_num_channels())[:nc]]

            # Index into the selected subset; bounded by its real size so a
            # recording with fewer than `nc` channels cannot index out of range.
            random_channel = np.random.randint(len(channel_ids))
            tr_or = rec_to_compress.get_traces(start_frame=start_frame, end_frame=end_frame,
                                               channel_ids=channel_ids)

            t_start = time.perf_counter()
            tr_flac_str = rec_audio_flac_stream.get_traces(start_frame=start_frame, end_frame=end_frame,
                                                           channel_ids=channel_ids)
            t_stop = time.perf_counter()
            elapsed_time_flac = np.round(t_stop - t_start, 2)
            print(f"FLAC {ch} - stream - num channels {nc}: {elapsed_time_flac} s")

            t_start = time.perf_counter()
            tr_flac_cnc = rec_audio_flac_concat.get_traces(start_frame=start_frame, end_frame=end_frame,
                                                           channel_ids=channel_ids)
            t_stop = time.perf_counter()
            elapsed_time_flac = np.round(t_stop - t_start, 2)
            print(f"FLAC {ch} - concat - num channels {nc}: {elapsed_time_flac} s")

            # Labels now name what is plotted (the concat trace used to be labelled "WV").
            axs[ch, inc].plot(tr_or[:, random_channel], "k", label="GT", lw=2, alpha=0.8)
            axs[ch, inc].plot(tr_flac_str[:, random_channel], "C0", label="FLAC stream", alpha=0.8)
            axs[ch, inc].plot(tr_flac_cnc[:, random_channel], "C1", label="FLAC concat", alpha=0.8)

    # One legend per figure is enough: every panel uses the same colours.
    axs[0, 0].legend()

In [None]:
# FLAC vs. WavPack (both lossless, "stream" mode): read random windows from
# each recording, time each read and overlay the decoded traces on the
# original ones.
# The original cell referenced `rec_audio_flac` / `rec_audio_wv`, which are
# never defined in this notebook; the "stream" recordings are used instead.
snippet_durations = [0.5, 1, 2, 5]
num_chunks = 3
num_channels = [10, 384]
num_samples = rec_to_compress.get_num_samples()
sampling_frequency = rec_to_compress.get_sampling_frequency()

# `snippet_dur` (not `dur`) so the total-duration variable `dur` is not clobbered.
for snippet_dur in snippet_durations:
    print(f"Snippet duration {snippet_dur}s")
    samples = int(sampling_frequency * snippet_dur)
    fig, axs = plt.subplots(nrows=num_chunks, ncols=len(num_channels))
    for ch in range(num_chunks):
        start_frame = np.random.randint(num_samples - samples - 1)
        end_frame = start_frame + samples

        for inc, nc in enumerate(num_channels):
            # Random subset of (at most) `nc` channels.
            channel_ids = rec_to_compress.channel_ids[np.random.permutation(rec_to_compress.get_num_channels())[:nc]]

            # Index into the selected subset; bounded by its real size so a
            # recording with fewer than `nc` channels cannot index out of range.
            random_channel = np.random.randint(len(channel_ids))
            tr_or = rec_to_compress.get_traces(start_frame=start_frame, end_frame=end_frame,
                                               channel_ids=channel_ids)

            t_start = time.perf_counter()
            tr_flac = rec_audio_flac_stream.get_traces(start_frame=start_frame, end_frame=end_frame,
                                                       channel_ids=channel_ids)
            t_stop = time.perf_counter()
            elapsed_time_flac = np.round(t_stop - t_start, 2)
            print(f"FLAC {ch} - dur{snippet_dur} - num channels {nc}: {elapsed_time_flac} s")

            t_start = time.perf_counter()
            tr_wv = rec_audio_wv_stream.get_traces(start_frame=start_frame, end_frame=end_frame,
                                                   channel_ids=channel_ids)
            t_stop = time.perf_counter()
            elapsed_time_wv = np.round(t_stop - t_start, 2)
            print(f"WV test{ch} - num channels {nc}: {elapsed_time_wv} s")

            axs[ch, inc].plot(tr_or[:, random_channel], "k", label="GT", lw=2, alpha=0.8)
            axs[ch, inc].plot(tr_flac[:, random_channel], "C0", label="FLAC", alpha=0.8)
            axs[ch, inc].plot(tr_wv[:, random_channel], "C1", label="WV", alpha=0.8)

    # One legend per figure is enough: every panel uses the same colours.
    axs[0, 0].legend()

In [None]:
# MP3 read benchmark: for each snippet duration, pull random windows from the
# MP3 recording, time the read and overlay the decoded trace on the original.
snippet_durations = [0.5, 1, 2, 5]
num_chunks = 3
num_channels = [10, 384]
num_samples = rec_to_compress.get_num_samples()

for dur in snippet_durations:
    print(f"Snippet duration {dur}s")
    fig, axs = plt.subplots(nrows=num_chunks, ncols=len(num_channels))
    for ch in range(num_chunks):
        # Random window of `samples` frames inside the recording.
        samples = int(rec_to_compress.get_sampling_frequency() * dur)
        start_frame = np.random.randint(num_samples - samples - 1)
        end_frame = start_frame + samples
        frame_window = dict(start_frame=start_frame, end_frame=end_frame)

        for col, nc in enumerate(num_channels):
            # Random subset of `nc` channels, then one random column of that subset to plot.
            shuffled_indices = np.random.permutation(rec_to_compress.get_num_channels())
            channel_ids = rec_to_compress.channel_ids[shuffled_indices[:nc]]
            random_channel = np.random.randint(nc)

            tr_or = rec_to_compress.get_traces(channel_ids=channel_ids, **frame_window)

            # Only the MP3 read is timed.
            t_start = time.perf_counter()
            tr_mp3 = rec_audio_mp3.get_traces(channel_ids=channel_ids, **frame_window)
            t_stop = time.perf_counter()
            elapsed_time_mp3 = np.round(t_stop - t_start, 2)
            print(f"MP3 {ch} - dur{dur} - num channels {nc}: {elapsed_time_mp3} s")

            panel = axs[ch, col]
            panel.plot(tr_or[:, random_channel], "k", label="GT", lw=2, alpha=0.8)
            panel.plot(tr_mp3[:, random_channel], "C0", label="MP3", alpha=0.8)


### Spike sorting: MP3 vs. FLAC

In [None]:
# Sorter name and the extra keyword arguments forwarded to si.run_sorter below.
sorter = "kilosort2_5"
sorter_params = dict(n_jobs_bin=10, total_memory="4G")

In [None]:
# Run the configured sorter on the MP3-decoded recording; output goes to "mp3_ks25".
sort_mp3_KS = si.run_sorter(sorter, rec_audio_mp3, output_folder=test_audio_folder / "mp3_ks25", verbose=True,
                            **sorter_params)

In [None]:
# Show a summary of the MP3 sorting result.
print(sort_mp3_KS)

In [None]:
# Run the configured sorter on the FLAC recording; output goes to "flac_ks25".
# `rec_audio_flac` is never defined in this notebook, so the "stream" FLAC
# recording is used here (matching the "stream" mode used for the MP3 recording).
sort_flac_KS = si.run_sorter(sorter, rec_audio_flac_stream, output_folder=test_audio_folder / "flac_ks25",
                             verbose=True, **sorter_params)

In [None]:
# Show a summary of the FLAC sorting result.
print(sort_flac_KS)

In [None]:
# Compare the two sorting outputs against each other, named "FLAC" and "MP3".
mcmp = si.compare_multiple_sorters([sort_flac_KS, sort_mp3_KS], name_list=["FLAC", "MP3"], verbose=True)

In [None]:
# Plot the agreement between the sorters from the multi-sorter comparison.
si.plot_multicomp_agreement_by_sorter(mcmp)

In [None]:
# Plot the agreement matrix of the first comparison held by `mcmp`.
si.plot_agreement_matrix(mcmp.comparisons[0])

In [None]:
# Compare the MP3 sorting against the FLAC sorting.
# NOTE(review): presumably the first argument (FLAC sorting) is treated as the
# ground truth and the second (MP3 sorting) as the tested output - confirm
# against the SpikeInterface API.
cmp_gt = si.compare_sorter_to_ground_truth(sort_flac_KS, sort_mp3_KS)

In [None]:
# Number of units returned by get_well_detected_units with a 0.9 score threshold.
len(cmp_gt.get_well_detected_units(well_detected_score=0.9))