In [None]:
from pathlib import Path
import numpy as np

import spikeinterface.core as si
import spikeinterface.curation as scur
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.widgets as sw

# Set SpikeInterface's global job kwargs so that jobs use n_jobs=1.
si.set_global_job_kwargs(n_jobs = 1)

# Folder for this dataset's files.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
output_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Create a SortingAnalyzer, compute its extensions and metrics, and save it to disk

# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
base_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/SpikeInterface Dataset Tutorial")
curation_dataset = base_folder / "dataset_curation"
recording = si.load_extractor(curation_dataset / "curation_recording")
sorting = si.load_extractor(curation_dataset / "curation_sorting")

print(recording)
print(sorting)

analyzer = si.create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
analyzer.compute({
    'noise_levels': {},
    'random_spikes': {'max_spikes_per_unit': 1_000},
    'templates': {'ms_before': 1.5, 'ms_after': 3.5},
    'spike_amplitudes': {},
    'correlograms': {'bin_ms': 0.5},
    'waveforms': {},
    'principal_components': {},
    'spike_locations': {},
    'unit_locations': {},
    'template_similarity': {}
})

# Request the non-PCA and PCA quality metrics in a single call. Computing the
# "quality_metrics" extension twice with different metric_names relies on the
# second call being merged with (rather than replacing) the first, which is not
# guaranteed across SpikeInterface versions, and it runs the extension twice.
quality_metric_names = list(sqm.get_quality_metric_list()) + list(sqm.get_quality_pca_metric_list())
analyzer.compute("quality_metrics", metric_names=quality_metric_names)
analyzer.compute("template_metrics", metric_names=spost.get_template_metric_names())

# NOTE(review): this saves under base_folder / "analyzer", while later cells load
# ".../code/si_dataset.zarr" -- confirm which location is the intended one.
analyzer.save_as(format="zarr", folder=base_folder / "analyzer")

In [None]:
# Reload the analyzer that was saved to disk as zarr.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
path_to_zarr = Path("/home/jake/Documents/ephys_analysis/code/si_dataset.zarr")
analyzer = si.load_sorting_analyzer(path_to_zarr)

# Read the metric tables through the public extension accessor instead of
# reaching into the analyzer's internal `extensions[...].data` dictionaries.
quality_metrics = analyzer.get_extension('quality_metrics').get_data()
template_metrics = analyzer.get_extension('template_metrics').get_data()

# (Re)compute the location and similarity extensions on the loaded analyzer.
analyzer.compute({
    'spike_locations': {},
    'unit_locations': {},
    'template_similarity': {}
})

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
Compute : spike_locations: 100%|##########| 300/300 [00:01<00:00, 196.22it/s]


In [None]:
import os

