In [1]:
from phasic_tonic.detect_phasic import detect_phasic
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences, get_segments

import numpy as np
import pandas as pd
import pynapple as nap
import yasa

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

import os

# Sampling rate (Hz) of the downsampled HPC signal handed to detect_phasic; the raw
# LFP is sampled at n_down * targetFs.
targetFs = 500

logger = logger_setup()

# Paths default to the original author's machine; set the environment variables
# below to point the notebook elsewhere without editing it.
# Note: CONFIG_DIR names a YAML *file*, not a directory.
CONFIG_DIR = os.environ.get("PHASIC_TONIC_CONFIG", "/home/nero/phasic_tonic/data/dataset_loading.yaml")
DATASET_DIR = os.environ.get("PHASIC_TONIC_DATASET_DIR", "/home/nero/datasets/preprocessed")

def get_start_end(hypno: np.ndarray, sleep_state_id: int):
    """Split the runs of one sleep state into parallel start / end index lists.

    Parameters
    ----------
    hypno : np.ndarray
        Sleep-state vector, one code per hypnogram bin.
    sleep_state_id : int
        State code whose consecutive runs are extracted.

    Returns
    -------
    tuple[list, list]
        ``(starts, ends)``: the first and last index of each run, as reported by
        ``get_sequences`` (end presumably inclusive — confirm against that helper).
    """
    state_bins = np.where(hypno == sleep_state_id)[0]
    # Materialise once so both comprehensions below see the same runs.
    runs = list(get_sequences(state_bins))
    starts = [first for first, _ in runs]
    ends = [last for _, last in runs]
    return (starts, ends)

def str_to_tuple(string):
    """Parse a stringified integer tuple, e.g. ``"(9121, 9168)"`` -> ``(9121, 9168)``.

    Surrounding parentheses are stripped and each comma-separated field is converted
    with ``int`` (which tolerates the spaces left after splitting).
    """
    inner = string.strip("()")
    return tuple(int(field) for field in inner.split(","))

def load_data(fname):
    """Load an ``.npz`` archive of REM epochs into a dict keyed by integer tuples.

    Each archive member name is a stringified tuple such as ``"(9121, 9168)"``; it is
    converted back to a tuple with ``str_to_tuple``.

    Parameters
    ----------
    fname : str or path-like
        Path to the ``.npz`` file.

    Returns
    -------
    dict
        ``{(start, end): np.ndarray}`` with every member fully read into memory.
    """
    # NpzFile keeps the underlying file open until closed; the context manager
    # releases it once every member has been read (members load lazily on access).
    with np.load(fname) as loaded_data:
        loaded_dict = {str_to_tuple(key): loaded_data[key] for key in loaded_data.files}
    return loaded_dict

In [2]:
from pathlib import Path
from phasic_tonic.detect_phasic_v2 import detect_phasic_v2

# Sorted so the "first" recording is the same on every run (glob order is
# filesystem-dependent).
datasets = sorted(Path(DATASET_DIR).glob('*.npz'))
trial = ''

# Smoke test: run the packaged detector on the first preprocessed recording only.
for fname in datasets:
    metadata = get_metadata(str(fname.stem))
    rem_epochs = load_data(fname)
    phREM = detect_phasic_v2(rem_epochs, fs=targetFs)
    print(phREM)
    break


{(9121, 9168): [(4578998, 4579494)], (9430, 9468): [], (9738, 9820): [(4894364, 4895363), (4904309, 4905333), (4905697, 4906490)]}


In [31]:
import numpy as np
from scipy.signal import hilbert
from neurodsp.filt import filter_signal

