## This notebook assumes that you've imported one or more NWB files into DataJoint 
## It allows you to run spike sorters on those data using the SpikeInterface package

#### Load all of the relevant modules

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import numpy as np
import pynwb
import os

# CHANGE ME TO THE BASE DIRECTORY FOR DATA STORAGE ON YOUR SYSTEM
# data_dir = Path('/Users/loren/data/nwb_builder_test_data') 
data_dir = Path('/mnt/c/Users/Ryan/Documents/NWB_Data/Frank Lab Data/')

# Sub-directories used as the DataJoint file stores configured below.
raw_dir = data_dir / 'raw'
analysis_dir = data_dir / 'analysis'

# Export the storage locations as environment variables; they are set before
# nwb_datajoint is imported later in this cell.
os.environ['NWB_DATAJOINT_BASE_DIR'] = str(data_dir)
os.environ['KACHERY_STORAGE_DIR'] = str(data_dir / 'kachery-storage')
os.environ['SPIKE_SORTING_STORAGE_DIR'] = str(data_dir / 'spikesorting')
os.environ['DJ_SUPPORT_FILEPATH_MANAGEMENT'] = 'TRUE'

# DataJoint and DataJoint schema
import datajoint as dj
# Connection settings: honour the DJ_HOST / DJ_USER / DJ_PASS environment
# variables when they are set so that real credentials never have to be typed
# into the notebook; otherwise fall back to the local tutorial-database values.
dj.config['database.host'] = os.environ.get('DJ_HOST', 'localhost')
dj.config['database.user'] = os.environ.get('DJ_USER', 'root')
dj.config['database.password'] = os.environ.get('DJ_PASS', 'tutorial')
# Two file-protocol stores, one for raw data and one for analysis files; each
# uses the same directory for 'location' and 'stage'.
dj.config['stores'] = {
  'raw': {
    'protocol': 'file',
    'location': str(raw_dir),
    'stage' : str(raw_dir)
  },
  'analysis': {
    'protocol': 'file',
    'location': str(analysis_dir),
    'stage': str(analysis_dir)
  }
}

import nwb_datajoint as nd
from ndx_franklab_novela import Probe

import spiketoolkit as st

import warnings
# Install a blanket 'ignore' filter: warnings of every category are suppressed
# from here on. Comment this line out when debugging unexpected behaviour.
warnings.simplefilter('ignore')

### Set the nwb file name and the name of the probe file to create from DataJoint

In [None]:
# Display the Session table so the available sessions / NWB files can be inspected.
nd.common.Session()

In [None]:
# Name of the NWB file that every later cell operates on.
# Alternatives kept for reference: look the name up from a session id, or use
# the untrimmed file.
#nwb_file_name = (nd.common.Session() & {'session_id': 'beans_01'}).fetch1('nwb_file_name')
# nwb_file_name = 'beans20190718_.nwb'
nwb_file_name = 'beans20190718-trim_.nwb'

### Set the sort grouping by shank

In [None]:
# Create the sort groups for this NWB file (grouping by shank, per the method
# name), then display the SortGroup table.
nd.common.SortGroup().set_group_by_shank(nwb_file_name)
nd.common.SortGroup()

Optional: Display all of the electrodes with their sort groups

In [None]:
# Display the SortGroupElectrode part table of SortGroup.
nd.common.SortGroup.SortGroupElectrode()

### create the spike sorter and parameter lists 

In [None]:
# Fill the sorter table and the sorter-parameter table from SpikeInterface.
# The cells below rely on a ('mountainsort4', 'default') parameter row existing
# after this step.
nd.common.SpikeSorter().insert_from_spikeinterface()
nd.common.SpikeSorterParameters().insert_from_spikeinterface()

In [None]:
# Display the SpikeSorterParameters table.
nd.common.SpikeSorterParameters()

### create a 'franklab_mountainsort' parameter set
#### Note that we're doing the filtering using spikeinterface, so we set filter to False

In [None]:
# Fetch the single row holding mountainsort4's 'default' parameter set and
# show it as the cell output.
p = (
    nd.common.SpikeSorterParameters()
    & {'sorter_name': 'mountainsort4', 'parameter_set_name': 'default'}
).fetch1()
p

