This notebook/demo uses SpikeInterface; [here is how to install SpikeInterface](https://spikeinterface.readthedocs.io/en/stable/get_started/installation.html).
Note that this demo was written with SpikeInterface version 0.103.0.
It is possible that SpikeInterface changes and we don't notice. In that case, please update the SI-related blocks with SI's new way of extracting raw data and waveforms from raw data, according to their ReadMe. The UnitMatch side of things should remain stable. An example blog post can also be found here: https://olebialas.github.io/posts/2025-10-02-unitmatch/

In the SI environment:

`pip install UnitMatchPy`

OR in the UM environment:

`pip install spikeinterface`


In [None]:
import sys
from pathlib import Path

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import UnitMatchPy.extract_raw_data as erd
import numpy as np 

## Load Data & get good units

SpikeInterface can load many different types of ephys data; look [here](https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html) for documentation on the functions that read the different data formats. [Example data can be found here.](https://figshare.com/articles/dataset/UnitMatch_Demo_-_data/24305758/1)

In [None]:
# Make the lists of recordings/sortings to iterate over (one entry per session).
# Which reader to use depends on whether you have compressed data or not - see the SI documentation.
# Note, the recordings are only needed when you don't have the raw waveforms yet. SKIP when using our demo data.
# Alternative for uncompressed SpikeGLX data:
# recordings = [se.read_spikeglx(r'Path/To/bin/1', stream_name="imec0.ap"), se.read_spikeglx(r'Path/To/bin/2', stream_name="imec0.ap")]
recordings = [se.read_cbin_ibl(r'Path/To/Cbin/1', stream_name="ap"), se.read_cbin_ibl(r'Path/To/Cbin/2/', stream_name="ap")] # you can add more sessions if you want

# Note, whether read_kilosort returns only the 'good' units is controlled by a keyword argument.
# The original note here said good-only is the default and named the keyword 'load_good_only=False'.
# NOTE(review): recent SpikeInterface releases appear to call it `keep_good_only` (default False) -
# confirm against the installed SpikeInterface version before relying on either behaviour.
KS_dirs = [r'/Path/To/KS/1',r'/Path/To/KS/2'] # you can add more sessions if you want

# One sorting per Kilosort directory, so adding a session only requires extending KS_dirs
# (keep the order and length consistent with `recordings`).
sortings = [se.read_kilosort(ks_dir) for ks_dir in KS_dirs]


## Process average waveforms / templates

Be aware that the SpikeInterface method is different from the native UnitMatch method in ExtractRawDemo.ipynb or in the MATLAB version.

In [None]:
# Pre-process the raw data
# Note, only applicable when you don't have the raw waveforms yet. SKIP when using our demo data.
# Each preprocessing call returns a recording object that has to be kept: rebinding only the
# loop variable would leave `recordings` unprocessed, so the result is written back into the list.
for i, recording in enumerate(recordings):
    recording = spre.phase_shift(recording) #correct for time delay between recording channels
    recording = spre.highpass_filter(recording) #highpass

    # for motion correction, this can be very slow
    # Uncomment the line below to do in-session motion correction
    #recording = spre.correct_motion(recording, preset="nonrigid_fast_and_accurate")

    recordings[i] = recording

In [None]:
# Split every sorting and every recording into two halves (first half / second half).
# Note, only applicable when you don't have the raw waveforms yet. SKIP when using our demo data.

# Sortings first: the split point is taken from the still-unsplit recording of the same session.
for session, whole_sorting in enumerate(sortings):
    n_frames = recordings[session].get_num_samples()
    midpoint = n_frames // 2
    sortings[session] = [
        whole_sorting.frame_slice(start_frame=0, end_frame=midpoint),
        whole_sorting.frame_slice(start_frame=midpoint, end_frame=n_frames),
    ]

# Then the recordings themselves, using the same midpoint rule.
for session, whole_recording in enumerate(recordings):
    n_frames = whole_recording.get_num_samples()
    midpoint = n_frames // 2
    recordings[session] = [
        whole_recording.frame_slice(start_frame=0, end_frame=midpoint),
        whole_recording.frame_slice(start_frame=midpoint, end_frame=n_frames),
    ]
 

In [None]:
# Build one sorting analyzer per (session, half) pair.
# Note, only applicable when you don't have the raw waveforms yet. SKIP when using our demo data.

# analysers[session][half] pairs the half-sorting with the matching half-recording (dense, sparse=False).
analysers = [
    [
        si.create_sorting_analyzer(sortings[session][half], recordings[session][half], sparse=False)
        for half in range(2)
    ]
    for session in range(len(recordings))
]

In [None]:
# Create templates using SortingAnalyzer (SI >= 0.101).
# Note, only applicable when you don't have the raw waveforms yet. SKIP when using our demo data.
# NOTE: We now use a per-session compute+save pipeline in the next cell
# to reduce memory pressure. This cell intentionally does not compute
# or store all sessions' waveforms/templates in memory.
# The flag below is informational: no visible cell reads it.
USE_PER_SESSION_PIPELINE = True
print('Using per-session compute+save pipeline. Skip heavy batch compute.')

## Save extracted data in a unit match friendly folder

In [None]:
# Note, only applicable when you don't have the raw waveforms yet. SKIP when using our demo data.

# Per-session compute + save to a UnitMatch-friendly folder to keep the memory footprint low.
# For every session this cell:
#   1) writes the unit labels (cluster_group.tsv),
#   2) computes the templates of both recording halves,
#   3) saves channel positions and the average waveforms via erd.save_avg_waveforms,
#   4) drops the session's analysers/recordings/sortings before moving on.
# Note: an existing UMInputData folder is reused (exist_ok=True), but nothing below checks for
# previously saved results, so everything is recomputed.
import os
import gc
import shutil
from joblib import parallel_backend
from threadpoolctl import threadpool_limits

UM_input_dir = os.path.join(os.getcwd(), 'UMInputData')
os.makedirs(UM_input_dir, exist_ok=True)
all_session_paths = []

# Index-based loop on purpose: entries of `analysers` are overwritten at the end of each
# iteration to release them, so no extra reference to the session's analysers should be held.
for i in range(len(analysers)):
    session_x_path = os.path.join(UM_input_dir, f'Session{i+1}')
    os.makedirs(session_x_path, exist_ok=True)

    # 1) Save good unit labels for this session
    good_units_path = os.path.join(session_x_path, 'cluster_group.tsv')
    channel_positions_path = os.path.join(session_x_path, 'channel_positions.npy')

    # Prefer an existing 'good_units' variable if present; otherwise copy the labels from the KS output
    try:
        _ = good_units  # raises NameError when the variable was never defined
        save_good_units = np.vstack((np.array(('cluster_id', 'group')), good_units[i]))
        save_good_units[0,0] = 0
        # '\t' escape instead of a literal TAB character, which is invisible and easily mangled by editors
        np.savetxt(good_units_path, save_good_units, fmt=['%d','%s'], delimiter='\t')
    except NameError:
        ks_label = os.path.join(KS_dirs[i], 'cluster_group.tsv')
        if os.path.exists(ks_label):
            shutil.copy2(ks_label, good_units_path)
        else:
            print(f'Warning: no good_units or cluster_group.tsv found for session {i+1}')

    # 2) Compute templates for both halves and keep only the resulting arrays
    t_halves = []
    for half in range(2):
        ana = analysers[i][half]
        # preselect spikes at low count to limit memory
        ana.compute('random_spikes', method='uniform', max_spikes_per_unit=100)
        with threadpool_limits(1), parallel_backend('threading'):
            ana.compute('waveforms', ms_before=1.0, ms_after=2.0, dtype='float32',
                       n_jobs=2, chunk_duration='100ms', total_memory='128M',
                       save=False, progress_bar=True)
            ana.compute('templates', n_jobs=2, chunk_duration='100ms', total_memory='128M',
                       progress_bar=True)
        t_halves.append(ana.get_extension('templates').get_data())
        # Best-effort release of the extensions to shrink memory. Two method names are tried;
        # a failure of both is reported instead of being silently swallowed, and never aborts the run.
        for ext in ('templates','waveforms','random_spikes'):
            try:
                ana.delete_extension(ext)
            except Exception:
                try:
                    ana.remove_extension(ext)
                except Exception as cleanup_error:
                    print(f'Note: could not release extension {ext!r} for session {i+1}, half {half+1}: {cleanup_error}')

    # Stack the two halves on a new last axis
    avg_waves = np.stack((t_halves[0], t_halves[1]), axis=-1)
    # Channel positions (taken from the first half; same for both halves)
    np.save(channel_positions_path, analysers[i][0].get_channel_locations())

    # 3) Save per-session average waveforms in UnitMatch format
    all_unit_ids = np.array(analysers[i][0].sorting.get_unit_ids(), dtype=int)
    # Optional switch: define `extract_good_units_only = True` in an earlier cell to save good units only
    extract_good_only = bool(globals().get('extract_good_units_only', False))

    # Derive good unit ids if needed
    good_ids = all_unit_ids
    if extract_good_only:
        try:
            if 'good_units' in globals():
                gu = good_units[i]
                if getattr(gu, 'ndim', 1) == 2 and gu.shape[1] >= 2:
                    # two-column table: (cluster_id, group)
                    good_ids = gu[gu[:, 1] == 'good', 0].astype(int)
                else:
                    # plain list/array of ids
                    good_ids = np.array(gu, dtype=int).ravel()
            else:
                # Fallback: read the saved TSV
                tbl = np.genfromtxt(good_units_path, delimiter='\t', names=True, dtype=None, encoding='utf-8')
                good_ids = tbl['cluster_id'][tbl['group'] == 'good'].astype(int)
        except Exception as e:
            # Deliberate best-effort: fall back to saving every unit rather than aborting the session
            print(f'Warning: could not parse good units for session {i+1}: {e}; saving all units.')
            extract_good_only = False
            good_ids = all_unit_ids

    erd.save_avg_waveforms(avg_waves, session_x_path, all_unit_ids, good_ids, extract_good_units_only=extract_good_only)


    # 4) Free memory aggressively before next session
    t_halves = None; avg_waves = None
    ana = None  # the loop variable would otherwise keep the second-half analyzer alive
    analysers[i] = [None, None]
    try:
        recordings[i] = None; sortings[i] = None
    except (NameError, IndexError, TypeError):
        # recordings/sortings may have been removed or replaced by the user; nothing left to free
        pass
    gc.collect()

    all_session_paths.append(session_x_path)

print('Per-session saving complete. Proceed to Run UnitMatch section.')

## Run UnitMatch

In [None]:
%load_ext autoreload
%autoreload 

import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import numpy as np
import matplotlib.pyplot as plt
import UnitMatchPy.save_utils as su
import UnitMatchPy.GUI as gui
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

In [None]:
# Get the default parameters, then apply the SpikeInterface-specific values on top.
# You can add your own parameters in the same way.

# SpikeInterface extracts waveforms in a different manner from native UnitMatch, so the waveform
# window parameters differ from the UnitMatch defaults. These must be applied AFTER fetching the
# defaults; assigning the defaults afterwards would silently discard them.
# NOTE(review): keys are written in snake_case to match the other param keys this notebook reads
# (e.g. 'match_threshold', 'n_units', 'peak_loc'); the previous spellings were 'SpikeWidth' and
# 'PeakLoc'. Confirm the key names against the installed UnitMatchPy version.
si_param = {'spike_width': 90, 'waveidx': np.arange(20,50), 'peak_loc': 35}
param = default_params.get_default_param()
for key, value in si_param.items():
    param[key] = value

# Point UnitMatch to the per-session UMInputData folders we just created
KS_dirs = all_session_paths
param['KS_dirs'] = KS_dirs
wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(KS_dirs)
param = util.get_probe_geometry(channel_pos[0], param)

In [None]:
def zero_center_waveform(waveform):
    """
    Remove the baseline offset from every waveform.

    The baseline is the mean over the first 15 time points (axis 1), taken separately for
    every index of the remaining axes, and is subtracted from all time points.
    This function is useful for Spike Interface where the waveforms are not centered about 0.

    Arguments:
        waveform - ndarray (nUnits, Time Points, Channels, CV)

    Returns:
        New ndarray with the same shape as the input, with the baseline subtracted
    """
    # keepdims leaves a length-1 time axis so the baseline broadcasts over all time points
    baseline = waveform[:, :15, :, :].mean(axis=1, keepdims=True)
    return waveform - baseline

In [None]:
# Read in the data, select the good units and extract the metadata
waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(wave_paths, unit_label_paths, param, good_units_only = True) 

#param['peak_loc'] = #may need to set as a value if the peak location is NOT ~ half the spike width

# create clus_info, contains all unit id/session related info
clus_info = {'good_units' : good_units, 'session_switch' : session_switch, 'session_id' : session_id, 
            'original_ids' : np.concatenate(good_units) }

# Extract parameters from waveform
extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)

# Extract metric scores
total_score, candidate_pairs, scores_to_include, predictors  = ov.extract_metric_scores(extracted_wave_properties, session_switch, within_session, param, niter  = 2)

# Probability analysis
# prior_match is 1 - (expected matches / number of unit pairs) and is placed FIRST in `priors`;
# the second entry is its complement. Column 1 of `probability` is what is read out below.
# NOTE(review): presumably class 0 = non-match and class 1 = match - confirm against bayes_functions.
prior_match = 1 - (param['n_expected_matches'] / param['n_units']**2 ) # degree of freedom in choosing the prior probability
priors = np.array((prior_match, 1-prior_match))

labels = candidate_pairs.astype(int)
cond = np.unique(labels)
score_vector = param['score_vector']  # not used further in this cell

# get_parameter_kernels returns the kernels, so no NaN-filled array needs to be pre-allocated
parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one = 1)

probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, cond)

# Reshape column 1 of the per-pair probabilities into a square (n_units x n_units) matrix
output_prob_matrix = probability[:,1].reshape(param['n_units'],param['n_units'])

In [None]:
# Evaluate the output probability matrix.
# NOTE(review): this evaluation uses a fixed threshold of 0.75, whereas everything below uses
# param['match_threshold'] - confirm that the difference is intended.
util.evaluate_output(output_prob_matrix, param, within_session, session_switch, match_threshold = 0.75)

match_threshold = param['match_threshold']
#match_threshold = try different values here!

# Binary match matrix: 1 where the match probability exceeds the threshold, 0 elsewhere
# (same dtype as the probability matrix).
above_threshold = output_prob_matrix > match_threshold
output_threshold = above_threshold.astype(output_prob_matrix.dtype)

plt.imshow(output_threshold, cmap = 'Greys')


In [None]:
# Waveform properties handed to the GUI, in the order they are unpacked below.
gui_property_keys = ('amplitude', 'spatial_decay', 'avg_centroid', 'avg_waveform',
                     'avg_waveform_per_tp', 'good_wave_idxs', 'max_site', 'max_site_mean')

# Quick sanity print: name, shape (None when the object has no shape) and type of each property.
for key in gui_property_keys:
    value = extracted_wave_properties[key]
    try:
        print(key, getattr(value, 'shape', None), type(value))
    except Exception as err:
        print(key, type(value), err)


