# Load the datasets

The paths to the dataset directories, and the patterns used to find the HPC and PFC recordings in those directories, are loaded from the config file.

In [1]:
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences

import numpy as np
import pandas as pd
import pynapple as nap
import yasa

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

# Native sampling rates (Hz) of the raw recordings in each dataset.
fs_cbd, fs_os, fs_rgs = 2500, 2500, 1000

# Common rate (Hz) every recording is resampled to; the n_down_* values are
# the per-dataset ratios (native rate / targetFs) passed to `preprocess`.
targetFs = 500
n_down_cbd, n_down_rgs, n_down_os = (rate / targetFs for rate in (fs_cbd, fs_rgs, fs_os))

logger = logger_setup()

# Path of the YAML config describing dataset directories and file patterns.
CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"

# mapped_datasets: recording name -> (states, HPC, PFC) file names
# (see how it is unpacked in the dataset loop).
Datasets = DatasetLoader(CONFIG_DIR)
mapped_datasets = Datasets.load_datasets()

def preprocess(signal: np.ndarray, n_down: float, target_fs=500) -> np.ndarray:
    """Downsample a raw LFP trace and zero out artifact windows.

    Args:
        signal: raw 1-D LFP trace.
        n_down: downsampling factor (native rate / target rate). The
            module-level ``n_down_*`` ratios are floats, hence the annotation.
        target_fs: sampling rate (Hz) of the downsampled trace. It is only
            used to tell the artifact detector the rate of the resampled data,
            so it must agree with ``n_down`` (native rate / n_down == target_fs).

    Returns:
        The downsampled trace with samples inside artifact windows set to 0,
        then mean-centred.
    """
    logger.debug("STARTED: Resampling to {0} Hz.".format(target_fs))
    data = resample(signal, down=n_down, method='fft', npad='auto')
    logger.debug("FINISHED: Resampling to {0} Hz.".format(target_fs))
    logger.debug("Resampled: {0} -> {1}.".format(str(signal.shape), str(data.shape)))

    logger.debug("STARTED: Remove artifacts.")
    # Flag windows whose standard deviation is an outlier (yasa 'std' method).
    window = 1  # seconds
    art_std, _ = yasa.art_detect(data, target_fs, window=window, method='std', threshold=4)
    # The artifact vector has one entry per window, i.e. a rate of 1/window Hz;
    # expand it to one flag per sample of `data`.
    art_up = yasa.hypno_upsample_to_data(art_std, 1 / window, data, target_fs)
    # Cast explicitly: indexing with an integer 0/1 array would zero samples
    # 0 and 1 instead of applying a mask.
    data[np.asarray(art_up, dtype=bool)] = 0
    logger.debug("FINISHED: Remove artifacts.")

    # NOTE(review): the mean includes the zeroed artifact samples.
    data -= data.mean()
    return data

def get_start_end(hypno: np.ndarray, sleep_state_id: int):
    """Return ``(starts, ends)`` lists for each contiguous run of a sleep state.

    Both lists hold indices into ``hypno``. Whether ``ends`` is inclusive is
    decided by ``get_sequences`` (presumably inclusive — confirm against
    ``phasic_tonic.utils``).
    """
    state_indices = np.nonzero(hypno == sleep_state_id)[0]
    # Materialise once in case get_sequences returns a one-shot iterator.
    runs = list(get_sequences(state_indices))
    starts = [first for first, _ in runs]
    ends = [last for _, last in runs]
    return (starts, ends)

Check the number of loaded recordings

In [2]:
# Tally recordings per dataset from the metadata 'treatment' code:
# 0/1 -> CBD, 2/3 -> RGS, 4 -> OS. Other codes are not counted.
cbd_cnt = rgs_cnt = os_cnt = 0

for name in mapped_datasets:
    metadata = get_metadata(name)
    if metadata['treatment'] in (0, 1):
        cbd_cnt += 1
    elif metadata['treatment'] in (2, 3):
        rgs_cnt += 1
    elif metadata['treatment'] == 4:
        os_cnt += 1

# Expected number of recordings in each dataset.
assert cbd_cnt == 170
assert rgs_cnt == 159
assert os_cnt == 210

# Loop through the dataset

In [3]:
# Downsampling factor for each metadata 'treatment' code
# (0/1 -> CBD, 2/3 -> RGS, 4 -> OS, as in the counting cell).
n_down_by_treatment = {
    0: n_down_cbd, 1: n_down_cbd,
    2: n_down_rgs, 3: n_down_rgs,
    4: n_down_os,
}

