In [8]:
%matplotlib inline
import spikeinterface as si
import spikeinterface.widgets as sw
from pathlib import Path
import os

import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.curation as scur

# recompute qc
from spikeinterface.qualitymetrics.misc_metrics import isi_violations, presence_ratio, amplitude_cutoff
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

import json

from scipy.stats import pearsonr
from scipy.stats import gaussian_kde
import pandas as pd

import spikeinterface as si
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
import spikeinterface.exporters as sexp

import numpy as np
import shutil

from pathlib import Path

%matplotlib widget

Input info

In [9]:
# Session (data asset) identifier; used to build every input/output path below.
session = "behavior_717121_2024-06-15_10-00-58"

In [10]:
# Raw data may live either under `<session>/ecephys/` or directly under `<session>/`.
raw_dir = f'/root/capsule/data/{session}/ecephys/'
if not os.path.exists(raw_dir):
    raw_dir = f'/root/capsule/data/{session}/'
stream_name = 'experiment1_Record Node 104#Neuropix-PXI-100.ProbeA_recording1'
data_folder = '/root/capsule/data'
# Find the curated sorting asset for this session (first match in listdir order).
# `dir_name` avoids shadowing the builtin `dir` in the notebook namespace, and the
# `else` clause stops a stale `curated_sorting_dir` from a previous run being reused.
for dir_name in os.listdir(data_folder):
    if session in dir_name and "sorted" in dir_name and "curated" in dir_name:
        curated_sorting_dir = f'{data_folder}/{dir_name}'
        break
else:
    raise FileNotFoundError(
        f"No curated sorting folder for session {session!r} found in {data_folder}"
    )
curated_folder = curated_sorting_dir + '/curated/' + stream_name
postprocessed_folder = curated_sorting_dir + '/postprocessed/' + stream_name

In [11]:
def qm_simple_sorting(sorting, timestamps, sample_rate=30000, bin_duration_s=60, mean_fr_ratio_thresh=0.01, isi_threshold_s=0.0015):
    """Recompute ISI-violation ratio, presence ratio and firing rate for every unit.

    Spike sample indices from ``sorting`` are converted to times by indexing
    ``timestamps``, so all metrics are evaluated on the recording's own clock.

    Parameters
    ----------
    sorting
        SpikeInterface sorting; ``get_unit_ids()`` and ``get_unit_spike_train(unit_id)``
        are used. Spike trains are assumed to be sample indices into ``timestamps``
        (single segment) -- TODO confirm for multi-segment sortings.
    timestamps : array-like
        Time of every recording sample, presumably in seconds (the ISI threshold and
        bin width are expressed in seconds).
    sample_rate : float, default 30000
        Sampling rate in Hz; ``1 / sample_rate`` is passed to ``isi_violations`` as ``min_isi_s``.
    bin_duration_s : float, default 60
        Width of the presence-ratio bins.
    mean_fr_ratio_thresh : float, default 0.01
        Per-unit bin threshold is ``floor(firing_rate * bin_duration_s * mean_fr_ratio_thresh)``,
        passed to ``presence_ratio`` as ``bin_n_spikes_thres``.
    isi_threshold_s : float, default 0.0015
        ISI threshold passed to ``isi_violations``.

    Returns
    -------
    pandas.DataFrame
        Columns ``unit_id``, ``isi_violations_ratio``, ``presence_ratio``, ``firing_rate``;
        one row per unit, in ``sorting.get_unit_ids()`` order.

    Raises
    ------
    ValueError
        If ``timestamps`` has fewer than two samples or does not span a positive duration.
    """
    timestamps = np.asarray(timestamps)
    if timestamps.size < 2:
        raise ValueError("timestamps must contain at least two samples")
    rec_length = timestamps[-1] - timestamps[0]
    # A zero/negative duration would otherwise give inf/NaN rates and an empty bin grid.
    if rec_length <= 0:
        raise ValueError(f"timestamps must span a positive duration, got {rec_length}")
    # np.arange excludes the stop value, so spikes after the last edge (a trailing
    # partial bin) fall outside the presence-ratio histogram.
    bin_edges = np.arange(timestamps[0], timestamps[-1], bin_duration_s)
    min_isi_s = 1 / sample_rate

    unit_ids = sorting.get_unit_ids()
    isi_violations_ratio = []
    presence_ratio_qc = []
    firing_rate = []
    for unit_id in unit_ids:
        unit_times = timestamps[sorting.get_unit_spike_train(unit_id)]
        # isi_violations returns (ratio, rate, count); only the ratio is reported here.
        isi_ratio, _, _ = isi_violations(
            [unit_times], rec_length, isi_threshold_s=isi_threshold_s, min_isi_s=min_isi_s
        )
        unit_fr = len(unit_times) / rec_length
        bin_n_spikes_thres = np.floor(unit_fr * bin_duration_s * mean_fr_ratio_thresh)
        isi_violations_ratio.append(isi_ratio)
        presence_ratio_qc.append(
            presence_ratio(unit_times, rec_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres)
        )
        firing_rate.append(unit_fr)

    return pd.DataFrame({
        'unit_id': unit_ids,
        'isi_violations_ratio': isi_violations_ratio,
        'presence_ratio': presence_ratio_qc,
        'firing_rate': firing_rate,
    })

