## Objective
Compare RT-Sort to other sorters by:
1. Pair units based on Hungarian match overlap score
2. For each pair:
   1. Store the mean and STD of latency and amplitude on all elecs for RT-Sort's pre-recording spikes
   2. Find spikes detected by both, RT-Sort only, other sorter only
   3. Plot and save mean footprints
   4. For each group
       1. For each spike:
           1. Find latency and amplitude on all elecs

Data will be stored as:\
One file per sorter, representing RT-Sort compared to that sorter. Each file will be a list where each element is a dictionary. Each dictionary represents an RT-Sort sequence. It has the items:
- "idx": Index of sequence in footprint plots (which is the index of the sequence if all sequences were in one list). Sequences with fewer than 2 spikes are not included since at least 2 spikes are needed to calculate the STD
- "location": (x, y) location of sequence's root electrode
- "root_elec": Index of root electrode
- "inner_loose_elecs": Indices of inner loose electrodes (including root)
- "loose_elecs": Indices of loose electrodes (including root)
- "footprint_elecs": Indices of footprint electrodes (including root)
- "means": Numpy array with shape (num_elecs, 2) where each row represents an electrode. The columns are (starting with column 0):
  - Mean latency
  - Mean amplitude
- "stds": Same as "means" except the STD rather than the mean
  
- "other_unit_idx": Index of the other unit
- "overlap_score": Overlap score between RT-Sort sequence and other sorter's unit.
  
- "matching_spikes": List where each element is a dictionary that represents a spike detected by both RT-Sort and other sorter.
  - "time": Time of spike in milliseconds, relative to start of the recording (not relative to the testing 5-10 minutes region)
  - "elecs": Numpy array with shape (num_elecs, 2) where each row represents an electrode. The columns are (starting with column 0):
    - Spike's latency
    - Spike's amplitude
- "rt_sort_only_spikes": Same as "matching_spikes" but spikes only detected by RT-Sort
- "other_sorter_only_spikes": Same as "matching_spikes" but spikes only detected by the other sorter
   
240220 - RT-Sort is definitely undermerging and misses many of other sorter's spikes while having consensus spikes, so using score_formula=#matches/#rt_sort and each unit from one sorter can be paired with more than one unit from other sorter

*Note: "amplitude" refers to standardized amplitude

## Globals setup

In [1]:
%load_ext autoreload

In [2]:
from multiprocessing import Pool
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from spikeinterface.extractors import NpzSortingExtractor

from tqdm import tqdm

%autoreload 2
from src.comparison import Comparison, DummySorter
from src import utils

In [4]:
# Root directory with one sub-directory per comparison sorter (Sorter loads "sorting_cached.npz" from each)
SORTERS_ROOT = Path("/data/MEAprojects/spikeinterface/spiketrains/mouse412804_probeC")
# Recording traces, memory-mapped read-only; indexed in this notebook as TRACES[channel, frame]
TRACES = np.load("/data/MEAprojects/dandi/000034/sub-mouse412804/rt_sort/dl_model/240318/scaled_traces.npy", mmap_mode="r")  
# Pickle file holding the RT-Sort sequences that are compared against the other sorters
SEQUENCES_PATH = Path("/data/MEAprojects/dandi/000034/sub-mouse412804/rt_sort/240319/tested_sequences.pickle")
# Electrode locations; indexed by electrode index (job() stores ELEC_LOCS[root_elec] as the sequence location)
ELEC_LOCS = utils.rec_si().get_channel_locations()
SAMP_FREQ = 30  # kHz
# Frames taken before/after a spike's frame when extracting its window (0.5 ms on each side)
N_BEFORE = N_AFTER = round(0.5 * SAMP_FREQ)  # round(0.3 * SAMP_FREQ)
# Frames (50 ms) directly preceding a spike's window that are used for its per-channel noise estimate
PRE_MEDIAN_FRAMES = round(50 * SAMP_FREQ)  # If None, use entire recording for SNR
# Units from other sorters with more spikes in TIME_FRAME than this firing rate allows are discarded
MAX_FR = 1000  # for units, Hz

