In [7]:
import matplotlib.pyplot as plt
from pprint import pprint
import MEArec as mr
import numpy as np
from cus_sort.custom_sorter import *
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import sys

In [10]:
def test_custom_sorter():
    print("=== Testing Custom SpikeInterface Sorter ===")
    
    # Create a toy recording for testing
    recording, sorting_true = se.toy_example(
        duration=10, 
        num_channels=4, 
        num_units=5,
        num_segments=1,
        seed=42
    )
    
    print(f"Test recording: {recording}")
    print(f"Available sorters: {ss.available_sorters()}")
    
    # Verify your sorter is registered
    if 'myspikesorter' not in ss.available_sorters():
        print("ERROR: myspikesorter not found in available sorters!")
        return
    
    # Run your custom sorter
    print("\nRunning custom sorter...")
    sorting = ss.run_sorter(
        sorter_name='myspikesorter',
        recording=recording,
        output_folder='sorting_output/test_output',
        detect_threshold=5.0,
        clustering_method='kmeans',
        remove_existing_folder=True,
        verbose=True
    )
    
    print(f"\nSorting completed! Found {len(sorting.get_unit_ids())} units")
    print(f"Unit IDs: {sorting.get_unit_ids()}")
    
    # Display spike counts for each unit
    for unit_id in sorting.get_unit_ids():
        spike_train = sorting.get_unit_spike_train(unit_id)
        print(f"Unit {unit_id}: {len(spike_train)} spikes")
    
    print("\n=== Test completed successfully! ===")

In [11]:
test_custom_sorter()

=== Testing Custom SpikeInterface Sorter ===
Test recording: GroundTruthRecording (InjectTemplatesRecording): 4 channels - 30.0kHz - 1 segments 
                      300,000 samples - 10.00s - float32 dtype - 4.58 MiB
Available sorters: ['combinato', 'hdsort', 'herdingspikes', 'ironclust', 'kilosort', 'kilosort2', 'kilosort2_5', 'kilosort3', 'kilosort4', 'klusta', 'mountainsort4', 'mountainsort5', 'myspikesorter', 'pykilosort', 'simple', 'spykingcircus', 'spykingcircus2', 'tridesclous', 'tridesclous2', 'waveclus', 'waveclus_snippets', 'yass']

Running custom sorter...
Generated 1000 spikes across 5 units.
Spike times range: 982.74 to 1799655.44 samples
myspikesorter run time 0.00s

Sorting completed! Found 5 units
Unit IDs: [0 1 2 3 4]
Unit 0: 185 spikes
Unit 1: 201 spikes
Unit 2: 198 spikes
Unit 3: 223 spikes
Unit 4: 193 spikes

=== Test completed successfully! ===


  sorting = ss.run_sorter(