def get_rem_epochs(eeg, hypno, fs, min_dur=3, rem_state=5):
    """Cut REM runs longer than ``min_dur`` hypnogram bins out of ``eeg``.

    Parameters
    ----------
    eeg : array-like
        Signal handed to ``get_segments``. Hypnogram bin ``i`` is mapped to samples
        ``[i * fs, (i + 1) * fs)``.
    hypno : np.ndarray
        Sleep-state vector; bins equal to ``rem_state`` are treated as REM.
    fs : int
        Samples per hypnogram bin (the sampling rate when bins are 1 s long).
    min_dur : int
        Only runs with ``end - start > min_dur`` are kept.
    rem_state : int
        Hypnogram code for REM (5, as used throughout this notebook).

    Returns
    -------
    dict
        ``{run: segment}`` where ``run`` is the ``(start, end)`` pair from
        ``get_sequences`` and ``segment`` the matching output of ``get_segments``.

    Raises
    ------
    ValueError
        If no REM run is longer than ``min_dur``.
    """
    rem_seq = get_sequences(np.where(hypno == rem_state)[0])
    # Filter the runs themselves (not just the sample ranges) so that the dict keys
    # stay paired with their own segments. Zipping the unfiltered runs against the
    # filtered segments shifted every key after the first dropped short run.
    long_seq = [seq for seq in rem_seq if (seq[1] - seq[0]) > min_dur]

    if not long_seq:
        raise ValueError("No REM epochs greater than min_dur.")

    # end is the last bin of the run, hence the +1 for an exclusive sample bound.
    rem_idx = [(start * fs, (end + 1) * fs) for start, end in long_seq]
    rem_epochs = get_segments(rem_idx, eeg)
    return {seq: seg for seq, seg in zip(long_seq, rem_epochs)}

def preprocess_rem_epoch(epoch, fs, w1=5.0, w2=12.0):
    """Band-pass ``epoch`` between ``w1`` and ``w2`` Hz and return its analytic phase/amplitude.

    Parameters
    ----------
    epoch : np.ndarray
        One REM segment sampled at ``fs``.
    fs : float
        Sampling rate in Hz.
    w1, w2 : float
        Pass-band edges in Hz (5-12 Hz by default).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Instantaneous phase (radians, from ``np.angle``) and amplitude envelope
        (``np.abs``) of the Hilbert transform of the filtered signal.
    """
    passband = (w1, w2)
    # Edges are kept so the output stays aligned sample-for-sample with the input.
    filtered = filter_signal(epoch, fs, 'bandpass', passband, remove_edges=False)
    analytic = hilbert(filtered)
    phase = np.angle(analytic)
    amplitude = np.abs(analytic)
    return phase, amplitude

def detect_troughs(signal, threshold=-3):
    """Return indices of interior local minima of ``signal`` that lie below ``threshold``.

    Sample ``i`` qualifies when it is strictly lower than ``signal[i - 1]``, not higher
    than ``signal[i + 1]``, and strictly below ``threshold``. The first and last samples
    are never reported. In this notebook it is applied to instantaneous phase, where
    values below -3 rad sit next to the -pi wrap.

    Returns
    -------
    np.ndarray
        Sorted integer indices into ``signal``.
    """
    centre = signal[1:-1]
    falls_into = signal[:-2] > centre
    rises_after = centre <= signal[2:]
    deep_enough = centre < threshold
    # Shift by one because ``centre`` starts at signal index 1.
    return np.flatnonzero(falls_into & rises_after & deep_enough) + 1

def smooth_signal(signal, window_size=11):
    """Moving-average ``signal`` with a flat window, returning an array of the same length.

    Parameters
    ----------
    signal : array-like
        1-D input (e.g. trough-to-trough intervals).
    window_size : int
        Length of the flat averaging window.

    Returns
    -------
    np.ndarray
        Float array with ``len(signal)`` elements, centred like
        ``np.convolve(..., 'same')``. Empty input gives an empty array.
    """
    signal = np.asarray(signal)
    n = signal.size
    if n == 0:
        # np.convolve raises on empty input; an empty result is the natural answer.
        return np.zeros(0)

    kernel = np.ones(window_size) / window_size
    if n >= window_size:
        return np.convolve(signal, kernel, 'same')

    # For inputs shorter than the window, 'same' mode returns max(len(a), len(v)) =
    # window_size samples, which broke index alignment with the troughs downstream.
    # Take the same centred slice of the full convolution that 'same' would use.
    full = np.convolve(signal, kernel, 'full')
    offset = (window_size - 1) // 2
    return full[offset:offset + n]

