In [34]:
import os, copy, time
from collections import OrderedDict

import numpy as np
import pandas as pd
import joblib
from scipy.signal import sosfiltfilt
from sklearn.pipeline import make_pipeline, clone
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

from brainda.datasets import Nakanishi2015, Wang2016, BETA
from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import (
    set_random_seeds,
    generate_loo_indices, match_loo_indices)
from brainda.algorithms.decomposition import (
    generate_filterbank, generate_cca_references)
from brainda.algorithms.deep_learning import EEGNet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import skorch

from skorch.helper import predefined_split

from utils import *
from models import *

In [21]:
def make_file(
    dataset, model_name, channels, srate, duration, events, 
    preprocess=None, 
    n_bands=None,
    augment=False, loo=False, fixed_dtn_template=False):
    """Build the ``.joblib`` file name that identifies one experiment setting.

    The stem is ``<dataset_code>-<model_name>-<n_channels>-<srate>-<n_samples>-<n_events>``
    where ``n_samples = int(duration*srate)``. Optional suffixes are appended
    in a fixed order: ``-<n_bands>``, ``-<preprocess>``, ``-augment``,
    ``-fixed``, ``-loo``; the name always ends with ``.joblib``.

    Parameters
    ----------
    dataset : object exposing a string ``dataset_code`` attribute.
    model_name : str
    channels, events : sized collections; only their lengths are used.
    srate : int
    duration : number; multiplied by ``srate`` and truncated to an int.
    preprocess : str or None
    n_bands : int or None
    augment, loo, fixed_dtn_template : bool flags enabling the matching suffix.

    Returns
    -------
    str
        The file name (no directory component).
    """
    stem = "{:s}-{:s}-{ch:d}-{srate:d}-{nt:d}-{event:d}".format(
        dataset.dataset_code, model_name,
        ch=len(channels), srate=srate,
        nt=int(duration*srate), event=len(events))

    suffixes = []
    if n_bands is not None:
        suffixes.append('-{:d}'.format(n_bands))
    if preprocess is not None:
        suffixes.append('-{:s}'.format(preprocess))
    # Flag suffixes keep the original ordering: augment, fixed, loo.
    for enabled, tag in (
            (augment, '-augment'),
            (fixed_dtn_template, '-fixed'),
            (loo, '-loo')):
        if enabled:
            suffixes.append(tag)

    return stem + ''.join(suffixes) + '.joblib'

def make_dl_model(
        model_name, n_channels, n_samples, n_classes, 
        Yf=None):
    """Instantiate one of the supported networks with its fixed hyper-parameters.

    Parameters
    ----------
    model_name : str
        One of ``'eegnet-ssvep'``, ``'dtn'`` or ``'ftn'``.
    n_channels, n_samples, n_classes : int
        Forwarded to the model constructor.
    Yf : optional
        Reference signals, forwarded to ``FTN`` only; ignored by the other
        models.

    Returns
    -------
    The object returned by ``EEGNet``, ``DTN`` or ``FTN``. The callers in this
    notebook use it as a skorch-style net (``.fit``, ``.predict``,
    ``.set_params``, ``.module``) -- NOTE(review): confirm in models.py.

    Raises
    ------
    ValueError
        If ``model_name`` is not a supported model. Previously this case fell
        through every branch and failed with ``UnboundLocalError`` at the
        ``return``.
    """
    supported_models = ('eegnet-ssvep', 'dtn', 'ftn')
    if model_name not in supported_models:
        raise ValueError(
            "Unknown model_name {!r}; expected one of {}".format(
                model_name, supported_models))

    # Re-seed before every construction so weight initialisation is identical
    # each time a model of a given type is built.
    set_random_seeds(64)
    if model_name == 'eegnet-ssvep':
        model = EEGNet(
            n_channels, n_samples, n_classes,
            time_kernel=(96, (1, n_samples), (1, 1)), 
            D=1,
            separa_kernel=(96, (1, 16), (1, 1)),
            dropout_rate=0.5,
            fc_norm_rate=1)
    elif model_name == 'dtn':
        # The two leading positional values (3, 120) are defined by DTN in
        # models.py -- TODO confirm their meaning there.
        model = DTN(
            3, 120, n_channels, n_samples, n_classes, 
            band_kernel=9, pooling_kernel=2, 
            dropout=0.95, momentum=0.1)
    else:  # 'ftn'
        model = FTN(
            3, 120, n_channels, n_samples, n_classes, Yf, 
            band_kernel=9, pooling_kernel=2,
            dropout=0.95)
    return model

def generate_dl_parameters(dataset, model_name):
    """Return ``(lr, batch_size, max_epochs)`` for training the global model.

    ``dataset`` is accepted for signature symmetry with
    ``generate_dl_ft_parameters`` but is not used. The batch size (256) and
    epoch count (600) are fixed; only the learning rate depends on the model:
    ``1e-2`` for ``'eegnet-ssvep'`` and ``1e-3`` for everything else.
    """
    learning_rate = 1e-2 if model_name == 'eegnet-ssvep' else 1e-3
    return learning_rate, 256, 600

def generate_dl_ft_parameters(dataset, model_name):
    """Return ``(lr, batch_size, max_epochs)`` for per-subject fine-tuning.

    The batch size is looked up by ``dataset.dataset_code`` (a ``KeyError`` is
    raised for a code that is not in the table). Fine-tuning always runs for
    100 epochs; the learning rate is ``1e-2`` for ``'eegnet-ssvep'`` and
    ``1e-3`` for any other model.
    """
    # Resolve the batch size first so an unknown dataset fails immediately.
    batch_size = {
        'nakanishi2015': 32,
        'wang2016': 32,
        'beta': 32,
    }[dataset.dataset_code]

    learning_rate = 1e-2 if model_name == 'eegnet-ssvep' else 1e-3
    return learning_rate, batch_size, 100

def exchange(X, n_trials, seeds=253, dropout=0):
    """Create new trials by mixing axis-1 components across existing trials.

    For each of the ``n_trials`` new trials and each position ``j`` on axis 1,
    a source trial is drawn uniformly at random (with replacement) and its
    ``j``-th component is copied.

    Parameters
    ----------
    X : numpy.ndarray
        Array whose first axis indexes trials and whose second axis indexes
        the components that get exchanged.
    n_trials : int
        Number of new trials to generate.
    seeds : int
        Value passed to ``np.random.seed``; note that this resets NumPy's
        global random state on every call.
    dropout : float
        When non-zero, each component position is kept with probability
        ``1 - dropout`` and zeroed otherwise. One keep/drop decision is drawn
        per component and shared by all generated trials. The mask is shaped
        ``(1, n_components, 1, 1)``, so this path expects a 4-D ``X``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_trials,) + X.shape[1:]``.
    """
    np.random.seed(seeds)
    n_components = X.shape[1]
    ind1 = np.random.randint(0, high=len(X), size=(n_trials, n_components))
    ind2 = np.tile(np.arange(n_components), (n_trials, 1))
    newX = X[ind1, ind2, ...]
    if dropout != 0:
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same
        # dtype and yields identical draws.
        keep = np.random.binomial(np.ones(n_components, dtype=int), 1-dropout)
        newX *= keep[np.newaxis, :, np.newaxis, np.newaxis]
    return newX
    
