In [1]:
from pathlib import Path
import numpy as np

import spikeinterface.core as si
import spikeinterface.curation as scur
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.widgets as sw

si.set_global_job_kwargs(n_jobs = 1)

output_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Create, compute metrics, and save SortingAnalyzer to disk

base_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/SpikeInterface Dataset Tutorial")
curation_dataset = base_folder / "dataset_curation"
recording = si.load_extractor(curation_dataset / "curation_recording")
sorting = si.load_extractor(curation_dataset / "curation_sorting")

print(recording)
print(sorting)

analyzer = si.create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
analyzer.compute({
    'noise_levels': {},
    'random_spikes': {'max_spikes_per_unit': 1_000},
    'templates': {'ms_before': 1.5, 'ms_after': 3.5},
    'spike_amplitudes': {},
    # 'correlograms': {'bin_ms': 0.5},
    # 'waveforms': {},
    'principal_components': {},
    'spike_locations': {},
    'unit_locations': {},
    # 'template_similarity': {}
})

analyzer.compute("quality_metrics", metric_names = sqm.get_quality_metric_list())
analyzer.compute("quality_metrics", metric_names = sqm.get_quality_pca_metric_list())
analyzer.compute("template_metrics", metric_names = spost.get_template_metric_names())

analyzer.save_as("zarr", base_folder / "analyzer")

BinaryFolderRecording: 26 channels - 30.0kHz - 1 segments - 9,000,000 samples 
                       300.00s (5.00 minutes) - int16 dtype - 446.32 MiB
NumpyFolderSorting: 52 units - 1 segments - 30.0kHz


compute_waveforms: 100%|##########| 300/300 [00:00<00:00, 428.89it/s]
Fitting PCA: 100%|██████████| 52/52 [00:19<00:00,  2.66it/s]
Projecting waveforms: 100%|██████████| 52/52 [00:01<00:00, 29.10it/s]
Compute : spike_amplitudes + spike_locations: 100%|##########| 300/300 [00:03<00:00, 98.27it/s]
calculate_pc_metrics: 100%|██████████| 52/52 [15:22<00:00, 17.74s/it]
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TI

SortingAnalyzer: 26 channels - 52 units - 1 segments - zarr - has recording
Loaded 12 extensions: noise_levels, random_spikes, waveforms, templates, correlograms, principal_components, unit_locations, template_similarity, spike_amplitudes, spike_locations, quality_metrics, template_metrics

In [3]:
path_to_zarr = Path("/home/jake/Documents/ephys_analysis/code/si_dataset.zarr")
analyzer = si.load_sorting_analyzer(path_to_zarr)
quality_metrics = analyzer.extensions['quality_metrics'].data["metrics"]
template_metrics = analyzer.extensions['template_metrics'].data["metrics"]


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
Compute : spike_locations: 100%|##########| 300/300 [00:01<00:00, 225.17it/s]


In [12]:
import pickle as pkl



path_to_noise_model = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/num_repetitions_1_optimized_noise_grid_classifier.pkl")
path_to_sua_mua_model = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/num_repetitions_1_optimized_sua_grid_classifier.pkl")
output_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset")

with open(path_to_noise_model, 'rb') as noise_model_file:
        noise_neuron_model = pkl.load(noise_model_file)
with open(path_to_sua_mua_model, 'rb') as sua_mua_model_file:
        sua_mua_model = pkl.load(sua_mua_model_file)

print(noise_neuron_model)
print(sua_mua_model)

# combine two pipelines into a single pipeline
required_metrics = ['num_spikes', 'firing_rate',
    'presence_ratio', 'snr', 'isi_violations_ratio', 'isi_violations_count',
    'rp_contamination', 'rp_violations', 'sliding_rp_violation',
    'amplitude_cutoff', 'amplitude_median', 'amplitude_cv_median',
    'amplitude_cv_range', 'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
    'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
    'isolation_distance', 'l_ratio', 'd_prime', 'silhouette', 'nn_hit_rate',
    'nn_miss_rate', 'peak_to_valley', 'peak_trough_ratio', 'half_width',
    'repolarization_slope', 'recovery_slope', 'num_positive_peaks',
    'num_negative_peaks', 'velocity_above', 'velocity_below', 'exp_decay',
    'spread']



