In [1]:
import argparse
import json
import pickle
import numpy as np
from pathlib import Path

import spikeinterface as si
import spikeinterface.preprocessing as spre
import spikeinterface.generation as sgen
import spikeinterface.widgets as sw

In [36]:
%matplotlib widget

In [2]:
# Root folders: job JSON files are read from data_folder, and all generated
# recordings, sortings and figures are written below results_folder.
data_folder, results_folder = Path("../data"), Path("../results")

In [3]:
# Run configuration.
# COMPLEXITY may be written as a single name or as a list of names; it is
# normalised to a list so that later code can iterate over it.
COMPLEXITY = "medium"
COMPLEXITY = [COMPLEXITY] if isinstance(COMPLEXITY, str) else COMPLEXITY

NUM_UNITS = 10  # number of templates drawn for each hybrid case
NUM_CASES = 2  # number of hybrid cases generated per recording and complexity

CORRECT_MOTION = False  # when True, motion is estimated and handed to the hybrid generator

In [62]:
# Output locations, all below results_folder.
recordings_output_folder = results_folder / "recordings"
sortings_output_folder = results_folder / "sortings"
figure_output_folder = results_folder / "figures"
templates_figures_folder = figure_output_folder / "templates"

# parents=True so the cell also works when results_folder itself does not exist yet
# (e.g. on a fresh checkout); exist_ok=True keeps the cell re-runnable.
for output_folder in (
    recordings_output_folder,
    sortings_output_folder,
    figure_output_folder,
    templates_figures_folder,
):
    output_folder.mkdir(parents=True, exist_ok=True)

# params.json is resolved relative to the current working directory
with open("params.json", "r") as f:
    params = json.load(f)

print(f"COMPLEXITY: {COMPLEXITY}")
print(f"NUM_UNITS: {NUM_UNITS}")
print(f"NUM_CASES: {NUM_CASES}")
print(f"CORRECT_MOTION: {CORRECT_MOTION}")

# input json files
# find raw data
# Path.iterdir() yields entries in arbitrary order, so sort to process the jobs
# in the same order on every run.
job_json_files = sorted(p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name)
job_dicts = []
for job_json_file in job_json_files:
    with open(job_json_file) as f:
        job_dict = json.load(f)
    job_dicts.append(job_dict)
print(f"Found {len(job_dicts)} JSON job files")

# table describing the templates available in the database
templates_info = sgen.fetch_templates_database_info()
# templates_info = templates_info.query("amplitude_uv > 0")

COMPLEXITY: ['medium']
NUM_UNITS: 10
NUM_CASES: 2
CORRECT_MOTION: False
Found 6 JSON job files


In [63]:
len(templates_info)

2110

In [11]:
# Hard-coded selection of ten templates.
# NOTE(review): presumably row labels of templates_info, like the low-SNR
# selection -- confirm. The hybrid-generation loop later reassigns this name
# with a random draw.
templates_selected_indices = [271, 1681, 1476, 934, 1740, 1775, 1033, 1682, 1599, 1844]

In [72]:
# Hard-coded row labels of templates_info used for the low-SNR selection
# (looked up with templates_info.loc[...]).
templates_selected_indices_low_snr = [1708, 431, 953, 166, 2012, 617]

In [12]:
# Target amplitude range used when scaling the templates interactively.
# NOTE(review): presumably in microvolts, matching the amplitude_uv column -- confirm.
min_amplitude = 100
max_amplitude = 200

