In [None]:
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd
import seaborn as sns
import numpy as np
import pickle
from scipy import stats
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
%matplotlib widget

In [None]:
def load_corr(modalities, table_name, table_parent='tables'):
    """Load the CC 1 / CC 2 correlation tables of several modalities in long format.

    For every modality the file ``{table_parent}/Overlay/{modality}/{table_name}``
    is read. Rows of component 'CC 1' come first, followed by rows of 'CC 2';
    within each component the rows are ordered by the number embedded in
    'Subject ID'. The 'Att' and 'Unatt' columns are then melted into an
    'Attention' (label) / 'Corr' (value) pair.

    Parameters
    ----------
    modalities : iterable of str
        Sub-directory names below ``{table_parent}/Overlay``.
    table_name : str
        CSV file name, identical for every modality.
    table_parent : str
        Root directory of the tables.

    Returns
    -------
    pandas.DataFrame
        Columns 'Subject ID', 'Sig Level', 'Component', 'Modality',
        'Attention', 'Corr'.
    """
    kept_columns = ['Subject ID', 'Sig Level', 'Component', 'Att', 'Unatt', 'Modality']
    df_list = []
    for modality in modalities:
        df = pd.read_csv(f'{table_parent}/Overlay/{modality}/{table_name}')
        per_component = []
        for component in ('CC 1', 'CC 2'):
            rows = df[df['Component'] == component].copy()
            # Raw string: '\d' inside a plain literal is an invalid escape
            # sequence and triggers a SyntaxWarning on recent Python versions.
            rows['subject_num'] = rows['Subject ID'].str.extract(r'(\d+)', expand=False).astype(int)
            per_component.append(rows.sort_values(by='subject_num').drop(columns='subject_num'))
        cc1_cc2 = pd.concat(per_component, ignore_index=True)
        cc1_cc2['Modality'] = modality
        df_list.append(cc1_cc2[kept_columns])
    corr_test = pd.concat(df_list, ignore_index=True)
    return corr_test.melt(
        id_vars=['Subject ID', 'Sig Level', 'Component', 'Modality'],
        value_vars=['Att', 'Unatt'],
        var_name='Attention',
        value_name='Corr',
    )

def load_acc(modalities, table_name, column='Trial_len=30', table_parent='tables'):
    """Load one accuracy column per modality and stack the results.

    For every modality the file ``{table_parent}/Overlay/{modality}/{table_name}``
    is read, ``column`` is renamed to 'Accuracy', and the rows are ordered by the
    number embedded in 'Subject ID'.

    Parameters
    ----------
    modalities : iterable of str
        Sub-directory names below ``{table_parent}/Overlay``.
    table_name : str
        CSV file name, identical for every modality.
    column : str
        Name of the column that holds the accuracy to extract.
    table_parent : str
        Root directory of the tables.

    Returns
    -------
    pandas.DataFrame
        Columns 'Subject ID', 'Accuracy', 'Modality', with a fresh 0..n-1 index.
    """
    acc = []
    for modality in modalities:
        table = pd.read_csv(f'{table_parent}/Overlay/{modality}/{table_name}')
        table['Modality'] = modality
        table = table.rename(columns={column: 'Accuracy'})
        # Raw string: '\d' inside a plain literal is an invalid escape sequence
        # and triggers a SyntaxWarning on recent Python versions.
        table['subject_num'] = table['Subject ID'].str.extract(r'(\d+)', expand=False).astype(int)
        table = table.sort_values(by='subject_num').drop(columns='subject_num')
        acc.append(table[['Subject ID', 'Accuracy', 'Modality']])
    return pd.concat(acc, ignore_index=True)

def load_acc_permu_test(modality, file_name, table_parent='tables', significance_level=0.025):
    """Return a significance threshold from pickled permutation-test results.

    The pickle ``{table_parent}/Overlay/{modality}/{file_name}.pkl`` must hold a
    dict whose values are sequences of numbers. All values are pooled, sorted in
    ascending order, and the element at position
    ``int((1 - significance_level) * n)`` is returned.

    Parameters
    ----------
    modality : str
        Sub-directory name below ``{table_parent}/Overlay``.
    file_name : str
        Pickle file name without the '.pkl' extension.
    table_parent : str
        Root directory of the tables.
    significance_level : float
        Upper-tail fraction that defines the threshold; must lie in [0, 1].

    Returns
    -------
    The pooled value at the requested upper quantile.

    Raises
    ------
    ValueError
        If ``significance_level`` is outside [0, 1] or the pickle holds no values.
    """
    if not 0 <= significance_level <= 1:
        raise ValueError(f"significance_level must be in [0, 1], got {significance_level}")
    acc_file = f"{table_parent}/Overlay/{modality}/{file_name}.pkl"
    # NOTE: pickle.load can execute arbitrary code; only open trusted result files.
    with open(acc_file, 'rb') as f:
        acc = pickle.load(f)
    acc_all = []
    for values in acc.values():
        acc_all.extend(values)
    acc_sorted = np.sort(np.asarray(acc_all).flatten())
    nb_tests = len(acc_sorted)
    if nb_tests == 0:
        raise ValueError(f"No permutation results found in {acc_file}")
    # Clamp so that significance_level == 0 selects the maximum instead of overflowing.
    index = min(int((1 - significance_level) * nb_tests), nb_tests - 1)
    return acc_sorted[index]

