# This notebook assumes that you've imported one or more NWB files into DataJoint.
# It allows you to run spike sorters on those data using the SpikeInterface package.

#### Load all of the relevant modules

In [6]:
import warnings
warnings.simplefilter('ignore')

In [7]:
# Set the DJ_SUPPORT_FILEPATH_MANAGEMENT environment variable for this kernel.
# NOTE(review): presumably this switches on DataJoint's filepath-management feature — confirm.
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
# Load IPython's autoreload extension and select mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os

# Base directory holding the NWB data. A NWB_DATAJOINT_BASE_DIR that is already
# exported (and non-empty) takes precedence so the notebook can run on machines
# other than the author's; otherwise fall back to the original location.
data_dir = os.environ.get('NWB_DATAJOINT_BASE_DIR') or '/Users/loren/data/nwb_builder_test_data'
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

# Kachery storage and the spike-sorting temp directory both live under data_dir.
os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import numpy as np
import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj


env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Set the nwb file name and the name of the probe file to create from DataJoint

In [8]:
# Look up the NWB file recorded for session 'beans_01' in the Session table.
session_restriction = {'session_id': 'beans_01'}
nwb_file_name = (nd.common.Session() & session_restriction).fetch1('nwb_file_name')

### Set the sort grouping by shank

In [3]:
# Assign the electrodes of this NWB file to sort groups, one group per shank.
sort_group_table = nd.common.SortGroup()
sort_group_table.set_group_by_shank(nwb_file_name)

# Display the resulting SortGroup table (fresh, unrestricted view).
nd.common.SortGroup()

  and should_run_async(code)


About to delete:
Nothing to delete


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718.nwb,0,-1
beans20190718.nwb,1,-1
beans20190718.nwb,2,-1
beans20190718.nwb,3,-1
beans20190718.nwb,4,-1
beans20190718.nwb,5,-1
beans20190718.nwb,6,-1
beans20190718.nwb,7,-1


Optional: Display all of the electrodes with their sort groups

In [5]:
# Fetch every row of the SortGroupElectrode part table; the result is the cell output.
sort_group_electrode_table = nd.common.SortGroup.SortGroupElectrode()
sort_group_electrode_table.fetch()

  and should_run_async(code)


