In [40]:
from src.DatasetLoader import DatasetLoader
from src.runtime_logger import logger_setup
from src.utils import get_segments, get_sequences, phasic_detect, get_tonic
from src.helper import get_metadata
from pathlib import Path

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

import yasa
import numpy as np
import pandas as pd

# Notebook-wide logger created by the project's logger_setup helper; the
# cells below report progress and problems through logger.debug / logger.warning.
logger = logger_setup()

# Absolute locations of the loader configuration and of the three raw datasets.
CONF = "/home/nero/phasic_tonic/configs/dataset_loading.yaml"
CBD_DIR = "/home/nero/datasets/CBD/"
RGS_DIR = "/home/nero/datasets/RGS14/"
OS_DIR = "/home/nero/datasets/OSbasic/"

# Native sampling rate of each dataset and the common rate everything is resampled to.
fs_cbd, fs_os, fs_rgs = 2500, 2500, 1000
targetFs = 500

# Per-dataset downsampling factor (true division, so each factor is a float).
n_down_cbd, n_down_rgs, n_down_os = (
    native_fs / targetFs for native_fs in (fs_cbd, fs_rgs, fs_os)
)

# Layout of the power table: six metadata columns, then freq_1 ... freq_<max_freq>.
max_freq = 100
freq_columns = ['freq_' + str(k) for k in range(1, max_freq + 1)]
metadata_columns = ['rat_id', 'study_day', 'condition', 'treatment', 'trial_num', 'state']
power_columns = [*metadata_columns, *freq_columns]
power_df = pd.DataFrame(columns=power_columns)

# Instantaneous-frequency results, filled by save_if and keyed by recording name.
inst_freq = dict()

# 'dataset_name' : {'dir' : '/path/to/dataset', 'pattern_set': 'pattern_set_in_config'}
datasets = {
    label: {"dir": directory, "pattern_set": label}
    for label, directory in (("CBD", CBD_DIR), ("RGS", RGS_DIR), ("OS", OS_DIR))
}

def partition_to_4(rem_dict):
    """Distribute REM epochs over four consecutive regions.

    `rem_dict` maps ``(start, end)`` tuples to numpy arrays. Each entry is
    placed according to its ``end`` value: below 2700 goes to the first
    region, below 5400 to the second, below 8100 to the third, and everything
    else to the fourth. The result is a list of four dictionaries (some may be
    empty), each filled in ascending key order.
    """
    upper_bounds = (2700, 5400, 8100)
    regions = [dict() for _ in range(len(upper_bounds) + 1)]

    for epoch in sorted(rem_dict):
        _, end = epoch
        # The region index is the number of bounds that `end` is not below.
        region = sum(not (end < bound) for bound in upper_bounds)
        regions[region][epoch] = rem_dict[epoch]

    return regions

# Build the loader from the dataset table and the config path, then load.
# The result is iterated by recording name in the processing loop, where each
# entry unpacks into a (states file, HPC file) pair.
Datasets = DatasetLoader(datasets, CONF)
mapped_datasets = Datasets.load_datasets()

In [41]:
from neurodsp.spectral import compute_spectrum
from emd.spectra import frequency_transform

