In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from tqdm import tqdm

import torch
from braindecode.models import ShallowFBCSPNet
from sklearn.model_selection import StratifiedGroupKFold
from braindecode.datasets import create_from_X_y
from braindecode import EEGRegressor
from skorch.helper import predefined_split

from dataloader import get_dataloaders, get_datasets
from utils import seed_everything
from train import load_data
from config import Config
from train import train

class CFG(Config):
    """Notebook-level experiment config: overrides a handful of `Config` defaults for raw-EEG input."""

    # Input modality and per-recording trial selection.
    data_type = 'eeg'
    spec_trial_selection = 'all'
    eeg_trial_selection = 'first'

    # Scheduler settings, presumably read by the project's training helpers.
    # NOTE(review): the skorch loop at the end of this notebook uses
    # CosineAnnealingLR and does not read these two values.
    scheduler_step_size = 2
    lr_gamma = 1e-1

# Params
sfreq = 200  # sampling rate (Hz) handed to MNE when wrapping each recording
n_jobs = 32

# Recording channel order: 19 channels typed as EEG, then the 'EKG' lead (must stay last).
ch_list = [
    'Fp1', 'F3', 'C3', 'P3', 'F7', 'T3', 'T5', 'O1', 'Fz', 'Cz', 'Pz',
    'Fp2', 'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG',
]

# Bipolar derivations written as electrode chains: every pair of neighbours in a
# chain is one derivation (Fp1-F7, F7-T3, ...). Chain order fixes pair order.
_bipolar_chains = [
    ['Fp1', 'F7', 'T3', 'T5', 'O1'],
    ['Fp2', 'F8', 'T4', 'T6', 'O2'],
    ['Fp1', 'F3', 'C3', 'P3', 'O1'],
    ['Fp2', 'F4', 'C4', 'P4', 'O2'],
    ['Fz', 'Cz', 'Pz'],
]
ch_pairs = [pair for chain in _bipolar_chains for pair in zip(chain, chain[1:])]
sub_ch_list = [f'{first}-{second}' for first, second in ch_pairs]

# Seed the RNGs through the project helper before anything stochastic
# (CV shuffling, model initialisation) runs further down.
seed_everything(CFG.seed)

# Metadata table plus a mapping of signal data; only its 'eeg_data' entry is used here.
df, data = load_data(CFG)
eeg_data = data['eeg_data']
display(df)

In [None]:
mne.set_log_level('warning')

# Input montage switch: False -> referential channels from `ch_list`,
# True -> bipolar derivations from `ch_pairs`. In both cases the EKG lead is
# appended last and typed as ECG, so `picks=['eeg']` below drops it.
USE_DIPOLES = False
montage_names = sub_ch_list + ['EKG'] if USE_DIPOLES else list(ch_list)
montage_types = ['eeg'] * (len(montage_names) - 1) + ['ecg']
# Channel metadata is identical for every recording, so build it once, not per row.
montage_info = mne.create_info(ch_names=montage_names, sfreq=sfreq, ch_types=montage_types)
# Names of the channels kept by picks=['eeg'] (everything except the trailing EKG), in order.
tf_ch_names = montage_names[:-1]

# Column indices for the bipolar derivations (only used when USE_DIPOLES is True).
first_idx = [ch_list.index(a) for a, _ in ch_pairs]
second_idx = [ch_list.index(b) for _, b in ch_pairs]
ekg_idx = ch_list.index('EKG')

dataset = []
for row in tqdm(df.itertuples(), total=len(df)):
    # Replace NaN samples with 0 before handing the array to MNE.
    signal = np.nan_to_num(eeg_data[row.eeg_id], nan=0.0)
    # NOTE(review): `signal` is treated as (samples, channels) and transposed to
    # (channels, samples) for RawArray - confirm the layout against load_data.
    if USE_DIPOLES:
        bipolar = signal[:, first_idx] - signal[:, second_idx]
        signal = np.concatenate([bipolar, signal[:, [ekg_idx]]], axis=1).T
    else:
        signal = signal.T

    raw = mne.io.RawArray(signal, montage_info)
    # Keep a 50 s window starting at the labelled offset (end point excluded).
    raw = raw.crop(tmin=row.eeg_label_offset_seconds, tmax=row.eeg_label_offset_seconds + 50, include_tmax=False)
    dataset.append(raw.get_data(picks=['eeg']))

dataset = np.stack(dataset, axis=0)
assert dataset.shape[1] == len(tf_ch_names), 'channel count does not match tf_ch_names'
y = df[CFG.TARGETS].values

In [None]:
# Stratified, group-aware K-fold (groups from CFG.grouping_vars never straddle
# train and validation). Only the first fold is used as the hold-out split.
skf = StratifiedGroupKFold(
    n_splits=CFG.cv_fold,
    shuffle=True,
    random_state=CFG.seed,
)
# X is a zero placeholder: it only supplies the sample count for the split.
fold_iter = skf.split(
    X=np.zeros(len(df)),
    y=df[CFG.stratif_vars],
    groups=df[CFG.grouping_vars],
)
train_index, valid_index = next(fold_iter)

