In [2]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go

import lussac.utils as utils
import spikeinterface.core as si
import spikeinterface.comparison as sc
import spikeinterface.qualitymetrics as sqm
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw

si.set_global_job_kwargs(n_jobs=6)

In [15]:
data_folder = Path("/media/aurelien/Main_SSD/SI_Hackathon_2024")
recording = si.load_extractor(data_folder / "recording.bin")  # drift, motion corrected
gt_sorting = si.load_extractor(data_folder / "gt_sorting")

utils.Utils.t_max = recording.get_num_frames()
utils.Utils.sampling_frequency = recording.sampling_frequency

In [4]:
# Recording is already motion corrected, so turn it to false.
if (sc2_folder := (data_folder / "sc2_default")).exists():
    sorting = si.load_extractor(sc2_folder / "sorter_output" / "sorting")
else:
    sorting = ss.run_sorter("spykingcircus2", recording, output_folder=data_folder / "sc2_default", apply_motion_correction=False,
                            merging={}, verbose=True, remove_existing_folder=True)

In [5]:
firing_rate = sorting.count_num_spikes_per_unit(outputs="array") / recording.get_num_frames() * recording.sampling_frequency
sorting = sorting.select_units(sorting.unit_ids[firing_rate > 0.8])

In [6]:
if (analyzer_folder := (data_folder / "analyzer_sc2")).exists():
    analyzer = si.load_sorting_analyzer(analyzer_folder)
else:
    analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder=analyzer_folder)
    analyzer.compute({
        'noise_levels': {},
        'random_spikes': {},
        'waveforms': {'ms_before': 1.2, 'ms_after': 2.3},
        'templates': {},
        'correlograms': {'bin_ms': 0.4},
        'spike_amplitudes': {},
        'unit_locations': {'method': "monopolar_triangulation"},
        'template_similarity': {}
    })

In [7]:
comp = sc.compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=True)

In [8]:
n_spikes = sorting.count_num_spikes_per_unit(outputs="array")
n_true_spikes = gt_sorting.count_num_spikes_per_unit(outputs="array")
n_coincidence = comp.match_event_count

precision_matrix = n_coincidence / n_spikes[None, :]
recall_matrix = n_coincidence / n_true_spikes[:, None]
agreement_matrix = n_coincidence / (n_spikes[None, :] + n_true_spikes[:, None] - n_coincidence)

In [9]:
second_best_recall = np.sort(np.array(recall_matrix), axis=0)[-2, :]
best_precision = np.max(precision_matrix, axis=0)

good_units_mask = second_best_recall < 1/3  # Merged units
good_units_mask &= best_precision > 0.5

matching = {}
matched = {}
for unit_id in sorting.unit_ids[good_units_mask]:
    neuron_id = agreement_matrix[unit_id].idxmax()

    matched[unit_id] = neuron_id
    if neuron_id in matching:
        matching[neuron_id].append(unit_id)
    else:
        matching[neuron_id] = [unit_id]

matching

