# Load the datasets

The path to dataset directory and patterns to search in those directories for the HPC, PFC recordings are in loaded from the config file.

In [1]:
from phasic_tonic.detect_phasic import detect_phasic
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences

import numpy as np
import pandas as pd
import pynapple as nap
import yasa

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

fs_cbd = 2500
fs_os = 2500
fs_rgs = 1000

targetFs = 500
n_down_cbd = fs_cbd/targetFs
n_down_rgs = fs_rgs/targetFs
n_down_os = fs_os/targetFs

logger = logger_setup()

CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"

Datasets = DatasetLoader(CONFIG_DIR)
mapped_datasets = Datasets.load_datasets()

def preprocess(signal: np.ndarray, n_down: int, target_fs=500) -> np.ndarray:
    """Downsample and remove artifacts."""
    
    logger.debug("STARTED: Resampling to 500 Hz.")
    # Downsample to 500 Hz
    data = resample(signal, down=n_down, method='fft', npad='auto')
    logger.debug("FINISHED: Resampling to 500 Hz.")
    logger.debug("Resampled: {0} -> {1}.".format(str(signal.shape), str(data.shape)))
    
    logger.debug("STARTED: Remove artifacts.")
    # Remove artifacts
    art_std, _ = yasa.art_detect(data, target_fs , window=1, method='std', threshold=4)
    art_up = yasa.hypno_upsample_to_data(art_std, 1, data, target_fs)
    data[art_up] = 0
    logger.debug("FINISHED: Remove artifacts.")
        
    data -= data.mean()
    return data

def get_start_end(hypno: np.ndarray, sleep_state_id: int):
    """Convert sleep states into lists of start and end time indices."""
    seq = get_sequences(np.where(hypno == sleep_state_id)[0])
    start = []
    end = []
    for s, e in seq:
        start.append(s)
        end.append(e)
    return (start, end)

Check the number of loaded recordings

In [2]:
cbd_cnt = 0
rgs_cnt = 0
os_cnt = 0

# Count recordings belonging to CBD dataset
for name in mapped_datasets:
    metadata = get_metadata(name)
    if metadata['treatment'] == 0 or metadata['treatment'] == 1:
        cbd_cnt += 1
    elif metadata['treatment'] == 2 or metadata['treatment'] == 3:
        rgs_cnt += 1
    elif metadata['treatment'] == 4:
        os_cnt += 1

assert cbd_cnt == 170
assert rgs_cnt == 159
assert os_cnt == 210

# Waveform analysis

In [3]:
import emd
from neurodsp.filt import filter_signal

def compute_range(x):
    return x.max() - x.min()

def asc2desc(x):
    """Ascending to Descending ratio ( A / A+D )."""
    pt = emd.cycles.cf_peak_sample(x, interp=True)
    tt = emd.cycles.cf_trough_sample(x, interp=True)
    if (pt is None) or (tt is None):
        return np.nan
    asc = pt + (len(x) - tt)
    #desc = tt - pt
    return asc / len(x)

def peak2trough(x):
    """Peak to trough ratio ( P / P+T )."""
    des = emd.cycles.cf_descending_zero_sample(x, interp=True)
    if des is None:
        return np.nan
    return des / len(x)

def compute_cycles(signal, fs, metadata):
    signal = filter_signal(signal, fs, 'bandpass', (5,12), remove_edges=False)
    # Perform EMD and compute cycle metrics
    IP, IF, IA = emd.spectra.frequency_transform(signal, fs, 'hilbert', smooth_phase=3)
    C = emd.cycles.Cycles(IP.flatten())
    print("Detected cycles before extraction:")
    print(C)

    # Compute cycle metrics
    C.compute_cycle_metric('start_sample', np.arange(len(C.cycle_vect)), emd.cycles.cf_start_value)
    C.compute_cycle_metric('stop_sample', signal, emd.cycles.cf_end_value)
    C.compute_cycle_metric('peak_sample', signal, emd.cycles.cf_peak_sample)
    C.compute_cycle_metric('desc_sample', signal, emd.cycles.cf_descending_zero_sample)
    C.compute_cycle_metric('trough_sample', signal, emd.cycles.cf_trough_sample)
    C.compute_cycle_metric('duration_samples', signal, len)
    C.compute_cycle_metric('max_amp', IA, np.max)
    C.compute_cycle_metric('mean_if', IF, np.mean)
    C.compute_cycle_metric('max_if', IF, np.max)
    C.compute_cycle_metric('range_if', IF, compute_range)  # Make sure 'compute_range' is defined
    C.compute_cycle_metric('asc2desc', signal, asc2desc)  # Make sure 'asc2desc' is defined
    C.compute_cycle_metric('peak2trough', signal, peak2trough)  # Make sure 'peak2trough' is defined

    print('\nFinished computing the cycles metrics\n')

    # Extract a subset of the cycles
    amp_thresh = np.percentile(IA, 25)
    lo_freq_duration = fs / 5
    hi_freq_duration = fs / 12
    conditions = ['is_good==1',
                  f'duration_samples<{lo_freq_duration}',
                  f'duration_samples>{hi_freq_duration}',
                  f'max_amp>{amp_thresh}']

    print("Cycles after extraction:")
    df_emd = C.get_metric_dataframe(conditions=conditions)
    print(f'{len(df_emd)}')

    #Add the metadata
    df_emd["rat"]       = metadata["rat_id"]
    df_emd["study_day"] = metadata["study_day"]
    df_emd["condition"] = metadata["condition"]
    df_emd["treatment"] = metadata["treatment"]
    df_emd["trial_num"] = metadata["trial_num"]
    df_emd["state"]     = metadata["state"]
    
    start, end = metadata["interval"]
    df_emd["start"] = start
    df_emd["end"] = end

    return df_emd