def generate_augment_data(X, y, filterbank, n_trials_per_class=0, seeds=253, dropout=0):
    """Synthesise extra trials per class by recombining filtered components.

    Every filter in ``filterbank`` is applied to ``X`` along the last axis
    with ``sosfiltfilt``; the residual (``X`` minus the sum of all filtered
    signals) is appended as one more component. For each class, components
    are shuffled across that class's trials via ``exchange`` and summed back
    into ``n_trials_per_class`` new trials.

    Parameters
    ----------
    X : numpy.ndarray, trials on the first axis, time on the last axis.
    y : numpy.ndarray of labels, one per trial.
    filterbank : sequence of SOS filters accepted by ``sosfiltfilt``.
    n_trials_per_class : int; ``0`` disables augmentation.
    seeds, dropout : forwarded to ``exchange``.

    Returns
    -------
    (augX, augY)
        Concatenated synthetic trials and their (float) labels, or
        ``(None, None)`` when ``n_trials_per_class`` is 0.
    """
    if n_trials_per_class == 0:
        return None, None

    band_signals = [sosfiltfilt(sos, np.copy(X), axis=-1) for sos in filterbank]
    components = np.stack(band_signals, axis=1)
    # Whatever the filters did not capture becomes its own component.
    residual = X - np.sum(components, 1)
    components = np.concatenate([components, residual[:, np.newaxis, ...]], axis=1)

    augX, augY = [], []
    for label in np.unique(y):
        mixed = exchange(
            components[y==label], n_trials_per_class, seeds=seeds, dropout=dropout)
        class_trials = np.sum(mixed, axis=1)
        augX.append(class_trials)
        augY.append(np.ones(len(class_trials))*label)

    return np.concatenate(augX, axis=0), np.concatenate(augY, axis=0)

In [22]:
# Select the compute device: CUDA device `device_id` when a GPU is available,
# otherwise fall back to the CPU.
device_id = 0
device = torch.device("cuda:{:d}".format(device_id) if torch.cuda.is_available() else "cpu")
# Disable cuDNN auto-tuning and request deterministic kernels so repeated runs
# are comparable.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
print("Total available GPU devices: {}".format(torch.cuda.device_count()))
print("Current pytorch device: {}".format(device))

Total available GPU devices: 1
Current pytorch device: cuda:0


In [23]:
# `datasets`, `delays` and `channels` are parallel lists: entry i of each
# belongs to the same dataset (they are iterated together with zip()).
datasets = [
    Nakanishi2015(), 
    Wang2016(), 
    BETA()
]
# Per-dataset value passed as `delay=` to get_ssvep_data
# (presumably the visual latency in seconds -- TODO confirm).
delays = [
    0.135, 
    0.14, 
    0.13
]
# Per-dataset channel selection.
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2']
]

# Sampling rate; int(srate*duration) is the number of samples kept per trial.
srate = 250
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

def data_hook(X, y, meta, caches):
    """Filter ``X`` along its last axis with a single zero-phase SOS filter.

    The filter is the first (only) entry of a one-band filterbank built by
    ``generate_filterbank`` from the band edges ``[8, 90]`` / ``[6, 95]`` at
    the module-level ``srate`` (presumably pass-band / stop-band edges in Hz
    -- confirm against brainda). ``y``, ``meta`` and ``caches`` are returned
    untouched.
    """
    pass_edges, stop_edges = [[8, 90]], [[6, 95]]
    sos = generate_filterbank(pass_edges, stop_edges, srate, order=4, rp=1)[0]
    return sosfiltfilt(sos, X, axis=-1), y, meta, caches

In [24]:
# Tensor dtypes used when converting inputs (X) and labels (y) with generate_tensors.
x_dtype, y_dtype = torch.float, torch.long

# Model names understood by make_dl_model; each is trained in turn.
models = ['eegnet-ssvep', 'ftn','dtn']

# n_bands only enters the result file name (see make_file); n_harmonics sets
# the number of harmonics for the CCA references and the augmentation filterbank.
n_bands = 3
n_harmonics = 5

# When False, settings whose result file already exists are loaded instead of recomputed.
force_update = False
save_folder = 'neural_networks'

## within-subject classification

