In [1]:
import numpy as np
from scipy import signal
from scipy.stats import zscore
from scipy.ndimage import label, uniform_filter1d
import pandas as pd
import spikeinterface as si
import spikeinterface.extractors as se
from spikeinterface import preprocessing as spp
from hdmf.data_utils import GenericDataChunkIterator
from pathlib import Path, PurePath
import mat73
import matplotlib.pyplot as plt
from itertools import groupby
from operator import itemgetter
import met_brewer as mb
import pywt

In [2]:
def get_span_start_stop(indices, offset=0):
    """Collapse runs of consecutive integer indices into (first, last) pairs.

    Parameters
    ----------
    indices : iterable of int
        Index values; a run continues while each value is exactly one more
        than the previous one.
    offset : int, optional
        Added to both ends of every pair (e.g. to turn window-relative indices
        into absolute ones).

    Returns
    -------
    list of tuple
        One ``(first + offset, last + offset)`` tuple per run, in input order.
        Empty input yields an empty list.
    """
    spans = []
    first = last = None
    for value in indices:
        if first is not None and value == last + 1:
            # Still inside the current run: just extend it.
            last = value
            continue
        if first is not None:
            spans.append((first + offset, last + offset))
        first = last = value
    if first is not None:
        spans.append((first + offset, last + offset))
    return spans

In [3]:
## Set Data Paths
# Session directories are laid out as <raw_data_path>/<animal>/<date>.
raw_data_path = Path("/mnt/born_animal/Hypothalamic_Sleep/data/raw/")
animal = "HYDO02"
date1, date2 = "2024-07-24_05-57-05", "2024-07-22_10-47-00"
data1_path, data2_path = (raw_data_path / animal / session for session in (date1, date2))
# Only the first session is required to exist.
assert data1_path.exists()

In [4]:
## filter_lfp.py

import numpy as np
import psutil
from scipy import signal
import spikeinterface as si
import spikeinterface.extractors as se
import ghostipy as gsp
from pathlib import Path, PosixPath
from typing import List, Literal


def load_rec(recording_path, concatenate=True, channels=None):
    """Load a Neuralynx recording, optionally joining segments and selecting channels.

    Parameters
    ----------
    recording_path : str or Path
        Directory passed to ``spikeinterface.extractors.read_neuralynx``.
    concatenate : bool, optional
        If True and the recording has more than one segment, join all segments
        into a single segment. The original per-segment timestamps are kept
        (laid end to end), so gaps between segments remain visible in
        ``get_times()``.
    channels : None, "all", list or tuple, optional
        ``None``, ``"all"`` or an empty sequence keep every channel; a list or
        tuple of channel ids keeps only those channels.

    Returns
    -------
    recording
        The (possibly concatenated and channel-sliced) recording.

    Raises
    ------
    ValueError
        If ``channels`` is of an unsupported form, or a requested channel id
        is not present in the recording.
    """
    recording = se.read_neuralynx(recording_path)
    n_segments = recording.get_num_segments()
    if concatenate and n_segments > 1:
        # Use the public per-segment time API instead of private segment objects.
        timestamps = np.concatenate(
            [recording.get_times(segment_index=seg) for seg in range(n_segments)]
        )
        recording = si.ConcatenateSegmentRecording([recording])
        recording.set_times(timestamps)

    if channels is None or (isinstance(channels, str) and channels == "all"):
        return recording
    if not isinstance(channels, (list, tuple)):
        # Previously such values were silently ignored and all channels returned.
        raise ValueError(
            f"channels must be None, 'all', or a list/tuple of channel ids; got {channels!r}."
        )
    if len(channels) == 0:
        # An empty selection keeps every channel (same as the original `if channels:`).
        return recording
    channel_ids = recording.get_channel_ids()
    missing = [channel for channel in channels if channel not in channel_ids]
    if missing:
        raise ValueError(
            f">= 1 channel of {channels} not found in recording (missing: {missing})."
        )
    return recording.channel_slice(channel_ids=list(channels))


