In [28]:
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd
import seaborn as sns
import numpy as np
import pickle
import copy
from scipy import stats
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from scipy.stats import wilcoxon
%matplotlib widget

In [29]:
def _result_file_name(table_path, Subj_ID, SEED, TRUELABEL_PERCENT, CROSSVIEW, track_resolu, label_resolu, compete_resolu, coe, MAX_ITER, LWCOV, SAMEWEIGHT, KEEP_TRAIN_PERCENT, BOOTSTRAP, SVAD):
    """Build the path of the pickled result file of one subject for one seed."""
    # The optional tags are computed outside the f-string: reusing the same
    # quote character inside a replacement field only parses on Python >= 3.12.
    view_tag = 'crossview' if CROSSVIEW else 'featview'
    resolu_tag = 'svadresolu' if SVAD else 'mmresolu'
    lwcov_tag = '_lwcov' if LWCOV else ''
    samew_tag = '_samew' if SAMEWEIGHT else ''
    keep_tag = KEEP_TRAIN_PERCENT if KEEP_TRAIN_PERCENT is not None else ''
    bootstrap_tag = '_bootstrap' if BOOTSTRAP else ''
    return (
        f'{table_path}{Subj_ID}_{view_tag}_trackresolu{track_resolu}'
        f'_truelabelpct_{TRUELABEL_PERCENT}_labelresolu_{label_resolu}'
        f'_{resolu_tag}_{compete_resolu}coe{coe}_nbiter{MAX_ITER}_seed{SEED}'
        f'{lwcov_tag}{samew_tag}{keep_tag}{bootstrap_tag}.pkl'
    )


