In [None]:
from pathlib import Path
import spikeinterface.extractors as si

In [3]:
patient = "DA047 Sorted"
session = "12-13"

base_path = Path("/Users/marco/Local_Sorting")
# session_path = base_path / patient / session
sort_path = session_path / "sort"
raw_path = session_path / "raw/micros"

# positive_dir = sort_path / "5.1"
# negative_dir = sort_path / "5.2"
# final_dir = sort_path / "final"

# raw_file = si.read_neuralynx(raw_path)
# num_channels = len(raw_file.channel_ids)
# print(num_channels)

16


In [191]:
# Written by Marco

import numpy as np
import scipy.io
from pathlib import Path


# Load spike output from .mat file, or empty dictionary if empty

def safe_load(path = sort_path):
    """Load .mat if it exists, otherwise return empty dict."""
    return scipy.io.loadmat(path) if Path(path).exists() else {}


def create_spike_summary(sort_path = sort_path, electrode_num=1, patient=patient, session=session):
    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"
    final_max = sort_path / f"final/A_ss{electrode_num}_Max_sorted_new.mat"
    final_min = sort_path / f"final/A_ss{electrode_num}_Min_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)
    final_max_spikes = safe_load(final_max)
    final_min_spikes = safe_load(final_min)

    # Compile spike sorting data into single dictionary
    spike_summary = {
        "Patient": f"{patient}",
        "Session": f"{session}",
        "Channel": f"{electrode_num}",
        "Positive Data": {
            "Positive Spikes": positive_spikes.get("useNegative", np.array([])).T,
            "Positive Merged": positive_spikes.get("useNegativeMerged", np.array([])),
        },
        "Negative Data": {
            "Negative Spikes": negative_spikes.get("useNegative", np.array([])).T,
            "NegativeMerged": negative_spikes.get("useNegativeMerged", np.array([])),
        },
        "Final Max": {
            "included": final_max_spikes.get("useNegative", np.array([])),
            "excluded": final_max_spikes.get("useNegativeExcluded", np.array([])).T,
        },
        "Final Min": {
            "included": final_min_spikes.get("useNegative", np.array([])),
            "excluded": final_min_spikes.get("useNegativeExcluded", np.array([])).T,
        },
    }
     #Create a list that includes patient data
    Patient_Data = (spike_summary["Patient"], spike_summary["Session"], spike_summary["Channel"])

    # Make a python dictionary of good units
    good_units = {
        # "info": Patient_Data,
        "Positive Included": spike_summary["Final Max"]["included"].tolist(),
        "Negative Included": spike_summary["Final Min"]["included"].tolist()
    }

    # by default include the spikes listed in the final directory
    if spike_summary["Final Max"]["included"].size != 0 or spike_summary["Final Min"]["included"].size != 0:
        noise = {
            # "info": Patient_Data,
            "Positive Noise": spike_summary["Final Max"]["excluded"].tolist() ,
            "Negative Noise": spike_summary["Final Min"]["excluded"].tolist(),
        }
    # if the final directory is empty, then load the spikes identified in 5.1 / 5.2, and in that case 
    else:
        noise = {
            # "info": Patient_Data,
            "Positive Noise": spike_summary["Positive Data"]["Positive Spikes"].tolist(),
            "Negative Noise": spike_summary["Negative Data"]["Negative Spikes"].tolist()
        }
        
    return Patient_Data, good_units, noise

patient_data, good_units, noise = create_spike_summary(electrode_num=3, patient = "DA047 Sorted", session = "12-13")
print("patient_data", patient_data)
print("good units: ", good_units)
print("noise: ", noise)



patient_data ('DA047 Sorted', '12-13', '3')
good units:  {'Positive Included': [[3634, 3804, 4042, 4093, 4143, 4158]], 'Negative Included': [[5007, 5181, 5261, 5364, 5494]]}
noise:  {'Positive Noise': [[3602, 3677, 3775, 3940, 4074]], 'Negative Noise': [[1649, 1650, 2115, 2212, 2433, 4162, 4568, 4585, 4726, 5190, 5253, 5376, 5389, 5460, 5485]]}