In [32]:
def scale_template_to_range2(
    templates,
    min_amplitude: float,
    max_amplitude: float,
    amplitude_function = "ptp",
):
    """
    Scale templates to have a range with the provided minimum and maximum amplitudes.

    The amplitude of each unit is measured on its extremum channel. The unit
    with the smallest amplitude is scaled to ``min_amplitude``, the unit with
    the largest amplitude to ``max_amplitude``, and the scale factor of the
    units in between is interpolated linearly as a function of their amplitude.

    Parameters
    ----------
    templates : Templates
        The input templates.
    min_amplitude : float
        The minimum amplitude of the output templates after scaling.
    max_amplitude : float
        The maximum amplitude of the output templates after scaling.
    amplitude_function : "ptp" | "min" | "max", default: "ptp"
        The function applied to the extremum-channel waveform of each unit
        to measure its amplitude.

    Returns
    -------
    Templates
        The scaled templates.

    Raises
    ------
    ValueError
        If ``amplitude_function`` is not one of "ptp", "min" or "max".
    """
    amplitude_functions = {"ptp": np.ptp, "min": np.min, "max": np.max}
    if amplitude_function not in amplitude_functions:
        raise ValueError(
            f"amplitude_function must be one of {list(amplitude_functions)}, got {amplitude_function!r}"
        )
    amp_fun = amplitude_functions[amplitude_function]

    from spikeinterface import Templates, get_template_extremum_channel

    extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values())
    extremum_channel_indices = np.array(extremum_channel_indices, dtype=int)

    # get amplitudes (one value per unit, measured on the unit's extremum channel)
    templates_array = templates.templates_array
    amplitudes = np.zeros(templates.num_units)
    for i in range(templates.num_units):
        amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]])

    # scale templates to meet min_amplitude and max_amplitude range
    lowest_amplitude = np.min(amplitudes)
    highest_amplitude = np.max(amplitudes)
    min_scale = lowest_amplitude / min_amplitude
    max_scale = highest_amplitude / max_amplitude
    amplitude_span = highest_amplitude - lowest_amplitude
    if amplitude_span == 0:
        # Single unit, or all units share one amplitude: the interpolation slope
        # would be 0/0. Use a constant scale, which maps every unit to min_amplitude.
        m = 0.0
    else:
        m = (max_scale - min_scale) / amplitude_span
    scales = m * (amplitudes - lowest_amplitude) + min_scale

    # A flat template has amplitude 0 and therefore scale 0; dividing by it
    # would fill the template with NaN. Such a template cannot be rescaled, so
    # it is left unchanged.
    scales = np.where(scales == 0, 1.0, scales)
    scaled_templates_array = templates.templates_array / scales[:, None, None]

    return Templates(
        templates_array=scaled_templates_array,
        sampling_frequency=templates.sampling_frequency,
        nbefore=templates.nbefore,
        sparsity_mask=templates.sparsity_mask,
        channel_ids=templates.channel_ids,
        unit_ids=templates.unit_ids,
        probe=templates.probe,
    )

In [73]:
templates_info

Unnamed: 0,probe,probe_manufacturer,brain_area,depth_along_probe,amplitude_uv,noise_level_uv,signal_to_noise_ratio,template_index,best_channel_index,spikes_per_unit,dataset,dataset_path
0,Neuropixels 1.0,IMEC,AON,1580.0,108.960100,17.474136,6.235507,0,158,1647,000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-b...,s3://spikeinterface-template-database/000409_s...
1,Neuropixels 1.0,IMEC,AON,1660.0,189.819410,19.073212,9.952148,1,166,1540,000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-b...,s3://spikeinterface-template-database/000409_s...
2,Neuropixels 1.0,IMEC,AON,1780.0,260.893740,17.710548,14.730980,2,178,20,000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-b...,s3://spikeinterface-template-database/000409_s...
3,Neuropixels 1.0,IMEC,AON,1780.0,391.812620,17.710548,22.123121,3,178,1398,000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-b...,s3://spikeinterface-template-database/000409_s...
4,Neuropixels 1.0,IMEC,AON,1900.0,74.740270,18.503231,4.039309,4,190,4386,000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-b...,s3://spikeinterface-template-database/000409_s...
...,...,...,...,...,...,...,...,...,...,...,...,...
2105,Neuropixels 1.0,IMEC,MOs5,2940.0,143.806870,19.734700,7.287006,30,294,10192,000409_sub-KS052_ses-ac7d3064-7f09-48a3-88d2-e...,s3://spikeinterface-template-database/000409_s...
2106,Neuropixels 1.0,IMEC,MOs5,3040.0,111.190180,15.118033,7.354805,31,305,10081,000409_sub-KS052_ses-ac7d3064-7f09-48a3-88d2-e...,s3://spikeinterface-template-database/000409_s...
2107,Neuropixels 1.0,IMEC,MOs2/3,3220.0,106.443360,16.054296,6.630210,32,322,4094,000409_sub-KS052_ses-ac7d3064-7f09-48a3-88d2-e...,s3://spikeinterface-template-database/000409_s...
2108,Neuropixels 1.0,IMEC,MOs2/3,3340.0,75.831894,17.653605,4.295547,33,334,5638,000409_sub-KS052_ses-ac7d3064-7f09-48a3-88d2-e...,s3://spikeinterface-template-database/000409_s...


