## This notebook assumes that you've imported one or more NWB files into DataJoint 
## It allows you to run spikesorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import numpy as np
import pynwb
import os

data_dir = Path('/Users/loren/data/nwb_builder_test_data') # CHANGE ME TO THE BASE DIRECTORY FOR DATA STORAGE ON YOUR SYSTEM

os.environ['NWB_DATAJOINT_BASE_DIR'] = str(data_dir)
os.environ['KACHERY_STORAGE_DIR'] = str(data_dir / 'kachery-storage')
os.environ['SPIKE_SORTING_STORAGE_DIR'] = str(data_dir / 'spikesorting')

# DataJoint and DataJoint schema

import datajoint as dj

dj.config['database.host'] = 'localhost'
dj.config['database.user'] = 'root'
dj.config['database.password'] = 'tutorial'

import nwb_datajoint as nd
from ndx_franklab_novela import Probe

import spiketoolkit as st

import warnings
warnings.simplefilter('ignore')

Connecting root@localhost:3306


### Set the nwb file name and the name of the probe file to create from DataJoint

In [2]:
nd.common.Session()

nwb_file_name  the name of the NWB file,subject_id,institution_name,lab_name,session_id,session_description,session_start_time,timestamps_reference_time,experiment_description
beans20190718_.nwb,Beans,"University of California, San Francisco",Loren Frank,beans_01,Reinforcement leaarning,2019-07-18 15:29:47,1970-01-01 00:00:00,Reinforcement learning


In [3]:
#nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')
nwb_file_name = 'beans20190718_.nwb'

### Set the sort grouping by shank

In [6]:
nd.common.SortGroup().set_group_by_shank(nwb_file_name)
nd.common.SortGroup()

About to delete:
`common_spikesorting`.`sort_group__sort_group_electrode`: 192 items
`common_spikesorting`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718_.nwb,0,-1
beans20190718_.nwb,1,-1
beans20190718_.nwb,2,-1
beans20190718_.nwb,3,-1
beans20190718_.nwb,4,-1
beans20190718_.nwb,5,-1
beans20190718_.nwb,6,-1
beans20190718_.nwb,7,-1


Optional: Display all of the electrodes with their sort groups

In [None]:
nd.common.SortGroup.SortGroupElectrode()

### create the spike sorter and parameter lists 

In [None]:
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

In [None]:
nd.common.SpikeSorterParameters()

### create a 'franklab_mountainsort' parameter set

In [None]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates='True')

Display the new parameter set

In [8]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}).fetch1()
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [22]:
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 1

In [9]:
nd.common.IntervalList()

nwb_file_name  the name of the NWB file,interval_list_name  descriptive name of this interval list,valid_times  numpy array with start and end times for each interval
beans20190718_.nwb,01_s1,=BLOB=
beans20190718_.nwb,02_r1,=BLOB=
beans20190718_.nwb,03_s2,=BLOB=
beans20190718_.nwb,04_r2,=BLOB=
beans20190718_.nwb,pos 0 valid times,=BLOB=
beans20190718_.nwb,pos 1 valid times,=BLOB=
beans20190718_.nwb,pos 2 valid times,=BLOB=
beans20190718_.nwb,pos 3 valid times,=BLOB=
beans20190718_.nwb,raw data valid times,=BLOB=


