# Debuggin DataJoint Pipeline: Spike Sorting Computations


This notebook is designed to help advanced users to quickly debug the computations and/or investigate errors that may occur during the execution of a DataJoint `populate` function.

It provides a briefly guide on how to dissect the `make` function, enabling a deeper understanding of the pipeline's computational steps and facilitating faster issue resolution.

**Note: This notebook is intended as a supplementary tool for debugging and should not replace best practices in coding development.**

The spike sorting analysis is managed by the `ephys_sorter` schema containing three main tables in the DataJoint pipeline:

1. PreProcessing

2. SIClustering

3. PostProcessing

Please review and understand the code for each table [here](https://github.com/dj-sciops/utah_organoids_element-array-ephys/blob/main/element_array_ephys/spike_sorting/si_spike_sorting.py).


### **Key Steps**

- **Setup**

- **Step 1: Select Session of Interest**

- **Step 2: `populate` Necessary Tables before Spike Sorting**

- **Step 3: Execute Each Part of the Spike Sorting Computations to Debug**


#### **Setup**


First, import the necessary packages for the data pipeline and essential schemas.


In [1]:
import os

if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [2]:
import datajoint as dj
import datetime
import pandas as pd
import numpy as np

In [3]:
from workflow.pipeline import ephys, ephys_sorter

[2024-10-07 19:08:06,552][INFO]: Connecting milagros@db.datajoint.com:3306
[2024-10-07 19:08:08,218][INFO]: Connected milagros@db.datajoint.com:3306


#### **Step 1: Select Session of Interest**


In this notebook, please ensure that the session you are inserting and using for debugging purposes has **`session_type=test`** to easily distinguish it from real ephys sessions in the pipeline.

**Note**: If you have already inserted one or more ephys sessions with `session_type=spike_sorting`, please delete those entries from the database to avoid mixing them with real sessions. This will help maintain clean management of the sessions during debugging.


In [4]:
session_info = {
    "organoid_id": "O09",
    "experiment_start_time": datetime.datetime(2023, 5, 18, 12, 25),
    "start_time": "2023-05-18 12:25:00",
    "end_time": "2023-05-18 12:26:30",
    "insertion_number": 0,
    "session_type": "test",  # Use this `session_type` for testing purposes
}

session_probe_info = dict(
    **session_info,
    probe="Q983",  # probe serial number
    port_id="A",  # Port ID ("A", "B", etc.)
    used_electrodes=[],  # electrodes used for the session; empty if all electrodes were used
)

In [5]:
# Ensure your `session_key` is inserted in the `EphysSession` and `EphysSessionProbe`
ephys.EphysSession.insert1(session_info)
ephys.EphysSessionProbe.insert1(session_probe_info, ignore_extra_fields=True)

DuplicateError: ("Duplicate entry 'O09-2023-05-18 12:25:00-0-2023-05-18 12:25:00-2023-05-18 12:26:3' for key 'ephys_session.PRIMARY'", 'To ignore duplicate entries in insert, set skip_duplicates=True')

In [6]:
session_key = (ephys.EphysSession & session_info).fetch1("KEY")
session_key

{'organoid_id': 'O09',
 'experiment_start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'insertion_number': 0,
 'start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'end_time': datetime.datetime(2023, 5, 18, 12, 26, 30)}

In [7]:
ephys.EphysSession * ephys.EphysSessionProbe & session_key

organoid_id  e.g. O17,experiment_start_time,insertion_number,start_time,end_time,session_type,probe  unique identifier for this model of probe (e.g. serial number),port_id,"used_electrodes  list of electrode IDs used in this session (if null, all electrodes are used)"
O09,2023-05-18 12:25:00,0,2023-05-18 12:25:00,2023-05-18 12:26:30,test,Q983,A,=BLOB=


#### **Step 2: `populate` Necessary Tables before Spike Sorting**


Populate the necessary tables:


In [8]:
ephys.EphysSessionInfo.populate(session_key)

{'success_count': 0, 'error_list': []}

In [9]:
ephys.ClusteringParamSet()

paramset_idx,clustering_method,paramset_desc,param_set_hash,params  dictionary of all applicable parameters
0,spykingcircus2,Default parameters for spyking circus2 using SpikeInterface v0.100.1,b6fb9ec2-768c-66b0-2b71-9b8ac91e94da,=BLOB=
1,spykingcircus2,Default parameter set for spyking circus2 using SpikeInterface v0.101.*,434894d0-eb7b-db6c-80e6-638a1322c568,=BLOB=
2,kilosort2,kilosort2 with SpikeInterface version 0.101+,79a731f3-f1b6-c110-5f8a-e25227464de7,=BLOB=
5,spykingcircus2,Spyking circus2 with a detection threshold 5 (neg direction),4c895afd-a1b1-5d64-b747-e8489078e2e3,=BLOB=
11,spykingcircus2,waveform>threshold: .25->2,17d41d84-067d-791c-8706-8cab83020b84,=BLOB=
12,spykingcircus2,waveform>threshold: .25->2 attempt 2,2b28cf23-2456-8202-b70f-96871b837a26,=BLOB=
13,spykingcircus2,waveform>threshold: .25->2 attempt 2,1faf6aee-71d6-fe26-74ec-6bb7cdc0f30f,=BLOB=
14,spykingcircus2,apply_preprocessing = False,ce720015-b59a-08d6-198e-def81c860f46,=BLOB=
15,spykingcircus2,"apply_preprocessing, matched_filtering, and apply_motion_correction = False",5f7a8362-c31c-061e-14b2-74ad55466546,=BLOB=
16,spykingcircus2,"default parameters, different format",0a3d0360-c0de-6c30-9c35-7c931a9a6f62,=BLOB=


**Note**: The next step is to create a `ClusteringTask` for the `ephys_sorter` to ingest the task and automatically process the spike sorting. However, in this specific notebook, the goal is to run the spike sorting individually and ensure that these `test` entries are not considered by the `ephys_sorter` schema. Therefore:

- Do not create a `ClusteringTask` for the `test` entry.
- Instead, modify the code of the `PreProcessing` make function accordingly.


In [10]:
# Task key
key = {**session_key, "paramset_idx": 101}
key

{'organoid_id': 'O09',
 'experiment_start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'insertion_number': 0,
 'start_time': datetime.datetime(2023, 5, 18, 12, 25),
 'end_time': datetime.datetime(2023, 5, 18, 12, 26, 30),
 'paramset_idx': 101}

#### **Step 3: Execute Each Part of the Spike Sorting Computations to Debug**


To debug, copy and paste the code of the three `make` functions into different code cells as needed (reference [here](https://github.com/dj-sciops/utah_organoids_element-array-ephys/blob/main/element_array_ephys/spike_sorting/si_spike_sorting.py)). This allows you to, for instance, check variables or reproduce the `si_recording` and `si_sorting` objects for exploration and testing purposes.


In [11]:
import spikeinterface as si
from element_array_ephys import probe, readers
from element_interface.utils import find_full_path, memoized_result
from spikeinterface import exporters, postprocessing, qualitymetrics, sorters

# This line has been updated to import the module here
from element_array_ephys.spike_sorting import si_preprocessing

In [12]:
# ----------------- First Part of the PreProcessing Make Function Copied Here ----------------- #

# Get clustering method and output directory.

### Original code
# clustering_method, output_dir, params = (
#     ephys.ClusteringTask * ephys.ClusteringParamSet & key
# ).fetch1("clustering_method", "clustering_output_dir", "params")

### Modified code here
clustering_method, params = (ephys.ClusteringParamSet & key).fetch1(
    "clustering_method", "params"
)

acq_software = (ephys.EphysRawFile & key).fetch("acq_software", limit=1)[0]

# Get sorter method and create output directory.
sorter_name = clustering_method.replace(".", "_")

for required_key in (
    "SI_PREPROCESSING_METHOD",
    "SI_SORTING_PARAMS",
    "SI_POSTPROCESSING_PARAMS",
):
    if required_key not in params:
        raise ValueError(
            f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
        )

# Set directory to store recording file.
# if not output_dir:
#     output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
#     # update clustering_output_dir
#     ephys.ClusteringTask.update1(
#         {**key, "clustering_output_dir": output_dir.as_posix()}
#     )

# create the folder if it does not exist
rel_output_dir = "O09-12_raw/202305181225_20230518122630/O09/spykingcircus2_101_example"
output_dir = ephys.get_processed_root_data_dir() / rel_output_dir
os.makedirs(output_dir, exist_ok=True)

In [13]:
############### New code here to debug
print(output_dir)

/Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_20230518122630/O09/spykingcircus2_101_example


In [14]:
# ----------------- Part of the PreProcessing Make Function Copied Here ----------------- #

recording_dir = output_dir / sorter_name / "recording"
recording_dir.mkdir(parents=True, exist_ok=True)
recording_file = recording_dir / "si_recording.pkl"

# Get probe information to recording object
probe_info = (probe.Probe * ephys.EphysSessionProbe & key).fetch1()
electrode_query = probe.ElectrodeConfig.Electrode & (
    probe.ElectrodeConfig & {"probe_type": probe_info["probe_type"]}
)

# Filter for used electrodes. If probe_info["used_electrodes"] is None, it means all electrodes were used.
number_of_electrodes = len(electrode_query)
probe_info["used_electrodes"] = (
    probe_info["used_electrodes"]
    if probe_info["used_electrodes"] is not None and len(probe_info["used_electrodes"])
    else list(range(number_of_electrodes))
)
unused_electrodes = [
    elec
    for elec in range(number_of_electrodes)
    if elec not in probe_info["used_electrodes"]
]
electrodes_df = (
    (probe.ProbeType.Electrode * electrode_query)
    .fetch(format="frame", order_by="electrode")
    .reset_index()[["electrode", "x_coord", "y_coord", "shank", "channel_idx"]]
)

"""Get the row indices of the port from the data matrix."""
session_info = (ephys.EphysSessionInfo & key).fetch1("session_info")
port_indices = np.array(
    [
        ind
        for ind, ch in enumerate(session_info["amplifier_channels"])
        if ch["port_prefix"] == probe_info["port_id"]
    ]
)  # get the row indices of the port

# Create SI recording extractor object
si_extractor: si.extractors.neoextractors = (
    si.extractors.extractorlist.recording_extractor_full_dict[
        acq_software.replace(" ", "").lower()
    ]
)  # data extractor object

files, file_times = (
    ephys.EphysRawFile
    & key
    & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
).fetch("file_path", "file_time", order_by="file_time")

si_recording = None
# Read data. Concatenate if multiple files are found.
for file_path in (find_full_path(ephys.get_ephys_root_data_dir(), f) for f in files):
    if not si_recording:
        stream_name = [
            s for s in si_extractor.get_streams(file_path)[0] if "amplifier" in s
        ][0]
        si_recording: si.BaseRecording = si_extractor(
            file_path, stream_name=stream_name
        )
    else:
        si_recording: si.BaseRecording = si.concatenate_recordings(
            [
                si_recording,
                si_extractor(file_path, stream_name=stream_name),
            ]
        )

si_recording = si_recording.channel_slice(
    si_recording.channel_ids[port_indices]
)  # select only the port data

# Create SI probe object
si_probe = readers.probe_geometry.to_probeinterface(electrodes_df)
si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
si_recording.set_probe(probe=si_probe, in_place=True)

# Account for additional electrodes being removed
if unused_electrodes:
    chn_ids_to_remove = [
        f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}"
        for elec in unused_electrodes
    ]
else:
    chn_ids_to_remove = []

si_recording = si_recording.remove_channels(remove_channel_ids=chn_ids_to_remove)

# Run preprocessing and save results to output folder
si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
si_recording = si_preproc_func(si_recording)
si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir)

In [15]:
############### New code here to debug
# `si_recording` can be explored here
si_recording

In [16]:
# ----------------- SIClustering Make Function Copied Here ----------------- #

# Load recording object.
# clustering_method, output_dir, params = (
#     ephys.ClusteringTask * ephys.ClusteringParamSet & key
# ).fetch1("clustering_method", "clustering_output_dir", "params")
# output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
sorter_name = clustering_method.replace(".", "_")
recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
si_recording: si.BaseRecording = si.load_extractor(
    recording_file, base_folder=output_dir
)

sorting_params = params["SI_SORTING_PARAMS"]
sorting_output_dir = output_dir / sorter_name / "spike_sorting"


# Run sorting
@memoized_result(
    uniqueness_dict=sorting_params,
    output_directory=sorting_output_dir,
)
def _run_sorter():
    # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
    si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
        sorter_name=sorter_name,
        recording=si_recording,
        output_folder=sorting_output_dir,
        remove_existing_folder=True,
        verbose=True,
        docker_image=sorter_name not in si.sorters.installed_sorters(),
        **sorting_params,
    )

    # Save sorting object
    sorting_save_path = sorting_output_dir / "si_sorting.pkl"
    si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)


