## NWB-Datajoint tutorial 1

**Note: make a copy of this notebook and run the copy to avoid git conflicts in the future**

This is the first in a multi-part tutorial on the NWB-Datajoint pipeline used in Loren Frank's lab, UCSF. It demonstrates how to run spike sorting within the pipeline.

If you have not done [tutorial 0](0_intro.ipynb) yet, make sure to do so before proceeding.

Let's start by importing the `nwb_datajoint` package, along with a few others. 

In [None]:
import os
import numpy as np
import datajoint as dj


In [None]:
# ignore datajoint+jupyter async warnings
import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)
warnings.simplefilter('ignore', category=ResourceWarning)

# Lab-specific temp/storage paths and figurl channel (Frank lab values; adjust for your setup)
os.environ['NWB_DATAJOINT_TEMP_DIR'] = "/stelmo/nwb/tmp"
os.environ['KACHERY_STORAGE_DIR'] = "/stelmo/nwb/kachery-storage"
os.environ['FIGURL_CHANNEL'] = "franklab2"



In [None]:
# import tables so that we can call them easily
from nwb_datajoint.common import (RawPosition, HeadDir, Speed, LinPos, StateScriptFile, VideoFile,
                                  IntervalPositionInfo, IntervalLinearizedPosition,
                                  DataAcquisitionDevice, CameraDevice, Probe,
                                  DIOEvents,
                                  ElectrodeGroup, Electrode, Raw, SampleCount,
                                  LFPSelection, LFP, LFPBandSelection, LFPBand,
                                  SortGroup, SpikeSortingFilterParameters, SpikeSortingArtifactDetectionParameters,
                                  SpikeSortingRecordingSelection, SpikeSortingRecording, 
                                  SpikeSortingWorkspace, 
                                  SpikeSorter, SpikeSorterParameters, SortingID,
                                  SpikeSortingSelection, SpikeSorting, 
                                  SpikeSortingMetricParameters,
                                  ModifySortingParameters, ModifySortingSelection, ModifySorting, 
                                  AutomaticCurationParameters, AutomaticCurationSelection,
                                  AutomaticCuration,
                                  CuratedSpikeSortingSelection, CuratedSpikeSorting,
                                  UnitInclusionParameters,
                                  FirFilter,
                                  IntervalList, SortInterval,
                                  Lab, LabMember, LabTeam, Institution,
                                  BrainRegion,
                                  SensorData,
                                  Session, ExperimenterList,
                                  Subject,
                                  Task, TaskEpoch,
                                  Nwbfile, AnalysisNwbfile, 
                                  KacheryChannel, NwbfileKacherySelection, NwbfileKachery,
                                  AnalysisNwbfileKacherySelection, AnalysisNwbfileKachery)

In [None]:
# key for IntervalPositionInfo
poskey = {'nwb_file_name': 'chimi20200216_new_.nwb', 'position_info_param_name': 'default_decoding'}
# key for IntervalLinearizedPosition
lposkey = {'position_info_param_name': 'default', 'nwb_file_name': 'chimi20200216_new_.nwb',
           'interval_list_name': 'pos 1 valid times', 'track_graph_name': '6 arm',
           'linearization_param_name': 'default'}

In [None]:
ipi = (IntervalPositionInfo & poskey).fetch1()
AnalysisNwbfileKacherySelection().insert1({'channel_name':'franklab2', 'analysis_file_name': ipi['analysis_file_name']}, skip_duplicates=True)
AnalysisNwbfileKachery.populate()
AnalysisNwbfileKachery()

In [None]:
ilp = (IntervalLinearizedPosition & lposkey).fetch1()
AnalysisNwbfileKacherySelection().insert1({'channel_name':'franklab2', 'analysis_file_name': ilp['analysis_file_name']}, skip_duplicates=True)
ilp

In [None]:
nwb_file_name = 'despereaux20191125_.nwb'

In [None]:
SpikeSortingRecording & {'nwb_file_name':nwb_file_name}

In [None]:
key = {'nwb_file_name': 'chimi20200216_new_.nwb', 'position_info_param_name':'default_decoding'}

In [None]:
key = {'nwb_file_name':nwb_file_name}
(SpikeSortingRecording & key)

In [None]:
key = {'nwb_file_name':  nwb_file_name}
SpikeSortingWorkspace().url(key)

In [None]:
from nwb_datajoint.decoding import UnitMarkParameters, UnitMarks, MarkParameters

In [None]:
UnitMarks.populate()

Next, set up the sort groups for this file. Each sort group is a set of electrodes that are spike sorted together.

In [None]:
# Uncomment to create one sort group per electrode group
#SortGroup().set_group_by_electrode_group(nwb_file_name)

In [None]:
SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name}

#### Define sort interval
Next, we make a decision about the time interval for our spike sorting. Let's re-examine `IntervalList`.

In [None]:
IntervalList & {'nwb_file_name' : nwb_file_name}

For our example, let's choose the first 300 seconds of the first run interval (`02_r1`) as our sort interval and call it `test`. To do so, we first fetch the `valid_times` of this interval, define our new sort interval, and add it to the `SortInterval` table.

In [None]:
interval_list_name = '02_r1'

In [None]:
interval_list = (IntervalList & {'nwb_file_name' : nwb_file_name,
                            'interval_list_name' : interval_list_name}).fetch1('valid_times')
print(interval_list)

In [None]:
# Take the first valid segment of the run interval and keep only its first 300 s
sort_interval = np.copy(interval_list[0])
sort_interval[1] = sort_interval[0] + 300
sort_interval_name = 'test'
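The interval arithmetic above can be wrapped in a small helper. This is a sketch with a hypothetical `first_n_seconds` function (not part of `nwb_datajoint`); it assumes `valid_times` is an `(n, 2)` array of start/stop times in seconds and, unlike the cell above, clips the sort interval so it cannot extend past the end of the first valid segment.

```python
import numpy as np

def first_n_seconds(valid_times, duration_s):
    """Return [start, stop] covering the first `duration_s` seconds of the
    first valid segment, clipped to that segment's end.

    `valid_times` is assumed to be an (n_segments, 2) array of [start, stop]
    times in seconds, like the `valid_times` fetched from IntervalList.
    """
    valid_times = np.asarray(valid_times, dtype=float)
    if valid_times.ndim != 2 or valid_times.shape[1] != 2 or len(valid_times) == 0:
        raise ValueError("valid_times must be a non-empty (n, 2) array")
    start, stop = valid_times[0]
    # never run past the end of the first valid segment
    return np.array([start, min(start + duration_s, stop)])

# Hypothetical valid times: two segments, the first one 1000 s long.
example_valid_times = np.array([[10.0, 1010.0], [1500.0, 2000.0]])
print(first_n_seconds(example_valid_times, 300))
```

Clipping matters when the requested duration is longer than the segment: a 300 s request on a 100 s segment returns the whole 100 s segment instead of an interval that includes invalid time.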

In [None]:
# Check out SortInterval
(SortInterval & {'nwb_file_name' : nwb_file_name})

In [None]:
# Specify the required attributes.
# This time, the entries take the form of a dictionary.
SortInterval.insert1({'nwb_file_name' : nwb_file_name,
                      'sort_interval_name' : sort_interval_name,
                      'sort_interval' : sort_interval}, skip_duplicates=True)

In [None]:
# See results
SortInterval & {'nwb_file_name' : nwb_file_name, 'sort_interval_name': sort_interval_name}

Now set the filtering parameters. Here we insert the default parameters and a new set of filtering parameters for hippocampal data.

In [None]:
SpikeSortingFilterParameters().insert_default()
# Use the first stored filter parameter dict as a template for the new set
filter_param_dict = SpikeSortingFilterParameters.fetch('filter_parameter_dict')[0]

In [None]:
# Raise the high-pass cutoff to 600 Hz for hippocampal data
filter_param_dict['frequency_min'] = 600
SpikeSortingFilterParameters().insert1({'filter_parameter_set_name': 'franklab_default_hippocampus', 
                                       'filter_parameter_dict' : filter_param_dict}, skip_duplicates=True)
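The cell above edits the fetched dictionary in place. If you prefer to leave the template untouched and catch misspelled keys early, a minimal sketch is shown below. `derive_parameters` is a hypothetical helper, not part of `nwb_datajoint`, and the template values are made up (only the `frequency_min` key appears in this notebook).

```python
import copy

def derive_parameters(template, **overrides):
    """Return a deep copy of `template` with `overrides` applied.

    Keys that are not already in the template raise KeyError, which catches
    typos such as 'freq_min' vs 'frequency_min'.
    """
    unknown = set(overrides) - set(template)
    if unknown:
        raise KeyError(f"unknown parameter(s): {sorted(unknown)}")
    derived = copy.deepcopy(template)  # the template itself is left unchanged
    derived.update(overrides)
    return derived

# Hypothetical template values, for illustration only.
default_filter = {'frequency_min': 300, 'frequency_max': 6000}
hippocampus_filter = derive_parameters(default_filter, frequency_min=600)
print(default_filter['frequency_min'], hippocampus_filter['frequency_min'])
```

The key check is useful here because the filter parameters and the MountainSort4 parameters use different names for similar settings (`frequency_min` vs `freq_min`).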

Similarly, we set up the `SpikeSortingArtifactDetectionParameters`, which can be used to remove artifacts from the data. For the moment we just set up a `none` parameter set, which does nothing when used.

In [None]:
SpikeSortingArtifactDetectionParameters().insert_default()

Now we set up the recording parameters so that we can get the recording extractor.

In [None]:
sort_group_id = 2  # use sort group 2
sort_interval_name = 'test'
filter_param_name = 'franklab_default_hippocampus'
artifact_param_name = 'none'
interval_list_name = '02_r1'
lab_team = 'Loren Frank'

In [None]:
# collect the params
key = dict()
key['nwb_file_name'] = nwb_file_name
key['sort_group_id'] = sort_group_id
key['filter_parameter_set_name'] = filter_param_name
key['sort_interval_name'] = sort_interval_name
key['artifact_parameter_name'] = artifact_param_name
key['interval_list_name'] = interval_list_name
key['team_name'] = lab_team

ssr_key = key

In [None]:
SpikeSortingRecordingSelection()

In [None]:
SpikeSortingRecordingSelection.insert1(key, skip_duplicates=True)

In [None]:
SpikeSortingRecording.populate()

Now we need to populate the `SpikeSortingWorkspace` table to make this recording available via kachery.

In [None]:
SpikeSortingRecording()

In [None]:
SpikeSortingWorkspace.populate()


Next, we choose a spike sorter and its parameters. For our example, we will be using `mountainsort4`.

In [None]:
#SpikeSortingWorkspace().url(key)

In [None]:
SpikeSorter().insert_from_spikeinterface()
SpikeSorterParameters().insert_from_spikeinterface()

In [None]:
sorter_name='mountainsort4'

In [None]:
# Let's look at the default params
ms4_default_params = (SpikeSorterParameters & {'sorter_name' : sorter_name,
                                               'spikesorter_parameter_set_name' : 'default'}).fetch1()
print(ms4_default_params)

In [None]:
# Change the default params
param_dict = ms4_default_params['parameter_dict']
# Detect downward-going (negative) spikes only
param_dict['detect_sign'] = -1
# Sort together electrodes that are within 100 microns of each other
param_dict['adjacency_radius'] = 100
param_dict['curation'] = False
# Turn the filter off since we filter the data before starting the sort
param_dict['filter'] = False
param_dict['freq_min'] = 0
param_dict['freq_max'] = 0
# Turn whitening off since we whiten the data before starting the sort
param_dict['whiten'] = False
# Set num_workers to the number of electrodes (4 for a tetrode)
param_dict['num_workers'] = 4
param_dict['verbose'] = True
# Set the clip size to the number of samples in 1.33 ms
param_dict['clip_size'] = int(1.33e-3 * (Raw & {'nwb_file_name' : nwb_file_name}).fetch1('sampling_rate'))
param_dict['noise_overlap_threshold'] = 0
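Two of the settings above come from simple arithmetic. The sketch below uses hypothetical helper names: `clip_size_samples` computes the clip size the same way as the cell above, and `neighbors_within` gives one geometric reading of `adjacency_radius` (electrodes closer than the radius are treated as neighbors). How MountainSort4 uses the radius internally is an assumption here, and the tetrode coordinates are made up.

```python
import numpy as np

def clip_size_samples(sampling_rate_hz, clip_duration_s=1.33e-3):
    """Number of samples in a clip of `clip_duration_s`, truncated with int()."""
    return int(clip_duration_s * sampling_rate_hz)

def neighbors_within(positions_um, radius_um):
    """For each electrode, list the indices of electrodes within `radius_um`
    (including itself), given an (n_electrodes, 2) array of x/y positions in um."""
    positions_um = np.asarray(positions_um, dtype=float)
    diffs = positions_um[:, None, :] - positions_um[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))  # pairwise Euclidean distances
    return [np.flatnonzero(row <= radius_um).tolist() for row in dists]

# 30 kHz sampling: 1.33 ms -> 39.9 samples, truncated to 39.
print(clip_size_samples(30000))

# Hypothetical tetrode geometry: four contacts on a 20 um square.
tetrode_um = [[0, 0], [20, 0], [0, 20], [20, 20]]
print(neighbors_within(tetrode_um, 100))
```

With a 100 µm radius every contact of a tetrode falls in every other contact's neighborhood, so all four channels are effectively sorted together.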



In [None]:
param_dict

In [None]:
# Give a unique name here
parameter_set_name = 'franklab_tetrode_hippocampus_30KHz'
SpikeSorterParameters()

In [None]:
# Insert
SpikeSorterParameters.insert1({'sorter_name': sorter_name,
                               'spikesorter_parameter_set_name': parameter_set_name,
                               'parameter_dict': param_dict}, skip_duplicates=True)

In [None]:
# Check that insert was successful
p = (SpikeSorterParameters & {'sorter_name': sorter_name,
                              'spikesorter_parameter_set_name': parameter_set_name}).fetch1()
p

#### Bringing everything together

We now collect all the decisions we made up to here and put them into the `SpikeSortingSelection` table (note: this is different from the spike sor*ter* parameters defined above).

In [None]:
key = (SpikeSortingWorkspace & ssr_key).fetch1("KEY")
key['sorter_name'] = sorter_name
key['spikesorter_parameter_set_name'] = parameter_set_name
ss_key = key

In [None]:
# insert
SpikeSortingSelection.insert1(key, skip_duplicates=True)

In [None]:
# Uncomment to delete this selection entry and start over
#(SpikeSortingSelection & {'nwb_file_name' : nwb_file_name, 'sort_interval_name' : sort_interval_name}).delete()

In [None]:
# inspect
(SpikeSortingSelection & {'nwb_file_name' : nwb_file_name})

#### Running spike sorting
Now we can run spike sorting. This amounts to populating another table (`SpikeSorting`) from the entries of `SpikeSortingSelection`.

In [None]:
# Restrict to this entry (otherwise every entry in SpikeSortingSelection is sorted)
# `proj()` keeps only the primary key
SpikeSorting.populate([(SpikeSortingSelection & {'nwb_file_name' : nwb_file_name, 'sort_interval_name' : sort_interval_name}).proj()])
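Conceptually, DataJoint's `populate` takes the entries of the table's key source (here, the entries of `SpikeSortingSelection`) that match the restriction and are not yet in the table, then calls `make` once for each. The following is a simplified pure-Python model of that behavior with a hypothetical `populate_sketch` function and made-up keys; it is not DataJoint's implementation.

```python
def populate_sketch(key_source, computed, make, restriction=None):
    """Simplified model of a DataJoint-style populate.

    key_source:  list of primary-key dicts (e.g. rows of a selection table)
    computed:    list of keys already present in the computed table
    make:        function called once per missing key; returns the new row
    restriction: optional dict; only keys matching all its items are processed
    """
    restriction = restriction or {}
    done = {tuple(sorted(k.items())) for k in computed}
    new_rows = []
    for key in key_source:
        if any(key.get(attr) != value for attr, value in restriction.items()):
            continue  # outside the restriction
        if tuple(sorted(key.items())) in done:
            continue  # already computed, so it is skipped
        new_rows.append(make(key))
    return new_rows

# Hypothetical selection entries; sort group 1 of a.nwb was already sorted.
selection = [
    {'nwb_file_name': 'a.nwb', 'sort_group_id': 1},
    {'nwb_file_name': 'a.nwb', 'sort_group_id': 2},
    {'nwb_file_name': 'b.nwb', 'sort_group_id': 1},
]
already_sorted = [{'nwb_file_name': 'a.nwb', 'sort_group_id': 1}]
rows = populate_sketch(selection, already_sorted,
                       make=lambda key: {**key, 'status': 'sorted'},
                       restriction={'nwb_file_name': 'a.nwb'})
print(rows)
```

This is why rerunning the populate cell is cheap: entries that were already sorted are skipped, and the restriction keeps other files and sort intervals from being sorted by accident.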

In [None]:
#SpikeSortingWorkspace().url(key)

#### Define quality metric parameters

We're almost done. The next set of parameters controls how the quality metrics used for curation are computed. We start from the default metric options and make sure the `noise_overlap`, `firing_rate`, and `num_spikes` metrics are enabled.

In [None]:
SpikeSortingMetricParameters()

In [None]:
metric_dict = SpikeSortingMetricParameters().get_metric_dict()
metric_param_dict = SpikeSortingMetricParameters().get_metric_parameter_dict()

In [None]:
for k in metric_dict:
    print(f"'{k}': {metric_dict[k]}\n")

In [None]:
metric_dict['noise_overlap'] = True
metric_dict['firing_rate'] = True
metric_dict['num_spikes'] = True
for k in metric_dict:
    print(f"'{k}': {metric_dict[k]}\n")
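Two of the metrics enabled above, `num_spikes` and `firing_rate`, are simple enough to compute by hand. This sketch uses a hypothetical `simple_unit_metrics` helper and assumes firing rate means spike count divided by the sort-interval duration; the pipeline's exact definitions may differ.

```python
import numpy as np

def simple_unit_metrics(spike_times_s, interval_s):
    """num_spikes and firing_rate (Hz) for one unit over [start, stop] seconds."""
    start, stop = interval_s
    spike_times_s = np.asarray(spike_times_s, dtype=float)
    in_interval = (spike_times_s >= start) & (spike_times_s <= stop)
    num_spikes = int(in_interval.sum())
    return {'num_spikes': num_spikes, 'firing_rate': num_spikes / (stop - start)}

# Hypothetical unit: 600 spikes spread evenly over a 300 s sort interval.
example_times = np.linspace(10.0, 309.0, 600)
print(simple_unit_metrics(example_times, (10.0, 310.0)))
```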

In [None]:
cluster_metrics_list_name = 'franklab_cluster_metrics_09-19-2021'

In [None]:
#(SpikeSortingMetricParameters & {'cluster_metrics_list_name' : cluster_metrics_list_name}).delete()

Add the cluster metrics to the table if they are not there already.

In [None]:
SpikeSortingMetricParameters.insert1({'cluster_metrics_list_name' : cluster_metrics_list_name,
                                      'metric_dict' : metric_dict,
                                      'metric_parameter_dict' : metric_param_dict}, skip_duplicates=True)


Add the default automatic curation parameters under the name `none`.

In [None]:
param = AutomaticCurationParameters().get_default_parameters()
AutomaticCurationParameters().insert1({'automatic_curation_parameter_set_name':'none', 
                                      'automatic_curation_parameter_dict': param}, skip_duplicates=True)

Now add an entry to select those parameters for automatic curation of this recording.

In [None]:
# first get the sorting ID
acs_key = (SpikeSortingRecording & ssr_key).fetch1('KEY')
acs_key['sorting_id'] = (SpikeSorting & ss_key).fetch1('sorting_id')
acs_key['automatic_curation_parameter_set_name'] = 'none'
acs_key['cluster_metrics_list_name'] = cluster_metrics_list_name
AutomaticCurationSelection.insert1(acs_key, skip_duplicates=True)

Now we populate the `AutomaticCuration` table. With the `none` parameter set, this only computes the metrics and does not add any labels.

In [None]:
#AutomaticCuration.delete()

In [None]:
AutomaticCuration.populate(acs_key)
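To show what a non-empty rule set could do, here is a sketch that labels units whose metric crosses a threshold. The rule format `{metric: (comparison, threshold, labels)}`, the `label_units` function, and the label names are hypothetical and are not taken from `AutomaticCurationParameters`; the metric values are made up.

```python
import operator

# Supported comparison strings mapped to the corresponding operators.
_COMPARATORS = {'>': operator.gt, '>=': operator.ge, '<': operator.lt, '<=': operator.le}

def label_units(metrics, label_rules):
    """Assign labels to units whose metrics cross a threshold.

    metrics:     {unit_id: {metric_name: value}}
    label_rules: {metric_name: (comparison, threshold, [labels])}  (hypothetical format)
    Returns {unit_id: sorted list of labels}; units matching no rule are omitted.
    """
    labels = {}
    for unit_id, unit_metrics in metrics.items():
        for metric_name, (comparison, threshold, new_labels) in label_rules.items():
            value = unit_metrics.get(metric_name)
            if value is not None and _COMPARATORS[comparison](value, threshold):
                labels.setdefault(unit_id, set()).update(new_labels)
    return {unit_id: sorted(unit_labels) for unit_id, unit_labels in labels.items()}

example_metrics = {1: {'noise_overlap': 0.02}, 2: {'noise_overlap': 0.4}}
print(label_units(example_metrics, {}))  # empty rule set: no labels
print(label_units(example_metrics, {'noise_overlap': ('>', 0.1, ['noise', 'reject'])}))
```

With an empty rule set the function returns no labels, which mirrors the "metrics only" behavior described for the `none` parameter set.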

We can now curate the recording using the figurl interface. To do so, we get the figurl link for this recording.

In [None]:
SpikeSortingWorkspace().url(ssr_key)

Once you're done with manual curation, you can add the units (with an optional new set of metrics) to the final `CuratedSpikeSorting` table, which includes only accepted units.

In [None]:
css_key = (AutomaticCuration & acs_key).fetch1('KEY')
css_key['final_cluster_metrics_list_name'] = cluster_metrics_list_name
CuratedSpikeSortingSelection.insert1(css_key, skip_duplicates=True)

In [None]:
CuratedSpikeSorting.populate(css_key)
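The final table keeps only accepted units. A minimal sketch of that filtering step, assuming each unit carries a list of curation labels and that acceptance is marked by an `accept` label; the label names, the `{unit_id: labels}` format, and the `accepted_units` helper are all hypothetical.

```python
def accepted_units(unit_labels, accept_label='accept'):
    """Return the sorted IDs of units whose labels include `accept_label`.

    unit_labels: {unit_id: list of label strings} (hypothetical format)
    """
    return sorted(unit_id for unit_id, labels in unit_labels.items()
                  if accept_label in labels)

# Hypothetical labels after manual curation.
example_labels = {1: ['accept'], 2: ['noise', 'reject'], 3: ['accept', 'mua'], 4: []}
print(accepted_units(example_labels))
```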

In [None]:
CuratedSpikeSorting.Unit()

In [None]:
sort_groups = (SortGroup & {'nwb_file_name' : nwb_file_name}).fetch('sort_group_id')
sort_groups

In [None]:
SpikeSorting()

In [None]:
dj.Diagram(SpikeSorting) + 5 - 6

In [None]:
dj.Diagram(ModifySorting) + 3 - 3

In [None]:
units = CuratedSpikeSorting().Unit().fetch()

In [None]:
units['noise_overlap']
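`fetch()` without arguments returns a NumPy structured array, so a metric column such as `noise_overlap` can be used directly as a boolean mask. A sketch on a hypothetical stand-in array (the unit IDs, field set, and values are made up, and the 0.1 cutoff is arbitrary):

```python
import numpy as np

# Hypothetical stand-in for CuratedSpikeSorting().Unit().fetch()
units_example = np.array(
    [(1, 0.02, 4.5), (2, 0.15, 0.8), (3, 0.05, 12.0)],
    dtype=[('unit_id', 'i4'), ('noise_overlap', 'f8'), ('firing_rate', 'f8')],
)
# Keep units whose noise overlap is below the (arbitrary) cutoff
low_noise = units_example[units_example['noise_overlap'] < 0.1]
print(low_noise['unit_id'])
```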