In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `utils` helpers) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2
# Use the Qt5 matplotlib backend (figures open in separate GUI windows).
%matplotlib qt5

In [3]:
import warnings, copy, os
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import mne
import joblib

from sklearn.metrics import confusion_matrix, balanced_accuracy_score

from brainda.datasets import (
    PhysionetMI, BNCI2014001, Weibo2014, Cho2017)

from brainda.paradigms import MotorImagery
from brainda.algorithms.utils.model_selection import (
    set_random_seeds,
    generate_kfold_indices, match_kfold_indices)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import skorch
from skorch.classifier import NeuralNetClassifier
from skorch.helper import predefined_split
from skorch.callbacks import (LRScheduler, EpochScoring, Checkpoint, Callback,
                              TrainEndCheckpoint, LoadInitState, EarlyStopping)

from brainda.algorithms.deep_learning import EEGNet, ShallowNet
from braindecode.models import ShallowFBCSPNet, TIDNet, EEGNetv4
from braindecode import EEGClassifier

from brainda.algorithms.transfer_learning import MEKT
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from utils import *

In [4]:
# Select the compute device: GPU `device_id` when CUDA is available, else the CPU.
device_id = 0
device = torch.device("cuda:{:d}".format(device_id) if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to the plain
# string 'cpu', so `device != 'cpu'` was always true and cuDNN autotuning was
# switched on even for CPU-only runs.
if device.type != 'cpu':
    torch.backends.cudnn.benchmark = True
print("Available GPU devices: {}".format(torch.cuda.device_count()))
print("Current pytorch device: {}".format(device))

Available GPU devices: 1
Current pytorch device: cuda:0


In [5]:
from brainda.algorithms.manifold.rpa import get_recenter, recenter

def preprocessing(X, meta, k, indices, y=None, method='cnorm', no_split=False):
    """Centre each trial along its last axis and apply one normalisation scheme.

    Parameters
    ----------
    X : ndarray
        Trial data; the mean over the last axis is always subtracted.
        The caller's array is not modified (a new array is produced).
    meta : DataFrame-like
        Only used by the alignment methods; must expose a ``'subject'`` column
        and an index usable to select rows of ``X``.
    k, indices
        Fold number and fold table forwarded to ``match_kfold_indices``
        (only used by the alignment methods when ``no_split`` is False).
    y
        Accepted for signature compatibility; not used.
    method : str
        ``'cnorm'``  - divide by the standard deviation over the last axis;
        ``'tnorm'``  - divide by the standard deviation over the last two axes;
        ``'riemann'`` / ``'euclid'`` - per-subject recentering, with ``method``
        passed to ``get_recenter`` as ``mean_method``;
        any other value (e.g. ``'raw'``) - mean removal only.
    no_split : bool
        When True the recentering matrix of a subject is estimated from all of
        that subject's trials; otherwise only from the train + validation
        trials of fold ``k`` and then applied to train, validation and test.

    Returns
    -------
    ndarray
        The processed data, same shape as ``X``.
    """
    data = X - np.mean(X, axis=-1, keepdims=True)

    # The two scaling schemes need no subject information: finish early.
    if method == 'cnorm':
        return data / np.std(data, axis=-1, keepdims=True)
    if method == 'tnorm':
        return data / np.std(data, axis=(-1, -2), keepdims=True)
    if method not in ('riemann', 'euclid'):
        return data

    # Subject-level alignment.
    for subject in np.unique(meta['subject']):
        subject_meta = meta[meta['subject'] == subject]

        if no_split:
            rows = subject_meta.index.to_numpy()
            whitener = get_recenter(
                data[rows], cov_method='lwf', mean_method=method)
            data[rows] = recenter(data[rows], whitener)
            continue

        train_rows, validate_rows, test_rows = match_kfold_indices(
            k, subject_meta, indices)
        # Test trials are excluded from the estimate of the recentering matrix.
        fit_rows = np.concatenate((train_rows, validate_rows))
        whitener = get_recenter(
            data[fit_rows], cov_method='lwf', mean_method=method)
        for rows in (train_rows, validate_rows, test_rows):
            data[rows] = recenter(data[rows], whitener)

    return data

In [6]:
def raw_hook(raw, caches, verbose=False):
    """Filter a raw recording with cut-offs 4 and 40 and hand it back.

    Intended to be registered with ``paradigm.register_raw_hook``. The return
    value of ``raw.filter`` is discarded, so this relies on the filter acting
    on ``raw`` in place. ``verbose`` is accepted but not used.

    Returns
    -------
    tuple
        ``(raw, caches)`` - the same objects that were passed in.
    """
    filter_options = {
        'l_trans_bandwidth': 2,
        'h_trans_bandwidth': 5,
        'phase': 'zero-double',
    }
    raw.filter(4, 40, **filter_options)
    return raw, caches

class NeuralNetClassifierNoLog(NeuralNetClassifier):
    """NeuralNetClassifier whose loss computation skips NeuralNetClassifier.get_loss.

    NOTE(review): judging by the class name, the intent is presumably to avoid
    the log transform skorch's classifier applies to predictions for NLLLoss,
    because the networks here already emit log-probabilities or logits.
    Confirm against the installed skorch version.
    """
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        """Return the loss from the get_loss that follows NeuralNetClassifier in the MRO."""
        # super(NeuralNetClassifier, self) starts the attribute lookup *after*
        # NeuralNetClassifier, so that class's own get_loss override is bypassed.
        return super(NeuralNetClassifier, self).get_loss(y_pred, y_true, *args, **kwargs)

def make_model(model_name, n_channels, n_samples, n_classes):
    """Build one of the supported networks with this notebook's fixed hyper-parameters.

    Parameters
    ----------
    model_name : str
        One of ``'eegnetv4'``, ``'shallowfbcspnet'``, ``'tidnet'`` (braindecode
        models) or ``'eegnet'``, ``'shallownet'`` (brainda models).
    n_channels, n_samples, n_classes : int
        Input channel count, samples per trial and number of output classes.

    Returns
    -------
    The freshly constructed model. ``set_random_seeds(64)`` is called right
    before construction, so repeated calls start from the same seed.

    Raises
    ------
    ValueError
        If ``model_name`` is not supported. Previously an unknown name fell
        through every branch and failed with ``UnboundLocalError`` on return.
    """
    # Constructors are wrapped in lambdas so nothing is built (and no seed is
    # touched) until the requested name has been validated.
    builders = {
        'eegnetv4': lambda: EEGNetv4(
            n_channels, n_classes,
            input_window_samples=n_samples,
            F1=8,
            D=2,
            F2=16,
            kernel_length=64,
            drop_prob=0.5),
        'shallowfbcspnet': lambda: ShallowFBCSPNet(
            n_channels, n_classes,
            input_window_samples=n_samples,
            n_filters_time=40,
            filter_time_length=13,
            n_filters_spat=20,
            pool_time_length=37,
            pool_time_stride=7,
            final_conv_length='auto',
            drop_prob=0.5),
        'tidnet': lambda: TIDNet(
            n_channels, n_classes,
            input_window_samples=n_samples,
            s_growth=24,
            t_filters=32,
            temp_layers=2,
            spat_layers=2,
            pooling=15,
            temp_span=0.05,
            bottleneck=3),
        # The brainda models take (n_channels, n_samples, n_classes) positionally.
        'eegnet': lambda: EEGNet(
            n_channels, n_samples, n_classes,
            time_kernel=(8, (1, 64), (1, 1)),
            D=2,
            pool_kernel1=((1, 4), (1, 4)),
            separa_kernel=(16, (1, 16), (1, 1)),
            pool_kernel2=((1, 8), (1, 8)),
            depthwise_norm_rate=1,
            fc_norm_rate=0.25,
            dropout_rate=0.5),
        'shallownet': lambda: ShallowNet(
            n_channels, n_samples, n_classes,
            n_time_filters=40,
            time_kernel=13,
            n_space_filters=20,
            pool_kernel=37,
            pool_stride=7,
            dropout_rate=0.5),
    }

    if model_name not in builders:
        raise ValueError(
            "Unknown model_name {!r}; expected one of {}".format(
                model_name, sorted(builders)))

    set_random_seeds(64)
    return builders[model_name]()
    
def network_training(model_name, dataset, selected_channels, srate, duration, events,
    kfold=5,
    save_folder='eegnet',
    batch_size=256, lr=1e-2, max_epochs=400, T_max=400,
    criterion=nn.NLLLoss, optimizer=optim.Adam,
    preprocess_methods=['raw'],
    force_update=False,
    verbose=False):
    """Train and evaluate one network on one dataset with k-fold cross-validation.

    For every entry of ``preprocess_methods`` the function loads the dataset
    through a ``MotorImagery`` paradigm, trains a fresh copy of the model on
    each fold with skorch, scores the fold's test split and stores the result
    in ``<save_folder>/<dataset_code>-<model_name>-<method>-<n>classes.joblib``.

    Parameters
    ----------
    model_name : str
        Name understood by ``make_model``; also used in file and checkpoint names.
    dataset
        Dataset object exposing ``dataset_code``, ``events`` and ``subjects``.
    selected_channels, srate, duration, events
        Paradigm settings: channel names, sampling rate, trial length in
        seconds and the event names to keep.
    kfold : int
        Number of folds.
    save_folder : str
        Directory for the result files (created if missing).
    batch_size, lr, max_epochs, T_max, criterion, optimizer
        Training settings forwarded to skorch; ``T_max - 1`` is passed to the
        cosine-annealing scheduler.
    preprocess_methods : list of str
        Methods handed to ``preprocessing`` one at a time.
    force_update : bool
        When False, an existing result file is loaded and reported instead of
        retraining.
    verbose
        Forwarded to the skorch classifier.

    Returns
    -------
    None. Results are printed and written to disk.
    """
    
    for preprocess_method in preprocess_methods:
        # All folds of a given model share one checkpoint directory.
        ckp_dirname = 'runs_{}'.format(model_name)
        os.makedirs(save_folder, exist_ok=True)
        save_file = os.path.join(
            save_folder, 
            "{}-{}-{}classes.joblib".format(
                dataset.dataset_code, 
                '{}-{}'.format(model_name, preprocess_method), 
                len(events)))
        # Cached run: report the stored accuracy and skip training.
        if not force_update and os.path.exists(save_file):
            kfold_accs = joblib.load(save_file)['kfold_accs']
            print("Dataset:{} {} Acc: {:.4f}".format(
                dataset.dataset_code,
                '{}-{}'.format(model_name, preprocess_method),
                np.mean(kfold_accs)))
            continue
        
        # The trial window is taken from the 'left_hand' event definition,
        # regardless of which events were requested.
        start_t = dataset.events['left_hand'][1][0]
        if start_t+ duration > dataset.events['left_hand'][1][1]:
            print("Warning: the current dataset avaliable trial duration is not long enough.")
        paradigm = MotorImagery(
            channels=selected_channels, 
            srate=srate, 
            intervals=[(start_t, start_t + duration)], 
            events=events)
        event_id = [dataset.events[e][0] for e in events]
        paradigm.register_raw_hook(raw_hook)
        X, y, meta = paradigm.get_data(
            dataset, 
            subjects=dataset.subjects, 
            return_concat=True, 
            verbose=False)
        # label_encoder comes from `utils`; presumably it maps the event ids to
        # class indices in the order of `events` - verify against utils.
        y = label_encoder(y, event_id)
        labels = np.unique(y)

        # Fixed seed so the fold assignment is reproducible.
        set_random_seeds(38)
        indices = generate_kfold_indices(meta, kfold=kfold)
        
        
        n_trials, n_channels, n_samples = X.shape
        n_classes = len(labels)
        x_dtype, y_dtype = torch.float, torch.long
        model = make_model(model_name, n_channels, n_samples, n_classes)
        # Snapshot of the untrained weights, restored at the start of each fold.
        initial_state = copy.deepcopy(model.state_dict())
        
        kfold_accs, kfold_cms = [], []
        model_states = []

        set_random_seeds(42)
        for k in range(kfold):
            filterX = np.copy(X)
            filterY = np.copy(y)

            filterX = preprocessing(filterX, meta, k, indices, method=preprocess_method)

            train_ind, validate_ind, test_ind = match_kfold_indices(k, meta, indices)
            trainX, trainy = filterX[train_ind], filterY[train_ind]
            validateX, validatey = filterX[validate_ind], filterY[validate_ind]
            testX, testy = filterX[test_ind], filterY[test_ind]

            # generate_tensors comes from `utils`; it is used here to obtain
            # torch tensors of the requested dtype.
            trainX, validateX, testX = generate_tensors(trainX, validateX, testX, dtype=x_dtype)
            trainy, validatey, testy = generate_tensors(trainy, validatey, testy, dtype=y_dtype)

            torch.cuda.empty_cache()
            ckp = Checkpoint(dirname=ckp_dirname)
            train_end_ckp = TrainEndCheckpoint(dirname=ckp_dirname)
            estopper = EarlyStopping(patience=50)

            # Every fold starts from the same initial weights.
            model.load_state_dict(copy.deepcopy(initial_state))
            # Use the fold's validation split instead of skorch's internal split.
            train_split = predefined_split(
                skorch.dataset.Dataset(
                    validateX, validatey))

            net = NeuralNetClassifierNoLog(model,
                    criterion=criterion,
                    optimizer=optimizer,
                    batch_size=batch_size, 
                    lr=lr, 
                    max_epochs=max_epochs,
                    device=device,
                    train_split=train_split,
                    iterator_train__shuffle=True,
                    callbacks=[
                        ('train_acc', EpochScoring('accuracy', 
                                                   name='train_acc', 
                                                   on_train=True, 
                                                   lower_is_better=False)),
                         ('lr_scheduler', LRScheduler('CosineAnnealingLR', T_max=T_max - 1)),
                        estopper,
                        ckp,
                        train_end_ckp
                    ],
                    verbose=verbose)

            net.fit(trainX, y=trainy)
            # Restore the parameters saved by the Checkpoint callback before
            # scoring and before storing the fold's weights.
            net.load_params(checkpoint=ckp)
            model_states.append(copy.deepcopy(net.module_.state_dict()))

            # test set
            pred_labels = net.predict(testX)
            true_labels = testy.numpy()
            cm = confusion_matrix(true_labels, pred_labels, labels=labels, normalize='true')
            kfold_accs.append(balanced_accuracy_score(true_labels, pred_labels))
            kfold_cms.append(cm)

        print("Dataset:{} {} Acc: {:.4f}".format(
            dataset.dataset_code,
            '{}-{}'.format(model_name, preprocess_method),
            np.mean(kfold_accs)))
        # Persist per-fold weights, balanced accuracies and confusion matrices.
        joblib.dump({
            'model_states': model_states, 
            'kfold_accs': kfold_accs, 'kfold_cms': kfold_cms}, 
            save_file)

### Network Training

In [24]:
# Experiment configuration shared by every training run.
srate = 128
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
duration = 3 # seconds
events = ['left_hand', 'right_hand']
preprocess_methods = ['raw', 'cnorm', 'tnorm', 'euclid', 'riemann']
kfold = 5

# 'tidnet' was misspelled 'tident': no make_model branch matches that name, and
# the cross-dataset cells look for the trained weights in the 'tidnet' folder.
model_names = ['shallownet', 'eegnet', 'shallowfbcspnet', 'eegnetv4', 'tidnet']
# One learning rate per entry of model_names, in the same order.
lrs = [0.0625 * 0.01, 1e-2, 0.0625 * 0.01, 1e-2, 1e-3]

for model_name, model_lr in zip(model_names, lrs):
    # The brainda models are trained with CrossEntropyLoss, the braindecode
    # models with NLLLoss.
    if model_name in ['shallownet', 'eegnet']:
        criterion = nn.CrossEntropyLoss
    else:
        criterion = nn.NLLLoss

    for dataset in datasets:
        network_training(model_name, dataset, selected_channels, srate, duration, events,
            kfold=kfold,
            save_folder=model_name,
            batch_size=256, lr=model_lr, max_epochs=200, T_max=200,
            criterion=criterion,
            preprocess_methods=preprocess_methods,
            force_update=False,
            verbose=False)

Dataset:bnci2014001 tidnet-raw Acc: 0.6131
Dataset:bnci2014001 tidnet-cnorm Acc: 0.5657
Dataset:bnci2014001 tidnet-tnorm Acc: 0.6107
Dataset:bnci2014001 tidnet-euclid Acc: 0.7872
Dataset:bnci2014001 tidnet-riemann Acc: 0.8007
Dataset:eegbci tidnet-raw Acc: 0.5478
Dataset:eegbci tidnet-cnorm Acc: 0.5313
Dataset:eegbci tidnet-tnorm Acc: 0.5629
Dataset:eegbci tidnet-euclid Acc: 0.6278
Dataset:eegbci tidnet-riemann Acc: 0.6358
Dataset:weibo2014 tidnet-raw Acc: 0.6285
Dataset:weibo2014 tidnet-cnorm Acc: 0.5519
Dataset:weibo2014 tidnet-tnorm Acc: 0.6127
Dataset:weibo2014 tidnet-euclid Acc: 0.6987
Dataset:weibo2014 tidnet-riemann Acc: 0.7038
Dataset:cho2017 tidnet-raw Acc: 0.5438
Dataset:cho2017 tidnet-cnorm Acc: 0.5351
Dataset:cho2017 tidnet-tnorm Acc: 0.5449
Dataset:cho2017 tidnet-euclid Acc: 0.6338
Dataset:cho2017 tidnet-riemann Acc: 0.6294


### Cross-dataset Performance

In [8]:
def batchnorm_pre_forward_hook(self, input):
    """Forward pre-hook that overwrites a BatchNorm layer's running statistics
    with the statistics of the incoming batch (global AdaBN).

    Parameters
    ----------
    self : nn.Module
        The batch-norm layer the hook is registered on.
    input : tuple
        Positional inputs of the layer's forward call; ``input[0]`` is the
        tensor to normalise, with channels on dimension 1.

    Returns
    -------
    None, so the layer's input is passed on unchanged.

    The statistics are reduced over every dimension except the channel
    dimension. For 4-D input this equals the former hard-coded ``(0, 2, 3)``,
    and it also makes the hook usable on BatchNorm1d / BatchNorm3d layers,
    which callers register it on as well (they match on the class name only).
    """
    # Layers without running statistics always use batch statistics already.
    if getattr(self, 'running_mean', None) is None or getattr(self, 'running_var', None) is None:
        return

    batch = input[0]
    reduce_dims = tuple(d for d in range(batch.dim()) if d != 1)

    old_training_state = self.training
    self.eval()
    # global AdaBN
    with torch.no_grad():
        self.running_mean.data.zero_()
        self.running_mean.data.add_(torch.mean(batch, dim=reduce_dims))
        self.running_var.data.zero_()
        self.running_var.data.add_(torch.var(batch, dim=reduce_dims))
    # Leave the layer in the mode it was found in.
    if old_training_state:
        self.train()

def cross_dataset_network_predict(
    model_name, source_dataset, target_dataset, selected_channels, srate, duration, events,
    kfold=5,
    save_folder='eegnet',
    preprocess_methods=['raw'],
    use_adabn=False,
    force_update=False):
    """Score networks trained on ``source_dataset`` on the trials of ``target_dataset``.

    For every preprocessing method the per-fold weights stored by
    ``network_training`` are loaded from
    ``<save_folder>/<source_code>-<model>-<method>-<n>classes.joblib`` and each
    fold's model predicts every target subject separately. When source and
    target are the same dataset only that fold's test trials are scored;
    otherwise all trials of the subject are scored.

    Parameters
    ----------
    model_name : str
        Name understood by ``make_model``.
    source_dataset, target_dataset
        Dataset objects exposing ``dataset_code``, ``events`` and ``subjects``.
    selected_channels, srate, duration, events
        Paradigm settings (channel names, sampling rate, trial length in
        seconds, event names).
    kfold : int
        Number of folds; must match the stored model states.
    save_folder : str
        Folder holding the trained weights and receiving the result file.
    preprocess_methods : list of str
        Methods handed to ``preprocessing`` one at a time.
    use_adabn : bool
        Register ``batchnorm_pre_forward_hook`` on every module whose class
        name contains ``'BatchNorm'``; results get an ``-adabn`` suffix.
    force_update : bool
        When False an existing result file is reported instead of recomputed.

    Returns
    -------
    None. Balanced accuracies and confusion matrices (fold x subject) are
    printed and written to disk.
    """
    same_dataset = source_dataset.dataset_code == target_dataset.dataset_code

    for preprocess_method in preprocess_methods:
        model_name_str = '{}-{}'.format(model_name, preprocess_method)
        result_name = model_name_str + '-adabn' if use_adabn else model_name_str

        model_file = os.path.join(
            save_folder,
            "{}-{}-{}classes.joblib".format(
                source_dataset.dataset_code,
                model_name_str,
                len(events)))
        save_file = os.path.join(
            save_folder,
            "{}->{}-{}-{}classes.joblib".format(
                source_dataset.dataset_code,
                target_dataset.dataset_code,
                result_name,
                len(events)))

        # Cached run: report the stored accuracy and move on.
        if not force_update and os.path.exists(save_file):
            kfold_accs = joblib.load(save_file)['kfold_accs']
            print("Source: {} Target: {} {} Acc: {:.4f}".format(
                source_dataset.dataset_code,
                target_dataset.dataset_code,
                result_name,
                np.mean(kfold_accs)))
            continue

        # The trial window is taken from the 'left_hand' event definition.
        start_t = target_dataset.events['left_hand'][1][0]
        if start_t+ duration > target_dataset.events['left_hand'][1][1]:
            print("Warning: the current dataset avaliable trial duration is not long enough.")
        paradigm = MotorImagery(
            channels=selected_channels,
            srate=srate,
            intervals=[(start_t, start_t + duration)],
            events=events)
        event_id = [target_dataset.events[e][0] for e in events]
        paradigm.register_raw_hook(raw_hook)

        X, y, meta = paradigm.get_data(
            target_dataset,
            subjects=target_dataset.subjects,
            return_concat=True,
            verbose=False)
        y = label_encoder(y, event_id)
        labels = np.unique(y)

        # Same seed as in training so the fold assignment is reproduced.
        set_random_seeds(38)
        indices = generate_kfold_indices(meta, kfold=kfold)

        model_states = joblib.load(model_file)['model_states']

        _, n_channels, n_samples = X.shape
        n_classes = len(labels)
        x_dtype, y_dtype = torch.float, torch.long
        model = make_model(model_name, n_channels, n_samples, n_classes)
        if use_adabn:
            for module in model.modules():
                if 'BatchNorm' in module.__class__.__name__:
                    module.register_forward_pre_hook(batchnorm_pre_forward_hook)

        kfold_accs, kfold_cms = [], []

        set_random_seeds(42)
        filterY = np.copy(y)
        if not same_dataset:
            # Across datasets the alignment uses all trials of each subject and
            # ignores the fold, so it is computed once instead of once per fold.
            cross_X = preprocessing(np.copy(X), meta, None, None,
                        method=preprocess_method,
                        no_split=True)

        for k in range(kfold):
            model.load_state_dict(copy.deepcopy(model_states[k]))
            model.eval()

            if same_dataset:
                filterX = preprocessing(np.copy(X), meta, k, indices,
                            method=preprocess_method,
                            no_split=False)
            else:
                filterX = cross_X

            sub_accs, sub_cms = [], []
            for sub_id in target_dataset.subjects:
                if same_dataset:
                    _, _, test_ind = match_kfold_indices(k, meta[meta['subject']==sub_id], indices)
                else:
                    test_ind = meta[meta['subject']==sub_id].index.to_numpy()
                testX, testy = filterX[test_ind], filterY[test_ind]

                testX = generate_tensors(testX, dtype=x_dtype)
                testy = generate_tensors(testy, dtype=y_dtype)

                torch.cuda.empty_cache()

                # Pure inference: do not record an autograd graph for the batch.
                with torch.no_grad():
                    output = model(testX)
                pred_labels = torch.argmax(output, dim=1).detach().numpy()
                true_labels = testy.numpy()
                cm = confusion_matrix(true_labels, pred_labels, labels=labels, normalize='true')
                sub_accs.append(balanced_accuracy_score(true_labels, pred_labels))
                sub_cms.append(cm)
            kfold_accs.append(sub_accs)
            kfold_cms.append(sub_cms)

        print("Source: {} Target: {} {} Acc: {:.4f}".format(
            source_dataset.dataset_code,
            target_dataset.dataset_code,
            result_name,
            np.mean(kfold_accs)))
        joblib.dump({
            'kfold_accs': kfold_accs, 'kfold_cms': kfold_cms},
            save_file)

In [10]:
# Cross-dataset evaluation of every trained architecture, without AdaBN.
srate = 128
duration = 3 # seconds
kfold = 5
events = ['left_hand', 'right_hand']
preprocess_methods = ['raw', 'cnorm', 'tnorm', 'euclid', 'riemann']
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]
model_names = ['shallownet', 'eegnet', 'shallowfbcspnet', 'eegnetv4', 'tidnet']

