In [1]:
from pathlib import Path
import spikeinterface.extractors as si
import numpy as np
import pandas as pd
import scipy.io
import os
from numba import njit, prange
import pyarrow as pa
import pyarrow.parquet as pq
import hashlib


In [2]:
import numba

print("Numba threads:", numba.get_num_threads())  # see how many cores it will use
numba.set_num_threads(8)  # for example, if you want to cap or change it

Numba threads: 8


In [3]:
def safe_load(path):
    """Safely load .mat file, returning empty dict if file doesn't exist."""
    return scipy.io.loadmat(path) if Path(path).exists() else {}

def create_spike_summary(sort_path, electrode_num, patient, session):
    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"
    final_max_file = sort_path / f"final/A_ss{electrode_num}_Max_sorted_new.mat"
    final_min_file = sort_path / f"final/A_ss{electrode_num}_Min_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)
    final_max_spikes = safe_load(final_max_file)
    final_min_spikes = safe_load(final_min_file)

    spike_summary = {
        "patient": patient,
        "session": session,
        "channel": electrode_num,
        "positive_data": {
            "positive_spikes": positive_spikes.get("useNegative", np.array([])).T,
            "positive_merged": positive_spikes.get("useNegativeMerged", np.array([])),
        },
        "negative_data": {
            "negative_spikes": negative_spikes.get("useNegative", np.array([])).T,
            "negative_merged": negative_spikes.get("useNegativeMerged", np.array([])),
        },
        "final_max": {
            "included": final_max_spikes.get("useNegative", np.array([])),
            "excluded": final_max_spikes.get("useNegativeExcluded", np.array([])).T,
        },
        "final_min": {
            "included": final_min_spikes.get("useNegative", np.array([])),
            "excluded": final_min_spikes.get("useNegativeExcluded", np.array([])).T,
        },
    }

    patient_data = (spike_summary["patient"], spike_summary["session"], spike_summary["channel"])

    good_units = {
        "positive_included": spike_summary["final_max"]["included"].tolist(),
        "negative_included": spike_summary["final_min"]["included"].tolist()
    }

    noise = {
        "positive_noise": [],
        "negative_noise": []
    }

    # Account for case where there is final max but no final min, or vice versa
    if spike_summary["final_max"]["included"].size != 0:
        noise["positive_noise"] = spike_summary["final_max"]["excluded"].tolist()
    else:
        noise["positive_noise"] = spike_summary["positive_data"]["positive_spikes"].tolist()
    
    if spike_summary["final_min"]["included"].size != 0:
        noise["negative_noise"] = spike_summary["final_min"]["excluded"].tolist()
    else:
        noise["negative_noise"] = spike_summary["negative_data"]["negative_spikes"].tolist()

    return patient_data, good_units, noise

In [4]:
@njit(parallel=True)
def _select_waveforms_for_label(labels, waveforms, label):
    n = labels.shape[0]
    if n == 0:
        return np.empty((0, 0), np.float32), 0

    n_t = waveforms.shape[1]

    # compute match mask (serial, cheap)
    mask = np.zeros(n, np.uint8)
    for i in range(n):
        if labels[i] == label and labels[i] != 0:
            mask[i] = 1

    # prefix positions for each matching spike
    pos = np.empty(n, np.int64)
    count = 0
    for i in range(n):
        pos[i] = count
        if mask[i]:
            count += 1

    out = np.empty((count, n_t), np.float32)

    # parallel copy of selected waveforms
    for i in prange(n):
        if mask[i]:
            out[pos[i], :] = waveforms[i]

    return out, count


def create_bag_of_spikes(patient_data,
                         spike_label,
                         alignment,
                         spike_type,
                         labels,
                         waveforms):
    labels = np.asarray(labels, dtype=np.int32)
    waveforms = np.asarray(waveforms, dtype=np.float32)

    if spike_label is None or labels.size == 0 or waveforms.size == 0:
        n_t = waveforms.shape[1] if waveforms.ndim == 2 else 0
        waveforms_for_label = np.empty((0, n_t), dtype=np.float32)
        n_selected = 0
    else:
        waveforms_for_label, n_selected = _select_waveforms_for_label(
            labels, waveforms, int(spike_label)
        )

    spike_dict = {
        "patient_data": patient_data,
        "spike_label": int(spike_label) if spike_label is not None else -1,
        "number_of_spikes": int(n_selected),
        "spike_type": spike_type,
        "alignment": alignment,
        "spikes": waveforms_for_label,
    }

    return spike_dict