In [6]:
# Within-subject evaluation.
# For every dataset / duration / model:
#   1. train a "global" model on all subjects for each leave-one-out fold,
#   2. score it per subject,
#   3. fine-tune a copy of the global model per subject (plain), and
#   4. fine-tune another copy per subject using augmented data for validation.
# Per-subject balanced accuracies are saved with joblib; the global model
# weights of every fold are saved next to them as a `.pt` file.
for dataset, dataset_channels, delay in zip(datasets, channels, delays):
    os.makedirs(save_folder, exist_ok=True)
    
    dataset_events = sorted(list(dataset.events.keys()))
    freqs = [dataset.get_freq(event) for event in dataset_events]
    phases = [dataset.get_phase(event) for event in dataset_events]
    
    # Load 1.1 s of data per trial; shorter durations are cut from it below.
    X, y, meta = get_ssvep_data(
        dataset, srate, dataset_channels, 1.1, dataset_events, 
        delay=delay, 
        data_hook=data_hook)
    labels = np.unique(y)
    # Reference signals for FTN; note that the stimulus phases are not used.
    Yf = generate_cca_references(
        freqs, srate, 1.1, 
        phases=None, 
        n_harmonics=n_harmonics)
    print("Dataset: {} Size: {}".format(dataset.dataset_code, X.shape))
    _, n_channels, n_samples = X.shape
    n_classes = len(labels)
    
    # Pre-computed leave-one-out split indices. The number of folds is read
    # from the entry stored under key 1 for the first event
    # (presumably subject 1 -- TODO confirm the layout of the indices file).
    indices = joblib.load(
        "indices/{:s}-loo-{:d}class-indices.joblib".format(
        dataset.dataset_code, n_classes))['indices']
    loo = len(indices[1][dataset_events[0]])
    
    # Augmentation filterbank: one band per harmonic, spanning the stimulus
    # frequency range scaled by the harmonic index (second edge list is 2 wider
    # on each side).
    min_f, max_f = np.min(freqs), np.max(freqs)
    wp = [[min_f*i, max_f*i] for i in range(1, n_harmonics+1)]
    ws= [[min_f*i-2, max_f*i+2] for i in range(1, n_harmonics+1)]
    aug_filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
    
    for duration in durations:
        for model_name in models:
            file_name = make_file(
                dataset, model_name, dataset_channels, srate, duration, dataset_events, 
                n_bands=n_bands)
            save_file = os.path.join(save_folder, file_name)
            # Reuse cached scores unless a recomputation is forced.
            if not force_update and os.path.exists(save_file):
                scores = joblib.load(save_file)
                global_sub_accs = scores['global_sub_accs']
                ft_sub_accs = scores['ft_sub_accs']
                ft_aug_sub_accs = scores['ft_aug_sub_accs']
                print("{:s} acc:{:.4f} ft_acc:{:.4f} ft_aug_acc:{:.4f}".format(
                    save_file, np.mean(global_sub_accs), np.mean(ft_sub_accs), np.mean(ft_aug_sub_accs)))
                continue  
                
            set_random_seeds(42)
            loo_global_accs = []
            loo_global_model_states = []
            loo_fine_tuning_accs = []
            loo_fine_tuning_accs_aug = []

            for k in range(loo):
                # Truncate to the current duration and remove the per-channel mean.
                filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
                filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

                train_ind, validate_ind, test_ind = match_loo_indices(
                    k, meta, indices)
                trainX, trainY, trainMeta = filterX[train_ind], filterY[train_ind], meta.iloc[train_ind]
                validateX, validateY, validateMeta = filterX[validate_ind], filterY[validate_ind], meta.iloc[validate_ind]
                testX, testY, testMeta = filterX[test_ind], filterY[test_ind], meta.iloc[test_ind]

                trainX, validateX, testX = generate_tensors(
                    trainX, validateX, testX, dtype=x_dtype)
                trainY, validateY, testY = generate_tensors(
                    trainY, validateY, testY, dtype=y_dtype) 

                ## global model: trained on all subjects of this fold
                net = make_dl_model(
                    model_name, n_channels, int(srate*duration), n_classes,
                    Yf=Yf[..., :int(srate*duration)])
                
                lr, batch_size, max_epochs = generate_dl_parameters(
                    dataset, model_name)
                net.device = device
                net.lr = lr
                net.max_epochs = max_epochs
                net.batch_size = batch_size
                net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
                net.verbose=False

                # Validate on the fixed validation split instead of a random hold-out.
                net.train_split = predefined_split(
                    skorch.dataset.Dataset(
                        {'X': validateX}, validateY))

                # DTN additionally receives the labels as a module input
                # (presumably to update its templates -- confirm in models.py).
                if model_name == 'dtn':
                    net = net.fit({'X': trainX, 'y': trainY}, y=trainY)
                else:
                    net = net.fit(
                        {'X': trainX}, 
                        y=trainY)

                loo_global_model_states.append(
                    copy.deepcopy(net.module.state_dict()))
                
                ## testing
                sub_accs = []
                for sub_id in dataset.subjects:
                    sub_test_mask = (testMeta['subject']==sub_id).to_numpy()
                    pred_labels = net.predict({'X': testX[sub_test_mask]})
                    true_labels = testY[sub_test_mask].numpy()
                    sub_acc = balanced_accuracy_score(true_labels, pred_labels)
                    sub_accs.append(sub_acc)
                loo_global_accs.append(sub_accs)

                ## fine-tuning
                sub_accs = []
                sub_accs_aug = []
                for sub_id in dataset.subjects:
                    sub_train_mask = (trainMeta['subject']==sub_id).to_numpy()
                    sub_valid_mask = (validateMeta['subject']==sub_id).to_numpy()
                    sub_test_mask = (testMeta['subject']==sub_id).to_numpy()

                    sub_trainX, sub_trainY = trainX[sub_train_mask], trainY[sub_train_mask]
                    sub_validateX, sub_validateY = validateX[sub_valid_mask], validateY[sub_valid_mask]
                    sub_testX, sub_testY = testX[sub_test_mask], testY[sub_test_mask]    

                    lr, batch_size, max_epochs = generate_dl_ft_parameters(
                        dataset, model_name)

                    # Start every fine-tuning run from this fold's global weights.
                    net = make_dl_model(
                        model_name, n_channels, int(srate*duration), n_classes,
                        Yf=Yf[..., :int(srate*duration)])

                    net.device = device
                    net.lr = lr
                    net.max_epochs = max_epochs
                    net.batch_size = batch_size
                    net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
                    net.verbose=False
                    net.module.load_state_dict(
                        copy.deepcopy(loo_global_model_states[k]))

                    net.train_split = predefined_split(
                        skorch.dataset.Dataset(
                            {'X': sub_validateX}, sub_validateY))

                    if model_name == 'dtn':
                        net = net.fit({'X': sub_trainX, 'y': sub_trainY}, y=sub_trainY)
                    else:
                        net = net.fit(
                            {'X': sub_trainX}, 
                            y=sub_trainY)

                    pred_labels = net.predict({'X': sub_testX})
                    true_labels = sub_testY.numpy()
                    sub_acc = balanced_accuracy_score(true_labels, pred_labels)
                    sub_accs.append(sub_acc)
                    
                    # augment training and validation data
                    # the augmented data used for validation while the original training and validation data are combined for training
                    sub_aug_trainX, sub_aug_trainY = generate_augment_data(
                        np.concatenate([sub_trainX.numpy(), sub_validateX.numpy()], axis=0),
                        np.concatenate([sub_trainY.numpy(), sub_validateY.numpy()], axis=0), aug_filterbank,
                        n_trials_per_class=20)
                    sub_aug_trainX = generate_tensors(sub_aug_trainX, dtype=x_dtype)
                    sub_aug_trainY = generate_tensors(sub_aug_trainY, dtype=y_dtype)
                    sub_trainX = torch.cat([sub_trainX, sub_validateX], 0)
                    sub_trainY = torch.cat([sub_trainY, sub_validateY], 0) 

                    sub_validateX = sub_aug_trainX
                    sub_validateY = sub_aug_trainY 

                    lr, batch_size, max_epochs = generate_dl_ft_parameters(
                        dataset, model_name)

                    net = make_dl_model(
                        model_name, n_channels, int(srate*duration), n_classes,
                        Yf=Yf[..., :int(srate*duration)])

                    net.device = device
                    net.lr = lr
                    net.max_epochs = max_epochs
                    net.batch_size = batch_size
                    net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
                    net.verbose=False
                    net.module.load_state_dict(
                        copy.deepcopy(loo_global_model_states[k]))

                    net.train_split = predefined_split(
                        skorch.dataset.Dataset(
                            {'X': sub_validateX}, sub_validateY))

                    if model_name == 'dtn':
                        net = net.fit({'X': sub_trainX, 'y': sub_trainY}, y=sub_trainY)
                    else:
                        net = net.fit(
                            {'X': sub_trainX}, 
                            y=sub_trainY)

                    pred_labels = net.predict({'X': sub_testX})
                    true_labels = sub_testY.numpy()
                    sub_acc = balanced_accuracy_score(true_labels, pred_labels)
                    sub_accs_aug.append(sub_acc)
                loo_fine_tuning_accs.append(sub_accs)
                loo_fine_tuning_accs_aug.append(sub_accs_aug)
            # Transpose from (fold, subject) to (subject, fold) before saving.
            global_sub_accs = np.array(loo_global_accs).T
            ft_sub_accs = np.array(loo_fine_tuning_accs).T
            ft_aug_sub_accs = np.array(loo_fine_tuning_accs_aug).T
            joblib.dump(
                {
                    'global_sub_accs': global_sub_accs, 
                    'ft_sub_accs': ft_sub_accs,
                    'ft_aug_sub_accs': ft_aug_sub_accs
                }, save_file)
            # Swap only the extension: replacing the bare substring 'joblib'
            # would also rewrite it anywhere else in the path (e.g. in
            # save_folder) and disagree with the loaders, which look for
            # save_file.replace('.joblib', '.pt').
            torch.save(loo_global_model_states, save_file.replace('.joblib', '.pt'))
            print("{:s} acc:{:.4f} ft_acc:{:.4f} ft_aug_acc:{:.4f}".format(
                save_file, np.mean(global_sub_accs), np.mean(ft_sub_accs), np.mean(ft_aug_sub_accs)))

