In [None]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import os

import sys

import spikeinterface.full as si

%matplotlib widget

sys.path.append("../ephys-compression/")

from audiocompression import write_recording_audio, AudioRecordingExtractor

In [None]:
def get_dir_size(path='.'):
    """Return the summed size in bytes of every file found under ``path``.

    The directory is scanned with ``os.scandir``; sub-directories are visited
    as well. Each entry reported as a file contributes its ``st_size``;
    entries that are neither a file nor a directory contribute nothing.

    Parameters
    ----------
    path : str or path-like
        Directory to measure. Defaults to the current directory.

    Returns
    -------
    int
        Total number of bytes (0 for a directory without files).
    """
    size_in_bytes = 0
    # Explicit work list instead of recursion: each pending directory is
    # scanned once and its sub-directories are queued for later.
    pending_dirs = [path]
    while pending_dirs:
        with os.scandir(pending_dirs.pop()) as entries:
            for item in entries:
                if item.is_file():
                    size_in_bytes += item.stat().st_size
                elif item.is_dir():
                    pending_dirs.append(item.path)
    return size_in_bytes

In [None]:
# Open Ephys folder read by this notebook (absolute path on the author's machine).
oe_folder = "/home/alessio/Documents/data/allen/npix-open-ephys/595262_2022-02-21_15-18-07/Record Node 102"
# Read stream "0" and keep the first element of the list returned by split_recording().
rec_oe = si.split_recording(si.read_openephys(oe_folder, stream_id="0"))[0]
print(rec_oe)

In [None]:
# Sampling frequency of the recording; used below to turn times into frame indices.
fs = rec_oe.get_sampling_frequency()

In [None]:
# Size of the recording's samples in bytes: samples x channels x bytes per sample.
# Used below as the numerator of every compression ratio.
n_samples_oe = rec_oe.get_num_samples()
n_channels_oe = rec_oe.get_num_channels()
bytes_per_sample = rec_oe.get_dtype().itemsize
total_bytes = n_samples_oe * n_channels_oe * bytes_per_sample

In [None]:
# Write the recording to "data/test_full_flac" with the project's audio writer.
# The settings for this format are gathered in one dict so they are easy to compare
# with the other formats.
flac_write_kwargs = dict(
    cformat="flac",
    chunk_duration="5s",
    n_jobs=10,
    overwrite=True,
    progress_bar=True,
)
rec_audio_flac = write_recording_audio(rec_oe, "data/test_full_flac", **flac_write_kwargs)

In [None]:
# Compression ratio = size of the recording's samples / size on disk of the FLAC folder.
flac_folder = "data/test_full_flac/"
total_bytes_flac = get_dir_size(flac_folder)
cr_flac = total_bytes / total_bytes_flac
print(f"CR FLAC: {cr_flac}")

In [None]:
# Write the recording to "data/test_full_mp3" with the project's audio writer.
# Unlike the other formats, this call also passes verbose=True.
mp3_write_kwargs = dict(
    cformat="mp3",
    chunk_duration="5s",
    n_jobs=10,
    overwrite=True,
    progress_bar=True,
    verbose=True,
)
rec_audio_mp3 = write_recording_audio(rec_oe, "data/test_full_mp3", **mp3_write_kwargs)

In [None]:
# Compression ratio = size of the recording's samples / size on disk of the MP3 folder.
mp3_folder = "data/test_full_mp3/"
total_bytes_mp3 = get_dir_size(mp3_folder)
cr_mp3 = total_bytes / total_bytes_mp3
print(f"CR MP3: {cr_mp3}")

In [None]:
# Write the recording to "data/test_full_aac" with the project's audio writer.
# NOTE(review): chunk_duration is "10s" here but "5s" for FLAC and MP3 - confirm intended.
aac_write_kwargs = dict(
    cformat="aac",
    chunk_duration="10s",
    n_jobs=10,
    overwrite=True,
    progress_bar=True,
)
rec_audio_aac = write_recording_audio(rec_oe, "data/test_full_aac", **aac_write_kwargs)

In [None]:
# Compression ratio = size of the recording's samples / size on disk of the AAC folder.
aac_folder = "data/test_full_aac/"
total_bytes_aac = get_dir_size(aac_folder)
cr_aac = total_bytes / total_bytes_aac
print(f"CR AAC: {cr_aac}")

In [None]:
# Re-open the written FLAC and MP3 folders as recordings (replacing the objects returned
# by write_recording_audio above).
rec_audio_flac = AudioRecordingExtractor("data/test_full_flac/")
rec_audio_mp3 = AudioRecordingExtractor("data/test_full_mp3/")

# Give each re-opened recording the probe group of the original, then copy the
# original's metadata onto it.
for rec_audio in (rec_audio_flac, rec_audio_mp3):
    rec_audio.set_probegroup(rec_oe.get_probegroup(), in_place=True)
    rec_oe.copy_metadata(rec_audio)

