In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt5

In [2]:
import warnings, copy, os
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import mne
import joblib

import matplotlib
matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", context="paper")


from brainda.datasets import (
    PhysionetMI, BNCI2014001, Weibo2014, Cho2017)

In [3]:
import itertools

import scipy
import scipy.stats

def compute_pvals_wilcoxon(scores, order=None):
    '''Returns kxk matrix of one-tailed p-values computed via the paired
    Wilcoxon signed-rank test (scipy.stats.wilcoxon)

    scores: 2-D array-like (NumPy array or DataFrame) of shape
    (num samples, num algorithms); rows are samples, columns are pipelines,
    and values are scores

    order: unused; kept so that callers passing it keep working

    Entry [i, j] is the p-value for "pipeline i scores higher than
    pipeline j". The diagonal is left at 0.

    '''
    # Coerce to an ndarray so DataFrame input works with the positional
    # [:, i] indexing below (a DataFrame raises on it).
    scores = np.asarray(scores)
    n_algo = scores.shape[1]
    out = np.zeros((n_algo, n_algo))
    for i in range(n_algo):
        for j in range(n_algo):
            if i == j:
                continue
            # two-sided p-value, halved because we want the one-tailed one
            p = scipy.stats.wilcoxon(scores[:, i], scores[:, j])[1] / 2
            # the sign of the mean difference decides which tail we are in
            if (scores[:, i] - scores[:, j]).mean() < 0:
                p = 1 - p  # was in the other side of the distribution
            out[i, j] = p
    return out


def _pairedttest_exhaustive(data):
    '''Returns p-values for exhaustive ttest that runs through all possible
    permutations of the first dimension. Very bad idea for size greater than 12

    data is a (subj, alg, alg) matrix of differences between scores for each
    pair of algorithms per subject

    '''
    out = np.ones((data.shape[1], data.shape[1]))
    true = data.sum(axis=0)
    nperms = 2**data.shape[0]
    for perm in itertools.product([-1, 1], repeat=data.shape[0]):
        # turn into numpy array
        perm = np.array(perm)
        # multiply permutation by subject dimension and sum over subjects
        randperm = (data * perm[:, None, None]).sum(axis=0)
        # compare to true difference (numpy autocasts bool to 0/1)
        out += (randperm > true)
    out = out / nperms
    # control for cases where pval is 1
    out[out == 1] = 1 - (1 / nperms)
    return out


def _pairedttest_random(data, nperms):
    '''Returns p-values based on nperms permutations of a paired ttest

    data is a (subj, alg, alg) matrix of differences between scores for each
    pair of algorithms per subject
    '''
    out = np.ones((data.shape[1], data.shape[1]))
    true = data.sum(axis=0)
    for i in range(nperms):
        perm = np.random.randint(2, size=(data.shape[0],))
        perm[perm == 0] = -1
        # multiply permutation by subject dimension and sum over subjects
        randperm = (data * perm[:, None, None]).sum(axis=0)
        # compare to true difference (numpy autocasts bool to 0/1)
        out += (randperm > true)
    out[out == nperms] = nperms - 1
    return out / nperms


def compute_pvals_perm(scores):
    '''Returns kxk matrix of p-values computed via a sign-flip permutation
    test on the paired score differences

    scores: 2-D array of shape (num samples, num algorithms); rows are
    samples, columns are pipelines, and values are scores

    With more than 13 samples, 10000 random permutations are drawn;
    otherwise every permutation is enumerated.

    '''
    n_sub = scores.shape[0]
    n_algo = scores.shape[1]
    # (sub, k, k) stack of pairwise differences; the diagonal stays at zero
    diffs = np.zeros((n_sub, n_algo, n_algo))
    for a, b in itertools.permutations(range(n_algo), 2):
        diffs[:, a, b] = scores[:, a] - scores[:, b]
    if n_sub > 13:
        return _pairedttest_random(diffs, 10000)
    return _pairedttest_exhaustive(diffs)

def compute_pvals(scores, perm_cutoff=20):
    '''Returns the kxk matrix of pairwise p-values for scores, choosing the
    test from the number of samples

    scores: (num samples, num algorithms) array of scores

    perm_cutoff: with fewer than this many samples the permutation test is
    used, otherwise the Wilcoxon test

    '''
    use_permutation = len(scores) < perm_cutoff
    test = compute_pvals_perm if use_permutation else compute_pvals_wilcoxon
    return test(scores)

## Within-subject Results

In [4]:
# Number of cross-validation folds.
# NOTE(review): not referenced by any other cell visible here -- presumably it
# mirrors the training run; confirm.
kfold = 5
# Class labels; the visible cells only use len(events), to build file names.
events = ['left_hand', 'right_hand']
# Datasets to analyse, one instance each.
datasets = [BNCI2014001(), PhysionetMI(), Weibo2014(), Cho2017()]
# model_names = ['shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4', 'tidnet']
# Model identifiers; each one is also the folder its result files are read from.
model_names = ['shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4']
# Preprocessing variants; each one is part of the result file name.
preprocess_methods = ['raw', 'cnorm', 'tnorm', 'euclid', 'riemann']
# Significance threshold handed to show_p: combined p-values above it are masked.
pval = 1e-3

In [5]:
# Empty long-format table: one row per (dataset, subject, model, method).
results = pd.DataFrame(columns=[
    'dataset', 'subject', 'accuracy', 'model', 'method',
])

In [6]:
# Load the within-subject accuracies of the runs without AdaBN.
# Rows are collected in a list and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0 and re-copies the whole table on
# every call.
use_adabn = False
name_suffix = '-adabn' if use_adabn else ''
rows = []
for model_name in model_names:
    for preprocess_method in preprocess_methods:
        for dataset in datasets:
            # within-subject: source and target are the same dataset
            source_dataset = dataset
            target_dataset = dataset
            save_folder = model_name
            model_name_str = '{}-{}'.format(model_name, preprocess_method)

            save_file = os.path.join(
                save_folder,
                "{}->{}-{}-{}classes.joblib".format(
                    source_dataset.dataset_code,
                    target_dataset.dataset_code,
                    model_name_str + name_suffix,
                    len(events)))
            kfold_accs = joblib.load(save_file)['kfold_accs']
            # average over axis 0 (presumably the folds -- confirm against the
            # training script); sub_accs is then indexed per subject
            sub_accs = np.mean(kfold_accs, 0)

            for i, sub_id in enumerate(dataset.subjects):
                rows.append({
                    'dataset': dataset.dataset_code,
                    'subject': sub_id,
                    'accuracy': sub_accs[i],
                    'model': model_name,
                    'method': preprocess_method + name_suffix,
                })

