In [1]:
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_segments, get_sequences, phasic_detect, get_tonic, create_hypnogram
from phasic_tonic.helper import get_metadata
from pathlib import Path

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

import yasa
import numpy as np
import pandas as pd

# Project logger; every processing step below reports through logger.debug().
logger = logger_setup()

# Dataset-loading configuration file and the root directory of each dataset.
CONF = "/home/nero/phasic_tonic/configs/dataset_loading.yaml"
CBD_DIR, RGS_DIR, OS_DIR = (
    "/home/nero/datasets/CBD/",
    "/home/nero/datasets/RGS14/",
    "/home/nero/datasets/OSbasic/",
)

# Native sampling rate of each dataset's recordings (Hz).
fs_cbd, fs_os, fs_rgs = 2500, 2500, 1000

# Common sampling rate (Hz) every recording is resampled to, and the resulting
# down-sampling factor per dataset (true division, so the factors are floats).
targetFs = 500
n_down_cbd, n_down_rgs, n_down_os = (
    native_fs / targetFs for native_fs in (fs_cbd, fs_rgs, fs_os)
)

# Columns shared by every results table. The functions that fill the tables
# append one value per metadata key, so these are presumably exactly the keys
# of the metadata dict of a recording -- TODO confirm against get_metadata().
_METADATA_COLUMNS = ('rat_id', 'study_day', 'condition', 'treatment', 'trial_num')


def _empty_table(*value_columns):
    """Return a column-name -> empty-list dict: metadata columns first, then value_columns."""
    return {column: [] for column in (*_METADATA_COLUMNS, *value_columns)}


# Filled by analysis(): total phasic / tonic duration, one row per trial.
per_trial_percentage = _empty_table('phasic', 'tonic')

# Filled by _save_nrem_rem(): NREM / REM duration, one row per trial.
nrem_rem_percentage = _empty_table('nrem', 'rem')

# Filled by _save_epoch(): one row per (REM epoch, state) combination.
per_rem_epoch_percentage = _empty_table(
    'state', 'epoch_id', 'duration', 'count', 'rem_epoch_dur'
)

# Filled by analysis(): recording name -> {'phasic': [...], 'tonic': [...]}.
phasic_tonic_idx = {}

# 'dataset_name' : {'dir' : '/path/to/dataset', 'pattern_set': 'pattern_set_in_config'}
# Every dataset uses the pattern set that carries its own name.
datasets = {
    label: {"dir": directory, "pattern_set": label}
    for label, directory in (("CBD", CBD_DIR), ("RGS", RGS_DIR), ("OS", OS_DIR))
}

