# Load the datasets

The path to dataset directory and patterns to search in those directories for the HPC, PFC recordings are in loaded from the config file.

In [4]:
from phasic_tonic.detect_phasic import detect_phasic
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences, get_start_end, preprocess

import numpy as np
import pandas as pd
import pynapple as nap
import yasa

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

fs_cbd = 2500
fs_os = 2500
fs_rgs = 1000

targetFs = 500
n_down_cbd = fs_cbd/targetFs
n_down_rgs = fs_rgs/targetFs
n_down_os = fs_os/targetFs

logger = logger_setup()

CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"

Datasets = DatasetLoader(CONFIG_DIR)
mapped_datasets = Datasets.load_datasets()

Check the number of loaded recordings

In [5]:
cbd_cnt = 0
rgs_cnt = 0
os_cnt = 0

# Count recordings belonging to CBD dataset
for name in mapped_datasets:
    metadata = get_metadata(name)
    if metadata['treatment'] == 0 or metadata['treatment'] == 1:
        cbd_cnt += 1
    elif metadata['treatment'] == 2 or metadata['treatment'] == 3:
        rgs_cnt += 1
    elif metadata['treatment'] == 4:
        os_cnt += 1

assert cbd_cnt == 170
assert rgs_cnt == 159
assert os_cnt == 210

# Waveform analysis

In [8]:
import emd
from neurodsp.filt import filter_signal

def compute_range(x):
    return x.max() - x.min()

def asc2desc(x):
    """Ascending to Descending ratio ( A / A+D )."""
    pt = emd.cycles.cf_peak_sample(x, interp=True)
    tt = emd.cycles.cf_trough_sample(x, interp=True)
    if (pt is None) or (tt is None):
        return np.nan
    asc = pt + (len(x) - tt)
    #desc = tt - pt
    return asc / len(x)

def peak2trough(x):
    """Peak to trough ratio ( P / P+T )."""
    des = emd.cycles.cf_descending_zero_sample(x, interp=True)
    if des is None:
        return np.nan
    return des / len(x)

def compute_cycles(signal, fs, metadata):
    signal = filter_signal(signal, fs, 'bandpass', (5,12), remove_edges=False)
    # Perform EMD and compute cycle metrics
    IP, IF, IA = emd.spectra.frequency_transform(signal, fs, 'hilbert', smooth_phase=3)
    C = emd.cycles.Cycles(IP.flatten())
    logger.debug("Detected cycles before extraction:")
    logger.debug(C)

    # Compute cycle metrics
    C.compute_cycle_metric('start_sample', np.arange(len(C.cycle_vect)), emd.cycles.cf_start_value)
    C.compute_cycle_metric('stop_sample', signal, emd.cycles.cf_end_value)
    C.compute_cycle_metric('peak_sample', signal, emd.cycles.cf_peak_sample)
    C.compute_cycle_metric('desc_sample', signal, emd.cycles.cf_descending_zero_sample)
    C.compute_cycle_metric('trough_sample', signal, emd.cycles.cf_trough_sample)
    C.compute_cycle_metric('duration_samples', signal, len)
    C.compute_cycle_metric('max_amp', IA, np.max)
    C.compute_cycle_metric('mean_if', IF, np.mean)
    C.compute_cycle_metric('max_if', IF, np.max)
    C.compute_cycle_metric('range_if', IF, compute_range)  # Make sure 'compute_range' is defined
    C.compute_cycle_metric('asc2desc', signal, asc2desc)  # Make sure 'asc2desc' is defined
    C.compute_cycle_metric('peak2trough', signal, peak2trough)  # Make sure 'peak2trough' is defined

    logger.debug('\nFinished computing the cycles metrics\n')

    # Extract a subset of the cycles
    amp_thresh = np.percentile(IA, 25)
    lo_freq_duration = fs / 5
    hi_freq_duration = fs / 12
    conditions = ['is_good==1',
                  f'duration_samples<{lo_freq_duration}',
                  f'duration_samples>{hi_freq_duration}',
                  f'max_amp>{amp_thresh}']

    logger.debug("Cycles after extraction:")
    df_emd = C.get_metric_dataframe(conditions=conditions)
    logger.debug(f'{len(df_emd)}')

    #Add the metadata
    df_emd["rat"]       = metadata["rat_id"]
    df_emd["study_day"] = metadata["study_day"]
    df_emd["condition"] = metadata["condition"]
    df_emd["treatment"] = metadata["treatment"]
    df_emd["trial_num"] = metadata["trial_num"]
    df_emd["state"]     = metadata["state"]
    
    start, end = metadata["interval"]
    df_emd["start"] = start
    df_emd["end"] = end

    return df_emd