Dataset: nakanishi2015 Size: (1800, 8, 274)
neural_networks/nakanishi2015-eegnet-ssvep-8-250-50-12-3.joblib acc:0.4844 ft_acc:0.6428 ft_aug_acc:0.6589
neural_networks/nakanishi2015-ftn-8-250-50-12-3.joblib acc:0.4894 ft_acc:0.7233 ft_aug_acc:0.7333
neural_networks/nakanishi2015-dtn-8-250-50-12-3.joblib acc:0.4822 ft_acc:0.7128 ft_aug_acc:0.7311
neural_networks/nakanishi2015-eegnet-ssvep-8-250-75-12-3.joblib acc:0.5761 ft_acc:0.7294 ft_aug_acc:0.7428
neural_networks/nakanishi2015-ftn-8-250-75-12-3.joblib acc:0.6183 ft_acc:0.7989 ft_aug_acc:0.8144
neural_networks/nakanishi2015-dtn-8-250-75-12-3.joblib acc:0.5778 ft_acc:0.8011 ft_aug_acc:0.8200
neural_networks/nakanishi2015-eegnet-ssvep-8-250-100-12-3.joblib acc:0.6511 ft_acc:0.7828 ft_aug_acc:0.8133
neural_networks/nakanishi2015-ftn-8-250-100-12-3.joblib acc:0.6861 ft_acc:0.8394 ft_aug_acc:0.8528
neural_networks/nakanishi2015-dtn-8-250-100-12-3.joblib acc:0.7000 ft_acc:0.8772 ft_aug_acc:0.8811
neural_networks/nakanishi2015-eegnet-ssvep-8

Replace DTN templates with grand-average templates

In [25]:
# `datasets`, `delays` and `channels` are parallel lists: entry i of each
# belongs to the same dataset (they are iterated together with zip()).
datasets = [
    Nakanishi2015(), 
    Wang2016(), 
    BETA()
]
# Per-dataset value passed as `delay=` to get_ssvep_data
# (presumably the visual latency in seconds -- TODO confirm).
delays = [
    0.135, 
    0.14, 
    0.13
]
# Per-dataset channel selection.
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2']
]

# Sampling rate; int(srate*duration) is the number of samples kept per trial.
srate = 250
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

def data_hook(X, y, meta, caches):
    """Filter ``X`` along its last axis with a single zero-phase SOS filter.

    The filter is the first (only) entry of a one-band filterbank built by
    ``generate_filterbank`` from the band edges ``[8, 90]`` / ``[6, 95]`` at
    the module-level ``srate`` (presumably pass-band / stop-band edges in Hz
    -- confirm against brainda). ``y``, ``meta`` and ``caches`` are returned
    untouched.
    """
    pass_edges, stop_edges = [[8, 90]], [[6, 95]]
    sos = generate_filterbank(pass_edges, stop_edges, srate, order=4, rp=1)[0]
    return sosfiltfilt(sos, X, axis=-1), y, meta, caches

In [26]:
# Tensor dtypes used when converting inputs (X) and labels (y) with generate_tensors.
x_dtype, y_dtype = torch.float, torch.long

# Only DTN is evaluated in this section.
models = ['dtn']

# n_bands only enters the result file name (see make_file); n_harmonics sets
# the number of harmonics for the CCA references and the augmentation filterbank.
n_bands = 3
n_harmonics = 5

# When False, settings whose result file already exists are loaded instead of recomputed.
force_update = False
save_folder = 'neural_networks'