# Every ordered (source, target) pair, including source == target.
dataset_pairs = [(src, tgt) for src in datasets for tgt in datasets]

for model_name in model_names:
    for source_dataset, target_dataset in dataset_pairs:
        cross_dataset_network_predict(
            model_name, source_dataset, target_dataset,
            selected_channels, srate, duration, events,
            kfold=kfold,
            save_folder=model_name,
            preprocess_methods=preprocess_methods,
            use_adabn=False,
            force_update=False)

Source: bnci2014001 Target: bnci2014001 shallownet-raw Acc: 0.8165
Source: bnci2014001 Target: bnci2014001 shallownet-cnorm Acc: 0.8081
Source: bnci2014001 Target: bnci2014001 shallownet-tnorm Acc: 0.7987
Source: bnci2014001 Target: bnci2014001 shallownet-euclid Acc: 0.8634
Source: bnci2014001 Target: bnci2014001 shallownet-riemann Acc: 0.8611
Source: bnci2014001 Target: eegbci shallownet-raw Acc: 0.6094
Source: bnci2014001 Target: eegbci shallownet-cnorm Acc: 0.6132
Source: bnci2014001 Target: eegbci shallownet-tnorm Acc: 0.6172
Source: bnci2014001 Target: eegbci shallownet-euclid Acc: 0.6384
Source: bnci2014001 Target: eegbci shallownet-riemann Acc: 0.6285
Source: bnci2014001 Target: weibo2014 shallownet-raw Acc: 0.6279
Source: bnci2014001 Target: weibo2014 shallownet-cnorm Acc: 0.6328
Source: bnci2014001 Target: weibo2014 shallownet-tnorm Acc: 0.6381
Source: bnci2014001 Target: weibo2014 shallownet-euclid Acc: 0.7003
Source: bnci2014001 Target: weibo2014 shallownet-riemann Acc: 0.68

