In [34]:
import os
import numpy as np
import json

from pyddeeg import METRIC_NAME_TO_INDEX

# Root of the processed RQA-window dataset. The hard-coded path is only the
# author's local default; set the PYDDEEG_RQA_ROOT environment variable to
# point the notebook at another copy of the data without editing this cell.
ROOT = os.environ.get(
    "PYDDEEG_RQA_ROOT",
    "/home/mario/Python/Datasets/EEG/timeseries/processed/rqa_windows/",
)
ELECTRODE_INDEXED = os.path.join(ROOT, "electrode_indexed")

# Load the dataset index. get_patient_data looks it up as
# DATASET_INDEXES[window_size][direction][electrode] -> list of file paths.
# The `with` block closes the file handle instead of leaking it until GC.
with open(os.path.join(ELECTRODE_INDEXED, "dataset_index.json"), "r") as index_file:
    DATASET_INDEXES = json.load(index_file)

# Display the metric-name -> metric-axis index mapping used when slicing metrics.
METRIC_NAME_TO_INDEX

{'RR': 0,
 'DET': 1,
 'L_max': 2,
 'L_mean': 3,
 'ENT': 4,
 'LAM': 5,
 'TT': 6,
 'V_max': 7,
 'V_mean': 8,
 'V_ENT': 9,
 'W_max': 10,
 'W_mean': 11,
 'W_ENT': 12,
 'CLEAR': 13,
 'PERM_ENT': 14}

In [41]:
import os
import numpy as np
import mne
from pyddeeg import METRIC_NAME_TO_INDEX

def _find_group_file(file_paths, electrode, group, direction):
    """Return the first path whose filename contains ``{electrode}_{group}_{DIRECTION}``, or None."""
    tag = f"{electrode}_{group}_{direction.upper()}"
    return next((path for path in file_paths if tag in os.path.basename(path)), None)


def get_patient_data(window_size, direction, group, metric_name, patient_index,
                     sfreq=250.0, ch_type='eeg'):
    """
    Extract patient data for a specific metric across all electrodes and convert to MNE object.

    Electrodes with no file for ``group``, with an unreadable file, or with
    too few patients are skipped with a printed message; the remaining
    electrodes become the channels of the returned object.

    Parameters:
    -----------
    window_size (str): Size of the window (e.g., 'window_50'); a top-level key of DATASET_INDEXES
    direction (str): Direction of stimulus ('up' or 'down')
    group (str): Patient group ('CT' or 'DD')
    metric_name (str): Name of the RQA metric (a key of METRIC_NAME_TO_INDEX)
    patient_index (int): Index of the patient along axis 0 of each file's 'metrics' array
    sfreq (float, optional): Sampling frequency in Hz, defaults to 250.0.
        NOTE(review): confirm this matches the spacing of the RQA windows,
        not the original EEG sampling rate.
    ch_type (str, optional): MNE channel type assigned to every channel,
        defaults to 'eeg'. RQA metrics are not voltages, so a type such as
        'misc' may give more sensible plot scaling.

    Returns:
    --------
    mne.io.RawArray: MNE object containing the data, one channel per electrode
        in index order.

    Raises:
    -------
    ValueError: If window_size, direction or metric_name is unknown, if no
        electrode yields data, or if electrodes yield series of different lengths.
    """
    # Validate inputs
    if window_size not in DATASET_INDEXES:
        raise ValueError(f"Invalid window_size. Available options: {list(DATASET_INDEXES.keys())}")

    if direction not in DATASET_INDEXES[window_size]:
        raise ValueError(f"Invalid direction. Available options: {list(DATASET_INDEXES[window_size].keys())}")

    if metric_name not in METRIC_NAME_TO_INDEX:
        raise ValueError(f"Invalid metric_name. Available options: {list(METRIC_NAME_TO_INDEX.keys())}")

    metric_index = METRIC_NAME_TO_INDEX[metric_name]

    electrode_data = []
    electrode_names = []

    # electrode -> list of candidate file paths (one per group, presumably)
    for electrode, file_paths in DATASET_INDEXES[window_size][direction].items():
        file_path = _find_group_file(file_paths, electrode, group, direction)

        if file_path is None:
            print(f"Warning: No data found for electrode {electrode} in group {group}")
            continue

        try:
            # NpzFile keeps the archive open; the context manager closes it.
            # Indexing the archive reads 'metrics' fully into memory first.
            with np.load(file_path) as npz_file:
                metrics_data = npz_file['metrics']

            # Assumed layout (n_patients, n_metrics, n_windows): axis 0 is
            # compared with patient_index and axis 1 is indexed by the metric.
            # TODO confirm axis 2 is the window/time axis.
            if patient_index >= metrics_data.shape[0]:
                print(f"Warning: Patient index {patient_index} out of range for {electrode} in group {group}")
                continue

            data = metrics_data[patient_index, metric_index, :]
        except (OSError, KeyError, ValueError, IndexError) as e:
            # Best-effort per electrode: skip unreadable or malformed files,
            # but let unrelated programming errors propagate.
            print(f"Error loading data for electrode {electrode}: {e}")
            continue

        electrode_data.append(data)
        electrode_names.append(electrode)

    if not electrode_data:
        raise ValueError("No valid data found for the specified parameters")

    # np.array on ragged rows gives an opaque error (or an object array on old
    # NumPy), so report the per-electrode lengths explicitly.
    if len({series.shape[0] for series in electrode_data}) > 1:
        lengths = ", ".join(
            f"{name}={series.shape[0]}" for name, series in zip(electrode_names, electrode_data)
        )
        raise ValueError(f"Electrodes have different numbers of samples: {lengths}")

    # (n_channels, n_samples), the layout mne.io.RawArray expects.
    data_tensor = np.stack(electrode_data)

    info = mne.create_info(
        ch_names=electrode_names,
        sfreq=sfreq,
        ch_types=[ch_type] * len(electrode_names),
    )

    return mne.io.RawArray(data_tensor, info)

In [42]:
# Example: the RR metric for the first control-group ("CT") patient,
# window size 50, upward-stimulus recordings.
example_request = {
    "window_size": "window_50",
    "direction": "up",
    "group": "CT",
    "metric_name": "RR",
    "patient_index": 0,
}
raw = get_patient_data(**example_request)

# Browse the per-electrode metric series.
raw.plot()

# Pull the (channels x samples) array back out of the RawArray.
data_tensor = raw.get_data()
print(f"Shape of data tensor: {data_tensor.shape}")

ImportError: cannot import name 'broadcast_to' from 'numpy.lib.stride_tricks' (/home/mario/miniconda3/envs/pyddeeg/lib/python3.10/site-packages/numpy/lib/stride_tricks.py)

In [43]:
import numpy as np
import numpy.lib.stride_tricks as st

# Compatibility shim: this NumPy install does not expose `broadcast_to` from
# `numpy.lib.stride_tricks` (see the ImportError above), presumably because
# NumPy >= 2.0 trimmed that namespace while an older dependency still imports
# the name from there. Re-export the public `np.broadcast_to` under the old
# location. This only helps imports that happen AFTER this cell runs, so it
# must execute before the failing library code is first loaded. Upgrading that
# dependency (or pinning numpy<2) is the durable fix.
try:
    st.broadcast_to  # probe only; equivalent to hasattr()
except AttributeError:
    st.broadcast_to = np.broadcast_to