def load_res(TRUELABEL_PERCENT, table_path, nb_subj, SEEDS, CROSSVIEW, track_resolu, label_resolu, compete_resolu, coe, MAX_ITER, LWCOV, SAMEWEIGHT, KEEP_TRAIN_PERCENT, BOOTSTRAP, SVAD, fold=None):
    """Load the pickled per-subject results for every seed and aggregate them.

    One file per (subject, seed) is read from ``table_path``; the file name is
    built from the remaining configuration arguments.

    Parameters
    ----------
    TRUELABEL_PERCENT : float
        Part of the file name. For 0.0 and 1.0 only the files of ``SEEDS[0]``
        are read and their content is repeated once per entry of ``SEEDS``.
    table_path : str
        Prefix prepended verbatim to the file name (no separator is added).
    nb_subj : int
        Number of subjects; subject ids are ``0 .. nb_subj - 1``.
    SEEDS : sequence
        Seeds to load. Must not be empty.
    CROSSVIEW, track_resolu, label_resolu, compete_resolu, coe, LWCOV,
    SAMEWEIGHT, KEEP_TRAIN_PERCENT, BOOTSTRAP
        Only used to build the file name.
    MAX_ITER : int
        Number of columns of every result array; also part of the file name.
    SVAD : bool
        If False, the keys ``'rt_list'``, ``'corr_sum_list'`` and ``'acc_list'``
        are read. If True, ``'corr_sum_att'``, ``'corr_sum_unatt'`` and
        ``'acc'`` are read instead.
    fold : int or None
        Only used when ``SVAD`` is True: row of the stored arrays to take, or
        None to average the stored arrays over their first axis.

    Returns
    -------
    tuple
        ``(rt_avg_over_seeds, corr_avg_over_seeds_att,
        corr_avg_over_seeds_unatt, acc_avg_over_seeds, rt_all, corr_all_att,
        corr_all_unatt, acc_all)``. The ``*_avg_over_seeds`` arrays have shape
        ``(nb_subj, MAX_ITER)``; the ``*_all`` arrays stack the per-seed arrays
        along axis 0, giving ``(nb_subj * len(SEEDS), MAX_ITER)``. The two
        ``rt`` entries are None when ``SVAD`` is True and the two ``unatt``
        entries are None when it is False.

    Raises
    ------
    ValueError
        If ``SEEDS`` is empty.
    """
    if len(SEEDS) == 0:
        raise ValueError('SEEDS must contain at least one seed')

    def load_one_seed(seed):
        """Return (rt, corr_att, corr_unatt, acc) for all subjects of one seed."""
        rt = np.zeros((nb_subj, MAX_ITER))
        corr_att = np.zeros((nb_subj, MAX_ITER))
        corr_unatt = np.zeros((nb_subj, MAX_ITER))
        acc = np.zeros((nb_subj, MAX_ITER))
        for Subj_ID in range(nb_subj):
            file_name = _result_file_name(
                table_path, Subj_ID, seed, TRUELABEL_PERCENT, CROSSVIEW,
                track_resolu, label_resolu, compete_resolu, coe, MAX_ITER,
                LWCOV, SAMEWEIGHT, KEEP_TRAIN_PERCENT, BOOTSTRAP, SVAD,
            )
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at result files produced by trusted runs.
            with open(file_name, 'rb') as f:
                res = pickle.load(f)
            if not SVAD:
                rt[Subj_ID, :] = res['rt_list']
                corr_att[Subj_ID, :] = res['corr_sum_list']
                acc[Subj_ID, :] = res['acc_list']
            elif fold is not None:
                corr_att[Subj_ID, :] = res['corr_sum_att'][fold, :]
                corr_unatt[Subj_ID, :] = res['corr_sum_unatt'][fold, :]
                acc[Subj_ID, :] = res['acc'][fold, :]
            else:
                corr_att[Subj_ID, :] = np.mean(res['corr_sum_att'], axis=0)
                corr_unatt[Subj_ID, :] = np.mean(res['corr_sum_unatt'], axis=0)
                acc[Subj_ID, :] = np.mean(res['acc'], axis=0)
        return rt, corr_att, corr_unatt, acc

    if TRUELABEL_PERCENT == 0.0 or TRUELABEL_PERCENT == 1.0:
        # Every seed maps to the SEEDS[0] files in this case, so read them
        # once and replicate instead of unpickling the same files per seed.
        shared = load_one_seed(SEEDS[0])
        per_seed = [tuple(arr.copy() for arr in shared) for _ in SEEDS]
    else:
        per_seed = [load_one_seed(seed) for seed in SEEDS]

    rt_list, corr_sum_att_list, corr_sum_unatt_list, acc_list = (list(group) for group in zip(*per_seed))

    corr_avg_over_seeds_att = np.mean(corr_sum_att_list, axis=0)
    acc_avg_over_seeds = np.mean(acc_list, axis=0)
    corr_all_att = np.concatenate(corr_sum_att_list, axis=0)
    acc_all = np.concatenate(acc_list, axis=0)
    if SVAD:
        rt_avg_over_seeds = rt_all = None
        corr_avg_over_seeds_unatt = np.mean(corr_sum_unatt_list, axis=0)
        corr_all_unatt = np.concatenate(corr_sum_unatt_list, axis=0)
    else:
        rt_avg_over_seeds = np.mean(rt_list, axis=0)
        rt_all = np.concatenate(rt_list, axis=0)
        corr_avg_over_seeds_unatt = corr_all_unatt = None
    return rt_avg_over_seeds, corr_avg_over_seeds_att, corr_avg_over_seeds_unatt, acc_avg_over_seeds, rt_all, corr_all_att, corr_all_unatt, acc_all

In [30]:
# Seeds whose result files are loaded (single-seed alternative: SEEDS = [4]).
SEEDS = [2, 4, 8, 16, 32]

# Resolution settings; load_res only uses them to build the result-file names.
track_resolu = label_resolu = compete_resolu = 60

# Flags and options that select which result files are read.
SVAD = True
BOOTSTRAP = True
LWCOV = False
SAMEWEIGHT = False
KEEP_TRAIN_PERCENT = None
coe = None

