# Model-Based Curation Tutorial

In this tutorial we download a pretrained model and apply it to metrics files using SpikeInterface.

```python
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import spikeinterface.core as si

# Optional: use more cores for parallel computations, e.g.
# si.set_global_job_kwargs(n_jobs=8)
```


## Download a pretrained model

Let's download a pretrained model from [Hugging Face](https://huggingface.co/) (HF). The
``load_model`` function can either download a model directly from HF or load one from a local
folder. When downloading, it saves the model to a temporary folder. It returns the model
together with some metadata about it.

```python
# If the import in the next cell fails, add the root folder of your local
# UnitRefine clone to the Python path manually:
# import sys
# sys.path.append(r"C:\path\to\UnitRefine")  # replace with the location of UnitRefine on your machine
```

```python
from UnitRefine.scripts.model_based_curation import load_model

# Alternatively, import it from the SpikeInterface package:
# from spikeinterface.curation import load_model
```

```python
model, model_info = load_model(
    repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
    # object types that are trusted when loading the saved model file
    trusted=["numpy.dtype"],
)
```

```python
model
```

The model object (an sklearn ``Pipeline``) records which metrics were used to train it.
We can access them through the model's ``feature_names_in_`` attribute (or from ``model_info``).

## If you only have metrics files

If you don't have access to a sorting analyzer for a particular recording, you can still apply the pretrained model to new data. This works as long as you have a table of quality metrics that contains the metrics the model was trained on.

```python
# Load your metrics table (one row per unit, one column per metric)
metrics = pd.read_csv("all_metrics.csv")
```
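Before predicting, it can help to check the loaded table. The check confirms that every metric the model expects is present (the model lists them as ``model.feature_names_in_``). It also counts how many values in those columns are missing. The helper ``check_metrics_table`` below is hypothetical and is shown with a small made-up table. Whether NaN values are acceptable depends on whether the loaded pipeline imputes missing values, which this sketch does not check.

```python
import pandas as pd


def check_metrics_table(metrics: pd.DataFrame, feature_names) -> pd.Series:
    """Check a metrics table against the features a model expects (hypothetical helper).

    Raises KeyError if any required metric column is absent; otherwise returns
    the number of missing (NaN) values in each required column.
    """
    feature_names = list(feature_names)
    absent = [name for name in feature_names if name not in metrics.columns]
    if absent:
        raise KeyError(f"metrics table lacks columns required by the model: {absent}")
    return metrics[feature_names].isna().sum()


# Usage with a small hypothetical table; in this tutorial you would call
# check_metrics_table(metrics, model.feature_names_in_) instead.
metrics_demo = pd.DataFrame({
    "snr": [4.1, None, 2.3],
    "firing_rate": [2.0, 7.5, 0.4],
    "unit_id": [0, 1, 2],
})
nan_counts = check_metrics_table(metrics_demo, ["snr", "firing_rate"])
print(nan_counts)
```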

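```python
# Select the metrics the model was trained on, in the same order
columns = model.feature_names_in_

# Predicted class for each unit, and the probability of each class
classifier_preds = model.predict(metrics[columns])
classifier_probabs = model.predict_proba(metrics[columns])
```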
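For a scikit-learn classifier, ``predict_proba`` returns one column per class, ordered as in the fitted model's ``classes_`` attribute. The sketch below turns predictions and probabilities into a per-unit table with the predicted label and the probability of that label. It uses small hypothetical arrays in place of ``classifier_preds``, ``classifier_probabs`` and ``model.classes_``. It reuses this tutorial's ``{1: 'noise', 0: 'neural'}`` mapping, which is worth confirming for the model you actually load.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for model.classes_, classifier_preds and classifier_probabs.
classes = np.array([0, 1])
preds = np.array([1, 0, 0, 1])
probas = np.array([
    [0.20, 0.80],
    [0.90, 0.10],
    [0.60, 0.40],
    [0.45, 0.55],
])
label_names = {1: "noise", 0: "neural"}  # mapping used in this tutorial

results = pd.DataFrame({"prediction": pd.Series(preds).map(label_names)})

# Column i of the probability array belongs to class classes[i].
for i, cls in enumerate(classes):
    results[f"prob_{label_names[cls]}"] = probas[:, i]

# Probability the model assigned to the class it predicted for each unit.
column_of_class = {cls: i for i, cls in enumerate(classes)}
predicted_columns = [column_of_class[p] for p in preds]
results["confidence"] = probas[np.arange(len(preds)), predicted_columns]
print(results)
```

Looking up columns through ``classes_`` avoids assuming that column 1 always means "noise".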

```python
# Map the numeric predictions to readable labels
print(pd.Series(classifier_preds).map({1: 'noise', 0: 'neural'}))
```


```
0       noise
1       noise
2       noise
3       noise
4       noise
        ...  
555    neural
556    neural
557     noise
558    neural
559    neural
Length: 560, dtype: object
```
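``model.predict`` returns one prediction per row of ``metrics[columns]``, in the same order. The labels can therefore be attached back to the metrics table, for example to count units per class or to keep only units labelled as neural. The minimal sketch below uses a small hypothetical table and prediction array in place of ``metrics`` and ``classifier_preds``.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the metrics table and classifier_preds.
demo_metrics = pd.DataFrame({
    "snr": [1.1, 6.3, 0.8, 5.2, 7.4],
    "firing_rate": [0.2, 4.0, 0.1, 3.3, 8.9],
})
demo_preds = np.array([1, 0, 1, 0, 0])

# Attach readable labels, aligned on the metrics table's own index.
demo_labels = pd.Series(demo_preds, index=demo_metrics.index).map({1: "noise", 0: "neural"})
labelled = demo_metrics.assign(label=demo_labels)

# Number of units per predicted class.
print(labelled["label"].value_counts())

# Keep only the units classified as neural.
neural_units = labelled[labelled["label"] == "neural"]
print(neural_units.index.tolist())  # [1, 3, 4]
```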