def analysis(data, rem, metadata, targetFs=targetFs):
    """Split a set of REM epochs into phasic and tonic parts and store their features.

    Parameters
    ----------
    data : signal handed to `get_segments` together with the phasic / tonic
        index ranges (the caller passes the resampled HPC trace).
    rem : dict keyed by ``(rem_start, rem_end)`` tuples whose values are the
        matching signal segments; forwarded unchanged to `phasic_detect`.
    metadata : dict describing the recording. A DataFrame built from its
        items goes to `save_spectrum`; the dict itself goes to `save_if`.
    targetFs : sampling rate passed to the detection and feature helpers and
        used to turn the REM key values into sample positions.

    Returns
    -------
    None. Results are recorded through `save_spectrum` and `save_if`. A state
    (phasic or tonic) that has no samples is skipped instead of raising.
    """
    if len(rem) == 0:
        logger.debug("No REM given. Skipping")
        return None

    metadata_df = pd.DataFrame({key: pd.Series(value) for key, value in metadata.items()})

    phasic = []
    tonic = []

    # Detect phasic epochs
    phasicREM = phasic_detect(rem=rem, fs=targetFs, thr_dur=900, nfilt=11)
    logger.debug("Detected phasic: {0}.".format(phasicREM))

    for rem_idx in phasicREM:
        rem_start, rem_end = rem_idx
        phasic += phasicREM[rem_idx]
        # Tonic epochs are determined as everywhere that is not phasic in the REM epoch
        # NOTE(review): the caller cuts REM segments up to (end + 1) * targetFs,
        # while the range given to get_tonic stops at rem_end * targetFs --
        # confirm that leaving out that last stretch is intended.
        tonic += get_tonic(rem_start*targetFs, rem_end*targetFs, phasicREM[rem_idx])
        logger.debug("REM epoch: ({0}, {1}) ".format(rem_start*targetFs, rem_end*targetFs))

    if phasic:
        # Combine all phasic episodes per trial
        phasic = np.concatenate(get_segments(phasic, data))
        logger.debug(f"Phasic: {phasic.shape}")
        # Save power spectrum
        save_spectrum(phasic, targetFs, metadata_df, state="phasic")
        # Save instantaneous freq
        save_if(phasic, targetFs, metadata, state="phasic")

    if tonic:
        # Combine all tonic episodes per trial
        tonic = np.concatenate(get_segments(tonic, data))
        save_spectrum(tonic, targetFs, metadata_df, state="tonic")
        save_if(tonic, targetFs, metadata, state="tonic")
    else:
        # np.concatenate raises ValueError on an empty sequence, so a REM set
        # without any tonic interval must be skipped like the phasic case.
        logger.debug("No tonic REM found. Skipping tonic features.")

def save_spectrum(signal, fs, metadata_df, state):
    """Compute the Welch power spectrum of `signal` and add it as a row to `power_df`.

    Parameters
    ----------
    signal : samples to analyse; passed to `compute_spectrum`.
    fs : sampling rate passed to `compute_spectrum`.
    metadata_df : DataFrame with the recording metadata. It is not modified;
        a copy with the `state` column set is what gets stored.
    state : label written to the `state` column (the callers use "phasic"
        and "tonic").

    Returns
    -------
    None. The module-level `power_df` is rebound to a frame with the new row
    placed first. A scalar spectrum is logged and ignored.
    """
    global power_df

    # One table column per entry of freq_columns; never a separate magic number.
    n_bins = len(freq_columns)

    _, spectrum = compute_spectrum(signal, fs, method='welch', avg_type='mean')
    if spectrum.ndim == 0:
        logger.warning(f"Array is expected not scalar. Received {spectrum}")
        return None

    if spectrum.shape[0] < n_bins:
        logger.warning(f"Spectrum has size {spectrum.shape}. Padding.")
        # NOTE(review): a spectrum with fewer bins presumably comes from a
        # signal shorter than the Welch segment, in which case its bins are
        # spaced differently from those of full-length spectra and zero
        # padding does not line them up with the freq_* columns -- confirm
        # whether such rows should be interpolated or dropped instead.
        values = np.zeros(n_bins)
        values[:len(spectrum)] = spectrum
    else:
        values = spectrum[:n_bins]
    spec_df = pd.DataFrame([values], columns=freq_columns)

    # Label a copy so the caller's metadata frame is left untouched.
    labelled_df = metadata_df.assign(state=state)
    # DataFrame containing metadata information, spectrum information
    comb_df = pd.concat([labelled_df, spec_df], axis=1)

    if power_df.empty:
        # Concatenating with an empty frame is deprecated in pandas and would
        # change the result dtypes in a future version; start from the first row.
        power_df = comb_df
    else:
        power_df = pd.concat([comb_df, power_df])

def save_if(signal, fs, m, state):
    """Store the Hilbert instantaneous frequency of `signal` in `inst_freq`.

    The entry is keyed by a name assembled from the metadata dict `m`
    (rat id, study day, condition, treatment, trial number) followed by
    `state`. Of the three outputs of `frequency_transform`, only the second
    one is kept.
    """
    _, frequency, _ = frequency_transform(signal, fs, 'hilbert', smooth_phase=3)
    key = (
        f"Rat{m['rat_id']}_SD{m['study_day']}_{m['condition']}_{m['treatment']}"
        f"_posttrial{m['trial_num']}_{state}"
    )
    inst_freq[key] = frequency