Pipeline(steps=[('imputer', KNNImputer()), ('scaler', MinMaxScaler()),
                ('classifier',
                 HalvingGridSearchCV(cv=3,
                                     estimator=RandomForestClassifier(random_state=42),
                                     factor=2,
                                     param_grid={'bootstrap': [True, False],
                                                 'max_depth': [None, 10, 20],
                                                 'min_samples_leaf': [1, 2, 4],
                                                 'min_samples_split': [2, 5,
                                                                       10],
                                                 'n_estimators': [50, 100,
                                                                  200]},
                                     scoring='balanced_accuracy'))])
Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),
                ('scaler', RobustScaler()),
        

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [15]:
from sklearn.pipeline import Pipeline

classification_pipeline = Pipeline(
        steps=[
                ('noise_neuron', noise_neuron_model),
                ('sua_mua', sua_mua_model),
        ]
)

In [34]:
# use noise_neuron_model to classify noise and neuron.
classification_pipeline.named_steps['sua_mua']


In [32]:
input_data = pd.concat([quality_metrics, template_metrics], axis=1).drop(['sd_ratio'],axis =1)
input_data = input_data[required_metrics]

predictions = classification_pipeline.named_steps['sua_mua'].predict(input_data)

probabilities = classification_pipeline.named_steps['sua_mua'].predict_proba(input_data)

In [33]:
probabilities

