In [1]:
from phasic_tonic.detect_phasic import detect_phasic_v2
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences, get_segments

import numpy as np
import pandas as pd
import pynapple as nap

from pathlib import Path
from tqdm.auto import tqdm
from scipy.io import loadmat

# Sampling rate in samples per second. It is passed to detect_phasic_v2 with
# the loaded REM epochs, and sample indices are divided by it to get seconds.
fs = 500

logger = logger_setup()

# NOTE(review): despite the name, CONFIG_DIR points at a YAML file, not a directory.
CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"
# Directory that is scanned for '*.npz' dataset files.
DATASET_DIR = "/home/nero/datasets/preprocessed"
# NOTE(review): neither output directory is referenced in the cells visible
# here — confirm they are still needed. All four paths are absolute and
# machine-specific.
OUTPUT_DIR1 = "/home/nero/phasic_tonic/data/analysis_output/whole_posttrial5/"
OUTPUT_DIR2 = "/home/nero/phasic_tonic/data/analysis_output/segmented_posttrial5/"

def str_to_tuple(string):
    """Parse the text form of a tuple of integers, e.g. ``"(9121, 9168)"``.

    Surrounding whitespace and parentheses are removed and the remainder is
    split on commas. Empty components are skipped, so the repr of a
    one-element tuple (``"(5,)"``) and of an empty tuple (``"()"``) parse
    instead of raising.

    Parameters
    ----------
    string : str
        Text such as ``str((9121, 9168))``.

    Returns
    -------
    tuple of int

    Raises
    ------
    ValueError
        If a non-empty component is not a valid integer literal.
    """
    inner = string.strip().strip("()")
    # int() tolerates whitespace around the digits, so only fully blank
    # components (trailing comma, empty tuple) need to be filtered out.
    return tuple(int(part) for part in inner.split(",") if part.strip())