In [None]:
# Start from mountainsort4's 'default' parameter set and override selected
# entries to build the 'franklab_mountainsort_20KHz' set.
p = (nd.common.SpikeSorterParameters() & {'sorter_name': 'mountainsort4', 'parameter_set_name' : 'default'}).fetch1()
param = p['parameter_dict']
param['adjacency_radius'] = 100
param['curation'] = False
param['filter'] = False  # filtering is done with spikeinterface instead
param['num_workers'] = 7
param['verbose'] = True
param['clip_size'] = 30
param['noise_overlap_threshold'] = 0

# skip_duplicates is a boolean flag. The original passed the string 'True',
# which only worked because any non-empty string is truthy ('False' would have
# enabled it too).
# NOTE: with skip_duplicates, a row that already exists under this name is left
# untouched even if the values above have been edited since.
nd.common.SpikeSorterParameters().insert1({'sorter_name': 'mountainsort4',
                                           'parameter_set_name': 'franklab_mountainsort_20KHz',
                                           'parameter_dict' : param},
                                          skip_duplicates=True)

Display the new parameter set

In [None]:
# Read the 'franklab_mountainsort_20KHz' row back from the table and show it.
p = (
    nd.common.SpikeSorterParameters()
    & {
        'sorter_name': 'mountainsort4',
        'parameter_set_name': 'franklab_mountainsort_20KHz',
    }
).fetch1()
p

In [None]:
# Pull just the parameter dictionary out of the fetched row and display it.
param = p['parameter_dict']
param

### Create a set of spike sorting parameters for sorting group 1

In [None]:
# Create a short test interval for debugging: the first 240 time units
# (seconds, going by the original comment) of interval list '01_s1'.
# The lookup is restricted to the current NWB file as well, so fetch1() still
# returns exactly one row when several imported files share the name '01_s1'.
s1 = (nd.common.IntervalList() & {'nwb_file_name': nwb_file_name,
                                  'interval_list_name': '01_s1'}).fetch1('valid_times')
print(s1)

# [start of the first valid interval, start + 240]
a = np.asarray([s1[0][0], s1[0][0]+240])
print(a)

# replace is a boolean flag (the original passed the string 'True'); an existing
# 'test' interval for this file is overwritten.
nd.common.SortInterval().insert1({'nwb_file_name': nwb_file_name, 'sort_interval_name': 'test', 'sort_interval': a}, replace=True)

In [None]:
# Create the sorting waveform parameters table entry, starting from
# spiketoolkit's default waveform parameters.
waveform_param_dict = st.postprocessing.get_waveforms_params()
waveform_param_dict['grouping_property'] = 'group'
# set the window to half of the clip size before and half after
# (presumably clip_size = 30 samples at 20 kHz = 1.5 ms -> 0.75 ms each side;
# keep these in step with the sorter's clip_size)
waveform_param_dict['ms_before'] = .75
waveform_param_dict['ms_after'] = .75
waveform_param_dict['dtype'] = 'i2'
waveform_param_dict['verbose'] = False
waveform_param_dict['max_spikes_per_unit'] = 1000
# replace is a boolean flag (the original passed the string 'True'); the table
# is instantiated before insert1 to match the other cells in this notebook.
nd.common.SpikeSortingWaveformParameters().insert1({'waveform_parameters_name': 'franklab default',
                                                    'waveform_parameter_dict': waveform_param_dict},
                                                   replace=True)
nd.common.SpikeSortingWaveformParameters()

### create a list of metrics to be computed 

In [None]:
# Get the metric dictionary from the SpikeSortingMetrics table class and
# display it; individual metrics are switched on in the next cell.
metric_dict = nd.common.SpikeSortingMetrics().get_metric_dict()
metric_dict

Select a set of metrics to compute

In [None]:
# Switch on the subset of metrics to compute; every other entry keeps the
# value it had when the dictionary was fetched.
for metric_name in ('num_spikes', 'firing_rate', 'isi_violation', 'nn_hit_rate'):
    metric_dict[metric_name] = True
# Left disabled for now:
#metric_dict['noise overlap'] = True

Set the parameters for computing the metrics. \
All of the parameters in the schema have default values, so we only need to specify the ones that we want to change. \
See spiketoolkit.validation documentation for details.

In [None]:
# The only metric parameter overridden here; it is stored in the
# 'n_cluster_waveforms' field of the metrics entry inserted below.
n_cluster_waveforms=1000