In [None]:
# DTN with fixed templates.
# Reuses the global DTN weights saved by the within-subject training cell and,
# for each subject, overwrites the network's running templates with features
# of that subject's class-average training trials. This is done (a) on the
# global model directly and (b) after per-subject fine-tuning with augmented
# validation data. Per-subject balanced accuracies are saved to the
# `-fixed.joblib` result file.
for dataset, dataset_channels, delay in zip(datasets, channels, delays):
    os.makedirs(save_folder, exist_ok=True)
    
    dataset_events = sorted(list(dataset.events.keys()))
    freqs = [dataset.get_freq(event) for event in dataset_events]
    phases = [dataset.get_phase(event) for event in dataset_events]
    
    # Load 1.1 s of data per trial; shorter durations are cut from it below.
    X, y, meta = get_ssvep_data(
        dataset, srate, dataset_channels, 1.1, dataset_events, 
        delay=delay, 
        data_hook=data_hook)
    labels = np.unique(y)
    Yf = generate_cca_references(
        freqs, srate, 1.1, 
        phases=None, 
        n_harmonics=n_harmonics)
    print("Dataset: {} Size: {}".format(dataset.dataset_code, X.shape))
    _, n_channels, n_samples = X.shape
    n_classes = len(labels)
    
    # Pre-computed leave-one-out split indices; the fold count is read from
    # the entry stored under key 1 for the first event
    # (presumably subject 1 -- TODO confirm the layout of the indices file).
    indices = joblib.load(
        "indices/{:s}-loo-{:d}class-indices.joblib".format(
        dataset.dataset_code, n_classes))['indices']
    loo = len(indices[1][dataset_events[0]])
    
    min_f, max_f = np.min(freqs), np.max(freqs)
    wp = [[min_f*i, max_f*i] for i in range(1, n_harmonics+1)]
    ws= [[min_f*i-2, max_f*i+2] for i in range(1, n_harmonics+1)]
    aug_filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
    
    for duration in durations:
        for model_name in models:
            # Path of the results written by the within-subject training cell;
            # the matching `.pt` file holds the global weights of every fold.
            model_file = make_file(
                dataset, model_name, dataset_channels, srate, duration, dataset_events, 
                n_bands=n_bands)
            model_file = os.path.join(save_folder, model_file)
            
            file_name = make_file(
                dataset, model_name, dataset_channels, srate, duration, dataset_events, 
                n_bands=n_bands, fixed_dtn_template=True)
            save_file = os.path.join(save_folder, file_name)
            # Reuse cached scores unless a recomputation is forced.
            if not force_update and os.path.exists(save_file):
                scores = joblib.load(save_file)
                global_sub_accs_fixed = scores['global_sub_accs_fixed']
                ft_aug_sub_accs_fixed = scores['ft_aug_sub_accs_fixed']
                print("{:s} acc:{:.4f} ft_aug_acc:{:.4f}".format(
                    save_file, np.mean(global_sub_accs_fixed), np.mean(ft_aug_sub_accs_fixed)))
                continue  

            # Load the global weights only when they are actually needed, so
            # cached settings neither pay for nor depend on the `.pt` file.
            loo_global_model_states = torch.load(
                model_file.replace('.joblib', '.pt'), map_location='cpu')

            set_random_seeds(42)
            loo_global_accs_fixed = []
            loo_fine_tuning_accs_aug_fixed = []

            for k in range(loo):
                # Truncate to the current duration and remove the per-channel mean.
                filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
                filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

                train_ind, validate_ind, test_ind = match_loo_indices(
                    k, meta, indices)
                trainX, trainY, trainMeta = filterX[train_ind], filterY[train_ind], meta.iloc[train_ind]
                validateX, validateY, validateMeta = filterX[validate_ind], filterY[validate_ind], meta.iloc[validate_ind]
                testX, testY, testMeta = filterX[test_ind], filterY[test_ind], meta.iloc[test_ind]

                trainX, validateX, testX = generate_tensors(
                    trainX, validateX, testX, dtype=x_dtype)
                trainY, validateY, testY = generate_tensors(
                    trainY, validateY, testY, dtype=y_dtype) 

                # Rebuild the global model on the CPU and restore this fold's
                # weights; no training happens for the global evaluation.
                net = make_dl_model(
                    model_name, n_channels, int(srate*duration), n_classes,
                    Yf=Yf[..., :int(srate*duration)])
                
                lr, batch_size, max_epochs = generate_dl_parameters(
                    dataset, model_name)
                net.device = 'cpu'
                net.lr = lr
                net.max_epochs = max_epochs
                net.batch_size = batch_size
                net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
                net.verbose=False
                net.initialize()

                net.module.load_state_dict(
                    copy.deepcopy(loo_global_model_states[k]))
                
                ## testing
                sub_accs = []
                for sub_id in dataset.subjects:
                    # replace dtn templates with subject's grand average templates
                    sub_train_mask = (trainMeta['subject']==sub_id).to_numpy()
                    sub_trainX = (trainX[sub_train_mask]).numpy()
                    sub_trainY = (trainY[sub_train_mask]).numpy()
                    
                    # One class-average trial per label, with a singleton axis
                    # inserted after the class axis before it is fed to the module.
                    templates = np.stack(
                        [np.mean(sub_trainX[sub_trainY==label], axis=0) for label in labels], 0)
                    templates = torch.tensor(templates[:, np.newaxis, ...], dtype=torch.float)
                    with torch.no_grad():
                        net.module.eval()
                        out = net.module.instance_norm(templates)
                        out = net.module.feature_extractor(out)
                        net.module.update_running_templates(out)
                    sub_test_mask = (testMeta['subject']==sub_id).to_numpy()
                    pred_labels = net.predict({'X': testX[sub_test_mask]})
                    true_labels = testY[sub_test_mask].numpy()
                    sub_acc = balanced_accuracy_score(true_labels, pred_labels)
                    sub_accs.append(sub_acc)
                loo_global_accs_fixed.append(sub_accs)

                ## fine-tuning
                sub_accs_aug = []
                for sub_id in dataset.subjects:
                    sub_train_mask = (trainMeta['subject']==sub_id).to_numpy()
                    sub_valid_mask = (validateMeta['subject']==sub_id).to_numpy()
                    sub_test_mask = (testMeta['subject']==sub_id).to_numpy()

                    sub_trainX, sub_trainY = trainX[sub_train_mask], trainY[sub_train_mask]
                    sub_validateX, sub_validateY = validateX[sub_valid_mask], validateY[sub_valid_mask]
                    sub_testX, sub_testY = testX[sub_test_mask], testY[sub_test_mask]    
                    
                    # augment training and validation data
                    # the augmented data used for validation while the original training and validation data are combined for training
                    sub_aug_trainX, sub_aug_trainY = generate_augment_data(
                        np.concatenate([sub_trainX.numpy(), sub_validateX.numpy()], axis=0),
                        np.concatenate([sub_trainY.numpy(), sub_validateY.numpy()], axis=0), aug_filterbank,
                        n_trials_per_class=20)
                    sub_aug_trainX = generate_tensors(sub_aug_trainX, dtype=x_dtype)
                    sub_aug_trainY = generate_tensors(sub_aug_trainY, dtype=y_dtype)
                    sub_trainX = torch.cat([sub_trainX, sub_validateX], 0)
                    sub_trainY = torch.cat([sub_trainY, sub_validateY], 0) 

                    sub_validateX = sub_aug_trainX
                    sub_validateY = sub_aug_trainY 

                    lr, batch_size, max_epochs = generate_dl_ft_parameters(
                        dataset, model_name)

                    # Start fine-tuning from this fold's global weights.
                    net = make_dl_model(
                        model_name, n_channels, int(srate*duration), n_classes,
                        Yf=Yf[..., :int(srate*duration)])

                    net.device = device
                    net.lr = lr
                    net.max_epochs = max_epochs
                    net.batch_size = batch_size
                    net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
                    net.verbose=False
                    net.module.load_state_dict(
                        copy.deepcopy(loo_global_model_states[k]))

                    net.train_split = predefined_split(
                        skorch.dataset.Dataset(
                            {'X': sub_validateX}, sub_validateY))

                    if model_name == 'dtn':
                        net = net.fit({'X': sub_trainX, 'y': sub_trainY}, y=sub_trainY)
                    else:
                        net = net.fit(
                            {'X': sub_trainX}, 
                            y=sub_trainY)
                    
                    # Move the fine-tuned model back to the CPU before the
                    # templates are replaced and predictions are made.
                    net.device = 'cpu'
                    # replace dtn templates with subject's grand average templates
                    sub_trainX = sub_trainX.numpy()
                    sub_trainY = sub_trainY.numpy()
                    templates = np.stack(
                        [np.mean(sub_trainX[sub_trainY==label], axis=0) for label in labels], 0)
                    templates = torch.tensor(templates[:, np.newaxis, ...], dtype=torch.float)
                    with torch.no_grad():
                        net.module.eval()
                        net.module = net.module.to('cpu')
                        out = net.module.instance_norm(templates)
                        out = net.module.feature_extractor(out)
                        net.module.update_running_templates(out)

                    pred_labels = net.predict({'X': sub_testX})
                    true_labels = sub_testY.numpy()
                    sub_acc = balanced_accuracy_score(true_labels, pred_labels)
                    sub_accs_aug.append(sub_acc)
                loo_fine_tuning_accs_aug_fixed.append(sub_accs_aug)
            # Transpose from (fold, subject) to (subject, fold) before saving.
            global_sub_accs_fixed = np.array(loo_global_accs_fixed).T
            ft_aug_sub_accs_fixed = np.array(loo_fine_tuning_accs_aug_fixed).T
            joblib.dump(
                {
                    'global_sub_accs_fixed': global_sub_accs_fixed, 
                    'ft_aug_sub_accs_fixed': ft_aug_sub_accs_fixed
                }, save_file)
            print("{:s} acc:{:.4f} ft_aug_acc:{:.4f}".format(
                save_file, np.mean(global_sub_accs_fixed), np.mean(ft_aug_sub_accs_fixed)))