In [44]:
# Value of the `states` vector that the checks below treat as REM.
REM_STATE = 5
# Minimum (rem_end - rem_start) for a REM epoch to be kept.
# NOTE(review): segments are later cut up to (end + 1) * targetFs, so an epoch
# passing this filter spans at least min_dur + 1 state samples -- confirm.
min_dur = 3

with tqdm(mapped_datasets) as t:
    for name in t:
        metadata = get_metadata(name)
        t.set_postfix_str(name)
        states_fname, hpc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        # Pick the downsampling factor that matches the recording's dataset.
        treatment = metadata["treatment"]
        if treatment in (0, 1):
            n_down = n_down_cbd
        elif treatment in (2, 3):
            n_down = n_down_rgs
        elif treatment == 4:
            n_down = n_down_os
        else:
            # Without this branch n_down would be undefined on the first
            # recording or silently reused from the previous one.
            logger.warning(f"Unknown treatment {treatment!r} for {name}. Skipping.")
            continue

        # Load the states first: recordings without REM never need the LFP.
        hypno = loadmat(states_fname)['states']
        hypno = hypno.flatten()

        # Skip if no REM epoch is detected
        if not np.any(hypno == REM_STATE):
            logger.debug("No REM detected. Skipping.")
            continue

        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC']
        lfpHPC = lfpHPC.flatten()

        logger.debug("STARTED: Resampling to 500 Hz.")
        # Downsample to 500 Hz
        data_resample = resample(lfpHPC, down=n_down, method='fft', npad='auto')
        logger.debug("FINISHED: Resampling to 500 Hz.")
        logger.debug("Resampled: {0} -> {1}.".format(str(lfpHPC.shape), str(data_resample.shape)))
        del lfpHPC

        logger.debug("STARTED: Remove artifacts.")
        # Remove artifacts
        art_std, _ = yasa.art_detect(data_resample, targetFs, window=1, method='std', threshold=4, verbose='info')
        art_up = yasa.hypno_upsample_to_data(art_std, 1, data_resample, targetFs)
        # Index with a boolean mask: 0/1 integer labels would otherwise be
        # read as sample positions 0 and 1 instead of "artifact here".
        data_resample[np.asarray(art_up, dtype=bool)] = 0
        logger.debug("FINISHED: Remove artifacts.")
        del art_std, art_up

        data_resample -= data_resample.mean()

        logger.debug("STARTED: Extract REM epochs.")
        rem_seq = get_sequences(np.where(hypno == REM_STATE)[0])

        # Filter REM epochs based on a minimum duration
        rem_seq = [(rem_start, rem_end) for rem_start, rem_end in rem_seq if (rem_end - rem_start) >= min_dur]
        if len(rem_seq) == 0:
            logger.debug("Failed min duration criteria. Skipping.")
            logger.debug(f"{rem_seq}")
            continue

        # get REM segments
        rem_idx = [(start * targetFs, (end+1) * targetFs) for start, end in rem_seq]
        rem_epochs = get_segments(rem_idx, data_resample)
        logger.debug("FINISHED: Extract REM epochs.")
        logger.debug(f"{rem_seq}")

        # Combine the REM indices with the corresponding downsampled segments
        rem = dict(zip(rem_seq, rem_epochs))
        del rem_epochs, rem_seq

        if metadata["trial_num"] == '5':
            # Trial 5 is analysed as four parts, labelled 5-1 ... 5-4.
            for part_num, partition in enumerate(partition_to_4(rem), start=1):
                # Give every part its own metadata instead of editing the shared dict.
                part_metadata = {**metadata, "trial_num": '5-' + str(part_num)}
                logger.debug(f"Partition {part_metadata['trial_num']}")
                analysis(data_resample, partition, part_metadata)
        else:
            analysis(data_resample, rem, metadata)

  0%|          | 0/539 [00:00<?, ?it/s]

  power_df = pd.concat([comb_df, power_df])


Spectrum has size (77,). Padding.




In [47]:
# Persist the accumulated spectra table; the row index is not written.
power_df.to_csv('power_spectra.csv', index=False)
# inst_freq is a plain dict, so np.save stores it as a pickled object array
# (np.save appends ".npy"); the cell that reads it back uses
# np.load('inst_freq.npy', allow_pickle=True).item().
np.save('inst_freq', inst_freq)

In [74]:
# Read the saved spectra table back from disk.
p_df = pd.read_csv('power_spectra.csv')

