In [1]:
import numpy as np
from fastdtw import fastdtw
from spikeinterface import extractors as se
from spikeinterface.preprocessing import bandpass_filter
from spikeinterface.comparison import compare_sorter_to_ground_truth

In [2]:
# 1. Load MEArec simulated data
recording, sorting_true = se.read_mearec("recordings.h5")

In [3]:
# 2. Preprocessing
recording_filt = bandpass_filter(recording, freq_min=300, freq_max=6000)

In [5]:
# 3. Spike detection (simplified threshold-based)
def detect_spikes(recording, threshold=5.0):
    traces = recording.get_traces()
    std_noise = np.median(np.abs(traces)) / 0.6745  # MAD to SD conversion
    spike_frames = np.where(np.abs(traces) > threshold * std_noise)[0]
    return spike_frames

spike_times = detect_spikes(recording_filt)

In [7]:
# 4. Extract spike waveforms (fixed snippet extraction)
spike_waveforms = []
n_samples_pre = 15  # 0.5 ms pre-peak @30kHz
n_samples_post = 25  # 0.83 ms post-peak

for t in spike_times:
    start = max(0, t - n_samples_pre)
    end = t + n_samples_post
    snippet = recording_filt.get_traces(
        start_frame=start,
        end_frame=end
    )
    # Handle boundary cases with zero-padding
    if snippet.shape[0] < (n_samples_pre + n_samples_post):
        pad_before = n_samples_pre - (t - start)
        pad_after = (n_samples_pre + n_samples_post) - snippet.shape[0] - pad_before
        snippet = np.pad(
            snippet,
            ((pad_before, pad_after), (0, 0)),
            mode='constant'
        )
    spike_waveforms.append(snippet)

spike_waveforms = np.array(spike_waveforms)

In [13]:
# 5. Get ground truth templates with proper unit ID handling
template_waveforms = []
channel_locations = recording.get_channel_locations()

# Check if templates are available in recording properties
templates = recording.get_property("templates")
if templates is None:
    # Alternative: get templates directly from sorting_true if available
    print("Templates not found in recording properties. Extracting from spike trains instead.")
    
    # Extract waveforms for each unit
    for unit_id in sorting_true.unit_ids:
        # Get spike times for this unit
        unit_spike_frames = sorting_true.get_unit_spike_train(unit_id)
        
        if len(unit_spike_frames) == 0:
            continue
            
        # Extract a few waveforms and average them to create a template
        sample_spikes = unit_spike_frames[:min(10, len(unit_spike_frames))]
        waveforms = []
        
        for spike_frame in sample_spikes:
            if spike_frame < n_samples_pre or spike_frame + n_samples_post >= recording_filt.get_num_samples():
                continue
                
            wf = recording_filt.get_traces(
                start_frame=spike_frame - n_samples_pre,
                end_frame=spike_frame + n_samples_post
            )
            waveforms.append(wf)
        
        if waveforms:
            avg_waveform = np.mean(waveforms, axis=0)
            # Find best channel
            best_chan = np.argmax(np.max(np.abs(avg_waveform), axis=0))
            template_waveforms.append(avg_waveform[:, best_chan])
else:
    for unit_id in sorting_true.unit_ids:
        # Convert string unit ID to integer index (remove '#' prefix)
        unit_index = int(unit_id[1:]) if unit_id.startswith('#') else int(unit_id)
        
        # MEArec stores templates in recording properties
        unit_template = templates[unit_index]
    
    # Find best channel: peak absolute amplitude
    peak_amps = np.max(np.abs(unit_template), axis=1)
    best_chan = np.argmax(peak_amps)
    
    # Store template from best channel
    template_waveforms.append(unit_template[best_chan])

# Alternative: Handle both string and integer unit IDs robustly
def parse_unit_id(unit_id):
    """Convert unit ID to integer index, handling both '#0' and '0' formats"""
    if isinstance(unit_id, str):
        return int(unit_id.lstrip('#'))
    return int(unit_id)


# 6. Extract spike waveforms from best channel
spike_waveforms = []
n_pre = 15  # 0.5 ms @30kHz
n_post = 25  # 0.83 ms @30kHz

for t in spike_times:
    # Get multi-channel snippet
    snippet = recording_filt.get_traces(
        start_frame=max(0, t - n_pre),
        end_frame=t + n_post
    )
    
    # Pad if near boundaries
    if snippet.shape[0] < (n_pre + n_post):
        pad_before = n_pre - (t - max(0, t - n_pre))
        pad_after = (n_pre + n_post) - snippet.shape[0] - pad_before
        snippet = np.pad(snippet, ((pad_before, pad_after), (0, 0)))
    
    # Find best channel for this spike (max absolute amplitude)
    spike_peak_chan = np.argmax(np.max(np.abs(snippet), axis=0))
    spike_waveforms.append(snippet[:, spike_peak_chan])

spike_waveforms = np.array(spike_waveforms)



Templates not found in recording properties. Extracting from spike trains instead.


In [15]:

# 6. fastDTW classification (vectorized)
from scipy.spatial.distance import euclidean

unit_distances = np.zeros((len(spike_waveforms), len(template_waveforms)))

for i, spike in enumerate(spike_waveforms):
    for j, template in enumerate(template_waveforms):
        distance, _ = fastdtw(spike, template, dist=euclidean)  # Both are 1D arrays
        unit_distances[i,j] = distance

assigned_units = np.argmin(unit_distances, axis=1)

# 7. Create output sorting
sorting_fastdtw = se.NumpySorting.from_times_labels(
    spike_times,
    assigned_units,
    sampling_frequency=recording_filt.sampling_frequency
)


ValueError: Input vector should be 1-D.