with tqdm(mapped_datasets) as t:
    for name in t:
        metadata = get_metadata(name)
        t.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        n_down = n_down_by_treatment.get(metadata["treatment"])
        if n_down is None:
            # Without this guard n_down would silently keep the previous
            # recording's value (or be undefined on the first iteration).
            logger.warning("Unknown treatment {0} for {1}. Skipping.".format(metadata["treatment"], name))
            continue

        # Load the states first so recordings without REM are skipped
        # before the LFP files are read.
        hypno = loadmat(states_fname)['states'].flatten()

        # Skip if no REM epoch (state 5) is detected
        if not np.any(hypno == 5):
            logger.debug("No REM detected. Skipping.")
            continue

        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()
        lfpPFC = loadmat(pfc_fname)['PFC'].flatten()

        lfpHPC_down = preprocess(lfpHPC, n_down)

        # Stop at the first recording that contains REM; the following cells
        # use the variables left behind (hypno, lfpHPC, lfpPFC, lfpHPC_down, n_down).
        break


  0%|          | 0/539 [00:00<?, ?it/s]

In [25]:
# REM (state 5) epochs. start/end are hypnogram indices; using them directly
# as seconds assumes one hypnogram entry per second — confirm.
start, end = get_start_end(hypno=hypno, sleep_state_id=5)
rem_interval = nap.IntervalSet(start=start, end=end)

# Native sampling rate of the raw recording (n_down = native rate / targetFs).
fs = n_down*targetFs
# Derive timestamps from sample indices: np.arange with a float step can return
# one element more or fewer than the data, which would not match len(d).
t = np.arange(len(lfpHPC)) / fs

lfp = nap.TsdFrame(t=t, d=np.vstack([lfpHPC, lfpPFC]).T, columns=['HPC', 'PFC'])
lfpHPC_down = nap.Tsd(t=np.arange(len(lfpHPC_down)) / targetFs, d=lfpHPC_down)
lfpHPC_down

Time (s)
----------  ---------
0.0           16.2492
0.002        119.317
0.004         99.0067
0.006         58.1087
0.008        -32.143
0.01        -126.713
0.012       -131.46
...
2701.128    -491.824
2701.13     -473.586
2701.132    -473.871
2701.134    -434.081
2701.136    -351.645
2701.138    -271.414
2701.14     -220.062
dtype: float64, shape: (1350571,)

In [33]:
# Drop REM epochs shorter than 3 s and display the result (not assigned).
rem_interval.drop_short_intervals(threshold=3, time_units="s")

            start    end
       0     1331   1485
       1     2428   2475
shape: (2, 2), time unit: sec.

In [None]:
from scipy.signal import hilbert
from neurodsp.filt import filter_signal