def partition_to_4(rem_dict, partition_len=2700, n_partitions=4):
    """Distribute REM epochs over consecutive, equally long regions of a recording.

    An epoch is assigned to a region by its END position only: an epoch that
    starts in one region and ends in the next is placed in the later one.
    Epochs ending before 0 go to the first region and epochs ending at or
    beyond the last boundary go to the last region.

    Parameters
    ----------
    rem_dict : dict
        Maps ``(start, end)`` tuples to the data of that REM epoch.
    partition_len : int or float, optional
        Length of one region, in the same units as the epoch keys.
        Defaults to 2700, the value this function used originally.
    n_partitions : int, optional
        Number of regions to create. Defaults to 4.

    Returns
    -------
    list of dict
        ``n_partitions`` dictionaries; each holds the entries of ``rem_dict``
        that fall into that region, inserted in sorted key order.

    Raises
    ------
    ValueError
        If ``partition_len`` is not positive or ``n_partitions`` is below 1.
    """
    if partition_len <= 0:
        raise ValueError("partition_len must be positive")
    if n_partitions < 1:
        raise ValueError("n_partitions must be at least 1")

    partitions = [{} for _ in range(n_partitions)]
    last_region = n_partitions - 1

    for rem_idx in sorted(rem_dict.keys()):
        _, end = rem_idx
        # Floor division yields the region number; clamp it so that out-of-range
        # ends land in the first/last region, exactly like the former if/elif chain.
        region = min(max(int(end // partition_len), 0), last_region)
        partitions[region][rem_idx] = rem_dict[rem_idx]

    return partitions

def get_name(m):
    """Return the label of a recording built from its metadata mapping.

    ``m`` must provide the keys 'rat_id', 'study_day', 'condition',
    'treatment' and 'trial_num'. The result has the form
    ``Rat<rat_id>_SD<study_day>_<condition>_<treatment>_posttrial<trial_num>``.
    """
    template = "Rat{rat}_SD{day}_{cond}_{treat}_posttrial{trial}"
    return template.format(
        rat=m['rat_id'],
        day=m['study_day'],
        cond=m['condition'],
        treat=m['treatment'],
        trial=m['trial_num'],
    )

# Build the loader from the dataset table and the configuration file.
# NOTE(review): presumably CONF holds the file-name patterns that the
# 'pattern_set' entries refer to -- confirm against DatasetLoader.
Datasets = DatasetLoader(datasets, CONF)
# Mapping keyed by recording name; the processing loop unpacks every value as
# (states_fname, hpc_fname).
mapped_datasets = Datasets.load_datasets()

ModuleNotFoundError: No module named 'neurodsp'

In [5]:
def analysis(rem, hypno, metadata, targetFs=targetFs):
    """Detect phasic and tonic periods in the given REM epochs and record them.

    Nothing is recorded when ``rem`` is empty. Otherwise the function

    * appends one row to ``nrem_rem_percentage`` (via ``_save_nrem_rem``),
    * appends a 'phasic' and a 'tonic' row to ``per_rem_epoch_percentage`` for
      every epoch returned by ``phasic_detect`` (via ``_save_epoch``),
    * stores all phasic/tonic intervals in ``phasic_tonic_idx`` under
      ``get_name(metadata)``,
    * appends one row with the summed durations to ``per_trial_percentage``.

    Parameters
    ----------
    rem : dict
        Maps ``(start, end)`` REM-epoch keys to the matching signal segments.
    hypno : array-like
        Hypnogram handed to ``_save_nrem_rem`` for the NREM/REM totals.
    metadata : dict
        Recording metadata; every key must be a column of the result tables.
    targetFs : int, optional
        Sampling rate passed to ``phasic_detect`` and used to convert the
        summed interval lengths to seconds.
        NOTE(review): ``_save_epoch`` reads the module-level ``targetFs``
        instead, so a non-default value is applied only partially.

    Returns
    -------
    None
    """
    if len(rem) == 0:
        logger.debug("No REM given.")
        return None

    _save_nrem_rem(hypno, metadata)
    name = get_name(metadata)

    # Detect phasic epochs
    phasicREM = phasic_detect(rem=rem, fs=targetFs, thr_dur=900, nfilt=11)
    logger.debug("Detected phasic: {0}.".format(phasicREM))

    phasic_total, tonic_total = [], []

    for epoch_id, epoch_bounds in enumerate(phasicREM):
        rem_start, rem_end = epoch_bounds
        first_sample = rem_start*targetFs
        last_sample = rem_end*targetFs

        logger.debug("REM epoch: ({0}, {1}) ".format(first_sample, last_sample))

        # NOTE(review): the processing loop extracts each segment up to
        # (end+1)*targetFs, whereas this duration is end - start -- confirm
        # whether the final second of an epoch is meant to be excluded.
        epoch_dur = rem_end - rem_start

        # Save total phasic duration for this REM epoch
        phasic = phasicREM[epoch_bounds]
        _save_epoch(durations=phasic, metadata=metadata, state="phasic",
                    epoch_id=epoch_id, rem_epoch_dur=epoch_dur)

        # Tonic epochs are everything inside the REM epoch that is not phasic
        tonic = get_tonic(first_sample, last_sample, phasic)
        _save_epoch(durations=tonic, metadata=metadata, state="tonic",
                    epoch_id=epoch_id, rem_epoch_dur=epoch_dur)

        phasic_total.extend(phasic)
        tonic_total.extend(tonic)

    phasic_tonic_idx[name] = {'phasic': phasic_total, 'tonic': tonic_total}

    # Save total duration of phasic/tonic per trial
    for field, value in metadata.items():
        per_trial_percentage[field].append(value)
    for state, intervals in (('phasic', phasic_total), ('tonic', tonic_total)):
        per_trial_percentage[state].append(np.sum(np.diff(intervals)/targetFs))

def _save_nrem_rem(hypno, metadata):
    """Append the metadata and the NREM / REM totals of one hypnogram to ``nrem_rem_percentage``.

    Hypnogram value 3 is counted as NREM and value 5 as REM. For each state the
    total is the sum of ``np.diff`` over the runs returned by ``get_sequences``.
    NOTE(review): presumably each run is a (start, end) pair, so a run
    contributes end - start -- confirm whether the end index is inclusive.
    """
    for field, value in metadata.items():
        nrem_rem_percentage[field].append(value)

    state_specs = (
        ("nrem", 3, "NREM Duration: {0}"),
        ("rem", 5, "REM Duration: {0}"),
    )
    for column, state_code, message in state_specs:
        state_positions = np.where(hypno == state_code)[0]
        total = np.sum(np.diff(get_sequences(state_positions)))
        logger.debug(message.format(total))
        nrem_rem_percentage[column].append(total)

def _save_epoch(durations, metadata, state, epoch_id, rem_epoch_dur):
    """Append one (REM epoch, state) row to ``per_rem_epoch_percentage``.

    Parameters
    ----------
    durations : sequence
        Intervals of the given state inside the REM epoch; presumably
        (start, end) sample-index pairs -- TODO confirm. The summed ``np.diff``
        is divided by the module-level ``targetFs``.
    metadata : dict
        Recording metadata, copied into the row key by key.
    state : str
        Label stored in the 'state' column (callers pass "phasic" or "tonic").
    epoch_id : int
        Running number of the REM epoch within the trial.
    rem_epoch_dur : number
        Duration of the whole REM epoch, stored unchanged.
    """
    table = per_rem_epoch_percentage

    # Metadata columns first, one value per key.
    for field, value in metadata.items():
        table[field].append(value)

    table['state'].append(state)
    table["epoch_id"].append(epoch_id)

    summed_lengths = np.sum(np.diff(durations))
    table['duration'].append(summed_lengths/targetFs)

    n_intervals = len(durations)
    table['count'].append(n_intervals)
    table['rem_epoch_dur'].append(rem_epoch_dur)

In [6]:
# Process every recording: load LFP + hypnogram, resample, suppress artifacts,
# cut out the REM epochs and hand them to analysis().
with tqdm(mapped_datasets) as t:
    for name in t:
        metadata = get_metadata(name)
        t.set_postfix_str(name)
        states_fname, hpc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        # The treatment code selects the dataset and therefore the
        # down-sampling factor: 0/1 -> CBD, 2/3 -> RGS, 4 -> OS.
        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        else:
            # Without this branch an unexpected code silently reused the factor
            # of the previous recording (or raised NameError on the first one).
            raise ValueError(
                f"Unknown treatment code {metadata['treatment']!r} for {name}"
            )

        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC']
        lfpHPC = lfpHPC.flatten()

        # Load the states
        hypno = loadmat(states_fname)['states']
        hypno = hypno.flatten()

        # Skip if no REM epoch is detected (REM is hypnogram value 5)
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue

        logger.debug("STARTED: Resampling to 500 Hz.")
        # Downsample to 500 Hz
        data_resample = resample(lfpHPC, down=n_down, method='fft', npad='auto')
        logger.debug("FINISHED: Resampling to 500 Hz.")
        logger.debug("Resampled: {0} -> {1}.".format(str(lfpHPC.shape), str(data_resample.shape)))
        del lfpHPC

        logger.debug("STARTED: Remove artifacts.")
        # Remove artifacts
        art_std, _ = yasa.art_detect(data_resample, targetFs , window=1, method='std', threshold=4, verbose='info')
        art_up = yasa.hypno_upsample_to_data(art_std, 1, data_resample, targetFs)
        # The artifact flags are 0/1 values. Used directly as an index, an
        # integer array would address samples 0 and 1 instead of acting as a
        # mask, so convert to bool before zeroing the flagged samples.
        data_resample[art_up.astype(bool)] = 0
        logger.debug("FINISHED: Remove artifacts.")
        del art_std, art_up

        data_resample -= data_resample.mean()

        logger.debug("STARTED: Extract REM epochs.")
        rem_seq = get_sequences(np.where(hypno == 5)[0])

        # Filter REM epochs based on a minimum duration (in hypnogram samples)
        min_dur = 3
        rem_seq = [(rem_start, rem_end) for rem_start, rem_end in rem_seq if (rem_end - rem_start) >= min_dur]
        if len(rem_seq) == 0:
            logger.debug("Failed min duration criteria. Skipping.")
            logger.debug(f"{rem_seq}")
            continue

        # get REM segments: convert hypnogram indices to sample indices of the
        # resampled signal; end+1 so that the last hypnogram sample is included
        rem_idx = [(start * targetFs, (end+1) * targetFs) for start, end in rem_seq]
        rem_epochs = get_segments(rem_idx, data_resample)
        logger.debug("FINISHED: Extract REM epochs.")
        logger.debug(f"{rem_seq}")

        # Combine the REM indices with the corresponding downsampled segments
        rem = {seq:seg for seq, seg in zip(rem_seq, rem_epochs)}
        del rem_epochs, rem_seq

        if metadata["trial_num"] == '5':
            # Trial 5 is analysed in four consecutive parts of 2700 hypnogram
            # samples each; every part gets its own trial label 5.1 ... 5.4.
            for i, partition in enumerate(partition_to_4(rem)):
                metadata["trial_num"] = '5.' + str(i+1)

                logger.debug("Partition: {0}".format(str(partition)))
                # Detect phasic & save phasic/tonic percentage, rem_epoch durations
                analysis(partition, hypno[i*2700:(i+1)*2700], metadata)
        else:
            # Detect phasic & save phasic/tonic percentage, rem_epoch durations
            analysis(rem, hypno, metadata)

  0%|          | 0/539 [00:00<?, ?it/s]

In [12]:
# Optional export of the phasic/tonic interval dictionary (currently disabled):
#np.save('phasic_tonic_idx', phasic_tonic_idx)
# Number of stored entries; analysis() writes one per recording name, so a
# repeated name overwrites the earlier entry.
len(phasic_tonic_idx)

637

In [10]:
# Turn the accumulator dicts into DataFrames. Wrapping every column in
# pd.Series lets columns of unequal length coexist (shorter ones are NaN-padded).
nrem_df = pd.DataFrame({key:pd.Series(value) for key, value in nrem_rem_percentage.items()})
ph_df = pd.DataFrame({key:pd.Series(value) for key, value in per_rem_epoch_percentage.items()})

# Normalise by 2700, the length (in hypnogram samples) of one analysed period.
nrem_df["nrem_norm"] = nrem_df["nrem"]/2700.0
nrem_df["rem_norm"] = nrem_df["rem"]/2700.0

# Drop zero-length REM epochs (they would divide by zero below). The explicit
# .copy() makes the filtered frame independent of the original, so the column
# assignments that follow do not raise SettingWithCopyWarning.
ph_df = ph_df[ph_df["rem_epoch_dur"]>0.0].copy()
# NOTE(review): 45.0 is presumably 2700 s expressed in minutes -- confirm.
ph_df["count_norm"] = ph_df["count"]/(45.0)
# Despite the name this is a fraction in [0, 1], not a percentage.
ph_df["rem_percentage"] = ph_df["duration"]/ph_df["rem_epoch_dur"]

In [15]:
# Per-trial phasic/tonic totals as a DataFrame, plus both totals divided by
# 2700 (the same constant used to normalise the NREM/REM totals).
ph_tot_df = pd.DataFrame(
    {column: pd.Series(values) for column, values in per_trial_percentage.items()}
).assign(
    phasic_norm=lambda frame: frame["phasic"]/2700.0,
    tonic_norm=lambda frame: frame["tonic"]/2700.0,
)
ph_tot_df

Unnamed: 0,rat_id,study_day,condition,treatment,trial_num,phasic,tonic,phasic_norm,tonic_norm
0,5,8,HC,0,3,9.028,191.972,0.003344,0.071101
1,5,8,HC,0,2,16.704,381.296,0.006187,0.141221
2,5,8,HC,0,4,10.616,305.384,0.003932,0.113105
3,5,8,HC,0,5-1,7.952,473.048,0.002945,0.175203
4,5,8,HC,0,5-3,8.700,275.300,0.003222,0.101963
...,...,...,...,...,...,...,...,...,...
632,11,4,OR,4,5-1,19.252,410.748,0.007130,0.152129
633,11,4,OR,4,5-2,25.208,709.792,0.009336,0.262886
634,11,4,OR,4,5-3,12.966,329.034,0.004802,0.121864
635,11,4,OR,4,5-4,1.898,120.102,0.000703,0.044482


In [11]:
# Show the per-REM-epoch table (one 'phasic' and one 'tonic' row per epoch).
ph_df

Unnamed: 0,rat_id,study_day,condition,treatment,trial_num,state,epoch_id,duration,count,rem_epoch_dur,count_norm,rem_percentage
0,5,8,HC,0,3,phasic,0,5.874,3,154,0.066667,0.038143
1,5,8,HC,0,3,tonic,0,148.126,4,154,0.088889,0.961857
2,5,8,HC,0,3,phasic,1,3.154,1,47,0.022222,0.067106
3,5,8,HC,0,3,tonic,1,43.846,2,47,0.044444,0.932894
4,5,8,HC,0,2,phasic,0,4.856,3,85,0.066667,0.057129
...,...,...,...,...,...,...,...,...,...,...,...,...
4275,11,4,OR,4,3,tonic,2,82.590,2,84,0.044444,0.983214
4276,11,4,OR,4,3,phasic,3,8.504,5,137,0.111111,0.062073
4277,11,4,OR,4,3,tonic,3,128.496,6,137,0.133333,0.937927
4278,11,4,OR,4,3,phasic,4,4.926,2,65,0.044444,0.075785


In [12]:
# Show the per-trial NREM/REM totals together with their normalised columns.
nrem_df

Unnamed: 0,rat_id,study_day,condition,treatment,trial_num,nrem,rem,nrem_norm,rem_norm
0,5,8,HC,0,3,2304,201,0.853333,0.074444
1,5,8,HC,0,2,2090,398,0.774074,0.147407
2,5,8,HC,0,4,1783,316,0.660370,0.117037
3,5,8,HC,0,5-1,1736,481,0.642963,0.178148
4,5,8,HC,0,5-3,1255,285,0.464815,0.105556
...,...,...,...,...,...,...,...,...,...
632,11,4,OR,4,5-1,1494,593,0.553333,0.219630
633,11,4,OR,4,5-2,1720,571,0.637037,0.211481
634,11,4,OR,4,5-3,845,342,0.312963,0.126667
635,11,4,OR,4,5-4,489,122,0.181111,0.045185


In [24]:
# Reshape nrem_df to long format: the 'nrem' and 'rem' columns are stacked into
# one 'duration' column, and the new 'sleep' column records which of the two a
# row came from. The normalised duration is added and the result is written to CSV.
df_melted = (
    nrem_df
    .melt(
        id_vars=['rat_id', 'study_day', 'condition', 'treatment', 'trial_num'],
        value_vars=['nrem', 'rem'],
        var_name='sleep',
        value_name='duration',
    )
    .assign(duration_norm=lambda frame: frame["duration"]/2700.0)
)
df_melted.to_csv("nrem_rem_percentage0.csv", index=False)

In [25]:
# Show the long-format NREM/REM table that was written to CSV.
df_melted

Unnamed: 0,rat_id,study_day,condition,treatment,trial_num,sleep,duration,duration_norm
0,5,8,HC,0,3,nrem,2304,0.853333
1,5,8,HC,0,2,nrem,2090,0.774074
2,5,8,HC,0,4,nrem,1783,0.660370
3,5,8,HC,0,5-1,nrem,1736,0.642963
4,5,8,HC,0,5-3,nrem,1255,0.464815
...,...,...,...,...,...,...,...,...
1269,11,4,OR,4,5-1,rem,593,0.219630
1270,11,4,OR,4,5-2,rem,571,0.211481
1271,11,4,OR,4,5-3,rem,342,0.126667
1272,11,4,OR,4,5-4,rem,122,0.045185


In [16]:
# Write every result table to a CSV file in the working directory, without the index.
for frame, csv_name in (
    (ph_df, "per_rem_epoch_percentage.csv"),
    (nrem_df, "nrem_rem_percentage.csv"),
    (ph_tot_df, "per_trial_percentage.csv"),
):
    frame.to_csv(csv_name, index=False)