{'#859': [14],
 '#2693': [22],
 '#1449': [32, 33],
 '#1366': [35, 88],
 '#919': [36, 94, 96],
 '#1565': [37, 296],
 '#1399': [38, 300],
 '#6': [40, 108],
 '#1606': [42],
 '#391': [43, 44],
 '#2743': [47],
 '#2506': [48, 49],
 '#507': [51, 306],
 '#1004': [52, 129, 130],
 '#855': [53, 313, 314],
 '#2279': [55, 58, 59],
 '#1846': [56],
 '#1766': [61],
 '#1066': [63, 64],
 '#1964': [66, 137],
 '#2765': [67],
 '#2313': [69, 70, 71],
 '#1340': [73],
 '#270': [74, 141, 378],
 '#2402': [77],
 '#421': [78, 79, 142, 144],
 '#2301': [82],
 '#2461': [83],
 '#2018': [89, 387],
 '#2410': [90, 332, 333],
 '#2090': [93],
 '#1434': [98],
 '#1609': [99],
 '#2311': [100, 101],
 '#2469': [103, 105, 166, 350],
 '#1691': [110],
 '#378': [115, 117],
 '#1613': [119, 120, 121],
 '#2456': [122, 123, 407],
 '#2998': [126, 128, 308],
 '#1080': [132, 133],
 '#2421': [134, 326],
 '#2658': [138, 206, 208, 379],
 '#2519': [146, 147, 148],
 '#565': [149],
 '#1611': [151, 152, 153],
 '#2810': [154, 156, 297, 391, 392]

In [114]:
sw.plot_sorting_summary(analyzer, backend="spikeinterface_gui")

<spikeinterface.widgets.sorting_summary.SortingSummaryWidget at 0x7f0c192e2b10>

In [51]:
temp_similarity = analyzer.get_extension("template_similarity").get_data()
max_shift = 10
max_channels = 10

X = []
Y = []
pairs = []
colors = []
text = []

for unit_id1 in sorting.unit_ids[good_units_mask]:
    for unit_id2 in sorting.unit_ids[good_units_mask]:
        if unit_id2 <= unit_id1:
            continue

        unit_index1 = sorting.id_to_index(unit_id1)
        unit_index2 = sorting.id_to_index(unit_id2)
        neuron1 = matched[unit_id1]
        neuron2 = matched[unit_id2]

        template1 = analyzer.get_extension("templates").get_unit_template(unit_id1)
        template2 = analyzer.get_extension("templates").get_unit_template(unit_id2)

        best_channel_indices = np.argsort(np.max(np.abs(template1) + np.abs(template2), axis=0))[::-1][:10]
        
        max_diff = 1
        for shift in range(-max_shift, max_shift+1):
            n = len(template1)
            t1 = template1[max_shift: n-max_shift, best_channel_indices]
            t2 = template2[max_shift+shift: n-max_shift+shift, best_channel_indices]
            diff = np.sum(np.abs(t1 - t2)) / np.sum(np.abs(t1) + np.abs(t2))
            if diff < max_diff:
                max_diff = diff
        Y.append(max_diff)

        X.append(temp_similarity[unit_index1, unit_index2])
        colors.append("Crimson" if neuron1 == neuron2 else "CornflowerBlue")
        text.append(f"Units {unit_id1} - {unit_id2}")
        pairs.append([unit_id1, unit_id2])


fig = go.Figure()

fig.add_trace(go.Scatter(
    x=X,
    y=Y,
    mode="markers",
    marker_color=colors,
    text=text
))

fig.show(renderer="browser")

Opening in existing browser session.


In [52]:
n_pairs = dict(zip(*np.unique(colors, return_counts=True)))
print(f"There are {n_pairs['Crimson']} pairs to merge out of {n_pairs['CornflowerBlue']}")

There are 174 pairs to merge out of 37776


In [102]:
X = np.array(X)
Y = np.array(Y)
text = np.array(text)
colors = np.array(colors)
target = np.array(colors) == "Crimson"
pairs = np.array(pairs)

mask = Y < 0.15
new_target = target[mask]
new_pairs = pairs[mask]
new_X = X[mask]
new_Y = Y[mask]
new_colors = colors[mask]
new_text = text[mask]

print(f"There are {np.sum(new_target)} pairs to merge out of {len(new_target)}")

There are 165 pairs to merge out of 177


In [115]:
CCs = []
p_values = []

for pair in new_pairs:
    unit_id1, unit_id2 = pair
    spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
    spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))

    CC, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, (0.4, 1.9), limit=0.20)
    CCs.append(CC)
    p_values.append(p_value)

# mask = np.array(CCs) < 0.15
mask = np.array(p_values) > 0.01

newer_target = new_target[mask]
newer_pairs = new_pairs[mask]
newer_X = new_X[mask]
newer_Y = new_Y[mask]
newer_colors = new_colors[mask]
newer_text = new_text[mask]
newer_CCs = np.array(CCs)[mask]

print(f"There are {np.sum(newer_target)} pairs to merge out of {len(newer_target)}")

There are 165 pairs to merge out of 170


In [119]:
distances = []

for pair in newer_pairs:
    unit_id1, unit_id2 = pair
    unit_index1 = analyzer.sorting.id_to_index(unit_id1)
    unit_index2 = analyzer.sorting.id_to_index(unit_id2)
    
    unit_location1 = analyzer.get_extension("unit_locations").get_data()[unit_index1]
    unit_location2 = analyzer.get_extension("unit_locations").get_data()[unit_index2]
    distance = np.sqrt(np.sum((unit_location1 - unit_location2) ** 2))

    distances.append(distance)

distances = np.array(distances)

In [120]:
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=distances,
    y=new_Y,
    mode="markers",
    marker_color=new_colors,
    text=new_text
))

fig.show(renderer="browser")

Opening in existing browser session.