def analysis(lfp, phasic_interval, tonic_interval, metadata):
    for state, intervals in [("phasic", phasic_interval), ("tonic", tonic_interval)]:
        metadata['state'] = state
        for i, interval in enumerate(intervals):
            metadata['interval'] = (interval['start'].item(), interval['end'].item())
            df_emd = compute_cycles(lfp.restrict(interval).to_numpy(), lfp.rate, metadata)
            combined.append(df_emd)

# Loop through the dataset

In [9]:
combined = []

with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        
        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()
        
        # Skip if no REM epoch is detected
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue
        elif(np.sum(np.diff(get_sequences(np.where(hypno == 5)[0]))) < 10):
            logger.debug("No REM longer than 10s. Skipping.")
            continue
        
        # Detect phasic intervals
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phrem = detect_phasic(lfpHPC_down, hypno, targetFs)
        
        t = np.arange(0, len(lfpHPC_down)/targetFs, 1/targetFs)
        lfp = nap.Tsd(t=t, d=lfpHPC_down)
        
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / targetFs)
                end.append(e / targetFs)
        
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end).drop_short_intervals(0.6)
        tonic_interval = rem_interval.set_diff(phasic_interval).drop_short_intervals(0.6)
        
        #Compute waveform dynamics for each intervals
        analysis(lfp, phasic_interval, tonic_interval, metadata)
        break

  0%|          | 0/539 [00:00<?, ?it/s]

In [10]:
df = pd.concat(combined)
# df.to_csv("waveform.csv")
df

Unnamed: 0,index,is_good,start_sample,stop_sample,peak_sample,desc_sample,trough_sample,duration_samples,max_amp,mean_if,...,asc2desc,peak2trough,rat,study_day,condition,treatment,trial_num,state,start,end
0,1,1,16,-33.137473,14.066244,27.172172,41.198818,55,347.900971,8.966730,...,0.506680,0.494039,5,8,HC,0,3,phasic,1379.984,1381.77
1,2,1,71,-28.156040,13.254916,27.254254,40.920858,56,345.523611,8.913864,...,0.505965,0.486683,5,8,HC,0,3,phasic,1379.984,1381.77
2,3,1,127,-17.992928,15.455898,31.957958,48.208199,64,278.147014,7.854380,...,0.488245,0.499343,5,8,HC,0,3,phasic,1379.984,1381.77
3,4,1,191,-13.043188,13.465908,27.83984,39.987533,53,265.900736,9.442066,...,0.499592,0.525280,5,8,HC,0,3,phasic,1379.984,1381.77
4,5,1,244,-0.207466,13.009458,25.725726,40.592979,55,300.754510,9.185402,...,0.498481,0.467740,5,8,HC,0,3,phasic,1379.984,1381.77
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
275,318,1,20944,-3.167717,16.628014,33.861862,50.386043,75,201.657838,6.681162,...,0.549893,0.451491,5,8,HC,0,3,tonic,2432.308,2475.00
276,319,1,21019,-4.319788,20.325787,36.572573,51.165314,67,156.163254,7.474672,...,0.539709,0.545859,5,8,HC,0,3,tonic,2432.308,2475.00
277,320,1,21086,-10.509578,11.853206,25.162162,40.007809,55,128.469464,9.010428,...,0.488098,0.457494,5,8,HC,0,3,tonic,2432.308,2475.00
278,321,1,21141,-9.672022,18.187644,35.531532,53.924764,72,176.936658,6.969526,...,0.503651,0.493493,5,8,HC,0,3,tonic,2432.308,2475.00