def detect_phasic(eeg: nap.Tsd, rem_interval: nap.IntervalSet, pplot = False, nfilt=11, thr_dur=900, min_dur=3):
    """
    Detect phasic-like states in REM periods.

    Candidates are runs of theta (5-12 Hz) troughs whose smoothed inter-trough
    interval is at or below the 10th percentile of all REM inter-trough
    intervals. A candidate is kept when it
      1. lasts at least ``thr_dur`` ms,
      2. reaches a smoothed inter-trough interval at or below the 5th percentile,
      3. has a mean theta amplitude at least the mean amplitude of all REM.

    Args:
    -----
    eeg: Pynapple time series (Tsd) of the signal.
    rem_interval: Pynapple IntervalSet of REM epochs (s).
    pplot: plot the phasic/tonic states (not implemented yet; ignored).
    nfilt: length of the moving-average filter for inter-trough intervals.
    thr_dur: minimum duration of a phasic segment (ms).
    min_dur: minimum duration for REM epochs (s); shorter epochs are ignored.

    Returns:
    --------
    phrem_interval: Pynapple IntervalSet of phasic segments (s), spanning from
        the first to the last trough of each accepted run. Empty when no REM
        epoch survives ``min_dur``.

    Reference:
    *TO DO (Buzsaki's paper)

    Example:
    --------
    *TO DO
    """
    # Ensure minimum duration
    rem_interval = rem_interval.drop_short_intervals(threshold=min_dur, time_units="s")

    # Get REM epochs as a list of Tsd, ignoring epochs not covered by the signal
    rem_epochs = [eeg.restrict(rem_interval[i]) for i in range(len(rem_interval))]
    rem_epochs = [epoch for epoch in rem_epochs if len(epoch) > 0]
    if not rem_epochs:
        return nap.IntervalSet(start=[], end=[])

    thresholds, trough_indices, smooth_diffs, amps = compute_thresholds(rem_epochs, nfilt)
    thr1, thr2, thr3 = thresholds

    ph_starts, ph_ends = [], []
    for epoch, tridx, sdiff, amp in zip(rem_epochs, trough_indices, smooth_diffs, amps):
        if len(tridx) < 2:
            continue  # no inter-trough interval to evaluate

        fs = round(epoch.rate)
        times = np.asarray(epoch.index)

        # 'same' convolution returns len(filter) values when there are fewer
        # intervals than that; keep only positions that map onto troughs.
        sdiff = sdiff[:len(tridx) - 1]

        cand_idx = np.where(sdiff <= thr1)[0]
        if cand_idx.size == 0:
            continue

        # get_sequences is assumed to yield inclusive (start, end) runs of
        # consecutive indices — confirm against phasic_tonic.utils.
        for start, end in get_sequences(cand_idx):
            # Duration of the candidate in milliseconds
            dur = ((tridx[end] - tridx[start] + 1) / fs) * 1000
            if dur < thr_dur:
                continue  # too short
            if np.min(sdiff[start:end + 1]) > thr2:
                continue  # failed threshold 2 (never fast enough)
            if np.mean(amp[tridx[start]:tridx[end] + 1]) < thr3:
                continue  # failed threshold 3 (amplitude too low)
            ph_starts.append(times[tridx[start]])
            ph_ends.append(times[tridx[end]])

    return nap.IntervalSet(start=ph_starts, end=ph_ends)


def compute_thresholds(rem_epochs, nfilt):
    """
    Compute phasic-REM thresholds from a list of REM epochs (Tsd objects).

    Returns:
        thresholds: (thr1, thr2, thr3) — 10th and 5th percentiles of the
            smoothed inter-trough intervals (samples) pooled over all epochs,
            and the mean theta amplitude over all epochs.
        trough_indices: per-epoch arrays of theta-trough sample indices.
        smooth_trough_diffs: per-epoch smoothed inter-trough intervals.
        amplitudes: per-epoch theta (5-12 Hz) amplitude envelopes.

    Raises:
        ValueError: if no epoch contains at least two troughs.
    """
    # Filter to smoothen the intertrough intervals
    filt = np.ones(nfilt)/nfilt

    amplitudes = []
    trough_indices = []
    trough_diffs = []
    smooth_trough_diffs = []

    for rem_epoch in rem_epochs:
        # round() rather than int(): the rate estimate can fall just below the true rate
        amp, phase = _compute_hilbert(sig=rem_epoch.to_numpy(), fs=round(rem_epoch.rate), pass_type="bandpass", f_range=(5, 12), remove_edges=False)

        # trough indices: local minima of the phase below -3 rad
        trough_idx = _detect_troughs(phase, -3)

        # differences between troughs
        trough_diff = np.diff(trough_idx)

        amplitudes.append(amp)
        trough_indices.append(trough_idx)
        trough_diffs.append(trough_diff)
        # np.convolve rejects empty input, so keep an empty array for such epochs
        if trough_diff.size:
            smooth_trough_diffs.append(np.convolve(trough_diff, filt, 'same'))
        else:
            smooth_trough_diffs.append(trough_diff.astype(float))

    all_diffs = np.concatenate(trough_diffs) if trough_diffs else np.array([])
    if all_diffs.size == 0:
        raise ValueError("No theta troughs found in the REM epochs; cannot compute thresholds.")

    # Smooth the pooled intervals (across epoch boundaries, as _compute_thresholds does)
    pooled = np.convolve(all_diffs, filt, 'same')
    thr1 = np.percentile(pooled, 10)
    thr2 = np.percentile(pooled, 5)
    thr3 = np.concatenate(amplitudes).mean()

    return (thr1, thr2, thr3), trough_indices, smooth_trough_diffs, amplitudes
    

def _compute_hilbert(sig, fs, pass_type, f_range, remove_edges=False):
    """Filter ``sig`` and return the amplitude envelope and phase of its
    analytic signal (via the Hilbert transform)."""
    filtered = filter_signal(sig, fs, pass_type=pass_type, f_range=f_range, remove_edges=remove_edges)
    analytic = hilbert(filtered)
    envelope = np.abs(analytic)
    phase = np.angle(analytic)
    return envelope, phase