class AnoushkaAutoLabel():
    """Label units as noise / SUA / MUA using two pre-fitted classifiers.

    The caller fills in the attributes after construction:

    * ``noise_neuron_model`` -- classifier with ``predict`` / ``predict_proba``;
      its class ``1`` is mapped to ``'noise'`` and ``0`` to ``'neural'``.
    * ``sua_mua_model`` -- classifier with ``predict`` / ``predict_proba``;
      its class ``1`` is mapped to ``'sua'`` and ``0`` to ``'mua'``.
    * ``calculated_metrics`` -- DataFrame of metric columns, one row per unit
      (normally the return value of :meth:`check_required_metrics`).
    * ``output_folder`` -- directory that receives ``decoder_output_dataframe.csv``.
    """

    # Metric columns handed to the models, in the order the columns are arranged
    # before model application.
    REQUIRED_METRICS = [
        'num_spikes', 'firing_rate',
        'presence_ratio', 'snr', 'isi_violations_ratio', 'isi_violations_count',
        'rp_contamination', 'rp_violations', 'sliding_rp_violation',
        'amplitude_cutoff', 'amplitude_median', 'amplitude_cv_median',
        'amplitude_cv_range', 'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
        'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
        'isolation_distance', 'l_ratio', 'd_prime', 'silhouette', 'nn_hit_rate',
        'nn_miss_rate', 'peak_to_valley', 'peak_trough_ratio', 'half_width',
        'repolarization_slope', 'recovery_slope', 'num_positive_peaks',
        'num_negative_peaks', 'velocity_above', 'velocity_below', 'exp_decay',
        'spread',
    ]

    def __init__(self):
        self.noise_neuron_model = None
        self.sua_mua_model = None
        self.calculated_metrics = None

        # To fix later
        self.output_folder = None

    def apply_model(self):
        """Run both classifiers on ``calculated_metrics`` and save the result as CSV.

        The saved table holds the (float32, inf -> NaN) metric columns followed by
        ``noise_label``, ``noise_probs``, ``sua_label``, ``sua_probs``,
        ``decoder_label`` and ``decoder_probs``. ``decoder_label`` is ``'noise'``
        when the noise classifier says so and the SUA/MUA label otherwise;
        ``decoder_probs`` is the probability attached to that final label.

        Raises:
            ValueError: if a model, the metrics or the output folder is unset.
        """
        # Fail before doing any work, with a message naming what is missing,
        # instead of an AttributeError/TypeError deep inside the method.
        unset = [
            name
            for name in ('noise_neuron_model', 'sua_mua_model', 'calculated_metrics', 'output_folder')
            if getattr(self, name, None) is None
        ]
        if unset:
            raise ValueError(f"Cannot apply model, attributes not set: {unset}")

        # Define features
        self.X_columns = self.calculated_metrics.columns.to_list()

        # Prepare input data on a fresh frame so the caller's metrics are never
        # written to (and no chained-assignment warning is triggered).
        features = (
            self.calculated_metrics[self.X_columns]
            .replace([np.inf, -np.inf], np.nan)
            .astype('float32')
        )

        # Apply noise classifier, then the SUA/MUA classifier, on the same features.
        noise_predictions = self.noise_neuron_model.predict(features)
        noise_probs = self.noise_neuron_model.predict_proba(features)[:, 1]
        sua_predictions = self.sua_mua_model.predict(features)
        sua_probs = self.sua_mua_model.predict_proba(features)[:, 1]

        output = features.copy()
        output['noise_label'] = noise_predictions
        output['noise_label'] = output['noise_label'].map({1: 'noise', 0: 'neural'})
        output['noise_probs'] = noise_probs
        output['sua_label'] = sua_predictions
        output['sua_label'] = output['sua_label'].map({1: 'sua', 0: 'mua'})
        output['sua_probs'] = sua_probs

        # Final label/probability: the noise verdict wins, otherwise use SUA/MUA.
        # Vectorised column selection replaces the per-row DataFrame.apply calls.
        is_noise = output['noise_label'] == 'noise'
        output['decoder_label'] = output['sua_label'].where(~is_noise, 'noise')
        output['decoder_probs'] = output['sua_probs'].where(~is_noise, output['noise_probs'])

        # sua_probs is the probability of class 1 ('sua'); for units labelled
        # 'mua' report the complementary probability instead.
        is_mua = output['decoder_label'] == 'mua'
        output.loc[is_mua, 'decoder_probs'] = 1 - output.loc[is_mua, 'sua_probs']

        # Save the result to a CSV file
        output.to_csv(os.path.join(self.output_folder, 'decoder_output_dataframe.csv'))
        print('Decoder output saved to decoder_output_dataframe.csv')

    def check_required_metrics(self, calculated_metrics):
        """Return ``calculated_metrics`` restricted to, and ordered as, the required metrics.

        Args:
            calculated_metrics: DataFrame whose columns are metric names.

        Returns:
            A DataFrame containing exactly ``REQUIRED_METRICS``, in that order.

        Raises:
            ValueError: if any required metric column is absent.
        """
        required_metrics = list(self.REQUIRED_METRICS)
        metrics_list = calculated_metrics.columns.to_list()

        missing_metrics = [metric for metric in required_metrics if metric not in metrics_list]
        extra_metrics = [metric for metric in metrics_list if metric not in required_metrics]

        if missing_metrics:
            raise ValueError(f"Missing metrics: {missing_metrics}")

        if extra_metrics:
            print(f"Extra metrics: {extra_metrics}. Dropping before model application.")
        else:
            print('Metric list is complete. Proceeding to model application.')

        # Selecting by the required list both drops the extra columns and
        # reorders the remaining ones to match the model.
        return calculated_metrics[required_metrics]
                
                