In [None]:
# Store the selected metrics and their parameters under the name
# 'll_fl_probe_metrics', then display the table.
# skip_duplicates is a boolean flag (the original passed the string 'True');
# an existing entry with this name is left unchanged.
nd.common.SpikeSortingMetrics().insert1({'cluster_metrics_list_name': 'll_fl_probe_metrics',
                                         'n_cluster_waveforms' : n_cluster_waveforms,
                                         'metrics_dict' : metric_dict},
                                        skip_duplicates=True)
nd.common.SpikeSortingMetrics()

In [None]:
# Assemble one SpikeSortingParameters entry that ties together the file, sort
# group, sorter + parameter set, waveform parameters, interval and metrics list
# defined in the cells above.
sort_group_id = 1
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_mountainsort_20KHz',
    'waveform_parameters_name': 'franklab default',
    'interval_list_name': '01_s1',
    'sort_interval_name': 'test',
    'cluster_metrics_list_name': 'll_fl_probe_metrics',
}
# skip_duplicates is a boolean flag (the original passed the string 'True').
nd.common.SpikeSortingParameters().insert1(key, skip_duplicates=True)
nd.common.SpikeSortingParameters()

### run the sort - this can take some time

In [None]:
# Populate the SpikeSorting table. No restriction is passed, so this is not
# limited to the entry created above.
nd.common.SpikeSorting().populate()


### Example: Retrieve the spike trains:
Note that these spikes are all noise; this is for demonstration purposes only.

In [None]:
# Build the restriction once, then use it for both the plain fetch and the
# NWB fetch; the units table of the first result is shown as a DataFrame.
key = {'nwb_file_name': nwb_file_name, 'sort_group_id': sort_group_id}
sorting = (nd.common.SpikeSorting & key).fetch()
units = (nd.common.SpikeSorting & key).fetch_nwb()[0]['units'].to_dataframe()
units


### Everything below here is for deleting schema or testing.

In [None]:
# DESTRUCTIVE: deletes the contents of the SpikeSorting table (no restriction
# is applied). Do not run as part of "Run All".
nd.common.SpikeSorting.delete()

In [None]:
# DESTRUCTIVE: deletes every AnalysisNwbfile row, then calls the table's
# cleanup() helper (presumably tidies up files no longer referenced — see
# AnalysisNwbfile.cleanup to confirm).
nd.common.AnalysisNwbfile().delete()
nd.common.AnalysisNwbfile().cleanup()

In [None]:
# DESTRUCTIVE: drops the AnalysisNwbfile table itself, not just its rows.
# NOTE(review): in DataJoint this normally also drops dependent tables — confirm
# before running.
nd.common.AnalysisNwbfile().drop()

In [None]:
# DESTRUCTIVE: deletes all SpikeSorting rows, then runs cleanup() with
# delete_files=True (presumably removes the corresponding files from disk —
# confirm in AnalysisNwbfile.cleanup).
nd.common.SpikeSorting().delete()
nd.common.AnalysisNwbfile().cleanup(delete_files=True)

In [None]:
# Testing helper: rebuild the recording and sorting extractors for one entry.
# NOTE(review): this queries SortIntervalList with the fields
# 'sort_interval_list_name' / 'sort_intervals', whereas the interval was
# inserted earlier through SortInterval with 'sort_interval_name' /
# 'sort_interval'. Confirm that SortIntervalList still exists in the schema.
sort_interval = (nd.common.SortIntervalList() & {'sort_interval_list_name' : 'test'}).fetch1('sort_intervals')[0]

# Use the primary key of the first SpikeSorting row.
key = nd.common.SpikeSorting().fetch('KEY')[0]

# Only the first element of get_recording_extractor's return value is kept.
recording = nd.common.SpikeSorting().get_recording_extractor(key, sort_interval)[0]
sorting = nd.common.SpikeSorting().get_sorting_extractor(key, sort_interval)


In [None]:
# get the timestamps and select 1000 random times
raw_obj = (nd.common.Raw() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['raw']
ts = np.asarray(raw_obj.timestamps)
ts_sort_ind = np.where(np.logical_and((ts > sort_interval[0]), (ts < sort_interval[1])))[0]


In [None]:
# DESTRUCTIVE: deletes every row of the Nwbfile table (no restriction applied).
# NOTE(review): DataJoint deletes normally cascade to dependent tables, so this
# likely removes everything imported from the NWB files — confirm before running.
nd.common.Nwbfile().delete()

In [None]:
# Display what spiketoolkit's validation module returns from
# get_quality_metrics_list(), for reference when choosing metrics.
st.validation.get_quality_metrics_list()