# Sorted Spikes Decoding

The mechanics of decoding with sorted spikes are largely similar to those of decoding with unsorted spikes. You should familiarize yourself with the [clusterless decoding tutorial](./42_Decoding_Clusterless.ipynb) before proceeding with this one.

The elements we will need to decode with sorted spikes are:
- `PositionGroup`
- `SortedSpikesGroup`
- `DecodingParameters`
- `encoding_interval`
- `decoding_interval`

This time, instead of extracting waveform features, we can proceed directly from the `SpikeSortingOutput` table to specify which units we want to decode from. The rest of the decoding process is the same as before.
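Each element listed above becomes one or more fields of the selection key inserted in the Decoding section. The sketch below uses a hypothetical helper, `make_sorted_decoding_key`, which is not part of spyglass, to show that mapping with the values used in this tutorial. Note that a `SortedSpikesGroup` entry is identified by both its group name and its `unit_filter_params_name`.

```python
# Hypothetical helper (not part of spyglass): gather the decoding inputs
# listed above into one selection-key dictionary.
def make_sorted_decoding_key(
    nwb_file_name: str,
    sorted_spikes_group_name: str,
    unit_filter_params_name: str,
    position_group_name: str,
    decoding_param_name: str,
    encoding_interval: str,
    decoding_interval: str,
    estimate_decoding_params: bool = False,
) -> dict:
    """Return a dict whose keys mirror the selection key used in this tutorial."""
    return {
        "nwb_file_name": nwb_file_name,
        # A SortedSpikesGroup is identified by its name plus its unit filter
        "sorted_spikes_group_name": sorted_spikes_group_name,
        "unit_filter_params_name": unit_filter_params_name,
        "position_group_name": position_group_name,
        "decoding_param_name": decoding_param_name,
        # Interval list names: where the model is fit vs. where it decodes
        "encoding_interval": encoding_interval,
        "decoding_interval": decoding_interval,
        "estimate_decoding_params": estimate_decoding_params,
    }


key = make_sorted_decoding_key(
    nwb_file_name="mediumnwb20230802_.nwb",
    sorted_spikes_group_name="test_group",
    unit_filter_params_name="default_exclusion",
    position_group_name="test_group",
    decoding_param_name="contfrag_sorted",
    encoding_interval="pos 0 valid times",
    decoding_interval="test decoding interval",
)
print(sorted(key))
```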



In [9]:
from pathlib import Path
import datajoint as dj

dj.config.load(
    Path("../dj_local_conf.json").absolute()
)  # load config for database connection info

## SortedSpikesGroup

`SortedSpikesGroup` is a child table of `SpikeSortingOutput` in the spikesorting pipeline. It allows us to group the spike sorting results from multiple sources (e.g., multiple tetrode groups or intervals) into a single entry. Here we will group together the spiking of multiple tetrode groups to use for decoding.
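As a plain-Python picture of what a group looks like, inferred from the `SortedSpikesGroup` and `SortedSpikesGroup.Units` outputs shown later in this notebook: one master row identifies the group, and one part row per spike sorting merge ID records each source. The helper `build_group_rows` is hypothetical; it does not call spyglass or touch the database.

```python
# Hypothetical illustration only: no spyglass calls, no database access.
def build_group_rows(nwb_file_name, group_name, unit_filter_params_name, merge_ids):
    """Return (master_row, part_rows) for one sorted-spikes group."""
    # The master row identifies the group itself
    master = {
        "nwb_file_name": nwb_file_name,
        "unit_filter_params_name": unit_filter_params_name,
        "sorted_spikes_group_name": group_name,
    }
    # One part row per source sorting, each repeating the master key
    parts = [{**master, "spikesorting_merge_id": merge_id} for merge_id in merge_ids]
    return master, parts


master, parts = build_group_rows(
    nwb_file_name="mediumnwb20230802_.nwb",
    group_name="test_group",
    unit_filter_params_name="default_exclusion",
    merge_ids=[
        "143dff79-3779-c0d2-46fe-7c5040404219",
        "75286bf3-f876-4550-f235-321f2a7badef",
        "a900c1c8-909d-e583-c377-e98c4f0deebf",
    ],
)
print(len(parts))  # 3
```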


This table also allows us to filter units by the annotation labels assigned during curation (e.g., only include units labeled "good", exclude units labeled "noise") by choosing a parameter set from `UnitSelectionParams`. When data are accessed through `SortedSpikesGroup`, the table includes only units with at least one label in `include_labels` and no labels in `exclude_labels`. We can look at the available parameter sets here:
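The selection rule described above can be sketched in plain Python. `keep_unit` is hypothetical and is not spyglass's implementation. It assumes that an empty `include_labels` list places no include requirement on a unit, which is how the `default_exclusion` entry printed by the next cell (empty `include_labels`, excluding `noise` and `mua`) appears intended to behave.

```python
def keep_unit(unit_labels, include_labels, exclude_labels):
    """Hypothetical filter: True if a unit passes the include/exclude rule."""
    labels = set(unit_labels)
    # A unit carrying any excluded label is dropped
    if labels & set(exclude_labels):
        return False
    # Assumption: an empty include list means "no include requirement"
    if not include_labels:
        return True
    # Otherwise the unit needs at least one of the include labels
    return bool(labels & set(include_labels))


# Made-up curation labels for four units
unit_labels = {0: ["good"], 1: ["noise"], 2: [], 3: ["mua", "good"]}
kept = [
    unit_id
    for unit_id, labels in unit_labels.items()
    if keep_unit(labels, include_labels=[], exclude_labels=["noise", "mua"])
]
print(kept)  # [0, 2]
```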


In [1]:
from spyglass.spikesorting.analysis.v1.group import UnitSelectionParams

UnitSelectionParams().insert_default()

# look at the filter set we'll use here
unit_filter_params_name = "default_exclusion"
print(
    (
        UnitSelectionParams()
        & {"unit_filter_params_name": unit_filter_params_name}
    ).fetch1()
)
# look at full table
UnitSelectionParams()

[2024-02-02 12:06:04,725][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2024-02-02 12:06:04,762][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306


{'unit_filter_params_name': 'default_exclusion', 'include_labels': [], 'exclude_labels': ['noise', 'mua']}


| unit_filter_params_name | include_labels | exclude_labels |
|---|---|---|
| all_units | =BLOB= | =BLOB= |
| default_exclusion | =BLOB= | =BLOB= |
| exclude_noise | =BLOB= | =BLOB= |


Now we can make our sorted spikes group with this unit selection parameter set.

In [2]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs

nwb_copy_file_name = "mediumnwb20230802_.nwb"

sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "mountainsort4",
    "curation_id": 1,
}
# check the set of sorting we'll use
(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1

| sorting_id | merge_id | recording_id | sorter | sorter_param_name | nwb_file_name | interval_list_name | curation_id |
|---|---|---|---|---|---|---|---|
| 642242ff-5f0e-45a2-bcc1-ca681f37b4a3 | 75286bf3-f876-4550-f235-321f2a7badef | 01c5b8e9-933d-4f1e-9a5d-c494276edb3a | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 0a6611b3-c593-4900-a715-66bb1396940e | 1 |
| a4b5a94d-ba41-4634-92d0-1d31c9daa913 | 143dff79-3779-c0d2-46fe-7c5040404219 | a8a1d29d-ffdf-4370-8b3d-909fef57f9d4 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 3d782852-a56b-4a9d-89ca-be9e1a15c957 | 1 |
| 874775be-df0f-4850-8f88-59ba1bbead89 | a900c1c8-909d-e583-c377-e98c4f0deebf | 747f4eea-6df3-422b-941e-b5aaad7ec607 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 9cf9e3cd-7115-4b59-a718-3633725d4738 | 1 |


In [6]:
from spyglass.decoding.v1.sorted_spikes import SortedSpikesGroup

SortedSpikesGroup()

| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name |
|---|---|---|
| mediumnwb20230802_.nwb | all_units | test_group |


In [7]:
# get the merge_ids for the selected sorting
spikesorting_merge_ids = (
    (sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1
).fetch("merge_id")

# create a new sorted spikes group
unit_filter_params_name = "default_exclusion"
SortedSpikesGroup().create_group(
    group_name="test_group",
    nwb_file_name=nwb_copy_file_name,
    keys=[
        {"spikesorting_merge_id": merge_id}
        for merge_id in spikesorting_merge_ids
    ],
    unit_filter_params_name=unit_filter_params_name,
)
# check the new group
SortedSpikesGroup & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
}

| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name |
|---|---|---|
| mediumnwb20230802_.nwb | all_units | test_group |
| mediumnwb20230802_.nwb | default_exclusion | test_group |


In [8]:
# look at the sorting within the group we just made
SortedSpikesGroup.Units & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": unit_filter_params_name,
}

| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | spikesorting_merge_id |
|---|---|---|---|
| mediumnwb20230802_.nwb | default_exclusion | test_group | 143dff79-3779-c0d2-46fe-7c5040404219 |
| mediumnwb20230802_.nwb | default_exclusion | test_group | 75286bf3-f876-4550-f235-321f2a7badef |
| mediumnwb20230802_.nwb | default_exclusion | test_group | a900c1c8-909d-e583-c377-e98c4f0deebf |


## Model parameters

As before, we can specify the model parameters. The only difference is that we use the `ContFragSortedSpikesClassifier` instead of the `ContFragClusterlessClassifier`.

In [9]:
from spyglass.decoding.v1.core import DecodingParameters
from non_local_detector.models import ContFragSortedSpikesClassifier


DecodingParameters.insert1(
    {
        "decoding_param_name": "contfrag_sorted",
        "decoding_params": ContFragSortedSpikesClassifier(),
        "decoding_kwargs": dict(),
    },
    skip_duplicates=True,
)

DecodingParameters()

| decoding_param_name | decoding_params | decoding_kwargs |
|---|---|---|
| contfrag_clusterless | =BLOB= | =BLOB= |
| contfrag_clusterless_0.5.13 | =BLOB= | =BLOB= |
| contfrag_clusterless_6track | =BLOB= | =BLOB= |
| contfrag_sorted | =BLOB= | =BLOB= |
| contfrag_sorted_0.5.13 | =BLOB= | =BLOB= |
| j1620210710_contfrag_clusterless_1D | =BLOB= | =BLOB= |
| j1620210710_test_contfrag_clusterless | =BLOB= | =BLOB= |
| MS2220180629_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_2023_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_contfrag_clusterless | =BLOB= | =BLOB= |


### 1D Decoding

As in the clusterless notebook, we can decode 1D position if we specify the `track_graph`, `edge_order`, and `edge_spacing` parameters in the `Environment` class constructor. See the [clusterless decoding tutorial](./42_Decoding_Clusterless.ipynb) for more details.
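To give a feel for what `edge_order` and `edge_spacing` control, here is a minimal sketch of how track edges could be laid end to end on a single linear axis. It assumes, as in the `track_linearization` package, that `edge_spacing` is either one gap used between every pair of consecutive edges or a list with one gap per consecutive pair. The helper `edge_offsets` and the example track are hypothetical, and edge lengths are given directly rather than computed from node positions as a real track graph would.

```python
def edge_offsets(edge_lengths, edge_order, edge_spacing):
    """Hypothetical helper: start of each edge on the linearized 1D axis."""
    n_gaps = len(edge_order) - 1
    # A scalar spacing is reused between every pair of consecutive edges
    if isinstance(edge_spacing, (int, float)):
        gaps = [float(edge_spacing)] * n_gaps
    else:
        gaps = [float(g) for g in edge_spacing]
        if len(gaps) != n_gaps:
            raise ValueError("need one spacing value per pair of consecutive edges")
    offsets = {}
    position = 0.0
    for i, edge in enumerate(edge_order):
        offsets[edge] = position
        position += edge_lengths[edge]
        # Insert the gap after every edge except the last
        if i < n_gaps:
            position += gaps[i]
    return offsets


# Illustrative three-arm track: a stem (0, 1) and two arms leaving node 1
lengths = {(0, 1): 100.0, (1, 2): 50.0, (1, 3): 50.0}
print(edge_offsets(lengths, [(0, 1), (1, 2), (1, 3)], edge_spacing=15))
# {(0, 1): 0.0, (1, 2): 115.0, (1, 3): 180.0}
```

The gaps keep physically disconnected arms from sitting next to each other on the 1D axis, so probability mass cannot appear to flow directly from the end of one arm into the start of the next.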

## Decoding

Now we can decode position from the sorted spikes by inserting an entry into the `SortedSpikesDecodingSelection` table. Here we assume that `PositionGroup` has been specified as in the clusterless decoding tutorial.

In [2]:
from spyglass.decoding import SortedSpikesDecodingSelection

selection_key = {
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": "default_exclusion",
    "position_group_name": "test_group",
    "decoding_param_name": "contfrag_sorted",
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "encoding_interval": "pos 0 valid times",
    "decoding_interval": "test decoding interval",
    "estimate_decoding_params": False,
}

SortedSpikesDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

In [14]:
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1

SortedSpikesDecodingV1.populate(selection_key)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)



We verify that the results have been inserted into the `DecodingOutput` merge table.

In [3]:
from spyglass.decoding.decoding_merge import DecodingOutput

DecodingOutput.SortedSpikesDecodingV1 & selection_key

| merge_id | nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|---|---|
| 42e9e7f9-a6f2-9242-63ce-94228bc72743 | mediumnwb20230802_.nwb | default_exclusion | test_group | test_group | contfrag_sorted | pos 0 valid times | test decoding interval | 0 |


We can load the results as before:

In [6]:
results = (SortedSpikesDecodingV1 & selection_key).fetch_results()
results
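Once the results are loaded, a common first summary is the most probable (maximum a posteriori) position at each time step. The sketch below is a hypothetical helper, `map_position`, operating on a plain NumPy array. It assumes you have already pulled a posterior out of `results` as a (time, position bin) array, summing out any discrete-state dimension (e.g., continuous vs. fragmented) first. The exact variable and coordinate names inside `results` are not shown in this tutorial, so inspect `results` to find them.

```python
import numpy as np


def map_position(posterior, bin_centers):
    """Hypothetical helper: maximum a posteriori position at each time step."""
    posterior = np.asarray(posterior, dtype=float)
    bin_centers = np.asarray(bin_centers, dtype=float)
    if posterior.shape[1] != bin_centers.shape[0]:
        raise ValueError("posterior columns must match the number of position bins")
    # Pick the most probable bin in each row (time step)
    return bin_centers[np.argmax(posterior, axis=1)]


# Toy posterior: 2 time steps x 3 position bins
toy_posterior = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
print(map_position(toy_posterior, bin_centers=[0.0, 5.0, 10.0]))  # [5. 0.]
```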