In [12]:
def load_and_preprocess_recording(raw_data_folder, stream_name):
    """Load a stream's compressed recording and apply the preprocessing chain.

    Reads ``<raw_data_folder>/ecephys_compressed/<raw stream>.zarr``. The raw stream
    name is ``stream_name`` truncated at its first ``_recording`` occurrence, e.g.
    ``'..._ProbeA_recording1'`` -> ``'..._ProbeA'``. The recording is then passed through
    ``spre.phase_shift``, ``spre.highpass_filter`` and ``spre.common_reference``
    (default arguments), in that order.

    Parameters
    ----------
    raw_data_folder : str or os.PathLike
        Folder that contains ``ecephys_compressed``.
    stream_name : str
        Full stream name. If it has no ``_recording`` suffix, it is used unchanged.

    Returns
    -------
    The recording returned by ``spre.common_reference``.
    """
    compressed_folder = os.path.join(raw_data_folder, "ecephys_compressed")
    # str.find returns -1 when the suffix is missing; slicing with -1 would silently
    # drop the last character, so fall back to the full name in that case.
    suffix_pos = stream_name.find("_recording")
    raw_stream_name = stream_name if suffix_pos == -1 else stream_name[:suffix_pos]
    recording = si.read_zarr(os.path.join(compressed_folder, f"{raw_stream_name}.zarr"))
    # preprocess
    recording_processed = spre.phase_shift(recording)
    recording_processed = spre.highpass_filter(recording_processed)
    recording_processed = spre.common_reference(recording_processed)
    return recording_processed

In [13]:
def load_qm(sorting_dir):
    """Load the first quality-metrics table found under ``<sorting_dir>/postprocessed``.

    Walks the ``postprocessed`` tree top-down and reads ``metrics.csv`` from the first
    directory named ``quality_metrics``, using the CSV's first column as the index.
    The path of the folder that was found is printed.

    Parameters
    ----------
    sorting_dir : str or os.PathLike
        Root folder of the (curated) sorting asset.

    Returns
    -------
    pandas.DataFrame or None
        The metrics table, or ``None`` (after printing a message) when no
        ``quality_metrics`` folder exists.
    """
    qm_dir = os.path.join(sorting_dir, 'postprocessed')
    quality_folder_path = None
    for root, dirs, _files in os.walk(qm_dir):
        if 'quality_metrics' in dirs:
            quality_folder_path = os.path.join(root, 'quality_metrics')
            print(quality_folder_path)
            break

    if quality_folder_path is None:
        print('No quality metrics folder found.')
        return None
    qm_file = os.path.join(quality_folder_path, 'metrics.csv')
    return pd.read_csv(qm_file, index_col=0)


Load

In [14]:
# Load the recording's sample timestamps. Depending on the data-asset layout, the
# clipped ecephys folder sits either directly under raw_dir or under raw_dir/ecephys.
timestamps_rel = 'ecephys_clipped/Record Node 104/experiment1/recording1/continuous/Neuropix-PXI-100.ProbeA/timestamps.npy'
timestamps_candidates = [Path(raw_dir) / timestamps_rel, Path(raw_dir) / 'ecephys' / timestamps_rel]
timestamps_file = next((p for p in timestamps_candidates if p.exists()), None)
if timestamps_file is None:
    raise FileNotFoundError(
        f"timestamps.npy not found; looked in: {[str(p) for p in timestamps_candidates]}"
    )
timestamps = np.load(timestamps_file)
# Load the preprocessed recording (attached to the waveform extractor later).
recording_processed = load_and_preprocess_recording(raw_dir, stream_name)
# Load the curated sorting.
sorting = si.load_extractor(curated_folder)

