In [None]:
import os

import numpy as np
import joblib
from scipy.signal import sosfiltfilt
from sklearn.pipeline import make_pipeline, clone
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

from brainda.datasets import Nakanishi2015, Wang2016, BETA
from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import (
    set_random_seeds,
    generate_loo_indices, match_loo_indices)
from brainda.algorithms.decomposition import (
    SCCA, FBSCCA, 
    ItCCA, FBItCCA, 
    ECCA, FBECCA, 
    TtCCA, FBTtCCA, 
    MsetCCA, FBMsetCCA,
    MsCCA, FBMsCCA,
    MsetCCAR, FBMsetCCAR,
    TRCA, TRCAR, 
    FBTRCA, FBTRCAR,
    DSP, FBDSP,
    TDCA, FBTDCA,
    generate_filterbank, generate_cca_references)
from brainda.algorithms.deep_learning import EEGNet

import torch, skorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import *

In [None]:
# Generate (or reuse cached) leave-one-out split indices for every dataset.
# Indices are cached on disk under INDICES_DIR so later cells can reuse the
# same splits instead of regenerating them.
datasets = [
    Nakanishi2015(),
    Wang2016(),
    BETA()
]
srate = 100  # sampling rate passed to get_ssvep_data (presumably Hz -- confirm)
channels = ['OZ']
duration = 0.2 # seconds
force_update = False  # True -> regenerate indices even if a cached file exists

INDICES_DIR = 'indices'
LOO_SEED = 38

# Loop-invariant: create the cache directory once rather than per dataset.
os.makedirs(INDICES_DIR, exist_ok=True)

for dataset in datasets:
    events = list(dataset.events.keys())

    save_file = "{:s}-loo-{:d}class-indices.joblib".format(
        dataset.dataset_code, len(events))
    save_file = os.path.join(INDICES_DIR, save_file)
    if not force_update and os.path.exists(save_file):
        # Cached indices already exist; skip loading the data entirely.
        continue

    X, y, meta = get_ssvep_data(
        dataset, srate, channels, duration, events)

    # Reseed for every dataset so that (presumably, if generate_loo_indices
    # draws from the seeded RNGs) each dataset's split does not depend on
    # which other datasets were skipped because of the cache.
    set_random_seeds(LOO_SEED)
    indices = generate_loo_indices(meta)

    # Write to a temporary file, then atomically move it into place. An
    # interrupted dump would otherwise leave a truncated file that the
    # exists() check above would accept as a valid cache on every later run.
    tmp_file = save_file + '.tmp'
    joblib.dump(
        {'indices': indices},
        tmp_file)
    os.replace(tmp_file, save_file)
    print("{:s} loo indices generated.".format(
        dataset.dataset_code))
    del X, y, meta