_run_sorter()

[2024-10-07 19:10:47,068][INFO]: No existing results found, calling '_run_sorter'
  si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(


Motion correction activated (probe geometry compatible)




detect and localize:   0%|          | 0/120 [00:00<?, ?it/s]

  warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")


pairwise displacement:   0%|          | 0/15 [00:00<?, ?it/s]

write_memory_recording:   0%|          | 0/120 [00:00<?, ?it/s]

detect peaks using locally_exclusive:   0%|          | 0/120 [00:00<?, ?it/s]

extract waveforms shared_memory mono buffer:   0%|          | 0/120 [00:00<?, ?it/s]

detect peaks using matched_filtering:   0%|          | 0/1200 [00:00<?, ?it/s]

We found 3080 peaks in total
We kept 3080 peaks for clustering




extracting features:   0%|          | 0/120 [00:00<?, ?it/s]

estimate_templates:   0%|          | 0/120 [00:00<?, ?it/s]

We found 32 raw clusters, starting to clean with matching...


  warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")


write_memory_recording:   0%|          | 0/1 [00:00<?, ?it/s]

We kept 32 non-duplicated clusters...


estimate_templates:   0%|          | 0/120 [00:00<?, ?it/s]

find spikes (circus-omp-svd):   0%|          | 0/1200 [00:00<?, ?it/s]

We found 116918 spikes


OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


Final merging, keeping 32 units
spykingcircus2 run time 110.30s


In [17]:
############### New code here to debug
sorting_save_path = sorting_output_dir / "si_sorting.pkl"
print(f"`si_sorting` object is saved in the `sorting_save_path` {sorting_save_path}")

`si_sorting` object is saved in the `sorting_save_path` /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_20230518122630/O09/spykingcircus2_101_example/spykingcircus2/spike_sorting/si_sorting.pkl


In [18]:
# ----------------- PostProcessing First Part of the Make Function Copied Here ----------------- #

# Load recording & sorting object.
# clustering_method, output_dir, params = (
#     ephys.ClusteringTask * ephys.ClusteringParamSet & key
# ).fetch1("clustering_method", "clustering_output_dir", "params")
# output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
# sorter_name = clustering_method.replace(".", "_")

recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"

si_recording: si.BaseRecording = si.load_extractor(
    recording_file, base_folder=output_dir
)
si_sorting: si.sorters.BaseSorter = si.load_extractor(
    sorting_file, base_folder=output_dir
)

In [19]:
############### New code here to debug
# `si_sorting` can be explore here
si_sorting

In [29]:
# ----------------- PostProcessing Second Part of the Make Function Copied Here ----------------- #

postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]

