# This notebook assumes that you've imported one or more NWB files into DataJoint.
# It allows you to run spike sorters on those data using the SpikeInterface package.

#### Load all of the relevant modules

In [5]:
import warnings
warnings.simplefilter('ignore')

In [6]:
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
%load_ext autoreload
%autoreload 2

import os
data_dir = '/Users/loren/data/nwb_builder_test_data'
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import numpy as np
import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj


env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Set the nwb file name and the name of the probe file to create from DataJoint

In [7]:
# Look up the NWB file name stored for session 'beans_01'.
# fetch1 requires the restriction to match exactly one Session row.
session_restriction = dict(session_id='beans_01')
nwb_file_name = (nd.common.Session() & session_restriction).fetch1('nwb_file_name')

### Set the sort grouping by shank

In [8]:
# Rebuild the sort groups for this NWB file, grouping by shank.
# NOTE: as the output below shows, this asks for confirmation before deleting
# existing sort-group entries and the tables that depend on them.
sort_group_table = nd.common.SortGroup()
sort_group_table.set_group_by_shank(nwb_file_name)

# Display the resulting SortGroup table.
nd.common.SortGroup()

About to delete:
`common_ephys`.`sort_group__sort_group_electrode`: 192 items
`common_ephys`.`__spike_sorting`: 1 items
`common_ephys`.`spike_sorting_parameters`: 1 items
`common_ephys`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718.nwb,0,-1
beans20190718.nwb,1,-1
beans20190718.nwb,2,-1
beans20190718.nwb,3,-1
beans20190718.nwb,4,-1
beans20190718.nwb,5,-1
beans20190718.nwb,6,-1
beans20190718.nwb,7,-1


Optional: Display all of the electrodes with their sort groups

In [5]:
# Fetch all rows of SortGroup.SortGroupElectrode as a numpy record array.
sort_group_electrodes = nd.common.SortGroup.SortGroupElectrode()
sort_group_electrodes.fetch()

  and should_run_async(code)


