## This notebook assumes that you have already imported one or more NWB files into DataJoint.
## It lets you run spike sorters on those data using the SpikeInterface package.

#### Load all of the relevant modules

In [2]:
%load_ext autoreload
%autoreload 2

import os

# Root directory that holds the NWB files and everything derived from them.
data_dir = '/Users/loren/data/nwb_builder_test_data'  # CHANGE ME

# Storage locations are handed to the pipeline through environment variables.
# They are set before nwb_datajoint is imported further down -- presumably the
# package reads them at import time (TODO confirm), so keep this ordering.
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir
os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SPIKE_SORTING_STORAGE_DIR'] = os.path.join(data_dir, 'spikesorting')

import numpy as np
import pynwb

# DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj
from ndx_franklab_novela import Probe

import warnings
warnings.simplefilter('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Connecting root@localhost:3306


### Set the NWB file name by looking up the session in DataJoint

In [3]:
# Look up the NWB file name stored in the Session table for session 'beans_01'.
# fetch1 expects the restriction to match exactly one row.
session_restriction = {'session_id': 'beans_01'}
nwb_file_name = (nd.common.Session() & session_restriction).fetch1('nwb_file_name')

### Set the sort grouping by shank

In [4]:
# Rebuild the sort groups for this file, one group per shank.
# NOTE: in the recorded run this call asked for confirmation before deleting
# the existing sort-group rows (and the rows that depend on them).
sort_group_table = nd.common.SortGroup()
sort_group_table.set_group_by_shank(nwb_file_name)

# A bare table expression as the last line renders the table contents.
nd.common.SortGroup()

About to delete:
`common_spikesorting`.`sort_group__sort_group_electrode`: 256 items
`common_spikesorting`.`spike_sorting_parameters`: 1 items
`common_spikesorting`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718-trim.nwb,0,-1
beans20190718-trim.nwb,1,-1
beans20190718-trim.nwb,2,-1
beans20190718-trim.nwb,3,-1
beans20190718-trim.nwb,4,-1
beans20190718-trim.nwb,5,-1
beans20190718-trim.nwb,6,-1
beans20190718-trim.nwb,7,-1


Optional: Display all of the electrodes with their sort groups

In [5]:
# Table nested under SortGroup that maps each electrode to its sort group;
# leaving it as the last expression renders a preview of its rows.
sort_group_electrodes = nd.common.SortGroup.SortGroupElectrode()
sort_group_electrodes

nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,electrode_group_name  electrode group name from NWBFile,electrode_id  the unique number for this electrode
beans20190718-trim.nwb,0,0,0
beans20190718-trim.nwb,0,0,1
beans20190718-trim.nwb,0,0,2
beans20190718-trim.nwb,0,0,3
beans20190718-trim.nwb,0,0,4
beans20190718-trim.nwb,0,0,5
beans20190718-trim.nwb,0,0,6
beans20190718-trim.nwb,0,0,7
beans20190718-trim.nwb,0,0,8
beans20190718-trim.nwb,0,0,9


### Create the spike sorter and parameter lists

In [6]:
# Fill the sorter table, then the sorter-parameter table, from SpikeInterface.
# (Presumably this registers each available sorter together with its default
# parameter set -- the next cell fetches a 'default' set for mountainsort4.)
sorter_table = nd.common.SpikeSorter()
sorter_table.insert_from_spikeinterface()

sorter_parameter_table = nd.common.SpikeSorterParameters()
sorter_parameter_table.insert_from_spikeinterface()

### Create a 'franklab_mountainsort_20KHz' parameter set

In [7]:
# Start from the 'default' mountainsort4 parameter set stored in DataJoint.
p = (nd.common.SpikeSorterParameters()
     & {'sorter_name': 'mountainsort4', 'parameter_set_name': 'default'}).fetch1()
param = p['parameter_dict']

# Override the defaults with the values used for this parameter set.
param.update({
    'adjacency_radius': 100,
    'curation': False,
    'num_workers': 7,
    'verbose': True,
    'clip_size': 30,
    'noise_overlap_threshold': 0,
})

# Store the modified dictionary under a new parameter-set name.
# skip_duplicates is a boolean flag: pass True rather than the string 'True'
# (any non-empty string, even 'False', would be treated as truthy).
nd.common.SpikeSorterParameters().insert1(
    {
        'sorter_name': 'mountainsort4',
        'parameter_set_name': 'franklab_mountainsort_20KHz',
        'parameter_dict': param,
    },
    skip_duplicates=True,
)

Display the new parameter set

In [8]:
# Read the newly stored parameter set back from the table to verify the insert.
franklab_restriction = {
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
}
p = (nd.common.SpikeSorterParameters() & franklab_restriction).fetch1()

# Bare name as the last expression displays the fetched row.
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [9]:
# Pull just the parameter dictionary out of the row fetched into `p`;
# the bare name on the last line displays it.
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 1

In [10]:
# Create two 10 second test intervals for debugging.
# Restrict by NWB file as well as interval name so that fetch1 still matches a
# single row when more than one imported file has an '01_s1' interval list.
s1 = (nd.common.IntervalList()
      & {'nwb_file_name': nwb_file_name, 'interval_list_name': '01_s1'}).fetch1('valid_times')

# First interval starts at the beginning of the first valid time and lasts 10 s;
# the second is the same window shifted 120 s later.
a = s1[0][0]
b = a + 10
t = np.asarray([[a, b], [a + 120, b + 120]])

# replace is a boolean flag: pass True rather than the string 'True'.
nd.common.SortIntervalList().insert1(
    {
        'nwb_file_name': nwb_file_name,
        'sort_interval_list_name': 'test',
        'sort_intervals': t,
    },
    replace=True,
)

print(t)

[[1.56348899e+09 1.56348900e+09]
 [1.56348911e+09 1.56348912e+09]]


In [11]:
# Describe one sort: which file and sort group, which sorter and parameter
# set, and which interval lists to use.
sort_group_id = 1
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
    'interval_list_name': '01_s1',
    'sort_interval_list_name': 'test',
}