Dataset: nakanishi2015 Size: (1800, 8, 274)
neural_networks/nakanishi2015-dtn-8-250-50-12-3-fixed.joblib acc:0.5811 ft_aug_acc:0.7272
neural_networks/nakanishi2015-dtn-8-250-75-12-3-fixed.joblib acc:0.6961 ft_aug_acc:0.8161
neural_networks/nakanishi2015-dtn-8-250-100-12-3-fixed.joblib acc:0.8011 ft_aug_acc:0.8800


 case study

In [11]:
# Case-study configuration: a single dataset / duration / model instead of the
# full sweep. These assignments overwrite the globals used by earlier cells.
dataset = BETA()
delay = 0.13
dataset_channels = ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2']

srate = 250
duration = 0.5
model_name = 'dtn'

def data_hook(X, y, meta, caches):
    """Filter ``X`` along its last axis with a single zero-phase SOS filter.

    The filter is the first (only) entry of a one-band filterbank built by
    ``generate_filterbank`` from the band edges ``[8, 90]`` / ``[6, 95]`` at
    the module-level ``srate`` (presumably pass-band / stop-band edges in Hz
    -- confirm against brainda). ``y``, ``meta`` and ``caches`` are returned
    untouched.
    """
    pass_edges, stop_edges = [[8, 90]], [[6, 95]]
    sos = generate_filterbank(pass_edges, stop_edges, srate, order=4, rp=1)[0]
    return sosfiltfilt(sos, X, axis=-1), y, meta, caches

In [12]:
# Subject ids examined in the case study, in two groups of four
# (NOTE(review): the criterion behind "better"/"worse" is not shown in this
# notebook -- document it here). `subjects` keeps the order better-then-worse.
better = [8, 31, 43, 59]
worse = [14, 25, 40, 56]
subjects = better + worse

In [13]:
# Prepare data, references, split indices and saved global weights for the
# case study. Relies on `n_harmonics`, `n_bands` and `save_folder` defined in
# an earlier configuration cell.
dataset_events = sorted(list(dataset.events.keys()))
freqs = [dataset.get_freq(event) for event in dataset_events]
phases = [dataset.get_phase(event) for event in dataset_events]

# Load 1.1 s of data per trial; it is truncated to `duration` later.
X, y, meta = get_ssvep_data(
    dataset, srate, dataset_channels, 1.1, dataset_events, 
    delay=delay, 
    data_hook=data_hook)
labels = np.unique(y)
Yf = generate_cca_references(
    freqs, srate, 1.1, 
    phases=None, 
    n_harmonics=n_harmonics)
print("Dataset: {} Size: {}".format(dataset.dataset_code, X.shape))
_, n_channels, n_samples = X.shape
n_classes = len(labels)

# Pre-computed leave-one-out split indices; the fold count is read from the
# entry stored under key 1 for the first event (presumably subject 1 -- TODO confirm).
indices = joblib.load(
    "indices/{:s}-loo-{:d}class-indices.joblib".format(
    dataset.dataset_code, n_classes))['indices']
loo = len(indices[1][dataset_events[0]])

# Augmentation filterbank: one band per harmonic of the stimulus frequency range.
min_f, max_f = np.min(freqs), np.max(freqs)
wp = [[min_f*i, max_f*i] for i in range(1, n_harmonics+1)]
ws= [[min_f*i-2, max_f*i+2] for i in range(1, n_harmonics+1)]
aug_filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)

file_name = make_file(
    dataset, model_name, dataset_channels, srate, duration, dataset_events, 
    n_bands=n_bands)
save_file = os.path.join(save_folder, file_name)

set_random_seeds(42)

# Global model weights (one state dict per fold) written by the within-subject
# training cell; that cell must have been run for this setting beforehand.
loo_global_model_states = torch.load(
    save_file.replace('.joblib', '.pt'), map_location='cpu')

Dataset: beta Size: (11200, 9, 275)


