In [1]:
import os

nthreads = "32" # 64 on synapse
os.environ["OMP_NUM_THREADS"] = nthreads
os.environ["OPENBLAS_NUM_THREADS"] = nthreads
os.environ["MKL_NUM_THREADS"] = nthreads
os.environ["VECLIB_MAXIMUM_THREADS"] = nthreads
os.environ["NUMEXPR_NUM_THREADS"] = nthreads
import matplotlib.pyplot as plt
import scipy.stats
import sklearn
from sklearn.ensemble import (HistGradientBoostingRegressor, 
                              RandomForestRegressor)
from sklearn.impute import SimpleImputer
import numpy as np
import sklearn.linear_model
from scipy.interpolate import CubicSpline
import sklearn.model_selection
from typing import List, Dict, Union, Optional
from typing import Any
import pickle
import seaborn as sns
import scipy
from numpy.lib.stride_tricks import sliding_window_view

In [2]:
def parse_filename(filename: str | os.PathLike) -> dict[str,str]:
    """parse filename that are somewhat like BIDS but not rigoursly like it.

    Args:
        filename (str | os.PathLike): The filename to be parsed

    Returns:
        dict[str,str]: The filename parts
    """
    splitted_filename = filename.split('_')
    filename_parts = {}
    for part in splitted_filename:
        splitted_part = part.split('-')
        if splitted_part[0] in ['sub','ses','run','task']:
            label, value = splitted_part
            filename_parts[label] = value
        
    return filename_parts

def combine_data_from_filename(reading_dir: str | os.PathLike,
                               task:str = "checker",
                               run: str = "01"):
    """Combine the data from the files in the reading directory.

    Args:
        reading_dir (str | os.PathLike): The directory where the data is stored.
        task (str, optional): The task to concatenate. Defaults to "checker".
        run (str, optional): Either it's run-01 or run-01BlinksRemoved. 
                             Defaults to "01".

    Returns:
        _type_: _description_
    """
    big_data = dict()
    filename_list = os.listdir(reading_dir)
    for filename in filename_list:
        filename_parts = parse_filename(filename)
        subject = filename_parts["sub"]
        with open(os.path.join(reading_dir,filename), 'rb') as file: 
            data = pickle.load(file)
        if task in filename_parts['task'] and filename_parts['run'] == run:
            wrapped_data = {
                f'ses-{filename_parts["ses"]}':{
                    filename_parts["task"]:{
                        f'run-{filename_parts["run"]}': data
                    }
                }
            }
            if big_data.get(f'sub-{subject}'):
                big_data[f'sub-{subject}'].update(wrapped_data)
            else:
                big_data[f'sub-{subject}'] = wrapped_data


    return big_data

big_d = combine_data_from_filename('/data2/Projects/eeg_fmri_natview/derivatives/multimodal_prediction_models/data_prep/prediction_model_data_eeg_features_v2/group_data_Hz-3.8',
                                    task = 'checker',
                                    run = '01BlinksRemoved')

