In [1]:
from braindecode.datasets import MOABBDataset
from numpy import multiply, clip

import numpy as np
from scipy import linalg

from braindecode.preprocessing import (
    Preprocessor,
    exponential_moving_standardize,
    preprocess,
)

# Load the High-Gamma (Schirrmeister2017) recordings of subjects 1, 2 and 3 through MOABB.
dataset = MOABBDataset(
    dataset_name="Schirrmeister2017",
    subject_ids=[1, 2, 3],
)


def ZCA_whitening(data: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    '''
    Apply ZCA (zero-phase component analysis) whitening to a multichannel signal.

    Every channel is centred over time, then the centred data is multiplied by
    ``W = V diag(1 / sqrt(w + eps)) V^T``, where ``w`` / ``V`` are the eigenvalues /
    eigenvectors of the channel covariance. The output therefore has
    (approximately) identity channel covariance.

    Args:
        data: np.ndarray (n_channels, n_times)
        eps: non-negative value added to every eigenvalue before inverting it.
            Required because e.g. average-referenced data has a rank-deficient
            covariance (one eigenvalue ~0), which would otherwise yield inf/nan.

    Returns: whitened data: np.ndarray (n_channels, n_times), real-valued
    '''
    # Zero-centre every channel over time (axis=1 = time for (n_channels, n_times)).
    xc = data - np.mean(data, axis=1, keepdims=True)
    xcov = np.cov(xc, rowvar=True, bias=True)

    # The covariance is symmetric positive semi-definite: eigh returns real
    # eigenvalues and an orthonormal eigenbasis, so v.T is its inverse
    # (linalg.eig may return complex values and non-orthonormal vectors).
    w, v = linalg.eigh(xcov)
    # Round-off can make (near-)zero eigenvalues slightly negative.
    w = np.clip(w, 0.0, None)

    # ZCA whitening matrix V diag(1/sqrt(w + eps)) V^T; dividing the columns of
    # v by sqrt(w + eps) is the same as right-multiplying by the diagonal matrix.
    whitening = (v / np.sqrt(w + eps)) @ v.T
    return whitening @ xc

2024-03-12 22:05:33.140213: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-12 22:05:33.160795: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-12 22:05:33.160828: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-12 22:05:33.161425: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-12 22:05:33.165232: I tensorflow/core/platform/cpu_feature_guar

Extracting EDF parameters from /home/samuelboehm/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/1.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7

In [2]:
CHANNELS = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
            'T7', 'C3', 'Cz', 'C4', 'T8', 'P7', 'P3',
            'Pz', 'P4', 'P8', 'O1', 'O2', 'M1', 'M2']

low_cut_hz = 4.  # low cut frequency for filtering
# NOTE(review): 50 Hz equals the Nyquist frequency after resampling to
# `target_sfreq` = 100 Hz, while the FIR upper transition band extends beyond
# it (-6 dB at 56.25 Hz per MNE's filter log); consider a lower cut-off.
high_cut_hz = 50  # high cut frequency for filtering
# Amplitude limit (in uV) applied after the V -> uV conversion
clip_uv = 800
# Sampling frequency (Hz) after resampling
target_sfreq = 100
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

# Plain callables (instead of lambdas) keep these steps serialisable. The
# V -> uV scaling stays a lambda because numpy ufuncs such as `multiply` do not
# accept their operands as keyword arguments.
preprocessors = [
    Preprocessor('pick_channels', ch_names=CHANNELS),
    Preprocessor('set_eeg_reference', ref_channels='average'),
    Preprocessor(lambda data: multiply(data, factor)),  # Convert from V to uV
    Preprocessor(clip, a_min=-clip_uv, a_max=clip_uv),
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    Preprocessor('resample', sfreq=target_sfreq),
    # Average referencing (above) makes the channel covariance rank-deficient,
    # which the whitening step has to tolerate.
    Preprocessor(ZCA_whitening),
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size),
]