array([[0.8075    , 0.1925    ],
       [0.7125    , 0.2875    ],
       [0.805     , 0.195     ],
       [0.57083333, 0.42916667],
       [0.82583333, 0.17416667],
       [0.81166667, 0.18833333],
       [0.82583333, 0.17416667],
       [0.80916667, 0.19083333],
       [0.85333333, 0.14666667],
       [0.52416667, 0.47583333],
       [0.77083333, 0.22916667],
       [0.6       , 0.4       ],
       [0.66916667, 0.33083333],
       [0.36083333, 0.63916667],
       [0.5575    , 0.4425    ],
       [0.39666667, 0.60333333],
       [0.7925    , 0.2075    ],
       [0.63      , 0.37      ],
       [0.47666667, 0.52333333],
       [0.36916667, 0.63083333],
       [0.77      , 0.23      ],
       [0.845     , 0.155     ],
       [0.62083333, 0.37916667],
       [0.835     , 0.165     ],
       [0.83666667, 0.16333333],
       [0.7775    , 0.2225    ],
       [0.76166667, 0.23833333],
       [0.53      , 0.47      ],
       [0.5825    , 0.4175    ],
       [0.66333333, 0.33666667],
       [0.

In [None]:

class ModelBasedClassification:
    # TODO docstring
    
    def __init__(self, pipeline, sorting_analyzer):
        self.pipeline = pipeline
        self.classified_units = None

        self.sorting_analyzer = sorting_analyzer

        # TODO: split pipeline into model, required_metrics, else?

    def predict_labels(self):
        # TODO: make general predict_labels function, then allow for calling multiple times to use different models

        # Get metrics DataFrame for classification
        input_data = self._get_metrics_for_classification()

        # Prepare input data
        input_data[np.isinf(input_data)] = np.nan
        input_data = input_data.astype('float32')

        # Apply classifier
        predictions = self.pipeline.predict(input_data)
        probabilities = self.pipeline.predict_proba(input_data)
        
        return predictions, probabilities

        # TODO: return DataFrame? and set as SortingAnalyzer.sorting property
        # Maybe in separate function??


    def _get_metrics_for_classification(self):
        try:
            # TODO: check if required_metrics are ALL computed
            quality_metrics = self.sorting_analyzer.extensions['quality_metrics'].data["metrics"]
            template_metrics = self.sorting_analyzer.extensions['template_metrics'].data["metrics"]
        except:
            # TODO split into separate quality and template metrics computation so it's not recomputing unnecessarily
            # Calculate metrics
            self.sorting_analyzer.compute({
                'noise_levels': {},
                'random_spikes': {'max_spikes_per_unit': 1_000},
                'templates': {'ms_before': 1.5, 'ms_after': 3.5},
                'spike_amplitudes': {},
                # 'correlograms': {'bin_ms': 0.5},
                # 'waveforms': {},
                'principal_components': {},
                'spike_locations': {},
                'unit_locations': {},
                # 'template_similarity': {}
            })

            self.sorting_analyzer.compute("quality_metrics", metric_names = sqm.get_quality_metric_list())
            self.sorting_analyzer.compute("quality_metrics", metric_names = sqm.get_quality_pca_metric_list())
            self.sorting_analyzer.compute("template_metrics", metric_names = spost.get_template_metric_names())
                                          
            quality_metrics = self.sorting_analyzer.extensions['quality_metrics'].data["metrics"]
            template_metrics = self.sorting_analyzer.extensions['template_metrics'].data["metrics"]
        
        metrics_list = quality_metrics.columns.to_list() + template_metrics.columns.to_list()

        #TODO: make dynamic (pull from self.pipeline)
        required_metrics = ['num_spikes', 'firing_rate',
    'presence_ratio', 'snr', 'isi_violations_ratio', 'isi_violations_count',
    'rp_contamination', 'rp_violations', 'sliding_rp_violation',
    'amplitude_cutoff', 'amplitude_median', 'amplitude_cv_median',
    'amplitude_cv_range', 'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
    'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
    'isolation_distance', 'l_ratio', 'd_prime', 'silhouette', 'nn_hit_rate',
    'nn_miss_rate', 'peak_to_valley', 'peak_trough_ratio', 'half_width',
    'repolarization_slope', 'recovery_slope', 'num_positive_peaks',
    'num_negative_peaks', 'velocity_above', 'velocity_below', 'exp_decay',
    'spread']
                
        missing_metrics = [metric for metric in required_metrics if metric not in metrics_list]
        
        if len(missing_metrics) > 0:
            raise ValueError(f"Missing metrics: {missing_metrics}")
        
            # TODO: recalculate metrics which are missing
            # TODO: set properties of self.sorting_analyzer

        # Create DataFrame of all metrics and reorder columns to match the model
        calculated_metrics = pd.concat([quality_metrics, template_metrics], axis = 1)
        calculated_metrics = calculated_metrics[required_metrics]

        return calculated_metrics
                

In [5]:
import pickle as pkl
import pandas as pd

path_to_noise_model = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/num_repetitions_1_optimized_noise_grid_classifier.pkl")
path_to_sua_mua_model = Path("/home/jake/Documents/ephys_analysis/code/si_dataset/num_repetitions_1_optimized_sua_grid_classifier.pkl")
output_folder = Path("/home/jake/Documents/ephys_analysis/code/si_dataset")

path_to_zarr = Path("/home/jake/Documents/ephys_analysis/code/si_dataset.zarr")
analyzer = si.load_sorting_analyzer(path_to_zarr)

def compute_auto_label(sorting_analyzer, paths_to_model, method = 'anoushka', **kwargs): #change to loaded .pkl file?
    # TODO: docstring
    # TODO: paths_to_model/ actual models loaded from pkl?? make decision & implement flexibly
    # Logic:
    # Check the analyzer has all the required metrics, else throw useful error
    # Compare against required_metrics, drop any extra metrics and print a warning
    # If any metrics are missing, throw error

    # This needs fixing when we add Robyn's method
    # TODO: handle case where no metrics of either type have not been computed
    quality_metrics = sorting_analyzer.extensions['quality_metrics'].data["metrics"]
    template_metrics = sorting_analyzer.extensions['template_metrics'].data["metrics"]
    all_metrics = pd.concat([quality_metrics, template_metrics], axis=1)

    # If method, run AutoLabel class
    # TODO: think of good names

    auto_label = AnoushkaAutoLabel()
    # TODO: combine to single pkl
    # Load models from .pkl files
    with open(path_to_noise_model, 'rb') as noise_model_file:
        auto_label.noise_neuron_model = pkl.load(noise_model_file)
    with open(path_to_sua_mua_model, 'rb') as sua_mua_model_file:
        auto_label.sua_mua_model = pkl.load(sua_mua_model_file)

    # Check metrics loaded into SortingAnalyzer against a defined list
    auto_label.calculated_metrics = auto_label.check_required_metrics(all_metrics)
    auto_label.output_folder = output_folder

    # TODO: set output dataframe (when implemented) as SortingAnalyzer.sorting property
    # SortingAnalyzer.sorting property should retain only a dict of unit_IDs, label, and confidence
    auto_label.predict_labels()


compute_auto_label(analyzer, method = 'anoushka')

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


TypeError: compute_auto_label() missing 1 required positional argument: 'paths_to_model'

In [7]:
pd.read_csv('/home/jake/Documents/ephys_analysis/code/si_dataset/decoder_output_dataframe.csv').head()

Unnamed: 0.1,Unnamed: 0,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,sliding_rp_violation,...,velocity_above,velocity_below,exp_decay,spread,noise_label,noise_probs,sua_label,sua_probs,decoder_label,decoder_probs
0,0,3528.0,11.76,1.0,1.221621,0.361539,45.0,0.636261,36.0,0.23,...,,,0.01017,150.0,neural,0.483333,mua,0.1925,mua,0.8075
1,1,2599.0,8.663333,1.0,2.817327,0.236869,16.0,0.254432,10.0,0.07,...,412.3567,,0.012208,150.0,noise,0.600119,mua,0.2875,noise,0.600119
2,2,3026.0,10.086667,1.0,1.303826,0.404077,37.0,0.712534,28.0,0.15,...,,,0.014259,150.0,noise,0.5875,mua,0.195,noise,0.5875
3,3,343.0,1.143333,1.0,4.452133,0.0,0.0,0.0,0.0,0.295,...,,,0.024574,125.0,neural,0.467,mua,0.429167,mua,0.570833
4,4,2580.0,8.6,1.0,1.452094,0.420648,28.0,0.685975,20.0,0.28,...,153.1546,,0.021465,150.0,neural,0.416,mua,0.1775,mua,0.8225


In [None]:
import pandas as pd

all_metrics = pd.concat([quality_metrics, template_metrics], axis=1)
metrics_list = all_metrics.columns.to_list()
all_metrics = all_metrics.drop(['sd_ratio'], axis = 1)
print(len(all_metrics.columns.to_list()))

37


In [None]:
required_metrics = ['num_spikes', 'firing_rate',
        'presence_ratio', 'snr', 'isi_violations_ratio', 'isi_violations_count',
        'rp_contamination', 'rp_violations', 'sliding_rp_violation',
        'amplitude_cutoff', 'amplitude_median', 'amplitude_cv_median',
        'amplitude_cv_range', 'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
        'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
        'isolation_distance', 'l_ratio', 'd_prime', 'silhouette', 'nn_hit_rate',
        'nn_miss_rate', 'peak_to_valley', 'peak_trough_ratio', 'half_width',
        'repolarization_slope', 'recovery_slope', 'num_positive_peaks',
        'num_negative_peaks', 'velocity_above', 'velocity_below', 'exp_decay',
        'spread']

missing_metrics = [metric for metric in required_metrics if metric not in metrics_list]
extra_metrics = [metric for metric in metrics_list if metric not in required_metrics]
print(missing_metrics)
print(extra_metrics)

[]
['sd_ratio']


In [None]:
import os

class AnoushkaAutoLabel():
        def __init__(self):
                # TODO: make dict of models
                self.noise_neuron_model = None
                self.sua_mua_model = None

                self.calculated_metrics = None

                # TODO: fix such that predict_labels returns a DataFrame
                self.output_folder = None

        def predict_labels(self):
                # Define features
                self.X_columns = self.calculated_metrics.columns.to_list()
                
                # Prepare input data
                input_data = self.calculated_metrics[self.X_columns]
                input_data[np.isinf(input_data)] = np.nan
                input_data = input_data.astype('float32')

                # Apply noise classifier
                noise_predictions = self.noise_neuron_model.predict(input_data[self.X_columns])
                noise_probs =self.noise_neuron_model.predict_proba(input_data[self.X_columns])
                input_data['noise_label'] = noise_predictions
                input_data['noise_probs'] = noise_probs[:,1]
                input_data['noise_label'] = input_data['noise_label'].map({1: 'noise', 0: 'neural'})
                sua_predictions = self.sua_mua_model.predict(input_data[self.X_columns])
                sua_probs = self.sua_mua_model.predict_proba(input_data[self.X_columns])
                input_data['sua_label'] = sua_predictions
                input_data['sua_probs'] = sua_probs[:,1]
                input_data['sua_label'] = input_data['sua_label'].map({1 : 'sua', 0 :'mua'})
                # Create the 'decoder_label' column
                input_data['decoder_label'] = input_data.apply(lambda row: row['noise_label'] if row['noise_label'] == 'noise' else row['sua_label'], axis=1)
                # Create the 'decoder_probs' column
                input_data['decoder_probs'] = input_data.apply(lambda row: row['noise_probs'] if row['decoder_label'] == 'noise' else row['sua_probs'], axis=1)
                # Update 'decoder_probs' for 'mua' values
                input_data.loc[input_data['decoder_label'] == 'mua', 'decoder_probs'] = 1 - input_data['decoder_probs']
                # Save the result to a CSV file
                input_data.to_csv(os.path.join(self.output_folder, 'decoder_output_dataframe.csv'))
                print('Decoder output saved to decoder_output_dataframe.csv')

                # TODO: return DataFrame and set as SortingAnalyzer.sorting property

        def check_required_metrics(self, calculated_metrics):

                metrics_list = calculated_metrics.columns.to_list()
                
                # TODO: make dynamic (read from pkl)
                required_metrics = ['num_spikes', 'firing_rate',
    'presence_ratio', 'snr', 'isi_violations_ratio', 'isi_violations_count',
    'rp_contamination', 'rp_violations', 'sliding_rp_violation',
    'amplitude_cutoff', 'amplitude_median', 'amplitude_cv_median',
    'amplitude_cv_range', 'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
    'firing_range', 'drift_ptp', 'drift_std', 'drift_mad',
    'isolation_distance', 'l_ratio', 'd_prime', 'silhouette', 'nn_hit_rate',
    'nn_miss_rate', 'peak_to_valley', 'peak_trough_ratio', 'half_width',
    'repolarization_slope', 'recovery_slope', 'num_positive_peaks',
    'num_negative_peaks', 'velocity_above', 'velocity_below', 'exp_decay',
    'spread']
                
                missing_metrics = [metric for metric in required_metrics if metric not in metrics_list]
                extra_metrics = [metric for metric in metrics_list if metric not in required_metrics]
                
                if len(missing_metrics) > 0:
                        raise ValueError(f"Missing metrics: {missing_metrics}")
                elif len(extra_metrics) > 0:
                        calculated_metrics = calculated_metrics.drop(extra_metrics, axis = 1)
                        print(f"Extra metrics: {extra_metrics}. Dropping before model application.")
                else:
                        print('Metric list is complete. Proceeding to model application.')

                # Reorder columns to match the model
                calculated_metrics = calculated_metrics[required_metrics]

                return calculated_metrics
                
                