In [6]:
!pip install meegkit

Collecting meegkit
  Downloading meegkit-0.1.7-py3-none-any.whl.metadata (8.7 kB)
Collecting scikit-learn (from meegkit)
  Downloading scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl.metadata (12 kB)
Collecting joblib (from meegkit)
  Downloading joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting statsmodels (from meegkit)
  Downloading statsmodels-0.14.2-cp311-cp311-macosx_10_9_x86_64.whl.metadata (9.2 kB)
Collecting pyriemann>=0.2.7 (from meegkit)
  Downloading pyriemann-0.6-py2.py3-none-any.whl.metadata (8.3 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn->meegkit)
  Downloading threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Collecting patsy>=0.5.6 (from statsmodels->meegkit)
  Using cached patsy-0.5.6-py2.py3-none-any.whl.metadata (3.5 kB)
Downloading meegkit-0.1.7-py3-none-any.whl (78 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.2/78.2 kB[0m [31m701.0 kB/s[0m eta [36m0:00:00[0m[36m0:00:01[0m
[?25hDownloading pyriemann-0.6

In [8]:
import os
import numpy as np
import mne
import torch

from utils import get_args, get_mne_info

import mne
import logging
from pathlib import Path
import logging
from preprocessing.utils import split_raw, get_unannotated_raw, split_raw_annotations
from preprocessing.methods import PreprocessMethods

# Import typing
from typing import Tuple, List, Optional

from preprocessing.pipeline import BasePipeline

In [9]:
# Build the run configuration with utils.get_args() using its defaults.
# get_raw() reads `args.dataset_dir` from this object.
args = get_args()

In [12]:
def get_raw(args):
    """Load ``EEG-ImageNet_1.pth`` and return it as one annotated MNE recording.

    Each entry of ``data['dataset']`` contributes one EEG segment. The segments
    are concatenated along axis 1 into a single continuous array, and every
    segment is marked with an annotation whose description is ``str(label)``.

    Parameters
    ----------
    args
        Configuration object; only ``args.dataset_dir`` is read here.

    Returns
    -------
    raw : mne.io.RawArray
        The concatenated recording with one annotation per segment.
    descriptions
        ``data['labels']`` from the loaded file, returned unchanged.

    Raises
    ------
    ValueError
        If the file contains no segments.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point `dataset_dir` at trusted data.
    data = torch.load(os.path.join(args.dataset_dir, "EEG-ImageNet_1.pth"))

    descriptions = data['labels']

    eeg_data_list = []
    labels = []
    for event in data['dataset']:
        eeg_data_list.append(event['eeg_data'].numpy())  # tensor -> numpy array
        labels.append(event['label'])  # kept as stored; stringified for the annotation below

    if not eeg_data_list:
        raise ValueError("EEG-ImageNet_1.pth contains no events in data['dataset'].")

    # assumes every segment is (n_channels, n_times), so axis 1 is time -- TODO confirm
    eeg_data = np.concatenate(eeg_data_list, axis=1)

    info = get_mne_info()

    # Create RawArray object
    raw = mne.io.RawArray(eeg_data, info)

    # Derive onsets from integer sample offsets instead of summing float
    # durations, so rounding error cannot accumulate across many segments.
    sfreq = info['sfreq']
    sample_counts = np.array([segment.shape[1] for segment in eeg_data_list])
    start_samples = np.cumsum(sample_counts) - sample_counts

    # Create MNE Annotations object: one annotation per segment, back to back
    annotations = mne.Annotations(onset=start_samples / sfreq,
                                  duration=sample_counts / sfreq,
                                  description=[str(label) for label in labels])

    # Set annotations to the raw object
    raw.set_annotations(annotations)

    return raw, descriptions

In [15]:


class DownstreamPipeline(BasePipeline):
    """Preprocess the per-subject recordings and cut them into labelled windows.

    The pipeline loads one recording per subject through ``get_raw``, cleans
    each recording with ``run_single`` and finally splits the cleaned
    recordings into windows around their annotations.

    Attributes such as ``channels_rename``, ``do_ica`` and ``ransac`` and the
    helpers ``_to_standard_names``, ``_set_montage``, ``_remove_line_noise``,
    ``_filter``, ``_average_reference``, ``_ica_clean``,
    ``_interpolate_missing``, ``_drop_extra_and_reorder`` and
    ``_interpolate_nearest`` are not defined in this class; they are expected
    to come from ``BasePipeline``.
    """

    # Subject ids iterated by run(); each one is forwarded to get_args(subject=...).
    SUBJECT_IDS = range(0, 8)

    def __init__(self, args, descriptions: List[str], tmin: float = -0.5, tlen: float = 5.0, **kwargs):
        """Store the windowing parameters and forward ``kwargs`` to ``BasePipeline``.

        Parameters
        ----------
        args
            Configuration object; stored on the instance as ``self.args``.
        descriptions
            Annotation descriptions; the position of each entry becomes its
            integer label.
        tmin, tlen
            Passed unchanged to ``split_raw_annotations`` in ``run``.
        **kwargs
            Forwarded to ``BasePipeline.__init__``.
        """
        super().__init__(**kwargs)
        self.descriptions = descriptions
        self.description_map = {label: i for i, label in enumerate(descriptions)}
        self.tmin, self.tlen = tmin, tlen

        self.args = args

    def __call__(self, src_paths: Optional[List[str]] = None) -> Tuple[List[mne.io.Raw], List[Tuple[float, float]], List[int], List[int]]:
        """Run the pipeline; equivalent to ``run()``.

        ``src_paths`` is accepted for backward compatibility only. It is
        ignored because the recordings are loaded per subject via ``get_raw``.
        """
        return self.run()

    def run(self):
        """Load, clean and window every subject's recording.

        Returns
        -------
        raw_windows, times, indices, labels
            Four lists of equal length. ``raw_windows`` and ``times`` are the
            windows and time slices produced by ``split_raw_annotations``,
            ``indices[k]`` is the position of the source recording in the list
            of successfully loaded recordings, and ``labels[k]`` is the
            integer from ``description_map`` for that window's description.
        """
        logging.debug("Loading EDF files...")

        raws = []
        subject_ids = []  # parallel to `raws`; keeps the real subject id for log messages

        logging.debug("Splitting raws...")
        for subject in self.SUBJECT_IDS:
            try:
                # get_raw returns (raw, descriptions); only the recording is needed here.
                raw_orig, _ = get_raw(get_args(subject=subject))

                # Rename channels with channel_rename
                if self.channels_rename is not None:
                    raw_orig.rename_channels(self.channels_rename)

                self._to_standard_names(raw_orig)
                drop_chs = self._set_montage(raw_orig)
                logging.info(f"Subject {subject}.\tDropped channels when setting montage: {drop_chs}.")
            except Exception as e:
                # Best effort: a subject that cannot be loaded is skipped, but the reason is logged.
                logging.error(f"Dropping subject {subject}.\tError: {e}")
                continue

            raws.append(raw_orig)
            subject_ids.append(subject)

        total_windows = len(raws)
        logging.debug(f"Total files: {total_windows}")

        for i, raw in enumerate(raws):
            filename = f"subject {subject_ids[i]}"

            try:
                raws[i] = self.run_single(raw, filename)
            except Exception as e:
                logging.error(f"File: {filename}.\tError: {e}")
                # Keep the slot so positions in `raws` (and thus `indices`) stay stable.
                raws[i] = None

            # Log progress every N windows
            if (i + 1) % 10 == 0:
                logging.debug(f"Processed {i + 1}/{total_windows} files.")

        raw_windows = []
        times = []
        indices = []
        descriptions = []

        for i, raw in enumerate(raws):
            if raw is None:
                # Preprocessing failed for this recording (already logged above).
                continue

            windows, time_slices, descri = split_raw_annotations(raw, labels=self.descriptions, tmin=self.tmin,
                                                                 tlen=self.tlen, verbose=False)
            raw_windows.extend(windows)
            times.extend(time_slices)
            descriptions.extend(descri)
            indices.extend([i] * len(windows))

        labels = [self.description_map[description] for description in descriptions]

        assert len(raw_windows) == len(times) == len(indices) == len(labels)

        return raw_windows, times, indices, labels

    def run_single(self, raw, filename) -> Optional[mne.io.Raw]:
        """Clean one recording in place and return it.

        Steps, in order: line-noise removal, bad-channel detection and drop,
        filtering, average reference, optional ICA cleaning followed by a
        second bad-channel pass, interpolation of missing channels, removal of
        extra channels with reordering, and nearest interpolation.

        Parameters
        ----------
        raw
            Recording to clean; it is modified in place.
        filename
            Label used only in log messages.
        """
        window_info_str = f"File: {filename}."

        self._remove_line_noise(raw)

        # Bad channels are searched on the full recording
        # (the get_unannotated_raw restriction is disabled).
        self._log_and_drop_bad_channels(raw, window_info_str)

        self._filter(raw)
        self._average_reference(raw)

        if self.do_ica:
            excluded_idxs, labels, y_proba = self._ica_clean(raw)
            logging.info(f"{window_info_str}\tExcluding {len(excluded_idxs)} components: {excluded_idxs}.")
            logging.info(f"{window_info_str}\tLabels: {labels}.")
            logging.info(f"{window_info_str}\tProbabilities: {[round(prob, 2) for prob in y_proba]}.")

            # Second pass after ICA cleaning.
            self._log_and_drop_bad_channels(raw, window_info_str)

        missing_chs = self._interpolate_missing(raw)
        logging.info(f"{window_info_str}\tIntepolating {len(missing_chs)} channels: {missing_chs}.")

        extra_chs = self._drop_extra_and_reorder(raw)
        logging.info(f"{window_info_str}\tRemoving {len(extra_chs)} extra channels: {extra_chs}.")

        self._interpolate_nearest(raw)
        return raw

    def _log_and_drop_bad_channels(self, raw, window_info_str: str) -> None:
        """Find bad channels in ``raw``, log them, and drop them from ``raw``."""
        bad_chs = self._find_bad_channels(raw)
        logging.info(f"{window_info_str}\tFound {len(bad_chs)} bad channels: {bad_chs}.")

        # Drop bad channels
        raw.drop_channels(bad_chs)

    def _find_bad_channels(self, raw: mne.io.Raw, drop=True):
        """Return the bad channels reported by ``PreprocessMethods.find_bad_channels``.

        The ``drop`` argument is not used: ``drop=False`` is always passed on,
        because ``run_single`` drops the returned channels itself.
        """
        return PreprocessMethods.find_bad_channels(raw, ransac = self.ransac, drop = False)
    
class DownstreamPipelineBENDR(DownstreamPipeline):
    """Variant of ``DownstreamPipeline`` with a reduced per-recording step.

    ``run_single`` only zeroes missing channels, drops extra channels with
    reordering, and applies nearest interpolation. Loading, windowing and
    ``_find_bad_channels`` are inherited from ``DownstreamPipeline``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DownstreamPipeline.__init__``.

        Positional arguments are accepted as well as keywords, so the class
        can be constructed the same way as its parent.
        """
        super().__init__(*args, **kwargs)

    def _zero_missing(self, raw: mne.io.Raw):
        """Delegate to ``PreprocessMethods.zero_missing`` with ``self.chs`` and ``self.montage``."""
        return PreprocessMethods.zero_missing(raw, self.chs, self.montage)

    def run_single(self, raw, filename) -> Optional[mne.io.Raw]:
        """Apply the reduced preprocessing to one recording in place and return it.

        Parameters
        ----------
        raw
            Recording to process; it is modified in place.
        filename
            Label used only in log messages.
        """
        window_info_str = f"File: {filename}."

        missing_chs = self._zero_missing(raw)
        # "\t" replaces the invalid "\Z" escape so the message matches the other log lines.
        logging.info(f"{window_info_str}\tZeroing {len(missing_chs)} channels: {missing_chs}.")

        extra_chs = self._drop_extra_and_reorder(raw)
        logging.info(f"{window_info_str}\tRemoving {len(extra_chs)} extra channels: {extra_chs}.")

        self._interpolate_nearest(raw)
        return raw


In [14]:
# Load the recording once with the default configuration. get_raw returns the
# annotated recording together with the label descriptions; `descriptions` is
# needed to construct DownstreamPipeline in the following cell.
raw, descriptions = get_raw(args)

Creating RawArray with float64 data, n_channels=62, n_times=16006950
    Range : 0 ... 16006949 =      0.000 ... 16006.949 secs
Ready.


In [16]:
# run() returns four parallel lists. Bind them so the result is kept for later
# cells instead of being echoed as the cell output and then lost.
raw_windows, times, indices, labels = DownstreamPipeline(args, descriptions).run()