# Using custom 'good/mua' labels

Kilosort4 determines if a unit is 'good' or 'mua' by computing a few metrics based on its correlogram (see methods in [the Kilosort paper](https://www.nature.com/articles/s41592-024-02232-7) for more details). This tutorial demonstrates how to change these labels using your own criteria.

After sorting the data, we need to load the results and choose where the new labels will be saved. If you are doing everything in a single script using the API, you can instead use the variables returned by `run_kilosort`. Note that `st` will then have 3 columns, not the single column of spike times used in this notebook.
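
If you work with the arrays returned by `run_kilosort`, the spike-time column can be extracted before running the cells below. This is a minimal sketch. It assumes that the first column of the 3-column `st` holds spike times in samples, and the array shown is made-up placeholder data.

```python
import numpy as np

# Hypothetical stand-in for the 3-column `st` returned by run_kilosort.
# Assumption: column 0 = spike time (samples); the other columns are unused here.
st_api = np.array([
    [100, 3, 12.5],
    [250, 1, 9.0],
    [400, 3, 11.2],
])

# Keep only the spike times so `st` is 1-D, as in the rest of this notebook.
st = st_api[:, 0].astype(np.int64)
print(st.shape)  # (3,)
```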

In [1]:
from pathlib import Path
import shutil

import numpy as np

from kilosort.run_kilosort import load_sorting

# Path to load existing sorting results from.
results_dir = Path('c:/users/jacob/.kilosort/.test_data/kilosort4/')
# Paths where new labels will be saved for use with Phy.
save_1 = results_dir / 'cluster_KSLabel.tsv'
save_2 = results_dir / 'cluster_group.tsv'
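# Save a backup of KS4's original labels before overwriting (recommended).
# Only copy once, so re-running this cell doesn't replace the backup with modified labels.
backup = results_dir / 'cluster_KSLabel_backup.tsv'
if not backup.exists():
    shutil.copyfile(save_1, backup)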

# Load sorting results
ops, st, clu, similar_templates, is_ref, est_contam_rate, kept_spikes = \
    load_sorting(results_dir)

cluster_labels = np.unique(clu)  # integer label for each cluster
fs = ops['fs']                   # sampling rate

This will also load Kilosort4's default labels. For this tutorial we'll use those as a starting point, requiring that units satisfy the old criteria *and* the new criteria. You could also just ignore the old labels and only use your own.

In [2]:
# Option 1: Use existing labels as a starting point.
#           KS4 assigns "good" where is_ref is True, and "mua" otherwise.
label_good = is_ref.copy()

# Option 2: Ignore KS4's labels and only use your own criteria.
# label_good = np.ones(cluster_labels.size, dtype=bool)

For example, you might label units 'good' only if they have a firing rate of at least 1 Hz and a contamination rate below 0.2. Whatever criteria you use, the process is the same: create a boolean array with shape (n_clusters,) that is True where the criteria are met and False otherwise.
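
Each criterion is a boolean array of shape (n_clusters,), so any number of them can be combined in one step. The following minimal sketch uses made-up values for four clusters. `np.logical_and.reduce` is one convenient way to AND a list of arrays element-wise.

```python
import numpy as np

# Made-up per-cluster values for 4 clusters (illustration only).
is_ref = np.array([True, True, False, True])
est_contam_rate = np.array([0.05, 0.30, 0.10, 0.15])
firing_rate = np.array([2.0, 5.0, 3.0, 0.4])   # Hz

criteria = [
    is_ref,                    # KS4's original "good" label
    est_contam_rate < 0.2,     # contamination below 0.2
    firing_rate >= 1.0,        # firing rate of at least 1 Hz
]

# A cluster is "good" only if it passes every criterion.
label_good = np.logical_and.reduce(criteria)
print(label_good.tolist())  # [True, False, False, False]
```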

In [3]:
contam_good = est_contam_rate < 0.2   # this already has shape (n_clusters,)
fr_good = np.zeros(cluster_labels.size, dtype=bool)
for i, c in enumerate(cluster_labels):
    # Get all spikes assigned to this cluster (spike times in samples)
    spikes = st[clu == c]
    # Compute est. firing rate using your preferred method.
    # Note that this formula will not work well for units that drop in and out.
    duration = (spikes.max() - spikes.min()) / fs   # seconds
    # Clusters with zero duration (e.g. a single spike) are left as False.
    if duration > 0 and spikes.size / duration >= 1:
        fr_good[i] = True

# Update labels, requiring that all criteria hold for each cluster.
# np.logical_and only combines two arrays (a third positional argument is
# treated as the output array), so use .reduce for three or more.
label_good = np.logical_and.reduce([label_good, contam_good, fr_good])
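
The loop above is easy to read, but it scans all spikes once per cluster. A vectorized alternative computes spike counts and the first and last spike time of every cluster in a single pass. This sketch assumes `st` is 1-D spike times in samples and that cluster IDs in `clu` run from 0 to n_clusters-1. The helper name `firing_rates` is hypothetical and the data are made up.

```python
import numpy as np

def firing_rates(st, clu, n_clusters, fs):
    """Spikes per second between each cluster's first and last spike.

    Assumes `st` holds spike times in samples and `clu` holds cluster IDs
    in the range 0..n_clusters-1.
    """
    counts = np.bincount(clu, minlength=n_clusters)
    first = np.full(n_clusters, np.inf)
    last = np.full(n_clusters, -np.inf)
    np.minimum.at(first, clu, st)   # earliest spike of each cluster
    np.maximum.at(last, clu, st)    # latest spike of each cluster
    duration = (last - first) / fs  # seconds
    fr = np.zeros(n_clusters)
    # Clusters with zero duration (fewer than 2 distinct spike times) keep 0.
    valid = duration > 0
    fr[valid] = counts[valid] / duration[valid]
    return fr

# Made-up data: 3 clusters, 1 kHz sampling rate.
fs = 1000.0
st = np.array([0, 1000, 2000, 500, 4500, 3000])
clu = np.array([0, 0, 0, 1, 1, 2])
fr = firing_rates(st, clu, 3, fs)
print(fr.tolist())  # [1.5, 0.5, 0.0]
```

The boolean criterion is then `fr >= 1`, matching the loop version.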

Another example would be to only assign "good" to units with a presence ratio above some fraction, say 0.5. This will restrict the "good" label to units that are detected for at least half of the recording.

This involves binning the data into large chunks to determine during which periods each unit is active. The bins should be large enough not to penalize units with low firing rates, but small enough to capture periods when a unit is not detected. As a starting point, we recommend choosing the number of bins so that each bin spans roughly 5 minutes.
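
One way to follow the ~5 minute guideline is to derive the number of bins from the recording duration. This is a minimal sketch. The helper `n_presence_bins`, the 300 s default, and the floor of 2 bins are illustrative choices, not part of Kilosort.

```python
import numpy as np

def n_presence_bins(min_time, max_time, fs, bin_seconds=300.0):
    """Number of bins so that each bin spans roughly `bin_seconds`.

    `min_time` and `max_time` are in samples; `fs` is the sampling rate in Hz.
    At least 2 bins are returned so that a presence ratio is still meaningful.
    """
    duration = (max_time - min_time) / fs            # recording length in seconds
    return max(2, int(np.round(duration / bin_seconds)))

# Example: a 1-hour recording sampled at 30 kHz gives 12 bins of 5 minutes.
fs = 30000.0
print(n_presence_bins(0, 3600 * fs, fs))  # 12
```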

In [4]:
# Formula adapted from https://github.com/AllenInstitute/ecephys_spike_sorting/

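def presence_ratio(spike_train, num_bins, min_time, max_time, min_spike_pct=0.05):
    # num_bins + 1 evenly spaced edges give exactly num_bins bins.
    h, b = np.histogram(spike_train, np.linspace(min_time, max_time, num_bins + 1))
    # A bin counts as "present" if it holds more than min_spike_pct (5% by
    # default) of the unit's mean spike count per bin.
    min_spikes = h.mean()*min_spike_pct

    # NOTE: The Allen Institute version builds its histogram from num_bins edges
    #       (num_bins - 1 bins) but divides by num_bins, so the ratio can never
    #       reach 1.0. We divide by the actual number of bins instead, because
    #       that bias is large when there are only a few bins.
    return np.sum(h > min_spikes) / num_bins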

# Compute presence ratio for each cluster
presence = np.zeros(cluster_labels.size)
min_time = st.min()
max_time = st.max()
for i, c in enumerate(cluster_labels):
    spikes = st[clu == c]
    presence[i] = presence_ratio(spikes, 10, min_time, max_time)

presence_good = presence >= 0.5
# Update labels with the additional criteria.
label_good = np.logical_and(label_good, presence_good)

After we're finished changing labels, we need to save them again in the format expected by Phy.

In [None]:
# Convert True/False to 'good'/'mua'
ks_labels = ['good' if b else 'mua' for b in label_good]

# Write to two .tsv files, one row per cluster ID.
with open(save_1, 'w') as f:
    f.write('cluster_id\tKSLabel\n')
    for c, p in zip(cluster_labels, ks_labels):
        f.write(f'{c}\t{p}\n')
shutil.copyfile(save_1, save_2)
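
To confirm that a label file has the expected layout, you can read it back with Python's standard `csv` module. This self-contained sketch writes two made-up labels to a temporary directory, using the same header-plus-rows layout as the cell above, and then parses them.

```python
import csv
import tempfile
from pathlib import Path

# Made-up labels for illustration; in the notebook these come from `label_good`.
labels = {0: 'good', 1: 'mua'}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'cluster_KSLabel.tsv'
    # Same layout as above: a header row, then one "cluster_id<TAB>label" row per cluster.
    with open(path, 'w', newline='') as f:
        f.write('cluster_id\tKSLabel\n')
        for cid, lab in labels.items():
            f.write(f'{cid}\t{lab}\n')

    # Read it back: DictReader uses the header row as keys.
    with open(path, newline='') as f:
        rows = list(csv.DictReader(f, delimiter='\t'))

parsed = {int(r['cluster_id']): r['KSLabel'] for r in rows}
print(parsed)  # {0: 'good', 1: 'mua'}
```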