In [None]:
# Wrap the fixed-length trials of each split into braindecode datasets.
# Use the shared `sfreq` constant (not a literal) so these stay in sync with the
# sampling rate used when the recordings were wrapped in MNE RawArrays.
train_set = create_from_X_y(dataset[train_index], y[train_index], sfreq=sfreq, drop_last_window=False, ch_names=tf_ch_names)
valid_set = create_from_X_y(dataset[valid_index], y[valid_index], sfreq=sfreq, drop_last_window=False, ch_names=tf_ch_names)
# Drop the references to the full stacked arrays so their memory can be reclaimed.
del dataset, y

In [None]:
from numpy import multiply
from braindecode.preprocessing import Filter, exponential_moving_standardize, preprocess, Preprocessor

# Band-pass edges in Hz.
low_cut_hz, high_cut_hz = 1., 20.
# Exponential moving standardization settings; only referenced by the
# disabled standardization step in the pipeline below.
factor_new, init_block_size = 1e-3, 1000
# Filter every channel; the channel count is read from the first training window.
picks = [*range(train_set[0][0].shape[0])]

# Steps run in list order: filter the full 50 s window first (so filter edge
# transients land in the part that is cropped away), keep 15-35 s, then
# downsample. With a 20 Hz low-pass, 50 Hz keeps the pass band below Nyquist.
preprocessors = [
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz, picks=picks, verbose='ERROR'),
    Preprocessor('crop', tmin=15, tmax=35, include_tmax=False, verbose='ERROR'),
    Preprocessor('resample', sfreq=50, verbose='ERROR'),
    # Preprocessor(exponential_moving_standardize, factor_new=factor_new, init_block_size=init_block_size, picks=picks)
]

# preprocess() mutates the datasets in place: re-running this cell without
# rebuilding train_set/valid_set would filter, crop and resample them twice.
for split_set in (train_set, valid_set):
    preprocess(split_set, preprocessors, n_jobs=1)

In [None]:
# Visual sanity check: every channel of the first preprocessed training window,
# stacked vertically so traces do not overlap.
raw_data = train_set[0][0]
display(raw_data.shape)

# Vertical spacing = largest per-channel peak-to-peak amplitude. Using only
# channel 0's range collapsed all traces when that channel was flat (e.g. zeros
# from nan_to_num).
spacing = np.ptp(raw_data, axis=1).max()
# Label traces with channel names when they line up with the data, else with indices.
labels = tf_ch_names if len(tf_ch_names) == len(raw_data) else [str(i) for i in range(len(raw_data))]

fig, ax = plt.subplots(figsize=(20, 10))
sample_idx = np.arange(raw_data.shape[1])
# Start at channel 0 (the previous range(1, ...) silently skipped it).
for k, (label, channel) in enumerate(zip(labels, raw_data)):
    ax.plot(sample_idx, channel - k * spacing, label=label)

ax.set_title('First training window after preprocessing')
ax.set_xlabel('sample')
ax.set_yticks([])
# Traces now carry labels, so the legend is populated; place it outside the axes.
ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))
plt.show()

In [None]:
from skorch.callbacks import LRScheduler
from braindecode.models import ShallowFBCSPNet, Deep4Net, EEGInceptionERP, EEGConformer, EEGNetv4

# Input/output geometry shared by every candidate architecture below.
n_chans, n_times = train_set[0][0].shape
# One output per target column (y was built from df[CFG.TARGETS]) instead of a literal 6.
n_outputs = len(CFG.TARGETS)

# model = ShallowFBCSPNet(
#     n_chans=n_chans,
#     n_outputs=n_outputs,
#     n_times=n_times,
#     final_conv_length='auto'
#     )
model = Deep4Net(
    n_chans=n_chans,
    n_outputs=n_outputs,
    n_times=n_times,
    final_conv_length='auto',
)
# model = EEGConformer(
#     n_chans=n_chans,
#     n_outputs=n_outputs,
#     n_times=n_times
#     )

# from braindecode.augmentation import FrequencyShift, AugmentedDataLoader, SignFlip, SmoothTimeMask

# freq_shift = FrequencyShift(
#     probability=.5,
#     sfreq=sfreq,
#     max_delta_freq=2.
# )

# sign_flip = SignFlip(probability=.1)

# smooth_time_mask = SmoothTimeMask(probability=.5, mask_len_samples=100)

# transforms = [
#     freq_shift,
#     sign_flip
# ]

epochs = 8
lr = 0.01
# Fall back to CPU instead of crashing on `.cuda()` when no GPU is present;
# skorch moves the module to `device` when the net is initialised.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): torch.nn.KLDivLoss expects the model output to be
# log-probabilities and the targets to be probabilities. Whether Deep4Net ends
# in a LogSoftmax layer depends on the installed braindecode version - confirm,
# otherwise apply log_softmax to the model output.
net = EEGRegressor(
    model,
    # iterator_train=AugmentedDataLoader,
    # iterator_train__transforms=transforms,
    criterion=torch.nn.KLDivLoss(reduction='batchmean'),
    optimizer=torch.optim.AdamW,
    # Validate on the held-out CV fold after each epoch.
    train_split=predefined_split(valid_set),
    batch_size=12,
    lr=lr,
    device=device,
    # T_max=epochs so the cosine decay spans the whole run, assuming the
    # scheduler steps once per epoch (skorch's LRScheduler default).
    callbacks=[("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=epochs))],
)
net.fit(train_set, epochs=epochs)