In [23]:
import numpy as np
from fastdtw import fastdtw
import spikeinterface.extractors as se
from spikeinterface.preprocessing import bandpass_filter
import h5py

# 1. Load MEArec data
recording, sorting_true = se.read_mearec("recordings.h5")

# 2. Preprocessing
recording_filt = bandpass_filter(recording, freq_min=300, freq_max=6000)

# 3. Load templates and inspect shape
with h5py.File("recordings.h5", 'r') as f:
    templates = f['templates'][:]
    template_ids = f['template_ids'][:]
    
print(f"Templates shape: {templates.shape}")

# 4. Spike detection
def detect_spikes(recording, threshold=4.5):
    traces = recording.get_traces()
    std_noise = np.median(np.abs(traces)) / 0.6745
    spike_frames = np.where(np.abs(traces) > threshold * std_noise)[0]
    return spike_frames

spike_times = detect_spikes(recording_filt)

# 5. Extract spike waveforms
spike_waveforms = []
n_pre, n_post = 15, 25

for t in spike_times:
    snippet = recording_filt.get_traces(
        start_frame=max(0, t - n_pre),
        end_frame=t + n_post
    )
    if snippet.shape[0] < (n_pre + n_post):
        pad_before = n_pre - (t - max(0, t - n_pre))
        pad_after = (n_pre + n_post) - snippet.shape[0] - pad_before
        snippet = np.pad(snippet, ((pad_before, pad_after), (0, 0)))
    
    spike_peak_chan = np.argmax(np.max(np.abs(snippet), axis=0))
    spike_waveforms.append(snippet[:, spike_peak_chan])

spike_waveforms = np.array(spike_waveforms)

# 6. Extract templates with CORRECTED dimension handling
template_waveforms = []
for i, unit_id in enumerate(sorting_true.unit_ids):
    template_idx = int(unit_id.lstrip('#'))
    unit_template = templates[template_idx]  # Shape: (n_jitters, n_channels, n_samples)
    
    print(f"Unit {unit_id}: template shape = {unit_template.shape}")
    
    # CORRECTED: Handle 3D template structure (jitters, channels, samples)
    # Shape is (10, 4, 416) = (n_jitters, n_channels, n_samples)
    n_jitters, n_channels, n_samples = unit_template.shape
    
    # Find best channel by max amplitude over jitters and samples
    max_per_channel = np.max(np.abs(unit_template), axis=(0, 2))  # Max over jitters and samples
    best_chan = np.argmax(max_per_channel)  # Best channel index (0-3)
    
    print(f"Best channel: {best_chan} out of {n_channels} channels")
    
    # Extract template from best channel using median over jitters
    template_waveform = np.median(unit_template[:, best_chan, :], axis=0)  # Shape: (n_samples,)
    template_waveforms.append(template_waveform)

print(f"Extracted {len(template_waveforms)} templates")

# 7. fastDTW classification
unit_distances = np.zeros((len(spike_waveforms), len(template_waveforms)))

for i, spike in enumerate(spike_waveforms):
    for j, template in enumerate(template_waveforms):
        distance, _ = fastdtw(spike, template, dist=2)
        unit_distances[i,j] = distance

assigned_units = np.argmin(unit_distances, axis=1)

print(f"Classified {len(spike_times)} spikes into {len(template_waveforms)} units")


Templates shape: (5, 10, 4, 416)
Unit #0: template shape = (10, 4, 416)
Best channel: 3 out of 4 channels
Unit #1: template shape = (10, 4, 416)
Best channel: 3 out of 4 channels
Unit #2: template shape = (10, 4, 416)
Best channel: 2 out of 4 channels
Unit #3: template shape = (10, 4, 416)
Best channel: 3 out of 4 channels
Unit #4: template shape = (10, 4, 416)
Best channel: 2 out of 4 channels
Extracted 5 templates


KeyboardInterrupt: 

In [20]:
import spikeinterface.extractors as se
import h5py

# 1. Load the recording and check what's available
recording, sorting_true = se.read_mearec("recordings.h5")


# 2. Check recording properties
print("Recording properties:")
print(recording.get_property_keys())

# 3. Check if templates are stored in the recording
if 'templates' in recording.get_property_keys():
    templates = recording.get_property('templates')
    print(f"Templates shape: {templates.shape}")
else:
    print("No templates found in recording properties")

# 4. Inspect the H5 file directly
with h5py.File("recordings.h5", 'r') as f:
    print("H5 file structure:")
    def print_structure(name, obj):
        print(name)
    f.visititems(print_structure)

# 5. Check sorting extractor properties
print("Sorting properties:")
print(sorting_true.get_property_keys())

Recording properties:
['gain_to_uV', 'offset_to_uV', 'physical_unit', 'gain_to_physical_unit', 'offset_to_physical_unit', 'channel_name', 'contact_vector', 'location', 'group']
No templates found in recording properties
H5 file structure:
channel_positions
info
info/cell_types
info/cell_types/excitatory
info/cell_types/inhibitory
info/electrodes
info/electrodes/description
info/electrodes/dim
info/electrodes/electrode_name
info/electrodes/pitch
info/electrodes/plane
info/electrodes/shape
info/electrodes/size
info/electrodes/sortlist
info/electrodes/type
info/recordings
info/recordings/adc_bit_depth
info/recordings/angle_tol
info/recordings/bursting
info/recordings/bursting_units
info/recordings/chunk_duration
info/recordings/color_noise_floor
info/recordings/color_peak
info/recordings/color_q
info/recordings/drift_fs
info/recordings/drift_mode_probe
info/recordings/drift_mode_speed
info/recordings/drifting
info/recordings/dtype
info/recordings/duration
info/recordings/exp_decay
info/re