def get_phasic_candidates(sdiff, tridx, thr1, thr_dur, fs):
    """Find runs of small smoothed trough intervals that last at least ``thr_dur`` ms.

    Parameters
    ----------
    sdiff : np.ndarray
        Smoothed trough-to-trough intervals (samples).
    tridx : np.ndarray
        Trough sample indices; ``tridx[k]`` is the trough that starts interval ``k``.
    thr1 : float
        A run consists of consecutive intervals with ``sdiff <= thr1``.
    thr_dur : float
        Minimum run duration in milliseconds, measured between the troughs at the
        run's first and last index.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    list[tuple]
        ``(start, end)`` index pairs into ``sdiff``/``tridx`` for each qualifying run.
    """
    def _duration_ms(first, last):
        return ((tridx[last] - tridx[first] + 1) / fs) * 1000

    below_threshold = np.where(sdiff <= thr1)[0]
    selected = []
    for first, last in get_sequences(below_threshold):
        if _duration_ms(first, last) >= thr_dur:
            selected.append((first, last))
    return selected

def is_valid_phasic(start, end, sdiff, eegh, tridx, thr2, thr3):
    """Decide whether a candidate run qualifies as a phasic REM event.

    Parameters
    ----------
    start, end : int
        First and last (inclusive) index of the candidate run in ``sdiff``/``tridx``.
    sdiff : np.ndarray
        Smoothed trough-to-trough intervals.
    eegh : np.ndarray
        Amplitude envelope of the epoch, indexed by sample.
    tridx : np.ndarray
        Trough sample indices.
    thr2 : float
        The run's minimum smoothed interval must be ``<= thr2``.
    thr3 : float
        The mean envelope between the first and last trough must be ``>= thr3``.

    Returns
    -------
    bool
        True when both criteria hold.
    """
    # ``end`` is inclusive (tridx[end] is used as the last trough below and in
    # get_phasic_candidates), so the interval slice must include it as well; the
    # previous ``sdiff[start:end]`` silently ignored the last interval of the run.
    min_sdiff = np.min(sdiff[start:end + 1])
    mean_amp = np.mean(eegh[tridx[start]:tridx[end] + 1])
    return min_sdiff <= thr2 and mean_amp >= thr3

def detect_phasic(eeg, hypno, fs, thr_dur=900):
    """Detect phasic REM events in a recording from the spacing of band-passed troughs.

    REM epochs are cut from ``eeg`` with ``get_rem_epochs`` (hypnogram state 5). Each
    epoch is band-passed (5-12 Hz by default, see ``preprocess_rem_epoch``) and its phase
    troughs are located; the trough-to-trough intervals are smoothed. Thresholds are
    pooled over all REM epochs:

    * ``thr1``: 10th percentile of the smoothed intervals (candidate runs),
    * ``thr2``: 5th percentile (a run's minimum must reach it),
    * ``thr3``: mean amplitude envelope over all REM (a run's mean must reach it).

    Candidate runs must also last at least ``thr_dur`` milliseconds.

    Parameters
    ----------
    eeg : np.ndarray
        Signal sampled at ``fs``.
    hypno : np.ndarray
        Sleep-state vector; bin ``i`` maps to samples ``[i * fs, (i + 1) * fs)``.
    fs : int
        Sampling rate in Hz (samples per hypnogram bin).
    thr_dur : float
        Minimum event duration in milliseconds.

    Returns
    -------
    dict
        Maps each REM epoch key ``(rem_start, rem_end)`` to a list of
        ``(start_sample, end_sample)`` tuples (end exclusive) in samples of ``eeg``.

    Raises
    ------
    ValueError
        If no REM epoch is long enough or no troughs are found in any epoch.
    """
    rem_epochs = get_rem_epochs(eeg, hypno, fs)

    trough_difference_list = []
    amplitude_chunks = []
    eeg_seq, smooth_difference_seq, trough_idx_seq = {}, {}, {}

    for idx, epoch in rem_epochs.items():
        inst_phase, inst_amp = preprocess_rem_epoch(epoch, fs)
        trough_idx = detect_troughs(inst_phase)
        trough_difference = np.diff(trough_idx)

        smooth_difference_seq[idx] = smooth_signal(trough_difference)
        trough_idx_seq[idx] = trough_idx
        eeg_seq[idx] = inst_amp

        trough_difference_list.extend(trough_difference)
        amplitude_chunks.append(inst_amp)

    if not trough_difference_list:
        raise ValueError("No troughs detected in any REM epoch.")

    # Concatenate once; growing an array inside the loop copied it on every epoch.
    rem_eeg = np.concatenate(amplitude_chunks)
    trough_difference_smooth = smooth_signal(np.array(trough_difference_list))
    thr1 = np.percentile(trough_difference_smooth, 10)
    thr2 = np.percentile(trough_difference_smooth, 5)
    thr3 = rem_eeg.mean()

    phasicREM = {rem_idx: [] for rem_idx in rem_epochs.keys()}

    for rem_idx, trough_idx in trough_idx_seq.items():
        rem_start, rem_end = rem_idx
        # Sample offset of this epoch within ``eeg``.
        offset = rem_start * fs
        smooth_difference, eegh = smooth_difference_seq[rem_idx], eeg_seq[rem_idx]

        candidates = get_phasic_candidates(smooth_difference, trough_idx, thr1, thr_dur, fs)

        for start, end in candidates:
            if is_valid_phasic(start, end, smooth_difference, eegh, trough_idx, thr2, thr3):
                t_a = trough_idx[start] + offset
                # NOTE(review): get_rem_epochs cuts samples up to (rem_end + 1) * fs, so
                # clamping to rem_end * fs trims events ending in the epoch's last
                # hypnogram bin — confirm this is intended.
                t_b = min(trough_idx[end] + offset, rem_end * fs)
                phasicREM[rem_idx].append((t_a, t_b + 1))

    return phasicREM

