In [24]:
%matplotlib qt5
import warnings, copy, os, itertools
warnings.filterwarnings("ignore")
from collections import OrderedDict

import numpy as np
import pandas as pd
import joblib
from scipy.signal import welch

import matplotlib
matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", context="paper")
import matplotlib.patches as patches
from matplotlib.legend_handler import HandlerPatch


from brainda.datasets import Nakanishi2015, Wang2016, BETA
from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import set_random_seeds
from utils import *
from models import *

In [2]:
def make_file(
    dataset, model_name, channels, srate, duration, events, 
    preprocess=None, 
    n_bands=None,
    augment=False, loo=False, fixed_dtn_template=False):
    """Return the joblib file name under which one result set is stored.

    The name is ``<dataset_code>-<model>-<n_channels>-<srate>-<n_samples>-<n_events>``
    where ``n_samples = int(duration * srate)``, followed by optional parts in
    this fixed order: ``-<n_bands>``, ``-<preprocess>``, ``-augment``,
    ``-fixed`` (``fixed_dtn_template``), ``-loo``, and finally ``.joblib``.

    Example: ``'beta-dtn-9-250-125-40-3.joblib'``.
    """
    pieces = [
        '{:s}'.format(dataset.dataset_code),
        '{:s}'.format(model_name),
        '{:d}'.format(len(channels)),
        '{:d}'.format(srate),
        '{:d}'.format(int(duration * srate)),
        '{:d}'.format(len(events)),
    ]
    if n_bands is not None:
        pieces.append('{:d}'.format(n_bands))
    if preprocess is not None:
        pieces.append('{:s}'.format(preprocess))
    # Flag suffixes; their order is part of the on-disk naming scheme.
    flag_tags = (
        (augment, 'augment'),
        (fixed_dtn_template, 'fixed'),
        (loo, 'loo'),
    )
    pieces.extend(tag for enabled, tag in flag_tags if enabled)
    return '-'.join(pieces) + '.joblib'

def get_within_subject_scores(
    dataset, model_name, channels, srate, duration, events,
    n_bands=None,
    augment=False, fixed_dtn_template=False):
    """Load the stored within-subject accuracy array for one model/setting.

    Network models ('eegnet-ssvep', 'dtn', 'ftn') are read from
    ``neural_networks/``; every other model from ``matrix_decomposition/``.
    The file name comes from ``make_file``. ``augment`` is deliberately NOT
    forwarded to it: a network's plain and augmented accuracies are stored in
    the same file under different keys.

    Key selected:
      * matrix-decomposition model             -> 'sub_accs'
      * network, augment=False                  -> 'ft_sub_accs'
      * network, augment=True                   -> 'ft_aug_sub_accs'
      * network, augment=True, fixed template   -> 'ft_aug_sub_accs_fixed'
    """
    is_network = model_name in ('eegnet-ssvep', 'dtn', 'ftn')
    folder = 'neural_networks' if is_network else 'matrix_decomposition'
    file_name = make_file(
        dataset, model_name, channels, srate, duration, events,
        n_bands=n_bands, augment=False, fixed_dtn_template=fixed_dtn_template)
    stored = joblib.load(os.path.join(folder, file_name))

    if not is_network:
        return stored['sub_accs']
    if not augment:
        return stored['ft_sub_accs']
    if fixed_dtn_template:
        return stored['ft_aug_sub_accs_fixed']
    return stored['ft_aug_sub_accs']

class HandlerRect(HandlerPatch):
    """Legend handler that draws every patch handle as a fixed 10x10 square.

    The square's lower-left corner is at ``(width // 2, 0)`` in the legend's
    handle-box coordinates, and its style (colours, line width) is copied from
    the handle that was passed to ``ax.legend``.
    """

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height,
                       fontsize, trans):
        side = 10
        square = patches.Rectangle(xy=(width // 2, 0), width=side, height=side)
        # copy face/edge colour etc. from the original legend handle
        self.update_prop(square, orig_handle, legend)
        # place the square in the legend's coordinate frame
        square.set_transform(trans)
        return [square]

In [3]:
# Datasets analysed below. `datasets`, `delays` and `channels` are parallel
# lists: entry i of each belongs to the same dataset.
datasets = [Nakanishi2015(), Wang2016(), BETA()]
# Per-dataset delay in seconds.
# NOTE(review): only unpacked, never used, by the cells shown here.
delays = [0.135, 0.14, 0.13]
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],        # Nakanishi2015
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # Wang2016
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # BETA
]

# Sampling rate (Hz) and window lengths N_t (s); both enter the result file names.
srate = 250
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Filter-bank band count (file-name suffix) and harmonic count.
n_bands = 3
n_harmonics = 5

# All figures are written to this folder.
save_folder = 'figures'
os.makedirs(save_folder, exist_ok=True)

In [4]:
decomp_models = ['fbdsp', 'fbtrca', 'fbtrcar', 'fbtdca']
decomp_model_names = ['FBDSP', 'FBeTRCA', 'FBeTRCA-R', 'FBTDCA']

# One bar chart per dataset: BA(%) of every decomposition model vs. N_t, with
# SD error bars, saved as figures/within_subject_decomp_<dataset>.jpg.
for dataset, dataset_channels in zip(datasets, channels):
    dataset_events = sorted(dataset.events.keys())
    # Not used in this cell; kept because later cells may read the module-level value.
    freqs = [dataset.get_freq(event) for event in dataset_events]

    # Build one frame per (model, N_t) and concatenate once. DataFrame.append was
    # removed in pandas 2.0 and also copied the whole frame on every call.
    frames = []
    for model, model_name in zip(decomp_models, decomp_model_names):
        for duration in durations:
            sub_accs = get_within_subject_scores(
                dataset, model, dataset_channels, srate, duration, dataset_events,
                n_bands=n_bands)
            # NOTE(review): bars average over axis 0 and the SD error bars run over
            # axis 1 ('k'), whereas the significance cells average over axis -1.
            # Confirm which axis indexes subjects.
            acc = np.mean(sub_accs, 0)
            frames.append(pd.DataFrame({
                'Nt': duration,
                'BA': acc * 100,
                'k': np.arange(len(acc)),
                'model': model_name,
            }))
    results = pd.concat(frames, ignore_index=True)

    # Per-dataset y range so that differences between models stay visible.
    code = dataset.dataset_code
    min_val = {'nakanishi2015': 65, 'wang2016': 30}.get(code, 20)
    step = 5 if code == 'nakanishi2015' else 10

    with sns.plotting_context('paper', font_scale=2.5):
        with sns.axes_style('whitegrid'):
            cmaps = sns.color_palette('tab10', 4)
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(
                x='Nt',
                y='BA',
                hue='model',
                hue_order=decomp_model_names,
                data=results,
                ci='sd',
                palette=cmaps, errcolor='#858585',
                edgecolor=".2", lw=1, capsize=0.02,
                zorder=10, ax=ax)
            ax.set_xlabel('$N_t$(s)')
            ax.set_ylabel('BA(%)')
            ax.set_ylim(min_val, 100)
            ax.set_yticks(np.arange(min_val, 101, step))
            # Replace seaborn's legend with square swatches right of the axes.
            ax.legend_.remove()
            handles = [
                patches.Rectangle((0, 0), 0, 0, linewidth=1, edgecolor=c, facecolor=c)
                for c in cmaps]
            ax.legend(
                handles=handles, labels=decomp_model_names,
                bbox_to_anchor=(1.25, 0.13), loc='center',
                handler_map={patches.Rectangle: HandlerRect()})
            plt.setp(ax.legend_.get_texts(), fontsize='18')
            fig.tight_layout()
            fig.savefig(
                os.path.join(
                    save_folder, "within_subject_decomp_{}.jpg".format(code)),
                format='jpg',
                dpi=300)
            plt.close(fig)

In [5]:
# One entry per (dataset, N_t): a list with the per-subject-mean accuracy of
# every decomposition model. `show_p` (project utils) turns this into the
# matrix `P` and annotations `dataP` drawn below.
scores = []
for dataset, dataset_channels, delay in zip(datasets, channels, delays):
    dataset_events = sorted(list(dataset.events.keys()))
    freqs = [dataset.get_freq(event) for event in dataset_events]
    phases = [dataset.get_phase(event) for event in dataset_events]
    for duration in durations:
        scores.append([
            np.mean(
                get_within_subject_scores(
                    dataset, model, dataset_channels, srate, duration, dataset_events,
                    n_bands=n_bands),
                -1)
            for model in decomp_models
        ])
set_random_seeds(64)
P, dataP = show_p(scores, bonferroni_correcton=True)

colors = sns.color_palette('Reds')
# Three shades for the P<0.05 / P<0.01 / P<0.001 levels.
cmaps = [colors[k] for k in (0, 2, 5)]
with sns.plotting_context('paper', font_scale=2.2):
    with sns.axes_style('darkgrid'):
        fig, ax = plt.subplots(figsize=(9, 7))
        sns.heatmap(
            P, vmin=0, vmax=1, annot=dataP, fmt="1.0e",
            linewidths=0.1, linecolor='k', cmap=cmaps,
            cbar=False, square=True, ax=ax)

        # Model names along the top and left edges, without tick marks.
        ax.xaxis.tick_top()
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(decomp_model_names, rotation=45, ha='left')
        ax.set_yticks(ax.get_yticks())
        ax.set_yticklabels(decomp_model_names, rotation=0)
        ax.tick_params(axis='both', which='both', length=0)

        bottom, top = ax.get_ylim()
        ax.set_ylim(bottom, top)
        ax.xaxis.set_label_position('top')

        legend_handles = [
            patches.Rectangle((0, 0), 0, 0, linewidth=1, edgecolor=shade, facecolor=shade)
            for shade in cmaps]
        ax.legend(
            handles=legend_handles, labels=['P<0.05', 'P<0.01', 'P<0.001'],
            bbox_to_anchor=(1.17, 0.12), loc='center',
            handler_map={patches.Rectangle: HandlerRect()})
        plt.setp(ax.legend_.get_texts(), fontsize='14')
        fig.tight_layout()
        out_path = os.path.join(save_folder, "within_subject_decomp_significance.jpg")
        fig.savefig(out_path, format='jpg', dpi=300)
        plt.close(fig)

In [6]:
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
decomp_models = ['fbtdca', 'fbtrcar']
decomp_model_names = ['FBTDCA', 'FBeTRCA-R']
dl_models = ['eegnet-ssvep', 'dtn', 'ftn']
dl_model_names = ['EEGNet', 'DTN', 'FTN']

# One panel per dataset: BA(%) vs. N_t for all models, SD error bars.
# Networks are shown without data augmentation.
with sns.plotting_context('paper', font_scale=2.6):
    with sns.axes_style('whitegrid'):
        fig, axs = plt.subplots(1, 3, figsize=(7*3.2, 7))
        colors = sns.color_palette('tab10', n_colors=10)
        cmaps = [colors[0], colors[4], colors[2], colors[1], colors[3]]
        panel_titles = ['Nakanishi2015', 'Wang2016', 'BETA']
        all_models = decomp_models + dl_models
        all_model_names = decomp_model_names + dl_model_names

        for idata, (dataset, dataset_channels) in enumerate(zip(datasets, channels)):
            ax = axs[idata]
            dataset_events = sorted(dataset.events.keys())
            # Not used in this cell; kept because later cells may read the module-level value.
            freqs = [dataset.get_freq(event) for event in dataset_events]

            # One frame per (model, N_t), concatenated once (DataFrame.append was
            # removed in pandas 2.0 and is quadratic inside a loop).
            frames = []
            for model, model_name in zip(all_models, all_model_names):
                for duration in durations:
                    sub_accs = get_within_subject_scores(
                        dataset, model, dataset_channels, srate, duration, dataset_events,
                        n_bands=n_bands, augment=False)
                    acc = np.mean(sub_accs, 0)
                    frames.append(pd.DataFrame({
                        'Nt': duration,
                        'BA': acc * 100,
                        'k': np.arange(len(acc)),
                        'model': model_name,
                    }))
            results = pd.concat(frames, ignore_index=True)

            sns.lineplot(
                x='Nt',
                y='BA',
                hue='model',
                hue_order=all_model_names,
                data=results,
                ci='sd', err_style='bars',
                err_kws={'capsize': 2},
                palette=cmaps,
                linewidth=3.5, ax=ax)

            ax.set_xlabel('$N_t$(s)')
            ax.set_ylabel('BA(%)')
            min_val = 60 if dataset.dataset_code == 'nakanishi2015' else 20
            ax.set_ylim(min_val, 100)
            step = 5 if dataset.dataset_code == 'nakanishi2015' else 10
            ax.set_yticks(np.arange(min_val, 101, step))
            ax.legend_.remove()
            # Dataset name above the middle of the N_t axis (0.2 .. 1.0 s).
            ax.text(0.6, 102, panel_titles[idata], va='bottom', ha='center')
            if idata == 2:
                # A single shared legend, placed right of the last panel.
                ax.legend(bbox_to_anchor=(1.24, 0.25), loc='center')
                plt.setp(ax.legend_.get_texts(), fontsize='16')
        fig.tight_layout()
        fig.savefig(
            os.path.join(
                save_folder, "within_subject_decomp_dl_line.jpg"),
            format='jpg',
            dpi=300)
        plt.close(fig)

In [7]:
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
decomp_models = ['fbtdca', 'fbtrcar']
decomp_model_names = ['FBTDCA', 'FBeTRCA-R']
dl_models = ['dtn', 'ftn']
dl_model_names = ['DTN', 'FTN']

# Hue labels in plotting/legend order; '-A' marks the augmented network variants.
hue_order = ['FBTDCA', 'FBeTRCA-R', 'DTN', 'DTN-A', 'FTN', 'FTN-A']

# One bar chart per dataset comparing decompositions with DTN/FTN, without and
# with data augmentation.
for dataset, dataset_channels in zip(datasets, channels):
    dataset_events = sorted(dataset.events.keys())
    # Not used in this cell; kept because later cells may read the module-level value.
    freqs = [dataset.get_freq(event) for event in dataset_events]

    # (model id, hue label, augment) for every series in the chart.
    runs = (
        [(m, name, False) for m, name in zip(decomp_models, decomp_model_names)]
        + [(m, name, False) for m, name in zip(dl_models, dl_model_names)]
        + [(m, name + '-A', True) for m, name in zip(dl_models, dl_model_names)]
    )
    # One frame per (series, N_t), concatenated once (DataFrame.append was
    # removed in pandas 2.0).
    frames = []
    for model, label, augment in runs:
        for duration in durations:
            sub_accs = get_within_subject_scores(
                dataset, model, dataset_channels, srate, duration, dataset_events,
                n_bands=n_bands, augment=augment)
            acc = np.mean(sub_accs, 0)
            frames.append(pd.DataFrame({
                'Nt': duration,
                'BA': acc * 100,
                'k': np.arange(len(acc)),
                'model': label,
            }))
    results = pd.concat(frames, ignore_index=True)

    code = dataset.dataset_code
    min_val = {'nakanishi2015': 60, 'wang2016': 20}.get(code, 10)
    step = 5 if code == 'nakanishi2015' else 10

    with sns.plotting_context('paper', font_scale=2.6):
        with sns.axes_style('whitegrid'):
            colors = sns.color_palette('Paired', 10)
            # Adjacent 'Paired' colours so each network and its -A variant share a hue.
            cmaps = [colors[1], colors[2], colors[6], colors[7], colors[4], colors[5]]
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(
                x='Nt',
                y='BA',
                hue='model',
                hue_order=hue_order,
                data=results,
                ci='sd',
                palette=cmaps, errcolor='#858585',
                edgecolor=".2", lw=1, capsize=0.02,
                zorder=10, ax=ax)
            ax.set_xlabel('$N_t$(s)')
            ax.set_ylabel('BA(%)')
            ax.set_ylim(min_val, 100)
            ax.set_yticks(np.arange(min_val, 101, step))
            ax.legend_.remove()
            handles = [
                patches.Rectangle((0, 0), 0, 0, linewidth=1, edgecolor=c, facecolor=c)
                for c in cmaps]
            ax.legend(
                handles=handles, labels=hue_order,
                bbox_to_anchor=(1.25, 0.2), loc='center',
                handler_map={patches.Rectangle: HandlerRect()})
            plt.setp(ax.legend_.get_texts(), fontsize='16')
            fig.tight_layout()
            fig.savefig(
                os.path.join(
                    save_folder, "within_subject_decomp_dl_{}.jpg".format(code)),
                format='jpg',
                dpi=300)
            plt.close(fig)

In [8]:
# One entry per (dataset, N_t): per-subject-mean accuracies of the decomposition
# models followed by the augmented networks. `show_p` (project utils) turns this
# into the matrix `P` and annotations `dataP` drawn below.
scores = []
for dataset, dataset_channels, delay in zip(datasets, channels, delays):
    dataset_events = sorted(list(dataset.events.keys()))
    freqs = [dataset.get_freq(event) for event in dataset_events]
    phases = [dataset.get_phase(event) for event in dataset_events]

    for duration in durations:
        decomp_accs = [
            np.mean(get_within_subject_scores(
                dataset, model, dataset_channels, srate, duration, dataset_events,
                n_bands=n_bands), -1)
            for model in decomp_models]
        dl_accs = [
            np.mean(get_within_subject_scores(
                dataset, model, dataset_channels, srate, duration, dataset_events,
                n_bands=n_bands, augment=True), -1)
            for model in dl_models]
        scores.append(decomp_accs + dl_accs)

set_random_seeds(64)
P, dataP = show_p(scores, bonferroni_correcton=True)

colors = sns.color_palette('Reds')
# Three shades for the P<0.05 / P<0.01 / P<0.001 levels.
cmaps = [colors[k] for k in (0, 2, 5)]
with sns.plotting_context('paper', font_scale=2.2):
    with sns.axes_style('darkgrid'):
        fig, ax = plt.subplots(figsize=(10, 7))
        sns.heatmap(
            P, vmin=0, vmax=1, annot=dataP, fmt="1.0e",
            linewidths=0.1, linecolor='k', cmap=cmaps,
            cbar=False, square=True, ax=ax)

        # Model names along the top and left edges, without tick marks.
        ax.xaxis.tick_top()
        names = ['FBTDCA', 'FBeTRCA-R', 'DTN-A', 'FTN-A']
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(names, rotation=45, ha='left')
        ax.set_yticks(ax.get_yticks())
        ax.set_yticklabels(names, rotation=0)
        ax.tick_params(axis='both', which='both', length=0)

        bottom, top = ax.get_ylim()
        ax.set_ylim(bottom, top)
        ax.xaxis.set_label_position('top')
        fig.tight_layout()

        legend_handles = [
            patches.Rectangle((0, 0), 0, 0, linewidth=1, edgecolor=shade, facecolor=shade)
            for shade in cmaps]
        ax.legend(
            handles=legend_handles, labels=['P<0.05', 'P<0.01', 'P<0.001'],
            bbox_to_anchor=(1.17, 0.12), loc='center',
            handler_map={patches.Rectangle: HandlerRect()})
        plt.setp(ax.legend_.get_texts(), fontsize='14')

        out_path = os.path.join(save_folder, "within_subject_decomp_dl_significance.jpg")
        fig.savefig(out_path, format='jpg', dpi=300)
        plt.close(fig)

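## Direct template replacement

Compares the accuracy of augmented DTN without and with average templates (`fixed_dtn_template=False` vs. `True`).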

In [80]:
# Re-declared configuration so this section can be run on its own.
# `datasets`, `delays` and `channels` are parallel lists (one entry per dataset).
datasets = [Nakanishi2015(), Wang2016(), BETA()]
delays = [0.135, 0.14, 0.13]  # seconds; NOTE(review): unused by the cells shown here
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],        # Nakanishi2015
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # Wang2016
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # BETA
]

srate = 250  # Hz
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]  # N_t in seconds

# Network(s) compared with and without average templates.
models = ['dtn']

In [87]:
# Legend/hue labels paired with the `fixed_dtn_template` flag they select.
template_variants = [
    (False, 'w/o average templates'),
    (True, 'w/ average templates'),
]

# One bar chart per dataset: augmented DTN without vs. with average templates.
for dataset, dataset_channels in zip(datasets, channels):
    dataset_events = sorted(dataset.events.keys())
    # Not used in this cell; kept because later cells may read the module-level value.
    freqs = [dataset.get_freq(event) for event in dataset_events]

    # One frame per (variant, model, N_t), concatenated once (DataFrame.append
    # was removed in pandas 2.0). The 'model' column holds the variant label,
    # so several entries in `models` would be pooled into the same bar.
    frames = []
    for fixed, label in template_variants:
        for model in models:
            for duration in durations:
                sub_accs = get_within_subject_scores(
                    dataset, model, dataset_channels, srate, duration, dataset_events,
                    n_bands=n_bands,
                    augment=True,
                    fixed_dtn_template=fixed)
                acc = np.mean(sub_accs, 0)
                frames.append(pd.DataFrame({
                    'Nt': duration,
                    'BA': acc * 100,
                    'k': np.arange(len(acc)),
                    'model': label,
                }))
    results = pd.concat(frames, ignore_index=True)

    code = dataset.dataset_code
    min_val = {'nakanishi2015': 65, 'wang2016': 30}.get(code, 20)
    step = 5 if code == 'nakanishi2015' else 10

    with sns.plotting_context('paper', font_scale=2.5):
        with sns.axes_style('whitegrid'):
            colors = sns.color_palette('Paired', 10)
            cmaps = [colors[6], colors[7]]
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(
                x='Nt',
                y='BA',
                hue='model',
                hue_order=[label for _, label in template_variants],
                data=results,
                ci='sd',
                palette=cmaps, errcolor='#858585',
                edgecolor=".2", lw=1, capsize=0.02,
                zorder=10, ax=ax)
            ax.set_xlabel('$N_t$(s)')
            ax.set_ylabel('BA(%)')
            ax.set_ylim(min_val, 100)
            ax.set_yticks(np.arange(min_val, 101, step))
            ax.legend_.remove()
            handles = [
                patches.Rectangle((0, 0), 0, 0, linewidth=1, edgecolor=c, facecolor=c)
                for c in cmaps]
            # Legend labels are wrapped onto two lines to fit beside the axes.
            ax.legend(
                handles=handles, labels=['w/o average\n templates', 'w/ average\n templates'],
                bbox_to_anchor=(1.25, 0.13), loc='center',
                handler_map={patches.Rectangle: HandlerRect()})
            plt.setp(ax.legend_.get_texts(), fontsize='18')
            fig.tight_layout()
            fig.savefig(
                os.path.join(
                    save_folder, "within_subject_dtn_ave_template_{}.jpg".format(code)),
                format='jpg',
                dpi=300)
            plt.close(fig)

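## Analyses of DTN templates

Compares DTN templates with average templates for selected subjects, using their correlation and power spectra.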

In [82]:
# Configuration repeated so the template analysis can run stand-alone.
# Entry i of `datasets`, `delays` and `channels` describes the same dataset.
datasets = [
    Nakanishi2015(),
    Wang2016(),
    BETA(),
]
delays = [0.135, 0.14, 0.13]  # seconds; NOTE(review): unused by the cells shown here
channels = [
    ['PO7', 'PO3', 'POZ', 'PO4', 'PO8', 'O1', 'OZ', 'O2'],        # Nakanishi2015
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # Wang2016
    ['PZ', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],  # BETA
]

srate, n_bands, n_harmonics = 250, 3, 5
durations = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]  # N_t in seconds

save_folder = 'figures'
os.makedirs(save_folder, exist_ok=True)

In [10]:
# Accuracy gain of augmented DTN/FTN over FBTDCA on BETA. The file names encode
# 9 channels, 250 Hz, 125 samples (N_t = 0.5 s), 40 events and 3 bands.
dtn_scores = np.mean(joblib.load(os.path.join(
    'neural_networks', 'beta-dtn-9-250-125-40-3.joblib'))['ft_aug_sub_accs'], -1)
ftn_scores = np.mean(joblib.load(os.path.join(
    'neural_networks', 'beta-ftn-9-250-125-40-3.joblib'))['ft_aug_sub_accs'], -1)
base_scores = np.mean(joblib.load(os.path.join(
    'matrix_decomposition', 'beta-fbtdca-9-250-125-40-3.joblib'))['sub_accs'], -1)

# A fresh figure keeps the cell idempotent. plt.plot would draw into whatever
# figure is current, so a re-run stacked duplicate lines and legend entries.
fig, ax = plt.subplots()
ax.plot(ftn_scores - base_scores, label='ftn')
ax.plot(dtn_scores - base_scores, label='dtn')
ax.axhline(0, color='k', lw=0.8)  # no-gain reference
ax.set_title('BETA, $N_t$=0.5s')
ax.set_ylabel('accuracy difference vs. FBTDCA')
ax.legend();

In [56]:
templates_data = joblib.load('sub_templates.joblib')
# Set 1 is labelled "DTN Templates" and set 2 "Average Templates" in the plots.
dtn_sub_templates1 = templates_data['dtn_sub_templates1']
dtn_sub_templates2 = templates_data['dtn_sub_templates2']

# Subject ids used only in output file names. The plotting cell assumes entry i
# of `sub_corr`/`sub_power` belongs to subjects[i].
# NOTE(review): confirm the saved templates are stored in this order.
better = [8, 31, 43, 59]
worse = [14, 25, 40, 56]
subjects = better + worse
n_subjects = len(dtn_sub_templates1)
n_classes = 40  # NOTE(review): assumes the leading template axis indexes the 40 classes

sub_corr = []
for i_sub in range(n_subjects):
    # One flattened, mean-centred row per class template.
    a = dtn_sub_templates1[i_sub].reshape((n_classes, -1))
    b = dtn_sub_templates2[i_sub].reshape((n_classes, -1))
    a = a - np.mean(a, axis=-1, keepdims=True)
    b = b - np.mean(b, axis=-1, keepdims=True)
    # Pearson correlation of every row of a with every row of b:
    # R[i, j] = <a_i, b_j> / (|a_i| |b_j|). The row norms are computed directly
    # rather than from the diagonal of the full Gram matrices a@a.T and b@b.T.
    R = (a @ b.T) / np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))
    sub_corr.append(R)

sub_power = []
for i_sub in range(n_subjects):
    # Welch PSD along the last axis, then averaged over axis 1.
    # NOTE(review): assumes templates are sampled at 125 Hz and axis 1 is channels.
    f1, p1 = welch(dtn_sub_templates1[i_sub], fs=125)
    f2, p2 = welch(dtn_sub_templates2[i_sub], fs=125)
    # Only f1 is kept and used for both spectra, so the two grids must match.
    assert np.array_equal(f1, f2), 'template sets have different frequency grids'
    sub_power.append((f1, np.mean(p1, axis=1), np.mean(p2, axis=1)))

In [98]:
select_i_sub = 7

# Class order of the template arrays, sorted by stimulus frequency.
# The original read `freqs` leaked from the last loop iteration of an earlier
# cell, i.e. the last entry of `datasets` (BETA). It is computed explicitly here.
# NOTE(review): confirm the saved templates follow BETA's sorted event order.
template_dataset = datasets[-1]
template_freqs = np.array([
    template_dataset.get_freq(event)
    for event in sorted(template_dataset.events.keys())])
ix = np.argsort(template_freqs)
sorted_freqs = template_freqs[ix]

# --- correlation between DTN templates (rows) and average templates (columns)
with sns.plotting_context('paper', font_scale=2.2):
    with sns.axes_style('darkgrid'):
        fig, ax = plt.subplots(figsize=(12, 7))
        sns.heatmap(sub_corr[select_i_sub][ix][:, ix], vmin=-1, vmax=1,
            annot=False, linewidths=0.1, ax=ax, linecolor='k',
            cmap='jet', cbar=True, square=True)

        ax.xaxis.tick_top()
        tick_idx = np.arange(0, len(sorted_freqs), 5)
        # Heatmap cell k spans [k, k+1]: put the labels on cell centres, not edges.
        ax.set_xticks(tick_idx + 0.5)
        ax.set_xticklabels(sorted_freqs[tick_idx], rotation=45, ha='left')
        ax.set_yticks(tick_idx + 0.5)
        ax.set_yticklabels(sorted_freqs[tick_idx], rotation=0)
        ax.set_xlabel('Average Templates')
        ax.set_ylabel('DTN Templates')
        ax.xaxis.set_label_position('top')
        fig.tight_layout()

        fig.savefig(
            os.path.join(
                save_folder, "dtn_subject_{}_corr.jpg".format(subjects[select_i_sub])),
            format='jpg',
            dpi=300)
        plt.close(fig)

# --- per-class power spectra, one panel per class in frequency order
e_p = 9  # number of lowest Welch frequency bins shown per panel
n_r, n_c = 4, 10
with sns.plotting_context('paper', font_scale=1):
    with sns.axes_style('darkgrid'):
        fig, axs = plt.subplots(n_r, n_c, figsize=(16, 7))
        f, p1, p2 = sub_power[select_i_sub]
        p1_sorted, p2_sorted = p1[ix], p2[ix]
        for i, panel in enumerate(axs.flat):
            panel.plot(f[:e_p], p1_sorted[i][:e_p], label='DTN Templates')
            panel.plot(f[:e_p], p2_sorted[i][:e_p], label='Average Templates')
            # Ticks only: on matplotlib < 3.5 a second positional argument to
            # set_xticks is `minor`, so passing labels there set minor ticks.
            panel.set_xticks([0, 8, 12, 16])
            if i == 0:
                panel.set_xlabel('freq(Hz)')
            panel.set_title('{:.1f}'.format(sorted_freqs[i]))
        fig.tight_layout()
        # Single legend on the last (bottom-right) panel.
        axs[-1, -1].legend(loc='upper right')
        fig.savefig(
            os.path.join(save_folder, "dtn_subject_{}_power.jpg".format(subjects[select_i_sub])),
            format='jpg',
            dpi=300)
        plt.close(fig)