def load_isc(modalities, nb_comp=3, OVERLAY=True, table_parent='tables'):
    """Load the ISC rows of the first ``nb_comp`` components for the given modalities.

    Reads ``{table_parent}/Overlay/GCCA/OL_Single_Mod.csv`` when ``OVERLAY`` is
    true and ``.../SO_Single_Mod.csv`` otherwise, then keeps, modality by
    modality (in the order given), the rows whose 'Component' is one of
    'CC 1' ... 'CC {nb_comp}'.

    Parameters
    ----------
    modalities : iterable of str
        Values of the 'Modality' column to keep, in output order.
    nb_comp : int
        Number of leading components to keep.
    OVERLAY : bool
        Selects the 'OL' table (True) or the 'SO' table (False).
    table_parent : str
        Root directory of the tables.

    Returns
    -------
    pandas.DataFrame
        The selected rows; the index of the CSV rows is preserved.
    """
    prefix = 'OL' if OVERLAY else 'SO'
    # The table and the component list do not depend on the modality: build them once.
    table = pd.read_csv(f'{table_parent}/Overlay/GCCA/{prefix}_Single_Mod.csv')
    comp_list = [f'CC {i}' for i in range(1, nb_comp+1)]
    in_components = table['Component'].isin(comp_list)
    isc = [table[(table['Modality'] == modality) & in_components] for modality in modalities]
    return pd.concat(isc)

def load_isc_folds(modalities, xth_comp=1, OVERLAY=True, table_parent='tables'):
    """Load the per-fold ISC rows of one component for the given modalities.

    Reads ``{table_parent}/Overlay/GCCA/OL_Folds.csv`` when ``OVERLAY`` is true
    and ``.../SO_Folds.csv`` otherwise, then keeps, modality by modality (in the
    order given), the rows whose 'Component' equals ``'CC {xth_comp}'``.

    Parameters
    ----------
    modalities : iterable of str
        Values of the 'Modality' column to keep, in output order.
    xth_comp : int
        Index of the component to keep.
    OVERLAY : bool
        Selects the 'OL' table (True) or the 'SO' table (False).
    table_parent : str
        Root directory of the tables.

    Returns
    -------
    pandas.DataFrame
        The selected rows; the index of the CSV rows is preserved.
    """
    prefix = 'OL' if OVERLAY else 'SO'
    # The table does not depend on the modality: read it once, not once per iteration.
    table = pd.read_csv(f'{table_parent}/Overlay/GCCA/{prefix}_Folds.csv')
    is_component = table['Component'] == f'CC {xth_comp}'
    isc = [table[(table['Modality'] == modality) & is_component] for modality in modalities]
    return pd.concat(isc)

def plot_corr_mod(df_long, mod, ax):
    """Draw attended vs. unattended correlations per component for one modality.

    Rows of ``df_long`` whose 'Modality' equals ``mod`` are shown as dodged box
    and strip plots ('Component' on x, 'Corr' on y, hue 'Attention'). The mean
    of their 'Sig Level' column is drawn as a dashed grey horizontal line from
    x = -0.5 to x = 1.5. The axes title is ``mod``; both axis labels are cleared.

    Returns the ``(handles, labels)`` pair of the axes legend.
    """
    mod_rows = df_long[df_long['Modality'] == mod]
    # Shared keyword arguments of the two seaborn layers.
    layer_kwargs = dict(x='Component', y='Corr', hue='Attention', data=mod_rows, dodge=True, ax=ax)
    attention_colors = dict(zip(mod_rows['Attention'].unique(), sns.color_palette()))
    sns.boxplot(
        showfliers=False,
        boxprops={'edgecolor': 'black', 'alpha': 0.5},
        palette=attention_colors,
        **layer_kwargs,
    )
    sns.stripplot(size=4, jitter=True, legend=False, **layer_kwargs)
    ax.hlines(mod_rows['Sig Level'].mean(), -0.5, 1.5, color='grey', linestyle='--')
    ax.set_title(mod)
    ax.set_ylabel('')
    ax.set_xlabel('')
    return ax.get_legend_handles_labels()

def wilcoxon_pvalue(comparisons, alternative, BH_correction=False, PAIRED=True):
    """Compute one p-value per pair of samples.

    Parameters
    ----------
    comparisons : iterable of (group_1, group_2)
        Each tuple holds the two samples to compare.
    alternative : str
        Forwarded to the scipy test as ``alternative``.
    BH_correction : bool
        When true, the p-values are adjusted with
        ``stats.false_discovery_control(..., method='bh')``.
    PAIRED : bool
        True uses ``stats.wilcoxon`` (paired samples); False uses
        ``stats.mannwhitneyu`` (independent samples).

    Returns
    -------
    numpy.ndarray
        One p-value per comparison, in input order. An array is returned
        whether or not the correction is applied, so callers can index the
        result with a list of positions in both cases.
    """
    test = stats.wilcoxon if PAIRED else stats.mannwhitneyu
    p_values = np.array(
        [test(group_1, group_2, alternative=alternative).pvalue for group_1, group_2 in comparisons],
        dtype=float,
    )
    if BH_correction:
        p_values = np.asarray(stats.false_discovery_control(p_values, method='bh'))
    return p_values

def plot_pvalue(df, val_name, comparisons, alternative, BH_correction, ax, PAIRED=True, hpara=(1.5, 0)):
    """Annotate pairwise comparisons between modalities with brackets and p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'Modality' column and the column ``val_name``.
    val_name : str
        Column holding the values to compare.
    comparisons : sequence of (str, str)
        Pairs of modality names to compare.
    alternative, BH_correction, PAIRED
        Forwarded to ``wilcoxon_pvalue``.
    ax : matplotlib Axes
        Axes to draw on. The x position of a modality is its rank in
        ``df['Modality'].unique()``.
    hpara : sequence of two numbers
        ``(slope, offset)``: a bracket is drawn at
        ``max(values) + slope * (x1 + x2) - offset``.

    Raises
    ------
    ValueError
        If a modality named in ``comparisons`` is absent from ``df``.
    """
    modalities = list(df['Modality'].unique())
    comp_groups = [
        (df[df['Modality'] == first][val_name], df[df['Modality'] == second][val_name])
        for first, second in comparisons
    ]
    p_values = wilcoxon_pvalue(comp_groups, alternative, BH_correction, PAIRED)
    # Loop-invariant quantities of the bracket geometry.
    y_base = df[val_name].max()
    h, col = 0.5, 'k'
    for (mod1, mod2), p_value in zip(comparisons, p_values):
        x1 = modalities.index(mod1)
        x2 = modalities.index(mod2)
        y = y_base + hpara[0]*(x1+x2) - hpara[1]
        # clip_on=False keeps brackets visible when they extend past the y-limit.
        ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1.5, c=col, clip_on=False)
        ax.text((x1+x2)*.5, y+h, f"p = {p_value:.3f}", ha='center', va='bottom', color=col)