In [65]:
# Pick the rows of the template database table that belong to the low-SNR selection.
templates_selected_info = templates_info.loc[templates_selected_indices_low_snr]

# fetch the templates described by the selected rows
templates_selected = sgen.query_templates_from_database(templates_selected_info)

In [74]:
templates_selected_info

Unnamed: 0,probe,probe_manufacturer,brain_area,depth_along_probe,amplitude_uv,noise_level_uv,signal_to_noise_ratio,template_index,best_channel_index,spikes_per_unit,dataset,dataset_path
1708,Neuropixels 1.0,IMEC,SCiw,1640.0,251.24564,18.06974,13.90422,24,164,1111,000409_sub-KS042_ses-07dc4b76-5b93-4a03-82a0-b...,s3://spikeinterface-template-database/000409_s...
166,Neuropixels 1.0,IMEC,VISC5,800.0,379.2848,17.156858,22.10689,29,81,5629,000409_sub-KS046_ses-0ac8d013-b91e-4732-bc7b-a...,s3://spikeinterface-template-database/000409_s...
2012,Neuropixels 1.0,IMEC,VPMpc,1860.0,196.60464,18.606699,10.566337,18,186,5068,000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5...,s3://spikeinterface-template-database/000409_s...
617,Neuropixels 1.0,IMEC,ORBvl5,2420.0,350.6634,21.210417,16.532602,91,243,6104,000409_sub-KS046_ses-dfbe628d-365b-461c-a07f-8...,s3://spikeinterface-template-database/000409_s...


In [66]:
# scale templates
print(f"Scaling templates between {min_amplitude} and {max_amplitude}")
templates_scaled = sgen.scale_template_to_range(
    templates=templates_selected,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude
)

Scaling templates between 100 and 200


In [69]:
# Compute the sparsity of the scaled templates; it is passed to the template plot.
sparsity = si.compute_sparsity(templates_scaled)

In [70]:
# Plot the scaled templates with the ipywidgets backend, passing the computed sparsity.
sw.plot_unit_templates(templates_scaled, sparsity=sparsity, backend="ipywidgets")

