### Figures for the within-participant regression analysis
1) Results Figure 2 (average timeseries, snapshot roseplots, snapshot topoplots, example timeseries)
2) Supplementary Figure S1 (timeseries for all dimensions)
3) Example timeseries for a single dimension
4) Dynamic roseplot (animated GIF)
5) Dynamic topoplot (animated GIF)

In [None]:
"""
Author: linateichmann
Email: lina.teichmann@nih.gov

    Created on 2023-03-30 12:40:16
    Modified on 2023-03-30 12:40:16
"""


import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import mne, os
import seaborn as sns
from matplotlib.patches import ConnectionPatch
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
%matplotlib qt

# Root of the THINGS-MEG BIDS dataset. The THINGS_MEG_BIDS_DIR environment variable
# overrides the default (the original analysis machine), so the notebook can be rerun
# elsewhere without editing it. Trailing slashes are stripped because every path in
# this notebook is built as f'{bids_dir}/...'.
bids_dir = os.environ.get(
    'THINGS_MEG_BIDS_DIR',
    '/System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/',
).rstrip('/')
# per-participant regression result CSVs are read from this folder
res_folder = f'{bids_dir}/derivatives/meg_paper/output/regression'

# plotting defaults shared by all figures
font = 'Arial'
text_size = 12
text_size_big = 16
text_size_small = 8

plt.rcParams['font.size'] = text_size
plt.rcParams['font.family'] = font

# one colour per dimension (rows) and one short label per dimension (column 0)
colours = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/colors66.txt',header=None)
labels = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/labels_super_short.txt',header=None)


# load data
def add_dimension_weights(metadata, behav):
    """Return a copy of *metadata* with one ``dim<k>`` column per column of *behav*.

    Row ``r`` of the result receives ``behav[metadata['things_image_nr'][r], :]``.
    Rows whose ``things_image_nr`` is NaN (trials without a THINGS image) get NaN in
    every ``dim<k>`` column. Selection uses a boolean mask, so any DataFrame index works.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Trial metadata with a 0-based ``things_image_nr`` column (NaN allowed).
    behav : numpy.ndarray
        2-D array of dimension weights, one row per THINGS image.

    Returns
    -------
    pandas.DataFrame
        New DataFrame; *metadata* itself is left unmodified.
    """
    dim_cols = ['dim' + str(d + 1) for d in range(behav.shape[1])]
    out = metadata.assign(**{col: np.nan for col in dim_cols})
    has_image = out['things_image_nr'].notna()
    if has_image.any():
        image_rows = out.loc[has_image, 'things_image_nr'].astype(int).to_numpy()
        out.loc[has_image, dim_cols] = behav[image_rows, :]
    return out


def load_epochs(bids_dir,participant,behav):
    """Load one participant's preprocessed epochs and attach per-trial dimension weights.

    Parameters
    ----------
    bids_dir : str
        Root of the BIDS dataset.
    participant : int
        Participant number ``N`` in ``preprocessed_P{N}-epo.fif``.
    behav : numpy.ndarray
        2-D array of dimension weights, one row per (0-based) THINGS image.

    Returns
    -------
    Epochs object from ``mne.read_epochs`` (not preloaded) whose metadata has 0-based
    ``things_image_nr`` / ``things_category_nr`` and one ``dim<k>`` column per column
    of *behav* (NaN for trials without an image).
    """
    epochs = mne.read_epochs(f'{bids_dir}/derivatives/preprocessed/preprocessed_P{participant}-epo.fif',preload=False)
    metadata = epochs.metadata.copy()
    # THINGS-category number & image number start at 1 (MATLAB-based), so subtract 1 to use them for indexing
    metadata['things_image_nr'] = metadata['things_image_nr']-1
    metadata['things_category_nr'] = metadata['things_category_nr']-1
    # assign through the public setter instead of relying on the getter returning a live reference
    epochs.metadata = add_dimension_weights(metadata, behav)
    return epochs


# Dimension weights for every THINGS image: load_epochs indexes rows by the 0-based
# THINGS image number and takes all columns as the dimension weights.
behav = np.loadtxt(f'{bids_dir}/sourcedata/meg_paper/predictions_66d_elastic_clip-ViT-B-32_visual_THINGS.txt')
# In this notebook, participant 1's epochs supply the time axis (epochs.times), the
# sensor layout (epochs.info) and the trial metadata.
epochs = load_epochs(bids_dir,1,behav)

