In [2]:
import os, time
from collections import OrderedDict

import numpy as np
import joblib
from scipy.signal import sosfiltfilt
from sklearn.pipeline import make_pipeline, clone
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

from brainda.datasets import Nakanishi2015, Wang2016, BETA
from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import (
    set_random_seeds,
    generate_loo_indices, match_loo_indices)
from brainda.algorithms.decomposition import (
    SCCA, FBSCCA, 
    ItCCA, FBItCCA, 
    ECCA, FBECCA, 
    TtCCA, FBTtCCA, 
    MsetCCA, FBMsetCCA,
    MsCCA, FBMsCCA,
    MsetCCAR, FBMsetCCAR,
    TRCA, TRCAR, 
    FBTRCA, FBTRCAR,
    DSP, FBDSP,
    TDCA, FBTDCA,
    generate_filterbank, generate_cca_references)
from brainda.algorithms.deep_learning import EEGNet

import torch, skorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import *

In [3]:
def make_file(
    dataset, model_name, channels, srate, duration, events, 
    preprocess=None, 
    n_bands=None,
    augment=False):
    """Build the ``.joblib`` file name under which one experiment is cached.

    The name is a dash-separated list of: ``dataset.dataset_code``, the model
    name, the number of channels, the sampling rate, the number of samples
    (``int(duration*srate)``, i.e. truncated) and the number of events.
    Optional suffixes are appended in this order: ``n_bands``, ``preprocess``
    and the literal ``augment`` flag.

    Parameters
    ----------
    dataset : object exposing a string ``dataset_code`` attribute.
    model_name : str
    channels, events : sized collections; only their lengths are used.
    srate : int
    duration : number; multiplied by ``srate`` to get the sample count.
    preprocess : str or None
    n_bands : int or None
    augment : bool

    Returns
    -------
    str
        The file name, ending in ``.joblib``.
    """
    # Evaluate every field first, then format each with the same format
    # spec the name has always used ('s' for text, 'd' for integers).
    fields = (
        (dataset.dataset_code, 's'),
        (model_name, 's'),
        (len(channels), 'd'),
        (srate, 'd'),
        (int(duration*srate), 'd'),
        (len(events), 'd'),
    )
    parts = [format(value, spec) for value, spec in fields]

    if n_bands is not None:
        parts.append(format(n_bands, 'd'))
    if preprocess is not None:
        parts.append(format(preprocess, 's'))
    if augment:
        parts.append('augment')

    return '-'.join(parts) + '.joblib'

In [4]:
# Benchmark datasets. `delays` and `channels` are aligned with this list
# position by position (the evaluation loop zips the three together).
datasets = [Nakanishi2015(), Wang2016(), BETA()]

# Per-dataset `delay` forwarded to get_ssvep_data
# (presumably the visual latency in seconds -- TODO confirm).
delays = [0.135, 0.14, 0.13]

# Per-dataset channel selection.
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
]

srate = 250

# Evaluated window lengths: 0.2, 0.3, ..., 1.0 (integer division by ten
# yields exactly the same floats as writing the literals out).
durations = [tenths / 10 for tenths in range(2, 11)]

def data_hook(X, y, meta, caches):
    """Filter ``X`` along its last axis and pass everything else through.

    A single-band filter bank is designed with ``generate_filterbank`` for
    the band edges ``[8, 90]`` / ``[6, 95]`` (presumably pass band and stop
    band in Hz -- confirm against brainda) at the notebook-level ``srate``,
    and its first filter is applied forward and backward with
    ``sosfiltfilt``.

    Returns the filtered ``X`` together with the unchanged ``y``, ``meta``
    and ``caches``.
    """
    band_sos = generate_filterbank(
        [[8, 90]], [[6, 95]], srate, order=4, rp=1)[0]
    return sosfiltfilt(band_sos, X, axis=-1), y, meta, caches

In [5]:
n_bands = 3
n_harmonics = 5
# Passed positionally to FBTDCA; the evaluation code also crops `l` extra
# samples for the 'fbtdca' model.
l = 5

# Band edges of the filter bank: sub-band i spans [8*i, 90] with the outer
# edges [8*i-2, 95].
wp = [[8*band, 90] for band in range(1, n_bands+1)]
ws = [[8*band-2, 95] for band in range(1, n_bands+1)]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
# Sub-band weights n**(-1.25) + 0.25 for n = 1..len(filterbank).
filterweights = np.arange(1, len(filterbank)+1)**(-1.25) + 0.25

set_random_seeds(64)

