In [1]:
import mne
import pickle
import pyriemann
import numpy as np
from pyriemann.tangentspace import TangentSpace
from pyriemann.estimation import Shrinkage, Covariances
from pyriemann.preprocessing import Whitening
import s00_helper_functions_data_bva_edf3 as helper
import os
import pandas as pd
import envelope_and_TF_pipeline as etf

In [2]:
eeg_foldername = '/projects/EEG_FMRI/bids_eeg/BIDS/NEW/PREP_BV_EDF'
fmri_foldername = '/data2/Projects/eeg_fmri_natview/data/fmriprep/derivatives/cap_ts'
eeg_filename = 'sub-02_ses-01_task-checker_run-01_eeg.edf'
raw = mne.io.read_raw(os.path.join(eeg_foldername, eeg_filename), preload=True)
raw = etf.extract_eeg_only(raw)
raw = etf.specific_crop(raw, 1, False)
blinks_remover = etf.BlinkRemover(raw).remove_blinks()
blinks_removed_raw = blinks_remover.blink_removed_raw
muscle_annotations, _ = mne.preprocessing.annotate_muscle_zscore(blinks_removed_raw, 
                                            threshold       = 3.5, 
                                            ch_type         = None, 
                                            min_length_good = 0.1, 
                                            filter_freq     = (30, None), 
                                            )
blinks_removed_raw.set_annotations(blinks_removed_raw.annotations + muscle_annotations)
events = mne.events_from_annotations(blinks_removed_raw)
trigger_name = 'R128'
filtered_events_mask = np.where(events[0][:, 2] == events[1][trigger_name])
filtered_events = events[0][filtered_events_mask]

brainstate_exists, brainstate = helper.get_brainstate_data(
    brainstate_dir = fmri_foldername,
    sub='01', 
    ses='01', 
    task='checker')

epochs = mne.Epochs(blinks_removed_raw, 
                    filtered_events, 
                    tmin=-4, 
                    tmax=-1/raw.info['sfreq'], 
                    baseline=None, 
                    preload=True,
                    metadata=brainstate,
                    reject_by_annotation=True
                    )
epochs.drop(epochs.metadata[epochs.metadata['Mask'] == 0].index)


Extracting EDF parameters from /projects/EEG_FMRI/bids_eeg/BIDS/NEW/PREP_BV_EDF/sub-02_ses-01_task-checker_run-01_eeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


Reading 0 ... 66499  =      0.000 ...   265.996 secs...
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['New Segment', 'R', 'R128', 'S  1', 'S 10', 'S 11', 'S 12', 'S 25', 'S 26', 'S 27', 'S 99', 'Sync On', 'TEND', 'TPEAK', 'TSTART', 'Time 0', 'Userdefined Artifact']
Used Annotations descriptions: ['New Segment', 'R', 'R128', 'S  1', 'S 10', 'S 11', 'S 12', 'S 25', 'S 26', 'S 27', 'S 99', 'Sync On', 'TEND', 'TPEAK', 'TSTART', 'Time 0', 'Userdefined Artifact']
cropping from 36.568 to 244.368
Running EOG SSP computation
Using EOG channels: Fp1, Fp2
EOG channel index for this subject is: [0 1]
Filtering the data to remove DC offset to help distinguish blinks from saccades
Selecting channel Fp1 for blink detection
Setting up band-pass filter from 1 - 10 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) meth

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Setting up low-pass filter at 4 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal lowpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Upper passband edge: 4.00 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 5.00 Hz)
- Filter length: 413 samples (1.652 s)

Used Annotations descriptions: ['R', 'R128', 'S  1', 'S 10', 'S 11', 'S 12', 'S 25', 'S 26', 'S 27', 'S 99', 'Sync On', 'TEND', 'TPEAK', 'TSTART']
/data2/Projects/eeg_fmri_natview/data/fmriprep/derivatives/cap_ts/sub-01_ses-01_task-checker.txt
Adding metadata with 9 columns
98 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 98 events and 1000 original time points ...
5 bad epochs dropped
Dropped 0 epochs: 


Unnamed: 0,General,General.1
,MNE object type,Epochs
,Measurement date,2000-01-01 at 00:00:00 UTC
,Participant,X
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,93
,Events counts,2: 93
,Time range,-4.000 – -0.004 s
,Baseline,off
,Sampling frequency,250.00 Hz


