In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import wasserstein_distance

import spikeinterface as si
import spikeinterface.curation as scur
import spikeinterface.widgets as sw

%matplotlib widget

In [None]:
# Random seed for this notebook (same value as the `seed=` argument given to the
# ground-truth generator in the next cell).
seed = 2308

In [None]:
# Simulated ground-truth data: one segment of duration 300, 64 channels, 30 units.
# The generator reuses the notebook-level `seed` instead of repeating the literal,
# so the seed only has to be changed in one place.
recording, sorting_original = si.generate_ground_truth_recording(
    durations=[300],
    num_channels=64,
    num_units=30,
    seed=seed,
)

In [None]:
# Every third unit id of the original sorting; these are the units that get split below.
units_to_split = sorting_original.unit_ids[::3]

In [None]:
# Artificially split each selected unit into two sub-units.
# Spikes in the first half of the train get label 1 with probability
# `partial_split_prob`; spikes in the second half get label 1 with probability
# `1 - partial_split_prob`. All other spikes keep label 0.
partial_split_prob = 0.99

# Dedicated generator seeded with the notebook seed: re-running this cell (or the
# whole notebook) now produces the same split instead of a new random one.
rng = np.random.default_rng(seed)

sorting_split = sorting_original.select_units(sorting_original.unit_ids)
split_units = []
original_units = []
for unit in units_to_split:
    num_spikes = len(sorting_split.get_unit_spike_train(unit))
    half = num_spikes // 2
    indices = np.zeros(num_spikes, dtype=int)
    indices[:half] = (rng.random(half) < partial_split_prob).astype(int)
    indices[half:] = (rng.random(num_spikes - half) < 1 - partial_split_prob).astype(int)
    sorting_split = scur.split_unit_sorting(
        sorting_split,
        split_unit_id=unit,
        indices_list=indices,
        properties_policy="remove",
    )
    # presumably the two units created by the split are the last two unit ids — TODO confirm
    split_units.append(sorting_split.unit_ids[-2:])
    original_units.append(unit)

In [None]:
# Unit ids after splitting. `analyzer` is only created further down, so the ids are
# read from the sorting it will be built from; this keeps the cell runnable top-to-bottom.
sorting_split.unit_ids

In [None]:
# Raster plot of the sorting that contains the artificially split units.
sw.plot_rasters(sorting_split)

In [None]:
# Analyzer pairing the split sorting with the simulated recording.
analyzer = si.create_sorting_analyzer(sorting_split, recording)

In [None]:
# Reference analyzer built on the unsplit ground-truth sorting.
analyzer_original = si.create_sorting_analyzer(sorting_original, recording)
# Only these two extensions are requested for the reference analyzer.
analyzer_original.compute(["random_spikes", "templates"])

In [None]:
bin_duration_s = 2


def _get_time_bins(analyzer, bin_duration_s):
    """Return histogram bin edges, in samples, each `bin_duration_s` seconds wide."""
    bin_size = bin_duration_s * analyzer.sampling_frequency
    return np.arange(0, analyzer.get_num_samples(), bin_size)


def _get_normalized_spike_histogram(analyzer, unit_id, bins, percentile_norm):
    """Bin one unit's spike train and divide the counts by their `percentile_norm`-th percentile."""
    spike_train = analyzer.sorting.get_unit_spike_train(unit_id=unit_id)
    counts, _ = np.histogram(spike_train, bins)
    counts = counts.astype(float)
    # NOTE: when the chosen percentile of the counts is 0 this division produces inf/NaN.
    return counts / np.percentile(counts, percentile_norm)


def get_wass_distance(analyzer, unit1, unit2, bin_duration_s=2, percentile_norm=90, bins=None):
    """
    Wasserstein distance between the normalized binned spike counts of two units.

    Parameters
    ----------
    analyzer : object
        Must expose `sampling_frequency`, `get_num_samples()` and
        `sorting.get_unit_spike_train(unit_id=...)`.
    unit1, unit2 : unit ids
        The two units to compare.
    bin_duration_s : float, default: 2
        Bin width in seconds; only used when `bins` is None.
    percentile_norm : float, default: 90
        Percentile (0-100, as taken by `np.percentile`) of the bin counts used as
        the normalization factor of each histogram.
    bins : array-like or None, default: None
        Pre-computed bin edges in samples. Built from `bin_duration_s` when None.

    Returns
    -------
    d : float
        `wasserstein_distance` of the two normalized count vectors. A NaN result
        is replaced by 0.
    """
    if bins is None:
        bins = _get_time_bins(analyzer, bin_duration_s)

    h1 = _get_normalized_spike_histogram(analyzer, unit1, bins, percentile_norm)
    h2 = _get_normalized_spike_histogram(analyzer, unit2, bins, percentile_norm)

    d = wasserstein_distance(h1, h2)

    # A NaN distance is mapped to 0
    if np.isnan(d):
        d = 0

    return d


