In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import warnings, copy, os
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import mne
import joblib

from sklearn.metrics import confusion_matrix, balanced_accuracy_score

from brainda.datasets import (
    PhysionetMI, BNCI2014001, Weibo2014, Cho2017)

from brainda.paradigms import MotorImagery
from brainda.algorithms.utils.model_selection import (
    set_random_seeds,
    generate_kfold_indices, match_kfold_indices)

from brainda.algorithms.transfer_learning import MEKT, choose_multiple_subjects
from brainda.algorithms.manifold import tangent_space, mean_riemann
from brainda.algorithms.utils.covariance import Covariance, invsqrtm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from utils import *

In [3]:
def raw_hook(raw, caches, verbose=False):
    """Filter a raw recording in place before epochs are extracted.

    Registered on the paradigm through ``register_raw_hook``.

    Parameters
    ----------
    raw : object exposing ``filter(...)``
        Presumably an ``mne.io.Raw`` instance, in which case the call below is
        a 4-40 Hz band-pass (TODO confirm against the paradigm's hook caller).
    caches : object
        Passed through untouched.
    verbose : bool
        Accepted for signature compatibility; not used by this hook.

    Returns
    -------
    tuple
        ``(raw, caches)`` -- the same two objects that were passed in.
    """
    band_edges = (4, 40)
    transition_opts = {'l_trans_bandwidth': 2, 'h_trans_bandwidth': 5}
    raw.filter(*band_edges, phase='zero-double', **transition_opts)
    return (raw, caches)

def _mekt_trial_start(dataset, events, duration):
    """Return the trial start read from ``dataset.events`` for the first event.

    Prints a warning when ``start + duration`` exceeds the end of the interval
    stored for that event.
    """
    interval = dataset.events[events[0]][1]
    start_t = interval[0]
    if start_t + duration > interval[1]:
        print("Warning: the current dataset avaliable trial duration is not long enough.")
    return start_t


def _mekt_paradigm(selected_channels, srate, start_t, duration, events):
    """Build a MotorImagery paradigm for one window and register ``raw_hook`` on it."""
    paradigm = MotorImagery(
        channels=selected_channels,
        srate=srate,
        intervals=[(start_t, start_t + duration)],
        events=events)
    paradigm.register_raw_hook(raw_hook)
    return paradigm


def _mekt_aligned_tangent_features(covs):
    """Re-reference covariance matrices to their mean and map them to tangent space.

    The matrices are multiplied on both sides by ``invsqrtm`` of their
    ``mean_riemann`` reference, then passed to ``tangent_space`` with the
    identity as reference point.
    """
    M = mean_riemann(covs)
    iM12 = invsqrtm(M)
    aligned = iM12 @ covs @ iM12.T
    return tangent_space(aligned, np.eye(M.shape[0]))