In [2]:
eeg_foldername = '/projects/EEG_FMRI/bids_eeg/BIDS/NEW/PREP_BV_EDF'
fmri_foldername = '/data2/Projects/eeg_fmri_natview/data/fmriprep/derivatives/cap_ts'
eeg_filename = 'sub-14_ses-01_task-checker_run-01_eeg.edf'

def prepare_raw(eeg_foldername, eeg_filename):
    raw = mne.io.read_raw(os.path.join(eeg_foldername, eeg_filename), preload=True)
    raw = etf.extract_eeg_only(raw)
    raw = etf.specific_crop(raw, 1, False)
    blinks_remover = etf.BlinkRemover(raw).remove_blinks()
    blinks_removed_raw = blinks_remover.blink_removed_raw
    muscle_annotations, _ = mne.preprocessing.annotate_muscle_zscore(blinks_removed_raw, 
                                             threshold=3.5, 
                                             ch_type=None, 
                                             min_length_good=0.1, 
                                             filter_freq=(30, None), 
                                             )
    blinks_removed_raw.set_annotations(blinks_removed_raw.annotations + muscle_annotations)
    return blinks_removed_raw
def create_epochs(raw: mne.io.Raw,
                  brainstate: pd.DataFrame):
    events = mne.events_from_annotations(raw)
    trigger_name = 'R128'
    filtered_events_mask = np.where(events[0][:, 2] == events[1][trigger_name])
    filtered_events = events[0][filtered_events_mask]


    epochs = mne.Epochs(raw,
                        filtered_events, 
                        tmin=-4, 
                        tmax=-1/raw.info['sfreq'], 
                        baseline=None, 
                        preload=True,
                        metadata=brainstate,
                        reject_by_annotation=True
                        )
    epochs.drop(np.logical_not(epochs.metadata['Mask'].astype(bool)))
    return epochs

def calculate_tangent_space(epochs: mne.Epochs):
    covariance = Covariances(estimator='lwf').transform(epochs.get_data())
    whitener = Whitening(metric = 'riemann')
    whitened_cov_matrices = whitener.fit_transform(covariance)
    tangent_space = TangentSpace()
    X_tangent = tangent_space.fit_transform(whitened_cov_matrices)
    return X_tangent


In [3]:
%%capture
subjects = [ '01', '02', '03', '05', '06', '07', '08', '09', '10',
    '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22']
sessions = [ '01', '02' ] # 
tasks = [ 'checker', 'rest' ]
runs = [ '01', '02' ]
concatenated_data = {}
for subject in subjects:
    for session in sessions:
        for task in tasks:
            for run in runs:
                eeg_filename = f'sub-{subject}_ses-{session}_task-{task}_run-{run}_eeg.edf'
                brainstate_exists, brainstate = helper.get_brainstate_data(
                    brainstate_dir=fmri_foldername,
                    sub=subject,
                    ses=session,
                    task=task,
                )
                both_file_exist = (
                    os.path.exists(os.path.join(eeg_foldername, eeg_filename)) and
                    brainstate_exists
                )
                if both_file_exist:
                    try:
                        raw = prepare_raw(eeg_foldername, eeg_filename)
                        raw.filter(0.1, 40)
                        epochs = create_epochs(raw, brainstate)
                        tangent_array = calculate_tangent_space(epochs)
                        caps = {}
                        for cap_name in brainstate.columns:
                            if cap_name != 'Mask':
                                caps.update({cap_name: epochs.metadata[cap_name].values})
                        if task == 'checker':
                            concatenated_data.update({f'sub-{subject}': 
                                {f'ses-{session}': 
                                    {f'task-{task}': 
                                        {f'run-{run}': 
                                            {'X': tangent_array, 'Y': caps}
                                            }
                                        }
                                    }
                                }
                                                    )
                        elif task == 'rest':
                            concatenated_data[
                                f'sub-{subject}'][
                                    f'ses-{session}'].update({f'task-{task}':
                                {run: {'X': tangent_array,
                                        'Y': caps},
                                }
                                        }
                            )
                    except Exception as e:
                        print(f'Error in {eeg_filename}: {e}')
                        continue

In [4]:
def concatenate_data(data: dict,
                     subjects: list,
                     sessions: list,
                     task: str,
                     runs: list,
                     cap_names: list):
    X = []
    Y = []
    for subject in subjects:
      data = data.get(subject, False)
      if not data:
        continue
      for session in sessions:
        data = data.get(session, False)
        if not data:
          continue
        data = data.get(f'task-{task}', False)
        if not data:
          continue
        for run in runs:
          data = data.get(run, False)
          if not data:
            continue
          for cap_name in cap_names:
              X.append(data['X'])
              Y.append(data['Y'][cap_name])
      
    return np.concatenate(X), np.concatenate(Y)