In [19]:
# Load the saved waveform extractor without its recording, then attach the
# preprocessed recording loaded above.
we = si.load_waveforms(postprocessed_folder, with_recording=False)
we.set_recording(recording_processed)
# Load the precomputed quality metrics. load_qm returns None when no metrics are
# found; fail here with a clear message instead of a TypeError in the QC cell.
qm = load_qm(curated_sorting_dir)
if qm is None:
    raise FileNotFoundError(f"No quality metrics found under {curated_sorting_dir}/postprocessed")

/root/capsule/data/behavior_717121_2024-06-15_10-00-58_sorted-curated_2024-07-25_06-45-59/postprocessed/experiment1_Record Node 104#Neuropix-PXI-100.ProbeA_recording1/quality_metrics


In [20]:
# Per-unit spike counts, unit ids and decoder labels from the curated sorting.
counts = np.array(list(sorting.count_num_spikes_per_unit().values()))
unit_ids = sorting.unit_ids
labels = sorting.get_property('decoder_label')

# Recompute simple QC metrics against the real timestamps. Pass the settings
# explicitly so that editing them here actually takes effect.
sample_rate = 30000
bin_duration_s = 60.0
mean_fr_ratio_thresh = 0.01
qm_simple = qm_simple_sorting(
    sorting,
    timestamps,
    sample_rate=sample_rate,
    bin_duration_s=bin_duration_s,
    mean_fr_ratio_thresh=mean_fr_ratio_thresh,
)

# QC thresholds, applied to the precomputed metrics table `qm`.
# NOTE(review): `qm_simple` is recomputed above but not used for pass_qc -- confirm
# which table is intended.
isi_violations_ratio_max = 0.1
firing_rate_min = 0.2
presence_ratio_min = 0.95
pass_qc = (
    (qm['isi_violations_ratio'] < isi_violations_ratio_max)
    & (qm['firing_rate'] > firing_rate_min)
    & (qm['presence_ratio'] > presence_ratio_min)
)

Export and save

In [21]:
# Destination folder for this session's exported features.
outfolder = f'/root/capsule/scratch/features/{session}'
# Load the saved "principal_components" extension and the channel sparsity it uses.
pc = we.load_extension("principal_components")
pc_sparsity = pc.get_sparsity()
# Largest number of channels assigned to any unit; sizes the pc_feature_ind table.
max_num_channels_pc = max(map(len, pc_sparsity.unit_id_to_channel_indices.values()))

See also: [scikit-learn — model persistence: security & maintainability limitations](https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations)


In [26]:
# Export PC features for all spikes to a single .npy file.
pc_file = outfolder + '/pc_feature.npy'
# makedirs(exist_ok=True) creates missing parents and keeps the cell re-runnable;
# os.mkdir raises if the folder already exists or its parent is missing.
os.makedirs(outfolder, exist_ok=True)
pc.run_for_all_spikes(pc_file, n_jobs=12)

extract PCs:   0%|          | 0/6396 [00:00<?, ?it/s]

In [None]:
# Channel map for the PC features: row i lists the channel indices used for the i-th
# unit of `unit_ids`, right-padded with -1.
pc_feature_ind = np.full((len(unit_ids), max_num_channels_pc), -1, dtype="int64")
for unit_ind, unit_id in enumerate(unit_ids):
    chan_inds = pc_sparsity.unit_id_to_channel_indices[unit_id]
    pc_feature_ind[unit_ind, :len(chan_inds)] = chan_inds
np.save(outfolder + "/pc_feature_ind.npy", pc_feature_ind)

# Spike sample indices and per-spike unit ids, taken from segment 0 only.
# NOTE(review): spike_templates/spike_clusters hold unit *ids*, while rows of
# pc_feature_ind follow unit *order*; the two differ unless ids are 0..N-1.
# Confirm that downstream consumers expect ids.
all_spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0]
spike_times = all_spikes_seg0["sample_index"]
spike_labels = unit_ids[all_spikes_seg0["unit_index"]]
# Each array is written as an (n_spikes, 1) column vector.
np.save(outfolder + "/spike_times.npy", spike_times[:, np.newaxis])
np.save(outfolder + "/spike_templates.npy", spike_labels[:, np.newaxis])
np.save(outfolder + "/spike_clusters.npy", spike_labels[:, np.newaxis])