# NOTE(review): `fold` is not passed to any load_res call visible in this
# notebook section — confirm whether it is still needed.
fold = 0

# Number of iterations stored per result file: 6 for SVAD runs, 8 otherwise.
MAX_ITER = 6 if SVAD else 8

In [31]:
# Modality; it selects the sub-folder the result tables are read from.
MOD = 'EEG-EOG'
table_path = f'tables/{MOD}/'

CROSSVIEW = True
nb_subj = 19

# Zero-initialised per-subject buffers.
# NOTE(review): 64 and 5 are unexplained constants (presumably channel and
# component counts) — confirm and consider naming them.
L_data, L_feats = 3, 15
w_data = np.zeros(shape=(nb_subj, 64 * L_data, 5))
w_feats = np.zeros(shape=(nb_subj, L_feats, 5))

## Plot the initial results for each TRUELABEL_PERCENT

In [32]:
TRUELABEL_PERCENT_list = [0.0, 0.5, 1.0] # [0.0, 0.25, 0.5, 0.75, 1.0]

# One row per (seed, subject) pair, one column per initial true-label percentage.
init_shape = (nb_subj * len(SEEDS), len(TRUELABEL_PERCENT_list))
corr_att_init = np.zeros(init_shape)
acc_init = np.zeros(init_shape)
# load_res returns no rt results for SVAD runs and no unattended correlations otherwise.
rt_init = None if SVAD else np.zeros(init_shape)
corr_unatt_init = np.zeros(init_shape) if SVAD else None

for i, TRUELABEL_PERCENT in enumerate(TRUELABEL_PERCENT_list):
    # `fold` is left at its default (None), so SVAD results are averaged over folds.
    (
        rt_avg_over_seeds,
        corr_avg_over_seeds_att,
        corr_avg_over_seeds_unatt,
        acc_avg_over_seeds,
        rt_all,
        corr_all_att,
        corr_all_unatt,
        acc_all,
    ) = load_res(
        TRUELABEL_PERCENT=TRUELABEL_PERCENT,
        table_path=table_path,
        nb_subj=nb_subj,
        SEEDS=SEEDS,
        CROSSVIEW=CROSSVIEW,
        track_resolu=track_resolu,
        label_resolu=label_resolu,
        compete_resolu=compete_resolu,
        coe=coe,
        MAX_ITER=MAX_ITER,
        LWCOV=LWCOV,
        SAMEWEIGHT=SAMEWEIGHT,
        KEEP_TRAIN_PERCENT=KEEP_TRAIN_PERCENT,
        BOOTSTRAP=BOOTSTRAP,
        SVAD=SVAD,
    )
    # Column 0 of each result array is the first stored iteration.
    corr_att_init[:, i] = corr_all_att[:, 0]
    acc_init[:, i] = acc_all[:, 0]
    if SVAD:
        corr_unatt_init[:, i] = corr_all_unatt[:, 0]
    else:
        rt_init[:, i] = rt_all[:, 0]

In [None]:
# Quantity to show and its y-axis label. Keep the two together: the label used
# to be hard-coded to 'CC1+CC2' although accuracy is what is plotted here.
value_to_plot = acc_init
value_label = 'Accuracy'

# Per-column statistics, computed once and reused for the line and the band.
col_mean = np.mean(value_to_plot, axis=0)
col_std = np.std(value_to_plot, axis=0)
x_pos = list(range(len(TRUELABEL_PERCENT_list)))

plt.close('all')
fig, ax = plt.subplots()
sns.boxplot(data=value_to_plot, fill=None, ax=ax)
sns.stripplot(data=value_to_plot, ax=ax)
# Mean across rows with a +/- one standard deviation band.
ax.plot(col_mean, color='black')
ax.fill_between(x_pos, col_mean - col_std, col_mean + col_std, color='gray', alpha=0.2)
ax.set_xticks(x_pos)
ax.set_xticklabels([f'{pct}' for pct in TRUELABEL_PERCENT_list])
ax.set_xlabel('Initial True Label Percentage')
ax.set_ylabel(value_label)
plt.show()