# Model registry; insertion order is the order in which models are run.
models = OrderedDict()
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtrcar'] = FBTRCAR(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)
# for cross-subject validation
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)

# Set force_update to True to recompute results that are already on disk.
force_update = False
save_folder = 'matrix_decomposition'

In [5]:
# The output folder does not depend on the dataset, so create it once.
os.makedirs(save_folder, exist_ok=True)

# Leave-one-out evaluation of every model, on every dataset, for every
# window length. Per-subject accuracies are cached on disk and reused unless
# `force_update` is set.
for dataset, dataset_channels, delay in zip(datasets, channels, delays):
    dataset_events = sorted(list(dataset.events.keys()))
    freqs = [dataset.get_freq(event) for event in dataset_events]
    phases = [dataset.get_phase(event) for event in dataset_events]
    
    X, y, meta = get_ssvep_data(
        dataset, srate, dataset_channels, 1.1, dataset_events, 
        delay=delay, 
        data_hook=data_hook)
    labels = np.unique(y)
    # Reference signals are generated without phase information.
    Yf = generate_cca_references(
        freqs, srate, 1.1, 
        phases=None, 
        n_harmonics=n_harmonics)
    _, n_channels, _ = X.shape
    n_classes = len(labels)

    indices = joblib.load(
        "indices/{:s}-loo-{:d}class-indices.joblib".format(
        dataset.dataset_code, n_classes))['indices']
    
    for duration in durations:
        n_samples = int(srate*duration)
        for model_name in models:
            save_file = make_file(
                dataset, model_name, dataset_channels, srate, duration, dataset_events,
                n_bands=n_bands)
            save_path = os.path.join(
                save_folder, save_file)
            if not force_update and os.path.exists(save_path):
                sub_accs = joblib.load(save_path)['sub_accs']
                print("{:s} Acc:{:.2f}".format(save_file, np.mean(sub_accs)))
                continue

            set_random_seeds(42)
            # Number of folds, read from subject 1 / first event.
            loo = len(indices[1][dataset_events[0]])

            # Crop and de-mean once per (duration, model): neither step
            # depends on the fold, so doing it inside the fold loop only
            # repeated the same full-array copy `loo` times. 'fbtdca' gets
            # `l` extra samples (the same `l` its constructor receives).
            # NOTE(review): this relies on the fold indices being index
            # arrays, so every train/test split below is a copy and
            # `filterX` is never modified by a model.
            n_crop = n_samples + l if model_name == 'fbtdca' else n_samples
            filterX, filterY = np.copy(X[..., :n_crop]), np.copy(y)
            filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

            loo_accs = []
            for k in range(loo):
                sub_accs = []
                for sub_id in dataset.subjects:
                    sub_meta = meta[meta['subject']==sub_id]
                    train_ind, validate_ind, test_ind = match_loo_indices(
                        k, sub_meta, indices)
                    trainX, trainY = filterX[train_ind], filterY[train_ind]
                    validateX, validateY = filterX[validate_ind], filterY[validate_ind]
                    testX, testY = filterX[test_ind], filterY[test_ind]

                    # No separate validation step is used here, so the
                    # validation trials are folded into the training set.
                    trainX = np.concatenate([trainX, validateX], axis=0)
                    trainY = np.concatenate([trainY, validateY], axis=0)

                    model = clone(models[model_name]).fit(
                        trainX, trainY, 
                        Yf=Yf[..., :n_samples])
                    pred_labels = model.predict(testX)
                    true_labels = testY
                    sub_accs.append(
                        balanced_accuracy_score(true_labels, pred_labels))
                loo_accs.append(sub_accs)
            # Transpose from (fold, subject) to (subject, fold) before saving.
            sub_accs = np.array(loo_accs).T
            joblib.dump({'sub_accs': sub_accs}, save_path)
            print("{:s} Acc:{:.2f}".format(save_file, np.mean(sub_accs)))

