In [56]:
from pathlib import Path
import spikeinterface.extractors as si
import numpy as np
import scipy.io

In [57]:
def safe_load(path):
    """Safely load .mat file, returning empty dict if file doesn't exist."""
    return scipy.io.loadmat(path) if Path(path).exists() else {}

def create_spike_summary(sort_path, electrode_num, patient, session):
    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"
    final_max_file = sort_path / f"final/A_ss{electrode_num}_Max_sorted_new.mat"
    final_min_file = sort_path / f"final/A_ss{electrode_num}_Min_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)
    final_max_spikes = safe_load(final_max_file)
    final_min_spikes = safe_load(final_min_file)

    spike_summary = {
        "patient": patient,
        "session": session,
        "channel": electrode_num,
        "positive_data": {
            "positive_spikes": positive_spikes.get("useNegative", np.array([])).T,
            "positive_merged": positive_spikes.get("useNegativeMerged", np.array([])),
        },
        "negative_data": {
            "negative_spikes": negative_spikes.get("useNegative", np.array([])).T,
            "negative_merged": negative_spikes.get("useNegativeMerged", np.array([])),
        },
        "final_max": {
            "included": final_max_spikes.get("useNegative", np.array([])),
            "excluded": final_max_spikes.get("useNegativeExcluded", np.array([])).T,
        },
        "final_min": {
            "included": final_min_spikes.get("useNegative", np.array([])),
            "excluded": final_min_spikes.get("useNegativeExcluded", np.array([])).T,
        },
    }

    patient_data = (spike_summary["patient"], spike_summary["session"], spike_summary["channel"])

    good_units = {
        "positive_included": spike_summary["final_max"]["included"].tolist(),
        "negative_included": spike_summary["final_min"]["included"].tolist()
    }

    noise = {
        "positive_noise": [],
        "negative_noise": []
    }

    # Account for case where there is final max but no final min, or vice versa
    if spike_summary["final_max"]["included"].size != 0:
        noise["positive_noise"] = spike_summary["final_max"]["excluded"].tolist()
    else:
        noise["positive_noise"] = spike_summary["positive_data"]["positive_spikes"].tolist()
    
    if spike_summary["final_min"]["included"].size != 0:
        noise["negative_noise"] = spike_summary["final_min"]["excluded"].tolist()
    else:
        noise["negative_noise"] = spike_summary["negative_data"]["negative_spikes"].tolist()

    return patient_data, good_units, noise

In [58]:
def create_bag_of_spikes(patient_data, good_units, noise,
                         sort_path, electrode_num, spike_num,
                         alignment="positive", spike_type="noise"):

    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)

    # Determine which data source to use based on alignment
    spikes_data = positive_spikes if alignment == "positive" else negative_spikes
    
    # Extract labels and waveforms
    labels = np.array(spikes_data.get("assignedNegative", np.array([]))).ravel()
    waveforms = np.array(spikes_data.get("allSpikesCorrFree", np.array([])))

    # Get label based on spike_type and alignment
    label = None
    if spike_type == "noise":
        noise_key = "positive_noise" if alignment == "positive" else "negative_noise"
        if len(noise.get(noise_key, [])) > 0:
            label = int(noise[noise_key][0][spike_num])
    elif spike_type == "good_units":
        units_key = "positive_included" if alignment == "positive" else "negative_included"
        if len(good_units.get(units_key, [])) > 0:
            label = int(good_units[units_key][0][spike_num])

    # Extract waveforms for the selected label
    if label is None or labels.size == 0:
        waveforms_for_label = np.array([])
        n_selected = 0
    else:
        mask = (labels == label) & (labels != 0)
        waveforms_for_label = waveforms[mask] if waveforms.size else np.array([])
        n_selected = int(mask.sum())

    spike_dict = {
        "patient_data": patient_data,
        "spike_label": label,
        "number_of_spikes": n_selected,
        "spike_type": spike_type,
        "alignment": alignment,
        "spikes": waveforms_for_label,
    }

    return spike_dict


In [59]:
base_path = Path("/Users/marco/Local_Sorting")

patient = "DA014"
session = "10-03"

sort_path = base_path / patient / session / "sort"

electrode_num = 5
spike_num = 0
alignment = "positive"
spike_type = "good_units"

patient_data, good_units, noise = create_spike_summary(
    sort_path=sort_path, 
    electrode_num=electrode_num, 
    patient=patient, 
    session=session
)

bag_1 = create_bag_of_spikes(
    good_units=good_units,
    sort_path=sort_path,
    spike_type=spike_type,
    noise=noise,
    electrode_num=electrode_num, 
    spike_num=spike_num, 
    patient_data=patient_data, 
    alignment=alignment
)

print(good_units)
print()
print(noise)
bag_1

{'positive_included': [[943]], 'negative_included': []}

{'positive_noise': [[1291, 1341, 1440, 1461, 1502, 1526, 1539, 1552]], 'negative_noise': [[336, 424, 1346, 1007, 1189, 1402, 996, 1376, 1304, 1018, 1513, 1519, 1484]]}


