## NWB-Datajoint Spike Sorting Tutorial

**Note: make a copy of this notebook and run the copy to avoid git conflicts in the future**

This is the second in a multi-part tutorial on the NWB-Datajoint pipeline used in Loren Frank's lab, UCSF. It demonstrates how to run spike sorting and curate units within the pipeline.

If you have not done [tutorial 0](0_intro.ipynb) yet, make sure to do so before proceeding.

Let's start by importing the `nwb_datajoint` package, along with a few others.<br>
**Note 2: Make sure you are running this within the nwb_datajoint Conda environment.**

In [None]:
import os
import numpy as np
import datajoint as dj
import nwb_datajoint as nd

# ignore datajoint+jupyter async warnings
import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)
warnings.simplefilter('ignore', category=ResourceWarning)
# set the temporary and kachery storage directories used by the pipeline
os.environ['NWB_DATAJOINT_TEMP_DIR'] = "/stelmo/nwb/tmp"
os.environ['KACHERY_STORAGE_DIR'] = "/stelmo/nwb/kachery-storage"

#### Import tables from nwb_datajoint

In [None]:
from nwb_datajoint.common import (SortGroup, SpikeSortingFilterParameters, SpikeSortingArtifactDetectionParameters,
                                  SpikeSortingRecordingSelection, SpikeSortingRecording, 
                                  SpikeSortingWorkspace, 
                                  SpikeSorter, SpikeSorterParameters, SortingID,
                                  SpikeSortingSelection, SpikeSorting, 
                                  SpikeSortingMetricParameters,
                                  AutomaticCurationParameters, AutomaticCurationSelection,
                                  AutomaticCuration,
                                  CuratedSpikeSortingSelection, CuratedSpikeSorting,
                                  IntervalList, SortInterval, Raw,
                                  Lab, LabMember, LabTeam, Session,
                                  Nwbfile, AnalysisNwbfile)

Let's first make sure that you're a part of the LorenLab `LabTeam`, so you'll have the right permissions for this tutorial.<br>Replace `your_name`, `your_email`, and `datajoint_username`, with your information.

In [None]:
your_name = 'Daniel Gramling'
your_email = 'gmail@gmail.com'
datajoint_username = 'user'

In [None]:
lab_member_list = np.unique(LabTeam.LabTeamMember().fetch('lab_member_name')).tolist()
lorenlab_team_members = (LabTeam().LabTeamMember() & {'team_name' : 'LorenLab'}).fetch('lab_member_name').tolist()
if your_name not in lab_member_list:
    LabMember().insert1([your_name, your_name.split()[0], your_name.split()[1]], skip_duplicates=True)
    LabMember.LabMemberInfo.insert1([your_name, your_email, datajoint_username], skip_duplicates=True)
    LabTeam.LabTeamMember.insert1({'team_name' : 'LorenLab', 
                                   'lab_member_name' : your_name}, skip_duplicates=True)
    print(f'Hi {your_name}! You have just been added to the LabMember table and the LorenLab team. Congrats!')
elif your_name not in lorenlab_team_members:
    LabTeam.LabTeamMember.insert1({'team_name' : 'LorenLab', 
                                   'lab_member_name' : your_name}, skip_duplicates=True)
    print(f'Hi {your_name}! You have just been added to the LorenLab team. Congrats!')
else:
    print(f'Hi {your_name}! You are already on the team. Congrats!')
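
The cell above builds the first and last name with `your_name.split()[0]` and `your_name.split()[1]`, which assumes the name has exactly two space-separated parts. A minimal sketch of a more forgiving split is shown below; `split_full_name` is a hypothetical helper for illustration and is not part of `nwb_datajoint`.

```python
def split_full_name(full_name):
    """Split a full name into (first_name, last_name).

    Hypothetical helper for illustration; not part of nwb_datajoint.
    Everything after the first word is treated as the last name.
    """
    parts = full_name.split()
    if not parts:
        raise ValueError("full_name must contain at least one word")
    first_name = parts[0]
    last_name = " ".join(parts[1:])  # empty string if only one word was given
    return first_name, last_name


print(split_full_name("Daniel Gramling"))
```

With a helper like this, a single-word name no longer raises an `IndexError`, and a multi-word last name is kept whole.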

#### Setting the NWB filename to be looked at
NWB filenames take the form of an animal name plus the date of the recording.<br>For this tutorial, we will use the NWB file `'montague20200802_.nwb'`. The animal name is `'montague'` and the date of the recording is `'20200802'`; the trailing `'_'` indicates that this is a copy of the original NWB file.

In [None]:
nwb_file_name = 'montague20200802_.nwb'
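
To make the naming convention concrete, here is a small sketch that splits a file name into its parts. The regular expression and the helper `parse_nwb_file_name` are assumptions made for illustration (letters only for the animal name, an eight-digit date, an optional trailing underscore); they are not part of `nwb_datajoint`, and real file names may not all follow this exact pattern.

```python
import re

# Assumed pattern: letters for the animal name, eight digits for the date,
# and an optional trailing underscore marking a copy of the original file.
NWB_NAME_PATTERN = re.compile(
    r"^(?P<animal>[A-Za-z]+)(?P<date>\d{8})(?P<copy>_?)\.nwb$"
)


def parse_nwb_file_name(file_name):
    """Return the animal name, date, and copy flag encoded in a file name."""
    match = NWB_NAME_PATTERN.match(file_name)
    if match is None:
        raise ValueError(f"unexpected NWB file name: {file_name!r}")
    return {
        "animal_name": match.group("animal"),
        "date": match.group("date"),
        "is_copy": match.group("copy") == "_",
    }


print(parse_nwb_file_name("montague20200802_.nwb"))
```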

This can also be set programmatically by setting the `animal_name` and searching for a specified `date` in available NWB files.

In [None]:
animal_name = 'montague'
date = '20200802'
nwb_files = (Session() & {'subject_id': animal_name}).fetch('nwb_file_name')
nwb_file_name = [file for file in nwb_files if date in file][0]
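
Note that indexing with `[0]` raises an `IndexError` when no file matches the date, and silently picks the first match when several do. A hedged sketch of a stricter lookup follows; `pick_nwb_file` is a hypothetical helper and `example_files` is a made-up list standing in for the result of the `Session()` query.

```python
def pick_nwb_file(nwb_files, date):
    """Return the single file name containing `date`, or raise a clear error.

    Hypothetical helper for illustration; not part of nwb_datajoint.
    """
    matches = [name for name in nwb_files if date in name]
    if not matches:
        raise LookupError(f"no NWB file found for date {date}")
    if len(matches) > 1:
        raise LookupError(f"multiple NWB files match date {date}: {matches}")
    return matches[0]


# Made-up list standing in for the file names fetched from Session()
example_files = ["montague20200801_.nwb", "montague20200802_.nwb"]
print(pick_nwb_file(example_files, "20200802"))
```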

## Setting what part of a recording we want to sort
### SortGroup()
For each NWB file there will be multiple electrodes available to sort spikes from.<br>We commonly sort over multiple electrodes at a time, also referred to as a `SortGroup`.<br>This is accomplished by grouping electrodes according to what tetrode or shank of a probe they were on.

In [None]:
# Set sort group
SortGroup().set_group_by_shank(nwb_file_name)

Each electrode will have an `electrode_id` and be associated with an `electrode_group_name`, which will correspond with a `sort_group_id`. In this case, the data was recorded from a 32 tetrode (128 channel) drive, and thus results in 128 unique `electrode_id`, 32 unique `electrode_group_name`, and 32 unique `sort_group_id`. 

In [None]:
SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name}

In [None]:
sort_group_array = (SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name}).fetch('electrode_id',
                                                                                           'electrode_group_name',
                                                                                           'sort_group_id')
print(f"There are {len(np.unique(sort_group_array[0]))} unique electrode_id's, \
{len(np.unique(sort_group_array[1]))} unique electrode_group_name's, \
and {len(np.unique(sort_group_array[2]))} unique sort_group_id's")
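
As a rough illustration of where the counts 128 and 32 come from, the sketch below groups electrodes into tetrodes in plain Python. It assumes electrode IDs run from 0 to 127 with four consecutive IDs per tetrode; that numbering is an assumption made for the example, and the real grouping is read from the NWB file by `set_group_by_shank`.

```python
from collections import defaultdict

N_TETRODES = 32            # from the tutorial text
CHANNELS_PER_TETRODE = 4   # a tetrode has four channels

# Assumed numbering: four consecutive electrode IDs per tetrode
sort_groups = defaultdict(list)
for electrode_id in range(N_TETRODES * CHANNELS_PER_TETRODE):
    sort_group_id = electrode_id // CHANNELS_PER_TETRODE
    sort_groups[sort_group_id].append(electrode_id)

n_electrodes = sum(len(ids) for ids in sort_groups.values())
print(f"{n_electrodes} electrodes in {len(sort_groups)} sort groups")
print(f"sort group 10 holds electrodes {sort_groups[10]}")
```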

### IntervalList()
Next, we make a decision about the time interval for our spike sorting. Let's re-examine `IntervalList`.

In [None]:
IntervalList & {'nwb_file_name' : nwb_file_name}

For our example, let's start with the first run interval (`02_r1`) as our sort interval. We first fetch `valid_times` for this interval.

In [None]:
interval_list_name = '02_r1'

In [None]:
interval_list = (IntervalList & {'nwb_file_name' : nwb_file_name,
                            'interval_list_name' : interval_list_name}).fetch1('valid_times')
print(f'IntervalList begins as a {np.round((interval_list[0][1] - interval_list[0][0]) / 60,0):g} min long epoch')
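
The print statement above reports only the length of the first `[start, stop]` pair in `valid_times`. If an interval list holds several pairs, the total duration is the sum over all of them. A minimal sketch, using made-up times in seconds rather than values from the NWB file:

```python
def total_duration_minutes(valid_times):
    """Sum the lengths of all [start, stop] pairs and convert to minutes."""
    return sum(stop - start for start, stop in valid_times) / 60


# Made-up example: two valid stretches of 20 and 70 minutes
example_valid_times = [[0.0, 1200.0], [1500.0, 5700.0]]
print(f"{total_duration_minutes(example_valid_times):g} min in total")
```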

### SortInterval()
For brevity's sake, we'll select only the first 600 seconds of that 90 minute epoch as our sort interval. To do so, we copy the first interval in `interval_list` (the `valid_times` we fetched above) and set its end time to its start time plus 600 seconds.

In [None]:
sort_interval_name = interval_list_name + '_first600'
# copy so that interval_list itself is not modified in place
sort_interval = np.copy(interval_list[0])
sort_interval[1] = sort_interval[0] + 600
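
Adding 600 seconds to the start time works here because the epoch is much longer than 600 seconds. For shorter epochs the end of the sort interval could overshoot the end of the epoch. A hedged sketch of a clipped version follows; `first_n_seconds` is a hypothetical helper, not part of `nwb_datajoint`, and the times are made up.

```python
def first_n_seconds(interval, n_seconds):
    """Return [start, stop] covering at most the first n_seconds of interval."""
    start, stop = interval
    return [start, min(start + n_seconds, stop)]


# Made-up epoch boundaries in seconds
print(first_n_seconds([100.0, 5500.0], 600))  # long epoch: full 600 s window
print(first_n_seconds([100.0, 400.0], 600))   # short epoch: clipped to its end
```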

We can now add this `sort_interval` with the specified `sort_interval_name` `'02_r1_first600'` to the `SortInterval` table. The `SortInterval.insert1()` function takes a dictionary with keys `nwb_file_name`, `sort_interval_name`, and `sort_interval`.

In [None]:
SortInterval.insert1({'nwb_file_name' : nwb_file_name,
                     'sort_interval_name' : sort_interval_name,
                     'sort_interval' : sort_interval}, skip_duplicates=True)

Now that we've inserted the entry into `SortInterval()` you can see that entry by querying `SortInterval()` using the `nwb_file_name` and `sort_interval_name`. 

In [None]:
SortInterval & {'nwb_file_name' : nwb_file_name, 'sort_interval_name': sort_interval_name}

Now using the `.fetch()` command, you can retrieve your user-defined sort interval from the `SortInterval` table.<br>A quick double-check will show that it is indeed a 600 second segment.

In [None]:
fetched_sort_interval = (SortInterval & {'nwb_file_name' : nwb_file_name,
                                      'sort_interval_name': sort_interval_name}).fetch('sort_interval')[0]
print(f'The sort interval goes from {fetched_sort_interval[0]} to {fetched_sort_interval[1]}, \
which is {(fetched_sort_interval[1] - fetched_sort_interval[0])} seconds. COOL!')

### SpikeSortingFilterParameters()
Let's first take a look at the `SpikeSortingFilterParameters()` table.

In [None]:
SpikeSortingFilterParameters()

Now let's set the filtering parameters. Here we insert the default parameters, and then fetch the default parameter dictionary.<br>Note the lack of `[0]` after the `fetch1()` command compared to previous uses of `fetch()`, since it will only return a single object.

In [None]:
SpikeSortingFilterParameters().insert_default()
filter_param_dict = (SpikeSortingFilterParameters() &
                     {'filter_parameter_set_name': 'default'}).fetch1('filter_parameter_dict')
print(f'{filter_param_dict}')

Adjust the `frequency_min` parameter, and insert that into `SpikeSortingFilterParameters()` as a new set of filtering parameters for hippocampal data, named `'franklab_default_hippocampus'`.

In [None]:
filter_param_dict['frequency_min'] = 600
SpikeSortingFilterParameters().insert1({'filter_parameter_set_name': 'franklab_default_hippocampus', 
                                       'filter_parameter_dict' : filter_param_dict}, skip_duplicates=True)
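
The fetched dictionary is an ordinary Python `dict`, so editing it changes only the local object until it is inserted under a new name. If you want to keep the fetched defaults around for comparison, one option is to edit a copy instead. The sketch below uses placeholder values; only the key `frequency_min` appears in this tutorial, and the other key and all numbers are assumptions.

```python
import copy

# Placeholder defaults for illustration; the real values come from the table
default_params = {"frequency_min": 300, "frequency_max": 6000}

# Work on a deep copy so the fetched defaults stay untouched
hippocampus_params = copy.deepcopy(default_params)
hippocampus_params["frequency_min"] = 600

print("default:    ", default_params)
print("hippocampus:", hippocampus_params)
```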

### SpikeSortingArtifactDetectionParameters()
Similarly, we set up the `SpikeSortingArtifactDetectionParameters` table, which allows us to remove artifacts from the data.<br>
For the moment we just set up a `"none"` parameter set, which will do nothing when used.

In [None]:
SpikeSortingArtifactDetectionParameters().insert_default()

#### Setting a key
Now we set up the parameters of the recording we are interested in, so we can get the recording extractor.<br>The `sort_group_id` refers back to the `SortGroup` table we populated at the beginning of the tutorial. We'll use `sort_group_id` 10 here. <br>Our `sort_interval_name` is the same as above: `'02_r1_first600'`.<br>Our `filter_param_name` and `artifact_param_name` are the same ones we just inserted into `SpikeSortingFilterParameters()` and `SpikeSortingArtifactDetectionParameters()`, respectively.<br>The `interval_list` was also set above as `'02_r1'`. Unlike `sort_interval_name`, which reflects our subsection of the recording, we keep `interval_list` unchanged from the original epoch name.

In [None]:
sort_group_id = 10
sort_interval_name = '02_r1_first600'
filter_param_name = 'franklab_default_hippocampus'
artifact_param_name = 'none'
interval_list = '02_r1'
lab_team = 'LorenLab'

Here we make a dictionary to hold all these values, which will make querying and inserting into tables all the easier moving forward.<br>We'll assign this to `ssr_key` as these values are relevant to the recording we'll use to spike sort, also referred to as the spike sorting recording **(ssr)** :-)

In [None]:
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['sort_interval_name'] = sort_interval_name
key['filter_parameter_set_name'] = filter_param_name
key['artifact_parameter_name'] = artifact_param_name
key['interval_list_name'] = interval_list
key['team_name'] = lab_team

ssr_key = key
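
One Python detail worth keeping in mind: `ssr_key = key` does not copy the dictionary, it gives the same dictionary a second name. Any later in-place change through `key` would also show up in `ssr_key`. A minimal sketch of the difference between an alias and a copy, using a shortened stand-in for the key above:

```python
# Shortened stand-in for the key built above
key = {"nwb_file_name": "montague20200802_.nwb", "sort_group_id": 10}

alias = key           # same dictionary, second name
snapshot = dict(key)  # shallow copy, independent of later edits to key

key["sort_group_id"] = 11

print("alias sees:   ", alias["sort_group_id"])
print("snapshot sees:", snapshot["sort_group_id"])
```

Assigning a brand-new dictionary to `key` later on rebinds only the name `key` and leaves the dictionary that `ssr_key` refers to unchanged.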

### SpikeSortingRecordingSelection()
We now insert all of these parameters into the `SpikeSortingRecordingSelection()` table, which we will use to specify what time/tetrode/etc of the recording we want to extract. By specifying all of those names in the previous cell, we're identifying which entries from the `SortGroup`, `SortInterval`, `SpikeSortingFilterParameters`, `IntervalList`, `SpikeSortingArtifactDetectionParameters`, and `LabTeam` tables we want to pass into `SpikeSortingRecordingSelection`!

In [None]:
SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True)
SpikeSortingRecordingSelection() & ssr_key

### SpikeSortingRecording()
And now we're ready to extract the recording! We use the `.proj()` command to pass along all of the primary keys from the `SpikeSortingRecordingSelection()` table to the `SpikeSortingRecording` table, so it knows exactly what to extract.<br>**Note**: we're using `ssr_key` to specify this exact set of parameters.<br>**Note 2**: This step might take a bit.

In [None]:
SpikeSortingRecording.populate([(SpikeSortingRecordingSelection & ssr_key).proj()])

#### Now we can see our recording in the table. _E x c i t i n g !_

In [None]:
SpikeSortingRecording() & ssr_key

### SpikeSortingWorkspace()
Now we need to populate the `SpikeSortingWorkspace` table to make this recording available via kachery (our server backend).

In [None]:
SpikeSortingWorkspace.populate()

In [None]:
SpikeSortingWorkspace() & ssr_key

A bit of an aside... you can now access the workspace using the `sortingview` package. Uncomment and run the cell below if you want to explore a bit.<br>The workspace is an object that contains the recording object and eventually the sorting object. 

In [None]:
# import sortingview as sv
# workspace = sv.load_workspace((SpikeSortingWorkspace() & ssr_key).fetch1('workspace_uri'))

### SpikeSorter() setup
For our example, we will be using `mountainsort4`. We first insert the available sorters and their default parameters with `insert_from_spikeinterface()`, and then `fetch` the defaults from the `SpikeSorterParameters()` table.

In [None]:
SpikeSorter().insert_from_spikeinterface()
SpikeSorterParameters().insert_from_spikeinterface()

In [None]:
# Let's look at the default params
sorter_name='mountainsort4'
ms4_default_params = (SpikeSorterParameters & {'sorter_name' : sorter_name,
                                               'spikesorter_parameter_set_name' : 'default'}).fetch1()
print(ms4_default_params)

Now we can change these default parameters to line up more closely with our preferences. 

In [None]:
param_dict = ms4_default_params['parameter_dict']
# Detect downward-going spikes (-1 is downward, 1 is upward, 0 is both up and down)
param_dict['detect_sign'] = -1 
# We will sort electrodes together that are within 100 microns of each other
param_dict['adjacency_radius'] = 100
param_dict['curation'] = False
# Turn filter off since we will filter it prior to starting sort
param_dict['filter'] = False
param_dict['freq_min'] = 0
param_dict['freq_max'] = 0
# Turn whiten off since we will whiten it prior to starting sort
param_dict['whiten'] = False
# set num_workers to be the same number as the number of electrodes
param_dict['num_workers'] = 4
param_dict['verbose'] = True
# set clip size as the number of samples in 1.33 milliseconds, based on the sampling rate
# (the builtin int replaces np.int, which was removed from recent NumPy releases)
param_dict['clip_size'] = int(1.33e-3 * (Raw & {'nwb_file_name' : nwb_file_name}).fetch1('sampling_rate'))
param_dict['noise_overlap_threshold'] = 0
param_dict
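
The `clip_size` line converts a duration into a whole number of samples by truncating toward zero. As a worked example, the sketch below assumes a 30 kHz sampling rate, which is suggested by the parameter set name `'franklab_tetrode_hippocampus_30KHz'` but is not read from the file here; `clip_size_samples` is a hypothetical helper.

```python
def clip_size_samples(sampling_rate_hz, clip_duration_s=1.33e-3):
    """Number of whole samples in a clip of the given duration (truncated)."""
    return int(clip_duration_s * sampling_rate_hz)


# Assumed sampling rate of 30 kHz
print(clip_size_samples(30000))
```

At 30 kHz, 1.33 ms is 39.9 samples, so truncation gives 39 samples; `round()` would give 40 instead.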

This set of parameters has already been inserted into the table as `'franklab_tetrode_hippocampus_30KHz'`.<br>We can take a look at these and insert a new `spikesorter_parameter_set_name` and `parameter_dict` into the `SpikeSorterParameters()` table if need be. 

In [None]:
SpikeSorterParameters() & {'sorter_name' : sorter_name}

In [None]:
(SpikeSorterParameters() & {'sorter_name' : sorter_name,
                           'spikesorter_parameter_set_name' : 'franklab_tetrode_hippocampus_30KHz'}).fetch1()

In [None]:
# Give a unique name here if parameters different than default
parameter_set_name = 'franklab_tetrode_hippocampus_30KHz'

Now we insert our parameters for use by the spike sorter into `SpikeSorterParameters()` and double-check that they made it into the table.

In [None]:
SpikeSorterParameters.insert1({'sorter_name': sorter_name,
                               'spikesorter_parameter_set_name': parameter_set_name,
                               'parameter_dict': param_dict}, skip_duplicates=True)
# Check that insert was successful
p = (SpikeSorterParameters & {'sorter_name': sorter_name, 'spikesorter_parameter_set_name': parameter_set_name}).fetch1()
p

### Gearing up to Spike Sort by adding to `SpikeSortingSelection()`

We now collect all the decisions we made up to here and put them into the `SpikeSortingSelection` table, which is specific to this recording and eventual sorting segment.<br>We'll add in a few parameters to our key and call it `ss_key` for spike sorting key now.<br>(**note**: the spike *sorter* parameters defined above are for the sorter, `mountainsort4` in this case.)

In [None]:
key = (SpikeSortingWorkspace & ssr_key).fetch1("KEY")
key['sorter_name'] = sorter_name
key['spikesorter_parameter_set_name'] = 'franklab_tetrode_hippocampus_30KHz'
ss_key = key
SpikeSortingSelection.insert1(ss_key, skip_duplicates=True)
(SpikeSortingSelection & ss_key)

### Running Spike Sorting
Now we can run spike sorting. It's nothing more than populating a table (`SpikeSorting`) based on the entries of `SpikeSortingSelection`.<br>**Note**: This will take a little bit.

In [None]:
# `proj` gives you the primary key
SpikeSorting.populate([(SpikeSortingSelection & ss_key).proj()])

#### Check to make sure the table populated

In [None]:
SpikeSorting() & ss_key

### SpikeSortingMetricParameters()
#### Define quality metric parameters for curation with `SpikeSortingMetricParameters()` table

We're almost done. There are more parameters related to how to compute the quality metrics for curation. We just use the default options here. The default has already been inserted into the table as `'franklab_cluster_metrics_09-19-2021'`.<br>For this tutorial we'll go through the motions of adding it, regardless.

In [None]:
SpikeSortingMetricParameters()

Below we'll take a look at what the default set of metrics is.

In [None]:
metric_dict = SpikeSortingMetricParameters().get_metric_dict()
metric_param_dict = SpikeSortingMetricParameters().get_metric_parameter_dict()
for k in metric_dict:
    print(f"'{k}': {metric_dict[k]}\n")

And now set the ones we want to calculate to `True`.

In [None]:
metric_dict['noise_overlap'] = True
metric_dict['firing_rate'] = True
metric_dict['num_spikes'] = True
for k in metric_dict:
    print(f"'{k}': {metric_dict[k]}\n")
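
Because `metric_dict[name] = True` creates a new entry when `name` is misspelled, a typo would go unnoticed. The sketch below shows one way to guard against that. The three metric names come from the cell above; the starting dictionary and the extra key `'some_other_metric'` are placeholders, and `enable_metrics` is a hypothetical helper.

```python
def enable_metrics(metric_dict, names):
    """Set the given metrics to True, refusing names that are not present."""
    unknown = [name for name in names if name not in metric_dict]
    if unknown:
        raise KeyError(f"unknown metric name(s): {unknown}")
    for name in names:
        metric_dict[name] = True
    return metric_dict


# Placeholder dictionary; the real one comes from get_metric_dict()
example_metric_dict = {
    "noise_overlap": False,
    "firing_rate": False,
    "num_spikes": False,
    "some_other_metric": False,
}
enable_metrics(example_metric_dict, ["noise_overlap", "firing_rate", "num_spikes"])
enabled = sorted(name for name, on in example_metric_dict.items() if on)
print(enabled)
```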

In [None]:
cluster_metrics_list_name = 'franklab_cluster_metrics_09-19-2021'

#### And now add the cluster metrics to the `SpikeSortingMetricParameters()` table.
**Note** we have `skip_duplicates=True`, so if an entry with the same name already exists in the table, a new one won't get inserted. 

In [None]:
SpikeSortingMetricParameters.insert1({'cluster_metrics_list_name' : cluster_metrics_list_name,
                            'metric_dict' : metric_dict, 
                            'metric_parameter_dict' : metric_param_dict}, skip_duplicates=True)


### Automatic Curation: AutomaticCurationParameters(), AutomaticCurationSelection()
#### Retrieve the default automatic curation parameters and add to `AutomaticCurationParameters()` table

In [None]:
param = AutomaticCurationParameters().get_default_parameters()
AutomaticCurationParameters().insert1({'automatic_curation_parameter_set_name':'none', 
                                      'automatic_curation_parameter_dict': param}, skip_duplicates=True)

#### Add an entry to `AutomaticCurationSelection()` to select those parameters for automatic curation of this sorting.
First we'll get the `sorting_id` from the `SpikeSorting` table. Then we identify which entries from `AutomaticCurationParameters()` and `SpikeSortingMetricParameters()` we want to use during automatic curation.<br>This will all get added into a new automatic curation selection key (`acs_key`).<br>**Note**: This is similar to how we added parameters to `SpikeSortingSelection()` prior to populating `SpikeSorting()`.

In [None]:
acs_key = (SpikeSortingRecording & ssr_key).fetch1('KEY')
acs_key['sorting_id'] = (SpikeSorting & ss_key).fetch1('sorting_id')
acs_key['automatic_curation_parameter_set_name'] = 'none'
acs_key['cluster_metrics_list_name'] = cluster_metrics_list_name
AutomaticCurationSelection.insert1(acs_key, skip_duplicates=True)

In [None]:
(AutomaticCurationSelection() & acs_key)

### We'll stop here for now... there are some bugs below that need to be worked out :-)

Now we populate the `AutomaticCuration()` table, which in this case just computes the metrics and does not add labels.

In [None]:
AutomaticCuration.populate(acs_key)

In [None]:
AutomaticCuration() & acs_key

### Revisiting `SpikeSortingWorkspace()`.
To perform manual curation, we use the `figurl` interface.<br>`figurl` will load more quickly if we run `SpikeSortingWorkspace().precalculate()` beforehand.

In [None]:
SpikeSortingWorkspace().precalculate(ssr_key)

Here we'll use the sortingview backend to access the figurl *url*. We do this by loading the workspace that we set up while populating `SpikeSortingWorkspace`.<br>This workspace in tandem with the ids of the spike sorting and recording segment we extracted during this tutorial, will allow us to retrieve the url. We'll also use this opportunity to enable permissions for everyone to curate this sorting. 

In [None]:
import sortingview as sv
workspace = sv.load_workspace((SpikeSortingWorkspace() & ssr_key).fetch1('workspace_uri'))
sorting_id = acs_key['sorting_id']
recording_id = workspace.recording_ids[0]
url = workspace.experimental_spikesortingview(recording_id=recording_id, sorting_id=sorting_id,
                                                  label=workspace.label, include_curation=True)
# collect the Google user names of everyone on the LorenLab team
member_names, member_emails = LabMember().LabMemberInfo().fetch('lab_member_name', 'google_user_name')
authorized_users = [email for name, email in zip(member_names, member_emails) if name in lorenlab_team_members]
workspace.set_sorting_curation_authorized_users(sorting_id=sorting_id, user_ids=authorized_users)
print(f'{url}')
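
The list comprehension in the cell above pairs each lab member's name with a Google user name and keeps only the members on the team. The same filtering step is sketched below in isolation, with made-up names and addresses instead of values fetched from `LabMember`.

```python
# Made-up stand-ins for the two columns fetched from LabMemberInfo
names = ["Alice Example", "Bob Example", "Carol Example"]
emails = ["alice@example.com", "bob@example.com", "carol@example.com"]

# Made-up team roster; a set makes the membership test fast
team_members = {"Alice Example", "Carol Example"}

# Keep only the addresses whose owner is on the team
authorized = [email for name, email in zip(names, emails) if name in team_members]
print(authorized)
```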

Once you're done with manual curation through figurl, you can click 'Close Curation' or uncomment and run the cell below to close your curation.

In [None]:
# workspace.add_sorting_curation_action(sorting_id=sorting_id, action={'type':'CLOSE_CURATION'})

### CuratedSpikeSorting()
Now you can add the units (with the option for a new set of metrics) to the `CuratedSpikeSorting` table, which includes only accepted units.<br>This is accomplished by first adding an entry to `CuratedSpikeSortingSelection()` and then populating `CuratedSpikeSorting` from this.

In [None]:
# start from the primary key of the automatic curation entry
css_key = (AutomaticCuration & acs_key).fetch1('KEY')
css_key['final_cluster_metrics_list_name'] = cluster_metrics_list_name
CuratedSpikeSortingSelection.insert1(css_key, skip_duplicates=True)
CuratedSpikeSorting.populate(css_key)

And now you can see all your accepted units in the `Unit` part table of `CuratedSpikeSorting`.

In [None]:
CuratedSpikeSorting().Unit() & css_key