def load_spike_arrays(sort_path, electrode_num):
    """Load label and waveform arrays once per channel to avoid repeated .mat I/O."""
    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)

    pos_labels = np.array(positive_spikes.get("assignedNegative", np.array([]))).ravel()
    pos_waves = np.array(positive_spikes.get("allSpikesCorrFree", np.array([])))

    neg_labels = np.array(negative_spikes.get("assignedNegative", np.array([]))).ravel()
    neg_waves = np.array(negative_spikes.get("allSpikesCorrFree", np.array([])))

    return {
        "positive": (pos_labels, pos_waves),
        "negative": (neg_labels, neg_waves),
    }


In [5]:
def list_subdirectories(dir_path, print_count = True):
    file_list = [name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]
    if print_count == False:
        return file_list
    else:
        print(dir_path, "contains", file_list)
        return file_list


In [6]:
def get_units(good_units: dict,
              noise: dict,
              alignment: str,
              spike_type: str):
    # pick source dict
    src = good_units if spike_type == "good_units" else noise

    # key map - updated to match new snake_case keys
    key_map = {
        ("positive", "good_units"): "positive_included",
        ("negative", "good_units"): "negative_included",
        ("positive", "noise"): "positive_noise",
        ("negative", "noise"): "negative_noise",
    }

    key = key_map.get((alignment, spike_type), None)
    if key is None:
        return []

    arr = src.get(key, [])
    if isinstance(arr, (list, np.ndarray)) and len(arr) > 0:
        return arr[0]
    return []


def create_bag_of_bags(num_channels, patient_path, patient):
    base_path = Path(patient_path).parent
    alignments = ("positive", "negative")
    spike_types = ("noise", "good_units")

    sessions = list_subdirectories(patient_path)

    bag_of_bags = []

    for session_dir in sessions:
        sort_path = base_path / patient / session_dir / "sort"
        for channel in range(int(num_channels)):
            patient_data, good_units, noise = create_spike_summary(
                sort_path=sort_path,
                electrode_num=channel,
                patient=patient,
                session=session_dir,
            )

            # Load spike label/waveform arrays once per channel
            spike_data = load_spike_arrays(sort_path, channel)

            for spike_type in spike_types:
                for alignment in alignments:
                    labels, waveforms = spike_data[alignment]
                    units = np.asarray(
                        get_units(good_units, noise, alignment, spike_type),
                        dtype=np.int32,
                    )
                    for label in units:
                        bag = create_bag_of_spikes(
                            patient_data=patient_data,
                            spike_label=int(label),
                            alignment=alignment,
                            spike_type=spike_type,
                            labels=labels,
                            waveforms=waveforms,
                        )
                        bag_of_bags.append(bag)
                        print(
                            "Neuron: ",
                            bag["spike_label"],
                            " has ",
                            bag["number_of_spikes"],
                            "spikes",
                            f"Session={session_dir} ",
                            f"Channel={channel} ",
                        )
        print()

    return bag_of_bags        

In [14]:
def count_channels(patient_path):
    subdirectories = list_subdirectories(patient_path)
    raw_path = patient_path / subdirectories[0] / "raw/micros"
    # Count all files in raw_path that end with .ncs
    ncs_count = sum(1 for file in raw_path.iterdir() if file.is_file() and file.suffix == ".ncs")
    print(f"Number of .ncs files in {raw_path}: {ncs_count}")
    return str(ncs_count)

In [8]:
def make_unit_hash(patient: str, session: str, label_id: int) -> str:
    key = f"{patient}|{session}|{label_id}".encode("utf-8")
    # 4 bytes â†’ 8 hex characters
    return hashlib.blake2b(key, digest_size=4).hexdigest()