In [33]:
def detect_phasic_v2(rem_epochs, fs, thr_dur=900):
    """Detect phasic REM events in pre-segmented REM epochs.

    Same procedure as ``detect_phasic``, except that the REM epochs are supplied
    directly (e.g. from ``load_data``) instead of being cut from a hypnogram. Each
    epoch is band-passed (5-12 Hz by default), its phase troughs are located, and the
    smoothed trough-to-trough intervals are thresholded with values pooled over all
    epochs:

    * ``thr1``: 10th percentile of the smoothed intervals (candidate runs),
    * ``thr2``: 5th percentile (a run's minimum must reach it),
    * ``thr3``: mean amplitude envelope over all epochs (a run's mean must reach it).

    Parameters
    ----------
    rem_epochs : dict
        Maps ``(rem_start, rem_end)`` keys to 1-D signal segments sampled at ``fs``.
        ``rem_start * fs`` is used as each segment's sample offset, so keys are
        presumably hypnogram bin indices of 1 s — TODO confirm.
    fs : int
        Sampling rate in Hz.
    thr_dur : float
        Minimum event duration in milliseconds.

    Returns
    -------
    dict
        Maps each key of ``rem_epochs`` to a list of ``(start_sample, end_sample)``
        tuples (end exclusive) in whole-recording samples.

    Raises
    ------
    ValueError
        If ``rem_epochs`` is empty or no troughs are found in any epoch.
    """
    trough_difference_list = []
    amplitude_chunks = []
    eeg_seq, smooth_difference_seq, trough_idx_seq = {}, {}, {}

    for idx, epoch in rem_epochs.items():
        inst_phase, inst_amp = preprocess_rem_epoch(epoch, fs)
        trough_idx = detect_troughs(inst_phase)
        trough_difference = np.diff(trough_idx)

        smooth_difference_seq[idx] = smooth_signal(trough_difference)
        trough_idx_seq[idx] = trough_idx
        eeg_seq[idx] = inst_amp

        trough_difference_list.extend(trough_difference)
        amplitude_chunks.append(inst_amp)

    if not trough_difference_list:
        raise ValueError("No troughs detected in any REM epoch.")

    # Concatenate once; growing an array inside the loop copied it on every epoch.
    rem_eeg = np.concatenate(amplitude_chunks)
    trough_difference_smooth = smooth_signal(np.array(trough_difference_list))
    thr1 = np.percentile(trough_difference_smooth, 10)
    thr2 = np.percentile(trough_difference_smooth, 5)
    thr3 = rem_eeg.mean()

    phasicREM = {rem_idx: [] for rem_idx in rem_epochs.keys()}

    for rem_idx, trough_idx in trough_idx_seq.items():
        rem_start, rem_end = rem_idx
        # Sample offset of this epoch within the whole recording.
        offset = rem_start * fs
        smooth_difference, eegh = smooth_difference_seq[rem_idx], eeg_seq[rem_idx]

        candidates = get_phasic_candidates(smooth_difference, trough_idx, thr1, thr_dur, fs)

        for start, end in candidates:
            if is_valid_phasic(start, end, smooth_difference, eegh, trough_idx, thr2, thr3):
                t_a = trough_idx[start] + offset
                # NOTE(review): if segments span up to (rem_end + 1) * fs samples (as in
                # get_rem_epochs), this clamp trims events ending in the last bin — confirm.
                t_b = min(trough_idx[end] + offset, rem_end * fs)
                phasicREM[rem_idx].append((t_a, t_b + 1))

    return phasicREM