In [12]:
# Cross-dataset evaluation of every trained architecture, with AdaBN enabled.
srate = 128
duration = 3 # seconds
kfold = 5
events = ['left_hand', 'right_hand']
preprocess_methods = ['raw', 'cnorm', 'tnorm', 'euclid', 'riemann']
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]
model_names = ['shallownet', 'eegnet', 'shallowfbcspnet', 'eegnetv4', 'tidnet']

# Every ordered (source, target) pair, including source == target.
dataset_pairs = [(src, tgt) for src in datasets for tgt in datasets]

for model_name in model_names:
    for source_dataset, target_dataset in dataset_pairs:
        cross_dataset_network_predict(
            model_name, source_dataset, target_dataset,
            selected_channels, srate, duration, events,
            kfold=kfold,
            save_folder=model_name,
            preprocess_methods=preprocess_methods,
            use_adabn=True,
            force_update=False)

Source: bnci2014001 Target: bnci2014001 shallownet-raw-adabn Acc: 0.8232
Source: bnci2014001 Target: bnci2014001 shallownet-cnorm-adabn Acc: 0.8197
Source: bnci2014001 Target: bnci2014001 shallownet-tnorm-adabn Acc: 0.8047
Source: bnci2014001 Target: bnci2014001 shallownet-euclid-adabn Acc: 0.8645
Source: bnci2014001 Target: bnci2014001 shallownet-riemann-adabn Acc: 0.8637
Source: bnci2014001 Target: eegbci shallownet-raw-adabn Acc: 0.6374
Source: bnci2014001 Target: eegbci shallownet-cnorm-adabn Acc: 0.6251
Source: bnci2014001 Target: eegbci shallownet-tnorm-adabn Acc: 0.6394
Source: bnci2014001 Target: eegbci shallownet-euclid-adabn Acc: 0.6579
Source: bnci2014001 Target: eegbci shallownet-riemann-adabn Acc: 0.6481
Source: bnci2014001 Target: weibo2014 shallownet-raw-adabn Acc: 0.6419
Source: bnci2014001 Target: weibo2014 shallownet-cnorm-adabn Acc: 0.6514
Source: bnci2014001 Target: weibo2014 shallownet-tnorm-adabn Acc: 0.6542
Source: bnci2014001 Target: weibo2014 shallownet-euclid-

