# This notebook assumes that you've imported one or more NWB files into DataJoint 
# It allows you to run spikesorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [1]:
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
%load_ext autoreload
%autoreload 2

import os
data_dir = '/Users/loren/data/nwb_builder_test_data'
os.environ['NWB_DATAJOINT_BASE_DIR'] = data_dir

os.environ['KACHERY_STORAGE_DIR'] = os.path.join(data_dir, 'kachery-storage')
os.environ['SORTING_TEMP_DIR'] = os.path.join(data_dir, 'sort_tmp')

import numpy as np
import pynwb
import os

#DataJoint and DataJoint schema
import nwb_datajoint as nd
import datajoint as dj


env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
Connecting root@localhost:3306


### Set the nwb file name and the name of the probe file to create from DataJoint

In [2]:
nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')

### Set the sort grouping by shank

In [3]:
nd.common.SortGroup().set_group_by_shank(nwb_file_name)
nd.common.SortGroup()

About to delete:
`common_ephys`.`spike_sorting_parameters`: 1 items
`common_ephys`.`sort_group__sort_group_electrode`: 256 items
`common_ephys`.`sort_group`: 8 items


Proceed? [yes, No]:  yes


Committed.


nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,"sort_reference_electrode_id  the electrode to use for reference. -1: no reference, -2: common median"
beans20190718.nwb,0,-1
beans20190718.nwb,1,-1
beans20190718.nwb,2,-1
beans20190718.nwb,3,-1
beans20190718.nwb,4,-1
beans20190718.nwb,5,-1
beans20190718.nwb,6,-1
beans20190718.nwb,7,-1


In [4]:
nd.common.SortGroup.SortGroupElectrode().fetch()

array([('beans20190718.nwb', 0, '0',   0),
       ('beans20190718.nwb', 0, '0',   1),
       ('beans20190718.nwb', 0, '0',   2),
       ('beans20190718.nwb', 0, '0',   3),
       ('beans20190718.nwb', 0, '0',   4),
       ('beans20190718.nwb', 0, '0',   5),
       ('beans20190718.nwb', 0, '0',   6),
       ('beans20190718.nwb', 0, '0',   7),
       ('beans20190718.nwb', 0, '0',   8),
       ('beans20190718.nwb', 0, '0',   9),
       ('beans20190718.nwb', 0, '0',  10),
       ('beans20190718.nwb', 0, '0',  11),
       ('beans20190718.nwb', 0, '0',  12),
       ('beans20190718.nwb', 0, '0',  13),
       ('beans20190718.nwb', 0, '0',  14),
       ('beans20190718.nwb', 0, '0',  15),
       ('beans20190718.nwb', 0, '0',  16),
       ('beans20190718.nwb', 0, '0',  17),
       ('beans20190718.nwb', 0, '0',  18),
       ('beans20190718.nwb', 0, '0',  19),
       ('beans20190718.nwb', 0, '0',  20),
       ('beans20190718.nwb', 0, '0',  21),
       ('beans20190718.nwb', 0, '0',  22),
       ('be

### create the spike sorter and parameter lists 

In [7]:
nd.common.SpikeSorter().delete()
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

About to delete:
`common_ephys`.`spike_sorter_parameters`: 10 items
`common_ephys`.`spike_sorter`: 10 items


Proceed? [yes, No]:  yes


Committed.


### create a 'franklab_mountainsort' parameter set

In [5]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz', 'parameter_dict' : param}, skip_duplicates='True')

In [6]:
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'franklab_mountainsort_20KHz'}).fetch1()
p

{'sorter_name': 'mountainsort4',
 'parameter_set_name': 'franklab_mountainsort_20KHz',
 'parameter_dict': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'filter': True,
  'whiten': True,
  'curation': False,
  'num_workers': 7,
  'clip_size': 30,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0,
  'verbose': True}}

In [7]:
param = p['parameter_dict']
param

{'detect_sign': -1,
 'adjacency_radius': 100,
 'freq_min': 300,
 'freq_max': 6000,
 'filter': True,
 'whiten': True,
 'curation': False,
 'num_workers': 7,
 'clip_size': 30,
 'detect_threshold': 3,
 'detect_interval': 10,
 'noise_overlap_threshold': 0,
 'verbose': True}

### Create a set of spike sorting parameters for sorting group 4

In [8]:
sort_group_id = 4

In [9]:
# create two 60 second test intervals for debugging
t = (nd.common.IntervalList() & {'interval_list_name' : '01_s1'}).fetch1('valid_times')
a = t[0]
b = a + 60
t = np.asarray([[a,b]])
t = np.vstack((t, np.asarray([[a+120,b+120]])))
nd.common.SortIntervalList().insert1({'nwb_file_name' : nwb_file_name, 'sort_interval_list_name' : 'test', 'sort_intervals' : t}, replace='True')

In [10]:
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['sorter_name'] = 'mountainsort4'
key['parameter_set_name'] = 'franklab_mountainsort_20KHz'
key['interval_list_name'] = '01_s1'
key['sort_interval_list_name'] = 'test'
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates='True')

nwb_file_name  the name of the NWB file,sort_group_id  identifier for a group of electrodes,sorter_name  the name of the spike sorting algorithm,parameter_set_name  label for this set of parameters,sort_interval_list_name  descriptive name of this interval list,interval_list_name  descriptive name of this interval list
beans20190718.nwb,4,mountainsort4,franklab_mountainsort_20KHz,test,01_s1


### run the sort

In [12]:
nd.common.SpikeSorting().delete()

About to delete:
Nothing to delete


In [3]:
nd.common.SpikeSorting().populate()

writing new NWB file beans20190718_00000000.nwb
sample indeces: [      19 21839000]
Sorting {'nwb_file_name': 'beans20190718.nwb', 'sort_group_id': 4, 'sorter_name': 'mountainsort4', 'parameter_set_name': 'franklab_mountainsort_20KHz', 'sort_interval_list_name': 'test', 'analysis_file_name': 'beans20190718_00000000.nwb'}...
Using 7 workers.
Using tmpdir: /tmp/tmp3mlpois5
Num. workers = 7
Preparing /tmp/tmp3mlpois5/timeseries.hdf5...
Cleaning tmpdir:: /tmp/tmp3mlpois5


KeyboardInterrupt: 

### Example: Retrieve the spike trains:

In [82]:
sorting = (nd.common.SpikeSorting & {'nwb_file_name' : nwb_file_name, 'sort_group_id' : sort_group_id}).fetch_nwb()
sorting[0]['units'].to_dataframe()

Unnamed: 0_level_0,spike_times
id,Unnamed: 1_level_1
1,"[1563488988.8826828, 1563488988.9594333, 15634..."
2,"[1563488989.0944343, 1563488989.1491346, 15634..."
3,"[1563488988.8936327, 1563488989.0692837, 15634..."
4,"[1563488988.884683, 1563488989.0539339, 156348..."
5,"[1563488988.9498332, 1563488988.9755335, 15634..."
...,...
118,"[1563488988.920283, 1563488989.0027835, 156348..."
119,"[1563488988.9888334, 1563488989.2898355, 15634..."
120,"[1563488988.9386835, 1563488989.1438847, 15634..."
121,"[1563488988.8953328, 1563488989.0155838, 15634..."