new_rows = pd.DataFrame(rows, columns=results.columns)
# Skip the concat while `results` is still empty so its placeholder object
# dtypes do not leak into the combined table.
results = new_rows if results.empty else pd.concat(
    [results, new_rows], ignore_index=True)

In [7]:
# Load the within-subject accuracies of the AdaBN runs; their rows get the
# '-adabn' suffix on the method name.
# Rows are collected in a list and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0 and re-copies the whole table on
# every call.
use_adabn = True
name_suffix = '-adabn' if use_adabn else ''
rows = []
for model_name in model_names:
    for preprocess_method in preprocess_methods:
        for dataset in datasets:
            # within-subject: source and target are the same dataset
            source_dataset = dataset
            target_dataset = dataset
            save_folder = model_name
            model_name_str = '{}-{}'.format(model_name, preprocess_method)

            save_file = os.path.join(
                save_folder,
                "{}->{}-{}-{}classes.joblib".format(
                    source_dataset.dataset_code,
                    target_dataset.dataset_code,
                    model_name_str + name_suffix,
                    len(events)))
            kfold_accs = joblib.load(save_file)['kfold_accs']
            # average over axis 0 (presumably the folds -- confirm against the
            # training script); sub_accs is then indexed per subject
            sub_accs = np.mean(kfold_accs, 0)

            for i, sub_id in enumerate(dataset.subjects):
                rows.append({
                    'dataset': dataset.dataset_code,
                    'subject': sub_id,
                    'accuracy': sub_accs[i],
                    'model': model_name,
                    'method': preprocess_method + name_suffix,
                })

new_rows = pd.DataFrame(rows, columns=results.columns)
# Skip the concat while `results` is still empty so its placeholder object
# dtypes do not leak into the combined table.
results = new_rows if results.empty else pd.concat(
    [results, new_rows], ignore_index=True)

In [8]:
dataset_name = 'bnci2014001'
# Mean accuracy per (model, method) on one dataset. The accuracy column is
# selected explicitly: pandas >= 2.0 raises when mean() meets the string
# columns instead of silently dropping them.
(results[results['dataset'] == dataset_name]
 .astype({'accuracy': float})
 .groupby(['model', 'method'])[['accuracy']]
 .mean())

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy
model,method,Unnamed: 2_level_1
eegnet,cnorm,0.828927
eegnet,cnorm-adabn,0.816981
eegnet,euclid,0.892105
eegnet,euclid-adabn,0.885454
eegnet,raw,0.827518
eegnet,raw-adabn,0.819116
eegnet,riemann,0.893952
eegnet,riemann-adabn,0.887411
eegnet,tnorm,0.824056
eegnet,tnorm-adabn,0.822085


In [9]:
# Bar plot of accuracy per model and method, one figure per dataset.
model_order = model_names
hue_order = ['raw', 'raw-adabn',
             'cnorm', 'cnorm-adabn',
             'tnorm', 'tnorm-adabn',
             'euclid', 'euclid-adabn',
             'riemann', 'riemann-adabn']
# names = ['ShallowNet', 'ShallowFBCSPNet', 'EEGNet', 'EEGNetv4', 'TIDNet']
# Display names for the x axis, in the same order as model_order.
names = ['ShallowNet', 'ShallowFBCSPNet', 'EEGNet', 'EEGNetv4']
os.makedirs('images', exist_ok=True)

for dataset in datasets:
    with sns.plotting_context('paper', font_scale=2):
        with sns.axes_style('whitegrid'):
            # NOTE(review): seaborn >= 0.12 deprecates ci= in favour of
            # errorbar= -- confirm against the installed version.
            g = sns.catplot(
                x='model',
                y='accuracy',
                hue='method',
                # follow model_order instead of repeating the list, so the
                # bars stay in sync with model_names
                order=model_order,
                hue_order=hue_order,
                data=results[results['dataset']==dataset.dataset_code],
                kind='bar',
                height=5, aspect=2, ci='sd',
                # one colour per hue level
                palette=sns.color_palette('Paired', n_colors=len(hue_order)), errcolor='#858585',
                edgecolor=".2", lw=1,capsize=0.02,
                zorder=5)
            g.despine(left=True)
            g.set_axis_labels(y_var='accuracy')
            g.set(ylim=(0.5, 1))
            g.set(yticks=np.arange(0.5, 1.05, 0.1))
            g._legend.set_title("")
            plt.setp(g._legend.get_texts(), fontsize='12')
            # Relabel the ticks on the grid's own axes rather than on whatever
            # pyplot currently considers the active axes.
            ax = g.axes.flat[0]
            xticks = ax.get_xticks()
            ax.set_xticks(xticks)
            ax.set_xticklabels(names)
            g.savefig("images/within-subject-{}.jpg".format(dataset.dataset_code), format='jpg', dpi=300)

significance analysis

In [10]:
def show_p(all_scores, pval=1e-3):
    '''Returns the kxk matrix of p-values combined over several experiments,
    with non-significant entries masked

    all_scores: sequence of experiments; each experiment is a sequence of k
    per-algorithm score vectors (one value per sample). k must be the same
    for every experiment.

    pval: significance threshold; combined p-values above it are replaced
    by NaN

    Per-experiment p-values come from compute_pvals and are combined with
    Stouffer's method, each experiment weighted by the square root of its
    number of samples. The diagonal is always NaN.

    '''
    Ps, n_subs = [], []
    for scores in all_scores:
        # (k, n_samples) -> (n_samples, k), the layout compute_pvals indexes
        scores = np.array(scores).T
        Ps.append(compute_pvals(scores))
        n_subs.append(len(scores))

    weights = np.sqrt(np.array(n_subs))
    Ps = np.array(Ps)

    n_algo = Ps.shape[1]
    P = np.zeros((n_algo, n_algo))
    for i in range(n_algo):
        for j in range(n_algo):
            P[i, j] = scipy.stats.combine_pvalues(
                Ps[:, i, j], weights=weights, method='stouffer')[1]
    # np.diag_indices takes (n, ndim); the default ndim=2 is what a matrix
    # needs. np.nan is used because the np.NaN alias is gone in NumPy 2.0.
    diag = np.diag_indices(n_algo)
    P[diag] = np.nan
    P[P > pval] = np.nan
    return P

In [11]:
# Every preprocessing variant, first without and then with AdaBN.
adabn_methods = ['{}-adabn'.format(name) for name in preprocess_methods]
methods = preprocess_methods + adabn_methods

# One entry per (dataset, model): the list of per-method accuracy vectors.
scores = []
for dataset in datasets:
    in_dataset = results['dataset'] == dataset.dataset_code
    for model_name in model_names:
        in_model = results['model'] == model_name
        inner_scores = [
            results.loc[
                in_dataset & in_model & (results['method'] == method),
                'accuracy'].to_numpy()
            for method in methods
        ]
        scores.append(inner_scores)