job_kwargs = postprocessing_params.get(
    "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
)

analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"


@memoized_result(
    uniqueness_dict=postprocessing_params,
    output_directory=analyzer_output_dir,
)
def _sorting_analyzer_compute():
    # Sorting Analyzer
    sorting_analyzer = si.create_sorting_analyzer(
        sorting=si_sorting,
        recording=si_recording,
        format="binary_folder",
        folder=analyzer_output_dir,
        sparse=True,
        overwrite=True,
        **job_kwargs,
    )

    # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
    # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
    extensions_params = postprocessing_params.get("extensions", {})
    extensions_to_compute = {
        ext_name: extensions_params[ext_name]
        for ext_name in sorting_analyzer.get_computable_extensions()
        if ext_name in extensions_params
    }

    sorting_analyzer.compute(extensions_to_compute, **job_kwargs)

    # # Save to phy format
    # if postprocessing_params.get("export_to_phy", False):
    #     si.exporters.export_to_phy(
    #         sorting_analyzer=sorting_analyzer,
    #         output_folder=analyzer_output_dir / "phy",
    #         use_relative_path=True,
    #         **job_kwargs,
    #     )
    # # Generate spike interface report
    # if postprocessing_params.get("export_report", True):
    #     si.exporters.export_report(
    #         sorting_analyzer=sorting_analyzer,
    #         output_folder=analyzer_output_dir / "spikeinterface_report",
    #         **job_kwargs,
    #     )

    _sorting_analyzer_compute()