array([('beans20190718.nwb', 0, '0',   0),
       ('beans20190718.nwb', 0, '0',   1),
       ('beans20190718.nwb', 0, '0',   3),
       ('beans20190718.nwb', 0, '0',   4),
       ('beans20190718.nwb', 0, '0',   5),
       ('beans20190718.nwb', 0, '0',   7),
       ('beans20190718.nwb', 0, '0',   8),
       ('beans20190718.nwb', 0, '0',   9),
       ('beans20190718.nwb', 0, '0',  11),
       ('beans20190718.nwb', 0, '0',  12),
       ('beans20190718.nwb', 0, '0',  13),
       ('beans20190718.nwb', 0, '0',  15),
       ('beans20190718.nwb', 0, '0',  16),
       ('beans20190718.nwb', 0, '0',  17),
       ('beans20190718.nwb', 0, '0',  19),
       ('beans20190718.nwb', 0, '0',  20),
       ('beans20190718.nwb', 0, '0',  21),
       ('beans20190718.nwb', 0, '0',  23),
       ('beans20190718.nwb', 0, '0',  24),
       ('beans20190718.nwb', 0, '0',  25),
       ('beans20190718.nwb', 0, '0',  27),
       ('beans20190718.nwb', 0, '0',  28),
       ('beans20190718.nwb', 0, '0',  29),
       ('be

### Create the spike sorter and parameter lists

In [9]:
# Start from a clean slate: deleting SpikeSorter prompts for confirmation and,
# as the output below shows, also removes the dependent parameter entries.
nd.common.SpikeSorter().delete()

# Repopulate the sorter list first, then the parameter sets, from SpikeInterface.
for table in (nd.common.SpikeSorter, nd.common.SpikeSorterParameters):
    table().insert_from_spikeinterface()

About to delete:
`common_ephys`.`spike_sorter_parameters`: 11 items
`common_ephys`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


### Create a 'franklab_mountainsort_20KHz' parameter set for mountainsort4

In [10]:
# Start from the default mountainsort4 parameter set and override selected values.
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name': 'default'}).fetch1()
param = p['parameter_dict']
param.update({
    'adjacency_radius': 100,
    'curation': False,
    'num_workers': 7,
    'verbose': True,
    'clip_size': 30,
    'noise_overlap_threshold': 0,
})

# Store the modified copy under a new parameter-set name.
# skip_duplicates is a boolean flag: the original passed the string 'True',
# which only worked because any non-empty string (even 'False') is truthy.
nd.common.SpikeSorterParameters().insert1(
    {
        'sorter_name': 'mountainsort4',
        'parameter_set_name': 'franklab_mountainsort_20KHz',
        'parameter_dict': param,
    },
    skip_duplicates=True,
)

Display the new parameter set

In [13]:
# Read the stored 'franklab_mountainsort_20KHz' entry back from the table and display it.
franklab_restriction = {
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
}
p = (nd.common.SpikeSorterParameters() & franklab_restriction).fetch1()
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [14]:
# Pull out just the sorter parameter dictionary from the fetched entry and display it.
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 4

In [15]:
# Create two 60 second test intervals for debugging.
# The interval list is restricted by NWB file as well as by name, so fetch1 still
# matches a single row once other files also define an '01_s1' interval list.
valid_times = (
    nd.common.IntervalList()
    & {'nwb_file_name': nwb_file_name, 'interval_list_name': '01_s1'}
).fetch1('valid_times')

interval_start = valid_times[0][0]     # start of the first valid interval
interval_end = interval_start + 60     # 60 seconds later

# Second interval has the same length and starts 120 s after the first one starts.
sort_intervals = np.asarray([
    [interval_start, interval_end],
    [interval_start + 120, interval_end + 120],
])

# replace is a boolean flag; the original passed the string 'True'.
nd.common.SortIntervalList().insert1(
    {
        'nwb_file_name': nwb_file_name,
        'sort_interval_list_name': 'test',
        'sort_intervals': sort_intervals,
    },
    replace=True,
)

print(sort_intervals)

[[1.56348899e+09 1.56348905e+09]
 [1.56348911e+09 1.56348917e+09]]


In [16]:
# Display the SortIntervalList table to confirm the 'test' entry was inserted.
sort_interval_table = nd.common.SortIntervalList()
sort_interval_table

nwb_file_name  the name of the NWB file,sort_interval_list_name  descriptive name of this interval list,sort_intervals  2D numpy array with start and end times for each interval to be used for spike sorting
beans20190718.nwb,test,=BLOB=


In [17]:
# Describe one sorting job: which file and sort group to sort, with which sorter
# and parameter set, and over which interval lists.
sort_group_id = 4
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
    'interval_list_name': '01_s1',
    'sort_interval_list_name': 'test',
}

# skip_duplicates is a boolean flag; the original passed the string 'True'.
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates=True)

In [18]:
# Fetch the 'test' sort interval list for this file.
# fetch1 requires exactly one matching row, so the table is restricted explicitly;
# the unrestricted call fails once the table holds more than one entry.
(
    nd.common.SortIntervalList()
    & {'nwb_file_name': nwb_file_name, 'sort_interval_list_name': 'test'}
).fetch1()

{'nwb_file_name': 'beans20190718.nwb',
 'sort_interval_list_name': 'test',
 'sort_intervals': array([[1.56348899e+09, 1.56348905e+09],
        [1.56348911e+09, 1.56348917e+09]])}

### Run the sort

In [19]:
# Run the sorter via populate(); the log below shows mountainsort4 being run
# for the parameters inserted above.
spike_sorting_table = nd.common.SpikeSorting()
spike_sorting_table.populate()

writing new NWB file beans20190718_00000002.nwb
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000002.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpbtdity6m
Num. workers = 7
Preparing /tmp/tmpbtdity6m/timeseries.hdf5...
'end_frame' set to 1199993
Preparing neighborhood sorters (M=24, N=1199993)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpbtdity6m
mountainsort4 run time 98.85s
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000002.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpxudge30m
Num. workers = 7
Preparing /tmp/tmpxudge30m/timeseries.hdf5...
'end_frame' set to 1199992
Preparing neighborhood sorters (M=

### Example: Retrieve the spike trains:

In [11]:
# Load the sorting results for the chosen file and sort group from the analysis NWB file.
sorting_restriction = {'nwb_file_name': nwb_file_name, 'sort_group_id': sort_group_id}
sorting = (nd.common.SpikeSorting & sorting_restriction).fetch_nwb()

# Show the units table of the first result as a DataFrame (one row of spike times per unit).
first_units = sorting[0]['units']
first_units.to_dataframe()

Unnamed: 0_level_0,spike_times
id,Unnamed: 1_level_1
1,"[1563488988.8826828, 1563488988.9594333, 15634..."
2,"[1563488989.0944343, 1563488989.1491346, 15634..."
3,"[1563488988.8936327, 1563488989.0692837, 15634..."
4,"[1563488988.884683, 1563488989.0539339, 156348..."
5,"[1563488988.9498332, 1563488988.9755335, 15634..."
...,...
59,"[1563488988.909683, 1563488988.9649334, 156348..."
60,"[1563488988.9439833, 1563488988.9715836, 15634..."
61,"[1563488989.3086355, 1563488993.8108146, 15634..."
62,"[1563488988.891583, 1563488988.9447331, 156348..."


In [17]:
# Same query as before, additionally requesting the 'units_object_id' attribute
# alongside the loaded units object.
sorting_restriction = {'nwb_file_name': nwb_file_name, 'sort_group_id': sort_group_id}
sorting = (nd.common.SpikeSorting & sorting_restriction).fetch_nwb('units_object_id')
sorting

[{'units_object_id': '10172968-4314-41e8-9eda-9c1f27f86b87',
  'units': units pynwb.misc.Units at 0x140546828303696
  Fields:
    colnames: ['spike_times']
    columns: (
      spike_times_index <class 'hdmf.common.table.VectorIndex'>,
      spike_times <class 'hdmf.common.table.VectorData'>
    )
    description: Autogenerated by NWBFile
    id: id <class 'hdmf.common.table.ElementIdentifiers'>
    waveform_unit: volts}]