In [9]:
def bags_and_spikes_from_list(bags):
    bag_rows = []
    spike_rows = []

    for bag_id, bag in enumerate(bags):
        patient, session, channel = bag["patient_data"]
        label_id = int(bag["spike_label"])
        spike_type = bag["spike_type"]
        alignment = bag["alignment"]
        waves = np.asarray(bag["spikes"], dtype="float32")
        n_spikes = waves.shape[0]
        mean_wave = waves.mean(axis=0) if n_spikes > 0 else np.zeros(256, np.float32)

        unit_id = make_unit_hash(patient, session, label_id)

        bag_rows.append(
            (bag_id, patient, session, int(channel), label_id,
             spike_type, alignment, n_spikes, mean_wave, unit_id)
        )

        for i in range(n_spikes):
            spike_rows.append((bag_id, i, waves[i], unit_id))  # optional: attach unit_id to each spike
    return bag_rows, spike_rows



In [32]:
def write_to_parquat(bag_rows, spike_rows, patient_path, write_df=False):
    bag_ids, patients, sessions, channels, labels, types, aligns, n_spikes, means, unit_ids = zip(*bag_rows)

    bags_table = pa.table({
        "bag_id":      pa.array(bag_ids, type=pa.int32()),
        "patient":     pa.array(patients),
        "session":     pa.array(sessions),
        "channel":     pa.array(channels, type=pa.int32()),
        "spike_label": pa.array(labels, type=pa.int32()),
        "spike_type":  pa.array(types),
        "alignment":   pa.array(aligns),
        "n_spikes":    pa.array(n_spikes, type=pa.int32()),
        "mean_waveform": pa.FixedSizeListArray.from_arrays(
            pa.array(np.stack(means).reshape(-1)), 256
        ),
        "unit_id": pa.array(unit_ids),
    })

    bag_id_s, spike_idxs, wf_list, spike_unit_ids = zip(*spike_rows)
    wf_stack = np.stack(wf_list).astype("float32")

    spikes_table = pa.table({
        "bag_id":   pa.array(bag_id_s, type=pa.int32()),
        "spike_idx": pa.array(spike_idxs, type=pa.int32()),
        "waveform": pa.FixedSizeListArray.from_arrays(
            pa.array(wf_stack.reshape(-1)), 256
        ),
        "unit_id": pa.array(spike_unit_ids),
    })

    bags_table_path = Path(patient_path) / "spike_bags.parquet"
    spikes_table_path = Path(patient_path) / "spikes.parquet"

    pq.write_table(bags_table, bags_table_path)
    pq.write_table(spikes_table, spikes_table_path)

    if write_df:
        bags_df = pd.read_parquet(bags_table_path)
        spikes_df = pd.read_parquet(spikes_table_path)
        return bags_df, spikes_df

In [33]:
def main(patient):
    base_path = "/Users/marco/Local_Sorting"
    # patient = "DA047 Sorted"
    patient_path = Path(base_path) / patient
    num_channels = count_channels(patient_path)
    A047_Bag = create_bag_of_bags(num_channels, patient_path, patient)
    bag_rows, spike_rows = bags_and_spikes_from_list(A047_Bag)
    write_to_parquat(bag_rows, spike_rows, patient_path)

In [None]:
sorted_paths = ["DA014", "DA047 Sorted", "DA050 Sorted", "DA054 Sorted", "DA056 Sorted", "DA057 Sorted", "DA061"]

for path in sorted_paths:
    main(path)