{'patient_data': ('DA014', '10-03', 5),
 'spike_label': 943,
 'number_of_spikes': 571,
 'spike_type': 'good_units',
 'alignment': 'positive',
 'spikes': array([[ 0.36140452,  0.36140452,  0.36140452, ..., -0.06617414,
          0.03611035,  0.15354787],
        [ 0.25503776,  0.25503776,  0.25503776, ..., -0.47555287,
         -0.39148501, -0.2468079 ],
        [ 0.19377394,  0.19377394,  0.19377394, ..., -0.83741608,
         -0.63000237, -0.45082429],
        ...,
        [-0.19905316, -0.19905316, -0.19905316, ...,  0.86036245,
          0.96206012,  1.03951206],
        [-0.081159  , -0.081159  , -0.081159  , ..., -0.71670151,
         -0.64508366, -0.6218226 ],
        [-0.48106327, -0.48106327, -0.48106327, ..., -0.50197335,
         -0.35344196, -0.27902289]])}

positive noise - Works!

negative noise - Works!

positive good - Works!

negative good - Works! 

In [63]:
import os

def list_subdirectories(dir_path):
    file_list = [name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]
    print(dir_path, "contains", file_list)
    return file_list

dir_path = "/Users/marco/Local_Sorting/DA047 Sorted"

list_subdirectories(dir_path)

/Users/marco/Local_Sorting/DA047 Sorted contains ['12-15', '12-13', '12-16']


['12-15', '12-13', '12-16']

In [67]:
def get_units(good_units: dict,
              noise: dict,
              alignment: str,
              spike_type: str):
    # pick source dict
    src = good_units if spike_type == "good_units" else noise

    # key map - updated to match new snake_case keys
    key_map = {
        ("positive", "good_units"): "positive_included",
        ("negative", "good_units"): "negative_included",
        ("positive", "noise"): "positive_noise",
        ("negative", "noise"): "negative_noise",
    }

    key = key_map.get((alignment, spike_type), None)
    if key is None:
        return []

    arr = src.get(key, [])
    if isinstance(arr, (list, np.ndarray)) and len(arr) > 0:
        return arr[0]
    return []


def create_bag_of_bags(num_channels, patient_path, patient):
    base_path = Path(patient_path).parent
    alignments = ("positive", "negative")
    spike_types = ("noise", "good_units")

    sessions = list_subdirectories(patient_path)

    bag_of_bags = []

    for session_dir in sessions:
        sort_path = base_path / patient / session_dir / "sort"
        for channel in range(int(num_channels)):
            patient_data, good_units, noise = create_spike_summary(
                sort_path=sort_path,
                electrode_num=channel,
                patient=patient,
                session=session_dir
            )
            for spike_type in spike_types:
                for alignment in alignments:
                    units = get_units(good_units, noise, alignment, spike_type)
                    for unit_idx in range(len(units)):
                        bag = create_bag_of_spikes(
                            patient_data=patient_data,
                            good_units=good_units,
                            noise=noise,
                            sort_path=sort_path,
                            electrode_num=channel, 
                            spike_num=unit_idx,
                            alignment=alignment,
                            spike_type=spike_type,
                        )
                        bag_of_bags.append(bag)
                        print(
                            f"Session={session_dir} "
                            f"Channel={channel} "
                            f"Kind={spike_type} "
                            f"Align={alignment} "
                            f"Unit={unit_idx} done"
                        )
        print()

    return bag_of_bags        

In [68]:
num_channels = 3
patient = "DA047 Sorted"
patient_path = "/Users/marco/Local_Sorting/DA047 Sorted"

A047_Bag = create_bag_of_bags(num_channels, patient_path, patient)

/Users/marco/Local_Sorting/DA047 Sorted contains ['12-15', '12-13', '12-16']
Session=12-15 Channel=1 Kind=noise Align=positive Unit=0 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=1 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=2 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=3 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=4 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=5 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=6 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=7 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=8 done
Session=12-15 Channel=1 Kind=noise Align=negative Unit=0 done
Session=12-15 Channel=1 Kind=noise Align=negative Unit=1 done
Session=12-15 Channel=1 Kind=noise Align=negative Unit=2 done
Session=12-15 Channel=1 Kind=noise Align=negative Unit=3 done
Session=12-15 Channel=1 Kind=noise Align=negative Unit=4 done
Session=12-15 Channel=2 Kind=noise Align=positive Unit=

In [69]:
A047_Bag

[{'patient_data': ('DA047 Sorted', '12-15', 1),
  'spike_label': 1022,
  'number_of_spikes': 63,
  'spike_type': 'noise',
  'alignment': 'positive',
  'spikes': array([[-0.00198035, -0.00198035, -0.00198035, ...,  0.48905609,
           0.80071638,  1.07602585],
         [ 0.03300863,  0.03300863,  0.03300863, ...,  0.34135452,
           0.32218732,  0.37301812],
         [ 1.13324126,  1.13324126,  1.13324126, ..., -0.32588126,
          -0.33251722, -0.12589569],
         ...,
         [ 0.18445288,  0.18445288,  0.18445288, ..., -0.0368876 ,
          -0.09347259, -0.12625635],
         [-0.9313494 , -0.9313494 , -0.9313494 , ...,  1.41648887,
           1.43442458,  1.40231297],
         [ 0.59270997,  0.59270997,  0.59270997, ..., -0.17114422,
          -0.1932632 , -0.21598504]])},
 {'patient_data': ('DA047 Sorted', '12-15', 1),
  'spike_label': 1023,
  'number_of_spikes': 69,
  'spike_type': 'noise',
  'alignment': 'positive',
  'spikes': array([[ 0.17733422,  0.13761052,  0.13