# Only spikes inside this window are compared; the end is capped at the recording's duration in ms
TIME_FRAME = (5*60*1000, min(10*60*1000, TRACES.shape[1]/SAMP_FREQ))  # start_ms, end_ms

##
# RT-Sort sequences; job() indexes this by position
SEQUENCES = utils.pickle_load(SEQUENCES_PATH)

  % (ns['name'], ns['version'], self.__namespaces.get(ns['name'])['version']))
  % (ns['name'], ns['version'], self.__namespaces.get(ns['name'])['version']))
  "Your data may be transposed." % (self.__class__.__name__, kwargs["name"]))


Recording does not have scaled traces. Setting gain to 0.195


In [5]:
class Sorter:
    """Wrapper of NpzSortingExtractor for Comparison.

    Only spikes inside the global TIME_FRAME are exposed, and units that are
    silent or exceed MAX_FR within that window are dropped.
    """

    def __init__(self, full_name, name):
        """Load the cached sorting stored under SORTERS_ROOT / full_name.

        Args:
            full_name: Name of the sorter's directory inside SORTERS_ROOT.
            name: Label of the sorter (used to name the saved output file).
        """
        self.npz = NpzSortingExtractor(SORTERS_ROOT / full_name / "sorting_cached.npz")
        self.name = name

    def __len__(self):
        """Return the number of units retained by get_spike_times()."""
        return len(self.get_spike_times())

    def get_spike_times(self):
        """Return the spike times of every retained unit.

        Spike trains are divided by SAMP_FREQ (kHz) to convert them to
        milliseconds and restricted to TIME_FRAME (both ends inclusive).
        A unit is retained only if it has at least one spike in the window
        and no more spikes than a firing rate of MAX_FR Hz allows.

        Returns:
            List with one array of spike times (ms) per retained unit, in
            the order of the extractor's unit ids.
        """
        start_ms, end_ms = TIME_FRAME
        # Largest spike count allowed in the window; same for every unit, so compute it once
        max_spikes = MAX_FR * (end_ms - start_ms) / 1000

        spike_times = []
        for uid in self.npz.get_unit_ids():
            times = self.npz.get_unit_spike_train(uid) / SAMP_FREQ
            in_window = (start_ms <= times) & (times <= end_ms)
            # count_nonzero counts in native code; builtin sum() would loop over the array in Python
            num_spikes = np.count_nonzero(in_window)
            if 0 < num_spikes <= max_spikes:
                spike_times.append(times[in_window])
        return spike_times
    
# One wrapper per comparison sorter; the directory name under SORTERS_ROOT doubles as the sorter's label
HS = Sorter("herdingspikes", "herdingspikes")
KS = Sorter("kilosort2", "kilosort2")
IC = Sorter("ironclust", "ironclust")
TDC = Sorter("tridesclous", "tridesclous")
SC = Sorter("spykingcircus", "spykingcircus")
HDS = Sorter("hdsort", "hdsort")
# Position in this list is the other_sorter_idx that job() receives in its task tuple
OTHER_SORTERS = [KS, HS, IC, TDC, SC, HDS]

In [6]:
# If using SNR across entire recording, need to calculate that before running next cell
def _calc_snr(chan):
    """Estimate the noise level of one channel over the entire recording.

    Despite the name, the value is not a ratio: it is the median absolute
    value of the channel's trace divided by 0.6745, floored at 0.5.

    Args:
        chan: Index of the channel (row of TRACES).

    Returns:
        Tuple (chan, noise_level); the channel index is returned so results
        can be matched to their channel when they arrive out of order.
    """
    median_abs = np.median(np.abs(TRACES[chan, :]))
    # Floor at 0.5 so later divisions by this value never divide by 0
    noise_level = np.clip(median_abs / 0.6745, a_min=0.5, a_max=None)
    return chan, noise_level

# Whole-recording noise levels are only computed when no pre-spike window is configured;
# job() reads PRE_MEDIANS only in that case
if PRE_MEDIAN_FRAMES is None:
    tasks = range(TRACES.shape[0])  # one task per channel
    PRE_MEDIANS = np.zeros(len(tasks), dtype=float)
    with Pool(processes=16) as pool:
        # Results arrive in arbitrary order, so _calc_snr returns the channel index with each value
        for chan, snr in tqdm(pool.imap_unordered(_calc_snr, tasks), total=len(tasks)):
            PRE_MEDIANS[chan] = snr