def plot_acc(df, title, ax, comparisons=None, alternative=None, BH_correction=None, xlabel=True, hpara=(1.5, 0), top=100, sig_levels=None):
    """Box/strip plot of per-subject accuracies (in percent) for each modality.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns 'Subject ID', 'Accuracy' and 'Modality'. 'Accuracy' is
        multiplied by 100 on a copy; ``df`` itself is not modified.
    title : str or None
        Axes title; nothing is set when None.
    ax : matplotlib Axes
        Axes to draw on.
    comparisons, alternative, BH_correction, hpara
        When ``comparisons`` is given, forwarded to ``plot_pvalue`` to annotate
        pairwise tests.
    xlabel : bool
        Show 'Modality' as x label (True) or an empty label (False).
    top : float
        Upper y-axis limit; the lower limit is left to matplotlib.
    sig_levels : sequence of float or None
        Thresholds as fractions, one per modality in plotting order; each is
        multiplied by 100 and drawn as a dashed grey segment over its box.
        Entries beyond the number of modalities are ignored.
    """
    # Work on a copy so the caller's frame keeps its fractional accuracies.
    df_copy = df.copy()
    df_copy['Accuracy'] = df_copy['Accuracy'] * 100
    sns.boxplot(x='Modality', y='Accuracy', data=df_copy, showfliers=False, color='lightgray', ax=ax)
    sns.stripplot(x='Modality', y='Accuracy', hue='Subject ID', data=df_copy, jitter=True, ax=ax, legend=False)
    if comparisons is not None:
        plot_pvalue(df_copy, 'Accuracy', comparisons, alternative, BH_correction, ax, hpara=hpara)
    if sig_levels is not None:
        nb_modalities = len(df_copy['Modality'].unique())
        for x_pos, sig_val in zip(range(nb_modalities), sig_levels):
            ax.plot([x_pos-0.4, x_pos+0.4], [sig_val*100, sig_val*100], linestyle='--', color='grey')
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel('Accuracy (%)')
    ax.set_xlabel('Modality' if xlabel else '')
    ax.set_ylim(top=top)
    # Remove the top and right spines (boundaries)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Only build a legend when something is labelled; an unconditional
    # ax.legend() warns and leaves an empty legend box on the axes.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

def plot_corr(df, comp, att, title, ax, comparisons=None, alternative=None, BH_correction=None):
    """Box/strip plot of per-subject correlations for each modality.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns 'Subject ID', 'Component', 'Attention',
        'Modality' and 'Corr'.
    comp : str
        Value of 'Component' to keep (e.g. 'CC 1').
    att : str
        Value of 'Attention' to keep (e.g. 'Att').
    title : str or None
        Axes title; nothing is set when None.
    ax : matplotlib Axes
        Axes to draw on.
    comparisons, alternative, BH_correction
        When ``comparisons`` is given, forwarded to ``plot_pvalue`` to annotate
        pairwise tests.
    """
    df = df[(df['Component'] == comp) & (df['Attention'] == att)]
    sns.boxplot(x='Modality', y='Corr', data=df, showfliers=False, color='lightgray', ax=ax)
    sns.stripplot(x='Modality', y='Corr', hue='Subject ID', data=df, jitter=True, ax=ax, legend=False)
    if comparisons is not None:
        plot_pvalue(df, 'Corr', comparisons, alternative, BH_correction, ax)
    if title is not None:
        ax.set_title(title)
    # The plotted quantity is the 'Corr' column, so the axis must not read 'Accuracy'.
    ax.set_ylabel('Correlation Coefficient')
    ax.set_xlabel('Modality')
    # Only build a legend when something is labelled; an unconditional
    # ax.legend() warns and leaves an empty legend box on the axes.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