### Fusing MEKT, alignment strategies, and AdaBN

In [9]:
class Expand4D(nn.Module):
    """Parameter-free layer that inserts a singleton axis at position 1."""

    def forward(self, X):
        """Return ``X`` with an extra size-1 dimension at index 1."""
        return torch.unsqueeze(X, 1)

def make_feature_extractor(model_name, model):
    """Wrap the leading layers of a trained network as a feature extractor.

    The returned ``nn.Sequential`` reuses the layer objects of ``model`` (no
    copy), so it shares the weights that are loaded into ``model``.

    Parameters
    ----------
    model_name : str
        ``'eegnet'`` / ``'shallownet'``: an ``Expand4D`` layer followed by all
        children of ``model.model`` except the last one;
        ``'shallowfbcspnet'``: all children but the last three, then ``Flatten``;
        ``'eegnetv4'``: all children but the last four, then ``Flatten``;
        ``'tidnet'``: the ``model.dscnn`` sub-module.
    model : nn.Module
        Network built by ``make_model`` for the same ``model_name``.

    Returns
    -------
    nn.Sequential

    Raises
    ------
    ValueError
        If ``model_name`` is not supported. Previously an unknown name failed
        with ``UnboundLocalError`` at the return statement.
    """
    # NOTE(review): the number of trailing children dropped presumably matches
    # the classifier head of each architecture - confirm when upgrading
    # brainda / braindecode.
    if model_name in ('eegnet', 'shallownet'):
        return nn.Sequential(Expand4D(), *list(model.model.children())[:-1])
    if model_name == 'shallowfbcspnet':
        return nn.Sequential(*list(model.children())[:-3], nn.Flatten())
    if model_name == 'eegnetv4':
        return nn.Sequential(*list(model.children())[:-4], nn.Flatten())
    if model_name == 'tidnet':
        return nn.Sequential(model.dscnn)
    raise ValueError(
        "Unknown model_name {!r}; expected one of "
        "['eegnet', 'eegnetv4', 'shallowfbcspnet', 'shallownet', 'tidnet']".format(model_name))

