## Objective
Determine the most accurate way to count matching spikes between sorters.

In [2]:
import numpy as np
import sys

In [40]:
def spikeinterface(times1, times2, delta=0.4):
    """Count matching spikes the way spikeinterface does it.

    Both spike trains are pooled and sorted by time. Every pair of
    neighbouring spikes that come from different trains and whose gap is at
    most ``delta`` (with a 1e-4 tolerance) is a candidate match. A run of
    consecutive candidate pairs is counted as a single match.

    Args:
        times1: Spike times of the first train.
        times2: Spike times of the second train.
        delta: Maximum time difference for two spikes to match.

    Returns:
        Number of matches counted (0 when there is no candidate pair).
    """
    pooled = np.concatenate((times1, times2))
    # Tag every pooled spike with the train it came from (1 or 2).
    origin = np.concatenate((np.ones(len(times1)), 2 * np.ones(len(times2))))

    order = pooled.argsort()
    pooled = pooled[order]
    origin = origin[order]

    # Gap to the next pooled spike, expressed relative to the allowed delta.
    excess = pooled[1:] - pooled[:-1] - delta
    crosses_trains = origin[1:] != origin[:-1]
    candidates = np.where(crosses_trains & (excess <= 1e-4))[0]
    if candidates.size == 0:
        return 0

    # Prevents a spike being matched with more than one other spike:
    # adjacent candidate indices share a spike, so only the breaks between
    # runs of adjacent indices start a new match.
    run_breaks = np.where(candidates[1:] != candidates[:-1] + 1)[0]
    return len(run_breaks) + 1


def searchsorted(times1, times2, delta=0.4):
    """Count spikes of ``times1`` greedily paired with a spike of ``times2``.

    For each spike in ``times1`` (in the order given), the nearest spike of
    ``times2`` on either side that has not been paired yet is looked up with
    ``np.searchsorted``. The closer of the two is taken as the partner if it
    lies within ``delta`` (with a 1e-4 tolerance); a tie goes to the left
    neighbour. Each spike of ``times2`` is paired at most once.

    Args:
        times1: Spike times of the first train.
        times2: Spike times of the second train; assumed sorted ascending,
            as ``np.searchsorted`` requires.
        delta: Maximum time difference for two spikes to match.

    Returns:
        Number of spikes in ``times1`` that were paired.
    """
    # Convert once so np.searchsorted does not re-convert a list on every call.
    times2 = np.asarray(times2)
    num_times2 = len(times2)

    used = set()  # indices into times2 that are already paired
    count = 0
    for st1 in times1:
        insert_at = np.searchsorted(times2, st1)

        # Nearest unpaired spike before the insertion point, if any.
        idx_left = insert_at - 1
        while idx_left in used:
            idx_left -= 1
        left = times2[idx_left] if idx_left >= 0 else -np.inf

        # Nearest unpaired spike at or after the insertion point, if any.
        idx_right = insert_at
        while idx_right in used:
            idx_right += 1
        right = times2[idx_right] if idx_right < num_times2 else np.inf

        # Pick the closer candidate; a missing side is infinitely far away.
        if right - st1 < st1 - left:
            gap, partner = right - st1, idx_right
        else:
            gap, partner = st1 - left, idx_left

        if gap - delta <= 1e-4:
            count += 1
            used.add(partner)

    return count


def merge_count(times1, times2, delta=0.4):
    """Count matching spikes with a single two-pointer sweep.

    Walks both spike trains from the start. When the two current spikes are
    within ``delta`` of each other (with a 1e-4 tolerance) they are counted
    as a match and both pointers advance; otherwise only the pointer at the
    earlier spike advances.

    Args:
        times1: Spike times of the first train.
        times2: Spike times of the second train.
        delta: Maximum time difference for two spikes to match.

    Returns:
        Number of matched pairs.
    """
    n1, n2 = len(times1), len(times2)
    i = j = 0
    matches = 0
    while i < n1 and j < n2:
        a, b = times1[i], times2[j]
        if abs(a - b) - delta <= 1e-4:
            # Pair found: consume one spike from each train.
            matches += 1
            i += 1
            j += 1
        elif a < b:
            i += 1
        else:
            j += 1
    return matches