P = show_p(scores, pval=pval)

In [12]:
# Heatmap of the combined p-values; cells masked to NaN by show_p stay blank.
with sns.plotting_context('paper', font_scale=2):
    with sns.axes_style('darkgrid'):
        f, ax = plt.subplots(figsize=(9, 9))
        sns.heatmap(P, vmin=0, vmax=1,
            annot=False, fmt="1.0e", linewidths=0.1, ax=ax, linecolor='k',
            cmap=sns.color_palette('Reds_r'), cbar=False, square=True)

        ax.xaxis.tick_top()
        xticks = ax.get_xticks()
        plt.xticks(xticks, methods, rotation=45, ha='left')
        plt.yticks(ax.get_yticks(), methods, rotation=0)
        ax.tick_params(axis='both', which='both', length=0)

        # Pin the y-limits to the full grid (row 0 on top). The previous
        # "grow the current limits by 0.5" adjustment only gave this result
        # when the first and last rows were clipped in half, and added empty
        # half-rows otherwise.
        ax.set_ylim(P.shape[0], 0)
        ax.xaxis.set_label_position('top')
        plt.tight_layout()
        plt.savefig("images/within_subject_pvalues_single_tail.jpg", format='jpg', dpi=300)

In [13]:
# EEGNet accuracies on the 'eegbci' dataset, one vector per method, in the
# order raw, euclid, riemann.
eegbci_eegnet = results[
    (results['dataset']=='eegbci') & (results['model']=='eegnet')]
scores = [
    eegbci_eegnet.loc[eegbci_eegnet['method']==method, 'accuracy'].to_numpy()
    for method in ['raw', 'euclid', 'riemann']
]

In [14]:
# a / b / c: raw, euclid and riemann accuracies, in the order the scores
# list was built
a, b, c = scores[0], scores[1], scores[2]
# rows where raw beats at least one of the other two and is above 0.5
raw_wins = np.logical_or(a > b, a > c)
ind = np.logical_and(raw_wins, a > 0.5)
# row position -> subject id
# (assumes subject ids run 1..N in row order -- TODO confirm)
outliers = np.flatnonzero(ind) + 1
print("Number of outliers: {}".format(len(outliers)))
is_eegnet_on_eegbci = (results['dataset']=='eegbci') & (results['model']=='eegnet')
data = results[
    is_eegnet_on_eegbci & results['method'].isin(['raw', 'riemann', 'euclid'])]

with sns.plotting_context('paper', font_scale=1.5):
    with sns.axes_style('whitegrid'):
        with sns.color_palette('Paired'):
            g = sns.catplot(
                x='subject', y='accuracy', hue='method',
                data=data[data['subject'].isin(outliers)],
                kind='bar', height=5, aspect=2.3, ci='sd',
                palette=sns.color_palette("tab10"))
            g.despine(left=True)
            g.set_axis_labels(y_var='accuracy')
            g.set(ylim=(0, 1))
            g.set(yticks=np.arange(0, 1.1, 0.1))
            g._legend.set_title("")
            g.savefig("images/phsionet_within_subject_outliers1.jpg", format='jpg', dpi=300)

Number of outliers: 35


In [15]:
# rows where the raw accuracy is below 0.5
ind = a < 0.5

# row position -> subject id
# (assumes subject ids run 1..N in row order -- TODO confirm)
outliers = np.flatnonzero(ind) + 1
print("Number of outliers: {}".format(len(outliers)))
is_eegnet_on_eegbci = (results['dataset']=='eegbci') & (results['model']=='eegnet')
data = results[
    is_eegnet_on_eegbci & results['method'].isin(['raw', 'riemann', 'euclid'])]

with sns.plotting_context('paper', font_scale=1.5):
    with sns.axes_style('whitegrid'):
        with sns.color_palette('Paired'):
            g = sns.catplot(
                x='subject', y='accuracy', hue='method',
                data=data[data['subject'].isin(outliers)],
                kind='bar', height=5, aspect=0.8, ci='sd', legend=False,
                palette=sns.color_palette("tab10"))
            g.despine(left=True)
            g.set_axis_labels(y_var='accuracy')
            g.set(ylim=(0, 1))
            g.set(yticks=np.arange(0, 1.1, 0.1))
            g.savefig("images/phsionet_within_subject_outliers2.jpg", format='jpg', dpi=300)

Number of outliers: 5


## Cross-subject Results

In [16]:
# Cross-dataset table: one row per (source dataset, target dataset, subject,
# model, method). NOTE: this rebinds `results`, replacing the within-subject
# table built earlier in the notebook.
results = pd.DataFrame(
    columns=['source_dataset', 'target_dataset', 'subject', 'accuracy', 'model', 'method'])
# Per-file frames are gathered and concatenated once; DataFrame.append was
# removed in pandas 2.0 and re-copies the table on every call.
frames = []
for model_name in model_names:
    for method in methods:
        for source_dataset in datasets:
            for target_dataset in datasets:
                # only pairs of different datasets are loaded here
                if source_dataset.dataset_code == target_dataset.dataset_code:
                    continue
                save_folder = model_name
                model_name_str = '{}-{}'.format(model_name, method)

                save_file = os.path.join(
                    save_folder,
                    "{}->{}-{}-{}classes.joblib".format(
                        source_dataset.dataset_code,
                        target_dataset.dataset_code,
                        model_name_str,
                        len(events)))
                kfold_accs = joblib.load(save_file)['kfold_accs']
                # average over axis 0 (presumably the folds -- confirm); one
                # value per target subject is expected to remain
                sub_accs = np.mean(kfold_accs, axis=0)

                frames.append(pd.DataFrame.from_dict({
                    'source_dataset': source_dataset.dataset_code,
                    'target_dataset': target_dataset.dataset_code,
                    'subject': target_dataset.subjects,
                    'accuracy': sub_accs,
                    'model': model_name,
                    'method': method,
                }))

# keep the empty, correctly-labelled table when nothing was loaded
if frames:
    results = pd.concat(frames, ignore_index=True)

In [17]:
# Mean accuracy per (target dataset, method) for one model trained on one
# source dataset. NOTE: source_dataset is a dataset code string here, not a
# dataset object as in the loading loops.
model_name = 'eegnetv4'
source_dataset = 'eegbci'

