## This notebook assumes that you've imported one or more NWB files into DataJoint 
## It allows you to run spikesorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [5]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import numpy as np
import pynwb
import os

data_dir = Path('/Users/loren/data/nwb_builder_test_data') # CHANGE ME TO THE BASE DIRECTORY FOR DATA STORAGE ON YOUR SYSTEM

os.environ['NWB_DATAJOINT_BASE_DIR'] = str(data_dir)
os.environ['KACHERY_STORAGE_DIR'] = str(data_dir / 'kachery-storage')
os.environ['SPIKE_SORTING_STORAGE_DIR'] = str(data_dir / 'spikesorting')

# DataJoint and DataJoint schema

import datajoint as dj

dj.config['database.host'] = 'localhost'
dj.config['database.user'] = 'root'
dj.config['database.password'] = 'tutorial'

import nwb_datajoint as nd
from ndx_franklab_novela import Probe

import spiketoolkit as st

import warnings
warnings.simplefilter('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Set the nwb file name and the name of the probe file to create from DataJoint

In [6]:
nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')

### Set the sort grouping by shank

In [7]:
nd.common.SortGroup().set_group_by_shank(nwb_file_name)
nd.common.SortGroup()

About to delete:
`common_spikesorting`.`sort_group__sort_group_electrode`: 256 items
`common_spikesorting`.`__spike_sorting`: 1 items
`common_spikesorting`.`spike_sorting_parameters`: 1 items
`common_spikesorting`.`sort_group`: 8 items


Proceed? [yes, No]:  no


Cancelled deletes.


DuplicateError: ("Duplicate entry 'beans20190718-trim.nwb-0' for key 'PRIMARY'", 'To ignore duplicate entries in insert, set skip_duplicates=True')

Optional: Display all of the electrodes with their sort groups

In [22]:
nd.common.SortGroup.SortGroupElectrode()

nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,electrode_group_name  electrode group name from NWBFile,electrode_id  the unique number for this electrode
beans20190718-trim.nwb,0,0,0
beans20190718-trim.nwb,0,0,1
beans20190718-trim.nwb,0,0,2
beans20190718-trim.nwb,0,0,3
beans20190718-trim.nwb,0,0,4
beans20190718-trim.nwb,0,0,5
beans20190718-trim.nwb,0,0,6
beans20190718-trim.nwb,0,0,7
beans20190718-trim.nwb,0,0,8
beans20190718-trim.nwb,0,0,9


### create the spike sorter and parameter lists 

In [4]:
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

### create a 'franklab_mountainsort' parameter set

In [5]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates='True')

Display the new parameter set

In [6]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}).fetch1()
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [7]:
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 1

In [9]:
# create a 100 second test intervals for debugging
s1 = (nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch1('valid_times')
a = np.asarray([s1[0][0], s1[0][0]+10])
#t = np.vstack((t, np.asarray([[a+120,b+120]])))
nd.common.SortInterval().insert1({'nwb_file_name' : nwb_file_name, 'sort_interval_name' : 'test', 'sort_interval' : a}, replace='True')

print(a)

[1.56348899e+09 1.56348900e+09]


In [12]:
# create the sorting waveform parameters table
n_noise_waveforms = 1000 # the number of random noise waveforms to save
waveform_param_dict = st.postprocessing.get_waveforms_params()
waveform_param_dict['grouping_property'] = 'group'
# set the window to half of the clip size before and half after
waveform_param_dict['ms_before'] = .75
waveform_param_dict['ms_after'] = .75
waveform_param_dict['dtype'] = 'i2'
waveform_param_dict['verbose'] = False
waveform_param_dict['max_spikes_per_unit'] = 1000
nd.common.SpikeSortingWaveformParameters.insert1({'waveform_parameters_name' : 'franklab default', 'n_noise_waveforms' : n_noise_waveforms, 
                                                   'waveform_parameter_dict' : waveform_param_dict}, replace='True')

In [35]:
#nd.common.SpikeSortingParameters().delete()

  and should_run_async(code)


About to delete:
`common_spikesorting`.`spike_sorting_parameters`: 1 items


Proceed? [yes, No]:  yes


Committed.


In [14]:
sort_group_id = 1
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['sorter_name'] = 'mountainsort4'
key['parameter_set_name'] = 'franklab_mountainsort_20KHz'
key['waveform_parameters_name'] = 'franklab default'
key['interval_list_name'] = '01_s1'
key['sort_interval_name'] = 'test'
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates='True')