nakanishi2015-fbdsp-8-250-50-12-3.joblib Acc:0.71
nakanishi2015-fbtrca-8-250-50-12-3.joblib Acc:0.74
nakanishi2015-fbtrcar-8-250-50-12-3.joblib Acc:0.74
nakanishi2015-fbtdca-8-250-50-12-3.joblib Acc:0.74
nakanishi2015-fbscca-8-250-50-12-3.joblib Acc:0.09
nakanishi2015-fbdsp-8-250-75-12-3.joblib Acc:0.81
nakanishi2015-fbtrca-8-250-75-12-3.joblib Acc:0.83
nakanishi2015-fbtrcar-8-250-75-12-3.joblib Acc:0.83
nakanishi2015-fbtdca-8-250-75-12-3.joblib Acc:0.83
nakanishi2015-fbscca-8-250-75-12-3.joblib Acc:0.12
nakanishi2015-fbdsp-8-250-100-12-3.joblib Acc:0.87
nakanishi2015-fbtrca-8-250-100-12-3.joblib Acc:0.89
nakanishi2015-fbtrcar-8-250-100-12-3.joblib Acc:0.90
nakanishi2015-fbtdca-8-250-100-12-3.joblib Acc:0.90
nakanishi2015-fbscca-8-250-100-12-3.joblib Acc:0.15
nakanishi2015-fbdsp-8-250-125-12-3.joblib Acc:0.90
nakanishi2015-fbtrca-8-250-125-12-3.joblib Acc:0.92
nakanishi2015-fbtrcar-8-250-125-12-3.joblib Acc:0.93
nakanishi2015-fbtdca-8-250-125-12-3.joblib Acc:0.92
nakanishi2015-fbscca-8

Computation time: average training time and average single-trial inference time of one model on one subject

In [7]:
# Load the third configured dataset (position 2 of `datasets`) for the
# timing measurement, using the same loading steps as the main evaluation.
dataset = datasets[2]
dataset_channels = channels[2]
delay = delays[2]

dataset_events = sorted(dataset.events.keys())
freqs = list(map(dataset.get_freq, dataset_events))
phases = list(map(dataset.get_phase, dataset_events))

X, y, meta = get_ssvep_data(
    dataset, srate, dataset_channels, 1.1, dataset_events,
    delay=delay, data_hook=data_hook)
labels = np.unique(y)

# Reference signals are generated without phase information.
Yf = generate_cca_references(
    freqs, srate, 1.1, phases=None, n_harmonics=n_harmonics)

_, n_channels, _ = X.shape
n_classes = len(labels)

# Precomputed leave-one-out split indices for this dataset / class count.
indices_file = "indices/{:s}-loo-{:d}class-indices.joblib".format(
    dataset.dataset_code, n_classes)
indices = joblib.load(indices_file)['indices']

In [14]:
# Model (a key of `models`) and window length used for the timing run.
model_name = 'fbtdca'
duration = 0.5

In [16]:
# Measure, for subject 1 only, the average time to fit the selected model
# and the average time to classify a single test trial, both averaged over
# the leave-one-out folds.
set_random_seeds(42)
# Number of folds, read from subject 1 / first event.
loo = len(indices[1][dataset_events[0]])
loo_accs = []

# Crop and de-mean once: neither step depends on the fold, so there is no
# need to repeat the full-array copy inside the fold loop. 'fbtdca' gets
# `l` extra samples (the same `l` its constructor receives).
n_samples = int(srate*duration)
n_crop = n_samples + l if model_name == 'fbtdca' else n_samples
filterX, filterY = np.copy(X[..., :n_crop]), np.copy(y)
filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

training_time = 0
inference_time = 0
for k in range(loo):
    sub_accs = []
    for sub_id in [1]:
        sub_meta = meta[meta['subject']==sub_id]
        train_ind, validate_ind, test_ind = match_loo_indices(
            k, sub_meta, indices)
        trainX, trainY = filterX[train_ind], filterY[train_ind]
        validateX, validateY = filterX[validate_ind], filterY[validate_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        trainX = np.concatenate([trainX, validateX], axis=0)
        trainY = np.concatenate([trainY, validateY], axis=0)

        # perf_counter is the high-resolution clock intended for measuring
        # short intervals; time.time() is a wall clock that can be coarse
        # and can jump when the system clock is adjusted.
        start_t = time.perf_counter()
        model = clone(models[model_name]).fit(
            trainX, trainY, 
            Yf=Yf[..., :n_samples])
        end_t = time.perf_counter()
        training_time += (end_t - start_t)

        # Classify the test trials one at a time so that the figure is a
        # per-trial latency, then average over the trials of this fold.
        trial_time = 0
        for i in range(len(testX)):
            start_t = time.perf_counter()
            pred_labels = model.predict(testX[i][np.newaxis, ...])
            end_t = time.perf_counter()
            trial_time += (end_t - start_t)
        trial_time /= len(testX)
        inference_time += trial_time
# Average over folds (a single subject is timed, so no further division).
training_time /= loo
inference_time /= loo

In [18]:
# Report the fold-averaged timings with four decimal places.
for message, seconds in (
        ("average training time: {:.4f}s", training_time),
        ("average inference time: {:.4f}s", inference_time)):
    print(message.format(seconds))

average training time: 1.1393s
average inference time: 0.2236s