# Transform the data. `preprocess` modifies `dataset` in place, so re-running
# this cell applies every step again (e.g. a second V -> uV scaling); reload
# `dataset` before re-running.
dataset = preprocess(dataset, preprocessors, n_jobs=-1)


  warn('Preprocessing choices with lambda functions cannot be saved.')


2024-03-12 22:05:47.944704: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-12 22:05:47.964969: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-12 22:05:47.964992: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-12 22:05:47.965551: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-12 22:05:47.968972: I tensorflow/core/platform/cpu_feature_guar

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


/home/samuelboehm/miniconda3/envs/eeg-gan/lib/python3.10/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
/home/samuelboehm/miniconda3/envs/eeg-gan/lib/python3.10/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(
2024-03-12 22:05:49.958025: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-12 22:05:49.981241: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-12 22:05:49.981268: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been re

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.6s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)



/home/samuelboehm/miniconda3/envs/eeg-gan/lib/python3.10/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
/home/samuelboehm/miniconda3/envs/eeg-gan/lib/python3.10/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(
  getattr(raw_or_epochs, self.fn)(**self.kwargs)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


  getattr(raw_or_epochs, self.fn)(**self.kwargs)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  getattr(raw_or_epochs, self.fn)(**self.kwargs)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.8s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (1.650 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.8s


In [3]:
from braindecode.preprocessing import create_windows_from_events

# Integer class label for each event description
MAPPING = {'right_hand': 0, 'rest': 1}

# Windows run from 0.5 s before the cue to 1.5 s before the trial end.
# Total trial duration is 4 s, so each window covers 4 + 0.5 - 1.5 = 3 s.
trial_start_offset_seconds = -0.5
trial_stop_offset_seconds = -1.5
# Backward-compatible alias for the previously misspelled variable name.
tiral_stop_offset_seconds = trial_stop_offset_seconds

# Extract sampling frequency, check that it is the same in all recordings
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets), \
    'all recordings must share one sampling frequency'

# Convert the offsets to samples. Round before casting so floating-point error
# (e.g. 0.29 * 100 == 28.999...) cannot truncate to the wrong sample.
trial_start_offset_samples = int(round(trial_start_offset_seconds * sfreq))
trial_stop_offset_samples = int(round(trial_stop_offset_seconds * sfreq))

# Cut one window per event; windows that do not fit inside the recording are dropped.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=trial_stop_offset_samples,
    preload=True,
    mapping=MAPPING,
    drop_bad_windows=True,
)


Used Annotations descriptions: ['rest', 'right_hand']
Used Annotations descriptions: ['rest', 'right_hand']
Used Annotations descriptions: ['rest', 'right_hand']
Used Annotations descriptions: ['rest', 'right_hand']
Used Annotations descriptions: ['rest', 'right_hand']
Used Annotations descriptions: ['rest', 'right_hand']




In [56]:
# Group the windows by their `run` description: train on the '0train' run
# and validate on the '1test' run.
splitted = windows_dataset.split('run')
train_set, valid_set = splitted['0train'], splitted['1test']

In [114]:
# Inspect the first window: signal array, class label and window indices.
# Index directly instead of `for ... in ...: break`, and avoid binding `_`:
# assigning to `_` shadows IPython's last-output variable and stops it updating.
X, y, window_inds = windows_dataset[0]
print(X.shape)
print(y)
print(window_inds)

(21, 300)
1
[0, 950, 1250]


In [64]:
from torch.utils.data import DataLoader

import torch

BATCH_SIZE = 64
SEED = 42  # fixes the shuffling order so runs are reproducible

# Shuffled training loader; the dedicated generator makes the batch order
# deterministic without touching torch's global RNG.
train_dataloader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [73]:
# Pull a single shuffled batch; it is the collated [signals, labels, window indices].
batch_iterator = iter(train_dataloader)
X = next(batch_iterator)

3

