# Load the datasets

The paths to the dataset directories, and the filename patterns used to locate the HPC and PFC recordings within them, are loaded from the config file.

In [1]:
from phasic_tonic.detect_phasic import detect_phasic
from phasic_tonic.DatasetLoader import DatasetLoader
from phasic_tonic.helper import get_metadata
from phasic_tonic.runtime_logger import logger_setup
from phasic_tonic.utils import get_sequences, get_segments

import numpy as np
import pandas as pd
import pynapple as nap
import yasa

from tqdm.auto import tqdm
from scipy.io import loadmat
from mne.filter import resample

# Acquisition sampling rates (Hz) of the CBD, OS and RGS recordings.
fs_cbd, fs_os, fs_rgs = 2500, 2500, 1000

# Common sampling rate (Hz) every recording is brought down to.
targetFs = 500

# Downsampling factor per dataset: acquisition rate / target rate.
# True division is intentional here, so the factors are floats.
n_down_cbd, n_down_rgs, n_down_os = (
    fs_cbd / targetFs,
    fs_rgs / targetFs,
    fs_os / targetFs,
)

# Logger used for the debug tracing throughout this notebook.
logger = logger_setup()

# NOTE: despite the name, CONFIG_DIR points at a YAML file, not a directory.
CONFIG_DIR = "/home/nero/phasic_tonic/data/dataset_loading.yaml"
# Prefix for the saved .npz files. It is concatenated directly with the file
# name when saving, so the trailing slash is required.
OUTPUT_DIR = "/home/nero/datasets/preprocessed/"

# Build the loader from the config and collect the recordings it describes.
# The result is iterated by recording name and indexed by that name to get a
# (sleep-states file, HPC file, PFC file) triple.
Datasets = DatasetLoader(CONFIG_DIR)
mapped_datasets = Datasets.load_datasets()

def preprocess(signal: np.ndarray, n_down: float, target_fs=500) -> np.ndarray:
    """Downsample a signal, zero out artifact windows and mean-centre it.

    Parameters
    ----------
    signal : np.ndarray
        Raw recording to process.
    n_down : float
        Downsampling factor, forwarded to ``resample`` as ``down``. Callers
        pass ``original_rate / target_rate``, which is a float.
    target_fs : int, optional
        Sampling rate of the signal *after* downsampling. It is only used to
        tell YASA how many samples make up one window; it does not drive the
        resampling itself, so it must agree with ``n_down``.

    Returns
    -------
    np.ndarray
        The resampled signal with every sample inside a detected artifact
        window set to 0, after which the mean of the whole array is
        subtracted.
    """

    logger.debug("STARTED: Resampling to 500 Hz.")
    # FFT-based downsampling by the requested factor.
    data = resample(signal, down=n_down, method='fft', npad='auto')
    logger.debug("FINISHED: Resampling to 500 Hz.")
    logger.debug("Resampled: {0} -> {1}.".format(str(signal.shape), str(data.shape)))

    logger.debug("STARTED: Remove artifacts.")
    # Flag windows (window=1) whose standard-deviation z-score exceeds 4,
    # then stretch the per-window flags to one flag per sample of `data`.
    art_std, _ = yasa.art_detect(data, target_fs, window=1, method='std', threshold=4)
    art_up = yasa.hypno_upsample_to_data(art_std, 1, data, target_fs)
    # Force a boolean mask: if the flags came back as 0/1 integers, indexing
    # with them directly would be fancy indexing and only zero samples 0 and 1.
    artifact_mask = np.asarray(art_up, dtype=bool)
    data[artifact_mask] = 0
    logger.debug("FINISHED: Remove artifacts.")

    # Centre after blanking, so blanked samples end up at -mean, not at 0.
    data -= data.mean()
    return data

def partition_to_4(rem_dict, region_len=2700):
    """Split a dictionary of REM epochs into four consecutive regions.

    Each key is a ``(start, end)`` tuple and the epoch is assigned to a region
    by its ``end`` value alone:

    * region 0: ``end < region_len``
    * region 1: ``region_len <= end < 2 * region_len``
    * region 2: ``2 * region_len <= end < 3 * region_len``
    * region 3: everything else (no upper bound)

    Parameters
    ----------
    rem_dict : dict
        Maps ``(start, end)`` tuples to the data of that epoch.
    region_len : int, optional
        Width of one region, in the same units as ``end``. Defaults to 2700
        (presumably 45 minutes expressed in seconds -- confirm with caller).

    Returns
    -------
    list of dict
        Always four dictionaries (possibly empty), each holding the entries
        of ``rem_dict`` that fall into that region, inserted in sorted key
        order.

    Raises
    ------
    ValueError
        If ``region_len`` is not positive.
    """
    if region_len <= 0:
        raise ValueError("region_len must be positive, got {0}".format(region_len))

    # Lower bounds of regions 1, 2 and 3.
    cutoffs = (region_len, 2 * region_len, 3 * region_len)
    partitions = [{} for _ in range(4)]

    for rem_idx in sorted(rem_dict):
        _, end = rem_idx
        # The number of cutoffs at or below `end` is the region index (0..3).
        region = sum(end >= cutoff for cutoff in cutoffs)
        partitions[region][rem_idx] = rem_dict[rem_idx]

    return partitions