def plot_compare(df_long, df_long_cp, df_mode, df_cp_mode, comp, mode, title, ax, alternative, BH_correction, PAIRED=True, Mod_to_Compare='All', SELEMOD=None, sig_levels=None, sig_levels_cp=None):
    """Compare two conditions modality by modality with dodged box/strip plots.

    Parameters
    ----------
    df_long, df_long_cp : pandas.DataFrame
        Tables of the two conditions; both need 'Modality' and 'Value' columns
        (plus 'Component' / 'Attention' when ``comp`` / ``mode`` are given).
        Neither frame is modified.
    df_mode, df_cp_mode : str
        Legend labels of the first and second condition.
    comp : str or None
        Value of 'Component' to keep. When None the data is treated as
        accuracies: 'Value' is multiplied by 100 and the y label is
        'Accuracy (%)'; otherwise the y label is 'Correlation Coefficient'.
    mode : str or None
        Value of 'Attention' to keep; no filtering when None.
    title : str or None
        Axes title (an empty title is set when None).
    ax : matplotlib Axes
        Axes to draw on.
    alternative, BH_correction, PAIRED
        Forwarded to ``wilcoxon_pvalue``; one test per modality compares the
        first condition against the second.
    Mod_to_Compare : 'All' or sequence of str
        Modalities to test; 'All' uses every modality present in the data.
    SELEMOD : sequence of str or None
        When given, only these modalities are drawn. The tests (and the BH
        correction) are still computed over all compared modalities.
    sig_levels, sig_levels_cp : sequence of float or None
        Thresholds of the first / second condition, drawn as dashed grey
        segments at ``value * 100``.
    """
    def _select(frame):
        # Filter, then copy so that adding the 'CP' column below never writes
        # into the caller's frame (or into a view of it).
        if comp is not None:
            frame = frame[frame['Component'] == comp]
        if mode is not None:
            frame = frame[frame['Attention'] == mode]
        return frame.copy()

    MODE = 'CC' if comp is not None else 'Acc'
    df_long = _select(df_long)
    df_long_cp = _select(df_long_cp)
    # Label every row with its condition so seaborn can dodge by it.
    df_long['CP'] = df_mode
    df_long_cp['CP'] = df_cp_mode
    df_combined = pd.concat([df_long, df_long_cp])
    if MODE == 'Acc':
        df_combined['Value'] = df_combined['Value'] * 100
    if isinstance(Mod_to_Compare, str) and Mod_to_Compare == 'All':
        modalities = df_combined['Modality'].unique()
    else:
        modalities = Mod_to_Compare
    # Arrays on both sides so the SELEMOD selection below can index by position,
    # whatever container the caller or wilcoxon_pvalue provided.
    modalities = np.asarray(modalities)
    comparisons = [
        (
            df_combined[(df_combined['Modality'] == modality) & (df_combined['CP'] == df_mode)]['Value'],
            df_combined[(df_combined['Modality'] == modality) & (df_combined['CP'] == df_cp_mode)]['Value'],
        )
        for modality in modalities
    ]
    p_values = np.asarray(wilcoxon_pvalue(comparisons, alternative, BH_correction, PAIRED))
    if SELEMOD is not None:
        selected_mod_idx = [i for i, mod in enumerate(modalities) if mod in SELEMOD]
        p_values = p_values[selected_mod_idx]
        modalities = modalities[selected_mod_idx]
        df_combined = df_combined[df_combined['Modality'].isin(SELEMOD)]
    boxprops = dict(edgecolor='black', alpha=0.5)
    sns.boxplot(x='Modality', y='Value', hue='CP', data=df_combined, showfliers=False, dodge=True, ax=ax, boxprops=boxprops)
    sns.stripplot(x='Modality', y='Value', hue='CP', data=df_combined, size=4, jitter=True, dodge=True, ax=ax, legend=False)
    # Horizontal offset used to centre the threshold segments over the two dodged boxes.
    dodge_width = 0.4
    y_text = df_combined['Value'].max()
    for i, modality in enumerate(modalities):
        ax.text(i, y_text, f"p = {p_values[i]:.3f}", ha='center', va='bottom')
        # NOTE(review): thresholds are looked up by plotted position and always
        # scaled by 100; confirm this is intended when SELEMOD reorders/drops
        # modalities or when comp is not None.
        if sig_levels is not None and i < len(sig_levels):
            sig_val = sig_levels[i]
            x_pos = i - dodge_width/2
            ax.plot([x_pos - 0.2, x_pos + 0.2], [sig_val*100, sig_val*100], linestyle='--', color='grey')
        if sig_levels_cp is not None and i < len(sig_levels_cp):
            sig_val_cp = sig_levels_cp[i]
            x_pos = i + dodge_width/2
            ax.plot([x_pos - 0.2, x_pos + 0.2], [sig_val_cp*100, sig_val_cp*100], linestyle='--', color='grey')
    ax.set_title(title if title is not None else '')
    ax.set_ylabel('Correlation Coefficient' if MODE == 'CC' else 'Accuracy (%)')
    ax.set_xlabel('Modality')

def adjust_legend(axes, loc='best'):
    """Rebuild the legend of every axes so that each label appears only once.

    For each axes the first handle of every distinct label is kept, in order of
    first appearance, and the legend is redrawn at ``loc``.
    """
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        kept_handles, kept_labels = [], []
        for handle, label in zip(handles, labels):
            if label in kept_labels:
                continue
            kept_handles.append(handle)
            kept_labels.append(label)
        if kept_labels:
            ax.legend(tuple(kept_handles), tuple(kept_labels), loc=loc)
        else:
            # Nothing labelled: fall back to the plain call.
            ax.legend(loc=loc)

def plot_CC_pairs(df_long, mode='Att'):
    """Plot, per modality, each subject's CC 1 -> CC 2 correlation as a line.

    One panel is created per modality found in the rows of ``df_long`` whose
    'Attention' equals ``mode``. A subject's line is red when any of its
    correlations exceeds the mean 'Sig Level' of that modality and blue
    otherwise; the count of such subjects over the total number of subjects is
    written above the panel.

    Parameters
    ----------
    df_long : pandas.DataFrame
        Long-format table with columns 'Subject ID', 'Sig Level', 'Component',
        'Modality', 'Attention' and 'Corr'.
    mode : str
        Value of 'Attention' to plot.
    """
    df_long = df_long[df_long['Attention'] == mode]
    modalities = df_long['Modality'].unique()
    # squeeze=False always yields a 2-D array of axes, so a single modality works too.
    fig, axes = plt.subplots(1, len(modalities), sharey=True, figsize=(10, 4), squeeze=False)
    for ax, mod in zip(axes[0], modalities):
        df_mod = df_long.loc[df_long['Modality'] == mod].copy()
        # Numeric x position: 0.0 for 'CC 1', 1.0 for any other component.
        df_mod['Component'] = (df_mod['Component'] != 'CC 1').astype(float)
        # The threshold is the same for every subject: compute it once.
        sig_level = df_mod['Sig Level'].mean()
        # Flag every row of a subject when any of its correlations exceeds the threshold.
        df_mod['Above Sig Level'] = df_mod.groupby('Subject ID')['Corr'].transform(lambda x: any(x > sig_level))
        subjects = df_mod['Subject ID'].unique()
        num_sig_subjects = sum(df_mod.groupby('Subject ID')['Above Sig Level'].any())
        for subject in subjects:
            df_subject = df_mod[df_mod['Subject ID'] == subject]
            color = 'red' if df_subject['Above Sig Level'].any() else '#3C74BC'
            sns.lineplot(x='Component', y='Corr', data=df_subject, ax=ax, marker='o', color=color, legend=False)
        ax.set_xlabel(mod)
        ax.set_ylabel('Correlation Coefficient')
        ax.set_xticks([0, 1], ['CC 1', 'CC 2'])
        # Drop a legend only if one exists, instead of creating one just to remove it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        # Add the number of significant subjects over the total number of subjects to the top of the plot
        ax.text(0.5, 1.02, f'#sig_subj/#subj\n{num_sig_subjects}/{len(subjects)}', transform=ax.transAxes, ha='center', va='bottom')

