In [40]:
import sys
import os
import shutil
import time

import traceback
import signal as sg

from pathlib import Path
import h5py
import json
import pickle

import scipy
import numpy as np
import pandas as pd

import seaborn as sns
sns.set(style='whitegrid', palette='muted')

## added TreeMazeanalyses folder using the following command
## conda develop /home/alexgonzalez/Documents/TreeMazeAnalyses2
import TreeMazeAnalyses2.Utils.robust_stats as rs
import TreeMazeAnalyses2.Pre_Processing.pre_process_functions as pp
import TreeMazeAnalyses2.Sorting.sort_functions as sf

import spikeextractors as se
import spikesorters as ss

sns.set(style= 'whitegrid', palette= 'muted')

import signal as sg

from importlib import reload

class timeout:
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    Implemented with ``SIGALRM`` / ``ITIMER_REAL``, so it only works on Unix
    and only from the main thread.

    Parameters
    ----------
    seconds : int or float, default 1
        Time budget for the ``with`` body. Fractional values are allowed;
        0 disables the timer.
    error_message : str, default 'Timeout'
        Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block with ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # Keep whatever handler was installed before so it can be restored on
        # exit instead of leaking this instance's handler into later code.
        self._previous_handler = sg.signal(sg.SIGALRM, self.handle_timeout)
        # setitimer accepts float seconds (sg.alarm only accepts ints) and fires
        # the same SIGALRM signal.
        sg.setitimer(sg.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Cancel any pending timer first so it cannot fire after the handler swap.
        sg.setitimer(sg.ITIMER_REAL, 0)
        # signal.signal returns None when the prior handler was not installed
        # from Python; in that case there is nothing we can restore.
        if self._previous_handler is not None:
            sg.signal(sg.SIGALRM, self._previous_handler)
        return False  # never suppress exceptions raised in the body


In [14]:
# Note: '__file__' is a string literal here, not the module variable (which a
# notebook kernel does not define). abspath() resolves it against the current
# working directory, so its dirname is that same directory.
notebook_dir = os.path.dirname(os.path.abspath('__file__'))
print(notebook_dir)
print(os.getcwd())


/home/alexgonzalez/Documents
/home/alexgonzalez/Documents


In [15]:
# Pick the subject and sorter; the JSON task table for that pair lives under
# <data_folder>/TasksDir.
subject_id = 'Li'
sorter = 'KS2'
data_folder = Path('/Data_SSD2T/Data/PreProcessed/') / subject_id

task_table_name = 'sort_{}_{}.json'.format(subject_id, sorter)
task_table_file = data_folder / 'TasksDir' / task_table_name
task_table = json.loads(task_table_file.read_text())
    

In [17]:
# Inspect the first entry of the task table (keys are stringified task numbers).
task_table['1']
    

{'session_name': 'Li_T3g_052818',
 'session_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818',
 'files': {'1': {'session': 'Li_T3g_052818',
   'task_type': 'KS2',
   'file_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_2.npy',
   'file_header_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_2_info.pickle',
   'tt_id': '2',
   'save_path': '/Data_SSD2T/Data/Sorted/Li/Li_T3g_052818/tt_2'},
  '2': {'session': 'Li_T3g_052818',
   'task_type': 'KS2',
   'file_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_5.npy',
   'file_header_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_5_info.pickle',
   'tt_id': '5',
   'save_path': '/Data_SSD2T/Data/Sorted/Li/Li_T3g_052818/tt_5'},
  '3': {'session': 'Li_T3g_052818',
   'task_type': 'KS2',
   'file_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_6.npy',
   'file_header_path': '/Data_SSD2T/Data/PreProcessed/Li/Li_T3g_052818/tt_6_info.pickle',
   'tt_id': '6',
   'save_path': '/Data_SSD2T/Data/S

In [18]:
# Choose which task-table entry (session) to process; keys are stringified ints.
task_num = 1
task_num_str = str(task_num)
tasks_info = task_table[task_num_str]

# Session-level fields used by later cells.
session_name, n_files, task_list = (
    tasks_info[key] for key in ('session_name', 'n_files', 'files')
)
print("Processing Session {}".format(session_name))

Processing Session Li_T3g_052818


In [23]:
# load task data for one file of the selected session
# `task` was previously never defined; pick it from the session's file list.
# task_list keys are stringified file numbers ('1', '2', ...) in the task table.
file_num = 1
task = task_list[str(file_num)]

data = np.load(task['file_path'])
# NOTE(review): pickle.load can execute arbitrary code — only load header files
# produced by this pipeline.
with open(task['file_header_path'], 'rb') as f:
    data_info = pickle.load(f)

In [35]:
# Show which fields the file's header pickle provides.
print('Data Information Fields:')
print([field for field in data_info])


Data Information Fields:
['data_dir', 'session', 'fs', 'tt_num', 'n_chans', 'chan_files', 'a_ds', 'ref_chan', 'chan_ids', 'input_range', 'tt_geom', 'bad_chan_thr', 'n_samps', 'tB', 'tE', 'Raw', 'chan_code', 'bad_chans']


In [22]:
# Filter each channel into the spike band ('Sp' entry of the project's SOS
# filter bank — presumably the spike frequency range; confirm in pp).
# The header loaded above is `data_info`; the previous `info` name was undefined.
SOS, _ = pp.get_sos_filter_bank(['Sp'], fs=data_info['fs'])

# Fail fast, before allocating, if the header disagrees with the array layout.
assert data_info['n_chans'] == data.shape[0], "Inconsistent formating in the data files. Aborting."

# sosfiltfilt returns floating-point samples; if `data` is integer-typed,
# np.zeros_like(data) would silently truncate them, so promote to >= float32.
# Float inputs keep their own dtype.
spk_data = np.zeros_like(data, dtype=np.promote_types(data.dtype, np.float32))

t0 = time.perf_counter()
for ch in range(data_info['n_chans']):
    # zero-phase (forward-backward) filtering, one channel at a time
    spk_data[ch] = scipy.signal.sosfiltfilt(SOS, data[ch])
    print('', end='.')  # one progress dot per channel
t1 = time.perf_counter()

print('\nTime to spk filter data {0:0.2f}s'.format(t1 - t0))

....
Time to spk filter data 4.16s


In [36]:
# Re-import pre_process_functions so edits to it are picked up without a kernel
# restart (development convenience).
pp = reload(pp)

# Build per-channel masks from the clipped segments recorded in the header, then
# compute each channel's median absolute deviation on the spike-band signal.
clipped_segments = data_info['Raw']['ClippedSegs']
chan_masks = pp.create_chan_masks(clipped_segments, data_info['n_samps'])
chan_mad = pp.get_signals_mad(spk_data, chan_masks)

print('Channels Median Absolute Deviation in the Spike Frequency Range:')
print(chan_mad)

Channels Median Absolute Deviation in the Spike Frequency Range:
[7.21095753 2.83736539 2.82375073 2.82144165]


In [42]:
# Wrap the mask-weighted spike-band signals as a spikeextractors recording,
# using the tetrode geometry and sampling rate stored in the header.
masked_spk_data = spk_data * chan_masks
spk_masked_signals = se.NumpyRecordingExtractor(
    timeseries=masked_spk_data,
    geom=data_info['tt_geom'],
    sampling_frequency=data_info['fs'],
)