def mekt_transform_and_predict(
    source_dataset, target_dataset, selected_channels, srate, duration, events,
    kfold=5,
    save_folder='mekt',
    force_update=False):
    """Evaluate MEKT + LDA transfer from ``source_dataset`` to ``target_dataset``.

    For every fold, tangent-space features are built per source subject from
    that fold's train+validate trials, MEKT maps source and target features
    into a shared space, and an LDA fitted on the source features predicts the
    labels of each target subject. When source and target have the same
    ``dataset_code`` the target subject is excluded from the source pool
    (leave-one-subject-out). Balanced accuracies and row-normalised confusion
    matrices are printed (mean accuracy) and dumped with joblib.

    Parameters
    ----------
    source_dataset, target_dataset : dataset objects
        Must expose ``dataset_code``, ``subjects`` and an ``events`` mapping
        whose values are read as ``(event_id, (start, stop))``.
    selected_channels : list of str
        Channel names forwarded to ``MotorImagery``.
    srate : number
        Forwarded to ``MotorImagery`` as ``srate``.
    duration : number
        Length of the window, added to the trial start read from the dataset.
    events : sequence of str
        Event names to classify. The window start is taken from the first
        entry. (Previously this argument was silently replaced by
        ``['left_hand', 'right_hand']`` after naming the cache file.)
    kfold : int
        Number of folds generated by ``generate_kfold_indices``.
    save_folder : str
        Directory (created if missing) holding the cached result file.
    force_update : bool
        When False and the result file exists, only the cached mean accuracy
        is printed and nothing is recomputed.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``events`` is empty.
    """
    events = list(events)
    if not events:
        raise ValueError("events must contain at least one event name.")

    os.makedirs(save_folder, exist_ok=True)
    save_file = "{}->{}-mekt-{}classes.joblib".format(
        source_dataset.dataset_code,
        target_dataset.dataset_code,
        len(events))
    save_path = os.path.join(save_folder, save_file)

    if not force_update and os.path.exists(save_path):
        kfold_accs = joblib.load(save_path)['kfold_accs']
        print("Source: {} Target: {} {} Acc: {:.4f}".format(
            source_dataset.dataset_code,
            target_dataset.dataset_code,
            'mekt',
            np.mean(kfold_accs)))
        return

    # ---- source data: loaded once for all subjects --------------------------
    source_start = _mekt_trial_start(source_dataset, events, duration)
    source_event_id = [source_dataset.events[e][0] for e in events]
    source_paradigm = _mekt_paradigm(
        selected_channels, srate, source_start, duration, events)
    Xs, ys, metas = source_paradigm.get_data(
        source_dataset,
        subjects=source_dataset.subjects,
        return_concat=True,
        verbose=False)
    ys = np.asarray(label_encoder(ys, source_event_id))

    set_random_seeds(38)
    indices = generate_kfold_indices(metas, kfold=kfold)

    # Covariances do not depend on the fold, so estimate them once instead of
    # once per fold. Source trials are mean-centred along the last axis first.
    covest = Covariance(estimator='lwf')
    Xs = np.asarray(Xs)
    source_covs = covest.transform(Xs - np.mean(Xs, axis=-1, keepdims=True))
    source_subjects = metas['subject'].to_numpy()

    # ---- target data: one load + feature extraction per subject -------------
    # These features use every trial of the target subject and are therefore
    # identical for all folds; the duration warning is emitted once per dataset
    # rather than once per subject and fold.
    same_dataset = source_dataset.dataset_code == target_dataset.dataset_code
    target_start = _mekt_trial_start(target_dataset, events, duration)
    target_event_id = [target_dataset.events[e][0] for e in events]
    target_sets = []
    for sub_id in target_dataset.subjects:
        target_paradigm = _mekt_paradigm(
            selected_channels, srate, target_start, duration, events)
        Xt, yt, _ = target_paradigm.get_data(
            target_dataset, subjects=[sub_id], return_concat=True, verbose=False)
        yt = label_encoder(yt, target_event_id)
        # NOTE(review): source trials are mean-centred before covariance
        # estimation but target trials are not. Kept as-is to preserve the
        # reported numbers -- confirm whether the asymmetry is intended.
        featureXt = _mekt_aligned_tangent_features(covest.transform(Xt))
        target_sets.append((sub_id, featureXt, yt))

    kfold_accs, kfold_cms = [], []
    set_random_seeds(42)

    for k in range(kfold):
        # Per-subject alignment of this fold's training trials.
        fold_X, fold_y, fold_subs = [], [], []
        for sub_id in source_dataset.subjects:
            sub_meta = metas[metas['subject'] == sub_id]
            train_ind, validate_ind, _ = match_kfold_indices(k, sub_meta, indices)
            # Validation trials are merged into the training pool; the test
            # split of the source data is not used.
            fit_ind = np.concatenate((train_ind, validate_ind))
            fold_X.append(_mekt_aligned_tangent_features(source_covs[fit_ind]))
            fold_y.append(ys[fit_ind])
            fold_subs.append(source_subjects[fit_ind])
        featureXs = np.concatenate(fold_X, axis=0)
        featureYs = np.concatenate(fold_y, axis=0)
        featureSubs = np.concatenate(fold_subs, axis=0)

        sub_accs, sub_cms = [], []
        for sub_id, featureXt, true_labels in target_sets:
            if same_dataset:
                # leave-one-subject-out: never train on the target subject
                keep = featureSubs != sub_id
                train_X, train_y = featureXs[keep], featureYs[keep]
            else:
                train_X, train_y = featureXs, featureYs

            mekt = MEKT()
            # Defensive copy: the cached target features are reused on every
            # fold, and whether fit_transform modifies its inputs is not
            # visible from here.
            source_features, target_features = mekt.fit_transform(
                train_X, train_y, np.copy(featureXt))
            clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')

            pred_labels = clf.fit(source_features, train_y).predict(target_features)
            sub_accs.append(balanced_accuracy_score(true_labels, pred_labels))
            sub_cms.append(confusion_matrix(
                true_labels, pred_labels,
                labels=np.unique(true_labels), normalize='true'))

        kfold_accs.append(sub_accs)
        kfold_cms.append(sub_cms)

    print("Source: {} Target: {} {} Acc: {:.4f}".format(
        source_dataset.dataset_code,
        target_dataset.dataset_code,
        'mekt',
        np.mean(kfold_accs)))
    joblib.dump(
        {'kfold_accs': kfold_accs, "kfold_cms": kfold_cms},
        save_path)

In [None]:
# Experiment configuration for the cross-dataset MEKT benchmark.
srate = 128  # forwarded to the paradigm as `srate` (presumably Hz)
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]

# Channel subset requested from every dataset.
selected_channels = [
    'FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 
    'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 
    'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 
    'P1', 'PZ', 'P2', 'POZ']
duration = 3  # window length added to each dataset's trial start
events = ['left_hand', 'right_hand']
kfold = 5
save_folder = 'mekt'
# True recomputes every pair even when a result file already exists in
# `save_folder`; set to False to reuse cached results.
force_update = True

# Evaluate every ordered (source, target) pair, including source == target.
for source_dataset in datasets:
    for target_dataset in datasets:
        mekt_transform_and_predict(
            source_dataset, target_dataset, selected_channels, srate, duration, events,
            kfold=kfold,  # previously configured above but never passed on
            save_folder=save_folder,
            force_update=force_update)

Source: bnci2014001 Target: bnci2014001 mekt Acc: 0.6979
Source: bnci2014001 Target: eegbci mekt Acc: 0.6358
Source: bnci2014001 Target: weibo2014 mekt Acc: 0.6863
Source: bnci2014001 Target: cho2017 mekt Acc: 0.5885
Source: eegbci Target: bnci2014001 mekt Acc: 0.7063
Source: eegbci Target: eegbci mekt Acc: 0.6640
Source: eegbci Target: weibo2014 mekt Acc: 0.6988
Source: eegbci Target: cho2017 mekt Acc: 0.6030
Source: weibo2014 Target: bnci2014001 mekt Acc: 0.6843
Source: weibo2014 Target: eegbci mekt Acc: 0.6372
Source: weibo2014 Target: weibo2014 mekt Acc: 0.6997
