# This notebook assumes that you've imported one or more NWB files into DataJoint 
# It allows you to run spikesorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [1]:
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
%load_ext autoreload
%autoreload 2

import os
data_dir = '/Users/loren/data/nwb_builder_test_data'
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj


env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
Connecting root@localhost:3306


### Set the nwb file name and the name of the probe file to create from DataJoint

In [2]:
nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')
#probe_file_name = '/Users/loren/data/nwb_builder_test_data/test.prb'

### Set the sort grouping by shank

In [3]:
nd.common.SortGroup().set_group_by_shank(nwb_file_name)


About to delete:
Nothing to delete


### create the spike sorter and parameter lists 

In [5]:
nd.common.SpikeSorter().delete()
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

About to delete:
`common_ephys`.`spike_sorter_parameters`: 10 items
`common_ephys`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


### create a 'franklab_mountainsort' parameter set

In [6]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates='True')

In [7]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}).fetch1()
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 4

In [9]:
# create a 60 second test interval for debugging
t = (nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch('valid_times')
a = t[0][0,0]
b = a + 60
import numpy as np
t = np.asarray([[a,b]])
nd.common.IntervalList().insert1({'nwb_file_name' : nwb_file_name, 'interval_list_name' : 'test', 'valid_times' : t}, replace='True')

In [10]:
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = 4
key['sorter_name'] = 'mountainsort4'
key['parameter_set_name'] = 'franklab_mountainsort_20KHz'
key['interval_list_name'] = '01_s1'
#key['interval_list_name'] = 'test'
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates='True')

### run the sort

In [11]:
nd.common.SpikeSorting().populate()

sample indeces: [     19 1200012]
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'interval_list_name': 'test'}...
Using 7 workers.
Using tmpdir: /tmp/tmpmhbf9vjb
Num. workers = 7
Preparing /tmp/tmpmhbf9vjb/timeseries.hdf5...
'end_frame' set to 1199993
Preparing neighborhood sorters (M=32, N=1199993)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpmhbf9vjb
mountainsort4 run time 166.65s
looking up times for unit 1: 1354 spikes, sample 1: 117
looking up times for unit 2: 1620 spikes, sample 1: 4352
looking up times for unit 3: 1464 spikes, sample 1: 336
looking up times for unit 4: 1486 spikes, sample 1: 157
looking up times for unit 5: 794 spikes, sample 1: 1460
looking up times for unit 6: 92 spikes, sample 1: 2829
looking up times for unit 7: 941 spikes, sample 1: 3679
looking up times for unit 8: 58 spikes, sample 1: 22821
looking up times for unit 9: 522 spi

Exception ignored in: <function _TemporaryFileCloser.__del__ at 0x7f98bc3c5b00>
Traceback (most recent call last):
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.7/tempfile.py", line 448, in __del__
    self.close()
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.7/tempfile.py", line 444, in close
    unlink(self.name)
FileNotFoundError: [Errno 2] No such file or directory: '/var/folders/00/yfl7qb1x6rxgs725vxlb57_00000gq/T/tmpb249a00z'


In [None]:
nd.common.SpikeSorting()

In [7]:
recording = se.NwbRecordingExtractor(nd.common.Nwbfile.get_abs_path(nwb_file_name), electrical_series_name='e-series')



In [None]:
recording = recording.load_probe_file(probe_file_name)

import spikeinterface.widgets as sw
probe = sw.plot_electrode_geometry(recording)
#p = si.widgets.plot_electrode_geometry(recording)

In [9]:
import numpy as np
s1 = ((nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch('valid_times'))



1091.95

In [11]:
second_shank_electrodes = (nd.common.SortGroup.SortGroupElectrode() & {'nwb_file_name' : nwb_file_name} & {'sort_group_id' : 1}).fetch('electrode_id')
second_shank_electrodes

array([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
       49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [None]:
valid_times = (nd.common.IntervalList() & {'nwb_file_name' : nwb_file_name, 'interval_list_name' : '01_s1'}).fetch('valid_times')
valid_times[0]

In [110]:
a = valid_times[0][0,0]
b = a + 600
new_valid_times =np.asarray([[a,b]])
new_valid_times
#nd.common.IntervalList().insert1({'nwb_file_name' : nwb_file_name, 'interval_list_name' : 'test', 'valid_times' : new_valid_times})
nd.common.IntervalList()

nwb_file_name  the name of the NWB file,interval_list_name  descriptive name of this interval list,valid_times  2D numpy array with start and end times for each interval
beans20190718.nwb,01_s1,=BLOB=
beans20190718.nwb,02_r1,=BLOB=
beans20190718.nwb,03_s2,=BLOB=
beans20190718.nwb,04_r2,=BLOB=
beans20190718.nwb,pos 0 valid times,=BLOB=
beans20190718.nwb,pos 1 valid times,=BLOB=
beans20190718.nwb,pos 2 valid times,=BLOB=
beans20190718.nwb,pos 3 valid times,=BLOB=
beans20190718.nwb,raw data valid times,=BLOB=
beans20190718.nwb,test,=BLOB=


In [32]:
second_shank_recording = se.SubRecordingExtractor(recording, channel_ids=second_shank_electrodes)
second_shank_recording.frame_to_time?

[0;31mSignature:[0m [0msecond_shank_recording[0m[0;34m.[0m[0mframe_to_time[0m[0;34m([0m[0mframe[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
This function converts a user-inputted frame index to a time with units of seconds.

Parameters
----------
frame: float
    The frame to be converted to a time

Returns
-------
time: float
    The corresponding time in seconds
[0;31mFile:[0m      ~/Src/NWB/spikeextractors/spikeextractors/subrecordingextractor.py
[0;31mType:[0m      method


In [15]:
second_shank_sort = si.sorters.run_mountainsort4(recording=second_shank_recording, **param, grouping_property='group', output_folder='/Users/loren/data/nwb_builder_test_data/tmp')

KeyboardInterrupt: 

In [35]:
a = dict()
a[1] = 'test'

In [36]:
a[1]

'test'