## This demo notebook can be used to (optionally) decompress ephys data and to create the two average waveforms per unit, for each session, that UnitMatch needs.

In [None]:
%load_ext autoreload
%autoreload 

import sys
from pathlib import Path

import UnitMatchPy.extract_raw_data as erd
import numpy as np 
from pathlib import Path
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import os

## Optional, decompress compressed data

In [None]:
# GIVE a list of directories, one per session, each holding the raw compressed
# recording for that session, i.e. its .cbin, .ch and .meta files.
# NOTE: these are raw strings, so the doubled backslashes are kept literally;
# a single backslash or a forward slash is enough as a path separator.
raw_data_dir_paths = [r'Path\\to\\rawdata\\Session1', r'Path\\to\\rawdata\\Session2']

# GIVE a path to a directory where the decompressed data will be saved.
# The decompressed files are large; using a fast SSD is advised for quicker run times.
# The decompression cell below creates a 'DecompData' folder inside this directory,
# with one 'SessionN' sub-folder per session.
decomp_data_save_dir = r'path/to/decompressed/data/save/dir'
# Look up the compressed-data (.cbin), channel (.ch) and metadata (.meta) paths.
# Presumably one path of each kind is returned per directory given above -- TODO confirm.
cbin_paths, ch_paths, meta_paths = erd.get_raw_data_paths(raw_data_dir_paths)

In [None]:
# Decompress Data
# For every session, decompress the .cbin/.ch pair into
# '<decomp_data_save_dir>/DecompData/SessionN/RawData.bin' and collect the
# resulting paths in `data_paths` (used by the extraction cell).
from mtscomp import Reader

# makedirs also creates a missing parent directory and does not fail when the
# folder already exists, so the cell can safely be re-run.
decomp_dir = os.path.join(decomp_data_save_dir, 'DecompData')
os.makedirs(decomp_dir, exist_ok=True)

data_paths = []
for i in range(len(raw_data_dir_paths)):
    session_dir = os.path.join(decomp_dir, f'Session{i+1}')  # +1 so numbering starts at 1
    os.makedirs(session_dir, exist_ok=True)
    tmpPath = os.path.join(session_dir, 'RawData.bin')
    data_paths.append(tmpPath)

    # A finished file from an earlier run is reused as-is
    if os.path.exists(tmpPath):
        print(f"Decompressed data for Session{i+1} already exists. Skipping decompression.")
        continue

    # Decompress into a temporary '.part' file and only rename it to its final
    # name once it is complete. An interrupted run therefore never leaves a
    # truncated 'RawData.bin' behind that a later run would mistake for a
    # finished file and skip.
    partial_path = tmpPath + '.part'
    if os.path.exists(partial_path):
        os.remove(partial_path)  # Left over from an interrupted run

    r = Reader(check_after_decompress=False)  # Skip the verification check to save time
    r.open(cbin_paths[i], ch_paths[i])
    try:
        r.tofile(partial_path)
    finally:
        r.close()  # Release the file handles even if decompression fails
    os.replace(partial_path, tmpPath)

## Give parameters and paths needed for extraction

In [None]:
# Set up the parameters used by the extraction cell below
sample_amount = 1000 # number of spikes sampled per unit, shared by both CV halves; use at least 500 per CV
spike_width = 82 # assuming 30 kHz sampling; 82 and 61 are common choices, covers the AP and the space around it needed for processing
# NOTE: half_width and max_width are evaluated here, from the spike_width value above
# (82 -> 41). The KS4 override of spike_width further down does NOT change them.
half_width = np.floor(spike_width/2).astype(int)
max_width = np.floor(spike_width/2).astype(int) # size of the margin at the start and end of the recording to ignore, so only full spikes are kept
n_channels = 384 # Neuropixels default
extract_good_units_only = False # bool, set to True if you want to only extract units marked as good

KS4_data = True # bool, selects the erd.extract_a_unit_KS4 extraction path (presumably for Kilosort 4 output -- confirm)
if KS4_data:
    # With these defaults the window is 20 samples before and 41 samples after.
    # NOTE(review): max_width (41) equals samples_after (41) here, which is presumably
    # what the edge filter needs -- re-check max_width if either value is changed.
    spike_width = 61
    samples_before = 20
    samples_after = spike_width - samples_before