In [None]:
i = 0
def create_bag_of_spikes(i):
    
    # target label (e.g., from your good_units)
    lbl = int(good_units["Positive Included"][0][i])

    # flatten labels safely
    labels = np.array(positive_spikes.get("assignedNegative", np.array([]))).ravel()

    # pick the waveform matrix to slice
    W = np.array(positive_spikes.get("allSpikesCorrFree", np.array([])))

    # build mask: match label and ignore zeros
    mask = (labels == lbl) & (labels != 0)

    # slice; result is (n_selected_spikes, n_samples)
    waveforms_for_label = W[mask] if (W.size and labels.size) else np.array([])

    # optional: count selected spikes
    n_selected = int(mask.sum()) if labels.size else 0

    spike_dict = {"Patient Data": Patient_Data,
    "Spike Label": lbl,
    "Number of Spikes": n_selected,
    "Spikes": waveforms_for_label
    }

    return spike_dict


bag_0 = create_bag_of_spikes(1)
bag_0

{'Patient Data': ('DA047 Sorted', '12-13', '15'),
 'Spike Label': 4023,
 'Number of Spikes': 332,
 'Spikes': array([[-0.00766344, -0.00766344, -0.00766344, ..., -0.22314554,
          0.03320956,  0.1151993 ],
        [-0.39866404, -0.39866404, -0.39866404, ..., -1.04854944,
         -0.81387112, -0.74724279],
        [-0.38320734, -0.38320734, -0.38320734, ..., -0.27450319,
         -0.28213579, -0.32498458],
        ...,
        [ 0.40969492,  0.40969492,  0.40969492, ...,  0.49720991,
          0.38197632,  0.24522426],
        [ 0.03832915,  0.03832915,  0.03832915, ...,  0.4282705 ,
          0.61119347,  0.69317112],
        [ 1.33231828,  1.33231828,  1.33231828, ..., -0.01799642,
         -0.10343879, -0.36783007]])}

In [129]:
# There is a syntax error here due to invalid argument annotations in the function definition.
# In Python, you can use type hints like def func(arg1: type, arg2: type), but 'good_units[array]' is not valid.
# Remove '[array]' from the function argument list:

import numpy as np

def get_units(good_units: dict,
              noise: dict,
              alignment: str,
              spike_type: str):
    # pick source dict
    src = good_units if spike_type == "good_units" else noise

    # key map
    key_map = {
        ("positive", "good_units"): "Positive Included",
        ("negative", "good_units"): "Negative Included",
        ("positive", "noise"): "Positive Noise",
        ("negative", "noise"): "Negative Noise",
    }

    key = key_map.get((alignment, spike_type), None)
    if key is None:
        return []

    arr = src.get(key, [])
    if isinstance(arr, (list, np.ndarray)) and len(arr) > 0:
        return arr[0]
    return []


# example
# print(get_units(good_units, noise, "negative", "noise"))


print(get_units(good_units, noise, "positive", "noise"))

[701, 4803, 4963, 4752, 4701, 2934, 3007, 4651, 5282, 5319, 5125, 4854, 5128]


In [187]:
# Written by Marco