tmp = results[(results['model']==model_name)&(results['source_dataset']==source_dataset)]
# The accuracy column is selected explicitly: pandas >= 2.0 raises when
# mean() meets the string columns instead of silently dropping them.
(tmp
 .astype({'accuracy': float})
 .groupby(['target_dataset', 'method'])[['accuracy']]
 .mean())

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy
target_dataset,method,Unnamed: 2_level_1
bnci2014001,cnorm,0.746373
bnci2014001,cnorm-adabn,0.752623
bnci2014001,euclid,0.784568
bnci2014001,euclid-adabn,0.788272
bnci2014001,raw,0.746065
bnci2014001,raw-adabn,0.768904
bnci2014001,riemann,0.779938
bnci2014001,riemann-adabn,0.786651
bnci2014001,tnorm,0.739043
bnci2014001,tnorm-adabn,0.753858


significance analysis

In [18]:
# One entry per (source dataset, model): the list of per-method accuracy
# vectors.
scores = []
for dataset in datasets:
    from_source = results['source_dataset'] == dataset.dataset_code
    for model_name in model_names:
        in_model = results['model'] == model_name
        inner_scores = [
            results.loc[
                from_source & in_model & (results['method'] == method),
                'accuracy'].to_numpy()
            for method in methods
        ]
        scores.append(inner_scores)

P = show_p(scores, pval=pval)

In [19]:
# Heatmap of the combined p-values; cells masked to NaN by show_p stay blank.
with sns.plotting_context('paper', font_scale=2):
    with sns.axes_style('darkgrid'):
        f, ax = plt.subplots(figsize=(9, 9))
        sns.heatmap(P, vmin=0, vmax=1,
            annot=False, fmt="1.0e", linewidths=0.1, ax=ax, linecolor='k',
            cmap=sns.color_palette('Reds_r'), cbar=False, square=True)

        ax.xaxis.tick_top()
        xticks = ax.get_xticks()
        plt.xticks(xticks, methods, rotation=45, ha='left')
        plt.yticks(ax.get_yticks(), methods, rotation=0)
        ax.tick_params(axis='both', which='both', length=0)

        # Pin the y-limits to the full grid (row 0 on top). The previous
        # "grow the current limits by 0.5" adjustment only gave this result
        # when the first and last rows were clipped in half, and added empty
        # half-rows otherwise.
        ax.set_ylim(P.shape[0], 0)
        ax.xaxis.set_label_position('top')
        plt.tight_layout()
        plt.savefig("images/cross_dataset_pvalues_single_tail.jpg", format='jpg', dpi=300)

model analyses

In [20]:
from brainda.algorithms.deep_learning import EEGNet

# Template EEGNet; one copy per dataset receives that dataset's saved weights.
# NOTE(review): the hyper-parameters presumably have to match the ones used
# when the checkpoints were trained -- confirm.
model = EEGNet(22, 384, 2,
            time_kernel=(8, (1, 64), (1, 1)),
            D=2,
            pool_kernel1=((1, 4), (1, 4)),
            separa_kernel=(16, (1, 16), (1, 1)),
            pool_kernel2=((1, 8), (1, 8)),
            depthwise_norm_rate=1,
            fc_norm_rate=0.25,
            dropout_rate=0.5)

models = [copy.deepcopy(model) for _ in range(len(datasets))]

save_folder = 'eegnet'
model_name = 'eegnet-{}'.format('raw')
for i, source in enumerate(datasets):
    file_name = "{}-{}-2classes.joblib".format(source.dataset_code, model_name)
    model_file = os.path.join(save_folder, file_name)
    model_states = joblib.load(model_file)['model_states']
    # only the first stored state is loaded
    models[i].load_state_dict(copy.deepcopy(model_states[0]))

# temporal-convolution kernels, stacked per dataset
weights = np.stack(
    [np.squeeze(models[i].step1.time_conv.weight.detach().numpy()) for i in range(len(datasets))
    ])

# Block matrix of correlations between the kernels of every pair of datasets.
Rs = []
for i in range(len(weights)):
    R = []
    for j in range(len(weights)):
        # corrcoef stacks both kernel sets; keep only the cross block
        # (rows: kernels of dataset i, columns: kernels of dataset j)
        R.append(np.corrcoef(weights[i], weights[j])[:len(weights[i]), len(weights[i]):len(weights[i])+len(weights[j])])
    R = np.concatenate(R, axis=-1)
    Rs.append(R)

Rs = np.concatenate(Rs, axis=0)

# Blank out weak correlations (|r| < level) and the diagonal.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
level = 0.5
Rs[np.logical_and(Rs<level, Rs>-level)] = np.nan
ix = np.diag_indices(len(Rs))
Rs[ix[0], ix[1]] = np.nan

In [22]:
import matplotlib.patches as patches

# Heatmap of the kernel-correlation matrix Rs with dataset separators and
# three highlighted rows, saved to images/.
with sns.plotting_context('paper', font_scale=2):
    with sns.axes_style(None):
        f, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(Rs, vmin=-1, vmax=1, 
            annot=False, fmt=".2f", linewidths=0.5, linecolor='#dbdbdb', ax=ax, square=True,
            cmap=sns.color_palette('RdBu_r', n_colors=20), cbar=True)

        ax.xaxis.tick_top()
        # 32 tick positions, with one dataset name at positions 4, 12, 20, 28
        # (assumes Rs is 32x32: 4 datasets x 8 kernels each -- TODO confirm)
        labels = ['' for _ in range(32)]
        labels[4] = 'BNCI2014001'
        labels[12] = 'PhysionetMI'
        labels[20] = 'Weibo2014'
        labels[28] = 'Cho2017'
        plt.xticks(ticks=np.arange(0, 32), labels=labels, rotation=0)
        plt.yticks(ticks=np.arange(0, 32), labels=labels, rotation=0)

        ax.tick_params(axis='both', which='both', length=0)
        # vertical separators at 8, 16 and 24
        ax.plot([8, 8], [0, 32], color='k')
        ax.plot([16, 16], [0, 32], color='k')
        ax.plot([24, 24], [0, 32], color='k')

        # horizontal separators at 8, 16 and 24
        ax.plot([0, 32], [8, 8], color='k')
        ax.plot([0, 32], [16, 16], color='k')
        ax.plot([0, 32], [24, 24], color='k')
        # NOTE(review): hand-tuned shift of the current y-limits; the result
        # depends on the limits matplotlib/seaborn set beforehand -- confirm
        # it still frames the grid with the installed versions.
        b, t = plt.ylim()
        b += 1
        t -= 0.2
        plt.ylim(b, t)
        plt.tight_layout()
    
        # translucent red bands over rows 1, 4 and 6, from x=7.8 to x=32
        rect = patches.Rectangle((7.8, 1), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)
        rect = patches.Rectangle((7.8, 4), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)
        rect = patches.Rectangle((7.8, 6), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)       
    
        plt.savefig("images/eegnet_raw_temporal_weight_correlation.jpg", format='jpg', dpi=300)

In [23]:
from scipy.signal import freqz