In [23]:
# Show the current trial label. NOTE(review): the displayed value comes from kernel
# state — the only visible assignment is `trial = ''` in an earlier cell, which does
# not produce the recorded output.
trial

'5-3'

Check that the expected number of recordings was loaded for each dataset (CBD, RGS, OS)

In [33]:
cbd_cnt = rgs_cnt = os_cnt = 0

# Tally recordings per dataset from each recording's treatment code:
# 0/1 -> CBD, 2/3 -> RGS, 4 -> OS.
for name in mapped_datasets:
    metadata = get_metadata(name)
    treatment = metadata['treatment']
    if treatment in (0, 1):
        cbd_cnt += 1
    elif treatment in (2, 3):
        rgs_cnt += 1
    elif treatment == 4:
        os_cnt += 1

# Expected number of recordings in each dataset.
assert (cbd_cnt, rgs_cnt, os_cnt) == (170, 159, 210)

# Loop through the dataset

In [None]:
with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        # Choose the downsampling factor for this recording's dataset. Unknown
        # treatment codes are skipped explicitly; otherwise `n_down` would silently
        # keep the previous recording's value (or be undefined on the first one).
        treatment = metadata["treatment"]
        if treatment in (0, 1):
            n_down = n_down_cbd
        elif treatment in (2, 3):
            n_down = n_down_rgs
        elif treatment == 4:
            n_down = n_down_os
        else:
            logger.warning("Unknown treatment {0} for {1}. Skipping.".format(treatment, name))
            continue

        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()
        lfpPFC = loadmat(pfc_fname)['PFC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()

        # Skip if no REM epoch is detected.
        # NOTE(review): the second check appears to sum the lengths of *all* REM bouts,
        # while its log message refers to a single bout longer than 10 s — confirm.
        if not np.any(hypno == 5):
            logger.debug("No REM detected. Skipping.")
            continue
        elif np.sum(np.diff(get_sequences(np.where(hypno == 5)[0]))) < 10:
            logger.debug("No REM longer than 10s. Skipping.")
            continue

        # Create Pynapple IntervalSet
        start, end = get_start_end(hypno=hypno, sleep_state_id=5)
        rem_interval = nap.IntervalSet(start=start, end=end)

        # Create TsdFrame for HPC and PFC signals
        fs = n_down*targetFs
        # Timestamps from sample indices: np.arange with a float step can return one
        # element too many through rounding, which would not match len(lfpHPC).
        t = np.arange(len(lfpHPC)) / fs
        lfp = nap.TsdFrame(t=t, d=np.vstack([lfpHPC, lfpPFC]).T, columns=['HPC', 'PFC'])

        # Detect phasic intervals on the signal downsampled to targetFs
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phREM = detect_phasic(lfpHPC_down, hypno, targetFs)

        # Create phasic REM IntervalSet (sample indices at targetFs -> seconds)
        start, end = [], []
        for rem_idx in phREM:
            for s, e in phREM[rem_idx]:
                start.append(s/targetFs)
                end.append(e/targetFs)
        phasic_interval = nap.IntervalSet(start, end)

        # Tonic REM is whatever REM remains after removing phasic intervals.
        # TODO(review): results are overwritten on every iteration and never stored,
        # so only the last recording's intervals survive the loop.
        tonic_interval = rem_interval.set_diff(phasic_interval)
        

In [66]:
# Run the pipeline step by step on one recording. This cell defines `name`, `lfp`,
# `phasic_interval` and `tonic_interval` for the cells below, so those no longer
# depend on whatever the batch loop left in the kernel.
name = 'Rat5_SD8_HC_0_posttrial4'
states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
metadata = get_metadata(name)

# Load the LFP data
lfpHPC = loadmat(hpc_fname)['HPC'].flatten()
lfpPFC = loadmat(pfc_fname)['PFC'].flatten()

# Load the states
hypno = loadmat(states_fname)['states'].flatten()

# Report (without stopping) the conditions under which the batch loop would skip
if not np.any(hypno == 5):
    print("No REM detected. Skipping.")
else:
    total_rem = np.sum(np.diff(get_sequences(np.where(hypno == 5)[0])))
    if total_rem < 10:
        logger.debug("No REM longer than 10s. Skipping.")
        print(total_rem)

# Pick this recording's downsampling factor instead of reusing the `n_down` left
# over from the last iteration of the batch loop.
if metadata["treatment"] in (0, 1):
    n_down = n_down_cbd
elif metadata["treatment"] in (2, 3):
    n_down = n_down_rgs
elif metadata["treatment"] == 4:
    n_down = n_down_os
else:
    raise ValueError("Unknown treatment {0} for {1}".format(metadata["treatment"], name))

# Create Pynapple IntervalSet
start, end = get_start_end(hypno=hypno, sleep_state_id=5)
rem_interval = nap.IntervalSet(start=start, end=end)

# Create TsdFrame for HPC and PFC signals; timestamps come from sample indices so
# the time axis always has exactly len(lfpHPC) entries.
fs = int(n_down*targetFs)
t = np.arange(len(lfpHPC)) / fs
lfp = nap.TsdFrame(t=t, d=np.vstack([lfpHPC, lfpPFC]).T, columns=['HPC', 'PFC'])

# Detect phasic intervals
lfpHPC_down = preprocess(lfpHPC, n_down)
phREM = detect_phasic(lfpHPC_down, hypno, targetFs)

# Phasic REM IntervalSet (sample indices at targetFs -> seconds); tonic REM is the
# remaining REM time.
phasic_start, phasic_end = [], []
for rem_idx in phREM:
    for s, e in phREM[rem_idx]:
        phasic_start.append(s/targetFs)
        phasic_end.append(e/targetFs)
phasic_interval = nap.IntervalSet(phasic_start, phasic_end)
tonic_interval = rem_interval.set_diff(phasic_interval)

  rem_interval = nap.IntervalSet(start=start, end=end)


## Access the HPC and PFC signals during phasic REM

In [None]:
# Display the LFP (HPC and PFC columns) restricted to the phasic REM intervals.
lfp.restrict(phasic_interval)

In [5]:
# One NumPy array of HPC samples per phasic / tonic interval, so each bout can be
# analysed on its own.
hpc_trace = lfp["HPC"]
phrem_hpc = [hpc_trace.restrict(phasic_interval[k]).to_numpy() for k in range(len(phasic_interval))]
tonic_hpc = [hpc_trace.restrict(tonic_interval[k]).to_numpy() for k in range(len(tonic_interval))]

In [25]:
# Write each bout as a separate positional array (arr_0, arr_1, ...) into one .npz
# per REM type, named after the recording.
for suffix, bouts in (('_phasic', phrem_hpc), ('_tonic', tonic_hpc)):
    np.savez(name + suffix, *bouts)

In [None]:
# NOTE(review): this reloads the *posttrial3* file, while the single-recording cell
# above works on Rat5_SD8_HC_0_posttrial4 — confirm which recording is intended.
# The context manager closes the archive; every member is read while it is open.
with np.load('Rat5_SD8_HC_0_posttrial3_phasic.npz') as data:
    phrem = [data[key] for key in data.files]
phrem

## Access HPC and PFC signals during tonic REM

In [None]:
# Display the LFP (HPC and PFC columns) restricted to the tonic REM intervals.
lfp.restrict(tonic_interval)