## This notebook assumes that you have already imported one or more NWB files into DataJoint.
## It lets you run spike sorters on those data using the SpikeInterface package.

#### Load all of the relevant modules

In [1]:
%load_ext autoreload
%autoreload 2

import os

# Base directory for the NWB / DataJoint data. If NWB_DATAJOINT_BASE_DIR is
# already exported in the environment it is honoured; otherwise fall back to
# the original hard-coded location so existing setups keep working.
data_dir = os.environ.get('NWB_DATAJOINT_BASE_DIR', '/Users/loren/data/nwb_builder_test_data')
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

# The remaining directories are always derived from data_dir.
os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import numpy as np
import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj
from ndx_franklab_novela import Probe


Connecting root@localhost:3306


In [2]:
import warnings
warnings.simplefilter('ignore')

  and should_run_async(code)


### Set the NWB file name by looking up the session in DataJoint

In [3]:
# Restrict the Session table to session 'beans_01' and read the NWB file
# name stored for it (fetch1 expects exactly one matching row).
session_query = nd.common.Session() & {'session_id': 'beans_01'}
nwb_file_name = session_query.fetch1('nwb_file_name')

### Set the sort grouping by shank

In [4]:
# Rebuild the sort groups for this file via set_group_by_shank. As the output
# below shows, existing sort groups and their dependent entries are deleted
# first, behind an interactive confirmation prompt.
sort_group_table = nd.common.SortGroup()
sort_group_table.set_group_by_shank(nwb_file_name)

# Display the resulting SortGroup table.
nd.common.SortGroup()

About to delete:
`common_spikesorting`.`spike_sorting_parameters`: 1 items
`common_spikesorting`.`sort_group__sort_group_electrode`: 192 items
`common_spikesorting`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718.nwb,0,-1
beans20190718.nwb,1,-1
beans20190718.nwb,2,-1
beans20190718.nwb,3,-1
beans20190718.nwb,4,-1
beans20190718.nwb,5,-1
beans20190718.nwb,6,-1
beans20190718.nwb,7,-1


Optional: Display all of the electrodes with their sort groups

In [5]:
# Fetch every row of the SortGroupElectrode table and show the resulting array.
sort_group_electrodes = nd.common.SortGroup.SortGroupElectrode().fetch()
sort_group_electrodes

  and should_run_async(code)


array([('beans20190718.nwb', 0, '0',   0),
       ('beans20190718.nwb', 0, '0',   1),
       ('beans20190718.nwb', 0, '0',   3),
       ('beans20190718.nwb', 0, '0',   4),
       ('beans20190718.nwb', 0, '0',   5),
       ('beans20190718.nwb', 0, '0',   7),
       ('beans20190718.nwb', 0, '0',   8),
       ('beans20190718.nwb', 0, '0',   9),
       ('beans20190718.nwb', 0, '0',  11),
       ('beans20190718.nwb', 0, '0',  12),
       ('beans20190718.nwb', 0, '0',  13),
       ('beans20190718.nwb', 0, '0',  15),
       ('beans20190718.nwb', 0, '0',  16),
       ('beans20190718.nwb', 0, '0',  17),
       ('beans20190718.nwb', 0, '0',  19),
       ('beans20190718.nwb', 0, '0',  20),
       ('beans20190718.nwb', 0, '0',  21),
       ('beans20190718.nwb', 0, '0',  23),
       ('beans20190718.nwb', 0, '0',  24),
       ('beans20190718.nwb', 0, '0',  25),
       ('beans20190718.nwb', 0, '0',  27),
       ('beans20190718.nwb', 0, '0',  28),
       ('beans20190718.nwb', 0, '0',  29),
       ('be

### create the spike sorter and parameter lists 

In [6]:
# Start from a clean SpikeSorter table. The delete prompts for confirmation
# and, per the output below, also removes dependent spike_sorter_parameters rows.
sorter_table = nd.common.SpikeSorter()
sorter_table.delete()

# Re-insert the sorters and their parameter sets from SpikeInterface.
sorter_table.insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

About to delete:
`common_spikesorting`.`spike_sorter_parameters`: 11 items
`common_spikesorting`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


### create a 'franklab_mountainsort' parameter set

In [7]:
# Start from mountainsort4's 'default' parameter set and override selected
# entries to build the 'franklab_mountainsort_20KHz' set.
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param.update({
    'adjacency_radius': 100,
    'curation': False,
    'num_workers': 7,
    'verbose': True,
    'clip_size': 30,
    'noise_overlap_threshold': 0,
})

# skip_duplicates takes a boolean; the string 'True' only worked because any
# non-empty string is truthy (so would 'False').
nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates=True)

Display the new parameter set

In [8]:
# Read the newly inserted parameter set back from the table and display it.
franklab_restriction = {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}
p = (nd.common.SpikeSorterParameters() & franklab_restriction).fetch1()
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [9]:
# Pull just the parameter dictionary out of the fetched row and display it.
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sort group 1

In [10]:
# create two 10 second test intervals for debugging
# Restrict by file as well as by interval name so fetch1() still sees a single
# row when more than one NWB file defines an interval list called '01_s1'.
s1 = (nd.common.IntervalList() & {'nwb_file_name' : nwb_file_name, 'interval_list_name' : '01_s1'}).fetch1('valid_times')

# First interval starts at the first valid time; the second is the same
# window shifted by 120.
a = s1[0][0]
b = a + 10
t = np.asarray([[a, b], [a + 120, b + 120]])

# replace takes a boolean; the string 'True' only worked because it is truthy.
nd.common.SortIntervalList().insert1({'nwb_file_name' : nwb_file_name, 'sort_interval_list_name' : 'test', 'sort_intervals' : t}, replace=True)

print(t)

[[1.56348899e+09 1.56348900e+09]
 [1.56348911e+09 1.56348912e+09]]


In [11]:

# Clear any existing entries from the SpikeSortingParameters table.
nd.common.SpikeSortingParameters().delete()

About to delete:
Nothing to delete


In [6]:
# Describe one sort: which file and sort group to sort, which sorter and
# parameter set to use, and which interval lists apply.
sort_group_id = 1
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
    'interval_list_name': '01_s1',
    'sort_interval_list_name': 'test',
}

# skip_duplicates takes a boolean; the string 'True' only worked because it is truthy.
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates=True)