def plot_acc_pairs(df, cond1, cond2, condition='Modality', title=None, ax=None, show_legend=True, sig_levels=None):
    """Paired plot of per-subject accuracies (in percent) under two conditions.

    Every subject with at least one row in both conditions is drawn as two
    points joined by a line (first matching row of each condition). The mean
    accuracy of each condition, taken over all of its rows, is overlaid as a
    dashed black line with triangle markers.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns 'Subject ID', 'Accuracy' and ``condition``.
    cond1, cond2 : str
        Values of ``condition`` shown at x = 1 and x = 2.
    condition : str
        Column that distinguishes the two conditions; also used as x label.
    title : str or None
        Axes title; nothing is set when None.
    ax : matplotlib Axes or None
        Axes to draw on; a new 5x5 inch figure is created when None.
    show_legend : bool
        Draw the legend outside the axes, anchored at its upper-right corner.
    sig_levels : sequence of two floats or None
        Thresholds as fractions for cond1 and cond2, drawn at ``value * 100``.
    """
    if ax is None:
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111)

    # x positions of the two conditions.
    x_mod1, x_mod2 = 1, 2
    in_cond1 = df[condition] == cond1
    in_cond2 = df[condition] == cond2
    # 19 colours sampled from tab20; reused cyclically beyond 19 subjects.
    colors = plt.cm.tab20(np.linspace(0, 1, 19))

    nb_plotted = 0
    for subject in df["Subject ID"].unique():
        is_subject = df["Subject ID"] == subject
        acc_1 = df[is_subject & in_cond1]["Accuracy"].values
        acc_2 = df[is_subject & in_cond2]["Accuracy"].values
        # Skip subjects that lack one of the two conditions.
        if len(acc_1) == 0 or len(acc_2) == 0:
            continue
        ax.plot([x_mod1, x_mod2], [acc_1[0]*100, acc_2[0]*100], 'o-', label=subject, color=colors[nb_plotted % 19])
        nb_plotted += 1

    # Condition means, connected by a dashed line.
    mean_acc_1 = df[in_cond1]['Accuracy'].mean() * 100
    mean_acc_2 = df[in_cond2]['Accuracy'].mean() * 100
    ax.plot([x_mod1, x_mod2], [mean_acc_1, mean_acc_2], '^--', color='black', label='Mean')

    if sig_levels is not None:
        ax.hlines(sig_levels[0]*100, x_mod1-0.2, x_mod1+0.2, colors='grey', linestyles='--')
        ax.hlines(sig_levels[1]*100, x_mod2-0.2, x_mod2+0.2, colors='grey', linestyles='--')

    ax.set_xticks([x_mod1, x_mod2])
    ax.set_xticklabels([cond1, cond2])
    ax.set_xlabel(condition)
    ax.set_ylabel("Accuracy (%)")
    if show_legend:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    if title is not None:
        ax.set_title(title)
    # Leave a margin on both sides of the two x positions.
    ax.set_xlim(0.75, 2.25)

def plot_isc(modalities, ax, SELEMOD=None, table_parent='tables'):
    """Compare per-fold ISC of the first component between the two datasets.

    Loads the 'SO' and 'OL' fold tables with ``load_isc_folds``, draws them with
    ``plot_compare`` (labels 'Single-object dataset' and 'Superimposed-object
    dataset', one-sided 'greater' test with BH correction) and adds the mean
    'Sig Level (ISC)' of every plotted modality as dashed grey segments.

    Parameters
    ----------
    modalities : sequence of str
        Modalities to load and test.
    ax : matplotlib Axes
        Axes to draw on.
    SELEMOD : sequence of str or None
        Subset of modalities to draw; all are drawn when None.
    table_parent : str
        Root directory of the tables.

    Returns
    -------
    (handles, labels)
        Legend entries of the axes plus a 'Significance Level' entry. The axes
        legend itself is removed so the caller can build a figure-level legend.
    """
    isc_df_sg_folds = load_isc_folds(modalities, xth_comp=1, OVERLAY=False, table_parent=table_parent)
    isc_df_sg_folds = isc_df_sg_folds.rename(columns={'ISC': 'Value'})
    isc_df_ol_folds = load_isc_folds(modalities, xth_comp=1, OVERLAY=True, table_parent=table_parent)
    isc_df_ol_folds = isc_df_ol_folds.rename(columns={'ISC': 'Value'})

    plot_compare(isc_df_sg_folds, isc_df_ol_folds, 'Single-object dataset', 'Superimposed-object dataset', 'CC 1', None, None, ax, 'greater', True, PAIRED=True, SELEMOD=SELEMOD)
    # Mirror the order plot_compare uses (order of appearance in the combined
    # data, filtered by SELEMOD) so every threshold lands on its own boxes even
    # when SELEMOD lists the modalities in a different order.
    seen_modalities = pd.concat([isc_df_sg_folds['Modality'], isc_df_ol_folds['Modality']]).unique()
    plotted = [mod for mod in seen_modalities if SELEMOD is None or mod in SELEMOD]
    for i, mod in enumerate(plotted):
        sig_level_SO = isc_df_sg_folds[isc_df_sg_folds['Modality'] == mod]['Sig Level (ISC)'].mean()
        sig_level_OL = isc_df_ol_folds[isc_df_ol_folds['Modality'] == mod]['Sig Level (ISC)'].mean()
        ax.hlines(y=sig_level_SO, xmin=i-0.45, xmax=i-0.03, linestyle='--', color='grey')
        ax.hlines(y=sig_level_OL, xmin=i+0.03, xmax=i+0.45, linestyle='--', color='grey')
    ax.set_xlabel('')
    ax.set_ylabel('ISC')
    h, l = ax.get_legend_handles_labels()
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    h.append(Line2D([0], [0], color='grey', linestyle='--'))
    l.append('Significance Level')
    return h, l