In [7]:
# create a 10 second test intervals for debugging
s1 = (nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch1('valid_times')
print(s1)


a = np.asarray([s1[0][0], s1[0][0]+240])
#t = np.vstack((t, np.asarray([[a+120,b+120]])))
nd.common.SortInterval().insert1({'nwb_file_name' : nwb_file_name, 'sort_interval_name' : 'test', 'sort_interval' : a}, replace='True')

print(a)

[[1.56348899e+09 1.56349008e+09]]
[1.56348899e+09 1.56348923e+09]


In [8]:
# create the sorting waveform parameters table
n_noise_waveforms = 1000 # the number of random noise waveforms to save
waveform_param_dict = st.postprocessing.get_waveforms_params()
waveform_param_dict['grouping_property'] = 'group'
# set the window to half of the clip size before and half after
waveform_param_dict['ms_before'] = .75
waveform_param_dict['ms_after'] = .75
waveform_param_dict['dtype'] = 'i2'
waveform_param_dict['verbose'] = False
waveform_param_dict['max_spikes_per_unit'] = 1000
nd.common.SpikeSortingWaveformParameters.insert1({'waveform_parameters_name' : 'franklab default', 'n_noise_waveforms' : n_noise_waveforms, 
                                                   'waveform_parameter_dict' : waveform_param_dict}, replace='True')

### create a list of metrics to be computed 

In [9]:
metric_dict = nd.common.SpikeSortingMetrics().get_metric_dict()
metric_dict

{'num_spikes': False,
 'firing_rate': False,
 'presence_ratio': False,
 'isi_violation': False,
 'amplitude_cutoff': False,
 'snr': False,
 'max_drift': False,
 'cumulative_drift': False,
 'silhouette_score': False,
 'isolation_distance': False,
 'l_ratio': False,
 'd_prime': False,
 'nn_hit_rate': False,
 'nn_miss_rate': False}

Select a set of metrics to compute

In [10]:
metric_dict['num_spikes'] = True
metric_dict['firing_rate'] = True
metric_dict['isi_violation'] = True
metric_dict['nn_hit_rate'] = True
#metric_dict['noise overlap'] = True

Set the parameters for computing the metrics. \
All of the parameters in the schema have default values, so we only need to specify the ones that we want to change. \
See spiketoolkit.validation documentation for details.

In [11]:
n_cluster_waveforms=1000

In [13]:
nd.common.SpikeSortingMetrics().insert1({'cluster_metrics_list_name' : 'll_fl_probe_metrics', 
                             'n_cluster_waveforms' : n_cluster_waveforms,
                             'metrics_dict' : metric_dict}, skip_duplicates='True')

In [11]:
nd.common.SpikeSortingParameters().drop()

`common_spikesorting`.`spike_sorting_parameters` (1 tuples)
`common_spikesorting`.`__spike_sorting` (0 tuples)
`common_spikesorting`.`__curated_spike_sorting` (0 tuples)
`common_spikesorting`.`__curated_spike_sorting__units` (0 tuples)


Proceed? [yes, No]:  yes


Tables dropped.  Restart kernel.


In [26]:
#nd.common.SpikeSortingParameters().delete()

In [16]:
sort_group_id = 1
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['sorter_name'] = 'mountainsort4'
key['parameter_set_name'] = 'franklab_mountainsort_20KHz'
key['waveform_parameters_name'] = 'franklab default'
key['interval_list_name'] = '01_s1'
key['sort_interval_name'] = 'test'
key['cluster_metrics_list_name'] = 'll_fl_probe_metrics'
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates='True')

### run the sort - this can take some time

In [17]:
nd.common.SpikeSorting().populate()