def get_other_distance(analyzer, unit1, unit2, bin_duration_s=2, percentile_norm=90, bins=None):
    """
    Measure how far the summed normalized activity of two units is from a flat profile of ones.

    Both units are binned and normalized as in `get_wass_distance`; the result is
    `sum(|h1 + h2 - 1|) / total_duration`. It is 0 when the two normalized
    histograms add up to exactly one in every bin.

    Parameters
    ----------
    analyzer : object
        Must expose `sampling_frequency`, `get_num_samples()`,
        `get_total_duration()` and `sorting.get_unit_spike_train(unit_id=...)`.
    unit1, unit2 : unit ids
        The two units to compare.
    bin_duration_s : float, default: 2
        Bin width in seconds; only used when `bins` is None.
    percentile_norm : float, default: 90
        Percentile (0-100) of the bin counts used as the normalization factor.
    bins : array-like or None, default: None
        Pre-computed bin edges in samples. Built from `bin_duration_s` when None.

    Returns
    -------
    d : float
        The distance described above (may be inf/NaN if a normalization factor is 0).
    """
    if bins is None:
        bins = _get_time_bins(analyzer, bin_duration_s)

    h1 = _get_normalized_spike_histogram(analyzer, unit1, bins, percentile_norm)
    h2 = _get_normalized_spike_histogram(analyzer, unit2, bins, percentile_norm)

    d = np.sum(np.abs(h1 + h2 - np.ones_like(h1))) / analyzer.get_total_duration()

    return d

In [None]:
# Extensions computed on the split analyzer; "templates" and "template_similarity"
# are read back through analyzer.get_extension(...) in other cells of this notebook.
analyzer.compute(["random_spikes", "templates", "template_similarity", "spike_amplitudes"])

In [None]:
# Pairwise get_other_distance between all units of the split analyzer.
num_units = analyzer.get_num_units()
# The diagonal (self-pairs) is set to 1, exactly as before.
distances = np.ones((num_units, num_units))
# Not used in this cell; kept because it is a notebook-level variable.
all_templates = analyzer.get_extension("templates").get_templates()

# The bin edges are the same for every pair, so they are built once here instead
# of once per get_other_distance call.
distance_bins = np.arange(0, analyzer.get_num_samples(), bin_duration_s * analyzer.sampling_frequency)

for i, unit1 in enumerate(analyzer.unit_ids):
    for j, unit2 in enumerate(analyzer.unit_ids[i + 1:], start=i + 1):
        # get_other_distance only depends on h1 + h2, so it is symmetric in the two
        # units: evaluate each pair once and mirror the value.
        d = get_other_distance(analyzer, unit1, unit2, bins=distance_bins)
        distances[i, j] = d
        distances[j, i] = d


In [None]:
# `similarity` is only assigned in a later cell; fetch it here so this cell does
# not depend on out-of-order execution.
similarity = analyzer.get_extension("template_similarity").get_data()

# For every artificially split unit, report the distance and the template
# similarity between its two halves.
for original, split in zip(original_units, split_units):
    unit_indices = analyzer.sorting.ids_to_indices(split)
    print(unit_indices)
    first, second = unit_indices[0], unit_indices[1]
    print(f"Units: {original}->{split} - distance: {distances[first, second]} - similarity: {similarity[first, second]}")

In [None]:
# Data of the "template_similarity" extension. Presumably a (num_units, num_units)
# matrix: elsewhere in this notebook it is indexed with pairs of unit indices and
# used as a mask on `distances` — TODO confirm.
similarity = analyzer.get_extension("template_similarity").get_data()

In [None]:
# Keep only candidate pairs: strictly upper triangle, template similarity not
# below 0.7, distance not above 0.1. Everything else becomes NaN.
similarity_thr = 0.7
max_distance = 0.1

# The lower triangle and diagonal are removed with an explicit mask. The previous
# `distance_thr == 0` test also discarded upper-triangle pairs whose distance is
# exactly 0, which are the best merge candidates.
upper_triangle = np.triu(np.ones(distances.shape, dtype=bool), k=1)
# Negated comparisons keep the original NaN handling (NaN never fails a threshold).
keep = upper_triangle & ~(similarity < similarity_thr) & ~(distances > max_distance)
distance_thr = np.where(keep, distances, np.nan)

In [None]:
# Image of the thresholded distance matrix (discarded entries are NaN).
fig, ax = plt.subplots()
ax.imshow(distance_thr)

In [None]:
# Boolean mask of the entries of distance_thr that are not NaN.
np.logical_not(np.isnan(distance_thr))

In [None]:
# (row, column) index pairs of the non-NaN entries of distance_thr, one pair per row.
is_candidate = ~np.isnan(distance_thr)
potential_merges = np.argwhere(is_candidate)

In [None]:
# Display the candidate index pairs.
potential_merges

In [None]:
# `simi` is not defined anywhere in this notebook (it raises NameError); the
# variable holding the template similarity is `similarity`.
similarity

In [None]:
# NOTE(review): "spike_amplitudes" is already in the list passed to
# analyzer.compute(...) in an earlier cell, so this call looks redundant — confirm
# before removing.
analyzer.compute("spike_amplitudes")

