## This demo notebook can be used to (optionally) decompress ephys data and create two average waveforms per session needed for Unit Match. 

In [None]:
%load_ext autoreload
%autoreload 

import sys
from pathlib import Path

import UnitMatchPy.Extract_raw_data as erd
import numpy as np 
from pathlib import Path
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import os

## Optional, decompress compressed data

In [None]:
# GIVE one directory per session; each holds that session's raw compressed
# recording, i.e. the .cbin, .ch and .meta files
RawDataDirPaths = [
    r'Path\to\rawdata\Session1',
    r'Path\to\rawdata\Session2',
]

# GIVE the directory the decompressed data will be saved in.
# The decompressed files are large, so a fast SSD is advised for quicker run times.
# One sub-folder per session is created inside this directory when decompressing.
DecompDataSaveDir = r'path/to/decompressed/data/save/dir'

# Collect the compressed-data, channel and metadata file paths of every session
cbinPaths, chPaths, metaPaths = erd.get_raw_data_paths(RawDataDirPaths)

In [None]:
#Decompress Data
from mtscomp import Reader

# All decompressed sessions live under '<DecompDataSaveDir>/DecompData'.
# makedirs(exist_ok=True) creates missing parent folders and lets this cell be
# re-run without failing because the folder already exists.
DecompDir = os.path.join(DecompDataSaveDir, 'DecompData')
os.makedirs(DecompDir, exist_ok=True)

# Path of the decompressed .bin file of each session, in session order
DataPaths = []
for i in range(len(RawDataDirPaths)):
    # one folder per session called 'SessionX', +1 so numbering starts at 1
    SessionDir = os.path.join(DecompDir, f'Session{i+1}')
    os.makedirs(SessionDir, exist_ok=True)
    tmpPath = os.path.join(SessionDir, 'RawData.bin')
    DataPaths.append(tmpPath)

    # create .bin with the decompressed data

    #r = Reader() # do the mtscomp verification
    r = Reader(check_after_decompress = False) #Skip the verification check to save time
    try:
        r.open(cbinPaths[i], chPaths[i])
        r.tofile(tmpPath)
    finally:
        # always release the reader, even when decompression fails part-way
        r.close()

## Give parameters and paths needed for extraction

In [None]:
#Set Up Parameters
SampleAmount = 1000 # for both CV, at least 500 per CV
SpikeWidth = 82 # assuming 30khz sampling, 82 and 61 are common choices, covers the AP and space around needed for processing
HalfWidth = np.floor(SpikeWidth/2).astype(int) # half of the waveform window, in samples
nChannels = 384 #neuropixels default
ExtractGoodUnitsOnly = False # bool, set to True if you want to only extract units marked as good

KS4data = False #bool, set to True if using Kilosort4, as KS4 spike times refer to the start of the waveform not the peak
if KS4data:
    # the KS4 window is asymmetric around the spike time; only needed (and only defined) for KS4 data
    SamplesBefore = 20
    SamplesAfter = SpikeWidth - SamplesBefore

#List of paths to the KiloSort directories, one per session
KSdirs = [r'path/to/KiloSort/Dir/Session1', r'path/to/KiloSort/Dir/Session2']
nSessions = len(KSdirs) #How many sessions are being extracted

# Pass the flag chosen above rather than a hard-coded True, so the KiloSort
# data is loaded consistently with the extraction mode used later on.
SpikeIds, SpikeTimes, GoodUnits = erd.extract_KSdata(KSdirs, ExtractGoodUnitsOnly = ExtractGoodUnitsOnly)

### If you have not decompressed data above

In [None]:
#give metadata + Raw data paths
#if you are NOT decompressing data here, provide a list of paths to the decompressed data and the metadata


#DataPaths = [r'path/to/Decompressed/data1.bin', r'path/to/Decompressed/data2.bin']
#metaPaths = [r'path/to/data1.meta', r'path/to/data2.meta']


In [None]:
# When every unit is extracted no good-unit list is needed, so use one None
# placeholder per session. When ExtractGoodUnitsOnly is True the GoodUnits
# loaded from the KiloSort directories must be kept, because the extraction
# step indexes into GoodUnits[sid] and would fail on None.
if not ExtractGoodUnitsOnly:
    GoodUnits = [None] * nSessions

In [None]:
#Extract the units

# The extraction routine and its window arguments depend only on the spike
# sorter, so choose them once instead of duplicating the loop per case.
if KS4data:
    ExtractUnit = erd.Extract_A_UnitKS4
    WindowArgs = (SamplesBefore, SamplesAfter)
else:
    ExtractUnit = erd.Extract_A_Unit
    WindowArgs = (HalfWidth,)

for sid in range(nSessions):
    #load metadata
    MetaData = erd.Read_Meta(Path(metaPaths[sid]))
    nChannelsTot = int(MetaData['nSavedChans'])
    # int16 samples take 2 bytes each; integer division keeps the counts exact
    nElements = int(MetaData['fileSizeBytes']) // 2
    nSamples = nElements // nChannelsTot

    #create memmap to raw data, for that session
    # mode='r': the raw data is only read here. The numpy default ('r+') opens
    # the file writable and can extend it if the metadata size does not match.
    Data = np.memmap(DataPaths[sid], dtype = 'int16', mode = 'r', shape = (nSamples, nChannelsTot))

    # Remove spikes which won't have a full waveform recorded (too close to
    # the start or end of the recording); the mask is shared by ids and times
    # NOTE(review): the mask uses HalfWidth for KS4 data as well, although the
    # KS4 window is passed as SamplesBefore/SamplesAfter - confirm against
    # erd.Extract_A_UnitKS4 that this is the intended margin.
    IncompleteSpikes = np.logical_or(SpikeTimes[sid] < HalfWidth, SpikeTimes[sid] > (Data.shape[0] - HalfWidth))
    SpikeIdsTmp = np.delete(SpikeIds[sid], IncompleteSpikes)
    SpikeTimesTmp = np.delete(SpikeTimes[sid], IncompleteSpikes)

    # Units to extract: only the good units, or every unit id in the session
    if ExtractGoodUnitsOnly:
        Units = GoodUnits[sid]
        if Units is None:
            raise ValueError(f'ExtractGoodUnitsOnly is True but GoodUnits is None for session {sid + 1}; '
                             'load the good units before extracting')
    else:
        Units = np.unique(SpikeIds[sid])
    nUnits = Units.shape[0]

    SampleIdx = erd.get_sample_idx(SpikeTimesTmp, SpikeIdsTmp, SampleAmount, units = Units)

    # One job per unit; each job reads its spikes from the memory-mapped data
    AvgWaveforms = Parallel(n_jobs = -1, verbose = 10, mmap_mode = 'r', max_nbytes = None)(
        delayed(ExtractUnit)(SampleIdx[uid], Data, *WindowArgs, SpikeWidth, nChannels, SampleAmount)
        for uid in range(nUnits)
    )
    AvgWaveforms = np.asarray(AvgWaveforms)

    #Save in file named 'RawWaveforms' in the KS Directory
    erd.Save_AvgWaveforms(AvgWaveforms, KSdirs[sid], GoodUnits = GoodUnits[sid], ExtractGoodUnitsOnly = ExtractGoodUnitsOnly)

    # Release this session's memmap before moving on; done inside the loop so
    # it also works when there are no sessions at all
    del Data

#### Optional: delete the decompressed data

In [None]:
import shutil

# DELETE the whole decompressed-data directory (i.e. the folders of all sessions)
shutil.rmtree(path=DecompDir)