## Download Required Files

To get started, download the following two files:

1. **quality_metrics.csv**
2. **cluster_group.tsv**

Save these files to your computer.



In [19]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd 
import os

import spikeinterface.core as si
import spikeinterface.widgets as sw


# Note: you can set the number of cores used, e.g.
# si.set_global_job_kwargs(n_jobs = 8)

In [20]:
# If the import in the next cell raises an ImportError, add the repository root to the Python path manually:

# import sys
# sys.path.append(r"C:\Users\jain\Documents\GitHub\UnitRefine")  # path to your local UnitRefine directory

In [None]:
# Import necessary libraries  
from UnitRefine.scripts.train_manual_curation import train_model

In [15]:
# Load the data from 'quality_metrics.csv' into a DataFrame
metrics = pd.read_csv('quality_metrics.csv')

# Define a list of column names that correspond to various quality metrics
metrics_cols = [
    'num_spikes', 'firing_rate', 'presence_ratio', 'snr',
    'isi_violations_ratio', 'isi_violations_count', 'rp_contamination',
    'rp_violations', 'sliding_rp_violation', 'amplitude_cutoff',
    'amplitude_median', 'amplitude_cv_median', 'amplitude_cv_range',
    'sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'firing_range',
    'drift_ptp', 'drift_std', 'drift_mad', 'isolation_distance', 'l_ratio',
    'd_prime', 'silhouette', 'nn_hit_rate', 'nn_miss_rate'
]
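
Before training, it can help to confirm that every requested metric name is actually a column in the metrics file, since a missing column would otherwise only surface later. The following is a minimal sketch using a made-up table in place of `quality_metrics.csv`; the names `toy_metrics` and `requested_cols` are hypothetical and used only for illustration.

```python
import pandas as pd

# Toy stand-in for quality_metrics.csv (values are made up for illustration)
toy_metrics = pd.DataFrame({
    "num_spikes": [1200, 300],
    "firing_rate": [4.0, 1.0],
    "snr": [6.5, 2.1],
})

# Hypothetical subset of the metric names requested for training
requested_cols = ["num_spikes", "firing_rate", "snr", "l_ratio"]

# Metrics that were requested but are absent from the table
missing_cols = [col for col in requested_cols if col not in toy_metrics.columns]
# Metrics that are present and can be used
available_cols = [col for col in requested_cols if col in toy_metrics.columns]

print("Missing:", missing_cols)
print("Available:", available_cols)
```

With the real data, the same check would be run on `metrics` and `metrics_cols`.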


In [16]:
# Load human-curated labels from 'cluster_group.tsv' file
# This file contains information about the quality of each cluster

labels_to_train = pd.read_csv('cluster_group.tsv', sep='\t')  # Read the file, specifying tab as the separator

# Map label names to numerical values:
# 'good' clusters are labeled as 1, while 'mua' (multi-unit activity) and 'noise' are labeled as 0
labels_to_train = labels_to_train['group'].map({'good': 1, 'mua': 0, 'noise': 0}).to_list()
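
One thing to watch with `map`: any label that is not a key in the dictionary becomes `NaN` rather than raising an error. The sketch below uses a made-up label table (the `unsorted` entry is a hypothetical example of an unmapped label) to show how to spot such entries and check the class balance before training.

```python
import pandas as pd

# Toy stand-in for cluster_group.tsv (made-up values)
toy_labels = pd.DataFrame({
    "cluster_id": [0, 1, 2, 3],
    "group": ["good", "mua", "noise", "unsorted"],
})

# Same mapping as in the tutorial: 'good' -> 1, everything curated as bad -> 0
mapped = toy_labels["group"].map({"good": 1, "mua": 0, "noise": 0})

# Labels missing from the mapping turn into NaN; list them explicitly
unmapped = toy_labels.loc[mapped.isna(), "group"].tolist()
print("Unmapped labels:", unmapped)

# Class balance among the labels that were mapped (NaN is ignored by value_counts)
class_counts = mapped.value_counts().to_dict()
print("Class counts:", class_counts)
```

If `unmapped` is not empty, either extend the mapping or drop those units from both the labels and the metrics so the two stay aligned.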


In [17]:
current_directory = os.getcwd()
print("Current working directory:", current_directory)

Current working directory: c:\Users\jain\Documents\GitHub\UnitRefine\UnitRefine\tutorial\train_model


### Step: Train the Classifier with Your Data

Load your dataset and the corresponding curated labels to train the classifier. Adjust parameters such as training data size, feature selection, or classifier type to observe their impact on model performance.

**Note:** For improved generalizability, train the model on multiple labeled recordings from varied conditions or sessions.

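After training, the model and a `model_info.json` file are saved to the directory passed as `folder`. The best-performing pipeline is available as `trainer.best_pipeline`, which is assigned to `best_model` below and can be used to make predictions on other recordings. Displaying `best_model` shows its steps and their parameters.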
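
To make the moving parts concrete, here is a rough sketch of the kind of pipeline such a training run searches over: median imputation, standard scaling, and a random forest, tuned with scikit-learn's `RandomizedSearchCV`. This is not UnitRefine's actual implementation; the synthetic data, the parameter grid, and all variable names are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic "metrics" and "labels" (made up; not real recording data)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
X[rng.random(X.shape) < 0.05] = np.nan  # sprinkle in some missing values

# Hold out 20% of the units for testing, mirroring test_size=0.2
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Impute -> scale -> classify
pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="median")),
    ("scaler", StandardScaler()),
    ("classifier", RandomForestClassifier(random_state=0)),
])

# Small, hypothetical hyperparameter search space
search = RandomizedSearchCV(
    pipeline,
    param_distributions={
        "classifier__n_estimators": [50, 100],
        "classifier__max_depth": [3, 5, None],
    },
    n_iter=4,
    cv=3,
    random_state=0,
)
search.fit(X_train, y_train)

best_pipeline = search.best_estimator_
test_accuracy = best_pipeline.score(X_test, y_test)
print("Best parameters:", search.best_params_)
print("Held-out accuracy:", test_accuracy)
```

Changing the imputation strategy, the scaler, or the classifier in this sketch corresponds to the `imputation_strategies`, `scaling_techniques`, and `classifiers` arguments used in the training call.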



In [18]:
# For illustration we pass the same CSV twice; in practice, use several different recordings to improve model performance
trainer = train_model(
    mode = "csv",
    labels = [labels_to_train,labels_to_train],
    metrics_paths = ['quality_metrics.csv','quality_metrics.csv'], # List of paths to the metrics files
    folder = current_directory, # Optional, can be set to save the model and model_info.json file
    metric_names = metrics_cols, # Can be set to specify which metrics to use for training
    imputation_strategies =  ["median"], # Default to all
    scaling_techniques =  ["standard_scaler"], # Default to all
    # Default to Random Forest only. Other classifiers you can try:
    # ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
    classifiers = None,
    test_size = 0.2,
    overwrite = True
)





Running RandomForestClassifier with imputation median and scaling StandardScaler()
BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV

In [None]:
trainer

In [7]:
best_model = trainer.best_pipeline


In [None]:
# Load and display the top pipelines and their accuracies
accuracies = pd.read_csv("model_accuracies.csv", index_col = 0)
accuracies.head()

In [None]:
# Plot feature importances
import numpy as np
import matplotlib.pyplot as plt

importances = best_model.named_steps['classifier'].feature_importances_
indices = np.argsort(importances)[::-1]
features = best_model.feature_names_in_
n_features = best_model.n_features_in_

plt.figure(figsize=(12, 6))
plt.title("Feature Importances")
plt.bar(range(n_features), importances[indices], align="center")
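# Reorder the labels with the same indices so each bar gets the matching feature name
plt.xticks(range(n_features), np.asarray(features)[indices], rotation=90)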
plt.xlim([-1, n_features])
plt.show()

# Let's apply best_model to new data!

## Load new_data.csv from the UnitRefine folder and apply your best_model to it

In [None]:
# read new csv file
new_data = pd.read_csv('new_data.csv')

In [None]:
new_data.head()

In [11]:
output_labels = best_model.predict(new_data[metrics_cols])

In [12]:
output_probs = best_model.predict_proba(new_data[metrics_cols])

In [None]:
output_labels

In [None]:
output_probs[:, 1]  # predicted probability that each unit is single-unit activity (SUA, label 1)
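
The labels and probabilities are easier to work with when collected in one table, and low-confidence units can then be flagged for manual review. The sketch below uses made-up probabilities in place of `output_probs`; the `unit_ids`, the `confidence_threshold` of 0.8, and the column names are hypothetical choices, and `argmax` is used here only as a stand-in for `best_model.predict`.

```python
import numpy as np
import pandas as pd

# Made-up class probabilities standing in for best_model.predict_proba(...)
# Column 0: probability of label 0 (mua/noise); column 1: probability of label 1 (good)
toy_probs = np.array([
    [0.10, 0.90],
    [0.45, 0.55],
    [0.80, 0.20],
    [0.02, 0.98],
])
unit_ids = [0, 1, 2, 3]  # hypothetical unit identifiers

results = pd.DataFrame({
    "unit_id": unit_ids,
    "prediction": toy_probs.argmax(axis=1),  # most probable class per unit
    "prob_good": toy_probs[:, 1],            # probability of the 'good' class
})

# Confidence is the probability of whichever class was predicted
results["confidence"] = toy_probs.max(axis=1)

# Flag units below a hypothetical confidence threshold for manual review
confidence_threshold = 0.8
results["needs_review"] = results["confidence"] < confidence_threshold

print(results)
```

The threshold trades off how many units are curated automatically against how many are sent back for manual inspection, so it is worth tuning on data with known labels.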