100%|██████████| 248/248 [00:04<00:00, 51.80it/s]


## Obtaining data

In [15]:
def job(task):
    """Compare one RT-Sort sequence with its best-overlapping unit from one other sorter.

    Args:
        task: Tuple (seq_idx, other_sorter_idx) indexing SEQUENCES and
            OTHER_SORTERS respectively.

    Returns:
        None if the sequence has fewer than 2 spikes. Otherwise a dictionary
        with the sequence's metadata ("idx", "location", "root_elec",
        "inner_loose_elecs", "loose_elecs", "footprint_elecs"), its per-electrode
        "means" and "stds" (columns: latency, amplitude), the paired unit
        ("other_unit_idx", "overlap_score"), and three lists of spikes
        ("matching_spikes", "rt_sort_only_spikes", "other_sorter_only_spikes").
        Each spike is a dictionary with "time" (the spike time as produced by
        the matching step) and "elecs" (array with one row per electrode and
        the columns latency in frames, standardized amplitude).
        If the other sorter has no retained units, "other_unit_idx" is None,
        "overlap_score" is 0.0, and every spike of the sequence is stored in
        "rt_sort_only_spikes".
    """
    seq_idx, other_sorter_idx = task
    seq = SEQUENCES[seq_idx]
    seq_spike_train = seq.spike_train
    if len(seq_spike_train) < 2:
        # Sequences with fewer than 2 spikes are skipped (the STDs below use ddof=1)
        return None
    if PRE_MEDIAN_FRAMES is None:
        # Noise levels precomputed over the entire recording
        pre_medians = PRE_MEDIANS

    save_data = {
        "idx": seq.idx,
        "location": ELEC_LOCS[seq.root_elec],
        "root_elec": seq.root_elec,
        "inner_loose_elecs": seq.inner_loose_elecs,
        "loose_elecs": seq.loose_elecs,
        "footprint_elecs": seq.comp_elecs,
        # Stack as rows then transpose so each row is an electrode: (latency, amplitude)
        "means": np.vstack((
            seq.all_latencies,
            seq.all_amp_medians
        )).T,
        # presumably every_latency/every_amp_median hold one row per electrode — TODO confirm
        "stds": np.vstack((
            np.std(seq.every_latency, axis=1, ddof=1),
            np.std(seq.every_amp_median, axis=1, ddof=1)
        )).T
    }

    # Pair the sequence with the other sorter's unit that has the highest overlap score.
    # NOTE(review): the score here is #matches / (#matches + #rt_sort_only + #other_only),
    # whereas the notebook's header notes mention score_formula=#matches/#rt_sort — confirm
    # which one is intended.
    other_sorter = OTHER_SORTERS[other_sorter_idx]
    best_score = -np.inf
    best_idx = None
    best_spike_trains = None  # Will be tuple of (matching, unmatching RT-Sort, unmatching other sorter)
    for other_idx, other_spike_train in enumerate(other_sorter.get_spike_times()):
        matching, seq_only, other_only = Comparison.get_matching_events(seq_spike_train, other_spike_train)
        score = len(matching) / (len(matching) + len(seq_only) + len(other_only))
        assert 0 <= score <= 1, "overlap score is calculated wrong"
        if score > best_score:
            best_score = score
            best_idx = other_idx
            best_spike_trains = (matching, seq_only, other_only)
    if best_spike_trains is None:
        # The other sorter has no retained units: nothing can match, so every spike of the
        # sequence is RT-Sort only (previously this case crashed when iterating over None)
        best_score = 0.0
        best_spike_trains = ([], seq_spike_train, [])
    save_data["other_unit_idx"] = best_idx
    save_data["overlap_score"] = best_score

    array_elecs = np.arange(ELEC_LOCS.shape[0])  # for indexing so latency time matches with amplitude
    for name, spike_train in zip(("matching_spikes", "rt_sort_only_spikes", "other_sorter_only_spikes"), best_spike_trains):
        spike_data = []
        for spike in spike_train:
            # Get spike window
            frame = round(spike * SAMP_FREQ)
            this_n_before = N_BEFORE if frame - N_BEFORE >= 0 else frame  # Prevents indexing negative value as start
            rec_window = np.abs(TRACES[:, frame-this_n_before:frame+N_AFTER+1])

            # Get latencies (frame of largest absolute value on each electrode, relative to the spike's frame)
            latencies = np.argmax(rec_window, axis=1) - this_n_before
            if name != "other_sorter_only_spikes":
                # The spike's frame is taken as the root electrode's reference, so its latency is forced to 0
                frame_offset = 0  # Frame of spike is frame + frame_offset
                latencies[seq.root_elec] = 0
            else:
                # Spikes from the other sorter are not necessarily centered on the root electrode,
                # so re-reference all latencies to the root electrode's peak
                frame_offset = latencies[seq.root_elec]
                latencies -= latencies[seq.root_elec]
            assert latencies[seq.root_elec] == 0, "Root electrode has a non-zero latency"

            # Get pre-medians (per-channel noise level from the frames right before the spike window)
            if PRE_MEDIAN_FRAMES is not None:
                pre_medians = TRACES[:, max(0, frame-this_n_before-PRE_MEDIAN_FRAMES):frame-this_n_before]
                pre_medians = np.median(np.abs(pre_medians), axis=1)
                pre_medians = np.clip(pre_medians / 0.6745, a_min=0.5, a_max=None)  # Prevents median of 0

            # Get amps: absolute trace value at each electrode's latency, divided by its noise level.
            # TRACES is indexed directly; an earlier version indexed rec_window and hit out-of-bounds
            # errors for other-sorter spikes that are not centered on the root electrode's trough.
            amps = np.abs(TRACES[array_elecs, frame+frame_offset+latencies])
            amps = amps / pre_medians

            # Store data
            spike_data.append({
                "time": spike,
                "elecs": np.vstack((latencies, amps)).T
            })
        save_data[name] = spike_data

    return save_data