def create_bag_of_spikes(electrode_num=0, spike_num=0, patient_data=patient_data, alignment = "positive", spike_type = "noise"):
    positive_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    negative_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"
    # final_max = sort_path / f"final/A_ss{electrode_num}_Max_sorted_new.mat"
    # final_min = sort_path / f"final/A_ss{electrode_num}_Min_sorted_new.mat"

    positive_spikes = safe_load(positive_file)
    negative_spikes = safe_load(negative_file)
    
    # The final spikes arent needed because in all cases the original spike file is stored in sort/5.1 and sort/5.2
    # final_max_spikes = safe_load(final_max)
    # final_min_spikes = safe_load(final_min)

    if spike_type =="noise":
        if alignment == "positive":    
            # target label (e.g., from your good_units)
            lbl = int(noise["Positive Noise"][0][spike_num]) if len(noise["Positive Noise"]) > 0 else []

            # flatten labels safely
            labels = np.array(positive_spikes.get("assignedNegative", np.array([]))).ravel()

            # pick the waveform matrix to slice
            W = np.array(positive_spikes.get("allSpikesCorrFree", np.array([])))

        else:  
            # target label (e.g., from your good_units)
            lbl = int(noise["Negative Noise"][0][spike_num]) if len(noise["Negative Noise"]) > 0 else []

            # flatten labels safely
            labels = np.array(negative_spikes.get("assignedNegative", np.array([]))).ravel()

            # pick the waveform matrix to slice
            W = np.array(negative_spikes.get("allSpikesCorrFree", np.array([])))
    else:
        if alignment == "positive":    
            # target label (e.g., from your good_units)
            lbl = int(good_units["Positive Included"][0][spike_num]) if len(good_units["Positive Included"]) > 0 else []

            # flatten labels safely
            labels = np.array(positive_spikes.get("assignedNegative", np.array([]))).ravel()

            # pick the waveform matrix to slice
            W = np.array(positive_spikes.get("allSpikesCorrFree", np.array([])))

        else:  
            # target label (e.g., from your good_units)
            lbl = int(good_units["Negative Included"][0][spike_num]) if len(good_units["Negative Included"]) > 0 else []

            # flatten labels safely
            labels = np.array(negative_spikes.get("assignedNegative", np.array([]))).ravel()

            # pick the waveform matrix to slice
            W = np.array(negative_spikes.get("allSpikesCorrFree", np.array([]))) 

    # build mask: match label and ignore zeros
    mask = (labels == lbl) & (labels != 0)

    # slice; result is (n_selected_spikes, n_samples)
    waveforms_for_label = W[mask] if (W.size and labels.size) else np.array([])

    # optional: count selected spikes
    n_selected = int(mask.sum()) if labels.size else 0

    spike_dict = {"patient data": patient_data,
    "spike label": lbl,
    "number of spikes": n_selected,
    "spike_type": spike_type,
    "alignment": alignment,
    "spikes": waveforms_for_label
    }

    return spike_dict

bag_1 = create_bag_of_spikes(electrode_num=3, spike_num=3, patient_data=patient_data, alignment = "positive")
bag_1

{'patient data': ('DA047 Sorted', '12-13', '1'),
 'spike label': 4752,
 'number of spikes': 0,
 'spike_type': 'noise',
 'alignment': 'positive',
 'spikes': array([], shape=(0, 256), dtype=float64)}

In [None]:
# Written by Chat

import numpy as np
from typing import Any, Dict, Tuple, Optional

def _as_list_first(seq: Any):
    """Return a flat Python list. If list-of-lists/ndarrays, take first inner."""
    if isinstance(seq, np.ndarray):
        seq = seq.tolist()
    if not isinstance(seq, list) or not seq:
        return []
    first = seq[0]
    if isinstance(first, np.ndarray):
        return first.tolist()
    if isinstance(first, list):
        return first
    return seq

def _safe_label(src: Dict[str, Any], key: str, idx: int) -> Optional[int]:
    """
    Safely fetch src[key] possibly shaped like [[...]] or ndarray,
    return element at idx as int, or None if missing/out-of-range.
    """
    flat = _as_list_first(src.get(key, []))
    if 0 <= idx < len(flat):
        try:
            return int(flat[idx])
        except Exception:
            return None
    return None

