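### Figures for the within- and cross-participant regression analyses
1) Results Figure 2 (average timeseries, snapshot topoplots, example timeseries; the snapshot roseplots were removed in revision)
2) Supplementary Figure S1 (within-participant timeseries for all dimensions)
3) One example timeseries
4) Dynamic roseplot (animated GIF)
5) Dynamic topoplot (animated GIF)
6) Results Figure 3 (within- vs cross-participant timeseries)
7) Supplementary table of pairwise across-minus-within differences between time windows
8) Supplementary figure with cross-participant timeseries for all dimensions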

In [None]:
"""
Author: linateichmann
Email: lina.teichmann@nih.gov

    Created on 2023-03-30 12:40:16
    Modified on 2023-03-30 12:40:16
"""


import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import mne, os
import seaborn as sns
from matplotlib.patches import ConnectionPatch
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

%matplotlib qt

# Root of the THINGS-MEG BIDS dataset. Alternative location:
# bids_dir = '/System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/'
bids_dir = '/Volumes/THINGS-MEG/THINGS-MEG-bids'
# Folder with the regression outputs loaded further down.
res_folder = f'{bids_dir}/derivatives/meg_paper/output/regression'

# Typography shared by all figures in this notebook.
font = 'Arial'
text_size, text_size_big, text_size_small = 12, 16, 6
plt.rcParams.update({'font.size': text_size, 'font.family': font})

# Plot colour and short label of each dimension (one row per dimension, no header row).
colours = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/colors66.txt', header=None)
labels = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/labels_super_short.txt', header=None)


# load data
def load_epochs(bids_dir,participant,behav):
    """Load a participant's preprocessed epochs and attach per-image dimension weights.

    Parameters
    ----------
    bids_dir : str
        BIDS root; epochs are read from
        ``{bids_dir}/derivatives/preprocessed/preprocessed_P{participant}-epo.fif``.
    participant : int
        Participant number used in the file name.
    behav : numpy.ndarray
        2-D weight matrix; rows are indexed by the 0-based THINGS image number,
        columns are the dimensions.

    Returns
    -------
    mne.Epochs
        The epochs (not preloaded). Their metadata is modified in place:
        ``things_image_nr`` and ``things_category_nr`` become 0-based, and columns
        ``dim1`` ... ``dimN`` (N = ``behav.shape[1]``) hold the weights of each
        trial's image, NaN for trials without an image number.
    """
    epochs = mne.read_epochs(f'{bids_dir}/derivatives/preprocessed/preprocessed_P{participant}-epo.fif',preload=False)
    metadata = epochs.metadata  # same DataFrame object; edited in place
    # THINGS-category number & image number start at 1 (matlab based) so subtracting 1 to use for indexing
    metadata['things_image_nr'] = metadata['things_image_nr']-1
    metadata['things_category_nr'] = metadata['things_category_nr']-1

    # adding dimensional weights to metadata. Vectorised: a per-row .loc loop is very slow
    # for many trials and relied on the metadata index being exactly 0..n-1.
    dim_cols = ['dim'+ str(d+1) for d in range(behav.shape[1])]
    metadata[dim_cols] = np.nan
    image_nr = metadata['things_image_nr'].to_numpy(dtype=float, na_value=np.nan)
    has_image = ~np.isnan(image_nr)
    metadata.loc[has_image, dim_cols] = behav[image_nr[has_image].astype(int),:]
    return epochs


# Dimension weights per THINGS image (row = 0-based image number, as used by load_epochs).
behav = np.loadtxt(f'{bids_dir}/sourcedata/meg_paper/predictions_66d_elastic_clip-ViT-B-32_visual_THINGS.txt')
# Participant 1's epochs; later cells take epochs.times / epochs.info from them for plotting.
epochs = load_epochs(bids_dir,1,behav)

# Keep only trials whose trial_type is 'exp' and reset their metadata index to 0..n-1.
epochs_exp = epochs[(epochs.metadata['trial_type']=='exp')]   
epochs_exp.metadata.reset_index(inplace=True,drop=True)


# load results
# all_dat_dims[p]: (timepoints, 66) correlation between predicted and true dimension
#                  values, averaged over the 12 cross-validation folds (participant p+1).
# all_dat_sens[p]: DataFrame read from P{p+1}_linreg_within_predict-sens.csv.
all_dat_dims=[]
all_dat_sens=[]

for p in range(1,5):
    corr_ridge= []
    for cv in range(1,13):
        print(f'loading results from participant {p} cv {cv}')
        # y_pred is indexed [timepoint, sample, dimension]; y_true [sample, dimension]
        y_pred = np.load(f'{res_folder}/P{p}_ridge-reg_within_predict-dims_ypred_cv{cv}.npy')
        y_true = np.load(f'{res_folder}/P{p}_ridge-reg_within_predict-dims_ytrue_cv{cv}.npy')
        # Pearson correlation across samples, for every timepoint and dimension
        corr_t = []
        for t in range(y_pred.shape[0]):
            corr_t.append([np.corrcoef(y_pred[t,:,d],y_true[:,d])[0,1] for d in range(66)])
        corr_ridge.append(np.array(corr_t))
        
    all_dat_dims.append(np.mean(corr_ridge,axis=0))
    all_dat_sens.append(pd.read_csv(f'{res_folder}/P{p}_linreg_within_predict-sens.csv',index_col=0))

# change data order based on peak amplitude
# sorted_idx: dimension indices ordered by the peak (max over time) of the participant-averaged
# correlation, highest first. Every *_ranked array below uses this order on its dimension axis.
avg_dim_data = np.mean(all_dat_dims,axis=0)
sorted_idx = np.flipud(np.argsort(np.max(avg_dim_data,axis=0)))

avg_dim_data_ranked = avg_dim_data[:,sorted_idx]
dim_data_ranked = np.array(all_dat_dims)[:,:,sorted_idx]   # (participants, timepoints, dimensions)
colours_ranked = colours.to_numpy()[sorted_idx,:]
labels_ranked = labels.loc[sorted_idx,:]
labels_ranked.reset_index(inplace=True,drop=True)

# NOTE(review): assumes the only metadata columns starting with 'dim' are dim1..dim66 in that
# order (as added by load_epochs) — TODO confirm no other metadata column matches.
filter_col = [col for col in epochs_exp.metadata.columns if col.startswith('dim')] 
weights_sorted_dims = epochs_exp.metadata.loc[:,filter_col].to_numpy()[:,sorted_idx]


# Rank-ordered labels with '/' removed; used as folder names for the example images.
labels_ranked_list = labels_ranked.loc[:,0].to_list()
labels_ranked_list = [item.replace('/', '') for item in labels_ranked_list]

In [2]:
## load permutation results and check for significance at an individual level
# Per participant, a single threshold is used: the largest 99th percentile (p<0.01) of the
# permutation null over all timepoints and dimensions. Correlations above it count as significant.
dat_dims_arr = np.array(all_dat_dims)  # (participants, timepoints, dimensions)
significance_mat,significance_mat_avg = [],[]
for participant in range(len(all_dat_dims)):
    # load permutations
    corr_ridge_perm_all = np.load(f'{res_folder}/P{participant+1}_ridge-reg_within_predict-dims_permutations.npy') 
    # average over cross-validation splits -> indexed [permutation, timepoint, dimension]
    corr_ridge_perm_avgcv = np.mean(corr_ridge_perm_all,axis=0)
    # 99th percentile of the null distribution for every timepoint and dimension
    thresholds = np.percentile(corr_ridge_perm_avgcv,99,axis=0)
    # threshold defined as the max 99%ile across time and dimensions
    threshold = np.max(thresholds)

    significance_mat.append(dat_dims_arr[participant,:,:]>threshold)
    significance_mat_avg.append(np.mean(dat_dims_arr[participant,:,:],axis=1)>threshold)


# one boolean matrix, participants x time x dimensions, marking where correlations are significant
# (significance_mat_avg: participants x time, for the dimension-averaged correlation)
significance_mat = np.array(significance_mat)
significance_mat_avg = np.array(significance_mat_avg)
significance_mat_rank = significance_mat[:,:,sorted_idx]

In [3]:
# MAKE RESULTS FIGURE
# Figure 2: (A) correlation averaged over dimensions (participants thin, mean thick) with
# topographies of the participant-averaged all_dat_sens values at every x tick;
# (B) six example dimensions spread over the peak-correlation ranking, with example images.
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# ranks (0 = highest peak correlation) of the dimensions shown in panel B
dim_ranks = [int(i) for i in np.round(np.linspace(0,65,6))]

plt.close('all')
axd = plt.figure(constrained_layout=True, figsize=(8.25, 11.75)).subplot_mosaic(
    """
    .AAAAAA
    .AAAAAA
    .BBBCCC
    .BBBCCC
    .DDDEEE
    .DDDEEE
    .FFFGGG
    .FFFGGG
    .......
    """
)

fig = plt.gcf()
#### panel A
axd['A'].plot(epochs.times*1000, epochs.times*0, 'grey', lw=1, linestyle='--')

# plot individual data and mean; significant timepoints are one row of dots per participant
for p in range(4):
    axd['A'].plot(epochs.times*1000,all_dat_dims[p].mean(axis=1),'k',alpha=0.3,lw=1)
    sig_tp = (epochs.times*1000)[significance_mat_avg[p,:]]
    axd['A'].plot(sig_tp,np.repeat(-.02-(p*1/100),len(sig_tp)),'k.',alpha=0.3,markersize=3)
    axd['A'].text(1310,-.025-(p*1/100),f'S0{p+1}: p<0.01',c='k',fontsize=5.8,ha='left')

axd['A'].plot(epochs.times*1000,np.mean(np.mean(all_dat_dims,axis=0),axis=1),'k',alpha=1,lw=3)

axd['A'].set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])
axd['A'].set_ylim([-0.07,0.15])
axd['A'].set_xlabel('time (ms)')
axd['A'].set_ylabel('Correlation:\n true vs predicted label')

axd['A'].spines['top'].set_visible(False)
axd['A'].spines['right'].set_visible(False)

ts = np.arange(-100,1400,200)
axd['A'].set_xticks(ts)

# add topoplots below panel A, one per tick, each linked to its tick by a dashed arrow
avg_sens = np.mean(all_dat_sens,axis=0)  # participant average, computed once
height_topo = .2
width_topo = 150
for t in ts:
    # sample index of t ms (t must coincide exactly with a sample time)
    tix = np.where(t==epochs.times*1000)[0][0]
    axin2 = axd['A'].inset_axes([t-width_topo/2,-0.32, width_topo,height_topo], transform= axd['A'].transData)

    topo,cm=mne.viz.plot_topomap(avg_sens[tix,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=axin2,
                                    sensors=False,contours=False)

    # a single colourbar, created with the first topoplot
    if t == ts[0]:
        cbar_ax = fig.add_axes([0.04,0.72,0.008,0.05])
        clb = fig.colorbar(topo, cax=cbar_ax)

    # add arrows
    con = ConnectionPatch(xyA=(t,axd['A'].get_ylim()[0]), xyB=(0,axin2.get_ylim()[1]), 
                            coordsA="data", coordsB="data",
                            axesA=axd['A'], axesB=axin2, color="black", linestyle='--',arrowstyle="-|>")
    axd['A'].add_artist(con)

# white background behind the x tick labels (the arrows start at the bottom of panel A)
plt.setp(axd['A'].get_xticklabels(), backgroundcolor="white")


# panel B: individual timeseries of the example dimensions
for rank,ax in zip(dim_ranks,[axd['B'],axd['C'],axd['D'],axd['E'],axd['F'],axd['G']]):
    ax.plot(epochs.times*1000, epochs.times*0, 'grey', lw=1, linestyle='--')

    for p in range(4):
        ax.plot(epochs.times*1000,dim_data_ranked[p,:,rank],c=colours_ranked[rank,:],alpha=0.3,lw=1)
        sig_tp = (epochs.times*1000)[significance_mat_rank[p,:,rank]]
        ax.plot(sig_tp,np.repeat(-.02-(p*1/50),len(sig_tp)),'k.',mfc=colours_ranked[rank,:],mec=colours_ranked[rank,:],alpha=0.3,markersize=3)

    ax.plot(epochs.times*1000,avg_dim_data_ranked[:,rank],c=colours_ranked[rank,:],lw=2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim([-0.09,0.5])
    ax.set_ylabel('Correlation')
    ax.set_xlabel('time (ms)')
    ax.set_title(' ')  # blank title, presumably to reserve room for the label text at y=0.5
    ax.text(700,0.5,f'{labels_ranked.loc[rank,0]} ({sorted_idx[rank]+1})',c='k',fontsize=text_size,ha='center')

    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])


# add example images: image {ii}.jpg of each example dimension at x (axes fraction)
example_im_path = f'{bids_dir}/sourcedata/meg_paper/example_images/'

for ii,x in enumerate(np.arange(0.1,0.9,0.15)):
    for i,dim in enumerate(dim_ranks):
        image = plt.imread(f'{example_im_path}/{labels_ranked_list[dim]}/{ii}.jpg')
        imagebox = OffsetImage(image, zoom=0.08)
        ab = AnnotationBbox(imagebox, (x, .73), frameon=False, xycoords='axes fraction', boxcoords="offset points", pad=0)
        axd[chr(ord('B')+i)].add_artist(ab)

# labels
axd['B'].text(-400,0.3,' ',fontsize=text_size_big*3)  # invisible text, presumably a layout spacer
fig.text(0.01,0.97,'A',fontsize=text_size_big*2)
fig.text(0.01,0.67,'B',fontsize=text_size_big*2)

fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Figure2.pdf')





In [4]:
# supplementary figure showing all dimension timecourses
# One panel per dimension in rank order (6 x 11 grid): participants thin, mean thick,
# significant timepoints as one row of dots per participant.
plt.close('all')
fig,axs = plt.subplots(figsize=(8.25, 11.75),num=2,ncols=6,nrows=11,sharex=True,sharey=True)

axs=axs.flatten()

for i,ax in enumerate(axs):
    ax.plot(epochs.times*1000,epochs.times*0,'grey',linestyle='--')
    for p in range(4):
        ax.plot(epochs.times*1000,dim_data_ranked[p,:,i],c=colours_ranked[i,:],alpha=0.3,lw=1)
        sig_tp = (epochs.times*1000)[significance_mat_rank[p,:,i]]
        ax.plot(sig_tp,np.repeat(-.02-(p*1/50),len(sig_tp)),'k.',mfc=colours_ranked[i,:],mec=colours_ranked[i,:],alpha=0.3,markersize=.8)

    ax.plot(epochs.times*1000,avg_dim_data_ranked[:,i],c=colours_ranked[i,:],lw=2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim([-0.11,0.3])

    ax.set_title(f'{labels_ranked.loc[i,0]}\n({sorted_idx[i]+1})',c='k',fontsize=text_size_small)
    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])


fig.supylabel('Correlation: true vs predicted label')
fig.supxlabel('time (ms)')
fig.tight_layout()
fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Supplementary_within.pdf')



In [4]:
# figure for example timeseries
# Participant-averaged timecourse of one dimension (selected by rank), saved with a transparent background.
fig,ax = plt.subplots(1,1,figsize=(4,3))

i = 7  # rank of the example dimension
# original dimension number, derived the same way as the "(N)" labels in the other figures
dim_number = sorted_idx[i]+1
ax.plot(epochs.times*1000,epochs.times*0,'grey',linestyle='--')
ax.plot(epochs.times*1000,avg_dim_data_ranked[:,i],c=colours_ranked[i,:],alpha=1,lw=2)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim([-0.05,0.3])
ax.set_title(f'Dimension {dim_number}: ' + str(labels_ranked.loc[i,0]),c=colours_ranked[i,:],fontsize=text_size)
ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])


ax.set_ylabel('Correlation:\ntrue vs predicted label')
ax.set_xlabel('time (ms)')

fig.tight_layout()
fig.savefig(f'{bids_dir}/sourcedata/meg_paper/example_timeseries.png', transparent=True)


In [27]:
# roseplot animation
# Polar bar chart of the participant-averaged correlation of every dimension (rank order),
# animated over time and saved as a GIF. Dimensions above label_threshold get a text label.
plt.close('all')
mean_dim_data_ranked = np.mean(dim_data_ranked,axis=0)  # (timepoints, dimensions), computed once
# number of bars taken from the data so theta and width always match it
n_dims = mean_dim_data_ranked.shape[1]
theta = np.linspace(0.0, 2 * np.pi, n_dims, endpoint=False)
width = np.pi/n_dims
label_threshold = 0.1

fig= plt.figure(num=1, figsize=(10,7))
ax = plt.subplot(projection='polar')
title = ax.text(0.5,0.85, "", bbox={'facecolor':'w', 'alpha':0.5, 'pad':5},
                transform=ax.transAxes, ha="center")


def time_label(t):
    """Return the time of sample t rounded (not truncated) to whole milliseconds, e.g. '-100 ms'."""
    return f'{int(round(epochs.times[t]*1000))} ms'


def init(t=0):
    """Draw the bars for sample t (first frame) and set the radial limits."""
    radii = mean_dim_data_ranked[t,:]
    bars = ax.bar(theta, radii, width=width, bottom=0, color=colours_ranked, alpha=1)
    ax.set_ylim(0,0.3)
    ax.set_rlabel_position(315)
    title.set_text(time_label(t))
    return bars


def update(t):
    """Redraw the whole polar axes for sample t and return the bar container."""
    ax.cla()
    ax.set_ylim(0,0.3)
    ax.text(np.radians(ax.get_rlabel_position())-.1,ax.get_rmax()/2,'Correlation',
        rotation=ax.get_rlabel_position(),ha='center',va='center')
    ax.axes.set_xticks([np.radians(315)])
    ax.axes.set_xticklabels('')
    ax.yaxis.grid(True,color='k',linestyle=':',alpha=0.5)

    radii = mean_dim_data_ranked[t,:]
    bars = ax.bar(theta, radii, width=width, bottom=0, color=colours_ranked, alpha=1)

    # np.where returns indices, so iterate them directly: testing np.any(indices) would be
    # False when the only qualifying index is 0, i.e. the top-ranked dimension
    for dim in np.where(radii>label_threshold)[0]:
        rotation = np.rad2deg(theta[dim])
        if 90 < rotation < 270:
            # left half: flip the text so it is not upside down and anchor it at its right end
            rotation = rotation-180
            label = labels_ranked.loc[dim,0] +' '
            ha = 'right'
        else:
            label = ' ' + labels_ranked.loc[dim,0]
            ha = 'left'
        ax.text(theta[dim],radii[dim],label, ha=ha, va='center', rotation=rotation, rotation_mode="anchor",fontsize=8)

    ax.set_title(time_label(t))
    title.set_text(time_label(t))
    return bars

# animate all but the last 80 samples
ani = FuncAnimation(fig, update, frames=len(epochs.times)-80,init_func=init,blit=False,cache_frame_data=False,repeat=False)
writer = animation.PillowWriter(fps=3)

ani.save(f'{bids_dir}/derivatives/meg_paper/figures/roseplot_short.gif',writer=writer, dpi=600)


In [5]:
# topoplot animation
# Topography of the participant-averaged all_dat_sens values, animated over time and saved as a GIF.
plt.close('all')
avg_sens = np.mean(all_dat_sens,axis=0)  # computed once instead of on every frame
fig,ax = plt.subplots(num=1, figsize=(5,5))
title = ax.text(0.5,0.85, "", bbox={'facecolor':'w', 'alpha':0.5, 'pad':5},
                transform=ax.transAxes, ha="center")


def init(t=0):
    """Draw the topography for sample t, add the colourbar (once) and return the image."""
    topo,cm = mne.viz.plot_topomap(avg_sens[t,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=ax,
                                        sensors=True,contours=False)
    cbar = fig.colorbar(topo,fraction=0.03, pad=0.04)
    cbar.set_label('Correlation')
    fig.tight_layout()
    title.set_text(f'{int(round(epochs.times[t]*1000))} ms')
    return topo


def update(t):
    """Redraw the topography for sample t, titled with its time rounded to whole ms."""
    ax.cla()
    topo,cm = mne.viz.plot_topomap(avg_sens[t,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=ax,
                                        sensors=True,contours=False)
    fig.tight_layout()
    ax.set_title(f'{int(round(epochs.times[t]*1000))} ms')
    return topo


# animate all but the last 80 samples
ani = FuncAnimation(fig, update, frames=len(epochs.times)-80,init_func=init,blit=False,cache_frame_data=False,repeat=False)

writer = animation.PillowWriter(fps=3)

ani.save(f'{bids_dir}/derivatives/meg_paper/figures/topoplot_short.gif',writer=writer, dpi=600)



In [5]:
## load cross-participant results
# For each test participant, average the cross-participant results over the three other
# (training) participants and the 12 cross-validation folds.
all_dat_dims_cross = []
for test_participant in range(1,5):
    train_participants = [i for i in np.arange(1,5) if i !=test_participant]
    res_ppt = [
        pd.read_csv(f'{res_folder}/Ridge-reg_cross_trainP{train_participant}_testP{test_participant}_cv{cv}.csv',index_col=0).to_numpy()
        for train_participant in train_participants
        for cv in range(1,13)
    ]
    all_dat_dims_cross.append(np.mean(res_ppt,axis=0))
all_dat_dims_cross = np.array(all_dat_dims_cross)

# reorder dimensions by the within-participant ranking, then average over test participants
dim_data_cross_ranked = all_dat_dims_cross[:,:,sorted_idx]
dim_data_cross_ranked_avg = np.mean(dim_data_cross_ranked,axis=0)


In [6]:
# load permutations and check whether cross-decoding & difference between within & across is significant 
# Thresholds follow the within-participant analysis: the largest per-(timepoint, dimension)
# 99th percentile (p<0.01) of the permutation null, averaged over training participants.
all_dat_dims_diff = all_dat_dims_cross-np.array(all_dat_dims)  # across minus within

significance_mat_cross,significance_mat_avg_cross = [],[]
significance_mat_diff,significance_mat_avg_diff = [],[]

for test_participant in range(1,5):
    # the within-participant null of the test participant does not depend on the training participant
    corr_ridge_perm_within_all = np.load(f'{res_folder}/P{test_participant}_ridge-reg_within_predict-dims_permutations.npy') 
    corr_ridge_perm_avgcv_cross,corr_ridge_perm_avgcv_diff = [],[]
    for train_participant in range(1,5):
        if train_participant == test_participant:
            continue
        corr_ridge_perm_cross_all = np.load(f'{res_folder}/Ridge-reg_cross_trainP{train_participant}_testP{test_participant}_permutations.npy') 
        diff_perm = corr_ridge_perm_cross_all-corr_ridge_perm_within_all

        # average over cross-validation splits
        corr_ridge_perm_avgcv_cross.append(np.mean(corr_ridge_perm_cross_all,axis=0))
        corr_ridge_perm_avgcv_diff.append(np.mean(diff_perm,axis=0))

    # average over train participants -> indexed [permutation, timepoint, dimension]
    corr_ridge_perm_avg_cross = np.mean(corr_ridge_perm_avgcv_cross,axis=0)
    corr_ridge_perm_avg_diff = np.mean(corr_ridge_perm_avgcv_diff,axis=0)
    # 99th percentile for each timepoint and each dimension null-distribution
    thresholds_cross = np.percentile(corr_ridge_perm_avg_cross,99,axis=0)
    thresholds_diff = np.percentile(corr_ridge_perm_avg_diff,99,axis=0)

    # threshold defined as the max 99%ile across time and dimensions
    threshold_cross = np.max(thresholds_cross)
    threshold_diff = np.max(thresholds_diff)

    significance_mat_cross.append(all_dat_dims_cross[test_participant-1,:,:]>threshold_cross)
    significance_mat_avg_cross.append(np.mean(all_dat_dims_cross[test_participant-1,:,:],axis=1)>threshold_cross)

    # NOTE(review): this tests the upper tail of the across-minus-within null, i.e. across > within
    # (one-sided); confirm that is the intended direction.
    significance_mat_diff.append(all_dat_dims_diff[test_participant-1,:,:]>threshold_diff)
    significance_mat_avg_diff.append(np.mean(all_dat_dims_diff[test_participant-1,:,:],axis=1)>threshold_diff)


# boolean matrices, participants x time x dimensions, marking where effects are significant
significance_mat_cross = np.array(significance_mat_cross)
significance_mat_avg_cross = np.array(significance_mat_avg_cross)
significance_mat_rank_cross = significance_mat_cross[:,:,sorted_idx]

significance_mat_diff = np.array(significance_mat_diff)
significance_mat_avg_diff = np.array(significance_mat_avg_diff)
significance_mat_rank_diff = significance_mat_diff[:,:,sorted_idx]




In [54]:
# MAKE RESULTS FIGURE
# Figure 3: (A) per participant, dimension-averaged within- (faint) and across-participant
# (solid) correlation, with an inset of the across-minus-within difference at selected
# timepoints; (B) within vs across timecourses of four example dimensions.
plt.close('all')
axd = plt.figure(constrained_layout=True,figsize=(8.25, 11.75)).subplot_mosaic(

    """
    ....
    WXYZ
    ....
    CDEF
    GHIJ
    KLMN
    OPQR
    ....
    ....
    ....
    """)

fig = plt.gcf()

# timepoints of interest for the difference boxplots and their sample indices
# (each must coincide exactly with a sample time)
x = epochs.times*1000
points_of_interest_ms = [100, 200, 300, 400, 500, 600 ]
points_of_interest=np.array([np.where(x==v)[0][0] for v in points_of_interest_ms])

### A
for p,ax in enumerate([axd['W'],axd['X'],axd['Y'],axd['Z']]):
    # within
    ax.plot(epochs.times*1000,np.mean(dim_data_ranked[p,:,:],axis=1),'gray',alpha=0.3,label='within')
    sig_tp = (epochs.times*1000)[significance_mat_avg[p,:]]
    ax.plot(sig_tp,np.repeat(0,len(sig_tp)),marker='.',color='gray',alpha=0.1,markersize=1)

    #cross
    ax.plot(epochs.times*1000,np.mean(dim_data_cross_ranked[p,:,:],axis=1),'gray',label='across')
    sig_tp = (epochs.times*1000)[significance_mat_avg_cross[p,:]]
    ax.plot(sig_tp,np.repeat(-.007,len(sig_tp)),marker='.',color='gray',alpha=1,markersize=1)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])
    ax.set_ylim([-0.01,0.15])
    ax.legend(frameon=False,fontsize=8)

    axin2 = ax.inset_axes([-100,-0.15, 1400,0.13], transform= ax.transData)

    # Boxplots at timepoints of interest: across-minus-within difference of every dimension,
    # averaged over the 5-sample window centred on each timepoint
    bp_data = np.mean([all_dat_dims_diff[p,i:i2,:] for i,i2 in zip(points_of_interest-2,points_of_interest+3)],axis=1)
    bp = axin2.boxplot(bp_data.T, patch_artist=True,positions=points_of_interest_ms,widths=80,showfliers=False)
    for box in bp['boxes']:
        box.set_color('grey')
        box.set_alpha(0.6)
    for median in bp['medians']:
        median.set_color('red')

    # NOTE(review): the dots show each dimension's difference at the single centre sample,
    # whereas the boxes summarise the 5-sample window average (bp_data) — confirm intended.
    axin2.scatter(np.repeat(points_of_interest_ms,66),all_dat_dims_diff[p,points_of_interest,:],1,'k')

    axin2.set_xticks([0,500,1000])
    axin2.set_xticklabels([0,500,1000])
    axin2.set_ylim([-0.14,0.02])

    # stylize axis 
    axin2.spines['top'].set_visible(False)
    axin2.spines['right'].set_visible(False)
    axin2.set_title('')
    axin2.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

    # y labels only on the first participant's column
    if p==0:
        axin2.set_ylabel('Difference',fontsize=8)
        ax.set_ylabel('Correlation', labelpad=10,fontsize=8)
    else:
        ax.set_yticks([])
        axin2.set_yticks([])

    ax.set_title(f'S0{p+1}')
    axin2.set_xlabel('time (ms)')


### B
# individual timeseries of four example dimensions (indices into the peak-correlation ranking)
dim_ranks = [0,4,7,8]


all_ax = [[axd['C'],axd['D'],axd['E'],axd['F']],
            [axd['G'],axd['H'],axd['I'],axd['J']],
            [axd['K'],axd['L'],axd['M'],axd['N']],
            [axd['O'],axd['P'],axd['Q'],axd['R']]]
# one column per participant, one row per example dimension
for p in range(4):
    for rank,ax in zip(dim_ranks,np.array(all_ax)[:,p]):
        ax.plot(epochs.times*1000, epochs.times*0, 'grey', lw=1, linestyle='--')
        ax.plot(epochs.times*1000,dim_data_ranked[p,:,rank],c=colours_ranked[rank,:],linestyle='-',lw=1,alpha=0.4,label='within')
        ax.plot(epochs.times*1000,dim_data_cross_ranked[p,:,rank],c=colours_ranked[rank,:],lw=2,label='across')

        # significant timepoints of the across-participant prediction
        sig_tp = (epochs.times*1000)[significance_mat_rank_cross[p,:,rank]]
        ax.plot(sig_tp,np.repeat(-.02,len(sig_tp)),marker='.',c=colours_ranked[rank,:],alpha=1,markersize=1)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_ylim([-0.05,0.37])
        ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])
        ax.legend(frameon=False,fontsize=8)

        if p ==0:
            label_text = f'{labels_ranked.loc[rank,0]} ({sorted_idx[rank]+1})'
            slash_index = label_text.rfind("/")
            bracket_index = label_text.find("(")

            # Break the line after the last "/" if there is one, otherwise before the first "("
            if slash_index != -1:
                label_text = f'{label_text[:slash_index+1]}\n{label_text[slash_index+1:]}'
            else:
                label_text = f'{label_text[:bracket_index]}\n{label_text[bracket_index:]}'
            ax.set_ylabel(f'{label_text}', labelpad=10,fontsize=8)
        if p>0:
            ax.set_yticks([])

        if rank == dim_ranks[-1]:
            ax.set_xlabel('time (ms)')
        else:
            ax.set_xticklabels('')
        if rank ==dim_ranks[0]:
            ax.set_title(f'S0{p+1}', y=0.9)

# add A and B labels
fig.text(0.01,0.92,'A',fontsize=text_size_big*2)
fig.text(0.01,0.63,'B',fontsize=text_size_big*2)


fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Figure3.pdf')

In [45]:
# supplementary with table with differences and stats for diff timewindows
# Left column: across-minus-within difference at the timepoints of interest (boxplots over
# dimensions), one row per participant. Right column: table of pairwise mean differences
# between timepoints (row minus column); '*' / green = paired t-test across dimensions
# significant after Bonferroni correction at alpha=0.01, red = not significant.
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import multipletests
plt.close('all')
axd = plt.figure(constrained_layout=True,figsize=(8.25, 11.75)).subplot_mosaic(
    """
    HA.
    IB.
    JC.
    KD.
    ...
    ...
    ...
    ...
    """)

fig = plt.gcf()

# timepoints of interest and their sample indices (each must coincide exactly with a sample time)
x = epochs.times*1000
points_of_interest_ms = [100, 200, 300, 400, 500, 600 ]
points_of_interest=np.array([np.where(x==v)[0][0] for v in points_of_interest_ms])
n_timepoints = len(points_of_interest)
# all index pairs (i, j) with i < j; this order is shared by the tests and the table
pairs = [(i, j) for i in range(n_timepoints) for j in range(i + 1, n_timepoints)]

# per participant: (timepoints of interest, dimensions) difference, averaged over the
# 5-sample window centred on each timepoint of interest
bp_data_all = [np.mean([all_dat_dims_diff[p,i:i2,:] for i,i2 in zip(points_of_interest-2,points_of_interest+3)],axis=1)
               for p in range(4)]


for p,ax in enumerate([axd['H'],axd['I'],axd['J'],axd['K']]):
    # Boxplots at timepoints of interest
    bp_data = bp_data_all[p]
    bp = ax.boxplot(bp_data.T, patch_artist=True,positions=points_of_interest_ms,widths=80,showfliers=False)
    for box in bp['boxes']:
        box.set_color('grey')
        box.set_alpha(0.6)
    for median in bp['medians']:
        median.set_color('red')

    # NOTE(review): dots are the single centre-sample values, boxes the 5-sample window averages
    ax.scatter(np.repeat(points_of_interest_ms,66),all_dat_dims_diff[p,points_of_interest,:],1,'k')

    ax.set_xticks([0,500,1000])
    ax.set_xticklabels([0,500,1000])
    ax.set_ylim([-0.14,0.02])

    # stylize axis 
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

    ax.set_ylabel('Difference')
    ax.set_title(f'S0{p+1}')
    ax.set_xlabel('time (ms)')


for p,ax in enumerate([axd['A'],axd['B'],axd['C'],axd['D']]):
    bp_data = bp_data_all[p]

    # paired t-test (pairs = dimensions) for every pair of timepoints, Bonferroni-corrected
    p_values,mean_differences = [],[]
    for i, j in pairs:
        _, p_val = ttest_rel(bp_data[i], bp_data[j])
        p_values.append(p_val)
        mean_differences.append(np.mean(bp_data[i]) - np.mean(bp_data[j]))
    rejections, corrected_p_values = multipletests(p_values, alpha=0.01,method='bonferroni')[:2]

    # Table: upper triangle holds the mean difference (3 decimals, no leading zero)
    table_data = np.full((n_timepoints, n_timepoints), '', dtype=object)
    colors = np.full((n_timepoints, n_timepoints), 'white', dtype=object)
    for idx, (i, j) in enumerate(pairs):
        formatted = f"{mean_differences[idx]:.3f}"
        if formatted.startswith('0'):
            formatted = formatted[1:]
        if formatted.startswith('-0'):
            formatted = formatted.replace('-0', '-')
        if rejections[idx]:
            table_data[i, j] = f"{formatted}*"
            colors[i, j] = 'green'  # Reject null
        else:
            table_data[i, j] = f"{formatted}"
            colors[i, j] = 'red'    # Do not reject null

    ax.axis('tight')
    ax.axis('off')

    # Create the table
    table = ax.table(cellText=table_data, cellColours=colors, colLabels=[f'{i}ms' for i in points_of_interest_ms],
                    rowLabels=[f' {i}ms ' for i in points_of_interest_ms], cellLoc='center', loc='center')

    # Explicitly blank the lower-triangle and first-column cells (table_data leaves them empty)
    for i in range(n_timepoints):
        for j in range(n_timepoints):
            if j < i:  # Lower triangle
                table[i + 1, j + 1].set_text_props(text='')
            if j == 0:  # First column
                table[i + 1, j].set_text_props(text='')

    # Adjust font size and table properties
    for key, cell in table.get_celld().items():
        cell.set_fontsize(12)
        cell.set_edgecolor('k')

    # same limits as the boxplot column
    ax.set_xlim([-100,1400])
    ax.set_ylim([-0.14,0.02])

fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Supplementary_timewindows_stats.pdf')


In [402]:
# supplementary figure showing all dimension timecourses
# Cross-participant version: one panel per dimension in rank order (6 x 11 grid), test
# participants thin, their mean thick, significant timepoints as one row of dots per participant.
plt.close('all')
fig,axs = plt.subplots(figsize=(8.25, 11.75),num=2,ncols=6,nrows=11,sharex=True,sharey=True)
text_size_small = 7

axs=axs.flatten()

for i,ax in enumerate(axs):
    ax.plot(epochs.times*1000,epochs.times*0,'grey',linestyle='--')
    for p in range(4):
        ax.plot(epochs.times*1000,dim_data_cross_ranked[p,:,i],c=colours_ranked[i,:],alpha=0.3,lw=1)
        sig_tp = (epochs.times*1000)[significance_mat_rank_cross[p,:,i]]
        ax.plot(sig_tp,np.repeat(-.02-(p*1/50),len(sig_tp)),'k.',mfc=colours_ranked[i,:],mec=colours_ranked[i,:],alpha=0.3,markersize=.8)

    ax.plot(epochs.times*1000,dim_data_cross_ranked_avg[:,i],c=colours_ranked[i,:],lw=2)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim([-0.11,0.3])

    ax.set_title(f'{labels_ranked.loc[i,0]}\n({sorted_idx[i]+1})',c='k',fontsize=text_size_small)
    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])


fig.supylabel('Correlation: true vs predicted label')
fig.supxlabel('time (ms)')
fig.tight_layout()

fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Supplementary_cross.pdf')