# inds[r][d]: index of the temporal kernel of dataset d drawn in panel r
# (None would skip that dataset).
inds = [
    [1, 2, 0, 0],
    [4, 6, 0, 2],
    [6, 5, 2, 0]]

srate = 128
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
# Dataset display names (typo 'PhysioneMI' fixed to match the other figures).
# NOTE: this rebinds `names`, which the within-subject bar-plot cell uses for
# model display names; re-run that cell from its start before re-plotting it.
names = ['BNCI2014001', 'PhysionetMI', 'Weibo2014', 'Cho2017']

fig, axes = plt.subplots(len(inds), 1,figsize=(12, 9))
for iind in range(len(inds)):
    for i, ix in enumerate(inds[iind]):
        if ix is None:
            continue
        # frequency response of kernel ix of dataset i
        w, h = freqz(weights[i][ix], fs=srate)
        # indices of the frequency bins closest to 4 Hz and 40 Hz
        ia, ib = np.argmin(np.abs(w-4)), np.argmin(np.abs(w-40))
        axes[iind].plot(w[ia:ib], 20 * np.log10(abs(h[ia:ib])), label='{}-{}'.format(names[i], ix), alpha=0.8, linewidth=2)
    axes[iind].set_ylabel('Amplitude [dB]')
    axes[iind].set_xlabel('Frequency [Hz]')
    for item in ([axes[iind].title, axes[iind].xaxis.label, axes[iind].yaxis.label] +
                 axes[iind].get_xticklabels() + axes[iind].get_yticklabels()):
        item.set_fontsize(15)
    axes[iind].legend(fontsize=12, loc='lower left')
    axes[iind].set_xlim((4, 40))
    axes[iind].set_ylim((-60, 30))
# lay out once, after every panel is complete
fig.tight_layout()
fig.savefig("images/eegnet_raw_freqz.jpg", format='jpg', dpi=300)

In [24]:
# The sensor layout is identical for every figure, so it is built once.
montage = mne.channels.make_standard_montage('standard_1005')
info = mne.create_info(ch_names=selected_channels, sfreq=250, ch_types='eeg')
info.set_montage(montage, match_case=False)

plt.rcParams.update({'font.size': 10})

# One figure per row of inds: a column per dataset, two topomaps per column.
for iind in range(len(inds)):
    fig, axes = plt.subplots(2, 4, figsize=(12, 4))

    for i, ix in enumerate(inds[iind]):
        if ix is None:
            # hide the unused column instead of leaving empty axes
            axes[0, i].axis('off')
            axes[1, i].axis('off')
            continue
        spatial_weights = np.squeeze(models[i].step2.depthwise_conv.weight.detach().numpy())
        axes[0, i].set_title("{}-{}".format(names[i], ix), fontsize=16)
        # rows 2*ix and 2*ix+1 are drawn for temporal kernel ix
        mne.viz.plot_topomap(
            spatial_weights[2*ix], info, show=True, names=selected_channels, show_names=True,
            contours=False, extrapolate='head', axes=axes[0, i])
        mne.viz.plot_topomap(
            spatial_weights[2*ix+1], info, show=True, names=selected_channels, show_names=True,
            contours=False, extrapolate='head', axes=axes[1, i])
    # lay out and save once per figure, after all of its columns are drawn
    fig.tight_layout()
    fig.savefig("images/eegnet_raw_topomap_{}.jpg".format(iind), format='jpg', dpi=300)

In [25]:
from brainda.algorithms.deep_learning import EEGNet

# Template EEGNet; one copy per dataset receives that dataset's saved weights.
# NOTE(review): the hyper-parameters presumably have to match the ones used
# when the checkpoints were trained -- confirm.
model = EEGNet(22, 384, 2,
            time_kernel=(8, (1, 64), (1, 1)),
            D=2,
            pool_kernel1=((1, 4), (1, 4)),
            separa_kernel=(16, (1, 16), (1, 1)),
            pool_kernel2=((1, 8), (1, 8)),
            depthwise_norm_rate=1,
            fc_norm_rate=0.25,
            dropout_rate=0.5)

models = [copy.deepcopy(model) for _ in range(len(datasets))]

save_folder = 'eegnet'
model_name = 'eegnet-{}'.format('euclid')
for i, source in enumerate(datasets):
    file_name = "{}-{}-2classes.joblib".format(source.dataset_code, model_name)
    model_file = os.path.join(save_folder, file_name)
    model_states = joblib.load(model_file)['model_states']
    # only the first stored state is loaded
    models[i].load_state_dict(copy.deepcopy(model_states[0]))

# temporal-convolution kernels, stacked per dataset
weights = np.stack(
    [np.squeeze(models[i].step1.time_conv.weight.detach().numpy()) for i in range(len(datasets))
    ])

# Block matrix of correlations between the kernels of every pair of datasets.
Rs = []
for i in range(len(weights)):
    R = []
    for j in range(len(weights)):
        # corrcoef stacks both kernel sets; keep only the cross block
        # (rows: kernels of dataset i, columns: kernels of dataset j)
        R.append(np.corrcoef(weights[i], weights[j])[:len(weights[i]), len(weights[i]):len(weights[i])+len(weights[j])])
    R = np.concatenate(R, axis=-1)
    Rs.append(R)

Rs = np.concatenate(Rs, axis=0)

# Blank out weak correlations (|r| < level) and the diagonal.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
level = 0.5
Rs[np.logical_and(Rs<level, Rs>-level)] = np.nan
ix = np.diag_indices(len(Rs))
Rs[ix[0], ix[1]] = np.nan

In [26]:
import matplotlib.patches as patches

# Heatmap of the kernel-correlation matrix Rs with dataset separators and
# three highlighted rows, saved to images/.
with sns.plotting_context('paper', font_scale=1.7):
    with sns.axes_style(None):
        f, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(Rs, vmin=-1, vmax=1, 
            annot=False, fmt=".2f", linewidths=0.5, linecolor='#dbdbdb', ax=ax, square=True,
            cmap=sns.color_palette('RdBu_r', n_colors=20), cbar=True)

        ax.xaxis.tick_top()
        # 32 tick positions, with one dataset name at positions 4, 12, 20, 28
        # (assumes Rs is 32x32: 4 datasets x 8 kernels each -- TODO confirm)
        labels = ['' for _ in range(32)]
        labels[4] = 'BNCI2014001'
        labels[12] = 'PhysionetMI'
        labels[20] = 'Weibo2014'
        labels[28] = 'Cho2017'
        plt.xticks(ticks=np.arange(0, 32), labels=labels, rotation=0)
        plt.yticks(ticks=np.arange(0, 32), labels=labels, rotation=0)
        ax.tick_params(axis='both', which='both', length=0)
        # vertical separators at 8, 16 and 24
        ax.plot([8, 8], [0, 32], color='k')
        ax.plot([16, 16], [0, 32], color='k')
        ax.plot([24, 24], [0, 32], color='k')

        # horizontal separators at 8, 16 and 24
        ax.plot([0, 32], [8, 8], color='k')
        ax.plot([0, 32], [16, 16], color='k')
        ax.plot([0, 32], [24, 24], color='k')
        # NOTE(review): hand-tuned shift of the current y-limits; the result
        # depends on the limits matplotlib/seaborn set beforehand -- confirm
        # it still frames the grid with the installed versions.
        b, t = plt.ylim()
        b += 1
        t -= 0.2
        plt.ylim(b, t)