In [62]:
from lightning import LightningDataModule
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class HighGammaModule(LightningDataModule):
    '''
    LightningDataModule serving (optionally down-sampled) High-Gamma windows.

    Meant for stage-wise training: ``set_stage`` replaces ``self.data`` with a
    resampled copy of the original windows for the requested stage, while the
    full-resolution data passed to the constructor is never modified.

    Args:
        data: windowed braindecode dataset (e.g. a ``BaseConcatDataset`` of
            ``WindowsDataset``) at full resolution.
        n_stages: number of training stages; stage ``n_stages - 1`` uses the
            full sampling frequency.
        batch_size: batch size of the training DataLoader.
        num_workers: number of DataLoader worker processes.
    '''
    def __init__(self, data, n_stages: int, batch_size: int, num_workers: int) -> None:
        # Initialise the Lightning base class before setting attributes.
        super().__init__()
        # braindecode's `preprocess` mutates datasets in place, so keep the
        # original around and resample copies of it in `set_stage`.
        self._full_data = data
        self.data = data
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_stages = n_stages

    def train_dataloader(self):
        '''Return a shuffling DataLoader over the data of the current stage.'''
        return DataLoader(self.data,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True)

    def set_stage(self, stage: int):
        '''
        Resample the original data to the temporal resolution of ``stage``.

        NOTE(review): the original left the target rate unfinished
        (``sfreq = ``). This assumes a progressive-growing schedule in which
        every stage doubles the sampling frequency and the last stage
        (``n_stages - 1``) uses the full rate -- TODO confirm.

        Args:
            stage: zero-based stage index in ``[0, n_stages)``.

        Raises:
            ValueError: if ``stage`` is outside ``[0, n_stages)``.
        '''
        import copy

        if not 0 <= stage < self.n_stages:
            raise ValueError(f'stage must be in [0, {self.n_stages}), got {stage}')

        first_ds = self._full_data.datasets[0]
        # WindowsDataset stores its signal as `.windows`; raw-based braindecode
        # datasets store it as `.raw`.
        signal = first_ds.windows if hasattr(first_ds, 'windows') else first_ds.raw
        full_sfreq = signal.info['sfreq']
        sfreq = full_sfreq / 2 ** (self.n_stages - 1 - stage)

        if sfreq == full_sfreq:
            self.data = self._full_data
            return

        # Resample a deep copy so the full-resolution data stays intact and
        # repeated stage changes do not compound resampling.
        self.data = preprocess(copy.deepcopy(self._full_data),
                               [Preprocessor('resample', sfreq=sfreq)],
                               n_jobs=-1)


<braindecode.datasets.base.BaseConcatDataset at 0x7fe9a6184b10>

2024-03-12 22:06:45.952674: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different 2024-03-12 22:06:45.952675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different 2024-03-12 22:06:45.952673: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different 2024-03-12 22:06:45.952673: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environmennumerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environmennumerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
t variable `TF_ENABLE_ONEDNN_OPTS=0`.
t variable `TF_ENABLE_ONEDNN

In [6]:
# FIXME(review): `res` is not defined anywhere in this notebook; it appears to
# come from a deleted or edited cell (its windows have 96 samples instead of the
# 300 of `windows_dataset`, which suggests a resampled dataset). This cell raises
# NameError after Restart & Run All. Binding `_` here also shadows IPython's
# last-output variable.
for X,y, _ in res:
    print(X.shape)
    print(y)
    print(_)
    break

(21, 96)
1
[0, 950, 1250]


In [7]:
# FIXME(review): `res` is undefined in this notebook (hidden kernel state).
# The listed objects have the same ids as `windows_dataset.datasets`, i.e.
# whatever produced `res` modified `windows_dataset` in place.
res.datasets

[<braindecode.datasets.base.WindowsDataset at 0x7f09d058f1f0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058f340>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058db70>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058e2c0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058ded0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058fb50>]

In [8]:
# Per-recording WindowsDataset objects. Their ids match those shown for
# `res.datasets`, so the two names refer to the same (in-place modified) data.
windows_dataset.datasets

[<braindecode.datasets.base.WindowsDataset at 0x7f09d058f1f0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058f340>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058db70>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058e2c0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058ded0>,
 <braindecode.datasets.base.WindowsDataset at 0x7f09d058fb50>]