# List of paths to the Kilosort (KS) output directories, one per session
KS_dirs = [r'path/to/KSdir/Session1', r'Path/to/KSdir/Session2']
n_sessions = len(KS_dirs) # how many sessions are being extracted
# NOTE(review): the literal True is passed here regardless of `extract_good_units_only`
# above. The extraction cell indexes good_units[sid] in both of its branches, so this is
# presumably needed to have good_units populated -- confirm before changing it.
spike_ids, spike_times, good_units, all_unit_ids = erd.extract_KS_data(KS_dirs, extract_good_units_only = True)

### If you have not decompressed data above

In [None]:
#give metadata + Raw data paths
#if you are NOT decompressing data here, provide a list of paths to the decompressed data and the metadata

#data_paths = [r'path/to/Decompressed/data1.bin', r'path/to/Decompressed/data2.bin']
#meta_paths = [r'path/to/data.meta', r'path/to/data.meta']


In [None]:
# Extract the units
# For every session: memory-map the raw recording, drop spikes too close to the
# edges of the recording, sample spikes per unit, compute the average waveforms
# in parallel and save them to that session's KS directory.

# Choose the per-unit extraction routine and its trailing arguments once; they
# do not depend on the session.
if KS4_data:
    extract_unit = erd.extract_a_unit_KS4
    extract_args = (samples_before, samples_after, spike_width, n_channels, sample_amount)
else:
    extract_unit = erd.extract_a_unit
    extract_args = (half_width, spike_width, n_channels, sample_amount)

for sid in range(n_sessions):
    # Units to extract: only the units marked as good, or every unit in the session
    if extract_good_units_only:
        units = good_units[sid]
        n_units = units.shape[0]
    else:
        units = np.unique(spike_ids[sid])
        n_units = len(units)

    # Load metadata; the raw data is int16, i.e. 2 bytes per element
    meta_data = erd.read_meta(Path(meta_paths[sid]))
    n_channels_tot = int(meta_data['nSavedChans'])
    n_samples = int(meta_data['fileSizeBytes']) // 2 // n_channels_tot

    # Create a memmap to the raw data for that session. It is opened read-only
    # (np.memmap defaults to 'r+') so the recording cannot be modified by accident.
    data = np.memmap(data_paths[sid], dtype='int16', mode='r', shape=(n_samples, n_channels_tot))

    # Remove spikes which won't have a full waveform recorded. The mask is
    # computed once and applied by boolean indexing to both arrays; flattening
    # first matches what np.delete (without an axis) returned.
    session_spike_ids = np.asarray(spike_ids[sid]).ravel()
    session_spike_times = np.asarray(spike_times[sid]).ravel()
    too_close_to_edge = np.logical_or(
        session_spike_times < max_width,
        session_spike_times > (data.shape[0] - max_width),
    )
    keep = ~too_close_to_edge
    spike_ids_tmp = session_spike_ids[keep]
    spike_times_tmp = session_spike_times[keep]

    # Pick the spikes used for each unit
    sample_idx = erd.get_sample_idx(spike_times_tmp, spike_ids_tmp, sample_amount, units=units)

    # One parallel job per unit
    avg_waveforms = Parallel(n_jobs=-1, verbose=10, mmap_mode='r', max_nbytes=None)(
        delayed(extract_unit)(sample_idx[uid], data, *extract_args)
        for uid in range(n_units)
    )
    avg_waveforms = np.asarray(avg_waveforms)

    # Save in file named 'RawWaveforms' in the KS Directory
    erd.save_avg_waveforms(avg_waveforms, KS_dirs[sid], all_unit_ids[sid], good_units=good_units[sid], extract_good_units_only=extract_good_units_only)

    # Drop this session's memmap before moving on, so the decompressed file is
    # no longer referenced from here (it may be deleted in the next cell).
    del data

#### Optional: delete the decompressed data

In [None]:
import shutil

# PERMANENTLY DELETE the whole 'DecompData' directory, i.e. the decompressed data of
# every session. `decomp_dir` is defined in the decompression cell above, so that cell
# must have been run in this kernel. Only run this once the waveforms have been extracted.
shutil.rmtree(decomp_dir)