In [13]:
# Fine-tune the global DTN for each case-study subject and fold, and cache the
# resulting weights in 'sub_case_model_states.pt'. Skipped entirely when that
# file already exists.
if not os.path.exists('sub_case_model_states.pt'):
    loo_sub_model_states = []
    for k in range(loo):
        # Truncate to the chosen duration and remove the per-channel mean.
        filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
        filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

        train_ind, validate_ind, test_ind = match_loo_indices(
            k, meta, indices)
        trainX, trainY, trainMeta = filterX[train_ind], filterY[train_ind], meta.iloc[train_ind]
        validateX, validateY, validateMeta = filterX[validate_ind], filterY[validate_ind], meta.iloc[validate_ind]
        testX, testY, testMeta = filterX[test_ind], filterY[test_ind], meta.iloc[test_ind]

        trainX, validateX, testX = generate_tensors(
            trainX, validateX, testX, dtype=x_dtype)
        trainY, validateY, testY = generate_tensors(
            trainY, validateY, testY, dtype=y_dtype) 

        ## fine-tuning
        sub_model_states = []
        for sub_id in subjects:
            sub_train_mask = (trainMeta['subject']==sub_id).to_numpy()
            sub_valid_mask = (validateMeta['subject']==sub_id).to_numpy()
            sub_test_mask = (testMeta['subject']==sub_id).to_numpy()

            sub_trainX, sub_trainY = trainX[sub_train_mask], trainY[sub_train_mask]
            sub_validateX, sub_validateY = validateX[sub_valid_mask], validateY[sub_valid_mask]
            sub_testX, sub_testY = testX[sub_test_mask], testY[sub_test_mask]    

            # augment training and validation data
            # the augmented data used for validation while the original training and validation data are combined for training
            sub_aug_trainX, sub_aug_trainY = generate_augment_data(
                np.concatenate([sub_trainX.numpy(), sub_validateX.numpy()], axis=0),
                np.concatenate([sub_trainY.numpy(), sub_validateY.numpy()], axis=0), aug_filterbank,
                n_trials_per_class=20)
            sub_aug_trainX = generate_tensors(sub_aug_trainX, dtype=x_dtype)
            sub_aug_trainY = generate_tensors(sub_aug_trainY, dtype=y_dtype)
            sub_trainX = torch.cat([sub_trainX, sub_validateX], 0)
            sub_trainY = torch.cat([sub_trainY, sub_validateY], 0) 

            sub_validateX = sub_aug_trainX
            sub_validateY = sub_aug_trainY 

            lr, batch_size, max_epochs = generate_dl_ft_parameters(
                dataset, model_name)

            # Start fine-tuning from this fold's global weights.
            net = make_dl_model(
                model_name, n_channels, int(srate*duration), n_classes,
                Yf=Yf[..., :int(srate*duration)])

            net.device = device
            net.lr = lr
            net.max_epochs = max_epochs
            net.batch_size = batch_size
            net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
            net.verbose=False
            net.module.load_state_dict(
                copy.deepcopy(loo_global_model_states[k]))

            net.train_split = predefined_split(
                skorch.dataset.Dataset(
                    {'X': sub_validateX}, sub_validateY))

            # DTN additionally receives the labels as a module input.
            if model_name == 'dtn':
                net = net.fit({'X': sub_trainX, 'y': sub_trainY}, y=sub_trainY)
            else:
                net = net.fit(
                    {'X': sub_trainX}, 
                    y=sub_trainY)

            sub_model_states.append(
                copy.deepcopy(net.module.state_dict()))
        loo_sub_model_states.append(sub_model_states)
    # Transpose the nesting from [fold][subject] to [subject][fold].
    loo_sub_model_states = [x for x in zip(*loo_sub_model_states)]
    torch.save(loo_sub_model_states, 'sub_case_model_states.pt')

In [19]:
# Extract two kinds of per-subject templates from the fine-tuned DTN models
# (indexed as [subject][fold]) and save both to 'sub_templates.joblib'.
loo_sub_model_states = torch.load(
    'sub_case_model_states.pt', map_location='cpu')

# A single model instance is reused; its weights are swapped per subject/fold.
net = make_dl_model(
    'dtn', n_channels, int(srate*duration), n_classes,
    Yf=Yf[..., :int(srate*duration)])

# Templates 1: the module's own `running_template`, averaged over the folds.
dtn_sub_templates1 = []
for i_sub in range(len(subjects)):
    template = 0 
    for k in range(loo):
        net.module.load_state_dict(
            copy.deepcopy(loo_sub_model_states[i_sub][k]))
        template += net.module.running_template.detach().numpy()
    template /= loo
    dtn_sub_templates1.append(np.squeeze(template))

# Same truncation and mean removal as used during training.
filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

# Templates 2: class-average trials of all of a subject's data, passed through
# instance_norm and feature_extractor of each fold's model, averaged over folds.
dtn_sub_templates2 = []
for i_sub in range(len(subjects)):
    subX, subY = filterX[meta['subject']==subjects[i_sub]], filterY[meta['subject']==subjects[i_sub]]
    templates = []
    for label in labels:
        templates.append(np.mean(subX[subY==label], axis=0))
    templates = np.stack(templates, 0)
    # Insert a singleton axis after the class axis before feeding the module.
    templates = torch.tensor(templates[:, np.newaxis, ...], dtype=torch.float)
    
    template = 0 
    for k in range(loo):
        net.module.load_state_dict(
            copy.deepcopy(loo_sub_model_states[i_sub][k]))
        net.module.eval()
        out = net.module.instance_norm(templates)
        out = net.module.feature_extractor(out).detach().numpy()
        template += out
    template /= loo
    dtn_sub_templates2.append(np.squeeze(template))

joblib.dump(
    {'dtn_sub_templates1': dtn_sub_templates1, 'dtn_sub_templates2': dtn_sub_templates2},
    'sub_templates.joblib')

computation time

In [41]:
# Computation-time configuration: a single dataset / duration / model.
# These assignments overwrite the globals used by earlier cells.
dataset = BETA()
delay = 0.13
dataset_channels = ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2']

srate = 250
duration = 0.5
model_name = 'ftn'

def data_hook(X, y, meta, caches):
    filterbank = generate_filterbank([[8, 90]], [[6, 95]], srate, order=4, rp=1)
    X = sosfiltfilt(filterbank[0], X, axis=-1)
    return X, y, meta, caches

In [42]:
# Events in sorted key order, with the frequency and phase looked up for each.
dataset_events = sorted(dataset.events.keys())
freqs = [dataset.get_freq(ev) for ev in dataset_events]
phases = [dataset.get_phase(ev) for ev in dataset_events]

# Load the trials (length argument 1.1, filtered through `data_hook`) and
# build CCA references of the same length; note the references are requested
# with `phases=None`, i.e. the `phases` list above is not used here.
X, y, meta = get_ssvep_data(
    dataset, srate, dataset_channels, 1.1, dataset_events,
    delay=delay, data_hook=data_hook)
labels = np.unique(y)
Yf = generate_cca_references(
    freqs, srate, 1.1, phases=None, n_harmonics=n_harmonics)