In [10]:
def cross_dataset_fusing_predict(
    model_name, source_dataset, target_dataset, selected_channels, srate, duration, events,
    kfold=5,
    save_folder='eegnet',
    preprocess_methods=['raw'],
    use_adabn=False,
    force_update=False):
    """Cross-dataset prediction with deep features, MEKT and an LDA classifier.

    For each preprocessing method and fold, the network trained on
    ``source_dataset`` is cut down to a feature extractor. Features of the
    source train + validation trials and of each target subject are passed to
    ``MEKT``, and an LDA fitted on the transformed source features predicts
    the target subject's labels.

    Parameters
    ----------
    model_name : str
        Name understood by ``make_model`` and ``make_feature_extractor``.
    source_dataset, target_dataset
        Dataset objects exposing ``dataset_code``, ``events`` and ``subjects``.
    selected_channels, srate, duration, events
        Paradigm settings (channel names, sampling rate, trial length in
        seconds, event names).
    kfold : int
        Number of folds; must match the stored model states.
    save_folder : str
        Folder holding the trained weights and receiving the result file.
    preprocess_methods : list of str
        Methods handed to ``preprocessing`` one at a time.
    use_adabn : bool
        Register ``batchnorm_pre_forward_hook`` on the extractor's batch-norm
        layers *after* the source features are computed, so only the target
        features are produced with adapted statistics.
    force_update : bool
        When False an existing result file is reported instead of recomputed.

    Returns
    -------
    None. Balanced accuracies and confusion matrices (fold x subject) are
    printed and written to ``save_folder``.
    """
    for preprocess_method in preprocess_methods:
        model_name_str = '{}-{}'.format(model_name, preprocess_method)
        result_name = model_name_str + ('-adabn-mekt' if use_adabn else '-mekt')

        model_file = os.path.join(
            save_folder,
            "{}-{}-{}classes.joblib".format(
                source_dataset.dataset_code,
                model_name_str,
                len(events)))
        save_file = os.path.join(
            save_folder,
            "{}->{}-{}-{}classes.joblib".format(
                source_dataset.dataset_code,
                target_dataset.dataset_code,
                result_name,
                len(events)))

        # Cached run: report the stored accuracy and move on.
        if not force_update and os.path.exists(save_file):
            kfold_accs = joblib.load(save_file)['kfold_accs']
            print("Source: {} Target: {} {} Acc: {:.4f}".format(
                source_dataset.dataset_code,
                target_dataset.dataset_code,
                result_name,
                np.mean(kfold_accs)))
            continue

        # ---- source data -------------------------------------------------
        start_t = source_dataset.events['left_hand'][1][0]
        if start_t+ duration > source_dataset.events['left_hand'][1][1]:
            print("Warning: the current dataset avaliable trial duration is not long enough.")
        paradigm = MotorImagery(
            channels=selected_channels,
            srate=srate,
            intervals=[(start_t, start_t + duration)],
            events=events)
        event_id = [source_dataset.events[e][0] for e in events]
        paradigm.register_raw_hook(raw_hook)

        Xs, ys, metas = paradigm.get_data(
            source_dataset,
            subjects=source_dataset.subjects,
            return_concat=True,
            verbose=False)
        ys = label_encoder(ys, event_id)
        labels = np.unique(ys)

        # Same seed as in training so the fold assignment is reproduced.
        set_random_seeds(38)
        indices = generate_kfold_indices(metas, kfold=kfold)

        model_states = joblib.load(model_file)['model_states']

        _, n_channels, n_samples = Xs.shape
        n_classes = len(labels)
        x_dtype = torch.float

        kfold_accs, kfold_cms = [], []

        set_random_seeds(42)
        for k in range(kfold):
            model = make_model(model_name, n_channels, n_samples, n_classes)
            model.load_state_dict(copy.deepcopy(model_states[k]))
            fe = make_feature_extractor(model_name, model)
            fe.eval()

            filterX = np.copy(Xs)
            filterY = np.copy(ys)
            filterX = preprocessing(filterX, metas, k, indices,
                        method=preprocess_method,
                        no_split=False)
            train_ind, valid_ind, test_ind = match_kfold_indices(k, metas, indices)
            # The source model saw train + validation trials; use both as source features.
            train_ind = np.concatenate((train_ind, valid_ind))
            trainX, trainy = filterX[train_ind], filterY[train_ind]
            trainX = generate_tensors(trainX, dtype=x_dtype)
            # Pure inference over the whole source set: skip autograd bookkeeping.
            with torch.no_grad():
                source_features = fe(trainX).detach().numpy()

            if use_adabn:
                for module in fe.modules():
                    if 'BatchNorm' in module.__class__.__name__:
                        module.register_forward_pre_hook(batchnorm_pre_forward_hook)

            # ---- target subjects -----------------------------------------
            # NOTE(review): the target data does not depend on the fold, yet it
            # is reloaded for every fold; consider loading it once if
            # get_data is expensive.
            sub_accs, sub_cms = [], []
            for sub_id in target_dataset.subjects:
                start_t = target_dataset.events['left_hand'][1][0]
                if start_t+ duration > target_dataset.events['left_hand'][1][1]:
                    print("Warning: the current dataset avaliable trial duration is not long enough.")
                paradigm = MotorImagery(
                    channels=selected_channels,
                    srate=srate,
                    intervals=[(start_t, start_t + duration)],
                    events=events)
                event_id = [target_dataset.events[e][0] for e in events]
                paradigm.register_raw_hook(raw_hook)

                Xt, yt, metat = paradigm.get_data(
                    target_dataset,
                    subjects=[sub_id],
                    return_concat=True,
                    verbose=False)
                yt = label_encoder(yt, event_id)

                Xt = preprocessing(Xt, metat, None, None,
                        method=preprocess_method,
                        no_split=True)

                torch.cuda.empty_cache()
                testX = generate_tensors(Xt, dtype=x_dtype)
                with torch.no_grad():
                    target_features = fe(testX).detach().numpy()

                mekt = MEKT(rho=5)
                sf, tf = mekt.fit_transform(source_features, trainy, target_features)
                clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
                pred_labels = clf.fit(sf, trainy).predict(tf)
                true_labels = yt

                cm = confusion_matrix(true_labels, pred_labels, labels=labels, normalize='true')
                sub_accs.append(balanced_accuracy_score(true_labels, pred_labels))
                sub_cms.append(cm)
            kfold_accs.append(sub_accs)
            kfold_cms.append(sub_cms)

        print("Source: {} Target: {} {} Acc: {:.4f}".format(
            source_dataset.dataset_code,
            target_dataset.dataset_code,
            result_name,
            np.mean(kfold_accs)))
        joblib.dump({
            'kfold_accs': kfold_accs, 'kfold_cms': kfold_cms},
            save_file)