def get_recording_path(base_path, animal="", date=""):
    """Return ``base_path / animal / date`` after checking that it exists.

    Empty ``animal`` or ``date`` components are skipped, so the path can point
    directly at ``base_path / date`` or at ``base_path`` itself.

    Raises
    ------
    ValueError
        If the assembled path does not exist.
    """
    candidate = Path(base_path).joinpath(animal, date)
    if candidate.exists():
        return candidate
    raise ValueError(f"Data path {candidate} does not exist.")


def get_valid_times(recording, atol=1e-6):
    """Return the (start, stop) times of every gap-free stretch of a recording.

    A gap is any step between consecutive timestamps that exceeds the nominal
    sample period ``1 / sampling_frequency`` by more than ``atol`` seconds.

    Parameters
    ----------
    recording
        Object exposing ``get_times()`` and
        ``get_time_info()["sampling_frequency"]``.
    atol : float, optional
        Tolerance on the timestamp step before it counts as a gap.

    Returns
    -------
    list of tuple
        ``(start_time, stop_time)`` per stretch, where both are timestamps of
        actual samples (the stop is the last sample before the next gap, or
        the final sample). A recording with no gaps yields one stretch; an
        empty recording yields ``[]``.
    """
    timestamps = np.asarray(recording.get_times())
    if timestamps.size == 0:
        return []
    dt = 1 / recording.get_time_info()["sampling_frequency"]
    # Index i marks a gap between sample i and sample i + 1.
    gap_inds = np.flatnonzero(np.diff(timestamps) - dt > atol)
    # Each stretch begins right after a gap (or at 0) and ends right before the
    # next gap (or at the last sample) — so the final stretch is always included.
    starts = np.concatenate(([0], gap_inds + 1))
    stops = np.concatenate((gap_inds, [timestamps.size - 1]))
    return [(timestamps[lo], timestamps[hi]) for lo, hi in zip(starts, stops)]


def reference_recording(
    recording,
    reference: Literal["global", "single"] = "global",
    ref_channel_id=None,
):
    """Re-reference a recording with ``spikeinterface.preprocessing.common_reference``.

    Parameters
    ----------
    recording
        SpikeInterface recording to re-reference.
    reference : {"global", "single"}, optional
        ``"global"`` subtracts the across-channel median; ``"single"``
        subtracts the channel(s) given by ``ref_channel_id``. The original
        signature used ``Literal[...]`` as the *default value*, which matched
        neither branch; it is now a proper annotation defaulting to "global".
    ref_channel_id : optional
        Reference channel id(s), forwarded as ``ref_channel_ids``; required
        when ``reference == "single"``.

    Returns
    -------
    recording
        The re-referenced recording (float64).

    Raises
    ------
    ValueError
        For an unknown ``reference`` or a missing ``ref_channel_id`` in
        single-reference mode (previously both silently returned the input).
    """
    if reference == "global":
        return si.preprocessing.common_reference(
            recording,
            reference=reference,
            operator="median",
            dtype=np.float64,
        )
    if reference == "single":
        if ref_channel_id is None:
            raise ValueError("reference='single' requires ref_channel_id.")
        return si.preprocessing.common_reference(
            recording,
            reference=reference,
            ref_channel_ids=ref_channel_id,
            dtype=np.float64,
        )
    raise ValueError(
        f"Unknown reference {reference!r}; expected 'global' or 'single'."
    )


def get_filter_coeff(target_fs, band_edges):
    """Design a band-pass FIR filter with ghostipy.

    Parameters
    ----------
    target_fs : float
        Sampling rate (Hz) the filter is designed for.
    band_edges : sequence of 4 floats
        ``[low stop, low pass, high pass, high stop]`` in Hz; the desired gain
        is 0 / 1 / 1 / 0 at these edges.

    Returns
    -------
    numpy.ndarray
        At least 1-D array of filter taps. The tap count is estimated from the
        mean width of the two transition bands.
    """
    lower_transition = band_edges[1] - band_edges[0]
    upper_transition = band_edges[3] - band_edges[2]
    mean_transition = (lower_transition + upper_transition) / 2.0
    n_taps = gsp.estimate_taps(target_fs, mean_transition)
    # Value passed to firdesign as `p` (transition-band spline parameter).
    spline_p = 2
    taps = gsp.firdesign(n_taps, band_edges, [0, 1, 1, 0], fs=target_fs, p=spline_p)
    return np.array(taps, ndmin=1)