# Names of the 100 spectrum columns: freq_1 ... freq_100.
freq_columns = ['freq_' + str(k) for k in range(1, 101)]

# Spectrum values of every row that belongs to rat 6, as a 2-D array.
power_spectra = p_df.loc[p_df["rat_id"] == 6, freq_columns].to_numpy()
p_df

Unnamed: 0,rat_id,study_day,condition,treatment,trial_num,state,freq_1,freq_2,freq_3,freq_4,...,freq_91,freq_92,freq_93,freq_94,freq_95,freq_96,freq_97,freq_98,freq_99,freq_100
0,11,4,OR,4,3,tonic,2736.026822,16421.768553,4135.474851,1666.474326,...,7.378674,6.911788,7.005151,6.399858,6.078365,6.181880,6.290018,6.066528,6.198129,5.915563
1,11,4,OR,4,3,phasic,1523.036778,9045.967523,5188.735581,2225.896928,...,10.285460,9.803339,12.455225,16.528076,14.520866,9.508602,8.905631,7.758662,9.971283,12.324690
2,11,4,OR,4,5-4,tonic,1274.414783,6161.926361,2508.813358,1840.071068,...,9.535300,10.151998,9.709683,7.429483,7.541275,7.264755,8.134707,7.181261,6.820787,6.950894
3,11,4,OR,4,5-4,phasic,79.950623,2229.675086,493.966703,582.934758,...,7.108655,9.693376,3.242763,8.852094,22.640716,24.825705,44.586460,23.414907,5.669849,10.135948
4,11,4,OR,4,5-3,tonic,1275.173231,7247.088122,2398.905923,1365.076419,...,10.122217,9.634638,9.225120,8.261307,8.228207,7.962016,8.041525,7.597038,7.652500,7.154726
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1230,5,8,HC,0,4,phasic,519.976419,2093.722335,1472.769526,756.644446,...,10.055235,14.824464,10.189933,8.270498,7.600562,10.495390,12.724941,8.645372,10.104771,14.694172
1231,5,8,HC,0,2,tonic,1016.690579,6036.690405,4230.969049,2803.580158,...,7.924306,7.385913,7.013693,6.323671,5.760065,6.017093,5.926854,5.890788,5.736669,13.233448
1232,5,8,HC,0,2,phasic,629.374771,3792.447321,3227.017071,2555.540524,...,13.549006,13.425325,12.322300,9.709928,7.776854,10.120088,13.035475,6.050749,7.007161,12.877313
1233,5,8,HC,0,3,tonic,1145.375884,6085.499219,5307.484668,3268.170008,...,6.959264,7.429401,7.553319,6.469456,6.463410,6.150589,5.568496,5.575029,5.283055,12.060076


In [7]:
# The file holds a pickled dict wrapped in an array, hence allow_pickle and
# .item(). Unpickling executes code from the file: only load trusted files.
ifreq = np.load('inst_freq.npy', allow_pickle=True).item()

for name in ifreq:
    # Names look like Rat<id>_SD<day>_<condition>_<treatment>_posttrial<trial>_<state>;
    # the fixed prefixes are stripped by their length.
    metaname = name.split('_')
    metadata = {
        "rat_id": int(metaname[0][len("Rat"):]),
        "study_day": int(metaname[1][len("SD"):]),
        "condition": metaname[2],
        "treatment": int(metaname[3]),
        "trial_num": metaname[4][len("posttrial"):],
        "state": metaname[5],
    }

    print(name)
    print(metadata)

Rat5_SD8_HC_0_posttrial3_phasic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '3', 'state': 'phasic'}
Rat5_SD8_HC_0_posttrial3_tonic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '3', 'state': 'tonic'}
Rat5_SD8_HC_0_posttrial2_phasic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '2', 'state': 'phasic'}
Rat5_SD8_HC_0_posttrial2_tonic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '2', 'state': 'tonic'}
Rat5_SD8_HC_0_posttrial4_phasic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '4', 'state': 'phasic'}
Rat5_SD8_HC_0_posttrial4_tonic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '4', 'state': 'tonic'}
Rat5_SD8_HC_0_posttrial5-1_phasic
{'rat_id': 5, 'study_day': 8, 'condition': 'HC', 'treatment': 0, 'trial_num': '5-1', 'state': 'phasic'}
Rat5_SD8_HC_0_posttrial5-1_tonic
{'rat_id': 5, 'study_day': 8, '