#         ax.xaxis.set_label_position('top')
        plt.tight_layout()  
    
        # translucent red bands over rows 0, 2 and 3, from x=7.8 to x=32
        rect = patches.Rectangle((7.8, 0), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)
        rect = patches.Rectangle((7.8, 2), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)
        rect = patches.Rectangle((7.8, 3), 24.2, 1, linewidth=1.5, edgecolor='none', facecolor='r', alpha=0.2)
        ax.add_patch(rect)
    
        plt.savefig("images/eegnet_euclid_temporal_weight_correlation.jpg", format='jpg', dpi=300)

In [27]:
from scipy.signal import freqz

# One row per figure panel. Entry k of a row is an index into weights[k]
# (drawn with the label names[k]); None means nothing is drawn for that entry.
inds = [
    [0, 7, 4, 7],
    [2, 2, None, 1],
    [3, 4, None, 2]]

# Sampling rate handed to freqz so that the frequency axis is in Hz.
srate = 128
selected_channels = ['FZ', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'PZ', 'P2', 'POZ']
# 'PhysionetMI' was previously misspelled 'PhysioneMI'; the string is shown
# in the legends below and in any later cell that reuses `names`.
names = ['BNCI2014001', 'PhysionetMI', 'Weibo2014', 'Cho2017']

fig, axes = plt.subplots(len(inds), 1,figsize=(12, 9))
for iind, row in enumerate(inds):
    panel = axes[iind]
    for i, ix in enumerate(row):
        if ix is None:
            continue
        # Frequency response of the selected weight vector.
        w, h = freqz(weights[i][ix], fs=srate)
        # Only the bins nearest to 4 Hz .. 40 Hz are drawn.
        ia, ib = np.argmin(np.abs(w-4)), np.argmin(np.abs(w-40))
        panel.plot(w[ia:ib], 20 * np.log10(abs(h[ia:ib])), label='{}-{}'.format(names[i], ix), alpha=0.8, linewidth=2)
    panel.set_ylabel('Amplitude [dB]')
    panel.set_xlabel('Frequency [Hz]')
    for item in ([panel.title, panel.xaxis.label, panel.yaxis.label] +
                 panel.get_xticklabels() + panel.get_yticklabels()):
        item.set_fontsize(15)
    panel.legend(fontsize=12, loc='lower left')
    panel.set_xlim((4, 40))
    panel.set_ylim((-60, 30))

# Lay the figure out once, after every panel is complete, then save it.
plt.tight_layout()
plt.savefig("images/eegnet_euclid_freqz.jpg", format='jpg', dpi=300)

In [28]:
# Channel layout shared by all topomaps; it does not depend on the loop
# variables, so it is built once.
montage = mne.channels.make_standard_montage('standard_1005')
info = mne.create_info(ch_names=selected_channels, sfreq=250, ch_types='eeg')
info.set_montage(montage, match_case=False)

plt.rcParams.update({'font.size': 10})

# One 2x4 figure per row of `inds`: column i belongs to names[i]; the two
# rows show spatial_weights[2*ix] and spatial_weights[2*ix+1].
for iind in range(len(inds)):
    fig, axes = plt.subplots(2, 4, figsize=(12, 4))

    for i, ix in enumerate(inds[iind]):
        if ix is None:
            # Nothing selected for this column: hide both panels.
            axes[0, i].axis('off')
            axes[1, i].axis('off')
            continue
        spatial_weights = np.squeeze(models[i].step2.depthwise_conv.weight.detach().numpy())
        axes[0, i].set_title("{}-{}".format(names[i], ix), fontsize=16)
        for panel_row in range(2):
            mne.viz.plot_topomap(
                spatial_weights[2*ix+panel_row], info, show=True, names=selected_channels, show_names=True,
                contours=False, extrapolate='head', axes=axes[panel_row, i])

    # Lay out and write each figure exactly once, after all of its columns
    # are drawn. Saving inside the column loop rewrote the same file for
    # every column and would skip the final state if the last column was None.
    plt.tight_layout()
    plt.savefig("images/eegnet_euclid_topomap_{}.jpg".format(iind), format='jpg', dpi=300)

## MEKT and Networks

In [29]:
# Load the saved MEKT accuracies for every source -> target dataset pair and
# add them to `results`.
#
# The per-pair frames are collected in a list and concatenated once:
# DataFrame.append was removed in pandas 2.0, and appending inside the loop
# copies the whole table on every iteration.
# NOTE: re-running this cell adds the same rows to `results` again.
save_folder = 'mekt'
mekt_frames = []
for source_dataset in datasets:
    for target_dataset in datasets:
        if source_dataset.dataset_code == target_dataset.dataset_code:
            continue
        save_file = os.path.join(
            save_folder,
            "{}->{}-{}-{}classes.joblib".format(
                source_dataset.dataset_code,
                target_dataset.dataset_code,
                'mekt',
                len(events)))
        # NOTE(review): joblib.load unpickles the file; only load result
        # files that were produced locally.
        kfold_accs = joblib.load(save_file)['kfold_accs']
        # Mean over axis 0 -- presumably the folds, leaving one accuracy per
        # target subject; TODO confirm against the code that wrote the file.
        sub_accs = np.mean(kfold_accs, axis=0)

        mekt_frames.append(pd.DataFrame.from_dict({
            'source_dataset': source_dataset.dataset_code,
            'target_dataset': target_dataset.dataset_code,
            'subject': target_dataset.subjects,
            'accuracy': sub_accs,
            'model': 'mekt',
            'method': '',
        }))

results = pd.concat([results] + mekt_frames, ignore_index=True)