def load_data(fname):
    """Load an ``.npz`` archive into a dict keyed by integer tuples.

    Each array name in the archive is converted with ``str_to_tuple`` and
    mapped to the corresponding array.

    Parameters
    ----------
    fname : str or path-like
        Path of the ``.npz`` file to read.

    Returns
    -------
    dict
        ``{tuple_of_int: array}`` with one entry per array in the archive;
        empty when the archive holds no arrays.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the file open.
    # Every array is read inside the `with`, which then closes the handle
    # instead of leaking one per dataset.
    with np.load(fname) as archive:
        return {str_to_tuple(key): archive[key] for key in archive.files}

# Path.glob yields entries in arbitrary, filesystem-dependent order. Sorting
# makes the processing order (and which file a loop reaches first)
# reproducible across runs and machines.
compressed_datasets = sorted(Path(DATASET_DIR).glob('*.npz'))
len(compressed_datasets)

699

In [2]:
import emd
from neurodsp.filt import filter_signal

def compute_range(x):
    """Return the spread of ``x``: its maximum minus its minimum."""
    highest = x.max()
    lowest = x.min()
    return highest - lowest

def asc2desc(x):
    """Ascending to Descending ratio ( A / A+D ).

    The ascending part is counted as the samples before the peak plus the
    samples after the trough, then divided by the number of samples in
    ``x``. Returns ``np.nan`` when emd reports no peak or no trough.
    """
    peak_loc = emd.cycles.cf_peak_sample(x, interp=True)
    trough_loc = emd.cycles.cf_trough_sample(x, interp=True)
    if peak_loc is None or trough_loc is None:
        return np.nan

    n_samples = len(x)
    # Rising portion: start -> peak, plus trough -> end.
    ascending = peak_loc + (n_samples - trough_loc)
    return ascending / n_samples

def peak2trough(x):
    """Peak to trough ratio ( P / P+T ).

    Computed as the descending zero-crossing sample reported by emd divided
    by the number of samples in ``x``; ``np.nan`` when emd reports none.
    """
    zero_cross = emd.cycles.cf_descending_zero_sample(x, interp=True)
    return np.nan if zero_cross is None else zero_cross / len(x)

def compute_cycles(signal, fs, metadata, band=(5, 12)):
    """Band-pass a signal, find its cycles and return per-cycle metrics.

    The signal is band-pass filtered, its instantaneous phase, frequency and
    amplitude are obtained with a Hilbert transform, and cycles are
    identified from the phase. No EMD sift is performed here; only emd's
    frequency-transform and cycle utilities are used.

    Parameters
    ----------
    signal : array-like
        One-dimensional signal to analyse.
    fs : float
        Sampling rate of ``signal`` in samples per second.
    metadata : dict
        Must provide ``rat_id``, ``study_day``, ``condition``, ``treatment``,
        ``trial_num``, ``state`` and ``interval`` (a ``(start, end)`` pair).
        These values are copied onto every row of the result.
    band : tuple of (float, float), optional
        Pass band ``(low, high)`` in Hz. It sets both the filter and the
        accepted cycle durations. Defaults to ``(5, 12)``.

    Returns
    -------
    The metric dataframe produced by ``Cycles.get_metric_dataframe``.
    It is restricted to cycles that emd marks as good, whose duration lies
    strictly between ``fs / high`` and ``fs / low`` samples, and whose peak
    amplitude exceeds the 25th percentile of the instantaneous amplitude.
    The metadata columns listed above are added, with ``interval`` split
    into ``start`` and ``end``.
    """
    low_hz, high_hz = band
    signal = filter_signal(signal, fs, 'bandpass', (low_hz, high_hz), remove_edges=False)

    # Hilbert-based instantaneous phase / frequency / amplitude of the
    # filtered signal; cycles are then located from the phase.
    IP, IF, IA = emd.spectra.frequency_transform(signal, fs, 'hilbert', smooth_phase=3)
    C = emd.cycles.Cycles(IP.flatten())

    # start/stop are sample *positions*, so both must be read from an index
    # vector. The original took 'stop_sample' from the signal itself, which
    # stored the amplitude at the last sample of each cycle instead.
    sample_index = np.arange(len(C.cycle_vect))
    C.compute_cycle_metric('start_sample', sample_index, emd.cycles.cf_start_value)
    C.compute_cycle_metric('stop_sample', sample_index, emd.cycles.cf_end_value)

    # Control points and shape metrics measured on the filtered waveform.
    C.compute_cycle_metric('peak_sample', signal, emd.cycles.cf_peak_sample)
    C.compute_cycle_metric('desc_sample', signal, emd.cycles.cf_descending_zero_sample)
    C.compute_cycle_metric('trough_sample', signal, emd.cycles.cf_trough_sample)
    C.compute_cycle_metric('duration_samples', signal, len)
    C.compute_cycle_metric('max_amp', IA, np.max)
    C.compute_cycle_metric('mean_if', IF, np.mean)
    C.compute_cycle_metric('max_if', IF, np.max)
    C.compute_cycle_metric('range_if', IF, compute_range)
    C.compute_cycle_metric('asc2desc', signal, asc2desc)
    C.compute_cycle_metric('peak2trough', signal, peak2trough)

    # Keep only cycles whose length is compatible with the pass band
    # (one period of the band edges, in samples) and whose peak amplitude
    # is above the lowest quartile.
    amp_thresh = np.percentile(IA, 25)
    lo_freq_duration = fs / low_hz
    hi_freq_duration = fs / high_hz
    conditions = ['is_good==1',
                  f'duration_samples<{lo_freq_duration}',
                  f'duration_samples>{hi_freq_duration}',
                  f'max_amp>{amp_thresh}']

    df_emd = C.get_metric_dataframe(conditions=conditions)

    # Tag every cycle with the recording metadata.
    df_emd["rat"]       = metadata["rat_id"]
    df_emd["study_day"] = metadata["study_day"]
    df_emd["condition"] = metadata["condition"]
    df_emd["treatment"] = metadata["treatment"]
    df_emd["trial_num"] = metadata["trial_num"]
    df_emd["state"]     = metadata["state"]

    start, end = metadata["interval"]
    df_emd["start"] = start
    df_emd["end"] = end

    return df_emd

In [3]:
# Per-recording pipeline: load HPC LFP and sleep states, detect phasic REM,
# then compute cycle metrics for every phasic and tonic interval.
#
# NOTE(review): this cell raised NameError when run — `mapped_datasets` is not
# defined in any visible cell. `n_down_cbd`, `n_down_rgs`, `n_down_os`,
# `preprocess`, `detect_phasic` and `targetFs` are likewise neither defined
# nor imported in the visible cells, so it cannot run on a fresh kernel.
combined = []

with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)
        # NOTE(review): pfc_fname is unpacked but never used below.
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug("Loading: {0}".format(name))

        # Select n_down from the treatment code (presumably a downsampling
        # factor per dataset — TODO confirm).
        # NOTE(review): n_down is left unassigned when treatment is not 0-4.
        if metadata["treatment"] == 0 or metadata["treatment"] == 1:
            n_down = n_down_cbd
        elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
            n_down = n_down_rgs
        elif metadata["treatment"] == 4:
            n_down = n_down_os
        
        # Load the LFP data
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()

        # Load the states
        hypno = loadmat(states_fname)['states'].flatten()
        
        # Skip if no REM epoch is detected (state code 5 is treated as REM)
        if(not (np.any(hypno == 5))):
            logger.debug("No REM detected. Skipping.")
            continue
        # NOTE(review): the lengths are summed over all sequences, so this
        # looks like a test of total REM time rather than of the longest
        # bout that the log message describes — confirm against get_sequences.
        elif(np.sum(np.diff(get_sequences(np.where(hypno == 5)[0]))) < 10):
            logger.debug("No REM longer than 10s. Skipping.")
            continue
        
        # Detect phasic intervals
        lfpHPC_down = preprocess(lfpHPC, n_down)
        phrem = detect_phasic(lfpHPC_down, hypno, targetFs)
        # NOTE(review): this unconditional break ends the loop after the first
        # recording that reaches it; everything below in the loop body is
        # unreachable.
        break
        t = np.arange(0, len(lfpHPC_down)/targetFs, 1/targetFs)
        lfp = nap.Tsd(t=t, d=lfpHPC_down)
        
        # Collect REM epoch bounds (the keys of phrem, used as-is) and phasic
        # bounds (sample indices converted to seconds).
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / targetFs)
                end.append(e / targetFs)
        
        # Tonic = REM minus phasic.
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end)
        tonic_interval = rem_interval.set_diff(phasic_interval)

        # Drop intervals shorter than the 0.6 threshold.
        phasic_interval = phasic_interval.drop_short_intervals(0.6)
        tonic_interval = tonic_interval.drop_short_intervals(0.6)
        
        # Compute waveform dynamics for each interval. The same metadata dict
        # is updated in place before every call.
        metadata['state'] = 'phasic'
        for i in range(len(phasic_interval)):
            metadata['interval'] = (phasic_interval[i]['start'].item(), phasic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp.restrict(phasic_interval[i]).to_numpy(), lfp.rate, metadata)
            combined.append(df_emd)
            
        metadata['state'] = 'tonic'
        for i in range(len(tonic_interval)):
            metadata['interval'] = (tonic_interval[i]['start'].item(), tonic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp.restrict(tonic_interval[i]).to_numpy(), lfp.rate, metadata)
            combined.append(df_emd)

NameError: name 'mapped_datasets' is not defined

In [5]:
# Build REM / phasic / tonic interval sets for the first dataset that contains
# REM epochs, leaving them in the kernel namespace for inspection.
combined = []

with tqdm(compressed_datasets) as datasets:
    for fname in datasets:
        metaname = str(fname.stem)

        datasets.set_postfix_str(metaname)
        metadata = get_metadata(metaname)

        rem_epochs = load_data(fname)

        # Skip archives that contain no arrays.
        if not rem_epochs:
            continue
        
        phrem = detect_phasic_v2(rem_epochs, fs)
        
        # REM bounds come straight from the phrem keys; phasic bounds are
        # sample indices converted to seconds with fs.
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / fs)
                end.append(e / fs)
        
        # Tonic = REM minus phasic.
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end)
        tonic_interval = rem_interval.set_diff(phasic_interval)

        # NOTE(review): this unconditional break stops after the first dataset
        # with REM epochs. The rest of the loop body is unreachable and refers
        # to `lfpHPC_down`, `targetFs` and `per_trial_stats`, none of which is
        # defined in the visible cells.
        break
        t = np.arange(0, len(lfpHPC_down)/targetFs, 1/targetFs)
        lfp = nap.Tsd(t=t, d=lfpHPC_down)
        
        # NOTE(review): repeats the interval construction above, but divides
        # by targetFs instead of fs.
        start, end = [], []
        rem_start, rem_end = [], []
        for rem_idx in phrem:
            rem_start.append(rem_idx[0])
            rem_end.append(rem_idx[1])

            for s, e in phrem[rem_idx]:
                start.append(s / targetFs)
                end.append(e / targetFs)
        
        rem_interval = nap.IntervalSet(rem_start, rem_end)
        phasic_interval = nap.IntervalSet(start, end)
        tonic_interval = rem_interval.set_diff(phasic_interval)

        # Drop intervals shorter than the 0.6 threshold.
        phasic_interval = phasic_interval.drop_short_intervals(0.6)
        tonic_interval = tonic_interval.drop_short_intervals(0.6)
        
        # Compute waveform dynamics for each interval. The same metadata dict
        # is updated in place before every call.
        metadata['state'] = 'phasic'
        for i in range(len(phasic_interval)):
            metadata['interval'] = (phasic_interval[i]['start'].item(), phasic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp.restrict(phasic_interval[i]).to_numpy(), lfp.rate, metadata)
            combined.append(df_emd)
            
        metadata['state'] = 'tonic'
        for i in range(len(tonic_interval)):
            metadata['interval'] = (tonic_interval[i]['start'].item(), tonic_interval[i]['end'].item())
            df_emd = compute_cycles(lfp.restrict(tonic_interval[i]).to_numpy(), lfp.rate, metadata)
            combined.append(df_emd)

        # Rewrite trial labels '5-0'..'5-3' as '5.1'..'5.4'.
        if metadata['trial_num'] in ['5-0', '5-1', '5-2', '5-3']:
            a, b = metadata['trial_num'].split('-')
            metadata['trial_num'] = a + '.' + str(int(b)+1)
        
        # Save duration bouts
        # NOTE(review): by this point metadata also holds the 'state' key set
        # above (and 'interval' whenever an interval loop ran), so the inner
        # loop appends metadata['state'] to per_trial_stats['state'] and the
        # line after it appends `state` again — the columns would end up with
        # unequal lengths.
        for state, interval in [("phasic", phasic_interval), ("tonic", tonic_interval)]:
            for condition in metadata.keys():
                per_trial_stats[condition].append(metadata[condition])
            per_trial_stats['state'].append(state)
            per_trial_stats['total_duration'].append(interval.tot_length())
            per_trial_stats['num_bouts'].append(len(interval))

# df_trial = pd.DataFrame(per_trial_stats)

  0%|          | 0/699 [00:00<?, ?it/s]

In [6]:
# Cut each REM epoch's LFP into its phasic segments; an epoch with no phasic
# activity is kept whole as a single tonic segment.
# TODO: the tonic remainder of epochs that do contain phasic activity is not
# extracted yet, and `phasic` / `tonic` are not stored anywhere (`combined`
# stays empty).
combined = []

with tqdm(compressed_datasets) as datasets:
    for fname in datasets:
        metaname = str(fname.stem)

        datasets.set_postfix_str(metaname)
        metadata = get_metadata(metaname)

        rem_epochs = load_data(fname)

        # Skip archives that contain no arrays.
        if not rem_epochs:
            continue

        phrem = detect_phasic_v2(rem_epochs, fs)

        for rem_idx in rem_epochs:
            lfpREM = rem_epochs[rem_idx]
            phasic_intervals = phrem[rem_idx]

            # Epoch keys are in seconds, while the phasic bounds are sample
            # indices counted from the start of the recording (dividing them
            # by fs lands inside the epoch). Shift them so they index into
            # this epoch's own array.
            epoch_offset = int(rem_idx[0] * fs)

            phasic, tonic = [], []
            if phasic_intervals:
                # seg_start / seg_end rather than start / end, so the
                # `start` / `end` lists that other cells display are not
                # overwritten.
                for seg_start, seg_end in phasic_intervals:
                    seg_lo = seg_start - epoch_offset
                    seg_hi = seg_end - epoch_offset
                    if seg_lo < 0:
                        # A negative bound would silently wrap around.
                        raise ValueError(
                            f"phasic segment {(seg_start, seg_end)} starts before REM epoch {rem_idx}"
                        )
                    # The original sliced `lfp`, which is not defined in the
                    # visible cells; the epoch's own signal is `lfpREM`.
                    phasic.append(lfpREM[seg_lo:seg_hi])
            else:
                tonic.append(lfpREM)
            


{(9121,
  9168): array([-521.40680664, -425.96849279, -371.52721119, ...,  -47.62286642,
         -41.44032593,  -24.79185434]),
 (9430,
  9468): array([-352.73359408, -372.59991162, -369.94954227, ..., -238.26998387,
        -208.68342674, -167.58037353]),
 (9738,
  9820): array([-111.45551596, -177.1756736 , -245.16593697, ...,   18.80917785,
         -23.71035767,  -96.97847569])}

In [13]:
# Show the phasic intervals built from the `start` / `end` lists.
# NOTE(review): relies on `start` and `end` left in the kernel by an earlier cell.
nap.IntervalSet(start=start, end=end)

            start      end
       0  9158     9158.99
       1  9788.73  9790.73
       2  9808.62  9810.67
       3  9811.39  9812.98
shape: (4, 2), time unit: sec.

In [21]:
# Print each entry of rem_interval together with its position.
# NOTE(review): relies on `rem_interval` left in the kernel by an earlier cell.
for i, interval in enumerate(rem_interval):
    print(i, interval)

0             start    end
       0     9121   9168
shape: (1, 2), time unit: sec.
1             start    end
       0     9430   9468
shape: (1, 2), time unit: sec.
2             start    end
       0     9738   9820
shape: (1, 2), time unit: sec.


In [19]:
# Sanity check: convert sample index 4579495 to seconds at 500 samples/s
# (presumably a phasic end sample taken from phrem — TODO confirm).
4579495/500

9158.99

In [14]:
# Display the REM epochs as an IntervalSet.
# NOTE(review): relies on `rem_interval` left in the kernel by an earlier cell.
rem_interval

            start    end
       0     9121   9168
       1     9430   9468
       2     9738   9820
shape: (3, 2), time unit: sec.

In [15]:
# Display the tonic intervals, i.e. rem_interval.set_diff(phasic_interval).
# NOTE(review): relies on `tonic_interval` left in the kernel by an earlier cell.
tonic_interval

            start      end
       0  9121     9158
       1  9158.99  9168
       2  9430     9468
       3  9738     9788.73
       4  9790.73  9808.62
       5  9810.67  9811.39
       6  9812.98  9820
shape: (7, 2), time unit: sec.