# This notebook assumes that you've imported one or more NWB files into DataJoint 
# It allows you to run spikesorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [1]:
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
%load_ext autoreload
%autoreload 2

import os
data_dir = '/Users/loren/data/nwb_builder_test_data'
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import numpy as np
import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj


env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
Connecting root@localhost:3306


### Set the nwb file name and the name of the probe file to create from DataJoint

In [2]:
nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')
#probe_file_name = '/Users/loren/data/nwb_builder_test_data/test.prb'

### Set the sort grouping by shank

In [3]:
nd.common.SortGroup().set_group_by_shank(nwb_file_name)
nd.common.SortGroup()

About to delete:
`common_ephys`.`sort_group__sort_group_electrode`: 256 items
`common_ephys`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718.nwb,0,-1
beans20190718.nwb,1,-1
beans20190718.nwb,2,-1
beans20190718.nwb,3,-1
beans20190718.nwb,4,-1
beans20190718.nwb,5,-1
beans20190718.nwb,6,-1
beans20190718.nwb,7,-1


### create the spike sorter and parameter lists 

In [4]:
nd.common.SpikeSorter().delete()
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

About to delete:
`common_ephys`.`spike_sorter_parameters`: 11 items
`common_ephys`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


  if isinstance(obj, collections.ByteString):
  if isinstance(obj, collections.MutableSequence):
  if isinstance(obj, collections.Sequence):
  if isinstance(obj, collections.Set):


### create a 'franklab_mountainsort' parameter set

In [5]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates='True')

In [6]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}).fetch1()
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 4

In [6]:
sort_group_id = 4

In [4]:
# create two 60 second test intervals for debugging
t = (nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch('valid_times')
a = t[0][0,0]
b = a + 60
t = np.asarray([[a,b]])
t = np.vstack((t, np.asarray([[a+120,b+120]])))
nd.common.SortIntervalList().insert1({'nwb_file_name' : nwb_file_name, 'sort_interval_list_name' : 'test', 'sort_intervals' : t}, replace='True')

IntegrityError: (1451, 'Cannot delete or update a parent row: a foreign key constraint fails (`common_ephys`.`spike_sorting_parameters`, CONSTRAINT `spike_sorting_parameters_ibfk_3` FOREIGN KEY (`nwb_file_name`, `sort_interval_list_name`) REFERENCES `common_interval`.`sort_interval_lis)')

In [6]:
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['sorter_name'] = 'mountainsort4'
key['parameter_set_name'] = 'franklab_mountainsort_20KHz'
key['interval_list_name'] = '01_s1'
key['sort_interval_list_name'] = 'test'
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates='True')

### run the sort

In [3]:
nd.common.SpikeSorting().delete()

About to delete:
Nothing to delete


In [4]:
nd.common.SpikeSorting().populate()

writing new NWB file beans20190718_00000002.nwb
sample indeces: [     19 1200012]
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'interval_list_name': '01_s1', 'analysis_file_name': 'beans20190718_00000002.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpqpvs967f
Num. workers = 7
Preparing /tmp/tmpqpvs967f/timeseries.hdf5...
'end_frame' set to 1199993
Preparing neighborhood sorters (M=32, N=1199993)...
Preparing output...
Done with ms4alg.
Cleaning tmpdir::::: /tmp/tmpqpvs967f
mountainsort4 run time 142.67s
sample indeces: [2400004 3599996]
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'interval_list_name': '01_s1', 'analysis_file_name': 'beans20190718_00000002.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmpt7

### Example: Retrieve the spike trains:

In [7]:
sorting = (nd.common.SpikeSorting & {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}).fetch_nwb()
sorting[0]['units'].to_dataframe()