In [None]:
# Mean of every column of the array selected for plotting (displayed as the cell output).
value_to_plot.mean(axis=0)

In [None]:
# Paired one-sided Wilcoxon signed-rank test: is the initial accuracy for
# TRUELABEL_PERCENT_list[1] greater than for TRUELABEL_PERCENT_list[0]?
# NOTE(review): rows stack all seeds, and for 0.0 / 1.0 load_res repeats the
# SEEDS[0] results once per seed, so rows are not independent samples — confirm
# this is acceptable for the test.
acc_candidate = acc_init[:, 1]
acc_baseline = acc_init[:, 0]
wilcoxon(acc_candidate, acc_baseline, alternative='greater')

## Plot the results of each iteration for a given initial TRUELABEL_PERCENT

In [83]:
TRUELABEL_PERCENT = 0.0
# fold=6 selects row 6 of the per-fold arrays stored in each result file (only
# used when SVAD is True).
# NOTE(review): the `fold` variable of the configuration cell is not used here — confirm which fold is intended.
(
    rt_avg_over_seeds,
    corr_avg_over_seeds_att,
    corr_avg_over_seeds_unatt,
    acc_avg_over_seeds,
    rt_all,
    corr_all_att,
    corr_all_unatt,
    acc_all,
) = load_res(
    TRUELABEL_PERCENT=TRUELABEL_PERCENT,
    table_path=table_path,
    nb_subj=nb_subj,
    SEEDS=SEEDS,
    CROSSVIEW=CROSSVIEW,
    track_resolu=track_resolu,
    label_resolu=label_resolu,
    compete_resolu=compete_resolu,
    coe=coe,
    MAX_ITER=MAX_ITER,
    LWCOV=LWCOV,
    SAMEWEIGHT=SAMEWEIGHT,
    KEEP_TRAIN_PERCENT=KEEP_TRAIN_PERCENT,
    BOOTSTRAP=BOOTSTRAP,
    SVAD=SVAD,
    fold=6,
)

In [None]:
# Quantity to show and its y-axis label. Choose between rt_all, corr_all_att,
# corr_all_unatt and acc_all (rt_all is None for SVAD runs, corr_all_unatt is
# None otherwise) and update the label with it: it used to be hard-coded to
# 'CC1+CC2' although accuracy is what is plotted here.
value_to_plot = acc_all
value_label = 'Accuracy'

# Per-iteration statistics, computed once and reused for the line and the band.
col_mean = np.mean(value_to_plot, axis=0)
col_std = np.std(value_to_plot, axis=0)
# Take the x extent from the plotted array itself rather than the global MAX_ITER,
# so the band always matches the data.
x_pos = range(value_to_plot.shape[1])

plt.close('all')
fig, ax = plt.subplots()
sns.boxplot(data=value_to_plot, fill=None, ax=ax)
sns.stripplot(data=value_to_plot, ax=ax)
# Mean across rows with a +/- one standard deviation band.
ax.plot(col_mean, color='black')
ax.fill_between(x_pos, col_mean - col_std, col_mean + col_std, color='gray', alpha=0.2)
ax.set_xlabel('Number of Iterations')
ax.set_ylabel(value_label)
plt.show()

In [None]:
# Mean of every iteration column of the array selected for plotting (displayed as the cell output).
value_to_plot.mean(axis=0)

In [None]:
# Paired one-sided Wilcoxon signed-rank test: is the plotted quantity greater
# in the second iteration column than in the first?
# NOTE(review): rows stack all seeds (replicated SEEDS[0] results when
# TRUELABEL_PERCENT is 0.0 or 1.0), so they are not independent samples.
second_iteration = value_to_plot[:, 1]
first_iteration = value_to_plot[:, 0]
wilcoxon(second_iteration, first_iteration, alternative='greater')