## This demo notebook can be used to (optionally) decompress ephys data and create two average waveforms per session needed for Unit Match. 

In [None]:
%load_ext autoreload
%autoreload 

import sys
from pathlib import Path

import UnitMatchPy.extract_raw_data as erd
import numpy as np 
from pathlib import Path
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import os
import json

## Give the parameters and paths needed for extraction

In [None]:
# Set up the parameters shared by every session in this notebook.
sample_amount = 1000 # for both CV, at least 500 per CV
spike_width = 82 # assuming 30khz sampling, 82 and 61 are common choices (KS1-3, KS4), covers the AP and space around needed for processing
half_width = np.floor(spike_width/2).astype(int)
max_width = np.floor(spike_width/2).astype(int) # Size of area at start and end of recording to ignore to get only full spikes
n_channels = 384 # neuropixels default, the number of channels EXCLUDING sync channels
extract_good_units_only = False # bool, set to true if you want to only extract units marked as good

KS4_data = False # bool, set to true if using Kilosort 4, as KS4 spike times refer to start of waveform not peak
if KS4_data:
    # These two names only exist when KS4_data is True; the extraction cell only reads them in its KS4 branch.
    samples_before = 20
    samples_after = spike_width - samples_before
    max_width = samples_after # For KS4, ignore samples_after samples at the start and end of the recording instead of half a spike width

# List of paths to the KiloSort directories, one per session
KS_dirs = [r'path/to/KiloSort/Dir/Session1', r'path/to/KiloSort/Dir/Session2']
n_sessions = len(KS_dirs) # How many sessions are being extracted
# NOTE(review): extract_good_units_only is hard-coded to True here instead of passing the
# variable set above. good_units is used when saving even if all units are extracted, so
# this may be deliberate -- confirm against erd.extract_KS_data before changing it.
spike_ids, spike_times, good_units, all_unit_ids = erd.extract_KS_data(KS_dirs, extract_good_units_only = True)

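This step needs decompressed raw data: give the path to each session's decompressed `.dat` file and to its `structure.oebin` metadata file.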

In [None]:
# Give the raw data and metadata paths, one entry per session.
# Both lists are indexed by session number together with KS_dirs, so keep all three in the same order.

# Raw data files; the extraction cell memory-maps each one as int16 samples.
data_paths = [r'path/to/Decompressed/data1.dat', r'path/to/Decompressed/data2.dat']
# Metadata files; each is parsed as JSON and meta['continuous'][0]['num_channels'] is read from it.
meta_paths = [r'path/to/data/structure.oebin', r'path/to/data/structure.oebin']

In [None]:
# Extract the average waveforms of the units, one session at a time.
#
# For each session this cell:
#   1. picks the units to extract (good units only, or every unit with spikes),
#   2. reads the total channel count from the metadata and memory-maps the raw data,
#   3. drops spikes too close to the start/end of the recording to have a full waveform,
#   4. chooses the spikes to sample for each unit,
#   5. extracts the waveforms of all units in parallel,
#   6. saves the result in the session's KS directory.
#
# Reads the parameters and path lists defined in the cells above.

raw_dtype = np.dtype('int16')  # sample type the raw data file is read as

for sid in range(n_sessions):
    # The good-units-only and all-units cases differ only in which units are extracted
    if extract_good_units_only:
        units = good_units[sid]
    else:
        units = np.unique(spike_ids[sid])
    n_units = units.shape[0]

    # load metadata
    with open(meta_paths[sid], 'r') as meta_file:
        meta = json.load(meta_file)
    n_channels_tot = int(meta['continuous'][0]['num_channels'])

    # The file is a flat array of samples, so its size gives the number of time points.
    # Integer division keeps this exact and ignores any trailing partial row.
    n_bytes = os.path.getsize(data_paths[sid])
    n_samples = n_bytes // (raw_dtype.itemsize * n_channels_tot)

    # create memmap to raw data, for that session
    # np.memmap defaults to mode 'r+' (read-write); open read-only so the raw recording
    # cannot be modified and read-only files can be used.
    # NOTE(review): assumes the erd extraction helpers only read from `data` -- confirm.
    data = np.memmap(data_paths[sid], dtype = raw_dtype, mode = 'r', shape = (n_samples, n_channels_tot))

    # Remove spikes which won't have a full waveform recorded (computed once, applied to both arrays)
    too_close_to_edge = np.logical_or(spike_times[sid] < max_width, spike_times[sid] > (n_samples - max_width))
    spike_ids_tmp = np.delete(spike_ids[sid], too_close_to_edge)
    spike_times_tmp = np.delete(spike_times[sid], too_close_to_edge)

    # might be slow extracting sample for good units only?
    sample_idx = erd.get_sample_idx(spike_times_tmp, spike_ids_tmp, sample_amount, units = units)

    # One job per unit; the KS4 helper takes the samples before/after the spike time
    # instead of a symmetric half width.
    if KS4_data:
        jobs = (delayed(erd.extract_a_unit_KS4)(sample_idx[uid], data, samples_before, samples_after, spike_width, n_channels, sample_amount) for uid in range(n_units))
    else:
        jobs = (delayed(erd.extract_a_unit)(sample_idx[uid], data, half_width, spike_width, n_channels, sample_amount) for uid in range(n_units))
    avg_waveforms = np.asarray(Parallel(n_jobs = -1, verbose = 10, mmap_mode = 'r', max_nbytes = None)(jobs))

    # Save in file named 'RawWaveforms' in the KS Directory
    erd.save_avg_waveforms(avg_waveforms, KS_dirs[sid], all_unit_ids[sid], GoodUnits = good_units[sid], extract_good_units_only = extract_good_units_only)

    # Release this session's mapping before opening the next one. Doing it inside the
    # loop also avoids a NameError when there are no sessions to process.
    del data