In [None]:
# Root directory of the result tables; passed as `table_parent` to the loaders below.
table_parent='tables'
# Prefix prepended to every figure file name in the savefig calls.
fig_save_path = '../../Manuscript/2nd/Submission_final/latex-files/images/'
# When False, the '_noBOOTSTRAP' variants of the accuracy tables are loaded.
BOOTSTRAP = True

## Correlations are modulated by attention

In [None]:
# One panel per modality: attended vs. unattended correlations for CC 1 and CC 2.
modalities = ['EEG', 'EOG', 'GAZE', 'GAZE_V', 'EOG_V', 'SACC']
table_name = 'ObjFlow_Corr_Train_Att_Mask_False_REG_False.csv'
corr_df = load_corr(modalities, table_name, table_parent=table_parent)

plt.close('all')
fig, axes = plt.subplots(1, len(modalities), figsize=(12, 5), sharey=True)
for i, ax in enumerate(axes):
    # Keep the handles of the last panel for the shared legend, then drop the panel legend.
    h, l = plot_corr_mod(corr_df, modalities[i], ax)
    ax.legend().remove()
axes[0].set_ylabel('Correlation')

# Shared figure legend, extended with an entry for the dashed threshold line.
significance_line = Line2D([], [], color='grey', linestyle='--', label='Significance Level')
h += [significance_line]
l += ['Significance Level']
fig.legend(h, l, loc='upper center', ncol=len(h))
plt.show()
# plt.savefig(fig_save_path + 'corr_mod.pdf', dpi=600)

## Performance of visual attention decoding tasks

In [None]:
# Decoding accuracy per modality, with permutation-test thresholds and pairwise tests.
modalities = ['EEG', 'EOG', 'GAZE', 'GAZE_V', 'EOG_V', 'SACC']
sig_levels = []
for mod in modalities:
    sig_levels.append(load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent))
table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')
comparisons = [('EEG', 'GAZE_V'), ('GAZE_V', 'SACC'), ('EEG', 'SACC')]
alternative = 'two-sided' # Whether the compared distributions are significantly different
BH_correction = True
plt.close('all')
fig, ax = plt.subplots(figsize=(6, 3.5))
plot_acc(acc_df, None, ax, comparisons, alternative, BH_correction, top=85, sig_levels=sig_levels)
fig.tight_layout()
# Save through the figure object and before plt.show(): on non-interactive
# backends show() releases the figure and a later plt.savefig writes a blank file.
fig.savefig(fig_save_path + 'acc_vad_mod.pdf', dpi=600, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Accuracy when training on superimposed-object data vs. single-object data.
modalities = ['EEG', 'GAZE_V', 'SACC']
sig_levels = []
sig_levels_cp = []
for mod in modalities:
    # Thresholds of the two training conditions come from separate permutation files.
    att_threshold = load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent)
    so_threshold = load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_SO_Mask_False_trial_len30', table_parent=table_parent)
    sig_levels.append(att_threshold)
    sig_levels_cp.append(so_threshold)

# plot_compare expects the plotted quantity in a 'Value' column.
table_name = f"ObjFlow_SVAD_Indpd_Train_SO_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_df_SO = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30').rename(columns={'Accuracy': 'Value'})

table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30').rename(columns={'Accuracy': 'Value'})

plt.close('all')
fig, ax = plt.subplots(figsize=(8, 5))
plot_compare(
    acc_df, acc_df_SO,
    'Train with superimposed objects', 'Train with single object',
    None, None, None, ax, 'two-sided', True,
    sig_levels=sig_levels, sig_levels_cp=sig_levels_cp,
)
adjust_legend([ax])

fig.savefig(fig_save_path + 'acc_vad_SO.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)

## Performance when regressing out eye movements from EEG

In [None]:
# Paired accuracies of 'EEG' vs. 'EEG-EOG&GAZE_V', one line per subject.
table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
modalities = ['EEG', 'EEG-EOG&GAZE_V']
sig_levels = []
for mod in modalities:
    sig_levels.append(load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent))