In [None]:
import pickle as pkl
import pandas as pd

# All inputs/outputs of this step are derived from one base directory so the
# location only has to be changed in a single place.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
output_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset")
path_to_noise_model = output_folder / "num_repetitions_1_optimized_noise_grid_classifier.pkl"
path_to_sua_mua_model = output_folder / "num_repetitions_1_optimized_sua_grid_classifier.pkl"

# The zarr analyzer sits next to the dataset folder: ".../code/si_dataset.zarr".
path_to_zarr = output_folder.with_name("si_dataset.zarr")
analyzer = si.load_sorting_analyzer(path_to_zarr)

def compute_auto_label(sorting_analyzer, method = 'anoushka', **kwargs):
    """Auto-label the units of a SortingAnalyzer and write the labels to CSV.

    Args:
        sorting_analyzer: analyzer with the 'quality_metrics' and
            'template_metrics' extensions already computed.
        method: labelling method; only 'anoushka' is implemented.
        **kwargs: optional overrides for the notebook-level defaults:
            ``noise_model_path`` (default: ``path_to_noise_model``),
            ``sua_mua_model_path`` (default: ``path_to_sua_mua_model``) and
            ``output_folder`` (default: the global ``output_folder``).
            Other keys are ignored.

    Raises:
        ValueError: for an unknown ``method``, for a metric extension that has
            not been computed, or (from the metric check) for missing metrics.
    """
    # An unknown method used to fall through and silently do nothing.
    if method != 'anoushka':
        raise ValueError(f"Unknown auto-label method: {method!r}. Supported methods: ['anoushka']")

    noise_model_path = kwargs.get('noise_model_path')
    if noise_model_path is None:
        noise_model_path = path_to_noise_model
    sua_mua_model_path = kwargs.get('sua_mua_model_path')
    if sua_mua_model_path is None:
        sua_mua_model_path = path_to_sua_mua_model
    destination_folder = kwargs.get('output_folder')
    if destination_folder is None:
        destination_folder = output_folder

    # This needs fixing when we add Robyn's method
    metric_frames = []
    for extension_name in ('quality_metrics', 'template_metrics'):
        extension = sorting_analyzer.get_extension(extension_name)
        if extension is None:
            raise ValueError(
                f"The sorting analyzer has no '{extension_name}' extension; "
                f"run sorting_analyzer.compute('{extension_name}') first."
            )
        metric_frames.append(extension.get_data())
    all_metrics = pd.concat(metric_frames, axis=1)

    auto_label = AnoushkaAutoLabel()

    # Load models from .pkl files.
    # SECURITY: unpickling executes arbitrary code -- only load model files from
    # a trusted source. scikit-learn also warns when the pickling and loading
    # versions differ (see its model-persistence documentation).
    with open(noise_model_path, 'rb') as noise_model_file:
        auto_label.noise_neuron_model = pkl.load(noise_model_file)
    with open(sua_mua_model_path, 'rb') as sua_mua_model_file:
        auto_label.sua_mua_model = pkl.load(sua_mua_model_file)

    # Check metrics loaded into SortingAnalyzer against a defined list;
    # extra metrics are dropped with a message, missing ones raise.
    auto_label.calculated_metrics = auto_label.check_required_metrics(all_metrics)
    auto_label.output_folder = destination_folder

    auto_label.apply_model()

# Label the units of the analyzer loaded above; the result is written to
# decoder_output_dataframe.csv inside output_folder.
compute_auto_label(analyzer, method = 'anoushka')

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Extra metrics: ['sd_ratio']. Dropping before model application.
Decoder output saved to decoder_output_dataframe.csv


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [None]:
# The CSV was written with the DataFrame index as its first, unnamed column;
# read it back as the index instead of as an extra "Unnamed: 0" data column.
pd.read_csv('/home/jake/Documents/ephys_analysis/code/si_dataset/decoder_output_dataframe.csv', index_col=0)