Check the number of loaded recordings

In [2]:
# Tally the loaded recordings per dataset, using the treatment code found in
# each recording's metadata: 0/1 -> CBD, 2/3 -> RGS, 4 -> OS.
cbd_cnt = rgs_cnt = os_cnt = 0

for name in mapped_datasets:
    metadata = get_metadata(name)
    treatment = metadata['treatment']
    if treatment in (0, 1):
        cbd_cnt += 1
    elif treatment in (2, 3):
        rgs_cnt += 1
    elif treatment == 4:
        os_cnt += 1

# Sanity check against the expected number of recordings in each dataset.
assert cbd_cnt == 170
assert rgs_cnt == 159
assert os_cnt == 210

# Loop through the dataset

In [3]:
# REM bouts are kept only when (end - start) exceeds this many hypnogram
# entries. Constant across recordings, so it is defined once outside the loop.
min_duration = 3

# For every recording: find the REM bouts in the hypnogram, cut the matching
# stretches out of the preprocessed HPC signal and save them to .npz files.
with tqdm(mapped_datasets) as mapped_tqdm:
    for name in mapped_tqdm:
        metadata = get_metadata(name)
        mapped_tqdm.set_postfix_str(name)

        logger.debug("Loading: {0}".format(name))
        states_fname, hpc_fname, pfc_fname = mapped_datasets[name]
        logger.debug(f"Sleep States file: {states_fname}")
        logger.debug(f"HPC file: {hpc_fname}")
        logger.debug(f"PFC file: {pfc_fname}")

        # Pick the downsampling factor of the dataset this recording is from.
        treatment = metadata["treatment"]
        if treatment == 0 or treatment == 1:
            n_down = n_down_cbd
        elif treatment == 2 or treatment == 3:
            n_down = n_down_rgs
        elif treatment == 4:
            n_down = n_down_os
        else:
            # Without this branch n_down would be undefined, or silently keep
            # the previous recording's value and resample at the wrong ratio.
            raise ValueError(
                f"Unknown treatment {treatment!r} for recording {name!r}; "
                "cannot choose a downsampling factor."
            )

        # Load the hypnogram first: recordings without usable REM are skipped
        # before the LFP file is read.
        hypno = loadmat(states_fname)['states'].flatten()

        logger.debug("STARTED: Extract REM epochs.")
        # State code 5 marks REM in the hypnogram.
        if not np.any(hypno == 5):
            logger.debug("No REM detected. Skipping.")
            continue

        # (start, end) hypnogram indices of each sufficiently long REM run.
        rem_idx = [
            (start, end)
            for start, end in get_sequences(np.where(hypno == 5)[0])
            if (end - start) > min_duration
        ]

        if not rem_idx:
            logger.debug("No REM epochs greater than min_dur.")
            logger.debug(f"{rem_idx}")
            continue

        # Convert hypnogram indices to sample indices at the target rate;
        # `end` is inclusive, hence the +1 to cover its full last entry.
        rem_seq = [(start*targetFs, (end+1)*targetFs) for start, end in rem_idx]
        logger.debug("FINISHED: Extract REM epochs.")
        logger.debug(f"REM Epochs: {rem_seq}")

        # Load the LFP data only now that it is known to be needed.
        lfpHPC = loadmat(hpc_fname)['HPC'].flatten()
        lfpHPC_down = preprocess(lfpHPC, n_down)
        rem_seg = get_segments(rem_seq, lfpHPC_down)

        # Map each (start, end) hypnogram index pair to its signal segment.
        rem = {idx: array for idx, array in zip(rem_idx, rem_seg)}
        del rem_seg, rem_seq

        if metadata["trial_num"] == '5':
            # Trial 5 is split into four regions, saved as one file each.
            for i, partition in enumerate(partition_to_4(rem)):
                # NOTE(review): this label is 1-based while the file suffix
                # below is 0-based, and the updated metadata is not read
                # again in this cell -- confirm which numbering is intended.
                metadata["trial_num"] = '5-' + str(i+1)
                logger.debug("Partition: {0}".format(str(partition.keys())))

                fname = f"{name}-{i}.npz"
                logger.debug(f"Saving to {fname}.")

                # npz entry names must be strings, so the tuple keys are
                # stringified.
                np.savez(OUTPUT_DIR+fname, **{str(key): value for key, value in partition.items()})
        else:
            fname = f"{name}.npz"
            logger.debug(f"Saving to {fname}.")

            # npz entry names must be strings, so the tuple keys are
            # stringified.
            np.savez(OUTPUT_DIR+fname, **{str(key): value for key, value in rem.items()})

  0%|          | 0/539 [00:00<?, ?it/s]