This notebook/demo uses SpikeInterface; [here is how to install SpikeInterface](https://spikeinterface.readthedocs.io/en/stable/get_started/installation.html).

In the SI environment:

`pip install UnitMatchPy`

In [None]:
import sys
from pathlib import Path

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import unitmatchpy.extract_raw_data as erd
import numpy as np 

## Load Data & get good units

SpikeInterface can load many different types of ephys data; look [here](https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html) for documentation on the functions that read the different data formats. [Example data can be found here.](https://figshare.com/articles/dataset/UnitMatch_Demo_-_data/24305758/1)

In [None]:
#Make list of recordings/sortings to iterate over
# One entry per session: later cells index recordings[i] and sortings[i] together,
# so both lists must have the same length and the same session order.
recordings = [se.read_spikeglx(r'path/to/SpikeGLX/data', stream_name="imec0.ap")]

sortings = [se.read_kilosort(r'Path/To/KiloSort/Directory')]

#Will only make average waveforms for good units
# When True, later cells restrict each sorting to units whose 'quality' property is 'good'.
extract_good_units_only = True

In [None]:
#Getting good units only
sortings[0].get_property_keys() #lists keys for attached properties if 'quality' is not suitable

#Good units which will be used in Unit Match
# good_units[i]: (n_units, 2) array with columns [original cluster id, quality label]
# units_used[i]: original cluster ids of every unit in session i
good_units = []
units_used = []
for i, sorting in enumerate(sortings):
    unit_ids_tmp = sorting.get_property('original_cluster_id')
    is_good_tmp = sorting.get_property('quality')
    good_units.append(np.stack((unit_ids_tmp, is_good_tmp), axis=1))

    units_used.append(unit_ids_tmp)
    if extract_good_units_only:
        # select_units() expects unit ids, not positional indices. Masking the
        # sorting's own unit ids stays correct when ids are not 0..N-1 and when
        # only a single unit is labelled 'good' (no 0-d array from squeeze()).
        good_mask = is_good_tmp == 'good'
        sortings[i] = sorting.select_units(sorting.unit_ids[good_mask])
        


## Process average waveforms / templates

Be aware that the SpikeInterface method is different from the native UnitMatch method in ExtractRawDemo.ipynb or in the MATLAB version.

In [None]:
# Pre-process the raw data
# The pre-processing functions return new recording objects, so the result has
# to be written back into the list; rebinding only the loop variable would
# leave `recordings` untouched and the later cells would use the raw data.
for i, recording in enumerate(recordings):
    recording = spre.phase_shift(recording) #correct for time delay between recording channels
    recording = spre.highpass_filter(recording) #highpass

    # for motion correction, this can be very slow
    # Uncomment the line below to do in-session motion correction
    #recording = spre.correct_motion(recording, preset="nonrigid_fast_and_accurate")

    recordings[i] = recording

In [None]:
# Cut every session at its midpoint so that each entry of `sortings` and
# `recordings` becomes a two-element list: [first half, second half].

# Sortings first: their frame boundaries come from the (still unsplit) recordings.
for i, sorting in enumerate(sortings):
    n_frames = recordings[i].get_num_samples()
    midpoint = n_frames // 2
    sortings[i] = [
        sorting.frame_slice(start_frame=first, end_frame=last)
        for first, last in ((0, midpoint), (midpoint, n_frames))
    ]

# Then the recordings themselves, using the same boundaries.
for i, recording in enumerate(recordings):
    n_frames = recording.get_num_samples()
    midpoint = n_frames // 2
    recordings[i] = [
        recording.frame_slice(start_frame=first, end_frame=last)
        for first, last in ((0, midpoint), (midpoint, n_frames))
    ]
 

In [None]:
# Build one dense (sparse=False) sorting analyzer per session half.
# analysers[i][h] pairs sortings[i][h] with recordings[i][h] for h in {0, 1}.
analysers = [
    [
        si.create_sorting_analyzer(sortings[i][half], recordings[i][half], sparse=False)
        for half in (0, 1)
    ]
    for i in range(len(recordings))
]

In [None]:
# Compute the 'fast_templates' extension on both halves of every session and
# collect the two template arrays stacked along a new last axis.
all_waveforms = []
for session_analysers in analysers:
    for half in range(2):
        analyser = session_analysers[half]
        # templates are computed from a uniform random subset of spikes per unit
        analyser.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
        # alternative call: compute('fast_templates', n_jobs = 0.8,  return_scaled=True)
        analyser.compute('fast_templates', n_jobs=0.8)

    half_templates = [
        session_analysers[half].get_extension('fast_templates').get_data()
        for half in range(2)
    ]
    all_waveforms.append(np.stack(half_templates, axis=-1))

# Channel positions per session; both halves share the same positions,
# so the first half's analyzer is used.
all_positions = [session_analysers[0].get_channel_locations() for session_analysers in analysers]

## Save extracted data in a unit match friendly folder

In [None]:
import os
UM_input_dir = os.path.join(os.getcwd(), 'UMInputData')

# exist_ok=True so that re-running this cell does not crash on the existing folders
os.makedirs(UM_input_dir, exist_ok=True)

# One 'SessionN' folder per session, each holding cluster_group.tsv,
# channel_positions.npy and the average waveforms written by save_avg_waveforms.
all_session_paths = []
for i in range(len(recordings)):
    session_x_path = os.path.join(UM_input_dir, f'Session{i+1}') #lets start at 1
    os.makedirs(session_x_path, exist_ok=True)

    #save the good units as a .tsv: first column is unit ID, second is 'good' or 'mua'
    good_units_path = os.path.join(session_x_path, 'cluster_group.tsv')
    channel_positions_path = os.path.join(session_x_path, 'channel_positions.npy')
    save_good_units = np.vstack((np.array(('cluster_id', 'group')), good_units[i])) #Title of column one is '0000' not 'cluster_id'
    save_good_units[0,0] = 0 # need to be int to use np.savetxt 
    np.savetxt(good_units_path, save_good_units, fmt =['%d','%s'], delimiter='\t')
    if extract_good_units_only:
        # use THIS session's labels (good_units[i]); indexing session 0 here
        # would label every later session with the first session's good units
        Units = np.argwhere(good_units[i][:,1] == 'good')
        erd.save_avg_waveforms(all_waveforms[i], session_x_path, Units, ExtractGoodUnitsOnly = extract_good_units_only)
    else:
        erd.save_avg_waveforms(all_waveforms[i], session_x_path, good_units[i], ExtractGoodUnitsOnly = extract_good_units_only)
    np.save(channel_positions_path, all_positions[i])

    all_session_paths.append(session_x_path)

## Run UnitMatch

In [None]:
%load_ext autoreload
%autoreload 

import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import numpy as np
import matplotlib.pyplot as plt
import UnitMatchPy.save_utils as su
import UnitMatchPy.GUI as gui
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

In [None]:
#get default parameters, can add your own before or after!
param = default_params.get_default_param()

# SpikeInterface extracts waveforms in a different manner to the native
# UnitMatch extraction, so the waveform-window defaults are overridden here.
# These overrides must be applied AFTER fetching the defaults; assigning the
# dict first and then calling get_default_param() throws them away.
# NOTE(review): keys are written in the snake_case style used for every other
# param key read in this notebook (e.g. 'peak_loc', 'n_units'); confirm the
# exact names against UnitMatchPy.default_params for the installed version.
param.update({'spike_width': 90, 'waveidx': np.arange(20,50), 'peak_loc': 35})

KS_dirs = [r'path/to/KSdir/Session1', r'Path/to/KSdir/Session2']

param['KS_dirs'] = KS_dirs
wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(KS_dirs)
# probe geometry is derived from the first session's channel positions
param = util.get_probe_geometry(channel_pos[0], param)

In [None]:
def zero_center_waveform(waveform):
    """
    Shift each waveform so that its baseline sits at zero.

    The baseline is the mean over the first 15 time points, computed
    separately for every unit, channel and cross-validation half. Useful for
    Spike Interface output, where the waveforms are not centered about 0.

    Arguments:
        waveform - ndarray (nUnits, Time Points, Channels, CV)

    Returns:
        A new array of the same shape with the baseline subtracted
    """
    # keepdims keeps the time axis (length 1) so the baseline broadcasts
    # against the full waveform without an explicit broadcast_to
    baseline = waveform[:, :15, :, :].mean(axis=1, keepdims=True)
    return waveform - baseline

In [None]:
#read in data, select the good units and extract metadata
waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(wave_paths, unit_label_paths, param, good_units_only = True) 

#param['peak_loc'] = #may need to set as a value if the peak location is NOT ~ half the spike width

# create clus_info, contains all unit id/session related info
clus_info = {'good_units' : good_units, 'session_switch' : session_switch, 'sessions_id' : session_id, 
            'original_ids' : np.concatenate(good_units) }

#Extract parameters from waveform
extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)

#Extract metric scores
total_score, candidate_pairs, scores_to_include, predictors  = ov.extract_metric_scores(extracted_wave_properties, session_switch, within_session, param, niter  = 2)

#Probability analysis
# NOTE: despite its name, prior_match is 1 - (expected matches / number of unit pairs);
# it is passed as the FIRST entry of `priors`, with its complement as the second.
prior_match = 1 - (param['n_expected_matches'] / param['n_units']**2 ) # freedom of choice in prior prob
priors = np.array((prior_match, 1-prior_match))

labels = candidate_pairs.astype(int)
cond = np.unique(labels)
score_vector = param['score_vector']

# (the former NaN-filled preallocation of parameter_kernels was removed: it was
# overwritten by this call before ever being read)
parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one = 1)

probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, cond)

# column 1 of `probability` reshaped into an (n_units, n_units) matrix
output_prob_matrix = probability[:,1].reshape(param['n_units'],param['n_units'])

In [None]:
# Report on the probability matrix (evaluated here at a fixed threshold of 0.75)
util.evaluate_output(output_prob_matrix, param, within_session, session_switch, match_threshold = 0.75)

# Threshold used for the binary match matrix below
match_threshold = param['match_threshold']
#match_threshold = try different values here!

# 1 where the probability exceeds the threshold, 0 elsewhere (same dtype as the input)
output_threshold = (output_prob_matrix > match_threshold).astype(output_prob_matrix.dtype)

plt.imshow(output_threshold, cmap = 'Greys')


In [None]:
# Unpack the individual entries of extracted_wave_properties into named
# variables; several of them (avg_centroid, avg_waveform, avg_waveform_per_tp,
# max_site) are reused by the final save cell.
amplitude = extracted_wave_properties['amplitude']
spatial_decay = extracted_wave_properties['spatial_decay']
avg_centroid = extracted_wave_properties['avg_centroid']
avg_waveform = extracted_wave_properties['avg_waveform']
avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp']
wave_idx = extracted_wave_properties['good_wave_idxs']
max_site = extracted_wave_properties['max_site']
max_site_mean = extracted_wave_properties['max_site_mean']
# Hand everything to the GUI module (presumably prepares the data shown by
# gui.run_GUI() in the next cell -- confirm against UnitMatchPy.GUI).
gui.process_info_for_GUI(output_prob_matrix, match_threshold, scores_to_include, total_score, amplitude, spatial_decay,
                         avg_centroid, avg_waveform, avg_waveform_per_tp, wave_idx, max_site, max_site_mean, 
                         waveform, within_session, channel_pos, clus_info, param)

In [None]:
# Run the GUI; its three return values feed util.curate_matches in the next cell.
is_match, not_match, matches_GUI = gui.run_GUI()

In [None]:
#this function has 2 modes, 'And' and 'Or', which return a match if it appears in both or in one cv pair respectively
#it will then add all the matches selected in is_match, and remove all matches in not_match
matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode = 'And')

In [None]:
# Index pairs of all matches: entries of the thresholded matrix equal to 1.
# (match_threshold is the scalar cut-off, so comparing IT with 1 never yields
# the list of matching pairs; output_threshold is the binary match matrix.)
matches = np.argwhere(output_threshold == 1)
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)

save_dir = r'Path/to/save/directory'
#NOTE - change matches to matches_curated if manual curation was done with the GUI
su.save_to_output(save_dir, scores_to_include, matches, # matches_curated
                  output_prob_matrix, avg_centroid, avg_waveform, avg_waveform_per_tp, max_site,
                  total_score, output_threshold, clus_info, param, UIDs = UIDs, matches_curated = None, save_match_table = True)