def time_bound_check(start, stop, timestamps, n_samples):
    """Convert a (start, stop) time window into sample indices.

    The window is first clamped to ``[timestamps[0], timestamps[-1]]``, then
    located with ``np.searchsorted`` (left side). The stop index is capped at
    ``n_samples``.

    Returns
    -------
    tuple
        ``(frm, to)`` sample indices suitable for slicing.
    """
    first_ts, last_ts = timestamps[0], timestamps[-1]
    lo = first_ts if start < first_ts else start
    hi = last_ts if stop > last_ts else stop
    frm, to = np.searchsorted(timestamps, (lo, hi))
    return frm, min(to, n_samples)


def filter_data(recording, filter_coeff, valid_times, decimation=None, channels=None):
    """FIR-filter (and optionally decimate) each valid time span of a recording.

    Parameters
    ----------
    recording
        SpikeInterface recording.
    filter_coeff : numpy.ndarray
        FIR taps, e.g. from ``get_filter_coeff``.
    valid_times : iterable of (start, stop)
        Time windows to filter, in the recording's time base (e.g. from
        ``get_valid_times``). Windows that map to zero samples are skipped.
    decimation : int or None, optional
        Keep every ``decimation``-th output sample; ``None`` means no
        decimation (passed to ghostipy as ``ds=1``).
    channels : sequence of channel ids, optional
        Channels to filter; defaults to all channels.

    Returns
    -------
    filtered_data : numpy.ndarray
        ``(n_out_samples, len(channels))`` array, same dtype as the traces.
    new_timestamps : numpy.ndarray
        Timestamps of the retained (decimated) samples, span by span.

    Notes
    -----
    The original only produced output when the recording was larger than
    available RAM and raised UnboundLocalError otherwise. Since all traces are
    loaded with ``get_traces`` regardless, the filtering now always runs.
    """
    if channels is None:
        channels = recording.get_channel_ids()
    ds = 1 if decimation is None else int(decimation)
    timestamps = recording.get_times()
    n_samples = recording.get_num_samples()
    # get_traces already returns only the requested channels, so the channel
    # restriction for ghostipy must use column positions 0..n-1, not channel ids.
    traces = recording.get_traces(channel_ids=channels)
    input_dim_restrictions = [None] * traces.ndim
    input_dim_restrictions[1] = np.s_[list(range(len(channels)))]
    # Output window shifted by (numtaps - 1) // 2, as in the original call.
    # NOTE(review): presumably compensates the FIR group delay — confirm.
    filter_delay = (len(filter_coeff) - 1) // 2

    # Pass 1: ask ghostipy how many output samples each span produces.
    spans = []
    span_offsets = []
    n_out = 0
    for start, stop in valid_times:
        frm, to = time_bound_check(start, stop, timestamps, n_samples)
        if to <= frm:
            continue
        shape, _ = gsp.filter_data_fir(
            traces,
            filter_coeff,
            axis=0,
            input_index_bounds=[frm, to],
            output_index_bounds=[filter_delay, filter_delay + to - frm],
            describe_dims=True,
            ds=ds,
            input_dim_restrictions=input_dim_restrictions,
        )
        spans.append((frm, to))
        span_offsets.append(n_out)
        n_out += shape[0]

    filtered_data = np.empty((n_out, len(channels)), dtype=traces.dtype)
    new_timestamps = np.empty((n_out,), dtype=timestamps.dtype)

    # Pass 2: filter each span into its slot of the preallocated output.
    ts_offset = 0
    for (frm, to), out_offset in zip(spans, span_offsets):
        span_ts = timestamps[frm:to:ds]
        new_timestamps[ts_offset : ts_offset + len(span_ts)] = span_ts
        ts_offset += len(span_ts)
        gsp.filter_data_fir(
            traces,
            filter_coeff,
            axis=0,
            input_index_bounds=[frm, to],
            output_index_bounds=[filter_delay, filter_delay + to - frm],
            outarray=filtered_data,
            ds=ds,
            input_dim_restrictions=input_dim_restrictions,
            output_offset=out_offset,
        )
    return filtered_data, new_timestamps