In [16]:
# Directory that receives one "<sorter name>.npy" file per comparison sorter.
# NOTE(review): the path ends in "snr_entire_recording", which corresponds to PRE_MEDIAN_FRAMES = None;
# confirm it matches the PRE_MEDIAN_FRAMES value in use before running.
SAVE_ROOT = SEQUENCES_PATH.parent / "consistent_spikes/prop_delay_0.5ms/snr_entire_recording"
SAVE_ROOT.mkdir(exist_ok=True, parents=True)

# Silence NumPy floating-point errors while the jobs run (set before the pools are created)
np.seterr("ignore")
##
try:
    for sorter_idx, other_sorter in enumerate(OTHER_SORTERS):
        sorter_name = other_sorter.name
        print(sorter_name)
        # One task per RT-Sort sequence, all against the current sorter
        tasks = [(idx, sorter_idx) for idx in range(len(SEQUENCES))]
        with Pool(processes=16) as pool:
            # job() returns None for sequences it skips; results arrive in arbitrary order
            all_save_data = [
                save_data
                for save_data in tqdm(pool.imap_unordered(job, tasks), total=len(tasks))
                if save_data is not None
            ]
        np.save(SAVE_ROOT / f"{sorter_name}.npy", np.array(all_save_data, dtype=object))
finally:
    # Switch back to raising on floating-point errors even if a job failed part-way through
    np.seterr("raise")

kilosort2


100%|██████████| 168/168 [00:58<00:00,  2.86it/s]


herdingspikes


100%|██████████| 168/168 [00:26<00:00,  6.37it/s]


ironclust


100%|██████████| 168/168 [02:40<00:00,  1.05it/s]


tridesclous


100%|██████████| 168/168 [00:50<00:00,  3.32it/s]


spykingcircus


100%|██████████| 168/168 [02:12<00:00,  1.27it/s]


hdsort


100%|██████████| 168/168 [01:19<00:00,  2.12it/s]


{'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}