def create_bag_of_spikes(
    electrode_num: int = 0,
    spike_num: int = 0,
    patient_data: Any = None,
    alignment: str = "positive",       # "positive" | "negative"
    spike_type: str = "noise"          # "noise" | "good_units"
):
    # expects globals: sort_path, safe_load, noise, good_units
    pos_file = sort_path / f"5.1/A_ss{electrode_num}_sorted_new.mat"
    neg_file = sort_path / f"5.2/A_ss{electrode_num}_sorted_new.mat"

    pos = safe_load(pos_file)
    neg = safe_load(neg_file)

    # choose spike dict by alignment
    spikes = pos if alignment == "positive" else neg

    # choose label source dict and key
    if spike_type == "noise":
        label_src: Dict[str, Any] = noise
        label_key = "Positive Noise" if alignment == "positive" else "Negative Noise"
    else:
        label_src = good_units
        label_key = "Positive Included" if alignment == "positive" else "Negative Included"

    # safe label extraction
    lbl = _safe_label(label_src, label_key, spike_num)

    # flatten labels and pick waveform matrix
    labels = np.array(spikes.get("assignedNegative", np.array([]))).ravel()
    W = np.array(spikes.get("allSpikesCorrFree", np.array([])))

    # build mask only if we have a label and arrays are non-empty
    if lbl is not None and labels.size and W.size:
        # guard dimension mismatch (labels should match W rows)
        n_rows = W.shape[0]
        if labels.shape[0] != n_rows:
            # trim to the shorter length to avoid index errors
            n = min(n_rows, labels.shape[0])
            W = W[:n]
            labels = labels[:n]
        mask = (labels == lbl) & (labels != 0)
        waveforms_for_label = W[mask]
        n_selected = int(mask.sum())
    else:
        waveforms_for_label = np.array([])
        n_selected = 0

    return {
        "patient data": patient_data,
        "spike label": lbl,
        "number of spikes": n_selected,
        "spike_type": spike_type,
        "alignment": alignment,
        "spikes": waveforms_for_label,
    }

# example call (uses your globals noise/good_units/sort_path/safe_load/patient_data):
bag_1 = create_bag_of_spikes(electrode_num=1, spike_num=0, patient_data=patient_data, alignment="positive", spike_type="noise")


In [169]:
bag_1

{'patient data': ('DA047 Sorted', '12-13', '1'),
 'spike label': 701,
 'number of spikes': 79,
 'spike_type': 'noise',
 'alignment': 'positive',
 'spikes': array([[-0.52701887, -0.52701887, -0.52701887, ..., -0.39468965,
         -0.58021059, -0.74705469],
        [ 0.22877798,  0.22877798,  0.22877798, ..., -0.58857936,
         -0.54955398, -0.55416127],
        [ 0.57100215,  0.57100215,  0.57100215, ..., -0.00713703,
          0.19888277,  0.34133471],
        ...,
        [-0.6481804 , -0.6481804 , -0.6481804 , ..., -1.36424797,
         -1.42281428, -1.45733443],
        [ 0.65367604,  0.65367604,  0.65367604, ..., -0.13056104,
          0.01292183,  0.12577971],
        [-0.24687845, -0.24687845, -0.24687845, ..., -1.15873753,
         -1.22594925, -1.34878511]])}

In [None]:
# Written by Marco

import os

patient_data, good_units, noise = create_spike_summary(electrode_num=1, patient = "DA047 Sorted", session = "12-13")

def list_subdirectories(dir_path):
    file_list = [name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]
    print(dir_path, "contains", file_list)
    return file_list

def create_bag_of_bags(num_channels, patient_path, patient = "DA047 Sorted"):
    # get positive and negative allignment
    alignments = ["positive", "negative"]
    spike_type = ["noise", "good_units"]
    # include all sessions
    subdirs = list_subdirectories(patient_path)
    bag_of_bags = []
    for dir in subdirs:
        for channel in range(int(num_channels)+1):
            print(dir, channel)
            patient_data, good_units, noise = create_spike_summary(electrode_num=channel, patient = patient, session = dir)
            for type in spike_type:
                for alignment in alignments:
                    units = get_units(good_units, noise, alignment = alignment, spike_type = type)
                    print(units)
                    num_units = range(len(units))
                    for unit in num_units:
                        bag = create_bag_of_spikes(electrode_num=channel, spike_num=unit, patient_data=patient_data, alignment = alignment, spike_type=type)
                        bag_of_bags.append(bag)
                        return unit
                        print(" Unit: ", unit, "from channel: ", channel, "Is Done!")
    return bag_of_bags


