# Load the datasets

The path to dataset directory and patterns to search in those directories for the HPC, PFC recordings are in loaded from the config file.

In [24]:
from phasic_tonic.detect_phasic import detect_phasic
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences

import numpy as np
import pandas as pd
import pynapple as nap
import yasa
import os

from tqdm import tqdm
from scipy.io import loadmat
from mne.filter import resample

fs_cbd = 2500
fs_os = 2500
fs_rgs = 1000

targetFs = 500
n_down_cbd = fs_cbd/targetFs
n_down_rgs = fs_rgs/targetFs
n_down_os = fs_os/targetFs

logger = logger_setup()

CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"
OUTPUT_DIR = "/home/nero/datasets/preprocessed"

Datasets = DatasetLoader(CONFIG_DIR)
mapped_datasets = Datasets.load_datasets()

def preprocess(signal: np.ndarray, n_down: int, target_fs=500) -> np.ndarray:
    """Downsample and remove artifacts."""
    
    logger.debug("STARTED: Resampling to 500 Hz.")
    # Downsample to 500 Hz
    data = resample(signal, down=n_down, method='fft', npad='auto')
    logger.debug("FINISHED: Resampling to 500 Hz.")
    logger.debug("Resampled: {0} -> {1}.".format(str(signal.shape), str(data.shape)))
    
    logger.debug("STARTED: Remove artifacts.")
    # Remove artifacts
    art_std, _ = yasa.art_detect(data, target_fs , window=1, method='std', threshold=4)
    art_up = yasa.hypno_upsample_to_data(art_std, 1, data, target_fs)
    data[art_up] = 0
    logger.debug("FINISHED: Remove artifacts.")
        
    data -= data.mean()
    return data

def get_start_end(hypno: np.ndarray, sleep_state_id: int):
    """Convert sleep states into lists of start and end time indices."""
    seq = get_sequences(np.where(hypno == sleep_state_id)[0])
    start = []
    end = []
    for s, e in seq:
        start.append(s)
        end.append(e)
    return (start, end)

def _detect_troughs(signal, thr):
    lidx  = np.where(signal[0:-2] > signal[1:-1])[0]
    ridx  = np.where(signal[1:-1] <= signal[2:])[0]
    thidx = np.where(signal[1:-1] < thr)[0]
    sidx = np.intersect1d(lidx, np.intersect1d(ridx, thidx))+1
    return sidx

Check the number of loaded recordings

In [25]:
cbd_cnt = 0
rgs_cnt = 0
os_cnt = 0

# Count recordings belonging to CBD dataset
for name in mapped_datasets:
    metadata = get_metadata(name)
    if metadata['treatment'] == 0 or metadata['treatment'] == 1:
        cbd_cnt += 1
    elif metadata['treatment'] == 2 or metadata['treatment'] == 3:
        rgs_cnt += 1
    elif metadata['treatment'] == 4:
        os_cnt += 1

assert cbd_cnt == 170
assert rgs_cnt == 159
assert os_cnt == 210

# Loop through the dataset

In [33]:
for name in mapped_datasets:
    metadata = get_metadata(name)
    states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
    logger.debug("Loading: {0}".format(name))

    if metadata["treatment"] == 0 or metadata["treatment"] == 1:
        n_down = n_down_cbd
    elif metadata["treatment"] == 2 or metadata["treatment"] == 3:
        n_down = n_down_rgs
    elif metadata["treatment"] == 4:
        n_down = n_down_os
    
    # Load the LFP data
    lfpHPC = loadmat(hpc_fname)['HPC'].flatten()

    # Load the states
    hypno = loadmat(states_fname)['states'].flatten()
    
    # Skip if no REM epoch is detected
    if(not (np.any(hypno == 5))):
        logger.debug("No REM detected. Skipping.")
        continue
    
    # Detect phasic intervals
    lfpHPC_down = preprocess(lfpHPC, n_down)
    fname = os.path.join(OUTPUT_DIR, name)
    print(f"Saving to: {fname}")
    np.save(fname, lfpHPC_down) 

Saving to: /home/nero/datasets/preprocessed/Rat5_SD8_HC_0_posttrial3
Saving to: /home/nero/datasets/preprocessed/Rat5_SD8_HC_0_posttrial2


  rem_interval = nap.IntervalSet(start=start, end=end)


Saving to: /home/nero/datasets/preprocessed/Rat5_SD8_HC_0_posttrial4
Saving to: /home/nero/datasets/preprocessed/Rat5_SD8_HC_0_posttrial5


KeyboardInterrupt: 

In [27]:
t

array([0.0000000e+00, 4.0000000e-04, 8.0000000e-04, ..., 2.7011404e+03,
       2.7011408e+03, 2.7011412e+03])

In [23]:
tqdm(mapped_datasets)

  0%|          | 0/539 [00:00<?, ?it/s]

<tqdm.auto.tqdm at 0x7fec5a21b350>

## Access the HPC and PFC signals during phasic REM

In [32]:
lfp.restrict(phasic_interval)

AttributeError: 'numpy.ndarray' object has no attribute 'restrict'

In [20]:
phrem_hpc = [lfp["HPC"].restrict(phasic_interval[i]).to_numpy() for i in range(len(phasic_interval))]
tonic_hpc = [lfp["HPC"].restrict(tonic_interval[i]).to_numpy() for i in range(len(tonic_interval))]

In [24]:
phrem_hpc

[array([-226.95125111, -217.58033104, -198.25506026, ...,  -94.37276814,
        -216.52844067, -190.42849516]),
 array([-357.54982828, -394.76025735, -339.96985809, ..., -200.31561612,
        -230.31751253, -281.52736263]),
 array([ -39.15300095,  -44.30848611, -141.90788946, ..., -287.74691549,
        -263.49530349, -289.73187494]),
 array([-331.68989302, -295.6511893 , -314.86898616, ..., -283.82479763,
        -203.91310409, -315.64129598])]

In [21]:
name

'Rat5_SD8_HC_0_posttrial3'

In [25]:
np.savez(name+'_phasic', *phrem_hpc)
np.savez(name+'_tonic', *tonic_hpc)

In [31]:
data = np.load('Rat5_SD8_HC_0_posttrial3_phasic.npz')
phrem = [data[key] for key in data]
phrem

[array([-226.95125111, -217.58033104, -198.25506026, ...,  -94.37276814,
        -216.52844067, -190.42849516]),
 array([-357.54982828, -394.76025735, -339.96985809, ..., -200.31561612,
        -230.31751253, -281.52736263]),
 array([ -39.15300095,  -44.30848611, -141.90788946, ..., -287.74691549,
        -263.49530349, -289.73187494]),
 array([-331.68989302, -295.6511893 , -314.86898616, ..., -283.82479763,
        -203.91310409, -315.64129598])]

## Access HPC and PFC signals during tonic REM

In [38]:
lfp.restrict(tonic_interval)

Time (s)           HPC       PFC
----------  ----------  --------
1331.0      -180.893    -93.2364
1331.0004   -150.785    -69.5545
1331.0008   -132.412    -70.9175
1331.0012    -85.3764   -50.1542
1331.0016    -79.2403   -50.5931
1331.002     -75.6617   -49.3135
1331.0024    -57.6323   -34.4635
...
2474.9976   -125.165     76.2955
2474.998    -118.796     64.7644
2474.9984     -4.49443  173.913
2474.9988    -37.3338   133.221
2474.9992   -137.777     25.3803
2474.9996   -103.864     80.9616
2475.0       -36.7785   134.176
dtype: float64, shape: (479936, 2)