# keep only trials whose trial_type is 'exp'
epochs_exp = epochs[(epochs.metadata['trial_type']=='exp')]   
# NOTE(review): this in-place reset only takes effect if `epochs_exp.metadata` returns
# the live DataFrame rather than a copy — confirm for the installed MNE version.
epochs_exp.metadata.reset_index(inplace=True,drop=True)


# load results
# One pair of CSVs per participant (P1-P4). all_dat_dims: per-dimension timeseries
# (rows are time samples; they are plotted against epochs.times). all_dat_sens: the
# sensor-wise results used for the topoplots.
all_dat_dims=[]
all_dat_sens=[]

for p in range(1,5):
    all_dat_dims.append(pd.read_csv(f'{res_folder}/P{p}_linreg_within.csv',index_col=0))
    all_dat_sens.append(pd.read_csv(f'{res_folder}/P{p}_linreg_within_predict-sens.csv',index_col=0))

# change data order based on peak amplitude
# avg_dim_data is the participant average (time x dimension); sorted_idx orders the
# dimensions by their maximum over time, from highest to lowest.
avg_dim_data = np.mean(all_dat_dims,axis=0)
sorted_idx = np.flipud(np.argsort(np.max(avg_dim_data,axis=0)))

# Apply the same ordering to the data (participant x time x dimension), colours and
# labels. Afterwards index k means "k-th strongest dimension" and sorted_idx[k]+1 is
# the dimension number shown in parentheses in the panel titles.
avg_dim_data_ranked = avg_dim_data[:,sorted_idx]
dim_data_ranked = np.array(all_dat_dims)[:,:,sorted_idx]
colours_ranked = colours.to_numpy()[sorted_idx,:]
labels_ranked = labels.loc[sorted_idx,:]
labels_ranked.reset_index(inplace=True,drop=True)

# per-trial dimension weights in ranked order (not used elsewhere in this notebook)
filter_col = [col for col in epochs_exp.metadata.columns if col.startswith('dim')] 
weights_sorted_dims = epochs_exp.metadata.loc[:,filter_col].to_numpy()[:,sorted_idx]


# ranked labels with '/' removed so they can be used as folder names of the example images
labels_ranked_list = labels_ranked.loc[:,0].to_list()
labels_ranked_list = [item.replace('/', '') for item in labels_ranked_list]

In [None]:
# MAKE RESULTS FIGURE
# Panel A: participant-averaged correlation timeseries (mean over all dimensions) with
# rose-plot and topoplot snapshots at fixed time points. Panels B-G and the example
# images are added further down in this cell.

# six ranks spread evenly over the peak-amplitude ranking (rank 0 = strongest peak);
# these dimensions are highlighted in the rose plots and shown in panels B-G
dim_ranks = [int(i) for i in np.round(np.linspace(0,dim_data_ranked.shape[2]-1,6))]

plt.close('all')
fig = plt.figure(constrained_layout=True, figsize=(8.25, 11.75))
axd = fig.subplot_mosaic(
    """
    .AAAAAA
    .......
    .BBBCCC
    .BBBCCC
    .DDDEEE
    .DDDEEE
    .FFFGGG
    .FFFGGG
    .......
    """
)
ax_a = axd['A']

#### make figure
times_ms = epochs.times*1000
# Per-participant mean over dimensions as plain arrays, so [idx] below is a positional
# lookup that does not depend on the index stored in the CSV files.
participant_means = [np.asarray(dat.mean(axis=1)) for dat in all_dat_dims]
# participant averages for the snapshots (rows = time samples), computed once
mean_dim_data = np.mean(dim_data_ranked,axis=0)
mean_sens_data = np.mean(all_dat_sens,axis=0)

ax_a.plot(times_ms, epochs.times*0, 'grey', lw=1, linestyle='--')

# plot individual data and mean
for participant_mean in participant_means:
    ax_a.plot(times_ms,participant_mean,'k',alpha=0.3,lw=1)
ax_a.plot(times_ms,np.mean(np.mean(all_dat_dims,axis=0),axis=1),'k',alpha=1,lw=3)

ax_a.set_xlim([times_ms[0],times_ms[-1]])
ax_a.set_ylim([-0.01,0.15])
ax_a.set_xlabel('time (ms)')
ax_a.set_ylabel('Correlation:\n true vs predicted label')

ax_a.spines['top'].set_visible(False)
ax_a.spines['right'].set_visible(False)

# snapshot times (ms), also used as the x ticks
ts = np.arange(-100,1400,200)
ax_a.set_xticks(ts)

# inset geometry in panel-A data units (ms horizontally, correlation vertically)
width_polar = 800
height_polar = .25
width_topo = 150
height_topo = .2