In [15]:
(nd.common.SortInterval()).fetch()

array([('beans20190718-trim.nwb', 'test', array([1.56348899e+09, 1.56348900e+09]))],
      dtype=[('nwb_file_name', 'O'), ('sort_interval_name', 'O'), ('sort_interval', 'O')])

### run the sort - this can take some time

In [3]:
nd.common.SpikeSorting().populate()

in spike sorting
writing new NWB file beans20190718-trim_00000017.nwb
Sorting {'nwb_file_name': 'beans20190718-trim.nwb', 'sort_group_id': 1, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_name': 'test', 'analysis_file_name': 'beans20190718-trim_00000017.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmp4fh8yynb
Num. workers = 7
Preparing /tmp/tmp4fh8yynb/timeseries.hdf5...
'end_frame' set to 199999
Preparing neighborhood sorters (M=32, N=199999)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmp4fh8yynb
mountainsort4 run time 18.91s
Subsampling sorting
Finding unit peak channels
Processing chunk 1 of 1; chunk-range: 0 193638; num-frames: 193638
Retrieving traces for chunk
Collecting waveforms for chunk
Finding unit neighborhoods
Getting unit waveforms for 37 units
Processing chunk 1 of 1; chunk-range: 0 199999; num-frames: 199999
Retrieving traces for chunk
Collecting waveforms for chunk
Subsampling sorting
Finding


### Example: Retrieve the spike trains:
Note that these spikes are all noise; this is for demonstration purposes only.

In [4]:
nd.common.SpikeSorting()

nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,sorter_name  the name of the spike sorting algorithm,parameter_set_name  label for this set of parameters,sort_interval_name  descriptive name for this interval,analysis_file_name  the name of the file,units_object_id  the object ID for the units for this sort group,units_waveforms_object_id  the object ID for the unit waveforms,noise_waveforms_object_id  the object ID for the noise waveforms
beans20190718-trim.nwb,1,mountainsort4,franklab_mountainsort_20KHz,test,beans20190718-trim_00000017.nwb,886bbfd0-047b-4b97-9a28-5cdb2357bccb,,


In [84]:
sort_interval = (nd.common.SortIntervalList() & {'sort_interval_list_name' : 'test'}).fetch1('sort_intervals')[0]

key = nd.common.SpikeSorting().fetch('KEY')[0]

recording = nd.common.SpikeSorting().get_recording_extractor(key, sort_interval)[0]
sorting = nd.common.SpikeSorting().get_sorting_extractor(key, sort_interval)


In [71]:
# get the timestamps and select 1000 random times
raw_obj = (nd.common.Raw() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['raw']
ts = np.asarray(raw_obj.timestamps)
ts_sort_ind = np.where(np.logical_and((ts > sort_interval[0]), (ts < sort_interval[1])))[0]


array([     19,      20,      21, ..., 2000004, 2000005, 2000006])

In [81]:
n_random_snippets = 1000
# generate a list of random times and get snippets for those times
rng = np.random.default_rng()
noise_indeces = np.sort(rng.choice(ts_sort_ind, n_random_snippets))
noise_sorting=se.NumpySortingExtractor()
noise_sorting.set_times_labels(times=noise_indeces,labels=np.zeros(noise_indeces.shape))

[   2372    3238    5280   10319   12819   17293   17485   17677   18460
   18480   21556   22520   23313   26127   29066   30371   30604   32355
   40270   42190   47176   48161   55720   59438   60787   61329   62426
   62873   64101   65649   67058   72391   74226   81095   82199   84472
   86057   87812   88760   91252   94648   94713   96691   97239  100988
  101801  104249  104621  106990  111278  121511  125590  126504  129303
  130006  131704  131717  132235  135002  136458  140085  142260  145351
  145657  146885  149081  150091  153295  153696  162761  167047  169256
  170378  176024  177927  178131  178222  180584  185076  187366  187404
  191455  192311  192656  197217  199148  205133  208741  209980  211279
  216413  216796  218254  221083  222767  222940  223482  224560  226358
  226485  226925  226966  228314  229097  230243  232736  235865  236303
  238904  239363  242016  242713  243171  245037  249725  251657  252296
  257655  262073  263269  264907  265418  265468  2

In [82]:
noise_sorting.get_unit_spike_train(0)

array([   2372,    3238,    5280,   10319,   12819,   17293,   17485,
         17677,   18460,   18480,   21556,   22520,   23313,   26127,
         29066,   30371,   30604,   32355,   40270,   42190,   47176,
         48161,   55720,   59438,   60787,   61329,   62426,   62873,
         64101,   65649,   67058,   72391,   74226,   81095,   82199,
         84472,   86057,   87812,   88760,   91252,   94648,   94713,
         96691,   97239,  100988,  101801,  104249,  104621,  106990,
        111278,  121511,  125590,  126504,  129303,  130006,  131704,
        131717,  132235,  135002,  136458,  140085,  142260,  145351,
        145657,  146885,  149081,  150091,  153295,  153696,  162761,
        167047,  169256,  170378,  176024,  177927,  178131,  178222,
        180584,  185076,  187366,  187404,  191455,  192311,  192656,
        197217,  199148,  205133,  208741,  209980,  211279,  216413,
        216796,  218254,  221083,  222767,  222940,  223482,  224560,
        226358,  226

In [83]:
import spikeextractors as se
import labbox_ephys as le
import numpy as np
import h5py

#add a second snippets file with 1000 random spike times across all channels

# Specify the output path
output_h5_path = 'real_snippets.h5'

# Prepare the snippets h5 file
le.prepare_snippets_h5_from_extractors(
    recording=recording,
    sorting=sorting,
    output_h5_path=output_h5_path,
    start_frame=None,
    end_frame=None,
    snippet_len = (10,20),
    max_events_per_unit=None,
    max_neighborhood_size=6
)

output_h5_path = 'noise_snippets.h5'
le.prepare_snippets_h5_from_extractors(
    recording=recording,
    sorting=noise_sorting,
    output_h5_path=output_h5_path,
    start_frame=None,
    end_frame=None,
    snippet_len = (10,20),
    max_events_per_unit=None,
    max_neighborhood_size=10000
)



# Example display some contents of the file
with h5py.File(output_h5_path, 'r') as f:
    unit_ids = np.array(f.get('unit_ids'))
    sampling_frequency = np.array(f.get('sampling_frequency'))[0]
    print(f'Unit IDs: {unit_ids}')
    print(f'Sampling freq: {sampling_frequency}')
    for unit_id in unit_ids:
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(f'Unit {unit_id} | Tot num events: {len(unit_spike_train)} | shape of subsampled snippets: {unit_waveforms.shape}')

Subsampling sorting
Finding unit peak channels
Processing chunk 1 of 1; chunk-range: 0 704643; num-frames: 704643
Retrieving traces for chunk
Collecting waveforms for chunk
Finding unit neighborhoods
Getting unit waveforms for 54 units
Processing chunk 1 of 1; chunk-range: 0 1999988; num-frames: 1999988
Retrieving traces for chunk
Collecting waveforms for chunk
Subsampling sorting
Finding unit peak channels
Processing chunk 1 of 1; chunk-range: 0 42191; num-frames: 42191
Retrieving traces for chunk
Collecting waveforms for chunk
Finding unit neighborhoods
Getting unit waveforms for 1 units
Processing chunk 1 of 1; chunk-range: 0 1999988; num-frames: 1999988
Retrieving traces for chunk
Collecting waveforms for chunk
Unit IDs: [0]
Sampling freq: nan
Unit 0 | Tot num events: 1000 | shape of subsampled snippets: (1000, 32, 30)


In [45]:
cd .

/Users/loren/Src/NWB/nwb_datajoint/notebooks


  and should_run_async(code)


In [8]:
nd.common.SpikeSorting().delete()

About to delete:
`common_spikesorting`.`__spike_sorting`: 1 items


Proceed? [yes, No]:  yes


Committed.


In [40]:
sorting = (nd.common.SpikeSorting & {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}).fetch()
key = {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}
units = (nd.common.SpikeSorting & key).fetch_nwb()[0]['units'].to_dataframe()
units

  and should_run_async(code)


Unnamed: 0_level_0,spike_times,obs_intervals,waveform_mean,sort_interval
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,"[1563488988.8793328, 1563488988.880833, 156348...","[[1563488988.8777807, 1563489088.8777807]]","[[-76.0, -72.5, -75.0, -92.0, -87.5, -83.5, -9...","[[1563488988.8777807, 1563489088.8777807]]"
2,"[1563488988.943533, 1563488988.969233, 1563488...","[[1563488988.8777807, 1563489088.8777807]]","[[22.0, 24.0, 17.0, 13.5, 10.5, -9.0, -5.0, -7...","[[1563488988.8777807, 1563489088.8777807]]"
3,"[1563488988.8768828, 1563488988.8803327, 15634...","[[1563488988.8777807, 1563489088.8777807]]","[[22.0, 15.5, 19.5, 12.5, 20.0, 17.0, 13.5, 25...","[[1563488988.8777807, 1563489088.8777807]]"
4,"[1563488988.888583, 1563488988.8907828, 156348...","[[1563488988.8777807, 1563489088.8777807]]","[[48.0, 60.5, 60.5, 55.0, 53.0, 63.0, 53.5, 37...","[[1563488988.8777807, 1563489088.8777807]]"
5,"[1563488988.879233, 1563488989.9093895, 156348...","[[1563488988.8777807, 1563489088.8777807]]","[[-225.5, -207.0, -219.0, -206.5, -225.0, -206...","[[1563488988.8777807, 1563489088.8777807]]"
6,"[1563488988.906633, 1563488988.9317832, 156348...","[[1563488988.8777807, 1563489088.8777807]]","[[-18.0, -12.0, -23.5, -27.0, -23.5, -41.0, -4...","[[1563488988.8777807, 1563489088.8777807]]"
7,"[1563488988.9779334, 1563488989.0894341, 15634...","[[1563488988.8777807, 1563489088.8777807]]","[[203.0, 199.0, 204.0, 207.0, 210.5, 222.0, 22...","[[1563488988.8777807, 1563489088.8777807]]"
8,"[1563488989.6780376, 1563488991.5644002, 15634...","[[1563488988.8777807, 1563489088.8777807]]","[[-184.0, -181.0, -187.0, -187.0, -204.0, -227...","[[1563488988.8777807, 1563489088.8777807]]"
9,"[1563488988.9461832, 1563488988.9508832, 15634...","[[1563488988.8777807, 1563489088.8777807]]","[[53.0, 51.0, 42.5, 54.5, 37.5, 34.5, 14.5, 17...","[[1563488988.8777807, 1563489088.8777807]]"
10,"[1563488988.8768828, 1563488988.9551833, 15634...","[[1563488988.8777807, 1563489088.8777807]]","[[5.5, 11.0, 14.0, 9.0, 9.0, 17.0, 20.0, 21.0,...","[[1563488988.8777807, 1563489088.8777807]]"


In [10]:
dj.conn

  and should_run_async(code)
[autoreload of nwb_datajoint.common.common_spikesorting failed: Traceback (most recent call last):
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/Users/loren/opt/anaconda3/envs/nwb_datajoint/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a

<function datajoint.connection.conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None)>

In [11]:
dj.config

  and should_run_async(code)


{   'connection.charset': '',
    'connection.init_function': None,
    'database.host': 'localhost',
    'database.password': 'tutorial',
    'database.port': 3306,
    'database.reconnect': True,
    'database.use_tls': None,
    'database.user': 'root',
    'databasse.password': 'tutorial',
    'display.limit': 12,
    'display.show_tuple_count': True,
    'display.width': 14,
    'enable_python_native_blobs': True,
    'fetch_format': 'array',
    'loglevel': 'INFO',
    'password': 'tutorial',
    'safemode': True,
    'user': 'root'}