In [39]:
############### New code here to debug
print(
    f"Now you can explore the SpikeInterface report here: {analyzer_output_dir / 'spikeinterface_report'}\n"
    f"And the results using Phy here: {analyzer_output_dir / 'phy'}"
)

Now you can explore the SpikeInterface report here: /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_20230518122630/O09/spykingcircus2_101/spykingcircus2/sorting_analyzer/spikeinterface_report
And the results using Phy here: /Users/milagros/Documents/data/organoids/outbox/O09-12_raw/202305181225_20230518122630/O09/spykingcircus2_101/spykingcircus2/sorting_analyzer/phy


In [39]:
# Sorting Analyzer
sorting_analyzer = si.create_sorting_analyzer(
    sorting=si_sorting,
    recording=si_recording,
    format="binary_folder",
    folder=analyzer_output_dir,
    sparse=True,
    overwrite=True,
    **job_kwargs,
)

estimate_sparsity:   0%|          | 0/120 [00:00<?, ?it/s]

In [35]:
# The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
# each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
extensions_params = postprocessing_params.get("extensions", {})
extensions_to_compute = {
    ext_name: extensions_params[ext_name]
    for ext_name in sorting_analyzer.get_computable_extensions()
    if ext_name in extensions_params
}

sorting_analyzer.compute(extensions_to_compute, **job_kwargs)

# # Save to phy format
# if postprocessing_params.get("export_to_phy", False):
#     si.exporters.export_to_phy(
#         sorting_analyzer=sorting_analyzer,
#         output_folder=analyzer_output_dir / "phy",
#         use_relative_path=True,
#         **job_kwargs,
#     )
# # Generate spike interface report
# if postprocessing_params.get("export_report", True):
#     si.exporters.export_report(
#         sorting_analyzer=sorting_analyzer,
#         output_folder=analyzer_output_dir / "spikeinterface_report",
#         **job_kwargs,
#     )

compute_waveforms:   0%|          | 0/120 [00:00<?, ?it/s]

Fitting PCA:   0%|          | 0/32 [00:00<?, ?it/s]

KeyboardInterrupt: 