In [None]:
combined = []

with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        
        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()
        
        # Skip if no REM epoch is detected
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue
        elif(np.sum(np.diff(get_sequences(np.where(hypno == 5)[0]))) < 10):
            logger.debug("No REM longer than 10s. Skipping.")
            continue
        
        # Detect phasic intervals
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phrem = detect_phasic(lfpHPC_down, hypno, targetFs)
        
        t = np.arange(0, len(lfpHPC_down)/targetFs, 1/targetFs)
        lfp = nap.Tsd(t=t, d=lfpHPC_down)
        
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / targetFs)
                end.append(e / targetFs)
        
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end).drop_short_intervals(0.6)
        tonic_interval = rem_interval.set_diff(phasic_interval).drop_short_intervals(0.6)
        
        #Compute waveform dynamics for each intervals
        analysis(lfp, phasic_interval, tonic_interval, metadata)

In [11]:
def partition_to_4(rem_dict):
    # rem_dict: dictionary with keys as tuples and values as numpy arrays
    keys = sorted(rem_dict.keys())
    partitions = [{} for _ in range(4)]  # Create a list of 4 empty dictionaries

    for rem_idx in keys:
        _, end = rem_idx
        if end < 2700:  # First region
            partitions[0][rem_idx] = rem_dict[rem_idx]
        elif end < 5400:  # Second region
            partitions[1][rem_idx] = rem_dict[rem_idx]
        elif end < 8100:  # Third region
            partitions[2][rem_idx] = rem_dict[rem_idx]
        else:  # Fourth region
            partitions[3][rem_idx] = rem_dict[rem_idx]

    return partitions

In [None]:
combined = []

with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        
        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()
        
        # Skip if no REM epoch is detected
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue
        elif(np.sum(np.diff(get_sequences(np.where(hypno == 5)[0]))) < 10):
            logger.debug("No REM longer than 10s. Skipping.")
            continue
        
        # Detect phasic intervals
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phrem = detect_phasic(lfpHPC_down, hypno, targetFs)
        
        t = np.arange(0, len(lfpHPC_down)/targetFs, 1/targetFs)
        lfp = nap.Tsd(t=t, d=lfpHPC_down)
        
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / targetFs)
                end.append(e / targetFs)
        
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end).drop_short_intervals(0.6)
        tonic_interval = rem_interval.set_diff(phasic_interval).drop_short_intervals(0.6)
        
        rem_seq = get_sequences(np.where(hypno == 5)[0])

        # Filter REM epochs based on a minimum duration
        min_dur = 3
        rem_seq = [(rem_start, rem_end) for rem_start, rem_end in rem_seq if (rem_end - rem_start) >= min_dur]
        
        # get REM segments
        rem_idx = [(start * targetFs, (end+1) * targetFs) for start, end in rem_seq]
        rem_epochs = get_segments(rem_idx, data_resample)
        logger.debug("FINISHED: Extract REM epochs.")
        logger.debug(f"{rem_seq}")

        # Combine the REM indices with the corresponding downsampled segments
        rem = {seq:seg for seq, seg in zip(rem_seq, rem_epochs)}
        del rem_epochs, rem_seq
        
        #Compute waveform dynamics for each intervals
        if metadata["trial_num"] == '5':
            for i, partition in enumerate(partition_to_4(rem)):
                metadata["trial_num"] = '5.' + str(i+1)

                logger.debug("Partition: {0}".format(str(partition)))
                # Detect phasic & save phasic/tonic percentage, rem_epoch durations
                analysis(partition, hypno[i*2700:(i+1)*2700], metadata)
        else:
            # Detect phasic & save phasic/tonic percentage, rem_epoch durations
            analysis(rem, hypno, metadata)