acc_reg_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')
# The three settings below are only consumed by the optional box-plot view
# (commented out further down); the paired plot does not use them.
comparisons = [('EEG', 'EEG-EOG&GAZE_V')]
alternative = 'greater' # Whether the compared distributions are significantly different
BH_correction = True
plt.close('all')
fig, ax = plt.subplots(figsize=(5, 5))
# plot_acc(acc_reg_df, None, ax, comparisons, alternative, BH_correction)
# plot_acc(acc_reg_df, None, ax, top=85, sig_levels=sig_levels)
plot_acc_pairs(acc_reg_df, 'EEG', 'EEG-EOG&GAZE_V', title=None, ax=ax, show_legend=True, sig_levels=sig_levels)
fig.tight_layout()
# Save through the figure object and before plt.show(): on non-interactive
# backends show() releases the figure and a later plt.savefig writes a blank file.
fig.savefig(fig_save_path + 'acc_reg_eye.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)
plt.show()

## Performance when using a subset of electrodes in the visual cortex

In [None]:
regions = ['frontal', 'frontal_central', 'temporal', 'central_parietal', 'parietal_occipital']
ifREG = False
modalities = ['EEG']

# Thresholds and accuracy tables are collected in plotting order:
# whole brain first, then the regions in reversed order.
sig_levels = [load_acc_permu_test('EEG', 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent)]
table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_{ifREG}{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')
# The 'Modality' column is reused as the category label of the x axis.
acc['Modality'] = 'whole brain'
acc_frames = [acc]

for region in reversed(regions):
    table_name = f"{region}-OF_SVAD_Indpd_Train_Att_Mask_False_REG_{ifREG}{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
    acc = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')
    acc['Modality'] = region
    acc_frames.append(acc)
    sig_levels.append(load_acc_permu_test('EEG', f'{region}-OF_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent))
# ignore_index avoids repeated row labels in the stacked frame.
acc_df = pd.concat(acc_frames, ignore_index=True)

plt.close('all')
fig, ax = plt.subplots(figsize=(9, 4))
comparisons = [('whole brain', 'parietal_occipital'), ('whole brain', 'frontal_central'), ('whole brain', 'frontal'), ('whole brain', 'temporal'), ('whole brain', 'central_parietal')]
alternative = 'greater' # Whether the compared distributions are significantly different
BH_correction = True
plot_acc(acc_df, None, ax, comparisons, alternative, BH_correction, xlabel=False, hpara=[2.8, 0], top=85, sig_levels=sig_levels)
# Lay out and save before showing: on non-interactive backends show() releases
# the figure, so a later tight_layout/savefig would act on an empty figure.
fig.tight_layout()
fig.savefig(fig_save_path + 'acc_regions.pdf', dpi=600, bbox_inches='tight', pad_inches=0)
plt.show()

## Performance after removing saccades

In [None]:
# Accuracy of the 'No SACC' table against ten control runs.
modalities = ['EEG']
acc_tables = []

table_name = f"ObjFlow_SVAD_Train_Aug_Multisubj_Syn_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc = load_acc(modalities, table_name, column='Trial_len=30', table_parent=table_parent)
# The 'Modality' column is reused as the category label of the x axis.
acc['Modality'] = 'No SACC'
acc_tables.append(acc)

for run in range(10):
    table_name = f"Run{run+1}-ObjFlow_SVAD_Train_Aug_Multisubj_Syn_True_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
    acc = load_acc(modalities, table_name, table_parent=table_parent)
    acc['Modality'] = f'Control-{run+1}'
    acc_tables.append(acc)
acc_df = pd.concat(acc_tables)

plt.close('all')
fig, ax = plt.subplots(figsize=(12, 6))
# One test per control run: 'No SACC' against 'Control-1' ... 'Control-10'.
comparisons = [('No SACC', f'Control-{k}') for k in range(1, 11)]
alternative = 'less' # Whether the compared distributions are significantly different
BH_correction = True
plot_acc(acc_df, None, ax, comparisons, alternative, BH_correction, xlabel=False, hpara=[2.5, 3])
plt.show()
# plt.savefig(fig_save_path + 'acc_wo_sacc.jpg', dpi=300)


## Match-mismatch (MM) vs. selective visual attention decoding (SVAD)

In [None]:
# SVAD vs. match-mismatch (MM) accuracy for 'EEG' and 'EEG-EOG&GAZE_V'.
modalities = ['EEG', 'EEG-EOG&GAZE_V']
sig_levels = []
sig_levels_cp = []
for mod in modalities:
    # NOTE(review): both conditions read the same permutation file - confirm
    # that the MM task has no permutation results of its own.
    svad_threshold = load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent)
    mm_threshold = load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent)
    sig_levels.append(svad_threshold)
    sig_levels_cp.append(mm_threshold)

ifREG = False
# plot_compare expects the plotted quantity in a 'Value' column.
table_name = f"ObjFlow_MM_Train_Att_Mask_False_REG_{ifREG}{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_mm_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30').rename(columns={'Accuracy': 'Value'})
table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_{ifREG}{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30').rename(columns={'Accuracy': 'Value'})

plt.close('all')
fig, ax = plt.subplots(figsize=(6, 4))
plot_compare(
    acc_df, acc_mm_df, 'SVAD', 'MM',
    None, None, None, ax, 'less', True,
    sig_levels=sig_levels, sig_levels_cp=sig_levels_cp,
)
adjust_legend([ax])
fig.savefig(fig_save_path + 'acc_vad_mm.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)

In [None]:
modalities = ['EEG', 'EEG-EOG&GAZE_V']
nb_subj = 19
# table_parent='tables'
# Row label -> key of the pickled correlation dict, in the order the rows are emitted.
corr_keys = {
    'Data_Att(Match)': 'EEG_Att_VAD',
    'Data_Unatt': 'EEG_Unatt',
    'Data_Mismatch': 'EEG_MM',
}
# Collect plain records and build the frame once; growing a DataFrame row by row
# is slow and leaves every column with object dtype.
records = []
for modality in modalities:
    for Subj_ID in range(nb_subj):
        file = f"{table_parent}/Overlay/{modality}/ObjFlow_{Subj_ID}_Corrdict_Len_30_REG_False.pkl"
        # NOTE: pickle.load can execute arbitrary code; only open trusted result files.
        with open(file, 'rb') as f:
            corr_dict = pickle.load(f)
        for type_label, key in corr_keys.items():
            # Mean over axis 0 of the first two columns, then summed over those two columns.
            avg_corr = np.sum(np.mean(corr_dict[key][:, :2], axis=0))
            records.append({
                'Subject ID': 'Subj '+str(Subj_ID+1),
                'Modality': modality,
                'Correlation': avg_corr,
                'Type': type_label,
            })
df = pd.DataFrame(records, columns=['Subject ID', 'Modality', 'Correlation', 'Type'])

sig_levels = []
for mod in modalities:
    sig_levels.append(load_acc_permu_test(mod, 'ObjFlow_AvgCorr_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent))
plt.close('all')
fig, ax = plt.subplots(figsize=(7, 4))
corr_plot = []
for mod in modalities:
    table = df[df['Modality'] == mod].copy()
    # Raw string: '\d' inside a plain literal is an invalid escape sequence.
    table['subject_num'] = table['Subject ID'].str.extract(r'(\d+)', expand=False).astype(int)
    # Each subject has three rows with the same key; a stable sort keeps their
    # 'Type' order, which determines the hue order of the plot.
    table = table.sort_values(by='subject_num', kind='stable').drop(columns='subject_num')
    corr_plot.append(table)
df_plot = pd.concat(corr_plot, ignore_index=True)
boxprops = dict(edgecolor='black', alpha=0.5)
sns.boxplot(x='Modality', y='Correlation', data=df_plot, hue='Type', dodge=True, ax=ax, showfliers=False, boxprops=boxprops)
sns.stripplot(x='Modality', y='Correlation', data=df_plot, hue='Type', jitter=True, dodge=True, ax=ax, legend=False)
# plot significance levels
for i, mod in enumerate(modalities):
    sig_val = sig_levels[i]
    ax.plot([i-0.4, i+0.4], [sig_val, sig_val], linestyle='--', color='grey')
# Save through the figure object and before plt.show(): on non-interactive
# backends show() releases the figure and a later plt.savefig writes a blank file.
fig.savefig(fig_save_path + 'eeg_att_unatt_mm.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)
plt.show()

In [None]:
# Density of the 'Att_Unatt' and 'Att_MM' values of one subject, for two segment lengths.
Subj_ID = 0
modality = 'EEG'
feat_name = 'ObjFlow'
REGFEATS = False

# Load the pickled correlation dict of each segment length (30 s first, then 5 s).
corr_by_len = {}
for trial_len in (30, 5):
    file = f'{table_parent}/Overlay/{modality}/{feat_name}_{Subj_ID}_Corrdict_Len_{trial_len}_REG_{REGFEATS}.pkl'
    with open(file, 'rb') as f:
        corr_dict = pickle.load(f)
    corr_by_len[trial_len] = (corr_dict['Att_Unatt'], corr_dict['Att_MM'])
Att_Unatt_30, Att_MM_30 = corr_by_len[30]
Att_Unatt_5, Att_MM_5 = corr_by_len[5]

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(6, 4))
panels = [
    (axes[0], Att_Unatt_30, Att_MM_30, '30-second segments'),
    (axes[1], Att_Unatt_5, Att_MM_5, '5-second segments'),
]
for panel_ax, att_unatt, att_mm, panel_title in panels:
    sns.kdeplot(att_unatt, color='blue', label='Att-Unatt', ax=panel_ax)
    sns.kdeplot(att_mm, color='red', label='Att-Mismatch', ax=panel_ax, linestyle='--')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Correlation Coefficient')
    panel_ax.set_xlim(-1, 1)
# plt.legend() acts on the current axes only.
plt.legend()
plt.show()
# plt.savefig(fig_save_path + 'att_unatt_mm.jpg', dpi=300, bbox_inches='tight', pad_inches=0.05)

## Combining EEG and GAZE_V

In [None]:
# Accuracy of 'EEG', 'GAZE_V' and their combination, for SVAD (left) and MM (right).
modalities = ['EEG', 'GAZE_V', 'EEG+GAZE_V']
sig_levels = []
for mod in modalities:
    sig_levels.append(load_acc_permu_test(mod, 'ObjFlow_Acc_Permu_Train_Att_Mask_False_trial_len30', table_parent=table_parent))
table_name = f"ObjFlow_SVAD_Train_Att_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_comb_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')
table_name = f"ObjFlow_MM_Train_Att_Mask_False_REG_False{'_noBOOTSTRAP' if not BOOTSTRAP else ''}.csv"
acc_comb_mm_df = load_acc(modalities, table_name, table_parent=table_parent, column='Trial_len=30')

plt.close('all')
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
comparisons = [('EEG', 'EEG+GAZE_V'), ('GAZE_V', 'EEG+GAZE_V')]
alternative = 'less'
BH_correction = True
plot_acc(acc_comb_df, 'Selective visual attention decoding', axes[0], comparisons, alternative, BH_correction, top=95, hpara=[4, 5.5], sig_levels=sig_levels)
# NOTE(review): the match-mismatch panel reuses the thresholds loaded above - confirm this is intended.
plot_acc(acc_comb_mm_df, 'Match-mismatch', axes[1], comparisons, alternative, BH_correction, top=95, hpara=[4, 6.5], sig_levels=sig_levels)
# Save through the figure object and before plt.show(): on non-interactive
# backends show() releases the figure and a later plt.savefig writes a blank file.
fig.savefig(fig_save_path + 'acc_vad_mm_comb.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)
plt.show()

## GCCA

In [None]:
# Per-fold ISC of both datasets, split over two panels of different width.
modalities = ['EEG-EOG&GAZE_V', 'EOG', 'GAZE', 'GAZE_V', 'EOG_V', 'SACC']
SELEMOD_1 = ['EEG-EOG&GAZE_V', 'GAZE_V', 'EOG_V', 'SACC']
SELEMOD_2 = ['EOG', 'GAZE']

fig = plt.figure(figsize=(10, 4))

# Create a GridSpec with 1 row and 2 columns
gs = GridSpec(1, 2, width_ratios=[2, 1])  # Adjust width_ratios to set the size of each subplot
# Create subplots with different sizes
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])
# Both panels return their legend entries; those of the last call feed the shared legend.
h, l = plot_isc(modalities, ax1, SELEMOD_1)
h, l = plot_isc(modalities, ax2, SELEMOD_2)
fig.legend(h, l, loc='upper center', ncol=len(h))
# Save through the figure object and before plt.show(): on non-interactive
# backends show() releases the figure and a later plt.savefig writes a blank file.
fig.savefig(fig_save_path + 'isc_fold.pdf', dpi=600, bbox_inches='tight', pad_inches=0.05)
plt.show()