In [1]:
%load_ext autoreload
%autoreload 
import Extract_raw_data as erd
import numpy as np 
from pathlib import Path
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import os

The waveform extraction needs decompressed data; the cells below can decompress the raw data and save it.

Decompress n sessions at a time (n can be 1).

In [2]:
# GIVE a list of directories, one per session, holding the raw compressed
# recording, e.g. the .cbin, .ch and .meta files.
RawDataDirPaths = [
    r'Path\to\rawdata\Session1',
    r'Path\to\rawdata\Session2',
]

# One path of each file type is returned per session directory given above.
cbinPaths, chPaths, metaPaths = erd.get_raw_data_paths(RawDataDirPaths)

In [3]:
#Decompress Data
from mtscomp import Reader

#GIVE a path to a directory where the decompressed data will be saved.
#The decompressed file is large, using an SSD is advised for quicker run times.
#A folder called 'DecompData' is created there, holding one 'SessionN' folder per session.
#DecompDataSaveDir = r'path/to/decompressed/data/save/dir'
DecompDataSaveDir = r'c:\Users\Experiment\Data'
DecompDir = os.path.join(DecompDataSaveDir, 'DecompData')
# makedirs(exist_ok=True) so that re-running this cell (or a missing parent folder)
# does not fail before any decompression has been attempted
os.makedirs(DecompDir, exist_ok=True)


DataPaths = []  # path to the decompressed .bin of each session, in session order
for i in range(len(RawDataDirPaths)):
    SessionDir = os.path.join(DecompDir, f'Session{i+1}')  #+1 so folder names start at 'Session1'
    os.makedirs(SessionDir, exist_ok=True)
    tmpPath = os.path.join(SessionDir, 'RawData.bin')
    DataPaths.append(tmpPath)

    # create .bin with the decompressed data

    #r = Reader() # do the mtscomp verification
    r = Reader(check_after_decompress = False) #Skip the verification check to save time
    r.open(cbinPaths[i], chPaths[i])
    try:
        r.tofile(tmpPath)
    finally:
        # always release the reader, even if writing the .bin fails part-way
        r.close()

Decompressing:   0%|          | 0/164 [00:00<?, ?it/s]

Decompressing: 100%|██████████| 164/164 [02:31<00:00,  1.08it/s]
INFO:mtscomp:Wrote c:\Users\Experiment\Data\DecompData\Session1\RawData.bin (42.2 GB).


Extract average waveforms from decompressed data

In [4]:
#Set Up Parameters
SampleAmount = 1000 # at least 500 per CV
SpikeWidth = 82 # assuming 30khz sampling, UM standard, covers the AP and space around needed for processing
HalfWidth = np.floor(SpikeWidth/2).astype(int)
nChannels = 384
ExtractGoodUnitsOnly = True

#List of paths to a KS dir, one per session.
#These directories are also where the average waveforms are saved by the extraction cell,
#so they must not point inside the decompressed-data folder, which is deleted at the end.
KSdirs = [r'path/to/KiloSort/Dir/Session1', r'path/to/KiloSort/Dir/Session2']
nSessions = len(KSdirs) #How many sessions are being extracted
# pass the flag set above through, so changing it in one place is enough
SpikeIds, SpikeTimes, GoodUnits = erd.extract_KSdata(KSdirs, ExtractGoodUnitsOnly = ExtractGoodUnitsOnly)

#give metadata + Raw data paths
#if you are NOT decompressing data here, provide a list of paths to the decompressed data and the metadata

#DataPaths = [r'c:\Users\Experiment\Data\RawData.bin']
#metaPaths = [r'']


In [5]:
#Extract the units

# Both modes run the same per-session pipeline; they differ only in which units are
# sampled and in the arguments handed to erd.Save_AvgWaveforms.
Data = None  # bound up front so the final `del Data` is valid even with zero sessions
for sid in range(nSessions):
    #load metadata
    MetaData = erd.Read_Meta(Path(metaPaths[sid]))
    nChannelsTot = int(MetaData['nSavedChans'])
    # the data is read as int16, i.e. 2 bytes per element; integer division keeps counts exact
    nElements = int(MetaData['fileSizeBytes']) // 2
    nSamples = nElements // nChannelsTot

    #create memmap to raw data, for that session
    # mode='r': the raw recording is only read here, so never open it writable
    Data = np.memmap(DataPaths[sid], dtype = 'int16', mode = 'r', shape = (nSamples, nChannelsTot))

    # Remove spikes which won't have a full waveform recorded, i.e. those within
    # HalfWidth samples of the start or end of the recording
    EdgeSpikes = np.logical_or(SpikeTimes[sid] < HalfWidth, SpikeTimes[sid] > (Data.shape[0] - HalfWidth))
    SpikeIdsTmp = np.delete(SpikeIds[sid], EdgeSpikes)
    SpikeTimesTmp = np.delete(SpikeTimes[sid], EdgeSpikes)

    # choose which units to extract
    if ExtractGoodUnitsOnly:
        #might be slow extracting sample for good units only?
        Units = GoodUnits[sid]
        nUnits = GoodUnits[sid].shape[0]
    else:
        #Extracting ALL the Units
        Units = np.unique(SpikeIds[sid])
        nUnits = len(Units)

    SampleIdx = erd.get_sample_idx(SpikeTimesTmp, SpikeIdsTmp, SampleAmount, units = Units)

    # one erd.Extract_A_Unit task per unit, spread over all available cores
    AvgWaveforms = Parallel(n_jobs = -1, verbose = 10, mmap_mode = 'r', max_nbytes = None)(
        delayed(erd.Extract_A_Unit)(SampleIdx[uid], Data, HalfWidth, SpikeWidth, nChannels, SampleAmount)
        for uid in range(nUnits)
    )
    AvgWaveforms = np.asarray(AvgWaveforms)

    #Save in file named 'RawWaveforms' in the KS Directory
    if ExtractGoodUnitsOnly:
        erd.Save_AvgWaveforms(AvgWaveforms, KSdirs[sid], GoodUnits = GoodUnits[sid], ExtractGoodUnitsOnly = ExtractGoodUnitsOnly)
    else:
        erd.Save_AvgWaveforms(AvgWaveforms, KSdirs[sid])

# drop the reference to the last session's memmap
# NOTE(review): presumably needed so the .bin can be deleted afterwards - confirm
del Data

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    4.6s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:    4.7s
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:    6.6s
[Parallel(n_jobs=-1)]: Done  26 tasks      | elapsed:    8.5s
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed:   10.3s
[Parallel(n_jobs=-1)]: Done  48 tasks      | elapsed:   11.0s
[Parallel(n_jobs=-1)]: Done  61 tasks      | elapsed:   13.5s
[Parallel(n_jobs=-1)]: Done  74 tasks      | elapsed:   15.3s
[Parallel(n_jobs=-1)]: Done  89 tasks      | elapsed:   18.0s
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed:   20.0s
[Parallel(n_jobs=-1)]: Done 121 tasks      | elapsed:   22.6s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:   25.5s
[Parallel(n_jobs=-1)]: Done 157 tasks      | elapsed:   28.4s
[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:   31.5s
[Parallel(n_jobs=-1)]: Done 197 tasks      | elapsed:  

The cell below will delete the decompressed data; you can also delete it in the file explorer.

In [6]:
import shutil

#DELETE the decompressed data directory/folder (i.e. the data of all sessions)
# Guarded so that re-running this cell, or running it after the folder was already
# removed by hand, is a no-op instead of an error.
if os.path.isdir(DecompDir):
    shutil.rmtree(DecompDir)