In [73]:
n_seconds = recording.get_num_frames() / recording.sampling_frequency
n_bins = int(n_seconds // 20)

for pair, target in zip(newer_pairs, newer_target):
    unit_id1, unit_id2 = pair
    spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
    spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
    merged_spike_train = np.sort(np.concatenate((spike_train1, spike_train2)))
    
    hist1, _ = np.histogram(spike_train1, bins=n_bins)
    hist2, _ = np.histogram(spike_train2, bins=n_bins)
    hist_merged, _ = np.histogram(merged_spike_train, bins=n_bins)

    cv1 = np.std(hist1) / np.mean(hist1)
    cv2 = np.std(hist2) / np.mean(hist2)
    cv_merged = np.std(hist_merged) / np.mean(hist_merged)

    metric = max(cv1, cv2) / (cv_merged)

    if metric < 1.1:
        print(target, metric, pair)

True 0.9631350152745939 [103 105]
True 1.0882664476338895 [103 166]
True 0.844491401680653 [103 350]
False 1.032358511431898 [356 408]
True 1.025321956868236 [389 437]


In [78]:
# Function for Pierre

def aurelien_merge(analyzer: si.SortingAnalyzer, refractory_period, template_threshold: float = 0.12, CC_threshold: float = 0.15,
                   max_shift: int = 10, max_channels: int = 10) -> list[tuple]:
    """
    Looks at a sorting analyzer, and returns a list of potential pairwise merges.

    Parameters
    ----------
    analyzer: SortingAnalyzer
        The analyzer to look at
    refractory_period: array/list/tuple of 2 floats
        (censored_period_ms, refractory_period_ms)
    template_threshold: float
        The threshold on the template difference.
        Any pair above this threshold will not be considered.
    CC_treshold: float
        The threshold on the cross-contamination.
        Any pair above this threshold will not be considered.
    max_shift: int
        The maximum shift when comparing the templates (in number of time samples).
    max_channels: int
        The maximum number of channels to consider when comparing the templates.
    """
    
    pairs = []

    for unit_id1 in analyzer.unit_ids:
        for unit_id2 in analyzer.unit_ids:
            if unit_id2 <= unit_id1:
                continue

            # Computing template difference
            template1 = analyzer.get_extension("templates").get_unit_template(unit_id1)
            template2 = analyzer.get_extension("templates").get_unit_template(unit_id2)
    
            best_channel_indices = np.argsort(np.max(np.abs(template1) + np.abs(template2), axis=0))[::-1][:10]
            
            max_diff = 1
            for shift in range(-max_shift, max_shift+1):
                n = len(template1)
                t1 = template1[max_shift: n-max_shift, best_channel_indices]
                t2 = template2[max_shift+shift: n-max_shift+shift, best_channel_indices]
                diff = np.sum(np.abs(t1 - t2)) / np.sum(np.abs(t1) + np.abs(t2))
                if diff < max_diff:
                    max_diff = diff

            if max_diff > template_threshold:
                continue

            # Compuyting the cross-contamination difference
            spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
            spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
            CC, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, refractory_period, limit=CC_threshold)

            if p_value < 0.05:
                continue

            pairs.append((unit_id1, unit_id2))

    return pairs

aurelien_merge(analyzer, (0.4, 1.9))

[(1, 13),
 (32, 33),
 (35, 88),
 (36, 94),
 (36, 96),
 (37, 296),
 (38, 300),
 (43, 44),
 (48, 49),
 (51, 306),
 (52, 129),
 (52, 130),
 (53, 313),
 (53, 314),
 (55, 58),
 (55, 59),
 (56, 132),
 (58, 59),
 (63, 64),
 (66, 137),
 (69, 70),
 (69, 71),
 (70, 71),
 (74, 141),
 (74, 378),
 (78, 79),
 (78, 142),
 (78, 144),
 (79, 142),
 (79, 144),
 (87, 388),
 (89, 387),
 (94, 96),
 (100, 101),
 (103, 105),
 (103, 166),
 (103, 350),
 (105, 166),
 (105, 350),
 (115, 117),
 (119, 120),
 (120, 121),
 (122, 123),
 (123, 407),
 (126, 128),
 (126, 308),
 (128, 308),
 (129, 130),
 (132, 133),
 (134, 326),
 (138, 206),
 (138, 208),
 (138, 379),
 (141, 378),
 (142, 144),
 (146, 147),
 (146, 148),
 (147, 148),
 (151, 152),
 (151, 153),
 (152, 153),
 (154, 156),
 (154, 391),
 (156, 297),
 (156, 391),
 (156, 392),
 (158, 159),
 (163, 342),
 (166, 350),
 (168, 171),
 (169, 173),
 (170, 235),
 (170, 449),
 (180, 405),
 (181, 245),
 (181, 450),
 (182, 183),
 (183, 184),
 (187, 188),
 (189, 453),
 (189, 456

In [101]:
agreement_matrix[[374, 376]].loc[['#164', '#1798']]

Unnamed: 0,374,376
#164,0.002065,0.292969
#1798,0.69913,0.001075


In [99]:
agreement_matrix[374].idxmax()

'#1798'

In [100]:
agreement_matrix[376].idxmax()

'#164'