array([('beans20190718.nwb', 0, '0',   0),
       ('beans20190718.nwb', 0, '0',   1),
       ('beans20190718.nwb', 0, '0',   3),
       ('beans20190718.nwb', 0, '0',   4),
       ('beans20190718.nwb', 0, '0',   5),
       ('beans20190718.nwb', 0, '0',   7),
       ('beans20190718.nwb', 0, '0',   8),
       ('beans20190718.nwb', 0, '0',   9),
       ('beans20190718.nwb', 0, '0',  11),
       ('beans20190718.nwb', 0, '0',  12),
       ('beans20190718.nwb', 0, '0',  13),
       ('beans20190718.nwb', 0, '0',  15),
       ('beans20190718.nwb', 0, '0',  16),
       ('beans20190718.nwb', 0, '0',  17),
       ('beans20190718.nwb', 0, '0',  19),
       ('beans20190718.nwb', 0, '0',  20),
       ('beans20190718.nwb', 0, '0',  21),
       ('beans20190718.nwb', 0, '0',  23),
       ('beans20190718.nwb', 0, '0',  24),
       ('beans20190718.nwb', 0, '0',  25),
       ('beans20190718.nwb', 0, '0',  27),
       ('beans20190718.nwb', 0, '0',  28),
       ('beans20190718.nwb', 0, '0',  29),
       ('be

### Create the spike sorter and parameter lists

In [4]:
# WARNING: destructive. Deletes every row of SpikeSorter; per the recorded output
# this also removes the dependent spike_sorter_parameters rows and asks for
# interactive confirmation ("Proceed? [yes, No]"), so the cell cannot run unattended.
nd.common.SpikeSorter().delete()
# Repopulate the sorter list and the parameter sets from SpikeInterface.
# Order matters: the delete above must run before these inserts.
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

  and should_run_async(code)


About to delete:
`common_ephys`.`spike_sorter_parameters`: 11 items
`common_ephys`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


  if isinstance(obj, collections.ByteString):
  if isinstance(obj, collections.MutableSequence):
  if isinstance(obj, collections.Sequence):
  if isinstance(obj, collections.Set):


### Create a 'franklab_mountainsort_20KHz' parameter set

In [7]:
# Start from the 'default' mountainsort4 parameter set and override the values
# that define the 'franklab_mountainsort_20KHz' set.
p = (nd.common.SpikeSorterParameters()
     & {'sorter_name': 'mountainsort4', 'parameter_set_name': 'default'}).fetch1()
param = p['parameter_dict']
param.update({
    'adjacency_radius': 100,
    'curation': False,
    'num_workers': 7,
    'verbose': True,
    'clip_size': 30,
    'noise_overlap_threshold': 0,
})

# skip_duplicates is a boolean flag; the original passed the string 'True',
# which only worked because any non-empty string is truthy.
# Note: with skip_duplicates=True an already-existing parameter set of this
# name is left untouched, even if the values above have changed.
nd.common.SpikeSorterParameters().insert1(
    {'sorter_name': 'mountainsort4',
     'parameter_set_name': 'franklab_mountainsort_20KHz',
     'parameter_dict': param},
    skip_duplicates=True)

  and should_run_async(code)


In [8]:
# Read back the stored 'franklab_mountainsort_20KHz' parameter set and display it.
franklab_restriction = {
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
}
p = (nd.common.SpikeSorterParameters() & franklab_restriction).fetch1()
p

  and should_run_async(code)


{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [9]:
# Pull the parameter dictionary out of the fetched row `p` and display it.
# Note: this rebinds `param`, which an earlier cell also assigns.
param = p['parameter_dict']
param

  and should_run_async(code)


{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 4

In [10]:
# create two 60 second test intervals for debugging
# Restrict by nwb_file_name as well as the interval name: fetch1 requires exactly
# one matching row, and another imported NWB file may also have a '01_s1' list.
s1 = (nd.common.IntervalList()
      & {'nwb_file_name': nwb_file_name, 'interval_list_name': '01_s1'}).fetch1('valid_times')
a = s1[0][0]   # start of the first valid interval
b = a + 60     # end of the first 60 s test interval
# Second interval has the same length and starts 120 s after the first one starts.
t = np.asarray([[a, b], [a + 120, b + 120]])

# replace is a boolean flag; the original passed the string 'True'.
nd.common.SortIntervalList().insert1(
    {'nwb_file_name': nwb_file_name,
     'sort_interval_list_name': 'test',
     'sort_intervals': t},
    replace=True)

print(t)

[[1.56348899e+09 1.56348905e+09]
 [1.56348911e+09 1.56348917e+09]]


In [11]:
# Display the SortIntervalList table to check the inserted interval list.
nd.common.SortIntervalList()

  and should_run_async(code)


nwb_file_name  the name of the NWB file,sort_interval_list_name  descriptive name of this interval list,sort_intervals  2D numpy array with start and end times for each interval to be used for spike sorting
beans20190718.nwb,test,=BLOB=


In [11]:
# Sort group to run; later cells reuse sort_group_id to fetch the results.
sort_group_id = 4

# One SpikeSortingParameters entry: which file and sort group to sort, with which
# sorter and parameter set, and over which interval lists.
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
    'interval_list_name': '01_s1',
    'sort_interval_list_name': 'test',
}

# skip_duplicates is a boolean flag; the original passed the string 'True'.
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates=True)

In [13]:
# Fetch the 'test' sort interval list for this NWB file. fetch1 requires exactly
# one matching row, so restrict explicitly instead of relying on the table
# containing a single entry.
(nd.common.SortIntervalList()
 & {'nwb_file_name': nwb_file_name, 'sort_interval_list_name': 'test'}).fetch1()

  and should_run_async(code)


{'nwb_file_name': 'beans20190718.nwb',
 'sort_interval_list_name': 'test',
 'sort_intervals': array([[1.56348899e+09, 1.56348905e+09],
        [1.56348911e+09, 1.56348917e+09]])}

### run the sort

In [10]:
# WARNING: destructive. The table is not restricted, so this deletes every
# SpikeSorting row (all files and sort groups), not just the entry defined in
# this notebook.
# NOTE(review): consider restricting to the current nwb_file_name / sort_group_id.
nd.common.SpikeSorting().delete()

About to delete:
Nothing to delete


  and should_run_async(code)


In [9]:
# Run the spike sorting by populating the SpikeSorting table. The recorded output
# shows it writing a new analysis NWB file and running mountainsort4 with 7 workers.
# NOTE(review): the output logs "Sorting {...}" twice for the same key, presumably
# once per sort interval — confirm.
nd.common.SpikeSorting().populate()

writing new NWB file beans20190718_00000000.nwb
sort_intervals: [[1.56348899e+09 1.56348905e+09]
 [1.56348911e+09 1.56348917e+09]]
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000000.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpoa63jf0s
Num. workers = 7
Preparing /tmp/tmpoa63jf0s/timeseries.hdf5...
'end_frame' set to 1199993
Preparing neighborhood sorters (M=24, N=1199993)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpoa63jf0s
mountainsort4 run time 81.73s
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000000.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpevpkdoi3
Num. workers = 7
Preparing /tmp/tmpevpkdo

### Example: Retrieve the spike trains:

In [12]:
# Fetch the NWB objects stored for this file / sort group and show the first
# entry's units table as a DataFrame.
sorting_restriction = {'nwb_file_name': nwb_file_name, 'sort_group_id': sort_group_id}
sorting = (nd.common.SpikeSorting & sorting_restriction).fetch_nwb()
first_units = sorting[0]['units']
first_units.to_dataframe()

Unnamed: 0_level_0,spike_times,obs_intervals,sort_interval
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,"[1563488988.9037828, 1563488988.9509833, 15634...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
2,"[1563488988.877833, 1563488988.9309332, 156348...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
3,"[1563488988.9510334, 1563488988.9952335, 15634...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
4,"[1563488988.893883, 1563488988.9036329, 156348...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
5,"[1563488988.8824327, 1563488988.9414332, 15634...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
6,"[1563488988.896533, 1563488988.9461331, 156348...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
7,"[1563488988.8768828, 1563488988.8872328, 15634...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
8,"[1563488989.051534, 1563488989.0542839, 156348...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
9,"[1563488988.8849828, 1563488988.932483, 156348...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"
10,"[1563488988.896183, 1563488988.905333, 1563488...","[[1563488988.8777807, 1563489048.8777807]]","[[1563488988.8777807, 1563489048.8777807]]"


In [17]:
# Same restriction, but request only the 'units_object_id' attribute alongside
# the NWB objects, then display the returned list.
units_restriction = {'nwb_file_name': nwb_file_name, 'sort_group_id': sort_group_id}
sorting = (nd.common.SpikeSorting & units_restriction).fetch_nwb('units_object_id')
sorting

[{'units_object_id': '10172968-4314-41e8-9eda-9c1f27f86b87',
  'units': units pynwb.misc.Units at 0x140546828303696
  Fields:
    colnames: ['spike_times']
    columns: (
      spike_times_index <class 'hdmf.common.table.VectorIndex'>,
      spike_times <class 'hdmf.common.table.VectorData'>
    )
    description: Autogenerated by NWBFile
    id: id <class 'hdmf.common.table.ElementIdentifiers'>
    waveform_unit: volts}]

[{'nwb_file_name': 'beans20190718.nwb',
  'sort_group_id': 4,
  'sorter_name': 'mountainsort4',
  'parameter_set_name': 'franklab_mountainsort_20KHz',
  'sort_interval_list_name': 'test',
  'analysis_file_name': 'beans20190718_00000000.nwb',
  'units_object_id': '10172968-4314-41e8-9eda-9c1f27f86b87',
  'units': units pynwb.misc.Units at 0x140551060536656
  Fields:
    colnames: ['spike_times']
    columns: (
      spike_times_index <class 'hdmf.common.table.VectorIndex'>,
      spike_times <class 'hdmf.common.table.VectorData'>
    )
    description: Autogenerated by NWBFile
    id: id <class 'hdmf.common.table.ElementIdentifiers'>
    waveform_unit: volts}]

In [7]:
# Leftover debugging cell: it evaluated `nwbf`, a name that is never defined in
# this notebook, and raised NameError on a fresh kernel. Nothing to run here.