print("Dataset: {} Size: {}".format(dataset.dataset_code, X.shape))
_, n_channels, n_samples = X.shape
n_classes = len(labels)

# Pre-computed leave-one-out split indices for this dataset / class count.
indices_file = "indices/{:s}-loo-{:d}class-indices.joblib".format(
    dataset.dataset_code, n_classes)
indices = joblib.load(indices_file)['indices']
# Fold count = length of the entry stored under key 1 and the first event
# (presumably key 1 is a subject id - confirm against the index file).
loo = len(indices[1][dataset_events[0]])

# Filterbank used for data augmentation: one band per harmonic spanning the
# stimulus frequency range, with stop-band edges 2 units wider on each side.
min_f, max_f = np.min(freqs), np.max(freqs)
harmonics = range(1, n_harmonics+1)
wp = [[min_f*h, max_f*h] for h in harmonics]
ws = [[min_f*h-2, max_f*h+2] for h in harmonics]
aug_filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)

# Result-file name for this configuration; the global model states live next
# to it with a '.pt' suffix instead of '.joblib'.
file_name = make_file(
    dataset, model_name, dataset_channels, srate, duration, dataset_events,
    n_bands=n_bands)
save_file = os.path.join(save_folder, file_name)

set_random_seeds(42)

state_file = save_file.replace('.joblib', '.pt')
loo_global_model_states = torch.load(state_file, map_location='cpu')

Dataset: beta Size: (11200, 9, 275)


In [43]:
# Benchmark of fine-tuning time and single-trial inference time.
#
# For every leave-one-out fold the global model state of that fold is
# fine-tuned on one subject's data and then evaluated trial by trial on the
# CPU.  Both accumulators are summed over folds here; the reporting cell
# divides them by `loo`.
#
# Timing uses `time.perf_counter()` (monotonic, highest available resolution)
# rather than `time.time()`, and inference runs under `torch.no_grad()` so the
# measurement does not include autograd bookkeeping.
training_time = 0 #gpu
inference_time = 0 #cpu
for k in range(loo):
    # Crop to the analysis window and remove the per-channel mean over time.
    filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    # Train / validation / test split of fold k.
    train_ind, validate_ind, test_ind = match_loo_indices(
        k, meta, indices)
    trainX, trainY, trainMeta = filterX[train_ind], filterY[train_ind], meta.iloc[train_ind]
    validateX, validateY, validateMeta = filterX[validate_ind], filterY[validate_ind], meta.iloc[validate_ind]
    testX, testY, testMeta = filterX[test_ind], filterY[test_ind], meta.iloc[test_ind]

    trainX, validateX, testX = generate_tensors(
        trainX, validateX, testX, dtype=x_dtype)
    trainY, validateY, testY = generate_tensors(
        trainY, validateY, testY, dtype=y_dtype)

    ## fine-tuning
    # Only subject 1 is benchmarked; the per-fold averaging in the reporting
    # cell relies on exactly one subject being processed per fold.
    for sub_id in [1]:
        sub_train_mask = (trainMeta['subject']==sub_id).to_numpy()
        sub_valid_mask = (validateMeta['subject']==sub_id).to_numpy()
        sub_test_mask = (testMeta['subject']==sub_id).to_numpy()

        sub_trainX, sub_trainY = trainX[sub_train_mask], trainY[sub_train_mask]
        sub_validateX, sub_validateY = validateX[sub_valid_mask], validateY[sub_valid_mask]
        sub_testX, sub_testY = testX[sub_test_mask], testY[sub_test_mask]

        # augment training and validation data
        # the augmented data used for validation while the original training and validation data are combined for training
        sub_aug_trainX, sub_aug_trainY = generate_augment_data(
            np.concatenate([sub_trainX.numpy(), sub_validateX.numpy()], axis=0),
            np.concatenate([sub_trainY.numpy(), sub_validateY.numpy()], axis=0), aug_filterbank,
            n_trials_per_class=20)
        sub_aug_trainX = generate_tensors(sub_aug_trainX, dtype=x_dtype)
        sub_aug_trainY = generate_tensors(sub_aug_trainY, dtype=y_dtype)
        sub_trainX = torch.cat([sub_trainX, sub_validateX], 0)
        sub_trainY = torch.cat([sub_trainY, sub_validateY], 0)

        sub_validateX = sub_aug_trainX
        sub_validateY = sub_aug_trainY

        lr, batch_size, max_epochs = generate_dl_ft_parameters(
            dataset, model_name)

        net = make_dl_model(
            model_name, n_channels, int(srate*duration), n_classes,
            Yf=Yf[..., :int(srate*duration)])

        # Fine-tuning hyper-parameters, then start from the global state of
        # this fold instead of a fresh initialisation.
        net.device = device
        net.lr = lr
        net.max_epochs = max_epochs
        net.batch_size = batch_size
        net.set_params(callbacks__lr_scheduler__T_max=max_epochs-1)
        net.verbose=False
        net.module.load_state_dict(
            copy.deepcopy(loo_global_model_states[k]))

        # Validate on the augmented set rather than on an internal split.
        net.train_split = predefined_split(
            skorch.dataset.Dataset(
                {'X': sub_validateX}, sub_validateY))

        # --- training time: one full `fit` call ---
        start_t = time.perf_counter()
        if model_name == 'dtn':
            net = net.fit({'X': sub_trainX, 'y': sub_trainY}, y=sub_trainY)
        else:
            net = net.fit(
                {'X': sub_trainX},
                y=sub_trainY)
        end_t = time.perf_counter()
        training_time += (end_t-start_t)

        # --- inference time: mean latency of one trial on the CPU ---
        module = net.module
        module.eval()
        module = module.to('cpu')

        sub_inference_time = 0
        with torch.no_grad():  # pure forward pass, no autograd graph
            for i in range(len(sub_testX)):
                start_t = time.perf_counter()
                pred_label = torch.argmax(module(sub_testX[i][np.newaxis, ...]), -1)
                end_t = time.perf_counter()
                sub_inference_time += (end_t -start_t)
        sub_inference_time /= len(sub_testX)
        inference_time += sub_inference_time

In [44]:
# Per-fold averages of the totals accumulated by the timing loop.  They are
# stored under new names so the accumulators stay untouched and re-running
# this cell prints the same numbers instead of dividing by `loo` again.
avg_training_time = training_time / loo
avg_inference_time = inference_time / loo
print("average training time: {:.4f}s".format(avg_training_time))
print("average inference time: {:.4f}s".format(avg_inference_time))

average training time: 22.9188s
average inference time: 0.0107s