subjects = concatenated_data.keys()
sessions = ['ses-01', 'ses-02']
runs = ['run-01', 'run-02']
X_checker, Y_checker = concatenate_data(concatenated_data, 
                                        subjects, 
                                        sessions, 
                                        'checker', 
                                        runs, 
                                        ['CAP1']
                                        )


AttributeError: 'bool' object has no attribute 'get'

In [5]:
test_dict = {'bonjour':'test',
             'truc': 'machin'}
copy_test_dict = test_dict.copy()

In [8]:
concatenated_data

{'sub-01': {'ses-01': {'task-checker': {'run-01': {'X': array([[ 0.28028944,  0.48658775,  0.20501813, ...,  0.33168602,
              0.02073354,  0.36349279],
            [ 0.12508083,  0.34849428,  0.21465141, ...,  0.30605396,
              0.0387871 ,  0.22099769],
            [ 0.1436615 ,  0.21891172, -0.05938623, ...,  0.07312293,
              0.08790445,  0.60235921],
            ...,
            [ 0.14888362,  0.14487367,  0.01040696, ...,  0.26122529,
             -0.0540205 ,  0.07886934],
            [ 0.32800827, -0.03912195, -0.22735189, ...,  0.20379479,
             -0.08281528,  0.21255274],
            [ 0.30157772, -0.12379771, -0.2464769 , ...,  0.16245569,
              0.10578993, -0.03862156]]),
     'Y': {'CAP1': array([ 0.14915197,  0.27593621,  0.28534789,  0.14248913, -0.05807165,
             -0.15029446, -0.12175829,  0.04829685,  0.18616599,  0.18754464,
              0.13205279, -0.01225724, -0.17374914, -0.23729988, -0.21881648,
             -0.1841921

In [11]:
from typing import List, Dict, Union, Optional, Any
def _find_item(desired_key: str, 
               obj: Dict[str, Any],
               exceptions: list) -> Any:
    """Find any item in an encapsulated dictionary."

    Args:
        desired_key (str): They key to look for.
        obj (Dict[str, Any]): the dictionary.

    Returns:
        Any: The returned item found in the encapsulated dictionary.
    """
    if obj.get(desired_key) is not None:
        return obj[desired_key]
    
    for key, value in obj.items():
        print(key)
        if isinstance(value, dict) and key not in exceptions:
            item = _find_item(desired_key, value, exceptions)
            if item:
                return item
test = _find_item('X',concatenated_data,exceptions=['sub-01'])

sub-01
sub-02
ses-01
task-checker
run-01


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [14]:
class KeyHandler:
    def __init__(self) -> None:
        self.method_flag = None
        self.subjects = list()
        self.sessions = list()
        self.runs = list()
        self.tasks = list()
        self.descriptions = list()
    def extract_key_value_pairs(self,
                                key_value_string: str):
        
        key_value_pair_dict = dict()
        possible_keys = [
            'subject',
            'session',
            'task',
            'run',
            'description',
        ]
        if isinstance(key_value_string,str):
            separated_pairs = key_value_string.split('_')
        
        for pair in separated_pairs:
            key, value = pair.split('-')
            for possible_key in possible_keys:
                if key in possible_key:
                    key_value_pair_dict[possible_key] = value

        return key_value_pair_dict
        
    
    def read_keys_list(self,key_list: list):
        if self.method_flag is not None:
            raise Exception("You already called a method on this object. To\
                prevent overwritting paramters re-initiate another object")
        self.method_flag = 'read'
        for keys in key_list:
            key_value_pair_dict = self._set_key_value_from_string(keys)
            for key, value in key_value_pair_dict.items():
                attribute_name = key+'s'
                getattr(self,attribute_name).append(value)
    
    def select_from_list(self,
                         key_list: list,
                         subjects: list | str | None,
                         sessions: list | str | None,
                         tasks: list | str | None,
                         descriptions: list | str | None,
                         runs: list | str | None,
                         how: str = 'include'):
        """Select a subset of data name

        Args:
            key_list (list): The list to select from.
            subjects (list | str | None): The list of subjects to select from. 
                                          Defaults to None 
                                          (select all the subjects).
            sessions (list | str | None): The list of sessions to select from. 
                                          Defaults to None 
                                          (select all the sessions).
            tasks (list | str | None): The list of tasks to select from. 
                                          Defaults to None 
                                          (select all the tasks).
            descriptions (list | str | None): The list of descriptions to select from. 
                                          Defaults to None 
                                          (select all the descriptions).
            runs (list | str | None): The list of runs to select from. 
                                          Defaults to None 
                                          (select all the runs).
            how (str, optional): How to select the subset. Either to include 
                                 the desired list or exlude it. 
                                 Defaults to 'include'.
        """
        for keys in key_list:
            


    
    def generate_one_key_list(self,
                          key_name: str,
                          prefix:str | None,
                          start: int | None,
                          stop: int | None,
                          nb_digit: int):
        if self.method_flag is not None:
            raise Exception("You already called a method on this object. To\
                prevent overwritting paramters re-initiate another object")
        self.method_flag = 'write'
        
        for i in range(start,stop):
            string_index = str(i)
            key_string = prefix + f'{string_index.rjust(nb_digit,'0')}'
            getattr(self,key_name).append(key_string)

        
        return self
    def generate_all_keys(self,
                          subject_parameters: dict,
                          session_parameters: dict,
                          tasks_paramters: dict,
                          ):
        pass

        


In [1]:
class KeyHandler:
    def __init__(self) -> None:
        self.method_flag = None
        self.subjects = list()
        self.sessions = list()
        self.runs = list()
        self.tasks = list()
        self.descriptions = list()

    def transform_format(self):
        pass
    
    def read_value_format(self,value: str):
        prefix = ''.join(filter(lambda x: x.isalpha(), value))
        pass
    
    def extract_key_value_pairs(self, key_value_string: str):
        key_value_pair_dict = dict()
        possible_keys = [
            'subject',
            'session',
            'task',
            'run',
            'description',
        ]
        if isinstance(key_value_string, str):
            separated_pairs = key_value_string.split('_')
        
        for pair in separated_pairs:
            key, value = pair.split('-')
            for possible_key in possible_keys:
                if key in possible_key:
                    key_value_pair_dict[possible_key] = value

        return key_value_pair_dict

    def read_keys_list(self, key_list: list):
        if self.method_flag is not None:
            raise Exception("You already called a method on this object. To prevent overwriting parameters, re-initiate another object.")
        self.method_flag = 'read'
        for keys in key_list:
            key_value_pair_dict = self.extract_key_value_pairs(keys)
            for key, value in key_value_pair_dict.items():
                attribute_name = key + 's'
                getattr(self, attribute_name).append(value)

    def select_from_list(self, key_list: list,
                         subjects: list | str | None = None,
                         sessions: list | str | None = None,
                         tasks: list | str | None = None,
                         descriptions: list | str | None = None,
                         runs: list | str | None = None,
                         how: str = 'include'):
        """Select a subset of data names based on given criteria."""
        
        def match_criteria(keys, criteria, key_type):
            if criteria is None:
                return True
            key_value = self.extract_key_value_pairs(keys).get(key_type)
            if isinstance(criteria, str):
                criteria = [criteria]
            if how == 'include':
                return key_value in criteria
            elif how == 'exclude':
                return key_value not in criteria
            return False

        selected_keys = []
        for keys in key_list:
            if (match_criteria(keys, subjects, 'subject') and
                match_criteria(keys, sessions, 'session') and
                match_criteria(keys, tasks, 'task') and
                match_criteria(keys, descriptions, 'description') and
                match_criteria(keys, runs, 'run')):
                selected_keys.append(keys)

        return selected_keys

    def generate_one_key_list(self, 
                              key_name: str, 
                              prefix: str | None, 
                              start: int | None, 
                              stop: int | None,
                              nb_digit: int
                              ):
        if self.method_flag is not None:
            raise Exception("You already called a method on this object. To prevent overwriting parameters, re-initiate another object.")
        self.method_flag = 'write'
        
        for i in range(start, stop):
            string_index = str(i).rjust(nb_digit, '0')
            key_string = prefix + string_index
            getattr(self, key_name).append(key_string)

        return self

    def generate_all_keys(self, 
                          subject_parameters: dict, 
                          session_parameters: dict, 
                          tasks_parameters: dict):
        if self.method_flag is not None:
            raise Exception("You already called a method on this object. To prevent overwriting parameters, re-initiate another object.")
        self.method_flag = 'write'

        subjects = [f'sub-{str(i).rjust(subject_parameters["nb_digit"], "0")}'
                    for i in range(subject_parameters["start"], subject_parameters["stop"])]
        sessions = [f'ses-{str(i).rjust(session_parameters["nb_digit"], "0")}'
                    for i in range(session_parameters["start"], session_parameters["stop"])]
        tasks = [f'task-{task}' for task in tasks_parameters["names"]]

        key_list = []
        for subject in subjects:
            for session in sessions:
                for task in tasks:
                    for run in range(1, tasks_parameters["runs"] + 1):
                        key_list.append(f'{subject}_{session}_{task}_run-{str(run).rjust(2, "0")}')
        
        return key_list


In [10]:
# Instantiate the object
handler = KeyHandler()

# Generate keys
key_list = handler.generate_all_keys(
    subject_parameters={'start': 20, 'stop': 25, 'nb_digit': 3},
    session_parameters={'start': 1, 'stop': 3, 'nb_digit': 2},
    tasks_parameters={'names': ['rest', 'taskA'], 'runs': 1}
)

# Select a subset
selected_keys = handler.select_from_list(
    key_list, 
    subjects=['020','021'],
    how = 'exclude'
    #sessions=['01', '02']
)

print(selected_keys)

['sub-022_ses-01_task-rest_run-01', 'sub-022_ses-01_task-taskA_run-01', 'sub-022_ses-02_task-rest_run-01', 'sub-022_ses-02_task-taskA_run-01', 'sub-023_ses-01_task-rest_run-01', 'sub-023_ses-01_task-taskA_run-01', 'sub-023_ses-02_task-rest_run-01', 'sub-023_ses-02_task-taskA_run-01', 'sub-024_ses-01_task-rest_run-01', 'sub-024_ses-01_task-taskA_run-01', 'sub-024_ses-02_task-rest_run-01', 'sub-024_ses-02_task-taskA_run-01']


In [13]:
st = 'ab012'

False

In [18]:
test = ''.join(filter(lambda x: x.isalpha(), st))


In [21]:
class BidsComponentObject:
    def __init__(self, )

'ab'

In [27]:
strings = ['sub-020_ses-05_task-tamere_run-01',
           'sub-022_ses-01_task-tamere_run-03', 
           'sub-023_ses-05_task-tamere_run-04'
]
handler = KeyHandler()
handler.read_keys_list(strings)

In [18]:
handler = KeyHandler()
handler.generate_one_key_list(key_name='subjects',
                           prefix='',
                           start=1,
                           stop=10,
                           nb_digit=4)
handler.subjects

['0001', '0002', '0003', '0004', '0005', '0006', '0007', '0008', '0009']

In [17]:
helper.get_brainstate_data(
    brainstate_dir = fmri_foldername,
    sub='01', 
    ses='01', 
    task='checker')

/data2/Projects/eeg_fmri_natview/data/fmriprep/derivatives/cap_ts/sub-01_ses-01_task-checker.txt


(True,
         CAP1      CAP2      CAP3      CAP4      CAP5      CAP6      CAP7  \
 0  -0.076605  0.085739 -0.039984  0.026376 -0.145276  0.143430 -0.054188   
 1  -0.047615  0.086727  0.284220 -0.282032 -0.243064  0.225278 -0.011694   
 2   0.014826  0.033645  0.348907 -0.349491 -0.125130  0.105316 -0.070708   
 3   0.149152 -0.084671  0.469230 -0.474412  0.042073 -0.054206 -0.188274   
 4   0.275936 -0.245874  0.327646 -0.320053  0.189417 -0.166823 -0.136113   
 ..       ...       ...       ...       ...       ...       ...       ...   
 93 -0.057774  0.117872  0.305637 -0.312967 -0.322685  0.318181  0.370831   
 94  0.222922 -0.167606  0.256536 -0.251746 -0.450964  0.492126  0.334022   
 95  0.329290 -0.310865  0.131008 -0.108967 -0.410154  0.477974  0.086643   
 96  0.267027 -0.261492  0.125305 -0.101495 -0.407285  0.456321 -0.105831   
 97  0.012580  0.004559  0.149337 -0.150767 -0.273292  0.251552 -0.143269   
 
         CAP8  Mask  
 0   0.058562     0  
 1  -0.012179     0  
 