In [13]:
def test_case(times1, times2, num_matches):
    """Print each matcher's count as a percentage of the expected count.

    Args:
        times1: Spike times of the first train.
        times2: Spike times of the second train.
        num_matches: Expected number of matches; used as the divisor, so it
            must be non-zero.
    """
    def as_percent(counter):
        # Divide before scaling, so the printed float is count/expected*100.
        return counter(times1, times2) / num_matches * 100

    print(f"Expected: {num_matches}")
    print(f"Spikeinterface: {as_percent(spikeinterface)}%")
    print(f"Search sorted: {as_percent(searchsorted)}%")
    print(f"Merge count: {as_percent(merge_count)}%")

In [69]:
# Identical spike trains: every spike has an exact partner.
times1 = list(range(6))
times2 = list(range(6))
num_matches = len(times1)
test_case(times1, times2, num_matches)

Expected: 6


Spikeinterface: 100.0%
Search sorted: 100.0%
Merge count: 100.0%


In [30]:
# Each times2 spike trails its times1 partner by 0.3, inside the default delta of 0.4.
test_case(times1=[1, 1.8], times2=[1.3, 2.1], num_matches=2)

Expected: 2
Spikeinterface: 100.0%
Search sorted: 100.0%
Merge count: 100.0%


In [31]:
# Each times2 spike trails its partner by 0.4 (the default delta), and the gap
# from 1.4 to the next times1 spike at 1.8 is also 0.4.
test_case(times1=[1, 1.8], times2=[1.4, 2.2], num_matches=2)

Expected: 2
Spikeinterface: 50.0%
Search sorted: 100.0%
Merge count: 100.0%


In [32]:
# Identical trains whose own spikes are only 0.2 apart, i.e. closer than the default delta.
test_case(times1=[1, 1.2], times2=[1, 1.2], num_matches=2)

Expected: 2
Spikeinterface: 50.0%
Search sorted: 100.0%
Merge count: 100.0%


In [71]:
# Each times1 spike is 0.3 after a times2 spike, but times1's spike at 1 is
# nearer to 1.1 (0.1 away) than to 0.7.
test_case(times1=[1, 1.4], times2=[0.7, 1.1], num_matches=2)

"""
spikeinterface counts [1, 0.7]
searchsorted counts [1, 1.1]
merge_count matches [1, 0.7] and [1.4, 1.1]
"""

Expected: 2
Spikeinterface: 50.0%
Search sorted: 50.0%
Merge count: 100.0%


'\nspikeinterface counts [1, 0.7]\nsearchsorted counts [1, 1.1]\nmerge_count matches [1, 0.7] and [1.4, 1.1]\n'

In [73]:
# times2 has one more spike than times1: the spike at 1 is within the default
# delta of both 0.7 and 1.1.
test_case(times1=[1, 2], times2=[0.7, 1.1, 2.2], num_matches=2)

Expected: 2
Spikeinterface: 100.0%
Search sorted: 100.0%
Merge count: 100.0%


In [64]:
# Randomised test: N spikes in times1, NUM_MATCHES of which have a partner
# in times2 that is shifted forward by less than DELTA.
N = 10000
NUM_MATCHES = 3000

DELTA = 0.4
np.random.seed(56)  # fixed seed so the cell is reproducible
## times1: one spike in each unit interval [i, i + 1)
times1 = np.arange(N)
times1 = times1 + np.random.random(N)
times1 = np.sort(times1)

# times2: NUM_MATCHES distinct spikes drawn from times1 (replace=False), each
# shifted forward by a random amount in [0, DELTA) ...
times2 = np.random.choice(times1, NUM_MATCHES, replace=False)
times2 = times2 + np.random.random(NUM_MATCHES) * DELTA
# ... plus N - NUM_MATCHES extra spikes starting at N*3, beyond every times1
# value (all below N), so they are never within DELTA of a times1 spike.
times2 = np.concatenate((times2, np.arange(N*3, N*3+(N-NUM_MATCHES))))
times2 = np.sort(times2)

# NOTE: DELTA is not passed on; the matchers run with their default delta
# (0.4), which equals DELTA here.
test_case(times1, times2, NUM_MATCHES)

Expected: 3000
Spikeinterface: 94.06666666666666%
Search sorted: 100.0%
Merge count: 100.0%