AppLayout(children=(HBox(children=(Checkbox(value=False, description='same axis'), Checkbox(value=True, descri…

<spikeinterface.widgets.unit_templates.UnitTemplatesWidget at 0x7f54323f6640>

In [39]:
import numpy as np

In [40]:
# Check whether every sample of the fetched template at position 7 is exactly zero.
np.all(templates_selected.templates_array[7] == 0)

True

In [14]:
templates_scaled.templates_array

array([[[-1.65166614e-01, -2.10923701e-01,  4.72816540e-01, ...,
          1.42642163e-01, -9.12499081e-02, -9.73552235e-03],
        [-6.11263491e-01, -2.68284150e-01,  2.23890167e-01, ...,
          1.45937587e-01,  1.18205402e-01,  1.36762345e-02],
        [-4.04854830e-01, -4.97854111e-01,  4.58918723e-01, ...,
          2.01813494e-01,  2.68021757e-01,  1.21053156e-01],
        ...,
        [ 2.54353677e-01, -4.88021079e-01,  3.06056694e-01, ...,
         -9.29555717e-02, -4.36194070e-02,  3.86796308e-01],
        [ 7.03663551e-01, -2.63253038e-01, -1.61745150e-01, ...,
          5.18626285e-01,  2.46275653e-02, -2.44266787e-02],
        [ 2.56641211e-01, -2.57185181e-01, -4.50205615e-02, ...,
          2.62353290e-01,  9.69671435e-02,  2.63127018e-01]],

       [[-7.31666039e-03,  3.05356217e-02,  1.32057956e-02, ...,
          1.06024563e-01, -1.78024271e-01, -5.24797657e-02],
        [-1.79900631e-01, -4.65382305e-02, -3.93628685e-02, ...,
          1.07270142e-01,  1.83883714e

In [None]:
# for each JSON file, we now create hybrid recordings
# For every job: load and preprocess the recording, optionally estimate motion,
# then for each complexity and case draw templates, scale them, build the hybrid
# recording, and save the recording job, the sorting and a template figure.
for job_dict in job_dicts:
    recording_name = job_dict["recording_name"]
    print(f"Creating hybrid recordings for {recording_name}")
    recording = si.load_extractor(job_dict["recording_dict"], base_folder=data_folder)
    print(f"\t{recording}")

    # preprocess
    recording_preproc = spre.highpass_filter(recording)
    recording_preproc = spre.common_reference(recording_preproc)

    # motion stays None unless CORRECT_MOTION is enabled
    motion = None
    if CORRECT_MOTION:
        print("Estimating motion")
        motion_figures_folder = figure_output_folder / "motion"
        motion_figures_folder.mkdir(parents=True, exist_ok=True)

        # only the motion info is kept; the corrected recording is discarded
        _, motion_info = spre.correct_motion(
            recording_preproc, preset="dredge_fast", n_jobs=-1, progress_bar=True, output_motion_info=True
        )
        motion = motion_info["motion"]
        w = sw.plot_motion_info(
            motion_info,
            recording_preproc,
            color_amplitude=True,
            scatter_decimate=10,
            amplitude_cmap="Greys_r"
        )
        # One figure per recording: the file name must contain the recording name,
        # otherwise every recording overwrites the same "recording_name.png".
        w.figure.savefig(motion_figures_folder / f"{recording_name}.png", dpi=300)

    for complexity in COMPLEXITY:
        print(f"\tGenerating complexity: {complexity}")
        # amplitude range for this complexity comes from params.json
        min_amplitude, max_amplitude = params["amplitudes"][complexity]
        for case in range(NUM_CASES):
            print(f"\t\tGenerating case: {case}")
            case_name = f"{recording_name}_{complexity}_{case}"

            # sample templates
            # NOTE: the draw is unseeded, so a re-run selects different templates;
            # the selected indices are printed and stored as unit ids below.
            print(f"\t\t\tSelecting and fetching templates")
            templates_selected_indices = np.random.choice(templates_info.index, size=NUM_UNITS, replace=False)
            print(f"\t\t\tSelected indices: {list(templates_selected_indices)}")
            templates_selected_info = templates_info.loc[templates_selected_indices]

            # fetch templates
            templates_selected = sgen.query_templates_from_database(templates_selected_info)

            # scale templates
            print(f"\t\t\tScaling templates between {min_amplitude} and {max_amplitude}")
            templates_scaled = sgen.scale_template_to_range(
                templates=templates_selected,
                min_amplitude=min_amplitude,
                max_amplitude=max_amplitude
            )

            print(f"\t\t\tConstructing hybrid recording")
            recording_hybrid, sorting_hybrid = sgen.generate_hybrid_recording(
                recording=recording_preproc,
                templates=templates_scaled,
                motion=motion,
                seed=None,
            )
            print(recording_hybrid)

            # rename hybrid units with selected indices for provenance
            sorting_hybrid = sorting_hybrid.rename_units(templates_selected_indices)

            # we construct here a pkl version of the job json because
            # it needs to be compatible with the preprocessing capsule
            recording_dict = recording_hybrid.to_dict(
                include_annotations=True,
                include_properties=True,
                relative_to=data_folder,
                recursive=True,
            )
            dump_dict = {
                "session_name": job_dict["session_name"],
                "recording_name": case_name,
                "recording_dict": recording_dict
            }
            file_path = recordings_output_folder / f"job_{case_name}.pkl"
            file_path.write_bytes(pickle.dumps(dump_dict))

            sorting_hybrid.dump_to_pickle(
                sortings_output_folder / f"{case_name}.pkl",
                relative_to=data_folder
            )

            # generate some plots!
            # rebuild a Templates object from the templates held by the hybrid recording
            templates_obj = si.Templates(
                recording_hybrid.templates,
                channel_ids=templates_selected.channel_ids,
                unit_ids=sorting_hybrid.unit_ids,
                probe=recording.get_probe(),
                sampling_frequency=templates_selected.sampling_frequency,
                nbefore=templates_selected.nbefore
            )
            sparsity = si.compute_sparsity(templates_obj)

            # figure height grows with the number of units
            figsize = (7, 3*NUM_UNITS)
            w = sw.plot_unit_templates(
                templates_obj,
                sparsity=sparsity,
                figsize=figsize,
                ncols=2,
            )
            w.figure.savefig(templates_figures_folder / f"{case_name}.pdf")