# skip_duplicates is a boolean flag: pass True rather than the string 'True'.
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates=True)

In [12]:
# Show the 'test' sort interval list for this file. fetch1 requires exactly one
# matching row, so restrict explicitly instead of relying on the table holding
# a single entry.
(nd.common.SortIntervalList()
 & {'nwb_file_name': nwb_file_name, 'sort_interval_list_name': 'test'}).fetch1()

{'nwb_file_name': 'beans20190718-trim.nwb',
 'sort_interval_list_name': 'test',
 'sort_intervals': array([[1.56348899e+09, 1.56348900e+09],
        [1.56348911e+09, 1.56348912e+09]])}

### Run the sort - this can take some time

In [13]:
# Run the sorter through DataJoint's populate mechanism. No restriction is
# given, so this presumably processes every pending SpikeSortingParameters
# entry (confirm); the recorded run took about 26 s for the first sort.
spike_sorting_table = nd.common.SpikeSorting()
spike_sorting_table.populate()

writing new NWB file beans20190718-trim_00000003.nwb
Sorting {'nwb_file_name': 'beans20190718-trim.nwb', 'sort_group_id': 1, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718-trim_00000003.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmps_ny5f7w
Num. workers = 7
Preparing /tmp/tmps_ny5f7w/timeseries.hdf5...
'end_frame' set to 199999
Preparing neighborhood sorters (M=32, N=199999)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmps_ny5f7w
mountainsort4 run time 26.39s
{'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'parameter_dict': {'detect_sign': -1, 'adjacency_radius': 100, 'freq_min': 300, 'freq_max': 6000, 'filter': True, 'whiten': True, 'curation': False, 'num_workers': 7, 'clip_size': 30, 'detect_threshold': 3, 'detect_interval': 10, 'noise_overlap_threshold': 0, 'verbose': True}}
Sorting {'nwb_file_name': 'beans20190


### Example: Retrieve the spike trains:
Note that these spikes are all noise; this is for demonstration purposes only.

In [54]:
# Select the sorting results for this file and sort group.
key = {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}
sorting_for_group = nd.common.SpikeSorting & key

# Table rows describing the sort (file names, object ids, ...).
sorting = sorting_for_group.fetch()

# fetch_nwb() returns a list; take the first entry's 'units' table and convert
# it to a DataFrame for display.
units = sorting_for_group.fetch_nwb()[0]['units'].to_dataframe()
units

[{'nwb_file_name': 'beans20190718-trim.nwb', 'sort_group_id': 1, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718-trim_00000003.nwb', 'units_object_id': '3b112528-e0ec-4739-bd97-8b84baeed433', 'units_waveforms_object_id': '', 'nwb2load_filepath': '/Users/loren/data/nwb_builder_test_data/analysis/beans20190718-trim_00000003.nwb'}]


Unnamed: 0_level_0,spike_times,obs_intervals,waveform_mean,sort_interval
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,"[1563488988.8793328, 1563488988.897183, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[-68.0, -87.0, -78.0, -73.0, -80.0, -91.0, -9...","[[1563488988.8777807, 1563488998.8777807]]"
2,"[1563488988.8909328, 1563488988.943533, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[40.5, 40.5, 11.0, 22.0, 1.5, 20.0, -15.0, 7....","[[1563488988.8777807, 1563488998.8777807]]"
3,"[1563488988.8768828, 1563488988.8803327, 15634...","[[1563488988.8777807, 1563488998.8777807]]","[[55.0, 17.0, 37.0, 45.0, 41.0, 52.0, 38.0, 31...","[[1563488988.8777807, 1563488998.8777807]]"
4,"[1563488988.888583, 1563488988.9524333, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[87.5, 85.0, 95.5, 61.0, 65.0, 65.0, 41.0, 22...","[[1563488988.8777807, 1563488998.8777807]]"
5,"[1563488988.879233, 1563488988.9779334, 156348...","[[1563488988.8777807, 1563488998.8777807]]","[[169.5, 179.5, 172.0, 159.5, 183.0, 195.5, 19...","[[1563488988.8777807, 1563488998.8777807]]"
...,...,...,...,...
71,"[1563488988.8785827, 1563488988.8977828, 15634...","[[1563489108.8777807, 1563489118.8777807]]","[[-26.0, -34.0, -28.0, -35.0, -29.0, -23.0, -2...","[[1563489108.8777807, 1563489118.8777807]]"
72,"[1563488988.911533, 1563488988.9631834, 156348...","[[1563489108.8777807, 1563489118.8777807]]","[[-41.5, -41.0, -37.0, -45.0, -43.5, -58.5, -6...","[[1563489108.8777807, 1563489118.8777807]]"
73,"[1563488988.897333, 1563488988.913633, 1563488...","[[1563489108.8777807, 1563489118.8777807]]","[[-65.0, -53.0, -69.0, -46.0, -65.0, -61.0, -5...","[[1563489108.8777807, 1563489118.8777807]]"
74,"[1563488988.8896828, 1563488988.896733, 156348...","[[1563489108.8777807, 1563489118.8777807]]","[[-54.0, -75.0, -59.5, -79.5, -71.5, -74.0, -7...","[[1563489108.8777807, 1563489118.8777807]]"