# Unpack the properties into the individual names used by the GUI and by the saving step.
(amplitude, spatial_decay, avg_centroid, avg_waveform,
 avg_waveform_per_tp, wave_idx, max_site, max_site_mean) = (
    extracted_wave_properties[key] for key in gui_property_keys
)

gui.process_info_for_GUI(output_prob_matrix, match_threshold, scores_to_include, total_score, amplitude, spatial_decay,
                         avg_centroid, avg_waveform, avg_waveform_per_tp, wave_idx, max_site, max_site_mean, 
                         waveform, within_session, channel_pos, clus_info, param)

In [None]:
# Run the UnitMatch GUI; it returns three values that feed the curation call on matches
# (presumably the pairs marked as match / not match in the GUI plus the GUI's own match list - confirm).
is_match, not_match, matches_GUI = gui.run_GUI()

In [None]:
# This function has 2 modes, 'And' and 'Or': a pair is returned as a match if it appears in
# both cv pairs ('And') or in at least one cv pair ('Or').
# It then adds all the matches selected as is_match and removes all the matches in not_match.
matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode = 'And')

In [None]:
# Index pairs (row, column) of all entries flagged as matches in the thresholded matrix.
# This must use the binary matrix `output_threshold`: comparing the scalar `match_threshold`
# to 1 is never true for a probability threshold and would always yield an empty match list.
matches = np.argwhere(output_threshold == 1)
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)

save_dir = r'Path/To/save/directory'
#NOTE - change `matches` to `matches_curated` if you did manual curation with the GUI
su.save_to_output(save_dir, scores_to_include, matches # matches_curated
                  , output_prob_matrix, avg_centroid, avg_waveform, avg_waveform_per_tp, max_site,
                   total_score, output_threshold, clus_info, param, UIDs = UIDs, matches_curated = None, save_match_table = True)