in spike sorting
writing new NWB file beans20190718_000000.nwb
Sorting {'nwb_file_name': 'beans20190718_.nwb', 'sort_group_id': 1, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_name': 'test', 'analysis_file_name': 'beans20190718_000000.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpt7gilkpr
Num. workers = 7
Preparing /tmp/tmpt7gilkpr/timeseries.hdf5...
'end_frame' set to 4799970
Preparing neighborhood sorters (M=24, N=4799970)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpt7gilkpr
mountainsort4 run time 50.28s
Subsampling sorting
Finding unit peak channels
Processing chunk 1 of 1; chunk-range: 0 2543997; num-frames: 2543997
Retrieving traces for chunk
Collecting waveforms for chunk
Finding unit neighborhoods
Getting unit waveforms for 33 units
Processing chunk 1 of 1; chunk-range: 0 4799970; num-frames: 4799970
Retrieving traces for chunk
Collecting waveforms for chunk
Subsampling sorting
Finding unit peak c


### Example: Retrieve the spike trains:
Note that these spikes are all noise; this is for demonstration purposes only.

In [15]:
nd.common.SpikeSorting.delete()

About to delete:
Nothing to delete


In [9]:
nd.common.AnalysisNwbfile().delete()
nd.common.AnalysisNwbfile().cleanup()

About to delete:
`common_lab`.`analysis_nwbfile`: 1 items


Proceed? [yes, No]:  yes


Committed.


In [4]:
nd.common.AnalysisNwbfile().drop()

[autoreload of nwb_datajoint.common.common_spikesorting failed: Traceback (most recent call last):
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt

`common_lab`.`analysis_nwbfile` (0 tuples)
`common_lab`.`__analysis_nwbfile_kachery` (0 tuples)
`common_spikesorting`.`__spike_sorting` (0 tuples)
`common_spikesorting`.`__curated_spike_sorting` (0 tuples)
`common_spikesorting`.`__curated_spike_sorting__units` (0 tuples)
`common_ephys`.`_l_f_p` (0 tuples)
`common_ephys`.`l_f_p_band_selection` (0 tuples)
`common_ephys`.`l_f_p_band_selection__l_f_p_band_electrode` (0 tuples)
`common_ephys`.`__l_f_p_band` (0 tuples)


Proceed? [yes, No]:  yes


Tables dropped.  Restart kernel.


In [12]:
nd.common.SpikeSorting().delete()
nd.common.AnalysisNwbfile().cleanup(delete_files=True)

[autoreload of nwb_datajoint.common.common_spikesorting failed: Traceback (most recent call last):
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt

About to delete:
`common_spikesorting`.`__spike_sorting`: 1 items


Proceed? [yes, No]:  yes


0it [00:00, ?it/s]

Committed.





In [84]:
sort_interval = (nd.common.SortIntervalList() & {'sort_interval_list_name' : 'test'}).fetch1('sort_intervals')[0]

key = nd.common.SpikeSorting().fetch('KEY')[0]

recording = nd.common.SpikeSorting().get_recording_extractor(key, sort_interval)[0]
sorting = nd.common.SpikeSorting().get_sorting_extractor(key, sort_interval)


In [71]:
# get the timestamps and select 1000 random times
raw_obj = (nd.common.Raw() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['raw']
ts = np.asarray(raw_obj.timestamps)
ts_sort_ind = np.where(np.logical_and((ts > sort_interval[0]), (ts < sort_interval[1])))[0]


array([     19,      20,      21, ..., 2000004, 2000005, 2000006])

In [7]:
sorting = (nd.common.SpikeSorting & {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}).fetch()
key = {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}
units = (nd.common.SpikeSorting & key).fetch_nwb()[0]['units'].to_dataframe()
units


Unnamed: 0_level_0,spike_times,obs_intervals,waveform_mean,sort_interval,num_spikes,firing_rate,isi_violation,nn_hit_rate
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,"[1563488988.9327333, 1563488988.9815834, 15634...","[[1563488988.8777807, 1563488998.8777807]]","[[-14.0, -15.0, -13.5, -12.0, -9.5, -6.0, -13....","[[1563488988.8777807, 1563488998.8777807]]",401,40.045817,0.429112,0.388889
2,"[1563488988.920283, 1563488988.9500833, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[-10.0, -2.0, -3.0, -1.0, -5.0, 2.0, 3.0, -2....","[[1563488988.8777807, 1563488998.8777807]]",211,21.07149,0.378016,0.256554
3,"[1563488988.9864836, 1563488989.1257844, 15634...","[[1563488988.8777807, 1563488998.8777807]]","[[-12.0, -12.0, -6.0, -14.0, -18.0, -14.0, -9....","[[1563488988.8777807, 1563488998.8777807]]",195,19.473652,0.575373,0.323741
4,"[1563488988.905783, 1563488988.911983, 1563488...","[[1563488988.8777807, 1563488998.8777807]]","[[-6.0, -3.0, 3.0, 0.0, -9.0, -10.0, -12.0, -1...","[[1563488988.8777807, 1563488998.8777807]]",207,20.67203,0.196383,0.252252
5,"[1563488988.8783827, 1563488988.901983, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[3.5, 0.0, 6.0, 3.5, 3.0, -3.5, -5.0, 5.0, -5...","[[1563488988.8777807, 1563488998.8777807]]",270,26.963518,0.230859,0.329949
6,"[1563488988.8863328, 1563488988.9429333, 15634...","[[1563488988.8777807, 1563488998.8777807]]","[[-10.5, -6.0, -17.5, -18.5, -2.0, -11.0, -10....","[[1563488988.8777807, 1563488998.8777807]]",214,21.371084,0.367492,0.257246
7,"[1563488988.8768828, 1563488988.8978329, 15634...","[[1563488988.8777807, 1563488998.8777807]]","[[0.0, 2.0, 7.5, 3.5, -1.5, -4.5, -5.0, -9.0, ...","[[1563488988.8777807, 1563488998.8777807]]",274,27.362977,0.403503,0.181572
8,"[1563488988.951233, 1563488988.9756334, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[0.0, 0.0, -4.0, 4.0, 3.5, -6.0, -1.0, -1.5, ...","[[1563488988.8777807, 1563488998.8777807]]",272,27.163247,0.318468,0.436275
9,"[1563488988.8793328, 1563488988.897183, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[0.0, -1.5, -9.5, -5.5, -11.5, -0.5, 6.5, 3.5...","[[1563488988.8777807, 1563488998.8777807]]",414,41.34406,0.314213,0.196429
10,"[1563488988.8909328, 1563488988.922383, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[-14.0, -9.5, -17.5, -21.0, -19.0, -10.0, -29...","[[1563488988.8777807, 1563488998.8777807]]",244,24.367031,1.385133,0.217125


In [15]:
nd.common.Nwbfile().delete()

About to delete:
`common_lab`.`__nwbfile_kachery`: 1 items
`common_spikesorting`.`__spike_sorting`: 1 items
`file_lock`.`analysis_nwbfile_lock`: 1 items
`common_lab`.`analysis_nwbfile`: 1 items
`common_ephys`.`l_f_p_selection__l_f_p_electrode`: 16 items
`common_spikesorting`.`sort_group__sort_group_electrode`: 256 items
`common_ephys`.`_electrode`: 256 items
`common_ephys`.`_electrode_group`: 2 items
`common_ephys`.`l_f_p_selection`: 1 items
`common_behav`.`_raw_position`: 1 items
`common_dio`.`_d_i_o_events`: 19 items
`common_ephys`.`_raw`: 1 items
`common_sensors`.`_sensor_data`: 1 items
`common_spikesorting`.`spike_sorting_parameters`: 1 items
`common_behav`.`_state_script_file`: 1 items
`common_behav`.`_video_file`: 4 items
`common_task`.`_task_epoch`: 4 items
`common_interval`.`interval_list`: 9 items
`common_interval`.`sort_interval`: 1 items
`common_session`.`_experimenter_list__experimenter`: 1 items
`common_session`.`_experimenter_list`: 1 items
`common_spikesorting`.`sort_gro

Proceed? [yes, No]:  yes


Committed.


In [18]:
st.validation.get_quality_metrics_list()

[autoreload of nwb_datajoint.common.common_spikesorting failed: Traceback (most recent call last):
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt

['num_spikes',
 'firing_rate',
 'presence_ratio',
 'isi_violation',
 'amplitude_cutoff',
 'snr',
 'max_drift',
 'cumulative_drift',
 'silhouette_score',
 'isolation_distance',
 'l_ratio',
 'd_prime',
 'nn_hit_rate',
 'nn_miss_rate']