In [30]:
def _mean_accuracy_matrix(results, datasets, model_name, method=None):
    """Return the (source, target) matrix of mean accuracies for one model.

    Parameters
    ----------
    results : DataFrame with 'source_dataset', 'target_dataset', 'model',
        'method' and 'accuracy' columns.
    datasets : sequence of objects exposing ``dataset_code``; it defines the
        row (source) and column (target) order.
    model_name : value matched against the 'model' column.
    method : value matched against the 'method' column, or None to ignore
        that column.

    Returns
    -------
    ndarray of shape (len(datasets), len(datasets)). Entry [i, j] is the mean
    accuracy of the rows selected for source i and target j. The diagonal
    (source == target) is NaN, as is any pair without matching rows.
    """
    score_mat = np.full((len(datasets), len(datasets)), np.nan)
    for i, source_dataset in enumerate(datasets):
        for j, target_dataset in enumerate(datasets):
            if source_dataset.dataset_code == target_dataset.dataset_code:
                # Same-dataset transfer is not part of the comparison.
                continue
            selected = (
                (results['source_dataset'] == source_dataset.dataset_code)
                & (results['target_dataset'] == target_dataset.dataset_code)
                & (results['model'] == model_name))
            if method is not None:
                selected = selected & (results['method'] == method)
            score_mat[i, j] = np.mean(results[selected].accuracy.to_numpy())
    return score_mat


# model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4', 'tidnet']
model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4']

# Layout of score_mats along axis 0:
#   [0]                      MEKT (no method filter)
#   [1 : len(model_names)]   each network with 'euclid-adabn'
#   [len(model_names) : ]    each network with 'riemann-adabn'
score_mats = [_mean_accuracy_matrix(results, datasets, 'mekt')]
for adabn_method in ('euclid-adabn', 'riemann-adabn'):
    for model_name in model_names[1:]:
        score_mats.append(
            _mean_accuracy_matrix(results, datasets, model_name, adabn_method))

score_mats = np.array(score_mats)

In [31]:
names = ['BNCI2014001', 'PhysionetMI', 'Weibo2014', 'Cho2017']
model_name_strs = ['ShallowNet', 'ShallowFBCSPNet', 'EEGNet', 'EEGNetv4']

with sns.plotting_context('paper', font_scale=2), sns.axes_style('darkgrid'):
    # Figure 1: MEKT accuracies (score_mats[0]) in percent.
    fig, axs = plt.subplots(1, 1, figsize=(6, 6))
    sns.heatmap(
        score_mats[0]*100, vmin=50, vmax=100,
        annot=True, fmt=".1f", linewidths=0.1, ax=axs,
        cmap=sns.color_palette('Blues', n_colors=10), cbar=False, square=True)
    axs.xaxis.tick_top()
    axs.tick_params(axis='both', which='both', length=0)
    plt.xticks(axs.get_xticks(), names, rotation=45, ha='left')
    plt.yticks(axs.get_yticks(), names, rotation=0)
    # Widen the y-range by half a cell at both ends.
    b, t = axs.get_ylim()
    axs.set_ylim(b + 0.5, t - 0.5)
    axs.set_title('MEKT', loc='center')
    plt.tight_layout()
    plt.savefig("images/mekt_cross_dataset_accuracy.jpg", format='jpg', dpi=300)

    # Figure 2: difference to MEKT in percentage points, one column per
    # network; top row labelled 'EA', bottom row labelled 'RA'.
    fig, axs = plt.subplots(2, 4, figsize=(16, 8))
    vmax = 10
    # (grid row, y-label, offset of the first matching matrix in score_mats)
    row_specs = [(0, 'EA', 1), (1, 'RA', len(model_names))]
    for i in range(len(model_names[1:])):
        for grid_row, row_label, offset in row_specs:
            panel = axs[grid_row, i]
            sns.heatmap(
                score_mats[i+offset]*100-score_mats[0]*100,
                vmin=-vmax, vmax=vmax,
                annot=True, fmt=".1f", linewidths=0.1, ax=panel,
                cmap=sns.color_palette('RdBu_r', n_colors=10), cbar=False, square=True)
            b, t = panel.get_ylim()
            panel.set_ylim(b + 0.5, t - 0.5)
            if grid_row == 0:
                # Only the top row carries the network name.
                panel.set_title(model_name_strs[i], loc='center')
            panel.get_xaxis().set_visible(False)
            # The y-axis stays visible in the first column only, where it
            # carries the row label and no ticks.
            panel.get_yaxis().set_visible(i == 0)
            if i == 0:
                panel.set_ylabel(row_label)
                panel.set_yticks([])
    plt.tight_layout()
    plt.savefig("images/network_cross_dataset_accuracy.jpg", format='jpg', dpi=300)

In [33]:
# model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4', 'tidnet']
model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4']
methods = ['euclid-adabn', 'riemann-adabn']

# scores[d] lists, for source dataset d, one accuracy array per pipeline:
# MEKT first (any method), then each network once per entry of `methods`.
scores = []
for dataset in datasets:
    from_source = results['source_dataset'] == dataset.dataset_code
    inner_scores = []
    for model_name in model_names:
        of_model = from_source & (results['model'] == model_name)
        if model_name == 'mekt':
            inner_scores.append(results[of_model].accuracy.to_numpy())
            continue
        for method in methods:
            with_method = of_model & (results['method'] == method)
            inner_scores.append(results[with_method].accuracy.to_numpy())
    scores.append(inner_scores)

# show_p and pval are defined elsewhere in the notebook.
P = show_p(scores, pval=pval)

In [34]:
# Tick labels for the rows and columns of P.
names = ['MEKT']
for net_label in ('ShallowNet', 'ShallowFBCSPNet', 'EEGNet', 'EEGNetv4'):
    names += ['EA-' + net_label, 'RA-' + net_label]

with sns.plotting_context('paper', font_scale=2), sns.axes_style('darkgrid'):
    f, ax = plt.subplots(figsize=(9, 9))
    sns.heatmap(
        P, vmin=0, vmax=1,
        annot=False, fmt="1.0e", linewidths=0.1, linecolor='k',
        ax=ax, square=True,
        cmap=sns.color_palette('Reds_r'), cbar=False)

    # Column labels go on top, slanted; row labels stay horizontal.
    ax.xaxis.tick_top()
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(names, rotation=45, ha='left')
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(names, rotation=0)
    ax.tick_params(axis='both', which='both', length=0)

    # Widen the y-range by half a cell at both ends.
    b, t = ax.get_ylim()
    ax.set_ylim(b + 0.5, t - 0.5)
    ax.xaxis.set_label_position('top')
    f.tight_layout()
    f.savefig("images/network_cross_dataset_pvalues_single_tail.jpg", format='jpg', dpi=300)

### Fusing MEKT and Networks

In [35]:
# Load the saved accuracies of the network + MEKT fusion runs and add them to
# `results` with method '<euclid|riemann>-adabn-mekt'.
#
# The per-run frames are collected in a list and concatenated once:
# DataFrame.append was removed in pandas 2.0, and appending inside the loop
# copies the whole table on every iteration.
# NOTE: re-running this cell adds the same rows to `results` again.
model_names = ['eegnet', 'eegnetv4']
methods = ['euclid', 'riemann']