In [189]:
DA047_Bag = create_bag_of_bags(num_channels=3, patient = "DA047 Sorted", patient_path = "/Users/marco/Local_Sorting/DA047 Sorted")
print(DA047_Bag)

/Users/marco/Local_Sorting/DA047 Sorted contains ['12-15', '12-13', '12-16']
12-15 0
[]
[]
[]
[]
12-15 1
[701, 4803, 4963, 4752, 4701, 2934, 3007, 4651, 5282, 5319, 5125, 4854, 5128]
0


In [None]:
# Written by Chat

import os

# assume: create_spike_summary, get_units, create_bag_of_spikes exist

def list_subdirectories(dir_path: str):
    names = os.listdir(dir_path)
    subdirs = [n for n in names if os.path.isdir(os.path.join(dir_path, n))]
    print(dir_path, "contains", subdirs)
    return subdirs

def create_bag_of_bags(
    num_channels: int,
    patient_path: str,
    patient: str = "DA047 Sorted",
):
    alignments = ("positive", "negative")
    spike_kinds = ("noise", "good_units")

    sessions = list_subdirectories(patient_path)
    bag_of_bags = []

    for session_dir in sessions:
        for channel in range(int(num_channels)):  # 0..num_channels-1
            patient_data, gu, nz = create_spike_summary(
                electrode_num=channel, patient=patient, session=session_dir
            )

            for spike_kind in spike_kinds:
                for alignment in alignments:
                    units = get_units(gu, nz, alignment=alignment, spike_type=spike_kind)
                    for unit_idx in range(len(units)+1):
                        bag = create_bag_of_spikes(
                            electrode_num=channel,
                            spike_num=unit_idx,
                            patient_data=patient_data,
                            alignment=alignment,
                            spike_type=spike_kind
                            # pass dicts explicitly to avoid stale globals
                            # good_units=gu,
                            # noise=nz,
                        )
                        bag_of_bags.append(bag)
                        print(
                            f"Session={session_dir} "
                            f"Channel={channel} "
                            f"Kind={spike_kind} "
                            f"Align={alignment} "
                            f"Unit={unit_idx} done"
                        )
        print()  # spacer per session

    return bag_of_bags


In [186]:
DA047_Bag = create_bag_of_bags(num_channels=3, patient = "DA047 Sorted", patient_path = "/Users/marco/Local_Sorting/DA047 Sorted")
print(DA047_Bag)

/Users/marco/Local_Sorting/DA047 Sorted contains ['12-15', '12-13', '12-16']
Session=12-15 Channel=0 Kind=noise Align=positive Unit=0 done
Session=12-15 Channel=0 Kind=noise Align=negative Unit=0 done
Session=12-15 Channel=0 Kind=good_units Align=positive Unit=0 done
Session=12-15 Channel=0 Kind=good_units Align=negative Unit=0 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=0 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=1 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=2 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=3 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=4 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=5 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=6 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=7 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=8 done
Session=12-15 Channel=1 Kind=noise Align=positive Unit=9 done
Session=12-15 Channel=1 Kind=noise Align=posi

In [164]:
print(DA047_Bag)

0


In [72]:
patient = "DA047 Sorted"
session = "12-13"

base_path = Path("/Users/marco/Local_Sorting")
patient_path = base_path / patient
session_path = patient_path / session
sort_path = session_path / "sort"
raw_path = session_path / "raw/micros"

positive_dir = sort_path / "5.1"
negative_dir = sort_path / "5.2"
final_dir = sort_path / "final"



raw_file = si.read_neuralynx(raw_path)
num_channels = len(raw_file.channel_ids)
print(num_channels)

16


In [76]:
import os

def list_subdirectories(dir_path):
    file_list = [name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]
    print(dir_path, "contains", file_list)
    return file_list

# Example usage:
subdirs = list_subdirectories(patient_path)


/Users/marco/Local_Sorting/DA047 Sorted contains ['12-15', '12-13', '12-16']