In [13]:
# Show the 'test' sort interval list for this file. The restriction keeps
# fetch1() valid once the table holds more than one row; an unrestricted
# fetch1() raises as soon as a second entry exists.
(nd.common.SortIntervalList() & {'nwb_file_name' : nwb_file_name, 'sort_interval_list_name' : 'test'}).fetch1()

{'nwb_file_name': 'beans20190718.nwb',
 'sort_interval_list_name': 'test',
 'sort_intervals': array([[1.56348899e+09, 1.56348900e+09],
        [1.56348911e+09, 1.56348912e+09]])}

### run the sort

In [4]:
# Run the sort through the SpikeSorting table's populate() call. The output
# below shows a new analysis NWB file being written and a "Sorting {...}" log
# line printed for each key that is processed.
nd.common.SpikeSorting().populate()

writing new NWB file beans20190718_00000006.nwb
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 1, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000006.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpjv6dyvsj
Num. workers = 7
Preparing /tmp/tmpjv6dyvsj/timeseries.hdf5...
'end_frame' set to 199999
Preparing neighborhood sorters (M=24, N=199999)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpjv6dyvsj
mountainsort4 run time 15.91s
{'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'parameter_dict': {'detect_sign': -1, 'adjacency_radius': 100, 'freq_min': 300, 'freq_max': 6000, 'filter': True, 'whiten': True, 'curation': False, 'num_workers': 7, 'clip_size': 30, 'detect_threshold': 3, 'detect_interval': 10, 'noise_overlap_threshold': 0, 'verbose': True}}
27
Sorting {'nwb_file_name': 'beans20190718.nwb', 's


### Example: Retrieve the spike trains:

In [7]:
# Restriction selecting the sort for this file and sort group.
key = {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}

# Matching SpikeSorting rows.
sorting = (nd.common.SpikeSorting & key).fetch()

# Units table of the first matching NWB object, converted to a DataFrame.
sorting_nwb = (nd.common.SpikeSorting & key).fetch_nwb()
units = sorting_nwb[0]['units'].to_dataframe()


In [15]:
# Display the waveform_mean entry of the third row of the units table.
units.iloc[2]['waveform_mean']

array([[ -83.,  -73.,  -96., -100., -100., -101., -101., -116., -106.,
        -150.,  -90., -108., -108.,  -64.,  -39.,  -71.,  -23.,   15.,
           7.,   20.,   61.,   49.,   60.,   89.,   55.,   73.,   63.,
          91.],
       [ -64., -102.,  -83.,  -97.,  -91., -113., -124., -116., -123.,
        -108., -110., -116.,  -71.,  -64.,  -82.,  -12.,   13.,    3.,
          38.,   60.,   54.,   58.,   77.,   72.,   70.,   53.,   72.,
          64.],
       [ -41.,  -30.,  -40.,  -58.,  -62.,  -75., -120., -160., -151.,
        -164., -181., -208., -188., -314., -471., -349., -180., -174.,
        -161.,  -57.,  -53.,  -63.,  -28.,   -5.,    5.,   14.,   36.,
           6.],
       [ -48.,  -70., -107., -108.,  -67., -100., -108., -101., -100.,
         -91.,  -97.,  -84.,  -56.,  -65.,  -67.,  -19.,   10.,   21.,
           1.,   25.,   11.,   41.,   54.,   78.,   68.,   60.,   93.,
          53.],
       [ -74.,  -74.,  -68.,  -69.,  -77.,  -86.,  -99.,  -72.,  -73.,
         -77.

In [15]:
# `sort_interval` is not assigned in any earlier cell of this notebook, so on a
# fresh kernel this cell raised NameError. Fetch it from the 'test' sort
# interval list that was inserted above.
# NOTE(review): assumes the first of the two test intervals is the one wanted;
# change the index to select the other - confirm.
sort_interval = (nd.common.SortIntervalList() & {'nwb_file_name' : nwb_file_name, 'sort_interval_list_name' : 'test'}).fetch1('sort_intervals')[0]

# Gather the spike times of every unit sorted over `sort_interval`, together
# with a parallel list giving the unit index of each spike.
unit_timestamps = []
unit_labels = []
for index, unit in units.iterrows():
    # array_equal returns False (rather than broadcasting or raising) when the
    # stored interval does not have the same shape as sort_interval.
    if np.array_equal(np.ravel(unit['sort_interval']), sort_interval):
        unit_timestamps.extend(unit['spike_times'])
        unit_labels.extend([index] * len(unit['spike_times']))

In [10]:
# Build a sorting extractor for `key` over `sort_interval`.
# NOTE(review): `sort_interval` is not assigned in this cell; it must already
# exist in the kernel namespace or this call raises NameError.
s = nd.common.get_sorting_extractor(key, sort_interval)

  and should_run_async(code)


  and should_run_async(code)


[1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27]

In [11]:
# Show the pynwb version installed in the current conda environment.
!conda list | grep pynwb

pynwb                     1.3.3                    pypi_0    pypi