In [5]:
## Set Data Paths
# Local copy of the raw data; this overrides the network raw_data_path set earlier.
# NOTE(review): `animal` is not passed here, so the session is expected directly
# under raw_data_path/<date> — confirm this matches the local layout.
raw_data_path = Path("/home/born-animal/Desktop/data/")
date1 = "2024-07-24_05-57-05"
rec_path = get_recording_path(base_path=raw_data_path, date=date1)
channels = ["37", "39"]
recording = load_rec(recording_path=rec_path, concatenate=True, channels=channels)
valid_times = get_valid_times(recording)

# Down-sample to target_sampling_rate, then band-pass into the spindle band.
target_sampling_rate = 1000  # Hz
spi_band = np.array([10, 16])  # Hz
low_cutoff, high_cutoff = spi_band
order = 4  # passed to spikeinterface's bandpass_filter as filter_order
down_rec = spp.resample(recording, resample_rate=target_sampling_rate)
down_filt_rec = spp.bandpass_filter(
    down_rec,
    freq_min=low_cutoff,
    freq_max=high_cutoff,
    filter_order=order,
)
filt_timestamps = down_filt_rec.get_times()

  warn(


In [6]:
# start = 0 * target_sampling_rate
# dur = 5 * target_sampling_rate
# # plt.plot(down_rec.get_times()[start:start+dur], down_rec.get_traces(channel_ids=["0"])[start:start+dur, 0])
# plt.plot(filt_timestamps[start : start + dur], filt_rec_traces[start : start + dur, 0])
# plt.plot(filt_timestamps[start : start + dur], spi_amp[start : start + dur, 0])

In [7]:
# Directory containing the sleep-scoring .mat files loaded below.
scoring_path = Path("/mnt/born_animal/DanielG/hypothalamic_sleep_scoring/")

In [8]:
# Load the sleep scoring (struct "SlStNew") with mat73, which reads MATLAB v7.3 files.
# NOTE(review): this scoring file is dated 2024-05-21, but the recording loaded
# above is date1 = "2024-07-24_05-57-05" — confirm the hypnogram belongs to
# this recording before trusting the state masks derived from it.
scoring = mat73.loadmat(
    PurePath(scoring_path, "2024-05-21_10-28-00_EEGEMG_25Hz.mat"), use_attrdict=True
)["SlStNew"]
# First column of `codes`: one stage code per scoring epoch, cast to float.
hypno = scoring["codes"][:, 0].astype(float)

In [9]:
# Analysis configuration. Spindle ("spi") durations and distances are in
# seconds; the detection cell multiplies them by Fs to get sample counts.
cfg = {
    "scoring": {
        "name": "Hypothalamus Animal 1",
        "scoring": hypno,  # one stage code per scoring epoch
        "scoring_epoch_length": 10,  # seconds per scoring epoch
        "code_NREM": [2, 4],  # hypnogram codes treated as NREM
        "code_REM": [3],  # hypnogram codes treated as REM
        "code_WAKE": [1],  # hypnogram codes treated as wake
    },
    "spectrum": {
        # Sampling rate (Hz) of the down-sampled, band-passed recording.
        "Fs": down_filt_rec.get_sampling_frequency(),
        "artfctpad": 0,
        "spectrum": 1,
        "spec_freq": [1, 45],
        "invertdata": 0,
        # NOTE(review): slow-oscillation settings are not used by the cells shown here.
        "slo": {
            "slo": 1,
            "slo_dur_min": [0.5, 0.25],
            "slo_dur_max": [2.5, 2.5],
            "slo_thr": 1.5,
            "slo_peak2peak_min": 70,  # rec is in uV
            "slo_freq": [0.1, 4],
            "slo_filt_ord": 3,
            "slo_rel_thr": 33,  # online threshold: 20 | offline analysis: 33
            "slo_dur_max_down": 0.300,  # in s
        },
        "spi": {
            "spi": 1,
            "spi_dur_min": [0.5, 0.25],  # min duration (s) for criteria 1 and 2
            "spi_dur_max": [2.5, 2.5],  # max duration (s) for criteria 1 and 2
            # Multiples of the NREM envelope SD for the three amplitude criteria.
            "spi_thr": [1.5, 2, 2.5],
            # Empty -> per-channel thresholds; a shared threshold is not implemented.
            "spi_thr_chan": [],
            # NOTE(review): the band-pass actually applied uses spi_band / order = 4
            # set in the loading cell, not spi_freq / spi_filt_ord.
            "spi_freq": [10, 16],
            "spi_peakdist_max": 0.125,  # max distance (s) between successive peaks
            "spi_filt_ord": 6,
            "spi_indiv": 0,
        },
        "rip": 1,
    },
}

In [10]:
# Per-epoch 0/1 masks: 1 where the hypnogram code belongs to the state's code list.
NREM_mask = np.isin(cfg["scoring"]["scoring"], cfg["scoring"]["code_NREM"]).astype(int)
REM_mask = np.isin(cfg["scoring"]["scoring"], cfg["scoring"]["code_REM"]).astype(int)
WAKE_mask = np.isin(cfg["scoring"]["scoring"], cfg["scoring"]["code_WAKE"]).astype(int)

In [11]:
# Expand the per-epoch state masks to per-sample masks aligned with the
# down-sampled recording, and record each state's interior transitions:
#   onset  = first sample of a run (rising-edge index + 1)
#   offset = last sample of a run  (falling-edge index, inclusive)
# Runs that touch the first/last sample have no edge and are handled in the next cell.
samples_per_epoch = int(
    cfg["scoring"]["scoring_epoch_length"] * down_filt_rec.get_sampling_frequency()
)
n_rec_samples = down_filt_rec.get_num_samples()
state_dict = {
    state: {"mask": None, "onset": None, "offset": None}
    for state in ("NREM", "REM", "WAKE")
}
for val, epoch_mask in zip(state_dict.values(), [NREM_mask, REM_mask, WAKE_mask]):
    sample_mask = np.repeat(epoch_mask, samples_per_epoch)[:n_rec_samples].astype(int)
    if sample_mask.size < n_rec_samples:
        # Recording outlasts the hypnogram: mark unscored samples as "not in state"
        # so the mask can index full-length traces.
        sample_mask = np.pad(sample_mask, (0, n_rec_samples - sample_mask.size))
    edges = np.diff(sample_mask)
    val["onset"] = np.flatnonzero(edges > 0) + 1
    val["offset"] = np.flatnonzero(edges < 0)
    val["mask"] = sample_mask.astype(bool)

In [12]:
# Close runs that touch the edges of the recording. A state active on the first
# sample has no rising edge, so its onset is sample 0. A state active on the
# last sample has no falling edge; offsets are inclusive *sample* indices, so
# the closing offset is mask.size - 1, not the number of scoring epochs.
for state in state_dict.values():
    state_mask = np.asarray(state["mask"], dtype=bool)
    if state_mask.size == 0:
        continue
    if state_mask[0]:
        state["onset"] = np.concatenate(([0], state["onset"]))
    if state_mask[-1]:
        state["offset"] = np.concatenate((state["offset"], [state_mask.size - 1]))

In [13]:
# Convert each state's onset/offset sample indices (down-sampled recording) into
# seconds from the start of the recording. The original multiplied sample
# indices by the epoch length as if they were epoch indices. Offsets are
# inclusive, so the end time uses offset + 1.
fs_states = cfg["spectrum"]["Fs"]
for key, val in state_dict.items():
    val["times"] = np.array(
        (
            val["onset"] / fs_states,
            (val["offset"] + 1) / fs_states,
        )
    )

In [14]:
# Spindle detection on the smoothed Hilbert envelope of the spindle-band trace,
# restricted to NREM periods. Candidates must pass, in order:
#   1. envelope > thr[0] for min_dur_1..max_dur_1 samples,
#   2. within it, envelope > thr[1] for min_dur_2..max_dur_2 samples,
#   3. envelope > thr[2] somewhere in the merged criterion-2 span,
#   4. at least one peak, all successive peaks closer than spi_peakdist_max.
fs = cfg["spectrum"]["Fs"]
spi_cfg = cfg["spectrum"]["spi"]
min_dur_1 = spi_cfg["spi_dur_min"][0] * fs
max_dur_1 = spi_cfg["spi_dur_max"][0] * fs

min_dur_2 = spi_cfg["spi_dur_min"][1] * fs
max_dur_2 = spi_cfg["spi_dur_max"][1] * fs
buff = 5 * fs  # candidates ending within 5 s of the recording end are dropped
peakdist_max = spi_cfg["spi_peakdist_max"] * fs
nrem_mask = state_dict["NREM"]["mask"]
nrem_periods = list(zip(state_dict["NREM"]["onset"], state_dict["NREM"]["offset"]))
if spi_cfg["spi_thr_chan"]:
    # Previously this left every threshold at 0 and accepted any candidate.
    raise NotImplementedError(
        "A shared threshold across channels (spi_thr_chan) is not implemented."
    )

ch_dict = {
    ch: {"spindles": [], "rejects": 0, "trace": None, "spi_amp_smooth": None}
    for ch in channels
}
for channel in channels:
    print("Processing channel", channel)
    trace = down_filt_rec.get_traces(
        channel_ids=[channel], return_scaled=True
    ).flatten()
    spi_amp = np.abs(signal.hilbert(trace, axis=0))
    assert nrem_mask.size == spi_amp.size, "NREM mask and trace lengths differ"
    # NOTE(review): the mean is computed but thresholds use SD only — confirm
    # whether thresholds should be mean + k * SD.
    spi_amp_mean = spi_amp[nrem_mask].mean()
    spi_amp_std = spi_amp[nrem_mask].std()
    # 100 ms moving average of the envelope (zero-padded at the edges).
    spi_amp_smooth = uniform_filter1d(
        spi_amp,
        int(0.1 * fs),
        axis=0,
        mode="constant",
        cval=0,
    )
    ch_dict[channel]["trace"] = trace
    ch_dict[channel]["spi_amp_smooth"] = spi_amp_smooth
    # Channel-specific thresholds: multiples of the NREM envelope SD.
    thr = np.asarray(spi_cfg["spi_thr"], dtype=float) * spi_amp_std

    for onset, offset in nrem_periods:
        # Criterion 1 (slicing replaces a full-length boolean mask per period).
        above_thr = np.flatnonzero(spi_amp_smooth[onset : offset + 1] > thr[0])
        if above_thr.size == 0:
            continue
        good_spans_1 = [
            span
            for span in get_span_start_stop(above_thr, onset)
            if min_dur_1 < span[1] - span[0] < max_dur_1
            and span[0] != onset  # skip spans already running at the period start
            and (spi_amp_smooth.size - span[1]) > buff
        ]
        for spans in good_spans_1:
            # Criterion 2: indices are relative to spans[0], so offset by it
            # (the original offset by the NREM onset, giving wrong positions).
            above_thr_2 = np.flatnonzero(
                spi_amp_smooth[spans[0] : spans[1] + 1] > thr[1]
            )
            if above_thr_2.size == 0:
                ch_dict[channel]["rejects"] += 1
                continue
            good_spans_2 = [
                span
                for span in get_span_start_stop(above_thr_2, offset=spans[0])
                if min_dur_2 < span[1] - span[0] < max_dur_2
            ]
            if not good_spans_2:
                ch_dict[channel]["rejects"] += 1
                continue
            # Merge all qualifying sub-spans: first start to last stop.
            good_span_2 = (good_spans_2[0][0], good_spans_2[-1][1])
            # Criterion 3.
            if not np.any(spi_amp_smooth[good_span_2[0] : good_span_2[1]] > thr[2]):
                ch_dict[channel]["rejects"] += 1
                continue
            # Criterion 4: peaks (relative to spans[0]) must be closely spaced.
            peaks = signal.find_peaks(trace[spans[0] : spans[1]], prominence=thr[0])[0]
            if peaks.size == 0 or np.any(np.diff(peaks) >= peakdist_max):
                ch_dict[channel]["rejects"] += 1
                continue
            # Peak closest to the middle of the span, relative to spans[0].
            central_peak_ind = peaks[
                np.argmin(np.abs(peaks - ((spans[1] - spans[0]) // 2)))
            ]
            ch_dict[channel]["spindles"].append(spans + (central_peak_ind, thr))

Processing channel 37
Processing channel 39


In [16]:
# Save one figure per detected spindle, centred on its central peak.
colors = mb.met_brew(name="Thomas", n=8)
plot_dir = Path("/home/born-animal/Desktop/plots/")
plot_dir.mkdir(parents=True, exist_ok=True)
plot_kwargs = {"raw": True}
plot_raw = bool(plot_kwargs.get("raw", False))
fs = cfg["spectrum"]["Fs"]
win = int(1.25 * fs)  # half-window in samples (1.25 s)
n_plot_samples = len(filt_timestamps)
for ch, vals in ch_dict.items():
    raw_trace = (
        down_rec.get_traces(channel_ids=[ch], return_scaled=True).flatten()
        if plot_raw
        else None
    )
    trace = vals["trace"]
    spi_amp_smooth = vals["spi_amp_smooth"]
    for ind, (spi_start, spi_stop, central_peak_ind, thrs) in enumerate(
        vals["spindles"]
    ):
        center = spi_start + central_peak_ind
        # Clip to the recording: negative indices would silently wrap around.
        inds_to_plot = np.arange(
            max(center - win, 0), min(center + win + 1, n_plot_samples)
        )
        zero_point = filt_timestamps[center]
        times = filt_timestamps[inds_to_plot] - zero_point
        fig, ax = plt.subplots(figsize=(9, 6))
        if plot_raw:
            ax.plot(times, raw_trace[inds_to_plot], label="Raw Trace", c=colors[-1])
        ax.plot(
            times, trace[inds_to_plot], lw=2, label="Filtered Trace", c=colors[0]
        )
        ax.plot(
            times, spi_amp_smooth[inds_to_plot], label="Envelope", lw=2, c=colors[4]
        )
        ax.vlines(
            x=np.asarray([filt_timestamps[spi_start], filt_timestamps[spi_stop]])
            - zero_point,
            ymin=-100,
            ymax=100,
            colors=["g", "r"],
            ls="--",
            lw=2,
            alpha=0.6,
            label="On/Offset",
        )
        ax.hlines(
            y=thrs,
            xmin=filt_timestamps[spi_start] - zero_point,
            xmax=filt_timestamps[spi_stop] - zero_point,
            color="k",
            label="Thresholds",
            ls="--",
        )
        duration = filt_timestamps[spi_stop] - filt_timestamps[spi_start]
        ax.set_title(f"Channel: {ch} - Duration: {duration:.2f}s, ID: {ind:03d}")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(f"Amplitude (uV) [{low_cutoff}-{high_cutoff} Hz]")
        ax.legend(loc="upper right")
        ax.set_ylim([-250, 250])
        # Convert the half-window from samples to seconds using the actual rate.
        ax.set_xlim([-win / fs, win / fs])
        fig.tight_layout()
        fig.savefig(
            Path(plot_dir, f"spindle_{ch}_{ind:03d}.png"),
            dpi=450,
            facecolor="w",
            transparent=False,
        )
        plt.close(fig)