def _compute_thresholds(rem, fs, smooth_filt):
    """
    Computes the thresholds for detecting phasic REM states.

    Args:
        rem: mapping of REM-epoch key -> 1-D signal of that epoch.
        fs: sampling rate (Hz) of the signals.
        smooth_filt: kernel used to smooth the inter-trough intervals.

    Returns:
        thr1, thr2: 10th / 5th percentile of the smoothed inter-trough
            intervals (samples), pooled across all epochs.
        thr3: mean theta (5-12 Hz) amplitude over all epochs.
        tridx_seq: dict key -> trough indices of that epoch.
        eeg_seq: dict key -> theta amplitude envelope of that epoch.

    Raises:
        ValueError: if ``rem`` is empty or contains no inter-trough interval.
    """
    amplitudes = []
    trough_diffs = []
    eeg_seq = {}
    tridx_seq = {}

    for idx in rem:
        amp, phase = _compute_hilbert(rem[idx], fs, pass_type="bandpass", f_range=(5, 12), remove_edges=False)

        # trough indices
        tridx = _detect_troughs(phase, -3)

        # dict of trough indices / amplitudes for each REM period
        tridx_seq[idx] = tridx
        eeg_seq[idx] = amp

        amplitudes.append(amp)
        trough_diffs.append(np.diff(tridx))

    if not amplitudes:
        raise ValueError("`rem` is empty; cannot compute thresholds.")

    # Concatenate once instead of growing a Python list element by element.
    rem_eeg = np.concatenate(amplitudes)
    trdiff = np.concatenate(trough_diffs)
    # np.convolve raises ValueError if no epoch had two troughs
    trdiff_sm = np.convolve(trdiff, smooth_filt, 'same')

    # potential candidates for phasic REM:
    # the smoothed difference between troughs is less than
    # the 10th percentile:
    thr1 = np.percentile(trdiff_sm, 10)
    # the minimum smoothed difference in the candidate phREM is less than
    # the 5th percentile
    thr2 = np.percentile(trdiff_sm, 5)
    # the peak amplitude is larger than the mean of the amplitude
    # of the REM EEG.
    thr3 = rem_eeg.mean()

    return thr1, thr2, thr3, tridx_seq, eeg_seq

def _detect_troughs(signal, thr):
    lidx  = np.where(signal[0:-2] > signal[1:-1])[0]
    ridx  = np.where(signal[1:-1] <= signal[2:])[0]
    thidx = np.where(signal[1:-1] < thr)[0]
    sidx = np.intersect1d(lidx, np.intersect1d(ridx, thidx))+1
    return sidx


In [105]:
# Hand-picked interval (s) for inspection; it lies inside the first REM epoch
# listed in the drop_short_intervals output above (1331-1485 s).
phrem = nap.IntervalSet(start=[1331], end=[1400])
phrem

            start    end
       0     1331   1400
shape: (1, 2), time unit: sec.

In [None]:
from .utils import get_sequences

import numpy as np
from neurodsp.filt import filter_signal
from scipy.signal import hilbert

"""
detect_phasic
    Compute thresholds
        * Plotting (takes thresholds as argument)
        * apply thresholds
"""