# The AAC output is not re-opened:
# rec_audio_aac = AudioRecordingExtractor("data/test_full_aac/")

In [None]:
# Window passed to the plots as `time_range`, and the same two bounds converted to
# frame indices by multiplying with the sampling frequency.
time_range = [60, 70]
start_frame, end_frame = (int(bound * fs) for bound in time_range)

# Ten-channel subset: positions 50..59 of the recording's channel ids.
channel_ids = rec_oe.channel_ids[50:60]

In [None]:
# Overlay the original recording (C0) and the MP3 recording (C1) on one axis,
# for the selected window and channels.
fig, ax = plt.subplots()

raw_plot_kwargs = dict(time_range=time_range, channel_ids=channel_ids, alpha=0.8, ax=ax)
si.plot_timeseries(rec_oe, color="C0", **raw_plot_kwargs)
si.plot_timeseries(rec_audio_mp3, color="C1", **raw_plot_kwargs)

In [None]:
# Band-pass filter both recordings (no arguments given, so the library defaults apply).
rec_oe_f = si.bandpass_filter(rec_oe)
rec_audio_mp3_f = si.bandpass_filter(rec_audio_mp3)

# Same overlay as for the unfiltered data: original in C0, MP3 in C1.
fig, ax = plt.subplots()

filtered_plot_kwargs = dict(time_range=time_range, channel_ids=channel_ids, alpha=0.8, ax=ax)
si.plot_timeseries(rec_oe_f, color="C0", **filtered_plot_kwargs)
si.plot_timeseries(rec_audio_mp3_f, color="C1", **filtered_plot_kwargs)

In [None]:
# Display the channel ids selected for the plots.
channel_ids

In [None]:
# Pull the band-pass filtered traces for the plotted window and channels from both
# recordings. `tr_oe_f` / `tr_mp3_f` were not defined anywhere in the notebook, so this
# cell raised NameError on a fresh kernel.
# The frame bounds are derived from `time_range` here so the window stays the same even
# if `start_frame` / `end_frame` are reassigned by another cell.
trace_window = dict(
    start_frame=int(time_range[0] * fs),
    end_frame=int(time_range[1] * fs),
    channel_ids=channel_ids,
)
tr_oe_f = rec_oe_f.get_traces(**trace_window)
tr_mp3_f = rec_audio_mp3_f.get_traces(**trace_window)

# Overlay the first selected channel: original vs. MP3.
plt.figure()

plt.plot(tr_oe_f[:, 0], label="original")
plt.plot(tr_mp3_f[:, 0], label="mp3")
plt.legend();

In [None]:
# Recording compared against the original in the lossless check below (the FLAC one).
rec = rec_audio_flac

In [None]:
# Assert FLAC is lossless: randomly placed chunks of several durations must contain
# exactly the same samples in `rec` (the FLAC recording) and in the original.
chunk_sizes_s = [1, 5, 10, 20]
nchunks = 2
# Seeded generator so that a failing chunk can be reproduced by re-running the cell.
rng = np.random.default_rng(0)
# Bound the random start by the recordings that are actually read below
# (previously the MP3 recording's length was used).
max_samples = min(rec.get_num_samples(), rec_oe.get_num_samples())
for i in range(nchunks):
    for chunk_s in chunk_sizes_s:
        print(f"chunk in s: {chunk_s} -- {i + 1} / {nchunks}")
        num_samples = int(chunk_s * fs)
        # Chunk-local names: the notebook-level `start_frame` / `end_frame` that define
        # the plotting window are left untouched.
        chunk_start = int(rng.integers(0, max_samples - num_samples))
        chunk_end = chunk_start + num_samples
        tr_audio = rec.get_traces(start_frame=chunk_start, end_frame=chunk_end)
        tr_orig = rec_oe.get_traces(start_frame=chunk_start, end_frame=chunk_end)
        # Lossless means identical samples, so compare exactly rather than within a
        # tolerance; on failure this reports where the arrays differ.
        np.testing.assert_array_equal(tr_audio, tr_orig)

In [None]:
# Can the MP3 recording still be spike sorted? Try it on a short excerpt.

In [None]:
# Display the default parameters of the "tridesclous" sorter.
si.get_default_params("tridesclous")

In [None]:
# Multiplied by `fs` below to get the end frame of the excerpt that is spike sorted
# (so presumably a duration in seconds - confirm).
nsec = 60

In [None]:
# Excerpt of the band-pass filtered MP3 recording, from frame 0 up to nsec * fs frames.
sub_end_frame = int(nsec * fs)
rec_sub_mp3 = rec_audio_mp3_f.frame_slice(start_frame=0, end_frame=sub_end_frame)

In [None]:
# Run tridesclous on the MP3 excerpt; the sorter options are listed in one place.
tridesclous_kwargs = dict(n_jobs_bin=10, total_memory="2G", verbose=True)
sort_mp3 = si.run_tridesclous(rec_sub_mp3, **tridesclous_kwargs)