fused_frames = []
for model_name in model_names:
    for method in methods:
        for source_dataset in datasets:
            for target_dataset in datasets:
                if source_dataset.dataset_code == target_dataset.dataset_code:
                    continue
                # Result files are stored in a folder named after the model.
                save_folder = model_name
                model_name_str = '{}-{}-adabn-mekt'.format(model_name, method)

                save_file = os.path.join(
                    save_folder,
                    "{}->{}-{}-{}classes.joblib".format(
                        source_dataset.dataset_code,
                        target_dataset.dataset_code,
                        model_name_str,
                        len(events)))
                # NOTE(review): joblib.load unpickles the file; only load
                # result files that were produced locally.
                kfold_accs = joblib.load(save_file)['kfold_accs']
                # Mean over axis 0 -- presumably the folds, leaving one
                # accuracy per target subject; TODO confirm.
                sub_accs = np.mean(kfold_accs, axis=0)

                fused_frames.append(pd.DataFrame.from_dict({
                    'source_dataset': source_dataset.dataset_code,
                    'target_dataset': target_dataset.dataset_code,
                    'subject': target_dataset.subjects,
                    'accuracy': sub_accs,
                    'model': model_name,
                    'method': '{}-adabn-mekt'.format(method),
                }))

results = pd.concat([results] + fused_frames, ignore_index=True)

In [38]:
# Mean accuracy per (target dataset, method) for one model trained on one
# source dataset.
model_name = 'eegnet'
source_dataset = 'weibo2014'

tmp = results[(results['model']==model_name)&(results['source_dataset']==source_dataset)]
# Average only the 'accuracy' column. On pandas >= 2.0 a bare .mean() no
# longer drops the string columns ('source_dataset', 'model') silently and
# raises TypeError instead.
tmp.groupby(['target_dataset', 'method'])[['accuracy']].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy
target_dataset,method,Unnamed: 2_level_1
bnci2014001,cnorm,0.639815
bnci2014001,cnorm-adabn,0.650849
bnci2014001,euclid,0.691127
bnci2014001,euclid-adabn,0.696451
bnci2014001,euclid-adabn-mekt,0.703009
bnci2014001,raw,0.649846
bnci2014001,raw-adabn,0.667515
bnci2014001,riemann,0.690586
bnci2014001,riemann-adabn,0.694985
bnci2014001,riemann-adabn-mekt,0.704861


In [39]:
# model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4', 'tidnet']
model_names = ['mekt', 'eegnet']
methods = ['euclid-adabn', 'riemann-adabn', 'euclid-adabn-mekt', 'riemann-adabn-mekt']

# scores[d] lists, for source dataset d, one accuracy array per pipeline:
# MEKT first (any method), then 'eegnet' once per entry of `methods`.
scores = []
for dataset in datasets:
    from_source = results['source_dataset'] == dataset.dataset_code
    inner_scores = []
    for model_name in model_names:
        of_model = from_source & (results['model'] == model_name)
        if model_name == 'mekt':
            inner_scores.append(results[of_model].accuracy.to_numpy())
            continue
        for method in methods:
            with_method = of_model & (results['method'] == method)
            inner_scores.append(results[with_method].accuracy.to_numpy())
    scores.append(inner_scores)

# show_p and pval are defined elsewhere in the notebook.
P = show_p(scores, pval=pval)

In [40]:
# Tick labels for the rows and columns of P.
names = ['MEKT', 'EA', 'RA', 'EA-MEKT', 'RA-MEKT']

with sns.plotting_context('paper', font_scale=2), sns.axes_style('darkgrid'):
    f, ax = plt.subplots(figsize=(9, 9))
    sns.heatmap(
        P, vmin=0, vmax=1,
        annot=False, fmt="1.0e", linewidths=0.1, linecolor='k',
        ax=ax, square=True,
        cmap=sns.color_palette('Reds_r'), cbar=False)

    # Column labels go on top, slanted; row labels stay horizontal.
    ax.xaxis.tick_top()
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(names, rotation=45, ha='left')
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(names, rotation=0)
    ax.tick_params(axis='both', which='both', length=0)

    # Widen the y-range by half a cell at both ends.
    b, t = ax.get_ylim()
    ax.set_ylim(b + 0.5, t - 0.5)
    ax.xaxis.set_label_position('top')
    f.tight_layout()
    f.savefig("images/eegnet_mekt_cross_dataset_pvalues_single_tail.jpg", format='jpg', dpi=300)

In [41]:
# model_names = ['mekt', 'shallownet', 'shallowfbcspnet', 'eegnet', 'eegnetv4', 'tidnet']
model_names = ['mekt', 'eegnetv4']
methods = ['euclid-adabn', 'riemann-adabn', 'euclid-adabn-mekt', 'riemann-adabn-mekt']

# scores[d] lists, for source dataset d, one accuracy array per pipeline:
# MEKT first (any method), then 'eegnetv4' once per entry of `methods`.
scores = []
for dataset in datasets:
    from_source = results['source_dataset'] == dataset.dataset_code
    inner_scores = []
    for model_name in model_names:
        of_model = from_source & (results['model'] == model_name)
        if model_name == 'mekt':
            inner_scores.append(results[of_model].accuracy.to_numpy())
            continue
        for method in methods:
            with_method = of_model & (results['method'] == method)
            inner_scores.append(results[with_method].accuracy.to_numpy())
    scores.append(inner_scores)

# show_p and pval are defined elsewhere in the notebook.
P = show_p(scores, pval=pval)

In [42]:
# Tick labels for the rows and columns of P.
names = ['MEKT', 'EA', 'RA', 'EA-MEKT', 'RA-MEKT']

with sns.plotting_context('paper', font_scale=2), sns.axes_style('darkgrid'):
    f, ax = plt.subplots(figsize=(9, 9))
    sns.heatmap(
        P, vmin=0, vmax=1,
        annot=False, fmt="1.0e", linewidths=0.1, linecolor='k',
        ax=ax, square=True,
        cmap=sns.color_palette('Reds_r'), cbar=False)

    # Column labels go on top, slanted; row labels stay horizontal.
    ax.xaxis.tick_top()
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(names, rotation=45, ha='left')
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(names, rotation=0)
    ax.tick_params(axis='both', which='both', length=0)

    # Widen the y-range by half a cell at both ends.
    b, t = ax.get_ylim()
    ax.set_ylim(b + 0.5, t - 0.5)
    ax.xaxis.set_label_position('top')
    f.tight_layout()
    f.savefig("images/eegnetv4_mekt_cross_dataset_pvalues_single_tail.jpg", format='jpg', dpi=300)