def detect_phasic(rem, fs, nfilt=11, thr_dur=900):
    """
    Detect phasic REM segments in a set of REM epochs.

    Args:
        rem: mapping of REM epoch -> 1-D signal. Keys are unpacked as
            ``(rem_start, rem_end)`` and multiplied by ``fs``, so presumably
            they are epoch bounds in seconds — confirm against the caller.
        fs: sampling rate (Hz) of the signals.
        nfilt: length of the moving-average filter for inter-trough intervals.
        thr_dur: minimum duration (ms) of a phasic segment.

    Returns:
        dict mapping every key of ``rem`` to a list of ``(start, end)`` sample
        indices of phasic segments, offset by ``rem_start * fs``; ``end`` is
        exclusive and clipped to the end of the REM epoch.
    """
    # Moving-average kernel used to smooth the inter-trough intervals
    filt = np.ones(nfilt)/nfilt

    thr1, thr2, thr3, iti, amps = _compute_thresholds(rem, fs, smooth_filt=filt)

    phREM = {rem_idx: [] for rem_idx in rem.keys()}

    for rem_idx in iti:
        rem_start, rem_end = rem_idx
        offset = rem_start * fs

        # trough indices (relative to the epoch)
        tridx = iti[rem_idx]
        if len(tridx) < 2:
            continue  # no inter-trough interval to evaluate

        # Smoothed inter-trough intervals (same kernel and mode that
        # _compute_thresholds uses). 'same' mode returns len(filt) values when
        # there are fewer intervals than that, so trim to the valid length.
        sdiff = np.convolve(np.diff(tridx), filt, 'same')[:len(tridx) - 1]

        # amplitude of the REM epoch
        eegh = amps[rem_idx]

        # get the candidates for phREM
        cand_idx = np.where(sdiff <= thr1)[0]
        if cand_idx.size == 0:
            continue

        for start, end in get_sequences(cand_idx):
            # Duration of the candidate in milliseconds
            dur = ((tridx[end]-tridx[start]+1)/fs) * 1000
            if dur < thr_dur:
                continue  # too short

            # The candidate run includes index ``end``
            min_sdiff = np.min(sdiff[start:end+1])
            if min_sdiff > thr2:
                continue  # failed threshold 2

            mean_amp = np.mean(eegh[tridx[start]:tridx[end]+1])
            if mean_amp < thr3:
                continue  # failed threshold 3

            t_a = tridx[start] + offset
            t_b = np.min((tridx[end] + offset, rem_end * fs))

            ph_idx = (t_a, t_b+1)
            phREM[rem_idx].append(ph_idx)

    return phREM

def _compute_hilbert(sig, fs, pass_type, f_range, remove_edges=False):
    """
    Filter ``sig`` and take its Hilbert transform.

    Returns the instantaneous amplitude (envelope) and phase of the
    filtered signal.
    """
    filtered = filter_signal(sig, fs, pass_type=pass_type, f_range=f_range, remove_edges=remove_edges)
    analytic = hilbert(filtered)
    envelope = np.abs(analytic)
    phase = np.angle(analytic)
    return envelope, phase


def _compute_thresholds(rem, fs, smooth_filt):
    """
    Computes the thresholds for detecting phasic REM states.

    Args:
        rem: mapping of REM-epoch key -> 1-D signal of that epoch.
        fs: sampling rate (Hz) of the signals.
        smooth_filt: kernel used to smooth the inter-trough intervals.

    Returns:
        thr1, thr2: 10th / 5th percentile of the smoothed inter-trough
            intervals (samples), pooled across all epochs.
        thr3: mean theta (5-12 Hz) amplitude over all epochs.
        tridx_seq: dict key -> trough indices of that epoch.
        eeg_seq: dict key -> theta amplitude envelope of that epoch.

    Raises:
        ValueError: if ``rem`` is empty or contains no inter-trough interval.
    """
    amplitudes = []
    trough_diffs = []
    eeg_seq = {}
    tridx_seq = {}

    for idx in rem:
        amp, phase = _compute_hilbert(rem[idx], fs, pass_type="bandpass", f_range=(5, 12), remove_edges=False)

        # trough indices
        tridx = _detect_troughs(phase, -3)

        # dict of trough indices / amplitudes for each REM period
        tridx_seq[idx] = tridx
        eeg_seq[idx] = amp

        amplitudes.append(amp)
        trough_diffs.append(np.diff(tridx))

    if not amplitudes:
        raise ValueError("`rem` is empty; cannot compute thresholds.")

    # Concatenate once instead of growing a Python list element by element.
    rem_eeg = np.concatenate(amplitudes)
    trdiff = np.concatenate(trough_diffs)
    # np.convolve raises ValueError if no epoch had two troughs
    trdiff_sm = np.convolve(trdiff, smooth_filt, 'same')

    # potential candidates for phasic REM:
    # the smoothed difference between troughs is less than
    # the 10th percentile:
    thr1 = np.percentile(trdiff_sm, 10)
    # the minimum smoothed difference in the candidate phREM is less than
    # the 5th percentile
    thr2 = np.percentile(trdiff_sm, 5)
    # the peak amplitude is larger than the mean of the amplitude
    # of the REM EEG.
    thr3 = rem_eeg.mean()

    return thr1, thr2, thr3, tridx_seq, eeg_seq

def _detect_troughs(signal, thr):
    lidx  = np.where(signal[0:-2] > signal[1:-1])[0]
    ridx  = np.where(signal[1:-1] <= signal[2:])[0]
    thidx = np.where(signal[1:-1] < thr)[0]
    sidx = np.intersect1d(lidx, np.intersect1d(ridx, thidx))+1
    return sidx