# Loop through the dataset

In [4]:
combined = []

with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        
        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()
        lfpPFC = loadmat(pfc_fname)['PFC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()
        
        # Skip if no REM epoch is detected
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue

        # Create Pynapple IntervalSet        
        start, end = get_start_end(hypno=hypno, sleep_state_id=5)
        rem_interval = nap.IntervalSet(start=start, end=end)
        rem_interval = rem_interval.drop_short_intervals(threshold=10, time_units='s')
        
        # Create TsdFrame for HPC and PFC signals
        fs = n_down*targetFs
        t = np.arange(0, len(lfpHPC)/fs, 1/fs)
        lfp = nap.TsdFrame(t=t, d=np.vstack([lfpHPC, lfpPFC]).T, columns=['HPC', 'PFC'])
        
        # Detect phasic intervals
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phREM = detect_phasic(lfpHPC_down, hypno, targetFs)

        # Create phasic REM IntervalSet
        start, end = [], []
        for rem_idx in phREM:
            for s, e in phREM[rem_idx]:
                start.append(s/targetFs)
                end.append(e/targetFs)
        phasic_interval = nap.IntervalSet(start, end)
        tonic_interval = rem_interval.set_diff(phasic_interval)

        phasic_interval = phasic_interval.drop_short_intervals(0.6)
        tonic_interval = tonic_interval.drop_short_intervals(0.6)
        
        #Compute waveform dynamics for each intervals
        metadata['state'] = 'phasic'
        for i in range(len(phasic_interval)):
            metadata['interval'] = (phasic_interval[i]['start'].item(), phasic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp['HPC'].restrict(phasic_interval[i]).to_numpy(), lfp['HPC'].rate, metadata)
            combined.append(df_emd)
            
        metadata['state'] = 'tonic'
        for i in range(len(tonic_interval)):
            metadata['interval'] = (tonic_interval[i]['start'].item(), tonic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp['HPC'].restrict(tonic_interval[i]).to_numpy(), lfp['HPC'].rate, metadata)
            combined.append(df_emd)


  0%|          | 0/539 [00:00<?, ?it/s]

Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (17 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
14
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (14 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
12
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (27 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
23
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (30 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
24
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (366 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
309
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (239 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
210
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (341 cycl

  rem_interval = nap.IntervalSet(start=start, end=end)


Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (16 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
14
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (10 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
7
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (16 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
12
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (45 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
33
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (14 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
11
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (10 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
8
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (413 cycles 1 m

  rem_interval = nap.IntervalSet(start=start, end=end)
  phasic_interval = nap.IntervalSet(start, end)
  phasic_interval = nap.IntervalSet(start, end)


Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (11 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
8
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (11 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
9
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (12 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
9
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (11 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
9
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (9 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
6
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (41 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
32
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (11 cycles 1 metric

  rem_interval = nap.IntervalSet(start=start, end=end)


Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (35 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
30
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (589 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
526
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (53 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
45
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (16 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
11
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (12 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
9
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (19 cycles 1 metrics) 

Finished computing the cycles metrics

Cycles after extraction:
14
Detected cycles before extraction:
<class 'emd.cycles.Cycles'> (23 cycles 1

KeyboardInterrupt: 

In [5]:
df = pd.concat(combined)
# df.to_csv("example_cycles")
df

Unnamed: 0,index,is_good,start_sample,stop_sample,peak_sample,desc_sample,trough_sample,duration_samples,max_amp,mean_if,...,asc2desc,peak2trough,rat,study_day,condition,treatment,trial_num,state,start,end
0,1,1,73,-1.344392,74.491578,139.778779,210.05873,279,346.934723,8.965488,...,0.514096,0.500999,5,8,HC,0,3,phasic,1379.984,1381.77
1,2,1,352,-4.894912,66.294021,136.284284,204.577186,280,344.788590,8.907574,...,0.506132,0.486730,5,8,HC,0,3,phasic,1379.984,1381.77
2,3,1,632,-1.225875,77.254592,159.772773,241.03858,319,277.813918,7.853249,...,0.486571,0.500855,5,8,HC,0,3,phasic,1379.984,1381.77
3,4,1,951,-3.933588,68.30965,140.175175,200.916937,264,265.904236,9.446615,...,0.497700,0.530967,5,8,HC,0,3,phasic,1379.984,1381.77
4,5,1,1215,-0.085567,67.03187,130.621622,204.962082,273,300.707967,9.187249,...,0.494761,0.478467,5,8,HC,0,3,phasic,1379.984,1381.77
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
168,195,1,68131,-0.287040,88.620047,199.937938,291.02446,366,86.433262,6.831588,...,0.446982,0.546279,3,7,HC,1,3,tonic,2133.000,2161.00
169,196,1,68497,-0.879038,86.956669,180.034034,279.384504,372,102.770422,6.713709,...,0.482721,0.483962,3,7,HC,1,3,tonic,2133.000,2161.00
170,197,1,68869,-0.868295,77.873202,161.431431,216.971526,272,99.756428,9.186683,...,0.488609,0.593498,3,7,HC,1,3,tonic,2133.000,2161.00
171,198,1,69141,-1.748482,68.027982,127.43043,199.516114,272,103.235045,9.183427,...,0.516588,0.468494,3,7,HC,1,3,tonic,2133.000,2161.00


In [127]:
from phasic_tonic.utils import ensure_duration, get_segments
from scipy.signal import hilbert
from neurodsp.filt import filter_signal

eeg=lfpHPC_down
fs=targetFs
min_dur = 3
rem_idx = []

rem_seq = get_sequences(np.where(hypno == 5)[0])
for start, end in rem_seq:
    if ((end-start) > min_dur):
        rem_idx.append((start * fs, (end+1) * fs))

if len(rem_idx) == 0:
    raise ValueError("No REM epochs greater than min_dur.")

# get REM segments
rem_epochs = get_segments(rem_idx, eeg)

# Combine the REM indices with the corresponding downsampled segments
rem = {seq:seg for seq, seg in zip(rem_seq, rem_epochs)}

w1 = 5.0
w2 = 12.0
nfilt = 11
thr_dur = 900

trdiff_list = []
rem_eeg = np.array([])
eeg_seq = {}
sdiff_seq = {}
tridx_seq = {}
filt = np.ones((nfilt,))
filt = filt / filt.sum()

for idx in rem:
    start, end = idx

    epoch = rem[idx]
    epoch = filter_signal(epoch, fs, 'bandpass', (w1,w2), remove_edges=False)
    epoch = hilbert(epoch)

    inst_phase = np.angle(epoch)
    inst_amp = np.abs(epoch)

    # trough indices
    tridx = _detect_troughs(inst_phase, -3)

    # differences between troughs
    trdiff = np.diff(tridx)

    # smoothed trough differences
    sdiff_seq[idx] = np.convolve(trdiff, filt, 'same')

    # dict of trough differences for each REM period
    tridx_seq[idx] = tridx

    eeg_seq[idx] = inst_amp

    # differences between troughs
    trdiff_list += list(trdiff)

    # amplitude of the entire REM sleep
    rem_eeg = np.concatenate((rem_eeg, inst_amp)) 

trdiff = np.array(trdiff_list)
trdiff_sm = np.convolve(trdiff, filt, 'same')

# potential candidates for phasic REM:
# the smoothed difference between troughs is less than
# the 10th percentile:
thr1 = np.percentile(trdiff_sm, 10)
# the minimum smoothed difference in the candidate phREM is less than
# the 5th percentile
thr2 = np.percentile(trdiff_sm, 5)
# the peak amplitude is larger than the mean of the amplitude
# of the REM EEG.
thr3 = rem_eeg.mean()

phasicREM = {rem_idx:[] for rem_idx in rem.keys()}

for rem_idx in tridx_seq:
    rem_start, rem_end = rem_idx
    offset = rem_start * fs

    # trough indices
    tridx = tridx_seq[rem_idx]

    # smoothed trough interval
    sdiff = sdiff_seq[rem_idx]

    # amplitude of the REM epoch
    eegh = eeg_seq[rem_idx]

    # get the candidates for phREM
    cand_idx = np.where(sdiff <= thr1)[0]
    cand = get_sequences(cand_idx)

    for start, end in cand:
        # Duration of the candidate in milliseconds
        dur = ( (tridx[end]-tridx[start]+1)/fs ) * 1000
        if dur < thr_dur:
            continue # Failed Threshold 1
        
        min_sdiff = np.min(sdiff[start:end])
        if min_sdiff > thr2:
            continue # Failed Threshold 2
        
        mean_amp =  np.mean(eegh[tridx[start]:tridx[end]+1])
        if mean_amp < thr3:
            continue # Failed Threshold 3
        
        t_a = tridx[start] + offset
        t_b = np.min((tridx[end] + offset, rem_end * fs))
        
        ph_idx = (t_a, t_b+1)
        phasicREM[rem_idx].append(ph_idx)

def _detect_troughs(signal, thr):
    lidx  = np.where(signal[0:-2] > signal[1:-1])[0]
    ridx  = np.where(signal[1:-1] <= signal[2:])[0]
    thidx = np.where(signal[1:-1] < thr)[0]
    sidx = np.intersect1d(lidx, np.intersect1d(ridx, thidx))+1
    return sidx