In [3]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray | list:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return concatenation_list

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates']
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        masks.append(temp_mask.flatten() > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates']
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities)
    
    windowed_mask = sliding_window_view(joined_masks[:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if np.ndim(windowed_mask) < 3:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = 1)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [4]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = '01BlinksRemoved',
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

TypeError: generate_key_list() missing 1 required positional argument: 'runs'

In [5]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = '01BlinksRemoved'
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

SyntaxError: invalid syntax. Perhaps you forgot a comma? (<ipython-input-5-5666ffb7b577>, line 4)

In [6]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = '01BlinksRemoved',
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

In [7]:
array

[]

In [8]:
keys_list

[]

In [9]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

In [10]:
array

[array([[ 732.19      ,  721.63075631,  731.01113221,  753.43031701,
          781.9875    ,  809.78187049,  829.91261779,  835.47893119,
          819.58      ,  779.46340235,  728.97027164,  686.09013012,
          668.8125    ,  687.99756699,  725.98817064,  757.99781397,
          759.24      ,  713.63422033,  639.92392081,  565.55853589,
          517.9875    ,  517.05578608,  552.19052111,  605.21437059,
          657.95      ,  694.58314317,  708.75180725,  696.4570677 ,
          653.7       ,  581.37339904,  499.93693739,  434.74200704,
          411.14      ,  445.37544818,  517.26544321,  597.52021663,
          656.85      ,  674.3974645 ,  663.03503979,  644.06759519,
          638.8       ,  660.9353579 ,  693.76971012,  712.99733229,
          694.3125    ,  624.62223671,  535.68455721,  470.4702241 ,
          471.95      ,  566.8488593 ,  716.90862353,  867.625326  ,
          964.495     ,  968.41877139,  901.91813616,  802.91968285,
          709.35      ,  652.13738

In [11]:
array.shape

AttributeError: 'list' object has no attribute 'shape'

In [12]:
a = np.array(array)

In [13]:
a.shape

(2, 1, 756)

Restarted mne (Python 3.12.2)

In [1]:
import os

nthreads = "32" # 64 on synapse
os.environ["OMP_NUM_THREADS"] = nthreads
os.environ["OPENBLAS_NUM_THREADS"] = nthreads
os.environ["MKL_NUM_THREADS"] = nthreads
os.environ["VECLIB_MAXIMUM_THREADS"] = nthreads
os.environ["NUMEXPR_NUM_THREADS"] = nthreads
import matplotlib.pyplot as plt
import scipy.stats
import sklearn
from sklearn.ensemble import (HistGradientBoostingRegressor, 
                              RandomForestRegressor)
from sklearn.impute import SimpleImputer
import numpy as np
import sklearn.linear_model
from scipy.interpolate import CubicSpline
import sklearn.model_selection
from typing import List, Dict, Union, Optional
from typing import Any
import pickle
import seaborn as sns
import scipy
from numpy.lib.stride_tricks import sliding_window_view

In [2]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates']
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        masks.append(temp_mask.flatten() > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates']
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities)
    
    windowed_mask = sliding_window_view(joined_masks[:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if np.ndim(windowed_mask) < 3:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = 1)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [3]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

NameError: name 'big_d' is not defined

In [4]:
def parse_filename(filename: str | os.PathLike) -> dict[str,str]:
    """parse filename that are somewhat like BIDS but not rigoursly like it.

    Args:
        filename (str | os.PathLike): The filename to be parsed

    Returns:
        dict[str,str]: The filename parts
    """
    splitted_filename = filename.split('_')
    filename_parts = {}
    for part in splitted_filename:
        splitted_part = part.split('-')
        if splitted_part[0] in ['sub','ses','run','task']:
            label, value = splitted_part
            filename_parts[label] = value
        
    return filename_parts

def combine_data_from_filename(reading_dir: str | os.PathLike,
                               task:str = "checker",
                               run: str = "01"):
    """Combine the data from the files in the reading directory.

    Args:
        reading_dir (str | os.PathLike): The directory where the data is stored.
        task (str, optional): The task to concatenate. Defaults to "checker".
        run (str, optional): Either it's run-01 or run-01BlinksRemoved. 
                             Defaults to "01".

    Returns:
        _type_: _description_
    """
    big_data = dict()
    filename_list = os.listdir(reading_dir)
    for filename in filename_list:
        filename_parts = parse_filename(filename)
        subject = filename_parts["sub"]
        with open(os.path.join(reading_dir,filename), 'rb') as file: 
            data = pickle.load(file)
        if task in filename_parts['task'] and filename_parts['run'] == run:
            wrapped_data = {
                f'ses-{filename_parts["ses"]}':{
                    filename_parts["task"]:{
                        f'run-{filename_parts["run"]}': data
                    }
                }
            }
            if big_data.get(f'sub-{subject}'):
                big_data[f'sub-{subject}'].update(wrapped_data)
            else:
                big_data[f'sub-{subject}'] = wrapped_data


    return big_data

big_d = combine_data_from_filename('/data2/Projects/eeg_fmri_natview/derivatives/multimodal_prediction_models/data_prep/prediction_model_data_eeg_features_v2/group_data_Hz-3.8',
                                    task = 'checker',
                                    run = '01BlinksRemoved')

In [5]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

In [6]:
array.shape

(2, 1, 756)

In [7]:
masks = combine_masks(big_data = big_d,
                      key_list = keys_list,
                      modalities = ['brainstates','pupil'])

Overall mask shape: (1512,)


Restarted mne (Python 3.12.2)

In [1]:
import os

nthreads = "32" # 64 on synapse
os.environ["OMP_NUM_THREADS"] = nthreads
os.environ["OPENBLAS_NUM_THREADS"] = nthreads
os.environ["MKL_NUM_THREADS"] = nthreads
os.environ["VECLIB_MAXIMUM_THREADS"] = nthreads
os.environ["NUMEXPR_NUM_THREADS"] = nthreads
import matplotlib.pyplot as plt
import scipy.stats
import sklearn
from sklearn.ensemble import (HistGradientBoostingRegressor, 
                              RandomForestRegressor)
from sklearn.impute import SimpleImputer
import numpy as np
import sklearn.linear_model
from scipy.interpolate import CubicSpline
import sklearn.model_selection
from typing import List, Dict, Union, Optional
from typing import Any
import pickle
import seaborn as sns
import scipy
from numpy.lib.stride_tricks import sliding_window_view

In [2]:
def parse_filename(filename: str | os.PathLike) -> dict[str,str]:
    """parse filename that are somewhat like BIDS but not rigoursly like it.

    Args:
        filename (str | os.PathLike): The filename to be parsed

    Returns:
        dict[str,str]: The filename parts
    """
    splitted_filename = filename.split('_')
    filename_parts = {}
    for part in splitted_filename:
        splitted_part = part.split('-')
        if splitted_part[0] in ['sub','ses','run','task']:
            label, value = splitted_part
            filename_parts[label] = value
        
    return filename_parts

def combine_data_from_filename(reading_dir: str | os.PathLike,
                               task:str = "checker",
                               run: str = "01"):
    """Combine the data from the files in the reading directory.

    Args:
        reading_dir (str | os.PathLike): The directory where the data is stored.
        task (str, optional): The task to concatenate. Defaults to "checker".
        run (str, optional): Either it's run-01 or run-01BlinksRemoved. 
                             Defaults to "01".

    Returns:
        _type_: _description_
    """
    big_data = dict()
    filename_list = os.listdir(reading_dir)
    for filename in filename_list:
        filename_parts = parse_filename(filename)
        subject = filename_parts["sub"]
        with open(os.path.join(reading_dir,filename), 'rb') as file: 
            data = pickle.load(file)
        if task in filename_parts['task'] and filename_parts['run'] == run:
            wrapped_data = {
                f'ses-{filename_parts["ses"]}':{
                    filename_parts["task"]:{
                        f'run-{filename_parts["run"]}': data
                    }
                }
            }
            if big_data.get(f'sub-{subject}'):
                big_data[f'sub-{subject}'].update(wrapped_data)
            else:
                big_data[f'sub-{subject}'] = wrapped_data


    return big_data

big_d = combine_data_from_filename('/data2/Projects/eeg_fmri_natview/derivatives/multimodal_prediction_models/data_prep/prediction_model_data_eeg_features_v2/group_data_Hz-3.8',
                                    task = 'checker',
                                    run = '01BlinksRemoved')

In [3]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates']
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities)
    
    windowed_mask = sliding_window_view(joined_masks[:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if np.ndim(windowed_mask) < 3:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = 1)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [4]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

In [5]:
masks = combine_masks(big_data = big_d,
                      key_list = keys_list,
                      modalities = ['brainstates','pupil'])

Overall mask shape: (2, 1, 756)


Connected to mne (Python 3.12.2)

In [1]:
masks

NameError: name 'masks' is not defined

In [2]:
import os

nthreads = "32" # 64 on synapse
os.environ["OMP_NUM_THREADS"] = nthreads
os.environ["OPENBLAS_NUM_THREADS"] = nthreads
os.environ["MKL_NUM_THREADS"] = nthreads
os.environ["VECLIB_MAXIMUM_THREADS"] = nthreads
os.environ["NUMEXPR_NUM_THREADS"] = nthreads
import matplotlib.pyplot as plt
import scipy.stats
import sklearn
from sklearn.ensemble import (HistGradientBoostingRegressor, 
                              RandomForestRegressor)
from sklearn.impute import SimpleImputer
import numpy as np
import sklearn.linear_model
from scipy.interpolate import CubicSpline
import sklearn.model_selection
from typing import List, Dict, Union, Optional
from typing import Any
import pickle
import seaborn as sns
import scipy
from numpy.lib.stride_tricks import sliding_window_view

In [3]:
def parse_filename(filename: str | os.PathLike) -> dict[str,str]:
    """parse filename that are somewhat like BIDS but not rigoursly like it.

    Args:
        filename (str | os.PathLike): The filename to be parsed

    Returns:
        dict[str,str]: The filename parts
    """
    splitted_filename = filename.split('_')
    filename_parts = {}
    for part in splitted_filename:
        splitted_part = part.split('-')
        if splitted_part[0] in ['sub','ses','run','task']:
            label, value = splitted_part
            filename_parts[label] = value
        
    return filename_parts

def combine_data_from_filename(reading_dir: str | os.PathLike,
                               task:str = "checker",
                               run: str = "01"):
    """Combine the data from the files in the reading directory.

    Args:
        reading_dir (str | os.PathLike): The directory where the data is stored.
        task (str, optional): The task to concatenate. Defaults to "checker".
        run (str, optional): Either it's run-01 or run-01BlinksRemoved. 
                             Defaults to "01".

    Returns:
        _type_: _description_
    """
    big_data = dict()
    filename_list = os.listdir(reading_dir)
    for filename in filename_list:
        filename_parts = parse_filename(filename)
        subject = filename_parts["sub"]
        with open(os.path.join(reading_dir,filename), 'rb') as file: 
            data = pickle.load(file)
        if task in filename_parts['task'] and filename_parts['run'] == run:
            wrapped_data = {
                f'ses-{filename_parts["ses"]}':{
                    filename_parts["task"]:{
                        f'run-{filename_parts["run"]}': data
                    }
                }
            }
            if big_data.get(f'sub-{subject}'):
                big_data[f'sub-{subject}'].update(wrapped_data)
            else:
                big_data[f'sub-{subject}'] = wrapped_data


    return big_data

big_d = combine_data_from_filename('/data2/Projects/eeg_fmri_natview/derivatives/multimodal_prediction_models/data_prep/prediction_model_data_eeg_features_v2/group_data_Hz-3.8',
                                    task = 'checker',
                                    run = '01BlinksRemoved')

In [4]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates'],
                        subject_agnostic: bool = False
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities,
                                 subject_agnostic = subject_agnostic)
    
    windowed_mask = sliding_window_view(joined_masks[:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if np.ndim(windowed_mask) < 3:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = 1)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [5]:
keys_list = generate_key_list(subjects = ['01','02'],
                              sessions = ['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
array = create_big_feature_array(big_data = big_d,
                                 modality = 'pupil',
                                 array_name = 'feature',
                                 index_to_get = 1,
                                 axis_to_get = 0,
                                 keys_list = keys_list)

In [6]:
masks = combine_masks(big_data = big_d,
                      key_list = keys_list,
                      modalities = ['brainstates','pupil'])

Overall mask shape: (2, 1, 756)


In [7]:
masks

array([[[False, False, False, ...,  True,  True,  True]],

       [[False, False, False, ...,  True,  True,  True]]])

In [8]:
windowed_mask = build_windowed_mask(big_data=big_d,
                                    key_list=keys_list,
                                    modalities=['pupil', 'brainstates'])

Overall mask shape: (2, 1, 756)


ValueError: window shape cannot be larger than input array shape

In [9]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates'],
                        subject_agnostic: bool = False
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities,
                                 subject_agnostic = subject_agnostic)
    
    windowed_mask = sliding_window_view(joined_masks[:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if subject_agnostic:
        max_dim = 4
        axis = 2
    else:
        max_dim = 3
        axis = 1
        
    if np.ndim(windowed_mask) < max_dim:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = axis)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [10]:
windowed_mask = build_windowed_mask(big_data=big_d,
                                    key_list=keys_list,
                                    modalities=['pupil', 'brainstates'])

Overall mask shape: (2, 1, 756)


ValueError: window shape cannot be larger than input array shape

In [11]:
masks

array([[[False, False, False, ...,  True,  True,  True]],

       [[False, False, False, ...,  True,  True,  True]]])

In [12]:
masks.shape

(2, 1, 756)

In [13]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates'],
                        subject_agnostic: bool = False
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities,
                                 subject_agnostic = subject_agnostic)
    
    windowed_mask = sliding_window_view(joined_masks[:,:,:-1], 
                                        window_shape=window_length,
                                        axis = 0)
    print('Windowed mask shape:',windowed_mask.shape)
    if subject_agnostic:
        max_dim = 4
        axis = 2
    else:
        max_dim = 3
        axis = 1
        
    if np.ndim(windowed_mask) < max_dim:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = axis)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [14]:
windowed_mask = build_windowed_mask(big_data=big_d,
                                    key_list=keys_list,
                                    modalities=['pupil', 'brainstates'])

Overall mask shape: (2, 1, 756)


ValueError: window shape cannot be larger than input array shape

In [15]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates'],
                        subject_agnostic: bool = False
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities,
                                 subject_agnostic = subject_agnostic)
    
    windowed_mask = sliding_window_view(joined_masks[:,:,:-1], 
                                        window_shape=window_length,
                                        axis = 2)
    print('Windowed mask shape:',windowed_mask.shape)
    if subject_agnostic:
        max_dim = 4
        axis = 2
    else:
        max_dim = 3
        axis = 1
        
    if np.ndim(windowed_mask) < max_dim:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = axis)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten") -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:-1,...], 
            window_shape=window_length, 
            axis=1
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [16]:
windowed_mask = build_windowed_mask(big_data=big_d,
                                    key_list=keys_list,
                                    modalities=['pupil', 'brainstates'])

Overall mask shape: (2, 1, 756)
Windowed mask shape: (2, 1, 711, 45)


In [17]:
data = create_big_feature_array(big_data = big_d,
                                modality = 'pupil',
                                array_name = 'feature',
                                index_to_get = 1,
                                axis_to_get = 0,
                                keys_list = keys_list)

In [18]:
data.shape


(2, 1, 756)

In [19]:
data = create_big_feature_array(big_data = big_d,
                                modality = 'brainstate',
                                array_name = 'feature',
                                index_to_get = 1,
                                axis_to_get = 0,
                                keys_list = keys_list)

KeyError: 'brainstate'

In [20]:
data = create_big_feature_array(big_data = big_d,
                                modality = 'brainstates',
                                array_name = 'feature',
                                index_to_get = 1,
                                axis_to_get = 0,
                                keys_list = keys_list)

In [21]:
data.shape

(2, 1, 756)

In [22]:
build_windowed_data(array=data)

ValueError: window shape cannot be larger than input array shape

In [23]:
def filter_data(data: np.ndarray, 
                low_freq_cutoff: float | None = None,
                high_freq_cutoff: float | None = None,
                ):
    """Filter the data using a bandpass filter.

    Args:
        data (np.ndarray): The data to filter
        low_freq_cutoff (float | None): The lower bound to filter. 
                                        Defaults to None.
        high_freq_cutoff (float, optional): The higher bound to filter. 
                                            Defaults to 0.1.

    Returns:
        data: _description_
    """
    if high_freq_cutoff and not low_freq_cutoff:
        filter_type = 'low'
        freq = high_freq_cutoff
    elif low_freq_cutoff and not high_freq_cutoff:
        filter_type = 'high'
        freq = low_freq_cutoff
    elif high_freq_cutoff and low_freq_cutoff:
        filter_type = 'band'
        freq = [low_freq_cutoff, high_freq_cutoff]
    
    filtered_data = scipy.signal.butter(
        4, 
        freq, 
        btype=filter_type,
        output='sos'
        )
    filtered_data = scipy.signal.sosfilt(filtered_data, data, axis=1)
    return filtered_data

def generate_key_list(subjects: list[str] | str,
                      sessions: list[str] | str,
                      task: str,
                      runs: list[str] | str,
                      big_data: dict | None,
                      return_dict = False) -> list[tuple[str, str, str, str]]:
    """Generate a list of keys to access the data in the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        subjects (list[str] | str): The list of subjects to consider
        sessions (list[str] | str): The list of sessions to consider
        tasks (str): The task to consider
        runs (list[str] | str): The list of runs to consider
    
    Returns:
        list[tuple[str]]: The list of keys to access the data
    """
    key_list = list()
    for subject in subjects:
        for session in sessions:
            for run in runs:
                try:
                    big_data[f'sub-{subject}'][f'ses-{session}'][task][f'run-{run}']
                    key_list.append((
                        f'sub-{subject}',
                        f'ses-{session}',
                        task,
                        f'run-{run}'
                        ))
                    
                except:
                    continue
    
    return key_list

def extract_cap_name_list(big_data: dict,
                       keys_list: list[tuple[str, ...]]) -> list[str]:
    """Extract the list of CAP names from the big dictionary.
    
    Args:
        big_data (dict): The big dictionary containing all the data
        keys_list (list): The list of keys to access the data in the dictionary.
    
    Returns:
        list: The list of CAP names
    """
    subject, session, task, run = keys_list[0]
    return big_data[subject][session][task][run]['brainstates']['labels']

def get_real_cap_name(cap_names: str | list[str],
                      cap_list: list[str]) -> list:
    """Get the real CAP name based on a substring from the list of CAP names.
    
    Args:
        cap_name (str): The substring to look for in the list of CAP names
        cap_list (list): The list of CAP names
    
    Returns:
        str: The real CAP name
    """
    real_cap_names = list()
    if isinstance(cap_names,str):
        cap_names = [cap_names]
    for cap_name in cap_names:
        real_cap_names.extend([cap for cap in cap_list if cap_name in cap])
    
    return real_cap_names

def create_big_feature_array(big_data: dict,
                             modality: str,
                             array_name: str,
                             index_to_get: int | None,
                             axis_to_get: int | None,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             ) -> np.ndarray:
    """This is to make a big numpy array across subject.

    It concatenates the features of interest along the time axis (2nd dim).

    Args:
        big_data (dict): The dictionary containing all the data
        to_concat (str): The name of the feature to concatenate (EEGBandEnvelope
                            for example). This choose the feature to consider 
                            later as X.
        index_to_get (int): The index (on the third dimension) of the frequency
                                of interest (or the frequency band). 
                                If None, the entire array is considered.
        axis_to_get (int): The axis to get the data from. If None, the entire
                                array is considered.
        keys_list (list): The list of keys to access the data in the dictionary.
    """
    
    concatenation_list = list()
    for keys in keys_list:
        subject, session, task, run = keys
        if isinstance(index_to_get, int) and isinstance(axis_to_get, int):
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality
                              ][array_name].take(
                                index_to_get,
                                axis = axis_to_get
                            )
        else:
            extracted_array = big_data[subject
                ][session
                    ][task
                        ][run
                            ][modality][array_name]
        
        if extracted_array.ndim < 2:
            extracted_array = np.reshape(extracted_array,(1,extracted_array.shape[0]))
        
        #extracted_array = filter_data(extracted_array,high_freq_cutoff=0.1)
        concatenation_list.append(extracted_array)
    
    if subject_agnostic:
        return np.concatenate(concatenation_list,axis = axis_to_concatenate)
    else:
        return np.array(concatenation_list)

def _find_item(desired_key: str, obj: Dict[str, Any]) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for value in obj.values():
        if isinstance(value, dict):
            item = _find_item(desired_key, value)
            if item:
                return item

def get_specific_location(big_data: Dict, 
                          channel_names: Optional[List[str]] = None, 
                          anatomical_location: Optional[List[str]] = None, 
                          laterality: Optional[List[str]] = None) -> Union[np.ndarray, None]:
    """
    Filters the channels based on anatomical location, laterality, and channel names.

    Parameters:
    - big_data (dict): The dictionary containing channel information.
    - channel_names (list[str], optional): List of channel names to filter.
    - anatomical_location (list[str], optional): List of anatomical locations to filter.
    - laterality (list[str], optional): List of lateralities to filter.

    Returns:
    - np.ndarray | None: A boolean array indicating the filtered channels or None if no channel info is found.
    """
    channel_info = _find_item("channels_info", big_data)
    if not channel_info:
        return None

    mask = np.zeros(len(channel_info['channel_name']), dtype=bool)

    if anatomical_location:
        
        anatomy_mask = np.isin(
            channel_info.get('anatomy', []), 
            anatomical_location
            )
        
        mask = np.logical_or(mask, anatomy_mask)

    if laterality:
        
        laterality_mask = np.isin(
            channel_info.get('laterality', []), 
            laterality
            )

        if anatomical_location:
            comparison = getattr(np, 'logical_and')
        else:
            comparison = getattr(np, 'logical_or')
        
        mask = comparison(mask, laterality_mask)

    if channel_names:
        
        channel_mask = np.isin(
            channel_info.get('channel_name', []), 
            channel_names
            )
        
        mask = np.logical_or(mask, channel_mask)

    return mask if mask.any() else None

def combine_masks(big_data:dict,
                  key_list: list,
                  modalities: list = ['EEGbandsEnvelopes','brainstates'],
                  subject_agnostic: bool = False
                  ) -> np.ndarray[bool]:
    masks = []
    for modality in modalities:
        if "envelopes" in modality.lower() or "tfr" in modality.lower():
            array_name = 'artifact_mask'
            index_to_get = None
            axis_to_get = None

        else:
            array_name = 'feature'
            index_to_get = -1
            axis_to_get = 0
        
        temp_mask = create_big_feature_array(
            big_data = big_data,
            modality = modality,
            array_name = array_name,
            index_to_get = index_to_get,
            axis_to_get=axis_to_get,
            keys_list=key_list,
            axis_to_concatenate=0
        )
        if subject_agnostic:
            temp_mask = temp_mask.flatten()
            
        masks.append(temp_mask > 0.5)
    
    masks = np.array(masks)
    overall_mask = np.all(masks, axis = 0)
    print('Overall mask shape:',overall_mask.shape)
    return np.all(masks, axis = 0)
        
def build_windowed_mask(big_data: dict,
                        key_list:list,
                        window_length: int = 45,
                        modalities = ['brainstates'],
                        subject_agnostic: bool = False
                        ) -> np.ndarray:
    """Build the mask based on the brainstate and EEG ones.
    
    The mask will be windowed to match the windowed data.

    Args:
        big_data (dict): The dictionary containing all the data
        window_length (int, optional): The length of the sliding window in
                                       samples. Defaults to 45.
        steps (int, optional): The sliding steps in samples. Defaults to 1.

    Returns:
        np.ndarray: The windowed mask
    """

    joined_masks = combine_masks(big_data,
                                 key_list,
                                 modalities = modalities,
                                 subject_agnostic = subject_agnostic)
    
    windowed_mask = sliding_window_view(joined_masks[:,:,:-1], 
                                        window_shape=window_length,
                                        axis = 2)
    print('Windowed mask shape:',windowed_mask.shape)
    if subject_agnostic:
        max_dim = 4
        axis = 2
    else:
        max_dim = 3
        axis = 1
        
    if np.ndim(windowed_mask) < max_dim:
        return windowed_mask
    else:
        # Take the case of EEG channels. If there is one channel not good, reject the entire window.
        return np.all(windowed_mask, axis = axis)

def build_windowed_data(array: np.ndarray,
                        window_length: int = 45,
                        reduction: str = "flatten",
                        subject_agnostic = False) -> np.ndarray:
    if array.ndim == 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:-1], 
            window_shape=window_length, 
        )

    elif array.ndim > 1:
        windowed_data = np.lib.stride_tricks.sliding_window_view(
            array[:,:,:-1,...], 
            window_shape=window_length, 
            axis=2
        )
        new_shape_X = (windowed_data.shape[1], -1) + windowed_data.shape[3:]
        windowed_data = windowed_data.transpose(1, 0, 2, *range(3, array.ndim + 1))
        windowed_data = windowed_data.reshape(new_shape_X)
        
        if array.ndim > 2:
            if reduction == 'flatten':
                windowed_data = windowed_data.reshape(windowed_data.shape[0], -1)
            
            elif reduction == 'gfp':
                windowed_data = np.squeeze(np.var(windowed_data, axis=0))
            
    return windowed_data
    
def create_X_and_Y(big_data: dict,
                   keys_list: list[tuple[str, ...]],
                   X_name: str,
                   cap_name: str,
                   bands_names: str | list | None =  None,
                   chan_select_args: Dict[str,str] | None = None,
                   normalization: str | None = 'zscore',
                   reduction_method: str = 'flatten',
                   window_length: int = 45,
                   integrate_pupil: bool = True,
                  ) -> tuple[Any,Any]:
    
    bands_list = ['delta','theta','alpha','beta','gamma']

    if isinstance(bands_names,list):
        index_band = [bands_list.index(band) for band in bands_names]

    elif isinstance(bands_names, str):
        index_band = bands_list.index(bands_names)
    elif not bands_names:
        pass
    
    if "pupil" in X_name:
        index_to_get = 1 
        axis_to_get = 0
        integrate_pupil = False
    
    elif "envelopes" in X_name.lower() or "tfr" in X_name.lower():
        index_to_get = -1
        axis_to_get = index_band

    big_X_array = create_big_feature_array(
        big_data            = big_data,
        modality            = X_name,
        array_name          = 'feature',
        index_to_get        = index_to_get, #To modify for EEG band. It is now for pupil
        axis_to_get         = axis_to_get,
        keys_list           = keys_list
        )

    if "pupil" in X_name:
        first_derivative = np.diff(big_X_array[0,:], append = 0)
        second_derivative = np.diff(first_derivative, append = 0)
        print('Pupil shape:',big_X_array.shape)
        print('Pupil first derivative shape:',first_derivative.shape)
        print('Pupil second derivative shape:',second_derivative.shape)
        big_X_array = np.stack(
            (big_X_array[0,:],first_derivative,second_derivative),
            axis=0
        )
        
    if chan_select_args:
        channel_mask = get_specific_location(big_data, **chan_select_args)
        big_X_array = big_X_array[channel_mask,...]
    
    cap_names_list = extract_cap_name_list(big_data,keys_list)
    real_cap_name = get_real_cap_name(cap_name,cap_names_list)
    cap_index = [cap_names_list.index(cap) for cap in real_cap_name][0]
    
    if normalization == 'zscore':
        big_X_array = scipy.stats.zscore(big_X_array,axis=1)
    
    windowed_X = build_windowed_data(big_X_array,
                                     window_length,
                                     reduction_method)

    if integrate_pupil:
        pupil_array = create_big_feature_array(
            big_data            = big_data,
            modality            = 'pupil',
            array_name          = 'feature',
            index_to_get        = 1,
            axis_to_get         = 0,
            keys_list           = keys_list
            ) 

        if normalization == 'zscore':
            pupil_array = scipy.stats.zscore(pupil_array,axis=1)
        
        windowed_pupil = build_windowed_data(pupil_array,
                                             window_length,
                                             reduction_method)
    
    
    if integrate_pupil:
        windowed_X = np.concatenate(
            (windowed_X, windowed_pupil),
            axis=1
            )
    
    big_Y_array = create_big_feature_array(
        big_data            = big_data,
        modality            = 'brainstates', 
        array_name          = 'feature',
        index_to_get        = cap_index,
        axis_to_get         = 0,
        keys_list           = keys_list
        )

    if normalization == 'zscore':
        big_Y_array = scipy.stats.zscore(big_Y_array,axis=1)
            
    windowed_Y = np.squeeze(big_Y_array[:,window_length:])

    return windowed_X, windowed_Y

def thresholding_data_rejection(mask: np.ndarray,
                                threshold: int = 20,
                                ) -> np.ndarray[bool]:
    """By studying the mask, reject the window that have too much False.
    
    Based on the windowed mask, it evaluate the amount of data rejected and then
    generate a 1 dimensional boolean mask that will be applied to the
    X and Y data correct the data.

    Args:
        mask (np.ndarray): 2D array of boolean values
        threshold (int, optional): Percentage of data rejected to reject 
        the entire window. Defaults to 20.

    Returns:
        np.ndarray[bool]: A 1 dimensional boolean mask to apply to the data.
    """
    
    valid_data = np.sum(mask, axis = 1)
    percentage = valid_data * 100 / mask.shape[1]

    return np.squeeze(percentage > (100 - threshold))

def interpolate_nan(arr, strategy='imputer'):
    if strategy == 'imputer':
        arr = SimpleImputer(missing_values=np.nan, 
                                     strategy='median', 
                                     copy=False).fit_transform(arr)
    else:
        for i in range(arr.shape[0]):  # Loop through rows
            # Get indices of non-NaN values
            valid_idx = np.nonzero(~np.isnan(arr[i]))[0]
            invalid_idx = np.nonzero(np.isnan(arr[i]))[0]
            
            # If there are enough valid points for cubic spline interpolation
            if len(valid_idx) > 1:
                # Perform cubic spline interpolation
                cs = CubicSpline(valid_idx, arr[i, valid_idx])
                # Replace NaN values with interpolated values
                arr[i, invalid_idx] = cs(invalid_idx)
    
    return arr

def create_train_test_data(big_data: dict,
                           train_sessions: list[str],
                           test_subject: str,
                           test_sessions: str | list[str],
                           task: str,
                           runs: list[str],
                           cap_name: str,
                           X_name: str,
                           band_name: str,
                           window_length: int = 45,
                           chan_select_args = None,
                           masking = False,
                           ) -> tuple[Any,Any,Any,Any]:
    """Create the train and test data using leave one out method.

    Args:
        big_data (dict): The dictionary containing all the data
        test_subject (str): The subject to leave out for testing

    Returns:
        tuple[np.ndarray]: the train and test data
    """
    subjects = [sub.split('-')[1] for sub in big_data.keys()]
    train_subjects = [subject for subject in subjects if subject != test_subject]
    
    print(f'Train subjects: {train_subjects}')
    print(f'Test subject: {test_subject}')
    
    print(f'Train sessions: {train_sessions}')
    print(f'Test session: {test_sessions}')
    
    train_keys = generate_key_list(
        big_data = big_data,
        subjects = train_subjects,
        sessions = train_sessions,
        task     = task,
        runs     = runs
        )
    
    test_keys = generate_key_list(
        big_data = big_data,
        subjects = [test_subject],
        sessions = test_sessions,
        task     = task,
        runs     = runs
        )
    
    if test_keys == []:
        raise ValueError(f'No data for:sub-{test_subject}_ses-{test_sessions}')
    
    X_train, Y_train = create_X_and_Y(
        big_data         = big_data,
        keys_list        = train_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_train shape: {X_train.shape}')
    print(f'Y_train shape: {Y_train.shape}')

    
    X_test, Y_test = create_X_and_Y(
        big_data         = big_data,
        keys_list        = test_keys,
        X_name           = X_name,
        bands_names      = band_name,
        cap_name         = cap_name,
        normalization    = 'zscore',
        chan_select_args = chan_select_args,
        window_length    = window_length,
        reduction_method = 'flatten',
        )

    
    print(f'X_test shape: {X_test.shape}')
    print(f'Y_test shape: {Y_test.shape}')
    
    if masking:
        train_mask = build_windowed_mask(big_data,
                                        key_list = train_keys,
                                        window_length=window_length,
                                        modalities = ['brainstates','pupil']) # !!! TEMP FIX
        
        test_mask = build_windowed_mask(big_data,
                                        key_list=test_keys, 
                                        window_length=window_length,
                                        modalities=['brainstates','pupil']) # !!! TEMP FIX
        
        if "pupil" in X_name:
            train_mask = np.repeat(train_mask,3, axis=1)
            test_mask = np.repeat(test_mask,3, axis=1)
        
        
        window_rejection_mask_train = thresholding_data_rejection(train_mask)
        X_train_mask = train_mask[window_rejection_mask_train,:]
        X_train = X_train[window_rejection_mask_train,:]
        X_train[~X_train_mask] = np.nan
        X_train = interpolate_nan(X_train, strategy='imputer')
        Y_train = Y_train[window_rejection_mask_train]

        print(f'Train mask shape: {train_mask.shape}')
        print(f'Test mask shape: {test_mask.shape}')
        print(f'X_train shape after masking: {X_train.shape}')
        print(f'Y_train shape after masking: {Y_train.shape}')
        
        window_rejection_mask_test = thresholding_data_rejection(test_mask)
        print(f'X_test shape after masking = {X_test.shape}')
        print(f'Y_test shape after masking = {Y_test.shape}')
        X_test_mask = test_mask[window_rejection_mask_test,:]
        X_test = X_test[window_rejection_mask_test,:]
        X_test[~X_test_mask] = np.nan
        X_test = interpolate_nan(X_test, strategy='imputer')
        Y_test = Y_test[window_rejection_mask_test]
    
    return (X_train, 
            Y_train, 
            X_test, 
            Y_test)

def train_model(big_data,
                test_subject,
                train_sessions,
                test_sessions,
                task,
                runs,
                cap_name,
                X_name,
                band_name,
                window_length,
                chan_select_args = None,
                masking = False,
                model_name = 'ridge',
                viz_path = False):
    
    try:
        X_train, Y_train, X_test, Y_test = create_train_test_data(
        big_data         = big_data,
        train_sessions   = train_sessions,
        test_subject     = test_subject,
        test_sessions    = test_sessions,
        task             = task,
        runs             = runs,
        cap_name         = cap_name,
        X_name           = X_name,
        band_name        = band_name,
        window_length    = window_length,
        chan_select_args = chan_select_args,
        masking          = masking
            )
    except Exception as e:
        print(e)
        #raise e
        return None, None, None, None, None

    if 'ridge' in model_name.lower():
        model = sklearn.linear_model.RidgeCV(cv = 5)

    elif model_name.lower() == 'lasso': 

        alphas = np.linspace(1e-7,1e-4,1000)
        model = sklearn.linear_model.LassoCV(max_iter=10000,
                                             alphas = alphas
                                             )
    
    elif model_name.lower() == 'lassolars':
        model = sklearn.linear_model.LassoLarsCV(max_iter=10000,
                                                 max_n_alphas=1000)

    elif 'hist' in model_name.lower():
        model = HistGradientBoostingRegressor(max_iter=1000)

    elif 'forest' in model_name.lower():
        model = RandomForestRegressor(criterion = 'absolute_error', 
                                    max_features = 'log2', 
                                    n_estimators = 800)
        
    elif 'elastic' in model_name.lower():
        model = sklearn.linear_model.ElasticNetCV(max_iter=10000)
        
    model.fit(X_train,Y_train)
    if viz_path:
        plot_path(X_train,Y_train)

    return model, X_train, Y_train, X_test, Y_test

def plot_path(X_train, Y_train):
    alphas = np.linspace(1e-6,1e-3,1000)
    alphas_lasso, coef_lasso, _ = sklearn.linear_model.lasso_path(
        X_train, 
        Y_train, 
        alphas = alphas,
        max_iter = 10000
        )
    
    plt.figure(figsize=(12, 6))
    plt.plot(alphas_lasso,coef_lasso.T)
    plt.xlabel('alpha')
    plt.ylabel('coefficients')
    plt.title('Coefficient path')
    plt.legend(['delta','theta','alpha','beta','gamma'])
    plt.show()

In [24]:
build_windowed_data(array=data)

array([[ 0.08573921,  0.08846208,  0.09048314, ..., -0.17956672,
        -0.15462392, -0.12533475]])

In [25]:
a = build_windowed_data(data)

In [26]:
a.shape

(1, 63990)

In [27]:
data

array([[[ 0.08573921,  0.08846208,  0.09048314, ..., -0.19578882,
         -0.22103515, -0.24334531]],

       [[-0.07714075, -0.03197379,  0.00772897, ..., -0.15462392,
         -0.12533475, -0.09301757]]])

In [28]:
data.shape

(2, 1, 756)

In [29]:
windowed_data = np.lib.stride_tricks.sliding_window_view(
            data[:,:,:-1,...], 
            window_shape=45, 
            axis=2
        )

In [30]:
windowed_data

array([[[[ 0.08573921,  0.08846208,  0.09048314, ..., -0.29755922,
          -0.28849872, -0.27688954],
         [ 0.08846208,  0.09048314,  0.09178194, ..., -0.28849872,
          -0.27688954, -0.26281565],
         [ 0.09048314,  0.09178194,  0.09233806, ..., -0.27688954,
          -0.26281565, -0.24636105],
         ...,
         [-0.08941769, -0.05839202, -0.02681673, ..., -0.10314436,
          -0.13657576, -0.16760603],
         [-0.05839202, -0.02681673,  0.00485622, ..., -0.13657576,
          -0.16760603, -0.19578882],
         [-0.02681673,  0.00485622,  0.03617487, ..., -0.16760603,
          -0.19578882, -0.22103515]]],


       [[[-0.07714075, -0.03197379,  0.00772897, ..., -0.04586015,
          -0.04909177, -0.05167804],
         [-0.03197379,  0.00772897,  0.04230715, ..., -0.04909177,
          -0.05167804, -0.05366401],
         [ 0.00772897,  0.04230715,  0.07210036, ..., -0.05167804,
          -0.05366401, -0.05509477],
         ...,
         [ 0.22419547,  0.235200

In [31]:
windowed_data.shape

(2, 1, 711, 45)

In [32]:
masks = combine_masks(big_data = big_d,
                      key_list = keys_list,
                      modalities = ['brainstates','pupil'])

Overall mask shape: (2, 1, 756)


In [33]:
windowed_mask = build_windowed_mask(big_data=big_d,
                                    key_list=keys_list,
                                    modalities=['pupil', 'brainstates'])

Overall mask shape: (2, 1, 756)
Windowed mask shape: (2, 1, 711, 45)


In [34]:
np.diff(data[:,0,:], axis = 2, append = 0)

AxisError: axis 2 is out of bounds for array of dimension 2

In [35]:
data

array([[[ 0.08573921,  0.08846208,  0.09048314, ..., -0.19578882,
         -0.22103515, -0.24334531]],

       [[-0.07714075, -0.03197379,  0.00772897, ..., -0.15462392,
         -0.12533475, -0.09301757]]])

In [36]:
data.shape

(2, 1, 756)

In [37]:
np.diff(data, axis = 2, append = 0)

array([[[ 0.00272287,  0.00202105,  0.0012988 , ..., -0.02524632,
         -0.02231017,  0.24334531]],

       [[ 0.04516696,  0.03970276,  0.03457818, ...,  0.02928917,
          0.03231717,  0.09301757]]])

In [38]:
a = np.diff(data, axis = 2, append = 0)

In [39]:
a.shape

(2, 1, 756)

In [40]:
 first_derivative = np.diff(data, axis = 2, append = 0)
second_derivative = np.diff(first_derivative, axis = 2, append = 0)

In [41]:
big_X_array = np.stack(
            (data,first_derivative,second_derivative),
            axis=1
        )

In [42]:
big_X_array

array([[[[ 0.08573921,  0.08846208,  0.09048314, ..., -0.19578882,
          -0.22103515, -0.24334531]],

        [[ 0.00272287,  0.00202105,  0.0012988 , ..., -0.02524632,
          -0.02231017,  0.24334531]],

        [[-0.00070182, -0.00072225, -0.00074268, ...,  0.00293615,
           0.26565548, -0.24334531]]],


       [[[-0.07714075, -0.03197379,  0.00772897, ..., -0.15462392,
          -0.12533475, -0.09301757]],

        [[ 0.04516696,  0.03970276,  0.03457818, ...,  0.02928917,
           0.03231717,  0.09301757]],

        [[-0.0054642 , -0.00512458, -0.00478496, ...,  0.003028  ,
           0.0607004 , -0.09301757]]]])

In [43]:
big_X_array.shape

(2, 3, 1, 756)

In [44]:
big_X_array = np.concatenate(
            (data,first_derivative,second_derivative),
            axis=1
        )

In [45]:
big_X_array.shape

(2, 3, 756)

In [46]:
big_d

{'sub-01': {'ses-01': {'checker': {'run-01BlinksRemoved': {'brainstates': {'time': array([  1.05  ,   1.3125,   1.575 ,   1.8375,   2.1   ,   2.3625,
               2.625 ,   2.8875,   3.15  ,   3.4125,   3.675 ,   3.9375,
               4.2   ,   4.4625,   4.725 ,   4.9875,   5.25  ,   5.5125,
               5.775 ,   6.0375,   6.3   ,   6.5625,   6.825 ,   7.0875,
               7.35  ,   7.6125,   7.875 ,   8.1375,   8.4   ,   8.6625,
               8.925 ,   9.1875,   9.45  ,   9.7125,   9.975 ,  10.2375,
              10.5   ,  10.7625,  11.025 ,  11.2875,  11.55  ,  11.8125,
              12.075 ,  12.3375,  12.6   ,  12.8625,  13.125 ,  13.3875,
              13.65  ,  13.9125,  14.175 ,  14.4375,  14.7   ,  14.9625,
              15.225 ,  15.4875,  15.75  ,  16.0125,  16.275 ,  16.5375,
              16.8   ,  17.0625,  17.325 ,  17.5875,  17.85  ,  18.1125,
              18.375 ,  18.6375,  18.9   ,  19.1625,  19.425 ,  19.6875,
              19.95  ,  20.2125,  20.475 ,  20.

In [47]:
big_d.keys()

dict_keys(['sub-01', 'sub-02', 'sub-03', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-13', 'sub-14', 'sub-15', 'sub-16', 'sub-18', 'sub-20', 'sub-21', 'sub-22'])

In [48]:
big_d['sub-01'].keys()

dict_keys(['ses-01'])

In [49]:
big_d['sub-01']['ses-01'].keys()

dict_keys(['checker'])

In [50]:
big_d['sub-01']['ses-01']['checker'].keys()

dict_keys(['run-01BlinksRemoved'])

In [51]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved'].keys()

dict_keys(['brainstates', 'pupil', 'respiration', 'EEGbandsEnvelopes', 'CustomEnvelopes', 'MorletTFR'])

In [52]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil'].keys()

dict_keys(['time', 'labels', 'feature', 'feature_info'])

In [53]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil']['labels'].keys()

AttributeError: 'list' object has no attribute 'keys'

In [54]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil']['labels']

['pupil_size', 'X_position', 'Y_position', 'tmask']

In [55]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil']['feature']

array([[2.20605000e+03, 2.31969392e+03, 2.38213323e+03, ...,
        8.42903625e+02, 8.47308300e+02, 7.89433825e+02],
       [7.32190000e+02, 7.21630756e+02, 7.31011132e+02, ...,
        5.59520226e+02, 5.58074861e+02, 5.43317066e+02],
       [3.56600000e+02, 1.88609946e+02, 1.25748331e+02, ...,
        3.78516542e+02, 3.82081468e+02, 3.95230659e+02],
       [1.00000000e+00, 1.00000000e+00, 1.00000000e+00, ...,
        1.00000056e+00, 1.00000090e+00, 1.00000079e+00]])

In [56]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil']['feature'].shape

(4, 756)

In [57]:
big_d['sub-01']['ses-01']['checker']['run-01BlinksRemoved']['pupil']['feature'][0,:]

array([2206.05      , 2319.69392148, 2382.13323169, 2405.58092605,
       2402.25      , 2384.35344895, 2364.10426831, 2353.71545352,
       2365.4       , 2406.32197024, 2463.44969506, 2518.70257235,
       2554.        , 2556.68015447, 2535.75632645, 2505.66058521,
       2480.825     , 2472.33241188, 2477.86874912, 2491.77071181,
       2508.375     , 2523.10347926, 2535.71867705, 2547.06828632,
       2558.        , 2568.69742108, 2576.68779268, 2578.83426793,
       2572.        , 2554.45488329, 2530.09577724, 2504.22628258,
       2482.15      , 2468.28031139, 2463.46972335, 2467.68052363,
       2480.875     , 2502.05473051, 2526.37845437, 2548.04420103,
       2561.25      , 2562.11678217, 2554.45708419, 2544.00634411,
       2536.5       , 2536.35314079, 2542.69945887, 2553.35229752,
       2566.125     , 2578.68862343, 2588.14508032, 2591.45399705,
       2585.575     , 2569.28611551, 2548.63896984, 2531.50358926,
       2525.75      , 2536.90324267, 2561.1084153 , 2592.16563

In [58]:
keys_list = generate_key_list(subjects=['01','02'],)
create_big_feature_array(big_data: dict = big_d,
                             modality: str = 'pupil',
                             array_name: str = 'feature',
                             index_to_get: int | None = 1,
                             axis_to_get: int | None = 0,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             )

SyntaxError: invalid syntax (<ipython-input-58-c4a2c54248d0>, line 2)

In [59]:
keys_list = generate_key_list(subjects=['01','02'],
                              sessions=['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
create_big_feature_array(big_data: dict = big_d,
                             modality: str = 'pupil',
                             array_name: str = 'feature',
                             index_to_get: int | None = 1,
                             axis_to_get: int | None = 0,
                             keys_list: list[tuple[str, ...]],
                             axis_to_concatenate: int = 1,
                             subject_agnostic: bool = False
                             )

SyntaxError: invalid syntax (<ipython-input-59-c14f35c0c252>, line 6)

In [60]:
keys_list = generate_key_list(subjects=['01','02'],
                              sessions=['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
create_big_feature_array(big_data= big_d,
                             modality = 'pupil',
                             array_name = 'feature',
                             index_to_get = 1,
                             axis_to_get = 0,
                             keys_list = keys_list,
                             )

array([[[732.19      , 721.63075631, 731.01113221, ..., 559.52022583,
         558.07486132, 543.31706616]],

       [[565.8       , 570.06061691, 573.14070504, ..., 482.97845266,
         480.70802425, 484.23358372]]])

In [61]:
keys_list = generate_key_list(subjects=['01','02'],
                              sessions=['01','02'],
                              task = 'checker',
                              runs = ['01BlinksRemoved'],
                              big_data = big_d)
a = create_big_feature_array(big_data= big_d,
                             modality = 'pupil',
                             array_name = 'feature',
                             index_to_get = 1,
                             axis_to_get = 0,
                             keys_list = keys_list,
                             )

In [62]:
a.shape

(2, 1, 756)

In [63]:
masks = combine_masks(big_d,key_list=keys_list,modalities=['pupil','brainstate'])

KeyError: 'brainstate'

In [64]:
masks = combine_masks(big_d,key_list=keys_list,modalities=['pupil','brainstates'])

Overall mask shape: (2, 1, 756)


In [65]:
windowed_mask = build_windowed_data(big_d,keys_list,window_length = 45,modalities = ['pupil','brainstates'])

TypeError: build_windowed_data() got multiple values for argument 'window_length'

In [66]:
windowed_mask = build_windowed_data(big_d,key_list = keys_list,window_length = 45,modalities = ['pupil','brainstates'])

TypeError: build_windowed_data() got an unexpected keyword argument 'key_list'

In [67]:
windowed_mask = build_windowed_mask(big_d,key_list = keys_list,window_length = 45,modalities = ['pupil','brainstates'])

Overall mask shape: (2, 1, 756)
Windowed mask shape: (2, 1, 711, 45)


In [68]:
data.shape

(2, 1, 756)

In [69]:
windowed_data

array([[[[ 0.08573921,  0.08846208,  0.09048314, ..., -0.29755922,
          -0.28849872, -0.27688954],
         [ 0.08846208,  0.09048314,  0.09178194, ..., -0.28849872,
          -0.27688954, -0.26281565],
         [ 0.09048314,  0.09178194,  0.09233806, ..., -0.27688954,
          -0.26281565, -0.24636105],
         ...,
         [-0.08941769, -0.05839202, -0.02681673, ..., -0.10314436,
          -0.13657576, -0.16760603],
         [-0.05839202, -0.02681673,  0.00485622, ..., -0.13657576,
          -0.16760603, -0.19578882],
         [-0.02681673,  0.00485622,  0.03617487, ..., -0.16760603,
          -0.19578882, -0.22103515]]],


       [[[-0.07714075, -0.03197379,  0.00772897, ..., -0.04586015,
          -0.04909177, -0.05167804],
         [-0.03197379,  0.00772897,  0.04230715, ..., -0.04909177,
          -0.05167804, -0.05366401],
         [ 0.00772897,  0.04230715,  0.07210036, ..., -0.05167804,
          -0.05366401, -0.05509477],
         ...,
         [ 0.22419547,  0.235200

In [70]:
windowed_data.shape

(2, 1, 711, 45)

In [71]:
X, Y = create_X_and_Y(big_data=big_d,
                      keys_list=keys_list,
                      X_name = 'pupil,
                      cap_name='CAPS1',
                      )

SyntaxError: unterminated string literal (detected at line 3) (<ipython-input-71-2db8246a3a09>, line 3)

In [72]:
X, Y = create_X_and_Y(big_data=big_d,
                      keys_list=keys_list,
                      X_name = 'pupil',
                      cap_name='CAPS1',
                      )

Pupil shape: (2, 1, 756)
Pupil first derivative shape: (1, 756)
Pupil second derivative shape: (1, 756)


IndexError: list index out of range

In [73]:
X, Y = create_X_and_Y(big_data=big_d,
                      keys_list=keys_list,
                      X_name = 'pupil',
                      cap_name='tsCAP1',
                      )

Pupil shape: (2, 1, 756)
Pupil first derivative shape: (1, 756)
Pupil second derivative shape: (1, 756)


In [74]:
X.shape

(1, 95985)

In [75]:
Y.shape

(2, 0, 756)