# Clusterless Decoding

## Overview

_Developer Note:_ if you plan to make a PR in the future, be sure to copy this
notebook and use the `gitignore` prefix `temp` to avoid future conflicts.

This is one notebook in a multi-part series on Spyglass.

- To set up your Spyglass environment and database, see
  [the Setup notebook](./00_Setup.ipynb)
- This tutorial assumes you've already
  [extracted waveforms](./41_Extracting_Clusterless_Waveform_Features.ipynb), as well as loaded
  [position data](./20_Position_Trodes.ipynb). For 1D decoding, the position data should also be
  [linearized](./24_Linearization.ipynb).

Clusterless decoding can be performed on either 1D or 2D data. We will start with 2D data.

## Elements of Clusterless Decoding
- **Position Data**: This is the data that we want to decode. It can be 1D or 2D.
- **Spike Waveform Features**: These are the features that we will use to decode the position data.
- **Decoding Model Parameters**: This is how we define the model that we will use to decode the position data.

## Grouping Data
An important concept will be groups. Groups are tables that allow us to specify collections of data. We will use groups in two situations here:
1. We want to decode from more than one tetrode (or probe), so we will create a group that contains all of the tetrodes that we want to decode from.
2. Similarly, we will create a group for the position data that we want to decode, so that we can decode position data from multiple sessions.

### Grouping Waveform Features
Let's start with grouping the Waveform Features. We will first inspect the waveform features that we have extracted to figure out the primary keys of the data that we want to decode from. We need to use the tables `SpikeSortingSelection` and `SpikeSortingOutput` to figure out the `merge_id` associated with `nwb_file_name` to get the waveform features associated with the NWB file of interest.


In [5]:
from pathlib import Path
import datajoint as dj

dj.config.load(
    Path("../dj_local_conf.json").absolute()
)  # load config for database connection info

In [6]:
from spyglass.spikesorting.merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs
from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection


nwb_copy_file_name = "mediumnwb20230802_.nwb"

sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "clusterless_thresholder",
    "sorter_param_name": "default_clusterless",
}

feature_key = {"features_param_name": "amplitude"}

(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 * (
    UnitWaveformFeaturesSelection & feature_key
)

[2023-12-25 16:02:39,751][INFO]: Connecting root@localhost:3306
[2023-12-25 16:02:39,838][INFO]: Connected root@localhost:3306


sorting_id,merge_id,features_param_name  a name for this set of parameters,recording_id,sorter,sorter_param_name,nwb_file_name  name of the NWB file,interval_list_name  descriptive name of this interval list,curation_id
0cbd8579-6c48-4506-a116-e27e9b89f174,86acdb0f-84f0-73a2-a851-1f8305cd2e41,amplitude,458ba3c2-3a08-4291-af10-c74d823330d4,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,b52d303b-12b1-4584-8f35-63dde543836c,0
22283413-433c-4c6f-b2fa-f82a10327df7,46829e10-1984-99a1-65a3-2b485a2f037f,amplitude,8e3daa47-2ee6-435f-892b-f095b1c5aa1a,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,1536b082-018d-4674-b562-cac09d298b7f,0
22e6bc74-e755-440c-a507-f9292fd494c9,ec308784-2bfb-dd90-147c-e4d44e5f649b,amplitude,7fb98af6-d486-439f-ae1b-7abdfddae56b,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,b1a34880-c87b-403f-8a3b-c346e614c782,0
32fa3502-7fa9-469b-a7a1-0e0e670fe28e,4b3065e5-76c2-bd48-32a1-ae62484f9314,amplitude,ef989f5a-3cf4-488d-be1f-660970fdfd69,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,4fa14f8e-14a2-49ce-a1b4-6c447fdc3a1e,0
338442ef-821c-401e-91ba-8eec27490701,609aeb54-dc2e-52d3-91bf-1728e0a2cf09,amplitude,86d39675-d6b0-4697-b336-9b2b1766d8f3,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,3824e250-27e5-4cc5-a49d-56d9e37b3ad8,0
43495249-ab6b-4067-b04a-11401b998215,88492b1c-f4a9-9669-bb5b-7f1573015187,amplitude,ded4b85c-a2f8-465d-ab21-504905c06403,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,7f38783a-215f-47c1-853b-2e1ddc941d7f,0
43a6942c-668e-44a1-aa5b-a7aebc5c424a,f515c07f-fc80-b28a-750d-d0d5491259f4,amplitude,078776e3-1b9c-4755-bef8-b9201bcdd717,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,ac7875e6-a370-4cf3-a74e-263f0d98a17a,0
4986cd16-515f-441a-8653-36cf3a312ca0,f4e29a80-ec96-dbe8-7081-425ac311b74c,amplitude,db9d73cf-f9e2-46b4-8eb7-a8d059d99bf6,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,d630e3bb-10b2-4466-9c20-1db14565bcf4,0
59e06873-aae3-438a-8bc1-2988315b3d7e,d7754d5f-af01-19f4-3fdc-c9635081667a,amplitude,aeda79a6-8442-4a39-93b7-bce6da6fcacd,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,6bea1980-8ea0-4160-afc3-aef93743fb9d,0
67b0fafd-693f-4a26-a20b-100c0a4731a7,2567bf67-bc67-47a5-aa2a-2bce19da232d,amplitude,47337655-182c-4c9d-b79d-ea0c6ce51b34,clusterless_thresholder,default_clusterless,mediumnwb20230802_.nwb,b7fc2304-9cbf-4d85-8028-39cab674273a,0


In [91]:
from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection

spikesorting_merge_id = (
    (sgs.SpikeSortingSelection & sorter_keys)
    * SpikeSortingOutput.CurationV1
    * (UnitWaveformFeaturesSelection & feature_key)
).fetch("merge_id")

waveform_selection_keys = [
    {"merge_id": merge_id, "features_param_name": "amplitude"}
    for merge_id in spikesorting_merge_id
]

UnitWaveformFeaturesSelection & waveform_selection_keys

merge_id,features_param_name  a name for this set of parameters
00763b68-d663-c446-0555-1f2622d7da50,amplitude
03954edd-f8fd-3dd9-cd10-f0eee47d6b3d,amplitude
0720e5f2-625e-09d2-b522-ca2652c09f2a,amplitude
153954b2-b230-cb1f-749d-f977a22eaae9,amplitude
189fb8c6-f964-00a9-f392-a9dbb138ea63,amplitude
2567bf67-bc67-47a5-aa2a-2bce19da232d,amplitude
26310ce7-9ac3-4159-99f8-a3ad17037235,amplitude
411dff13-44f0-3e03-e867-689ae275e418,amplitude
43a98eab-1fa6-184b-1f09-2e923984b03a,amplitude
46829e10-1984-99a1-65a3-2b485a2f037f,amplitude


In [74]:
# from spyglass.common import BrainRegion, Electrode

# (
#     (
#         sgs.SpikeSortingRecordingSelection
#         * sgs.SortGroup.SortGroupElectrode
#         * Electrode
#         * BrainRegion
#     )
#     & [
#         {
#             "recording_id": recording_id,
#         }
#         for recording_id in (
#             SpikeSortingOutput.CurationV1() * sgs.SpikeSortingSelection()
#         ).fetch("recording_id")
#     ]
# )

We will create a group called `test_group` that contains all of the tetrodes that we want to decode from. We will use the `create_group` function to create this group. This function takes two arguments: the name of the group, and the keys of the tables that we want to include in the group.

In [93]:
from spyglass.decoding.v1.clusterless import UnitWaveformFeaturesGroup

UnitWaveformFeaturesGroup().create_group("test_group", waveform_selection_keys)
UnitWaveformFeaturesGroup & {"waveform_features_group_name": "test_group"}

waveform_features_group_name
test_group


We can confirm that "test_group" is associated with the tetrodes that we want to decode from by querying the `UnitFeatures` part table.

In [94]:
UnitWaveformFeaturesGroup.UnitFeatures & {
    "waveform_features_group_name": "test_group"
}

waveform_features_group_name,merge_id,features_param_name  a name for this set of parameters
test_group,00763b68-d663-c446-0555-1f2622d7da50,amplitude
test_group,03954edd-f8fd-3dd9-cd10-f0eee47d6b3d,amplitude
test_group,0720e5f2-625e-09d2-b522-ca2652c09f2a,amplitude
test_group,153954b2-b230-cb1f-749d-f977a22eaae9,amplitude
test_group,189fb8c6-f964-00a9-f392-a9dbb138ea63,amplitude
test_group,2567bf67-bc67-47a5-aa2a-2bce19da232d,amplitude
test_group,26310ce7-9ac3-4159-99f8-a3ad17037235,amplitude
test_group,411dff13-44f0-3e03-e867-689ae275e418,amplitude
test_group,43a98eab-1fa6-184b-1f09-2e923984b03a,amplitude
test_group,46829e10-1984-99a1-65a3-2b485a2f037f,amplitude


### Grouping Position Data

We will now create a group called `test_group` that contains all of the position data that we want to decode from. As before, we will use the `create_group` function to create this group. This function takes two arguments: the name of the group, and the keys of the tables that we want to include in the group.

We use the `PositionOutput` table to figure out the `merge_id` associated with `nwb_file_name` to get the position data associated with the NWB file of interest. In this case, we only have one position to insert, but we could insert multiple positions if we wanted to decode from multiple sessions.

Note that the position data sampling frequency determines the time step of the decoding. In this case, the position data sampling frequency is 30 Hz, so the time step of the decoding will be 1/30 seconds. In practice, you will want a smaller time step, such as 2 ms (a sampling frequency of 500 Hz), which allows you to decode at a finer time scale. To do this, interpolate the position data to a higher sampling frequency as shown in the [position trodes notebook](./20_Position_Trodes.ipynb).
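
As a rough illustration of the upsampling step, the sketch below linearly interpolates 30 Hz position samples onto a 500 Hz time grid with NumPy. The function name `upsample_position` and the synthetic data are hypothetical and are not part of Spyglass; for real data, follow the interpolation shown in the position notebook linked above.

```python
import numpy as np


def upsample_position(time, position, new_rate_hz):
    """Linearly interpolate position samples onto a regular, finer time grid.

    Hypothetical helper for illustration only; not a Spyglass function.
    """
    time = np.asarray(time, dtype=float)
    position = np.asarray(position, dtype=float)

    # Build a regular time grid that stays within the original time range
    duration = time[-1] - time[0]
    n_new = int(np.floor(duration * new_rate_hz)) + 1
    new_time = time[0] + np.arange(n_new) / new_rate_hz

    # Interpolate each position dimension independently
    new_position = np.column_stack(
        [
            np.interp(new_time, time, position[:, dim])
            for dim in range(position.shape[1])
        ]
    )
    return new_time, new_position


# Synthetic example: one second of 2D position sampled at 30 Hz
time_30hz = np.arange(31) / 30.0
position_30hz = np.column_stack([2.0 * time_30hz, -1.0 * time_30hz])

time_500hz, position_500hz = upsample_position(time_30hz, position_30hz, 500.0)
print(np.diff(time_500hz).mean())  # decoding time step in seconds
```

Linear interpolation is only one option; it is used here because it makes the relationship between sampling frequency and decoding time step easy to see.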

You will also want to specify the name of the position variables if they are different from the default names. The default names are `position_x` and `position_y`.

In [97]:
from spyglass.position import PositionOutput

PositionOutput.TrodesPosV1 & {"nwb_file_name": nwb_copy_file_name}

merge_id,nwb_file_name  name of the NWB file,interval_list_name  descriptive name of this interval list,trodes_pos_params_name  name for this set of parameters
a95d0105-de87-9c2f-85cf-f940e0490bee,mediumnwb20230802_.nwb,pos 0 valid times,default


In [104]:
from spyglass.decoding.v1.clusterless import PositionGroup

position_merge_ids = (
    PositionOutput.TrodesPosV1
    & {
        "nwb_file_name": nwb_copy_file_name,
        "interval_list_name": "pos 0 valid times",
        "trodes_pos_params_name": "default",
    }
).fetch("merge_id")

PositionGroup().create_group(
    "test_group",
    [{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
)

PositionGroup & {"position_group_name": "test_group"}

position_group_name,position_variables  list of position variables to decode
test_group,=BLOB=


In [106]:
(PositionGroup & {"position_group_name": "test_group"}).fetch1(
    "position_variables"
)

['position_x', 'position_y']

In [107]:
PositionGroup.Position & {"position_group_name": "test_group"}

position_group_name,pos_merge_id
test_group,a95d0105-de87-9c2f-85cf-f940e0490bee


## Decoding Model Parameters

We will use the `non_local_detector` package to decode the data. This package is highly flexible and allows several different types of models to be used. In this case, we will use the `ContFragClusterlessClassifier` to decode the data. This has two discrete states: Continuous and Fragmented, which correspond to different types of movement models. To read more about this model, see:
> Denovellis, E.L., Gillespie, A.K., Coulter, M.E., Sosa, M., Chung, J.E., Eden, U.T., and Frank, L.M. (2021). Hippocampal replay of experience at real-world speeds. eLife 10, e64505. [10.7554/eLife.64505](https://doi.org/10.7554/eLife.64505).

Let's first look at the model and the default parameters:


In [110]:
from non_local_detector.models import ContFragClusterlessClassifier

ContFragClusterlessClassifier()

You can change these parameters like so: 

In [112]:
from non_local_detector.models import ContFragClusterlessClassifier

ContFragClusterlessClassifier(
    clusterless_algorithm_params={
        "block_size": 10000,
        "position_std": 12.0,
        "waveform_std": 24.0,
    },
)

To insert these parameters into the database, we first convert the initialized model into a dictionary, and then insert that dictionary into the database.

In [114]:
vars(ContFragClusterlessClassifier())

{'discrete_initial_conditions': array([0.5, 0.5]),
 'continuous_initial_conditions_types': [UniformInitialConditions(),
  UniformInitialConditions()],
 'discrete_transition_concentration': 1.1,
 'discrete_transition_stickiness': array([0., 0.]),
 'discrete_transition_regularization': 1e-10,
 'discrete_transition_type': DiscreteStationaryDiagonal(diagonal_values=array([0.98, 0.98])),
 'continuous_transition_types': [[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use_manifold_distance=False, direction=None),
   Uniform(environment_name='', environment2_name=None)],
  [Uniform(environment_name='', environment2_name=None),
   Uniform(environment_name='', environment2_name=None)]],
 'environments': (Environment(environment_name='', place_bin_size=2.0, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0),),
 'infer_track_interior': True,
 'obs
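
The `discrete_transition_type` entry above, `DiscreteStationaryDiagonal(diagonal_values=array([0.98, 0.98]))`, gives the probability of staying in each discrete state from one time step to the next. The sketch below builds the corresponding transition matrix by hand to show what those values imply. The helper `diagonal_transition_matrix` is hypothetical, it assumes the remaining probability is spread evenly over the other states, and it is not the `non_local_detector` implementation.

```python
import numpy as np


def diagonal_transition_matrix(diagonal_values):
    """Build a row-stochastic matrix with the given stay probabilities.

    Hypothetical helper for illustration only. Assumes the probability of
    leaving a state is split evenly among the other states.
    """
    diag = np.asarray(diagonal_values, dtype=float)
    n_states = diag.size

    # Probability of moving to each of the other states, per row
    off_diagonal = (1.0 - diag) / (n_states - 1)
    matrix = np.repeat(off_diagonal[:, np.newaxis], n_states, axis=1)

    # Overwrite the diagonal with the stay probabilities
    np.fill_diagonal(matrix, diag)
    return matrix


# Two discrete states: Continuous and Fragmented
transition = diagonal_transition_matrix([0.98, 0.98])
print(transition)

# Each row is a probability distribution over the next state
print(transition.sum(axis=1))
```

With a stay probability of 0.98, a state persists for about 1 / (1 - 0.98) = 50 time steps on average, so the decoding time step directly affects how long each state is expected to last.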

In [116]:
from spyglass.decoding.v1.core import DecodingParameters


DecodingParameters.insert1(
    {
        "decoding_param_name": "contfrag_clusterless",
        "decoding_params": vars(ContFragClusterlessClassifier()),
        "decoding_kwargs": dict(),
    },
    skip_duplicates=True,
)

DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}

decoding_param_name  a name for this set of parameters,decoding_params  initialization parameters for model,decoding_kwargs  additional keyword arguments
contfrag_clusterless,=BLOB=,=BLOB=


We can retrieve these parameters and rebuild the model like so:

In [126]:
model_params = (
    DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}
).fetch1()

ContFragClusterlessClassifier(**model_params["decoding_params"])
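
The store-and-rebuild pattern above relies on `vars` returning a dictionary whose keys match the constructor's keyword arguments. A minimal sketch with a hypothetical stand-in class (`ToyClassifier` is not part of `non_local_detector`) shows the round trip:

```python
class ToyClassifier:
    """Hypothetical stand-in for a model class; illustration only."""

    def __init__(self, place_bin_size=2.0, movement_var=6.0):
        # Attribute names match the constructor arguments, which is what
        # makes the vars() -> **kwargs round trip work.
        self.place_bin_size = place_bin_size
        self.movement_var = movement_var


original = ToyClassifier(movement_var=4.0)

# dict(...) copies the attribute dictionary, so later edits to the stored
# parameters do not modify the original object.
stored_params = dict(vars(original))

# Rebuild an equivalent model from the stored parameters
rebuilt = ToyClassifier(**stored_params)
print(vars(rebuilt) == vars(original))
```

If a class stored attributes under names that differ from its constructor arguments, the `**` unpacking step would fail, so this pattern is worth checking whenever you switch to a different model class.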

## Decoding

Now that we have grouped the data and defined the model parameters, we have finally set up the elements in tables that we need to decode the data. We now need to use the `ClusterlessDecodingSelection` to fully specify all the parameters and data that we want.

This has:
- `waveform_features_group_name`: the name of the group that contains the waveform features that we want to decode from
- `position_group_name`: the name of the group that contains the position data that we want to decode from
- `decoding_param_name`: the name of the decoding parameters that we want to use
- `nwb_file_name`: the name of the NWB file that we want to decode from
- `encoding_interval`: the interval of time that we want to train the initial model on
- `decoding_interval`: the interval of time that we want to decode from
- `estimate_decoding_params`: whether or not we want to estimate the decoding parameters


Notice the last three parameters. The `encoding_interval` is the interval of time that we want to train the initial model on. The `decoding_interval` is the interval of time that we want to decode from. These two intervals can be the same, but they do not have to be. For example, we may want to train the model on a long interval of time, but only decode from a short interval of time. This is useful if we want to decode from a short interval of time that is not representative of the entire session. In this case, we will train the model on a longer interval of time that is representative of the entire session.

These keys come from the `IntervalList` table. The `IntervalList` table contains the `nwb_file_name` and `interval_list_name` that we need to specify the `encoding_interval` and `decoding_interval`.
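
To make the distinction between the two intervals concrete, the sketch below selects the time samples that fall inside an encoding interval and a shorter decoding interval. It assumes each interval list is an array of start and end times, as the `valid_times` column of `IntervalList` describes; the helper `in_intervals` and the interval values are hypothetical.

```python
import numpy as np


def in_intervals(time, intervals):
    """Return a boolean mask of samples inside any [start, end] interval.

    Hypothetical helper for illustration only; not a Spyglass function.
    """
    time = np.asarray(time, dtype=float)
    mask = np.zeros(time.shape, dtype=bool)
    for start, end in np.atleast_2d(intervals):
        mask |= (time >= start) & (time <= end)
    return mask


time = np.arange(0.0, 10.0, 0.5)  # 20 samples from 0.0 s to 9.5 s

# Train on the whole recording, decode only a short window
encoding_interval = np.array([[0.0, 9.5]])
decoding_interval = np.array([[2.0, 4.0]])

is_encoding = in_intervals(time, encoding_interval)
is_decoding = in_intervals(time, decoding_interval)
print(is_encoding.sum(), is_decoding.sum())
```

Here every decoded sample also lies inside the encoding interval, but nothing requires that: the two intervals can be chosen independently.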

The last parameter is `estimate_decoding_params`. This is a boolean that specifies whether or not we want to estimate the decoding parameters. If this is `True`, then we will estimate the initial conditions and discrete transition matrix from the data.

In [128]:
from spyglass.decoding.v1.clusterless import ClusterlessDecodingSelection

ClusterlessDecodingSelection()

waveform_features_group_name,position_group_name,decoding_param_name  a name for this set of parameters,nwb_file_name  name of the NWB file,encoding_interval  descriptive name of this interval list,decoding_interval  descriptive name of this interval list,estimate_decoding_params  whether to estimate the decoding parameters
test,02_r1,contfrag_clusterless,mediumnwb20230802_.nwb,pos 0 valid times,decoding interval,0
test,02_r1,contfrag_clusterless,mediumnwb20230802_.nwb,pos 0 valid times,decoding interval5,0


In [132]:
from spyglass.common import IntervalList

IntervalList & {"nwb_file_name": nwb_copy_file_name}

nwb_file_name  name of the NWB file,interval_list_name  descriptive name of this interval list,valid_times  numpy array with start/end times for each interval
mediumnwb20230802_.nwb,02_r1,=BLOB=
mediumnwb20230802_.nwb,03143dcd-d09a-4216-8000-631d346875ad,=BLOB=
mediumnwb20230802_.nwb,078776e3-1b9c-4755-bef8-b9201bcdd717,=BLOB=
mediumnwb20230802_.nwb,078a7847-23a6-4820-a71e-e0f4fc5b31b8,=BLOB=
mediumnwb20230802_.nwb,1536b082-018d-4674-b562-cac09d298b7f,=BLOB=
mediumnwb20230802_.nwb,23e7dd2c-b24f-4bd3-b769-b3dfbcc9dfbd,=BLOB=
mediumnwb20230802_.nwb,2ce6a87c-2c6b-4fd9-af00-35f181c3fd2f,=BLOB=
mediumnwb20230802_.nwb,333b230a-14d8-45c0-bd0d-1eec9797152e,=BLOB=
mediumnwb20230802_.nwb,3824e250-27e5-4cc5-a49d-56d9e37b3ad8,=BLOB=
mediumnwb20230802_.nwb,3a5e3bf4-8bdb-4050-afb9-c3034f204ff7,=BLOB=


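Once we have figured out the keys that we need, we can insert them into the `ClusterlessDecodingSelection` table. Note that the key below refers to the groups `test` and `02_r1`, which already appear in the selection table above; to use the groups created earlier in this notebook, substitute `test_group` for both names.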

In [134]:
selection_key = {
    "waveform_features_group_name": "test",
    "position_group_name": "02_r1",
    "decoding_param_name": "contfrag_clusterless",
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "encoding_interval": "pos 0 valid times",
    "decoding_interval": "decoding interval5",
    "estimate_decoding_params": False,
}

ClusterlessDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

ClusterlessDecodingSelection & selection_key

waveform_features_group_name,position_group_name,decoding_param_name  a name for this set of parameters,nwb_file_name  name of the NWB file,encoding_interval  descriptive name of this interval list,decoding_interval  descriptive name of this interval list,estimate_decoding_params  whether to estimate the decoding parameters
test,02_r1,contfrag_clusterless,mediumnwb20230802_.nwb,pos 0 valid times,decoding interval5,0


To run decoding, we simply populate the `ClusterlessDecodingV1` table. This will run the decoding and insert the results into the database. We can then retrieve the results from the database.

In [135]:
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

ClusterlessDecodingV1.populate(selection_key)

We can now see it as an entry in the `DecodingOutput` table.

In [138]:
from spyglass.decoding.decoding_merge import DecodingOutput

DecodingOutput.ClusterlessDecodingV1 & selection_key

merge_id,waveform_features_group_name,position_group_name,decoding_param_name  a name for this set of parameters,nwb_file_name  name of the NWB file,encoding_interval  descriptive name of this interval list,decoding_interval  descriptive name of this interval list,estimate_decoding_params  whether to estimate the decoding parameters
abf0dacc-472c-bf7c-041a-c00b9098e463,test,02_r1,contfrag_clusterless,mediumnwb20230802_.nwb,pos 0 valid times,decoding interval5,0


We can load the results of the decoding:

In [141]:
decoding_results = (ClusterlessDecodingV1 & selection_key).load_results()
decoding_results



Finally, if we have deleted entries from the table, we can use the `cleanup` function to remove the corresponding result files from the file system:

In [143]:
DecodingOutput().cleanup()

[12:51:35][INFO] Spyglass: Cleaning up decoding outputs