In [13]:
# Fusion experiment: alignment + AdaBN + MEKT, for the two EEGNet variants.
srate = 128
duration = 3 # seconds
kfold = 5
events = ['left_hand', 'right_hand']
preprocess_methods = ['euclid', 'riemann']
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]
model_names = ['eegnet', 'eegnetv4']

for model_name in model_names:
    for source_dataset in datasets:
        for target_dataset in datasets:
            # Only genuine cross-dataset pairs are evaluated here.
            if not (source_dataset.dataset_code == target_dataset.dataset_code):
                cross_dataset_fusing_predict(
                    model_name, source_dataset, target_dataset,
                    selected_channels, srate, duration, events,
                    kfold=kfold,
                    save_folder=model_name,
                    preprocess_methods=preprocess_methods,
                    use_adabn=True,
                    force_update=False)

Source: bnci2014001 Target: eegbci eegnet-euclid-adabn-mekt Acc: 0.6663
Source: bnci2014001 Target: eegbci eegnet-riemann-adabn-mekt Acc: 0.6680
Source: bnci2014001 Target: weibo2014 eegnet-euclid-adabn-mekt Acc: 0.7159
Source: bnci2014001 Target: weibo2014 eegnet-riemann-adabn-mekt Acc: 0.7142
Source: bnci2014001 Target: cho2017 eegnet-euclid-adabn-mekt Acc: 0.6077
Source: bnci2014001 Target: cho2017 eegnet-riemann-adabn-mekt Acc: 0.6051
Source: eegbci Target: bnci2014001 eegnet-euclid-adabn-mekt Acc: 0.7765
Source: eegbci Target: bnci2014001 eegnet-riemann-adabn-mekt Acc: 0.7764
Source: eegbci Target: weibo2014 eegnet-euclid-adabn-mekt Acc: 0.7247
Source: eegbci Target: weibo2014 eegnet-riemann-adabn-mekt Acc: 0.7372
Source: eegbci Target: cho2017 eegnet-euclid-adabn-mekt Acc: 0.6476
Source: eegbci Target: cho2017 eegnet-riemann-adabn-mekt Acc: 0.6437
Source: weibo2014 Target: bnci2014001 eegnet-euclid-adabn-mekt Acc: 0.7030
Source: weibo2014 Target: bnci2014001 eegnet-riemann-adabn-