# rose-plot layout: one bar per dimension, evenly spaced around the circle
n_dims = dim_data_ranked.shape[2]
theta = np.linspace(0.0, 2 * np.pi, n_dims, endpoint=False)
width = np.pi/n_dims
# largest accepted distance between a requested snapshot time and its nearest sample
half_sample_ms = 1000/(2*epochs.info['sfreq'])

for t in ts:
    # Nearest sample instead of exact float equality: epochs.times*1000 is not
    # guaranteed to reproduce the integer t exactly.
    t_idx = int(np.argmin(np.abs(times_ms - t)))
    if abs(times_ms[t_idx] - t) > half_sample_ms:
        raise ValueError(f'No sample at {t} ms in epochs.times')

    # add polar plot (rose plot of the per-dimension average at this time point)
    axin1 = ax_a.inset_axes([t-width_polar/2,0.12, width_polar,height_polar], transform=ax_a.transData,polar=True)
    axin1.set_aspect('equal')

    radii = mean_dim_data[t_idx,:]
    for dim in range(n_dims):
        # dimensions shown in panels B-G keep their colour; all others are faint grey
        if dim in dim_ranks:
            axin1.bar(theta[dim], radii[dim], width=width, bottom=0, color=colours_ranked[dim,:], alpha=1)
        else:
            axin1.bar(theta[dim], radii[dim], width=width, bottom=0, color='k', alpha=.2)

    axin1.set_ylim(0,0.2)
    axin1.get_xaxis().set_visible(False)
    axin1.get_yaxis().set_visible(False)
    axin1.spines['polar'].set_visible(False)
    axin1.patch.set_alpha(0)

    # add topoplots (sensor-wise average at this time point)
    axin2 = ax_a.inset_axes([t-width_topo/2,-0.3, width_topo,height_topo], transform=ax_a.transData)
    topo,cm = mne.viz.plot_topomap(mean_sens_data[t_idx,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=axin2,
                                   sensors=False,contours=False)

    # a single colour bar, created together with the first snapshot
    if t == ts[0]:
        cbar_ax = fig.add_axes([0.04,0.74,0.008,0.05])
        clb = fig.colorbar(topo, cax=cbar_ax)

    # add arrows: from the bottom of panel A at time t to the top of the topoplot ...
    con = ConnectionPatch(xyA=(t,ax_a.get_ylim()[0]), xyB=(0,axin2.get_ylim()[1]),
                          coordsA="data", coordsB="data",
                          axesA=ax_a, axesB=axin2, color="black", linestyle='--',arrowstyle="-|>")
    ax_a.add_artist(con)

    # ... and from just above the highest participant curve up to the rose plot
    # (from just above the axes when t is the first sample)
    if t_idx!=0:
        ypoint = np.max([participant_mean[t_idx] for participant_mean in participant_means])+0.01
    else:
        ypoint = ax_a.get_ylim()[1]+0.01

    con = ConnectionPatch(xyA=(t,ypoint), xyB=(0.5,0.3),
                          coordsA="data", coordsB="axes fraction",
                          axesA=ax_a, axesB=axin1, color="black", linestyle='--',arrowstyle="-|>")
    ax_a.add_artist(con)

# white background behind the x tick labels (ticks are fixed above, so once suffices)
plt.setp(ax_a.get_xticklabels(), backgroundcolor="white")


# add individual timeseries
# Panels B-G: one panel per rank in dim_ranks, showing every participant (thin lines)
# and the participant average (thick line) for that dimension.
time_axis_ms = epochs.times*1000
zero_line = epochs.times*0

for rank, ax in zip(dim_ranks, [axd[panel] for panel in 'BCDEFG']):
    colour = colours_ranked[rank,:]
    ax.plot(time_axis_ms, zero_line, 'grey', lw=1, linestyle='--')

    for participant_ts in dim_data_ranked[:,:,rank]:
        ax.plot(time_axis_ms, participant_ts, c=colour, alpha=0.3, lw=1)
    ax.plot(time_axis_ms, avg_dim_data_ranked[:,rank], c=colour, lw=2)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_ylim([-0.05,0.5])
    ax.set_ylabel('Correlation')
    ax.set_xlabel('time (ms)')
    # blank title, presumably to keep room above the axes for the label drawn at y=0.5
    ax.set_title(' ')
    panel_label = f'{labels_ranked.loc[rank,0]} ({sorted_idx[rank]+1})'
    ax.text(700, 0.5, panel_label, c='k', fontsize=text_size, ha='center')

    ax.set_xlim([time_axis_ms[0], time_axis_ms[-1]])



# add example images
# Six example images per selected dimension, in a row near the top of panels B-G,
# read from example_images/<ranked label without '/'>/<0..5>.jpg.
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

example_im_path = f'{bids_dir}/sourcedata/meg_paper/example_images/'
image_x_positions = np.arange(0.1,0.9,0.15)

for image_nr, x in enumerate(image_x_positions):
    for dim, panel in zip(dim_ranks, 'BCDEFG'):
        image_file = f'{example_im_path}/{labels_ranked_list[dim]}/{image_nr}.jpg'
        imagebox = OffsetImage(plt.imread(image_file), zoom=0.08)
        ab = AnnotationBbox(imagebox, (x, .73), frameon=False, xycoords='axes fraction', boxcoords="offset points", pad=0)
        axd[panel].add_artist(ab)

# labels
# invisible spacer text; presumably reserves room to the left of panel B
axd['B'].text(-400,0.3,' ',fontsize=text_size_big*3)
for letter, y in (('A', 0.95), ('B', 0.63)):
    fig.text(0.01, y, letter, fontsize=text_size_big*2)

fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Figure2.pdf')





In [None]:
# supplementary figure showing all dimension timecourses
# An 11 x 6 grid with one panel per dimension in ranked order: every participant (thin
# lines) and the participant average (thick line), titled "<label>\n(<dimension number>)".
plt.close('all')
fig, axs = plt.subplots(figsize=(8.25, 11.75), num=2, ncols=6, nrows=11, sharex=True, sharey=True)
axs = axs.flatten()

time_axis_ms = epochs.times*1000
zero_line = epochs.times*0
time_limits = [epochs.times[0]*1000, epochs.times[-1]*1000]

for rank, ax in enumerate(axs):
    colour = colours_ranked[rank,:]
    ax.plot(time_axis_ms, zero_line, 'grey', linestyle='--')
    for participant_ts in dim_data_ranked[:,:,rank]:
        ax.plot(time_axis_ms, participant_ts, c=colour, alpha=0.3, lw=1)
    ax.plot(time_axis_ms, avg_dim_data_ranked[:,rank], c=colour, lw=2)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_ylim([-0.05,0.3])

    ax.set_title(f'{labels_ranked.loc[rank,0]}\n({sorted_idx[rank]+1})', c='k', fontsize=text_size_small)
    ax.set_xlim(time_limits)

fig.supylabel('Correlation: true vs predicted label')
fig.supxlabel('time (ms)')
fig.tight_layout()
fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Supplementary_within.pdf')



In [None]:
# Single example timeseries: participant-averaged correlation of the dimension at rank i
# (rank 0 = strongest peak), saved as a transparent PNG.
fig,ax = plt.subplots(1,1,figsize=(4,3))

i = 7
# Dimension number derived from the ranking instead of being hard-coded, using the same
# sorted_idx+1 convention as the other panel titles in this notebook, so the title
# stays correct if i or the ranking changes.
dim_number = sorted_idx[i]+1

ax.plot(epochs.times*1000,epochs.times*0,'grey',linestyle='--')
ax.plot(epochs.times*1000,avg_dim_data_ranked[:,i],c=colours_ranked[i,:],alpha=1,lw=2)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim([-0.05,0.3])
ax.set_title(f'Dimension {dim_number}: {labels_ranked.loc[i,0]}',c=colours_ranked[i,:],fontsize=text_size)
ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

ax.set_ylabel('Correlation:\ntrue vs predicted label')
ax.set_xlabel('time (ms)')

fig.tight_layout()
fig.savefig(f'{bids_dir}/sourcedata/meg_paper/example_timeseries.png', transparent=True)


In [None]:
# roseplot animation
# Animated rose plot of the participant-averaged correlation for every dimension
# (ranked order), one frame per time sample; dimensions above 0.1 get a text label.
plt.close('all')
# n_dims must exist before theta is built from it; previously theta used whatever n_dims
# was left over from the results-figure cell (NameError in a fresh kernel).
n_dims = np.array(dim_data_ranked).shape[2]
theta = np.linspace(0.0, 2 * np.pi, n_dims, endpoint=False)
width = np.pi/n_dims
# participant average (rows = time samples), computed once instead of once per frame
mean_dim_data = np.mean(dim_data_ranked,axis=0)

fig = plt.figure(num=1, figsize=(10,7))
ax = plt.subplot(projection='polar')


def frame_label(t):
    """Return the time of sample ``t`` in whole milliseconds, e.g. ``'150 ms'``.

    Rounds rather than truncates, so e.g. 149.99999999999997 is shown as 150.
    """
    return str(int(round(epochs.times[t]*1000))) + ' ms'


def init(t=0):
    """Draw the bars for sample ``t`` and set the radial limit and label position."""
    radii = mean_dim_data[t,:]
    bars = ax.bar(theta, radii, width=width, bottom=0, color=colours_ranked, alpha=1)
    ax.set_ylim(0,0.3)
    ax.set_rlabel_position(315)
    ax.set_title(frame_label(t))
    return bars


def update(t):
    """Redraw the rose plot for time sample ``t``.

    ``ax.cla()`` removes every artist on the axes (including text added with
    ``ax.text``), so the time stamp is set as the axes title after redrawing.
    """
    ax.cla()
    ax.set_ylim(0,0.3)
    ax.text(np.radians(ax.get_rlabel_position())-.1,ax.get_rmax()/2,'Correlation',
            rotation=ax.get_rlabel_position(),ha='center',va='center')
    ax.set_xticks([np.radians(315)])
    ax.set_xticklabels([])
    ax.yaxis.grid(True,color='k',linestyle=':',alpha=0.5)

    radii = mean_dim_data[t,:]
    bars = ax.bar(theta, radii, width=width, bottom=0, color=colours_ranked, alpha=1)

    # Label every dimension above 0.1. Iterate the indices directly: the old
    # `np.any(indices)` guard dropped the labels whenever index 0 (the strongest
    # dimension) was the only one above threshold.
    for dim in np.flatnonzero(radii>0.1):
        rotation = np.rad2deg(theta[dim])
        if 90 < rotation < 270:
            # keep text on the left half upright
            rotation = rotation-180
            label = labels_ranked.loc[dim,0] + ' '
            ha = 'right'
        else:
            label = ' ' + labels_ranked.loc[dim,0]
            ha = 'left'
        ax.text(theta[dim],radii[dim],label, ha=ha, va='center', rotation=rotation, rotation_mode="anchor",fontsize=8)

    ax.set_title(frame_label(t))
    return bars


# the last 80 samples are not animated (presumably why the output is named *_short.gif)
ani = FuncAnimation(fig, update, frames=len(epochs.times)-80,init_func=init,blit=False,cache_frame_data=False,repeat=False)
writer = animation.PillowWriter(fps=3)

ani.save(f'{bids_dir}/derivatives/meg_paper/figures/roseplot_short.gif',writer=writer, dpi=600)


In [None]:
# topoplot animation
# Animated topoplot of the participant-averaged sensor-wise results, one frame per
# time sample.
plt.close('all')
fig,ax = plt.subplots(num=1, figsize=(5,5))
# participant average (rows = time samples), computed once instead of once per frame
mean_sens_data = np.mean(all_dat_sens,axis=0)


def frame_label(t):
    """Return the time of sample ``t`` in whole milliseconds, e.g. ``'150 ms'``.

    Rounds rather than truncates, so e.g. 149.99999999999997 is shown as 150.
    """
    return str(int(round(epochs.times[t]*1000))) + ' ms'


def init(t=0):
    """Draw the topoplot for sample ``t`` and add the colour bar."""
    topo,cm = mne.viz.plot_topomap(mean_sens_data[t,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=ax,
                                   sensors=True,contours=False)
    cbar = fig.colorbar(topo,fraction=0.03, pad=0.04)
    cbar.set_label('Correlation')
    fig.tight_layout()
    ax.set_title(frame_label(t))
    return topo


def update(t):
    """Redraw the topoplot for time sample ``t``.

    ``ax.cla()`` removes all artists on the axes, so the time stamp is set as the axes
    title after redrawing. vlim and cmap match ``init``, so the colour bar stays valid.
    """
    ax.cla()
    topo,cm = mne.viz.plot_topomap(mean_sens_data[t,:], epochs.info,vlim=[0,0.2],cmap='RdPu',axes=ax,
                                   sensors=True,contours=False)
    fig.tight_layout()
    ax.set_title(frame_label(t))
    return topo


# the last 80 samples are not animated (presumably why the output is named *_short.gif)
ani = FuncAnimation(fig, update, frames=len(epochs.times)-80,init_func=init,blit=False,cache_frame_data=False,repeat=False)

writer = animation.PillowWriter(fps=3)

ani.save(f'{bids_dir}/derivatives/meg_paper/figures/topoplot_short.gif',writer=writer, dpi=600)