Unnamed: 0.1,Unnamed: 0,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,sliding_rp_violation,...,velocity_above,velocity_below,exp_decay,spread,noise_label,noise_probs,sua_label,sua_probs,decoder_label,decoder_probs
0,0,3528.0,11.76,1.0,1.221621,0.361539,45.0,0.636261,36.0,0.23,...,,,0.01017,150.0,neural,0.483333,mua,0.1925,mua,0.8075
1,1,2599.0,8.663333,1.0,2.817327,0.236869,16.0,0.254432,10.0,0.07,...,412.3567,,0.012208,150.0,noise,0.600119,mua,0.2875,noise,0.600119
2,2,3026.0,10.086667,1.0,1.303826,0.404077,37.0,0.712534,28.0,0.15,...,,,0.014259,150.0,noise,0.5875,mua,0.195,noise,0.5875
3,3,343.0,1.143333,1.0,4.452133,0.0,0.0,0.0,0.0,0.295,...,,,0.024574,125.0,neural,0.467,mua,0.429167,mua,0.570833
4,4,2580.0,8.6,1.0,1.452094,0.420648,28.0,0.685975,20.0,0.28,...,153.1546,,0.021465,150.0,neural,0.416,mua,0.1775,mua,0.8225
5,5,221.0,0.736667,1.0,1.698258,2.04746,1.0,1.0,1.0,,...,,,0.01944,150.0,neural,0.425333,mua,0.188333,mua,0.811667
6,6,106.0,0.353333,1.0,2.240662,8.899964,1.0,1.0,1.0,,...,,,0.031577,150.0,noise,0.611667,mua,0.174167,noise,0.611667
7,7,2896.0,9.653334,1.0,1.563825,0.691562,58.0,1.0,47.0,,...,,,0.018391,150.0,neural,0.47,mua,0.190833,mua,0.809167
8,8,17.0,0.056667,1.0,1.929623,0.0,0.0,0.0,0.0,,...,,,0.0154,150.0,noise,0.761333,mua,0.146667,noise,0.761333
9,9,6209.0,20.696667,1.0,1.782075,0.046691,18.0,0.060174,15.0,0.06,...,,,0.022949,150.0,neural,0.4835,mua,0.480833,mua,0.519167


In [None]:
import pandas as pd

# Put the quality and template metric tables side by side (one row per unit).
all_metrics = pd.concat([quality_metrics, template_metrics], axis=1)

# Remember every available column name *before* removing 'sd_ratio'.
metrics_list = list(all_metrics.columns)

# Remove 'sd_ratio', then report how many metric columns remain.
all_metrics = all_metrics.drop(columns=['sd_ratio'])
print(all_metrics.shape[1])

37


In [None]:
# Metric names to check the available columns against.
required_metrics = [
    'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
    'isi_violations_ratio', 'isi_violations_count',
    'rp_contamination', 'rp_violations', 'sliding_rp_violation',
    'amplitude_cutoff', 'amplitude_median',
    'amplitude_cv_median', 'amplitude_cv_range',
    'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
    'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
    'isolation_distance', 'l_ratio', 'd_prime', 'silhouette',
    'nn_hit_rate', 'nn_miss_rate',
    'peak_to_valley', 'peak_trough_ratio', 'half_width',
    'repolarization_slope', 'recovery_slope',
    'num_positive_peaks', 'num_negative_peaks',
    'velocity_above', 'velocity_below', 'exp_decay', 'spread',
]

# Required names that are not available, in required order.
missing_metrics = list(filter(lambda name: name not in metrics_list, required_metrics))
# Available names that are not required, in column order.
extra_metrics = list(filter(lambda name: name not in required_metrics, metrics_list))

print(missing_metrics)
print(extra_metrics)

[]
['sd_ratio']