In [None]:
# Template viewer for the split analyzer, using the ipywidgets backend.
sw.plot_unit_templates(analyzer, backend="ipywidgets")

In [None]:
# Same viewer for the unsplit reference analyzer, for comparison.
sw.plot_unit_templates(analyzer_original, backend="ipywidgets")

In [None]:
# Amplitude viewer for the split analyzer, using the ipywidgets backend.
sw.plot_amplitudes(analyzer, backend="ipywidgets")

In [None]:
# Distance at matrix position (6, 7); these are row/column indices of `distances`.
distances[6, 7]

In [None]:
# Distance at matrix position (0, 1); these are row/column indices of `distances`.
distances[0, 1]

In [None]:
# Pair of unit ids inspected by hand in the cells below.
unit1 = 16
unit2 = 17
# np.percentile expects q in the range 0-100. The previous value 0.99 selected the
# 0.99th percentile (close to the minimum bin count) rather than the 99th.
# NOTE(review): 99 is assumed to be the intended value — confirm.
percentile_norm = 99

In [None]:
# Step-by-step version of get_wass_distance for the pair (unit1, unit2).

# `bins` was only assigned in a later cell; build the bin edges here (same
# construction as inside get_wass_distance) so the cell runs top-to-bottom.
bin_size = bin_duration_s * analyzer.sampling_frequency
bins = np.arange(0, analyzer.get_num_samples(), bin_size)

st1 = analyzer.sorting.get_unit_spike_train(unit_id=unit1)
st2 = analyzer.sorting.get_unit_spike_train(unit_id=unit2)

# Per-bin spike counts, each scaled by its `percentile_norm`-th percentile
h1, _ = np.histogram(st1, bins)
h1 = h1 / np.percentile(h1, percentile_norm)

h2, _ = np.histogram(st2, bins)
h2 = h2 / np.percentile(h2, percentile_norm)

d = wasserstein_distance(h1, h2)
    

In [None]:
# Display the histogram bin edges.
bins

In [None]:
# Raw (un-normalized) per-bin spike counts of st1; this overwrites any earlier `h1`.
h1, _ = np.histogram(st1, bins)

In [None]:
# NOTE: np.percentile takes q in the range 0-100, so this is the 0.2th percentile
# of the bin counts, not the 20th.
np.percentile(h1, 0.2)

In [None]:
# `bin_size` was never assigned in this notebook. Derive it as the get_*_distance
# helpers do: bin duration in seconds times the sampling frequency.
bin_size = bin_duration_s * recording.get_sampling_frequency()
# Bin edges in samples covering the recording
bins = np.arange(0, recording.get_num_samples(), bin_size)

In [None]:
# Spike trains of units 16 and 17 from the split sorting.
unit16 = sorting_split.get_unit_spike_train(unit_id=16)
unit17 = sorting_split.get_unit_spike_train(unit_id=17)

# The cells below use `h16` / `h17`, which were never assigned. They are built
# here as binned spike counts scaled by their 90th percentile, the default
# normalization of get_other_distance.
# NOTE(review): this normalization is assumed from how h16/h17 are used below — confirm.
h16, _ = np.histogram(unit16, bins)
h16 = h16 / np.percentile(h16, 90)
h17, _ = np.histogram(unit17, bins)
h17 = h17 / np.percentile(h17, 90)

In [None]:
# Overlay the two histograms against the left bin edges.
# `b` was never assigned; the bin edges live in `bins`.
plt.figure()
plt.plot(bins[:-1], h1)
plt.plot(bins[:-1], h2)

In [None]:
from scipy.stats import wasserstein_distance

In [None]:
# Wasserstein distance between the two histograms.
# h16 and h17 must already exist in the kernel when this cell runs.
wasserstein_distance(h16, h17)

In [None]:
# Summed histogram of the two units against the left bin edges.
# `b` was never assigned; the bin edges live in `bins`.
plt.plot(bins[:-1], h16 + h17, lw=3)

In [None]:
# Same formula as get_other_distance, applied to h16 + h17.
# `rec` was never assigned; the recording object is named `recording`.
np.sum(np.abs(h16 + h17 - np.ones_like(h16))) / recording.get_total_duration()

In [None]:
# Same deviation-from-ones measure for h17 alone.
# `rec` was never assigned; the recording object is named `recording`.
np.sum(np.abs(h17 - np.ones_like(h16))) / recording.get_total_duration()

In [None]:
# Same deviation-from-ones measure for h16 alone.
# `rec` was never assigned; the recording object is named `recording`.
np.sum(np.abs(h16 - np.ones_like(h16))) / recording.get_total_duration()

In [None]:
# NOTE(review): rebinds `analyzer` to a new object built from the same inputs as
# the earlier create_sorting_analyzer call; extensions computed on the previous
# object are presumably not carried over — confirm this reset is intended.
analyzer = si.create_sorting_analyzer(sorting_split, recording)

In [None]:
# Bare attribute reference: shows the object's repr, nothing is called.
scur.get_potential_auto_merge

In [None]:
def get_potential_drift_merges(analyzer, )