/Users/marco/Local_Sorting/DA014 contains ['10-03', '09-26', '09-28', '09-29', '09-27']
Number of .ncs files in /Users/marco/Local_Sorting/DA014/10-03/raw/micros: 24
/Users/marco/Local_Sorting/DA014 contains ['10-03', '09-26', '09-28', '09-29', '09-27']
Neuron:  934  has  75 spikes Session=10-03  Channel=1 
Neuron:  1020  has  1672 spikes Session=10-03  Channel=1 
Neuron:  877  has  103 spikes Session=10-03  Channel=1 
Neuron:  963  has  125 spikes Session=10-03  Channel=1 
Neuron:  976  has  1143 spikes Session=10-03  Channel=1 
Neuron:  785  has  61 spikes Session=10-03  Channel=2 
Neuron:  942  has  276 spikes Session=10-03  Channel=2 
Neuron:  1168  has  170 spikes Session=10-03  Channel=2 
Neuron:  1224  has  331 spikes Session=10-03  Channel=2 
Neuron:  1264  has  202 spikes Session=10-03  Channel=2 
Neuron:  954  has  70 spikes Session=10-03  Channel=2 
Neuron:  1263  has  74 spikes Session=10-03  Channel=2 
Neuron:  546  has  82 spikes Session=10-03  Channel=2 
Neuron:  1232  h

In [41]:
DA061 = pd.read_parquet("/Users/marco/Local_Sorting/DA061/spike_bags.parquet")

In [47]:
DA061


Unnamed: 0,bag_id,patient,session,channel,spike_label,spike_type,alignment,n_spikes,mean_waveform,unit_id
0,0,DA061,8-27,1,1253,noise,positive,67,"[0.53085613, 0.53085613, 0.53085613, 0.5308561...",40b44db5
1,1,DA061,8-27,1,1380,noise,positive,1726,"[0.040133964, 0.03995413, 0.03972575, 0.039479...",0f4ce97f
2,2,DA061,8-27,1,990,noise,negative,60,"[-0.46930152, -0.46930152, -0.46930152, -0.469...",eccffffb
3,3,DA061,8-27,1,1101,noise,negative,222,"[-0.034568544, -0.035344493, -0.036510855, -0....",c5c87b35
4,4,DA061,8-27,1,1218,noise,negative,1177,"[0.025253, 0.025265744, 0.025265457, 0.0253143...",890ba89c
...,...,...,...,...,...,...,...,...,...,...
259,259,DA061,8-28 No Neurons,7,2165,noise,positive,428,"[0.018016724, 0.017854564, 0.018029535, 0.0183...",567abba1
260,260,DA061,8-28 No Neurons,7,2389,noise,positive,541,"[-0.07594722, -0.07594722, -0.07594722, -0.075...",90a263b6
261,261,DA061,8-28 No Neurons,7,1396,noise,negative,62,"[-0.21880156, -0.21661927, -0.21449585, -0.212...",7c56603d
262,262,DA061,8-28 No Neurons,7,2081,noise,negative,360,"[0.031126153, 0.031616647, 0.032165922, 0.0328...",57953849


In [1]:
import pyarrow.dataset as ds
import pyarrow.parquet as pq

bag_paths = [
    "/Users/marco/Local_Sorting/DA014/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA047 Sorted/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA050 Sorted/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA054 Sorted/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA056 Sorted/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA057 Sorted/spike_bags.parquet",
    "/Users/marco/Local_Sorting/DA061/spike_bags.parquet"
]

spikes_paths = [
    "/Users/marco/Local_Sorting/DA014/spikes.parquet",
    "/Users/marco/Local_Sorting/DA047 Sorted/spikes.parquet",
    "/Users/marco/Local_Sorting/DA050 Sorted/spikes.parquet",
    "/Users/marco/Local_Sorting/DA054 Sorted/spikes.parquet",
    "/Users/marco/Local_Sorting/DA056 Sorted/spikes.parquet",
    "/Users/marco/Local_Sorting/DA057 Sorted/spikes.parquet",
    "/Users/marco/Local_Sorting/DA061/spikes.parquet"
]


bag_dataset = ds.dataset(bag_paths, format="parquet")
bag_table = bag_dataset.to_table()
pq.write_table(bag_table, "/Users/marco/Local_Sorting/Classifier/data/bags_merged.parquet")

In [2]:
spike_dataset = ds.dataset(spikes_paths, format="parquet")
spike_table = spike_dataset.to_table()
pq.write_table(spike_table, "/Users/marco/Local_Sorting/Classifier/data/spikes_merged.parquet")